Pytorch深度学习者必备技能教程四_第1页
Pytorch深度学习者必备技能教程四_第2页
Pytorch深度学习者必备技能教程四_第3页
全文预览已结束

下载本文档

版权说明:本文档由用户提供并上传,收益归属内容提供方,若内容存在侵权,请进行举报或认领

文档简介

1、通过Pytorch搭建的神经网络来描述一个简单的分类问题 搭建思路描述:? Torch生成数据集、预处理?搭建神经网络?损失函数与优化函数选择?模型训练以及图像绘制?结果可视化展示搭建神经网络1 .导入需要的包与模块:上一篇文章我们已经讲了导包注意的问题了,现在直接导入搭建该神经 网络所需要的包即可from torch.nn import functional as Ffrom torch.autogradimport Variablefrom matplotlib import pyplot as pltfrom torch.nn import Linear, CrossEntropyLos

2、sfrom torch.optim import SGDimport torch.nn as nnimport torch2 .Torch生成数据集、预处理:今天搭建的神经网络主要是要对提供数据集进行分类,所以这里需要构 造两种类型的数据。n_data = torch.ones(100, 2)x_0 = torch.normal(2 * n_data, 1) #第一个参数:平均值 第二个参数:方差y_0 = torch.zeros(100)x_1 = torch.normal(-2 * n_data, 1)y_1 = torch.ones(100)x = torch.cat(x_0, x_1)

3、, O).type(torch.FloatTensor) #第一个参数:输入数据第二个参数:指定隹足y = torch.cat(y_0, y_1), ).type(torch.FloatTensor)x, y = Variable(X) , Variable(Y)3 .搭建神经网络:搭建一个全连接神经网络,函数的选择原因和上篇文章一样。在_init_函数里搭建一个隐藏层与输出层。forward 函数主要负责前向传播,并得到神经网络的输出值classNet(Module):def_init_(self, n_feature, n_hidden, n_output):super(Net, self

4、)._init_()self.hidden = Linear(n_feature, n_hidden)self.pred = Linear(n_hidden, n_output)def forward (self, x):x = F.relu(self.hidden(x)x = self.pred(x)return xnet = Net(n_feature=2, n_hidden=10, n_output=2)4 .损失函数与优化函数选择:在机器学习中(特别是分类模型),模型训练时,通常都是使用交叉嫡 (Cross-Entropy )作为损失进行最小化对于分类问题,我们一般选用交叉 嫡作为损失

5、函数。优化函数的选择,和上篇文章讲的同理,这里就选SGD即可满足要求。optimizer= SGD(net.parameters(), lr=0.02)loss_func= CrossEntropyLoss()5 .模型训练以及图像绘制此处我们对模型训练1000次,一次来保证最终的参数符合我们的要求,并将最终的预测结果展示出来。plt.ion()for step in range(1000):y_pred = net (x)loss = loss_func(y_pred, y)optimizer.zero_grad()loss.backward()optimizer.step()if step

6、 % 5 = 0:plt.cla()pred = torch.max(y_pred, 1)1 #返回行最大值对应的列索引pred_y = pred.data.numpy().sequeeze()target_y = y.data.numpy()plt.scatter(x.data.numpy:, 0, x.data.numpy():, 1, c=pred_y, s=100, lw=0, cmap="RdYlGn")accuracy = float(pred_y = target_y).astype(int).sum() / float(target_y.size) # # 求预测值与真实值之间的田柳一一plt.text(1.5, -4, "Accuracy=%.2f" % accuracy,fontdict="size": 20, "color": "red")plt.pause(0.1)plt.ioff()plt.savefig("自己保存图片的路径")plt.show()6 .结果可视化展示:-4-202由绘制的图可以看出来,对于我们自己生成的数据,神经网

温馨提示

  • 1. 本站所有资源如无特殊说明,都需要本地电脑安装OFFICE2007和PDF阅读器。图纸软件为CAD,CAXA,PROE,UG,SolidWorks等.压缩文件请下载最新的WinRAR软件解压。
  • 2. 本站的文档不包含任何第三方提供的附件图纸等,如果需要附件,请联系上传者。文件的所有权益归上传用户所有。
  • 3. 本站RAR压缩包中若带图纸,网页内容里面会有图纸预览,若没有图纸预览就没有图纸。
  • 4. 未经权益所有人同意不得将文件中的内容挪作商业或盈利用途。
  • 5. 人人文库网仅提供信息存储空间,仅对用户上传内容的表现方式做保护处理,对用户上传分享的文档内容本身不做任何修改或编辑,并不能对任何下载内容负责。
  • 6. 下载文件中如有侵权或不适当内容,请与我们联系,我们立即纠正。
  • 7. 本站不保证下载资源的准确性、安全性和完整性, 同时也不承担用户因使用这些下载资源对自己和他人造成任何形式的伤害或损失。

评论

0/150

提交评论