
Conversation

@zeenolife
Contributor

Description

As described in #390, this PR addresses initialisation of TensorDicts from nested Python dictionaries.

Namely, it adds three new features:

  1. Nested dict initialisation, with inferred batch size and device:
TensorDict({"a": {"b": torch.randn(3, 1)}}, [3])

returns

TensorDict(
    fields={
        a: TensorDict(
            fields={
                b: Tensor(torch.Size([3, 1]), dtype=torch.float32)},
            batch_size=torch.Size([3]),
            device=cpu,
            is_shared=False)},
    batch_size=torch.Size([3]),
    device=cpu,
    is_shared=False)
  2. Setting a dict value, with inferred batch size and device:
td = TensorDict({"a": torch.randn(3, 1)}, [3])
td["b"] = {"c": torch.randn(3, 4)}

returns

TensorDict(
    fields={
        a: Tensor(torch.Size([3, 1]), dtype=torch.float32),
        b: TensorDict(
            fields={
                c: Tensor(torch.Size([3, 4]), dtype=torch.float32)},
            batch_size=torch.Size([3]),
            device=cpu,
            is_shared=False)},
    batch_size=torch.Size([3]),
    device=cpu,
    is_shared=False)
  3. Recursive conversion of a TensorDict into a dict:
td = TensorDict({"a": torch.randn(3, 1)}, [3])
td["b"] = {"c": torch.randn(3, 4)}
td.to_dict()

returns

{'a': tensor([[-2.0345],
        [ 0.8855],
        [-0.6279]]), 'b': {'c': tensor([[ 0.2649, -1.3553, -0.0903,  1.7265],
        [-0.0252,  1.1936, -0.2416,  0.1220],
        [ 0.2263,  0.6542,  0.4279,  0.2826]])}}

Motivation and Context

The motivation and context for the change are described in #390.

Types of changes

What types of changes does your code introduce? Remove all that do not apply:

  • New feature (non-breaking change which adds core functionality)

Checklist

Go over all the following points, and put an x in all the boxes that apply.
If you are unsure about any of these, don't hesitate to ask. We are here to help!

  • [x] I have read the CONTRIBUTION guide (required)
  • [ ] My change requires a change to the documentation.
  • [x] I have updated the tests accordingly (required for a bug fix or a new feature).
  • [ ] I have updated the documentation accordingly.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Sep 8, 2022
@zeenolife
Contributor Author

It seems the tests are currently failing due to the new version of gym; this is being handled in #403.

@zeenolife
Contributor Author

PTAL @vmoens

@vmoens vmoens changed the title [DRAFT] Adding support for initialising TensorDicts from nested dicts. Addressing #390 [DRAFT] Adding support for initialising TensorDicts from nested dicts Sep 8, 2022
@vmoens vmoens linked an issue Sep 8, 2022 that may be closed by this pull request
3 tasks
@vmoens vmoens added the enhancement New feature or request label Sep 8, 2022
@zeenolife zeenolife changed the title [DRAFT] Adding support for initialising TensorDicts from nested dicts [Feature] Adding support for initialising TensorDicts from nested dicts Sep 9, 2022
Collaborator

@vmoens vmoens left a comment


Almost there: It seems everything is coded and the CI is happy!
Can we test this kind of setting:

tensordict[:, :2] = {"a": torch.randn(3, 4, 5), ...}

One way to test that would be

sub_td = tensordict[:, :2].to_tensordict()  # clone the data to a new tensordict
sub_td.zero_()
sub_dict = sub_td.to_dict()
tensordict[:, :2] = sub_dict
# check that all values in tensordict[:, :2] are zero

)
raise err
else:
indexed_bs = _getitem_batch_size(self.batch_size, index)
Contributor Author


I've put the TensorDict casting here, because otherwise the batch size wouldn't be computed correctly for broadcasting.

- future
- cloudpickle
- gym
- gym==0.25.1
Collaborator


we should be able to remove this now

Contributor Author


The main branch currently has this fixed too; should I remove it here?

@zeenolife
Contributor Author

Note about the PR:

  • Ideally, we would want the _process_tensor() function to be the main and only "pre-processor" of the values. This would let us move all the logic out of the __setitem__ dunder method and into the .set*() methods. However, the .set*() methods of the TensorDictBase subclasses are currently inconsistent, which results in complicated inter-calls and errors. I propose creating a follow-up task to make all .set*() methods consistent, so that _process_tensor() becomes the single "pre-processor" of the input.
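To illustrate the proposed refactor, here is a toy sketch (hypothetical code, not the library's actual internals; the class and its helpers are stand-ins) of what routing every setter through a single pre-processor could look like:

```python
import torch

class TinyTensorDict:
    """Toy stand-in, NOT the real TensorDict: illustrates one pre-processor."""

    def __init__(self, batch_size):
        self._data = {}
        self.batch_size = torch.Size(batch_size)

    def _process_tensor(self, value):
        # Single entry point for all values: recurse into dicts and
        # validate that leading dims match the batch size.
        if isinstance(value, dict):
            nested = TinyTensorDict(self.batch_size)
            for k, v in value.items():
                nested.set(k, v)
            return nested
        value = torch.as_tensor(value)
        if value.shape[: len(self.batch_size)] != self.batch_size:
            raise RuntimeError("batch dimension mismatch")
        return value

    def set(self, key, value):
        self._data[key] = self._process_tensor(value)
        return self

    def __setitem__(self, key, value):
        # __setitem__ just delegates to set(); no duplicated preprocessing
        self.set(key, value)

td = TinyTensorDict([3])
td["a"] = {"b": torch.randn(3, 2)}  # dict routed through _process_tensor
td.set("c", torch.randn(3))
```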

@vmoens
Collaborator

vmoens commented Sep 12, 2022

  • Ideally, we would want the _process_tensor() function to be the main and only "pre-processor" of the values. This would let us move all the logic out of the __setitem__ dunder method and into the .set*() methods. However, the .set*() methods of the TensorDictBase subclasses are currently inconsistent, which results in complicated inter-calls and errors. I propose creating a follow-up task to make all .set*() methods consistent, so that _process_tensor() becomes the single "pre-processor" of the input.

We definitely need some hardcore cleanup in the calls from __setitem__ -> set / set_ / set_at_ -> _process_tensor!

@vmoens vmoens merged commit 59007c3 into pytorch:main Sep 12, 2022


Development

Successfully merging this pull request may close these issues.

[Feature Request] Creating tensordicts from nested dictionaries (and returning them)

3 participants