版权说明:本文档由用户提供并上传,收益归属内容提供方,若内容存在侵权,请进行举报或认领
文档简介
第10章航班乘客数预测第10章航班乘客数预测10.1PyTorch简介10.2安装PyTorch10.3导入相关库10.4PyTorch基础知识10.5读取数据10.6数据预处理10.7定义网络模型10.8定义损失函数和优化器10.9训练模型10.10测试模型第10章航班乘客数预测10.1PyTorch简介PyTorch是由Facebook开发,基于Torch开发,从并不常用的Lua语言转为Python语言开发的深度学习框架,可以用于构建深度神经网络。Pytorch是一个基于Python的科学计算库,它面向以下两种人群:希望将其代替Numpy来利用GPUs的威力;一个可以提供更加灵活和快速的深度学习研究平台。第10章航班乘客数预测10.2安装PyTorchPyTorch的安装可以直接查看官网教程,如下所示,官网地址:/get-started/locally/第10章航班乘客数预测10.2安装PyTorch第10章航班乘客数预测10.3导入相关库import
torchimport
torch.nn
as
nn
import
numpy
as
npimport
matplotlib.pyplot
as
pltplt.rcParams['font.sans-serif']=['simsun']
#设置加载的字体名plt.rcParams['axes.unicode_minus']=False
#解决保存图像是负号'-'显示为方块的问题第10章航班乘客数预测10.4PyTorch基础知识10.4.1张量(1)创建一个张量x=torch.Tensor([1,2,3])
#创建一个1维张量y=torch.Tensor([[1,2],[3,4]])
#创建一个2维张量z=torch.Tensor([[[1,2],[3,4]],[[5,6],[7,8]]])
#创建一个3维张量xyz输出结果:第10章航班乘客数预测10.4PyTorch基础知识10.4.1张量(2)张量的形状z.shape
#获取张量的形状z.size()
#获取张量的形状z.view(2,4)
#改变张量的形状,2行,4列z.reshape(1,8)
#改变张量的形状,1行,8列z.resize_(2,4)
#直接修改原始张量的形状,2行,4列输出结果:第10章航班乘客数预测10.4PyTorch基础知识10.4.2自动微分PyTorch提供了自动微分功能,可以自动计算梯度,这使得模型训练更加容易。我们使用torch.tensor()来定义张量,然后使用.backward()函数计算梯度。x=torch.tensor(2.0,requires_grad=True)#定义张量x,并将requires_grad设置为True,以便PyTorch跟踪它的计算历史y=x**2
#定义新的张量y,它是x的平方y.backward()#调用y.backward()来计算y相对于x的导数x.grad
#打印出结果为tensor(4.)第10章航班乘客数预测10.4PyTorch基础知识10.4.3神经网络PyTorch提供了torch.nn模块,可以帮助开发者更轻松地构建和训练神经网络模型。可以使用torch.nn.Module()类定义神经网络模型,然后使用torch.optim优化器进行训练。10.4.4数据加载PyTorch提供了torch.utils.data模块,可以帮助开发者更轻松地加载和处理数据。可以使用torch.utils.data.Dataset()类定义数据集,然后使用torch.utils.data,DataLoader()函数加载数据。10.4.5GPU加速PyTorch可以使用GPU加速,可以使用torch.cuda模块将张量和模型移动到GPU上运行。第10章航班乘客数预测10.5读取数据with
open("data\international-airline-passengers.csv","r",encoding="utf-8")asf:
next(f)#跳过第1行
data_csv=f.read()
#将文件内容读取到变量data中data=[row.split(',')forrowin
data_csv.split("\n")]#将字符串变量data_csv中的每一行按逗号分隔并返回一个列表。这个列表包含了每一行的元素。passengers=[int(each[1])foreachindata]#将列表变量data中的每个元素的第二个字符转换为整数并返回一个新的列表。Passengers#打印前10个月中每月的航班乘客数输出结果:第10章航班乘客数预测10.6数据预处理接下来,我们首先使用滑动窗口方法创建基于航班乘客数的时间序列数据。然后,将序列数据转换成满足模型输入要求的训练数据集和测试数据集。这样,我们就可以使用前2天的航班乘客数来预测第3天的航班乘客数。第10章航班乘客数预测10.7定义网络模型class
Net(nn.Module):
#初始化函数,定义网络结构
def
__init__(self):
#调用父类的初始化函数
super(Net,self).__init__()
#定义一个LSTM层,输入特征为1(只有乘客数),隐藏状态大小为32,层数为1,batch_first为True
self.lstm=nn.LSTM(input_size=1,hidden_size=32,num_layers=1,batch_first=True)
#定义一个线性层,将32*seq_len个输入特征映射到1个输出特征(预测下一月乘客数)
self.linear=nn.Linear(32*seq_len,1)
#前向传播函数
def
forward(self,input):
#将输入input输入到LSTM层中,得到输出结果lstm_out,隐藏状态h和单元状态c
lstm_out,(h,c)=self.lstm(input)
#将lstm_out进行reshape,变成一个形状为(-1,32*seq_len)的张量
x=lstm_out.reshape(-1,32*seq_len)
#将x输入到线性层中,得到输出pred
pred=self.linear(x)
#返回输出pred
return
pred第10章航班乘客数预测10.8定义损失函数和优化器model=Net()#定义一个Adam优化器,用于更新模型参数,学习率为0.003optimizer=torch.optim.Adam(model.parameters(),lr=0.003)#定义一个均方误差损失函数,用于计算模型预测值与真实值之间的误差loss_fun=nn.MSELoss()第10章航班乘客数预测10.9训练模型#将模型设置为训练模式model.train()#进行300轮训练for
epoch
in
range(300):
#将训练数据train_x输入到模型中,得到模型的输出output
output=model(train_x)
#计算模型输出output与训练标签train_y之间的均方误差损失
loss=loss_fun(output,train_y)
#将优化器的梯度清零
optimizer.zero_grad()
#反向传播计算梯度
loss.backward()
#使用优化器更新模型参数
optimizer.step()
#每20轮输出一次训练损失和测试损失
if
epoch%20==0
and
epoch>0:
#将测试数据test_x输入到模型中,得到模型的输出output
#计算模型输出output与测试标签test_y之间的均方误差损失
test_loss=loss_fun(model(test_x),test_y)
#输出当前轮数、训练损失和测试损失
print("epoch:{},loss:{},test_loss:{}".format(epoch,loss,test_loss))第10章航班乘客数预测10.9训练模型第10章航班乘客数预测10.10测试模型第10章航班乘客数预测10.10测试模型#将模型设置为评估模式model.eval()#构造预测结果result=X[0][:seq_len-1]
温馨提示
- 1. 本站所有资源如无特殊说明,都需要本地电脑安装OFFICE2007和PDF阅读器。图纸软件为CAD,CAXA,PROE,UG,SolidWorks等.压缩文件请下载最新的WinRAR软件解压。
- 2. 本站的文档不包含任何第三方提供的附件图纸等,如果需要附件,请联系上传者。文件的所有权益归上传用户所有。
- 3. 本站RAR压缩包中若带图纸,网页内容里面会有图纸预览,若没有图纸预览就没有图纸。
- 4. 未经权益所有人同意不得将文件中的内容挪作商业或盈利用途。
- 5. 人人文库网仅提供信息存储空间,仅对用户上传内容的表现方式做保护处理,对用户上传分享的文档内容本身不做任何修改或编辑,并不能对任何下载内容负责。
- 6. 下载文件中如有侵权或不适当内容,请与我们联系,我们立即纠正。
- 7. 本站不保证下载资源的准确性、安全性和完整性, 同时也不承担用户因使用这些下载资源对自己和他人造成任何形式的伤害或损失。
最新文档
- 学校药品器材安全警示标识
- 实验室事故报告流程
- 电子产品生产资产管理指南
- 2024年艺人演艺事业发展规划3篇
- 油气开采挖机设备租赁合同
- 高铁工程预应力施工协议
- 轨道车物料成本优化
- 铁路建设临时用电服务合同
- 保险服务合同管理细则
- 体育场馆车辆管理规定
- 新疆维吾尔自治区巴音郭楞蒙古自治州2023-2024学年二年级上学期期末数学试卷
- 医院门窗工程施工方案与施工方法
- 短视频实习运营助理
- 2024年中化石油福建有限公司招聘笔试参考题库含答案解析
- 对加快推进新型工业化的认识及思考
- 移植后淋巴细胞增殖性疾病
- 风光储储能项目PCS舱、电池舱吊装方案
- 中医跟师总结论文3000字(通用3篇)
- 《军队征集和招录人员政治考核规定》
- 住宅小区视频监控清单及报价2020
- 电动三轮车监理细则
评论
0/150
提交评论