• CIFAR-10 数据转为图片-python


    """
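    CIFAR-10 consists of 32x32 colour images in 10 classes, 6000 images per class: 50000 training images (split evenly into 5 batches) and 10000 test images (1000 drawn from each class).
    This script converts CIFAR-10 to PNG files.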
    """
    
    import os
    import pickle
    
    import numpy as np
    from imageio import imwrite
    
    # Root directory under which the data is stored
    base_dir = r'H:\DataStore'
    # Location of the CIFAR-10 batch files
    data_dir = os.path.join(base_dir, 'cifar-10-batches-py')
    # Output directory for the training images
    train_dir = os.path.join(base_dir, 'cifar-10-train-png')
    # Output directory for the test images
    test_dir = os.path.join(base_dir, 'cifar-10-test-png')

    # Switches for the export: training images are not generated here,
    # only the test images are.
    Train = False
    Test = True
    
    
    def unpickle(file_path):
        """Deserialize the pickled object stored at ``file_path``.

        The file is opened in binary mode and loaded with ``encoding='bytes'``,
        so 8-bit string instances pickled by Python 2 are read as ``bytes``
        objects (dictionary keys then have to be addressed as ``b'...'``).

        NOTE: unpickling can execute arbitrary code; only use this on files
        from a trusted source.
        """
        with open(file_path, 'rb') as stream:
            return pickle.load(stream, encoding='bytes')
    
    
    def create_dir(dir_path):
        """Create ``dir_path`` (including missing parents) if it does not exist.

        Calling this for a directory that already exists is a no-op. If the
        path exists but is not a directory, ``FileExistsError`` is raised.
        """
        # exist_ok=True replaces the separate "check, then create" steps, so a
        # directory appearing between the check and the creation cannot make
        # this fail.
        os.makedirs(dir_path, exist_ok=True)
    
    
    def get_label_names():
        """Return the ``b'label_names'`` entry of the ``batches.meta`` file.

        The meta file is looked up inside ``data_dir`` and loaded with
        ``unpickle``.
        """
        meta_path = os.path.join(data_dir, 'batches.meta')
        meta = unpickle(meta_path)
        return meta[b'label_names']
    
    
    def save_images(i, obj, class_num, label_names, dir_path):
        """Write the ``i``-th image of a batch to ``dir_path/<label>/<n>.png``.

        Args:
            i: Index of the image inside the batch.
            obj: Batch object; its ``b'data'`` and ``b'labels'`` entries are
                read (keys are bytes because the batch was loaded as bytes).
            class_num: Per-class image counters indexed by label index. The
                counter of the image's class is incremented in place and the
                new value becomes the file name.
            label_names: Sequence of class names as ``bytes``, indexed by
                label index.
            dir_path: Directory under which one sub-directory per class name
                is created.
        """
        # Row i is reshaped to (channel, height, width); channels are red,
        # green, blue.
        pixels = np.reshape(obj[b'data'][i], (3, 32, 32))
        # Image writers expect (height, width, channel).
        pixels = np.transpose(pixels, (1, 2, 0))
        # Class index of this image (0-9) and its readable name.
        label_idx = obj[b'labels'][i]
        class_name = label_names[label_idx].decode()
        class_dir = os.path.join(dir_path, class_name)
        create_dir(class_dir)
        # Count this image for its class; the running count names the file.
        class_num[label_idx] += 1
        image_path = os.path.join(class_dir, f"{class_num[label_idx]}.png")
        imwrite(image_path, pixels)
    
    
    if __name__ == '__main__':
        # The class names are shared by the training and the test export.
        _label_names = get_label_names()
        if Train:
            # Running image count for each of the 10 classes.
            train_class_num = [0] * 10
            for batch_no in range(1, 6):
                data_batch_path = os.path.join(data_dir, 'data_batch_' + str(batch_no))
                # The batch object is read through its b'data' and b'labels' keys.
                train_batch_obj = unpickle(data_batch_path)
                print("{} is loading...".format(data_batch_path))
                # Each batch holds 10000 images.
                for img_no in range(10000):
                    save_images(img_no, train_batch_obj, train_class_num, _label_names, train_dir)
            print('train loaded')
        if Test:
            # Separate counters so test file names start at 1 for every class.
            test_class_num = [0] * 10
            test_obj = unpickle(os.path.join(data_dir, 'test_batch'))
            for img_no in range(10000):
                save_images(img_no, test_obj, test_class_num, _label_names, test_dir)
            print('test loaded')
    
    
  • 相关阅读:
    io流笔记
    CSS3 中 transition 和 animation 的属性分别有哪些
    Springboot集成redis--不同环境切换
    聊一聊DTM子事务屏障功能之SQL Server版
    Android异步和线程
    阿曼市场最全开发攻略,看这一篇就够了
    css表单单选框、复选框、上传文字和隐藏字段、下拉菜单、文本域、字段集
    软件测试外包干了4年,感觉废了...
    KubeVela可持续测试应用部署之Mock基础设施
    深度学习——(生成模型)DDPM
  • 原文地址:https://blog.csdn.net/cnkeysky/article/details/139391578