Pytorch——ImageFolder简单使用
一、ImageFolderImageFolder会将目录中的文件夹名自动转化成序列,每个文件夹下会存储相同的一个类别,文件夹名为类名;当DataLoader载入时,标签自动就是整数序列了;还是先查看一下用法:help(datasets.ImageFolder)class ImageFolder(DatasetFolder)函数构造:| ImageFolder(root, transform=Non
一、ImageFolder
ImageFolder会将目录中的文件夹名自动转化成序列,每个文件夹下会存储相同的一个类别,文件夹名为类名;当DataLoader载入时,标签自动就是整数序列了;
还是先查看一下用法:
help(datasets.ImageFolder)
class ImageFolder(DatasetFolder)
函数构造:
| ImageFolder(root, transform=None, target_transform=None, loader=<default_loader>, is_valid_file=None)
|
| A generic data loader where the images are arranged in this way: ::
| 数据集存放的形式:
| root/dog/xxx.png
| root/dog/xxy.png
| root/dog/xxz.png
|
| root/cat/123.png
| root/cat/nsdf3.png
| root/cat/asd932_.png
| -----------------------------------------------------------------------------------------------------------
|主要参数:
| root (string): Root directory path.
| transform (callable, optional): A function/transform that takes in an PIL image
| and returns a transformed version. E.g, ``transforms.RandomCrop``
| target_transform (callable, optional): A function/transform that takes in the
| target and transforms it.
| loader (callable, optional): A function to load an image given its path.
| is_valid_file (callable, optional): A function that takes path of an Image file
| and check if the file is a valid_file (used to check of corrupt files)
---------------------------------------------------------------------------------------------------------------
使用RMB_data数据集:一共两个类:1和100

from torchvision import transforms,utils
from torch.utils import data
from torchvision import datasets
import torch
import matplotlib.pyplot as plt
%matplotlib inline
my_trans = transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor()
])
"""
ImageFolder会将目录中的文件夹名自动转化成序列,当DataLoader载入时,标签
自动就是整数序列了。
"""
train_data = datasets.ImageFolder('./data/RMB_data',transform=my_trans)
train_loader = data.DataLoader(train_data,batch_size=8,shuffle=True)
for i_batch,img in enumerate(train_loader):
if i_batch == 0:
# img代表一个列表:img[0]代表数据,img[1]代表标签
print(img[1])
fig = plt.figure()
# B x C x H x W 输入就是这种形式
grid = utils.make_grid(img[0],normalize=True)
# [C x H x W] 改变成 [H × W × C]
plt.imshow(grid.numpy().transpose((1,2,0)))
plt.show()
utils.save_image(grid,'test01.png')
break

再使用PIL加载保存的图像:
from PIL import Image
# 打开上一步保存的图像
Image.open('test01.png')

查看图像尺寸:
print(train_data[0][0].size())
print(train_data.classes) # 根据分的文件夹的名字来确定的类别
# print(train_data.imgs) # 返回从所有文件夹中得到的图片的路径以及其类别
print(train_data.class_to_idx)# 按顺序为这些类别定义索引为0,1...
更多推荐



所有评论(0)