版权说明:本文档由用户提供并上传,收益归属内容提供方,若内容存在侵权,请进行举报或认领
文档简介
PyTorch深度学习项目教程
猫狗图像分类IMAGECLASSIFICATION要点:监督学习、分类任务、多层感知机、数据增强
项目背景ProjectBackground要点:分类任务是深度学习的基本任务。问题:猫狗图像分类是一个经典的计算机视觉问题,目标是对给定的图像进行分类,判断图像中是猫还是狗。由于猫狗的相似性,将二者完全分开面临极大的挑战。解决:通过监督学习进行解决,即通过带标签的图像数据集来进行训练和评估模型。训练阶段,模型接收大量的猫狗图像作为输入,并对其进行学习,调整模型参数以最小化预测结果与真实标签之间的差异。评估阶段,使用另外的图像数据集对训练好的模型进行测试和验证其分类准确性。知识目标KnowledgeObjectives理解并应用PyTorch中DataSet类进行数据增强,以提高模型泛化能力学会使用torchstat工具包对模型的参数进行统计分析,以监控模型训练状态掌握logging工具包的使用,实现训练过程的日志记录与输出,便于模型调试与分析理解回归与分类任务的区别,并掌握分类任务中的关键算法学习和掌握Sigmoid与Softmax函数在不同分类任务中的应用能力目标AbilityGoals能够利用网络资源加载和处图像数据集,进行有效的数据预处理和增强。能够独立搭建并训练一个全连接神经网络模型,用于图像分类任务。能够运用日志记录和其他评估手段,对模型的性能进行监控和分析,提出改进策略素养目标ProfessionalAttainments培养项目化思维,养成项目开发的全局视角,合理规划项目进度增强自主学习能力,能够独立分析问题,寻找解决方案,并在项目中实践提升自我学习意识,通过本项目的学习,激发对监督学习数据集收集和构建的兴趣,主动寻找和应用更多的学习资源目录任务1准备猫狗数据集任务2设计图像分类全连接网络任务3训练图像分类网络任务4应用分类网络推理更多图片任务5认识深度学习的主要任务01任务1准备猫狗数据集1.1数据集获取
1.2数据整理及划分对收集到的图像进行预处理:可以使用图像编辑软件进行裁剪、缩放、旋转、去除噪声等预处理操作可参见OpenCV类的教材书籍一般将数据集分为训练集、验证集和测试集3类,根据数据样本数量可以按照8:1:1、8:2:0或7:2:1等比例进行划分。训练集:用于训练模型验证集:用于在训练过程中判断模型是否收敛测试集:用于评估模型的性能本项目,首先创建train、val和test文件夹,在对应文件夹内再创建cat、dog等具体类别,如图所示。之后在对应的类别文件夹内放置相应类别图片,即可完成数据集的构建1.3创建数据集类classDogCatDataset(Dataset):
def__init__(self,data_dir,transform=None):
“”“
分类任务的Dataset
:paramdata_dir:str,数据集所在路径
:paramtransform:torch.transform,数据预处理
”“”
self.label_name={“cat”:0,“dog”:1}#需要根据实际训练任务修改
self.data_info=self.get_img_info(data_dir)#data_info存储所有图片路径和标签,在DataLoader中通过index读取样本
self.transform=transform
def__getitem__(self,index):
path_img,label=self.data_info[index]
img=Image.open(path_img).convert(‘RGB’)#0~255
ifself.transformisnotNone:
img=self.transform(img)#在这里做transform,转为tensor等等
returnimg,label
def__len__(self):
returnlen(self.data_info)
defget_img_info(self,data_dir):
data_info=list()
forroot,dirs,_inos.walk(data_dir):
#遍历类别
forsub_dirindirs:
img_names=os.listdir(os.path.join(root,sub_dir))
img_names=list(filter(lambdax:x.endswith(‘.jpg’),img_names))
#遍历图片
foriinrange(len(img_names)):
img_name=img_names[i]
path_img=os.path.join(root,sub_dir,img_name)
label=self.label_name[sub_dir]
data_info.append((path_img,int(label)))
returndata_info项目组织结构02任务2设计图像分类网络2分类MLPclassMLP(nn.Module):
def__init__(self,classes=2):#64x64x3
#使用super调用父类中的init函数对网络进行初始化
super(MLP,self).__init__()
#构建网络子模块,存储到MLP的module属性中
self.fc1=nn.Linear(64*64*3,32*32*3)
self.fc2=nn.Linear(32*32*3,16*16*3)
self.fc3=nn.Linear(16*16*3,256)
self.fc4=nn.Linear(256,128)
self.fc5=nn.Linear(128,classes)
#按照每层网络结构写处forward函数,即如何进行前向传播
defforward(self,x):
x=x.view(x.size(0),-1)#重置为[batchsize,一维数组]的形状
out=torch.relu(self.fc1(x))
out=torch.relu(self.fc2(out))
out=torch.relu(self.fc3(out))
out=torch.relu(self.fc4(out))
out=self.fc5(out)
returnout总参数个数:40,341,890torchstat模块统计:项目组织结构03任务3训练图像分类网络3.1训练日志importlogging
defget_logger(filename,verbosity=1,name=None):
level_dict={0:logging.DEBUG,1:logging.INFO,2:logging.WARNING}
formatter=logging.Formatter(
"[%(asctime)s][%(filename)s][line:%(lineno)d][%(levelname)s]%(message)s"
)
logger=logging.getLogger(name)
logger.setLevel(level_dict[verbosity])
fh=logging.FileHandler(filename,mode="a",encoding='utf-8')
fh.setFormatter(formatter)
logger.addHandler(fh)
sh=logging.StreamHandler()
sh.setFormatter(formatter)
logger.addHandler(sh)
returnlogger
if__name__=='__main__':
logger=get_logger('test.log')
("这是一条测试指令")项目组织结构记录训练过程,更好追溯问题3.2训练过程项目组织结构#参数设置MAX_EPOCH=30BATCH_SIZE=128LR=0.01log_interval=10val_interval=1device=torch.device("cuda:0"iftorch.cuda.is_available()else"cpu")训练初始化train_transform=transforms.Compose([transforms.Resize((64,64)),transforms.RandomCrop(64,padding=4),transforms.ToTensor(),transforms.Normalize(norm_mean,norm_std),])train_data=DogCatDataset(data_dir=train_dir,transform=train_transform)train_loader=DataLoader(dataset=train_data,batch_size=BATCH_SIZE,shuffle=True,num_workers=4)配置数据集net=MLP(classes=2)net.to(device)加载网络模型criterion=nn.CrossEntropyLoss()#损失函数设置optimizer=optim.SGD(net.parameters(),lr=LR,momentum=0.9)#优化器设置scheduler=torch.optim.lr_scheduler.StepLR(optimizer,step_size=100,gamma=0.1)#设置学习率下降策略
配置训练策略forepochinrange(MAX_EPOCH):...fori,datainenumerate(train_loader):
迭代训练3.2.1初始化参数初始化训练参数,由于数据量超过1万张图片,一般需要三位数的训练轮次才能让模型收敛,这里可以先预设一个小MAX_EPOCH值,便于观察程序模块是否配置正确:判断是否安装了英伟达GPU及其CUDA驱动:MAX_EPOCH=30
BATCH_SIZE=128
LR=0.01
log_interval=10
val_interval=1device=torch.device("cuda:0"iftorch.cuda.is_available()else"cpu")3.2.2配置数据集split_dir=os.path.join("data","dogs-vs-cats")
train_dir=os.path.join(split_dir,"train")
valid_dir=os.path.join(split_dir,"val")
norm_mean=[0.485,0.456,0.406]
norm_std=[0.229,0.224,0.225]
train_transform=transforms.Compose([
transforms.Resize((64,64)),
transforms.RandomCrop(64,padding=4),
transforms.ToTensor(),
transforms.Normalize(norm_mean,norm_std),
])
valid_transform=transforms.Compose([
transforms.Resize((64,64)),
transforms.ToTensor(),
transforms.Normalize(norm_mean,norm_std),
])
#构建MyDataset实例
train_data=DogCatDataset(data_dir=train_dir,transform=train_transform)
valid_data=DogCatDataset(data_dir=valid_dir,transform=valid_transform)
#构建DataLoder
train_loader=DataLoader(dataset=train_data,batch_size=BATCH_SIZE,shuffle=True,num_workers=4)
valid_loader=DataLoader(dataset=valid_data,batch_size=BATCH_SIZE,shuffle=True)transforms.Compose:由Pytorch的torchvision库提供,包括图像缩放、裁剪、翻转、对比度变换等多种类型,充分利用模型的4000万个参数记住多样的数据变化,提高模型的表达能力DogCatDataSet:定义数据集DataLoader:主要作用是一次加载一部分数据到内存,防止数据一次性加载内存或显存容量不足,参数shuffle=True的意思为打乱数据加载顺序,shuffle的英文原意为洗牌3.2.3加载网络模型net=MLP(classes=2)
net.to(device)#GPU3.2.4配置训练策略1.配置损失函数criterion=nn.CrossEntropyLoss()#选择损失函数
3.2.4配置训练策略2.优化器optimizer=optim.SGD(net.parameters(),lr=LR,momentum=0.9)#选择优化器
scheduler=torch.optim.lr_scheduler.StepLR(optimizer,step_size=100,gamma=0.1)#设置学习率下降策略在PyTorch中,优化器(Optimizer)是用于更新神经网络参数的工具。它根据计算得到的损失函数的梯度来调整模型的参数,以最小化损失函数并改善模型的性能。即优化器是一种特定的机器学习算法,通常用于在训练深度学习模型时调整权重和偏差。是用于更新神经网络参数以最小化某个损失函数的方法。它通过不断更新模型的参数来实现这一目的。
SGD(StochasticGradientDescent):Adam(AdaptiveMomentEstimation):通过维护模型的梯度和梯度平方的一阶动量和二阶动量,来调整模型的参数。Adam的优点是计算效率高,收敛速度快,缺点是需要调整超参数SGD的基本思想是,通过梯度下降的方法,不断调整模型的参数,使模型的损失函数最小化。SGD的优点是实现简单、效率高,缺点是收敛速度慢、容易陷入局部最小值。RMSprop(RootMeanSquarePropagation):是一种改进的随机梯度下降优化器,用于优化模型的参数。基本思想是,通过维护模型的梯度平方的指数加权平均,来调整模型的参数。RMSprop的优点是收敛速度快,缺点是计算复杂度高,需要调整超参数3.2.5迭代训练数据模型损失函数优化器损失函数Loss给到模型Autograd反向传播计算梯度,并输入到优化器更新模型参数Dataloader按照BatchSize加载数据迭代训练过程04任务4应用分类网络推理4推理path_img="data/dogs-vs-cats/val/dog/dog.11250.jpg"
label_name={0:"cat",1:"dog"}
norm_mean=[0.485,0.456,0.406]
norm_std=[0.229,0.224,0.225]
valid_transform=transforms.Compose([
transforms.Resize((64,64)),
transforms.ToTensor(),
transforms.Normalize(norm_mean,norm_std),
])
device=torch.device("cuda:0"iftorch.cuda.is_available()else"cpu")
img=Image.open(path_img).convert('RGB')#0~255
img=valid_transform(img)#在这里做transform,转为tensor等等
model=torch.load('train_process/best.pth')
model.eval()
outputs=model(img.to(device).unsqueeze(0))
print(outputs)
_,predicted=torch.max(outputs.data,1)
confidence=torch.softmax(outputs,1).cpu().squeeze(0).detach().numpy()
print(confidence)
label=predicted.cpu().detach().numpy()[0]
print(label)
print(f'预测结果为{label_name[label]},置信度为{confidence[label]}’)项目组织结构4推理置信度confidence=torch.softmax(outputs,1).cpu().squeeze(0).detach().numpy()神经网络outputs输出结果不够直观,我们期望输出结果为有多大概率为猫,以及有多大概率为狗
05任务5认识深度学习的主要任务5深度学习的主要任务1.回归(Regression)回归任务的目标是预测连续数值的输出。深度学习在回归问题中可以学习到数据之间的复杂关系。例如,预测房价、股票价格、气温等连续变量的值。在回归任务中,我们可以使用不同的神经网络架构(如多层感知器、卷积神经网络)和损失函数(如均方误差、平均绝对误差)来训练模型。2.分类(Classification)分类任务的目标是将输入数据分为不同的类别。深度学习在分类问题中能够有效地从大量的特征中进行学习,自动提取数据中的有用信息,从而实现高准确率的分类。例如,图像分类、语音识别、文本分类等任务都
温馨提示
- 1. 本站所有资源如无特殊说明,都需要本地电脑安装OFFICE2007和PDF阅读器。图纸软件为CAD,CAXA,PROE,UG,SolidWorks等.压缩文件请下载最新的WinRAR软件解压。
- 2. 本站的文档不包含任何第三方提供的附件图纸等,如果需要附件,请联系上传者。文件的所有权益归上传用户所有。
- 3. 本站RAR压缩包中若带图纸,网页内容里面会有图纸预览,若没有图纸预览就没有图纸。
- 4. 未经权益所有人同意不得将文件中的内容挪作商业或盈利用途。
- 5. 人人文库网仅提供信息存储空间,仅对用户上传内容的表现方式做保护处理,对用户上传分享的文档内容本身不做任何修改或编辑,并不能对任何下载内容负责。
- 6. 下载文件中如有侵权或不适当内容,请与我们联系,我们立即纠正。
- 7. 本站不保证下载资源的准确性、安全性和完整性, 同时也不承担用户因使用这些下载资源对自己和他人造成任何形式的伤害或损失。
最新文档
- 合成材料制造的竞争优势考核试卷
- 糖尿病诊疗指南2024版
- 健康饮食增加身体活力的食物建议考核试卷
- 南京信息工程大学《天气学原理与方法Ⅱ》2022-2023学年第一学期期末试卷
- 炼铁废水处理技术的新进展考核试卷
- 《东北三省经济增长省际间投入产出效应测度研究》
- 《农村事实孤儿负面情绪的小组工作介入》
- 小学课间操活动方案
- 日用化学产品的用户满意度调查报告和改进方案建议考核试卷
- 基于工程认证理念的“环境工程微生物学”多元混合式教学改革及效果
- 豆绿色时尚风送货单excel模板
- 新苏教版五年级上册科学全册教学课件(2022年春整理)
- 小学体育水平一《走与游戏》教学设计
- 秋日私语(完整精确版)克莱德曼(原版)钢琴双手简谱 钢琴谱
- 办公室室内装修工程技术规范
- 盐酸安全知识培训
- 万盛关于成立医疗设备公司组建方案(参考模板)
- 科技特派员工作调研报告
- 中波广播发送系统概述
- 县疾控中心中层干部竞聘上岗实施方案
- 急性心肌梗死精美PPt完整版
评论
0/150
提交评论