Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .circleci/unittest/linux/scripts/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,5 @@ dependencies:
- mujoco_py
- hydra-core
- tensorboard
- wandb
- dm_control
1 change: 1 addition & 0 deletions .circleci/unittest/linux_stable/scripts/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,5 @@ dependencies:
- mujoco_py
- hydra-core
- tensorboard
- wandb
- dm_control
18 changes: 14 additions & 4 deletions examples/ddpg/ddpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,12 @@
transformed_env_constructor,
EnvConfig,
)
from torchrl.trainers.helpers.logger import LoggerConfig
from torchrl.trainers.helpers.losses import make_ddpg_loss, LossConfig
from torchrl.trainers.helpers.models import (
make_ddpg_actor,
DDPGModelConfig,
)
from torchrl.trainers.helpers.recorder import RecorderConfig
from torchrl.trainers.helpers.replay_buffer import (
make_replay_buffer,
ReplayArgsConfig,
Expand All @@ -46,7 +46,7 @@
EnvConfig,
LossConfig,
DDPGModelConfig,
RecorderConfig,
LoggerConfig,
ReplayArgsConfig,
)
for config_field in dataclasses.fields(config_cls)
Expand All @@ -68,7 +68,6 @@

@hydra.main(version_base=None, config_path=".", config_name="config")
def main(cfg: "DictConfig"):
from torchrl.trainers.loggers.tensorboard import TensorboardLogger

cfg = correct_for_frame_skip(cfg)

Expand All @@ -89,7 +88,18 @@ def main(cfg: "DictConfig"):
datetime.now().strftime("%y_%m_%d-%H_%M_%S"),
]
)
logger = TensorboardLogger(f"ddpg_logging/{exp_name}")
if cfg.logger == "tensorboard":
from torchrl.trainers.loggers.tensorboard import TensorboardLogger

logger = TensorboardLogger(log_dir="ddpg_logging", exp_name=exp_name)
elif cfg.logger == "csv":
from torchrl.trainers.loggers.csv import CSVLogger

logger = CSVLogger(log_dir="ddpg_logging", exp_name=exp_name)
elif cfg.logger == "wandb":
from torchrl.trainers.loggers.wandb import WandbLogger

logger = WandbLogger(log_dir="ddpg_logging", exp_name=exp_name)
video_tag = exp_name if cfg.record_video else ""

stats = None
Expand Down
20 changes: 15 additions & 5 deletions examples/dqn/dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,12 @@
transformed_env_constructor,
EnvConfig,
)
from torchrl.trainers.helpers.logger import LoggerConfig
from torchrl.trainers.helpers.losses import make_dqn_loss, LossConfig
from torchrl.trainers.helpers.models import (
make_dqn_actor,
DiscreteModelConfig,
)
from torchrl.trainers.helpers.recorder import RecorderConfig
from torchrl.trainers.helpers.replay_buffer import (
make_replay_buffer,
ReplayArgsConfig,
Expand All @@ -46,7 +46,7 @@
EnvConfig,
LossConfig,
DiscreteModelConfig,
RecorderConfig,
LoggerConfig,
ReplayArgsConfig,
)
for config_field in dataclasses.fields(config_cls)
Expand All @@ -59,8 +59,6 @@
@hydra.main(version_base=None, config_path=".", config_name="config")
def main(cfg: "DictConfig"):

from torchrl.trainers.loggers.tensorboard import TensorboardLogger

cfg = correct_for_frame_skip(cfg)

if not isinstance(cfg.reward_scaling, float):
Expand All @@ -80,7 +78,19 @@ def main(cfg: "DictConfig"):
datetime.now().strftime("%y_%m_%d-%H_%M_%S"),
]
)
logger = TensorboardLogger(f"dqn_logging/{exp_name}")
if cfg.logger == "tensorboard":
from torchrl.trainers.loggers.tensorboard import TensorboardLogger

logger = TensorboardLogger(log_dir="dqn_logging", exp_name=exp_name)
elif cfg.logger == "csv":
from torchrl.trainers.loggers.csv import CSVLogger

logger = CSVLogger(log_dir="dqn_logging", exp_name=exp_name)
elif cfg.logger == "wandb":
from torchrl.trainers.loggers.wandb import WandbLogger

logger = WandbLogger(log_dir="dqn_logging", exp_name=exp_name)

video_tag = exp_name if cfg.record_video else ""

stats = None
Expand Down
18 changes: 14 additions & 4 deletions examples/ppo/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,12 @@
transformed_env_constructor,
EnvConfig,
)
from torchrl.trainers.helpers.logger import LoggerConfig
from torchrl.trainers.helpers.losses import make_ppo_loss, PPOLossConfig
from torchrl.trainers.helpers.models import (
make_ppo_model,
PPOModelConfig,
)
from torchrl.trainers.helpers.recorder import RecorderConfig
from torchrl.trainers.helpers.trainers import make_trainer, TrainerConfig

config_fields = [
Expand All @@ -42,7 +42,7 @@
EnvConfig,
PPOLossConfig,
PPOModelConfig,
RecorderConfig,
LoggerConfig,
)
for config_field in dataclasses.fields(config_cls)
]
Expand All @@ -54,7 +54,6 @@

@hydra.main(version_base=None, config_path=".", config_name="config")
def main(cfg: "DictConfig"):
from torchrl.trainers.loggers.tensorboard import TensorboardLogger

cfg = correct_for_frame_skip(cfg)

Expand All @@ -75,7 +74,18 @@ def main(cfg: "DictConfig"):
datetime.now().strftime("%y_%m_%d-%H_%M_%S"),
]
)
logger = TensorboardLogger(f"ppo_logging/{exp_name}")
if cfg.logger == "tensorboard":
from torchrl.trainers.loggers.tensorboard import TensorboardLogger

logger = TensorboardLogger(log_dir="ppo_logging", exp_name=exp_name)
elif cfg.logger == "csv":
from torchrl.trainers.loggers.csv import CSVLogger

logger = CSVLogger(log_dir="ppo_logging", exp_name=exp_name)
elif cfg.logger == "wandb":
from torchrl.trainers.loggers.wandb import WandbLogger

logger = WandbLogger(log_dir="ppo_logging", exp_name=exp_name)
video_tag = exp_name if cfg.record_video else ""

stats = None
Expand Down
18 changes: 14 additions & 4 deletions examples/redq/redq.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,12 @@
transformed_env_constructor,
EnvConfig,
)
from torchrl.trainers.helpers.logger import LoggerConfig
from torchrl.trainers.helpers.losses import make_redq_loss, LossConfig
from torchrl.trainers.helpers.models import (
make_redq_model,
REDQModelConfig,
)
from torchrl.trainers.helpers.recorder import RecorderConfig
from torchrl.trainers.helpers.replay_buffer import (
make_replay_buffer,
ReplayArgsConfig,
Expand All @@ -46,7 +46,7 @@
EnvConfig,
LossConfig,
REDQModelConfig,
RecorderConfig,
LoggerConfig,
ReplayArgsConfig,
)
for config_field in dataclasses.fields(config_cls)
Expand All @@ -69,7 +69,6 @@

@hydra.main(version_base=None, config_path=".", config_name="config")
def main(cfg: "DictConfig"):
from torchrl.trainers.loggers.tensorboard import TensorboardLogger

cfg = correct_for_frame_skip(cfg)

Expand All @@ -90,7 +89,18 @@ def main(cfg: "DictConfig"):
datetime.now().strftime("%y_%m_%d-%H_%M_%S"),
]
)
logger = TensorboardLogger(f"redq_logging/{exp_name}")
if cfg.logger == "tensorboard":
from torchrl.trainers.loggers.tensorboard import TensorboardLogger

logger = TensorboardLogger(log_dir="redq_logging", exp_name=exp_name)
elif cfg.logger == "csv":
from torchrl.trainers.loggers.csv import CSVLogger

logger = CSVLogger(log_dir="redq_logging", exp_name=exp_name)
elif cfg.logger == "wandb":
from torchrl.trainers.loggers.wandb import WandbLogger

logger = WandbLogger(log_dir="redq_logging", exp_name=exp_name)
video_tag = exp_name if cfg.record_video else ""

stats = None
Expand Down
18 changes: 14 additions & 4 deletions examples/sac/sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,12 @@
transformed_env_constructor,
EnvConfig,
)
from torchrl.trainers.helpers.logger import LoggerConfig
from torchrl.trainers.helpers.losses import make_sac_loss, LossConfig
from torchrl.trainers.helpers.models import (
make_sac_model,
SACModelConfig,
)
from torchrl.trainers.helpers.recorder import RecorderConfig
from torchrl.trainers.helpers.replay_buffer import (
make_replay_buffer,
ReplayArgsConfig,
Expand All @@ -46,7 +46,7 @@
EnvConfig,
LossConfig,
SACModelConfig,
RecorderConfig,
LoggerConfig,
ReplayArgsConfig,
)
for config_field in dataclasses.fields(config_cls)
Expand All @@ -69,7 +69,6 @@

@hydra.main(version_base=None, config_path=".", config_name="config")
def main(cfg: "DictConfig"):
from torchrl.trainers.loggers.tensorboard import TensorboardLogger

cfg = correct_for_frame_skip(cfg)

Expand All @@ -90,7 +89,18 @@ def main(cfg: "DictConfig"):
datetime.now().strftime("%y_%m_%d-%H_%M_%S"),
]
)
logger = TensorboardLogger(f"sac_logging/{exp_name}")
if cfg.logger == "tensorboard":
from torchrl.trainers.loggers.tensorboard import TensorboardLogger

logger = TensorboardLogger(log_dir="sac_logging", exp_name=exp_name)
elif cfg.logger == "csv":
from torchrl.trainers.loggers.csv import CSVLogger

logger = CSVLogger(log_dir="sac_logging", exp_name=exp_name)
elif cfg.logger == "wandb":
from torchrl.trainers.loggers.wandb import WandbLogger

logger = WandbLogger(log_dir="sac_logging", exp_name=exp_name)
video_tag = exp_name if cfg.record_video else ""

stats = None
Expand Down
100 changes: 100 additions & 0 deletions test/test_loggers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import argparse
import os.path
import tempfile
from time import sleep

import pytest
import torch
from torchrl.trainers.loggers.csv import CSVLogger
from torchrl.trainers.loggers.tensorboard import TensorboardLogger, _has_tb
from torchrl.trainers.loggers.wandb import WandbLogger, _has_wandb


@pytest.mark.skipif(not _has_tb, reason="TensorBoard not installed")
class TestTensorboard:
    """Scalar-logging checks for ``TensorboardLogger``."""

    @pytest.mark.parametrize("steps", [None, [1, 10, 11]])
    def test_log_scalar(self, steps):
        """Log three random scalars as "foo" and read them back from the event files.

        Args:
            steps: explicit step indices passed to ``log_scalar``, or ``None``
                to log without a step.
        """
        torch.manual_seed(0)
        with tempfile.TemporaryDirectory() as log_dir:
            logger = TensorboardLogger(log_dir=log_dir, exp_name="ramala")

            values = torch.rand(3)
            for idx in range(3):
                requested_step = steps[idx] if steps else None
                logger.log_scalar(
                    value=values[idx].item(),
                    name="foo",
                    step=requested_step,
                )

            sleep(0.01)  # wait until events are registered
            from tensorboard.backend.event_processing.event_accumulator import (
                EventAccumulator,
            )

            accumulator = EventAccumulator(logger.experiment.get_logdir())
            accumulator.Reload()
            assert len(accumulator.Scalars("foo")) == 3, str(accumulator.Scalars("foo"))
            for idx in range(3):
                recorded = accumulator.Scalars("foo")[idx]
                assert recorded.value == values[idx]
                if steps:
                    assert recorded.step == steps[idx]


class TestCSVLogger:
    """Scalar-logging checks for ``CSVLogger``."""

    @pytest.mark.parametrize("steps", [None, [1, 10, 11]])
    def test_log_scalar(self, steps):
        """Log three random scalars as "foo" and verify the resulting CSV rows.

        Args:
            steps: explicit step indices passed to ``log_scalar``, or ``None``
                to log without a step (rows are then expected to be numbered
                0, 1, 2).
        """
        torch.manual_seed(0)
        with tempfile.TemporaryDirectory() as log_dir:
            exp_name = "ramala"
            logger = CSVLogger(log_dir=log_dir, exp_name=exp_name)

            values = torch.rand(3)
            for i in range(3):
                logger.log_scalar(
                    value=values[i].item(),
                    name="foo",
                    step=steps[i] if steps else None,
                )

            csv_path = os.path.join(log_dir, exp_name, "scalars", "foo.csv")
            with open(csv_path, "r") as file:
                rows = file.readlines()

            # Without this check an empty or truncated file would make the
            # per-row loop below pass vacuously.
            assert len(rows) == 3, str(rows)
            for i, row in enumerate(rows):
                step = steps[i] if steps else i
                assert row == f"{step},{values[i].item()}\n"


@pytest.mark.skipif(not _has_wandb, reason="Wandb not installed")
class TestWandbLogger:
    """Scalar-logging checks for ``WandbLogger`` (constructed with ``offline=True``)."""

    @pytest.mark.parametrize("steps", [None, [1, 10, 11]])
    def test_log_scalar(self, steps):
        """Log three random scalars as "foo" and inspect the run summary.

        Args:
            steps: explicit step indices passed to ``log_scalar``, or ``None``
                to log without a step.
        """
        torch.manual_seed(0)
        with tempfile.TemporaryDirectory() as log_dir:
            exp_name = "ramala"
            logger = WandbLogger(log_dir=log_dir, exp_name=exp_name, offline=True)

            try:
                values = torch.rand(3)
                for i in range(3):
                    logger.log_scalar(
                        value=values[i].item(),
                        name="foo",
                        step=steps[i] if steps else None,
                    )

                assert logger.experiment.summary["foo"] == values[-1].item()
                # Compute the expectation first: writing
                # ``assert a == i if not steps else steps[i]`` parses as
                # ``(a == i) if not steps else steps[i]`` and would merely
                # assert a truthy integer when steps are given.
                expected_step = steps[-1] if steps else len(values) - 1
                assert logger.experiment.summary["_step"] == expected_step
            finally:
                # Close the run even if an assertion fails so it does not
                # stay open for the tests that follow.
                logger.experiment.finish()
                del logger


if __name__ == "__main__":
    # Arguments argparse does not recognise are forwarded verbatim to pytest.
    _, passthrough = argparse.ArgumentParser().parse_known_args()
    pytest.main([__file__, "--capture", "no", "--exitfirst", *passthrough])
2 changes: 1 addition & 1 deletion torchrl/trainers/helpers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,5 @@
from .envs import *
from .losses import *
from .models import *
from .recorder import *
from .logger import *
from .replay_buffer import *
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@


@dataclass
class RecorderConfig:
class LoggerConfig:
logger: str = "csv"
# logger type to be used. One of 'tensorboard', 'wandb' or 'csv'
record_video: bool = False
# whether a video of the task should be rendered during logging.
no_video: bool = True
Expand Down
3 changes: 2 additions & 1 deletion torchrl/trainers/loggers/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@ class Logger:

"""

def __init__(self, exp_name: str) -> None:
def __init__(self, exp_name: str, log_dir: str) -> None:
self.exp_name = exp_name
self.log_dir = log_dir
self.experiment = self._create_experiment()

@abc.abstractmethod
Expand Down
Loading