• 【OCR】基于Encoder-Decoder的文本识别





            参考论文:《Robust Scene Text Recognition with Automatic Rectification》









    1. import os
    2. import random
    3. import numpy as np
    4. from PIL import Image
    5. import cv2
    6. import torch
    7. import torch.utils.data
    8. import torch.backends.cudnn as cudnn
    9. import torch.nn as nn
    10. import torch.nn.functional as F
    11. import torchvision.transforms as T
    12. from torch.autograd import Variable
    13. import collections
    14. import collections.abc
    15. cudnn.benchmark = True
    class configs():
        """Hard-coded settings shared by the data pipeline, the model, training and inference."""

        def __init__(self):
            # Data
            self.train_list = r'E:\code\OCR\crnn_seq2seq_ocr_pytorch-master\data\train_list.txt'
            self.eval_list = r'E:\code\OCR\crnn_seq2seq_ocr_pytorch-master\data\valid_list.txt'
            self.img_height = 32
            self.img_width = 280
            self.save_model_dir = 'seq_models'
            self.get_lexicon_dir = './lbl2id_map.txt'
            # To load the alphabet from a file instead of hard-coding it:
            # self.lexicon = self.get_lexicon(lexicon_name=self.get_lexicon_dir)
            self.lexicon = "0123456789"
            # character -> position in the lexicon
            self.all_chars = {char: idx for idx, char in enumerate(self.lexicon)}
            # position in the lexicon -> character
            self.all_nums = dict(enumerate(self.lexicon))
            # two extra classes on top of the lexicon (the SOS/EOS tokens used by the converter)
            self.class_num = len(self.lexicon) + 2
            self.label_word_length = 4
            self.random_sample = True  # whether to sample the data randomly
            self.teaching_forcing_prob = 0.5

            # Train
            self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
            self.batch_size = 64
            self.epoch = 31
            self.save_model_fre_epoch = 1
            self.hidden_size = 256  # hidden state size
            self.learning_rate = 0.0001
            self.encoder = ''
            self.decoder = ''
            self.max_width = 71  # longest sequence length (passed to the decoder as max_length)

            # Test / infer
            self.test_img_paths = r'E:\code\OCR\new_ocr\captcha_datasets\test-data-1'
            self.test_encoder_path = r'E:\code\OCR\crnn_seq2seq_ocr_pytorch-master\model\encoder_30.pth'
            self.test_decoder_path = r'E:\code\OCR\crnn_seq2seq_ocr_pytorch-master\model\decoder_30.pth'
            self.istrain = False
            self.istest = True

        def get_lexicon(self, lexicon_name):
            """Read the alphabet from a tab-separated label map file.

            The file is expected to hold one ``<character>\\t<id>`` pair per line, e.g.::

                0\\t0
                a\\t1
                ...
                z\\t63

            :param lexicon_name: path of the label map file (read as UTF-8).
            :return: the first field of every non-blank line, concatenated in file order.
            """
            # The context manager closes the file even if reading fails.
            with open(lexicon_name, 'r', encoding='utf-8') as fp:
                entries = [line.rstrip('\r\n') for line in fp]
            # Blank lines are skipped so a trailing newline cannot leak '\n' into the alphabet.
            return ''.join(entry.split('\t')[0] for entry in entries if entry)


    cfg = configs()
    # Data
    class TextLineDataset(torch.utils.data.Dataset):
        """Dataset backed by a text file holding one ``<image path> <label>`` pair per line.

        Fields are separated by whitespace; only the first two fields of a line are used.
        Blank lines in the list file are ignored.
        """

        def __init__(self, text_line_file=None, transform=None, target_transform=None):
            """
            Args:
                text_line_file (str): path of the list file.
                transform (callable, optional): applied to the loaded RGB PIL image.
                target_transform (callable, optional): applied to the label string.
            """
            self.text_line_file = text_line_file
            with open(text_line_file) as fp:
                # Drop blank lines so they cannot surface later as unparsable samples.
                self.lines = [line for line in fp if line.strip()]
            self.nSamples = len(self.lines)
            self.transform = transform
            self.target_transform = target_transform

        def __len__(self):
            return self.nSamples

        def __getitem__(self, index):
            """Return ``(image, label)`` for ``index``.

            If the image cannot be opened, the following samples are tried in order
            (wrapping around at the end of the list) until one loads.

            Raises:
                IndexError: ``index`` is outside ``[-len(self), len(self))``.
                ValueError: the selected line has no label field.
                RuntimeError: no image in the whole list could be opened.
            """
            total = len(self)
            if not -total <= index < total:
                raise IndexError('index range error')
            start = int(index) % total
            # Bounded scan instead of recursion: cannot run past the end or recurse forever.
            for offset in range(total):
                position = (start + offset) % total
                fields = self.lines[position].strip().split()
                if len(fields) < 2:
                    raise ValueError('line %d of %s has no label' % (position, self.text_line_file))
                img_path, label = fields[0], fields[1]
                try:
                    img = Image.open(img_path).convert('RGB')
                except IOError:
                    print('Corrupted image for %d' % position)
                    continue
                if self.transform is not None:
                    img = self.transform(img)
                if self.target_transform is not None:
                    label = self.target_transform(label)
                return (img, label)
            raise RuntimeError('no readable image listed in %s' % self.text_line_file)
    class ResizeNormalize(object):
        """Scale an image to a fixed height, fit it to a fixed width, and normalize it.

        The image is resized to ``img_height`` keeping its aspect ratio. If the
        resulting width reaches ``img_width`` the image is squashed to exactly
        ``img_width``; otherwise it is pasted at the left of a zero-filled canvas
        of width ``img_width``.
        """

        def __init__(self, img_width, img_height):
            self.img_width = img_width
            self.img_height = img_height
            self.toTensor = T.ToTensor()

        def __call__(self, img):
            """Return a normalized tensor for ``img``.

            Args:
                img: anything ``np.array`` turns into an ``H x W x C`` array
                    (callers in this file pass RGB PIL images).

            Returns:
                The tensor produced by ``ToTensor`` after ``(x - 0.5) / 0.5``
                (for 8-bit images ``ToTensor`` yields values in [0, 1], so the
                result lies in [-1, 1]).
            """
            img = np.array(img)
            h, w, c = img.shape
            height = self.img_height
            # Keep at least one pixel of width: extremely tall inputs would otherwise
            # round down to 0, which cv2.resize rejects.
            width = max(1, int(w * height / h))
            if width >= self.img_width:
                img = cv2.resize(img, (self.img_width, self.img_height))
            else:
                img = cv2.resize(img, (width, height))
                # Left-align the resized image on a zero canvas of the target width.
                img_pad = np.zeros((self.img_height, self.img_width, c), dtype=img.dtype)
                img_pad[:height, :width, :] = img
                img = img_pad
            img = Image.fromarray(img)
            img = self.toTensor(img)
            img.sub_(0.5).div_(0.5)
            return img
    class RandomSequentialSampler(torch.utils.data.sampler.Sampler):
        """Sampler that yields runs of consecutive indices starting at random offsets.

        Each run is ``batch_size`` long, so a DataLoader using the same batch size
        receives batches of neighbouring samples. Runs are drawn independently and
        may overlap.
        """

        def __init__(self, data_source, batch_size):
            """
            Args:
                data_source: any sized collection; only ``len(data_source)`` is used.
                batch_size (int): length of each consecutive run.
            """
            self.num_samples = len(data_source)
            self.batch_size = batch_size

        def __iter__(self):
            total = len(self)
            n_batches, tail = divmod(total, self.batch_size)
            # Largest start that keeps a full run inside the data; clamped at 0 so a
            # dataset smaller than one batch does not give randint a negative bound.
            max_start = max(0, total - self.batch_size)
            index = torch.zeros(total, dtype=torch.long)
            for i in range(n_batches):
                random_start = random.randint(0, max_start)
                batch_index = random_start + torch.arange(0, self.batch_size)
                index[i * self.batch_size:(i + 1) * self.batch_size] = batch_index
            # deal with tail
            if tail:
                random_start = random.randint(0, max_start)
                tail_index = random_start + torch.arange(0, tail)
                # n_batches (not the loop variable) so this also works when the loop never ran.
                index[n_batches * self.batch_size:] = tail_index
            # Plain Python ints are what dataset indexing expects.
            return iter(index.tolist())

        def __len__(self):
            return self.num_samples
    class AlignCollate(object):
        """Collate function that resizes/normalizes each image and batches the results."""

        def __init__(self, img_height=32, img_width=100):
            self.img_height = img_height
            self.img_width = img_width
            self.transform = ResizeNormalize(img_width=self.img_width, img_height=self.img_height)

        def __call__(self, batch):
            """Turn a list of ``(image, label)`` pairs into ``(image_batch, labels)``.

            Returns:
                A tensor with the transformed images concatenated along a new
                leading dimension, and the labels as a tuple in the same order.
            """
            pictures, labels = zip(*batch)
            # Give every transformed image a leading batch axis, then join them.
            batched = [self.transform(picture).unsqueeze(0) for picture in pictures]
            return torch.cat(batched, 0), labels
    def load_data(v, data):
        """Overwrite the buffer ``v`` in place with the contents of ``data``.

        ``v`` is first resized to ``data``'s size and then filled by copying;
        both steps run with gradient tracking disabled. Nothing is returned.
        """
        with torch.no_grad():
            v.resize_(data.size())
            v.copy_(data)
    SOS_TOKEN = 0  # special token for start of sentence
    EOS_TOKEN = 1  # special token for end of sentence


    class ConvertBetweenStringAndLabel(object):
        """Convert between strings and class-index labels.

        Index 0 is reserved for ``SOS_TOKEN`` and index 1 for ``EOS_TOKEN``;
        the i-th character of the alphabet is mapped to index ``i + 2``.

        Args:
            alphabet (str): set of the possible characters.
        """

        def __init__(self, alphabet):
            self.alphabet = alphabet
            self.dict = {}
            self.dict['SOS_TOKEN'] = SOS_TOKEN
            self.dict['EOS_TOKEN'] = EOS_TOKEN
            for i, item in enumerate(self.alphabet):
                self.dict[item] = i + 2
            # Reverse lookup (index -> key) so decode() does not scan the dict.
            self._labels = {index: key for key, index in self.dict.items()}

        def encode(self, text):
            """Encode one string or a batch of strings.

            Args:
                text (str or iterable of str): texts to convert.

            Returns:
                torch.LongTensor: for a single string, a 1-D tensor with one index
                per character. For a batch, a tensor of shape
                ``(max_length + 2, batch_size)`` where every column is
                ``SOS, <characters>, EOS`` followed by padding.

            Note:
                Characters missing from the alphabet and the padding positions are
                both written as index 2, which is also the index of the first
                alphabet character.

            Raises:
                TypeError: ``text`` is neither a string nor an iterable.
                ValueError: ``text`` is an empty batch.
            """
            if isinstance(text, str):
                return torch.LongTensor([self.dict.get(char, 2) for char in text])
            if not isinstance(text, collections.abc.Iterable):
                # torch.LongTensor(<int>) would silently return uninitialised memory.
                raise TypeError('text must be a str or an iterable of str')

            encoded = [self.encode(s) for s in text]
            if not encoded:
                raise ValueError('cannot encode an empty batch')
            max_length = max(len(labels) for labels in encoded)
            # +2 leaves room for the SOS and EOS markers around the longest text.
            targets = torch.full((len(encoded), max_length + 2), 2, dtype=torch.long)
            for row, labels in zip(targets, encoded):
                row[0] = SOS_TOKEN
                row[1:len(labels) + 1] = labels
                row[len(labels) + 1] = EOS_TOKEN
            # Time-major layout: one row per decoding step.
            return targets.transpose(0, 1).contiguous()

        def decode(self, t):
            """Map a single class index back to its dictionary key.

            Args:
                t (int or one-element tensor): class index.

            Returns:
                str: the matching alphabet character, or the literal key
                ``'SOS_TOKEN'`` / ``'EOS_TOKEN'`` for the two special indices.

            Raises:
                ValueError: ``t`` is not a known index.
            """
            index = int(t)
            try:
                return self._labels[index]
            except KeyError:
                raise ValueError('%d is not a known label index' % index) from None
    # Module-wide converter built over the configured lexicon.
    converter = ConvertBetweenStringAndLabel(alphabet=cfg.lexicon)
    # Model
    class CNN(nn.Module):
        """Convolutional feature extractor that collapses the image height to 1."""

        def __init__(self, channel_size):
            super().__init__()
            # Layers are listed in execution order; their positions in the
            # Sequential determine the parameter names in the state dict.
            layers = [
                nn.Conv2d(channel_size, 64, kernel_size=3, stride=1, padding=1),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=2, stride=2),

                nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=2, stride=2),

                nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(256),
                nn.ReLU(inplace=True),

                nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
                nn.ReLU(inplace=True),
                # Halve the height only; stride 1 along the width.
                nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 1), padding=(0, 1)),

                nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(512),
                nn.ReLU(inplace=True),

                nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 1), padding=(0, 1)),

                nn.Conv2d(512, 512, kernel_size=2, stride=1, padding=0),
                nn.BatchNorm2d(512),
                nn.ReLU(inplace=True),
            ]
            self.cnn = nn.Sequential(*layers)

        def forward(self, input):
            # [n, channel_size, 32, 280] -> [n, 512, 1, 71]
            return self.cnn(input)
    class BidirectionalLSTM(nn.Module):
        """Bidirectional LSTM followed by a linear projection applied at every time step."""

        def __init__(self, input_size, hidden_size, output_size):
            super().__init__()
            self.rnn = nn.LSTM(input_size, hidden_size, bidirectional=True)
            # Both directions are concatenated, hence hidden_size * 2 input features.
            self.embedding = nn.Linear(hidden_size * 2, output_size)

        def forward(self, input):
            """Map ``[steps, batch, input_size]`` to ``[steps, batch, output_size]``."""
            recurrent, _ = self.rnn(input)
            steps, batch, features = recurrent.size()
            # Fold time and batch together for the linear layer, then unfold again.
            projected = self.embedding(recurrent.view(steps * batch, features))
            return projected.view(steps, batch, -1)
    class AttnDecoderRNN(nn.Module):
        """Single-step GRU decoder with attention over a fixed number of encoder steps.

        The attention layer always produces ``max_length`` weights, so
        ``encoder_outputs`` must contain exactly ``max_length`` time steps.
        """

        def __init__(self, hidden_size, output_size, dropout_p=0.1, max_length=71):
            super().__init__()
            self.hidden_size = hidden_size
            self.output_size = output_size
            self.dropout_p = dropout_p
            self.max_length = max_length

            self.embedding = nn.Embedding(self.output_size, self.hidden_size)
            self.attn = nn.Linear(self.hidden_size * 2, self.max_length)
            self.attn_combine = nn.Linear(self.hidden_size * 2, self.hidden_size)
            self.dropout = nn.Dropout(self.dropout_p)
            self.gru = nn.GRU(self.hidden_size, self.hidden_size)
            self.out = nn.Linear(self.hidden_size, self.output_size)

        def forward(self, input, hidden, encoder_outputs):
            """Run one decoding step.

            Args:
                input: class indices of the previous step, one per batch element.
                hidden: GRU state of shape ``[1, batch, hidden_size]``.
                encoder_outputs: encoder features, ``[max_length, batch, hidden_size]``.

            Returns:
                ``(log_probs, hidden, attn_weights)``: log-softmax scores over the
                classes, the updated GRU state, and the attention weights.
            """
            embedded = self.dropout(self.embedding(input))

            # Attention weights from the previous token and the current state.
            attn_scores = self.attn(torch.cat((embedded, hidden[0]), 1))
            attn_weights = F.softmax(attn_scores, dim=1)

            # Weighted sum of the encoder outputs (batch-first for bmm).
            context = torch.bmm(attn_weights.unsqueeze(1), encoder_outputs.permute(1, 0, 2))
            context = context.squeeze(1)

            gru_input = self.attn_combine(torch.cat((embedded, context), 1)).unsqueeze(0)
            gru_input = F.relu(gru_input)
            output, hidden = self.gru(gru_input, hidden)

            log_probs = F.log_softmax(self.out(output[0]), dim=1)
            return log_probs, hidden, attn_weights

        def initHidden(self):
            """Return an all-zero state for a batch of one on the configured device."""
            return torch.zeros(1, 1, self.hidden_size, device=cfg.device)
    class Encoder(nn.Module):
        """CNN feature extractor followed by two stacked bidirectional LSTM layers."""

        def __init__(self, channel_size, hidden_size):
            super().__init__()
            self.cnn = CNN(channel_size)
            recurrent_layers = [
                BidirectionalLSTM(512, hidden_size, hidden_size),
                BidirectionalLSTM(hidden_size, hidden_size, hidden_size),
            ]
            self.rnn = nn.Sequential(*recurrent_layers)

        def forward(self, input):
            """Encode an image batch into a width-major feature sequence.

            The convolutional output must have height 1; it is reshaped from
            ``[b, c, 1, w]`` to ``[w, b, c]`` before entering the recurrent layers.
            """
            # conv features
            features = self.cnn(input)
            _, _, height, _ = features.size()
            assert height == 1, "the height of conv must be 1"

            # rnn feature: [b, c, 1, w] -> [b, c, w] -> [w, b, c]
            sequence = features.squeeze(2).permute(2, 0, 1)
            return self.rnn(sequence)
    class Decoder(nn.Module):
        """Thin wrapper that owns an ``AttnDecoderRNN`` and forwards calls to it."""

        def __init__(self, hidden_size, output_size, dropout_p=0.1, max_length=71):
            super().__init__()
            self.hidden_size = hidden_size
            self.decoder = AttnDecoderRNN(hidden_size, output_size, dropout_p, max_length)

        def forward(self, input, hidden, encoder_outputs):
            """Delegate one decoding step to the wrapped attention decoder."""
            return self.decoder(input, hidden, encoder_outputs)

        def initHidden(self, batch_size):
            """Return an all-zero initial state of shape ``[1, batch_size, hidden_size]``."""
            return Variable(torch.zeros(1, batch_size, self.hidden_size))
    # Utility functions
    # Model initialization
    def weights_init(model):
        """Re-initialize the convolution, batch-norm and linear layers inside ``model``.

        Conv2d weights get Kaiming-normal values, BatchNorm2d layers are reset to
        weight 1 / bias 0, and Linear biases are zeroed (Linear weights are left
        untouched). Every sub-module of ``model`` is visited.
        """
        for module in model.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight)
                continue
            if isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
                continue
            if isinstance(module, nn.Linear):
                nn.init.constant_(module.bias, 0)
    # Running average of the loss
    class Averager(object):
        """Accumulate values and report their element-wise average.

        Tensors contribute all of their elements; plain Python numbers count as
        a single element.
        """

        def __init__(self):
            self.reset()

        def add(self, v):
            """Add ``v`` to the running total.

            Args:
                v (torch.Tensor or number): value(s) to accumulate. Tensors are
                    detached first so the accumulator does not keep the autograd
                    graph alive.
            """
            if isinstance(v, torch.Tensor):
                count = v.numel()
                v = v.detach().sum()
            else:
                # Plain numbers previously left `count` undefined.
                count = 1
            self.n_count += count
            self.sum += v

        def reset(self):
            """Forget everything accumulated so far."""
            self.n_count = 0
            self.sum = 0

        def val(self):
            """Return the average of the accumulated elements, or 0 when empty."""
            res = 0
            if self.n_count != 0:
                res = self.sum / float(self.n_count)
            return res
    class ocr():
        """Training and inference entry points for the encoder / attention-decoder model.

        All settings are read from the module-level ``cfg`` object.
        """

        def train(self):
            """Train the encoder and decoder on ``cfg.train_list`` and save checkpoints.

            Each batch is decoded step by step; with probability
            ``cfg.teaching_forcing_prob`` the ground-truth token is fed back
            (teacher forcing), otherwise the model's own prediction is. Checkpoints
            are written to ``cfg.save_model_dir`` every
            ``cfg.save_model_fre_epoch`` epochs.
            """
            # create train dataset
            train_dataset = TextLineDataset(text_line_file=cfg.train_list, transform=None)
            sampler = RandomSequentialSampler(train_dataset, cfg.batch_size)
            train_loader = torch.utils.data.DataLoader(
                train_dataset, batch_size=cfg.batch_size, shuffle=False, sampler=sampler, num_workers=4,
                collate_fn=AlignCollate(img_height=cfg.img_height, img_width=cfg.img_width))

            # create crnn/seq2seq/attention network
            encoder = Encoder(channel_size=3, hidden_size=cfg.hidden_size)
            # for prediction of an indefinite long sequence
            decoder = Decoder(hidden_size=cfg.hidden_size, output_size=cfg.class_num, dropout_p=0.1,
                              max_length=cfg.max_width)
            encoder.apply(weights_init)
            decoder.apply(weights_init)
            encoder.to(cfg.device)
            decoder.to(cfg.device)
            criterion = torch.nn.NLLLoss().to(cfg.device)

            # optimizer
            encoder_optimizer = torch.optim.Adam(encoder.parameters(), lr=cfg.learning_rate, betas=(0.5, 0.999))
            decoder_optimizer = torch.optim.Adam(decoder.parameters(), lr=cfg.learning_rate, betas=(0.5, 0.999))

            # torch.save does not create missing directories.
            os.makedirs(cfg.save_model_dir, exist_ok=True)
            save_every = max(1, int(cfg.save_model_fre_epoch))

            # loss averager
            loss_avg = Averager()

            for epoch in range(cfg.epoch):
                encoder.train()
                decoder.train()
                for i, (cpu_images, cpu_texts) in enumerate(train_loader):
                    batch_size = cpu_images.size(0)
                    image = cpu_images.to(cfg.device)
                    target_variable = converter.encode(cpu_texts).to(cfg.device)

                    # CNN + BiLSTM
                    encoder_outputs = encoder(image)

                    # Row 0 of the targets holds the SOS token for every sample.
                    decoder_input = target_variable[0]
                    decoder_hidden = decoder.initHidden(batch_size).to(cfg.device)

                    loss = 0.0
                    teach_forcing = random.random() < cfg.teaching_forcing_prob
                    for di in range(1, target_variable.shape[0]):
                        decoder_output, decoder_hidden, _ = decoder(decoder_input, decoder_hidden,
                                                                    encoder_outputs)
                        loss += criterion(decoder_output, target_variable[di])
                        if teach_forcing:
                            decoder_input = target_variable[di]
                        else:
                            # squeeze(1) keeps the batch axis even when batch_size == 1.
                            _, topi = decoder_output.detach().topk(1)
                            decoder_input = topi.squeeze(1)

                    encoder.zero_grad()
                    decoder.zero_grad()
                    loss.backward()
                    encoder_optimizer.step()
                    decoder_optimizer.step()

                    loss_avg.add(loss)
                    if i % 10 == 0:
                        print(
                            '[Epoch {0}/{1}] [Batch {2}/{3}] Loss: {4}'.format(epoch, cfg.epoch, i, len(train_loader),
                                                                              loss_avg.val()))
                        loss_avg.reset()

                # save checkpoint
                if epoch % save_every == 0:
                    torch.save(encoder.state_dict(), '{0}/encoder_{1}.pth'.format(cfg.save_model_dir, epoch))
                    torch.save(decoder.state_dict(), '{0}/decoder_{1}.pth'.format(cfg.save_model_dir, epoch))

        def infer(self):
            """Run the saved model on every file in ``cfg.test_img_paths`` and print the accuracy.

            The expected text of an image is the second ``_``-separated field of
            its file name with ``.png`` removed; files without such a field count
            as wrong predictions.
            """
            # Build the networks and load the weights once, not once per image.
            encoder = Encoder(3, cfg.hidden_size).to(cfg.device)
            # no dropout during inference
            decoder = Decoder(cfg.hidden_size, cfg.class_num, dropout_p=0.0, max_length=cfg.max_width).to(cfg.device)
            # map_location follows cfg.device so CPU-only machines can load the checkpoints.
            encoder.load_state_dict(torch.load(cfg.test_encoder_path, map_location=cfg.device))
            decoder.load_state_dict(torch.load(cfg.test_decoder_path, map_location=cfg.device))
            encoder.eval()
            decoder.eval()

            transformer = ResizeNormalize(img_width=cfg.img_width, img_height=cfg.img_height)
            file_names = os.listdir(cfg.test_img_paths)
            max_length = 20
            correct = 0

            with torch.no_grad():
                for file_name in file_names:
                    test_img_path = os.path.join(cfg.test_img_paths, file_name)
                    image = Image.open(test_img_path).convert('RGB')
                    image = transformer(image).to(cfg.device)
                    image = image.view(1, *image.size())

                    encoder_out = encoder(image)
                    decoder_input = torch.full((1,), SOS_TOKEN, dtype=torch.long, device=cfg.device)
                    decoder_hidden = decoder.initHidden(1).to(cfg.device)
                    words, prob = self.seq2seq_decode(encoder_out, decoder, decoder_input, decoder_hidden,
                                                      max_length)

                    name_fields = file_name.replace('.png', '').split('_')
                    expected = name_fields[1] if len(name_fields) > 1 else None
                    if words == expected:
                        correct += 1

            # Guard against an empty test directory.
            accuracy = correct / len(file_names) if file_names else 0.0
            print("model" + '\t' + "|| acc: " + str(accuracy) + '\n')

        # Greedy decoding for inference
        def seq2seq_decode(self, encoder_out, decoder, decoder_input, decoder_hidden, max_length):
            """Greedily decode one sequence (batch size 1).

            Decoding stops at ``EOS_TOKEN`` or after ``max_length`` steps.

            Returns:
                ``(words, prob)``: the decoded string and the product of the
                probabilities of the chosen tokens (the EOS step included).
            """
            decoded_words = []
            prob = 1.0
            for _ in range(max_length):
                decoder_output, decoder_hidden, _ = decoder(decoder_input, decoder_hidden, encoder_out)
                # decoder_output holds log-probabilities; exp() recovers the probability.
                top_log_prob, topi = decoder_output.detach().topk(1)
                ni = topi.squeeze(1)
                decoder_input = ni
                prob *= torch.exp(top_log_prob).item()
                if ni.item() == EOS_TOKEN:
                    break
                decoded_words.append(converter.decode(ni))

            words = ''.join(decoded_words)
            return words, prob
    if __name__ == '__main__':
        # Run training and/or inference depending on the flags in cfg.
        myocr = ocr()
        if cfg.istrain:
            myocr.train()
        if cfg.istest:
            myocr.infer()






  • 相关阅读:
    SQL Server 阻止了对组件 ‘Ole Automation Procedures‘ 的 过程‘sys.sp_OACreate‘ 的访问
    使用ScottPlot库在.NET WinForms中快速实现大型数据集的交互式显示
    [Web安全 网络安全]-Burp Suite抓包软件‘下载‘安装‘配置‘与‘使用‘
    【杂烩】TeX Live+TeXStudio
    iNFTnews | 创造者经济的未来在Web3世界中该去向何处?
    MASA Framework - DDD设计(1)
  • 原文地址:https://blog.csdn.net/weian4913/article/details/126228281