Example: Feeding Pre-trained SKU Embeddings into a Transformer


    The goal of this example is to feed pre-trained SKU embedding vectors into a Transformer, in order to improve the Transformer's accuracy on the sales-forecasting task.

    I. Training Data Formats

    1. Sample record for embedding training:

    133657,本田#第八代雅阁,1816,4

    Field 1: sku_id

    Field 2: car model # trim

    Field 3: sequential id assigned to the car model # trim pair

    Field 4: category id of the sku

    2. Sample record for sales-forecast training:

    0053#031188,0_0_0_0_0_0_1_0_0_0_0_0_0_0_0_0_0_0_0_0_0_0_0_1_0_0_0_0_0_0_0_0_0_0_0_1_0_0_0_0_0_1_0_0_0_0_0_0_0_0_1_0_0_0_0_0_0_0_0_0_0_0_1_0_0_0_0_0_0_0_0_0_0_0_0_0_0_0_0_0_0_0_0_0_0_0_0_0_0_0_0_0_0_0_0_0_0_0_0_0_0_0_0_0_0_0_0_0_0_0_0_0_0_0_0_0_0_0_0_0_0_0_0_0_0_0

    Field 1: store code # sku_id

    Field 2: weekly sales for each of the past N weeks (underscore-separated)
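
    To make the two layouts concrete, here is a minimal parsing sketch (the helper names are hypothetical, not part of the original code):

    def parse_embedding_record(line: str):
        # e.g. '133657,本田#第八代雅阁,1816,4'
        sku_id, car_model, car_id, cat_id = line.strip().split(',')
        return sku_id, car_model, int(car_id), int(cat_id)

    def parse_sales_record(line: str):
        # e.g. '0053#031188,0_0_0_..._0'
        key, sale_info = line.strip().split(',')
        store_code, sku_id = key.split('#')
        weekly_sales = [float(v) for v in sale_info.split('_')]
        return store_code, sku_id, weekly_sales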

    II. Embedding Model Stage

    1. Input stage

    (1) The <sku_id, car_id> pair relation; one sku can map to multiple car_ids.

    (2) Build sku_car_matrix: in each sku's row, the positions of its car_ids are set to 1 and all other positions to 0.

    (3) Build a mapping from sku_id to its row index (along dim=0) of sku_car_matrix, so that a sku_id can later fetch its embedding directly by index; persist this mapping.

    (4) Wrap sku_car_matrix in a Dataset whose __getitem__() returns the pair <train_data[index], train_data[index]>, because the embedding is obtained through the reconstruction sku_car_matrix -> embedding -> sku_car_matrix.

    (5) Load the Dataset into a DataLoader with shuffle=False (which is also the DataLoader default), so that batch order stays aligned with the matrix row indices when the embeddings are written back. A toy illustration of steps (2)-(3) follows this list.
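
    As a toy illustration of steps (2)-(3), with made-up sku codes and car ids (not from the real data):

    import numpy as np

    # two skus, three car ids; sku 'A' fits cars 1 and 3, sku 'B' fits car 2
    sku2idx = {'A': 0, 'B': 1}
    pairs = [('A', 1), ('A', 3), ('B', 2)]

    matrix = np.zeros((len(sku2idx), 3), dtype='float32')
    for sku, car_id in pairs:
        matrix[sku2idx[sku], car_id - 1] = 1
    # matrix == [[1., 0., 1.],
    #            [0., 1., 0.]]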

    2. Model training stage

    (1) An encoder-decoder (autoencoder) network architecture.

    (2) Loss: the decoder output is compared against the label with MSE loss.

    (3) On every batch the intermediate embedding values are saved; collecting them over all batches yields the embedding matrix for that epoch. The embedding matrix from the epoch with the lowest train_loss is kept as the best embedding and persisted.

    (4) To sanity-check the embeddings, take a few skus and compare their pairwise Euclidean distances; in theory, similar skus should be closer. (A small sketch follows the listing below.)

    (5) Implementation (embedding_model_train.py):

    import os
    import logging

    import numpy as np
    import pandas as pd
    import torch
    import torch.nn as nn
    from torch.utils.data import Dataset, DataLoader
    from tqdm import trange

    import transformer_utils

    logger = logging.getLogger('Transformer.Embedding')


    class EmbeddingTrainDataset(Dataset):
        def __init__(self, matrix_data):
            self.train_data = matrix_data
            self.train_len = len(matrix_data)

        def __len__(self):
            return self.train_len

        def __getitem__(self, index):
            # input and label are the same row: the autoencoder reconstructs its input
            return self.train_data[index], self.train_data[index]


    class AutoEncoder(nn.Module):
        def __init__(self, input_dim, embedding_dim):
            super(AutoEncoder, self).__init__()
            self.encoder = nn.Sequential(
                nn.Linear(input_dim, input_dim // 2),
                nn.Tanh(),
                nn.Linear(input_dim // 2, input_dim // 4),
                nn.Tanh(),
                nn.Linear(input_dim // 4, embedding_dim),
            )
            self.decoder = nn.Sequential(
                nn.Linear(embedding_dim, input_dim // 4),
                nn.Tanh(),
                nn.Linear(input_dim // 4, input_dim // 2),
                nn.Tanh(),
                nn.Linear(input_dim // 2, input_dim),
            )

        def forward(self, x):
            encoded = self.encoder(x)
            decoded = self.decoder(encoded)
            return encoded, decoded


    if __name__ == '__main__':
        embedding_dim = 100
        epochs = 10000
        lr = 0.001
        gamma = 0.95
        batch_size = 1000
        transformer_utils.set_logger(os.path.join(os.getcwd(), 'train.log'))
        data_frame = pd.read_csv(
            os.path.join(os.getcwd(), 'data', 'abs_sku_to_Car_classfication_onehot_detail.csv'),
            header=None, names=['sku_code', 'car_model', 'car_id', 'cat_id'],
            dtype={0: str, 1: str, 2: int, 3: int})
        # map every sku_code to a row index of the one-hot matrix
        sku_code_set = set(data_frame['sku_code'].drop_duplicates())
        sku2idx_dict = {}
        for i, sku_code in enumerate(sku_code_set):
            sku2idx_dict[sku_code] = i
        car_id_num = max(data_frame['car_id'])
        sku_code_num = len(sku_code_set)
        sku_code_car_matrix = np.zeros((sku_code_num, car_id_num), dtype='float32')
        np.save(os.path.join(os.getcwd(), 'data', 'sku2idx_dict'), sku2idx_dict)
        for i in trange(len(data_frame)):
            sku_code = data_frame.loc[i, 'sku_code']
            car_id = data_frame.loc[i, 'car_id']
            sku_code_idx = sku2idx_dict[sku_code]
            sku_code_car_matrix[sku_code_idx, car_id - 1] = 1
        train_set = EmbeddingTrainDataset(sku_code_car_matrix)
        # shuffle must stay False so batch order matches the matrix row order
        train_loader = DataLoader(train_set, batch_size=batch_size, num_workers=0, shuffle=False)
        device = "cuda" if torch.cuda.is_available() else "cpu"
        autoencoder_model = AutoEncoder(car_id_num, embedding_dim).to(device)
        criterion = nn.MSELoss()
        optimizer = torch.optim.AdamW(autoencoder_model.parameters(), lr=lr)
        train_loss_summary = np.zeros(epochs)
        best_evaluate_loss = 100.0
        for epoch in trange(epochs):
            train_total_loss = 0
            sku_encoder_embedding = np.zeros((sku_code_num, embedding_dim), dtype='float32')
            train_loader_len = len(train_loader)
            for i, (x_input, x_label) in enumerate(train_loader):
                x_input = x_input.to(device)
                x_label = x_label.to(device)
                encoded, decoded = autoencoder_model(x_input)
                loss = criterion(decoded, x_label)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                train_total_loss += loss.item()
                # collect this batch's embeddings; offsets line up because shuffle=False
                sku_encoder_embedding[(i * batch_size):(i * batch_size + x_input.shape[0])] = \
                    encoded.detach().to('cpu').numpy()
            train_avg_loss = train_total_loss / train_loader_len
            logger.info(f'epoch: {epoch + 1}, train_loss: {train_avg_loss}')
            if train_avg_loss < best_evaluate_loss:
                best_evaluate_loss = train_avg_loss
                np.save(os.path.join(os.getcwd(), 'data', 'sku2embedding'), sku_encoder_embedding)
                logger.info(f'best embedding at: {epoch + 1}')
            if epoch >= 10:  # skip the earliest epochs so they do not distort the loss curve
                train_loss_summary[epoch] = train_avg_loss
            if epoch % 10 == 1:
                transformer_utils.plot_all_epoch(train_loss_summary, train_loss_summary, epoch,
                                                 'embedding_train_loss_summary.png')
        print('finish!')
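
    To apply the distance check from step (4), one can load the persisted artifacts and compare a few skus. A minimal sketch; '133657' is the sample sku_id from Part I, and the second code is a placeholder to fill in:

    import os
    import numpy as np

    # dicts saved with np.save need allow_pickle=True when loading
    sku2idx = np.load(os.path.join('data', 'sku2idx_dict.npy'), allow_pickle=True).item()
    sku2embedding = np.load(os.path.join('data', 'sku2embedding.npy'))

    def distance(sku_a, sku_b):
        # Euclidean distance between two sku embeddings; smaller means more similar
        va = sku2embedding[sku2idx[sku_a]]
        vb = sku2embedding[sku2idx[sku_b]]
        return float(np.linalg.norm(va - vb))

    print(distance('133657', '<another_sku_id>'))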

     

    III. Feeding the Embedding into the Transformer Forecaster

    1. Data preprocessing for the Transformer forecaster (transformer_preprocess_data.py):

    import os

    import numpy as np
    import pandas as pd
    from tqdm import trange


    # convert the raw records into a standard table: one column per warehouse#sku series
    def normalize_data_format(data):
        data_sale_list_series = data['sale_info'].apply(lambda row: list(map(float, row.split("_"))))
        data_frame = pd.DataFrame(item for item in data_sale_list_series)
        data_frame = pd.concat((data['warehouse_sku'], data_frame), axis=1)
        data_frame = data_frame.transpose()
        return data_frame


    # smooth excessively large values: cap anything above mean + 3 * std of the positive sales
    def smooth_big_value(data_frame):
        columns_len = len(data_frame.columns)
        print(">>>>smooth_big_value")
        for i in trange(columns_len):
            values = data_frame.iloc[1:, i]
            value_mean = np.mean(values[values > 0])
            value_std = np.std(values[values > 0], ddof=1)
            value_std = value_std if value_std > 0 else 0
            values_new = np.round(np.where(values > value_mean + 3 * value_std,
                                           value_mean + 3 * value_std, values).astype(float))
            values_new = np.array(values_new, dtype=int).astype(str)
            data_frame.iloc[1:, i] = values_new
        return data_frame


    # build the mappings between column position and series name (warehouse#sku)
    def gen_col2series(columns):
        columns = columns.values[0, :]
        id2series_dict = {}
        series2id_dict = {}
        j = 0
        for i, column in enumerate(columns):
            id2series_dict[i] = column
            if series2id_dict.get(column) is None:
                series2id_dict[column] = j
                j += 1
        return id2series_dict, series2id_dict


    # maximum value of each column (used for normalization)
    def gen_series2maxValue(data_frame):
        series_max_value = np.max(data_frame[1:], axis=0)
        series2maxValue = series_max_value.to_dict()
        return series2maxValue


    # slice each series into sliding windows and normalize by the series max
    def prep_data(data, series2maxValue):
        num_series = data.shape[1]
        time_len = data.shape[0]
        windows_per_series = np.full((num_series), (time_len - backcast_len))
        total_windows = np.sum(windows_per_series)
        x_input = np.zeros((total_windows, backcast_len, 1 + 2), dtype='float32')  # sale_info + series_info + max_value
        label = np.zeros((total_windows, backcast_len), dtype='float32')
        print(">>>>prep_data")
        count = 0
        zero_count = 0
        for series_idx in trange(num_series):
            for i in range(windows_per_series[series_idx]):
                x_input_data = data[i: i + backcast_len, series_idx]
                x_input_series = series_idx
                label_data = data[i + 1: i + backcast_len + 1, series_idx]  # label is the input shifted by one step
                if np.max(x_input_data) > 0:
                    x_input[count, :, 0] = x_input_data
                    x_input[count, :, 1] = x_input_series
                    x_input[count, :, 2] = series2maxValue.get(series_idx)
                    label[count] = label_data
                    x_input[count, :, 0] = x_input[count, :, 0] / series2maxValue.get(series_idx)
                    label[count] = label[count] / series2maxValue.get(series_idx)
                    count += 1
                elif np.max(label_data) == 0 and zero_count < 2000 and np.random.choice([0, 1], p=[0.6, 0.4]) > 0:
                    # keep only a limited, randomly sampled number of all-zero windows
                    x_input[count, :, 0] = x_input_data
                    x_input[count, :, 1] = x_input_series
                    x_input[count, :, 2] = 0
                    label[count] = label_data
                    zero_count += 1
                    count += 1
        x_input = x_input[:count]
        label = label[:count]
        return x_input, label


    # shuffle and split the data into training and test sets
    def split_train_test_data(x_input, label, train_ratio=0.8):
        x_len = x_input.shape[0]
        shuffle_idx = np.random.permutation(x_len)
        train_x_len = int(x_len * train_ratio)
        train_shuffle_idx = shuffle_idx[:train_x_len]
        test_shuffle_idx = shuffle_idx[train_x_len:]
        train_x_input = x_input[train_shuffle_idx]
        train_label = label[train_shuffle_idx]
        test_x_input = x_input[test_shuffle_idx]
        test_label = label[test_shuffle_idx]
        return train_x_input, train_label, test_x_input, test_label


    if __name__ == '__main__':
        backcast_len = 12
        train_val_num = 110
        data_frame = pd.read_csv(os.path.join(os.getcwd(), 'data', 'ads_hub_sale_num_detail_simple.csv'),
                                 header=None, names=['warehouse_sku', 'sale_info'])
        data_frame = normalize_data_format(data_frame)
        data_frame = data_frame[:train_val_num]
        data_frame = smooth_big_value(data_frame)
        id2series, series2id = gen_col2series(data_frame)
        series2maxValue = gen_series2maxValue(data_frame)
        x_input, label = prep_data(data_frame.values[1:].astype('float'), series2maxValue)
        train_x_input, train_label, test_x_input, test_label = split_train_test_data(x_input, label)
        np.save(os.path.join(os.getcwd(), 'data', 'train_data'), train_x_input)
        np.save(os.path.join(os.getcwd(), 'data', 'train_label'), train_label)
        np.save(os.path.join(os.getcwd(), 'data', 'test_data'), test_x_input)
        np.save(os.path.join(os.getcwd(), 'data', 'test_label'), test_label)
        np.save(os.path.join(os.getcwd(), 'data', 'series_max_value'), series2maxValue)
        np.save(os.path.join(os.getcwd(), 'data', 'series2id'), series2id)
        np.save(os.path.join(os.getcwd(), 'data', 'id2series'), id2series)
        print('finish!')

    2. DataLoader implementation (transformer_dataloader.py):

    import logging
    import os

    import numpy as np
    from torch.utils.data import Dataset

    logger = logging.getLogger('Transformer.Data')


    class TrainDataset(Dataset):
        def __init__(self, data_path):
            self.data = np.load(os.path.join(data_path, 'data', 'train_data.npy'))
            self.label = np.load(os.path.join(data_path, 'data', 'train_label.npy'))
            # the dict artifacts were saved with np.save, so allow_pickle is required to load them
            self.id2series_dict = np.load(os.path.join(data_path, 'data', 'id2series.npy'), allow_pickle=True).item()
            self.sku2idx_dict = np.load(os.path.join(data_path, 'data', 'sku2idx_dict.npy'), allow_pickle=True).item()
            self.sku2embedding = np.load(os.path.join(data_path, 'data', 'sku2embedding.npy'))
            # fallback vector for skus that never appeared in embedding training
            self.sku_embedding_avg = self.sku2embedding.mean(axis=0)
            self.train_len = self.data.shape[0]
            logger.info(f'train_len:{self.train_len}')
            logger.info('building datasets from train_data.npy')

        def __len__(self):
            return self.train_len

        def __getitem__(self, index):
            series_idx = int(self.data[index, 0, -2])
            series = self.id2series_dict.get(series_idx)   # 'warehouse#sku'
            sku_code = series.split('#')[1]
            sku_idx = self.sku2idx_dict.get(sku_code)
            if sku_idx is None:
                sku_embedding = self.sku_embedding_avg
            else:
                sku_embedding = self.sku2embedding[sku_idx]
            return (self.data[index, :, :-2], series_idx, sku_embedding, self.label[index])


    class TestDataset(Dataset):
        def __init__(self, data_path):
            self.data = np.load(os.path.join(data_path, 'data', 'test_data.npy'))
            self.label = np.load(os.path.join(data_path, 'data', 'test_label.npy'))
            self.id2series_dict = np.load(os.path.join(data_path, 'data', 'id2series.npy'), allow_pickle=True).item()
            self.sku2idx_dict = np.load(os.path.join(data_path, 'data', 'sku2idx_dict.npy'), allow_pickle=True).item()
            self.sku2embedding = np.load(os.path.join(data_path, 'data', 'sku2embedding.npy'))
            self.sku_embedding_avg = self.sku2embedding.mean(axis=0)
            self.test_len = self.data.shape[0]
            logger.info(f'test_len:{self.test_len}')
            logger.info('building datasets from test_data.npy')

        def __len__(self):
            return self.test_len

        def __getitem__(self, index):
            series_idx = int(self.data[index, 0, -2])
            series = self.id2series_dict.get(series_idx)
            sku_code = series.split('#')[1]
            sku_idx = self.sku2idx_dict.get(sku_code)
            if sku_idx is None:
                sku_embedding = self.sku_embedding_avg
            else:
                sku_embedding = self.sku2embedding[sku_idx]
            # the test set additionally returns the series max value, for de-normalized loss
            return (self.data[index, :, :-2], series_idx, sku_embedding, self.data[index, 0, -1], self.label[index])

    (1) id2series maps the series index (the column position of a sku's sales history) to the warehouse#sku identifier.

    (2) sku2idx maps the sku_id to its embedding row index.

    (3) sku2embedding maps the embedding row index to the embedding vector.

    (4) In the Dataset's __getitem__(), the sku_id is resolved first, and then its embedding vector is looked up. A toy walkthrough of this lookup chain follows.
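
    A minimal walkthrough of the lookup chain with made-up values (the real dicts come from the .npy artifacts loaded above; '0053#031188' is the sample record from Part I):

    import numpy as np

    id2series = {0: '0053#031188'}           # series index -> warehouse#sku
    sku2idx = {'031188': 7}                  # sku_id -> embedding row index
    sku2embedding = np.random.rand(10, 100)  # row index -> embedding vector

    series_idx = 0
    sku_code = id2series[series_idx].split('#')[1]   # '031188'
    sku_idx = sku2idx.get(sku_code)                  # 7
    embedding = sku2embedding[sku_idx] if sku_idx is not None else sku2embedding.mean(axis=0)
    print(embedding.shape)  # (100,)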

    3. Using the embedding values in training

    Implementation (transformer_train.py):

    import os
    import math
    import time
    import logging

    import numpy as np
    import torch
    import torch.nn as nn
    from torch.utils.data import DataLoader

    import transformer_utils
    from transformer_dataloader import TrainDataset, TestDataset

    logger = logging.getLogger('Transformer.Train')


    class PositionalEncoding(nn.Module):
        def __init__(self, d_model, max_len=5000):
            super(PositionalEncoding, self).__init__()
            pe = torch.zeros(max_len, d_model)
            position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
            div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
            pe[:, 0::2] = torch.sin(position * div_term)
            pe[:, 1::2] = torch.cos(position * div_term)
            pe = pe.unsqueeze(0).transpose(0, 1)  # [max_len, 1, d_model]
            self.register_buffer('pe', pe)

        def forward(self, x, embedding):
            # broadcast-add the positional encoding and the pre-trained sku embedding
            return x + self.pe[:x.size(0), :] + embedding


    class TransAm(nn.Module):
        def __init__(self, feature_size=100, num_layers=1, dropout=0.1):
            super(TransAm, self).__init__()
            self.model_type = 'Transformer'
            self.src_mask = None
            self.pos_encoder = PositionalEncoding(feature_size)
            self.encoder_layer = nn.TransformerEncoderLayer(d_model=feature_size, nhead=10, dropout=dropout)
            self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=num_layers)
            self.decoder = nn.Linear(feature_size, 1)
            self.init_weights()

        def init_weights(self):
            initrange = 0.1
            self.decoder.bias.data.zero_()
            self.decoder.weight.data.uniform_(-initrange, initrange)

        def _generate_square_subsequent_mask(self, sz):
            mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
            mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
            return mask

        def forward(self, src, pre_embedding):
            # rebuild the causal mask only when the sequence length changes
            if self.src_mask is None or self.src_mask.shape[0] != len(src):
                device = src.device
                mask = self._generate_square_subsequent_mask(len(src)).to(device)
                self.src_mask = mask
            src = self.pos_encoder(src, pre_embedding)
            output = self.transformer_encoder(src, self.src_mask)
            output = self.decoder(output)
            return output


    def evaluate(model, test_loader):
        test_total_loss = 0
        test_total_with_max_loss = 0
        model.eval()
        test_loader_len = len(test_loader)
        with torch.no_grad():
            for i, (test_batch, idx, embedding_batch, max_value, labels) in enumerate(test_loader):
                test_batch = test_batch.permute(1, 0, 2).to(device)                   # [seq_len, batch, 1]
                labels = labels.permute(1, 0).to(device)
                embedding_batch = torch.unsqueeze(embedding_batch, dim=0).to(device)  # [1, batch, emb]
                test_output = model(test_batch, embedding_batch)
                test_output = torch.squeeze(test_output)
                test_output[test_output < 0] = 0
                # score only the last forecast step
                test_labels = labels[-1]
                test_output = test_output[-1]
                test_loss = criterion(test_output, test_labels)
                test_total_loss += test_loss.item()
                # also report the loss on de-normalized (original-scale) values
                max_value = max_value.to(device)
                test_with_max_labels = test_labels * max_value
                test_with_max_output = test_output * max_value
                test_with_max_loss = criterion(test_with_max_output, test_with_max_labels)
                test_total_with_max_loss += test_with_max_loss.item()
        test_avg_loss = test_total_loss / test_loader_len
        test_with_max_avg_loss = test_total_with_max_loss / test_loader_len
        return test_avg_loss, test_with_max_avg_loss


    if __name__ == '__main__':
        transformer_utils.set_logger(os.path.join(os.getcwd(), 'train.log'))
        json_path = os.path.join(os.getcwd(), 'params.json')
        params = transformer_utils.Params(json_path)
        lr = params.lr
        epochs = params.epochs
        feature_size = params.feature_size
        gamma = params.gamma
        device = torch.device(params.mode)
        input_window = params.input_window
        train_set = TrainDataset(os.getcwd())
        test_set = TestDataset(os.getcwd())
        train_loader = DataLoader(train_set, batch_size=params.train_batch_size, num_workers=0)
        test_loader = DataLoader(test_set, batch_size=params.test_batch_size, num_workers=0)
        transformer_model = TransAm(feature_size=feature_size).to(device)
        criterion = nn.MSELoss()
        optimizer = torch.optim.AdamW(transformer_model.parameters(), lr=lr)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma=gamma)
        train_loss_summary = np.zeros(epochs)
        test_loss_summary = np.zeros(epochs)
        best_evaluate_loss = 100.0
        for epoch in range(1, epochs + 1):
            epoch_start_time = time.time()
            train_loader_len = len(train_loader)
            train_total_loss = 0
            transformer_model.train()
            for i, (train_batch, idx, embedding_batch, label_batch) in enumerate(train_loader):
                optimizer.zero_grad()
                train_batch = train_batch.permute(1, 0, 2).to(device)  # [seq_len, batch, 1]
                label_batch = label_batch.permute(1, 0).to(device)
                embedding_batch = torch.unsqueeze(embedding_batch, dim=0).to(device)
                output = transformer_model(train_batch, embedding_batch)
                output = torch.squeeze(output)
                loss = criterion(output, label_batch)
                loss.backward()
                optimizer.step()
                train_total_loss += loss.item()
            scheduler.step()  # decay the learning rate once per epoch
            train_avg_loss = train_total_loss / train_loader_len
            test_avg_loss, test_with_max_avg_loss = evaluate(transformer_model, test_loader)
            logger.info(f'epoch: {epoch}, train_loss: {train_avg_loss}, test_loss: {test_avg_loss}, '
                        f'test_max_loss: {test_with_max_avg_loss}')
            is_best = False
            if test_avg_loss < best_evaluate_loss:
                is_best = True
                best_evaluate_loss = test_avg_loss
            transformer_utils.save_checkpoint({'epoch': epoch,
                                               'state_dict': transformer_model.state_dict(),
                                               'optim_dict': optimizer.state_dict()},
                                              is_best,
                                              epoch=epoch)
            train_loss_summary[epoch - 1] = train_avg_loss  # epoch is 1-based
            test_loss_summary[epoch - 1] = test_avg_loss
            if epoch % 20 == 1:
                transformer_utils.plot_all_epoch(train_loss_summary, test_loss_summary, epoch,
                                                 'train_test_loss_summary.png')
        print('finish!')

    (1) The embedding_batch read from train_loader has shape [1200, 100], where 1200 is batch_size and 100 is embedding_size.

    (2) torch.unsqueeze(embedding_batch, dim=0) turns it into [1, 1200, 100], i.e. (sequence_length, batch_size, embedding_size), matching the input layout the Transformer expects.

    (3) Inside PositionalEncoding, the output is: x + self.pe[:x.size(0), :] + embedding

    Here x has shape [12, 1200, 1]: the history covers 12 weeks, and the per-step feature dimension is 1 (just the sales value).

    self.pe[:x.size(0), :] has shape [12, 1, 100]; the positional encoding is shared by every series in a batch, so it broadcasts over the batch dimension and x + pe has shape [12, 1200, 100].

    embedding has shape [1, 1200, 100]: a sku's embedding is an intrinsic property that does not change over time (seq_length), so it broadcasts over the time dimension against the previous sum, and the final shape is still [12, 1200, 100]. The sketch below verifies the broadcasting.
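
    A quick sanity check of these broadcast shapes with random tensors:

    import torch

    x = torch.randn(12, 1200, 1)           # [seq_len, batch, feature=1]
    pe = torch.randn(12, 1, 100)           # positional encoding, shared across the batch
    embedding = torch.randn(1, 1200, 100)  # sku embedding, shared across time steps

    out = x + pe + embedding
    print(out.shape)  # torch.Size([12, 1200, 100])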

    4. Utility module (transformer_utils.py):

    import logging
    import os
    import json

    import numpy as np
    import torch
    from tqdm import tqdm
    import matplotlib.pyplot as plt

    logger = logging.getLogger('Transformer.Utils')


    class Params:
        '''
        Loads hyperparameters from a json file.
        Example:
            params = Params(json_path)
            print(params.learning_rate)
        '''
        def __init__(self, json_path):
            with open(json_path) as f:
                params = json.load(f)
                self.__dict__.update(params)


    def set_logger(log_path):
        '''Set the logger to log info to the terminal and the file `log_path`.

        In general, it is useful to have a logger so that every output to the terminal
        is also saved in a permanent file. Here we save it to `model_dir/train.log`.
        Example:
            logging.info('Starting training...')
        Args:
            log_path: (string) where to log
        '''
        _logger = logging.getLogger('Transformer')
        _logger.setLevel(logging.INFO)
        fmt = logging.Formatter('[%(asctime)s] %(name)s: %(message)s', '%H:%M:%S')

        class TqdmHandler(logging.StreamHandler):
            # stream handler that plays nicely with tqdm progress bars
            def __init__(self, formatter):
                logging.StreamHandler.__init__(self)
                self.setFormatter(formatter)

            def emit(self, record):
                msg = self.format(record)
                tqdm.write(msg)

        file_handler = logging.FileHandler(log_path)
        file_handler.setFormatter(fmt)
        _logger.addHandler(file_handler)
        _logger.addHandler(TqdmHandler(fmt))


    def save_checkpoint(state, is_best, epoch, save_checkpoint=False, ins_name=-1):
        '''Saves model and training parameters as a per-epoch checkpoint. If is_best == True,
        also saves base_model/best.pth.tar.
        Args:
            state: (dict) contains the model's state_dict, and may contain other keys
                such as epoch and the optimizer's state_dict
            is_best: (bool) True if this is the best model seen so far
            epoch: (int) current epoch
            save_checkpoint: (bool) whether to also save a per-epoch checkpoint
            ins_name: (int) instance index
        '''
        if save_checkpoint:
            if ins_name == -1:
                filepath = os.path.join('transformer-training-checkpoint', f'epoch_{epoch}.pth.tar')
            else:
                filepath = os.path.join('transformer-training-checkpoint', f'epoch_{epoch}_ins_{ins_name}.pth.tar')
            if not os.path.exists('transformer-training-checkpoint'):
                logger.info('Checkpoint Directory does not exist! Making directory transformer-training-checkpoint')
                os.mkdir('transformer-training-checkpoint')
            torch.save(state, filepath)
            logger.info(f'Checkpoint saved to {filepath}')
        if is_best:
            torch.save(state, os.path.join(os.getcwd(), 'base_model', 'best.pth.tar'))
            logger.info('Best checkpoint saved to best.pth.tar')


    def plot_all_epoch(train_loss_summary, test_loss_summary, num_samples, png_name):
        x = np.arange(start=1, stop=num_samples + 1)
        f = plt.figure()
        plt.plot(x, train_loss_summary[:num_samples], label='train_loss', linestyle='--')
        plt.plot(x, test_loss_summary[:num_samples], label='test_loss', linestyle='-')
        plt.legend()  # without this, the labels above are never rendered
        f.savefig(os.path.join('base_model', png_name))
        plt.close()
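
    For completeness, restoring the best checkpoint saved above might look like this (a minimal sketch, assuming the TransAm class and params.json from earlier):

    import os
    import torch
    import transformer_utils
    from transformer_train import TransAm

    params = transformer_utils.Params(os.path.join(os.getcwd(), 'params.json'))
    model = TransAm(feature_size=params.feature_size)
    checkpoint = torch.load(os.path.join(os.getcwd(), 'base_model', 'best.pth.tar'), map_location='cpu')
    model.load_state_dict(checkpoint['state_dict'])
    model.eval()  # switch to inference mode before predicting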

    5. Configuration file (params.json):

    {
        "train_batch_size": 1200,
        "test_batch_size": 100,
        "lr": 0.005,
        "epochs": 1000,
        "feature_size": 100,
        "gamma": 0.95,
        "input_window": 12,
        "mode": "cuda"
    }

    Original article: https://blog.csdn.net/benben044/article/details/125476272