【TensorFlow Deep Learning】Handwritten Text Recognition and Prediction in Practice (with source code and dataset; fully detailed)



    I. Dataset Overview

    The application below recognizes handwritten English text using the IAM dataset. The IAM database contains handwritten English text and can be used to train and test handwritten-text recognizers and to perform writer identification and verification. It was first released at ICDAR 1999; a handwritten-sentence recognition system based on hidden Markov models was developed from it and published at ICPR 2000. IAM contains unconstrained handwritten text, scanned at 300 dpi and saved as 256-level grayscale PNG images. The current version of the database is 3.0, structured roughly as follows:

    Handwriting samples contributed by about 700 writers

    More than 1,500 pages of scanned text

    About 6,000 independently labeled sentences

    More than 10,000 independently labeled text lines

    More than 100,000 independently labeled words

    A few of the dataset's many scanned handwriting images are shown below.

     

     

    II. Implementation Steps

    1: Data cleaning

    Remove the comment lines and the entries marked as errors from the annotation file, count the remaining valid handwriting samples, and finally shuffle the cleaned data into random order, as the sketch below shows.
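    The sketch below mirrors the full listing in Part V; it assumes the IAM words.txt annotation format, where comment lines begin with "#" and the second space-separated field is an "ok"/"err" segmentation flag:

        import numpy as np

        # Keep only non-comment entries whose segmentation result is marked "ok".
        corpus = []
        for line in open("data/words.txt", "r").readlines():
            if line.startswith("#"):        # drop remark/comment lines
                continue
            if line.split(" ")[1] == "ok":  # drop entries marked as errors
                corpus.append(line)

        np.random.shuffle(corpus)           # randomize the sample order
        print("Valid handwriting samples:", len(corpus))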

    2: Sample splitting

    Next, split the samples into three datasets in an 8:1:1 ratio: a training set, a validation set, and a test set. The model is fitted on the training set, while the test set is used to assess the trained model's effectiveness, as the sketch below shows.
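    The split reduces to two cut points at 80% and 90% of the shuffled sample list (this mirrors the listing in Part V):

        # 80% training, 10% validation, 10% test
        train_flag = int(0.8 * len(corpus))
        test_flag = int(0.9 * len(corpus))
        train_data = corpus[:train_flag]
        validation_data = corpus[train_flag:test_flag]
        test_data = corpus[test_flag:]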

    3: Mapping between characters and numbers

    The mapping from characters to numbers is implemented with the StringLookup layer from TensorFlow's Keras package. Its main parameters are described below, with a small example after the list.

    max_tokens: the maximum size of the vocabulary

    num_oov_indices: the number of out-of-vocabulary (OOV) indices to reserve

    mask_token: the token that represents masked input

    oov_token: only used when invert is True; the value returned for OOV indices. Defaults to "[UNK]".
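    A small example of the mapping in both directions; the inverse lookup reuses the forward table's vocabulary with invert=True (the toy three-character vocabulary is illustrative only):

        from tensorflow.keras.layers.experimental.preprocessing import StringLookup
        import tensorflow as tf

        chars = ["a", "b", "c"]                       # toy character vocabulary
        char_to_num = StringLookup(vocabulary=chars)  # characters -> integer ids
        num_to_char = StringLookup(
            vocabulary=char_to_num.get_vocabulary(), invert=True  # ids -> characters
        )

        ids = char_to_num(tf.constant(["a", "c", "b"]))
        print(ids.numpy())               # integer ids; exact values depend on the reserved OOV/mask slots
        print(num_to_char(ids).numpy())  # [b'a' b'c' b'b']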

    4: Convolution transforms

    Two-dimensional convolution is applied with the Conv2D layer. Its main parameters are described below, with a small example after the list.

    filters: an integer, the dimensionality of the output space (the number of output feature maps)

    kernel_size: an integer or a tuple/list of two integers specifying the height and width of the convolution window

    strides: an integer or a tuple/list of two integers specifying the stride of the convolution along the height and width

    padding: the padding scheme for the output, "valid" or "same"

    activation: the activation function to use
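    For instance, a layer like the first convolution block in Part V can be built and applied as follows (the dummy input shape is illustrative):

        import tensorflow as tf
        from tensorflow import keras

        conv = keras.layers.Conv2D(
            filters=32,          # dimensionality of the output space: 32 feature maps
            kernel_size=(3, 3),  # 3x3 convolution window
            strides=(1, 1),      # stride 1 along height and width
            padding="same",      # "same" pads so the spatial size is preserved
            activation="relu",   # activation applied to the outputs
        )
        x = tf.zeros((1, 32, 128, 1))  # one dummy 32x128 grayscale image
        print(conv(x).shape)           # (1, 32, 128, 32)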

    III. Results

    The ground-truth text read from some of the handwriting samples is shown below.

    After training finishes, the trained model is applied to the imported test handwriting data to predict the writing; some of the predictions are shown below.

     

    IV. Summary of Results

    Judging from the prediction results, with average pooling and early stopping on the monitored training metrics, most English characters are predicted correctly. Training accuracy improves steadily, the loss stays within a reasonable range, and validation performance never fails to improve for many consecutive epochs, so the model trains stably.
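    The guard against stalled training mentioned above is the EarlyStopping callback that the training code in Part V uses; a minimal sketch:

        from tensorflow import keras

        # Stop when validation loss has not improved for 10 consecutive epochs
        # and roll back to the best weights seen so far.
        early_stopping = keras.callbacks.EarlyStopping(
            monitor="val_loss", patience=10, restore_best_weights=True
        )
        # Pass it to model.fit(..., callbacks=[early_stopping]).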

    V. Code

    Part of the code is shown below.

    from tensorflow.keras.layers.experimental.preprocessing import StringLookup
    from tensorflow import keras
    import matplotlib.pyplot as plt
    import tensorflow as tf
    import numpy as np
    import os

    plt.rcParams['font.family'] = ['Microsoft YaHei']
    np.random.seed(0)
    tf.random.set_seed(0)

    # ## Split the data
    # In[ ]:
    corpus_read = open("data/words.txt", "r").readlines()
    corpus = []
    length_corpus = 0
    for word in corpus_read:
        # Keep only non-comment entries whose segmentation flag is "ok"
        if (word[0] != "#") and (word.split(" ")[1] == "ok"):
            corpus.append(word)
    np.random.shuffle(corpus)
    length_corpus = len(corpus)
    print(length_corpus)
    corpus[400:405]
    # Split the data 80:10:10 into training:validation:test
    # In[ ]:
    train_flag = int(0.8 * len(corpus))
    test_flag = int(0.9 * len(corpus))
    train_data = corpus[:train_flag]
    validation_data = corpus[train_flag:test_flag]
    test_data = corpus[test_flag:]
    train_data_len = len(train_data)
    validation_data_len = len(validation_data)
    test_data_len = len(test_data)
    print("Training samples:", train_data_len)
    print("Validation samples:", validation_data_len)
    print("Test samples:", test_data_len)
    # In[ ]:
    image_direct = os.path.join("data", "images")
    def retrieve_image_info(data):
        image_location = []
        sample = []
        for (i, corpus_row) in enumerate(data):
            corpus_strip = corpus_row.strip()
            corpus_strip = corpus_strip.split(" ")
            image_name = corpus_strip[0]
            leve1 = image_name.split("-")[0]
            leve2 = image_name.split("-")[1]
            image_location_detail = os.path.join(
                image_direct, leve1, leve1 + "-" + leve2, image_name + ".png"
            )
            # Skip empty/corrupt image files
            if os.path.getsize(image_location_detail) > 0:
                image_location.append(image_location_detail)
                sample.append(corpus_row.split("\n")[0])
        print("Image path:", image_location[0], "Label text:", sample[0])
        return image_location, sample

    train_image, train_tag = retrieve_image_info(train_data)
    validation_image, validation_tag = retrieve_image_info(validation_data)
    test_image, test_tag = retrieve_image_info(test_data)
    # In[ ]:
    # Find the maximum label length and character vocabulary in the training data
    train_tag_extract = []
    vocab = set()
    max_len = 0
    for tag in train_tag:
        tag = tag.split(" ")[-1].strip()
        for i in tag:
            vocab.add(i)
        max_len = max(max_len, len(tag))
        train_tag_extract.append(tag)
    print("Max label length: ", max_len)
    print("Vocab size: ", len(vocab))
    print("Vocabulary: ", vocab)
    train_tag_extract[40:45]
    # In[ ]:
    print(train_tag[50:54])
    print(validation_tag[10:14])
    print(test_tag[80:84])

    def extract_tag_info(tags):
        extract_tag = []
        for tag in tags:
            tag = tag.split(" ")[-1].strip()
            extract_tag.append(tag)
        return extract_tag

    train_tag_tune = extract_tag_info(train_tag)
    validation_tag_tune = extract_tag_info(validation_tag)
    test_tag_tune = extract_tag_info(test_tag)
    print(train_tag_tune[50:54])
    print(validation_tag_tune[10:14])
    print(test_tag_tune[80:84])

    # In[ ]:
    AUTOTUNE = tf.data.AUTOTUNE
    # Map characters to integer ids
    string_to_no = StringLookup(vocabulary=list(vocab), invert=False)
    # Map integer ids back to characters
    no_map_string = StringLookup(
        vocabulary=string_to_no.get_vocabulary(), invert=True)
    # In[ ]:
    def distortion_free_resize(image, img_size):
        w, h = img_size
        # Resize while preserving aspect ratio, then pad up to the target size
        image = tf.image.resize(image, size=(h, w), preserve_aspect_ratio=True, antialias=False, name=None)
        # Compute the amount of padding needed on each side
        pad_height = h - tf.shape(image)[0]
        pad_width = w - tf.shape(image)[1]
        if pad_height % 2 != 0:
            height = pad_height // 2
            pad_height_top = height + 1
            pad_height_bottom = height
        else:
            pad_height_top = pad_height_bottom = pad_height // 2
        if pad_width % 2 != 0:
            width = pad_width // 2
            pad_width_left = width + 1
            pad_width_right = width
        else:
            pad_width_left = pad_width_right = pad_width // 2
        image = tf.pad(
            image,
            paddings=[
                [pad_height_top, pad_height_bottom],
                [pad_width_left, pad_width_right],
                [0, 0],
            ],
        )
        # Transpose and flip so the CTC time axis runs along the image width
        image = tf.transpose(image, perm=[1, 0, 2])
        image = tf.image.flip_left_right(image)
        return image
    # In[ ]:
    batch_size = 64
    padding_token = 99
    image_width = 128
    image_height = 32

    def preprocess_image(image_path, img_size=(image_width, image_height)):
        image = tf.io.read_file(image_path)
        image = tf.image.decode_png(image, 1)  # decode as single-channel grayscale
        image = distortion_free_resize(image, img_size)
        image = tf.cast(image, tf.float32) / 255.0  # scale pixel values to [0, 1]
        return image

    def vectorize_tag(tag):
        tag = string_to_no(tf.strings.unicode_split(tag, input_encoding="UTF-8"))
        length = tf.shape(tag)[0]
        pad_amount = max_len - length
        tag = tf.pad(tag, paddings=[[0, pad_amount]], constant_values=padding_token)
        return tag

    def process_images_tags(image_path, tag):
        image = preprocess_image(image_path)
        tag = vectorize_tag(tag)
        return {"image": image, "tag": tag}

    def prepare_dataset(image_paths, tags):
        dataset = tf.data.Dataset.from_tensor_slices((image_paths, tags)).map(
            process_images_tags, num_parallel_calls=AUTOTUNE
        )
        return dataset.batch(batch_size).cache().prefetch(AUTOTUNE)

    # In[ ]:
    train_final = prepare_dataset(train_image, train_tag_extract)
    validation_final = prepare_dataset(validation_image, validation_tag_tune)
    test_final = prepare_dataset(test_image, test_tag_tune)
    print(train_final.take(1))
    print(train_final)
    # In[ ]:
    plt.rcParams['font.family'] = ['Microsoft YaHei']
    for data in train_final.take(1):
        images, tags = data["image"], data["tag"]
        _, ax = plt.subplots(4, 4, figsize=(15, 8))
        for i in range(16):
            img = images[i]
            # Undo the transpose/flip applied in distortion_free_resize for display
            img = tf.image.flip_left_right(img)
            img = tf.transpose(img, perm=[1, 0, 2])
            img = (img * 255.0).numpy().clip(0, 255).astype(np.uint8)
            img = img[:, :, 0]
            tag = tags[i]
            # Drop padding tokens, then map ids back to characters
            indices = tf.gather(tag, tf.where(tf.math.not_equal(tag, padding_token)))
            tag = tf.strings.reduce_join(no_map_string(indices))
            tag = tag.numpy().decode("utf-8")
            ax[i // 4, i % 4].imshow(img)
            ax[i // 4, i % 4].set_title("Ground truth: %s" % tag)
            ax[i // 4, i % 4].axis("on")
    plt.show()
    # In[ ]:
    class CTCLoss(keras.layers.Layer):
        def call(self, y_true, y_pred):
            batch_len = tf.cast(tf.shape(y_true)[0], dtype="int64")
            input_length = tf.cast(tf.shape(y_pred)[1], dtype="int64")
            tag_length = tf.cast(tf.shape(y_true)[1], dtype="int64")
            input_length = input_length * tf.ones(shape=(batch_len, 1), dtype="int64")
            tag_length = tag_length * tf.ones(shape=(batch_len, 1), dtype="int64")
            loss = keras.backend.ctc_batch_cost(y_true, y_pred, input_length, tag_length)
            self.add_loss(loss)
            return loss

    def generate_model():
        # Inputs to the model
        input_img = keras.Input(shape=(image_width, image_height, 1), name="image")
        tags = keras.layers.Input(name="tag", shape=(None,))
        # First conv block.
        t = keras.layers.Conv2D(
            filters=32,
            kernel_size=(3, 3),
            activation="relu",
            kernel_initializer="he_normal",
            padding="same",
            name="ConvolutionLayer1")(input_img)
        t = keras.layers.AveragePooling2D((2, 2), name="AveragePooling_one")(t)
        # Second conv block.
        t = keras.layers.Conv2D(
            filters=64,
            kernel_size=(3, 3),
            activation="relu",
            kernel_initializer="he_normal",
            padding="same",
            name="ConvolutionLayer2")(t)
        t = keras.layers.AveragePooling2D((2, 2), name="AveragePooling_two")(t)
        # Two 2x2 poolings shrink each spatial dimension by 4; fold the feature
        # maps into a (time steps, features) sequence for the RNN layers.
        re_shape = ((image_width // 4), (image_height // 4) * 64)
        t = keras.layers.Reshape(target_shape=re_shape, name="reshape")(t)
        t = keras.layers.Dense(64, activation="relu", name="denseone", use_bias=False,
                               kernel_initializer='glorot_uniform',
                               bias_initializer='zeros')(t)
        t = keras.layers.Dropout(0.4)(t)
        # RNNs.
        t = keras.layers.Bidirectional(
            keras.layers.LSTM(128, return_sequences=True, dropout=0.4)
        )(t)
        t = keras.layers.Bidirectional(
            keras.layers.LSTM(64, return_sequences=True, dropout=0.4)
        )(t)
        # +2 leaves room for the CTC blank and padding tokens.
        t = keras.layers.Dense(
            len(string_to_no.get_vocabulary()) + 2, activation="softmax", name="densetwo"
        )(t)
        # Add CTC layer for calculating CTC loss at each step.
        output = CTCLoss(name="ctc_loss")(tags, t)
        # Define the model.
        model = keras.models.Model(
            inputs=[input_img, tags], outputs=output, name="handwriting"
        )
        # Compile the model and return.
        model.compile(optimizer=keras.optimizers.Adam())
        return model

    # Get the model.
    model = generate_model()
    model.summary()
    # In[ ]:
    validation_images = []
    validation_tags = []
    # Collected for optional evaluation; not used further below.
    for batch in validation_final:
        validation_images.append(batch["image"])
        validation_tags.append(batch["tag"])

    # In[ ]:
    # Inference-only model: shares the trained layers but stops at the softmax
    # output, so it skips the CTC loss layer.
    prediction_model = keras.models.Model(
        model.get_layer(name="image").input, model.get_layer(name="densetwo").output)
    epochs = 60
    early_stopping_patience = 10
    # Stop early when validation loss fails to improve for `early_stopping_patience` epochs.
    early_stopping = keras.callbacks.EarlyStopping(
        monitor="val_loss", patience=early_stopping_patience, restore_best_weights=True
    )
    # Train the model.
    history = model.fit(
        train_final,
        validation_data=validation_final,
        epochs=epochs,
        callbacks=[early_stopping],
    )
    # ## Inference
    # In[ ]:
    plt.rcParams['font.family'] = ['Microsoft YaHei']
    # A utility function to decode the output of the network.
    def handwriting_prediction(pred):
        input_len = np.ones(pred.shape[0]) * pred.shape[1]
        # Greedy CTC decoding, truncated to the maximum label length
        results = keras.backend.ctc_decode(pred, input_length=input_len, greedy=True)[0][0][:, :max_len]
        output_text = []
        for j in results:
            # Drop the -1 padding emitted by ctc_decode, then map ids to characters
            j = tf.gather(j, tf.where(tf.math.not_equal(j, -1)))
            j = tf.strings.reduce_join(no_map_string(j)).numpy().decode("utf-8")
            output_text.append(j)
        return output_text

    # Let's check results on some test samples.
    for test in test_final.take(1):
        test_images = test["image"]
        _, ax = plt.subplots(4, 4, figsize=(15, 8))
        predit = prediction_model.predict(test_images)
        predit_text = handwriting_prediction(predit)
        for k in range(16):
            img = test_images[k]
            img = tf.image.flip_left_right(img)
            img = tf.transpose(img, perm=[1, 0, 2])
            img = (img * 255.0).numpy().clip(0, 255).astype(np.uint8)
            img = img[:, :, 0]
            title = f"Prediction: {predit_text[k]}"
            ax[k // 4, k % 4].imshow(img)
            ax[k // 4, k % 4].set_title(title)
            ax[k // 4, k % 4].axis("on")
    plt.show()
    # In[ ]:

