本文是一个记录,用于记录自己如何快速上手 TF 的经历。
PS: 本文中的大部分概念基于莫烦的 MorvanZhou/Tensorflow-Tutorial
一、什么是 session
seesion 就是 TF 中的启动器,我们需要通过调用 session 来将数据传入整个计算图中,得到想要的结果,没有 session 就没有结果。
比如我们要计算:
$\begin{bmatrix}
- & 1. & 1. \
- & 1. & 1. \
\end{bmatrix} +
\begin{bmatrix} - & 2. & 3. \
- & 5. & 6. \
\end{bmatrix} =
\begin{bmatrix} - & 3. & 4. \
- & 6. & 7. \
\end{bmatrix}$
1 | import tensorflow as tf |
output:
Tensor(“Add_2:0”, shape=(2, 3), dtype=float32)
为什么没有得到我们想要的结果?
因为在我们写完上述代码之后,我们完成的是 “定义了一个可以运行的计算图”。如果我们需要启动这个计算图,我们还需要使用 tf.Session()。
1 | import tensorflow as tf |
二、什么是 variable
variable 就是变量的含义,通俗一点说就是“模型中的参数”,比如我们要训练一个线性模型 $y = k x + b$,那么 $x$ 和 $b$ 就是我们的 variable,值得注意的是,一旦定义了变量,在运行之前,要对变量进行初始化。
接下来,我们会定义一个变量 $x$,初始时候 $x=0$,之后每次运行都让 $x = x + 1$
1 | x = tf.Variable(0, name="x") |
三、什么是 placeholder
placeholder 翻译过来是占位符,通俗一点说就是“每次训练的过程中会输入到模型的数据”。我们在定义 placeholder 的时候需要将它的数据类型和数据格式一起传入。
比如下面会看到的 tf.float32 表示的就是每次训练扔进模型中的数据类型是 tf.float32。 而数据格式默认是 None 就是一维值也可以是多维,比如 [2,3] 表示2行、3列的数据格式。而 [None, 3] 表示列为3,行不定。
使用占位符的关键就是在 sess.run() 的时候使用 feed_dict 参数,将真实数据传入。
1 | import tensorflow as tf |
四、用 Neural Network 来拟合一个曲线(回归模型)
这里,我们想要做的是用 TF 来拟合一个非常简单的类似于 $y=x^2$ 的曲线。
用实现这个拟合的目标,我们会使用一个简单的两层神经网络来完成。这时候,我们会用到 tf.layers 这个模块。如果用 Pytorch 来类比,这就是基础版的 torch.nn,之所叫基础版本,是因为里面只给你简单的封装了一些常用的 layers, 不过也勉强够用了。
再介绍一下求梯度的过程, pytorch 中求梯度有两种方式 1: 用 backward() 求出梯度之后,手动对梯度进行处理;2:backward() 求出梯度之后,使用 step() 来自动为我们更新所有的梯度。
tf 中也是这样的,我们有两种方式: 1:optimizer.compute_gradients(loss) + optimizer.apply_gradients(); 2: optimizer.minimize(loss) 。
如果你不需要用 graident 做什么事,推荐使用第二种。
模拟数据如下图所示:
导入包加使用模拟数据
1 | import tensorflow as tf |
五、模型的存储和重载
很多时候,我们会需要使用预训练的模型,这样我们需要把模型重载。但是重载之前,我们得先学会如何保存模型。
tf 中模型的保存也比较简单,讲起来两步就能搞定。
1.实例化 Saver 类 (Saver 类是 tf 中用来存储和重载的类)
2.保存模型数据
代码形式如下:
1 | # 定义一个 Saver 类 |
我们保存完模型之后,如果下次要用这个模型,该怎么办呢? 同样使用 Saver 类就行了,代码如下:
1 | # 模型重新建立的过程 |
六、数据处理和分批
用过 Pytorch 的人都知道 Dataset 和 Dataloader 合起来用起来有多爽。那么 tf 有没有这么爽的东西呢?
有的,而且功能也差不多… 叫 tf.data.Dataset 和 tf.data.Iterator
先从一个简单的例子来看。
1 | # 创建一个Dataset对象 |
当我们需要封装的数据是比较简单的数据的时候,我们可以直接使用 tf.data.Dataset.from_tensor_slices() 将数据包装成 Dataset 类。然后调用 make_one_shot_iterator() 变成一个可以迭代器,最后通过 get_next() 帮助我们从迭代器中获取元素。
再从一个复杂一些的例子来看,比如现在我们需要从硬盘中读取图片信息。假设,我们知道文件夹下面放着的图片的信息,那么我们可以这样做。
1 | # 先建立一个常量保存所有图片的路径 |
最后再补充一些理论知识,看得懂就看,看不懂就随意吧…
tf.data.Dataset:表示一串元素(elements,tfrecors中的example),其中每个元素包含了一或多个Tensor对象。例如:在一个图片pipeline中,一个元素可以是单个训练样本,它们带有一个表示图片数据的tensors和一个label组成的pair。对于datasets其实理解为一个数据堆就行,我们可以在这个数据堆上进行多种操作,预处理、排序、batching等等。有两种不同的方式创建一个dataset:
- 创建一个source (例如:Dataset.from_tensor_slices()), 从一或多个tf.Tensor对象中构建一个dataset
- 应用一个transformation(例如:Dataset.batch()),从一或多个tf.data.Dataset对象上构建一个dataset
tf.data.Iterator:它提供了主要的方式来从一个dataset中抽取元素。通过Iterator.get_next() 返回的该操作会yields出Datasets中的下一个元素,作为输入pipeline和模型间的接口使用。
其实数据处理的过程有时候才是最花时间的过程,这里仅仅只是蜻蜓点水一般的略过了一下,还是得多看 API 文档
七、CNN 和 MNIST
终于到了用 CNN 来处理 MNIST 的环节了….