• MixtralForCausalLM DeepSpeed Inference: Saving HOST Memory [Latest Approach]


    This article demonstrates how to save HOST memory when running MixtralForCausalLM inference with DeepSpeed.
    Method: save the weights separately on each rank, and construct the model with accelerate's init_empty_weights (a minimal illustration follows the list below).
    Added capabilities:

    • Chunked saving and loading of safetensors files
    • A workaround for initializing buffers registered with register_buffer(..., persistent=False)
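
    As a minimal illustration (not part of the original scripts; the nn.Linear here is arbitrary), init_empty_weights creates parameters on PyTorch's meta device, which records shape and dtype but allocates no storage:

    import torch
    from torch import nn
    from accelerate import init_empty_weights

    with init_empty_weights():
        layer = nn.Linear(4096, 4096, bias=False)  # allocates no host RAM

    print(layer.weight.device)  # meta
    # bytes this weight would occupy if materialized: 4096*4096*4 = 64 MiB
    print(layer.weight.nelement() * layer.weight.element_size())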

    I. Results

    Run mode                   | HOST memory usage | Notes
    Single-GPU inference       | 13198 MB          |
    DS 4TP                     | 13246 MB/GPU      |
    DS 4TP, memory-optimized   | 369 MB/GPU        | Loads weights directly onto the device, saving HOST memory

    II. Special Notes

    • 1.In MixtralRotaryEmbedding, the caches are registered with self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False).
      Because persistent is False, the buffer is not saved into the state_dict, and module.to_empty(device) does not preserve its value either.
      The only option is to save it right after model initialization and, once engine.module has finished loading its weights, copy the buffer back in (see the sketch below).
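
    A self-contained sketch of both facts, with a toy module standing in for MixtralRotaryEmbedding (the Demo class and its cache buffer are made up for illustration):

    import torch
    from torch import nn

    class Demo(nn.Module):
        def __init__(self):
            super().__init__()
            # non-persistent buffer: excluded from state_dict, like sin_cached
            self.register_buffer("cache", torch.arange(4.0), persistent=False)

    demo = Demo()
    assert "cache" not in demo.state_dict()                   # not saved
    saved = {n: b.clone() for n, b in demo.named_buffers()}   # capture after init

    demo.cache.zero_()  # pretend re-materialization wiped the buffer

    for n, b in demo.named_buffers():                         # restore the values
        b.copy_(saved[n])
    assert torch.equal(demo.cache, torch.arange(4.0))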

    III. Test Steps

    1. Create a simplified Mixtral-8x7B config file

    mkdir skip_init_demo
    cd skip_init_demo
    tee ./config.json <<-'EOF'
    {
      "architectures": [
        "MixtralForCausalLM"
      ],
      "attention_dropout": 0.0,
      "bos_token_id": 1,
      "eos_token_id": 2,
      "hidden_act": "silu",
      "hidden_size": 1024,
      "initializer_range": 0.02,
      "intermediate_size": 4096,
      "max_position_embeddings": 1024,
      "model_type": "mixtral",
      "num_attention_heads": 32,
      "num_experts_per_tok": 2,
      "num_hidden_layers": 32,
      "num_key_value_heads": 8,
      "num_local_experts": 8,
      "output_router_logits": false,
      "rms_norm_eps": 1e-05,
      "rope_theta": 1000000.0,
      "router_aux_loss_coef": 0.02,
      "sliding_window": 128,
      "tie_word_embeddings": false,
      "torch_dtype": "bfloat16",
      "transformers_version": "4.36.0.dev0",
      "use_cache": true,
      "vocab_size": 32000
    }
    EOF
    

    2. Generate a random model, run CPU float32 inference, and print the result

    rm -rf Mixtral-8x7B
    tee gen_model.py <<-'EOF'
    import torch
    def main():
        torch.manual_seed(1)
        from transformers import MixtralForCausalLM, MixtralConfig
        config=MixtralConfig.from_pretrained("./config.json")
        model = MixtralForCausalLM(config).half()    
        model.eval()
        model.save_pretrained("./Mixtral-8x7B",safe_serialization=True)
        torch.manual_seed(2)
        input_tokens=torch.randint(0,32000,(1,128))
        model=model.float()
        output=model(input_tokens)
        output=output.logits.detach().reshape(-1).cpu().numpy()[:8]
        print(output)
    
    if __name__ == "__main__":
        main()
    EOF
    python gen_model.py
    du Mixtral-8x7B -lh
    

    Output

    6.3G    Mixtral-8x7B
    
    [-0.9623295  -0.36580455  0.767425    1.7021806  -0.17950581  0.36059803
     -0.49157432 -0.58618194]
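
    The checkpoint size is consistent with the config: each layer's MoE block dominates with 8 experts × 3 projections × (1024 × 4096) ≈ 100M parameters, so 32 layers plus attention and embeddings come to roughly 3.3B parameters, about 6.7 GB (6.3 GiB) in fp16, matching the du output.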
    

    3. Load the model and run single-GPU CUDA inference

    tee open_model.py <<-'EOF'
    import torch
    import os
    import psutil
    import json
    from safetensors import safe_open
    
    def get_mem_info():
        pid = os.getpid()
        current_process = psutil.Process(pid)
        memory_info = current_process.memory_info()
        print(f"RSS: {memory_info.rss / (1024 * 1024):.2f}MB VMS:{memory_info.vms / (1024 * 1024):.2f}MB")
    
    def main():
        from transformers import MixtralForCausalLM, MixtralConfig
        get_mem_info()
        config=MixtralConfig.from_pretrained("./config.json")
        model = MixtralForCausalLM(config).half()
        get_mem_info()
    
        with open("Mixtral-8x7B/model.safetensors.index.json", "r") as file:
            index_data = json.load(file)
    
        weight_files = index_data.get('weight_map', {})
        state_dict = {}
        # weight_map maps tensor name -> shard file; deduplicate the shard
        # files so each one is opened and read only once
        for shard in set(weight_files.values()):
            weights_path = os.path.join("Mixtral-8x7B", shard)
            with safe_open(weights_path, framework="pt") as f:
                for k in f.keys():
                    state_dict[k] = f.get_tensor(k)
            
        model.load_state_dict(state_dict, strict=True)
        get_mem_info()
    
        model=model.to("cuda:0")
        torch.manual_seed(2)
        input_tokens=torch.randint(0,32000,(1,128)).to("cuda:0")
        output=model(input_tokens)
        output=output.logits.detach().reshape(-1).cpu().numpy()[:8]
        print(output)
        
    if __name__ == "__main__":
        main()
    EOF
    python open_model.py
    

    Output:

    RSS: 251.70MB VMS:3292.21MB
    RSS: 6697.91MB VMS:13695.17MB
    RSS: 13198.57MB VMS:26385.02MB
    
    [-0.9633789  -0.36450195  0.76708984  1.703125   -0.1772461   0.3581543
     -0.48901367 -0.5888672 ]
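
    The numbers reconcile: RSS grows by ~6.4 GB when the fp16 model is constructed, and by another ~6.5 GB when the full state_dict is read, so the process briefly holds two complete copies of the weights, ~13.2 GB in total. The logits also differ slightly from step 2's because the weights were rounded to fp16 and inference runs in fp16 on the GPU.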
    
    

    4. DS 4TP CUDA inference

    tee open_model.py <<-'EOF'
    import torch
    import os
    import psutil
    import json
    import deepspeed
    from deepspeed.accelerator import get_accelerator
    from safetensors import safe_open
    
    deepspeed.init_distributed(dist_backend='nccl')
    world_size = torch.distributed.get_world_size()
    local_rank=int(os.environ['LOCAL_RANK'])
    rank=torch.distributed.get_rank()
    
    def get_mem_info(prefix):
        pid = os.getpid()
        current_process = psutil.Process(pid)
        memory_info = current_process.memory_info()
        print(f"{prefix} RANK:{os.environ['LOCAL_RANK']} RSS: {memory_info.rss / (1024 * 1024):.2f}MB VMS:{memory_info.vms / (1024 * 1024):.2f}MB")
    
    def main():
        torch.set_num_threads(1)
        from transformers import MixtralForCausalLM, MixtralConfig
        get_mem_info("Init")
        config=MixtralConfig.from_pretrained("./config.json")
        model = MixtralForCausalLM(config).half()
        get_mem_info("ModelCreate")
        print("-----------------------")
    
        with open("Mixtral-8x7B/model.safetensors.index.json", "r") as file:
            index_data = json.load(file)
    
        weight_files = index_data.get('weight_map', {})
        state_dict = {}
        # Deduplicate shard files so each is opened and read only once
        for shard in set(weight_files.values()):
            weights_path = os.path.join("Mixtral-8x7B", shard)
            with safe_open(weights_path, framework="pt") as f:
                for k in f.keys():
                    state_dict[k] = f.get_tensor(k)
    
        model.load_state_dict(state_dict, strict=True)
        get_mem_info("LoadState")
        print("-----------------------")
        engine = deepspeed.init_inference(model,
                                          tensor_parallel={"tp_size": world_size},
                                          dtype=torch.float16,
                                          replace_with_kernel_inject=False)
        device=get_accelerator().current_device_name()
        print("device:",device)
        torch.manual_seed(2)
        input_tokens=torch.randint(0,32000,(1,128)).to(device)
        output=engine(input_tokens)
        output=output.logits.detach().reshape(-1).cpu().numpy()[:8]
        if rank==0:
            print(output)
        
    if __name__ == "__main__":
        main()
    EOF
    deepspeed --num_gpus=4 open_model.py
    

    Output:

    
    Init RANK:1 RSS: 270.02MB VMS:3414.44MB
    Init RANK:3 RSS: 270.43MB VMS:3414.45MB
    Init RANK:2 RSS: 270.22MB VMS:3414.45MB
    Init RANK:0 RSS: 270.38MB VMS:3486.45MB
    
    ModelCreate RANK:0 RSS: 6757.33MB VMS:9965.12MB
    ModelCreate RANK:3 RSS: 6727.30MB VMS:9862.06MB
    ModelCreate RANK:2 RSS: 6757.18MB VMS:9893.12MB
    ModelCreate RANK:1 RSS: 6756.99MB VMS:9893.12MB
    
    LoadState RANK:2 RSS: 13248.96MB VMS:22772.97MB
    LoadState RANK:0 RSS: 13245.91MB VMS:22616.97MB
    LoadState RANK:3 RSS: 13233.00MB VMS:22490.91MB
    LoadState RANK:1 RSS: 13246.22MB VMS:23240.97MB
    
    [-0.96240234 -0.36547852  0.7680664   1.703125   -0.17382812  0.359375
     -0.49169922 -0.5883789 ]
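
    Note that per-rank HOST usage (~13.2 GB) is no better than the single-GPU case: every rank still constructs the full fp16 model and reads the full state_dict on the host before DeepSpeed partitions the weights onto its GPU.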
    
    

    5. Save each rank's engine.module weights separately under DS 4TP

    tee open_model.py <<-'EOF'
    import torch
    import os
    import psutil
    import json
    import deepspeed
    from deepspeed.accelerator import get_accelerator
    from safetensors import safe_open
    from safetensors.torch import save_file
    
    deepspeed.init_distributed(dist_backend='nccl')
    world_size = torch.distributed.get_world_size()
    local_rank=int(os.environ['LOCAL_RANK'])
    rank=torch.distributed.get_rank()
    
    def get_mem_info(prefix):
        pid = os.getpid()
        current_process = psutil.Process(pid)
        memory_info = current_process.memory_info()
        print(f"{prefix} RANK:{os.environ['LOCAL_RANK']} RSS: {memory_info.rss / (1024 * 1024):.2f}MB VMS:{memory_info.vms / (1024 * 1024):.2f}MB")
    
    def save_state_dict(state_dict,save_dir):
        max_bytes_per_file = 1 * 1024 * 1024 * 1024  # 1GB
        # Measure each tensor and split the state_dict into ~1 GB chunks
        split_state_dicts = []
        current_state_dict = {}
        current_size = 0
        for param_name, param_tensor in state_dict.items():
            tensor_size = param_tensor.element_size() * param_tensor.nelement()
            # If this tensor would push the chunk over the size limit, flush the accumulated tensors first
            if current_size + tensor_size > max_bytes_per_file:
                split_state_dicts.append(current_state_dict)
                current_state_dict = {}
                current_size = 0
            current_state_dict[param_name] = param_tensor
            current_size += tensor_size
    
        # Append the trailing chunk
        if current_state_dict:
            split_state_dicts.append(current_state_dict)
    
        # Save the split state_dicts and build an index file
        os.makedirs(save_dir, exist_ok=True)
        index = {
            "metadata": {
                "total_parts": len(split_state_dicts)
            },
            "weight_map": []
        }
        for i, sd in enumerate(split_state_dicts):
            part_file = os.path.join(save_dir, f"model_part_{i}.safetensors")
            save_file(sd, part_file)
            index["weight_map"].append(f"model_part_{i}.safetensors")
    
        # Write the index file
        index_file = os.path.join(save_dir, "index.json")
        with open(index_file, 'w') as f:
            json.dump(index, f, indent=4)
    
    def main():
        from transformers import MixtralForCausalLM, MixtralConfig
        get_mem_info("Init")
        config=MixtralConfig.from_pretrained("./config.json")
        model = MixtralForCausalLM(config).half()
        get_mem_info("ModelCreate")
        print("-----------------------")
        with open("Mixtral-8x7B/model.safetensors.index.json", "r") as file:
            index_data = json.load(file)
    
        weight_files = index_data.get('weight_map', {})
        state_dict = {}
        # Deduplicate shard files so each is opened and read only once
        for shard in set(weight_files.values()):
            weights_path = os.path.join("Mixtral-8x7B", shard)
            with safe_open(weights_path, framework="pt") as f:
                for k in f.keys():
                    state_dict[k] = f.get_tensor(k)
                    
        model.load_state_dict(state_dict, strict=True)
        get_mem_info("LoadState")
        print("-----------------------")
        engine = deepspeed.init_inference(model,
                                          tensor_parallel={"tp_size": world_size},
                                          dtype=torch.float16,
                                          replace_with_kernel_inject=False)
        save_state_dict(engine.module.state_dict(), f"./Mixtral-8x7B-{local_rank}")
    if __name__ == "__main__":
        main()
    EOF
    deepspeed --num_gpus=4 open_model.py
    du Mixtral-8x7B-* -lh
    

    Output

    1.7G    Mixtral-8x7B-0
    1.7G    Mixtral-8x7B-1
    1.7G    Mixtral-8x7B-2
    1.7G    Mixtral-8x7B-3
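
    The shards total 4 × 1.7 GB, slightly more than the original 6.3 GiB checkpoint, presumably because tensors that tensor parallelism does not split (such as the norm weights) are replicated into every rank's shard.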
    

    6. DS 4TP inference: build the model with init_empty_weights; each rank loads its own engine.module weights

    tee open_model.py <<-'EOF'
    import torch
    import os
    import psutil
    import json
    import deepspeed
    from deepspeed.accelerator import get_accelerator
    from accelerate import init_empty_weights
    from safetensors import safe_open
    
    deepspeed.init_distributed(dist_backend='nccl')
    world_size = torch.distributed.get_world_size()
    local_rank=int(os.environ['LOCAL_RANK'])
    rank=torch.distributed.get_rank()
    
    def get_mem_info(prefix):
        pid = os.getpid()
        current_process = psutil.Process(pid)
        memory_info = current_process.memory_info()
        print(f"{prefix} RANK:{os.environ['LOCAL_RANK']} RSS: {memory_info.rss / (1024 * 1024):.2f}MB VMS:{memory_info.vms / (1024 * 1024):.2f}MB")
    
    def my_load_state_dict(model,save_dir):
        index_file = os.path.join(save_dir, "index.json")
        with open(index_file, "r") as file:
            index_data = json.load(file)
    
        weight_files = index_data.get('weight_map', [])
        state_dict = {}
        for v in weight_files:
            weights_path = os.path.join(save_dir, v)
            with safe_open(weights_path, framework="pt") as f:
                for k in f.keys():
                    state_dict[k] = f.get_tensor(k)
    
        model.load_state_dict(state_dict, strict=True)
    
    def main():
        from transformers import MixtralForCausalLM, MixtralConfig
        get_mem_info("Init")
        config=MixtralConfig.from_pretrained("./config.json")
        with init_empty_weights():
            model = MixtralForCausalLM(config).half()
        get_mem_info("ModelCreate")
        print("-----------------------")
        # Capture buffers now: non-persistent ones (sin_cached/cos_cached) are
        # absent from the saved state_dict, so they must be carried over by hand
        buffer_dict = {}
        for name, param in model.named_buffers():
            buffer_dict[name] = param
        
        engine = deepspeed.init_inference(model,
                                          tensor_parallel={"tp_size": world_size},
                                          dtype=torch.float16,
                                          replace_with_kernel_inject=False)
        my_load_state_dict(engine.module,f"./Mixtral-8x7B-{local_rank}")
    
        # Restore the captured buffer values into the loaded engine.module
        for name, param in engine.module.named_buffers():
            param.copy_(buffer_dict[name])
        
        get_mem_info("LoadState")
        device=get_accelerator().current_device_name()
        torch.manual_seed(2)
        input_tokens=torch.randint(0,32000,(1,128)).to(device)
        output=engine(input_tokens)
        output=output.logits.detach().reshape(-1).cpu().numpy()[:8]
        if rank==0:
            print(output)
    if __name__ == "__main__":
        main()
    EOF
    deepspeed --num_gpus=4 open_model.py
    

    Output

    
    Init RANK:1 RSS: 269.73MB VMS:3382.40MB
    Init RANK:2 RSS: 269.45MB VMS:3382.39MB
    Init RANK:3 RSS: 269.86MB VMS:3382.39MB
    Init RANK:0 RSS: 269.96MB VMS:3454.39MB
    
    ModelCreate RANK:1 RSS: 300.44MB VMS:17064.71MB
    ModelCreate RANK:0 RSS: 297.03MB VMS:17136.70MB
    ModelCreate RANK:2 RSS: 299.22MB VMS:17064.70MB
    ModelCreate RANK:3 RSS: 300.66MB VMS:17065.70MB
    
    LoadState RANK:0 RSS: 366.28MB VMS:20159.03MB
    LoadState RANK:3 RSS: 369.87MB VMS:20152.03MB
    LoadState RANK:2 RSS: 368.37MB VMS:20151.02MB
    LoadState RANK:1 RSS: 369.16MB VMS:20087.04MB
    
    [-0.96240234 -0.36547852  0.7680664   1.703125   -0.17382812  0.359375
     -0.49169922 -0.5883789 ]
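
    The logits are identical to step 4's, while per-rank host RSS drops from ~13.2 GB to ~370 MB, matching the table in section I: nothing is materialized on the host at construction time, and each rank loads only its own ~1.7 GB shard.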
    
    
  • Original article: https://blog.csdn.net/m0_61864577/article/details/139744977