前言

想自己重写Dataset类,而不是通过torchvision.datasets.CIFAR10获取数据集。但从官网下载的数据集是压缩包,解压后得到的是用pickle序列化的批次文件(data_batch_1~data_batch_5、test_batch、batches.meta),无法直接看到图片和标签。因此参考相关博客,将图片和标签逐一读取并保存出来。

下载数据集

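首先可以通过PyTorch(torchvision)下载CIFAR10数据集。首次运行时需要将 datasets.CIFAR10 的 download 参数设为 True,数据会被下载并解压到 ./data/cifar-10-batches-py 目录,后面的脚本就从这个目录读取数据。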

#train.py

import torch
import torchvision
import torch.nn as nn
from model import LeNet
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision import transforms, datasets, utils
import matplotlib.pyplot as plt
import numpy as np


# Device: use the first GPU when available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)


# ToTensor() maps pixels to [0, 1]; Normalize(0.5, 0.5) then maps each
# channel to [-1, 1]. imshow() below inverts this with x / 2 + 0.5.
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# 50,000 training images. download=True fetches and extracts the archive
# into ./data on the first run; afterwards torchvision only verifies the
# existing files, so re-running is cheap.
train_set = torchvision.datasets.CIFAR10(root='./data', train=True,
                                         download=True, transform=transform)
# NOTE(review): shuffle=False is unusual for a training loader; enable it if
# this loader is actually used for training.
train_loader = torch.utils.data.DataLoader(train_set, batch_size=36,
                                           shuffle=False, num_workers=0)

# 10,000 test images, used here as the validation set.
val_set = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
val_loader = torch.utils.data.DataLoader(val_set, batch_size=4,
                                         shuffle=False, num_workers=0)
# Take the first validation batch. Use the built-in next(): newer PyTorch
# DataLoader iterators no longer provide a .next() method.
val_data_iter = iter(val_loader)
val_image, val_label = next(val_data_iter)
print(val_image.size())
print(train_set.class_to_idx)
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')


# The batch displayed below comes from val_loader (batch_size=4).
# Invert class_to_idx (name -> index) so numeric labels map back to names.
aaa = train_set.class_to_idx
cla_dict = {idx: name for name, idx in aaa.items()}
def imshow(img):
    """Show a normalized image tensor with matplotlib.

    The tensor is un-normalized with ``img / 2 + 0.5``, the inverse of
    ``Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))``. It is then converted to
    NumPy and its axes are reordered from (C, H, W) to (H, W, C), the
    channel-last layout ``plt.imshow`` expects. The resulting shape is
    printed before the image is shown.

    Args:
        img: a 3-D image tensor in (C, H, W) layout, e.g. the output of
            ``torchvision.utils.make_grid``. It must live on the CPU because
            ``.numpy()`` is called on it.
    """
    img = img / 2 + 0.5  # unnormalize
    npimg = img.numpy()
    # Compute the channel-last view once and reuse it for printing and display.
    hwc = np.transpose(npimg, (1, 2, 0))
    print(hwc.shape)
    plt.imshow(hwc)
    plt.show()

# Print the class name of every image in the batch, then show the batch as a
# single grid. Iterating over val_label itself (instead of range(4)) keeps this
# correct for any batch_size: no IndexError on smaller batches, no silently
# dropped labels on larger ones.
print(' '.join('%5s' % cla_dict[label.item()] for label in val_label))
imshow(utils.make_grid(val_image))

数据集可视化

通过pickle反序列化将数据读取出来,并把每张图片保存为 .jpg 文件,对应的标签保存为同名的 .txt 文件。

train

readDataTrain.py

import pickle
from imageio import imsave
import numpy as np


def load_file(filename):
    """Return the object pickled in *filename*, decoding Python 2 ``str`` as latin-1.

    With ``encoding='latin1'``, byte strings from a Python 2 pickle come back
    as ``str``. That is why ``batches.meta`` can be indexed with the plain key
    ``'label_names'``.

    Security note: unpickling can execute arbitrary code; only call this on
    trusted files.
    """
    with open(filename, 'rb') as handle:
        return pickle.load(handle, encoding='latin1')

def unpickle(file):
    """Load a pickled CIFAR batch from *file*, keeping Python 2 strings as bytes.

    With ``encoding='bytes'``, Python 2 ``str`` objects (including dict keys)
    come back as ``bytes``. That is why callers index the result with
    ``b'data'`` and ``b'labels'``.

    Security note: unpickling can execute arbitrary code; only call this on
    trusted files.

    Args:
        file: path of the pickle file to read.

    Returns:
        The unpickled object.
    """
    # Named `batch` rather than `dict` so the builtin is not shadowed.
    with open(file, 'rb') as fo:
        batch = pickle.load(fo, encoding='bytes')
    return batch



import os  # local import: only this export section needs it

# Class names, in label order, from the metadata file (currently unused below).
dic = load_file('data/cifar-10-batches-py/batches.meta')
labels_item = dic['label_names']

# NOTE(review): "imges" looks like a typo for "images"; it is kept so that code
# reading these folders keeps working. JPEG is also lossy, so the saved pixels
# will not exactly match the original data; PNG would be lossless.
train_img_dir = "data/cifar10/train/imges/"
train_label_dir = "data/cifar10/train/labels/"
# Neither imsave nor open(..., 'w') creates missing parent directories.
os.makedirs(train_img_dir, exist_ok=True)
os.makedirs(train_label_dir, exist_ok=True)

# Running 1-based image number across all five batches. Counting directly
# avoids assuming every batch holds the same number of images. It also avoids
# rebinding the builtins `len` and `id` at module level.
count = 0
for k in range(1, 6):
    dic = unpickle("data/cifar-10-batches-py/data_batch_" + str(k))
    dict_image_data = dic[b'data']
    dict_image_labels = dic[b'labels']

    for imgs, labels in zip(dict_image_data, dict_image_labels):
        count += 1
        img_id = str(count).zfill(5)
        # Each row holds 3*32*32 values stored channel-first; reshape to
        # (3, 32, 32), then move channels last for the image writer.
        imgs_array = np.reshape(imgs, (3, 32, 32)).transpose(1, 2, 0)
        imsave(train_img_dir + img_id + '.jpg', imgs_array)
        with open(train_label_dir + img_id + '.txt', 'w') as f:
            f.write(str(labels))


test

readDataTest.py

import pickle
from imageio import imsave
import numpy as np


def load_file(filename):
    """Return the object pickled in *filename*, decoding Python 2 ``str`` as latin-1.

    With ``encoding='latin1'``, byte strings from a Python 2 pickle come back
    as ``str``. That is why ``batches.meta`` can be indexed with the plain key
    ``'label_names'``.

    Security note: unpickling can execute arbitrary code; only call this on
    trusted files.
    """
    with open(filename, 'rb') as handle:
        return pickle.load(handle, encoding='latin1')

def unpickle(file):
    """Load a pickled CIFAR batch from *file*, keeping Python 2 strings as bytes.

    With ``encoding='bytes'``, Python 2 ``str`` objects (including dict keys)
    come back as ``bytes``. That is why callers index the result with
    ``b'data'`` and ``b'labels'``.

    Security note: unpickling can execute arbitrary code; only call this on
    trusted files.

    Args:
        file: path of the pickle file to read.

    Returns:
        The unpickled object.
    """
    # Named `batch` rather than `dict` so the builtin is not shadowed.
    with open(file, 'rb') as fo:
        batch = pickle.load(fo, encoding='bytes')
    return batch



import os  # local import: only this export section needs it

# Class names, in label order, from the metadata file (currently unused below).
dic = load_file('data/cifar-10-batches-py/batches.meta')
labels_item = dic['label_names']


dic = unpickle("data/cifar-10-batches-py/test_batch")
dict_image_data = dic[b'data']
dict_image_labels = dic[b'labels']

# NOTE(review): "imges" looks like a typo for "images"; it is kept so that code
# reading these folders keeps working. JPEG is also lossy, so the saved pixels
# will not exactly match the original data; PNG would be lossless.
test_img_dir = "data/cifar10/test/imges/"
test_label_dir = "data/cifar10/test/labels/"
# Neither imsave nor open(..., 'w') creates missing parent directories.
os.makedirs(test_img_dir, exist_ok=True)
os.makedirs(test_label_dir, exist_ok=True)

# enumerate(start=1) gives 1-based file numbers without rebinding the
# builtins `len` and `id` at module level.
for count, (imgs, labels) in enumerate(
        zip(dict_image_data, dict_image_labels), start=1):
    img_id = str(count).zfill(5)
    # Each row holds 3*32*32 values stored channel-first; reshape to
    # (3, 32, 32), then move channels last for the image writer.
    imgs_array = np.reshape(imgs, (3, 32, 32)).transpose(1, 2, 0)
    imsave(test_img_dir + img_id + '.jpg', imgs_array)
    with open(test_label_dir + img_id + '.txt', 'w') as f:
        f.write(str(labels))


效果

(原文此处为若干效果截图,图片已缺失)
标签是0-9之间的整数,每个标签单独保存在与图片同名的 .txt 文件中。
(原文此处为标签文件截图,图片已缺失)
标签与类别名的对应关系(即 train.py 中 print(train_set.class_to_idx) 的输出):

{'airplane': 0, 'automobile': 1, 'bird': 2, 
 'cat': 3, 'deer': 4, 'dog': 5, 
 'frog': 6, 'horse': 7, 'ship': 8, 'truck': 9}

参考资料

手把手教你CIFAR数据集可视化

Logo

永洪科技,致力于打造全球领先的数据技术厂商,具备从数据应用方案咨询、BI、AIGC智能分析、数字孪生、数据资产、数据治理、数据实施的端到端大数据价值服务能力。

更多推荐