14 changes: 8 additions & 6 deletions include/caffe/data_layers.hpp
@@ -12,6 +12,7 @@

#include "caffe/blob.hpp"
#include "caffe/common.hpp"
#include "caffe/data_transformer.hpp"
#include "caffe/filler.hpp"
#include "caffe/internal_thread.hpp"
#include "caffe/layer.hpp"
@@ -24,12 +25,12 @@ namespace caffe {

// TODO: DataLayer, ImageDataLayer, and WindowDataLayer all have the
// same basic structure and a lot of duplicated code.

template <typename Dtype>
class DataLayer : public Layer<Dtype>, public InternalThread {
public:
explicit DataLayer(const LayerParameter& param)
: Layer<Dtype>(param) {}
: Layer<Dtype>(param),
data_transformer_(param.data_param().transform_param()) {}
virtual ~DataLayer();
virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top);
@@ -53,11 +54,10 @@ class DataLayer : public Layer<Dtype>, public InternalThread {

virtual void CreatePrefetchThread();
virtual void JoinPrefetchThread();
virtual unsigned int PrefetchRand();
// The thread's function
virtual void InternalThreadEntry();

shared_ptr<Caffe::RNG> prefetch_rng_;
DataTransformer<Dtype> data_transformer_;

// LEVELDB
shared_ptr<leveldb::DB> db_;
@@ -178,7 +178,8 @@ template <typename Dtype>
class ImageDataLayer : public Layer<Dtype>, public InternalThread {
public:
explicit ImageDataLayer(const LayerParameter& param)
: Layer<Dtype>(param) {}
: Layer<Dtype>(param),
data_transformer_(param.data_param().transform_param()) {}
virtual ~ImageDataLayer();
virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top);
@@ -203,10 +204,11 @@ class ImageDataLayer : public Layer<Dtype>, public InternalThread {

virtual void CreatePrefetchThread();
virtual void JoinPrefetchThread();
virtual unsigned int PrefetchRand();
virtual void InternalThreadEntry();

DataTransformer<Dtype> data_transformer_;
shared_ptr<Caffe::RNG> prefetch_rng_;

vector<std::pair<std::string, int> > lines_;
int lines_id_;
int datum_channels_;
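Both layers above now own a DataTransformer built from the transform_param message nested inside data_param. The diff only shows the accessors used (crop_size, mirror, scale, mean_file), so the sketch below of how those fields get populated in C++ is inferred from those accessors and is not part of the PR; the field set and the example values are assumptions.

// --- sketch (not part of this PR): populating the nested transform_param ---
#include "caffe/proto/caffe.pb.h"

caffe::LayerParameter MakeExampleParam() {
  caffe::LayerParameter layer_param;
  caffe::TransformationParameter* t =
      layer_param.mutable_data_param()->mutable_transform_param();
  t->set_crop_size(227);     // random crop in TRAIN, center crop otherwise
  t->set_mirror(true);       // this implementation requires crop_size > 0 when mirroring
  t->set_scale(1.0f / 255);  // applied after mean subtraction
  t->set_mean_file("imagenet_mean.binaryproto");  // hypothetical file name
  return layer_param;  // DataLayer/ImageDataLayer read transform_param from here
}
// --- end sketch ---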
55 changes: 55 additions & 0 deletions include/caffe/data_transformer.hpp
@@ -0,0 +1,55 @@
#ifndef CAFFE_DATA_TRANSFORMER_HPP
#define CAFFE_DATA_TRANSFORMER_HPP

#include "caffe/common.hpp"
#include "caffe/proto/caffe.pb.h"

namespace caffe {

/**
 * @brief Applies common transformations to the input data, such as
 * scaling, mirroring, subtracting the image mean...
 */
template <typename Dtype>
class DataTransformer {
 public:
  explicit DataTransformer(const TransformationParameter& param)
      : param_(param) {
    phase_ = Caffe::phase();
  }
  virtual ~DataTransformer() {}

  void InitRand();

  /**
   * @brief Applies the transformation defined in the data layer's
   * transform_param block to the data.
   *
   * @param batch_item_id
   *    Datum position within the batch. This is used to compute the
   *    writing position in the top blob's data.
   * @param datum
   *    Datum containing the data to be transformed.
   * @param mean
   *    Pointer to the data mean, one value per element of the uncropped
   *    datum (channels * height * width).
   * @param transformed_data
   *    This is meant to be the top blob's data. The transformed data will be
   *    written at the appropriate place within the blob's data.
   */
  void Transform(const int batch_item_id, const Datum& datum,
      const Dtype* mean, Dtype* transformed_data);

 protected:
  virtual unsigned int Rand();

  // Transformation parameters
  TransformationParameter param_;

  shared_ptr<Caffe::RNG> rng_;
  Caffe::Phase phase_;
};

}  // namespace caffe

#endif  // CAFFE_DATA_TRANSFORMER_HPP

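The header above defines the whole public surface: construct from a TransformationParameter, seed the RNG with InitRand(), then call Transform() once per datum to write into a preallocated batch blob. A minimal driving loop might look like the sketch below; it is not taken from the PR, and the helper name FillBatchExample and the assumption that the top blob is already reshaped (to crop_size x crop_size when cropping) are ours.

// --- sketch (not part of this PR): driving DataTransformer over a batch ---
#include <vector>

#include "caffe/blob.hpp"
#include "caffe/data_transformer.hpp"
#include "caffe/proto/caffe.pb.h"

namespace caffe {

template <typename Dtype>
void FillBatchExample(const TransformationParameter& transform_param,
                      const std::vector<Datum>& batch,  // one Datum per item
                      const Blob<Dtype>& data_mean,     // channels*height*width values
                      Blob<Dtype>* top) {               // already reshaped for the batch
  DataTransformer<Dtype> transformer(transform_param);
  transformer.InitRand();  // seeds the RNG only for TRAIN with mirror or crop
  const Dtype* mean = data_mean.cpu_data();
  Dtype* top_data = top->mutable_cpu_data();
  for (int item_id = 0; item_id < static_cast<int>(batch.size()); ++item_id) {
    // Each call writes item item_id at its own offset inside top_data.
    transformer.Transform(item_id, batch[item_id], mean, top_data);
  }
}

}  // namespace caffe
// --- end sketch ---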
113 changes: 113 additions & 0 deletions src/caffe/data_transformer.cpp
@@ -0,0 +1,113 @@
#include <string>

#include "caffe/data_transformer.hpp"
#include "caffe/util/math_functions.hpp"
#include "caffe/util/rng.hpp"

namespace caffe {

template <typename Dtype>
void DataTransformer<Dtype>::Transform(const int batch_item_id,
                                       const Datum& datum,
                                       const Dtype* mean,
                                       Dtype* transformed_data) {
  const string& data = datum.data();
  const int channels = datum.channels();
  const int height = datum.height();
  const int width = datum.width();
  const int size = datum.channels() * datum.height() * datum.width();

  const int crop_size = param_.crop_size();
  const bool mirror = param_.mirror();
  const Dtype scale = param_.scale();

  if (mirror && crop_size == 0) {
    LOG(FATAL) << "Current implementation requires mirror and crop_size to be "
               << "set at the same time.";
  }

  if (crop_size) {
    CHECK(data.size()) << "Image cropping only supports uint8 data";
    int h_off, w_off;
    // We only do random crop when we do training.
    if (phase_ == Caffe::TRAIN) {
      h_off = Rand() % (height - crop_size);
      w_off = Rand() % (width - crop_size);
    } else {
      h_off = (height - crop_size) / 2;
      w_off = (width - crop_size) / 2;
    }
    if (mirror && Rand() % 2) {
      // Copy mirrored version
      for (int c = 0; c < channels; ++c) {
        for (int h = 0; h < crop_size; ++h) {
          for (int w = 0; w < crop_size; ++w) {
            int data_index = (c * height + h + h_off) * width + w + w_off;
            int top_index = ((batch_item_id * channels + c) * crop_size + h)
                * crop_size + (crop_size - 1 - w);
            Dtype datum_element =
                static_cast<Dtype>(static_cast<uint8_t>(data[data_index]));
            transformed_data[top_index] =
                (datum_element - mean[data_index]) * scale;
          }
        }
      }
    } else {
      // Normal copy
      for (int c = 0; c < channels; ++c) {
        for (int h = 0; h < crop_size; ++h) {
          for (int w = 0; w < crop_size; ++w) {
            int top_index = ((batch_item_id * channels + c) * crop_size + h)
                * crop_size + w;
            int data_index = (c * height + h + h_off) * width + w + w_off;
            Dtype datum_element =
                static_cast<Dtype>(static_cast<uint8_t>(data[data_index]));
            transformed_data[top_index] =
                (datum_element - mean[data_index]) * scale;
          }
        }
      }
    }
  } else {
    // we will prefer to use data() first, and then try float_data()
    if (data.size()) {
      for (int j = 0; j < size; ++j) {
        Dtype datum_element =
            static_cast<Dtype>(static_cast<uint8_t>(data[j]));
        transformed_data[j + batch_item_id * size] =
            (datum_element - mean[j]) * scale;
      }
    } else {
      for (int j = 0; j < size; ++j) {
        transformed_data[j + batch_item_id * size] =
            (datum.float_data(j) - mean[j]) * scale;
      }
    }
  }
}

template <typename Dtype>
void DataTransformer<Dtype>::InitRand() {
  const bool needs_rand = (phase_ == Caffe::TRAIN) &&
      (param_.mirror() || param_.crop_size());
  if (needs_rand) {
    const unsigned int rng_seed = caffe_rng_rand();
    rng_.reset(new Caffe::RNG(rng_seed));
  } else {
    rng_.reset();
  }
}

template <typename Dtype>
unsigned int DataTransformer<Dtype>::Rand() {
  CHECK(rng_);
  caffe::rng_t* rng =
      static_cast<caffe::rng_t*>(rng_->generator());
  return (*rng)();
}

INSTANTIATE_CLASS(DataTransformer);

} // namespace caffe
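The cropping and mirroring loops above are plain row-major NCHW indexing: the source pixel (c, h + h_off, w + w_off) of the uncropped datum is written to (item, c, h, w) of the batch of crops, with the column flipped to crop_size - 1 - w when mirroring. The standalone check below restates those two formulas; the helper names are ours and it is not part of the PR.

// --- sketch (not part of this PR): the index arithmetic used by Transform() ---
#include <cassert>

// Offset of (c, h, w) inside one uncropped datum of shape channels x height x width.
inline int DatumOffset(int c, int h, int w, int height, int width) {
  return (c * height + h) * width + w;
}

// Offset of (c, h, w) inside item `item` of a batch shaped
// batch x channels x crop_size x crop_size.
inline int TopOffset(int item, int c, int h, int w, int channels, int crop_size) {
  return ((item * channels + c) * crop_size + h) * crop_size + w;
}

int main() {
  const int channels = 3, height = 256, width = 256, crop_size = 227;
  const int h_off = 10, w_off = 20;  // as drawn by Rand() in the TRAIN phase
  // Mirrored copy: read (c, h + h_off, w + w_off), write column crop_size - 1 - w.
  int data_index = DatumOffset(1, 5 + h_off, 7 + w_off, height, width);
  int top_index = TopOffset(0, 1, 5, crop_size - 1 - 7, channels, crop_size);
  assert(data_index == (1 * height + 5 + h_off) * width + 7 + w_off);
  assert(top_index ==
         ((0 * channels + 1) * crop_size + 5) * crop_size + (crop_size - 1 - 7));
  return 0;
}
// --- end sketch ---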
98 changes: 10 additions & 88 deletions src/caffe/layers/data_layer.cpp
@@ -23,20 +23,8 @@ void DataLayer<Dtype>::InternalThreadEntry() {
if (output_labels_) {
top_label = prefetch_label_.mutable_cpu_data();
}
const Dtype scale = this->layer_param_.data_param().scale();
const int batch_size = this->layer_param_.data_param().batch_size();
const int crop_size = this->layer_param_.data_param().crop_size();
const bool mirror = this->layer_param_.data_param().mirror();

if (mirror && crop_size == 0) {
LOG(FATAL) << "Current implementation requires mirror and crop_size to be "
<< "set at the same time.";
}
// datum scales
const int channels = datum_channels_;
const int height = datum_height_;
const int width = datum_width_;
const int size = datum_size_;
const Dtype* mean = data_mean_.cpu_data();
for (int item_id = 0; item_id < batch_size; ++item_id) {
// get a blob
@@ -56,66 +44,13 @@ void DataLayer<Dtype>::InternalThreadEntry() {
LOG(FATAL) << "Unknown database backend";
}

const string& data = datum.data();
if (crop_size) {
CHECK(data.size()) << "Image cropping only support uint8 data";
int h_off, w_off;
// We only do random crop when we do training.
if (phase_ == Caffe::TRAIN) {
h_off = PrefetchRand() % (height - crop_size);
w_off = PrefetchRand() % (width - crop_size);
} else {
h_off = (height - crop_size) / 2;
w_off = (width - crop_size) / 2;
}
if (mirror && PrefetchRand() % 2) {
// Copy mirrored version
for (int c = 0; c < channels; ++c) {
for (int h = 0; h < crop_size; ++h) {
for (int w = 0; w < crop_size; ++w) {
int top_index = ((item_id * channels + c) * crop_size + h)
* crop_size + (crop_size - 1 - w);
int data_index = (c * height + h + h_off) * width + w + w_off;
Dtype datum_element =
static_cast<Dtype>(static_cast<uint8_t>(data[data_index]));
top_data[top_index] = (datum_element - mean[data_index]) * scale;
}
}
}
} else {
// Normal copy
for (int c = 0; c < channels; ++c) {
for (int h = 0; h < crop_size; ++h) {
for (int w = 0; w < crop_size; ++w) {
int top_index = ((item_id * channels + c) * crop_size + h)
* crop_size + w;
int data_index = (c * height + h + h_off) * width + w + w_off;
Dtype datum_element =
static_cast<Dtype>(static_cast<uint8_t>(data[data_index]));
top_data[top_index] = (datum_element - mean[data_index]) * scale;
}
}
}
}
} else {
// we will prefer to use data() first, and then try float_data()
if (data.size()) {
for (int j = 0; j < size; ++j) {
Dtype datum_element =
static_cast<Dtype>(static_cast<uint8_t>(data[j]));
top_data[item_id * size + j] = (datum_element - mean[j]) * scale;
}
} else {
for (int j = 0; j < size; ++j) {
top_data[item_id * size + j] =
(datum.float_data(j) - mean[j]) * scale;
}
}
}
// Apply data transformations (mirror, scale, crop...)
data_transformer_.Transform(item_id, datum, mean, top_data);

if (output_labels_) {
top_label[item_id] = datum.label();
}

// go to the next iter
switch (this->layer_param_.data_param().backend()) {
case DataParameter_DB_LEVELDB:
@@ -244,7 +179,7 @@ void DataLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
}

// image
int crop_size = this->layer_param_.data_param().crop_size();
int crop_size = this->layer_param_.data_param().transform_param().crop_size();
if (crop_size > 0) {
(*top)[0]->Reshape(this->layer_param_.data_param().batch_size(),
datum.channels(), crop_size, crop_size);
@@ -274,8 +209,9 @@ void DataLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
CHECK_GT(datum_height_, crop_size);
CHECK_GT(datum_width_, crop_size);
// check if we want to have mean
if (this->layer_param_.data_param().has_mean_file()) {
const string& mean_file = this->layer_param_.data_param().mean_file();
if (this->layer_param_.data_param().transform_param().has_mean_file()) {
const string& mean_file =
this->layer_param_.data_param().transform_param().mean_file();
LOG(INFO) << "Loading mean file from " << mean_file;
BlobProto blob_proto;
ReadProtoFromBinaryFileOrDie(mean_file.c_str(), &blob_proto);
@@ -305,15 +241,9 @@ void DataLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
template <typename Dtype>
void DataLayer<Dtype>::CreatePrefetchThread() {
phase_ = Caffe::phase();
const bool prefetch_needs_rand = (phase_ == Caffe::TRAIN) &&
(this->layer_param_.data_param().mirror() ||
this->layer_param_.data_param().crop_size());
if (prefetch_needs_rand) {
const unsigned int prefetch_rng_seed = caffe_rng_rand();
prefetch_rng_.reset(new Caffe::RNG(prefetch_rng_seed));
} else {
prefetch_rng_.reset();
}

data_transformer_.InitRand();

CHECK(!StartInternalThread()) << "Pthread execution failed";
}

Expand All @@ -322,14 +252,6 @@ void DataLayer<Dtype>::JoinPrefetchThread() {
CHECK(!WaitForInternalThreadToExit()) << "Pthread joining failed";
}

template <typename Dtype>
unsigned int DataLayer<Dtype>::PrefetchRand() {
CHECK(prefetch_rng_);
caffe::rng_t* prefetch_rng =
static_cast<caffe::rng_t*>(prefetch_rng_->generator());
return (*prefetch_rng)();
}

template <typename Dtype>
void DataLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {