Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -13784,7 +13784,7 @@ def policy(td):
assert r1["before_count"].max() == 18
assert r1["after_count"].max() == 6
finally:
env.close()
env.close(raise_if_closed=False)

@pytest.mark.parametrize("bwad", [False, True])
def test_serial_trans_env_check(self, bwad):
Expand Down
84 changes: 71 additions & 13 deletions torchrl/envs/batched_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,11 @@
clear_mpi_env_vars,
)

_CONSOLIDATE_ERR_CAPTURE = (
"TensorDict.consolidate failed. You can deactivate the tensordict consolidation via the "
"`consolidate` keyword argument of the ParallelEnv constructor."
)


def _check_start(fun):
def decorated_fun(self: BatchedEnvBase, *args, **kwargs):
Expand Down Expand Up @@ -307,6 +312,7 @@ def __init__(
non_blocking: bool = False,
mp_start_method: str = None,
use_buffers: bool = None,
consolidate: bool = True,
):
super().__init__(device=device)
self.serial_for_single = serial_for_single
Expand All @@ -315,6 +321,7 @@ def __init__(
self.num_threads = num_threads
self._cache_in_keys = None
self._use_buffers = use_buffers
self.consolidate = consolidate

self._single_task = callable(create_env_fn) or (len(set(create_env_fn)) == 1)
if callable(create_env_fn):
Expand Down Expand Up @@ -841,9 +848,12 @@ def __repr__(self) -> str:
f"\n\tbatch_size={self.batch_size})"
)

def close(self) -> None:
def close(self, *, raise_if_closed: bool = True) -> None:
if self.is_closed:
raise RuntimeError("trying to close a closed environment")
if raise_if_closed:
raise RuntimeError("trying to close a closed environment")
else:
return
if self._verbose:
torchrl_logger.info(f"closing {self.__class__.__name__}")

Expand Down Expand Up @@ -1470,6 +1480,12 @@ def look_for_cuda(tensor, has_cuda=has_cuda):
"_non_tensor_keys": self._non_tensor_keys,
}
)
else:
kwargs[idx].update(
{
"consolidate": self.consolidate,
}
)
process = proc_fun(target=func, kwargs=kwargs[idx])
process.daemon = True
process.start()
Expand Down Expand Up @@ -1526,7 +1542,16 @@ def _step_and_maybe_reset_no_buffers(
else:
workers_range = range(self.num_workers)

td = tensordict.consolidate(share_memory=True, inplace=True, num_threads=1)
if self.consolidate:
try:
td = tensordict.consolidate(
share_memory=True, inplace=True, num_threads=1
)
except Exception as err:
raise RuntimeError(_CONSOLIDATE_ERR_CAPTURE) from err
else:
td = tensordict

for i in workers_range:
# We send the same td multiple times as it is in shared mem and we just need to index it
# in each process.
Expand Down Expand Up @@ -1804,7 +1829,16 @@ def _step_no_buffers(
else:
workers_range = range(self.num_workers)

data = tensordict.consolidate(share_memory=True, inplace=True, num_threads=1)
if self.consolidate:
try:
data = tensordict.consolidate(
share_memory=True, inplace=True, num_threads=1
)
except Exception as err:
raise RuntimeError(_CONSOLIDATE_ERR_CAPTURE) from err
else:
data = tensordict

for i, local_data in zip(workers_range, data.unbind(0)):
self.parent_channels[i].send(("step", local_data))
# for i in range(data.shape[0]):
Expand Down Expand Up @@ -2026,9 +2060,14 @@ def _reset_no_buffers(
) -> Tuple[TensorDictBase, TensorDictBase]:
if is_tensor_collection(tensordict):
# tensordict = tensordict.consolidate(share_memory=True, num_threads=1)
tensordict = tensordict.consolidate(
share_memory=True, num_threads=1
).unbind(0)
if self.consolidate:
try:
tensordict = tensordict.consolidate(
share_memory=True, num_threads=1
)
except Exception as err:
raise RuntimeError(_CONSOLIDATE_ERR_CAPTURE) from err
tensordict = tensordict.unbind(0)
else:
tensordict = [None] * self.num_workers
out_tds = [None] * self.num_workers
Expand Down Expand Up @@ -2545,6 +2584,7 @@ def _run_worker_pipe_direct(
has_lazy_inputs: bool = False,
verbose: bool = False,
num_threads: int | None = None, # for fork start method
consolidate: bool = True,
) -> None:
if num_threads is not None:
torch.set_num_threads(num_threads)
Expand Down Expand Up @@ -2634,9 +2674,18 @@ def _run_worker_pipe_direct(
event.record()
event.synchronize()
mp_event.set()
child_pipe.send(
cur_td.consolidate(share_memory=True, inplace=True, num_threads=1)
)
if consolidate:
try:
child_pipe.send(
cur_td.consolidate(
share_memory=True, inplace=True, num_threads=1
)
)
except Exception as err:
raise RuntimeError(_CONSOLIDATE_ERR_CAPTURE) from err
else:
child_pipe.send(cur_td)

del cur_td

elif cmd == "step":
Expand All @@ -2650,9 +2699,18 @@ def _run_worker_pipe_direct(
event.record()
event.synchronize()
mp_event.set()
child_pipe.send(
next_td.consolidate(share_memory=True, inplace=True, num_threads=1)
)
if consolidate:
try:
child_pipe.send(
next_td.consolidate(
share_memory=True, inplace=True, num_threads=1
)
)
except Exception as err:
raise RuntimeError(_CONSOLIDATE_ERR_CAPTURE) from err
else:
child_pipe.send(next_td)

del next_td

elif cmd == "step_and_maybe_reset":
Expand Down
4 changes: 2 additions & 2 deletions torchrl/envs/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3651,7 +3651,7 @@ def _select_observation_keys(self, tensordict: TensorDictBase) -> Iterator[str]:
if key.rfind("observation") >= 0:
yield key

def close(self):
def close(self, *, raise_if_closed: bool = True):
self.is_closed = True

def __del__(self):
Expand Down Expand Up @@ -3843,7 +3843,7 @@ def _build_env(self, **kwargs) -> "gym.Env": # noqa: F821
def _make_specs(self, env: "gym.Env") -> None: # noqa: F821
raise NotImplementedError

def close(self) -> None:
def close(self, *, raise_if_closed: bool = True) -> None:
"""Closes the contained environment if possible."""
self.is_closed = True
try:
Expand Down
2 changes: 1 addition & 1 deletion torchrl/envs/libs/pettingzoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -837,7 +837,7 @@ def _update_agent_mask(self, td):
if agent not in agents_acting:
group_mask[index] = False

def close(self) -> None:
def close(self, *, raise_if_closed: bool = True) -> None:
self._env.close()


Expand Down
2 changes: 1 addition & 1 deletion torchrl/envs/libs/smacv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,7 @@ def _update_action_mask(self):
self.action_spec.update_mask(mask)
return mask

def close(self):
def close(self, *, raise_if_closed: bool = True):
# Closes StarCraft II
self._env.close()

Expand Down
2 changes: 1 addition & 1 deletion torchrl/envs/libs/unity_mlagents.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,7 +497,7 @@ def _reset(
self._env.reset()
return self._make_td_out(tensordict, is_reset=True)

def close(self):
def close(self, *, raise_if_closed: bool = True):
self._env.close()

@_classproperty
Expand Down
6 changes: 3 additions & 3 deletions torchrl/envs/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1026,7 +1026,7 @@ def _reset(self, tensordict: Optional[TensorDictBase] = None, **kwargs):
tensordict = tensordict.select(
*self.reset_keys, *self.state_spec.keys(True, True), strict=False
)
tensordict = self.transform._reset_env_preprocess(tensordict)
tensordict = self.transform._reset_env_preprocess(tensordict)
tensordict_reset = self.base_env._reset(tensordict, **kwargs)
if tensordict is None:
# make sure all transforms see a source tensordict
Expand Down Expand Up @@ -1083,8 +1083,8 @@ def is_closed(self) -> bool:
def is_closed(self, value: bool):
self.base_env.is_closed = value

def close(self):
self.base_env.close()
def close(self, *, raise_if_closed: bool = True):
self.base_env.close(raise_if_closed=raise_if_closed)
self.is_closed = True

def empty_cache(self):
Expand Down
Loading