-
Notifications
You must be signed in to change notification settings - Fork 18.6k
Tanh #116
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
Closed
Tanh #116
Changes from all commits
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
// Copyright 2014 Aravindh Mahendran | ||
// TanH neuron activation function layer. Adapted from ReLU layer code written by Yangqing Jia | ||
|
||
#include <algorithm>
#include <cmath>

#include "caffe/layer.hpp"
#include "caffe/vision_layers.hpp"
|
||
namespace caffe { | ||
|
||
template <typename Dtype> | ||
void TanHLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom, | ||
vector<Blob<Dtype>*>* top) { | ||
const Dtype* bottom_data = bottom[0]->cpu_data(); | ||
Dtype* top_data = (*top)[0]->mutable_cpu_data(); | ||
Dtype exp2x; | ||
const int count = bottom[0]->count(); | ||
for (int i = 0; i < count; ++i) { | ||
exp2x = exp(2*bottom_data[i]); | ||
top_data[i] = (exp2x - Dtype(1))/(exp2x + Dtype(1)); | ||
} | ||
} | ||
|
||
template <typename Dtype> | ||
Dtype TanHLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top, | ||
const bool propagate_down, | ||
vector<Blob<Dtype>*>* bottom) { | ||
if (propagate_down) { | ||
const Dtype* bottom_data = (*bottom)[0]->cpu_data(); | ||
const Dtype* top_diff = top[0]->cpu_diff(); | ||
Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff(); | ||
const int count = (*bottom)[0]->count(); | ||
Dtype exp2x; | ||
Dtype tanhx; | ||
for (int i = 0; i < count; ++i) { | ||
exp2x = exp(2*bottom_data[i]); | ||
tanhx = (exp2x - Dtype(1))/(exp2x + Dtype(1)); | ||
bottom_diff[i] = top_diff[i] * (1 - tanhx*tanhx); | ||
} | ||
} | ||
return Dtype(0); | ||
} | ||
|
||
template <typename Dtype> | ||
__global__ void TanHForward(const int n, const Dtype* in, Dtype* out) { | ||
int index = threadIdx.x + blockIdx.x * blockDim.x; | ||
if (index < n) { | ||
Dtype exp2x = exp(2*in[index]); | ||
out[index] = (exp2x - Dtype(1))/(exp2x + Dtype(1)); | ||
} | ||
} | ||
|
||
template <typename Dtype> | ||
void TanHLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom, | ||
vector<Blob<Dtype>*>* top) { | ||
const Dtype* bottom_data = bottom[0]->gpu_data(); | ||
Dtype* top_data = (*top)[0]->mutable_gpu_data(); | ||
const int count = bottom[0]->count(); | ||
TanHForward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>( | ||
count, bottom_data, top_data); | ||
CUDA_POST_KERNEL_CHECK; | ||
// << " count: " << count << " bottom_data: " | ||
// << (unsigned long)bottom_data << " top_data: " << (unsigned long)top_data | ||
// << " blocks: " << CAFFE_GET_BLOCKS(count) | ||
// << " threads: " << CAFFE_CUDA_NUM_THREADS; | ||
} | ||
|
||
template <typename Dtype> | ||
__global__ void TanHBackward(const int n, const Dtype* in_diff, | ||
const Dtype* in_data, Dtype* out_diff) { | ||
int index = threadIdx.x + blockIdx.x * blockDim.x; | ||
if (index < n) { | ||
Dtype exp2x = exp(2*in_data[index]); | ||
Dtype tanhx = (exp2x - Dtype(1))/(exp2x + Dtype(1)); | ||
out_diff[index] = in_diff[index] * (1 - tanhx*tanhx); | ||
} | ||
} | ||
|
||
template <typename Dtype> | ||
Dtype TanHLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top, | ||
const bool propagate_down, | ||
vector<Blob<Dtype>*>* bottom) { | ||
if (propagate_down) { | ||
const Dtype* bottom_data = (*bottom)[0]->gpu_data(); | ||
const Dtype* top_diff = top[0]->gpu_diff(); | ||
Dtype* bottom_diff = (*bottom)[0]->mutable_gpu_diff(); | ||
const int count = (*bottom)[0]->count(); | ||
TanHBackward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>( | ||
count, top_diff, bottom_data, bottom_diff); | ||
CUDA_POST_KERNEL_CHECK; | ||
} | ||
return Dtype(0); | ||
} | ||
|
||
INSTANTIATE_CLASS(TanHLayer); | ||
|
||
|
||
} // namespace caffe |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
// Copyright 2014 Aravindh Mahendran | ||
// Adapted from other test files | ||
|
||
#include <cmath> | ||
#include <cstring> | ||
#include <cuda_runtime.h> | ||
|
||
#include "gtest/gtest.h" | ||
#include "caffe/blob.hpp" | ||
#include "caffe/common.hpp" | ||
#include "caffe/filler.hpp" | ||
#include "caffe/vision_layers.hpp" | ||
#include "caffe/test/test_gradient_check_util.hpp" | ||
|
||
#include "caffe/test/test_caffe_main.hpp" | ||
|
||
namespace caffe { | ||
|
||
extern cudaDeviceProp CAFFE_TEST_CUDA_PROP; | ||
|
||
template <typename Dtype> | ||
class TanHLayerTest : public ::testing::Test { | ||
protected: | ||
TanHLayerTest() | ||
: blob_bottom_(new Blob<Dtype>(2, 10, 1, 1)), | ||
blob_top_(new Blob<Dtype>()) { | ||
// fill the values | ||
FillerParameter filler_param; | ||
GaussianFiller<Dtype> filler(filler_param); | ||
filler.Fill(this->blob_bottom_); | ||
blob_bottom_vec_.push_back(blob_bottom_); | ||
blob_top_vec_.push_back(blob_top_); | ||
}; | ||
virtual ~TanHLayerTest() { delete blob_bottom_; delete blob_top_; } | ||
Blob<Dtype>* const blob_bottom_; | ||
Blob<Dtype>* const blob_top_; | ||
vector<Blob<Dtype>*> blob_bottom_vec_; | ||
vector<Blob<Dtype>*> blob_top_vec_; | ||
}; | ||
|
||
typedef ::testing::Types<float, double> Dtypes; | ||
TYPED_TEST_CASE(TanHLayerTest, Dtypes); | ||
|
||
// Runs the layer forward in CPU mode and checks every output element against
// std::tanh of the matching input, to within an absolute error of 1e-4.
// std::tanh is used as an independent reference, rather than the overflow-prone
// closed form (e^{2x} - 1) / (e^{2x} + 1) that the layer itself may use.
TYPED_TEST(TanHLayerTest, TestForwardCPU) {
  LayerParameter layer_param;
  Caffe::set_mode(Caffe::CPU);
  TanHLayer<TypeParam> layer(layer_param);
  layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
  layer.Forward(this->blob_bottom_vec_, &(this->blob_top_vec_));
  // The top must hold one value per input before we index into it.
  const int count = this->blob_bottom_->count();
  ASSERT_EQ(count, this->blob_top_->count());
  const TypeParam* bottom_data = this->blob_bottom_->cpu_data();
  const TypeParam* top_data = this->blob_top_->cpu_data();
  for (int i = 0; i < count; ++i) {
    EXPECT_NEAR(std::tanh(bottom_data[i]), top_data[i], 1e-4);
  }
}
|
||
// Runs the project's GradientChecker in CPU mode over the layer, using the
// fixture's bottom/top blob vectors. The constructor arguments (1e-2, 1e-3)
// are presumably the finite-difference step and the tolerance; see
// test_gradient_check_util.hpp to confirm.
TYPED_TEST(TanHLayerTest, TestGradientCPU) {
  Caffe::set_mode(Caffe::CPU);
  LayerParameter layer_param;
  TanHLayer<TypeParam> layer(layer_param);
  GradientChecker<TypeParam> gradient_checker(1e-2, 1e-3);
  gradient_checker.CheckGradientExhaustive(layer, this->blob_bottom_vec_,
      this->blob_top_vec_);
}
|
||
// Runs the layer forward in GPU mode and checks every output element against
// std::tanh of the matching input, to within an absolute error of 1e-4.
// Results are read back through cpu_data(). std::tanh is used as an
// independent reference, rather than the overflow-prone closed form
// (e^{2x} - 1) / (e^{2x} + 1) that the kernel itself may use.
TYPED_TEST(TanHLayerTest, TestForwardGPU) {
  LayerParameter layer_param;
  Caffe::set_mode(Caffe::GPU);
  TanHLayer<TypeParam> layer(layer_param);
  layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
  layer.Forward(this->blob_bottom_vec_, &(this->blob_top_vec_));
  // The top must hold one value per input before we index into it.
  const int count = this->blob_bottom_->count();
  ASSERT_EQ(count, this->blob_top_->count());
  const TypeParam* bottom_data = this->blob_bottom_->cpu_data();
  const TypeParam* top_data = this->blob_top_->cpu_data();
  for (int i = 0; i < count; ++i) {
    EXPECT_NEAR(std::tanh(bottom_data[i]), top_data[i], 1e-4);
  }
}
|
||
// Runs the project's GradientChecker in GPU mode over the layer, using the
// fixture's bottom/top blob vectors. The constructor arguments (1e-2, 1e-3)
// are presumably the finite-difference step and the tolerance; see
// test_gradient_check_util.hpp to confirm.
TYPED_TEST(TanHLayerTest, TestGradientGPU) {
  Caffe::set_mode(Caffe::GPU);
  LayerParameter layer_param;
  TanHLayer<TypeParam> layer(layer_param);
  GradientChecker<TypeParam> gradient_checker(1e-2, 1e-3);
  gradient_checker.CheckGradientExhaustive(layer, this->blob_bottom_vec_,
      this->blob_top_vec_);
}
|
||
} |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You should create a SetUp method that checks there is exactly one bottom blob and exactly one top blob. If you also want this layer to work in-place, you could add this:
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry for the back-and-forth, but there is no need to write a custom SetUp function unless the layer needs special handling. NeuronLayer has a virtual SetUp() function that already takes care of the default construction and the in-place check:
https://github.com/BVLC/caffe/blob/master/src/caffe/layers/neuron_layer.cpp
So perhaps you can simply drop the custom SetUp function?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We need a setup for the situation in which tanh is not used in place. It should then call Reshape for (*top)[0].
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
https://github.com/BVLC/caffe/blob/master/src/caffe/layers/neuron_layer.cpp#L18 serves this purpose.