diff --git a/test/test_distributions.py b/test/test_distributions.py
index 2dde18cf213..43433b95fd3 100644
--- a/test/test_distributions.py
+++ b/test/test_distributions.py
@@ -8,7 +8,7 @@
 import pytest
 import torch
 from _utils_internal import get_available_devices
-from torch import nn
+from torch import autograd, nn
 from torchrl.data.tensordict.tensordict import _TensorDict
 from torchrl.modules import (
     TanhNormal,
@@ -211,6 +211,20 @@ def test_tanhtrsf(dtype):
     assert (some_big_number.sign() == ones.sign()).all()
 
 
+@pytest.mark.parametrize("dtype", [torch.float, torch.double])
+def test_tanhtrsf_grad(dtype):
+    """SafeTanhTransform must match the gradient of tanh in every dtype."""
+    torch.manual_seed(0)
+    trsf = SafeTanhTransform()
+    # Build the input in the parametrized dtype so each precision is exercised.
+    x = torch.randn(100, dtype=dtype, requires_grad=True)
+    y1 = trsf(x)
+    y2 = x.tanh()
+    g1 = autograd.grad(y1.sum(), x, retain_graph=True)[0]
+    g2 = autograd.grad(y2.sum(), x, retain_graph=True)[0]
+    torch.testing.assert_close(g1, g2)
+
+
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
     pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
diff --git a/torchrl/csrc/pybind.cpp b/torchrl/csrc/pybind.cpp
index 2e6fa604e6f..e3af8c3128d 100644
--- a/torchrl/csrc/pybind.cpp
+++ b/torchrl/csrc/pybind.cpp
@@ -11,6 +11,7 @@
 #include
 
 #include "segment_tree.h"
+#include "utils.h"
 
 namespace py = pybind11;
 
@@ -20,4 +21,6 @@ PYBIND11_MODULE(_torchrl, m) {
 
   torchrl::DefineMinSegmentTree("Fp32", m);
   torchrl::DefineMinSegmentTree("Fp64", m);
+
+  m.def("safetanh", &safetanh, "Safe Tanh");
 }
diff --git a/torchrl/csrc/utils.h b/torchrl/csrc/utils.h
new file mode 100644
index 00000000000..3f2944e2897
--- /dev/null
+++ b/torchrl/csrc/utils.h
@@ -0,0 +1,13 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+//
+// This source code is licensed under the MIT license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include
+#include
+
+#include
+// tanh(input) clamped to +/-0.999999; inline since it is defined in a header.
+inline torch::Tensor safetanh(const torch::Tensor& input) {
+  return torch::clamp(torch::tanh(input), -0.999999, 0.999999);
+}
diff --git a/torchrl/modules/distributions/continuous.py b/torchrl/modules/distributions/continuous.py
index 4c88c7e967d..5ca1ec4c26a 100644
--- a/torchrl/modules/distributions/continuous.py
+++ b/torchrl/modules/distributions/continuous.py
@@ -11,6 +11,7 @@
 from torch import distributions as D, nn
 from torch.distributions import constraints
 
+from torchrl._torchrl import safetanh
 from torchrl.modules.distributions.truncated_normal import (
     TruncatedNormal as _TruncatedNormal,
 )
@@ -85,9 +86,8 @@ class SafeTanhTransform(D.TanhTransform):
     """ """
 
     def _call(self, x: torch.Tensor) -> torch.Tensor:
-        eps = torch.finfo(x.dtype).eps
-        y = super()._call(x)
-        y.data.clamp_(-1 + eps, 1 - eps)
+        # safetanh (torchrl/csrc/utils.h) clamps tanh(x) to +/-0.999999.
+        y = safetanh(x)
         return y
 
     def _inverse(self, y: torch.Tensor) -> torch.Tensor: