Conversation

shelhamer
Member

This is the same as #2737 except for

  • Restoring the pooling layer fallback for max + argmax output when the layer is configured with two tops. The padding fallback is dropped, since padding is now supported by cuDNN pooling.
  • Clearing the warnings for unused variables in cuDNNConvolutionLayer::Forward_gpu() now that algo and workspace are determined in Reshape().
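The fallback in the first bullet can be sketched roughly as follows. This is a hypothetical illustration, not Caffe's actual code: `PoolEngine` and `ChoosePoolingEngine` are invented names standing in for the layer-factory decision described above, where cuDNN pooling returns only pooled values and cannot produce the argmax mask.

```cpp
#include <cstddef>

// Hypothetical sketch of the fallback decision: cuDNN pooling cannot
// produce the argmax mask, so a layer configured with two tops (max
// values + argmax indices) must use the native Caffe implementation.
enum class PoolEngine { CUDNN, CAFFE };

// top_count: number of top blobs the layer is configured with.
PoolEngine ChoosePoolingEngine(std::size_t top_count) {
  // One top: pooled values only, so cuDNN suffices.
  // Two tops: argmax is also required, so fall back to native Caffe.
  return (top_count > 1) ? PoolEngine::CAFFE : PoolEngine::CUDNN;
}
```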

Thanks @slayton58 for the integration.

@slayton58
Contributor

@shelhamer Regarding the test failure for Groups: it seems like this->weight_offset_ for the cuDNN routines is getting set incorrectly or not set at all (either way, it's wrong!). This seems to have been introduced in 9d8206e.

Setting it back explicitly to:

    this->weight_offset_ = (this->num_output_ / this->group_) * (this->channels_ / this->group_) * kernel_h * kernel_w;

in CuDNNConvolutionLayer::Setup seems to fix the issue. I'm not sure if there's a better way; let me know and I'll update the PR.
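The offset expression above can be read as "weights per group": with grouped convolution, each group owns a contiguous slice of the weight blob. A minimal standalone sketch of that computation (a hypothetical free function, not Caffe's member code) is:

```cpp
#include <cstddef>

// Number of weights owned by one convolution group:
// (output channels per group) x (input channels per group) x kernel area.
// Mirrors the expression from the comment above in free-function form.
std::size_t WeightOffset(std::size_t num_output, std::size_t channels,
                         std::size_t group,
                         std::size_t kernel_h, std::size_t kernel_w) {
  return (num_output / group) * (channels / group) * kernel_h * kernel_w;
}
```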

shelhamer added a commit that referenced this pull request Oct 16, 2015
@shelhamer shelhamer merged commit 321720d into BVLC:master Oct 16, 2015
@shelhamer shelhamer deleted the cudnnV3 branch October 16, 2015 03:17
@ronghanghu
Member

Great 👍

@shelhamer
Member Author

cuDNN v3 is not itself backward compatible with v2, so adopting v3 in this PR does deprecate v2. We plan to follow the latest cuDNN version in master but keep compatibility as far as the cuDNN interface itself allows.

Member


@shelhamer @slayton58 I am a bit confused here... Why do we need to zero out the diff? It confuses me, as parameter gradients should be accumulated.
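The accumulation semantics behind this question can be illustrated with a small sketch (not Caffe code; `AccumulateGrad` is an invented name): parameter gradients are summed into the diff across backward calls, so zeroing the diff mid-pass would silently discard earlier contributions.

```cpp
#include <cstddef>
#include <vector>

// Each backward call adds its contribution into diff rather than
// overwriting it, so contributions from multiple bottoms or minibatch
// splits sum correctly. Zeroing diff here would discard prior terms.
void AccumulateGrad(std::vector<float>& diff,
                    const std::vector<float>& contribution) {
  for (std::size_t i = 0; i < diff.size(); ++i) {
    diff[i] += contribution[i];
  }
}
```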

Member Author


This is definitely a bug -- thanks for the fix in #3254.
