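1. Image classification with a pre-trained model
  2. Extracting and visualizing features
  3. Extracting and storing features
  4. Visualizing from the feature file
    1. Reading .mat files in Python
  5. Using your own network
  6. Using networks from the Model Zoo
  7. References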

Run a CNN with CAFFE ( http://caffe.berkeleyvision.org ), extract features, and store them as LMDB for later use. The features can also be visualized.

Image classification with a pre-trained model

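The notebook at http://nbviewer.ipython.org/github/BVLC/caffe/blob/master/examples/00-classification.ipynb already explains in detail how to classify test images with a pre-trained model. Because CAFFE is updated constantly, the notebook's content and code change as well, so the following records only the main steps that currently run.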

  1. Download CAFFE and install its dependencies.

  2. From caffe_root, run ./scripts/download_model_binary.py models/bvlc_reference_caffenet to get the pre-trained CaffeNet.

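  3. Run the following code in IPython to load the network. Plain Python also works once the IPython-only lines (%matplotlib inline and the line starting with !) are commented out.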

    • ./models/bvlc_reference_caffenet/bvlc_reference_caffenet.caffemodel
    import numpy as np
    import matplotlib.pyplot as plt
    %matplotlib inline

    # Make sure that caffe is on the python path:
    caffe_root = '../' # this file is expected to be in {caffe_root}/examples
    import sys
    sys.path.insert(0, caffe_root + 'python')

    import caffe

    plt.rcParams['figure.figsize'] = (10, 10)
    plt.rcParams['image.interpolation'] = 'nearest'
    plt.rcParams['image.cmap'] = 'gray'

    import os
    if not os.path.isfile(caffe_root + 'models/bvlc_reference_caffenet/bvlc_reference_caffenet.caffemodel'):
        print("Downloading pre-trained CaffeNet model...")
        !../scripts/download_model_binary.py ../models/bvlc_reference_caffenet
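  4. Put the network in the test phase (on the CPU here), load the model prototxt and weights, and set up the input preprocessing with the data mean (mean_npy).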

    • ./models/bvlc_reference_caffenet/deploy.prototxt

    • ./models/bvlc_reference_caffenet/bvlc_reference_caffenet.caffemodel

    • ./python/caffe/imagenet/ilsvrc_2012_mean.npy

    caffe.set_mode_cpu()
    net = caffe.Net(caffe_root + 'models/bvlc_reference_caffenet/deploy.prototxt',
                    caffe_root + 'models/bvlc_reference_caffenet/bvlc_reference_caffenet.caffemodel',
                    caffe.TEST)

    # input preprocessing: 'data' is the name of the input blob == net.inputs[0]
    transformer = caffe.io.Transformer({'data': net.blobs['data'].data.shape})
    transformer.set_transpose('data', (2,0,1))
    transformer.set_mean('data', np.load(caffe_root + 'python/caffe/imagenet/ilsvrc_2012_mean.npy').mean(1).mean(1)) # mean pixel
    transformer.set_raw_scale('data', 255) # the reference model operates on images in [0,255] range instead of [0,1]
    transformer.set_channel_swap('data', (2,1,0)) # the reference model has channels in BGR order instead of RGB
  5. Load a test image and predict its class.

    • ./examples/images/cat.jpg
    # set net to batch size of 50
    net.blobs['data'].reshape(50,3,227,227)
    net.blobs['data'].data[...] = transformer.preprocess('data', caffe.io.load_image(caffe_root + 'examples/images/cat.jpg'))
    out = net.forward()
    print("Predicted class is #{}.".format(out['prob'].argmax()))
  6. Load the labels and print the top-k predictions (k = 5 here).

    • ./data/ilsvrc12/synset_words.txt
    # load labels
    imagenet_labels_filename = caffe_root + 'data/ilsvrc12/synset_words.txt'
    try:
        labels = np.loadtxt(imagenet_labels_filename, str, delimiter='\t')
    except IOError:
        # synset_words.txt is missing: fetch the auxiliary ILSVRC12 data, then retry
        !../data/ilsvrc12/get_ilsvrc_aux.sh
        labels = np.loadtxt(imagenet_labels_filename, str, delimiter='\t')

    # sort top k predictions from softmax output
    top_k = net.blobs['prob'].data[0].flatten().argsort()[-1:-6:-1]
    print(labels[top_k])
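For reference, the Transformer settings from step 4 and the top-k selection from step 6 can be sketched in plain NumPy. This is a minimal sketch, not CAFFE's implementation: it assumes the image already has the network's input size (caffe.io.Transformer also resizes), and the function names preprocess_like_caffe and top_k are hypothetical. The order of operations (transpose, channel swap, raw scale, mean subtraction) is intended to match what caffe.io.Transformer.preprocess applies for these settings.

```python
import numpy as np


def preprocess_like_caffe(img_rgb, mean_bgr, raw_scale=255.0):
    """Sketch of the step-4 preprocessing (hypothetical helper, no resizing).

    img_rgb:  H x W x 3 array in [0, 1], RGB order, assumed to be already
              at the network's input size.
    mean_bgr: per-channel mean in BGR order (length 3).
    Returns a 3 x H x W float32 array in BGR order.
    """
    x = np.asarray(img_rgb, dtype=np.float32)
    x = x.transpose(2, 0, 1)       # set_transpose((2,0,1)): HWC -> CHW
    x = x[[2, 1, 0], :, :]         # set_channel_swap((2,1,0)): RGB -> BGR
    x = x * raw_scale              # set_raw_scale(255): [0, 1] -> [0, 255]
    mean = np.asarray(mean_bgr, dtype=np.float32).reshape(3, 1, 1)
    return x - mean                # set_mean: subtract the per-channel mean


def top_k(probs, k=5):
    """Indices of the k largest probabilities, largest first (as in step 6)."""
    probs = np.asarray(probs).flatten()
    return probs.argsort()[-1:-(k + 1):-1]


if __name__ == "__main__":
    img = np.zeros((2, 2, 3), dtype=np.float32)
    img[..., 0] = 1.0  # a pure red 2 x 2 image
    out = preprocess_like_caffe(img, mean_bgr=[10.0, 20.0, 30.0])
    print(out.shape)       # (3, 2, 2)
    print(out[:, 0, 0])    # B, G, R = -10, -20, 225
    print(top_k([0.1, 0.5, 0.05, 0.3, 0.05], k=3))  # [1 3 0]
```

Note that the mean is subtracted after the channel swap, so the mean vector has to be in BGR order, as the per-channel mean of ilsvrc_2012_mean.npy used above is.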

Extracting and visualizing features

Continuing from the previous section: to visualize the features right after extraction, without storing them, proceed as follows.

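  1. The network's features (blob outputs) are stored in net.blobs, and its parameters (weights and biases) in net.params. The following code prints the name and shape of each layer. You can also save these arrays manually at this point.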

    [(k, v.data.shape) for k, v in net.blobs.items()]
    [(k, v[0].data.shape) for k, v in net.params.items()]
  2. Visualization. First, the helper function:

    # take an array of shape (n, height, width) or (n, height, width, channels)
    # and visualize each (height, width) thing in a grid of size approx. sqrt(n) by sqrt(n)
    def vis_square(data, padsize=1, padval=0):
        # normalize to [0, 1] on a copy: in-place "-=" and "/=" would overwrite
        # the arrays in net.params / net.blobs that were passed in
        data = data - data.min()
        data = data / data.max()

        # force the number of filters to be square
        n = int(np.ceil(np.sqrt(data.shape[0])))
        padding = ((0, n ** 2 - data.shape[0]), (0, padsize), (0, padsize)) + ((0, 0),) * (data.ndim - 3)
        data = np.pad(data, padding, mode='constant', constant_values=(padval, padval))

        # tile the filters into an image
        data = data.reshape((n, n) + data.shape[1:]).transpose((0, 2, 1, 3) + tuple(range(4, data.ndim + 1)))
        data = data.reshape((n * data.shape[1], n * data.shape[3]) + data.shape[4:])

        plt.imshow(data)
        plt.show()
    • Choose the layers to visualize by name. Both the filters (parameters) and the outputs (features) can be shown.
    # the parameters are a list of [weights, biases]
    filters = net.params['conv1'][0].data
    vis_square(filters.transpose(0, 2, 3, 1))

    feat = net.blobs['conv1'].data[0, :36]
    vis_square(feat, padval=1)

    # There are 256 filters, each of which has dimension 5 x 5 x 48. We show only the first 48 filters, with each channel shown separately, so that each filter is a row.
    filters = net.params['conv2'][0].data
    vis_square(filters[:48].reshape(48**2, 5, 5))

    # rectified, only the first 36 of 256 channels
    feat = net.blobs['conv2'].data[0, :36]
    vis_square(feat, padval=1)

    feat = net.blobs['conv3'].data[0]
    vis_square(feat, padval=0.5)

    feat = net.blobs['conv4'].data[0]
    vis_square(feat, padval=0.5)

    feat = net.blobs['conv5'].data[0]
    vis_square(feat, padval=0.5)

    feat = net.blobs['pool5'].data[0]
    vis_square(feat, padval=1)

    feat = net.blobs['fc6'].data[0]
    plt.subplot(2, 1, 1)
    plt.plot(feat.flat)
    plt.subplot(2, 1, 2)
    _ = plt.hist(feat.flat[feat.flat > 0], bins=100)

    feat = net.blobs['fc7'].data[0]
    plt.subplot(2, 1, 1)
    plt.plot(feat.flat)
    plt.subplot(2, 1, 2)
    _ = plt.hist(feat.flat[feat.flat > 0], bins=100)

    feat = net.blobs['prob'].data[0]
    plt.plot(feat.flat)
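The blob shapes printed in step 1 follow from each layer's kernel size, stride and padding. As a worked check, the sketch below recomputes the CaffeNet blob shapes for a 227×227 input. The layer hyperparameters in CAFFENET_LAYERS are assumed here from the usual CaffeNet definition, and the LRN layers are left out because they do not change the shape; compare them with models/bvlc_reference_caffenet/deploy.prototxt before relying on them. Convolution output sizes round down and pooling output sizes round up, following Caffe's convention.

```python
import math


def conv_out(size, kernel, stride=1, pad=0):
    """Output height/width of a convolution; Caffe rounds down."""
    return (size + 2 * pad - kernel) // stride + 1


def pool_out(size, kernel, stride=1, pad=0):
    """Output height/width of a pooling layer; Caffe rounds up."""
    return int(math.ceil(float(size + 2 * pad - kernel) / stride)) + 1


# Assumed CaffeNet hyperparameters (check against deploy.prototxt):
# (blob name, layer type, num_output, kernel, stride, pad)
CAFFENET_LAYERS = [
    ("conv1", "conv", 96, 11, 4, 0),
    ("pool1", "pool", None, 3, 2, 0),
    ("conv2", "conv", 256, 5, 1, 2),
    ("pool2", "pool", None, 3, 2, 0),
    ("conv3", "conv", 384, 3, 1, 1),
    ("conv4", "conv", 384, 3, 1, 1),
    ("conv5", "conv", 256, 3, 1, 1),
    ("pool5", "pool", None, 3, 2, 0),
]


def blob_shapes(input_size=227, input_channels=3):
    """List of (blob name, (channels, height, width)) for a single image."""
    shapes = []
    size, channels = input_size, input_channels
    for name, kind, num_output, kernel, stride, pad in CAFFENET_LAYERS:
        if kind == "conv":
            size = conv_out(size, kernel, stride, pad)
            channels = num_output
        else:  # pooling keeps the number of channels
            size = pool_out(size, kernel, stride, pad)
        shapes.append((name, (channels, size, size)))
    return shapes


if __name__ == "__main__":
    for name, (c, h, w) in blob_shapes():
        print("%s: %dx%dx%d = %d values" % (name, c, h, w, c * h * w))
```

For example, conv1 comes out as 96×55×55 = 290400 values and conv5 as 256×13×13 = 43264, which are the feature lengths (DIM) used in the LMDB-to-.mat script.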

Extracting and storing features

CAFFE provides a feature extraction tool; see http://caffe.berkeleyvision.org/gathered/examples/feature_extraction.html

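  1. Choose the images to extract features from. Here every file under examples/images is listed in examples/_temp/file_list.txt, one absolute path per line followed by the label 0.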

    • ./examples/_temp
    mkdir examples/_temp
    find `pwd`/examples/images -type f -exec echo {} \; > examples/_temp/temp.txt
    sed "s/$/ 0/" examples/_temp/temp.txt > examples/_temp/file_list.txt
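  2. As before, download the model; then fetch the auxiliary ILSVRC12 data and copy the feature extraction prototxt into examples/_temp.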

    • ./models/bvlc_reference_caffenet/bvlc_reference_caffenet.caffemodel

    • ./examples/_temp/imagenet_val.prototxt

    ./data/ilsvrc12/get_ilsvrc_aux.sh
    cp examples/feature_extraction/imagenet_val.prototxt examples/_temp
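  3. Use the extract_features.bin tool to extract the features and store them as LMDB. The arguments are extract_features.bin $MODEL $PROTOTXT $LAYER $OUTPUT_PATH $NUM_MINI_BATCHES $DB_TYPE, where $NUM_MINI_BATCHES is the number of mini-batches to run (the batch size itself is set in the prototxt) and $DB_TYPE is lmdb or leveldb.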

    • ./models/bvlc_reference_caffenet/bvlc_reference_caffenet.caffemodel

    • ./examples/_temp/imagenet_val.prototxt

    • ./examples/_temp/features

    ./build/tools/extract_features.bin \
    models/bvlc_reference_caffenet/bvlc_reference_caffenet.caffemodel \
    examples/_temp/imagenet_val.prototxt \
    fc7 examples/_temp/features 10 lmdb
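Where find and sed are not available, the image list from step 1 can also be built in Python. The helper below, write_file_list, is hypothetical and not part of CAFFE; it is a minimal sketch that writes each file's absolute path followed by a label (0 by default), which is the line format the sed command produces. Unlike find, it sorts the paths so the order is deterministic.

```python
import os


def write_file_list(image_dir, out_path, label=0):
    """Write one "<absolute path> <label>" line per file under image_dir.

    Hypothetical equivalent of the find + sed commands in step 1.
    Returns the list of paths that were written.
    """
    paths = []
    for root, _dirs, files in os.walk(image_dir):
        for name in files:
            paths.append(os.path.abspath(os.path.join(root, name)))
    paths.sort()  # deterministic order; find does not guarantee one
    with open(out_path, "w") as f:
        for p in paths:
            f.write("%s %d\n" % (p, label))
    return paths


if __name__ == "__main__":
    # run from caffe_root, like the shell commands above
    write_file_list("examples/images", "examples/_temp/file_list.txt")
```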

Visualizing from the feature file

Following http://www.cnblogs.com/platero/p/3967208.html and the lmdb documentation ( https://lmdb.readthedocs.org/en/release ), read the LMDB, convert it to a .mat file, and then load the .mat file in MATLAB for visualization.

  1. Install CAFFE's Python dependencies, then use the following two helper files to convert the LMDB into a .mat file.

    • ./feat_helper_pb2.py

      # Generated by the protocol buffer compiler.  DO NOT EDIT!

      from google.protobuf import descriptor
      from google.protobuf import message
      from google.protobuf import reflection
      from google.protobuf import descriptor_pb2
      # @@protoc_insertion_point(imports)


      DESCRIPTOR = descriptor.FileDescriptor(
      name='datum.proto',
      package='feat_extract',
      serialized_pb='\n\x0b\x64\x61tum.proto\x12\x0c\x66\x65\x61t_extract\"i\n\x05\x44\x61tum\x12\x10\n\x08\x63hannels\x18\x01 \x01(\x05\x12\x0e\n\x06height\x18\x02 \x01(\x05\x12\r\n\x05width\x18\x03 \x01(\x05\x12\x0c\n\x04\x64\x61ta\x18\x04 \x01(\x0c\x12\r\n\x05label\x18\x05 \x01(\x05\x12\x12\n\nfloat_data\x18\x06 \x03(\x02')


      _DATUM = descriptor.Descriptor(
      name='Datum',
      full_name='feat_extract.Datum',
      filename=None,
      file=DESCRIPTOR,
      containing_type=None,
      fields=[
      descriptor.FieldDescriptor(
      name='channels', full_name='feat_extract.Datum.channels', index=0,
      number=1, type=5, cpp_type=1, label=1,
      has_default_value=False, default_value=0,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
      descriptor.FieldDescriptor(
      name='height', full_name='feat_extract.Datum.height', index=1,
      number=2, type=5, cpp_type=1, label=1,
      has_default_value=False, default_value=0,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
      descriptor.FieldDescriptor(
      name='width', full_name='feat_extract.Datum.width', index=2,
      number=3, type=5, cpp_type=1, label=1,
      has_default_value=False, default_value=0,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
      descriptor.FieldDescriptor(
      name='data', full_name='feat_extract.Datum.data', index=3,
      number=4, type=12, cpp_type=9, label=1,
      has_default_value=False, default_value="",
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
      descriptor.FieldDescriptor(
      name='label', full_name='feat_extract.Datum.label', index=4,
      number=5, type=5, cpp_type=1, label=1,
      has_default_value=False, default_value=0,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
      descriptor.FieldDescriptor(
      name='float_data', full_name='feat_extract.Datum.float_data', index=5,
      number=6, type=2, cpp_type=6, label=3,
      has_default_value=False, default_value=[],
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
      ],
      extensions=[
      ],
      nested_types=[],
      enum_types=[
      ],
      options=None,
      is_extendable=False,
      extension_ranges=[],
      serialized_start=29,
      serialized_end=134,
      )

      DESCRIPTOR.message_types_by_name['Datum'] = _DATUM

      class Datum(message.Message):
          __metaclass__ = reflection.GeneratedProtocolMessageType
          DESCRIPTOR = _DATUM

          # @@protoc_insertion_point(class_scope:feat_extract.Datum)

      # @@protoc_insertion_point(module_scope)
    • ./lmdb2mat.py

      import sys
      import time

      import lmdb
      import numpy as np
      import scipy.io as sio

      import feat_helper_pb2


      def main(argv):
          # usage: python lmdb2mat.py LMDB BATCHNUM BATCHSIZE DIM OUT
          lmdb_name = argv[1]
          batch_num = int(argv[2])
          batch_size = int(argv[3])
          feat_dim = int(argv[4])
          out_name = argv[5]
          window_num = batch_num * batch_size  # number of feature vectors to read
          print(lmdb_name)

          start = time.time()
          db = lmdb.open(lmdb_name, readonly=True)
          datum = feat_helper_pb2.Datum()
          ft = np.zeros((window_num, feat_dim))
          with db.begin() as txn:
              # records are visited in key order; keep the first window_num of them
              for im_idx, (key, value) in enumerate(txn.cursor()):
                  if im_idx >= window_num:
                      break
                  datum.ParseFromString(value)
                  ft[im_idx, :] = datum.float_data
          db.close()

          print('time 1: %f' % (time.time() - start))
          sio.savemat(out_name, {'feats': ft})
          print('time 2: %f' % (time.time() - start))
          print('done!')


      if __name__ == '__main__':
          main(sys.argv)
    • Run the conversion from the shell:

      #!/usr/bin/env sh
      LMDB=./examples/_temp/features # LMDB path written by extract_features.bin
      BATCHNUM=1
      BATCHSIZE=10
      # DIM=290400 # feature length, conv1
      # DIM=43264 # conv5
      DIM=4096 # fc7
      OUT=./examples/_temp/features_fc7.mat # output .mat path
      python ./lmdb2mat.py $LMDB $BATCHNUM $BATCHSIZE $DIM $OUT
  2. Visualize the features in the .mat file with the display_network function from UFLDL.

    • display_network.m

      function [h, array] = display_network(A, opt_normalize, opt_graycolor, cols, opt_colmajor)
      % This function visualizes filters in matrix A. Each column of A is a
      % filter. We will reshape each column into a square image and visualizes
      % on each cell of the visualization panel.
      % All other parameters are optional, usually you do not need to worry
      % about it.
      % opt_normalize: whether we need to normalize the filter so that all of
      % them can have similar contrast. Default value is true.
      % opt_graycolor: whether we use gray as the heat map. Default is true.
      % cols: how many columns are there in the display. Default value is the
      % squareroot of the number of columns in A.
      % opt_colmajor: you can switch convention to row major for A. In that
      % case, each row of A is a filter. Default value is false.
      warning off all

      if ~exist('opt_normalize', 'var') || isempty(opt_normalize)
          opt_normalize= true;
      end

      if ~exist('opt_graycolor', 'var') || isempty(opt_graycolor)
          opt_graycolor= true;
      end

      if ~exist('opt_colmajor', 'var') || isempty(opt_colmajor)
          opt_colmajor = false;
      end

      % rescale
      A = A - mean(A(:));

      if opt_graycolor, colormap(gray); end

      % compute rows, cols
      [L M]=size(A);
      sz=sqrt(L);
      buf=1;
      if ~exist('cols', 'var')
          if floor(sqrt(M))^2 ~= M
              n=ceil(sqrt(M));
              while mod(M, n)~=0 && n<1.2*sqrt(M), n=n+1; end
              m=ceil(M/n);
          else
              n=sqrt(M);
              m=n;
          end
      else
          n = cols;
          m = ceil(M/n);
      end

      array=-ones(buf+m*(sz+buf),buf+n*(sz+buf));

      if ~opt_graycolor
          array = 0.1.* array;
      end


      if ~opt_colmajor
          k=1;
          for i=1:m
              for j=1:n
                  if k>M,
                      continue;
                  end
                  clim=max(abs(A(:,k)));
                  if opt_normalize
                      array(buf+(i-1)*(sz+buf)+(1:sz),buf+(j-1)*(sz+buf)+(1:sz))=reshape(A(:,k),sz,sz)'/clim;
                  else
                      array(buf+(i-1)*(sz+buf)+(1:sz),buf+(j-1)*(sz+buf)+(1:sz))=reshape(A(:,k),sz,sz)'/max(abs(A(:)));
                  end
                  k=k+1;
              end
          end
      else
          k=1;
          for j=1:n
              for i=1:m
                  if k>M,
                      continue;
                  end
                  clim=max(abs(A(:,k)));
                  if opt_normalize
                      array(buf+(i-1)*(sz+buf)+(1:sz),buf+(j-1)*(sz+buf)+(1:sz))=reshape(A(:,k),sz,sz)'/clim;
                  else
                      array(buf+(i-1)*(sz+buf)+(1:sz),buf+(j-1)*(sz+buf)+(1:sz))=reshape(A(:,k),sz,sz)';
                  end
                  k=k+1;
              end
          end
      end

      if opt_graycolor
          h=imagesc(array);
      else
          h=imagesc(array,'EraseMode','none',[-1 1]);
      end
      axis image off

      drawnow;

      warning on all
    • Run the following code in MATLAB:

      nsample = 2;
      % num_output = 96; % conv1
      % num_output = 256; % conv5
      num_output = 4096; % fc7

      load features_fc7.mat
      width = size(feats, 2);
      nmap = width / num_output;

      for i = 1 : nsample
          feat = feats(i, :);
          feat = reshape(feat, [nmap num_output]);
          figure('name', sprintf('image #%d', i));
          display_network(feat);
      end
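The MATLAB loop relies on how reshape orders elements. CAFFE stores each blob in row-major (channel, height, width) order, so one LMDB record is the blob flattened channel by channel, while MATLAB's reshape fills matrices column-major. As a result, reshape(feat, [nmap num_output]) puts channel k into column k, and display_network's reshape(A(:,k),sz,sz)' turns that column back into the k-th feature map. The NumPy sketch below reproduces this with order='F' on a small toy blob (the sizes and the helper name matlab_style_columns are made up for illustration), so the mapping can be checked without MATLAB.

```python
import numpy as np


def matlab_style_columns(flat, num_output):
    """Emulate MATLAB's reshape(feat, [nmap num_output]) on a 1-D vector.

    MATLAB fills matrices column-major, which is NumPy's order='F'.
    """
    flat = np.asarray(flat)
    nmap = flat.size // num_output
    return flat.reshape((nmap, num_output), order="F")


# Toy blob: 4 channels of 3 x 3 (conv1 would be 96 x 55 x 55).
C, H, W = 4, 3, 3
blob = np.arange(C * H * W, dtype=np.float32).reshape(C, H, W)

# One LMDB record holds the blob flattened in C (row-major) order.
flat = blob.ravel()
feat = matlab_style_columns(flat, num_output=C)
print(feat.shape)  # (9, 4): nmap = H*W rows, one column per channel

# display_network's reshape(A(:,k),sz,sz)' done in NumPy recovers channel k.
k = 2
image_k = feat[:, k].reshape((H, W), order="F").T
print(np.array_equal(image_k, blob[k]))  # True
```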

Reading .mat files in Python

In Python, scipy.io.loadmat() reads a .mat file and returns a dict.

import scipy.io
matfile = 'features_fc7.mat'
data = scipy.io.loadmat(matfile)
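Besides the saved variables, the dict returned by loadmat contains MATLAB metadata entries. The small round trip below uses a hypothetical file features_demo.mat filled with random placeholder values (not real features) to show that the matrix written by lmdb2mat.py comes back under the key 'feats' as an (images × feature length) array.

```python
import os
import tempfile

import numpy as np
import scipy.io

# Stand-in for features_fc7.mat: 2 images x 4096 dims of random placeholder
# values, saved under the same key that lmdb2mat.py uses.
feats = np.random.rand(2, 4096)
path = os.path.join(tempfile.mkdtemp(), "features_demo.mat")
scipy.io.savemat(path, {"feats": feats})

data = scipy.io.loadmat(path)
# Metadata entries such as '__header__', '__version__' and '__globals__'
# sit next to the saved variables.
print(sorted(k for k in data if not k.startswith("__")))  # ['feats']
loaded = data["feats"]
print(loaded.shape)  # (2, 4096)
```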

Using your own network

Just replace the files and parameters listed above with your own.

Using networks from the Model Zoo

Following https://github.com/BVLC/caffe/wiki/Model-Zoo, choose the network you need and download it to the corresponding location.

For example, VGG-16:

./scripts/download_model_from_gist.sh 211839e770f7b538e2d8
mv ./models/211839e770f7b538e2d8 ./models/VGG_ILSVRC_16_layers
./scripts/download_model_binary.py ./models/VGG_ILSVRC_16_layers

References

http://nbviewer.ipython.org/github/BVLC/caffe/blob/master/examples/00-classification.ipynb

http://caffe.berkeleyvision.org/gathered/examples/feature_extraction.html

http://www.cnblogs.com/platero/p/3967208.html

https://lmdb.readthedocs.org/en/release/