matplotlib
动态绘图——神经网络训练过程可视化¶
本文使用 matplotlib
实现动态绘图,可以用于查看神经网络训练过程的损失值和评估指标的变化情况。
本文部分代码参考了《动手学深度学习》的
utils.py
中的函数。
导入包¶
Python
import numpy as np
from IPython import display
from matplotlib import pyplot as plt
from matplotlib_inline import backend_inline
定义辅助函数和 Animator 类¶
Python
def use_svg_display():
"""Use the svg format to display a plot in Jupyter."""
backend_inline.set_matplotlib_formats("svg")
def set_axes(axes, xlabel, ylabel, xlim, ylim, xscale, yscale, legend):
"""Set the axes for matplotlib."""
axes.set_xlabel(xlabel), axes.set_ylabel(ylabel)
axes.set_xscale(xscale), axes.set_yscale(yscale)
axes.set_xlim(xlim), axes.set_ylim(ylim)
if legend:
axes.legend(legend)
axes.grid()
class Animator:
"""For plotting data in animation."""
def __init__(
self,
xlabels=[None, None],
ylabels=[None, None],
legends=[None, None],
xlims=[None, None],
ylims=[None, None],
xscales=["linear", "linear"],
yscales=["linear", "linear"],
fmts=["c--", "m", "g--", "r"],
nrows=1,
ncols=2,
figsize=(10, 4),
):
# 以 svg 矢量图格式显示
use_svg_display()
# Incrementally plot multiple lines
if legends is None:
legends = [[], []]
self.fig, self.axes = plt.subplots(nrows, ncols, figsize=figsize)
# Use a lambda function to capture arguments
self.config_axes = lambda: (
set_axes(
self.axes[0],
xlabels[0],
ylabels[0],
xlims[0],
ylims[0],
xscales[0],
yscales[0],
legends[0],
),
set_axes(
self.axes[1],
xlabels[1],
ylabels[1],
xlims[1],
ylims[1],
xscales[1],
yscales[1],
legends[1],
),
)
self.X, self.Y, self.fmts = None, None, fmts
def add(self, x, y):
# Add multiple data points into the figure
if not hasattr(y, "__len__"):
y = [y]
n = len(y)
if not hasattr(x, "__len__"):
x = [x] * n
if not self.X:
self.X = [[] for _ in range(n)]
if not self.Y:
self.Y = [[] for _ in range(n)]
# 添加新数据
for i, (a, b) in enumerate(zip(x, y)):
if a is not None and b is not None:
self.X[i].append(a)
self.Y[i].append(b)
# 清除坐标轴上的所有内容
self.axes[0].cla()
self.axes[1].cla()
# 绘制图像
for i, (x, y, fmt) in enumerate(zip(self.X, self.Y, self.fmts)):
if i < 2:
self.axes[0].plot(x, y, fmt)
else:
self.axes[1].plot(x, y, fmt)
# 配置图形参数
self.config_axes()
# 显示图像
display.display(self.fig)
# 当输出内容有更新时,则将旧的输出内容全部清除
display.clear_output(wait=True)
创建示例数据,演示动态绘图过程¶
Python
n_epochs = 30
# 定义可视化实例
animator = Animator(
xlabels=["epoch", "epoch"],
legends=[["train loss", "valid loss"], ["train metric", "valid metric"]],
xlims=[[1, n_epochs], [1, n_epochs]],
ylims=[[0, 20], [0, 20]],
)
for step in range(1, n_epochs + 1):
train_loss, train_score = (-3 * np.log(step) + 15, 3 * np.log(step))
val_loss, val_score = (-2 * np.log(step) + 15, 2 * np.log(step))
# 动态地绘制损失值和评估指标的折线图
animator.add(step, (train_loss, val_loss, train_score, val_score))