In the previous article the dataset was wrapped and loaded via dsets.py. With that in place, we can build the model and write the training script to actually train it. This article is my notes on Chapter 11 of 《PyTorch深度学习实战》(Deep Learning with PyTorch):
1. Building a simple convolutional neural network
2. Writing the training function
3. Writing the data structures for the training log (loss, accuracy, etc. for training and validation)
4. Visualizing the training information with TensorBoard.
In cell 2 of the book's 【code/p2_run_everything.ipynb】, a generic helper is defined that launches things the way a system process would; through it you can call the entry point of any of the scripts. Personally I find it rather clunky and not very user-friendly, so I'd suggest not spending much effort on this part of the code beyond understanding what it does.
- def run(app, *argv):
- argv = list(argv)
- argv.insert(0, '--num-workers=4') # <1> use 4 worker processes
- log.info("Running: {}({!r}).main()".format(app, argv))
-
- app_cls = importstr(*app.rsplit('.', 1)) # <2> dynamically import the module and look up the class
- app_cls(argv).main() # call the app class's main()
-
- log.info("Finished: {}({!r}).main()".format(app, argv))
Usage example: import the LunaTrainingApp class from training.py in the p2ch11 folder and call its main function, with epochs=1 as the argument.
run('p2ch11.training.LunaTrainingApp', '--epochs=1')
Where:
importstr implements dynamic loading of modules and their members, similar to `from 【pkg_name】 import 【func_name】`. With importstr a class or function can be loaded at run time without an import statement declared beforehand.
rsplit usage: list = str.rsplit(sep, maxsplit). See the article below. In short, it splits the string 【str】 on the separator 【sep】, starting from the right-hand side of the string, at most 【maxsplit】 times, and returns the pieces as a list.
Python实用语法之rsplit_明 总 有的博客-CSDN博客_python rsplit
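As a rough illustration of how dynamic loading and rsplit fit together (importstr_demo below is a simplified stand-in I wrote, not the book's exact util.importstr):
- import importlib
-
- def importstr_demo(dotted_name):
-     """Load an attribute from a module given its dotted path, e.g. 'p2ch11.training.LunaTrainingApp'."""
-     module_str, attr_str = dotted_name.rsplit('.', 1)  # split once, from the right: ['p2ch11.training', 'LunaTrainingApp']
-     module = importlib.import_module(module_str)       # import p2ch11.training at run time
-     return getattr(module, attr_str)                   # fetch the LunaTrainingApp class from it
-
- # app_cls = importstr_demo('p2ch11.training.LunaTrainingApp')
- # app_cls(['--epochs=1']).main()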
The book's 【prepcache.py】 uses the argparse library. argparse solves the problem of parsing the parameter names and values we pass on the command line when running a script. Once a parser has been defined, the script can be invoked from the command line much like a conda command, e.g. something of the form 【conda --user xxx】.
For details on argparse, see the following article:
argparse.ArgumentParser()的用法_无尽的沉默的博客-CSDN博客_argparse.argumentparser
A simple usage example:
- import argparse
-
- parser = argparse.ArgumentParser() # create an argument parser
- parser.add_argument("--arg1", type=int, help="an integer", default=1) # declare an int argument via --argName
- parser.add_argument("--arg2", type=int, help="an integer", default=2) # declare an int argument via --argName
-
- args = parser.parse_args() # parse the arguments
-
- print("arg1 = {0}".format(args.arg1))
- print("arg2 = {0}".format(args.arg2))
Running it from the command line gives:
- (pytorch) E:\CT\code>python test2.py --arg1 1 --arg2 2
- arg1 = 1
- arg2 = 2
The book's 【prepcache.py】 also uses the @classmethod decorator, which lets a method be called on the class without instantiating an object first.
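A minimal illustration of @classmethod (Greeter here is a made-up class, not from the book):
- class Greeter:
-     @classmethod
-     def hello(cls, name):
-         # cls is the class itself, so no instance has to be created first
-         return "hello, {}".format(name)
-
- print(Greeter.hello("world"))  # called directly on the class, no Greeter() needed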
The network used in Chapter 11 is a plain stack of convolutions followed by a linear layer, nothing special. Since the task is a simple binary classification (is the nodule a tumor or not), only a single linear layer is used for the head. The convolutions and pooling are 3D.
Multi-GPU training can be done with nn.DataParallel(model) or with DistributedDataParallel. The former is simpler and is typically used in the single-machine, multi-GPU case; the latter needs more configuration and is generally used for multi-machine, multi-GPU setups.
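A minimal sketch of the single-machine, multi-GPU case, essentially what the book's initModel does (see the training code further below):
- import torch
- from torch import nn
- from p2ch11.model import LunaModel
-
- model = LunaModel()
- if torch.cuda.is_available():
-     if torch.cuda.device_count() > 1:
-         model = nn.DataParallel(model)  # replicate the model across all visible GPUs; each batch is split along dim 0
-     model = model.to(torch.device("cuda"))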
When starting out, it's usually worth trying SGD with momentum first, e.g. lr=0.001 and momentum=0.9 (the book's code uses 0.99); if that doesn't work, switch to another optimizer such as Adam.
As introduced with the Ct class in the previous article, the width_irc parameter defines the size of each candidate chunk in IRC coordinates; it is also the input size the dataset feeds into the model.
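As a quick sanity check (my own snippet, not from the book), a dummy forward pass with that input size also shows that the model returns two outputs, the raw logits and the softmax probabilities:
- import torch
- from p2ch11.model import LunaModel
-
- model = LunaModel()
- x = torch.randn(2, 1, 32, 48, 48)   # (batch, channel, index, row, col): one channel, a 32x48x48 voxel chunk per candidate
- logits, probs = model(x)            # forward() returns both the raw logits and the softmax probabilities
- print(logits.shape, probs.shape)    # torch.Size([2, 2]) torch.Size([2, 2])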
The summary function from either the torchinfo or the torchsummary library can print the model's parameter information. For example:
- from p2ch11.model import LunaModel
- import torchinfo # install with: conda install torchinfo
-
- model = LunaModel()
- torchinfo.summary(model, (1, 32, 48, 48), batch_dim=0,
- col_names = ('input_size', 'output_size', 'num_params', 'kernel_size', 'mult_adds'), verbose = 1)
The output, i.e. the model information, is as follows:
- =====================================================================================================================================================================
- Layer (type:depth-idx) Input Shape Output Shape Param # Kernel Shape Mult-Adds
- =====================================================================================================================================================================
- LunaModel [1, 1, 32, 48, 48] [1, 2] -- -- --
- ├─BatchNorm3d: 1-1 [1, 1, 32, 48, 48] [1, 1, 32, 48, 48] 2 -- 2
- ├─LunaBlock: 1-2 [1, 1, 32, 48, 48] [1, 8, 16, 24, 24] -- -- --
- │ └─Conv3d: 2-1 [1, 1, 32, 48, 48] [1, 8, 32, 48, 48] 224 [3, 3, 3] 16,515,072
- │ └─ReLU: 2-2 [1, 8, 32, 48, 48] [1, 8, 32, 48, 48] -- -- --
- │ └─Conv3d: 2-3 [1, 8, 32, 48, 48] [1, 8, 32, 48, 48] 1,736 [3, 3, 3] 127,991,808
- │ └─ReLU: 2-4 [1, 8, 32, 48, 48] [1, 8, 32, 48, 48] -- -- --
- │ └─MaxPool3d: 2-5 [1, 8, 32, 48, 48] [1, 8, 16, 24, 24] -- 2 --
- ├─LunaBlock: 1-3 [1, 8, 16, 24, 24] [1, 16, 8, 12, 12] -- -- --
- │ └─Conv3d: 2-6 [1, 8, 16, 24, 24] [1, 16, 16, 24, 24] 3,472 [3, 3, 3] 31,997,952
- │ └─ReLU: 2-7 [1, 16, 16, 24, 24] [1, 16, 16, 24, 24] -- -- --
- │ └─Conv3d: 2-8 [1, 16, 16, 24, 24] [1, 16, 16, 24, 24] 6,928 [3, 3, 3] 63,848,448
- │ └─ReLU: 2-9 [1, 16, 16, 24, 24] [1, 16, 16, 24, 24] -- -- --
- │ └─MaxPool3d: 2-10 [1, 16, 16, 24, 24] [1, 16, 8, 12, 12] -- 2 --
- ├─LunaBlock: 1-4 [1, 16, 8, 12, 12] [1, 32, 4, 6, 6] -- -- --
- │ └─Conv3d: 2-11 [1, 16, 8, 12, 12] [1, 32, 8, 12, 12] 13,856 [3, 3, 3] 15,962,112
- │ └─ReLU: 2-12 [1, 32, 8, 12, 12] [1, 32, 8, 12, 12] -- -- --
- │ └─Conv3d: 2-13 [1, 32, 8, 12, 12] [1, 32, 8, 12, 12] 27,680 [3, 3, 3] 31,887,360
- │ └─ReLU: 2-14 [1, 32, 8, 12, 12] [1, 32, 8, 12, 12] -- -- --
- │ └─MaxPool3d: 2-15 [1, 32, 8, 12, 12] [1, 32, 4, 6, 6] -- 2 --
- ├─LunaBlock: 1-5 [1, 32, 4, 6, 6] [1, 64, 2, 3, 3] -- -- --
- │ └─Conv3d: 2-16 [1, 32, 4, 6, 6] [1, 64, 4, 6, 6] 55,360 [3, 3, 3] 7,971,840
- │ └─ReLU: 2-17 [1, 64, 4, 6, 6] [1, 64, 4, 6, 6] -- -- --
- │ └─Conv3d: 2-18 [1, 64, 4, 6, 6] [1, 64, 4, 6, 6] 110,656 [3, 3, 3] 15,934,464
- │ └─ReLU: 2-19 [1, 64, 4, 6, 6] [1, 64, 4, 6, 6] -- -- --
- │ └─MaxPool3d: 2-20 [1, 64, 4, 6, 6] [1, 64, 2, 3, 3] -- 2 --
- ├─Linear: 1-6 [1, 1152] [1, 2] 2,306 -- 2,306
- ├─Softmax: 1-7 [1, 2] [1, 2] -- -- --
- =====================================================================================================================================================================
- Total params: 222,220
- Trainable params: 222,220
- Non-trainable params: 0
- Total mult-adds (M): 312.11
- =====================================================================================================================================================================
- Input size (MB): 0.29
- Forward/backward pass size (MB): 13.12
- Params size (MB): 0.89
- Estimated Total Size (MB): 14.31
- =====================================================================================================================================================================
-
- Process finished with exit code 0
Before training starts, the weights need to be initialized. The initialization scheme is a generic one; see the _init_weights function in the book's 【model.py】:
- def _init_weights(self):
- for m in self.modules():
- if type(m) in {
- nn.Linear,
- nn.Conv3d,
- nn.Conv2d,
- nn.ConvTranspose2d,
- nn.ConvTranspose3d,
- }:
- nn.init.kaiming_normal_(
- m.weight.data, a=0, mode='fan_out', nonlinearity='relu',
- )
- if m.bias is not None:
- fan_in, fan_out = \
- nn.init._calculate_fan_in_and_fan_out(m.weight.data)
- bound = 1 / math.sqrt(fan_out)
- nn.init.normal_(m.bias, -bound, bound)
The book defines an enumerateWithEstimate function to estimate how long running a piece of code will take. The key is the yield keyword, which turns enumerateWithEstimate into a generator that loads the dataset one iteration at a time. For the usage of yield, see the article below.
python中yield的用法详解——最简单,最清晰的解释_冯爽朗的博客-CSDN博客_python yield
In short, a function func that contains the yield keyword behaves like breakpoint-style execution when used:
1. Calling func only creates a generator; on the first next(), the code runs up to the first yield and returns the value on its right-hand side, much like return. The code after the yield is not executed yet.
2. When next() is called again, func resumes from the statement right after the yield and runs until it hits the next yield. In enumerateWithEstimate the yield sits inside a for loop, so every loop iteration yields once; if the function body finishes without reaching another yield, the generator raises StopIteration rather than starting over.
3. Each subsequent next() call repeats step 2.
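A minimal example of that behaviour:
- def gen():
-     print("before the first yield")
-     yield 1                        # first next(): run up to here, hand 1 back to the caller, then pause
-     print("between the yields")
-     yield 2                        # second next(): resume after the first yield and pause here
-     print("after the last yield")  # a third next() runs this line and then raises StopIteration
-
- g = gen()        # calling gen() only creates the generator; nothing inside runs yet
- print(next(g))   # prints "before the first yield", then 1
- print(next(g))   # prints "between the yields", then 2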
In the book, the author uses the diskcache library to cache the dataset to disk the first time it is loaded; on subsequent training or validation runs the data is read straight from the disk cache, which saves a large part of the data loading and preprocessing time. For details on diskcache, see:
https://blog.csdn.net/wxyczhyza/article/details/127773721
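A minimal sketch of the idea (the cache directory and slow_square are made up for illustration; the book wraps getCtRawCandidate in dsets.py this way via a helper in util/disk.py):
- import time
- from diskcache import FanoutCache   # pip install diskcache
-
- cache = FanoutCache('data-cache/demo', shards=4, timeout=1)
-
- @cache.memoize(typed=True)
- def slow_square(x):
-     time.sleep(2)        # stand-in for expensive preprocessing such as extracting a CT chunk
-     return x * x
-
- print(slow_square(3))    # first call: ~2 s, the result is written to the on-disk cache
- print(slow_square(3))    # second call: returned from the cache almost instantly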
TensorBoard has been integrated into PyTorch since version 1.2 and can be used directly via torch.utils.tensorboard:
- from torch.utils.tensorboard import SummaryWriter # SummaryWriter records training metrics for TensorBoard
-
- writer = SummaryWriter(file_path) # pass the log directory when instantiating
- writer.add_scalar(tag, y_value, x_value) # add a scalar (tag, value, global step)
- # writer.add_histogram() # add a histogram
- # writer.add_image() # add an image
- writer.close() # close the file handle
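Once the training script has written its logs (under runs/ in this chapter), they can be viewed by running tensorboard --logdir runs from the project directory and opening http://localhost:6006 in a browser.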
The book's original code can be downloaded from the code link in the article below; here I'll just paste my own annotated version.
The code (model.py) is as follows:
- import math
-
- from torch import nn as nn
-
- from util.logconf import logging
-
- log = logging.getLogger(__name__)
- # log.setLevel(logging.WARN)
- # log.setLevel(logging.INFO)
- log.setLevel(logging.DEBUG)
-
-
- class LunaModel(nn.Module):
- def __init__(self, in_channels=1, conv_channels=8):
- super().__init__()
-
- self.tail_batchnorm = nn.BatchNorm3d(1)
-
- self.block1 = LunaBlock(in_channels, conv_channels)
- self.block2 = LunaBlock(conv_channels, conv_channels * 2)
- self.block3 = LunaBlock(conv_channels * 2, conv_channels * 4)
- self.block4 = LunaBlock(conv_channels * 4, conv_channels * 8)
-
- self.head_linear = nn.Linear(1152, 2)
- self.head_softmax = nn.Softmax(dim=1)
-
- self._init_weights()
-
- # see also https://github.com/pytorch/pytorch/issues/18182
- def _init_weights(self):
- for m in self.modules():
- if type(m) in {
- nn.Linear,
- nn.Conv3d,
- nn.Conv2d,
- nn.ConvTranspose2d,
- nn.ConvTranspose3d,
- }:
- nn.init.kaiming_normal_(
- m.weight.data, a=0, mode='fan_out', nonlinearity='relu',
- )
- if m.bias is not None:
- fan_in, fan_out = \
- nn.init._calculate_fan_in_and_fan_out(m.weight.data)
- bound = 1 / math.sqrt(fan_out)
- nn.init.normal_(m.bias, -bound, bound)
-
-
-
- def forward(self, input_batch):
- bn_output = self.tail_batchnorm(input_batch)
-
- block_out = self.block1(bn_output)
- block_out = self.block2(block_out)
- block_out = self.block3(block_out)
- block_out = self.block4(block_out)
-
- conv_flat = block_out.view(
- block_out.size(0),
- -1,
- )
- linear_output = self.head_linear(conv_flat)
-
- return linear_output, self.head_softmax(linear_output)
-
-
- class LunaBlock(nn.Module):
- def __init__(self, in_channels, conv_channels):
- super().__init__()
-
- self.conv1 = nn.Conv3d(
- in_channels, conv_channels, kernel_size=3, padding=1, bias=True,
- )
- self.relu1 = nn.ReLU(inplace=True)
- self.conv2 = nn.Conv3d(
- conv_channels, conv_channels, kernel_size=3, padding=1, bias=True,
- )
- self.relu2 = nn.ReLU(inplace=True)
-
- self.maxpool = nn.MaxPool3d(2, 2)
-
- def forward(self, input_batch):
- block_out = self.conv1(input_batch)
- block_out = self.relu1(block_out)
- block_out = self.conv2(block_out)
- block_out = self.relu2(block_out)
-
- return self.maxpool(block_out)
Function location: util\util.py
The function mainly relies on the yield keyword, which makes enumerateWithEstimate a generator: it keeps iterating over the dataset and uses the time taken per iteration so far to estimate the total time needed to load the whole dataset.
- # The function estimates how long it will take to run through the whole iterator. The idea:
- # step 1: thanks to yield, the dataset is consumed a piece at a time; the average time per item so far is delta_t = elapsed time / number of items processed
- # step 2: from the length of the iterator, estimate the time for the whole dataset: t_dataset = delta_t * dataset length
- def enumerateWithEstimate(
- iter, # an iterator over the dataset; the whole point is to estimate how long iterating over all of it will take
- desc_str, # descriptive text used in the log messages; anything readable will do
- start_ndx=0, # number of initial iterations to skip before timing starts (their time is not counted in the estimate)
- print_ndx=4, # iteration index at which the next log line is printed, updated as print_ndx = print_ndx * backoff; initial default 4
- backoff=None, # multiplier for the gap between two consecutive log lines: print_ndx = print_ndx * backoff
- iter_len=None, # length of the iterator; if not given, iter_len = len(iter)
- ):
- """
- In terms of behavior, `enumerateWithEstimate` is almost identical
- to the standard `enumerate` (the differences are things like how
- our function returns a generator, while `enumerate` returns a
- specialized `<enumerate object at 0x...>`).
- However, the side effects (logging, specifically) are what make the
- function interesting.
- :param iter: `iter` is the iterable that will be passed into
- `enumerate`. Required.
- :param desc_str: This is a human-readable string that describes
- what the loop is doing. The value is arbitrary, but should be
- kept reasonably short. Things like `"epoch 4 training"` or
- `"deleting temp files"` or similar would all make sense.
- :param start_ndx: This parameter defines how many iterations of the
- loop should be skipped before timing actually starts. Skipping
- a few iterations can be useful if there are startup costs like
- caching that are only paid early on, resulting in a skewed
- average when those early iterations dominate the average time
- per iteration.
- NOTE: Using `start_ndx` to skip some iterations makes the time
- spent performing those iterations not be included in the
- displayed duration. Please account for this if you use the
- displayed duration for anything formal.
- This parameter defaults to `0`.
- :param print_ndx: determines which loop interation that the timing
- logging will start on. The intent is that we don't start
- logging until we've given the loop a few iterations to let the
- average time-per-iteration a chance to stablize a bit. We
- require that `print_ndx` not be less than `start_ndx` times
- `backoff`, since `start_ndx` greater than `0` implies that the
- early N iterations are unstable from a timing perspective.
- `print_ndx` defaults to `4`.
- :param backoff: This is used to how many iterations to skip before
- logging again. Frequent logging is less interesting later on,
- so by default we double the gap between logging messages each
- time after the first.
- `backoff` defaults to `2` unless iter_len is > 1000, in which
- case it defaults to `4`.
- :param iter_len: Since we need to know the number of items to
- estimate when the loop will finish, that can be provided by
- passing in a value for `iter_len`. If a value isn't provided,
- then it will be set by using the value of `len(iter)`.
- :return:
- """
- if iter_len is None:
- iter_len = len(iter)
-
- if backoff is None:
- backoff = 2
- while backoff ** 7 < iter_len:
- backoff *= 2
-
- assert backoff >= 2
- while print_ndx < start_ndx * backoff:
- print_ndx *= backoff
-
- log.warning("{} ----/{}, starting".format(
- desc_str,
- iter_len,
- ))
- start_ts = time.time()
- for (current_ndx, item) in enumerate(iter):
- yield (current_ndx, item)
- if current_ndx == print_ndx:
- # ... <1> step 1: measure the elapsed time for the iterations processed so far; step 2: average it to get the time per item; step 3: multiply by the dataset length to estimate the time to load all the data
- duration_sec = ((time.time() - start_ts)
- / (current_ndx - start_ndx + 1)
- * (iter_len-start_ndx)
- )
-
- done_dt = datetime.datetime.fromtimestamp(start_ts + duration_sec)
- done_td = datetime.timedelta(seconds=duration_sec)
-
- log.info("{} {:-4}/{}, done at {}, {}".format(
- desc_str,
- current_ndx,
- iter_len,
- str(done_dt).rsplit('.', 1)[0], # estimated wall-clock time at which the whole dataset will be loaded, based on the current_ndx iterations so far
- str(done_td).rsplit('.', 1)[0], # estimated total time needed to load the whole dataset, based on the current_ndx iterations so far
- ))
-
- print_ndx *= backoff
-
- if current_ndx + 1 == start_ndx:
- start_ts = time.time()
-
- log.warning("{} ----/{}, done at {}".format(
- desc_str,
- iter_len,
- str(datetime.datetime.now()).rsplit('.', 1)[0],
- ))
This script loads the entire dataset once, to measure how long loading it takes (and to fill the disk cache). Its core is the call to enumerateWithEstimate.
- import argparse # command-line argument parser
- import sys
-
- import numpy as np
-
- import torch.nn as nn
- from torch.autograd import Variable
- from torch.optim import SGD
- from torch.utils.data import DataLoader
-
- from util.util import enumerateWithEstimate
- from .dsets import LunaDataset
- from util.logconf import logging
- from .model import LunaModel
-
- log = logging.getLogger(__name__)
- # log.setLevel(logging.WARN)
- log.setLevel(logging.INFO)
- # log.setLevel(logging.DEBUG)
-
-
- class LunaPrepCacheApp:
- @classmethod
- def __init__(self, sys_argv=None):
- if sys_argv is None:
- sys_argv = sys.argv[1:]
-
- parser = argparse.ArgumentParser() # command-line argument parser
- parser.add_argument('--batch-size', # add an argument
- help='Batch size to use for training',
- default=1024,
- type=int,
- )
- parser.add_argument('--num-workers',
- help='Number of worker processes for background data loading',
- default=8,
- type=int,
- )
-
- self.cli_args = parser.parse_args(sys_argv) # parse the arguments
-
- def main(self):
- log.info("Starting {}, {}".format(type(self).__name__, self.cli_args))
-
- self.prep_dl = DataLoader(
- LunaDataset(
- sortby_str='series_uid',
- ),
- batch_size=self.cli_args.batch_size,
- num_workers=self.cli_args.num_workers,
- )
-
- batch_iter = enumerateWithEstimate( # iterate over the dataset once and estimate the time needed to load all of it
- self.prep_dl,
- "Stuffing cache",
- start_ndx=self.prep_dl.num_workers,
- )
- for _ in batch_iter:
- pass
-
-
- if __name__ == '__main__':
- LunaPrepCacheApp().main() # __init__ is decorated with @classmethod, so the class's methods can also be called without explicit instantiation
For how to run this in Jupyter, see 【chapter11-cell2】 of the book's 【p2_run_everything.ipynb】. Concretely:
step 1: import the relevant libraries and functions;
step 2: invoke LunaPrepCacheApp in command-line style, for example as shown below.
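With the run helper from cell 2, the call should look something along the lines of:
- run('p2ch11.prepcache.LunaPrepCacheApp')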
Result:
From the log output, the dataset contains 551,065 samples in total, with 1,024 samples per batch, i.e. 539 batches; after loading 16 batches, the estimated time to load all of them comes out to about 6 hours 5 minutes.
I've commented parts of the code; the notes on some of the TensorBoard-related code are deferred to the sixth article in this series. The training results and code are as follows:
- import argparse
- import datetime
- import os
- import sys
-
- import numpy as np
-
- from torch.utils.tensorboard import SummaryWriter
-
- import torch
- import torch.nn as nn
- from torch.optim import SGD, Adam
- from torch.utils.data import DataLoader
-
- from util.util import enumerateWithEstimate
- from .dsets import LunaDataset
- from util.logconf import logging
- from .model import LunaModel
-
- log = logging.getLogger(__name__)
- # log.setLevel(logging.WARN)
- log.setLevel(logging.INFO)
- log.setLevel(logging.DEBUG)
-
- # Used for computeBatchLoss and logMetrics to index into metrics_t/metrics_a
- # Each sample's label, prediction, and loss during training are stored in one matrix, used for printing results and for display in TensorBoard
- # Row 0 holds the labels, row 1 the predictions, row 2 the loss values; each column is one sample
- METRICS_LABEL_NDX=0 # row index of the labels
- METRICS_PRED_NDX=1 # row index of the predictions
- METRICS_LOSS_NDX=2 # row index of the loss values
- METRICS_SIZE = 3 # number of rows in the matrix
-
- class LunaTrainingApp:
- def __init__(self, sys_argv=None):
- if sys_argv is None:
- sys_argv = sys.argv[1:]
-
- parser = argparse.ArgumentParser()
- parser.add_argument('--num-workers',
- help='Number of worker processes for background data loading',
- default=6, # number of CPU cores to use; my i5-12490f has 6 cores
- type=int,
- )
- parser.add_argument('--batch-size',
- help='Batch size to use for training',
- default=24, # number of samples per batch
- type=int,
- )
- parser.add_argument('--epochs',
- help='Number of epochs to train for',
- default=1, # number of epochs to train
- type=int,
- )
-
- parser.add_argument('--tb-prefix',
- default='p2ch11',
- help="Data prefix to use for Tensorboard run. Defaults to chapter.",
- )
-
- parser.add_argument('comment',
- help="Comment suffix for Tensorboard run.",
- nargs='?',
- default='dwlpt',
- )
- self.cli_args = parser.parse_args(sys_argv)
- self.time_str = datetime.datetime.now().strftime('%Y-%m-%d_%H.%M.%S')
-
- self.trn_writer = None
- self.val_writer = None
- self.totalTrainingSamples_count = 0
-
- self.use_cuda = torch.cuda.is_available()
- self.device = torch.device("cuda" if self.use_cuda else "cpu")
-
- self.model = self.initModel() # build the model and move it to CUDA
- self.optimizer = self.initOptimizer() # define the optimizer
-
- def initModel(self):
- model = LunaModel()
- if self.use_cuda:
- log.info("Using CUDA; {} devices.".format(torch.cuda.device_count()))
- if torch.cuda.device_count() > 1:
- model = nn.DataParallel(model) # if there are multiple GPUs, spread training across them
- model = model.to(self.device)
- return model
-
- def initOptimizer(self):
- # usually try SGD for the first run to see how it does, then consider other optimizers; common settings are lr=0.001, momentum=0.99
- return SGD(self.model.parameters(), lr=0.001, momentum=0.99)
- # return Adam(self.model.parameters())
-
- def initTrainDl(self):
- # Because LunaDataset's getCtRawCandidate is decorated by diskcache, the first time the dataset is loaded the data has to be read from file
- # and the processed result is cached to disk, which is slow; from the second run onwards it is loaded straight from the cache, which is much faster.
- train_ds = LunaDataset(
- val_stride=10,
- isValSet_bool=False,
- )
-
- batch_size = self.cli_args.batch_size
- if self.use_cuda:
- batch_size *= torch.cuda.device_count()
-
- train_dl = DataLoader(
- train_ds,
- batch_size=batch_size,
- num_workers=self.cli_args.num_workers,
- pin_memory=self.use_cuda,
- )
-
- return train_dl
-
- def initValDl(self):
- val_ds = LunaDataset(
- val_stride=10,
- isValSet_bool=True,
- )
-
- batch_size = self.cli_args.batch_size
- if self.use_cuda:
- batch_size *= torch.cuda.device_count()
-
- val_dl = DataLoader(
- val_ds,
- batch_size=batch_size,
- num_workers=self.cli_args.num_workers,
- pin_memory=self.use_cuda,
- )
-
- return val_dl
-
- def initTensorboardWriters(self):
- if self.trn_writer is None:
- log_dir = os.path.join('runs', self.cli_args.tb_prefix, self.time_str)
-
- self.trn_writer = SummaryWriter(
- log_dir=log_dir + '-trn_cls-' + self.cli_args.comment)
- self.val_writer = SummaryWriter(
- log_dir=log_dir + '-val_cls-' + self.cli_args.comment)
-
-
- def main(self):
- log.info("Starting {}, {}".format(type(self).__name__, self.cli_args))
-
- train_dl = self.initTrainDl()
- val_dl = self.initValDl()
-
- for epoch_ndx in range(1, self.cli_args.epochs + 1):
-
- log.info("Epoch {} of {}, {}/{} batches of size {}*{}".format(
- epoch_ndx,
- self.cli_args.epochs,
- len(train_dl),
- len(val_dl),
- self.cli_args.batch_size,
- (torch.cuda.device_count() if self.use_cuda else 1),
- ))
-
- trnMetrics_t = self.doTraining(epoch_ndx, train_dl)
- self.logMetrics(epoch_ndx, 'trn', trnMetrics_t)
-
- valMetrics_t = self.doValidation(epoch_ndx, val_dl)
- self.logMetrics(epoch_ndx, 'val', valMetrics_t)
-
- if hasattr(self, 'trn_writer'):
- self.trn_writer.close()
- self.val_writer.close()
-
-
- def doTraining(self, epoch_ndx, train_dl):
- self.model.train()
- trnMetrics_g = torch.zeros(
- METRICS_SIZE,
- len(train_dl.dataset),
- device=self.device,
- )
-
- # batch_iter = enumerateWithEstimate(
- # train_dl,
- # "E{} Training".format(epoch_ndx),
- # start_ndx=train_dl.num_workers,
- # )
- for batch_ndx, batch_tup in enumerate(train_dl):
- self.optimizer.zero_grad()
-
- loss_var = self.computeBatchLoss(
- batch_ndx,
- batch_tup,
- train_dl.batch_size,
- trnMetrics_g
- )
-
- loss_var.backward()
- self.optimizer.step()
-
- # # This is for adding the model graph to TensorBoard.
- # if epoch_ndx == 1 and batch_ndx == 0:
- # with torch.no_grad():
- # model = LunaModel()
- # self.trn_writer.add_graph(model, batch_tup[0], verbose=True)
- # self.trn_writer.close()
-
- self.totalTrainingSamples_count += len(train_dl.dataset)
-
- return trnMetrics_g.to('cpu')
-
-
- def doValidation(self, epoch_ndx, val_dl):
- with torch.no_grad():
- self.model.eval()
- valMetrics_g = torch.zeros(
- METRICS_SIZE,
- len(val_dl.dataset),
- device=self.device,
- )
-
- batch_iter = enumerateWithEstimate(
- val_dl,
- "E{} Validation ".format(epoch_ndx),
- start_ndx=val_dl.num_workers,
- )
- for batch_ndx, batch_tup in batch_iter:
- self.computeBatchLoss(
- batch_ndx, batch_tup, val_dl.batch_size, valMetrics_g)
-
- return valMetrics_g.to('cpu')
-
-
-
- def computeBatchLoss(self, batch_ndx, batch_tup, batch_size, metrics_g):
- input_t, label_t, _series_list, _center_list = batch_tup
-
- input_g = input_t.to(self.device, non_blocking=True)
- label_g = label_t.to(self.device, non_blocking=True)
-
- logits_g, probability_g = self.model(input_g)
-
- loss_func = nn.CrossEntropyLoss(reduction='none') # with reduction='none', the loss of every individual sample is returned
- loss_g = loss_func(
- logits_g,
- label_g[:,1],
- )
- start_ndx = batch_ndx * batch_size
- end_ndx = start_ndx + label_t.size(0)
-
- # store this batch's results into the metrics matrix
- metrics_g[METRICS_LABEL_NDX, start_ndx:end_ndx] = \
- label_g[:,1].detach()
- metrics_g[METRICS_PRED_NDX, start_ndx:end_ndx] = \
- probability_g[:,1].detach()
- metrics_g[METRICS_LOSS_NDX, start_ndx:end_ndx] = \
- loss_g.detach()
-
- return loss_g.mean()
-
-
- def logMetrics(
- self,
- epoch_ndx,
- mode_str,
- metrics_t,
- classificationThreshold=0.5,
- ):
- self.initTensorboardWriters()
- log.info("E{} {}".format(
- epoch_ndx,
- type(self).__name__,
- ))
-
- negLabel_mask = metrics_t[METRICS_LABEL_NDX] <= classificationThreshold
- negPred_mask = metrics_t[METRICS_PRED_NDX] <= classificationThreshold
-
- posLabel_mask = ~negLabel_mask
- posPred_mask = ~negPred_mask
-
- neg_count = int(negLabel_mask.sum())
- pos_count = int(posLabel_mask.sum())
-
- neg_correct = int((negLabel_mask & negPred_mask).sum())
- pos_correct = int((posLabel_mask & posPred_mask).sum())
-
- metrics_dict = {}
- metrics_dict['loss/all'] = \
- metrics_t[METRICS_LOSS_NDX].mean()
- metrics_dict['loss/neg'] = \
- metrics_t[METRICS_LOSS_NDX, negLabel_mask].mean()
- metrics_dict['loss/pos'] = \
- metrics_t[METRICS_LOSS_NDX, posLabel_mask].mean()
-
- metrics_dict['correct/all'] = (pos_correct + neg_correct) \
- / np.float32(metrics_t.shape[1]) * 100
- metrics_dict['correct/neg'] = neg_correct / np.float32(neg_count) * 100
- metrics_dict['correct/pos'] = pos_correct / np.float32(pos_count) * 100
-
- log.info(
- ("E{} {:8} {loss/all:.4f} loss, "
- + "{correct/all:-5.1f}% correct, "
- ).format(
- epoch_ndx,
- mode_str,
- **metrics_dict,
- )
- )
- log.info(
- ("E{} {:8} {loss/neg:.4f} loss, "
- + "{correct/neg:-5.1f}% correct ({neg_correct:} of {neg_count:})"
- ).format(
- epoch_ndx,
- mode_str + '_neg',
- neg_correct=neg_correct,
- neg_count=neg_count,
- **metrics_dict,
- )
- )
- log.info(
- ("E{} {:8} {loss/pos:.4f} loss, "
- + "{correct/pos:-5.1f}% correct ({pos_correct:} of {pos_count:})"
- ).format(
- epoch_ndx,
- mode_str + '_pos',
- pos_correct=pos_correct,
- pos_count=pos_count,
- **metrics_dict,
- )
- )
-
- writer = getattr(self, mode_str + '_writer')
-
- for key, value in metrics_dict.items():
- writer.add_scalar(key, value, self.totalTrainingSamples_count)
-
- writer.add_pr_curve(
- 'pr',
- metrics_t[METRICS_LABEL_NDX],
- metrics_t[METRICS_PRED_NDX],
- self.totalTrainingSamples_count,
- )
-
- bins = [x/50.0 for x in range(51)]
-
- negHist_mask = negLabel_mask & (metrics_t[METRICS_PRED_NDX] > 0.01)
- posHist_mask = posLabel_mask & (metrics_t[METRICS_PRED_NDX] < 0.99)
-
- if negHist_mask.any():
- writer.add_histogram(
- 'is_neg',
- metrics_t[METRICS_PRED_NDX, negHist_mask],
- self.totalTrainingSamples_count,
- bins=bins,
- )
- if posHist_mask.any():
- writer.add_histogram(
- 'is_pos',
- metrics_t[METRICS_PRED_NDX, posHist_mask],
- self.totalTrainingSamples_count,
- bins=bins,
- )
-
- # score = 1 \
- # + metrics_dict['pr/f1_score'] \
- # - metrics_dict['loss/mal'] * 0.01 \
- # - metrics_dict['loss/all'] * 0.0001
- #
- # return score
-
- # def logModelMetrics(self, model):
- # writer = getattr(self, 'trn_writer')
- #
- # model = getattr(model, 'module', model)
- #
- # for name, param in model.named_parameters():
- # if param.requires_grad:
- # min_data = float(param.data.min())
- # max_data = float(param.data.max())
- # max_extent = max(abs(min_data), abs(max_data))
- #
- # # bins = [x/50*max_extent for x in range(-50, 51)]
- #
- # try:
- # writer.add_histogram(
- # name.rsplit('.', 1)[-1] + '/' + name,
- # param.data.cpu().numpy(),
- # # metrics_a[METRICS_PRED_NDX, negHist_mask],
- # self.totalTrainingSamples_count,
- # # bins=bins,
- # )
- # except Exception as e:
- # log.error([min_data, max_data])
- # raise
-
-
- if __name__ == '__main__':
- LunaTrainingApp().main()