30 changes: 30 additions & 0 deletions test/test_transforms.py
@@ -6676,6 +6676,36 @@ def test_single_trans_env_check(self):
         assert "mykey" in env.reset().keys()
         assert ("next", "mykey") in env.rollout(3).keys(True)

+    def test_nested_key_env(self):
+        env = MultiKeyCountingEnv()
+        env_obs_spec_prior_primer = env.observation_spec.clone()
+        env = TransformedEnv(
+            env,
+            TensorDictPrimer(
+                CompositeSpec(
+                    {
+                        "nested_1": CompositeSpec(
+                            {
+                                "mykey": UnboundedContinuousTensorSpec(
+                                    (env.nested_dim_1, 4)
+                                )
+                            },
+                            shape=(env.nested_dim_1,),
+                        )
+                    }
+                ),
+                reset_key="_reset",
+            ),
+        )
+        check_env_specs(env)
+        env_obs_spec_post_primer = env.observation_spec.clone()
+        assert ("nested_1", "mykey") in env_obs_spec_post_primer.keys(True, True)
+        del env_obs_spec_post_primer[("nested_1", "mykey")]
+        assert env_obs_spec_post_primer == env_obs_spec_prior_primer
+
+        assert ("nested_1", "mykey") in env.reset().keys(True, True)
+        assert ("next", "nested_1", "mykey") in env.rollout(3).keys(True, True)
+
     def test_transform_no_env(self):
         t = TensorDictPrimer(mykey=UnboundedContinuousTensorSpec([3]))
         td = TensorDict({"a": torch.zeros(())}, [])
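For context, here is a minimal sketch of the nested-key priming this new test exercises. The spec layout mirrors the test above, but the standalone construction and `nested_dim` are illustrative, not part of the PR:

```python
# Sketch of priming a nested entry via a nested CompositeSpec;
# nested_dim and the key names mirror the test but are illustrative.
from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec
from torchrl.envs.transforms import TensorDictPrimer

nested_dim = 3
primer = TensorDictPrimer(
    CompositeSpec(
        {
            "nested_1": CompositeSpec(
                {"mykey": UnboundedContinuousTensorSpec((nested_dim, 4))},
                shape=(nested_dim,),
            )
        }
    ),
    reset_key="_reset",
)
# Appended to a TransformedEnv, this makes ("nested_1", "mykey") appear in
# env.reset() and at ("next", "nested_1", "mykey") in env.rollout(...).
```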
26 changes: 12 additions & 14 deletions torchrl/envs/transforms/transforms.py
@@ -4646,7 +4646,7 @@ def __init__(
         self.reset_key = reset_key

         # sanity check
-        for spec in self.primers.values():
+        for spec in self.primers.values(True, True):
             if not isinstance(spec, TensorSpec):
                 raise ValueError(
                     "The values of the primers must be a subtype of the TensorSpec class. "
@@ -4705,15 +4705,16 @@ def transform_observation_spec(
             raise ValueError(
                 f"observation_spec was expected to be of type CompositeSpec. Got {type(observation_spec)} instead."
             )
-        for key, spec in self.primers.items():
-            if spec.shape[: len(observation_spec.shape)] != observation_spec.shape:
-                expanded_spec = self._expand_shape(spec)
-                spec = expanded_spec
+
+        if self.primers.shape != observation_spec.shape:
             try:
-                device = observation_spec.device
-            except RuntimeError:
-                device = self.device
-            observation_spec[key] = self.primers[key] = spec.to(device)
+                # We try to set the primer shape to the observation spec shape
+                self.primers.shape = observation_spec.shape
+            except ValueError:
+                # If we fail, we expand them to that shape
+                self.primers = self._expand_shape(self.primers)
+        device = observation_spec.device
+        observation_spec.update(self.primers.clone().to(device))
         return observation_spec

     def transform_input_spec(self, input_spec: TensorSpec) -> TensorSpec:
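The rewritten method now aligns the batch shape of the primer `CompositeSpec` as a whole rather than expanding specs key by key. A standalone sketch of the same try/except pattern, assuming `CompositeSpec.expand` in place of the transform's internal `_expand_shape`, with illustrative shapes:

```python
# Standalone sketch of the shape-alignment fallback; not the transform's code.
from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec

primers = CompositeSpec({"hidden": UnboundedContinuousTensorSpec((4,))})
obs_batch_shape = (2,)  # stand-in for observation_spec.shape
try:
    # Succeeds when the new batch shape is a prefix of every leaf shape
    primers.shape = obs_batch_shape
except ValueError:
    # Otherwise prepend the batch dims to every leaf by expanding
    primers = primers.expand(*obs_batch_shape, *primers.shape)
```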
@@ -4763,8 +4764,8 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
     def _step(
         self, tensordict: TensorDictBase, next_tensordict: TensorDictBase
     ) -> TensorDictBase:
-        for key in self.primers.keys():
-            if key not in next_tensordict.keys(True):
+        for key in self.primers.keys(True, True):
+            if key not in next_tensordict.keys(True, True):
                 prev_val = tensordict.get(key)
                 next_tensordict.set(key, prev_val)
         return next_tensordict
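With nested leaf keys included, primed entries that the env's step output lacks are carried over from the previous tensordict, nested or not. The same pattern on a plain `TensorDict` (key names illustrative):

```python
# Sketch of the carry-forward: copy a primed entry that the step output
# lacks; ("nested", "hidden") is an illustrative key.
import torch
from tensordict import TensorDict

prev = TensorDict({"nested": TensorDict({"hidden": torch.ones(3)}, [])}, [])
nxt = TensorDict({"obs": torch.zeros(3)}, [])
key = ("nested", "hidden")
if key not in nxt.keys(True, True):
    nxt.set(key, prev.get(key))  # set() auto-creates the nested level
assert (nxt.get(key) == 1).all()
```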
@@ -4782,9 +4783,6 @@ def _reset(
         _reset = _get_reset(self.reset_key, tensordict)
         if _reset.any():
             for key, spec in self.primers.items(True, True):
-                if spec.shape[: len(tensordict.batch_size)] != tensordict.batch_size:
-                    expanded_spec = self._expand_shape(spec)
-                    self.primers[key] = spec = expanded_spec
                 if self.random:
                     shape = (
                         ()