• 计算性能的提升之混合式编程(MXNet)


            平时的练习与调试基本都是使用命令式编程,但是计算性能比较差,而使用符号式编程可以提高性能,在 MXNet中的命令式编程和符号式编程的优缺点,这篇文章也有介绍这两者的一些基本区别和操作。

    一般来说我们是在开发的时候使用命令式编程,经过调试测试之后,我们转换成符号式编程,取两者的优点,这种编程的方式就叫做混合编程。 

    命令式编程:

    1. def add(a,b):
    2. return a+b
    3. def fancy_func(a,b,c,d):
    4. e=add(a,b)
    5. f=add(c,d)
    6. g=add(e,f)
    7. return g
    8. fancy_func(1,2,3,4)
    9. #10

    符号式编程

    1. def mul_str():
    2. return '''
    3. def mul(a,b):
    4. return a*b
    5. '''
    6. def fancy_func_str():
    7. return '''
    8. def fancy_func(a,b,c,d):
    9. e=mul(a,b)
    10. f=mul(c,d)
    11. g=mul(e,f)
    12. return g
    13. '''
    14. #这里可以看出,直接将函数使用注释符号给包裹起来,当作字符串返回
    15. def evoke_str():
    16. return mul_str()+fancy_func_str()+'''
    17. print(fancy_func(2,5,3,4))
    18. '''
    19. prog=evoke_str()
    20. print(prog)
    21. '''
    22. def mul(a,b):
    23. return a*b
    24. def fancy_func(a,b,c,d):
    25. e=mul(a,b)
    26. f=mul(c,d)
    27. g=mul(e,f)
    28. return g
    29. print(fancy_func(2,5,3,4))
    30. '''
    31. y=compile(prog,'','exec')#编译整个计算流程并运行
    32. exec(y)
    33. #120

    符号式编程三个步骤:

    1、定义计算的整个流程
    2、把流程编译成可执行的程序
    3、给定输入,调用编译好的程序执行

    从上面也可以看出,命令式编程的代码比较直观和易于调试,符号式编程除了在编译的时候可以做更多优化使得性能提升之外,还可以很方便的进行移植,也就是说可以将程序编译成与Python无关的格式,这样就不需要Python解释器了,提高性能。

    混合式编程的性能比较

    MXNet框架中,设计Gluon时将两种的优势结合在了一起,也就是说在开发和调试时使用命令式编程,当需要产品级别的计算性能和部署时,用户可以将大部分命令式编程转换成符号式的程序来执行。

    1. from mxnet import nd,sym
    2. from mxnet.gluon import nn
    3. import time
    4. def get_net():
    5. #HybridSequential是HybridBlock类的子类
    6. #跟Sequential一样,创建HybridSequential实例
    7. net=nn.HybridSequential()
    8. net.add(nn.Dense(256,activation='relu'),
    9. nn.Dense(128,activation='relu'),
    10. nn.Dense(2))
    11. net.initialize()
    12. return net
    13. x=nd.random.normal(shape=(1,512))
    14. net=get_net()
    15. net(x)
    16. #调用hybridize函数来编译和优化HybridSequential实例中串联的层
    17. #里面的层需要继承HybridBlock类才能被优化计算,比如Dense类属于HybridBlock的子类,如果是继承Block类将不被优化
    18. net.hybridize()
    19. net(x)
    20. '''
    21. [[ 0.09882921 -0.02765738]]
    22. '''

    比较两者性能:

    1. def benchmark(net,x):
    2. start=time.time()
    3. for i in range(5000):
    4. _=net(x)
    5. nd.waitall()#等待所有计算完成,方便计时
    6. return time.time()-start
    7. net=get_net()
    8. print('没有调用hybridize前需要的时间:%.4f 秒'%benchmark(net,x))
    9. net.hybridize()
    10. print('调用hybridize后需要的时间:%.4f 秒'%benchmark(net,x))
    11. '''
    12. 没有调用hybridize前需要的时间:1.1750 秒
    13. 调用hybridize后需要的时间:0.7565 秒
    14. '''

    可以看出调用了hybridize之后,性能提升比较明显。

    保存与调用模型参数

    我们可以使用export函数将符号式编程的架构与模型参数保存到硬盘:

    1. def benchmark(net,x):
    2. start=time.time()
    3. for i in range(10):
    4. _=net(x)
    5. net.export('tony_mlp')#或者放在循环外也可以
    6. nd.waitall()#等待所有计算完成,方便计时
    7. return time.time()-start
    8. net=get_net()
    9. net.hybridize()
    10. benchmark(net,x)

    将会生成两个类型的文件,如:tony_mlp-0000.params,tony_mlp-symbol.json,分别是模型的参数与符号式程序(网络的结构),这样的话很方便使用其他前端语言或在其他设备上部署训练好的模型,同时,由于部署时使用的是符号式程序,计算性能往往比命令式程序的性能更好。

    tony_mlp-symbol.json是字典类型:

    1. import json
    2. with open('tony_mlp-symbol.json','rb') as f:
    3. c=json.load(f)
    4. print(c)

    tony_mlp-0000.params:加载模型参数时出错:

    net.collect_params().load('tony_mlp-0000.params')
    AssertionError: Parameter 'dense48_weight' is missing in file 'tony_mlp-0000.params', which contains parameters: 'dense42_weight', 'dense42_bias', 'dense43_weight', 'dense43_bias', 'dense44_weight', 'dense44_bias'. Please make sure source and target networks have the same prefix.

    就是说需要确保原网络和目标网络的前缀要一样,但每次执行都会不一样,这个如何处理?

    net.load_params('tony_mlp-0000.params')
    AssertionError: restore_prefix is 'hybridsequential20_' but Parameters name 'dense60_weight' does not start with 'hybridsequential20_'

    这两种方式都是属于前缀不符合,后来一想这样单独加载是错误的,因为必须要和json文件搭配使用,首先定义程序的框架,然后往里面灌输已训练好的参数,这样就没有问题了。 

    1. import mxnet as mx
    2. from collections import namedtuple
    3. x=nd.random.normal(shape=(3,512))
    4. symnet=mx.symbol.load('tony_mlp-symbol.json')
    5. mod=mx.mod.Module(symbol=symnet,context=mx.cpu())
    6. mod.bind(data_shapes=[('data',(3,512))])
    7. mod.load_params('tony_mlp-0000.params')
    8. print(mod.data_names)
    9. print(mod.data_shapes)
    10. print(mod.output_names)
    11. print(mod.output_shapes)
    12. Batch=namedtuple('Batch',['data'])
    13. mod.forward(Batch([x]))
    14. out=mod.get_outputs()
    15. print('\n')
    16. print(out)
    17. '''
    18. ['data']
    19. [DataDesc[data,(3, 512),,NCHW]]
    20. ['dense44_fwd_output']
    21. [('dense44_fwd_output', (3, 2))]
    22. [
    23. [[0.03862947 0.07485762]
    24. [0.11289009 0.18212686]
    25. [0.10921334 0.17731467]]
    26. ]
    27. '''

    这就相当于加载json文件的整个网络架构,直接代替了前面代码定义的网络,再创建模型实例,对输入数据设定形状,然后加载训练好的网络参数,最后对输入数据进行计算即可。

    继承HybridBlock

    跟继承Block一样的继承HybridBlock类,区别就是forward修改为hybrid_forward函数

    1. class HybridNet(nn.HybridBlock):
    2. def __init__(self,**kwargs):
    3. super(HybridNet,self).__init__(**kwargs)
    4. self.hidden=nn.Dense(10)
    5. self.output=nn.Dense(2)
    6. def hybrid_forward(self,F,x):
    7. print('F:',F)
    8. print('x:',x)
    9. x=F.relu(self.hidden(x))
    10. print('hidden:',x)
    11. return self.output(x)
    12. net=HybridNet()
    13. net.initialize()
    14. x=nd.random.normal(shape=(1,4))
    15. net(x)
    16. '''
    17. F:
    18. x:
    19. [[ 1.7974477 0.19594945 -1.7376398 0.04734707]]
    20. hidden:
    21. [[0. 0.14281581 0.14206699 0.11347395 0. 0.
    22. 0. 0. 0.1772946 0. ]]
    23. '''

    可以看出F使用的是命令式编程ndarray类

    net.hybridize()
    net(x)

    调用hybridize()函数看下是什么情况:

    1. '''
    2. F:
    3. x:
    4. hidden:
    5. [[ 0.00760206 -0.01790646]]
    6. '''

    可以看到F使用的是符号式编程的symbol类,虽然输入数据是NDArray,但是在hybrid_forward函数里都变成了symbol类。也可以看到在hybrid_forward函数里定义的打印语句没有打印任何数据,这是因为在调用hybrid_forward函数后运行net(x)的时候,符号式编程已经得到,之后运行的net(x)的时候MXNet将不再访问Python代码,而是直接在C++后端执行符号式编程,这样也是性能提升的一个原因。当然这样也不便于调试了。

  • 相关阅读:
    Ampere ARM Server 内核版本更新
    React 19 的新增功能:Action Hooks
    基于粒子群算法(PSO)的路径规划问题研究 (Matlab代码实现)
    MySQL 事物四种隔离级别分析
    XUbuntu22.04之解决桌面突然放大,屏幕跟着鼠标移动问题(一百九十)
    分类之混淆矩阵(Confusion Matrix)
    贪心算法之——背包问题(nyoj106)
    01-01HTML
    asp.net+sqlserver+c#学生作品展示及评分系统
    leetcode每日一题寒假版-1805.字符串中不同整数的数目(easy)
  • 原文地址:https://blog.csdn.net/weixin_41896770/article/details/127400770