Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Changed the ImageDataLayer to have more than 1 image output.
By adding more 'top' params in the data layer in the .prototxt file while
keeping the label last, more images can be loaded.
This can be used to load for example siamese image pairs from a textfile as
created by 'create_imageset' in 'examples/siamese/'.

Since this required changes in the BasePrefetchingDataLayer and also other
classes are inherited from it, for now a new class:
BasePrefetchingMultiDataLayer has been made.
After thorough testing, this code can be moved into the original class as it
still supports loading only a single entity of data.
  • Loading branch information
FlorisGaisser committed Jun 17, 2015
commit f92665dc27a792e20cf0a484b385a253137d4a62
57 changes: 53 additions & 4 deletions include/caffe/data_layers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,41 @@ class BasePrefetchingDataLayer :
Blob<Dtype> transformed_data_;
};

/**
 * @brief Provides pre-fetching base for data layers that feed multiple blobs
 * to the Net.
 *
 * Every top blob except the trailing label is a prefetched data blob;
 * subclasses fill them by overriding InternalThreadEntry().
 *
 * TODO(dox): thorough documentation for Forward and proto params.
 */
template <typename Dtype>
class BasePrefetchingMultiDataLayer :
    public BaseDataLayer<Dtype>, public InternalThread {
 public:
  explicit BasePrefetchingMultiDataLayer(const LayerParameter& param)
      : BaseDataLayer<Dtype>(param) {}
  // LayerSetUp allocates the prefetch blobs with raw 'new'; release them
  // here so destroying the layer does not leak them.
  virtual ~BasePrefetchingMultiDataLayer() {
    for (size_t i = 0; i < prefetch_data_.size(); ++i) {
      delete prefetch_data_[i];
    }
  }
  // LayerSetUp: implements common data layer setup functionality, and calls
  // DataLayerSetUp to do special data layer setup for individual layer types.
  // This method may not be overridden.
  void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top);

  virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top);
  virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top);

  virtual void CreatePrefetchThread();
  virtual void JoinPrefetchThread();
  // The thread's function; no-op here, overridden by subclasses.
  virtual void InternalThreadEntry() {}

 protected:
  // Number of data tops: top.size() - 1, the label being the last top.
  int input_data_size_;
  // Owned prefetch buffers, one per data top; freed in the destructor.
  std::vector<Blob<Dtype>*> prefetch_data_;
  Blob<Dtype> prefetch_label_;
  Blob<Dtype> transformed_data_;
};

template <typename Dtype>
class DataLayer : public BasePrefetchingDataLayer<Dtype> {
public:
Expand Down Expand Up @@ -217,27 +252,41 @@ class HDF5OutputLayer : public Layer<Dtype> {
/**
 * @brief Provides data to the Net from image files.
 *
 * Each line of the source file lists one or more image paths followed by a
 * label:
 *   [file_path] { ... [file_path]} [label]
 * The delimiter can be ' ', ',' or '\t' (checked in this order).
 * To enable more than one file input, just add a 'top' in the layer
 * parameters while keeping the label last.
 *
 * TODO(dox): thorough documentation for Forward and proto params.
 */
template <typename Dtype>
class ImageDataLayer : public BasePrefetchingMultiDataLayer<Dtype> {
 public:
  explicit ImageDataLayer(const LayerParameter& param)
      : BasePrefetchingMultiDataLayer<Dtype>(param) {}
  virtual ~ImageDataLayer();
  virtual void DataLayerSetUp(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top);

  virtual inline const char* type() const { return "ImageData"; }
  virtual inline int ExactNumBottomBlobs() const { return 0; }
  // One top per image column plus the trailing label; the expected count is
  // read directly from the layer definition, so any number of tops the user
  // declares is accepted.
  virtual inline int ExactNumTopBlobs() const {
    return this->layer_param_.top_size();
  }

 protected:
  shared_ptr<Caffe::RNG> prefetch_rng_;
  virtual void ShuffleImages();
  virtual void InternalThreadEntry();
  // Counts how many times 'delim' occurs in 'text'.
  int findNumOccurrences(char delim, std::string text);
  // Splits 'line' on 'delim' into 'parts'.
  void split(
      char delim,
      const std::string &line,
      std::vector<std::string> *parts);

  // Each entry: the image paths of one source line, plus its label.
  vector<std::pair<std::vector<std::string>, int> > lines_;
  int lines_id_;
};

Expand Down
69 changes: 69 additions & 0 deletions src/caffe/layers/base_data_layer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,76 @@ void BasePrefetchingDataLayer<Dtype>::Forward_cpu(
STUB_GPU_FORWARD(BasePrefetchingDataLayer, Forward);
#endif



template <typename Dtype>
void BasePrefetchingMultiDataLayer<Dtype>::LayerSetUp(
    const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
  // Every top blob except the trailing label receives data, so there is one
  // data blob fewer than there are tops (top = data + label).
  // NOTE(review): this assumes a label top is always present — confirm for
  // label-less configurations.
  input_data_size_ = top.size() - 1;
  for (int i = 0; i < input_data_size_; ++i) {
    prefetch_data_.push_back(new Blob<Dtype>());
  }
  BaseDataLayer<Dtype>::LayerSetUp(bottom, top);
  // Touch the CPU buffers of every prefetch blob before launching the
  // prefetch thread, so the thread does not issue cudaMalloc calls
  // concurrently with the main thread; on some GPUs that causes failures.
  for (int i = 0; i < input_data_size_; ++i) {
    this->prefetch_data_[i]->mutable_cpu_data();
  }
  if (this->output_labels_) {
    this->prefetch_label_.mutable_cpu_data();
  }
  DLOG(INFO) << "Initializing prefetch";
  this->CreatePrefetchThread();
  DLOG(INFO) << "Prefetch initialized.";
}

// Launches the background prefetch thread. The transformer's random
// generator is (re)initialized first, on the calling thread, before the
// worker starts.
template <typename Dtype>
void BasePrefetchingMultiDataLayer<Dtype>::CreatePrefetchThread() {
  this->data_transformer_->InitRand();
  CHECK(StartInternalThread()) << "Thread execution failed";
}

// Blocks until the background prefetch thread has exited, i.e. the next
// batch has been fully loaded into the prefetch blobs.
template <typename Dtype>
void BasePrefetchingMultiDataLayer<Dtype>::JoinPrefetchThread() {
  CHECK(WaitForInternalThreadToExit()) << "Thread joining failed";
}

template <typename Dtype>
void BasePrefetchingMultiDataLayer<Dtype>::Forward_cpu(
    const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
  // Wait until the background thread has filled the prefetch blobs.
  JoinPrefetchThread();
  DLOG(INFO) << "Thread joined";
  for (int i = 0; i < input_data_size_; ++i) {
    Blob<Dtype>* prefetched = prefetch_data_[i];
    // Shape each top like the prefetched batch, then copy it over.
    top[i]->ReshapeLike(*prefetched);
    caffe_copy(prefetched->count(), prefetched->cpu_data(),
        top[i]->mutable_cpu_data());
    DLOG(INFO) << "Prefetch copied";
  }
  if (this->output_labels_) {
    // The label blob is the last top.
    top[input_data_size_]->ReshapeLike(prefetch_label_);
    caffe_copy(prefetch_label_.count(), prefetch_label_.cpu_data(),
        top[input_data_size_]->mutable_cpu_data());
  }
  // Kick off prefetching of the next batch.
  DLOG(INFO) << "CreatePrefetchThread";
  CreatePrefetchThread();
}

#ifdef CPU_ONLY
// CPU-only build: stub out Forward_gpu so the class still links.
STUB_GPU_FORWARD(BasePrefetchingMultiDataLayer, Forward);
#endif

// Explicit instantiations so the template definitions in this translation
// unit are emitted (see the INSTANTIATE_CLASS macro definition).
INSTANTIATE_CLASS(BaseDataLayer);
INSTANTIATE_CLASS(BasePrefetchingDataLayer);
INSTANTIATE_CLASS(BasePrefetchingMultiDataLayer);

} // namespace caffe
29 changes: 29 additions & 0 deletions src/caffe/layers/base_data_layer.cu
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,33 @@ void BasePrefetchingDataLayer<Dtype>::Forward_gpu(

INSTANTIATE_LAYER_GPU_FORWARD(BasePrefetchingDataLayer);

template <typename Dtype>
void BasePrefetchingMultiDataLayer<Dtype>::Forward_gpu(
    const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
  // Wait until the background thread has filled the prefetch blobs.
  JoinPrefetchThread();
  for (int i = 0; i < input_data_size_; ++i) {
    Blob<Dtype>* prefetched = prefetch_data_[i];
    // Shape each top like the prefetched batch, then copy host -> device.
    top[i]->ReshapeLike(*prefetched);
    caffe_copy(prefetched->count(), prefetched->cpu_data(),
        top[i]->mutable_gpu_data());
    DLOG(INFO) << "Prefetch copied";
  }
  if (this->output_labels_) {
    // The label blob is the last top.
    top[input_data_size_]->ReshapeLike(prefetch_label_);
    caffe_copy(prefetch_label_.count(), prefetch_label_.cpu_data(),
        top[input_data_size_]->mutable_gpu_data());
  }
  // Kick off prefetching of the next batch.
  DLOG(INFO) << "CreatePrefetchThread";
  CreatePrefetchThread();
}

// Emit the GPU Forward definition for the instantiated Dtype variants.
INSTANTIATE_LAYER_GPU_FORWARD(BasePrefetchingMultiDataLayer);

} // namespace caffe
Loading