您现在的位置是:首页 > 技术教程 正文

补充d2l.torch库里面缺失train_ch3函数

admin 阅读: 2024-03-21
后台-插件-广告管理-内容页头部广告(手机)

在最新版本1.0.3,上 遇到d2l.torch库里面缺失train_ch3函数,下面是个人写的替代补充函数可以完全平替。
所有函数都放在util.py文件中

import torch.nn from d2l import torch as d2l from IPython import display class Accumulator: """ 在n个变量上累加 """ def __init__(self, n): self.data = [0.0] * n # 创建一个长度为 n 的列表,初始化所有元素为0.0。 def add(self, *args): # 累加 self.data = [a + float(b) for a, b in zip(self.data, args)] def reset(self): # 重置累加器的状态,将所有元素重置为0.0 self.data = [0.0] * len(self.data) def __getitem__(self, idx): # 获取所有数据 return self.data[idx] def accuracy(y_hat, y): """ 计算正确的数量 :param y_hat: :param y: :return: """ if len(y_hat.shape) > 1 and y_hat.shape[1] > 1: y_hat = y_hat.argmax(axis=1) # 在每行中找到最大值的索引,以确定每个样本的预测类别 cmp = y_hat.type(y.dtype) == y return float(cmp.type(y.dtype).sum()) def evaluate_accuracy(net, data_iter): """ 计算指定数据集的精度 :param net: :param data_iter: :return: """ if isinstance(net, torch.nn.Module): net.eval() # 通常会关闭一些在训练时启用的行为 metric = Accumulator(2) with torch.no_grad(): for X, y in data_iter: metric.add(accuracy(net(X), y), y.numel()) return metric[0] / metric[1] class Animator: """ 在动画中绘制数据 """ def __init__(self, xlabel=None, ylabel=None, legend=None, xlim=None, ylim=None, xscale='linear', yscale='linear', fmts=('-', 'm--', 'g-', 'r:'), nrows=1, ncols=1, figsize=(3.5, 2.5)): # 增量的绘制多条线 if legend is None: legend = [] d2l.use_svg_display() self.fig, self.axes = d2l.plt.subplots(nrows, ncols, figsize=figsize) if nrows * ncols == 1: self.axes = [self.axes, ] # 使用lambda函数捕获参数 self.config_axes = lambda: d2l.set_axes( self.axes[0], xlabel, ylabel, xlim, ylim, xscale, yscale, legend ) self.X, self.Y, self.fmts = None, None, fmts def add(self, x, y): """ 向图表中添加多个数据点 :param x: :param y: :return: """ if not hasattr(y, "__len__"): y = [y] n = len(y) if not hasattr(x, "__len__"): x = [x] * n if not self.X: self.X = [[] for _ in range(n)] if not self.Y: self.Y = [[] for _ in range(n)] for i, (a, b) in enumerate(zip(x, y)): if a is not None and b is not None: self.X[i].append(a) self.Y[i].append(b) self.axes[0].cla() for x, y, fmt in zip(self.X, self.Y, self.fmts): self.axes[0].plot(x, y, fmt) self.config_axes() display.display(self.fig) display.clear_output(wait=True) def train_epoch_ch3(net, train_iter, loss, updater): """ 训练模型一轮 :param net:是要训练的神经网络模型 :param train_iter:是训练数据的数据迭代器,用于遍历训练数据集 :param loss:是用于计算损失的损失函数 :param updater:是用于更新模型参数的优化器 :return: """ if isinstance(net, torch.nn.Module): # 用于检查一个对象是否属于指定的类(或类的子类)或数据类型。 net.train() # 训练损失总和, 训练准确总和, 样本数 metric = Accumulator(3) for X, y in train_iter: # 计算梯度并更新参数 y_hat = net(X) l = loss(y_hat, y) if isinstance(updater, torch.optim.Optimizer): # 用于检查一个对象是否属于指定的类(或类的子类)或数据类型。 # 使用pytorch内置的优化器和损失函数 updater.zero_grad() l.mean().backward() # 方法用于计算损失的平均值 updater.step() else: # 使用定制(自定义)的优化器和损失函数 l.sum().backward() updater(X.shape()) metric.add(float(l.sum()), accuracy(y_hat, y), y.numel()) # 返回训练损失和训练精度 return metric[0] / metric[2], metric[1] / metric[2] def train_ch3(net, train_iter, test_iter, loss, num_epochs, updater): """ 训练模型() :param net: :param train_iter: :param test_iter: :param loss: :param num_epochs: :param updater: :return: """ animator = Animator(xlabel='epoch', xlim=[1, num_epochs], ylim=[0.3, 0.9], legend=['train loss', 'train acc', 'test acc']) for epoch in range(num_epochs): trans_metrics = train_epoch_ch3(net, train_iter, loss, updater) test_acc = evaluate_accuracy(net, test_iter) animator.add(epoch + 1, trans_metrics + (test_acc,)) train_loss, train_acc = trans_metrics print(trans_metrics) def predict_ch3(net, test_iter, n=6): """ 进行预测 :param net: :param test_iter: :param n: :return: """ global X, y for X, y in test_iter: break trues = d2l.get_fashion_mnist_labels(y) preds = d2l.get_fashion_mnist_labels(net(X).argmax(axis=1)) titles = [true + "\n" + pred for true, pred in zip(trues, preds)] d2l.show_images( X[0:n].reshape((n, 28, 28)), 1, n, titles=titles[0:n] ) d2l.plt.show()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127
  • 128
  • 129
  • 130
  • 131
  • 132
  • 133
  • 134
  • 135
  • 136
  • 137
  • 138
  • 139
  • 140
  • 141
  • 142
  • 143
  • 144
  • 145
  • 146
  • 147
  • 148
  • 149
  • 150
  • 151
  • 152
  • 153
  • 154
  • 155
  • 156
  • 157
  • 158
  • 159
  • 160
  • 161
  • 162
  • 163
  • 164
  • 165
  • 166
  • 167
  • 168
  • 169
  • 170
  • 171
  • 172
  • 173

直接调用即可
在这里插入图片描述
找个位置放就行
在这里插入图片描述

标签:
声明

1.本站遵循行业规范,任何转载的稿件都会明确标注作者和来源;2.本站的原创文章,请转载时务必注明文章作者和来源,不尊重原创的行为我们将追究责任;3.作者投稿可能会经我们编辑修改或补充。

在线投稿:投稿 站长QQ:1888636

后台-插件-广告管理-内容页尾部广告(手机)
关注我们

扫一扫关注我们,了解最新精彩内容

搜索
排行榜