Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions include/caffe/vision_layers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,27 @@ class DropoutLayer : public NeuronLayer<Dtype> {
};


// FlattenLayer reshapes a (num, channels, height, width) bottom blob into a
// (num, channels * height * width, 1, 1) top blob. The element values and
// their order in memory are unchanged; only the blob's shape differs.
template <typename Dtype>
class FlattenLayer : public Layer<Dtype> {
 public:
  // channels_out_ is zero-initialized so that it never holds an
  // indeterminate value before SetUp() has computed the real size.
  explicit FlattenLayer(const LayerParameter& param)
      : Layer<Dtype>(param), channels_out_(0) {}
  // Checks that there is exactly one bottom and one top blob and shapes the
  // top blob to (num, channels * height * width, 1, 1).
  virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
      vector<Blob<Dtype>*>* top);

 protected:
  // Forward pass: copies the bottom data into the top blob.
  virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
      vector<Blob<Dtype>*>* top);
  virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
      vector<Blob<Dtype>*>* top);
  // Backward pass: copies the top diff into the bottom diff.
  virtual Dtype Backward_cpu(const vector<Blob<Dtype>*>& top,
      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
  virtual Dtype Backward_gpu(const vector<Blob<Dtype>*>& top,
      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
  // Number of elements per sample (channels * height * width of the bottom
  // blob); set by SetUp().
  int channels_out_;
};


template <typename Dtype>
class InnerProductLayer : public Layer<Dtype> {
public:
Expand Down
2 changes: 2 additions & 0 deletions src/caffe/layer_factory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ Layer<Dtype>* GetLayer(const LayerParameter& param) {
return new DropoutLayer<Dtype>(param);
} else if (type == "euclidean_loss") {
return new EuclideanLossLayer<Dtype>(param);
} else if (type == "flatten") {
return new FlattenLayer<Dtype>(param);
} else if (type == "im2col") {
return new Im2colLayer<Dtype>(param);
} else if (type == "infogain_loss") {
Expand Down
57 changes: 57 additions & 0 deletions src/caffe/layers/flatten_layer.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
// Copyright 2013 Yangqing Jia

#include <vector>

#include "caffe/layer.hpp"
#include "caffe/vision_layers.hpp"
#include "caffe/util/math_functions.hpp"

namespace caffe {

// Validates the blob counts and shapes the output.
//
// bottom: must contain exactly one blob of shape (num, channels, height,
//         width).
// top:    must contain exactly one blob; it is reshaped to
//         (num, channels * height * width, 1, 1).
template <typename Dtype>
void FlattenLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
    vector<Blob<Dtype>*>* top) {
  CHECK_EQ(bottom.size(), 1) << "Flatten Layer takes a single blob as input.";
  CHECK_EQ(top->size(), 1) << "Flatten Layer takes a single blob as output.";
  // Per-sample element count: every spatial position of every channel
  // becomes its own output channel.
  channels_out_ = bottom[0]->channels() * bottom[0]->height()
      * bottom[0]->width();
  (*top)[0]->Reshape(bottom[0]->num(), channels_out_, 1, 1);
}

// CPU forward pass: copies the bottom data into the top blob unchanged.
template <typename Dtype>
void FlattenLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
    vector<Blob<Dtype>*>* top) {
  const Dtype* bottom_data = bottom[0]->cpu_data();
  Dtype* top_data = (*top)[0]->mutable_cpu_data();
  // Copy every element of the blob (num * channels * height * width).
  // channels_out_ is only the per-sample size, so using it as the count
  // would leave all samples after the first uncopied.
  caffe_copy(bottom[0]->count(), bottom_data, top_data);
}

// GPU forward pass: copies the bottom data into the top blob unchanged.
template <typename Dtype>
void FlattenLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
    vector<Blob<Dtype>*>* top) {
  const Dtype* bottom_data = bottom[0]->gpu_data();
  Dtype* top_data = (*top)[0]->mutable_gpu_data();
  // Copy every element of the blob (num * channels * height * width).
  // channels_out_ is only the per-sample size, so using it as the count
  // would leave all samples after the first uncopied.
  caffe_gpu_copy(bottom[0]->count(), bottom_data, top_data);
}

// CPU backward pass: copies the top diff into the bottom diff unchanged.
//
// Returns zero. NOTE(review): the return value is presumably this layer's
// loss contribution, which a pure reshape does not have -- confirm against
// the Layer base class.
template <typename Dtype>
Dtype FlattenLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
    const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
  const Dtype* top_diff = top[0]->cpu_diff();
  Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff();
  // Copy the diff of every sample, not just the first channels_out_
  // elements (which cover a single sample only).
  caffe_copy((*bottom)[0]->count(), top_diff, bottom_diff);
  // A non-void function must return a value; flowing off the end is
  // undefined behavior.
  return Dtype(0.);
}


// GPU backward pass: copies the top diff into the bottom diff unchanged.
//
// Returns zero. NOTE(review): the return value is presumably this layer's
// loss contribution, which a pure reshape does not have -- confirm against
// the Layer base class.
template <typename Dtype>
Dtype FlattenLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
    const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
  const Dtype* top_diff = top[0]->gpu_diff();
  Dtype* bottom_diff = (*bottom)[0]->mutable_gpu_diff();
  // Copy the diff of every sample, not just the first channels_out_
  // elements (which cover a single sample only).
  caffe_gpu_copy((*bottom)[0]->count(), top_diff, bottom_diff);
  // A non-void function must return a value; flowing off the end is
  // undefined behavior.
  return Dtype(0.);
}

// NOTE(review): INSTANTIATE_CLASS is defined outside this file; presumably it
// emits the explicit template instantiations of FlattenLayer -- confirm.
INSTANTIATE_CLASS(FlattenLayer);

} // namespace caffe
93 changes: 93 additions & 0 deletions src/caffe/test/test_flatten_layer.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
// Copyright 2013 Yangqing Jia

#include <cstring>
#include <cuda_runtime.h>

#include "gtest/gtest.h"
#include "caffe/blob.hpp"
#include "caffe/common.hpp"
#include "caffe/filler.hpp"
#include "caffe/vision_layers.hpp"
#include "caffe/test/test_gradient_check_util.hpp"

#include "caffe/test/test_caffe_main.hpp"

namespace caffe {

// CUDA device properties object defined in another translation unit.
// NOTE(review): not referenced by the tests visible in this file.
extern cudaDeviceProp CAFFE_TEST_CUDA_PROP;

// Test fixture for FlattenLayer. Owns one 2 x 3 x 6 x 5 bottom blob filled by
// a GaussianFiller and one initially empty top blob, and exposes both through
// the single-element vectors that the layer interface takes.
template <typename Dtype>
class FlattenLayerTest : public ::testing::Test {
 protected:
  FlattenLayerTest()
      : blob_bottom_(new Blob<Dtype>(2, 3, 6, 5)),
        blob_top_(new Blob<Dtype>()) {
    // Populate the input blob using a default-constructed filler parameter.
    FillerParameter gaussian_param;
    GaussianFiller<Dtype> gaussian(gaussian_param);
    gaussian.Fill(this->blob_bottom_);
    blob_bottom_vec_.push_back(blob_bottom_);
    blob_top_vec_.push_back(blob_top_);
  }

  // The fixture owns both blobs; the vectors only hold the same pointers.
  virtual ~FlattenLayerTest() {
    delete blob_bottom_;
    delete blob_top_;
  }

  Blob<Dtype>* const blob_bottom_;
  Blob<Dtype>* const blob_top_;
  vector<Blob<Dtype>*> blob_bottom_vec_;
  vector<Blob<Dtype>*> blob_top_vec_;
};

// Element types the typed test case is registered with: each
// TYPED_TEST(FlattenLayerTest, ...) is run for both float and double.
typedef ::testing::Types<float, double> Dtypes;
TYPED_TEST_CASE(FlattenLayerTest, Dtypes);

// SetUp must shape the top blob as (num, channels * height * width, 1, 1)
// for the fixture's 2 x 3 x 6 x 5 bottom blob.
TYPED_TEST(FlattenLayerTest, TestSetup) {
  LayerParameter param;
  FlattenLayer<TypeParam> flatten(param);
  flatten.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));

  Blob<TypeParam>* const flattened = this->blob_top_;
  EXPECT_EQ(flattened->num(), 2);
  EXPECT_EQ(flattened->channels(), 3 * 6 * 5);
  EXPECT_EQ(flattened->height(), 1);
  EXPECT_EQ(flattened->width(), 1);
}

// CPU forward pass: every element of every sample must appear in the top
// blob at the flattened channel index c = (ch * 6 + h) * 5 + w.
TYPED_TEST(FlattenLayerTest, TestCPU) {
  LayerParameter layer_param;
  FlattenLayer<TypeParam> layer(layer_param);
  Caffe::set_mode(Caffe::CPU);
  layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
  layer.Forward(this->blob_bottom_vec_, &(this->blob_top_vec_));
  // Check all samples in the batch, not only the first one, so that a
  // forward pass that copies just part of the blob is detected.
  for (int n = 0; n < this->blob_bottom_->num(); ++n) {
    for (int c = 0; c < 3 * 6 * 5; ++c) {
      EXPECT_EQ(this->blob_top_->data_at(n, c, 0, 0),
          this->blob_bottom_->data_at(n, c / (6 * 5), (c / 5) % 6, c % 5));
    }
  }
}

// GPU forward pass: every element of every sample must appear in the top
// blob at the flattened channel index c = (ch * 6 + h) * 5 + w.
TYPED_TEST(FlattenLayerTest, TestGPU) {
  LayerParameter layer_param;
  FlattenLayer<TypeParam> layer(layer_param);
  Caffe::set_mode(Caffe::GPU);
  layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
  layer.Forward(this->blob_bottom_vec_, &(this->blob_top_vec_));
  // Check all samples in the batch, not only the first one, so that a
  // forward pass that copies just part of the blob is detected.
  for (int n = 0; n < this->blob_bottom_->num(); ++n) {
    for (int c = 0; c < 3 * 6 * 5; ++c) {
      EXPECT_EQ(this->blob_top_->data_at(n, c, 0, 0),
          this->blob_bottom_->data_at(n, c / (6 * 5), (c / 5) % 6, c % 5));
    }
  }
}

// Runs the exhaustive gradient check on the layer in CPU mode, with a
// GradientChecker constructed from (1e-2, 1e-2).
TYPED_TEST(FlattenLayerTest, TestCPUGradient) {
  Caffe::set_mode(Caffe::CPU);
  LayerParameter param;
  FlattenLayer<TypeParam> flatten(param);
  GradientChecker<TypeParam> gradient_checker(1e-2, 1e-2);
  gradient_checker.CheckGradientExhaustive(
      flatten, this->blob_bottom_vec_, this->blob_top_vec_);
}

// Runs the exhaustive gradient check on the layer in GPU mode, with a
// GradientChecker constructed from (1e-2, 1e-2).
TYPED_TEST(FlattenLayerTest, TestGPUGradient) {
  Caffe::set_mode(Caffe::GPU);
  LayerParameter param;
  FlattenLayer<TypeParam> flatten(param);
  GradientChecker<TypeParam> gradient_checker(1e-2, 1e-2);
  gradient_checker.CheckGradientExhaustive(
      flatten, this->blob_bottom_vec_, this->blob_top_vec_);
}


}