版权说明:本文档由用户提供并上传,收益归属内容提供方,若内容存在侵权,请进行举报或认领
文档简介
张明副教授人工智能原理:基于Python语言和TensorFlow第四章:TensorFlow运作方式数据的准备和下载图表构建与推理损失与训练状态检查与可视化状态检查与可视化评估模型评估图表的构建与输出4.1:数据的准备和下载以MNIST为例。4.1:数据的准备和下载1.提供Python源代码用于自动下载和安装这个数据集。********************************************************************************importtensorflow.examples.tutorials.mnist.input_datamnist=input_data.read_data_sets("MNIST_data/",one_hot=True)********************************************************************************2.确认并解压缩文件********************************************************************************data_sets=input_data.read_data_sets(FLAGS.train_dir,FLAGS.fake_data)********************************************************************************第四章:TensorFlow运作方式数据的准备和下载图表构建与推理损失与训练状态检查与可视化状态检查与可视化评估模型评估图表的构建与输出4.2:图表构建与推理1图表构建2推理运行mnist.py文件了,经过三阶段的模式函数操作:inference(),loss()和training(),图表就构建完成了。(1)inference()——尽可能地构建好图表,满足促使神经网络向前反馈并做出预测的要求。(2)loss()——往inference图表中添加生成损失(loss)所需要的操作(ops)。(3)training()——往损失图表中添加计算并应用梯度(gradients)所需的操作。在run_training()这个函数的一开始,是一个Python语言中的with命令,这个命令表明所有已经构建的操作都要与默认的tf.Graph全局实例关联起来,代码如下所示。withtf.Graph().as_default():tf.Graph实例是一系列可以作为整体执行的操作。TensorFlow的大部分场景只需要依赖默认图表一个实例即可。4.2:图表构建与推理1图表构建2推理inference()函数会尽可能地构建图表,做到返回包含了预测结果(outputprediction)的张量。4.3:损失与训练1损失2训练1:损失loss()函数通过添加所需的损失操作,进一步构建图表。首先,labels_placeholer中的值将被编码为一个含有1-hotvalues的张量。例如,如果类标识符为3,那么该值就会被转换为[0,0,0,1,0,0,0,0,0,0],代码如下所示。**************************************************************************************batch_size=tf.size(labels)labels=tf.expand_dims(labels,1)indices=tf.expand_dims(tf.range(0,batch_size,1),1)concated=tf.concat(1,[indices,labels])onehot_labels=tf.sparse_to_dense(concated,tf.pack([batch_size,NUM_CLASSES]),1.0,0.0)**************************************************************************************添加一个tf.nn.softmax_cross_entropy_with_logits操作,用来比较inference()函数与1-hot标签所输出的logits张量,代码如下所示。**************************************************************************************cross_entropy=tf.nn.softmax_cross_entropy_with_logits(logits,onehot_labels,name='xentropy')**************************************************************************************使用tf.reduce_mean函数,计算batch维度(第一维度)下交叉熵(crossentropy)的平均值,将该值作为总损失,代码如下所示**************************************************************************************loss=tf.reduce_mean(cross_entropy,name='xentropy_mean')**************************************************************************************最后,返回包含了损失值的张量。4.3:损失与训练1损失2训练2:训练TensorFlow有大量内置的优化算法。这个例子中,我们用梯度下降(gradientdescent)法让交叉熵下降,步长为0.01。函数从loss()函数中获取损失张量,将其交给tf.scalar_summary,后者在SummaryWriter配合使用时,可以向事件文件(eventsfile)中生成汇总值(summaryvalues)。**************************************************************************************tf.scalar_summary(,loss)**************************************************************************************返回的train_step操作对象,在运行时会使用梯度下降来更新参数**************************************************************************************foriinrange(1000):batch=mnist.train.next_batch(50)train_step.run(feed_dict={x:batch[0],y_:batch[1]})**************************************************************************************实例化一个tf.train.GradientDescentOptimizer,负责按照所要求的学习效率(learningrate)应用梯度下降法,代码如下所示。**************************************************************************************optimizer=tf.train.GradientDescentOptimizer(FLAGS.learning_rate)**************************************************************************************train_op操作**************************************************************************************global_step=tf.Variable(0,name='global_step',trainable=False)train_op=optimizer.minimize(loss,global_step=global_step)**************************************************************************************程序返回包含了训练操作(trainingop)输出结果的张量。4.4:状态检查与可视化1
状态检查2状态可视化在运行sess.run()函数时,要在代码中明确其需要获取的两个值:[train_op,loss],代码如下所示。**************************************************************************************forstepinxrange(FLAGS.max_steps):feed_dict=fill_feed_dict(data_sets.train,images_placeholder,labels_placeholder)_,loss_value=sess.run([train_op,loss],feed_dict=feed_dict)**************************************************************************************假设训练一切正常,没有出现NaN,训练循环会每隔100个训练步骤打印一行简单的状态文本,告知用户当前的训练状态,代码如下所示。**************************************************************************************ifstep%100==0:print'Step%d:loss=%.2f(%.3fsec)'%(step,loss_value,duration)**************************************************************************************4.4:状态检查与可视化1
状态检查2状态可视化为了释放TensorBoard所使用的事件文件(eventsfile),所有的即时数据都要在图表构建阶段合并至一个操作(op)中,代码如下所示。**************************************************************************************summary_op=tf.merge_all_summaries()**************************************************************************************为实例化一个tf.train.SummaryWriter,用于写入包含了图表本身和即时数据具体值的事件文件,代码如下所示。**************************************************************************************summary_writer=tf.train.SummaryWriter(FLAGS.train_dir,graph_def=sess.graph_def)**************************************************************************************每次运行summary_op时,都会往事件文件中写入最新的即时数据,函数的输出会传入事件文件读写器(writer)的add_summary()函数,代码如下所示。**************************************************************************************summary_str=sess.run(summary_op,feed_dict=feed_dict)summary_writer.add_summary(summary_str,step)**************************************************************************************事件文件写入完毕之后,可以在训练文件夹中打开一个TensorBoard,查看即时数据的情况,如图所示。TensorBoard中汇总数据(summarydata)的大体生命周期如下。首先,创建你想汇总数据的TensorFlow图,然后选择你想在哪个节点进行汇总(summary)操作。然后,你可以执行合并命令,它会依据特定步骤将所有数据生成一个序列化的Summaryprotobuf对象。最后,为了将汇总数据写入磁盘,需要将汇总的protobuf对象传递给tf.train.SummaryWriter。4.5:评估模型我们的模型性能如何呢?4.6:评估图表的构建与输出1评估图表的构建2评估图表的输出调用get_data(train=False)函数,抓取测试数据集。**************************************************************************************test_all_images,test_all_labels=get_data(train=False)**************************************************************************************在进入训练循环之前,我们应该先调用mnist.py文件中的evaluation函数,传入的logits和标签参数要与loss函数的一致。这样做是为了先构建Eval操作。**************************************************************************************eval_correct=mnist.evaluation(logits,labels_placeholder)**************************************************************************************把K的值设置为1,也就是只有在预测是真的标签时,才判定它是正确的把K的值设置为1,也就是只有在预测是真的标签时,才判定它是正确的。************************************************************************************eval_correct=tf.nn.in_top_k(logits,labels,1)************************************************************************************4.6:评估图表的构建与输出1评估图表的构建2评估图表的输出用给定的数据集评估模型**********************************************************************
温馨提示
- 1. 本站所有资源如无特殊说明,都需要本地电脑安装OFFICE2007和PDF阅读器。图纸软件为CAD,CAXA,PROE,UG,SolidWorks等.压缩文件请下载最新的WinRAR软件解压。
- 2. 本站的文档不包含任何第三方提供的附件图纸等,如果需要附件,请联系上传者。文件的所有权益归上传用户所有。
- 3. 本站RAR压缩包中若带图纸,网页内容里面会有图纸预览,若没有图纸预览就没有图纸。
- 4. 未经权益所有人同意不得将文件中的内容挪作商业或盈利用途。
- 5. 人人文库网仅提供信息存储空间,仅对用户上传内容的表现方式做保护处理,对用户上传分享的文档内容本身不做任何修改或编辑,并不能对任何下载内容负责。
- 6. 下载文件中如有侵权或不适当内容,请与我们联系,我们立即纠正。
- 7. 本站不保证下载资源的准确性、安全性和完整性, 同时也不承担用户因使用这些下载资源对自己和他人造成任何形式的伤害或损失。
最新文档
- 2022-2027年中国贝伐珠单抗行业市场全景评估及发展战略规划报告
- 箱变基础预制施工方案
- 保险服务品牌转让居间服务
- 4S店装修贷款协议
- 杭州市乒乓球馆租赁合同
- 湖北医药学院药护学院《新媒体运营与管理》2023-2024学年第一学期期末试卷
- 2025年度铁路轨道施工安全监管合同范本2篇
- 2025年度食品加工合同:原材料供应商与食品加工厂(2025版)3篇
- 2025年度融资租赁合同的租金支付方式3篇
- 2025年度运动服饰品牌授权与销售合同3篇
- 【大学课件】微型计算机系统
- (主城一诊)重庆市2025年高2025届高三学业质量调研抽测 (第一次)英语试卷(含答案)
- 2025关于标准房屋装修合同的范本
- 中国建材集团有限公司招聘笔试冲刺题2025
- 2024年马克思主义基本原理知识竞赛试题70题(附答案)
- 2024年湖北省中考物理真题含解析
- 荔枝病虫害防治技术规程
- 资金借贷还款协议
- 《实验性研究》课件
- 中国革命战争的战略问题(全文)
- 《阻燃材料与技术》课件全套 颜龙 第1讲 绪论 -第11讲 阻燃性能测试方法及分析技术
评论
0/150
提交评论