版权说明:本文档由用户提供并上传,收益归属内容提供方,若内容存在侵权,请进行举报或认领
文档简介
第10章航班乘客数预测第10章航班乘客数预测10.1PyTorch简介10.2安装PyTorch10.3导入相关库10.4PyTorch基础知识10.5读取数据10.6数据预处理10.7定义网络模型10.8定义损失函数和优化器10.9训练模型10.10测试模型第10章航班乘客数预测10.1PyTorch简介PyTorch是由Facebook开发,基于Torch开发,从并不常用的Lua语言转为Python语言开发的深度学习框架,可以用于构建深度神经网络。Pytorch是一个基于Python的科学计算库,它面向以下两种人群:希望将其代替Numpy来利用GPUs的威力;一个可以提供更加灵活和快速的深度学习研究平台。第10章航班乘客数预测10.2安装PyTorchPyTorch的安装可以直接查看官网教程,如下所示,官网地址:/get-started/locally/第10章航班乘客数预测10.2安装PyTorch第10章航班乘客数预测10.3导入相关库import
torchimport
torch.nn
as
nn
import
numpy
as
npimport
matplotlib.pyplot
as
pltplt.rcParams['font.sans-serif']=['simsun']
#设置加载的字体名plt.rcParams['axes.unicode_minus']=False
#解决保存图像是负号'-'显示为方块的问题第10章航班乘客数预测10.4PyTorch基础知识10.4.1张量(1)创建一个张量x=torch.Tensor([1,2,3])
#创建一个1维张量y=torch.Tensor([[1,2],[3,4]])
#创建一个2维张量z=torch.Tensor([[[1,2],[3,4]],[[5,6],[7,8]]])
#创建一个3维张量xyz输出结果:第10章航班乘客数预测10.4PyTorch基础知识10.4.1张量(2)张量的形状z.shape
#获取张量的形状z.size()
#获取张量的形状z.view(2,4)
#改变张量的形状,2行,4列z.reshape(1,8)
#改变张量的形状,1行,8列z.resize_(2,4)
#直接修改原始张量的形状,2行,4列输出结果:第10章航班乘客数预测10.4PyTorch基础知识10.4.2自动微分PyTorch提供了自动微分功能,可以自动计算梯度,这使得模型训练更加容易。我们使用torch.tensor()来定义张量,然后使用.backward()函数计算梯度。x=torch.tensor(2.0,requires_grad=True)#定义张量x,并将requires_grad设置为True,以便PyTorch跟踪它的计算历史y=x**2
#定义新的张量y,它是x的平方y.backward()#调用y.backward()来计算y相对于x的导数x.grad
#打印出结果为tensor(4.)第10章航班乘客数预测10.4PyTorch基础知识10.4.3神经网络PyTorch提供了torch.nn模块,可以帮助开发者更轻松地构建和训练神经网络模型。可以使用torch.nn.Module()类定义神经网络模型,然后使用torch.optim优化器进行训练。10.4.4数据加载PyTorch提供了torch.utils.data模块,可以帮助开发者更轻松地加载和处理数据。可以使用torch.utils.data.Dataset()类定义数据集,然后使用torch.utils.data,DataLoader()函数加载数据。10.4.5GPU加速PyTorch可以使用GPU加速,可以使用torch.cuda模块将张量和模型移动到GPU上运行。第10章航班乘客数预测10.5读取数据with
open("data\international-airline-passengers.csv","r",encoding="utf-8")asf:
next(f)#跳过第1行
data_csv=f.read()
#将文件内容读取到变量data中data=[row.split(',')forrowin
data_csv.split("\n")]#将字符串变量data_csv中的每一行按逗号分隔并返回一个列表。这个列表包含了每一行的元素。passengers=[int(each[1])foreachindata]#将列表变量data中的每个元素的第二个字符转换为整数并返回一个新的列表。Passengers#打印前10个月中每月的航班乘客数输出结果:第10章航班乘客数预测10.6数据预处理接下来,我们首先使用滑动窗口方法创建基于航班乘客数的时间序列数据。然后,将序列数据转换成满足模型输入要求的训练数据集和测试数据集。这样,我们就可以使用前2天的航班乘客数来预测第3天的航班乘客数。第10章航班乘客数预测10.7定义网络模型class
Net(nn.Module):
#初始化函数,定义网络结构
def
__init__(self):
#调用父类的初始化函数
super(Net,self).__init__()
#定义一个LSTM层,输入特征为1(只有乘客数),隐藏状态大小为32,层数为1,batch_first为True
self.lstm=nn.LSTM(input_size=1,hidden_size=32,num_layers=1,batch_first=True)
#定义一个线性层,将32*seq_len个输入特征映射到1个输出特征(预测下一月乘客数)
self.linear=nn.Linear(32*seq_len,1)
#前向传播函数
def
forward(self,input):
#将输入input输入到LSTM层中,得到输出结果lstm_out,隐藏状态h和单元状态c
lstm_out,(h,c)=self.lstm(input)
#将lstm_out进行reshape,变成一个形状为(-1,32*seq_len)的张量
x=lstm_out.reshape(-1,32*seq_len)
#将x输入到线性层中,得到输出pred
pred=self.linear(x)
#返回输出pred
return
pred第10章航班乘客数预测10.8定义损失函数和优化器model=Net()#定义一个Adam优化器,用于更新模型参数,学习率为0.003optimizer=torch.optim.Adam(model.parameters(),lr=0.003)#定义一个均方误差损失函数,用于计算模型预测值与真实值之间的误差loss_fun=nn.MSELoss()第10章航班乘客数预测10.9训练模型#将模型设置为训练模式model.train()#进行300轮训练for
epoch
in
range(300):
#将训练数据train_x输入到模型中,得到模型的输出output
output=model(train_x)
#计算模型输出output与训练标签train_y之间的均方误差损失
loss=loss_fun(output,train_y)
#将优化器的梯度清零
optimizer.zero_grad()
#反向传播计算梯度
loss.backward()
#使用优化器更新模型参数
optimizer.step()
#每20轮输出一次训练损失和测试损失
if
epoch%20==0
and
epoch>0:
#将测试数据test_x输入到模型中,得到模型的输出output
#计算模型输出output与测试标签test_y之间的均方误差损失
test_loss=loss_fun(model(test_x),test_y)
#输出当前轮数、训练损失和测试损失
print("epoch:{},loss:{},test_loss:{}".format(epoch,loss,test_loss))第10章航班乘客数预测10.9训练模型第10章航班乘客数预测10.10测试模型第10章航班乘客数预测10.10测试模型#将模型设置为评估模式model.eval()#构造预测结果result=X[0][:seq_len-1]
温馨提示
- 1. 本站所有资源如无特殊说明,都需要本地电脑安装OFFICE2007和PDF阅读器。图纸软件为CAD,CAXA,PROE,UG,SolidWorks等.压缩文件请下载最新的WinRAR软件解压。
- 2. 本站的文档不包含任何第三方提供的附件图纸等,如果需要附件,请联系上传者。文件的所有权益归上传用户所有。
- 3. 本站RAR压缩包中若带图纸,网页内容里面会有图纸预览,若没有图纸预览就没有图纸。
- 4. 未经权益所有人同意不得将文件中的内容挪作商业或盈利用途。
- 5. 人人文库网仅提供信息存储空间,仅对用户上传内容的表现方式做保护处理,对用户上传分享的文档内容本身不做任何修改或编辑,并不能对任何下载内容负责。
- 6. 下载文件中如有侵权或不适当内容,请与我们联系,我们立即纠正。
- 7. 本站不保证下载资源的准确性、安全性和完整性, 同时也不承担用户因使用这些下载资源对自己和他人造成任何形式的伤害或损失。
最新文档
- 2026年兴业银行南昌分行社会招聘参考考试试题附答案解析
- 2026河北雄安人才服务有限公司商业招商岗招聘1人参考考试题库附答案解析
- 2026青海西宁市应急管理局招聘安全生产实操考评员备考考试题库附答案解析
- 2026山东临沂市市直部分医疗卫生事业单位招聘医疗后勤岗位工作人员9人参考考试试题附答案解析
- 2026中建三局三公司校园招聘备考考试试题附答案解析
- 2026西藏山南加查县文旅局公益性岗位招聘1人备考考试题库附答案解析
- 2026上半年云南事业单位联考特殊教育职业学院招聘6人备考考试试题附答案解析
- 2026年保山市昌宁县财政局招聘公益性岗位人员(5人)参考考试题库附答案解析
- 广电局安全生产制度
- 学习生产车间管理制度
- GB/T 9706.266-2025医用电气设备第2-66部分:助听器及助听器系统的基本安全和基本性能专用要求
- 2026年企业级云服务器采购合同
- 2026广西桂林医科大学人才招聘27人备考题库(第一批)及参考答案详解一套
- 2026年度黑龙江省生态环境厅所属事业单位公开招聘工作人员57人备考题库及答案详解一套
- 2025安徽省中煤三建国际公司机关工作人员内部竞聘31人笔试历年参考题库附带答案详解
- 2026国家国防科技工业局所属事业单位第一批招聘62人笔试参考题库及答案解析
- 北京2025年北京教育科学研究院公开招聘笔试历年参考题库附带答案详解
- 2025至2030中国谷氨酸和味精行业深度研究及发展前景投资评估分析
- 人教版高二化学上册期末真题试题题库试题附答案完整版
- 生产样品合同范本
- 2025职业技能培训学校自查报告范文(3篇)
评论
0/150
提交评论