1. ImageFolder

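ImageFolder automatically converts the subfolder names of a directory into integer indices. Each subfolder holds the images of one class, and the folder name is used as the class name. As a result, when the data is loaded through a DataLoader, the labels are already integers.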
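To make the folder-name-to-index rule concrete, here is a minimal pure-Python sketch. It assumes the indices come from sorting the subfolder names as strings and numbering them from 0, which is how torchvision's `find_classes` behaves. The helper `folder_class_to_idx` is hypothetical and is not the library code.

```python
import os
import tempfile


def folder_class_to_idx(root):
    """Hypothetical helper: map each subfolder name under root to an integer index.

    Subfolder names are sorted as strings, and each name's position in the
    sorted list becomes its class index (0, 1, 2, ...).
    """
    classes = sorted(entry.name for entry in os.scandir(root) if entry.is_dir())
    return classes, {name: i for i, name in enumerate(classes)}


# Build a throwaway root/dog, root/cat layout to try it out
with tempfile.TemporaryDirectory() as root:
    for name in ("dog", "cat"):
        os.makedirs(os.path.join(root, name))
    classes, class_to_idx = folder_class_to_idx(root)
    print(classes)       # ['cat', 'dog']
    print(class_to_idx)  # {'cat': 0, 'dog': 1}
```

The order in which the folders are created does not matter: `cat` gets index 0 only because it sorts before `dog`.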

As usual, start by checking its documentation:

from torchvision import datasets
help(datasets.ImageFolder)

class ImageFolder(DatasetFolder)
Constructor signature:
ImageFolder(root, transform=None, target_transform=None, loader=<default_loader>, is_valid_file=None)

|  A generic data loader where the images are arranged in this way: ::
How the dataset is laid out on disk:
|      root/dog/xxx.png
|      root/dog/xxy.png
|      root/dog/xxz.png

|      root/cat/123.png
|      root/cat/nsdf3.png
|      root/cat/asd932_.png
|  -----------------------------------------------------------------------------------------------------------
|Main parameters:
|      root (string): Root directory path.
|      transform (callable, optional): A function/transform that takes in an PIL image
|          and returns a transformed version. E.g, ``transforms.RandomCrop``
|      target_transform (callable, optional): A function/transform that takes in the
|          target and transforms it.
|      loader (callable, optional): A function to load an image given its path.
|      is_valid_file (callable, optional): A function that takes path of an Image file
|          and check if the file is a valid_file (used to check of corrupt files)
---------------------------------------------------------------------------------------------------------------
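The `root` and `is_valid_file` parameters together decide which files become samples. The following is a simplified, hypothetical re-implementation of that collection step, not torchvision's own code. It assumes that every file under `root/<class_name>/` (nested folders included) is kept when `is_valid_file(path)` returns True, and that, by default, this means "has an image extension". The extension tuple below is an assumed subset, since torchvision keeps its own longer list.

```python
import os
import tempfile

# Assumed subset of image extensions, for illustration only
IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp")


def build_samples(root, class_to_idx, is_valid_file=None):
    """Hypothetical sketch: collect (path, class_index) pairs from root/<class_name>/.

    Files are visited class by class in sorted order, and within each folder
    in sorted filename order; a file is kept only if is_valid_file(path) is True.
    """
    if is_valid_file is None:
        def is_valid_file(path):
            # Default rule assumed here: keep files with a known image extension
            return path.lower().endswith(IMG_EXTENSIONS)

    samples = []
    for class_name in sorted(class_to_idx):
        class_dir = os.path.join(root, class_name)
        for dirpath, _, filenames in sorted(os.walk(class_dir)):
            for fname in sorted(filenames):
                path = os.path.join(dirpath, fname)
                if is_valid_file(path):
                    samples.append((path, class_to_idx[class_name]))
    return samples


with tempfile.TemporaryDirectory() as root:
    for rel in ("dog/xxx.png", "dog/notes.txt", "cat/123.png", "cat/nsdf3.png"):
        path = os.path.join(root, rel)
        os.makedirs(os.path.dirname(path), exist_ok=True)
        open(path, "w").close()  # empty placeholder files are enough for this demo
    samples = build_samples(root, {"cat": 0, "dog": 1})
    # Show paths relative to root so the output is readable
    rel_samples = [(os.path.relpath(p, root), idx) for p, idx in samples]
    print(rel_samples)
```

Because `notes.txt` fails the extension check, only the three `.png` files end up as samples. Passing a custom `is_valid_file` replaces that rule entirely, for example to skip files that cannot be opened.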

Example using the RMB_data dataset, which has two classes, 1 and 100 (one subfolder per class):
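Because the class folders here are named with numbers, it is worth noting that ImageFolder sorts folder names as strings, not as numbers. The sketch below assumes that string sort and uses a hypothetical helper, `class_indices`, to show the resulting mapping and its inverse. The inverse is handy for turning a predicted index back into a folder name. The second call adds a hypothetical `5` folder that is not part of RMB_data.

```python
def class_indices(folder_names):
    """Hypothetical helper: reproduce the sorted-name -> index assignment.

    Names are compared as plain strings, so "100" sorts before "5".
    """
    classes = sorted(folder_names)
    class_to_idx = {name: i for i, name in enumerate(classes)}
    idx_to_class = {i: name for name, i in class_to_idx.items()}
    return class_to_idx, idx_to_class


print(class_indices(["1", "100"]))       # RMB_data: '1' -> 0, '100' -> 1
print(class_indices(["1", "5", "100"]))  # '5' lands after '100'
```

For RMB_data the string order happens to match the numeric order. That stops being true as soon as a name such as `5` is added.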

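from torchvision import transforms, utils
from torch.utils import data
from torchvision import datasets
import torch
import matplotlib.pyplot as plt
# Jupyter magic: render figures inline (drop this line in a plain .py script)
%matplotlib inline

my_trans = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor()
])

# ImageFolder turns the subfolder names under ./data/RMB_data into integer
# class indices, so the labels produced by the DataLoader are already integers.
train_data = datasets.ImageFolder('./data/RMB_data', transform=my_trans)
train_loader = data.DataLoader(train_data, batch_size=8, shuffle=True)

for i_batch, (images, labels) in enumerate(train_loader):
    if i_batch == 0:
        # Each batch is a pair: images (a B x C x H x W tensor) and labels (B class indices)
        print(labels)
        fig = plt.figure()
        # make_grid expects a B x C x H x W batch and tiles it into one C x H x W image;
        # normalize=True rescales the pixel values to [0, 1]
        grid = utils.make_grid(images, normalize=True)
        # matplotlib expects H x W x C, so move the channel axis from first to last
        plt.imshow(grid.numpy().transpose((1, 2, 0)))
        plt.show()
        # Save the grid to disk so it can be reopened with PIL
        utils.save_image(grid, 'test01.png')
    break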
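The display step in the loop above does three things: it tiles the batch, rescales it, and moves the channel axis. The NumPy sketch below reproduces the arithmetic under stated assumptions. `grid_shape` assumes `make_grid`'s default `nrow=8` and `padding=2`, and its rule that single-channel input is repeated to 3 channels. `min_max_normalize` assumes `normalize=True` with no `value_range`, meaning one min and max taken over the whole batch. All three helper names are hypothetical.

```python
import math

import numpy as np


def grid_shape(b, c, h, w, nrow=8, padding=2):
    """Hypothetical helper: shape of the image make_grid builds from a B x C x H x W batch."""
    xmaps = min(nrow, b)              # images per row
    ymaps = math.ceil(b / xmaps)      # number of rows
    channels = 3 if c == 1 else c     # 1-channel input is repeated to 3 channels
    return (channels, (h + padding) * ymaps + padding, (w + padding) * xmaps + padding)


def min_max_normalize(x):
    """Rescale the whole array to [0, 1] using its global min and max."""
    low, high = float(x.min()), float(x.max())
    return (x - low) / max(high - low, 1e-5)


def chw_to_hwc(x):
    """Move the channel axis last so plt.imshow accepts the array."""
    return x.transpose((1, 2, 0))


print(grid_shape(8, 3, 224, 224))         # (3, 228, 1810): one row of eight 224 x 224 crops
demo = np.random.rand(3, 4, 5) * 10 - 5   # C x H x W array with values outside [0, 1]
shown = chw_to_hwc(min_max_normalize(demo))
print(shown.shape)                        # (4, 5, 3)
```

With `batch_size=8` and the default `nrow=8`, a full batch fits in a single row. A smaller final batch would simply give a narrower grid.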

Then load the saved image with PIL:

from PIL import Image
# Open the image saved in the previous step (Jupyter displays the returned Image inline)
Image.open('test01.png')
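Before writing the file, `save_image` converts the float grid to 8-bit pixels. It multiplies by 255, adds 0.5 to round to the nearest value, clamps to [0, 255], and moves channels last. The NumPy sketch below mirrors that conversion and uses a hypothetical helper, `to_uint8_hwc`. One more detail: PIL reports `.size` as (width, height). For a full batch of eight 224 x 224 crops the saved file is 1810 x 228 pixels, so `Image.open('test01.png').size` would be `(1810, 228)`.

```python
import numpy as np


def to_uint8_hwc(grid_chw):
    """Hypothetical helper: [0, 1] float C x H x W -> uint8 H x W x C, rounding to nearest."""
    scaled = np.clip(grid_chw * 255 + 0.5, 0, 255)  # +0.5 then truncation = round to nearest
    return scaled.transpose((1, 2, 0)).astype(np.uint8)


grid = np.array([[[0.0, 0.5, 1.0]]] * 3)  # C x H x W = 3 x 1 x 3, three gray levels
pixels = to_uint8_hwc(grid)
print(pixels.shape)               # (1, 3, 3)
print(pixels[0, :, 0].tolist())   # [0, 128, 255]
```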

Check the image size and the class information:

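print(train_data[0][0].size())   # C x H x W after RandomResizedCrop(224) + ToTensor: torch.Size([3, 224, 224])
print(train_data.classes)        # class names, taken from the subfolder names: ['1', '100']
# print(train_data.imgs)         # list of (image path, class index) pairs for every image found in the subfolders
print(train_data.class_to_idx)   # class names, in sorted order, mapped to indices 0, 1, ...: {'1': 0, '100': 1}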
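The first print above reports the sample as 3 x 224 x 224. `RandomResizedCrop(224)` yields a 224 x 224 PIL image, which is stored height x width x channels. `ToTensor` then moves the channels first and scales 8-bit values to [0, 1]. This NumPy sketch reproduces that layout and scale change for an 8-bit RGB image. `to_tensor_like` is a hypothetical name, and the real transform returns a torch tensor rather than a NumPy array.

```python
import numpy as np


def to_tensor_like(hwc_uint8):
    """Hypothetical helper: H x W x C uint8 in [0, 255] -> C x H x W float32 in [0.0, 1.0]."""
    return hwc_uint8.transpose((2, 0, 1)).astype(np.float32) / 255.0


rgb = np.zeros((224, 224, 3), dtype=np.uint8)  # stand-in for one 224 x 224 crop
rgb[..., 0] = 255                               # make it pure red
chw = to_tensor_like(rgb)
print(chw.shape)  # (3, 224, 224)
```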

 
