1. 准备
  2. 训练
  3. 验证
  4. 测试

参考 https://github.com/tensorflow/models/tree/master/research/slim

使用TensorFlow-Slim进行图像分类

准备

  1. 安装TensorFlow

    参考 https://www.tensorflow.org/install/

    例如,在Ubuntu下为Python 2.7安装支持GPU的TensorFlow 1.2:

    1
    2
    wget https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.2.0-cp27-none-linux_x86_64.whl
    pip install tensorflow_gpu-1.2.0-cp27-none-linux_x86_64.whl

  2. 下载TF-slim图像模型库

    1
    2
    cd $WORKSPACE
    git clone https://github.com/tensorflow/models/

  3. 准备数据

    有不少公开数据集,这里以官网提供的Flowers为例。

    官网提供了下载和转换数据的代码,为了理解代码并能使用自己的数据,这里参考官方提供的代码进行修改。

    1
    2
    3
    cd $WORKSPACE/data
    wget http://download.tensorflow.org/example_images/flower_photos.tgz
    tar zxf flower_photos.tgz

    数据集文件夹结构如下:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    flower_photos
    ├── daisy
    │   ├── 100080576_f52e8ee070_n.jpg
    │   └── ...
    ├── dandelion
    ├── LICENSE.txt
    ├── roses
    ├── sunflowers
    └── tulips

    实际应用中,我们自己的数据集不一定把图片按类别放在不同的文件夹里,因此这里生成list.txt来记录图片路径与标签的对应关系,每行格式为“<相对路径> <标签id>”。

    Python代码:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    import os

    class_names_to_ids = {'daisy': 0, 'dandelion': 1, 'roses': 2, 'sunflowers': 3, 'tulips': 4}
    data_dir = 'flower_photos/'
    output_path = 'list.txt'


    def write_image_list(data_dir, output_path, class_names_to_ids):
        """Write one "<class_name>/<image_name> <label_id>" line per image.

        Args:
          data_dir: root directory holding one sub-directory per class.
          output_path: text file to (over)write.
          class_names_to_ids: maps each class sub-directory name to its integer label.
        """
        with open(output_path, 'w') as fd:
            # Iterate classes by label id and images by file name so the list is
            # identical on every machine: os.listdir() order (and dict order in
            # Python 2) is arbitrary, and the seeded shuffle used for the
            # train/val split is only reproducible if its input order is fixed.
            for class_name in sorted(class_names_to_ids, key=class_names_to_ids.get):
                label = class_names_to_ids[class_name]
                for image_name in sorted(os.listdir(os.path.join(data_dir, class_name))):
                    if image_name.startswith('.'):
                        # Skip hidden files such as .DS_Store; they are not images.
                        continue
                    fd.write('{}/{} {}\n'.format(class_name, image_name, label))


    if __name__ == '__main__':
        write_image_list(data_dir, output_path, class_names_to_ids)

    为了方便后期查看标签对应的类别名,也可以定义labels.txt:每行一个类别名,第i行(从0开始计数)对应标签i,因此顺序必须与上面class_names_to_ids中的编号一致。

    1
    2
    3
    4
    5
    daisy
    dandelion
    roses
    sunflowers
    tulips

    随机划分训练集与验证集(从list.txt中随机抽取350张作为验证集,其余作为训练集):

    Python代码:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    import random

    _NUM_VALIDATION = 350
    _RANDOM_SEED = 0
    list_path = 'list.txt'
    train_list_path = 'list_train.txt'
    val_list_path = 'list_val.txt'


    def split_list(list_path, train_list_path, val_list_path,
                   num_validation=_NUM_VALIDATION, seed=_RANDOM_SEED):
        """Randomly split `list_path` into a training list and a validation list.

        After a seeded shuffle, the first `num_validation` entries go to
        `val_list_path` and the rest to `train_list_path`. Every written line
        ends with exactly one newline.

        Raises:
          ValueError: if `num_validation` is negative or leaves no training entries.
        """
        with open(list_path) as fd:
            # Strip and re-terminate every entry: a last line without '\n' would
            # otherwise be glued onto its neighbour once shuffled into the middle.
            # Blank lines are dropped.
            entries = [line.strip() for line in fd if line.strip()]
        if not 0 <= num_validation < len(entries):
            raise ValueError(
                'num_validation ({}) must be >= 0 and smaller than the {} entries in {}'.format(
                    num_validation, len(entries), list_path))
        # A private generator gives the same order as random.seed(seed) followed
        # by random.shuffle(), without resetting the global random state.
        random.Random(seed).shuffle(entries)
        with open(train_list_path, 'w') as fd:
            fd.writelines(entry + '\n' for entry in entries[num_validation:])
        with open(val_list_path, 'w') as fd:
            fd.writelines(entry + '\n' for entry in entries[:num_validation])


    if __name__ == '__main__':
        split_list(list_path, train_list_path, val_list_path)

    生成TFRecord数据:

    Python代码:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    import sys
    sys.path.insert(0, '../models/slim/')
    from datasets import dataset_utils
    import math
    import os
    import tensorflow as tf

    def convert_dataset(list_path, data_dir, output_dir, _NUM_SHARDS=5):
        """Convert the images listed in `list_path` into sharded TFRecord files.

        Args:
          list_path: text file with one "<relative image path> <integer label>"
            entry per line, such as list_train.txt or list_val.txt.
          data_dir: directory that the image paths in `list_path` are relative to.
          output_dir: existing directory that receives the
            data_XXXXX-of-YYYYY.tfrecord shards.
          _NUM_SHARDS: number of output shards; entries are written in
            consecutive chunks of ceil(len(entries) / _NUM_SHARDS).
        """
        with open(list_path) as fd:
            # Ignore blank lines (e.g. an extra empty line at the end): split()
            # would turn them into [] and lines[i][0] would raise IndexError.
            lines = [line.split() for line in fd if line.strip()]
        num_per_shard = int(math.ceil(len(lines) / float(_NUM_SHARDS)))
        with tf.Graph().as_default():
            # The decode op is only used to read each image's height and width;
            # the record itself stores the original encoded bytes.
            decode_jpeg_data = tf.placeholder(dtype=tf.string)
            decode_jpeg = tf.image.decode_jpeg(decode_jpeg_data, channels=3)
            with tf.Session('') as sess:
                for shard_id in range(_NUM_SHARDS):
                    output_path = os.path.join(
                        output_dir,
                        'data_{:05}-of-{:05}.tfrecord'.format(shard_id, _NUM_SHARDS))
                    start_ndx = shard_id * num_per_shard
                    end_ndx = min((shard_id + 1) * num_per_shard, len(lines))
                    tfrecord_writer = tf.python_io.TFRecordWriter(output_path)
                    try:
                        for i in range(start_ndx, end_ndx):
                            sys.stdout.write('\r>> Converting image {}/{} shard {}'.format(
                                i + 1, len(lines), shard_id))
                            sys.stdout.flush()
                            image_path = os.path.join(data_dir, lines[i][0])
                            # Close every image file right away; thousands are read.
                            with tf.gfile.FastGFile(image_path, 'rb') as image_file:
                                image_data = image_file.read()
                            image = sess.run(decode_jpeg,
                                             feed_dict={decode_jpeg_data: image_data})
                            height, width = image.shape[0], image.shape[1]
                            example = dataset_utils.image_to_tfexample(
                                image_data, b'jpg', height, width, int(lines[i][1]))
                            tfrecord_writer.write(example.SerializeToString())
                    finally:
                        # Flush and close the shard even if an image fails to decode.
                        tfrecord_writer.close()
        sys.stdout.write('\n')
        sys.stdout.flush()


    if __name__ == '__main__':
        # Python 2 has no os.makedirs(..., exist_ok=True); check first instead of
        # shelling out to `mkdir -p`.
        if not os.path.isdir('train'):
            os.makedirs('train')
        convert_dataset('list_train.txt', 'flower_photos', 'train/')
        if not os.path.isdir('val'):
            os.makedirs('val')
        convert_dataset('list_val.txt', 'flower_photos', 'val/')

    得到的文件夹结构如下:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    data
    ├── flower_photos
    ├── labels.txt
    ├── list_train.txt
    ├── list.txt
    ├── list_val.txt
    ├── train
    │   ├── data_00000-of-00005.tfrecord
    │   ├── ...
    │   └── data_00004-of-00005.tfrecord
    └── val
    ├── data_00000-of-00005.tfrecord
    ├── ...
    └── data_00004-of-00005.tfrecord

  4. (可选)下载模型

    官方提供了不少预训练模型,这里以Inception-ResNet-v2为例。

    1
    2
    3
    cd $WORKSPACE/checkpoints
    wget http://download.tensorflow.org/models/inception_resnet_v2_2016_08_30.tar.gz
    tar zxf inception_resnet_v2_2016_08_30.tar.gz

训练

  1. 读入数据

    官方提供了读入Flowers数据集的代码models/slim/datasets/flowers.py,同样这里也是参考并修改成能读入上面定义的通用数据集。

    把下面代码写入models/slim/datasets/dataset_classification.py

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    import os
    import tensorflow as tf
    slim = tf.contrib.slim

    def get_dataset(dataset_dir, num_samples, num_classes, labels_to_names_path=None, file_pattern='*.tfrecord'):
        """Return a slim Dataset over the TFRecord shards in `dataset_dir`.

        Each record is decoded using the features listed below
        ('image/encoded', 'image/format', 'image/class/label'). These are
        presumably the ones written by dataset_utils.image_to_tfexample in the
        conversion step; verify this if you use another writer.

        Args:
          dataset_dir: directory holding the TFRecord shards.
          num_samples: number of records; passed through as Dataset.num_samples
            and not checked against the files.
          num_classes: number of classes; labels are described as 0..num_classes-1.
          labels_to_names_path: optional text file with one class name per line;
            line i (0-based) names label i.
          file_pattern: glob, relative to `dataset_dir`, selecting the shards.

        Returns:
          A slim.dataset.Dataset providing the items 'image' and 'label'.
        """
        file_pattern = os.path.join(dataset_dir, file_pattern)
        keys_to_features = {
            'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
            'image/format': tf.FixedLenFeature((), tf.string, default_value='jpg'),
            'image/class/label': tf.FixedLenFeature(
                [], tf.int64, default_value=tf.zeros([], dtype=tf.int64)),
        }
        items_to_handlers = {
            'image': slim.tfexample_decoder.Image(),
            'label': slim.tfexample_decoder.Tensor('image/class/label'),
        }
        decoder = slim.tfexample_decoder.TFExampleDecoder(keys_to_features, items_to_handlers)
        items_to_descriptions = {
            'image': 'A color image of varying size.',
            'label': 'A single integer between 0 and ' + str(num_classes - 1),
        }
        labels_to_names = None
        if labels_to_names_path is not None:
            with open(labels_to_names_path) as fd:
                names = [line.strip() for line in fd]
            # Drop empty lines at the end of the file so they do not become bogus
            # extra labels; indices of the real names are unchanged.
            while names and not names[-1]:
                names.pop()
            labels_to_names = dict(enumerate(names))
        return slim.dataset.Dataset(
            data_sources=file_pattern,
            reader=tf.TFRecordReader,
            decoder=decoder,
            num_samples=num_samples,
            items_to_descriptions=items_to_descriptions,
            num_classes=num_classes,
            labels_to_names=labels_to_names)

  2. 构建模型

    官方提供了许多模型在models/slim/nets/

    如需要自定义模型,则参考官方提供的模型并放在对应的文件夹即可。

  3. 开始训练

    官方提供了训练脚本,如果使用官方的数据读入和处理,可使用以下方式开始训练。

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    cd $WORKSPACE/models/slim
    CUDA_VISIBLE_DEVICES="0" python train_image_classifier.py \
    --train_dir=train_logs \
    --dataset_name=flowers \
    --dataset_split_name=train \
    --dataset_dir=../../data/flowers \
    --model_name=inception_resnet_v2 \
    --checkpoint_path=../../checkpoints/inception_resnet_v2_2016_08_30.ckpt \
    --checkpoint_exclude_scopes=InceptionResnetV2/Logits,InceptionResnetV2/AuxLogits \
    --trainable_scopes=InceptionResnetV2/Logits,InceptionResnetV2/AuxLogits \
    --max_number_of_steps=1000 \
    --batch_size=32 \
    --learning_rate=0.01 \
    --learning_rate_decay_type=fixed \
    --save_interval_secs=60 \
    --save_summaries_secs=60 \
    --log_every_n_steps=10 \
    --optimizer=rmsprop \
    --weight_decay=0.00004

    如果不进行fine-tune(即从头训练),则删掉--checkpoint_path、--checkpoint_exclude_scopes和--trainable_scopes。

    如果要fine-tune所有层,则只删掉--trainable_scopes;--checkpoint_exclude_scopes需要保留,因为新数据集的类别数与预训练模型不同,Logits和AuxLogits层的参数无法从checkpoint中恢复。

    如果只使用CPU则加上--clone_on_cpu=True

    其它参数可删掉用默认值或自行修改。

    使用自己的数据时,需要对models/slim/train_image_classifier.py做以下三处修改:


    1
    from datasets import dataset_factory

    修改为

    1
    from datasets import dataset_classification


    1
    2
    dataset = dataset_factory.get_dataset(
    FLAGS.dataset_name, FLAGS.dataset_split_name, FLAGS.dataset_dir)

    修改为

    1
    2
    # Read the TFRecord shards in --dataset_dir with the generic reader defined in
    # datasets/dataset_classification.py instead of dataset_factory's named datasets.
    dataset = dataset_classification.get_dataset(
    FLAGS.dataset_dir, FLAGS.num_samples, FLAGS.num_classes, FLAGS.labels_to_names_path)


    1
    2
    tf.app.flags.DEFINE_string(
    'dataset_dir', None, 'The directory where the dataset files are stored.')

    后加入

    1
    2
    3
    4
    5
    6
    7
    8
    # Number of records in the TFRecords under --dataset_dir; passed to
    # get_dataset() as num_samples (3320 is the training split in this example).
    tf.app.flags.DEFINE_integer(
    'num_samples', 3320, 'Number of samples.')

    # Number of classes; passed to get_dataset() as num_classes.
    tf.app.flags.DEFINE_integer(
    'num_classes', 5, 'Number of classes.')

    # Optional labels.txt (one class name per line) used to name the labels.
    tf.app.flags.DEFINE_string(
    'labels_to_names_path', None, 'Label names file path.')

    训练时执行以下命令即可:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    cd $WORKSPACE/models/slim
    python train_image_classifier.py \
    --train_dir=train_logs \
    --dataset_dir=../../data/train \
    --num_samples=3320 \
    --num_classes=5 \
    --labels_to_names_path=../../data/labels.txt \
    --model_name=inception_resnet_v2 \
    --checkpoint_path=../../checkpoints/inception_resnet_v2_2016_08_30.ckpt \
    --checkpoint_exclude_scopes=InceptionResnetV2/Logits,InceptionResnetV2/AuxLogits \
    --trainable_scopes=InceptionResnetV2/Logits,InceptionResnetV2/AuxLogits

  4. 可视化log

    可一边训练一边可视化训练的log,可看到Loss趋势。

    1
    tensorboard --logdir train_logs/

验证

官方提供了验证脚本。

1
2
3
4
5
6
7
python eval_image_classifier.py \
--checkpoint_path=train_logs \
--eval_dir=eval_logs \
--dataset_name=flowers \
--dataset_split_name=validation \
--dataset_dir=../../data/flowers \
--model_name=inception_resnet_v2

同样,如果使用自己的数据集,则需要对models/slim/eval_image_classifier.py做相同的三处修改(注意num_samples的默认值改为验证集的样本数350):


1
from datasets import dataset_factory

修改为

1
from datasets import dataset_classification


1
2
dataset = dataset_factory.get_dataset(
FLAGS.dataset_name, FLAGS.dataset_split_name, FLAGS.dataset_dir)

修改为

1
2
# Read the TFRecord shards in --dataset_dir with the generic reader defined in
# datasets/dataset_classification.py instead of dataset_factory's named datasets.
dataset = dataset_classification.get_dataset(
FLAGS.dataset_dir, FLAGS.num_samples, FLAGS.num_classes, FLAGS.labels_to_names_path)


1
2
tf.app.flags.DEFINE_string(
'dataset_dir', None, 'The directory where the dataset files are stored.')

后加入

1
2
3
4
5
6
7
8
# Number of records in the TFRecords under --dataset_dir; passed to
# get_dataset() as num_samples (350 is the validation split, _NUM_VALIDATION).
tf.app.flags.DEFINE_integer(
'num_samples', 350, 'Number of samples.')

# Number of classes; passed to get_dataset() as num_classes.
tf.app.flags.DEFINE_integer(
'num_classes', 5, 'Number of classes.')

# Optional labels.txt (one class name per line) used to name the labels.
tf.app.flags.DEFINE_string(
'labels_to_names_path', None, 'Label names file path.')

验证时执行以下命令即可:

1
2
3
4
5
6
7
python eval_image_classifier.py \
--checkpoint_path=train_logs \
--eval_dir=eval_logs \
--dataset_dir=../../data/val \
--num_samples=350 \
--num_classes=5 \
--model_name=inception_resnet_v2

可以一边训练一边验证,但要注意使用其它的GPU或合理分配显存。

验证的log同样可以用TensorBoard可视化;如果已经有一个TensorBoard在显示训练的log,则需要指定另一个端口,如:

1
tensorboard --logdir eval_logs/ --port 6007

测试

参考models/slim/eval_image_classifier.py,可编写批量读取图片用模型进行推导的脚本models/slim/test_image_classifier.py

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import json
import math
import time
import numpy as np
import tensorflow as tf

from nets import nets_factory
from preprocessing import preprocessing_factory

slim = tf.contrib.slim

# Appears to be carried over from eval_image_classifier.py; main() never reads it.
tf.app.flags.DEFINE_string(
'master', '', 'The address of the TensorFlow master to use.')

# A checkpoint file, or a directory from which main() restores the latest
# checkpoint.
tf.app.flags.DEFINE_string(
'checkpoint_path', '/tmp/tfmodel/',
'The directory where the model was written to or an absolute path to a '
'checkpoint file.')

# Text file listing the test images, one entry per line; see main() for how
# each line is turned into an image path. Required.
tf.app.flags.DEFINE_string(
'test_list', '', 'Test image list.')

# Directory that the entries of --test_list are joined onto.
tf.app.flags.DEFINE_string(
'test_dir', '.', 'Test image directory.')

# Number of images fed to the network per sess.run call.
tf.app.flags.DEFINE_integer(
'batch_size', 16, 'Batch size.')

# Together with --labels_offset, sets the network's output size, so it should
# match the checkpoint being restored.
tf.app.flags.DEFINE_integer(
'num_classes', 5, 'Number of classes.')

# Subtracted from --num_classes when the network is built.
tf.app.flags.DEFINE_integer(
'labels_offset', 0,
'An offset for the labels in the dataset. This flag is primarily used to '
'evaluate the VGG and ResNet architectures which do not use a background '
'class for the ImageNet dataset.')

# Architecture name passed to nets_factory.get_network_fn().
tf.app.flags.DEFINE_string(
'model_name', 'inception_v3', 'The name of the architecture to evaluate.')

# Preprocessing name passed to preprocessing_factory; falls back to --model_name.
tf.app.flags.DEFINE_string(
'preprocessing_name', None, 'The name of the preprocessing to use. If left '
'as `None`, then the model_name flag is used.')

# Side length of the square network input; falls back to the network's
# default_image_size when unset.
tf.app.flags.DEFINE_integer(
'test_image_size', None, 'Eval image size')

FLAGS = tf.app.flags.FLAGS


def main(_):
    """Run batched top-k inference on every image listed in --test_list.

    Each non-blank line of --test_list is split on whitespace and only the
    first token is used as the image path (relative to --test_dir). Both a
    plain path list and a "path label" list such as list_val.txt are
    therefore accepted. For each image, one line "<path> [top-k class ids]"
    is printed, followed by a timing summary.

    Raises:
      ValueError: if --test_list is not given or no checkpoint is found.
    """
    if not FLAGS.test_list:
        raise ValueError('You must supply the test list with --test_list')

    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default():
        # Creates the global_step variable; the default Saver below then also
        # restores it from the checkpoint.
        slim.get_or_create_global_step()

        ####################
        # Select the model #
        ####################
        num_classes = FLAGS.num_classes - FLAGS.labels_offset
        network_fn = nets_factory.get_network_fn(
            FLAGS.model_name,
            num_classes=num_classes,
            is_training=False)

        #####################################
        # Select the preprocessing function #
        #####################################
        preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
        image_preprocessing_fn = preprocessing_factory.get_preprocessing(
            preprocessing_name,
            is_training=False)

        test_image_size = FLAGS.test_image_size or network_fn.default_image_size

        if tf.gfile.IsDirectory(FLAGS.checkpoint_path):
            checkpoint_path = tf.train.latest_checkpoint(FLAGS.checkpoint_path)
            if checkpoint_path is None:
                raise ValueError(
                    'No checkpoint found in {}'.format(FLAGS.checkpoint_path))
        else:
            checkpoint_path = FLAGS.checkpoint_path

        batch_size = FLAGS.batch_size
        tensor_input = tf.placeholder(
            tf.float32, [None, test_image_size, test_image_size, 3])
        logits, _ = network_fn(tensor_input)
        # top_k raises if k exceeds the number of classes.
        top_k = tf.nn.top_k(logits, min(5, num_classes))

        # Decode + preprocess ops are built once and fed raw JPEG bytes, so the
        # graph does not grow with the number of test images.
        jpeg_bytes = tf.placeholder(tf.string, shape=[])
        decoded_image = tf.image.decode_jpeg(jpeg_bytes, channels=3)
        processed_image = image_preprocessing_fn(
            decoded_image, test_image_size, test_image_size)

        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True

        with open(FLAGS.test_list) as list_file:
            test_ids = [line.split()[0] for line in list_file if line.strip()]
        tot = len(test_ids)

        with tf.Session(config=config) as sess:
            sess.run(tf.global_variables_initializer())
            saver = tf.train.Saver()
            saver.restore(sess, checkpoint_path)
            time_start = time.time()
            for idx in range(0, tot, batch_size):
                batch_ids = test_ids[idx:min(tot, idx + batch_size)]
                print(idx)
                images = []
                for image_id in batch_ids:
                    test_path = os.path.join(FLAGS.test_dir, image_id)
                    with open(test_path, 'rb') as image_file:
                        image_data = image_file.read()
                    images.append(sess.run(
                        processed_image, feed_dict={jpeg_bytes: image_data}))
                predictions = sess.run(
                    top_k, feed_dict={tensor_input: np.array(images)}).indices
                # Pair each prediction with its own image id.
                for image_id, prediction in zip(batch_ids, predictions):
                    print('{} {}'.format(image_id, prediction.tolist()))
            time_total = time.time() - time_start
            print('total time: {}, total images: {}, average time: {}'.format(
                time_total, tot, time_total / tot if tot else 0.0))


if __name__ == '__main__':
    tf.app.run()

测试时执行以下命令即可:

1
2
3
4
5
6
7
CUDA_VISIBLE_DEVICES="0" python test_image_classifier.py \
--checkpoint_path=train_logs/ \
--test_list=../../data/list_val.txt \
--test_dir=../../data/flower_photos/ \
--batch_size=16 \
--num_classes=5 \
--model_name=inception_resnet_v2