• TensorRT 模型加密杂谈


    在大多数项目交付场景中,经常需要对部署模型进行加密。模型加密一方面可以防止泄密,一方面可以便于模型跟踪管理,防止混淆。

    由于博主使用的部署模型多为TensorRT格式,这里以TensorRT模型为例,讲解如何对模型进行加密、解密以及推理加密模型。

    代码仓库:https://github.com/laugh12321/TRTCrypto

    加密算法的选择和支持的库

    Crypto++ 是C/C++的加密算法库,基本上涵盖了市面上的各类加密解密算法,包括对称加密算法(AES等)和非对称加密算法(RSA等)。

    两种算法使用的场景不同,非对称加密算法一般应用于数字签名和密钥协商的场景下,而对称加密算法一般应用于纯数据加密场景,性能更优。在对模型的加密过程中使用对称加密算法。

    简易版本

    以AES-GCM加密模式为例,编写一个检测的加密、解密方法

    std::string Encrypt(const std::string &data, const CryptoPP::SecByteBlock &key, const CryptoPP::SecByteBlock &iv) {
        std::string cipher;
    
        try {
            CryptoPP::GCM::Encryption e;
            e.SetKeyWithIV(key, key.size(), iv, iv.size());
    
            CryptoPP::StringSource(data, true, 
                new CryptoPP::AuthenticatedEncryptionFilter(e,
                    new CryptoPP::StringSink(cipher)
                ) // StreamTransformationFilter
            ); // StringSource
        }
        catch(const CryptoPP::Exception& e) {
            std::cerr << e.what() << std::endl;
            exit(1);
        }
        return cipher;
    }
    
    std::string Decrypt(const std::string &cipher, const CryptoPP::SecByteBlock &key, const CryptoPP::SecByteBlock &iv) {
        std::string recovered;
    
        try {
            CryptoPP::GCM::Decryption d;
            d.SetKeyWithIV(key, key.size(), iv, iv.size());
    
            // The StreamTransformationFilter removes
            //  padding as required.
            CryptoPP::StringSource(cipher, true, 
                new CryptoPP::AuthenticatedDecryptionFilter(d,
                    new CryptoPP::StringSink(recovered)
                ) // StreamTransformationFilter
            ); // StringSource
    
        }
        catch(const CryptoPP::Exception& e) {
            std::cerr << e.what() << std::endl;
            exit(1);
        }
        return recovered;
    }
    

    上述代码,使用AES-GCM加密模式对数据进行加密、解密,其中key和iv为加密算法的参数,keySize为key的长度。

    加密流程:

    1. 初始化加密器,设置key和iv
    2. 读取文件内容并存储在字符串data中
    3. 使用加密器对data进行加密,加密后的内容存储在字符串cipher中

    解密流程:

    1. 初始化解密器,设置key和iv
    2. 读取加密后的文件内容并存储在字符串cipher中
    3. 使用解密器对cipher进行解密,解密后的内容存储在字符串recovered中

    转换为序列化格式

    推理加密模型的方法有两种,一种是将模型解密后保存为文件再进行推理,另一种是将模型解密后转换为序列化格式,再进行推理。
    很明显第一种方式比较鸡肋,因为每次推理都需要进行解密,而且解密后的模型文件也会暴露在外面,不安全。这里使用第二种方式,将模型解密后转换为序列化格式进行推理。这里给出一个简单的例子,将存储解密数据的字符串recovered进行序列化。

    std::vector<unsigned char> Convert2TRTengine(const std::string& data) {
        unsigned char* engine_data[1];
        engine_data[0] = new unsigned char[data.length() + 1];
        std::copy(data.begin(), data.end(), engine_data[0]);
        engine_data[0][data.length()] = '\0';
    
        // Convert char* array to vector
        std::vector<unsigned char> engineData(engine_data[0], engine_data[0] + data.length());
        // Clean up the memory
        delete* engine_data;
        return engineData;
    }
    

    上述代码的返回值可以直接作为TensorRT的推理引擎的输入。例如进行反序列化操作

     nvinfer1::ICudaEngine* engine = runtime->deserializeCudaEngine(engineData.data(), engineData.size());
    

    使用MAC地址作为密钥

    一般情况下,我们只想让客户在指定的机器上运行模型,这时候就需要使用机器的唯一标识作为密钥,这里使用MAC地址作为密钥。

    #ifdef _WIN32
    #include 
    #include 
    #pragma comment(lib, "IPHLPAPI.lib")
    #else
    #include 
    #include 
    #include 
    #endif
    
    
    std::string GetMACAddress() {
    #ifdef _WIN32
        IP_ADAPTER_INFO adapterInfo[16];
        DWORD bufferSize = sizeof(adapterInfo);
        DWORD result = GetAdaptersInfo(adapterInfo, &bufferSize);
    
        if (result == ERROR_BUFFER_OVERFLOW) {
            // Resize buffer and try again
            IP_ADAPTER_INFO *newBuffer = new IP_ADAPTER_INFO[bufferSize / sizeof(IP_ADAPTER_INFO)];
            result = GetAdaptersInfo(newBuffer, &bufferSize);
            if (result != ERROR_SUCCESS) {
                delete[] newBuffer;
                return "";
            }
            delete[] newBuffer;
        }
        else if (result != ERROR_SUCCESS) {
            return "";
        }
    
        for (PIP_ADAPTER_INFO adapter = adapterInfo; adapter; adapter = adapter->Next) {
            if (adapter->Type == MIB_IF_TYPE_ETHERNET) {
                char macAddress[18];
                snprintf(macAddress, sizeof(macAddress), "%.2X:%.2X:%.2X:%.2X:%.2X:%.2X",
                    adapter->Address[0], adapter->Address[1], adapter->Address[2],
                    adapter->Address[3], adapter->Address[4], adapter->Address[5]);
                return macAddress;
            }
        }
    #else
        struct ifaddrs *ifaddr, *ifa;
        if (getifaddrs(&ifaddr) == -1) {
            return "";
        }
    
        for (ifa = ifaddr; ifa != nullptr; ifa = ifa->ifa_next) {
            if (ifa->ifa_addr == nullptr || ifa->ifa_addr->sa_family != AF_PACKET) {
                continue;
            }
            
            struct sockaddr_ll *s = reinterpret_cast<struct sockaddr_ll*>(ifa->ifa_addr);
            char macAddress[18];
            snprintf(macAddress, sizeof(macAddress), "%.2X:%.2X:%.2X:%.2X:%.2X:%.2X",
                s->sll_addr[0], s->sll_addr[1], s->sll_addr[2],
                s->sll_addr[3], s->sll_addr[4], s->sll_addr[5]);
            freeifaddrs(ifaddr);
            return macAddress;
        }
        freeifaddrs(ifaddr);
    #endif
    
        return "";
    }
    

    上述代码,使用了不同的API获取MAC地址,其中Windows使用GetAdaptersInfo函数,Linux使用getifaddrs函数。

    有了MAC地址,就可以将MAC地址转换为密钥,这里使用SHA256算法对MAC地址进行哈希,然后取前32个字节作为密钥。

    std::string GenerateAESKey(const std::string& macAddress) {
        CryptoPP::byte hash[CryptoPP::SHA256::DIGESTSIZE];
        CryptoPP::SHA256().CalculateDigest(hash, reinterpret_cast<const CryptoPP::byte*>(macAddress.c_str()), macAddress.length());
    
        CryptoPP::HexEncoder encoder;
        std::string encodedHash;
        encoder.Attach(new CryptoPP::StringSink(encodedHash));
        encoder.Put(hash, sizeof(hash));
        encoder.MessageEnd();
    
        return encodedHash.substr(0, 32); // AES-256 key length is 32 bytes
    }
    

    添加头部信息

    为了新的文件能够被区分和可迭代,除了加密后的数据外还添加了头部信息,比如为了判断该文件类型使用固定的魔数作为文件的开头;为了便于后面需求迭代写入版本号以示区别;为了能够在解密时判断是否采用了相同的密钥将加密时的密钥进行SHA256计算后存储;这三部分构成了目前加密后文件的头部信息。加密后的文件包含头部信息 + 密文信息。

    // Encrypt function with header information
    std::string EncryptWithHeader(const std::string &data, const CryptoPP::SecByteBlock &key, const CryptoPP::SecByteBlock &iv, const std::string &magicNumber, const std::string &version) {
        // Header format: [magic_number_len | magic_number | version_len | version | key_hash_length | key_hash | encrypted_data]
    
        std::string cipher;
    
        try {
            CryptoPP::GCM::Encryption e;
            e.SetKeyWithIV(key, key.size(), iv, iv.size());
    
            // Calculate SHA256 hash of the key
            std::string keyHash = CalculateSHA256(key);
    
            // Get the length
            const size_t magicNumberLength = magicNumber.length();
            const size_t versionLength = version.length();
            const size_t keyHashLength = keyHash.length();
    
            // Construct the header
            std::string header = Convert2String(magicNumberLength, MAGIC_NUMBER_LEN) + \
                magicNumber + Convert2String(versionLength, VERSION_NUMBER_LEN) + version + \
                Convert2String(keyHashLength, KEY_HASH_LENGTH_LEN) + keyHash;
    
            // Encrypt the data
            CryptoPP::StringSource s(data, true, 
                new CryptoPP::AuthenticatedEncryptionFilter(e,
                    new CryptoPP::StringSink(cipher)
                ) // StreamTransformationFilter
            ); // StringSource
    
            // Prepend the header to the cipher
            cipher = header + cipher;
        } catch(const CryptoPP::Exception& e) {
            std::cerr << e.what() << std::endl;
            exit(1);
        }
    
        return cipher;
    }
    
    // Decrypt function for header-aware encrypted data
    std::string DecryptWithHeader(const std::string &cipher, const CryptoPP::SecByteBlock &key, const CryptoPP::SecByteBlock &iv) {
        // Header format: [magic_number_len | magic_number | version_len | version | key_hash_length | key_hash | encrypted_data]
    
        std::string recovered;
    
        try {
            CryptoPP::GCM::Decryption d;
            d.SetKeyWithIV(key, key.size(), iv, iv.size());
    
            // Extract header information
            std::string magicNumberLengthStr = cipher.substr(0, MAGIC_NUMBER_LEN);
            size_t magicNumberLength = std::stoull(magicNumberLengthStr);
            std::string magicNumber = cipher.substr(MAGIC_NUMBER_LEN, magicNumberLength);
    
            std::string versionLengthStr = cipher.substr(MAGIC_NUMBER_LEN + magicNumberLength, VERSION_NUMBER_LEN);
            size_t versionLength = std::stoull(versionLengthStr);
            std::string version = cipher.substr(MAGIC_NUMBER_LEN + magicNumberLength + VERSION_NUMBER_LEN, versionLength);
    
            std::string keyHashLengthStr = cipher.substr(MAGIC_NUMBER_LEN + magicNumberLength + VERSION_NUMBER_LEN + versionLength, KEY_HASH_LENGTH_LEN);
            size_t keyHashLength = std::stoull(keyHashLengthStr);
            std::string keyHash = cipher.substr(MAGIC_NUMBER_LEN + magicNumberLength + VERSION_NUMBER_LEN + versionLength + KEY_HASH_LENGTH_LEN, keyHashLength);
    
            std::string encryptedData = cipher.substr(MAGIC_NUMBER_LEN + magicNumberLength + VERSION_NUMBER_LEN + versionLength + KEY_HASH_LENGTH_LEN + keyHashLength);
    
            // Verify the key using stored hash
            if (!VerifyKey(key, keyHash)) {
                std::cerr << "Key verification failed." << std::endl;
                exit(1);
            }
    
            // Decrypt the data
            // The StreamTransformationFilter removes
            //  padding as required.
            CryptoPP::StringSource(encryptedData, true, 
                new CryptoPP::AuthenticatedDecryptionFilter(d,
                    new CryptoPP::StringSink(recovered)
                ) // StreamTransformationFilter
            ); // StringSource
        } catch(const CryptoPP::Exception& e) {
            std::cerr << e.what() << std::endl;
            exit(1);
        }
    
        return recovered;
    }
    

    上述代码中,加密函数EncryptWithHeader中,首先计算密钥的SHA256哈希值,然后将魔数、版本号、密钥哈希值、密文依次拼接,然后使用AES-256-GCM算法加密,最后将头部信息和密文拼接返回。解密函数DecryptWithHeader中,首先从密文中提取出头部信息,然后使用密钥哈希值验证密钥是否正确,最后使用AES-256-GCM算法解密密文。

    参考资料

  • 相关阅读:
    HTTPS_SSL加密(HTTP终)
    Java进程假死排查记录
    What is UDS Service 0x10 - Diagnostic Session Control ?
    备战2023高企申报?先准备这些,通过率更高
    用友NC Cloud importhttpscer接口任意文件上传漏洞
    ThreadLocal源码解析以及常见面试题
    数字世界中的定位专家,手机号码归属地数据源下载!
    解决弹性布局父元素设置高自动换行,子元素均分高度问题(align-content: flex-start)
    【Windows】Windows Terminal教程
    JavaScript的ES6中的箭头函数的详细讲解
  • 原文地址:https://www.cnblogs.com/laugh12321/p/17617526.html