资讯专栏INFORMATION COLUMN

mnist 机器学习入门笔记(一) 学习softmax模型

shengguo / 2382人阅读

摘要:首先需要添加一个新的占位符用于输入正确值计算交叉熵的表达式可以实现为现在我们知道我们需要我们的模型做什么啦,用来训练它是非常容易的。

学习softmax回归模型 一. 下载mnist数据集

新建一个download.py 代码如下:

"""Functions for downloading and reading MNIST data."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import gzip
import os
import tempfile

import numpy
from six.moves import urllib
from six.moves import xrange  # pylint: disable=redefined-builtin
import tensorflow as tf
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets

运行则会自动下载mnist数据集

二. softmax模型参数

mnist.train.images(像素点):下载的mnist中的mnist.train包含60000万张简单的验证码图片, 我们将每张图片看作 28 * 28 = 784个像素点。如此我们可以用两个维度量表示整个数据集,维度一:图片序号, 维度二:像素点序号。 那么整个数据集最大像素点为[60000, 784]
mnist.train.labels(标签):接下来我们的任务是识别每张图片中的数字, 所以我们给每张图片设立一个标签, 标签值介于0~9之间(共十个值), 所以那个数据集的标签就可以做成两个维度, 维度一: 图片序号(60000), 维度二:标签值序号(10), 那么最大的标签可以表示为[60000, 10]

三. softmax数学推导

对于这里的数学推导,我就不过多说了。只能赞叹人类的智慧是伟大的, 然后简单分析下,不会数学推导的,我们可以这样来理解,分析一张图片的标签到底是数字几, 我们需要看图片中的每个像素点像数字几, 我们将每个像素点像某个标签的概率进行加权计算 w表示784个像素点中每个像素点更像数字几的加权, 然后再加上最终计算出的数字的干扰偏置量b即可。 大致理解和最终推导式

四. softmax实现

导入tensorflow
import tensorflow as tf
定义像素:
x = tf.placeholder(tf.float32, [None, 784])
x不是一个特定的值,而是一个占位符placeholder,我们在TensorFlow运行计算时输入这个值。
如此我们希望能输入任意数量的图片,所以在像素点参数中,第一个维度是无法确定的,所以我们用[None,784 ]来表示
定义w:
W = tf.Variable(tf.zeros([784,10]))
定义b:
b = tf.Variable(tf.zeros([10]))
实现softmax等式:
y = tf.nn.softmax(tf.matmul(x,W) + b)

五.训练模型

评估模型我们使用交叉熵作为成本函数。

首先需要添加一个新的占位符用于输入正确值:
y_ = tf.placeholder("float", [None,10])
计算交叉熵的表达式可以实现为:
cross_entropy = -tf.reduce_sum(y_*tf.log(y))
现在我们知道我们需要我们的模型做什么啦,用TensorFlow来训练它是非常容易的。因为TensorFlow拥有一张描述你各个计算单元的图,它可以自动地使用反向传播算法(backpropagation algorithm)来有效地确定你的变量是如何影响你想要最小化的那个成本值的。然后,TensorFlow会用你选择的优化算法来不断地修改变量以降低成本。
这里我们使用梯度下降算法来计算梯度下降算法:
qtrain_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)
我们已经设置好了我们的模型。在运行计算之前,我们需要添加一个操作来初始化我们创建的变量:
init = tf.initialize_all_variables()
启动我们的模型,并且初始化变量:
sess = tf.Session()
sess.run(init)
然后开始训练模型,这里我们让模型循环训练1000次!

for i in range(1000):
  batch_xs, batch_ys = mnist.train.next_batch(100)
  sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
评估模型

tf.argmax 是一个非常有用的函数,它能给出某个tensor对象在某一维上的其数据最大值所在的索引值。由于标签向量是由0,1组成,因此最大值1所在的索引位置就是类别标签,比如tf.argmax(y,1)返回的是模型对于任一输入x预测到的标签值,而 tf.argmax(y_,1) 代表正确的标签,我们可以用 tf.equal 来检测我们的预测是否真实标签匹配。
correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
该函数返回单个实例等正确性,返回结果为bool值。所以我们需要把结果转化为浮点数然后再求取平均值。
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
最后我们获得正确率为:
print sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels})

文章版权归作者所有,未经允许请勿转载,若此文章存在违规行为,您可以联系管理员删除。

转载请注明本文地址:https://www.ucloud.cn/yun/40794.html

相关文章

  • 深度学习

    摘要:深度学习在过去的几年里取得了许多惊人的成果,均与息息相关。机器学习进阶笔记之一安装与入门是基于进行研发的第二代人工智能学习系统,被广泛用于语音识别或图像识别等多项机器深度学习领域。零基础入门深度学习长短时记忆网络。 多图|入门必看:万字长文带你轻松了解LSTM全貌 作者 | Edwin Chen编译 | AI100第一次接触长短期记忆神经网络(LSTM)时,我惊呆了。原来,LSTM是神...

    Vultr 评论0 收藏0
  • tensorflow入门与实战

    TensorFlow是一种流行的机器学习库,它提供了许多工具和技术,使得机器学习和深度学习变得更加容易。在这篇文章中,我们将介绍TensorFlow的入门和实战技术,帮助您开始使用这个强大的工具。 首先,让我们来了解一下TensorFlow的基础知识。TensorFlow是一个用于数值计算的开源软件库,它使用数据流图来表示数学运算。数据流图是一种图形表示法,它将数学运算表示为节点,将数据表示为边...

    _Zhao 评论0 收藏400

发表评论

0条评论

shengguo

|高级讲师

TA的文章

阅读更多
最新活动
阅读需要支付1元查看
<