Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -219,8 +219,12 @@ jobs:
run: |
. .venv/bin/activate
which python
echo "========="
python --version
echo "========="
pip freeze
echo "========="
tango info

- name: ${{ matrix.task.name }}
run: |
Expand Down
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Fixed

- Fixed bug with `TorchEvalStep` when constructing callbacks.
- Fixed import errors that occurred when an integration was not installed.

### Changed

Expand Down
4 changes: 2 additions & 2 deletions tango/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@
import click
from click_help_colors import HelpColorsCommand, HelpColorsGroup

from tango.common.exceptions import CliRunError
from tango.common.exceptions import CliRunError, IntegrationMissingError
from tango.common.logging import (
cli_logger,
initialize_logging,
Expand Down Expand Up @@ -479,7 +479,7 @@ def info(settings: TangoGlobalSettings):
is_installed = True
try:
import_module_and_submodules(integration)
except (ModuleNotFoundError, ImportError):
except (IntegrationMissingError, ModuleNotFoundError, ImportError):
is_installed = False
if is_installed:
cli_logger.info(" [green]\N{check mark} %s[/]", name)
Expand Down
10 changes: 5 additions & 5 deletions tango/common/registrable.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,9 +188,9 @@ def search_modules(cls: Type[_RegistrableT], name: str):
):
return None

def try_import(module):
def try_import(module, recursive: bool = True):
try:
import_module_and_submodules(module)
import_module_and_submodules(module, recursive=recursive)
except IntegrationMissingError:
pass
except ImportError as e:
Expand All @@ -200,7 +200,7 @@ def try_import(module):
integrations = {m.split(".")[-1]: m for m in find_integrations()}
integrations_imported: Set[str] = set()
if name in integrations:
try_import(integrations[name])
try_import(integrations[name], recursive=False)
integrations_imported.add(name)
if name in Registrable._registry[cls]:
return None
Expand All @@ -209,7 +209,7 @@ def try_import(module):
# Try to guess the integration that it comes from.
maybe_integration = name.split("::")[0]
if maybe_integration in integrations:
try_import(integrations[maybe_integration])
try_import(integrations[maybe_integration], recursive=False)
integrations_imported.add(maybe_integration)
if name in Registrable._registry[cls]:
return None
Expand Down Expand Up @@ -248,7 +248,7 @@ def try_import(module):
# Try importing all other integrations.
for integration_name, module in integrations.items():
if integration_name not in integrations_imported:
try_import(module)
try_import(module, recursive=False)
integrations_imported.add(integration_name)
if name in Registrable._registry[cls]:
return None
Expand Down
21 changes: 12 additions & 9 deletions tango/common/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,9 @@ def resolve_module_name(package_name: str) -> Tuple[str, Path]:
return package_name, base_path


def import_module_and_submodules(package_name: str, exclude: Optional[Set[str]] = None) -> None:
def import_module_and_submodules(
package_name: str, exclude: Optional[Set[str]] = None, recursive: bool = True
) -> None:
"""
Import all submodules under the given package.

Expand Down Expand Up @@ -105,14 +107,15 @@ def import_module_and_submodules(package_name: str, exclude: Optional[Set[str]]
path = getattr(module, "__path__", [])
path_string = "" if not path else path[0]

# walk_packages only finds immediate children, so need to recurse.
for module_finder, name, _ in pkgutil.walk_packages(path):
# Sometimes when you import third-party libraries that are on your path,
# `pkgutil.walk_packages` returns those too, so we need to skip them.
if path_string and module_finder.path != path_string: # type: ignore[union-attr]
continue
subpackage = f"{package_name}.{name}"
import_module_and_submodules(subpackage, exclude=exclude)
if recursive:
# walk_packages only finds immediate children, so need to recurse.
for module_finder, name, _ in pkgutil.walk_packages(path):
# Sometimes when you import third-party libraries that are on your path,
# `pkgutil.walk_packages` returns those too, so we need to skip them.
if path_string and module_finder.path != path_string: # type: ignore[union-attr]
continue
subpackage = f"{package_name}.{name}"
import_module_and_submodules(subpackage, exclude=exclude)


def _parse_bool(value: Union[bool, str]) -> bool:
Expand Down
4 changes: 2 additions & 2 deletions tango/integrations/beaker/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
from tango.common.exceptions import IntegrationMissingError

try:
import beaker
except ModuleNotFoundError:
from beaker import Beaker
except (ModuleNotFoundError, ImportError):
raise IntegrationMissingError("beaker", dependencies={"beaker-py"})

from .executor import (
Expand Down
13 changes: 10 additions & 3 deletions tango/integrations/wandb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,22 @@
from .step_cache import WandbStepCache
from .workspace import WandbWorkspace

try:
import torch
except ModuleNotFoundError:
pass
else:
from .torch_train_callback import WandbTrainCallback

__all__.append("WandbTrainCallback")

try:
import flax
import jax
import tensorflow # flax has a tensorflow dependency
import torch
except ModuleNotFoundError:
pass
else:
from .flax_train_callback import WandbFlaxTrainCallback
from .torch_train_callback import WandbTrainCallback

__all__.append("WandbTrainCallback")
__all__.append("WandbFlaxTrainCallback")
3 changes: 3 additions & 0 deletions tests/common/registrable_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,6 @@ def test_registering_step_by_reserved_name(self):
@Step.register("ref")
class BadStep(Step):
pass

def test_search_modules(self):
    """Smoke test: searching for a name that nothing registers must not raise."""
    # No assertion on the result; the test only fails if the search raises.
    unregistered = "foo-bar-baz-non-existent"
    Step.search_modules(name=unregistered)