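torchvision datasets
- The dataset APIs provided by torchvision are convenient to use.
- If downloading from within PyCharm is slow, you can copy the download URL into a download manager such as Xunlei (Thunder) and download it there (some of the URLs can be found in the torchvision source code).
- The code is as follows: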
import torchvision
train_set = torchvision.datasets.CIFAR10("./Dataset", train=True, download=True)
test_set = torchvision.datasets.CIFAR10("./Dataset", train=False, download=True)
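- Each sample of the CIFAR10 dataset is a tuple: the first element is the image as a PIL image, and the second is the sample's label, i.e. its class index. The code is as follows: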
import torchvision
train_set = torchvision.datasets.CIFAR10("./Dataset", train=True, download=True)
test_set = torchvision.datasets.CIFAR10("./Dataset", train=False, download=True)
print(train_set[0])  # (PIL image, label) tuple
print(train_set.classes)  # list of the 10 class names
img, target = train_set[0]
print(img)
print(target)
print(train_set.classes[target])  # class name of the first sample
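The same (image, label) convention can be reproduced without downloading anything. The sketch below uses a hypothetical in-memory dataset, `TinyImageDataset`, with synthetic solid-colour images; only its class names are copied from CIFAR-10. It mimics the sample format and the `classes` lookup, not CIFAR10's actual implementation:

```python
from PIL import Image
from torch.utils.data import Dataset


class TinyImageDataset(Dataset):
    """Hypothetical stand-in mimicking CIFAR10's sample format:
    each item is a (PIL image, integer class index) tuple."""

    def __init__(self, num_samples=4):
        # Same class names as CIFAR-10
        self.classes = ["airplane", "automobile", "bird", "cat", "deer",
                        "dog", "frog", "horse", "ship", "truck"]
        # Synthetic 32x32 RGB images, one solid colour each
        self.data = [Image.new("RGB", (32, 32), (i * 40, 0, 0)) for i in range(num_samples)]
        # Synthetic labels: indices into self.classes
        self.targets = [i % len(self.classes) for i in range(num_samples)]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index], self.targets[index]


ds = TinyImageDataset()
img, target = ds[1]               # tuple unpacking, as with train_set[0]
print(img.size, img.mode)         # (32, 32) RGB
print(target, ds.classes[target]) # 1 automobile
```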
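- Applying transforms to the dataset
- Note that you only need to pass the transform object when calling the dataset API. Since dataset_transforms is an instance of the Compose class, it can be passed directly as the transform argument. The code is as follows: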
import torchvision
from torch.utils.tensorboard import SummaryWriter
# ToTensor converts each PIL image to a float tensor of shape C x H x W with values in [0, 1]
dataset_transforms = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
])
train_set = torchvision.datasets.CIFAR10("./Dataset", train=True, transform=dataset_transforms, download=True)
test_set = torchvision.datasets.CIFAR10("./Dataset", train=False, transform=dataset_transforms, download=True)
writer = SummaryWriter("logs")
for i in range(10):  # log the first 10 training images, one per step
    img, target = train_set[i]
    writer.add_image("train_set_img", img, i)
writer.close()
- The result as displayed in TensorBoard:

DataLoader (torch.utils.data) for torchvision datasets