数据并非总是以训练机器学习算法所需的最终处理形式呈现。我们使用转换对数据进行一些处理,使其适合用于训练。
所有 TorchVision 数据集都有两个参数——transform 用于修改特征,target_transform 用于修改标签——这两个参数接受包含转换逻辑的可调用对象。torchvision.transforms 模块提供了多种开箱即用的常用转换方法。
FashionMNIST 的特征为 PIL 图像格式,标签为整数。为了进行训练,我们需要将特征转换为归一化张量,将标签转换为独热编码张量。为了实现这些转换,我们使用 torchvision.transforms.v2 API 以及 torch.nn.functional.one_hot。
import torch import torch.nn.functional as F from torchvision import datasets from torchvision.transforms import v2 ds = datasets.FashionMNIST( root="data", train=True, download=True, transform=v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)]), target_transform=v2.Lambda( lambda y: F.one_hot(torch.tensor(y), num_classes=10).float() ), )
0%| | 0.00/26.4M [00:00<?, ?B/s] 0%| | 65.5k/26.4M [00:00<01:11, 366kB/s] 1%| | 229k/26.4M [00:00<00:38, 687kB/s] 3%|▎ | 885k/26.4M [00:00<00:12, 2.04MB/s] 14%|█▎ | 3.57M/26.4M [00:00<00:03, 7.13MB/s] 36%|███▌ | 9.47M/26.4M [00:00<00:01, 16.4MB/s] 59%|█████▊ | 15.5M/26.4M [00:01<00:00, 22.1MB/s] 80%|████████ | 21.2M/26.4M [00:01<00:00, 25.3MB/s] 100%|██████████| 26.4M/26.4M [00:01<00:00, 19.5MB/s] 0%| | 0.00/29.5k [00:00<?, ?B/s] 100%|██████████| 29.5k/29.5k [00:00<00:00, 336kB/s] 0%| | 0.00/4.42M [00:00<?, ?B/s] 1%|▏ | 65.5k/4.42M [00:00<00:11, 371kB/s] 5%|▌ | 229k/4.42M [00:00<00:05, 700kB/s] 21%|██ | 918k/4.42M [00:00<00:01, 2.16MB/s] 83%|████████▎ | 3.67M/4.42M [00:00<00:00, 7.47MB/s] 100%|██████████| 4.42M/4.42M [00:00<00:00, 6.26MB/s] 0%| | 0.00/5.15k [00:00<?, ?B/s] 100%|██████████| 5.15k/5.15k [00:00<00:00, 54.3MB/s]
ToImage() 和 ToDtype()
torchvision.transforms.v2 API 通过两步流程替代了传统的 ToTensor 转换。v2.ToImage 将 PIL 图像或 NumPy ndarray 转换为 torchvision.tv_tensors.Image 张量,而设置了 scale=True 的 v2.ToDtype 会将其转换为 float32 类型,并将像素强度值缩放到 [0., 1.] 范围内。
Lambda 转换
Lambda 转换可以应用任何用户自定义的 lambda 函数。在这里,我们使用 torch.nn.functional.one_hot 将整数标签转换为大小为 10(数据集中的标签数量)的独热编码张量,然后将其转换为 float 类型以匹配预期的数据类型。
target_transform = v2.Lambda( lambda y: F.one_hot(torch.tensor(y), num_classes=10).float() )