版权说明:本文档由用户提供并上传,收益归属内容提供方,若内容存在侵权,请进行举报或认领
文档简介
第10章航班乘客数预测第10章航班乘客数预测10.1PyTorch简介10.2安装PyTorch10.3导入相关库10.4PyTorch基础知识10.5读取数据10.6数据预处理10.7定义网络模型10.8定义损失函数和优化器10.9训练模型10.10测试模型第10章航班乘客数预测10.1PyTorch简介PyTorch是由Facebook开发,基于Torch开发,从并不常用的Lua语言转为Python语言开发的深度学习框架,可以用于构建深度神经网络。Pytorch是一个基于Python的科学计算库,它面向以下两种人群:希望将其代替Numpy来利用GPUs的威力;一个可以提供更加灵活和快速的深度学习研究平台。第10章航班乘客数预测10.2安装PyTorchPyTorch的安装可以直接查看官网教程,如下所示,官网地址:/get-started/locally/第10章航班乘客数预测10.2安装PyTorch第10章航班乘客数预测10.3导入相关库import
torchimport
torch.nn
as
nn
import
numpy
as
npimport
matplotlib.pyplot
as
pltplt.rcParams['font.sans-serif']=['simsun']
#设置加载的字体名plt.rcParams['axes.unicode_minus']=False
#解决保存图像是负号'-'显示为方块的问题第10章航班乘客数预测10.4PyTorch基础知识10.4.1张量(1)创建一个张量x=torch.Tensor([1,2,3])
#创建一个1维张量y=torch.Tensor([[1,2],[3,4]])
#创建一个2维张量z=torch.Tensor([[[1,2],[3,4]],[[5,6],[7,8]]])
#创建一个3维张量xyz输出结果:第10章航班乘客数预测10.4PyTorch基础知识10.4.1张量(2)张量的形状z.shape
#获取张量的形状z.size()
#获取张量的形状z.view(2,4)
#改变张量的形状,2行,4列z.reshape(1,8)
#改变张量的形状,1行,8列z.resize_(2,4)
#直接修改原始张量的形状,2行,4列输出结果:第10章航班乘客数预测10.4PyTorch基础知识10.4.2自动微分PyTorch提供了自动微分功能,可以自动计算梯度,这使得模型训练更加容易。我们使用torch.tensor()来定义张量,然后使用.backward()函数计算梯度。x=torch.tensor(2.0,requires_grad=True)#定义张量x,并将requires_grad设置为True,以便PyTorch跟踪它的计算历史y=x**2
#定义新的张量y,它是x的平方y.backward()#调用y.backward()来计算y相对于x的导数x.grad
#打印出结果为tensor(4.)第10章航班乘客数预测10.4PyTorch基础知识10.4.3神经网络PyTorch提供了torch.nn模块,可以帮助开发者更轻松地构建和训练神经网络模型。可以使用torch.nn.Module()类定义神经网络模型,然后使用torch.optim优化器进行训练。10.4.4数据加载PyTorch提供了torch.utils.data模块,可以帮助开发者更轻松地加载和处理数据。可以使用torch.utils.data.Dataset()类定义数据集,然后使用torch.utils.data,DataLoader()函数加载数据。10.4.5GPU加速PyTorch可以使用GPU加速,可以使用torch.cuda模块将张量和模型移动到GPU上运行。第10章航班乘客数预测10.5读取数据with
open("data\international-airline-passengers.csv","r",encoding="utf-8")asf:
next(f)#跳过第1行
data_csv=f.read()
#将文件内容读取到变量data中data=[row.split(',')forrowin
data_csv.split("\n")]#将字符串变量data_csv中的每一行按逗号分隔并返回一个列表。这个列表包含了每一行的元素。passengers=[int(each[1])foreachindata]#将列表变量data中的每个元素的第二个字符转换为整数并返回一个新的列表。Passengers#打印前10个月中每月的航班乘客数输出结果:第10章航班乘客数预测10.6数据预处理接下来,我们首先使用滑动窗口方法创建基于航班乘客数的时间序列数据。然后,将序列数据转换成满足模型输入要求的训练数据集和测试数据集。这样,我们就可以使用前2天的航班乘客数来预测第3天的航班乘客数。第10章航班乘客数预测10.7定义网络模型class
Net(nn.Module):
#初始化函数,定义网络结构
def
__init__(self):
#调用父类的初始化函数
super(Net,self).__init__()
#定义一个LSTM层,输入特征为1(只有乘客数),隐藏状态大小为32,层数为1,batch_first为True
self.lstm=nn.LSTM(input_size=1,hidden_size=32,num_layers=1,batch_first=True)
#定义一个线性层,将32*seq_len个输入特征映射到1个输出特征(预测下一月乘客数)
self.linear=nn.Linear(32*seq_len,1)
#前向传播函数
def
forward(self,input):
#将输入input输入到LSTM层中,得到输出结果lstm_out,隐藏状态h和单元状态c
lstm_out,(h,c)=self.lstm(input)
#将lstm_out进行reshape,变成一个形状为(-1,32*seq_len)的张量
x=lstm_out.reshape(-1,32*seq_len)
#将x输入到线性层中,得到输出pred
pred=self.linear(x)
#返回输出pred
return
pred第10章航班乘客数预测10.8定义损失函数和优化器model=Net()#定义一个Adam优化器,用于更新模型参数,学习率为0.003optimizer=torch.optim.Adam(model.parameters(),lr=0.003)#定义一个均方误差损失函数,用于计算模型预测值与真实值之间的误差loss_fun=nn.MSELoss()第10章航班乘客数预测10.9训练模型#将模型设置为训练模式model.train()#进行300轮训练for
epoch
in
range(300):
#将训练数据train_x输入到模型中,得到模型的输出output
output=model(train_x)
#计算模型输出output与训练标签train_y之间的均方误差损失
loss=loss_fun(output,train_y)
#将优化器的梯度清零
optimizer.zero_grad()
#反向传播计算梯度
loss.backward()
#使用优化器更新模型参数
optimizer.step()
#每20轮输出一次训练损失和测试损失
if
epoch%20==0
and
epoch>0:
#将测试数据test_x输入到模型中,得到模型的输出output
#计算模型输出output与测试标签test_y之间的均方误差损失
test_loss=loss_fun(model(test_x),test_y)
#输出当前轮数、训练损失和测试损失
print("epoch:{},loss:{},test_loss:{}".format(epoch,loss,test_loss))第10章航班乘客数预测10.9训练模型第10章航班乘客数预测10.10测试模型第10章航班乘客数预测10.10测试模型#将模型设置为评估模式model.eval()#构造预测结果result=X[0][:seq_len-1]
温馨提示
- 1. 本站所有资源如无特殊说明,都需要本地电脑安装OFFICE2007和PDF阅读器。图纸软件为CAD,CAXA,PROE,UG,SolidWorks等.压缩文件请下载最新的WinRAR软件解压。
- 2. 本站的文档不包含任何第三方提供的附件图纸等,如果需要附件,请联系上传者。文件的所有权益归上传用户所有。
- 3. 本站RAR压缩包中若带图纸,网页内容里面会有图纸预览,若没有图纸预览就没有图纸。
- 4. 未经权益所有人同意不得将文件中的内容挪作商业或盈利用途。
- 5. 人人文库网仅提供信息存储空间,仅对用户上传内容的表现方式做保护处理,对用户上传分享的文档内容本身不做任何修改或编辑,并不能对任何下载内容负责。
- 6. 下载文件中如有侵权或不适当内容,请与我们联系,我们立即纠正。
- 7. 本站不保证下载资源的准确性、安全性和完整性, 同时也不承担用户因使用这些下载资源对自己和他人造成任何形式的伤害或损失。
最新文档
- 新能源技术研发进展报告
- 个人网站搭建与维护实战指南
- 客户服务流程优化商洽通知函(6篇范文)
- 健康成长课:关注身心健康茁壮成长小学主题班会课件
- 预防校园欺凌筑就友善成长小学低年级主题班会课件
- 朔黄铁路小半径曲线改造工程施工2标混凝土搅拌站项目水土保持方案报告表
- 远离危险行为守护安全校园小学主题班会课件
- 2026年节日活动商业合作函4篇
- 重要客户信息泄露事情责任追究与修复预案
- 供应商交货地址变更商洽函5篇
- 2026银行遴选面试题及答案
- 2026乌鲁木齐城市轨道集团招聘(191人)笔试参考题库及答案详解
- 厂房设备搬迁改造项目合同文本
- 华中科技大学2026年强基计划校考(面试+体育测试)模拟试题及答案解析
- 2026年人教版高一第二学期地理期末普通高中统考试卷(附答案可下载)
- 2026贵州毕节黔西市粮油购销有限公司面向社会公开招聘工作人员3人考试模拟试题及答案详解
- 华为BTS3900基站维护手册
- 某塑料包装厂质量管理体系细则
- 四川省成都市高新区2024-2025学年七下期末数学试卷(原卷版)
- 2026年职业病防治知识考试试题(含答案)
- 2026年国家能源集团河南公司校园招聘笔试参考题库及答案解析
评论
0/150
提交评论