"""
CIFAR-10 是 32X32 的彩色图片,共有10个类别,每个类别6000张图片,50000张训练图片(均分为5个batch),10000张测试图片(每个类别选1000张)
将 CIFAR-10 转为 png
"""
import os
import pickle
import numpy as np
from imageio import imwrite
# Root folder that holds the extracted dataset and the PNG output folders.
base_dir = 'H:\\DataStore'
# Input: directory containing the pickled batch files.
data_dir = os.path.join(base_dir, 'cifar-10-batches-py')
# Output roots for the training and test splits (save_images adds per-label sub-folders).
train_dir, test_dir = (
    os.path.join(base_dir, name)
    for name in ('cifar-10-train-png', 'cifar-10-test-png')
)
# Switches selecting which split(s) the __main__ block converts.
Train, Test = False, True
def unpickle(file_path):
    """Load and return the object stored in the pickle file at *file_path*.

    ``encoding='bytes'`` makes 8-bit strings from Python 2 pickles load as
    ``bytes``; this is why callers index the result with byte keys such as
    ``b'data'`` and ``b'labels'``.
    """
    # NOTE(review): pickle can execute arbitrary code while loading; only use
    # this on trusted dataset files.
    with open(file_path, mode='rb') as handle:
        return pickle.load(handle, encoding='bytes')
def create_dir(dir_path):
    """Ensure the directory *dir_path* exists, creating missing parents as needed.

    ``exist_ok=True`` replaces a separate ``isdir`` pre-check. That pre-check
    was racy: another process could create the directory between the check and
    ``makedirs``, which then raised ``FileExistsError``. ``FileExistsError`` is
    still raised if *dir_path* exists but is not a directory.
    """
    os.makedirs(dir_path, exist_ok=True)
def get_label_names():
    """Return the class-name sequence stored under ``b'label_names'`` in ``batches.meta``.

    Entries are indexed by the integer labels found in each batch and are
    ``bytes``; callers ``.decode()`` them before use.
    """
    meta_path = os.path.join(data_dir, 'batches.meta')
    meta = unpickle(meta_path)
    return meta[b'label_names']
def save_images(i, obj, class_num, label_names, dir_path):
    """Write image *i* of batch *obj* as ``<dir_path>/<label name>/<k>.png``.

    Args:
        i: row index into ``obj[b'data']`` and ``obj[b'labels']``.
        obj: unpickled batch dict with ``b'data'`` and ``b'labels'`` entries.
        class_num: per-label counters, incremented in place; the new count is
            used as the file name, so images of the same class are numbered
            1, 2, 3, ... across calls.
        label_names: sequence of ``bytes`` class names indexed by label.
        dir_path: output root for this split (training or test).
    """
    # Reshape the flat row to (3, 32, 32), then move the size-3 axis last to
    # get a (32, 32, 3) array. Presumably these are colour planes -> HWC for imwrite.
    pixels = np.reshape(obj[b'data'][i], (3, 32, 32)).transpose(1, 2, 0)
    label_idx = obj[b'labels'][i]
    class_dir = os.path.join(dir_path, label_names[label_idx].decode())
    create_dir(class_dir)
    class_num[label_idx] += 1
    file_name = '{}.png'.format(class_num[label_idx])
    imwrite(os.path.join(class_dir, file_name), pixels)
# Script entry point: convert the selected split(s) into per-class PNG folders.
if __name__ == '__main__':
    _label_names = get_label_names()
    if Train:
        # One running counter per class, so file names keep increasing across all batches.
        train_class_num = [0] * len(_label_names)
        for i in range(1, 6):
            data_batch_path = os.path.join(data_dir, 'data_batch_' + str(i))
            # Announce before unpickling, so the message really precedes the load.
            print("{} is loading...".format(data_batch_path))
            train_batch_obj = unpickle(data_batch_path)
            # Use the number of rows actually present rather than assuming 10000.
            for j in range(len(train_batch_obj[b'labels'])):
                save_images(j, train_batch_obj, train_class_num, _label_names, train_dir)
        print('train loaded')
    if Test:
        test_class_num = [0] * len(_label_names)
        test_data_path = os.path.join(data_dir, 'test_batch')
        test_obj = unpickle(test_data_path)
        for i in range(len(test_obj[b'labels'])):
            save_images(i, test_obj, test_class_num, _label_names, test_dir)
        print('test loaded')