pytorch数据处理

数据读取

数据集类的组成

发现一些pytorch实现的神经网络中,自定义的数据集类里经常实现了__getitem__和__len__两个魔法函数,这些与pytorch的数据处理Dataloader密切相关

  • __getitem__:当实例对象使用[]运算符取值时,会调用这个方法里内容,相当于返回一条数据/样本
  • __len__:返回样本的数量

eg:

from torch.utils import data
import os
from PIL import  Image
import numpy as np

class DogCat(data.Dataset):
    def __init__(self, root):
        imgs = os.listdir(root)
        # 所有图片的绝对路径
        # 这里不实际加载图片,只是指定路径,当调用__getitem__时才会真正读图片
        self.imgs = [os.path.join(root, img) for img in imgs]
    
    def __getitem__(self, index):
        img_path = self.imgs[index]
        # dog->1, cat->0
        label = 1 if 'dog' in img_path.split('/')[-1] else 0
        pil_img = Image.open(img_path)
        array = np.asarray(pil_img)
        data = t.from_numpy(array)
        return data, label
  
    def __len__(self):
        return len(self.imgs)

调用:

dataset = DogCat('./data/dogcat/')
img, label = dataset[0]    # 相当于调用了dataset.__getitem__(0)

数据集类的实现

继承torch.utils.data.Dataset:复写上面提到的两个魔法函数即可(必须)

class MyDataset(Dataset): #继承Dataset
    def __init__(self, root_dir, transform=None): #__init__是初始化该类的一些基础参数
        self.root_dir = root_dir   #文件目录
        self.transform = transform #变换
        self.images = os.listdir(self.root_dir)#目录里的所有文件
  
    def __len__(self):#返回整个数据集的大小
        return len(self.images)
  
    def __getitem__(self,index):#根据索引index返回dataset[index]
        pass

Dataset一次调用__getitem__只返回一个样本,但在训练神经网络时最好是对一个batch的数据进行操作,同时需要打乱和并行加速等,因此使用pytorch提供的Dataloader

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False)
  • dataset:加载的数据集(Dataset对象)
  • batch_size:一个batch的大小
  • shuffle:是否将数据打乱
  • sampler:样本抽样
  • num_workers:使用多进程加载的进程数,0代表不使用多进程
  • collate_fn:如何将多个样本数据拼接成一个batch,一般使用默认的拼接方式即可
  • pin_memory:是否将数据保存在pin memory区,pin memory中的数据转到GPU会快一些
  • drop_last:dataset中的数据个数可能不是batch_size的整数倍,drop_last为True会将多出来不足一个batch的数据丢弃

设置好后自定义的数据集就可以使用torch.utils.data.DataLoader加载

dataset = MyDataset('./data')
dataloader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=0, drop_last=False)

从dataloader中迭代取出数据和标签

dataiter = iter(dataloader)
imgs, labels = next(dataiter)
imgs.size() # batch_size, channel, height, weight

或者

dataloader = DataLoader(dataset, batch_size=64, shuffle=True)
for i_batch, batch_data in enumerate(dataloader):
    print(i_batch)#打印batch编号
    print(batch_data['image'].size())
    print(batch_data['label'])

使用Dataset和DataLoader建议

  • 高负载的操作放在__getitem__中,如加载图片等
  • dataset中应尽量只包含只读对象,避免修改任何可变对象,利用多线程进行操作

图像转换

在读取数据集时为了方便后续操作,通常会直接在调用时对图像进行一系列操作

pytorch的torchvision提供了很多视觉图像处理的工具,对PIL Image对象Tensor对象的操作被封装在了transforms模块

参考:原理+代码实现_torchvision transforms

对PIL Image的主要操作有:

  • Scale/Resize:调整图片尺寸,长宽比保持不变,输入size可以为h、w序列([224, 224])或int型(224,会按长宽比调整最短边到224,长宽比不变)
  • CenterCrop、RandomCrop、RandomResizedCrop:裁剪图片

    • CenterCrop:以图象中心点为参考,按照给定size(同Scale中的size)进行裁剪
    • RandomCrop:随机选取中心点裁剪
    • RandomResizedCrop:先随机裁剪,再放缩图像为指定size
  • Pad:填充,输入padding为序列(2或4个数:2个时左右填充第一个数,上下填充第二数;4个时填充左、上、右、下)或整型(上下左右填充相同)
  • ToTensor:将PIL Image对象转成Tensor,会自动将[0, 255]归一化至[0, 1]

对Tensor的主要操作有:

  • Normalize:标准化,即$\frac{X-\mu}{\sigma}$
  • ToPILImage:将Tensor转换成PIL Image对象

对图片操作可以使用transforms.Compose将多个操作拼接在一起(类似nn.Sequential)

transform = transforms.Compose([
    Scale([224, 224]),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])])
最后修改:2022 年 10 月 12 日
如果觉得我的文章对你有用,请随意赞赏