• 【mindspore】【训练】训练过程内存占用大


    问题描述:

    我目前在做pytorch reconet模型在mindspore上复现的工作,现在遇到了显存溢出的问题,而且显存占用是torch中的三倍以上,pytorch只需要7.6G显存,而mindspore 24G都溢出了

    在pytorch中,在训练初始时加载一次vgg模型,在每个batch中vgg当做一个特征提取工具,也不需要参与模型梯度回传,训练步骤大体如下

    model = ReCoNet().cuda()
    vgg = Vgg16().cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    for epoch in range(n_epochs):
        for sample in traindata:
            optimizer.zero_grad()
            # Compute ReCoNet features and output
            reconet_input = preprocess_for_reconet(sample["frame"])
            feature_maps = model.encoder(reconet_input)
            output_frame = model.decoder(feature_maps)

            previous_reconet_input = preprocess_for_reconet(sample["previous_frame"])
            previous_feature_maps = model.encoder(previous_reconet_input)
            previous_output_frame = model.decoder(previous_feature_maps)

            # Compute VGG features
            vgg_input_frame = preprocess_for_vgg(sample["frame"])
            vgg_output_frame = preprocess_for_vgg(postprocess_reconet(output_frame))
            input_vgg_features = vgg(vgg_input_frame)
            output_vgg_features = vgg(vgg_output_frame)

            vgg_previous_input_frame = preprocess_for_vgg(sample["previous_frame"])
            vgg_previous_output_frame = preprocess_for_vgg(postprocess_reconet(previous_output_frame))
            previous_input_vgg_features = vgg(vgg_previous_input_frame)
            previous_output_vgg_features = vgg(vgg_previous_output_frame)

            loss = loss_func(...)
            loss.backward()
            optimizer.step()

    而在mindspore中,由于模型的loss函数比较复杂,无法通过传入一个loss_fn的方式,因此参考了教程中的自定义loss的方式,定义了一个包含loss的模型reconet_with_loss,并在construct中返回loss,loss计算过程与上面的pytorch过程一致,另外为了能在模型中使用vgg模型我把vgg作为一个初始化参数送入模型中,通过TrainOneStepCell来完成训练,代码如下

    model = RecoNet_with_Loss(reconet, vgg)
    optim = nn.Adam(reconet.trainable_params(), learning_rate=0.1, weight_decay=0.0)
    train_net = nn.TrainOneStepCell(model, optim)

    通过parameters_dict查看发现train_net参数量很多,是pytorch的几倍,而且包含了vgg的权重以及还包含大量moment的权重,不清楚这些是否占用了过多内存
    pytorch:


    mindspore:

    解答:

    首先,需要确认是哪一部分的内存占用格外的高,一般网络图不怎么占内存,占内存的操作主要集中在数据处理。

    1. 不带网络,单跑下数据处理看内存使用情况

    1. for data in dataset:
    2.     print("="*20)
    3.     for item in data:
    4.         print(item.shape)

    2. 如果是数据处理的问题,建议减小并行度或者看是否有操作导致内存泄露

    3. 如果不是数据问题,打桩一些怀疑的模块,看内存占用是否变小。

  • 相关阅读:
    TeamViewer 可信设备的信任管理
    Linux学习-内存管理
    【排序算法】计数排序(C语言)
    【InternLM实战营---第六节课笔记】
    【入门Flink】- 02Flink经典案例-WordCount
    掌动智能浅析Web自动化测试的重要性
    python多线程技术(Threading)
    MMU如何通过虚拟地址找到物理地址-下
    ref属性
    一种对数据库友好的GUID的变种使用方法
  • 原文地址:https://blog.csdn.net/weixin_45666880/article/details/125545075