• 使用perming加速训练可预测的模型


    监督学习模型的训练流程

    perming是一个主要在支持CUDA加速的Windows操作系统上架构的机器学习算法,基于感知机模型来解决分布在欧式空间中线性不可分数据集的解决方案,是基于PyTorch中预定义的可调用函数,设计的一个面向大规模结构化数据集的通用监督学习器,v1.4.2之后支持检测验证损失的变化间隔并提前停止训练。

    pip install perming --upgrade
    pip install perming>=1.4.2
    
    • 1
    • 2

    数据清洗后的特征输入

    在常见的自动化机器学习管线中,一组原始结构化数据集是经过一系列函数式的数据清洗操作后,得到了固定特征维度的特征数据集,但是该特征数据集没有专用的线性不可分检测方式以及相应的线性可分空间指定,所以需要用户指定潜在的线性可分空间的大小以及一些组合的学习参数。以下是以perming.Box为例展开机器学习训练的案例:

    import numpy
    import pandas
    df = pandas.read_csv('../data/bitcoin_heist_data.csv')
    df = df.to_numpy()
    labels = df[:,-1] # input
    features = df[:,1:-1].astype(numpy.float64) # input
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6

    此处下载数据集

    加载perming并配置超参数

    import perming # v1.6.0
    main = perming.Box(8, 29, (60,), batch_size=256, activation='relu', inplace_on=True, solver='sgd', learning_rate_init=0.01)
    main.print_config()
    
    • 1
    • 2
    • 3
    MLP(
      (mlp): Sequential(
        (Linear0): Linear(in_features=8, out_features=60, bias=True)
        (Activation0): ReLU(inplace=True)
        (Linear1): Linear(in_features=60, out_features=29, bias=True)
      )
    )
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    Out[1]: OrderedDict([('torch -v', '1.7.1+cu101'),
                 ('criterion', CrossEntropyLoss()),
                 ('batch_size', 256),
                 ('solver',
                  SGD (
                  Parameter Group 0
                      dampening: 0
                      lr: 0.01
                      momentum: 0
                      nesterov: False
                      weight_decay: 0
                  )),
                 ('lr_scheduler', None),
                 ('device', device(type='cuda'))])
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14

    参考这里查看每个模型的参数文档

    从numpy.ndarray多线程加载数据集

    main.data_loader(features, labels, random_seed=0)
    # 参考main.data_loader.__doc__获取更多默认参数的信息
    
    • 1
    • 2

    训练阶段和加速验证

    main.train_val(num_epochs=1, interval=100, early_stop=True)
    # 参考`main.train_val.__doc__`获取更多默认参数的信息,例如tolerance, patience
    
    • 1
    • 2
    Epoch [1/1], Step [100/3277], Training Loss: 2.5657, Validation Loss: 2.5551
    Epoch [1/1], Step [200/3277], Training Loss: 1.8318, Validation Loss: 1.8269
    Epoch [1/1], Step [300/3277], Training Loss: 1.2668, Validation Loss: 1.2844
    Epoch [1/1], Step [400/3277], Training Loss: 0.9546, Validation Loss: 0.9302
    Epoch [1/1], Step [500/3277], Training Loss: 0.7440, Validation Loss: 0.7169
    Epoch [1/1], Step [600/3277], Training Loss: 0.5863, Validation Loss: 0.5889
    Epoch [1/1], Step [700/3277], Training Loss: 0.5062, Validation Loss: 0.5086
    Epoch [1/1], Step [800/3277], Training Loss: 0.3308, Validation Loss: 0.4563
    Epoch [1/1], Step [900/3277], Training Loss: 0.3079, Validation Loss: 0.4204
    Epoch [1/1], Step [1000/3277], Training Loss: 0.4298, Validation Loss: 0.3946
    Epoch [1/1], Step [1100/3277], Training Loss: 0.3918, Validation Loss: 0.3758
    Epoch [1/1], Step [1200/3277], Training Loss: 0.4366, Validation Loss: 0.3618
    Process stop at epoch [1/1] with patience 10 within tolerance 0.001
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13

    使用内置的返回项来预测评估模型

    main.test()
    # main.test中的默认参数只在一维标签列的分类问题中作用
    # 因为损失衡量函数广泛且众多,以torcheval中的策略为主
    
    • 1
    • 2
    • 3
    loss of Box on the 104960 test dataset: 0.3505959212779999.
    
    • 1
    Out[2]: OrderedDict([('problem', 'classification'),
                 ('accuracy', '95.99942835365853%'),
                 ('num_classes', 29),
                 ('column', ('label name', ('true numbers', 'total numbers'))),
                 ('labels',
                  {'montrealAPT': [100761, 104857],
                   'montrealComradeCircle': [100761, 104857],
                   'montrealCryptConsole': [100761, 104857],
                   'montrealCryptXXX': [100761, 104857],
                   'montrealCryptoLocker': [100761, 104857],
                   'montrealCryptoTorLocker2015': [100761, 104857],
                   'montrealDMALocker': [100761, 104857],
                   'montrealDMALockerv3': [100761, 104857],
                   'montrealEDA2': [100761, 104857],
                   'montrealFlyper': [100761, 104857],
                   'montrealGlobe': [100761, 104857],
                   'montrealGlobeImposter': [100761, 104857],
                   'montrealGlobev3': [100761, 104857],
                   'montrealJigSaw': [100761, 104857],
                   'montrealNoobCrypt': [100761, 104857],
                   'montrealRazy': [100761, 104857],
                   'montrealSam': [100761, 104857],
                   'montrealSamSam': [100761, 104857],
                   'montrealVenusLocker': [100761, 104857],
                   'montrealWannaCry': [100761, 104857],
                   'montrealXLocker': [100761, 104857],
                   'montrealXLockerv5.0': [100761, 104857],
                   'montrealXTPLocker': [100761, 104857],
                   'paduaCryptoWall': [100761, 104857],
                   'paduaJigsaw': [100761, 104857],
                   'paduaKeRanger': [100761, 104857],
                   'princetonCerber': [100761, 104857],
                   'princetonLocky': [100761, 104857],
                   'white': [100761, 104857]}),
                 ('loss',
                  {'train': 0.330683171749115,
                   'val': 0.3547004163265228,
                    'test': 0.3505959212779999}),
                 ('sorted',
                  [('montrealAPT', [100761, 104857]),
                   ('montrealComradeCircle', [100761, 104857]),
                   ('montrealCryptConsole', [100761, 104857]),
                   ('montrealCryptXXX', [100761, 104857]),
                   ('montrealCryptoLocker', [100761, 104857]),
                   ('montrealCryptoTorLocker2015', [100761, 104857]),
                   ('montrealDMALocker', [100761, 104857]),
                   ('montrealDMALockerv3', [100761, 104857]),
                   ('montrealEDA2', [100761, 104857]),
                   ('montrealFlyper', [100761, 104857]),
                   ('montrealGlobe', [100761, 104857]),
                   ('montrealGlobeImposter', [100761, 104857]),
                   ('montrealGlobev3', [100761, 104857]),
                   ('montrealJigSaw', [100761, 104857]),
                   ('montrealNoobCrypt', [100761, 104857]),
                   ('montrealRazy', [100761, 104857]),
                   ('montrealSam', [100761, 104857]),
                   ('montrealSamSam', [100761, 104857]),
                   ('montrealVenusLocker', [100761, 104857]),
                   ('montrealWannaCry', [100761, 104857]),
                   ('montrealXLocker', [100761, 104857]),
                   ('montrealXLockerv5.0', [100761, 104857]),
                   ('montrealXTPLocker', [100761, 104857]),
                   ('paduaCryptoWall', [100761, 104857]),
                   ('paduaJigsaw', [100761, 104857]),
                   ('paduaKeRanger', [100761, 104857]),
                   ('princetonCerber', [100761, 104857]),
                   ('princetonLocky', [100761, 104857]),
                   ('white', [100761, 104857])])])
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68

    保存模型参数到本地

    main.save(con=False, dir='../models/bitcoin.ckpt')
    # 使用main.unique和main.indices来建立标签的双向转换
    
    • 1
    • 2

    加载模型参数到预训练算法

    main.load(con=False, dir='../models/bitcoin.ckpt')
    
    • 1

    加载模型后可以通过更改组合训练参数,例如优化器等来微调模型的训练。模型训练文件见Multi-classification Task.ipynb

    其他常用的模型初始化设置

    main = perming.Box(10, 3, (30,), batch_size=8, activation='relu', inplace_on=True, solver='sgd', criterion="MultiLabelSoftMarginLoss", learning_rate_init=0.01)
    # 用于解决多标签排序问题,在用户定义标签的双向转换之后,data_loader能检测划分数据集并封装
    
    • 1
    • 2

    使用如下访问该软件的测试和算法:

    git clone https://github.com/linjing-lab/easy-pytorch.git
    cd easy-pytorch/released_box
    
    • 1
    • 2
  • 相关阅读:
    使用 Dumpling 备份 TiDB 集群数据到兼容 S3 的存储
    [动态规划简单题] LeetCode 53. 最大子数组和
    圆锥折射作为偏振计量工具的模拟
    RK3399驱动开发 | 03 - WK2124串口芯片驱动调试
    IT企业做ISO20000 服务管理体系的好处
    css3 都有哪些新属性
    海康威视综合安防平台视频摄像头接入Java
    2003-2022年高铁列车信息数据
    FPGA实现SDI硬件解码UDP网络传输,送工程源码和QT上位机显示程序
    B站视频弹幕不挡住人脸效果
  • 原文地址:https://blog.csdn.net/linjing_zyq/article/details/133418875