2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -11,6 +11,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Added support for global settings file, `tango.yml`.
- Added 'include_package' (array of string) param to config spec.
- Added a custom error `StopEarly` that a `TrainCallback` can raise within the `TorchTrainStep`
to stop training early without crashing.

### Fixed

3 changes: 3 additions & 0 deletions docs/source/api/integrations/torch.rst
@@ -53,3 +53,6 @@ Callbacks

.. autoclass:: tango.integrations.torch.TrainCallback
:members:

.. autoclass:: tango.integrations.torch.StopEarly
:members:
21 changes: 20 additions & 1 deletion tango/integrations/torch/__init__.py
@@ -106,6 +106,24 @@ def run(self) -> DatasetDict:
Loading best weights from state_worker0_step100.pt
✓ Finished run for "train"

Tips
----

Debugging
~~~~~~~~~

When debugging a training loop that's causing errors on a GPU, you should set the environment variable
``CUDA_LAUNCH_BLOCKING=1``. This ensures that the stack trace shows where the error actually occurred.

You can also use a custom :class:`TrainCallback` to log each batch before it is passed to the model,
so that you can see the exact inputs causing the issue.
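
For example, a minimal sketch of such a logging callback (the ``pre_batch`` hook name and the
registered name are assumptions here; check the :class:`TrainCallback` API for the exact hook
to override):

.. code-block:: python

    import logging

    from tango.integrations.torch import TrainCallback

    @TrainCallback.register("log_batches")
    class LogBatchesCallback(TrainCallback):
        def pre_batch(self, step: int, batch) -> None:
            # Log the raw batch so the inputs that trigger the error are visible.
            logging.getLogger(__name__).info("step %d, batch: %s", step, batch)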

Stopping early
~~~~~~~~~~~~~~

You can stop the "torch::train" step early using a custom :class:`TrainCallback`. Your callback just
needs to raise the :class:`StopEarly` exception.
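
For example, a minimal sketch that stops training once a validation metric crosses a threshold
(``post_val_loop`` is one of the :class:`TrainCallback` hooks; the registered name and the
``0.95`` threshold are hypothetical):

.. code-block:: python

    from tango.integrations.torch import StopEarly, TrainCallback

    @TrainCallback.register("stop_early_on_metric")
    class StopEarlyOnMetric(TrainCallback):
        def post_val_loop(self, step: int, val_metric_name: str, val_metric: float) -> None:
            # Raising StopEarly ends the training loop cleanly instead of crashing.
            if val_metric >= 0.95:
                raise StopEarly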

"""

__all__ = [
@@ -119,11 +137,12 @@ def run(self) -> DatasetDict:
"Sampler",
"ConcatTensorDictsCollator",
"TrainCallback",
"StopEarly",
]

from .data import DataLoader, Sampler, DataCollator, ConcatTensorDictsCollator
from .format import TorchFormat
from .model import Model
from .optim import Optimizer, LRScheduler
from .train import TorchTrainStep
from .train_callback import TrainCallback
from .train_callback import TrainCallback, StopEarly
12 changes: 9 additions & 3 deletions tango/integrations/torch/train.py
@@ -17,7 +17,7 @@
from .format import TorchFormat
from .model import Model
from .optim import Optimizer, LRScheduler
from .train_callback import TrainCallback
from .train_callback import TrainCallback, StopEarly
from tango.common.dataset_dict import DatasetDict
from tango.common.exceptions import ConfigurationError
from tango.common.lazy import Lazy
@@ -471,13 +471,14 @@ def _train(
for callback in callbacks:
callback.pre_train_loop()

with Tqdm.tqdm(
train_batch_iterator = Tqdm.tqdm(
training_batches,
desc="Training",
initial=start_step,
total=train_steps,
disable=not is_local_main_process,
) as train_batch_iterator:
)
try:
for step, batch in train_batch_iterator:

def is_best_checkpoint() -> bool:
@@ -708,6 +709,11 @@ def save_state():
# Checkpoint.
if should_checkpoint_this_step:
save_state()
except StopEarly:
if is_local_main_process:
print("Stopping early!")
finally:
train_batch_iterator.close()

if is_distributed:
dist.barrier()
11 changes: 11 additions & 0 deletions tango/integrations/torch/train_callback.py
@@ -6,6 +6,7 @@
from .data import DataLoader
from .model import Model
from .optim import Optimizer, LRScheduler
from tango.common.exceptions import TangoError
from tango.common.registrable import Registrable


@@ -142,3 +143,13 @@ def post_val_loop(self, step: int, val_metric_name: str, val_metric: float) -> None:
Called right after the validation loop finishes.
"""
pass


class StopEarly(TangoError):
"""
Callbacks can raise this exception to stop training early without crashing.

.. important::
During distributed training all workers must raise this exception at the same point
in the training loop, otherwise there will be a deadlock.
"""