
@vmoens (Collaborator) commented Jul 25, 2022

R3M integration in torchrl

As soon as I'm granted access to the domain's S3 bucket, I'll upload the ResNet weights.

Here's a code snippet to get started:

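```python
import torch
from torchrl.envs.transforms import R3MTransform
from torchrl.data import TensorDict

# Build the transform once; it holds the ResNet backbone.
transform = R3MTransform(model_name="resnet50", keys_in=["next_observation"])

# The same transform handles unbatched inputs and any number of batch dims.
for shape in ([], [2], [2, 3]):
    # Random uint8 images in channel-last (H, W, C) layout, values in [0, 255].
    td = TensorDict(
        {"next_observation": torch.randint(256, (*shape, 224, 224, 3), dtype=torch.uint8)},
        shape,
    )
    transform(td)  # computes the R3M embedding in place
    print(td)
```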
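The loop above runs one transform on inputs with batch shapes `[]`, `[2]` and `[2, 3]`. A common way to support arbitrary leading batch dimensions is to flatten them into a single batch before the CNN and restore them on the output. The stdlib-only sketch below does just that shape bookkeeping; `r3m_shapes` and `EMBED_DIM` are hypothetical, and it assumes channel-last inputs that reach the network at their original size and an embedding equal to the ResNet's penultimate-layer width (512 for resnet18/resnet34, 2048 for resnet50).

```python
from math import prod
from typing import Sequence, Tuple

# Assumed embedding widths: the penultimate-layer (pre-fc) feature size of
# the torchvision ResNets that R3M builds on.
EMBED_DIM = {"resnet18": 512, "resnet34": 512, "resnet50": 2048}


def r3m_shapes(obs_shape: Sequence[int], model_name: str) -> Tuple[Tuple[int, ...], Tuple[int, ...]]:
    """Hypothetical helper: map a (*batch, H, W, C) image shape to
    (CNN input shape, output embedding shape)."""
    if len(obs_shape) < 3:
        raise ValueError("expected a (*batch, H, W, C) shape")
    *batch, h, w, c = obs_shape
    n = prod(batch)  # prod([]) == 1: an unbatched image becomes a batch of one
    cnn_input = (n, c, h, w)  # leading dims flattened, channels moved first
    embedding = (*batch, EMBED_DIM[model_name])  # leading dims restored
    return cnn_input, embedding
```

For the `[2, 3]` case of the snippet, `r3m_shapes((2, 3, 224, 224, 3), "resnet50")` returns `((6, 3, 224, 224), (2, 3, 2048))`: six images go through the network and the embeddings are reshaped back to the original `[2, 3]` batch.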

Downloading weights

Pretrained weights can be downloaded by passing `download=True`:

```python
transform = R3MTransform(model_name="resnet50", keys_in=["next_observation"], download=True)
```
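How the transform fetches and stores the weights is not described here. As an illustration only, the stdlib sketch below shows the usual download-once-then-cache pattern that a `download=True` flag tends to rely on; `fetch_weights`, its arguments and the cache layout are hypothetical and not part of TorchRL.

```python
import urllib.request
from pathlib import Path


def fetch_weights(url: str, cache_dir: str, filename: str) -> Path:
    """Hypothetical helper: download url to cache_dir/filename once, then reuse it."""
    path = Path(cache_dir).expanduser() / filename
    if not path.exists():
        path.parent.mkdir(parents=True, exist_ok=True)
        # Download to a temporary name first so an interrupted transfer
        # never leaves a truncated file at the final path.
        tmp = path.with_name(path.name + ".part")
        urllib.request.urlretrieve(url, str(tmp))
        tmp.replace(path)
    return path
```

A second call with the same `cache_dir` and `filename` finds the cached file and skips the network entirely, which matters when many workers start at once.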

Benefits:

  • Faster execution
  • Possibility of executing R3M on a large number of workers (the current implementation runs into conflicts above roughly 8 workers)
  • Possibility of downloading specific versions of R3M
  • Makes it easy to execute R3M on images stored in the replay buffer (if needed)
  • Fully customisable (e.g. can return both the R3M embedding and the original image, and works with any environment that returns images, not only Gym environments)
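The "fully customisable" point comes down to the transform reading from some keys, writing to others, and optionally keeping its inputs. The sketch below shows that bookkeeping on a plain dict; `apply_embedding`, `keys_out`, `keep_inputs` and the `_embedding` suffix are hypothetical names rather than the `R3MTransform` signature, and `embed` stands in for the ResNet.

```python
from typing import Callable, Dict, List, Optional


def apply_embedding(
    data: Dict[str, object],
    keys_in: List[str],
    embed: Callable[[object], object],
    keys_out: Optional[List[str]] = None,
    keep_inputs: bool = False,
) -> Dict[str, object]:
    """Hypothetical helper: store embed(data[k_in]) under k_out, in place."""
    if keys_out is None:
        keys_out = [f"{key}_embedding" for key in keys_in]
    if len(keys_out) != len(keys_in):
        raise ValueError("keys_in and keys_out must have the same length")
    for key_in, key_out in zip(keys_in, keys_out):
        data[key_out] = embed(data[key_in])
        # Drop the raw input unless asked to keep it (or it was overwritten).
        if not keep_inputs and key_in != key_out:
            del data[key_in]
    return data
```

With `embed=sum` and `keep_inputs=True`, `{"next_pixels": [1, 2, 3]}` becomes `{"next_pixels": [1, 2, 3], "next_pixels_embedding": 6}`; since the function only touches dict keys, the same call applies equally to a fresh observation or to a batch sampled from a replay buffer.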

Efficiency

Compared with the old implementation, the following code takes 1.4 s per batch of 4 rollouts on CUDA with the TorchRL version, versus 5.8 s with the mj_envs implementation.

```python
import time

import torch
from torchrl.envs import ParallelEnv, R3MTransform, TransformedEnv
from torchrl.trainers.helpers.envs import LIBS
from utils import MJEnv  # local TorchRL wrapper for mj_envs environments

# Register MJEnv in the LIBS mapping used by the trainer env helpers.
LIBS["mjenv"] = MJEnv

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

if __name__ == "__main__":
    # TorchRL version: R3M runs as a transform on top of 16 pixel-based envs.
    transform = R3MTransform(model_name="resnet18", keys_in=["next_pixels"])
    env1 = TransformedEnv(
        ParallelEnv(
            16,
            lambda: MJEnv("kitchen_micro_open-v3", from_pixels=True, device=device),
        ),
        transform=transform.eval(),
    )
    assert env1.device == device
    # assert not env1.transform.training
    # assert not env1.transform[-1].convnet.training
    del transform

    env1.reset()
    # Time two consecutive rollouts of 20 steps each.
    for _ in range(2):
        t0 = time.time()
        print(env1.rollout(max_steps=20))
        print(time.time() - t0)
    env1.close()
    del env1

    # Baseline: the mj_envs implementation.
    env2 = ParallelEnv(16, lambda: MJEnv("kitchen_micro_open-v4", device=device))
    assert env2.device == device
    env2.reset()
    for _ in range(2):
        t0 = time.time()
        print(env2.rollout(max_steps=20))
        print(time.time() - t0)
    env2.close()
    del env2
```

(MJEnv is the TorchRL wrapper for mj_envs environments.)
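The reported timings amount to roughly a 4× speedup (5.8 s / 1.4 s ≈ 4.1). The benchmark times two consecutive rollouts per environment; one plausible reason to look past the first is that it can include one-off costs such as CUDA initialisation. The hypothetical helper below makes that split explicit, running untimed warm-up calls before the timed ones and using `time.perf_counter`, which is intended for measuring short intervals.

```python
import time
from typing import Callable, List


def time_calls(fn: Callable[[], object], n_warmup: int = 1, n_repeat: int = 3) -> List[float]:
    """Hypothetical helper: run fn n_warmup times untimed, then return the
    wall-clock duration (in seconds) of each of n_repeat timed calls."""
    for _ in range(n_warmup):
        fn()
    durations = []
    for _ in range(n_repeat):
        t0 = time.perf_counter()
        fn()
        durations.append(time.perf_counter() - t0)
    return durations
```

For example, `time_calls(lambda: env1.rollout(max_steps=20))` would time three rollouts after one warm-up. Because CUDA kernels run asynchronously, if the timed work ends on the GPU, calling `torch.cuda.synchronize()` at the end of `fn` keeps pending kernels from leaking into the next measurement.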

TODO:

  • store weights on AWS
  • write tests

cc @vikashplus @suraj-nair-1

@facebook-github-bot added the CLA Signed label Jul 25, 2022
@vmoens added the enhancement label Jul 25, 2022
@vmoens requested a review from vikashplus August 8, 2022 17:03
@vmoens changed the title from [WIP]: R3M integration to [Feature]: R3M integration Aug 31, 2022
@vmoens marked this pull request as ready for review August 31, 2022 18:22
@vmoens merged commit a61c8a5 into main Aug 31, 2022
@vmoens deleted the r3m_integration branch August 31, 2022 18:22