diff --git a/examples/imagenet.prototxt b/examples/imagenet.prototxt index 5db585b7fec..e14aaa76f99 100644 --- a/examples/imagenet.prototxt +++ b/examples/imagenet.prototxt @@ -65,15 +65,6 @@ layers { bottom: "pool1" top: "norm1" } -layers { - layer { - name: "pad2" - type: "padding" - pad: 2 - } - bottom: "norm1" - top: "pad2" -} layers { layer { name: "conv2" @@ -81,6 +72,7 @@ layers { num_output: 256 group: 2 kernelsize: 5 + pad: 2 weight_filler { type: "gaussian" std: 0.01 @@ -94,7 +86,7 @@ layers { weight_decay: 1. weight_decay: 0. } - bottom: "pad2" + bottom: "norm1" top: "conv2" } layers { @@ -127,21 +119,13 @@ layers { bottom: "pool2" top: "norm2" } -layers { - layer { - name: "pad3" - type: "padding" - pad: 1 - } - bottom: "norm2" - top: "pad3" -} layers { layer { name: "conv3" type: "conv" num_output: 384 kernelsize: 3 + pad: 1 weight_filler { type: "gaussian" std: 0.01 @@ -155,7 +139,7 @@ layers { weight_decay: 1. weight_decay: 0. } - bottom: "pad3" + bottom: "norm2" top: "conv3" } layers { @@ -166,15 +150,6 @@ layers { bottom: "conv3" top: "conv3" } -layers { - layer { - name: "pad4" - type: "padding" - pad: 1 - } - bottom: "conv3" - top: "pad4" -} layers { layer { name: "conv4" @@ -182,6 +157,7 @@ layers { num_output: 384 group: 2 kernelsize: 3 + pad: 1 weight_filler { type: "gaussian" std: 0.01 @@ -195,7 +171,7 @@ layers { weight_decay: 1. weight_decay: 0. } - bottom: "pad4" + bottom: "conv3" top: "conv4" } layers { @@ -206,15 +182,6 @@ layers { bottom: "conv4" top: "conv4" } -layers { - layer { - name: "pad5" - type: "padding" - pad: 1 - } - bottom: "conv4" - top: "pad5" -} layers { layer { name: "conv5" @@ -222,6 +189,7 @@ layers { num_output: 256 group: 2 kernelsize: 3 + pad: 1 weight_filler { type: "gaussian" std: 0.01 @@ -235,7 +203,7 @@ layers { weight_decay: 1. weight_decay: 0. } - bottom: "pad5" + bottom: "conv4" top: "conv5" } layers { diff --git a/examples/imagenet_deploy.prototxt b/examples/imagenet_deploy.prototxt index 62579140e73..41e1c60af15 100644 --- a/examples/imagenet_deploy.prototxt +++ b/examples/imagenet_deploy.prototxt @@ -56,15 +56,6 @@ layers { bottom: "pool1" top: "norm1" } -layers { - layer { - name: "pad2" - type: "padding" - pad: 2 - } - bottom: "norm1" - top: "pad2" -} layers { layer { name: "conv2" @@ -72,6 +63,7 @@ layers { num_output: 256 group: 2 kernelsize: 5 + pad: 2 weight_filler { type: "gaussian" std: 0.01 @@ -85,7 +77,7 @@ layers { weight_decay: 1. weight_decay: 0. } - bottom: "pad2" + bottom: "norm1" top: "conv2" } layers { @@ -118,21 +110,13 @@ layers { bottom: "pool2" top: "norm2" } -layers { - layer { - name: "pad3" - type: "padding" - pad: 1 - } - bottom: "norm2" - top: "pad3" -} layers { layer { name: "conv3" type: "conv" num_output: 384 kernelsize: 3 + pad: 1 weight_filler { type: "gaussian" std: 0.01 @@ -146,7 +130,7 @@ layers { weight_decay: 1. weight_decay: 0. } - bottom: "pad3" + bottom: "norm2" top: "conv3" } layers { @@ -157,15 +141,6 @@ layers { bottom: "conv3" top: "conv3" } -layers { - layer { - name: "pad4" - type: "padding" - pad: 1 - } - bottom: "conv3" - top: "pad4" -} layers { layer { name: "conv4" @@ -173,6 +148,7 @@ layers { num_output: 384 group: 2 kernelsize: 3 + pad: 1 weight_filler { type: "gaussian" std: 0.01 @@ -186,7 +162,7 @@ layers { weight_decay: 1. weight_decay: 0. 
} - bottom: "pad4" + bottom: "conv3" top: "conv4" } layers { @@ -197,15 +173,6 @@ layers { bottom: "conv4" top: "conv4" } -layers { - layer { - name: "pad5" - type: "padding" - pad: 1 - } - bottom: "conv4" - top: "pad5" -} layers { layer { name: "conv5" @@ -213,6 +180,7 @@ layers { num_output: 256 group: 2 kernelsize: 3 + pad: 1 weight_filler { type: "gaussian" std: 0.01 @@ -226,7 +194,7 @@ layers { weight_decay: 1. weight_decay: 0. } - bottom: "pad5" + bottom: "conv4" top: "conv5" } layers { diff --git a/examples/imagenet_val.prototxt b/examples/imagenet_val.prototxt index fbc4c32522e..d1e4cd02acd 100644 --- a/examples/imagenet_val.prototxt +++ b/examples/imagenet_val.prototxt @@ -53,15 +53,6 @@ layers { bottom: "pool1" top: "norm1" } -layers { - layer { - name: "pad2" - type: "padding" - pad: 2 - } - bottom: "norm1" - top: "pad2" -} layers { layer { name: "conv2" @@ -69,8 +60,9 @@ layers { num_output: 256 group: 2 kernelsize: 5 + pad: 2 } - bottom: "pad2" + bottom: "norm1" top: "conv2" } layers { @@ -103,23 +95,15 @@ layers { bottom: "pool2" top: "norm2" } -layers { - layer { - name: "pad3" - type: "padding" - pad: 1 - } - bottom: "norm2" - top: "pad3" -} layers { layer { name: "conv3" type: "conv" num_output: 384 kernelsize: 3 + pad: 1 } - bottom: "pad3" + bottom: "norm2" top: "conv3" } layers { @@ -130,15 +114,6 @@ layers { bottom: "conv3" top: "conv3" } -layers { - layer { - name: "pad4" - type: "padding" - pad: 1 - } - bottom: "conv3" - top: "pad4" -} layers { layer { name: "conv4" @@ -146,8 +121,9 @@ layers { num_output: 384 group: 2 kernelsize: 3 + pad: 1 } - bottom: "pad4" + bottom: "conv3" top: "conv4" } layers { @@ -158,15 +134,6 @@ layers { bottom: "conv4" top: "conv4" } -layers { - layer { - name: "pad5" - type: "padding" - pad: 1 - } - bottom: "conv4" - top: "pad5" -} layers { layer { name: "conv5" @@ -174,8 +141,9 @@ layers { num_output: 256 group: 2 kernelsize: 3 + pad: 1 } - bottom: "pad5" + bottom: "conv4" top: "conv5" } layers { diff --git a/include/caffe/util/im2col.hpp b/include/caffe/util/im2col.hpp index 83c01ddab53..521efd31593 100644 --- a/include/caffe/util/im2col.hpp +++ b/include/caffe/util/im2col.hpp @@ -7,22 +7,22 @@ namespace caffe { template void im2col_cpu(const Dtype* data_im, const int channels, - const int height, const int width, const int ksize, const int stride, + const int height, const int width, const int ksize, const int pad, const int stride, Dtype* data_col); template void col2im_cpu(const Dtype* data_col, const int channels, - const int height, const int width, const int psize, const int stride, + const int height, const int width, const int psize, const int pad, const int stride, Dtype* data_im); template void im2col_gpu(const Dtype* data_im, const int channels, - const int height, const int width, const int ksize, const int stride, + const int height, const int width, const int ksize, const int pad, const int stride, Dtype* data_col); template void col2im_gpu(const Dtype* data_col, const int channels, - const int height, const int width, const int psize, const int stride, + const int height, const int width, const int psize, const int pad, const int stride, Dtype* data_im); } // namespace caffe diff --git a/include/caffe/util/insert_splits.hpp b/include/caffe/util/insert_splits.hpp new file mode 100644 index 00000000000..d0df85650c9 --- /dev/null +++ b/include/caffe/util/insert_splits.hpp @@ -0,0 +1,29 @@ +// Copyright 2014 Jeff Donahue + +#ifndef _CAFFE_UTIL_INSERT_SPLITS_HPP_ +#define _CAFFE_UTIL_INSERT_SPLITS_HPP_ + +#include 
"caffe/proto/caffe.pb.h" + +using std::pair; +using std::string; + +namespace caffe { + +// Copy NetParameters with SplitLayers added to replace any shared bottom +// blobs with unique bottom blobs provided by the SplitLayer. +void insert_splits(const NetParameter& param, NetParameter* param_split); + +void configure_split_layer(const string& layer_name, const string& blob_name, + const int blob_idx, const int split_count, + LayerConnection* split_layer_connection); + +string get_split_layer_name(const string& layer_name, const string& blob_name, + const int blob_idx); + +string get_split_blob_name(const string& layer_name, const string& blob_name, + const int blob_idx, const int split_idx); + +} // namespace caffe + +#endif // CAFFE_UTIL_INSERT_SPLITS_HPP_ diff --git a/include/caffe/vision_layers.hpp b/include/caffe/vision_layers.hpp index 82e52cd5bfe..5dd3c9076d6 100644 --- a/include/caffe/vision_layers.hpp +++ b/include/caffe/vision_layers.hpp @@ -44,6 +44,23 @@ class ReLULayer : public NeuronLayer { const bool propagate_down, vector*>* bottom); }; +template +class TanHLayer : public NeuronLayer { + public: + explicit TanHLayer(const LayerParameter& param) + : NeuronLayer(param) {} + + protected: + virtual void Forward_cpu(const vector*>& bottom, + vector*>* top); + virtual void Forward_gpu(const vector*>& bottom, + vector*>* top); + + virtual Dtype Backward_cpu(const vector*>& top, + const bool propagate_down, vector*>* bottom); + virtual Dtype Backward_gpu(const vector*>& top, + const bool propagate_down, vector*>* bottom); +}; template class SigmoidLayer : public NeuronLayer { @@ -109,9 +126,9 @@ class DropoutLayer : public NeuronLayer { template -class FlattenLayer : public Layer { +class SplitLayer : public Layer { public: - explicit FlattenLayer(const LayerParameter& param) + explicit SplitLayer(const LayerParameter& param) : Layer(param) {} virtual void SetUp(const vector*>& bottom, vector*>* top); @@ -130,9 +147,9 @@ class FlattenLayer : public Layer { template -class InnerProductLayer : public Layer { +class FlattenLayer : public Layer { public: - explicit InnerProductLayer(const LayerParameter& param) + explicit FlattenLayer(const LayerParameter& param) : Layer(param) {} virtual void SetUp(const vector*>& bottom, vector*>* top); @@ -142,23 +159,18 @@ class InnerProductLayer : public Layer { vector*>* top); virtual void Forward_gpu(const vector*>& bottom, vector*>* top); - virtual Dtype Backward_cpu(const vector*>& top, const bool propagate_down, vector*>* bottom); virtual Dtype Backward_gpu(const vector*>& top, const bool propagate_down, vector*>* bottom); - int M_; - int K_; - int N_; - bool biasterm_; - shared_ptr bias_multiplier_; + int count_; }; template -class PaddingLayer : public Layer { +class InnerProductLayer : public Layer { public: - explicit PaddingLayer(const LayerParameter& param) + explicit InnerProductLayer(const LayerParameter& param) : Layer(param) {} virtual void SetUp(const vector*>& bottom, vector*>* top); @@ -168,20 +180,18 @@ class PaddingLayer : public Layer { vector*>* top); virtual void Forward_gpu(const vector*>& bottom, vector*>* top); + virtual Dtype Backward_cpu(const vector*>& top, const bool propagate_down, vector*>* bottom); virtual Dtype Backward_gpu(const vector*>& top, const bool propagate_down, vector*>* bottom); - unsigned int PAD_; - int NUM_; - int CHANNEL_; - int HEIGHT_IN_; - int WIDTH_IN_; - int HEIGHT_OUT_; - int WIDTH_OUT_; + int M_; + int K_; + int N_; + bool biasterm_; + shared_ptr bias_multiplier_; }; - template class LRNLayer 
: public Layer { public: @@ -234,9 +244,9 @@ class Im2colLayer : public Layer { int CHANNELS_; int HEIGHT_; int WIDTH_; + int PAD_; }; - template class PoolingLayer : public Layer { public: @@ -288,6 +298,7 @@ class ConvolutionLayer : public Layer { int STRIDE_; int NUM_; int CHANNELS_; + int PAD_; int HEIGHT_; int WIDTH_; int NUM_OUTPUT_; @@ -487,4 +498,3 @@ class AccuracyLayer : public Layer { } // namespace caffe #endif // CAFFE_VISION_LAYERS_HPP_ - diff --git a/python/caffe/convlayervisualization.py b/python/caffe/convlayervisualization.py new file mode 100644 index 00000000000..8503a3b56c7 --- /dev/null +++ b/python/caffe/convlayervisualization.py @@ -0,0 +1,176 @@ +#!/usr/bin/env python +""" + Code for visualizing caffe network conv layers. + +""" + +import matplotlib +matplotlib.use('Agg') + +from caffe.convert import blobproto_to_array +from caffe.proto import caffe_pb2 + +import pylab as pl +import numpy as np + +from math import sqrt, ceil +import sys +import os + + +class ConvLayerVisualizer(object): + + def __init__(self, file_name): + self.load_model(file_name) + self.extract_conv_layers() + + def load_model(self, file_name): + """ + Load the snapshot into a NetParameter object. + """ + net = caffe_pb2.NetParameter() + data = open(file_name).read() + net.ParseFromString(data) + self.net = net + + def extract_conv_layers(self): + """ + Extract all the convolutional layers from the network. + """ + conv_layers = [] + for layer in self.net.layers: + if layer.layer.type != "conv": + continue + conv_layers.append(layer.layer) + self.conv_layers = conv_layers + + def visualize_conv_layers(self): + """ + Visualize the conv layer kernels in a subplot each. + + """ + conv_layers = self.conv_layers + + f, axes = pl.subplots(len(conv_layers)) + f.suptitle(self.net.name) + for idx, (conv_layer, ax) in enumerate(zip(conv_layers, axes)): + W = blobproto_to_array(conv_layer.blobs[0]) + # only combine color channels for the first conv layer + combine_chans = idx == 0 + plot_weights(W, title="Layer %d" % idx, + axis=ax, combine_chans=combine_chans) + self._fig = f + + def visualize_conv_layer(self, conv_layer_idx=0): + """ + Visualize a single conv layer of the network. + + """ + conv_layers = self.conv_layers + + W = blobproto_to_array(conv_layers[conv_layer_idx].blobs[0]) + # only combine color channels for the first conv layer + combine_chans = conv_layer_idx == 0 + f, axis = pl.subplots() + plot_weights(W, title="Layer %d" % conv_layer_idx, + axis=axis, combine_chans=combine_chans) + self._fig = f + + def save_fig_to_file(self, file_name): + self._fig.savefig(file_name) + + +def plot_weights(filters, title, axis=None, combine_chans=False): + """ + Takes conv layer kernel numpy ndarray as an input. + Plots all kernels in one big image. + """ + filters = filters - filters.min() + filters = filters / filters.max() + + if axis is None: + f, axis = pl.subplots() + + make_filter_fig(filters, + filter_start=0, + axis=axis, + title=title, + num_filters=filters.shape[0], + combine_chans=combine_chans) + + +def make_filter_fig(filters, + filter_start, + axis, + title, + num_filters, + combine_chans): + """ + Plot the given filters. 
+ + filters: + ndarray with dimensions: + num_examples, num_channels, filter_size, filter_size + + Code adapted from: + https://code.google.com/p/cuda-convnet/source/browse/trunk/shownet.py + """ + FILTERS_PER_ROW = int(ceil(sqrt(filters.shape[0]))) + MAX_ROWS = FILTERS_PER_ROW + MAX_FILTERS = FILTERS_PER_ROW * MAX_ROWS + num_colors = filters.shape[1] + f_per_row = int(ceil(FILTERS_PER_ROW / + float(1 if combine_chans else num_colors))) + filter_end = min(filter_start+MAX_FILTERS, num_filters) + filter_rows = int(ceil(float(filter_end - filter_start) / f_per_row)) + + assert filters.shape[2] == filters.shape[3] + filter_size = int(filters.shape[2]) + axis.set_title('%s %dx%d filters %d-%d' % (title, filter_size, filter_size, + filter_start, filter_end-1), + horizontalalignment='center') + num_filters = filter_end - filter_start + if not combine_chans: + bigpic = np.zeros((filter_size * filter_rows + filter_rows + 1, + filter_size*num_colors * f_per_row + f_per_row + 1), + dtype=np.single) + else: + bigpic = np.zeros((3, filter_size * filter_rows + filter_rows + 1, + filter_size * f_per_row + f_per_row + 1), + dtype=np.single) + + for m in xrange(filter_start, filter_end): + filter = filters[m,:,:,:] + y, x = (m - filter_start) / f_per_row, (m - filter_start) % f_per_row + if not combine_chans: + for c in xrange(num_colors): + filter_pic = filter[c,:].reshape((filter_size,filter_size)) + bigpic[1 + (1 + filter_size) * y:1 + (1 + filter_size) * y + filter_size, + 1 + (1 + filter_size*num_colors) * x + filter_size*c:1 + (1 + filter_size*num_colors) * x + filter_size*(c+1)] = filter_pic + else: + filter_pic = filter.reshape((3, filter_size, filter_size)) + bigpic[:, + 1 + (1 + filter_size) * y:1 + (1 + filter_size) * y + filter_size, + 1 + (1 + filter_size) * x:1 + (1 + filter_size) * x + filter_size] = filter_pic + + axis.set_xticks([]) + axis.set_yticks([]) + if not combine_chans: + axis.imshow(bigpic, cmap=pl.cm.gray, interpolation='nearest') + else: + bigpic = bigpic.swapaxes(0,2).swapaxes(0,1) + axis.imshow(bigpic, interpolation='nearest') + + +if __name__ == '__main__': + if len(sys.argv) != 3: + print 'Usage: %s input_net_proto_file output_image_file' % \ + os.path.basename(sys.argv[0]) + else: + print "Loading %s" % sys.argv[1] + visualizer = ConvLayerVisualizer(sys.argv[1]) + + visualizer.visualize_conv_layer() + + print 'Exporting conv layers to %s' % sys.argv[2] + visualizer.save_fig_to_file(sys.argv[2]) diff --git a/python/caffe/pycaffe.cpp b/python/caffe/pycaffe.cpp index 1beec163c8f..7e17d564938 100644 --- a/python/caffe/pycaffe.cpp +++ b/python/caffe/pycaffe.cpp @@ -32,12 +32,16 @@ using boost::python::vector_indexing_suite; class CaffeBlob { public: + CaffeBlob(const shared_ptr > &blob, const string& name) + : blob_(blob), name_(name) {} + CaffeBlob(const shared_ptr > &blob) : blob_(blob) {} CaffeBlob() {} + string name() const { return name_; } int num() const { return blob_->num(); } int channels() const { return blob_->channels(); } int height() const { return blob_->height(); } @@ -51,6 +55,7 @@ class CaffeBlob { protected: shared_ptr > blob_; + string name_; }; @@ -219,15 +224,27 @@ struct CaffeNet void set_device(int device_id) { Caffe::SetDevice(device_id); } vector blobs() { - return vector(net_->blobs().begin(), net_->blobs().end()); + vector result; + for (int i = 0; i < net_->blobs().size(); ++i) { + result.push_back(CaffeBlob(net_->blobs()[i], net_->blob_names()[i])); + } + return result; } vector params() { - return vector(net_->params().begin(), 
net_->params().end()); + vector result; + int ix = 0; + for (int i = 0; i < net_->layers().size(); ++i) { + for (int j = 0; j < net_->layers()[i]->blobs().size(); ++j) { + result.push_back(CaffeBlob(net_->params()[ix], net_->layer_names()[i])); + ix++; + } + } + return result; } // The pointer to the internal caffe::Net instant. - shared_ptr > net_; + shared_ptr > net_; }; @@ -251,6 +268,7 @@ BOOST_PYTHON_MODULE(pycaffe) boost::python::class_( "CaffeBlob", boost::python::no_init) + .add_property("name", &CaffeBlob::name) .add_property("num", &CaffeBlob::num) .add_property("channels", &CaffeBlob::channels) .add_property("height", &CaffeBlob::height) diff --git a/src/caffe/layer_factory.cpp b/src/caffe/layer_factory.cpp index b62ba3839c9..8733ff863e6 100644 --- a/src/caffe/layer_factory.cpp +++ b/src/caffe/layer_factory.cpp @@ -41,18 +41,20 @@ Layer* GetLayer(const LayerParameter& param) { return new InnerProductLayer(param); } else if (type == "lrn") { return new LRNLayer(param); - } else if (type == "padding") { - return new PaddingLayer(param); } else if (type == "pool") { return new PoolingLayer(param); } else if (type == "relu") { return new ReLULayer(param); + } else if (type == "tanh") { + return new TanHLayer(param); } else if (type == "sigmoid") { return new SigmoidLayer(param); } else if (type == "softmax") { return new SoftmaxLayer(param); } else if (type == "softmax_loss") { return new SoftmaxWithLossLayer(param); + } else if (type == "split") { + return new SplitLayer(param); } else if (type == "multinomial_logistic_loss") { return new MultinomialLogisticLossLayer(param); } else { diff --git a/src/caffe/layers/conv_layer.cpp b/src/caffe/layers/conv_layer.cpp index f2608be2f64..69a860bf285 100644 --- a/src/caffe/layers/conv_layer.cpp +++ b/src/caffe/layers/conv_layer.cpp @@ -18,6 +18,7 @@ void ConvolutionLayer::SetUp(const vector*>& bottom, KSIZE_ = this->layer_param_.kernelsize(); STRIDE_ = this->layer_param_.stride(); GROUP_ = this->layer_param_.group(); + PAD_ = this->layer_param_.pad(); NUM_ = bottom[0]->num(); CHANNELS_ = bottom[0]->channels(); HEIGHT_ = bottom[0]->height(); @@ -27,8 +28,8 @@ void ConvolutionLayer::SetUp(const vector*>& bottom, CHECK_EQ(CHANNELS_ % GROUP_, 0); // The im2col result buffer would only hold one image at a time to avoid // overly large memory usage. 
- int height_out = (HEIGHT_ - KSIZE_) / STRIDE_ + 1; - int width_out = (WIDTH_ - KSIZE_) / STRIDE_ + 1; + int height_out = (HEIGHT_ + 2 * PAD_ - KSIZE_) / STRIDE_ + 1; + int width_out = (WIDTH_ + 2 * PAD_ - KSIZE_) / STRIDE_ + 1; col_buffer_.Reshape(1, CHANNELS_ * KSIZE_ * KSIZE_, height_out, width_out); // Set the parameters CHECK_EQ(NUM_OUTPUT_ % GROUP_, 0) @@ -88,7 +89,7 @@ void ConvolutionLayer::Forward_cpu(const vector*>& bottom, for (int n = 0; n < NUM_; ++n) { // First, im2col im2col_cpu(bottom_data + bottom[0]->offset(n), CHANNELS_, HEIGHT_, - WIDTH_, KSIZE_, STRIDE_, col_data); + WIDTH_, KSIZE_, PAD_, STRIDE_, col_data); // Second, innerproduct with groups for (int g = 0; g < GROUP_; ++g) { caffe_cpu_gemm(CblasNoTrans, CblasNoTrans, M_, N_, K_, @@ -118,7 +119,7 @@ void ConvolutionLayer::Forward_gpu(const vector*>& bottom, for (int n = 0; n < NUM_; ++n) { // First, im2col im2col_gpu(bottom_data + bottom[0]->offset(n), CHANNELS_, HEIGHT_, - WIDTH_, KSIZE_, STRIDE_, col_data); + WIDTH_, KSIZE_, PAD_, STRIDE_, col_data); // Second, innerproduct with groups for (int g = 0; g < GROUP_; ++g) { caffe_gpu_gemm(CblasNoTrans, CblasNoTrans, M_, N_, K_, @@ -167,7 +168,7 @@ Dtype ConvolutionLayer::Backward_cpu(const vector*>& top, // since we saved memory in the forward pass by not storing all col data, // we will need to recompute them. im2col_cpu(bottom_data + (*bottom)[0]->offset(n), CHANNELS_, HEIGHT_, - WIDTH_, KSIZE_, STRIDE_, col_data); + WIDTH_, KSIZE_, PAD_, STRIDE_, col_data); // gradient w.r.t. weight. Note that we will accumulate diffs. for (int g = 0; g < GROUP_; ++g) { caffe_cpu_gemm(CblasNoTrans, CblasTrans, M_, K_, N_, @@ -185,7 +186,7 @@ Dtype ConvolutionLayer::Backward_cpu(const vector*>& top, } // col2im back to the data col2im_cpu(col_diff, CHANNELS_, HEIGHT_, - WIDTH_, KSIZE_, STRIDE_, bottom_diff + (*bottom)[0]->offset(n)); + WIDTH_, KSIZE_, PAD_, STRIDE_, bottom_diff + (*bottom)[0]->offset(n)); } } return Dtype(0.); @@ -225,7 +226,7 @@ Dtype ConvolutionLayer::Backward_gpu(const vector*>& top, // since we saved memory in the forward pass by not storing all col data, // we will need to recompute them. im2col_gpu(bottom_data + (*bottom)[0]->offset(n), CHANNELS_, HEIGHT_, - WIDTH_, KSIZE_, STRIDE_, col_data); + WIDTH_, KSIZE_, PAD_, STRIDE_, col_data); // gradient w.r.t. weight. Note that we will accumulate diffs. 
for (int g = 0; g < GROUP_; ++g) { caffe_gpu_gemm(CblasNoTrans, CblasTrans, M_, K_, N_, @@ -243,7 +244,7 @@ Dtype ConvolutionLayer::Backward_gpu(const vector*>& top, } // col2im back to the data col2im_gpu(col_diff, CHANNELS_, HEIGHT_, - WIDTH_, KSIZE_, STRIDE_, bottom_diff + (*bottom)[0]->offset(n)); + WIDTH_, KSIZE_, PAD_, STRIDE_, bottom_diff + (*bottom)[0]->offset(n)); } } return Dtype(0.); diff --git a/src/caffe/layers/data_layer.cpp b/src/caffe/layers/data_layer.cpp index d1262d03f24..ffb7fd0a9e2 100644 --- a/src/caffe/layers/data_layer.cpp +++ b/src/caffe/layers/data_layer.cpp @@ -129,6 +129,7 @@ void DataLayer::SetUp(const vector*>& bottom, leveldb::DB* db_temp; leveldb::Options options; options.create_if_missing = false; + options.max_open_files = 100; LOG(INFO) << "Opening leveldb " << this->layer_param_.source(); leveldb::Status status = leveldb::DB::Open( options, this->layer_param_.source(), &db_temp); diff --git a/src/caffe/layers/im2col_layer.cpp b/src/caffe/layers/im2col_layer.cpp index 976c8441e69..5f9986a2f86 100644 --- a/src/caffe/layers/im2col_layer.cpp +++ b/src/caffe/layers/im2col_layer.cpp @@ -16,11 +16,12 @@ void Im2colLayer::SetUp(const vector*>& bottom, CHECK_EQ(top->size(), 1) << "Im2col Layer takes a single blob as output."; KSIZE_ = this->layer_param_.kernelsize(); STRIDE_ = this->layer_param_.stride(); + PAD_ = this->layer_param_.pad(); CHANNELS_ = bottom[0]->channels(); HEIGHT_ = bottom[0]->height(); WIDTH_ = bottom[0]->width(); (*top)[0]->Reshape(bottom[0]->num(), CHANNELS_ * KSIZE_ * KSIZE_, - (HEIGHT_ - KSIZE_) / STRIDE_ + 1, (WIDTH_ - KSIZE_) / STRIDE_ + 1); + (HEIGHT_ + 2 * PAD_ - KSIZE_) / STRIDE_ + 1, (WIDTH_ + 2 * PAD_ - KSIZE_) / STRIDE_ + 1); }; template @@ -30,7 +31,7 @@ void Im2colLayer::Forward_cpu(const vector*>& bottom, Dtype* top_data = (*top)[0]->mutable_cpu_data(); for (int n = 0; n < bottom[0]->num(); ++n) { im2col_cpu(bottom_data + bottom[0]->offset(n), CHANNELS_, HEIGHT_, - WIDTH_, KSIZE_, STRIDE_, top_data + (*top)[0]->offset(n)); + WIDTH_, KSIZE_, PAD_, STRIDE_, top_data + (*top)[0]->offset(n)); } } @@ -41,7 +42,7 @@ void Im2colLayer::Forward_gpu(const vector*>& bottom, Dtype* top_data = (*top)[0]->mutable_gpu_data(); for (int n = 0; n < bottom[0]->num(); ++n) { im2col_gpu(bottom_data + bottom[0]->offset(n), CHANNELS_, HEIGHT_, - WIDTH_, KSIZE_, STRIDE_, top_data + (*top)[0]->offset(n)); + WIDTH_, KSIZE_, PAD_, STRIDE_, top_data + (*top)[0]->offset(n)); } } @@ -52,7 +53,7 @@ Dtype Im2colLayer::Backward_cpu(const vector*>& top, Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff(); for (int n = 0; n < top[0]->num(); ++n) { col2im_cpu(top_diff + top[0]->offset(n), CHANNELS_, HEIGHT_, - WIDTH_, KSIZE_, STRIDE_, bottom_diff + (*bottom)[0]->offset(n)); + WIDTH_, KSIZE_, PAD_, STRIDE_, bottom_diff + (*bottom)[0]->offset(n)); } return Dtype(0.); } @@ -65,7 +66,7 @@ Dtype Im2colLayer::Backward_gpu(const vector*>& top, Dtype* bottom_diff = (*bottom)[0]->mutable_gpu_diff(); for (int n = 0; n < top[0]->num(); ++n) { col2im_gpu(top_diff + top[0]->offset(n), CHANNELS_, HEIGHT_, - WIDTH_, KSIZE_, STRIDE_, bottom_diff + (*bottom)[0]->offset(n)); + WIDTH_, KSIZE_, PAD_, STRIDE_, bottom_diff + (*bottom)[0]->offset(n)); } return Dtype(0.); } diff --git a/src/caffe/layers/padding_layer.cu b/src/caffe/layers/padding_layer.cu deleted file mode 100644 index 90f5508b434..00000000000 --- a/src/caffe/layers/padding_layer.cu +++ /dev/null @@ -1,139 +0,0 @@ -// Copyright 2013 Yangqing Jia - -#include "caffe/layer.hpp" -#include "caffe/vision_layers.hpp" - 
-#include - -namespace caffe { - -template -void PaddingLayer::SetUp(const vector*>& bottom, - vector*>* top) { - PAD_ = this->layer_param_.pad(); - CHECK_EQ(bottom.size(), 1) << "Padding Layer takes a single blob as input."; - CHECK_EQ(top->size(), 1) << "Padding Layer takes a single blob as output."; - NUM_ = bottom[0]->num(); - CHANNEL_ = bottom[0]->channels(); - HEIGHT_IN_ = bottom[0]->height(); - WIDTH_IN_ = bottom[0]->width(); - HEIGHT_OUT_ = HEIGHT_IN_ + PAD_ * 2; - WIDTH_OUT_ = WIDTH_IN_ + PAD_ * 2; - (*top)[0]->Reshape(NUM_, CHANNEL_, HEIGHT_OUT_, WIDTH_OUT_); - -}; - -template -void PaddingLayer::Forward_cpu(const vector*>& bottom, - vector*>* top) { - Dtype* top_data = (*top)[0]->mutable_cpu_data(); - const Dtype* bottom_data = bottom[0]->cpu_data(); - memset(top_data, 0, sizeof(Dtype) * (*top)[0]->count()); - // In short, top[n, c, h, w] = bottom[n, c, h-pad, w-pad] if in range - for (int n = 0; n < NUM_; ++n) { - for (int c = 0; c < CHANNEL_; ++c) { - for (int h = 0; h < HEIGHT_IN_; ++h) { - // copy the width part - memcpy( - top_data + ((n * CHANNEL_ + c) * HEIGHT_OUT_ + h + PAD_) - * WIDTH_OUT_ + PAD_, - bottom_data + ((n * CHANNEL_ + c) * HEIGHT_IN_ + h) * WIDTH_IN_, - sizeof(Dtype) * WIDTH_IN_); - } - } - } -} - -template -Dtype PaddingLayer::Backward_cpu(const vector*>& top, - const bool propagate_down, vector*>* bottom) { - const Dtype* top_diff = top[0]->cpu_diff(); - Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff(); - //memset(bottom_data, 0, sizeof(Dtype) * (*bottom)[0]->count()); - for (int n = 0; n < NUM_; ++n) { - for (int c = 0; c < CHANNEL_; ++c) { - for (int h = 0; h < HEIGHT_IN_; ++h) { - // copy the width part - memcpy( - bottom_diff + ((n * CHANNEL_ + c) * HEIGHT_IN_ + h) * WIDTH_IN_, - top_diff + ((n * CHANNEL_ + c) * HEIGHT_OUT_ + h + PAD_) - * WIDTH_OUT_ + PAD_, - sizeof(Dtype) * WIDTH_IN_); - } - } - } - return Dtype(0.); -} - -template -__global__ void PaddingForward(const int count, const Dtype* in, Dtype* out, - const int num, const int channel, const int height_in, const int width_in, - const int pad) { - int index = threadIdx.x + blockIdx.x * blockDim.x; - if (index < count) { - int height_out = height_in + pad + pad; - int width_out = width_in + pad + pad; - int w = index % width_in; - index /= width_in; - int h = index % height_in; - index /= height_in; - int c = index % channel; - index /= channel; - out[((index * channel + c) * height_out + h + pad) * width_out + pad + w] = - in[((index * channel + c) * height_in + h) * width_in + w]; - } -} - -template -void PaddingLayer::Forward_gpu(const vector*>& bottom, - vector*>* top) { - const Dtype* bottom_data = bottom[0]->gpu_data(); - Dtype* top_data = (*top)[0]->mutable_gpu_data(); - const int count = bottom[0]->count(); - // First, set all data to be zero for the boundary pixels - CUDA_CHECK(cudaMemset(top_data, 0, sizeof(Dtype) * (*top)[0]->count())); - PaddingForward<<>>( - count, bottom_data, top_data, NUM_, CHANNEL_, HEIGHT_IN_, WIDTH_IN_, - PAD_); - CUDA_POST_KERNEL_CHECK; -} - -template -__global__ void PaddingBackward(const int count, const Dtype* in, Dtype* out, - const int num, const int channel, const int height_in, const int width_in, - const int pad) { - int index = threadIdx.x + blockIdx.x * blockDim.x; - if (index < count) { - int height_out = height_in + pad + pad; - int width_out = width_in + pad + pad; - int w = index % width_in; - index /= width_in; - int h = index % height_in; - index /= height_in; - int c = index % channel; - index /= channel; - out[((index * channel + c) * 
height_in + h) * width_in + w] =
-        in[((index * channel + c) * height_out + h + pad) * width_out + pad + w];
-  }
-}
-
-template <typename Dtype>
-Dtype PaddingLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
-    const bool propagate_down,
-    vector<Blob<Dtype>*>* bottom) {
-  if (propagate_down) {
-    const Dtype* top_diff = top[0]->gpu_diff();
-    Dtype* bottom_diff = (*bottom)[0]->mutable_gpu_diff();
-    const int count = (*bottom)[0]->count();
-    PaddingBackward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
-        count, top_diff, bottom_diff, NUM_, CHANNEL_, HEIGHT_IN_, WIDTH_IN_,
-        PAD_);
-    CUDA_POST_KERNEL_CHECK;
-  }
-  return Dtype(0);
-}
-
-INSTANTIATE_CLASS(PaddingLayer);
-
-
-}  // namespace caffe
diff --git a/src/caffe/layers/split_layer.cpp b/src/caffe/layers/split_layer.cpp
new file mode 100644
index 00000000000..5accdd08e32
--- /dev/null
+++ b/src/caffe/layers/split_layer.cpp
@@ -0,0 +1,101 @@
+// Copyright 2014 Jeff Donahue
+
+#include
+
+#include "caffe/layer.hpp"
+#include "caffe/vision_layers.hpp"
+#include "caffe/util/math_functions.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+void SplitLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
+    vector<Blob<Dtype>*>* top) {
+  CHECK_EQ(bottom.size(), 1) << "Split Layer takes a single blob as input.";
+  CHECK_GE(top->size(), 1) << "Split Layer takes at least one blob as output.";
+  count_ = bottom[0]->count();
+  for (int i = 0; i < top->size(); ++i) {
+    // Allow the 0th top blob to be 'in-place', but no others.
+    if (i == 0 && (*top)[i] == bottom[0]) {
+      continue;
+    } else {
+      CHECK_NE((*top)[i], bottom[0]) << "Only 0th top blob may be in place.";
+    }
+    (*top)[i]->Reshape(bottom[0]->num(), bottom[0]->channels(),
+        bottom[0]->height(), bottom[0]->width());
+    CHECK_EQ(count_, (*top)[i]->count());
+  }
+};
+
+template <typename Dtype>
+void SplitLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+    vector<Blob<Dtype>*>* top) {
+  const Dtype* bottom_data = bottom[0]->cpu_data();
+  for (int i = 0; i < top->size(); ++i) {
+    if (i == 0 && (*top)[i] == bottom[0]) {
+      continue;
+    }
+    Dtype* top_data = (*top)[i]->mutable_cpu_data();
+    caffe_copy(count_, bottom_data, top_data);
+  }
+}
+
+template <typename Dtype>
+void SplitLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+    vector<Blob<Dtype>*>* top) {
+  const Dtype* bottom_data = bottom[0]->gpu_data();
+  for (int i = 0; i < top->size(); ++i) {
+    if (i == 0 && (*top)[i] == bottom[0]) {
+      continue;
+    }
+    Dtype* top_data = (*top)[i]->mutable_gpu_data();
+    caffe_gpu_copy(count_, bottom_data, top_data);
+  }
+}
+
+template <typename Dtype>
+Dtype SplitLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
+    const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
+  if (propagate_down) {
+    const Dtype* top_diff = top[0]->cpu_diff();
+    Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff();
+    // Initialize by copying first top blob diff to our diff, unless we're
+    // doing in-place computation for the first blob, in which case the diff is
+    // already initialized.
+    if (top[0] != (*bottom)[0]) {
+      caffe_copy(count_, top_diff, bottom_diff);
+    }
+    // Add remaining top blob diffs.
+    for (int i = 1; i < top.size(); ++i) {
+      top_diff = top[i]->cpu_diff();
+      caffe_axpy(count_, Dtype(1.), top_diff, bottom_diff);
+    }
+  }
+  return Dtype(0.);
+}
+
+
+template <typename Dtype>
+Dtype SplitLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
+    const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
+  if (propagate_down) {
+    const Dtype* top_diff = top[0]->gpu_diff();
+    Dtype* bottom_diff = (*bottom)[0]->mutable_gpu_diff();
+    // Initialize by copying first top blob diff to our diff, unless we're
+    // doing in-place computation for the first blob, in which case the diff is
+    // already initialized.
+    if (top[0] != (*bottom)[0]) {
+      caffe_gpu_copy(count_, top_diff, bottom_diff);
+    }
+    // Add remaining top blob diffs.
+    for (int i = 1; i < top.size(); ++i) {
+      top_diff = top[i]->gpu_diff();
+      caffe_gpu_axpy(count_, Dtype(1.), top_diff, bottom_diff);
+    }
+  }
+  return Dtype(0.);
+}
+
+INSTANTIATE_CLASS(SplitLayer);
+
+}  // namespace caffe
diff --git a/src/caffe/layers/tanh_layer.cu b/src/caffe/layers/tanh_layer.cu
new file mode 100644
index 00000000000..22e0831afb7
--- /dev/null
+++ b/src/caffe/layers/tanh_layer.cu
@@ -0,0 +1,97 @@
+// Copyright 2014 Aravindh Mahendran
+// TanH neuron activation function layer. Adapted from ReLU layer code written by Yangqing Jia
+
+#include "caffe/layer.hpp"
+#include "caffe/vision_layers.hpp"
+#include
+
+namespace caffe {
+
+template <typename Dtype>
+void TanHLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+    vector<Blob<Dtype>*>* top) {
+  const Dtype* bottom_data = bottom[0]->cpu_data();
+  Dtype* top_data = (*top)[0]->mutable_cpu_data();
+  Dtype exp2x;
+  const int count = bottom[0]->count();
+  for (int i = 0; i < count; ++i) {
+    exp2x = exp(2*bottom_data[i]);
+    top_data[i] = (exp2x - Dtype(1))/(exp2x + Dtype(1));
+  }
+}
+
+template <typename Dtype>
+Dtype TanHLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
+    const bool propagate_down,
+    vector<Blob<Dtype>*>* bottom) {
+  if (propagate_down) {
+    const Dtype* bottom_data = (*bottom)[0]->cpu_data();
+    const Dtype* top_diff = top[0]->cpu_diff();
+    Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff();
+    const int count = (*bottom)[0]->count();
+    Dtype exp2x;
+    Dtype tanhx;
+    for (int i = 0; i < count; ++i) {
+      exp2x = exp(2*bottom_data[i]);
+      tanhx = (exp2x - Dtype(1))/(exp2x + Dtype(1));
+      bottom_diff[i] = top_diff[i] * (1 - tanhx*tanhx);
+    }
+  }
+  return Dtype(0);
+}
+
+template <typename Dtype>
+__global__ void TanHForward(const int n, const Dtype* in, Dtype* out) {
+  int index = threadIdx.x + blockIdx.x * blockDim.x;
+  if (index < n) {
+    Dtype exp2x = exp(2*in[index]);
+    out[index] = (exp2x - Dtype(1))/(exp2x + Dtype(1));
+  }
+}
+
+template <typename Dtype>
+void TanHLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+    vector<Blob<Dtype>*>* top) {
+  const Dtype* bottom_data = bottom[0]->gpu_data();
+  Dtype* top_data = (*top)[0]->mutable_gpu_data();
+  const int count = bottom[0]->count();
+  TanHForward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
+      count, bottom_data, top_data);
+  CUDA_POST_KERNEL_CHECK;
+  // << " count: " << count << " bottom_data: "
+  //     << (unsigned long)bottom_data << " top_data: " << (unsigned long)top_data
+  //     << " blocks: " << CAFFE_GET_BLOCKS(count)
+  //     << " threads: " << CAFFE_CUDA_NUM_THREADS;
+}
+
+template <typename Dtype>
+__global__ void TanHBackward(const int n, const Dtype* in_diff,
+    const Dtype* in_data, Dtype* out_diff) {
+  int index = threadIdx.x + blockIdx.x * blockDim.x;
+  if (index < n) {
+    Dtype exp2x = exp(2*in_data[index]);
+    Dtype tanhx = (exp2x - Dtype(1))/(exp2x + Dtype(1));
+    out_diff[index] = in_diff[index] * (1 - tanhx*tanhx);
+  }
+}
+
+template <typename Dtype>
+Dtype TanHLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
+    const bool propagate_down,
+    vector<Blob<Dtype>*>* bottom) {
+  if (propagate_down) {
+    const Dtype* bottom_data = (*bottom)[0]->gpu_data();
+    const Dtype* top_diff = top[0]->gpu_diff();
+    Dtype* bottom_diff = (*bottom)[0]->mutable_gpu_diff();
+    const int count = (*bottom)[0]->count();
+    TanHBackward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
+        count, top_diff, bottom_data, bottom_diff);
+    CUDA_POST_KERNEL_CHECK;
+  }
+  return Dtype(0);
+}
+
+INSTANTIATE_CLASS(TanHLayer);
+
+
+}  // namespace caffe
diff --git a/src/caffe/net.cpp b/src/caffe/net.cpp
index f265cd36c55..e976dfd5fd0 100644
--- a/src/caffe/net.cpp
+++ b/src/caffe/net.cpp
@@ -9,6 +9,7 @@
 #include "caffe/layer.hpp"
 #include
"caffe/net.hpp" #include "caffe/util/io.hpp" +#include "caffe/util/insert_splits.hpp" using std::pair; using std::map; @@ -29,7 +30,10 @@ Net::Net(const string& param_file) { } template -void Net::Init(const NetParameter& param) { +void Net::Init(const NetParameter& in_param) { + // Create a copy of in_param with splits added where necessary. + NetParameter param; + insert_splits(in_param, ¶m); // Basically, build all the layers and set up its connections. name_ = param.name(); map blob_name_to_idx; diff --git a/src/caffe/test/test_padding_layer.cpp b/src/caffe/test/test_padding_layer.cpp deleted file mode 100644 index da48111a66d..00000000000 --- a/src/caffe/test/test_padding_layer.cpp +++ /dev/null @@ -1,114 +0,0 @@ -// Copyright 2013 Yangqing Jia - -#include -#include - -#include "gtest/gtest.h" -#include "caffe/blob.hpp" -#include "caffe/common.hpp" -#include "caffe/filler.hpp" -#include "caffe/vision_layers.hpp" -#include "caffe/test/test_gradient_check_util.hpp" - -#include "caffe/test/test_caffe_main.hpp" - -namespace caffe { - -extern cudaDeviceProp CAFFE_TEST_CUDA_PROP; - -template -class PaddingLayerTest : public ::testing::Test { - protected: - PaddingLayerTest() - : blob_bottom_(new Blob(2, 3, 4, 5)), - blob_top_(new Blob()) { - // fill the values - FillerParameter filler_param; - GaussianFiller filler(filler_param); - filler.Fill(this->blob_bottom_); - blob_bottom_vec_.push_back(blob_bottom_); - blob_top_vec_.push_back(blob_top_); - }; - virtual ~PaddingLayerTest() { delete blob_bottom_; delete blob_top_; } - Blob* const blob_bottom_; - Blob* const blob_top_; - vector*> blob_bottom_vec_; - vector*> blob_top_vec_; -}; - -typedef ::testing::Types Dtypes; -TYPED_TEST_CASE(PaddingLayerTest, Dtypes); - -TYPED_TEST(PaddingLayerTest, TestCPU) { - LayerParameter layer_param; - layer_param.set_pad(1); - Caffe::set_mode(Caffe::CPU); - PaddingLayer layer(layer_param); - layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_)); - layer.Forward(this->blob_bottom_vec_, &(this->blob_top_vec_)); - EXPECT_EQ(this->blob_top_->num(), 2); - EXPECT_EQ(this->blob_top_->channels(), 3); - EXPECT_EQ(this->blob_top_->height(), 6); - EXPECT_EQ(this->blob_top_->width(), 7); - for (int n = 0; n < 2; ++n) { - for (int c = 0; c < 3; ++c) { - for (int h = 0; h < 4; ++h) { - for (int w = 0; w < 5; ++w) { - EXPECT_EQ(this->blob_bottom_->data_at(n, c, h, w), - this->blob_top_->data_at(n, c, h + 1, w + 1)); - } - } - } - } -} - -TYPED_TEST(PaddingLayerTest, TestCPUGrad) { - LayerParameter layer_param; - layer_param.set_pad(1); - Caffe::set_mode(Caffe::CPU); - PaddingLayer layer(layer_param); - GradientChecker checker(1e-2, 1e-3); - checker.CheckGradientExhaustive(layer, this->blob_bottom_vec_, this->blob_top_vec_); -} - -TYPED_TEST(PaddingLayerTest, TestGPU) { - if (CAFFE_TEST_CUDA_PROP.major >= 2) { - LayerParameter layer_param; - layer_param.set_pad(1); - Caffe::set_mode(Caffe::GPU); - PaddingLayer layer(layer_param); - layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_)); - layer.Forward(this->blob_bottom_vec_, &(this->blob_top_vec_)); - EXPECT_EQ(this->blob_top_->num(), 2); - EXPECT_EQ(this->blob_top_->channels(), 3); - EXPECT_EQ(this->blob_top_->height(), 6); - EXPECT_EQ(this->blob_top_->width(), 7); - for (int n = 0; n < 2; ++n) { - for (int c = 0; c < 3; ++c) { - for (int h = 0; h < 4; ++h) { - for (int w = 0; w < 5; ++w) { - EXPECT_EQ(this->blob_bottom_->data_at(n, c, h, w), - this->blob_top_->data_at(n, c, h + 1, w + 1)); - } - } - } - } - } else { - LOG(ERROR) << "Skipping test (gpu version 
too low)."; - } -} - -TYPED_TEST(PaddingLayerTest, TestGPUGrad) { - if (CAFFE_TEST_CUDA_PROP.major >= 2) { - LayerParameter layer_param; - layer_param.set_pad(1); - Caffe::set_mode(Caffe::GPU); - PaddingLayer layer(layer_param); - GradientChecker checker(1e-2, 1e-3); - checker.CheckGradientExhaustive(layer, this->blob_bottom_vec_, this->blob_top_vec_); - } else { - LOG(ERROR) << "Skipping test (gpu version too low)."; - } -} - -} diff --git a/src/caffe/test/test_split_layer.cpp b/src/caffe/test/test_split_layer.cpp new file mode 100644 index 00000000000..3311c9ac76c --- /dev/null +++ b/src/caffe/test/test_split_layer.cpp @@ -0,0 +1,1128 @@ +// Copyright 2014 Jeff Donahue + +#include +#include +#include + +#include "gtest/gtest.h" +#include "caffe/blob.hpp" +#include "caffe/common.hpp" +#include "caffe/filler.hpp" +#include "caffe/vision_layers.hpp" +#include "caffe/test/test_gradient_check_util.hpp" +#include "caffe/util/insert_splits.hpp" + +#include "caffe/test/test_caffe_main.hpp" + +namespace caffe { + +extern cudaDeviceProp CAFFE_TEST_CUDA_PROP; + +template +class SplitLayerTest : public ::testing::Test { + protected: + SplitLayerTest() + : blob_bottom_(new Blob(2, 3, 6, 5)), + blob_top_a_(new Blob()), + blob_top_b_(new Blob()) { + // fill the values + FillerParameter filler_param; + GaussianFiller filler(filler_param); + filler.Fill(this->blob_bottom_); + blob_bottom_vec_.push_back(blob_bottom_); + blob_top_vec_.push_back(blob_top_a_); + blob_top_vec_.push_back(blob_top_b_); + }; + virtual ~SplitLayerTest() { + delete blob_bottom_; + delete blob_top_a_; + delete blob_top_b_; + } + Blob* const blob_bottom_; + Blob* const blob_top_a_; + Blob* const blob_top_b_; + vector*> blob_bottom_vec_; + vector*> blob_top_vec_; +}; + +typedef ::testing::Types Dtypes; +TYPED_TEST_CASE(SplitLayerTest, Dtypes); + +TYPED_TEST(SplitLayerTest, TestSetup) { + LayerParameter layer_param; + SplitLayer layer(layer_param); + layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_)); + EXPECT_EQ(this->blob_top_a_->num(), 2); + EXPECT_EQ(this->blob_top_a_->channels(), 3); + EXPECT_EQ(this->blob_top_a_->height(), 6); + EXPECT_EQ(this->blob_top_a_->width(), 5); + EXPECT_EQ(this->blob_top_b_->num(), 2); + EXPECT_EQ(this->blob_top_b_->channels(), 3); + EXPECT_EQ(this->blob_top_b_->height(), 6); + EXPECT_EQ(this->blob_top_b_->width(), 5); +} + +TYPED_TEST(SplitLayerTest, TestCPU) { + LayerParameter layer_param; + SplitLayer layer(layer_param); + Caffe::set_mode(Caffe::CPU); + layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_)); + layer.Forward(this->blob_bottom_vec_, &(this->blob_top_vec_)); + for (int i = 0; i < this->blob_bottom_->count(); ++i) { + TypeParam bottom_value = this->blob_bottom_->cpu_data()[i]; + EXPECT_EQ(bottom_value, this->blob_top_a_->cpu_data()[i]); + EXPECT_EQ(bottom_value, this->blob_top_b_->cpu_data()[i]); + } +} + +TYPED_TEST(SplitLayerTest, TestGPU) { + LayerParameter layer_param; + SplitLayer layer(layer_param); + Caffe::set_mode(Caffe::GPU); + layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_)); + layer.Forward(this->blob_bottom_vec_, &(this->blob_top_vec_)); + for (int i = 0; i < this->blob_bottom_->count(); ++i) { + TypeParam bottom_value = this->blob_bottom_->cpu_data()[i]; + EXPECT_EQ(bottom_value, this->blob_top_a_->cpu_data()[i]); + EXPECT_EQ(bottom_value, this->blob_top_b_->cpu_data()[i]); + } +} + +TYPED_TEST(SplitLayerTest, TestCPUInPlace) { + LayerParameter layer_param; + SplitLayer layer(layer_param); + Caffe::set_mode(Caffe::CPU); + this->blob_top_vec_[0] = 
this->blob_bottom_vec_[0]; + layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_)); + layer.Forward(this->blob_bottom_vec_, &(this->blob_top_vec_)); + for (int i = 0; i < this->blob_bottom_->count(); ++i) { + TypeParam bottom_value = this->blob_bottom_->cpu_data()[i]; + EXPECT_EQ(bottom_value, this->blob_top_b_->cpu_data()[i]); + } +} + +TYPED_TEST(SplitLayerTest, TestGPUInPlace) { + LayerParameter layer_param; + SplitLayer layer(layer_param); + Caffe::set_mode(Caffe::GPU); + this->blob_top_vec_[0] = this->blob_bottom_vec_[0]; + layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_)); + layer.Forward(this->blob_bottom_vec_, &(this->blob_top_vec_)); + for (int i = 0; i < this->blob_bottom_->count(); ++i) { + TypeParam bottom_value = this->blob_bottom_->cpu_data()[i]; + EXPECT_EQ(bottom_value, this->blob_top_b_->cpu_data()[i]); + } +} + +TYPED_TEST(SplitLayerTest, TestCPUGradient) { + LayerParameter layer_param; + Caffe::set_mode(Caffe::CPU); + SplitLayer layer(layer_param); + GradientChecker checker(1e-2, 1e-2); + checker.CheckGradientExhaustive(layer, this->blob_bottom_vec_, + this->blob_top_vec_); +} + +TYPED_TEST(SplitLayerTest, TestGPUGradient) { + LayerParameter layer_param; + Caffe::set_mode(Caffe::GPU); + SplitLayer layer(layer_param); + GradientChecker checker(1e-2, 1e-2); + checker.CheckGradientExhaustive(layer, this->blob_bottom_vec_, + this->blob_top_vec_); +} + +TYPED_TEST(SplitLayerTest, TestCPUGradientInPlace) { + LayerParameter layer_param; + Caffe::set_mode(Caffe::CPU); + SplitLayer layer(layer_param); + GradientChecker checker(1e-2, 1e-2); + this->blob_top_vec_[0] = this->blob_bottom_vec_[0]; + checker.CheckGradientExhaustive(layer, this->blob_bottom_vec_, + this->blob_top_vec_); +} + +TYPED_TEST(SplitLayerTest, TestGPUGradientInPlace) { + LayerParameter layer_param; + Caffe::set_mode(Caffe::GPU); + SplitLayer layer(layer_param); + GradientChecker checker(1e-2, 1e-2); + this->blob_top_vec_[0] = this->blob_bottom_vec_[0]; + checker.CheckGradientExhaustive(layer, this->blob_bottom_vec_, + this->blob_top_vec_); +} + + +template +class SplitLayerInsertionTest : public ::testing::Test { + protected: + SplitLayerInsertionTest() { }; + void RunInsertionTest( + const string& input_param_string, const string& output_param_string) { + // Test that insert_splits called on the proto specified by + // input_param_string results in the proto specified by + // output_param_string. + NetParameter input_param; + CHECK(google::protobuf::TextFormat::ParseFromString( + input_param_string, &input_param)); + NetParameter expected_output_param; + CHECK(google::protobuf::TextFormat::ParseFromString( + output_param_string, &expected_output_param)); + NetParameter actual_output_param; + insert_splits(input_param, &actual_output_param); + EXPECT_EQ(expected_output_param.DebugString(), + actual_output_param.DebugString()); + // Also test idempotence. 
+ NetParameter double_split_insert_param; + insert_splits(actual_output_param, &double_split_insert_param); + EXPECT_EQ(actual_output_param.DebugString(), + double_split_insert_param.DebugString()); + } +}; + +typedef ::testing::Types InsertionDtypes; +TYPED_TEST_CASE(SplitLayerInsertionTest, InsertionDtypes); + +TYPED_TEST(SplitLayerInsertionTest, TestNoInsertion1) { + const string& input_proto = + "name: 'TestNetwork' " + "layers: { " + " layer { " + " name: 'data' " + " type: 'data' " + " } " + " top: 'data' " + " top: 'label' " + "} " + "layers: { " + " layer { " + " name: 'innerprod' " + " type: 'inner_product' " + " } " + " bottom: 'data' " + " top: 'innerprod' " + "} " + "layers: { " + " layer { " + " name: 'loss' " + " type: 'softmax_with_loss' " + " } " + " bottom: 'innerprod' " + " bottom: 'label' " + "} "; + this->RunInsertionTest(input_proto, input_proto); +} + +TYPED_TEST(SplitLayerInsertionTest, TestNoInsertion2) { + const string& input_proto = + "name: 'TestNetwork' " + "layers: { " + " layer { " + " name: 'data' " + " type: 'data' " + " } " + " top: 'data' " + " top: 'label' " + "} " + "layers: { " + " layer { " + " name: 'data_split' " + " type: 'split' " + " } " + " bottom: 'data' " + " top: 'data_split_0' " + " top: 'data_split_1' " + "} " + "layers: { " + " layer { " + " name: 'innerprod1' " + " type: 'inner_product' " + " } " + " bottom: 'data_split_0' " + " top: 'innerprod1' " + "} " + "layers: { " + " layer { " + " name: 'innerprod2' " + " type: 'inner_product' " + " } " + " bottom: 'data_split_1' " + " top: 'innerprod2' " + "} " + "layers: { " + " layer { " + " name: 'loss' " + " type: 'euclidean_loss' " + " } " + " bottom: 'innerprod1' " + " bottom: 'innerprod2' " + "} "; + this->RunInsertionTest(input_proto, input_proto); +} + +TYPED_TEST(SplitLayerInsertionTest, TestNoInsertionImageNet) { + const string& input_proto = + "name: 'CaffeNet' " + "layers { " + " layer { " + " name: 'data' " + " type: 'data' " + " source: '/home/jiayq/Data/ILSVRC12/train-leveldb' " + " meanfile: '/home/jiayq/Data/ILSVRC12/image_mean.binaryproto' " + " batchsize: 256 " + " cropsize: 227 " + " mirror: true " + " } " + " top: 'data' " + " top: 'label' " + "} " + "layers { " + " layer { " + " name: 'conv1' " + " type: 'conv' " + " num_output: 96 " + " kernelsize: 11 " + " stride: 4 " + " weight_filler { " + " type: 'gaussian' " + " std: 0.01 " + " } " + " bias_filler { " + " type: 'constant' " + " value: 0. " + " } " + " blobs_lr: 1. " + " blobs_lr: 2. " + " weight_decay: 1. " + " weight_decay: 0. " + " } " + " bottom: 'data' " + " top: 'conv1' " + "} " + "layers { " + " layer { " + " name: 'relu1' " + " type: 'relu' " + " } " + " bottom: 'conv1' " + " top: 'conv1' " + "} " + "layers { " + " layer { " + " name: 'pool1' " + " type: 'pool' " + " pool: MAX " + " kernelsize: 3 " + " stride: 2 " + " } " + " bottom: 'conv1' " + " top: 'pool1' " + "} " + "layers { " + " layer { " + " name: 'norm1' " + " type: 'lrn' " + " local_size: 5 " + " alpha: 0.0001 " + " beta: 0.75 " + " } " + " bottom: 'pool1' " + " top: 'norm1' " + "} " + "layers { " + " layer { " + " name: 'pad2' " + " type: 'padding' " + " pad: 2 " + " } " + " bottom: 'norm1' " + " top: 'pad2' " + "} " + "layers { " + " layer { " + " name: 'conv2' " + " type: 'conv' " + " num_output: 256 " + " group: 2 " + " kernelsize: 5 " + " weight_filler { " + " type: 'gaussian' " + " std: 0.01 " + " } " + " bias_filler { " + " type: 'constant' " + " value: 1. " + " } " + " blobs_lr: 1. " + " blobs_lr: 2. " + " weight_decay: 1. 
" + " weight_decay: 0. " + " } " + " bottom: 'pad2' " + " top: 'conv2' " + "} " + "layers { " + " layer { " + " name: 'relu2' " + " type: 'relu' " + " } " + " bottom: 'conv2' " + " top: 'conv2' " + "} " + "layers { " + " layer { " + " name: 'pool2' " + " type: 'pool' " + " pool: MAX " + " kernelsize: 3 " + " stride: 2 " + " } " + " bottom: 'conv2' " + " top: 'pool2' " + "} " + "layers { " + " layer { " + " name: 'norm2' " + " type: 'lrn' " + " local_size: 5 " + " alpha: 0.0001 " + " beta: 0.75 " + " } " + " bottom: 'pool2' " + " top: 'norm2' " + "} " + "layers { " + " layer { " + " name: 'pad3' " + " type: 'padding' " + " pad: 1 " + " } " + " bottom: 'norm2' " + " top: 'pad3' " + "} " + "layers { " + " layer { " + " name: 'conv3' " + " type: 'conv' " + " num_output: 384 " + " kernelsize: 3 " + " weight_filler { " + " type: 'gaussian' " + " std: 0.01 " + " } " + " bias_filler { " + " type: 'constant' " + " value: 0. " + " } " + " blobs_lr: 1. " + " blobs_lr: 2. " + " weight_decay: 1. " + " weight_decay: 0. " + " } " + " bottom: 'pad3' " + " top: 'conv3' " + "} " + "layers { " + " layer { " + " name: 'relu3' " + " type: 'relu' " + " } " + " bottom: 'conv3' " + " top: 'conv3' " + "} " + "layers { " + " layer { " + " name: 'pad4' " + " type: 'padding' " + " pad: 1 " + " } " + " bottom: 'conv3' " + " top: 'pad4' " + "} " + "layers { " + " layer { " + " name: 'conv4' " + " type: 'conv' " + " num_output: 384 " + " group: 2 " + " kernelsize: 3 " + " weight_filler { " + " type: 'gaussian' " + " std: 0.01 " + " } " + " bias_filler { " + " type: 'constant' " + " value: 1. " + " } " + " blobs_lr: 1. " + " blobs_lr: 2. " + " weight_decay: 1. " + " weight_decay: 0. " + " } " + " bottom: 'pad4' " + " top: 'conv4' " + "} " + "layers { " + " layer { " + " name: 'relu4' " + " type: 'relu' " + " } " + " bottom: 'conv4' " + " top: 'conv4' " + "} " + "layers { " + " layer { " + " name: 'pad5' " + " type: 'padding' " + " pad: 1 " + " } " + " bottom: 'conv4' " + " top: 'pad5' " + "} " + "layers { " + " layer { " + " name: 'conv5' " + " type: 'conv' " + " num_output: 256 " + " group: 2 " + " kernelsize: 3 " + " weight_filler { " + " type: 'gaussian' " + " std: 0.01 " + " } " + " bias_filler { " + " type: 'constant' " + " value: 1. " + " } " + " blobs_lr: 1. " + " blobs_lr: 2. " + " weight_decay: 1. " + " weight_decay: 0. " + " } " + " bottom: 'pad5' " + " top: 'conv5' " + "} " + "layers { " + " layer { " + " name: 'relu5' " + " type: 'relu' " + " } " + " bottom: 'conv5' " + " top: 'conv5' " + "} " + "layers { " + " layer { " + " name: 'pool5' " + " type: 'pool' " + " kernelsize: 3 " + " pool: MAX " + " stride: 2 " + " } " + " bottom: 'conv5' " + " top: 'pool5' " + "} " + "layers { " + " layer { " + " name: 'fc6' " + " type: 'innerproduct' " + " num_output: 4096 " + " weight_filler { " + " type: 'gaussian' " + " std: 0.005 " + " } " + " bias_filler { " + " type: 'constant' " + " value: 1. " + " } " + " blobs_lr: 1. " + " blobs_lr: 2. " + " weight_decay: 1. " + " weight_decay: 0. " + " } " + " bottom: 'pool5' " + " top: 'fc6' " + "} " + "layers { " + " layer { " + " name: 'relu6' " + " type: 'relu' " + " } " + " bottom: 'fc6' " + " top: 'fc6' " + "} " + "layers { " + " layer { " + " name: 'drop6' " + " type: 'dropout' " + " dropout_ratio: 0.5 " + " } " + " bottom: 'fc6' " + " top: 'fc6' " + "} " + "layers { " + " layer { " + " name: 'fc7' " + " type: 'innerproduct' " + " num_output: 4096 " + " weight_filler { " + " type: 'gaussian' " + " std: 0.005 " + " } " + " bias_filler { " + " type: 'constant' " + " value: 1. 
" + " } " + " blobs_lr: 1. " + " blobs_lr: 2. " + " weight_decay: 1. " + " weight_decay: 0. " + " } " + " bottom: 'fc6' " + " top: 'fc7' " + "} " + "layers { " + " layer { " + " name: 'relu7' " + " type: 'relu' " + " } " + " bottom: 'fc7' " + " top: 'fc7' " + "} " + "layers { " + " layer { " + " name: 'drop7' " + " type: 'dropout' " + " dropout_ratio: 0.5 " + " } " + " bottom: 'fc7' " + " top: 'fc7' " + "} " + "layers { " + " layer { " + " name: 'fc8' " + " type: 'innerproduct' " + " num_output: 1000 " + " weight_filler { " + " type: 'gaussian' " + " std: 0.01 " + " } " + " bias_filler { " + " type: 'constant' " + " value: 0 " + " } " + " blobs_lr: 1. " + " blobs_lr: 2. " + " weight_decay: 1. " + " weight_decay: 0. " + " } " + " bottom: 'fc7' " + " top: 'fc8' " + "} " + "layers { " + " layer { " + " name: 'loss' " + " type: 'softmax_loss' " + " } " + " bottom: 'fc8' " + " bottom: 'label' " + "} "; + this->RunInsertionTest(input_proto, input_proto); +} + +TYPED_TEST(SplitLayerInsertionTest, TestInsertionWithInPlace) { + const string& input_proto = + "name: 'TestNetwork' " + "layers: { " + " layer { " + " name: 'data' " + " type: 'data' " + " } " + " top: 'data' " + " top: 'label' " + "} " + "layers: { " + " layer { " + " name: 'innerprod' " + " type: 'inner_product' " + " } " + " bottom: 'data' " + " top: 'innerprod' " + "} " + "layers: { " + " layer { " + " name: 'relu' " + " type: 'relu' " + " } " + " bottom: 'innerprod' " + " top: 'innerprod' " + "} " + "layers: { " + " layer { " + " name: 'loss' " + " type: 'softmax_with_loss' " + " } " + " bottom: 'innerprod' " + " bottom: 'label' " + "} "; + this->RunInsertionTest(input_proto, input_proto); +} + +TYPED_TEST(SplitLayerInsertionTest, TestInsertion) { + const string& input_proto = + "name: 'TestNetwork' " + "layers: { " + " layer { " + " name: 'data' " + " type: 'data' " + " } " + " top: 'data' " + " top: 'label' " + "} " + "layers: { " + " layer { " + " name: 'innerprod1' " + " type: 'inner_product' " + " } " + " bottom: 'data' " + " top: 'innerprod1' " + "} " + "layers: { " + " layer { " + " name: 'innerprod2' " + " type: 'inner_product' " + " } " + " bottom: 'data' " + " top: 'innerprod2' " + "} " + "layers: { " + " layer { " + " name: 'innerprod3' " + " type: 'inner_product' " + " } " + " bottom: 'data' " + " top: 'innerprod3' " + "} " + "layers: { " + " layer { " + " name: 'loss1' " + " type: 'euclidean_loss' " + " } " + " bottom: 'innerprod1' " + " bottom: 'innerprod2' " + "} " + "layers: { " + " layer { " + " name: 'loss2' " + " type: 'euclidean_loss' " + " } " + " bottom: 'innerprod2' " + " bottom: 'innerprod3' " + "} "; + const string& expected_output_proto = + "name: 'TestNetwork' " + "layers: { " + " layer { " + " name: 'data' " + " type: 'data' " + " } " + " top: 'data' " + " top: 'label' " + "} " + "layers: { " + " layer { " + " name: 'data_data_0_split' " + " type: 'split' " + " } " + " bottom: 'data' " + " top: 'data' " + " top: 'data_data_0_split_1' " + " top: 'data_data_0_split_2' " + "} " + "layers: { " + " layer { " + " name: 'innerprod1' " + " type: 'inner_product' " + " } " + " bottom: 'data' " + " top: 'innerprod1' " + "} " + "layers: { " + " layer { " + " name: 'innerprod2' " + " type: 'inner_product' " + " } " + " bottom: 'data_data_0_split_1' " + " top: 'innerprod2' " + "} " + "layers: { " + " layer { " + " name: 'innerprod2_innerprod2_0_split' " + " type: 'split' " + " } " + " bottom: 'innerprod2' " + " top: 'innerprod2' " + " top: 'innerprod2_innerprod2_0_split_1' " + "} " + "layers: { " + " layer { " + " name: 
'innerprod3' " + " type: 'inner_product' " + " } " + " bottom: 'data_data_0_split_2' " + " top: 'innerprod3' " + "} " + "layers: { " + " layer { " + " name: 'loss1' " + " type: 'euclidean_loss' " + " } " + " bottom: 'innerprod1' " + " bottom: 'innerprod2' " + "} " + "layers: { " + " layer { " + " name: 'loss2' " + " type: 'euclidean_loss' " + " } " + " bottom: 'innerprod2_innerprod2_0_split_1' " + " bottom: 'innerprod3' " + "} "; + this->RunInsertionTest(input_proto, expected_output_proto); +} + +TYPED_TEST(SplitLayerInsertionTest, TestInsertionTwoTop) { + const string& input_proto = + "name: 'TestNetwork' " + "layers: { " + " layer { " + " name: 'data' " + " type: 'data' " + " } " + " top: 'data' " + " top: 'label' " + "} " + "layers: { " + " layer { " + " name: 'innerprod1' " + " type: 'inner_product' " + " } " + " bottom: 'data' " + " top: 'innerprod1' " + "} " + "layers: { " + " layer { " + " name: 'innerprod2' " + " type: 'inner_product' " + " } " + " bottom: 'label' " + " top: 'innerprod2' " + "} " + "layers: { " + " layer { " + " name: 'innerprod3' " + " type: 'inner_product' " + " } " + " bottom: 'data' " + " top: 'innerprod3' " + "} " + "layers: { " + " layer { " + " name: 'innerprod4' " + " type: 'inner_product' " + " } " + " bottom: 'label' " + " top: 'innerprod4' " + "} " + "layers: { " + " layer { " + " name: 'loss1' " + " type: 'euclidean_loss' " + " } " + " bottom: 'innerprod1' " + " bottom: 'innerprod3' " + "} " + "layers: { " + " layer { " + " name: 'loss2' " + " type: 'euclidean_loss' " + " } " + " bottom: 'innerprod2' " + " bottom: 'innerprod4' " + "} "; + const string& expected_output_proto = + "name: 'TestNetwork' " + "layers: { " + " layer { " + " name: 'data' " + " type: 'data' " + " } " + " top: 'data' " + " top: 'label' " + "} " + "layers: { " + " layer { " + " name: 'data_data_0_split' " + " type: 'split' " + " } " + " bottom: 'data' " + " top: 'data' " + " top: 'data_data_0_split_1' " + "} " + "layers: { " + " layer { " + " name: 'label_data_1_split' " + " type: 'split' " + " } " + " bottom: 'label' " + " top: 'label' " + " top: 'label_data_1_split_1' " + "} " + "layers: { " + " layer { " + " name: 'innerprod1' " + " type: 'inner_product' " + " } " + " bottom: 'data' " + " top: 'innerprod1' " + "} " + "layers: { " + " layer { " + " name: 'innerprod2' " + " type: 'inner_product' " + " } " + " bottom: 'label' " + " top: 'innerprod2' " + "} " + "layers: { " + " layer { " + " name: 'innerprod3' " + " type: 'inner_product' " + " } " + " bottom: 'data_data_0_split_1' " + " top: 'innerprod3' " + "} " + "layers: { " + " layer { " + " name: 'innerprod4' " + " type: 'inner_product' " + " } " + " bottom: 'label_data_1_split_1' " + " top: 'innerprod4' " + "} " + "layers: { " + " layer { " + " name: 'loss1' " + " type: 'euclidean_loss' " + " } " + " bottom: 'innerprod1' " + " bottom: 'innerprod3' " + "} " + "layers: { " + " layer { " + " name: 'loss2' " + " type: 'euclidean_loss' " + " } " + " bottom: 'innerprod2' " + " bottom: 'innerprod4' " + "} "; + this->RunInsertionTest(input_proto, expected_output_proto); +} + +TYPED_TEST(SplitLayerInsertionTest, TestInputInsertion) { + const string& input_proto = + "name: 'TestNetwork' " + "input: 'data' " + "input_dim: 10 " + "input_dim: 3 " + "input_dim: 227 " + "input_dim: 227 " + "layers: { " + " layer { " + " name: 'innerprod1' " + " type: 'inner_product' " + " } " + " bottom: 'data' " + " top: 'innerprod1' " + "} " + "layers: { " + " layer { " + " name: 'innerprod2' " + " type: 'inner_product' " + " } " + " bottom: 'data' " + " 
top: 'innerprod2' " + "} " + "layers: { " + " layer { " + " name: 'loss' " + " type: 'euclidean_loss' " + " } " + " bottom: 'innerprod1' " + " bottom: 'innerprod2' " + "} "; + const string& expected_output_proto = + "name: 'TestNetwork' " + "input: 'data' " + "input_dim: 10 " + "input_dim: 3 " + "input_dim: 227 " + "input_dim: 227 " + "layers: { " + " layer { " + " name: 'data_input_0_split' " + " type: 'split' " + " } " + " bottom: 'data' " + " top: 'data' " + " top: 'data_input_0_split_1' " + "} " + "layers: { " + " layer { " + " name: 'innerprod1' " + " type: 'inner_product' " + " } " + " bottom: 'data' " + " top: 'innerprod1' " + "} " + "layers: { " + " layer { " + " name: 'innerprod2' " + " type: 'inner_product' " + " } " + " bottom: 'data_input_0_split_1' " + " top: 'innerprod2' " + "} " + "layers: { " + " layer { " + " name: 'loss' " + " type: 'euclidean_loss' " + " } " + " bottom: 'innerprod1' " + " bottom: 'innerprod2' " + "} "; + this->RunInsertionTest(input_proto, expected_output_proto); +} + +TYPED_TEST(SplitLayerInsertionTest, TestWithInPlace) { + const string& input_proto = + "name: 'TestNetwork' " + "layers: { " + " layer { " + " name: 'data' " + " type: 'data' " + " } " + " top: 'data' " + " top: 'label' " + "} " + "layers: { " + " layer { " + " name: 'innerprod1' " + " type: 'inner_product' " + " } " + " bottom: 'data' " + " top: 'innerprod1' " + "} " + "layers: { " + " layer { " + " name: 'relu1' " + " type: 'relu' " + " } " + " bottom: 'innerprod1' " + " top: 'innerprod1' " + "} " + "layers: { " + " layer { " + " name: 'innerprod2' " + " type: 'inner_product' " + " } " + " bottom: 'innerprod1' " + " top: 'innerprod2' " + "} " + "layers: { " + " layer { " + " name: 'loss1' " + " type: 'euclidean_loss' " + " } " + " bottom: 'innerprod1' " + " bottom: 'label' " + "} " + "layers: { " + " layer { " + " name: 'loss2' " + " type: 'euclidean_loss' " + " } " + " bottom: 'innerprod2' " + " bottom: 'data' " + "} "; + const string& expected_output_proto = + "name: 'TestNetwork' " + "layers: { " + " layer { " + " name: 'data' " + " type: 'data' " + " } " + " top: 'data' " + " top: 'label' " + "} " + "layers: { " + " layer { " + " name: 'data_data_0_split' " + " type: 'split' " + " } " + " bottom: 'data' " + " top: 'data' " + " top: 'data_data_0_split_1' " + "} " + "layers: { " + " layer { " + " name: 'innerprod1' " + " type: 'inner_product' " + " } " + " bottom: 'data' " + " top: 'innerprod1' " + "} " + "layers: { " + " layer { " + " name: 'relu1' " + " type: 'relu' " + " } " + " bottom: 'innerprod1' " + " top: 'innerprod1' " + "} " + "layers: { " + " layer { " + " name: 'innerprod1_relu1_0_split' " + " type: 'split' " + " } " + " bottom: 'innerprod1' " + " top: 'innerprod1' " + " top: 'innerprod1_relu1_0_split_1' " + "} " + "layers: { " + " layer { " + " name: 'innerprod2' " + " type: 'inner_product' " + " } " + " bottom: 'innerprod1' " + " top: 'innerprod2' " + "} " + "layers: { " + " layer { " + " name: 'loss1' " + " type: 'euclidean_loss' " + " } " + " bottom: 'innerprod1_relu1_0_split_1' " + " bottom: 'label' " + "} " + "layers: { " + " layer { " + " name: 'loss2' " + " type: 'euclidean_loss' " + " } " + " bottom: 'innerprod2' " + " bottom: 'data_data_0_split_1' " + "} "; + this->RunInsertionTest(input_proto, expected_output_proto); +} + +} diff --git a/src/caffe/test/test_tanh_layer.cpp b/src/caffe/test/test_tanh_layer.cpp new file mode 100644 index 00000000000..a4226a28b22 --- /dev/null +++ b/src/caffe/test/test_tanh_layer.cpp @@ -0,0 +1,102 @@ +// Copyright 2014 Aravindh 
Mahendran +// Adapted from other test files + +#include +#include +#include + +#include "gtest/gtest.h" +#include "caffe/blob.hpp" +#include "caffe/common.hpp" +#include "caffe/filler.hpp" +#include "caffe/vision_layers.hpp" +#include "caffe/test/test_gradient_check_util.hpp" + +#include "caffe/test/test_caffe_main.hpp" + +namespace caffe { + +extern cudaDeviceProp CAFFE_TEST_CUDA_PROP; + +template +class TanHLayerTest : public ::testing::Test { + protected: + TanHLayerTest() + : blob_bottom_(new Blob(2, 10, 1, 1)), + blob_top_(new Blob()) { + // fill the values + FillerParameter filler_param; + GaussianFiller filler(filler_param); + filler.Fill(this->blob_bottom_); + blob_bottom_vec_.push_back(blob_bottom_); + blob_top_vec_.push_back(blob_top_); + }; + virtual ~TanHLayerTest() { delete blob_bottom_; delete blob_top_; } + Blob* const blob_bottom_; + Blob* const blob_top_; + vector*> blob_bottom_vec_; + vector*> blob_top_vec_; +}; + +typedef ::testing::Types Dtypes; +TYPED_TEST_CASE(TanHLayerTest, Dtypes); + +TYPED_TEST(TanHLayerTest, TestForwardCPU) { + LayerParameter layer_param; + Caffe::set_mode(Caffe::CPU); + TanHLayer layer(layer_param); + layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_)); + layer.Forward(this->blob_bottom_vec_, &(this->blob_top_vec_)); + // Test exact values + for (int i = 0; i < this->blob_bottom_->num(); ++i) { + for (int j = 0; j < this->blob_bottom_->channels(); ++j) { + for (int k = 0; k < this->blob_bottom_->height(); ++k) { + for (int l = 0; l < this->blob_bottom_->width(); ++l) { + EXPECT_GE(this->blob_top_->data_at(i,j,k,l) + 1e-4, + (exp(2*this->blob_bottom_->data_at(i,j,k,l))-1)/(exp(2*this->blob_bottom_->data_at(i,j,k,l))+1)); + EXPECT_LE(this->blob_top_->data_at(i,j,k,l) - 1e-4, + (exp(2*this->blob_bottom_->data_at(i,j,k,l))-1)/(exp(2*this->blob_bottom_->data_at(i,j,k,l))+1)); + } + } + } + } +} + +TYPED_TEST(TanHLayerTest, TestGradientCPU) { + LayerParameter layer_param; + Caffe::set_mode(Caffe::CPU); + TanHLayer layer(layer_param); + GradientChecker checker(1e-2, 1e-3); + checker.CheckGradientExhaustive(layer, this->blob_bottom_vec_, this->blob_top_vec_); +} + +TYPED_TEST(TanHLayerTest, TestForwardGPU) { + LayerParameter layer_param; + Caffe::set_mode(Caffe::GPU); + TanHLayer layer(layer_param); + layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_)); + layer.Forward(this->blob_bottom_vec_, &(this->blob_top_vec_)); + // Test exact values + for (int i = 0; i < this->blob_bottom_->num(); ++i) { + for (int j = 0; j < this->blob_bottom_->channels(); ++j) { + for (int k = 0; k < this->blob_bottom_->height(); ++k) { + for (int l = 0; l < this->blob_bottom_->width(); ++l) { + EXPECT_GE(this->blob_top_->data_at(i,j,k,l) + 1e-4, + (exp(2*this->blob_bottom_->data_at(i,j,k,l))-1)/(exp(2*this->blob_bottom_->data_at(i,j,k,l))+1)); + EXPECT_LE(this->blob_top_->data_at(i,j,k,l) - 1e-4, + (exp(2*this->blob_bottom_->data_at(i,j,k,l))-1)/(exp(2*this->blob_bottom_->data_at(i,j,k,l))+1)); + } + } + } + } +} + +TYPED_TEST(TanHLayerTest, TestGradientGPU) { + LayerParameter layer_param; + Caffe::set_mode(Caffe::GPU); + TanHLayer layer(layer_param); + GradientChecker checker(1e-2, 1e-3); + checker.CheckGradientExhaustive(layer, this->blob_bottom_vec_, this->blob_top_vec_); +} + +} diff --git a/src/caffe/util/im2col.cpp b/src/caffe/util/im2col.cpp index db79bb2cdc9..b32f6eeecc8 100644 --- a/src/caffe/util/im2col.cpp +++ b/src/caffe/util/im2col.cpp @@ -10,10 +10,10 @@ namespace caffe { template void im2col_cpu(const Dtype* data_im, const int channels, - const 
int height, const int width, const int ksize, const int stride, + const int height, const int width, const int ksize, const int pad, const int stride, Dtype* data_col) { - int height_col = (height - ksize) / stride + 1; - int width_col = (width - ksize) / stride + 1; + int height_col = (height + 2 * pad - ksize) / stride + 1; + int width_col = (width + 2 * pad - ksize) / stride + 1; int channels_col = channels * ksize * ksize; for (int c = 0; c < channels_col; ++c) { int w_offset = c % ksize; @@ -21,9 +21,13 @@ void im2col_cpu(const Dtype* data_im, const int channels, int c_im = c / ksize / ksize; for (int h = 0; h < height_col; ++h) { for (int w = 0; w < width_col; ++w) { - data_col[(c * height_col + h) * width_col + w] = - data_im[(c_im * height + h * stride + h_offset) * width - + w * stride + w_offset]; + int h_pad = h * stride - pad + h_offset; + int w_pad = w * stride - pad + w_offset; + if (h_pad >= 0 && h_pad < height && w_pad >= 0 && w_pad < width) + data_col[(c * height_col + h) * width_col + w] = + data_im[(c_im * height + h_pad) * width + w_pad]; + else + data_col[(c * height_col + h) * width_col + w] = 0; } } } @@ -31,19 +35,19 @@ void im2col_cpu(const Dtype* data_im, const int channels, // Explicit instantiation template void im2col_cpu(const float* data_im, const int channels, - const int height, const int width, const int ksize, const int stride, + const int height, const int width, const int ksize, const int pad, const int stride, float* data_col); template void im2col_cpu(const double* data_im, const int channels, - const int height, const int width, const int ksize, const int stride, + const int height, const int width, const int ksize, const int pad, const int stride, double* data_col); template void col2im_cpu(const Dtype* data_col, const int channels, - const int height, const int width, const int ksize, const int stride, + const int height, const int width, const int ksize, const int pad, const int stride, Dtype* data_im) { memset(data_im, 0, sizeof(Dtype) * height * width * channels); - int height_col = (height - ksize) / stride + 1; - int width_col = (width - ksize) / stride + 1; + int height_col = (height + 2 * pad - ksize) / stride + 1; + int width_col = (width + 2 * pad - ksize) / stride + 1; int channels_col = channels * ksize * ksize; for (int c = 0; c < channels_col; ++c) { int w_offset = c % ksize; @@ -51,8 +55,10 @@ void col2im_cpu(const Dtype* data_col, const int channels, int c_im = c / ksize / ksize; for (int h = 0; h < height_col; ++h) { for (int w = 0; w < width_col; ++w) { - data_im[(c_im * height + h * stride + h_offset) * width + w * stride - + w_offset] += data_col[(c * height_col + h) * width_col + w]; + int h_pad = h * stride - pad + h_offset; + int w_pad = w * stride - pad + w_offset; + if (h_pad >= 0 && h_pad < height && w_pad >= 0 && w_pad < width) + data_im[(c_im * height + h_pad) * width + w_pad] += data_col[(c * height_col + h) * width_col + w]; } } } @@ -60,10 +66,10 @@ void col2im_cpu(const Dtype* data_col, const int channels, // Explicit instantiation template void col2im_cpu(const float* data_col, const int channels, - const int height, const int width, const int psize, const int stride, + const int height, const int width, const int psize, const int pad, const int stride, float* data_im); template void col2im_cpu(const double* data_col, const int channels, - const int height, const int width, const int psize, const int stride, + const int height, const int width, const int psize, const int pad, const int stride, double* data_im); } // 
namespace caffe diff --git a/src/caffe/util/im2col.cu b/src/caffe/util/im2col.cu index 0b0c8b8354f..7f1376d6b16 100644 --- a/src/caffe/util/im2col.cu +++ b/src/caffe/util/im2col.cu @@ -9,10 +9,9 @@ namespace caffe { - template __global__ void im2col_gpu_kernel(const int n, const Dtype* data_im, - const int height, const int width, const int ksize, + const int height, const int width, const int ksize, const int pad, const int stride, const int height_col, const int width_col, Dtype* data_col) { int index = threadIdx.x + blockIdx.x * blockDim.x; if (index < n) { @@ -21,14 +20,16 @@ __global__ void im2col_gpu_kernel(const int n, const Dtype* data_im, int h_out = index % height_col; int channel_in = index / height_col; int channel_out = channel_in * ksize * ksize; - int h_in = h_out * stride; - int w_in = w_out * stride; + int h_in = h_out * stride - pad; + int w_in = w_out * stride - pad; data_col += (channel_out * height_col + h_out) * width_col + w_out; data_im += (channel_in * height + h_in) * width + w_in; for (int i = 0; i < ksize; ++i) { for (int j = 0; j < ksize; ++j) { - *data_col = data_im[i * width + j]; - data_col += height_col * width_col; + int h = h_in + i; + int w = w_in + j; + *data_col = (h >= 0 && w >= 0 && h < width && w < height) ? data_im[i * width + j] : 0; + data_col += height_col * width_col; } } } @@ -36,15 +37,15 @@ __global__ void im2col_gpu_kernel(const int n, const Dtype* data_im, template void im2col_gpu(const Dtype* data_im, const int channels, - const int height, const int width, const int ksize, const int stride, + const int height, const int width, const int ksize, const int pad, const int stride, Dtype* data_col) { // We are going to launch channels * height_col * width_col kernels, each // kernel responsible for copying a single-channel grid. - int height_col = (height - ksize) / stride + 1; - int width_col = (width - ksize) / stride + 1; + int height_col = (height + 2 * pad - ksize) / stride + 1; + int width_col = (width + 2 * pad - ksize) / stride + 1; int num_kernels = channels * height_col * width_col; im2col_gpu_kernel<<>>( - num_kernels, data_im, height, width, ksize, stride, height_col, width_col, + num_kernels, data_im, height, width, ksize, pad, stride, height_col, width_col, data_col); CUDA_POST_KERNEL_CHECK; } @@ -52,21 +53,21 @@ void im2col_gpu(const Dtype* data_im, const int channels, // Explicit instantiation template void im2col_gpu(const float* data_im, const int channels, - const int height, const int width, const int ksize, const int stride, + const int height, const int width, const int ksize, const int pad, const int stride, float* data_col); template void im2col_gpu(const double* data_im, const int channels, - const int height, const int width, const int ksize, const int stride, + const int height, const int width, const int ksize, const int pad, const int stride, double* data_col); template __global__ void col2im_gpu_kernel(const int n, const Dtype* data_col, - const int height, const int width, const int channels, const int ksize, + const int height, const int width, const int channels, const int ksize, const int pad, const int stride, const int height_col, const int width_col, Dtype* data_im) { int index = threadIdx.x + blockIdx.x * blockDim.x; if (index < n) { Dtype val = 0; - int w = index % width; - int h = (index / width) % height; + int w = index % width + pad; + int h = (index / width) % height + pad; int c = index / (width * height); // compute the start and end of the output int w_col_start = (w < ksize) ? 
0 : (w - ksize) / stride + 1; @@ -97,16 +98,16 @@ __global__ void col2im_gpu_kernel(const int n, const Dtype* data_col, template void col2im_gpu(const Dtype* data_col, const int channels, - const int height, const int width, const int ksize, const int stride, + const int height, const int width, const int ksize, const int pad, const int stride, Dtype* data_im) { //CUDA_CHECK(cudaMemset(data_im, 0, sizeof(Dtype) * height * width * channels)); - int height_col = (height - ksize) / stride + 1; - int width_col = (width - ksize) / stride + 1; + int height_col = (height + 2 * pad - ksize) / stride + 1; + int width_col = (width + 2 * pad - ksize) / stride + 1; int num_kernels = channels * height * width; // To avoid involving atomic operations, we will launch one kernel per // bottom dimension, and then in the kernel add up the top dimensions. col2im_gpu_kernel<<>>( - num_kernels, data_col, height, width, channels, ksize, stride, + num_kernels, data_col, height, width, channels, ksize, pad, stride, height_col, width_col, data_im); CUDA_POST_KERNEL_CHECK; } @@ -114,10 +115,10 @@ void col2im_gpu(const Dtype* data_col, const int channels, // Explicit instantiation template void col2im_gpu(const float* data_col, const int channels, - const int height, const int width, const int psize, const int stride, + const int height, const int width, const int psize, const int pad, const int stride, float* data_im); template void col2im_gpu(const double* data_col, const int channels, - const int height, const int width, const int psize, const int stride, + const int height, const int width, const int psize, const int pad, const int stride, double* data_im); diff --git a/src/caffe/util/insert_splits.cpp b/src/caffe/util/insert_splits.cpp new file mode 100644 index 00000000000..6db6458c4af --- /dev/null +++ b/src/caffe/util/insert_splits.cpp @@ -0,0 +1,129 @@ +// Copyright 2014 Jeff Donahue + +#include +#include +#include + +#include "caffe/common.hpp" +#include "caffe/util/insert_splits.hpp" + +using std::map; +using std::ostringstream; +using std::pair; +using std::make_pair; + +namespace caffe { + +void insert_splits(const NetParameter& param, NetParameter* param_split) { + // Initialize by copying from the input NetParameter. + param_split->CopyFrom(param); + param_split->clear_layers(); + map > blob_name_to_last_top_idx; + map, pair > bottom_idx_to_source_top_idx; + map, int> top_idx_to_bottom_count; + map, int> top_idx_to_bottom_split_idx; + map layer_idx_to_layer_name; + layer_idx_to_layer_name[-1] = "input"; + // Determine the number of times each blob is used as an input (bottom) blob. 
+ for (int i = 0; i < param.input_size(); ++i) { + const string& blob_name = param.input(i); + blob_name_to_last_top_idx[blob_name] = make_pair(-1, i); + } + for (int i = 0; i < param.layers_size(); ++i) { + const LayerConnection& layer_connection = param.layers(i); + layer_idx_to_layer_name[i] = layer_connection.layer().name(); + for (int j = 0; j < layer_connection.bottom_size(); ++j) { + const string& blob_name = layer_connection.bottom(j); + if (blob_name_to_last_top_idx.find(blob_name) == + blob_name_to_last_top_idx.end()) { + LOG(FATAL) << "Unknown blob input " << blob_name << " to layer " << j; + } + const pair& bottom_idx = make_pair(i, j); + const pair& top_idx = blob_name_to_last_top_idx[blob_name]; + bottom_idx_to_source_top_idx[bottom_idx] = top_idx; + ++top_idx_to_bottom_count[top_idx]; + } + for (int j = 0; j < layer_connection.top_size(); ++j) { + const string& blob_name = layer_connection.top(j); + blob_name_to_last_top_idx[blob_name] = make_pair(i, j); + } + } + // Create split layer for any input blobs used by other layers as bottom + // blobs more than once. + for (int i = 0; i < param.input_size(); ++i) { + const int split_count = top_idx_to_bottom_count[make_pair(-1, i)]; + if (split_count > 1) { + const string& layer_name = layer_idx_to_layer_name[-1]; + const string& blob_name = param.input(i); + LayerConnection* split_layer_connection = param_split->add_layers(); + configure_split_layer(layer_name, blob_name, i, split_count, + split_layer_connection); + } + } + for (int i = 0; i < param.layers_size(); ++i) { + LayerConnection* layer_connection = param_split->add_layers(); + layer_connection->CopyFrom(param.layers(i)); + // Replace any shared bottom blobs with split layer outputs. + for (int j = 0; j < layer_connection->bottom_size(); ++j) { + const pair& top_idx = + bottom_idx_to_source_top_idx[make_pair(i, j)]; + const int split_count = top_idx_to_bottom_count[top_idx]; + if (split_count > 1) { + const string& layer_name = layer_idx_to_layer_name[top_idx.first]; + const string& blob_name = layer_connection->bottom(j); + layer_connection->set_bottom(j, get_split_blob_name(layer_name, + blob_name, top_idx.second, top_idx_to_bottom_split_idx[top_idx]++)); + } + } + // Create split layer for any top blobs used by other layers as bottom + // blobs more than once. 
+ for (int j = 0; j < layer_connection->top_size(); ++j) { + const int split_count = top_idx_to_bottom_count[make_pair(i, j)]; + if (split_count > 1) { + const string& layer_name = layer_idx_to_layer_name[i]; + const string& blob_name = layer_connection->top(j); + LayerConnection* split_layer_connection = param_split->add_layers(); + configure_split_layer(layer_name, blob_name, j, split_count, + split_layer_connection); + } + } + } +} + +void configure_split_layer(const string& layer_name, const string& blob_name, + const int blob_idx, const int split_count, + LayerConnection* split_layer_connection) { + split_layer_connection->Clear(); + split_layer_connection->add_bottom(blob_name); + LayerParameter* split_layer_param = split_layer_connection->mutable_layer(); + split_layer_param->set_name( + get_split_layer_name(layer_name, blob_name, blob_idx)); + split_layer_param->set_type("split"); + for (int k = 0; k < split_count; ++k) { + split_layer_connection->add_top( + get_split_blob_name(layer_name, blob_name, blob_idx, k)); + } +} + +string get_split_layer_name(const string& layer_name, const string& blob_name, + const int blob_idx) { + ostringstream split_layer_name; + split_layer_name << blob_name << "_" << layer_name << "_" << blob_idx + << "_split"; + return split_layer_name.str(); +} + +string get_split_blob_name(const string& layer_name, const string& blob_name, + const int blob_idx, const int split_idx) { + // 0th split top blob is given the same name as the bottom blob so that + // computation is done 'in-place', saving a bit of time and memory. + if (split_idx == 0) { + return blob_name; + } + ostringstream split_blob_name; + split_blob_name << blob_name << "_" << layer_name << "_" << blob_idx + << "_split_" << split_idx; + return split_blob_name.str(); +} + +} // namespace caffe
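
The core of insert_splits() above is a counting pass: tally how many layers consume each blob as a bottom, and give every blob with more than one consumer a split layer with one top per consumer. Below is a minimal, self-contained model of that step. It is a sketch only: the `Layer` struct stands in for `LayerConnection`, and it counts by blob name rather than by (producer, top index) as the real code does, which is equivalent here because no name is reused in place. The layer and blob names follow the TestInsertion fixture.

```cpp
#include <iostream>
#include <map>
#include <string>
#include <vector>

// Stand-in for LayerConnection: a named layer with bottom and top blob names.
struct Layer {
  std::string name;
  std::vector<std::string> bottoms;
  std::vector<std::string> tops;
};

int main() {
  // The TestInsertion network: three inner products reading 'data',
  // two losses sharing 'innerprod2'.
  std::vector<Layer> net = {
      {"data", {}, {"data", "label"}},
      {"innerprod1", {"data"}, {"innerprod1"}},
      {"innerprod2", {"data"}, {"innerprod2"}},
      {"innerprod3", {"data"}, {"innerprod3"}},
      {"loss1", {"innerprod1", "innerprod2"}, {}},
      {"loss2", {"innerprod2", "innerprod3"}, {}},
  };
  // Count how many layers consume each blob as a bottom.
  std::map<std::string, int> bottom_count;
  for (const Layer& l : net)
    for (const std::string& b : l.bottoms) ++bottom_count[b];
  // Any blob with more than one consumer needs a split layer.
  for (const auto& kv : bottom_count)
    if (kv.second > 1)
      std::cout << "blob '" << kv.first << "' has " << kv.second
                << " consumers -> needs a split layer\n";
  return 0;
}
```

This reports 'data' (three consumers) and 'innerprod2' (two), matching the two split layers in TestInsertion's expected_output_proto. In-place layers such as the relu in TestInsertionWithInPlace do not trigger splits, because the bookkeeping tracks the last producer of each blob name, so the pre-relu and post-relu versions of 'innerprod' each have a single consumer; when the post-relu blob does have two consumers, as in TestWithInPlace, the split is attached after the relu, hence the 'innerprod1_relu1_0_split' name.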
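
The split layer and blob names asserted in the tests ('data_data_0_split', 'innerprod2_innerprod2_0_split_1', and so on) come from get_split_layer_name and get_split_blob_name. A small sketch of the same scheme; `split_blob_name` is a stand-in for the patch's helper, not the helper itself:

```cpp
#include <iostream>
#include <sstream>
#include <string>

// Mirrors the naming scheme in insert_splits.cpp: the split layer is named
// <blob>_<layer>_<blobidx>_split, its 0th top reuses the original blob name
// (in-place), and later tops get a _split_<k> suffix.
std::string split_blob_name(const std::string& layer_name,
                            const std::string& blob_name, int blob_idx,
                            int split_idx) {
  if (split_idx == 0) return blob_name;  // 0th split top is in-place
  std::ostringstream name;
  name << blob_name << "_" << layer_name << "_" << blob_idx << "_split_"
       << split_idx;
  return name.str();
}

int main() {
  // Reproduces the tops expected by TestInsertion for blob 'data', produced
  // by layer 'data' as its 0th top and consumed by three layers.
  for (int k = 0; k < 3; ++k)
    std::cout << split_blob_name("data", "data", 0, k) << "\n";
  // Prints: data, data_data_0_split_1, data_data_0_split_2
  return 0;
}
```

Keeping the 0th top in place, as the source comment notes, lets the first consumer keep reading the original blob and saves one copy per split.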
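
The padded im2col_cpu/col2im_cpu pair is the heart of the change: the column buffer now spans (height + 2*pad - ksize)/stride + 1 positions per side, samples whose padded coordinate falls outside the real image are written as zero on the way out, and their contributions are dropped on the way back. Below is a self-contained sketch of both directions with plain floats; `im2col_simple` and `col2im_simple` are names of convenience, not the patch's templated functions.

```cpp
#include <cstdio>
#include <vector>

// Forward: lay out one ksize x ksize patch per channel and output position,
// writing zero wherever the padded coordinate falls outside the image.
void im2col_simple(const std::vector<float>& im, int channels, int height,
                   int width, int ksize, int pad, int stride,
                   std::vector<float>& col) {
  int height_col = (height + 2 * pad - ksize) / stride + 1;
  int width_col = (width + 2 * pad - ksize) / stride + 1;
  int channels_col = channels * ksize * ksize;
  col.assign(channels_col * height_col * width_col, 0.f);
  for (int c = 0; c < channels_col; ++c) {
    int w_offset = c % ksize;
    int h_offset = (c / ksize) % ksize;
    int c_im = c / ksize / ksize;
    for (int h = 0; h < height_col; ++h) {
      for (int w = 0; w < width_col; ++w) {
        int h_pad = h * stride - pad + h_offset;
        int w_pad = w * stride - pad + w_offset;
        if (h_pad >= 0 && h_pad < height && w_pad >= 0 && w_pad < width)
          col[(c * height_col + h) * width_col + w] =
              im[(c_im * height + h_pad) * width + w_pad];
        // else: keep the zero from assign() -- this is the padding.
      }
    }
  }
}

// Backward: scatter-add each column entry onto its source pixel, dropping
// entries whose source lies in the padding (nothing there to receive them).
void col2im_simple(const std::vector<float>& col, int channels, int height,
                   int width, int ksize, int pad, int stride,
                   std::vector<float>& im) {
  int height_col = (height + 2 * pad - ksize) / stride + 1;
  int width_col = (width + 2 * pad - ksize) / stride + 1;
  int channels_col = channels * ksize * ksize;
  im.assign(channels * height * width, 0.f);
  for (int c = 0; c < channels_col; ++c) {
    int w_offset = c % ksize;
    int h_offset = (c / ksize) % ksize;
    int c_im = c / ksize / ksize;
    for (int h = 0; h < height_col; ++h) {
      for (int w = 0; w < width_col; ++w) {
        int h_pad = h * stride - pad + h_offset;
        int w_pad = w * stride - pad + w_offset;
        if (h_pad >= 0 && h_pad < height && w_pad >= 0 && w_pad < width)
          im[(c_im * height + h_pad) * width + w_pad] +=
              col[(c * height_col + h) * width_col + w];
      }
    }
  }
}

int main() {
  // 1 channel, 4x4 image of ones, 3x3 kernel, pad 1, stride 1.
  const int C = 1, H = 4, W = 4, K = 3, P = 1, S = 1;
  std::vector<float> im(C * H * W, 1.f), col, back;
  im2col_simple(im, C, H, W, K, P, S, col);
  col2im_simple(col, C, H, W, K, P, S, back);
  // The round trip leaves, at each pixel, the number of sliding windows that
  // cover it: 4 at a corner, 9 in the interior, so the two indexings agree.
  std::printf("corner=%g interior=%g\n", back[0], back[1 * W + 1]);
  return 0;
}
```

Note that the bounds test compares the padded row against height and the padded column against width; mixing the two up silently zeroes valid samples on non-square inputs, so it is worth double-checking in the per-thread versions of this check in the GPU kernels.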
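
The TestForward cases in test_tanh_layer.cpp bound the layer output against (exp(2x) - 1) / (exp(2x) + 1) to within 1e-4 on both sides, which is just tanh(x) written out. A quick standalone check of that identity over a range of inputs, using the same tolerance:

```cpp
#include <cmath>
#include <cstdio>
#include <cstdlib>

// Verifies that (exp(2x) - 1) / (exp(2x) + 1) matches std::tanh(x) to 1e-4,
// the tolerance used by the TanH forward tests.
int main() {
  for (float x = -5.f; x <= 5.f; x += 0.25f) {
    float via_exp = (std::exp(2 * x) - 1) / (std::exp(2 * x) + 1);
    float direct = std::tanh(x);
    if (std::fabs(via_exp - direct) > 1e-4f) {
      std::printf("mismatch at x=%g: %g vs %g\n", x, via_exp, direct);
      return EXIT_FAILURE;
    }
  }
  std::printf("identity holds to 1e-4 on [-5, 5]\n");
  return EXIT_SUCCESS;
}
```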
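
The TestGradient cases delegate to GradientChecker(1e-2, 1e-3), which perturbs each input and compares a finite difference against the layer's backward pass. A stripped-down sketch of that idea for tanh alone, using the standard derivative 1 - tanh(x)^2 in place of the layer's backward implementation; the step and tolerance mirror the checker's constructor arguments, but the comparison logic here is deliberately simplified:

```cpp
#include <cmath>
#include <cstdio>

// Compare the analytic tanh derivative against a centered finite difference,
// in the spirit of GradientChecker's exhaustive check.
int main() {
  const double step = 1e-2, tol = 1e-3;
  for (double x = -3.0; x <= 3.0; x += 0.5) {
    double analytic = 1.0 - std::tanh(x) * std::tanh(x);
    double numeric = (std::tanh(x + step) - std::tanh(x - step)) / (2 * step);
    if (std::fabs(analytic - numeric) > tol) {
      std::printf("gradient mismatch at x=%g\n", x);
      return 1;
    }
  }
  std::printf("analytic and numeric tanh gradients agree within %g\n", tol);
  return 0;
}
```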