深度学习与图像处理实战-模型评估及模型调优_第1页
深度学习与图像处理实战-模型评估及模型调优_第2页
深度学习与图像处理实战-模型评估及模型调优_第3页
深度学习与图像处理实战-模型评估及模型调优_第4页
深度学习与图像处理实战-模型评估及模型调优_第5页
已阅读5页,还剩45页未读 继续免费阅读

下载本文档

版权说明:本文档由用户提供并上传,收益归属内容提供方,若内容存在侵权,请进行举报或认领

文档简介

模型评估及模型调优深度学习与图像处理实战知识要点5.1评估指标5.1.1准确率5.1.2查准率5.1.3召回率5.1.4F1值5.1.5ROC与AUC5.2数据集处理5.2.1数据集划分5.2.2数据增强5.3模型调优5.3.1回调函数5.3.2超参数调整5.3.3模型结构调整目录5.1评估指标当训练好一个模型之后,需要对模型进行评估,评估可以反映出模型的各种指标优劣。在了解评估指标之前,必须先了解什么是混淆矩阵(CoufusionMatrix)。混淆矩阵是评估模型结果的指标,属于模型评估的一部分。此外,混淆矩阵多用于判断分类器(Classifier)的优劣,适用于分类型的数据模型,如分类树(ClassificationTree)、逻辑回归(LogisticRegression)、线性判别分析(LinearDiscriminantAnalysis)等方法。5.1评估指标所谓混淆矩阵就是根据分类时预测结果与实际情况的对比做出的表格,如表5-1所示。其中Positive代表正类、Negative代表负类、Predicted代表预测结果、Actual代表实际情况。ConfusionMatrixPredictedPositiveNegativeActualPositiveTPFNNegativeFPTN表5-1混淆矩阵表5-1中的指标解释如下。TP表示TruePositive,即真正:将正类预测为正类的数量。FP表示FalsePositive,即假正:将负类预测为正类的数量,可以称为误报率。TN表示TrueNegative,即真负:将负类预测为负类的数量。FN表示FalseNegative,即假负:将正类预测为负类的数量,可以称为漏报率。5.1评估指标5.1.1准确率最简单、最常使用的评估指标就是准确率(Accuracy),它可以从某种意义上判断出一个模型是否有效。true=[0,1,0,1,0,1,0,1,1,1]pred=[0,1,0,1,0,0,0,1,0,1]accuracy=0forindex,valueinenumerate(pred):ifvalue==true[index]:accuracy+=1print('accuarcy:%.2f%%'%(accuracy/len(true)*100))使用Python来计算准确率的代码如下。在正、负样本不均衡的情况下,准确率作为评估指标是不合适的。5.1评估指标5.1.2查准率

查准率(Precision)又叫精确率,它表示被正确检索的样本数与被检索到的样本总数之比,简单地说查准率是识别正确的结果在所识别出的结果中所占的比例。true=[0,1,0,1,0,1,0,1,1,1]pred=[0,1,0,1,0,0,0,1,1]precision=0forindex,valueinenumerate(pred):ifvalue==true[index]:precision+=1print('precision:%.2f%%'%(precision/len(pred)*100))使用Python计算查准率的代码如下。5.1评估指标5.1.3召回率

召回率(Recall)又叫查全率,它表示被正确检索的样本数与应当被检索到的样本数之比。从概念上看,查准率和召回率是一对相互矛盾的指标,一般而言,查准率高时,召回率往往偏低;召回率高时,查准率往往偏低。true=[0,1,0,1,0,1,0,1,1,1]pred=[0,1,0,1,0,0,0,1,1,1]#应当被检索到的样本数index_1_num=str(true).count("1")recall=0forindex,valueinenumerate(pred):使用Python计算召z回率的代码如下。#被正确检索的样本数ifvalue==true[index]andvalue==1:recall+=1print('recall:%.2f%%'%(recall/index_1_num*100))5.1评估指标5.1.3召回率

召回率(Recall)又叫查全率,它表示被正确检索的样本数与应当被检索到的样本数之比。从概念上看,查准率和召回率是一对相互矛盾的指标,一般而言,查准率高时,召回率往往偏低;召回率高时,查准率往往偏低。true=[0,1,0,1,0,1,0,1,1,1]pred=[0,1,0,1,0,0,0,1,1,1]#应当被检索到的样本数index_1_num=str(true).count("1")recall=0forindex,valueinenumerate(pred):使用Python计算召z回率的代码如下。#被正确检索的样本数ifvalue==true[index]andvalue==1:recall+=1print('recall:%.2f%%'%(recall/index_1_num*100))5.1评估指标5.1.3召回率

召回率(Recall)又叫查全率,它表示被正确检索的样本数与应当被检索到的样本数之比。从概念上看,查准率和召回率是一对相互矛盾的指标,一般而言,查准率高时,召回率往往偏低;召回率高时,查准率往往偏低。true=[0,1,0,1,0,1,0,1,1,1]pred=[0,1,0,1,0,0,0,1,1,1]#应当被检索到的样本数index_1_num=str(true).count("1")recall=0forindex,valueinenumerate(pred):使用Python计算召z回率的代码如下。#被正确检索的样本数ifvalue==true[index]andvalue==1:recall+=1print('recall:%.2f%%'%(recall/index_1_num*100))5.1评估指标5.1.4F1值

F1值(F1Score)是统计学中用来衡量二分类模型精度的一种指标。它同时兼顾了分类模型的查准率和召回率。F1值可以看作模型查准率和召回率的一种调和平均,它的最大值是1,最小值是0。true=[0,1,0,1,0,1,0,1,1,1]pred=[0,1,0,1,0,0,0,1,1,1]#应当被检索到的样本数index_1_num=str(true).count("1")recall=0precision=0forindex,valueinenumerate(pred):使用Python计算F1值的代码如下:ifvalue==true[index]andvalue==1:recall+=1ifvalue==true[index]:precision+=1precision=precision/len(pred)*100recall=recall/index_1_num*100print('f1:%.2f%%'%((2*precision*recall)/(precision+recall)))5.1评估指标5.1.5ROC与AUC

受试者工作特征曲线(ReceiverOperatingCharacteristicCurve,ROC)源于军事领域,而后在医学领域应用甚广,其名称也正是来自医学领域。

ROC的x轴表示假阳性率(FalsePositiveRate,FPR),y轴表示真阳性率(TruePositiveRate,TPR),也就是召回率。ROC越陡,表示模型效果越好。

曲线下面积(AreaUnderCurve,AUC)表示ROC与坐标轴围成的面积,显然这个面积的数值不会大于1。又由于ROC一般都处于y=x这条直线的上方,因此AUC的取值范围为0.5~1。AUC越大表示模型效果越好。5.1评估指标#导入要用的库fromsklearn.metricsimportroc_curvefromsklearn.metricsimportroc_auc_scoreasAUCimportmatplot.pyplotasplttrue=[0,1,0,1,0,1,0,1,1,1]pred=[0,1,0,1,0,0,0,1,1,1]#利用roc_curve函数获得的FPR和recall都是一系列值FPR,recall,thresholds=roc_curve(true,pred)#计算AUCarea=AUC(test_prob["1"],test_prob["0"])使用Python绘制ROC和AUC的代码如下。#画图plt.figure()plt.plot(FPR,recall,label='ROCcurve(Auc=%0.2f)'%area)plt.xlabel('FalsePositiveRate')plt.ylabel('TruePositiveRate')plt.legend(loc="lowerright")plt.show()5.2数据集处理5.2.1数据集划分

对于需要解决的问题的样本数据,在建立模型的过程中,数据一般会被划分为以下几个部分。训练集(TrainSet):用训练集对算法或模型进行训练。验证集(ValidationSet):又称简单交叉验证集(Hold-outCrossValidationSet),利用验证集进行交叉验证,即评估几种算法或模型中哪一个最好,从而选择出最好的模型。—测试集(TestSet):最后利用测试集对模型进行测试,获取模型运行的无偏估计(对学习方法进行评估)。在小数据量的时代,如100、1000、10000的数据量,可以将数据集按照以下比例进行划分。

无验证集的情况:70%/30%。—有验证集的情况:60%/20%/20%。5.2数据集处理

测试集的主要目的是评估模型的效果,如在单个分类器中,在百万级别的数据中选择其中1000条数据足以评估单个模型的效果。数据量较大的情况可以按照以下比例进行划分。00万数据量:98%/1%/1%。超百万数据量:99.5%/0.25%/0.25%(或者99.5%/0.4%/0.1%)。接下来,使用程序进行数据集的划分。假设某一文件存放的数据如图5-1所示,表明了图片路径与类别。图片名称中包含了图片类别(1与0)。图5-1数据展示5.2数据集处理

使用Python对数据集进行划分,代码如下。importrandom#分为训练集与测试集split=0.8withopen('train.txt')asf:txt_data=f.readlines()#随机打乱random.shuffle(txt_data)train_len=int(len(txt_data)*split)#划分数据集train_txt=txt_data[:train_len]test_txt=txt_data[train_len:]print(train_txt)print(test_txt)

然后,以8∶2的比例划分训练集以及测试集,这里只是提供了一种参考的写法,还可以通过修改split参数进行其他比例的划分,或者增加新的参数划分为训练集、测试集、验证集等。5.2数据集处理5.2.2数据增强

3只小狗在我们看来是同一只小狗,但是由于小狗的位置被平移了,神经网络认为这是不同的小狗。

在训练的时候,经常会遇到数据不足的情况。例如,在一个任务中,数据集只包含几百张图片,而通过这几百张图片训练出来的模型很容易造成过拟合。然而神经网络内部的参数是巨大的,大量的参数要求人们提供更多的数据以达到更好的预测结果。5.2数据集处理

可以很容易地联想到扩充数据最简单的方法,就是通过平移、旋转、镜像等多种方式,对已有的图片进行数据扩充,也就是本小节要说到的数据增强。

一个神经网络如果能够将一个放在不同地方、不同光线、不同背景中的物体识别成功,就称这个神经网络具有不变性。这种不变性具体来说就是对物体的位移、视角、光线、大小、角度、亮度等一种或多种变换的不变性。而为了完善这种不变性,数据中存在这种图片变换就显得尤为重要。所以,数据增强并不只是在数据量少的情况下有用,在数据量较多的情况下也可以使模型得到更好的效果。5.2数据集处理

两种不同品牌的鼠标预测图片一般来说,深度学习算法会去寻找能够区分两个类别的最明显的特征。在数据集里两个品牌的鼠标的头部朝向就是最明显的特征。所以,应该怎么去避免类似的事情发生呢?这就需要减少数据集中不相关的特征。对于上述的鼠标分类数据集来说,一个简单的方案就是增加多种头部朝向的的图片。然后重新训练模型,这样会得到性能更好的模型。5.2数据集处理使用Python实现图片翻转的代码如下。01OPTION图片翻转importcv2path=r'cat.jpg'img=cv2.imread(path)#水平翻转img_hor=cv2.flip(img,1)#垂直翻转img_ver=cv2.flip(img,0)#显示图片cv2.imshow('img',img)cv2.imshow('img_hor',img_hor)cv2.imshow('img_ver',img_ver)cv2.waitKey(0)cv2.destroyWindow()程序运行结果如图所示,从左到右分别是原图、水平翻转、垂直翻转的结果。5.2数据集处理使用Python实现图片旋转的代码如下。02OPTION图片旋转importcv2#图片旋转defrotation(img,angle):h,w=img.shape[:2]#得到中心点center=(w//2,h//2)#获得围绕中心点旋转某个角度的矩阵M=cv2.getRotationMatrix2D(center,angle,1.0)#旋转图片rotated=cv2.warpAffine(img,M,(w,h))returnrotatedpath=r'D:\GPU_SY\Opencv\opencv_image\cat.jpg'img=cv2.imread(path)#显示图片cv2.imshow('img',img)cv2.imshow('rotation_90_img',rotation_90_img)cv2.imshow('rotation_45_img',rotation_45_img)cv2.waitKey(0)cv2.destroyWindow()原图、旋转90°、旋转45°的结果5.2数据集处理在一般的神经网络模型中,都会对输入的图片尺寸进行预设,如ImageNet的224×224、YOLOv3的416×416等。由于自然界中的图片并不都是这种比例的,因此需要对图片进行缩放。03OPTION等比缩放fromPILimportImageimportnumpyasnpdefletterbox_image(image,size):#获得图片宽、高iw,ih=image.size#获得要缩放的图片尺寸w,h=size#得到小的边长的比值scale=min(w/iw,h/ih)#得到新的边长

使用Python实现等比缩放的代码如下。nw=int(iw*scale)nh=int(ih*scale)#缩放图片image=image.resize((nw,nh),Image.BICUBIC)#新建一张“画布”,以RGB颜色(128,128,128)填充new_image=Image.new('RGB',size,(128,128,128))#将缩放后的图片放到“画布”中,并居中new_image.paste(image,((w-nw)//2,(h-nh)//2))returnnp.array(new_image)

5.2数据集处理

path=r'D:\GPU_SY\Opencv\opencv_image\cat.jpg'img=cv2.imread(path)let_img=letterbox_image(Image.fromarray(img),(224,224))resize_img=cv2.resize(img,(224,224))#显示图片cv2.imshow('img',img)cv2.imshow('resize_img',resize_img)cv2.imshow('let_img',let_img)cv2.waitKey(0)cv2.destroyWindow()等比缩放5.2数据集处理位移只涉及沿x轴或y轴方向(或两者)移动图片。在Python中可以引入随机数,并且通过这个随机数将使用了等比缩放的图片放置到“画布”的随机位置上,形成位移。使用Python实现图片位移的代码如下。04OPTION位移defrand(a=0,b=1):returnnp.random.rand()*(b-a)+adefletterbox_image(image,size):#获得图片宽、高iw,ih=image.size#获得要缩放的图片尺寸w,h=size#得到小的边长的比值scale=min(w/iw,h/ih)#将比值进行随机缩小scale=scale*rand(.25,1)

nw=int(iw*scale)nh=int(ih*scale)image=image.resize((nw,nh),Image.BICUBIC)#将图片放到“画布”的随机位置上dx=int(rand(0,w-nw))dy=int(rand(0,h-nh))new_image=Image.new('RGB',(w,h),(128,128,128))new_image.paste(image,(dx,dy))returnnp.array(new_image)5.2数据集处理颜色变换类的数据增强有很多种,如噪声、模糊、颜色扰动、填充、擦除等。所谓颜色扰动,就是在某一颜色空间内通过增加或减少某些颜色分量来进行数据增强。使用Python实现颜色增强的代码如下。05OPTION位移defhsv_(image,hue=.1,sat=1.5,val=1.5):hue=rand(-hue,hue)sat=rand(1,sat)ifrand()<.5else1/rand(1,sat)val=rand(1,val)ifrand()<.5else1/rand(1,val)#将颜色空间转换到HSVx=cv2.cvtColor(image,cv2.COLOR_BGR2HSV)/255.#使用HSV颜色空间进行颜色增强x[...,0]+=huex[...,0][x[...,0]>1]-=1x[...,0][x[...,0]<0]+=1x[...,1]*=satx[...,2]*=val5.2数据集处理

x[x>1]=1x[x<0]=0#转换成原本的颜色空间image_data=np.array(x*255.,dtype='uint8')image_data=cv2.cvtColor(image_data,cv2.COLOR_HSV2BGR)returnimage_datapath=r'D:\GPU_SY\Opencv\opencv_image\cat.jpg'img=cv2.imread(path)hsv_img=hsv_(img)#显示图片cv2.imshow('img',img)cv2.imshow('hsv_img',hsv_img)cv2.waitKey(0)cv2.destroyWindow()图片颜色增强5.2数据集处理使用Python实现多种组合数据增强的代码如下。06OPTION使用多种组合进行图片数据增强nh=int(ih*scale)image=Image.fromarray(image)image=image.resize((nw,nh),Image.BICUBIC)#随机位移dx=int(rand(0,w-nw))dy=int(rand(0,h-nh))new_image=Image.new('RGB',(w,h),(128,128,128))new_image.paste(image,(dx,dy))image=new_image#随机翻转flip=rand()<.5ifflip:image=image.transpose(Image.FLIP_LEFT_RIGHT)#随机HSV增强hue=rand(-hue,hue)sat=rand(1,sat)ifrand()<.5else1/rand(1,sat)val=rand(1,val)ifrand()<.5else1/rand(1,val)frommatplotlib.colorsimportrgb_to_hsvdefget_random_data(image,input_shape,hue=.1,sat=1.5,val=1.5):#获得图片宽、高ih,iw=image.shape[:2]#获得要缩放的图片尺寸w,h=input_shape#得到小的边长的比值scale=min(w/iw,h/ih)#将比值进行随机缩小scale=scale*rand(.25,1)nw=int(iw*scale)5.2数据集处理path=r'D:\GPU_SY\Opencv\opencv_image\cat.jpg'img=cv2.imread(path)random_img=get_random_data(img,(224,224))#显示图片cv2.imshow('img',img)cv2.imshow('rotation_45_img',random_img)cv2.waitKey(0)cv2.destroyWindow()x=rgb_to_hsv(np.array(image)/255.)x[...,0]+=huex[...,0][x[...,0]>1]-=1x[...,0][x[...,0]<0]+=1x[...,1]*=satx[...,2]*=valx[x>1]=1x[x<0]=0#转换成原本的颜色空间image_data=np.array(x*255.,dtype='uint8')image_data=cv2.cvtColor(image_data,cv2.COLOR_HSV2BGR)returnimage_data通过引入的多个随机数,可以使图片完成平移、色彩变换、缩放等操作,并且这些操作都是随机的,大大增加了可用的数据量。5.2数据集处理

数据增强在很大程度上解决了数据不足的问题,把数据送入模型训练之前进行增强。但是这里有两种做法:一种做法是事先执行所有的转换,这实质上会增加数据集的大小;另一种做法是在把数据送入模型之前,小批量地执行这些转换。第一种做法叫作线下增强,这种做法适用于较小的数据集。线下增强后会增加一定倍数的数据,这个倍数取决于转换的倍数。第二种做法叫作线上增强,这种做法适用于较大的数据集。因为计算机可能无法承受爆炸性增加的数据,所以在训练的同时使用CPU加载数据并进行数据增强是一种较好的做法。5.3模型调优

模型调优归根结底还是解决欠拟合(Underfitting)和过拟合(Overfitting)的问题。当涉及机器学习算法时,往往会面临过拟合和欠拟合的问题。欠拟合意味着对算法简化过多,以至于很难映射到数据上;过拟合则意味着算法过于复杂,它完美地适应了训练数据,但是很难普及。欠拟合、拟合、过拟合的效果5.3模型调优什么因素会导致过拟合和欠拟合的情况发生?首先是数据的分布是否足够均匀。还有一个因素是模型是否复杂度过高,把学习进行地太过彻底,将样本数据的所有特征几乎都学习进去了。这时模型学到了数据中过多的局部特征,噪声带来的过多的假特征造成模型的“泛化性”和识别准确率几乎达到最低点,于是用训练好的模型预测新的样本的时候会发现模型效果很差。解决过拟合要从以下两个方面入手。首先是限制模型的学习,使模型在学习特征时忽略部分特征,这样就可以降低模型学到局部特征和错误特征的概率,使得识别准确率得到优化;其次是在选择数据的时候要尽可能全面,并且符合实际情况。5.3模型调优通常根据不同的训练阶段,可分为3种调优方法。

开始训练前:可以预先分析数据集的特征,选择合适的函数以及优化器,适当地进行数据增强等。

开始训练:在这个阶段,也有很多调优方法。例如,动态调整学习率、自动保存最优模型以及提前停止训练等。训练结束后:通过对模型的评估,了解此时模型处于什么状态,可以调整模型结构或者调整训练参数等。

1

2

35.3模型调优5.3.1回调函数

由于在训练的过程中,无法对一些参数进行修改,因此一些深度学习框架往往会留出一些回调函数来对训练过程中的参数进行调整。下面将使用tf.keras中的回调函数解决训练过程中的模型保存、学习率调整和终止训练的问题。importtensorflowfilepath='model.h5'tensorflow.keras.callbacks.ModelCheckpoint(filepath,monitor='val_loss',verbose=0,save_best_only=False,save_weights_only=False,mode='auto',period=1)01OPTION自动保存模型使用Python实现自动保存模型的代码如下。5.3模型调优表5-2ModelCheckpoint函数的参数及其说明参数说明filepath模型保存的路径monitor需要监视的值verbose是否显示信息save_best_only只保存监视值最好的模型,默认为Falsesave_weights_only是否只保存权重(不保存结构)mode评判最佳模型的标准,如min、maxperiod每几个周期保存一次,如果设置保存最优模型则该参数不起作用,因为每个周期结束都会保存5.3模型调优tensorflow.keras.callbacks.ReduceLROnPlateau(monitor='val_loss',factor=0.1,patience=10,verbose=0,mode='auto',epsilon=0.0001,cooldown=0,min_lr=0)02OPTION学习率动态调整使用Python实现学习率动态调整的代码如下。参数说明monitor需要监视的值factor学习率(1r)衰减因子,新学习率newlr=lr*factorpatience当过去patience个周期,被监视的值还没往更好的方向走时(一般来说是变小),则触发学习率衰减verbose是否显示信息mode评判最佳模型的标准,如min、maxepsilon阈值,用来确定是否进入检测值的“平原区”cooldown学习率下降后,会经过cooldown个训练轮数才重新进行正常操作min_lr学习率的下限表5-3

ReduceLROnPlateau函数的参数及其说明5.3模型调优tensorflow.keras.callbacks.EarlyStopping(monitor='val_loss',patience=0,verbose=0,mode='auto')03OPTION自动终止训练使用Python实现自动终止训练的代码如下。表5-4EarlyStopping函数的参数及其说明参数说明monitor需要监视的值patience当发现监视值相比上一个训练轮数没有下降时,经过patience个训练轮数后停止训练verbose是否显示信息mode评判标准,如min、max5.3模型调优fromtensorflow.kerasimportcallbackssave=callbacks.ModelCheckpoint('logs/epoch{epoch:03d}-val_loss{val_loss:.3f}.h5',monitor='val_loss',save_best_only=True,period=1)low_lr=callbacks.ReduceLROnPlateau(monitor='val_loss',factor=0.2,patience=5,min_lr=1e-6,verbose=1)eary_stop=callbacks.EarlyStopping(monitor='val_loss',patience=15,verbose=1,mode='auto')model.fit(x,y,batch_size=64,epochs=500,callbacks=[save,low_lr,eary_stop])使用Python组合回调函数并进行训练的代码如下。5.3模型调优5.3.2超参数调整

人工智能和机器学习领域的知名学者吴恩达很形象地使用动物和食物来命名训练一个模型的两种方法:熊猫法与鱼子酱法。01OPTION熊猫法先初始化一组超参数,然后每训练一段时间(如一天,表示为D1)就需要查看进展,观察其是否按照预想的方向发展,再进行一定的微调,接着继续训练,持续观察。如果发现偏离了方向,就需要立即对超参数进行调整。就这样,持续定期观察并进行调整,直到最后达到训练目标。5.3模型调优02OPTION鱼子酱法如果计算资源足够丰富,可以同时训练多个模型,就可以用鱼子酱法——使用多种超参数组合的模型进行训练,在训练结束后通过评估指标判断哪种模型效果最好。(2)迭代次数迭代次数也叫训练轮数,模型收敛即可停止迭代。一般可采用验证集作为停止迭代的条件。如果连续几轮模型损失都没有相应减少,则停止迭代。5.3模型调优

超参数是控制模型结构、功能、效率等的“调节旋钮”。通过“调节旋钮”,可以控制神经网络模型训练的基本方向。一般而言,常用的超参数有以下几种。(1)学习率学习率是最影响性能的超参数之一,如果只能调整一个超参数,那么最好的选择就是它。相对于其他超参数,学习率以一种更加复杂的方式控制着模型的有效容量。模型的有效容量是指模型拟合各种函数的能力。当学习率最优时,模型的有效容量最大。12(3)批大小批大小也叫batch_size,对于小数据量的模型,可以进行全量训练,这样能更准确地朝着极值所在的方向更新。但是对于大数据量模型,进行全量训练将导致内存溢出,因此需要选择一个较小的批大小。35.3模型调优5.1.4F1值

F1值(F1Score)是统计学中用来衡量二分类模型精度的一种指标。它同时兼顾了分类模型的查准率和召回率。F1值可以看作模型查准率和召回率的一种调和平均,它的最大值是1,最小值是0。true=[0,1,0,1,0,1,0,1,1,1]pred=[0,1,0,1,0,0,0,1,1,1]#应当被检索到的样本数index_1_num=str(true).count("1")recall=0precision=0forindex,valueinenumerate(pred):使用Python计算F1值的代码如下:ifvalue==true[index]andvalue==1:recall+=1ifvalue==true[index]:precision+=1precision=precision/len(pred)*100recall=recall/index_1_num*100print('f1:%.2f%%'%((2*precision*recall)/(precision+recall)))5.3模型调优5.3.3模型结构调整

为了演示微调模型结构,本小节重新搭建了一个简单的MNIST数据集分类的模型,具体代码如下所示。importnumpyasnpimportmatplotlib.pyplotaspltfromtensorflow.keras.modelsimportSequentialfromtensorflow.keras.layersimportDense,Activationfromtensorflow.keras.optimizersimportSGDfromtensorflow.kerasimportutilsfromtensorflow.keras.datasetsimportmnistimportmatplotlib.pyplotasplt#加载数据集,并划分为训练集和测试集两个部分(x_train,y_train),(x_test,y_test)=mnist.load_data()plt.imshow(x_train[0])print(x_train[0])#归一化y_train=utils.to_categorical(y_train)y_test=utils.to_categorical(y_test)x_train=x_train.reshape(-1,784)/255.x_test=x_test.reshape(-1,784)/255.#模型结构model=Sequential([Dense(units=256,input_dim=784,activation='relu'),Dense(units=128,activation='relu'),Dense(units=10,activation='softmax')])5.3模型调优sgd=SGD(lr=0.2)pile(optimizer=sgd,loss='mse',metrics=['acc'])model.fit(x_train,y_train,batch_size=32,epochs=10)#测试loss,acc=model.evaluate(x_test,y_test)print(loss,acc)

程序运行结果如图所示,这时候的训练准确率是97.04%,测试准确率是96.55%。3层全连接层运行结果5.3模型调优

由于模型结构比较简单,因此想要提升准确率,直接的办法是提升模型的复杂度。接下来修改模型结构,新增几层全连接层,代码如下。model=Sequential([Dense(units=256,input_dim=784,activation='relu'),Dense(units=192,activation='relu'),

温馨提示

  • 1. 本站所有资源如无特殊说明,都需要本地电脑安装OFFICE2007和PDF阅读器。图纸软件为CAD,CAXA,PROE,UG,SolidWorks等.压缩文件请下载最新的WinRAR软件解压。
  • 2. 本站的文档不包含任何第三方提供的附件图纸等,如果需要附件,请联系上传者。文件的所有权益归上传用户所有。
  • 3. 本站RAR压缩包中若带图纸,网页内容里面会有图纸预览,若没有图纸预览就没有图纸。
  • 4. 未经权益所有人同意不得将文件中的内容挪作商业或盈利用途。
  • 5. 人人文库网仅提供信息存储空间,仅对用户上传内容的表现方式做保护处理,对用户上传分享的文档内容本身不做任何修改或编辑,并不能对任何下载内容负责。
  • 6. 下载文件中如有侵权或不适当内容,请与我们联系,我们立即纠正。
  • 7. 本站不保证下载资源的准确性、安全性和完整性, 同时也不承担用户因使用这些下载资源对自己和他人造成任何形式的伤害或损失。

评论

0/150

提交评论