• pytest学习-pytorch单元测试


    希望测试pytorch各种算子、block、网络等在不同硬件平台,不同软件版本下的计算误差、耗时、内存占用等指标.

    本文基于torch.testing._internal

    一.公共模块[common.py]

    import torch
    from torch import nn
    import math
    import torch.nn.functional as F
    import time
    import os
    import socket
    import sys
    from datetime import datetime
    import numpy as np
    import collections
    import math
    import json
    import copy
    import traceback
    import subprocess
    import unittest
    import torch
    import inspect
    from torch.testing._internal.common_utils import TestCase, run_tests,parametrize,instantiate_parametrized_tests
    from torch.testing._internal.common_distributed import MultiProcessTestCase
    import torch.distributed as dist
    
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29500"
    os.environ["RANDOM_SEED"] = "0" 
    
    device="cpu"
    device_type="cpu"
    device_name="cpu"
    
    try:
        if torch.cuda.is_available():     
            device_name=torch.cuda.get_device_name().replace(" ","")
            device="cuda:0"
            device_type="cuda"
            ccl_backend='nccl'
    except:
        pass
    
    host_name=socket.gethostname()    
    sdk_version=os.getenv("SDK_VERSION","")   						 #从环境变量中获取sdk版本号
    metric_data_root=os.getenv("TORCH_UT_METRICS_DATA","./ut_data")  #日志存放的目录
    device_count=torch.cuda.device_count()
    
    if not os.path.exists(metric_data_root):
        os.makedirs(metric_data_root)
    
    def device_warmup(device):
        '''设备warmup,确保设备已经正常工作,排除设备初始化的耗时'''
        left = torch.rand([128,512], dtype = torch.float16).to(device)
        right = torch.rand([512,128], dtype = torch.float16).to(device)
        out=torch.matmul(left,right)
        torch.cuda.synchronize()
    
    torch.manual_seed(1) 
    np.random.seed(1)
    
    def loop_decorator(loops,rank=0):
        '''循环装饰器,用于统计函数的执行时间,内存占用等'''
        def decorator(func):
            def wrapper(*args,**kwargs):
                latency=[]
                memory_allocated_t0=torch.cuda.memory_allocated(rank)
                for _ in range(loops):
                    input_copy=[x.clone() for x in args]
                    beg= datetime.now().timestamp() * 1e6
                    pred= func(*input_copy)
                    gt=kwargs["golden"]
                    torch.cuda.synchronize()
                    end=datetime.now().timestamp() * 1e6
                    mse = torch.mean(torch.pow(pred.cpu().float()- gt.cpu().float(), 2)).item()
                    latency.append(end-beg)
                memory_allocated_t1=torch.cuda.memory_allocated(rank)
                avg_latency=np.mean(latency[len(latency)//2:]).round(3)
                first_latency=latency[0]
                return { "first_latency":first_latency,"avg_latency":avg_latency,
                          "memory_allocated":memory_allocated_t1-memory_allocated_t0,
                          "mse":mse}
            return wrapper
        return decorator
    
    class TorchUtMetrics:
        '''用于统计测试结果,比较之前的最小值'''
        def __init__(self,ut_name,thresold=0.2,rank=0):
            self.ut_name=f"{ut_name}_{rank}"
            self.thresold=thresold
            self.rank=rank
            self.data={"ut_name":self.ut_name,"metrics":[]}
            self.metrics_path=os.path.join(metric_data_root,f"{self.ut_name}_{self.rank}.jon")
            try:
                with open(self.metrics_path,"r") as f:
                    self.data=json.loads(f.read())
            except:
                pass
    
        def __enter__(self):
            self.beg= datetime.now().timestamp() * 1e6
            return self
    
        def __exit__(self, exc_type, exc_val, exc_tb):        
            self.report()
            self.save_data()
    
        def save_data(self):
            with open(self.metrics_path,"w") as f:
                f.write(json.dumps(self.data,indent=4))
    
        def set_metrics(self,metrics):
            self.end=datetime.now().timestamp() * 1e6
            item=collections.OrderedDict()
            item["time"]=datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')
            item["sdk_version"]=sdk_version
            item["device_name"]=device_name
            item["host_name"]=host_name
            item["metrics"]=metrics
            item["metrics"]["e2e_time"]=self.end-self.beg
            self.cur_item=item
            self.data["metrics"].append(self.cur_item)
    
        def get_metric_names(self):
            return self.data["metrics"][0]["metrics"].keys()
    
        def get_min_metric(self,metric_name,devicename=None):
            min_value=0
            min_value_index=-1
            for idx,item in enumerate(self.data["metrics"]):
                if devicename and (devicename!=item['device_name']):                
                    continue            
                val=float(item["metrics"][metric_name])
                if min_value_index==-1 or val<min_value:
                    min_value=val
                    min_value_index=idx
            return min_value,min_value_index
    
        def get_metric_info(self,index):
            metrics=self.data["metrics"][index]
            return f'{metrics["device_name"]}@{metrics["sdk_version"]}'
    
        def report(self):
            assert len(self.data["metrics"])>0
            for metric_name in self.get_metric_names():
                min_value,min_value_index=self.get_min_metric(metric_name)
                min_value_same_dev,min_value_index_same_dev=self.get_min_metric(metric_name,device_name)
                cur_value=float(self.cur_item["metrics"][metric_name])
                print(f"-------------------------------{metric_name}-------------------------------")
                print(f"{cur_value}#{device_name}@{sdk_version}")
                if min_value_index_same_dev>=0:
                    print(f"{min_value_same_dev}#{self.get_metric_info(min_value_index_same_dev)}")
                if min_value_index>=0:
                    print(f"{min_value}#{self.get_metric_info(min_value_index)}")
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
    • 75
    • 76
    • 77
    • 78
    • 79
    • 80
    • 81
    • 82
    • 83
    • 84
    • 85
    • 86
    • 87
    • 88
    • 89
    • 90
    • 91
    • 92
    • 93
    • 94
    • 95
    • 96
    • 97
    • 98
    • 99
    • 100
    • 101
    • 102
    • 103
    • 104
    • 105
    • 106
    • 107
    • 108
    • 109
    • 110
    • 111
    • 112
    • 113
    • 114
    • 115
    • 116
    • 117
    • 118
    • 119
    • 120
    • 121
    • 122
    • 123
    • 124
    • 125
    • 126
    • 127
    • 128
    • 129
    • 130
    • 131
    • 132
    • 133
    • 134
    • 135
    • 136
    • 137
    • 138
    • 139
    • 140
    • 141
    • 142
    • 143
    • 144
    • 145
    • 146
    • 147
    • 148
    • 149
    • 150
    • 151

    二.普通算子测试[test_clone.py]

    from common import *
    class TestCaseClone(TestCase):
        #如果不满足条件,则跳过这个测试
        @unittest.skipIf(device_count>1, "Not enough devices") 
        def test_todo(self):
            print(".TODO")
    
        #框架会自动遍历以下参数组合
        @parametrize("shape", [(10240,20480),(128,256)])
        @parametrize("dtype", [torch.float16,torch.float32])
        def test_clone(self,shape,dtype):
            
            #让这个函数循环执行loops次,统计第一次执行的耗时、后半段的平均时间、整个执行过程总的GPU内存使用量
            @loop_decorator(loops=5)
            def run(input_dev):
                output=input_dev.clone()
                return output
            
            #记录整个测试的总耗时,保存统计量,输出摘要(self._testMethodName:测试方法,result:函数返回值,metrics:统计量)
            with TorchUtMetrics(ut_name=self._testMethodName,thresold=0.2) as m:
                input_host=torch.ones(shape,dtype=dtype)*np.random.rand()
                input_dev=input_host.to(device)
                metrics=run(input_dev,golden=input_host.cpu())
                m.set_metrics(metrics)
                assert(metrics["mse"]==0)
            
    instantiate_parametrized_tests(TestCaseClone)
    
    if __name__ == "__main__":
        run_tests()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30

    三.集合通信测试[test_ccl.py]

    from common import *
    class TestCCL(MultiProcessTestCase):
        '''CCL测试用例'''
        def _create_process_group_vccl(self, world_size, store):
            dist.init_process_group(
                ccl_backend, world_size=world_size, rank=self.rank, store=store
            )        
            pg = dist.distributed_c10d._get_default_group()
            return pg
    
        def setUp(self):
            super().setUp()
            self._spawn_processes()
    
        def tearDown(self):
            super().tearDown()
            try:
                os.remove(self.file_name)
            except OSError:
                pass
    
        @property
        def world_size(self):
            return 4
          
        #框架会自动遍历以下参数组合
        @unittest.skipIf(device_count<4, "Not enough devices") 
        @parametrize("op",[dist.ReduceOp.SUM])
        @parametrize("shape", [(1024,8192)])
        @parametrize("dtype", [torch.int64])
        def test_allreduce(self,op,shape,dtype):
            if self.rank >= self.world_size:
                return
            
            store = dist.FileStore(self.file_name, self.world_size)
            pg = self._create_process_group_vccl(self.world_size, store)
            if not torch.distributed.is_initialized():
                return
        
            torch.cuda.set_device(self.rank)
            device = torch.device(device_type,self.rank)
            device_warmup(device)
            #让这个函数循环执行loops次,统计第一次执行的耗时、后半段的平均时间、整个执行过程总的GPU内存使用量
            @loop_decorator(loops=5,rank=self.rank)
            def run(input_dev):
                dist.all_reduce(input_dev, op=op)
                return input_dev
            
            #记录整个测试的总耗时,保存统计量,输出摘要(self._testMethodName:测试方法,result:函数返回值,metrics:统计量)
            with TorchUtMetrics(ut_name=self._testMethodName,thresold=0.2,rank=self.rank) as m:
                input_host=torch.ones(shape,dtype=dtype)*(100+self.rank)
                gt=[torch.ones(shape,dtype=dtype)*(100+i) for i in range(self.world_size)]
                gt_=gt[0]
                for i in range(1,self.world_size):
                    gt_=gt_+gt[i]
                input_dev=input_host.to(device)
                metrics=run(input_dev,golden=gt_)
                m.set_metrics(metrics)
                assert(metrics["mse"]==0)
            dist.destroy_process_group(pg)
        
    instantiate_parametrized_tests(TestCCL)
    
    if __name__ == "__main__":
        run_tests()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65

    四.测试命令

    # 运行所有的测试
    pytest -v -s -p no:warnings --html=torch_report.html --self-contained-html --capture=sys ./
    
    # 运行某一个测试
    python3 test_clone.py -k "test_clone_shape_(128, 256)_float32"
    
    • 1
    • 2
    • 3
    • 4
    • 5

    五.测试报告

    在这里插入图片描述

  • 相关阅读:
    隐私计算头条周刊(9.4-9.10)
    高能直播,大咖云集!邀你共启BizDevOps探索之路。
    [react] 什么是虚拟dom?虚拟dom比操作原生dom要快吗?虚拟dom是如何转变成真实dom并渲染到页面的?
    【Vue五分钟】 五分钟了解Webpack底层原理以及脚手架工具分析
    一文初探 Go reflect 反射包
    2022-09-17青少年软件编程(C语言)等级考试试卷(五级)解析
    cmmlu数据处理
    搭建Hadoop集群 并实现hdfs上的crud操作
    【算法】深度搜索(DFS) 和 广度搜索(BFS)
    给定两个字符串S和T,返回S子序列等于T的不同子序列个数有多少个? 如果得到子序列A删除的位置与得到子序列B删除的位置不同,那么认为A和B就是不同的。
  • 原文地址:https://blog.csdn.net/m0_61864577/article/details/137888050