关于用RNN训练图形验证码损失不收敛的问题
代码:
import torch.optim
from torch import nn
import os
import numpy as np
from PIL import Image
from torch.utils.data import DataLoader,Dataset
class LSTMRNN(nn.Module):
def __init__(self,input_size,hidden_size,output_size,num_layers):
super(LSTMRNN, self).__init__()
self.lstm = nn.LSTM(input_size,hidden_size,num_layers) # 输入变量x的特征数量,隐含层h的特征数量(即层中隐含单元的个数),隐含层的层数
self.linear = nn.Linear(hidden_size,output_size)
def forward(self,x):
x,_ = self.lstm(x)
s,b,h = x.shape
x = self.linear(x)
return x[-1,:,:]
class MyDataSet(Dataset):
def __init__(self,x,y):
self.len = x.shape[0]
self.x = x
self.y = y
def __getitem__(self, item):
return self.x[item],self.y[item]
def __len__(self):
return self.len
def getImage(img_path, img_name):
path = os.path.join(img_path, img_name)
label = img_name.split("_")[0]
img = Image.open(path)
captcha_array = np.array(img) # 向量形式
return label, captcha_array
def convert2gray(img):
if len(img.shape) > 2:
r, g, b = img[:, :, 0], img[:, :, 1], img[:, :, 2]
gray = 0.2989 * r + 0.5870 * g + 0.1140 * b
return gray
else:
return img
def text2vec(text):
text_len = len(text)
if text_len > max_captcha:
raise ValueError('验证码最长{}个字符'.format(max_captcha))
vector = np.zeros(max_captcha * len(char_set))
for i, ch in enumerate(text):
idx = i * len(char_set) + char_set.index(ch)
vector[idx] = 1
return vector
if __name__ == '__main__':
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
char_set = '0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'
max_captcha = 4
file_names = os.listdir('./sample/sample/train')
lines = 5000
start_lines = np.random.randint(1, 5000)
trainx = np.zeros([lines, 60 * 180])
trainy = np.zeros([lines, max_captcha*len(char_set)])
for i in range(start_lines, start_lines+lines):
label, image_array = getImage('./sample/sample/train', file_names[i])
trainx[i-start_lines, :] = convert2gray(image_array).flatten() / 255
trainy[i-start_lines, :] = text2vec(label)
if os.path.exists('./rnnmodel1.pth'):
print("model load")
lstm_model = torch.load('./rnnmodel1.pth')
else :
print("model create")
lstm_model = LSTMRNN(trainx.shape[1], 128, trainy.shape[1], 64).to(device)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(lstm_model.parameters())
print("read finsh")
# for epoch in range(100):
# print("epoch={}".format(epoch))
# for i in range(lines):
# output = lstm_model(torch.Tensor(trainx[i].reshape([1, 10800])).unsqueeze(1).to(device))
# loss = criterion(output, torch.Tensor(trainy[i].reshape([1, 248])).to(device))
#
# optimizer.zero_grad()
# loss.backward()
# optimizer.step()
# torch.save(lstm_model,'./rnnmodel1.pth')
pre_output = lstm_model(torch.Tensor(trainx[100].reshape([1,10800])).unsqueeze(1).to(device))
print(trainy[100])
print(pre_output)
其中读取了验证码图片,然后用平方误差作为损失函数进行收敛,但是一直都没有收敛