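Optimizing Model Parameters
Now that we have a model and data, it is time to train, validate, and test the model by optimizing its parameters on the data. Training a model is an iterative process; in each iteration the model makes a guess about the output, calculates the error in its guess (the loss), collects the derivatives of the error with respect to its parameters (as we saw in the previous section), and optimizes these parameters using gradient descent. For a more detailed walkthrough of this process, see the 3Blue1Brown video on backpropagation.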
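To make the four steps of one iteration concrete, here is a minimal sketch in plain Python with no PyTorch involved. It assumes a toy one-parameter model `y = w * x` and made-up data generated by the rule `y = 3 * x`; the data, the variable names, and the learning rate are illustrative assumptions, not part of this tutorial's model.

```python
# Toy data generated by the (hypothetical) rule y = 3 * x.
xs = [1.0, 2.0, 3.0, 4.0]
ys = [3.0, 6.0, 9.0, 12.0]

w = 0.0               # single model parameter, deliberately started far from 3
learning_rate = 0.01  # illustrative value

losses = []
for step in range(200):
    # 1. Predict with the current parameter.
    preds = [w * x for x in xs]
    # 2. Compute the prediction error (mean squared error).
    loss = sum((p - y) ** 2 for p, y in zip(preds, ys)) / len(xs)
    # 3. Compute the derivative of the loss with respect to w.
    grad = sum(2 * (p - y) * x for p, y, x in zip(preds, ys, xs)) / len(xs)
    # 4. Gradient descent: move w a small step against the gradient.
    w -= learning_rate * grad
    losses.append(loss)

print(f"learned w = {w:.4f}, first loss = {losses[0]:.3f}, last loss = {losses[-1]:.6f}")
```

In the rest of this tutorial, PyTorch's autograd and optimizer objects take over steps 3 and 4, so the derivative never has to be written by hand.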
Prerequisite Code
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import v2

training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])
)

test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])
)

train_dataloader = DataLoader(training_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)

class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

model = NeuralNetwork()
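Hyperparameters
Hyperparameters are adjustable parameters that let you control the model optimization process. Different hyperparameter values can affect model training and convergence rates.
We define the following hyperparameters for training:
Number of Epochs - the number of times to iterate over the dataset
Batch Size - the number of data samples propagated through the network before the parameters are updated
Learning Rate - how much to update the model parameters at each batch/epoch. Smaller values yield slow learning, while larger values may result in unpredictable behavior during training.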
learning_rate = 1e-3
batch_size = 64
epochs = 5
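The following sketch, in plain Python, illustrates how these three values interact. The first part counts parameter updates, taking the training set size of 60,000 from the progress output shown at the end of this tutorial. The second part runs gradient descent on the toy function f(w) = w**2, which is an assumption chosen only because its gradient (2*w) is easy to verify; `descend` is a hypothetical helper, not a PyTorch function.

```python
import math

# Part 1: how many parameter updates the hyperparameters imply.
dataset_size = 60_000   # FashionMNIST training images
batch_size = 64
epochs = 5
batches_per_epoch = math.ceil(dataset_size / batch_size)  # the last batch is smaller
total_updates = batches_per_epoch * epochs
print(f"{batches_per_epoch} updates per epoch, {total_updates} in total")


# Part 2: effect of the learning rate on a toy problem.
def descend(learning_rate, steps=20, w=1.0):
    """Run plain gradient descent on f(w) = w**2, whose gradient is 2*w."""
    for _ in range(steps):
        w -= learning_rate * 2 * w
    return w


for lr in (1e-3, 1e-1, 1.1):
    print(f"lr={lr:<6} -> w after 20 steps = {descend(lr):.4f}")
```

With the smallest rate, w barely moves away from its starting value of 1.0; with the middle rate it approaches the minimum at 0; with the largest rate each step overshoots and w grows in magnitude instead of shrinking.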
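Optimization Loop
Once the hyperparameters are set, we can train and optimize the model with an optimization loop. Each iteration of the optimization loop is called an epoch.
Each epoch consists of two main parts:
The Train Loop - iterate over the training dataset and try to converge to optimal parameters.
The Validation/Test Loop - iterate over the test dataset to check whether model performance is improving.
Let's briefly familiarize ourselves with some of the concepts used in the training loop.
Loss Function
When presented with some training data, our untrained network is likely not to give the correct answer. The loss function measures the degree of dissimilarity between the obtained result and the target value, and it is the loss that we want to minimize during training. To calculate the loss, we make a prediction using the inputs of a given data sample and compare it against the true data label value.
Common loss functions include nn.MSELoss (Mean Square Error) for regression tasks and nn.NLLLoss (Negative Log Likelihood) for classification. nn.CrossEntropyLoss combines nn.LogSoftmax and nn.NLLLoss.
We pass our model's output logits (raw, unnormalized scores) to nn.CrossEntropyLoss, which normalizes the logits and computes the prediction error.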
# Initialize the loss function
loss_fn = nn.CrossEntropyLoss()
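As a quick check of the statement that nn.CrossEntropyLoss combines nn.LogSoftmax and nn.NLLLoss, the sketch below computes the loss three ways on a made-up batch. The logits are random and the target labels are arbitrary; both are assumptions for illustration and have nothing to do with the trained model.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical batch: 4 samples, 10 classes (the model above also outputs 10 values).
logits = torch.randn(4, 10)            # raw, unnormalized scores
targets = torch.tensor([3, 0, 9, 1])   # made-up class indices

# Path A: CrossEntropyLoss applied directly to the logits.
ce = nn.CrossEntropyLoss()(logits, targets)

# Path B: LogSoftmax to normalize, then NLLLoss on the log-probabilities.
log_probs = nn.LogSoftmax(dim=1)(logits)
nll = nn.NLLLoss()(log_probs, targets)

# Path C: the same quantity written out by hand: the mean negative
# log-probability assigned to the correct class of each sample.
manual = -log_probs[torch.arange(4), targets].mean()

print(ce.item(), nll.item(), manual.item())
```

All three values should agree up to floating-point tolerance, which is why the model can return logits directly without a softmax layer of its own.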
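Optimizer
Optimization is the process of adjusting model parameters to reduce model error in each training step. Optimization algorithms define how this process is performed (in this example we use Stochastic Gradient Descent). All optimization logic is encapsulated in the optimizer object. Here, we use the SGD optimizer; PyTorch also provides many other optimizers, such as Adam and RMSprop, that suit different kinds of models and data.
We initialize the optimizer by registering the model's parameters that need to be trained and passing in the learning rate hyperparameter.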
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
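Inside the training loop, optimization happens in three steps:
Call optimizer.zero_grad() to reset the gradients of the model parameters. Gradients accumulate by default; to prevent double-counting, we explicitly zero them at each iteration.
Call loss.backward() to backpropagate the prediction loss. PyTorch stores the gradient of the loss with respect to each parameter.
Once we have the gradients, call optimizer.step() to adjust the parameters by the gradients collected in the backward pass.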
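The three steps above can be observed on a single parameter. This sketch assumes a hypothetical scalar parameter `w` and the toy loss `w ** 2`, chosen so the gradient (2*w) is easy to check by hand; it first shows what happens when the reset is skipped, then runs one clean cycle.

```python
import torch

# A single hypothetical parameter and the toy loss w**2 (gradient 2*w).
w = torch.tensor(2.0, requires_grad=True)
optimizer = torch.optim.SGD([w], lr=0.1)

# Two backward passes without resetting: the gradients add up.
(w ** 2).backward()
first_grad = w.grad.item()         # 2 * 2.0 = 4.0
(w ** 2).backward()
accumulated_grad = w.grad.item()   # 4.0 + 4.0 = 8.0

# Step 1: reset. Depending on the PyTorch version, the gradient is
# either set to None or filled with zeros.
optimizer.zero_grad()
grad_is_reset = w.grad is None or w.grad.item() == 0.0

# Step 2: backpropagate. Step 3: update, w <- 2.0 - 0.1 * 4.0.
(w ** 2).backward()
optimizer.step()

print(first_grad, accumulated_grad, grad_is_reset, w.item())
```

The order of the three calls within one iteration is flexible as long as the gradients are reset between consecutive backward passes; calling optimizer.zero_grad() right after optimizer.step() has the same effect as calling it at the start of the next iteration.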
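Full Implementation
We define train_loop, which loops over the optimization code, and test_loop, which evaluates the model's performance against the test data.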
def train_loop(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    # Set the model to training mode - important for batch normalization and dropout layers
    # Unnecessary in this situation but added for best practices
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        # Compute prediction and loss
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if batch % 100 == 0:
            # Read the batch size from the dataloader rather than a global variable
            loss, current = loss.item(), batch * dataloader.batch_size + len(X)
            print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")


def test_loop(dataloader, model, loss_fn):
    # Set the model to evaluation mode - important for batch normalization and dropout layers
    # Unnecessary in this situation but added for best practices
    model.eval()
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0

    # Evaluating the model with torch.no_grad() ensures that no gradients are computed during test mode
    # also serves to reduce unnecessary gradient computations and memory usage for tensors with requires_grad=True
    with torch.no_grad():
        for X, y in dataloader:
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()

    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

We initialize the loss function and optimizer, and pass them to train_loop and test_loop. Feel free to increase the number of epochs to track the model's improving performance.
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
epochs = 10
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train_loop(train_dataloader, model, loss_fn, optimizer)
    test_loop(test_dataloader, model, loss_fn)
print("Done!")

Epoch 1
-------------------------------
loss: 2.318118 [   64/60000]
loss: 2.302672 [ 6464/60000]
loss: 2.284030 [12864/60000]
loss: 2.272507 [19264/60000]
loss: 2.261528 [25664/60000]
loss: 2.232846 [32064/60000]
loss: 2.243716 [38464/60000]
loss: 2.216156 [44864/60000]
loss: 2.211433 [51264/60000]
loss: 2.175991 [57664/60000]
Test Error:
 Accuracy: 40.8%, Avg loss: 2.179781

Epoch 2
-------------------------------
loss: 2.188632 [   64/60000]
loss: 2.178872 [ 6464/60000]
loss: 2.129834 [12864/60000]
loss: 2.141780 [19264/60000]
loss: 2.096759 [25664/60000]
loss: 2.040408 [32064/60000]
loss: 2.073470 [38464/60000]
loss: 2.006684 [44864/60000]
loss: 2.008528 [51264/60000]
loss: 1.934280 [57664/60000]
Test Error:
 Accuracy: 55.2%, Avg loss: 1.942932

Epoch 3
-------------------------------
loss: 1.969567 [   64/60000]
loss: 1.942245 [ 6464/60000]
loss: 1.840353 [12864/60000]
loss: 1.869442 [19264/60000]
loss: 1.765993 [25664/60000]
loss: 1.714671 [32064/60000]
loss: 1.739229 [38464/60000]
loss: 1.650178 [44864/60000]
loss: 1.665230 [51264/60000]
loss: 1.549294 [57664/60000]
Test Error:
 Accuracy: 60.7%, Avg loss: 1.578091

Epoch 4
-------------------------------
loss: 1.641180 [   64/60000]
loss: 1.600254 [ 6464/60000]
loss: 1.461095 [12864/60000]
loss: 1.513520 [19264/60000]
loss: 1.399615 [25664/60000]
loss: 1.393271 [32064/60000]
loss: 1.405040 [38464/60000]
loss: 1.338134 [44864/60000]
loss: 1.362529 [51264/60000]
loss: 1.250924 [57664/60000]
Test Error:
 Accuracy: 62.9%, Avg loss: 1.288266

Epoch 5
-------------------------------
loss: 1.366134 [   64/60000]
loss: 1.340259 [ 6464/60000]
loss: 1.182602 [12864/60000]
loss: 1.268120 [19264/60000]
loss: 1.152727 [25664/60000]
loss: 1.179421 [32064/60000]
loss: 1.195093 [38464/60000]
loss: 1.142752 [44864/60000]
loss: 1.171334 [51264/60000]
loss: 1.077597 [57664/60000]
Test Error:
 Accuracy: 64.5%, Avg loss: 1.109038

Epoch 6
-------------------------------
loss: 1.181630 [   64/60000]
loss: 1.176610 [ 6464/60000]
loss: 1.001204 [12864/60000]
loss: 1.118861 [19264/60000]
loss: 1.002068 [25664/60000]
loss: 1.038551 [32064/60000]
loss: 1.068349 [38464/60000]
loss: 1.020661 [44864/60000]
loss: 1.049896 [51264/60000]
loss: 0.971519 [57664/60000]
Test Error:
 Accuracy: 65.5%, Avg loss: 0.995844

Epoch 7
-------------------------------
loss: 1.056387 [   64/60000]
loss: 1.072315 [ 6464/60000]
loss: 0.879060 [12864/60000]
loss: 1.022047 [19264/60000]
loss: 0.909181 [25664/60000]
loss: 0.940961 [32064/60000]
loss: 0.987317 [38464/60000]
loss: 0.941837 [44864/60000]
loss: 0.966922 [51264/60000]
loss: 0.901267 [57664/60000]
Test Error:
 Accuracy: 66.9%, Avg loss: 0.919692

Epoch 8
-------------------------------
loss: 0.966067 [   64/60000]
loss: 1.000780 [ 6464/60000]
loss: 0.792635 [12864/60000]
loss: 0.954458 [19264/60000]
loss: 0.848325 [25664/60000]
loss: 0.869870 [32064/60000]
loss: 0.931330 [38464/60000]
loss: 0.889131 [44864/60000]
loss: 0.907144 [51264/60000]
loss: 0.851083 [57664/60000]
Test Error:
 Accuracy: 68.0%, Avg loss: 0.865171

Epoch 9
-------------------------------
loss: 0.897707 [   64/60000]
loss: 0.947478 [ 6464/60000]
loss: 0.728783 [12864/60000]
loss: 0.904286 [19264/60000]
loss: 0.805542 [25664/60000]
loss: 0.816085 [32064/60000]
loss: 0.889305 [38464/60000]
loss: 0.852287 [44864/60000]
loss: 0.862027 [51264/60000]
loss: 0.812590 [57664/60000]
Test Error:
 Accuracy: 69.3%, Avg loss: 0.823883

Epoch 10
-------------------------------
loss: 0.843526 [   64/60000]
loss: 0.905047 [ 6464/60000]
loss: 0.679433 [12864/60000]
loss: 0.865432 [19264/60000]
loss: 0.773194 [25664/60000]
loss: 0.774396 [32064/60000]
loss: 0.855781 [38464/60000]
loss: 0.825071 [44864/60000]
loss: 0.826822 [51264/60000]
loss: 0.781875 [57664/60000]
Test Error:
 Accuracy: 70.5%, Avg loss: 0.791268

Done!