跳转至

理解 PyTorch 中的 CrossEntropyLoss

在机器学习中,特别是处理分类问题时,损失函数是衡量模型预测与实际标签差异的关键。在 PyTorch 中,CrossEntropyLoss是一个常用的损失函数,用于分类问题。它首先通过 Softmax 函数计算对应类别的概率值,然后计算每个样本的负对数似然损失,最后对所有样本的损失值求平均。

本文将通过一个简单的例子来手动计算CrossEntropyLoss,并展示如何使用 PyTorch 实现这一过程。

创建样本和标签

首先,我们需要创建三个简单的样本的 logits 值和对应的标签:

Python
import torch

# 创建 logits(模型输出的原始分数)
logits = torch.tensor([[0.3, 0.1], [-0.4, 0.7], [-0.1, 0.8]])

print("原始 logits:")
print(logits)

# 创建标签
labels = torch.tensor([0, 0, 1])

print("labels:")
print(labels)
Text Only
原始logits:
tensor([[ 0.3000,  0.1000],
        [-0.4000,  0.7000],
        [-0.1000,  0.8000]])

labels:
tensor([0, 0, 1])

在这个例子中,我们有三个样本,每个样本对应两个类别的得分。标签表示每个样本的真实类别。

手动实现 Softmax 和负对数似然损失

接下来,我们将手动实现 Softmax 函数和负对数似然损失的计算过程:

Python
# 应用 Softmax 函数
exp_logits = torch.exp(logits)
sum_exp_logits = torch.sum(exp_logits, axis=1, keepdim=True)
probabilities = exp_logits / sum_exp_logits

# 打印 Softmax 概率
print("\nSoftmax 概率:")
print(probabilities)
Text Only
Softmax概率:
tensor([[0.5498, 0.4502],
        [0.2497, 0.7503],
        [0.2891, 0.7109]])
Python
# 初始化损失值累加器
loss_accumulator = 0.0

# 用 for 循环计算每个样本的损失
for i in range(len(labels)):
    # 获取第 i 个样本的真实类别的概率
    prob = probabilities[i, labels[i]]
    # 计算第 i 个样本的负对数似然损失
    loss = -torch.log(prob)
    # 打印第 i 个样本的损失
    print(f"\n样本{i}的损失值:")
    print(loss)
    # 累加损失
    loss_accumulator += loss

# 计算平均损失
loss = loss_accumulator / len(labels)

# 打印平均损失值
print("\n平均损失值:")
print(loss)
Text Only
样本0的损失值:
tensor(0.5981)

样本1的损失值:
tensor(1.3873)

样本2的损失值:
tensor(0.3412)

平均损失值:
tensor(0.7755)

使用 PyTorch 的 CrossEntropyLoss

为了验证我们的手动计算,我们可以使用 PyTorch 的CrossEntropyLoss来计算损失值,并与我们的手动计算结果进行比较:

Python
# 使用 PyTorch 的 CrossEntropyLoss
criterion = torch.nn.CrossEntropyLoss()
torch_loss = criterion(logits, labels)

print(f"PyTorch 计算的损失值:{torch_loss}")
Text Only
PyTorch计算的损失值: 0.775542676448822

可以发现,与手动计算的结果一致。

总结

PyTorch 的CrossEntropyLoss首先通过 Softmax 函数计算对应类别的概率值,然后计算每个样本的负对数似然损失,最后对所有样本的损失值求平均。

评论