理解 PyTorch 中的 CrossEntropyLoss
¶
在机器学习中,特别是处理分类问题时,损失函数是衡量模型预测与实际标签差异的关键。在 PyTorch 中,CrossEntropyLoss
是一个常用的损失函数,用于分类问题。它首先通过 Softmax 函数计算对应类别的概率值,然后计算每个样本的负对数似然损失,最后对所有样本的损失值求平均。
本文将通过一个简单的例子来手动计算CrossEntropyLoss
,并展示如何使用 PyTorch 实现这一过程。
创建样本和标签¶
首先,我们需要创建三个简单的样本的 logits 值和对应的标签:
Python
import torch
# 创建 logits(模型输出的原始分数)
logits = torch.tensor([[0.3, 0.1], [-0.4, 0.7], [-0.1, 0.8]])
print("原始 logits:")
print(logits)
# 创建标签
labels = torch.tensor([0, 0, 1])
print("labels:")
print(labels)
Text Only
原始logits:
tensor([[ 0.3000, 0.1000],
[-0.4000, 0.7000],
[-0.1000, 0.8000]])
labels:
tensor([0, 0, 1])
在这个例子中,我们有三个样本,每个样本对应两个类别的得分。标签表示每个样本的真实类别。
手动实现 Softmax 和负对数似然损失¶
接下来,我们将手动实现 Softmax 函数和负对数似然损失的计算过程:
Python
# 应用 Softmax 函数
exp_logits = torch.exp(logits)
sum_exp_logits = torch.sum(exp_logits, axis=1, keepdim=True)
probabilities = exp_logits / sum_exp_logits
# 打印 Softmax 概率
print("\nSoftmax 概率:")
print(probabilities)
Python
# 初始化损失值累加器
loss_accumulator = 0.0
# 用 for 循环计算每个样本的损失
for i in range(len(labels)):
# 获取第 i 个样本的真实类别的概率
prob = probabilities[i, labels[i]]
# 计算第 i 个样本的负对数似然损失
loss = -torch.log(prob)
# 打印第 i 个样本的损失
print(f"\n样本{i}的损失值:")
print(loss)
# 累加损失
loss_accumulator += loss
# 计算平均损失
loss = loss_accumulator / len(labels)
# 打印平均损失值
print("\n平均损失值:")
print(loss)
Text Only
样本0的损失值:
tensor(0.5981)
样本1的损失值:
tensor(1.3873)
样本2的损失值:
tensor(0.3412)
平均损失值:
tensor(0.7755)
使用 PyTorch 的 CrossEntropyLoss
¶
为了验证我们的手动计算,我们可以使用 PyTorch 的CrossEntropyLoss
来计算损失值,并与我们的手动计算结果进行比较:
Python
# 使用 PyTorch 的 CrossEntropyLoss
criterion = torch.nn.CrossEntropyLoss()
torch_loss = criterion(logits, labels)
print(f"PyTorch 计算的损失值:{torch_loss}")
可以发现,与手动计算的结果一致。
总结¶
PyTorch 的CrossEntropyLoss
首先通过 Softmax 函数计算对应类别的概率值,然后计算每个样本的负对数似然损失,最后对所有样本的损失值求平均。