收藏本站,收获最前沿的人工智能与编程资讯!!

PyTorch 2.12 ​转换 transform

技术文档 13℃ 0

数据并非总是以训练机器学习算法所需的最终处理形式呈现。我们使用转换对数据进行一些处理,使其适合用于训练。

所有 TorchVision 数据集都有两个参数——transform 用于修改特征,target_transform 用于修改标签——这两个参数接受包含转换逻辑的可调用对象。torchvision.transforms 模块提供了多种开箱即用的常用转换方法。

FashionMNIST 的特征为 PIL 图像格式,标签为整数。为了进行训练,我们需要将特征转换为归一化张量,将标签转换为独热编码张量。为了实现这些转换,我们使用 torchvision.transforms.v2 API 以及 torch.nn.functional.one_hot

import torch
import torch.nn.functional as F
from torchvision import datasets
from torchvision.transforms import v2

ds = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)]),
    target_transform=v2.Lambda(
        lambda y: F.one_hot(torch.tensor(y), num_classes=10).float()
    ),
)
  0%|          | 0.00/26.4M [00:00<?, ?B/s]
  0%|          | 65.5k/26.4M [00:00<01:11, 366kB/s]
  1%|          | 229k/26.4M [00:00<00:38, 687kB/s]
  3%|▎         | 885k/26.4M [00:00<00:12, 2.04MB/s]
 14%|█▎        | 3.57M/26.4M [00:00<00:03, 7.13MB/s]
 36%|███▌      | 9.47M/26.4M [00:00<00:01, 16.4MB/s]
 59%|█████▊    | 15.5M/26.4M [00:01<00:00, 22.1MB/s]
 80%|████████  | 21.2M/26.4M [00:01<00:00, 25.3MB/s]
100%|██████████| 26.4M/26.4M [00:01<00:00, 19.5MB/s]

  0%|          | 0.00/29.5k [00:00<?, ?B/s]
100%|██████████| 29.5k/29.5k [00:00<00:00, 336kB/s]

  0%|          | 0.00/4.42M [00:00<?, ?B/s]
  1%|▏         | 65.5k/4.42M [00:00<00:11, 371kB/s]
  5%|▌         | 229k/4.42M [00:00<00:05, 700kB/s]
 21%|██        | 918k/4.42M [00:00<00:01, 2.16MB/s]
 83%|████████▎ | 3.67M/4.42M [00:00<00:00, 7.47MB/s]
100%|██████████| 4.42M/4.42M [00:00<00:00, 6.26MB/s]

  0%|          | 0.00/5.15k [00:00<?, ?B/s]
100%|██████████| 5.15k/5.15k [00:00<00:00, 54.3MB/s]

ToImage() 和 ToDtype()

torchvision.transforms.v2 API 通过两步流程替代了传统的 ToTensor 转换。v2.ToImage 将 PIL 图像或 NumPy ndarray 转换为 torchvision.tv_tensors.Image 张量,而设置了 scale=True 的 v2.ToDtype 会将其转换为 float32 类型,并将像素强度值缩放到 [0., 1.] 范围内。

Lambda 转换

Lambda 转换可以应用任何用户自定义的 lambda 函数。在这里,我们使用 torch.nn.functional.one_hot 将整数标签转换为大小为 10(数据集中的标签数量)的独热编码张量,然后将其转换为 float 类型以匹配预期的数据类型。

target_transform = v2.Lambda(
    lambda y: F.one_hot(torch.tensor(y), num_classes=10).float()
)
标签: PyTorch 2.12

相关推荐