torchrl/envs/common.py (16 additions, 14 deletions)
@@ -202,6 +202,7 @@ def __init__(
         device: DEVICE_TYPING = "cpu",
         dtype: Optional[Union[torch.dtype, np.dtype]] = None,
         batch_size: Optional[torch.Size] = None,
+        run_type_checks: bool = True,
     ):
         super().__init__()
         if device is not None:
@@ -224,6 +225,7 @@ def __init__(
             "batch_size" not in self.__class__.__dict__
         ):
             self.batch_size = torch.Size([])
+        self.run_type_checks = run_type_checks

     @classmethod
     def __new__(cls, *args, _inplace_update=False, _batch_locked=True, **kwargs):
@@ -312,21 +314,21 @@ def step(self, tensordict: TensorDictBase) -> TensorDictBase:
                 "tensordict.select()) inside _step before writing new tensors onto this new instance."
             )
         self.is_done = tensordict_out.get("done")
-
-        for key in self._select_observation_keys(tensordict_out):
-            obs = tensordict_out.get(key)
-            self.observation_spec.type_check(obs, key)
-
-        if tensordict_out._get_meta("reward").dtype is not self.reward_spec.dtype:
-            raise TypeError(
-                f"expected reward.dtype to be {self.reward_spec.dtype} "
-                f"but got {tensordict_out.get('reward').dtype}"
-            )
+        if self.run_type_checks:
+            for key in self._select_observation_keys(tensordict_out):
+                obs = tensordict_out.get(key)
+                self.observation_spec.type_check(obs, key)
+
+            if tensordict_out._get_meta("reward").dtype is not self.reward_spec.dtype:
+                raise TypeError(
+                    f"expected reward.dtype to be {self.reward_spec.dtype} "
+                    f"but got {tensordict_out.get('reward').dtype}"
+                )

-        if tensordict_out._get_meta("done").dtype is not torch.bool:
-            raise TypeError(
-                f"expected done.dtype to be torch.bool but got {tensordict_out.get('done').dtype}"
-            )
+            if tensordict_out._get_meta("done").dtype is not torch.bool:
+                raise TypeError(
+                    f"expected done.dtype to be torch.bool but got {tensordict_out.get('done').dtype}"
+                )
         tensordict.update(tensordict_out, inplace=self._inplace_update)

         del tensordict_out
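Usage note (not part of the diff): the new flag is an ordinary constructor argument stored as an attribute on the env, so it can be set at construction time or flipped afterwards. A minimal sketch, assuming _TypedEnv is a made-up EnvBase subclass and that stubbing the abstract hooks is enough to instantiate it:

from torchrl.envs.common import EnvBase

class _TypedEnv(EnvBase):
    # Hypothetical subclass for illustration only; the abstract hooks are
    # stubbed because only the constructor path matters here.
    def _step(self, tensordict):
        raise NotImplementedError
    def _reset(self, tensordict=None, **kwargs):
        raise NotImplementedError
    def _set_seed(self, seed):
        raise NotImplementedError

env = _TypedEnv(run_type_checks=False)  # defaults to True in EnvBase
assert env.run_type_checks is False
env.run_type_checks = True              # plain attribute, can be toggled later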
torchrl/envs/model_based/common.py (6 additions, 1 deletion)
@@ -97,6 +97,7 @@ class ModelBasedEnvBase(EnvBase, metaclass=abc.ABCMeta):
         device (torch.device, optional): device where the env input and output are expected to live
         dtype (torch.dtype, optional): dtype of the env input and output
         batch_size (torch.Size, optional): number of environments contained in the instance
+        run_type_checks (bool, optional): whether to run type checks on the output of the env's step

     Methods:
         step (TensorDict -> TensorDict): step in the environment
@@ -116,9 +117,13 @@ def __init__(
         device: DEVICE_TYPING = "cpu",
         dtype: Optional[Union[torch.dtype, np.dtype]] = None,
         batch_size: Optional[torch.Size] = None,
+        run_type_checks: bool = False,
     ):
         super(ModelBasedEnvBase, self).__init__(
-            device=device, dtype=dtype, batch_size=batch_size
+            device=device,
+            dtype=dtype,
+            batch_size=batch_size,
+            run_type_checks=run_type_checks,
         )
         self.world_model = world_model.to(self.device)
         self.world_model_params = params
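For the model-based base class the same flag is forwarded to EnvBase.__init__ but defaults to False. A minimal sketch under the same caveats (a made-up _DummyMBEnv subclass, an Identity module standing in for a real world model, abstract hooks stubbed only so the constructor runs):

import torch

from torchrl.envs.model_based.common import ModelBasedEnvBase

class _DummyMBEnv(ModelBasedEnvBase):
    # Hypothetical subclass for illustration only.
    def _reset(self, tensordict=None, **kwargs):
        raise NotImplementedError
    def _set_seed(self, seed):
        raise NotImplementedError

world_model = torch.nn.Identity()  # placeholder for a real world-model module
env = _DummyMBEnv(world_model)     # type checks are off by default here
assert env.run_type_checks is False
env = _DummyMBEnv(world_model, run_type_checks=True)  # opt in explicitly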