Technical Deep Dive | Migrating NLP Models to MindSpore: BERT for the Text Matching Task (Part 2): Training and Evaluation


Preface:

This article shows how to use MindSpore's Bert model for a downstream task: text matching on the LCQMC dataset.

Environment:

OS: Ubuntu 18.04

GPU: RTX 3090

MindSpore version: 1.3

Dataset: LCQMC

Definition of the LCQMC text matching task:

LCQMC is a question semantic matching dataset built by Harbin Institute of Technology and published at COLING 2018, a top international NLP conference. The goal is to judge whether two questions have the same meaning.

The fields in the dataset are:

text_a, text_b, label.

text_a and text_b are the texts of the two questions. If the two questions have the same meaning, label is 1; otherwise it is 0.
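
For reference, a record looks roughly like the tab-separated rows below (these pairs are made-up illustrations, not actual entries from the dataset):

    text_a              text_b              label
    这件衣服怎么洗        这种衣服如何清洗      1
    明天天气怎么样        今晚吃什么           0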

Weight Migration: PyTorch -> MindSpore

Since fine-tuned weights are already available officially, we try converting them directly and running prediction with the result.

First we need to know the parameter names, shapes, and so on, because the PyTorch and MindSpore model parameters must be mapped one to one.

To start, download the PyTorch bin file of HuggingFace's bert-base-chinese.
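
A quick way to build that one-to-one mapping is to dump names and shapes from both sides and compare them. A minimal sketch, assuming model is the mindtext BertModel constructed as in the usage example further below:

    import torch

    # The downloaded HuggingFace bin file is itself a state dict of named tensors.
    torch_param_dict = torch.load("pytorch_model.bin", map_location="cpu")
    for name, tensor in torch_param_dict.items():
        print(name, tuple(tensor.shape))

    # MindSpore side: every parameter of the (randomly initialized) model.
    for name, param in model.parameters_dict().items():
        print(name, param.data.shape)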

Next, use the following function to convert the PyTorch weight file into a MindSpore checkpoint file:

    from mindspore import Tensor, Parameter
    from mindspore.common.initializer import initializer
    from mindspore.train.serialization import save_checkpoint


    def torch_to_ms(model, torch_model, save_path):
        """
        Updates the MindSpore Bert model's param data from the torch param data.
        Args:
            model: mindspore model
            torch_model: torch state dict
        """
        print("start load")
        # load torch parameters and mindspore parameters
        torch_param_dict = torch_model
        ms_param_dict = model.parameters_dict()
        count = 0
        for ms_key in ms_param_dict.keys():
            ms_key_tmp = ms_key.split('.')
            if ms_key_tmp[0] == 'bert_embedding_lookup':
                count += 1
                update_torch_to_ms(torch_param_dict, ms_param_dict,
                                   'embeddings.word_embeddings.weight', ms_key)
            elif ms_key_tmp[0] == 'bert_embedding_postprocessor':
                if ms_key_tmp[1] == "token_type_embedding":
                    count += 1
                    update_torch_to_ms(torch_param_dict, ms_param_dict,
                                       'embeddings.token_type_embeddings.weight', ms_key)
                elif ms_key_tmp[1] == "full_position_embedding":
                    count += 1
                    update_torch_to_ms(torch_param_dict, ms_param_dict,
                                       'embeddings.position_embeddings.weight', ms_key)
                elif ms_key_tmp[1] == "layernorm":
                    if ms_key_tmp[2] == "gamma":
                        count += 1
                        update_torch_to_ms(torch_param_dict, ms_param_dict,
                                           'embeddings.LayerNorm.weight', ms_key)
                    else:
                        count += 1
                        update_torch_to_ms(torch_param_dict, ms_param_dict,
                                           'embeddings.LayerNorm.bias', ms_key)
            elif ms_key_tmp[0] == "bert_encoder":
                if ms_key_tmp[3] == 'attention':
                    # e.g. query_layer / key_layer / value_layer -> query / key / value
                    par = ms_key_tmp[4].split('_')[0]
                    count += 1
                    update_torch_to_ms(torch_param_dict, ms_param_dict,
                                       'encoder.layer.' + ms_key_tmp[2] + '.' + ms_key_tmp[3]
                                       + '.self.' + par + '.' + ms_key_tmp[5],
                                       ms_key)
                elif ms_key_tmp[3] == 'attention_output':
                    if ms_key_tmp[4] == 'dense':
                        count += 1
                        update_torch_to_ms(torch_param_dict, ms_param_dict,
                                           'encoder.layer.' + ms_key_tmp[2] + '.attention.output.'
                                           + ms_key_tmp[4] + '.' + ms_key_tmp[5],
                                           ms_key)
                    elif ms_key_tmp[4] == 'layernorm':
                        if ms_key_tmp[5] == 'gamma':
                            count += 1
                            update_torch_to_ms(torch_param_dict, ms_param_dict,
                                               'encoder.layer.' + ms_key_tmp[2] + '.attention.output.LayerNorm.weight',
                                               ms_key)
                        else:
                            count += 1
                            update_torch_to_ms(torch_param_dict, ms_param_dict,
                                               'encoder.layer.' + ms_key_tmp[2] + '.attention.output.LayerNorm.bias',
                                               ms_key)
                elif ms_key_tmp[3] == 'intermediate':
                    count += 1
                    update_torch_to_ms(torch_param_dict, ms_param_dict,
                                       'encoder.layer.' + ms_key_tmp[2] + '.intermediate.dense.' + ms_key_tmp[4],
                                       ms_key)
                elif ms_key_tmp[3] == 'output':
                    if ms_key_tmp[4] == 'dense':
                        count += 1
                        update_torch_to_ms(torch_param_dict, ms_param_dict,
                                           'encoder.layer.' + ms_key_tmp[2] + '.output.dense.' + ms_key_tmp[5],
                                           ms_key)
                    else:
                        if ms_key_tmp[5] == 'gamma':
                            count += 1
                            update_torch_to_ms(torch_param_dict, ms_param_dict,
                                               'encoder.layer.' + ms_key_tmp[2] + '.output.LayerNorm.weight',
                                               ms_key)
                        else:
                            count += 1
                            update_torch_to_ms(torch_param_dict, ms_param_dict,
                                               'encoder.layer.' + ms_key_tmp[2] + '.output.LayerNorm.bias',
                                               ms_key)
            if ms_key_tmp[0] == 'dense':
                if ms_key_tmp[1] == 'weight':
                    count += 1
                    update_torch_to_ms(torch_param_dict, ms_param_dict,
                                       'pooler.dense.weight', ms_key)
                else:
                    count += 1
                    update_torch_to_ms(torch_param_dict, ms_param_dict,
                                       'pooler.dense.bias', ms_key)
        save_checkpoint(model, save_path)
        print("finish load, %d parameters converted" % count)


    def update_bn(torch_param_dict, ms_param_dict, ms_key, ms_key_tmp):
        """Updates mindspore batchnorm param data from torch batchnorm param data."""
        str_join = '.'
        if ms_key_tmp[-1] == "moving_mean":
            ms_key_tmp[-1] = "running_mean"
            torch_key = str_join.join(ms_key_tmp)
            update_torch_to_ms(torch_param_dict, ms_param_dict, torch_key, ms_key)
        elif ms_key_tmp[-1] == "moving_variance":
            ms_key_tmp[-1] = "running_var"
            torch_key = str_join.join(ms_key_tmp)
            update_torch_to_ms(torch_param_dict, ms_param_dict, torch_key, ms_key)
        elif ms_key_tmp[-1] == "gamma":
            ms_key_tmp[-1] = "weight"
            torch_key = str_join.join(ms_key_tmp)
            update_torch_to_ms(torch_param_dict, ms_param_dict, 'transformer.' + torch_key, ms_key)
        elif ms_key_tmp[-1] == "beta":
            ms_key_tmp[-1] = "bias"
            torch_key = str_join.join(ms_key_tmp)
            update_torch_to_ms(torch_param_dict, ms_param_dict, 'transformer.' + torch_key, ms_key)


    def update_torch_to_ms(torch_param_dict, ms_param_dict, torch_key, ms_key):
        """Updates one mindspore param's data from the matching torch param's data."""
        value = torch_param_dict[torch_key].cpu().numpy()
        value = Parameter(Tensor(value), name=ms_key)
        _update_param(ms_param_dict[ms_key], value)


    def _update_param(param, new_param):
        """Updates param's data from new_param's data, checking dtype and shape."""
        if isinstance(param.data, Tensor) and isinstance(new_param.data, Tensor):
            if param.data.dtype != new_param.data.dtype:
                print("Failed to combine the net and the parameters for param %s." % param.name)
                msg = ("Net parameters {} type({}) different from parameter_dict's({})"
                       .format(param.name, param.data.dtype, new_param.data.dtype))
                raise RuntimeError(msg)
            if param.data.shape != new_param.data.shape:
                if not _special_process_par(param, new_param):
                    print("Failed to combine the net and the parameters for param %s." % param.name)
                    msg = ("Net parameters {} shape({}) different from parameter_dict's({})"
                           .format(param.name, param.data.shape, new_param.data.shape))
                    raise RuntimeError(msg)
                return
            param.set_data(new_param.data)
            return
        if isinstance(param.data, Tensor) and not isinstance(new_param.data, Tensor):
            if param.data.shape != (1,) and param.data.shape != ():
                print("Failed to combine the net and the parameters for param %s." % param.name)
                msg = ("Net parameters {} shape({}) is not (1,), inconsistent with parameter_dict's(scalar)."
                       .format(param.name, param.data.shape))
                raise RuntimeError(msg)
            param.set_data(initializer(new_param.data, param.data.shape, param.data.dtype))
        elif isinstance(new_param.data, Tensor) and not isinstance(param.data, Tensor):
            print("Failed to combine the net and the parameters for param %s." % param.name)
            msg = ("Net parameters {} type({}) different from parameter_dict's({})"
                   .format(param.name, type(param.data), type(new_param.data)))
            raise RuntimeError(msg)
        else:
            param.set_data(type(param.data)(new_param.data))


    def _special_process_par(par, new_par):
        """
        Processes the special condition.
        Like (12, 2048, 1, 1) -> (12, 2048); this case is caused by GE 4-dimension tensors.
        """
        par_shape_len = len(par.data.shape)
        new_par_shape_len = len(new_par.data.shape)
        delta_len = new_par_shape_len - par_shape_len
        delta_i = 0
        for delta_i in range(delta_len):
            if new_par.data.shape[par_shape_len + delta_i] != 1:
                break
        if delta_i == delta_len - 1:
            new_val = new_par.data.asnumpy()
            new_val = new_val.reshape(par.data.shape)
            par.set_data(Tensor(new_val, par.data.dtype))
            return True
        return False

A concrete usage example:

    # BertConfig/BertModel import paths follow the mindtext training script later
    # in this article; the PyTorch BertModel comes from HuggingFace transformers.
    from mindtext.modules.encoder.bert import BertConfig, BertModel as ms_bm
    from transformers import BertModel as tc_bm

    bert_config_file = "./model/test.yaml"
    bert_config = BertConfig.from_yaml_file(bert_config_file)
    model = ms_bm(bert_config, False)
    torch_model = tc_bm.from_pretrained("/content/model/bert_cn")
    torch_to_ms(model, torch_model.state_dict(), "./model/bert2.ckpt")

The names here must correspond one to one. If you change the model later, you also need to re-check whether this conversion function still maps every parameter correctly.
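
A worthwhile sanity check after converting is to reload the checkpoint and compare one weight against its PyTorch counterpart. A sketch, reusing torch_model from the example above; the MindSpore key name is an assumption, so check it against model.parameters_dict():

    import numpy as np
    from mindspore import load_checkpoint

    ms_params = load_checkpoint("./model/bert2.ckpt")
    torch_value = torch_model.state_dict()['embeddings.word_embeddings.weight'].cpu().numpy()
    # 'bert_embedding_lookup.embedding_table' is a typical MindSpore name for the
    # word-embedding table; substitute the actual key from your parameters_dict().
    ms_value = ms_params['bert_embedding_lookup.embedding_table'].data.asnumpy()
    print(np.allclose(torch_value, ms_value))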

Downstream Task: Training LCQMC Text Matching

1. Wrapping Bert as BertEmbedding

First, we wrap the previously built Bert model one more level, as BertEmbedding:

    # Copyright 2021 Huawei Technologies Co., Ltd
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    # http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    # ============================================================================
    """Bert Embedding."""
    import logging
    from typing import Tuple

    import mindspore.nn as nn
    from mindspore import Tensor
    from mindspore.train.serialization import load_checkpoint, load_param_into_net
    # Import path as used in the training script below.
    from mindtext.modules.encoder.bert import BertModel, BertConfig

    logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    logger = logging.getLogger(__name__)


    class BertEmbedding(nn.Cell):
        """
        A wrapper that loads pre-trained weight files into the Bert model.
        """
        def __init__(self, bert_config: BertConfig, is_training: bool = False):
            super(BertEmbedding, self).__init__()
            self.bert = BertModel(bert_config, is_training)

        def init_bertmodel(self, bert):
            """Manually initialize the wrapped BertModel."""
            self.bert = bert

        def from_pretrain(self, ckpt_file):
            """Load the model parameters from a checkpoint."""
            param_dict = load_checkpoint(ckpt_file)
            load_param_into_net(self.bert, param_dict)

        def construct(self, input_ids: Tensor, token_type_ids: Tensor, input_mask: Tensor) -> Tuple[Tensor, Tensor]:
            """
            Returns the result of the model after loading the pre-trained weights.
            Args:
                input_ids (:class:`mindspore.Tensor`): ids of the input characters.
                token_type_ids (:class:`mindspore.Tensor`): segment ids.
                input_mask (:class:`mindspore.Tensor`): mask for input_ids.
            Returns:
                sequence_output: the per-token hidden states.
                pooled_output: the pooled output of the first ([CLS]) token.
            """
            sequence_output, pooled_output, _ = self.bert(input_ids, token_type_ids, input_mask)
            return sequence_output, pooled_output
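
A minimal smoke test of the wrapper, using the paths from earlier in this article (the printed shapes depend on your config, e.g. (1, 128, 768) and (1, 768) for a base-sized model):

    import numpy as np
    from mindspore import Tensor

    bert_config = BertConfig.from_yaml_file("./model/test.yaml")
    embedding = BertEmbedding(bert_config, is_training=False)
    embedding.from_pretrain("./model/bert2.ckpt")  # the checkpoint converted above

    batch, seq_len = 1, 128
    input_ids = Tensor(np.zeros((batch, seq_len), dtype=np.int32))
    token_type_ids = Tensor(np.zeros((batch, seq_len), dtype=np.int32))
    input_mask = Tensor(np.ones((batch, seq_len), dtype=np.int32))
    sequence_output, pooled_output = embedding(input_ids, token_type_ids, input_mask)
    print(sequence_output.shape, pooled_output.shape)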

2. Downstream task: BertforSequenceClassification

With Bert as the pre-trained backbone, we take the embedding of Bert's [CLS] token and feed it into a fully connected network. That is BertforSequenceClassification.

    # Copyright 2021 Huawei Technologies Co., Ltd
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    # http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    # ============================================================================
    """Bert for Sequence Classification script."""
    import numpy as np
    import mindspore.nn as nn
    import mindspore.ops as ops
    import mindspore.common.dtype as mstype
    from mindspore.common.initializer import TruncatedNormal
    from mindspore.nn.learning_rate_schedule import LearningRateSchedule, PolynomialDecayLR, WarmUpLR
    from mindspore.context import ParallelMode
    from mindspore.common.tensor import Tensor
    from mindspore.common.parameter import Parameter
    from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
    from mindspore.ops import operations as P
    from mindspore.ops import functional as F
    from mindspore.ops import composite as C
    from mindspore.ops import Squeeze
    from mindspore.communication.management import get_group_size
    from mindspore import context, load_checkpoint, load_param_into_net
    from mindspore.common.seed import _get_graph_seed
    from bert_embedding import BertEmbedding


    class BertforSequenceClassification(nn.Cell):
        """
        Train interface for the classification finetuning task.
        Args:
            config (Class): Configuration for BertModel.
            is_training (bool): True for training mode, False for eval mode.
            num_labels (int): Number of label types.
            dropout_prob (float): Dropout probability for BertforSequenceClassification.
            multi_sample_dropout (int): Number of dropout samples per step.
            label_smooth (float): Label smoothing regularization factor.
        """
        def __init__(self, config, is_training, num_labels, dropout_prob=0.0, multi_sample_dropout=1, label_smooth=1):
            super(BertforSequenceClassification, self).__init__()
            if not is_training:
                config.hidden_dropout_prob = 0.0
                config.hidden_probs_dropout_prob = 0.0
            self.bert = BertEmbedding(config, is_training)
            self.cast = P.Cast()
            self.weight_init = TruncatedNormal(config.initializer_range)
            self.softmax = nn.Softmax(axis=-1)
            self.dtype = config.dtype
            self.num_labels = num_labels
            self.dense_1 = nn.Dense(config.hidden_size, self.num_labels, weight_init=self.weight_init,
                                    has_bias=True).to_float(mstype.float32)
            self.dropout_list = []
            for _ in range(0, multi_sample_dropout):
                seed0, seed1 = _get_graph_seed(1, "dropout")
                self.dropout_list.append(ops.Dropout(1 - dropout_prob, seed0, seed1))
            self.loss = nn.SoftmaxCrossEntropyWithLogits(sparse=False, reduction="mean")
            self.squeeze = Squeeze(1)
            self.is_training = is_training
            self.one_hot = nn.OneHot(depth=num_labels, axis=-1)
            self.label_smooth = label_smooth

        def from_pretrain(self, ckpt_file):
            """Load the model parameters from a checkpoint."""
            param_dict = load_checkpoint(ckpt_file)
            load_param_into_net(self, param_dict)

        def init_embedding(self, embedding):
            """Manually initialize the embedding."""
            self.bert = embedding

        def construct(self, input_ids, input_mask, token_type_id, label_ids=0):
            """Classification task."""
            _, pooled_output = self.bert(input_ids, token_type_id, input_mask)
            loss = None
            if self.is_training:
                onehot_label = self.one_hot(self.squeeze(label_ids))
                smooth_label = self.label_smooth * onehot_label \
                    + (1 - self.label_smooth) / (self.num_labels - 1) * (1 - onehot_label)
                # Multi-sample dropout: average the loss over several dropout samples.
                for dropout in self.dropout_list:
                    cls, _ = dropout(pooled_output)
                    logits = self.dense_1(cls)
                    temp_loss = self.loss(logits, smooth_label)
                    if loss is None:
                        loss = temp_loss
                    else:
                        loss += temp_loss
                loss = loss / len(self.dropout_list)
            else:
                loss = self.dense_1(pooled_output)
            return loss


    class BertLearningRate(LearningRateSchedule):
        """
        Warmup-decay learning rate for the Bert network.
        """
        def __init__(self, learning_rate, end_learning_rate, warmup_steps, decay_steps, power):
            super(BertLearningRate, self).__init__()
            self.warmup_flag = False
            if warmup_steps > 0:
                self.warmup_flag = True
                self.warmup_lr = WarmUpLR(learning_rate, warmup_steps)
            self.decay_lr = PolynomialDecayLR(learning_rate, end_learning_rate, decay_steps, power)
            self.warmup_steps = Tensor(np.array([warmup_steps]).astype(np.float32))
            self.greater = P.Greater()
            self.one = Tensor(np.array([1.0]).astype(np.float32))
            self.cast = P.Cast()

        def construct(self, global_step):
            decay_lr = self.decay_lr(global_step)
            if self.warmup_flag:
                is_warmup = self.cast(self.greater(self.warmup_steps, global_step), mstype.float32)
                warmup_lr = self.warmup_lr(global_step)
                lr = (self.one - is_warmup) * decay_lr + is_warmup * warmup_lr
            else:
                lr = decay_lr
            return lr


    class BertFinetuneCell(nn.Cell):
        """
        Especially defined for finetuning, where only four input tensors are needed.
        Appends an optimizer to the training network, after which the construct
        function can be called to create the backward graph.
        Different from the built-in loss_scale wrapper cell, we apply grad_clip before the optimization.
        Args:
            network (Cell): The training network. Note that the loss function should have been added.
            optimizer (Optimizer): Optimizer for updating the weights.
            scale_update_cell (Cell): Cell to do the loss scaling. Default: None.
        """
        def __init__(self, network, optimizer, scale_update_cell=None):
            super(BertFinetuneCell, self).__init__(auto_prefix=False)
            self.network = network
            self.network.set_grad()
            self.weights = optimizer.parameters
            self.optimizer = optimizer
            self.grad = C.GradOperation(get_by_list=True,
                                        sens_param=True)
            self.reducer_flag = False
            self.allreduce = P.AllReduce()
            self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
            if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
                self.reducer_flag = True
            self.grad_reducer = None
            if self.reducer_flag:
                mean = context.get_auto_parallel_context("gradients_mean")
                degree = get_group_size()
                self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
            self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
            self.cast = P.Cast()
            self.gpu_target = False
            if context.get_context("device_target") == "GPU":
                self.gpu_target = True
                self.float_status = P.FloatStatus()
                self.addn = P.AddN()
                self.reshape = P.Reshape()
            else:
                self.alloc_status = P.NPUAllocFloatStatus()
                self.get_status = P.NPUGetFloatStatus()
                self.clear_status = P.NPUClearFloatStatus()
            self.reduce_sum = P.ReduceSum(keep_dims=False)
            self.base = Tensor(1, mstype.float32)
            self.less_equal = P.LessEqual()
            self.hyper_map = C.HyperMap()
            self.loss_scale = None
            self.loss_scaling_manager = scale_update_cell
            if scale_update_cell:
                self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32))

        def construct(self,
                      input_ids,
                      input_mask,
                      token_type_id,
                      label_ids,
                      sens=None):
            """Bert finetune."""
            weights = self.weights
            init = False
            loss = self.network(input_ids,
                                input_mask,
                                token_type_id,
                                label_ids)
            if sens is None:
                scaling_sens = self.loss_scale
            else:
                scaling_sens = sens
            if not self.gpu_target:
                init = self.alloc_status()
                init = F.depend(init, loss)
                clear_status = self.clear_status(init)
                scaling_sens = F.depend(scaling_sens, clear_status)
            grads = self.grad(self.network, weights)(input_ids,
                                                     input_mask,
                                                     token_type_id,
                                                     label_ids,
                                                     self.cast(scaling_sens,
                                                               mstype.float32))
            self.optimizer(grads)
            return loss
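
Two details in construct are worth pausing on. Multi-sample dropout simply runs dropout several times on the same pooled output and averages the resulting losses. And the label-smoothing line, with the label_smooth=0.9 and num_labels=2 used in the training script below, turns a one-hot target [0, 1] into [0.1, 0.9]. A quick numeric check:

    import numpy as np

    label_smooth, num_labels = 0.9, 2  # values used in the training script below
    onehot = np.array([0.0, 1.0])
    smooth_label = label_smooth * onehot + (1 - label_smooth) / (num_labels - 1) * (1 - onehot)
    print(smooth_label)  # [0.1 0.9]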

3. Task training

    import time

    from mindspore import context, save_checkpoint
    from mindspore.nn.optim import AdamWeightDecay
    from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
    from mindspore.train import Model
    from mindspore.train.callback import Callback, TimeMonitor
    from mindtext.modules.encoder.bert import BertConfig
    # The LCQMCDataset import path depends on your mindtext checkout; the original
    # listing elided it.
    from mindtext.dataset import LCQMCDataset
    from bert import BertforSequenceClassification, BertLearningRate, BertFinetuneCell
    from bert_embedding import BertEmbedding


    def get_ms_timestamp():
        t = time.time()
        return int(round(t * 1000))


    class LossCallBack(Callback):
        """
        Monitor the loss in training.
        If the loss is NAN or INF, terminate training.
        Note:
            If per_print_times is 0, do not print the loss.
        Args:
            per_print_times (int): Print the loss every per_print_times steps. Default: 1.
        """
        def __init__(self, per_print_times=1, rank_ids=0):
            super(LossCallBack, self).__init__()
            if not isinstance(per_print_times, int) or per_print_times < 0:
                raise ValueError("print_step must be int and >= 0.")
            self._per_print_times = per_print_times
            self.rank_id = rank_ids
            self.time_stamp_first = get_ms_timestamp()

        def step_end(self, run_context):
            """Monitor the loss in training."""
            time_stamp_current = get_ms_timestamp()
            cb_params = run_context.original_args()
            print("time: {}, epoch: {}, step: {}, outputs are {}".format(time_stamp_current - self.time_stamp_first,
                                                                         cb_params.cur_epoch_num,
                                                                         cb_params.cur_step_num,
                                                                         str(cb_params.net_outputs)))
            with open("./loss_{}.log".format(self.rank_id), "a+") as f:
                f.write("time: {}, epoch: {}, step: {}, loss: {}".format(
                    time_stamp_current - self.time_stamp_first,
                    cb_params.cur_epoch_num,
                    cb_params.cur_step_num,
                    str(cb_params.net_outputs.asnumpy())))
                f.write('\n')


    def train(train_data, bert, optimizer, save_path, epoch_num):
        update_cell = DynamicLossScaleUpdateCell(loss_scale_value=2 ** 32, scale_factor=2, scale_window=1000)
        netwithgrads = BertFinetuneCell(bert, optimizer=optimizer, scale_update_cell=update_cell)
        callbacks = [TimeMonitor(train_data.get_dataset_size()), LossCallBack(train_data.get_dataset_size())]
        model = Model(netwithgrads)
        model.train(epoch_num, train_data, callbacks=callbacks, dataset_sink_mode=False)
        save_checkpoint(model.train_network.network, save_path)


    def main():
        # mode=0 is GRAPH_MODE; switch to context.PYNATIVE_MODE for eager debugging.
        context.set_context(mode=0, device_target="GPU")
        epoch_num = 6
        save_path = "./model/output/train_lcqmc2.ckpt"
        dataset = LCQMCDataset(paths='./dataset/lcqmc',
                               tokenizer="./model",
                               max_length=128,
                               truncation_strategy=True,
                               batch_size=32,
                               columns_list=['input_ids', 'attention_mask', 'token_type_ids', 'label'],
                               test_columns_list=['input_ids', 'attention_mask', 'token_type_ids', 'label'])
        ds = dataset.from_cache(batch_size=128,
                                columns_list=['input_ids', 'attention_mask', 'token_type_ids', 'label'],
                                test_columns_list=['input_ids', 'attention_mask', 'token_type_ids'])
        train_data = ds['train']
        bert_config_file = "./model/test.yaml"
        bert_config = BertConfig.from_yaml_file(bert_config_file)
        model_path = "./model/bert_cn.ckpt"
        bert = BertforSequenceClassification(bert_config, True, num_labels=2, dropout_prob=0.1,
                                             multi_sample_dropout=5, label_smooth=0.9)
        eb = BertEmbedding(bert_config, True)
        eb.from_pretrain(model_path)
        bert.init_embedding(eb)
        lr_schedule = BertLearningRate(learning_rate=2e-5,
                                       end_learning_rate=0.0,
                                       warmup_steps=int(train_data.get_dataset_size() * epoch_num * 0.1),
                                       decay_steps=train_data.get_dataset_size() * epoch_num,
                                       power=1.0)
        params = bert.trainable_params()
        optimizer = AdamWeightDecay(params, lr_schedule, eps=1e-8)
        train(train_data, bert, optimizer, save_path, epoch_num)


    if __name__ == "__main__":
        main()

Key parameters:

bert_config = BertConfig.from_yaml_file(bert_config_file): reads Bert's configuration.

eb.from_pretrain(model_path): loads Bert's MindSpore weight file. bert.init_embedding(eb): installs the loaded weights into the classifier.

lr_schedule: the learning rate schedule.

optimizer: the gradient optimizer.
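
The schedule warms up linearly over the first 10% of total steps and then decays to end_learning_rate; with power=1.0 the polynomial decay is linear. A small sanity check of the curve (PyNative mode assumed; the step counts are illustrative, not the real dataset sizes):

    from mindspore import Tensor
    import mindspore.common.dtype as mstype

    sched = BertLearningRate(learning_rate=2e-5, end_learning_rate=0.0,
                             warmup_steps=100, decay_steps=1000, power=1.0)
    for step in (0, 50, 100, 500, 1000):
        # Rises toward 2e-5 until step 100, then decays linearly toward 0.
        print(step, sched(Tensor(step, mstype.float32)))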

Evaluation

We use the LCQMC test set to evaluate the trained model and print its accuracy on the test set.

    import mindspore
    from mindspore import context
    from mindspore.nn import Accuracy
    from tqdm import tqdm

    # BertforSequenceClassification comes from the script defined above; the
    # LCQMCDataset import path depends on your mindtext checkout.
    from bert import BertforSequenceClassification
    from mindtext.modules.encoder.bert import BertConfig
    from mindtext.dataset import LCQMCDataset


    def evaluate(eval_data, model):
        metric = Accuracy('classification')
        metric.clear()
        squeeze = mindspore.ops.Squeeze(1)
        softmax = mindspore.nn.Softmax(axis=-1)
        for batch in tqdm(eval_data.create_dict_iterator(num_epochs=1), total=eval_data.get_dataset_size()):
            input_ids = batch['input_ids']
            token_type_id = batch['token_type_ids']
            input_mask = batch['attention_mask']
            label_ids = batch['label']
            inputs = {"input_ids": input_ids,
                      "input_mask": input_mask,
                      "token_type_id": token_type_id}
            output = model(**inputs)
            output = softmax(output)
            metric.update(output, squeeze(label_ids))
        print(metric.eval())


    def main():
        context.set_context(mode=0, device_target="GPU")  # mode=0 is GRAPH_MODE
        dataset = LCQMCDataset(paths='./dataset/lcqmc',
                               tokenizer="./model",
                               max_length=128,
                               truncation_strategy=True,
                               batch_size=128,
                               columns_list=['input_ids', 'attention_mask', 'token_type_ids', 'label'],
                               test_columns_list=['input_ids', 'attention_mask', 'token_type_ids', 'label'])
        ds = dataset.from_cache(batch_size=128,
                                columns_list=['input_ids', 'attention_mask', 'token_type_ids', 'label'],
                                test_columns_list=['input_ids', 'attention_mask', 'token_type_ids', 'label'])
        eval_data = ds['test']
        bert_config_file = "./model/test.yaml"
        bert_config = BertConfig.from_yaml_file(bert_config_file)
        bert = BertforSequenceClassification(bert_config, is_training=False, num_labels=2, dropout_prob=0.0)
        model_path = "./model/output/train_lcqmc2.ckpt"
        bert.from_pretrain(model_path)
        evaluate(eval_data, bert)


    if __name__ == "__main__":
        main()

Results

The trained model's accuracy on the dataset's validation and test sets:

Original article: https://blog.csdn.net/Kenji_Shinji/article/details/125478886