Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 31 additions & 1 deletion test/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,13 @@
MultOneHotDiscreteTensorSpec,
BoundedTensorSpec,
NdBoundedTensorSpec,
UnboundedContinuousTensorSpec,
)
from torchrl.data.tensordict.tensordict import assert_allclose_td, TensorDict
from torchrl.envs import EnvCreator, ObservationNorm, CatTensors, DoubleToFloat
from torchrl.envs.gym_like import default_info_dict_reader
from torchrl.envs.libs.dm_control import _has_dmc, DMControlEnv
from torchrl.envs.libs.gym import _has_gym, GymEnv
from torchrl.envs.libs.gym import GymEnv, GymWrapper, _has_gym
from torchrl.envs.transforms import (
TransformedEnv,
Compose,
Expand Down Expand Up @@ -1121,6 +1123,34 @@ def test_batch_unlocked_with_batch_size(device):
env.step(td_expanded)


@pytest.mark.skipif(not _has_gym, reason="no gym")
def test_info_dict_reader(seed=0):
import gym

env = GymWrapper(gym.make("HalfCheetah-v4"))
env.set_info_dict_reader(default_info_dict_reader(["x_position"]))

assert "x_position" in env.observation_spec.keys()
assert isinstance(env.observation_spec["x_position"], UnboundedContinuousTensorSpec)
Comment on lines +1133 to +1134
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is there something else I could be checking here?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you could do assert env.observation_spec["x_position"].is_in(tensordict["x_position"])

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's essentially testing that x_position is what it's supposed to be
Another thing we could do is set up a deliberately broken tensor spec (e.g. one that has the wrong shape or the wrong dtype) and check that this assertion raises an error

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just added some additional tests as you suggested, although it wasn't raising an error at any stage, just returning false for

env.observation_spec["x_position"].is_in(tensordict["x_position"])

Would you have expected an exception at some point?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no, but yes if you put an assert in front of it :p


tensordict = env.reset()
tensordict = env.rand_step(tensordict)

assert env.observation_spec["x_position"].is_in(tensordict["x_position"])

env2 = GymWrapper(gym.make("HalfCheetah-v4"))
env2.set_info_dict_reader(
default_info_dict_reader(
["x_position"], spec={"x_position": OneHotDiscreteTensorSpec(5)}
)
)

tensordict2 = env2.reset()
tensordict2 = env2.rand_step(tensordict2)

assert not env2.observation_spec["x_position"].is_in(tensordict2["x_position"])


# Allow running this test module directly as a script: any command-line
# arguments that argparse does not recognise are forwarded to pytest.
if __name__ == "__main__":
    args, unknown = argparse.ArgumentParser().parse_known_args()
    pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
53 changes: 48 additions & 5 deletions torchrl/envs/gym_like.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,38 @@

from __future__ import annotations

import abc
import warnings
from typing import Optional, Union, Tuple, Any, Dict
from typing import List, Optional, Sequence, Union, Tuple, Any, Dict

import numpy as np
import torch

from torchrl.data import TensorDict
from torchrl.data.tensor_specs import TensorSpec, UnboundedContinuousTensorSpec
from torchrl.data.tensordict.tensordict import TensorDictBase
from torchrl.envs.common import _EnvWrapper

__all__ = ["GymLikeEnv", "default_info_dict_reader"]


class default_info_dict_reader:
class BaseInfoDictReader(metaclass=abc.ABCMeta):
    """Abstract base class for info-dict readers.

    A concrete reader must be callable with an ``info_dict`` and a
    tensordict, and must expose an ``info_spec`` property mapping keys to
    their ``TensorSpec``.
    """

    @abc.abstractmethod
    def __call__(
        self, info_dict: Dict[str, Any], tensordict: TensorDictBase
    ) -> TensorDictBase:
        """Process ``info_dict`` together with ``tensordict``.

        Args:
            info_dict: dictionary with string keys and arbitrary values.
            tensordict: tensordict passed alongside the info dictionary.

        Returns:
            A tensordict.
        """
        raise NotImplementedError

    # ``abc.abstractproperty`` has been deprecated since Python 3.3; stacking
    # ``property`` on top of ``abc.abstractmethod`` is the supported spelling
    # and keeps the attribute abstract for subclasses.
    @property
    @abc.abstractmethod
    def info_spec(self) -> Dict[str, TensorSpec]:
        """Mapping from key name to the ``TensorSpec`` associated with it."""
        raise NotImplementedError


class default_info_dict_reader(BaseInfoDictReader):
"""
Default info-key reader.

Expand All @@ -39,11 +57,30 @@ class default_info_dict_reader:

"""

def __init__(self, keys=None):
def __init__(
self,
keys: List[str] = None,
spec: Union[Sequence[TensorSpec], Dict[str, TensorSpec]] = None,
):
if keys is None:
keys = []
self.keys = keys

if isinstance(spec, Sequence):
if len(spec) != len(self.keys):
raise ValueError(
"If specifying specs for info keys with a sequence, the "
"length of the sequence must match the number of keys"
)
self._info_spec = dict(zip(self.keys, spec))
else:
if spec is None:
spec = {}

self._info_spec = {
key: spec.get(key, UnboundedContinuousTensorSpec()) for key in self.keys
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

MyPy gives me a warning here because TensorSpec has abstract method index which UnboundedContinuousTensorSpec does not define.

}

def __call__(
self, info_dict: Dict[str, Any], tensordict: TensorDictBase
) -> TensorDictBase:
Expand All @@ -57,9 +94,13 @@ def __call__(
tensordict[key] = info_dict[key]
return tensordict

@property
def info_spec(self) -> Dict[str, TensorSpec]:
    """Return the mapping from info key to ``TensorSpec`` stored on this reader."""
    return self._info_spec


class GymLikeEnv(_EnvWrapper):
_info_dict_reader: callable
_info_dict_reader: BaseInfoDictReader

"""
A gym-like env is an environment whose behaviour is similar to gym environments in what
Expand Down Expand Up @@ -216,7 +257,7 @@ def _output_transform(self, step_outputs_tuple: Tuple) -> Tuple:
)
return step_outputs_tuple

def set_info_dict_reader(self, info_dict_reader: callable) -> GymLikeEnv:
def set_info_dict_reader(self, info_dict_reader: BaseInfoDictReader) -> GymLikeEnv:
"""
Sets an info_dict_reader function. This function should take as input an
info_dict dictionary and the tensordict returned by the step function, and
Expand All @@ -240,6 +281,8 @@ def set_info_dict_reader(self, info_dict_reader: callable) -> GymLikeEnv:

"""
self.info_dict_reader = info_dict_reader
for info_key, spec in info_dict_reader.info_spec.items():
self.observation_spec[info_key] = spec
return self

def __repr__(self) -> str:
Expand Down