diff --git a/CHANGELOG.md b/CHANGELOG.md
index b68756ccd..54af7bbb4 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -11,6 +11,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 - Added support for global settings file, `tango.yml`.
 - Added 'include_package' (array of string) param to config spec.
+- Added a custom error `StopEarly` that a `TrainCallback` can raise within the `TorchTrainStep`
+  to stop training early without crashing.
 
 ### Fixed
 
diff --git a/docs/source/api/integrations/torch.rst b/docs/source/api/integrations/torch.rst
index 75dca3e6f..f9d9ac2f3 100644
--- a/docs/source/api/integrations/torch.rst
+++ b/docs/source/api/integrations/torch.rst
@@ -53,3 +53,6 @@ Callbacks
 
 .. autoclass:: tango.integrations.torch.TrainCallback
     :members:
+
+.. autoclass:: tango.integrations.torch.StopEarly
+    :members:
diff --git a/tango/integrations/torch/__init__.py b/tango/integrations/torch/__init__.py
index fd3a95056..4da90f587 100644
--- a/tango/integrations/torch/__init__.py
+++ b/tango/integrations/torch/__init__.py
@@ -106,6 +106,24 @@ def run(self) -> DatasetDict:
     Loading best weights from state_worker0_step100.pt
     ✓ Finished run for "train"
 
+Tips
+----
+
+Debugging
+~~~~~~~~~
+
+When debugging a training loop that's causing errors on a GPU, you should set the environment variable
+``CUDA_LAUNCH_BLOCKING=1``. This will ensure that the stack trace shows where the error actually happened.
+
+You could also use a custom :class:`TrainCallback` to log each batch before it is passed into the model
+so that you can see the exact inputs that are causing the issue.
+
+Stopping early
+~~~~~~~~~~~~~~
+
+You can stop the "torch::train" step early using a custom :class:`TrainCallback`. Your callback just
+needs to raise the :class:`StopEarly` exception.
+
 """
 
 __all__ = [
@@ -119,6 +137,7 @@ def run(self) -> DatasetDict:
     "Sampler",
     "ConcatTensorDictsCollator",
     "TrainCallback",
+    "StopEarly",
 ]
 
 from .data import DataLoader, Sampler, DataCollator, ConcatTensorDictsCollator
@@ -126,4 +145,4 @@ def run(self) -> DatasetDict:
 from .model import Model
 from .optim import Optimizer, LRScheduler
 from .train import TorchTrainStep
-from .train_callback import TrainCallback
+from .train_callback import TrainCallback, StopEarly
diff --git a/tango/integrations/torch/train.py b/tango/integrations/torch/train.py
index f99373441..026eef220 100644
--- a/tango/integrations/torch/train.py
+++ b/tango/integrations/torch/train.py
@@ -17,7 +17,7 @@
 from .format import TorchFormat
 from .model import Model
 from .optim import Optimizer, LRScheduler
-from .train_callback import TrainCallback
+from .train_callback import TrainCallback, StopEarly
 from tango.common.dataset_dict import DatasetDict
 from tango.common.exceptions import ConfigurationError
 from tango.common.lazy import Lazy
@@ -471,13 +471,14 @@ def _train(
     for callback in callbacks:
         callback.pre_train_loop()
 
-    with Tqdm.tqdm(
+    train_batch_iterator = Tqdm.tqdm(
         training_batches,
         desc="Training",
         initial=start_step,
         total=train_steps,
         disable=not is_local_main_process,
-    ) as train_batch_iterator:
+    )
+    try:
         for step, batch in train_batch_iterator:
 
             def is_best_checkpoint() -> bool:
@@ -708,6 +709,11 @@ def save_state():
             # Checkpoint.
             if should_checkpoint_this_step:
                 save_state()
+    except StopEarly:
+        if is_local_main_process:
+            print("Stopping early!")
+    finally:
+        train_batch_iterator.close()
 
     if is_distributed:
         dist.barrier()
diff --git a/tango/integrations/torch/train_callback.py b/tango/integrations/torch/train_callback.py
index 6d5a2f094..f59009631 100644
--- a/tango/integrations/torch/train_callback.py
+++ b/tango/integrations/torch/train_callback.py
@@ -6,6 +6,7 @@
 from .data import DataLoader
 from .model import Model
 from .optim import Optimizer, LRScheduler
+from tango.common.exceptions import TangoError
 from tango.common.registrable import Registrable
 
 
@@ -142,3 +143,13 @@ def post_val_loop(self, step: int, val_metric_name: str, val_metric: float) -> None:
         Called right after the validation loop finishes.
         """
         pass
+
+
+class StopEarly(TangoError):
+    """
+    Callbacks can raise this exception to stop training early without crashing.
+
+    .. important::
+        During distributed training, all workers must raise this exception at the same point
+        in the training loop; otherwise there will be a deadlock.
+    """
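
Two hedged sketches of the callbacks mentioned in the new ``Tips`` section follow, for anyone who wants to try the patch out. First, the batch-logging idea from the Debugging tip. This assumes ``TrainCallback`` exposes a per-batch hook along the lines of ``pre_batch(self, step, batch)`` (that hook is not shown in this patch, so check the actual ``TrainCallback`` definition), and the registered name ``"log_batches"`` is made up for the example.

```python
import logging

from tango.integrations.torch import TrainCallback

logger = logging.getLogger(__name__)


# "log_batches" is a hypothetical registered name, used here for illustration only.
@TrainCallback.register("log_batches")
class LogBatchesCallback(TrainCallback):
    """
    Logs each batch right before it is passed to the model, to help pin down
    the exact inputs that trigger a CUDA error.
    """

    def pre_batch(self, step: int, batch) -> None:
        # Assumption: TrainCallback provides a per-batch hook with this shape.
        logger.info("step %d, batch: %s", step, batch)
```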
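
Second, a minimal sketch of the Stopping early tip: a callback that raises ``StopEarly`` once the validation metric crosses a threshold. ``post_val_loop`` is the hook visible in ``train_callback.py`` above; the registered name and the threshold value are made up.

```python
from tango.integrations.torch import StopEarly, TrainCallback


# "stop_early_on_metric" is a hypothetical registered name, for illustration only.
@TrainCallback.register("stop_early_on_metric")
class StopEarlyOnMetric(TrainCallback):
    """
    Ends the "torch::train" step cleanly once the validation metric is good enough.
    """

    TARGET = 0.95  # made-up threshold; assumes a higher validation metric is better

    def post_val_loop(self, step: int, val_metric_name: str, val_metric: float) -> None:
        # As long as every worker sees the same val_metric here, they all raise
        # StopEarly at the same point, satisfying the distributed-training caveat
        # in the StopEarly docstring.
        if val_metric >= self.TARGET:
            raise StopEarly
```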