pytorch数据处理
数据读取
数据集类的组成
发现一些pytorch实现的神经网络中,自定义的数据集类里经常实现了__getitem__和__len__两个魔法函数,这些与pytorch的数据处理Dataloader密切相关
- __getitem__:当实例对象使用[]运算符取值时,会调用这个方法里内容,相当于返回一条数据/样本
- __len__:返回样本的数量
eg:
from torch.utils import data
import os
from PIL import Image
import numpy as np
class DogCat(data.Dataset):
def __init__(self, root):
imgs = os.listdir(root)
# 所有图片的绝对路径
# 这里不实际加载图片,只是指定路径,当调用__getitem__时才会真正读图片
self.imgs = [os.path.join(root, img) for img in imgs]
def __getitem__(self, index):
img_path = self.imgs[index]
# dog->1, cat->0
label = 1 if 'dog' in img_path.split('/')[-1] else 0
pil_img = Image.open(img_path)
array = np.asarray(pil_img)
data = t.from_numpy(array)
return data, label
def __len__(self):
return len(self.imgs)
调用:
dataset = DogCat('./data/dogcat/')
img, label = dataset[0] # 相当于调用了dataset.__getitem__(0)
数据集类的实现
继承torch.utils.data.Dataset
:复写上面提到的两个魔法函数即可(必须)
class MyDataset(Dataset): #继承Dataset
def __init__(self, root_dir, transform=None): #__init__是初始化该类的一些基础参数
self.root_dir = root_dir #文件目录
self.transform = transform #变换
self.images = os.listdir(self.root_dir)#目录里的所有文件
def __len__(self):#返回整个数据集的大小
return len(self.images)
def __getitem__(self,index):#根据索引index返回dataset[index]
pass
Dataset
一次调用__getitem__只返回一个样本,但在训练神经网络时最好是对一个batch的数据进行操作,同时需要打乱和并行加速等,因此使用pytorch提供的Dataloader
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False)
- dataset:加载的数据集(Dataset对象)
- batch_size:一个batch的大小
- shuffle:是否将数据打乱
- sampler:样本抽样
- num_workers:使用多进程加载的进程数,0代表不使用多进程
- collate_fn:如何将多个样本数据拼接成一个batch,一般使用默认的拼接方式即可
- pin_memory:是否将数据保存在pin memory区,pin memory中的数据转到GPU会快一些
- drop_last:dataset中的数据个数可能不是batch_size的整数倍,drop_last为True会将多出来不足一个batch的数据丢弃
设置好后自定义的数据集就可以使用torch.utils.data.DataLoader
加载
dataset = MyDataset('./data')
dataloader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=0, drop_last=False)
从dataloader中迭代取出数据和标签
dataiter = iter(dataloader)
imgs, labels = next(dataiter)
imgs.size() # batch_size, channel, height, weight
或者
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)
for i_batch, batch_data in enumerate(dataloader):
print(i_batch)#打印batch编号
print(batch_data['image'].size())
print(batch_data['label'])
使用Dataset和DataLoader建议
- 高负载的操作放在
__getitem__
中,如加载图片等 - dataset中应尽量只包含只读对象,避免修改任何可变对象,利用多线程进行操作
图像转换
在读取数据集时为了方便后续操作,通常会直接在调用时对图像进行一系列操作
pytorch的torchvision提供了很多视觉图像处理的工具,对PIL Image对象和Tensor对象的操作被封装在了transforms模块下
参考:原理+代码实现_torchvision transforms
对PIL Image的主要操作有:
Scale/Resize
:调整图片尺寸,长宽比保持不变,输入size可以为h、w序列([224, 224])或int型(224,会按长宽比调整最短边到224,长宽比不变)CenterCrop、RandomCrop、RandomResizedCrop
:裁剪图片- CenterCrop:以图象中心点为参考,按照给定size(同Scale中的size)进行裁剪
- RandomCrop:随机选取中心点裁剪
- RandomResizedCrop:先随机裁剪,再放缩图像为指定size
Pad
:填充,输入padding为序列(2或4个数:2个时左右填充第一个数,上下填充第二数;4个时填充左、上、右、下)或整型(上下左右填充相同)ToTensor
:将PIL Image对象转成Tensor,会自动将[0, 255]归一化至[0, 1]
对Tensor的主要操作有:
Normalize
:标准化,即$\frac{X-\mu}{\sigma}$ToPILImage
:将Tensor转换成PIL Image对象
对图片操作可以使用transforms.Compose将多个操作拼接在一起(类似nn.Sequential)
transform = transforms.Compose([
Scale([224, 224]),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])])