5 changes: 5 additions & 0 deletions CHANGELOG.md
@@ -12,6 +12,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `IterableDatasetDict`, a version of `DatasetDict` for streaming-like datasets.
- Added a [PyTorch Lightning](https://www.pytorchlightning.ai) integration with `LightningTrainStep`.

### Fixed

- Fixed a bug with `FromParams` and `Lazy` where extra arguments would sometimes be passed
  through to a `Lazy` class when they shouldn't be.

## [v0.2.4](https://github.com/allenai/tango/releases/tag/v0.2.4) - 2021-10-22

### Added
3 changes: 1 addition & 2 deletions tango/common/from_params.py
@@ -492,8 +492,7 @@ def construct_arg(
return default

value_cls = args[0]
subextras = create_extras(value_cls, extras)
return Lazy(value_cls, params=deepcopy(popped_params), constructor_extras=subextras) # type: ignore
return Lazy(value_cls, params=deepcopy(popped_params)) # type: ignore

# For any other kind of iterable, we will just assume that a list is good enough, and treat
# it the same as List. This condition needs to be at the end, so we don't catch other kinds
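For readers of this change: the `Lazy` created by `construct_arg` now carries only the popped params, so any extra constructor arguments must be supplied explicitly when the `Lazy` is constructed. A minimal sketch of the resulting behavior (not part of this PR; the class names are made up, and it assumes `FromParams`, `Lazy`, and `Params` are importable from `tango.common`):

```python
from tango.common import FromParams, Lazy, Params


class Inner(FromParams):
    def __init__(self, size: int = 1, work_dir: str = "."):
        self.size = size
        self.work_dir = work_dir


class Outer(FromParams):
    def __init__(self, inner: Lazy[Inner]):
        self.inner = inner


# The Lazy wrapper holds only the params that were popped for "inner" ...
outer = Outer.from_params(Params({"inner": {"size": 4}}))

# ... and the caller decides which extra constructor arguments to pass at
# construction time, rather than having them forwarded implicitly.
inner = outer.inner.construct(work_dir="/tmp/inner")
assert (inner.size, inner.work_dir) == (4, "/tmp/inner")
```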
2 changes: 2 additions & 0 deletions tango/integrations/pytorch_lightning/__init__.py
@@ -132,6 +132,7 @@ def run(self) -> DatasetDict:
"LightningLogger",
"LightningModule",
"LightningProfiler",
"LightningPlugin",
"LightningTrainStep",
"LightningTrainer",
]
@@ -140,5 +141,6 @@ def run(self) -> DatasetDict:
from .callbacks import LightningCallback
from .loggers import LightningLogger
from .model import LightningModule
from .plugins import LightningPlugin
from .profilers import LightningProfiler
from .train import LightningTrainer, LightningTrainStep
20 changes: 20 additions & 0 deletions tango/integrations/pytorch_lightning/plugins.py
@@ -0,0 +1,20 @@
import pytorch_lightning as pl

from tango.common.registrable import Registrable


class LightningPlugin(pl.plugins.Plugin, Registrable):
"""
This is simply a :class:`~tango.common.registrable.Registrable`
version of the PyTorch Lightning :class:`~pytorch_lightning.plugins.Plugin` class.
"""


# Register all plugins.
for name, cls in pl.plugins.__dict__.items():
if (
isinstance(cls, type)
and issubclass(cls, pl.plugins.Plugin)
and not cls == pl.plugins.Plugin
):
LightningPlugin.register("pytorch_lightning::" + name)(cls)
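The loop above registers every concrete `Plugin` subclass exported by `pytorch_lightning.plugins` under a `pytorch_lightning::`-prefixed name. A small usage sketch (not part of this PR; it assumes Tango's `Registrable` exposes `list_available()` and `by_name()` as it does elsewhere in the library, and the exact plugin names depend on the installed pytorch_lightning version):

```python
import pytorch_lightning as pl

from tango.integrations.pytorch_lightning import LightningPlugin

# All plugin classes exported by pytorch_lightning.plugins should now be
# resolvable through the registry.
print(LightningPlugin.list_available())

# Look one of them up by its registered name (this particular name assumes a
# pytorch_lightning version that exports NativeMixedPrecisionPlugin).
plugin_cls = LightningPlugin.by_name("pytorch_lightning::NativeMixedPrecisionPlugin")
assert issubclass(plugin_cls, pl.plugins.Plugin)
```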
68 changes: 36 additions & 32 deletions tango/integrations/pytorch_lightning/train.py
@@ -1,3 +1,4 @@
from pathlib import Path
from typing import List, Optional, Union

import pytorch_lightning as pl
@@ -15,6 +16,7 @@
from .callbacks import LightningCallback
from .loggers import LightningLogger
from .model import LightningModule
from .plugins import LightningPlugin
from .profilers import LightningProfiler


@@ -24,6 +26,38 @@ class LightningTrainer(pl.Trainer, Registrable): # type: ignore
the PyTorch Lightning :class:`~pytorch_lightning.trainer.trainer.Trainer`.
"""

def __init__(
self,
work_dir: Path,
logger: Optional[Union[List[Lazy[LightningLogger]], Lazy[LightningLogger]]] = None,
callbacks: Optional[List[LightningCallback]] = None,
profiler: Optional[Union[str, Lazy[LightningProfiler]]] = None,
accelerator: Optional[Union[str, LightningAccelerator]] = None,
plugins: Optional[List[LightningPlugin]] = None,
**kwargs,
):
loggers: List[LightningLogger] = (
[]
if not logger
else [
logger_.construct(save_dir=work_dir)
for logger_ in (logger if isinstance(logger, list) else [logger])
]
)

profiler: Optional[Union[str, LightningProfiler]] = (
profiler.construct(dirpath=work_dir) if isinstance(profiler, Lazy) else profiler
)

super().__init__(
logger=loggers,
callbacks=callbacks,
profiler=profiler,
accelerator=accelerator,
plugins=plugins,
**kwargs,
)

def _to_params(self):
return {}

@@ -55,10 +89,6 @@ def run( # type: ignore[override]
*,
validation_dataloader: Lazy[DataLoader] = None,
validation_split: str = "validation",
loggers: Optional[List[Lazy[LightningLogger]]] = None,
callbacks: Optional[List[Lazy[LightningCallback]]] = None,
profiler: Optional[Union[str, Lazy[LightningProfiler]]] = None,
accelerator: Optional[Union[str, Lazy[LightningAccelerator]]] = None,
) -> torch.nn.Module:

"""
@@ -87,43 +117,17 @@ def run( # type: ignore[override]
:class:`dict` objects. If not specified, but ``validation_split`` is given,
the validation ``DataLoader`` will be constructed from the same parameters
as the train ``DataLoader``.
loggers: List[:class:`LightningLogger`]
A list of :class:`LightningLogger`.
callbacks: List[:class:`LightningCallback`]
A list of :class:`LightningCallback`.
profiler: Union[:class:`LightningProfiler`, :class:`str`], optional
:class:`LightningProfiler` object.
accelerator: Union[:class:`LightningAccelerator`, :class:`str`], optional
:class:`LightningAccelerator` object.

Returns
-------
:class:`LightningModule`
The trained model on CPU with the weights from the best checkpoint loaded.

"""
loggers: List[LightningLogger] = [
logger.construct(save_dir=self.work_dir) for logger in (loggers or [])
]

callbacks: List[LightningCallback] = [
callback.construct() for callback in (callbacks or [])
]

profiler: Optional[Union[str, LightningProfiler]] = (
profiler.construct(dirpath=self.work_dir) if isinstance(profiler, Lazy) else profiler
)

accelerator: Optional[Union[str, LightningAccelerator]] = (
accelerator.construct() if isinstance(accelerator, Lazy) else accelerator
)

trainer: LightningTrainer = trainer.construct(
logger=loggers, callbacks=callbacks, profiler=profiler, accelerator=accelerator
)
trainer: LightningTrainer = trainer.construct(work_dir=self.work_dir)

# Find the checkpoint callback and make sure it uses the right directory.
checkpoint_callback: pl.callbacks.model_checkpoint.ModelCheckpoint

for callback in trainer.callbacks:
if isinstance(callback, pl.callbacks.model_checkpoint.ModelCheckpoint):
callback.dirpath = self.work_dir
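With this refactor the trainer owns its loggers, accelerator, profiler, and plugins, and `LightningTrainStep.run()` only injects `work_dir` when constructing it, as the updated test fixture below shows. A rough sketch of building such a trainer directly from params (not part of this PR; it assumes `Params` is importable from `tango.common`, that `LightningTrainer` is registered under the name `"default"` as the fixture suggests, and that `from_params` forwards keyword extras such as `work_dir` to the constructor):

```python
from pathlib import Path

from tango.common import Params
from tango.integrations.pytorch_lightning import LightningTrainer

# Loggers and the profiler are parsed as Lazy objects and constructed inside
# LightningTrainer.__init__ with save_dir / dirpath pointed at work_dir.
trainer = LightningTrainer.from_params(
    Params(
        {
            "type": "default",
            "logger": [{"type": "pytorch_lightning::CSVLogger"}],
            "accelerator": "cpu",
            "max_epochs": 5,
        }
    ),
    work_dir=Path("workspace/train"),
)
```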
15 changes: 9 additions & 6 deletions test_fixtures/integrations/pytorch_lightning/train.jsonnet
@@ -11,12 +11,15 @@
"trainer": {
"type": "default",
"max_epochs": 5,
"log_every_n_steps": 3
},
"loggers": ["pytorch_lightning::TensorBoardLogger", "pytorch_lightning::CSVLogger"],
"accelerator": "cpu",
"profiler": {
"type": "pytorch_lightning::SimpleProfiler",
"log_every_n_steps": 3,
"logger": [
{"type": "pytorch_lightning::TensorBoardLogger"},
{"type": "pytorch_lightning::CSVLogger"},
],
"accelerator": "cpu",
"profiler": {
"type": "pytorch_lightning::SimpleProfiler",
},
},
"dataset_dict": {
"type": "ref",