diff --git a/README.md b/README.md index 224a150..6678163 100644 --- a/README.md +++ b/README.md @@ -124,14 +124,13 @@ Options: - [ ] Remove `typing.Any` when used as base class - Generated `TypeVar`s - [ ] Prefix extracted `TypeVar`s names with `_` - - [ ] De-duplicate extracted `TypeVar`s - - [ ] Prevent `TypeVar` name clashes (rename or merge) + - [x] De-duplicate extracted typevar-likes with same name if equivalent + - [ ] Rename extracted typevar-likes with same name if not equivalent - [ ] Infer variance of `typing_extensions.TypeVar(..., infer_variance=True)` whose name does not end with `_contra`/`_in` (`contravariant=True`) or `_co`/`_out` (`covariant=True`) - [x] Convert `default=Any` to `default={bound}` or `default=object` - [x] Remove `bound=Any` and `bound=object` - - [ ] Importing `TypeVar`'s (not recommended) - Imports - [x] Reuse existing `from typing[_extensions] import {name}` imports instead of adding new ones diff --git a/tests/test_py311.py b/tests/test_py311.py index 902f433..a1ce72f 100644 --- a/tests/test_py311.py +++ b/tests/test_py311.py @@ -41,6 +41,7 @@ def test_type_alias_param(): """) pyi_expect = _src(""" from typing import TypeAlias, TypeVar + T = TypeVar("T") Pair: TypeAlias = tuple[T, T] """) @@ -54,6 +55,7 @@ def test_type_alias_param_bound(): """) pyi_expect = _src(""" from typing import TypeAlias, TypeVar + N = TypeVar("N", bound=int) Shape2D: TypeAlias = tuple[N, N] """) @@ -72,6 +74,7 @@ def test_type_alias_param_constraints(): from typing import TypeAlias, TypeVar S = TypeVar("S", bytes, str) + PathLike: TypeAlias = S | os.PathLike[S] """) pyi_out = convert(pyi_in, python=PythonVersion.PY311) @@ -85,6 +88,7 @@ def test_type_alias_param_default(): pyi_expect = _src(""" from typing import TypeAlias from typing_extensions import TypeVar + T = TypeVar("T", default=object) OneOrMany: TypeAlias = T | tuple[T, ...] 
""") @@ -99,6 +103,7 @@ def test_type_alias_params_order_mismatch(): pyi_expect = _src(""" from typing import TypeVar from typing_extensions import TypeAliasType + T1 = TypeVar("T1") T0 = TypeVar("T0") RPair = TypeAliasType("RPair", tuple[T0, T1], type_params=(T1, T0)) @@ -107,12 +112,38 @@ def test_type_alias_params_order_mismatch(): assert pyi_out == pyi_expect +def test_type_alias_dupe_same(): + pyi_in = _src(""" + type Solo[T] = tuple[T] + type Pair[T] = tuple[T, T] + """) + pyi_expect = _src(""" + from typing import TypeAlias, TypeVar + + T = TypeVar("T") + Solo: TypeAlias = tuple[T] + Pair: TypeAlias = tuple[T, T] + """) + pyi_out = convert(pyi_in, python=PythonVersion.PY311) + assert pyi_out == pyi_expect + + +def test_type_alias_dupe_clash(): + pyi_in = _src(""" + type Solo[T] = tuple[T] + type SoloName[T: str] = tuple[T] + """) + with pytest.raises(NotImplementedError): + convert(pyi_in, python=PythonVersion.PY311) + + def test_generic_function(): pyi_in = _src(""" def spam[T](x: T) -> T: ... """) pyi_expect = _src(""" from typing import TypeVar + T = TypeVar("T") def spam(x: T) -> T: ... """) @@ -126,6 +157,7 @@ def f[Z: complex](z: Z) -> Z: ... """) pyi_expect = _src(""" from typing import TypeVar + Z = TypeVar("Z", bound=complex) def f(z: Z) -> Z: ... """) @@ -139,6 +171,7 @@ def f[Z: (int, float, complex)](z: Z) -> Z: ... """) pyi_expect = _src(""" from typing import TypeVar + Z = TypeVar("Z", int, float, complex) def f(z: Z) -> Z: ... """) @@ -152,6 +185,7 @@ def f[Z: complex = complex](z: Z = ...) -> Z: ... """) pyi_expect = _src(""" from typing_extensions import TypeVar + Z = TypeVar("Z", bound=complex, default=complex) def f(z: Z = ...) -> Z: ... """) @@ -177,6 +211,40 @@ def f(z: Z = ...) -> Z: ... assert pyi_out == pyi_expect +def test_generic_function_dupe_same(): + pyi_in = _src(""" + def f[T](x: T, /) -> T: ... + def g[T](y: T, /) -> T: ... 
+ """) + pyi_expect = _src(""" + from typing import TypeVar + + T = TypeVar("T") + def f(x: T, /) -> T: ... + def g(y: T, /) -> T: ... + """) + pyi_out = convert(pyi_in, python=PythonVersion.PY311) + assert pyi_out == pyi_expect + + +def test_generic_function_dupe_clash_bound(): + pyi_in = _src(""" + def f[T](x: T, /) -> T: ... + def g[T: str](v: T, /) -> T: ... + """) + with pytest.raises(NotImplementedError): + convert(pyi_in, python=PythonVersion.PY311) + + +def test_generic_function_dupe_clash_type(): + pyi_in = _src(""" + def f[T](x: T, /) -> T: ... + def g[*T](*xs: *T) -> T: ... + """) + with pytest.raises(NotImplementedError): + convert(pyi_in, python=PythonVersion.PY311) + + def test_generic_class(): pyi_in = _src(""" class C[T_contra, T, T_co]: ... @@ -184,6 +252,7 @@ class C[T_contra, T, T_co]: ... pyi_expect = _src(""" from typing import Generic from typing_extensions import TypeVar + T_contra = TypeVar("T_contra", contravariant=True) T = TypeVar("T", infer_variance=True) T_co = TypeVar("T_co", covariant=True) diff --git a/unpy/_py311.py b/unpy/_py311.py index 20f2556..cb581d5 100644 --- a/unpy/_py311.py +++ b/unpy/_py311.py @@ -1,12 +1,22 @@ import collections import functools from collections.abc import Callable -from typing import ClassVar, Final, Literal, LiteralString, Self, cast, override +from typing import Final, Literal, LiteralString, Self, cast, override import libcst as cst import libcst.matchers as m from libcst.metadata import Scope, ScopeProvider +from ._utils import ( + node_hash, + parse_assign, + parse_bool, + parse_call, + parse_str, + parse_subscript, + parse_tuple, +) + type _BuiltinModule = Literal[ "collections.abc", "inspect", @@ -19,92 +29,69 @@ type _NodeFlat[N: cst.CSTNode, FN: cst.CSTNode] = N | cst.FlattenSentinel[FN] type _NodeOptional[N: cst.CSTNode] = N | cst.RemovalSentinel -__all__ = "PY311Collector", "PY311Transformer" _PY313_TYPING_NAMES: Final = frozenset({"NoDefault", "ReadOnly", "TypeIs"}) _PY312_TYPING_NAMES: Final 
= frozenset({"TypeAliasType", "override"}) -def bool_expr(value: bool, /) -> cst.Name: - return cst.Name("True" if value else "False") - - -def str_expr(value: str, /) -> cst.SimpleString: - return cst.SimpleString(f'"{value}"') - - -def kwarg_expr(key: str, value: cst.BaseExpression, /) -> cst.Arg: - return cst.Arg( - keyword=cst.Name(key), - value=value, - equal=cst.AssignEqual(cst.SimpleWhitespace(""), cst.SimpleWhitespace("")), - ) - - -def _backport_type_alias( - node: cst.SimpleStatementLine, -) -> _NodeFlat[cst.SimpleStatementLine, cst.SimpleStatementLine]: +def _backport_type_alias(node: cst.SimpleStatementLine) -> cst.SimpleStatementLine: assert len(node.body) == 1 - type_alias_original = cast(cst.TypeAlias, node.body[0]) - name = type_alias_original.name + alias_original = cst.ensure_type(node.body[0], cst.TypeAlias) + name = alias_original.name - type_parameters = type_alias_original.type_parameters + type_parameters = alias_original.type_parameters tpars = type_parameters.params if type_parameters else () if len(tpars) > 1: # TODO: only do this if the order differs between the LHS and RHS. 
- type_alias_updated = cst.Assign( - [cst.AssignTarget(name)], - cst.Call( - cst.Name("TypeAliasType"), - [ - cst.Arg(str_expr(name.value)), - cst.Arg(type_alias_original.value), - kwarg_expr( - "type_params", - cst.Tuple([cst.Element(tpar.param.name) for tpar in tpars]), - ), - ], + alias_updated = parse_assign( + name, + parse_call( + "TypeAliasType", + parse_str(name.value), + alias_original.value, + type_params=parse_tuple(p.param.name for p in tpars), ), ) else: - type_alias_updated = cst.AnnAssign( - target=name, - annotation=cst.Annotation(cst.Name("TypeAlias")), - value=type_alias_original.value, + alias_updated = cst.AnnAssign( + name, + cst.Annotation(cst.Name("TypeAlias")), + alias_original.value, ) - if not tpars: - return cst.SimpleStatementLine( - [type_alias_updated], - leading_lines=node.leading_lines, - trailing_whitespace=node.trailing_whitespace, - ) + # if not tpars: + return cst.SimpleStatementLine( + [alias_updated], + leading_lines=node.leading_lines, + trailing_whitespace=node.trailing_whitespace, + ) - statements = [ - *(_backport_tpar(param) for param in tpars), - type_alias_updated, - ] - - lines: list[cst.SimpleStatementLine] = [] - for i, statement in enumerate(statements): - line = cst.SimpleStatementLine([statement]) - if i == 0: - line = line.with_changes(leading_lines=node.leading_lines) - elif i == len(statements) - 1: - line = line.with_changes( - trailing_whitespace=node.trailing_whitespace, - ) - lines.append(line) + # statements = [ + # *(_backport_tpar(param) for param in tpars), + # alias_updated, + # ] + + # lines: list[cst.SimpleStatementLine] = [] + # for i, statement in enumerate(statements): + # line = cst.SimpleStatementLine([statement]) + # if i == 0: + # line = line.with_changes(leading_lines=node.leading_lines) + # elif i == len(statements) - 1: + # line = line.with_changes( + # trailing_whitespace=node.trailing_whitespace, + # ) + # lines.append(line) - return cst.FlattenSentinel(lines) + # return 
cst.FlattenSentinel(lines) def _backport_tpar(tpar: cst.TypeParam, /, *, variant: bool = False) -> cst.Assign: param = tpar.param name = param.name.value - args = [cst.Arg(str_expr(name))] + args: list[cst.BaseExpression] = [parse_str(name)] + kwargs: dict[str, cst.BaseExpression] = {} match param: case cst.TypeVar(_, bound): @@ -115,16 +102,11 @@ def _backport_tpar(tpar: cst.TypeParam, /, *, variant: bool = False) -> cst.Assi variance = "covariant" else: variance = "infer_variance" - else: - variance = None - if variance: - args.append(kwarg_expr(variance, bool_expr(True))) + kwargs[variance] = parse_bool(True) match bound: case ( - None - | cst.Name("object") - | cst.Name("Any") + cst.Name("object" | "Any") | cst.Attribute(cst.Name("typing"), cst.Name("Any")) ): bound = None @@ -132,15 +114,14 @@ def _backport_tpar(tpar: cst.TypeParam, /, *, variant: bool = False) -> cst.Assi for el in elements: con = cst.ensure_type(el, cst.Element).value if isinstance(con, cst.Name) and con.value == "Any": - # `Any` is literally Hitler con = cst.Name("object") - args.append(cst.Arg(con)) + args.append(con) bound = None - case cst.BaseExpression(): + case cst.BaseExpression() | None: pass if bound: - args.append(kwarg_expr("bound", bound)) + kwargs["bound"] = bound case cst.TypeVarTuple(_) | cst.ParamSpec(_): bound = cst.Name("object") @@ -148,21 +129,17 @@ def _backport_tpar(tpar: cst.TypeParam, /, *, variant: bool = False) -> cst.Assi match default := tpar.default: case None: pass - case cst.Name("Any") as _b: - default = bound or cst.Name("object", lpar=_b.lpar, rpar=_b.rpar) + case cst.Name("Any"): + default = bound or cst.Name("object") case cst.BaseExpression(): pass if default: - args.append(kwarg_expr("default", default)) + kwargs["default"] = default # TODO: deal with existing `import {tname} as {tname_alias}` tname = type(param).__name__ - - return cst.Assign( - targets=[cst.AssignTarget(target=cst.Name(name))], - value=cst.Call(func=cst.Name(tname), args=args), - ) + 
return parse_assign(name, parse_call(tname, *args, **kwargs)) def _remove_tpars[N: _AnyDef](node: N, /) -> N: @@ -253,7 +230,9 @@ class PY311Collector(cst.CSTVisitor): current_imports: dict[_BuiltinModule, str] current_imports_from: dict[_BuiltinModule, dict[str, str]] missing_imports_from: dict[_BuiltinModule, set[str]] - missing_tvars: dict[str, list[cst.Assign]] + + missing_tvars: dict[str, list[cst.Assign]] # {root_node_name: [cst.Assign]} + visited_tvars: dict[str, int] # {typevar_name: typevar_hash} def __init__(self, /) -> None: self._stack = collections.deque() @@ -261,7 +240,9 @@ def __init__(self, /) -> None: self.current_imports = {} self.current_imports_from = collections.defaultdict(dict) self.missing_imports_from = collections.defaultdict(set) + self.missing_tvars = collections.defaultdict(list) + self.visited_tvars = {} super().__init__() @@ -296,6 +315,37 @@ def _require_typing_import( imports[module].add(name) return name + def _register_type_params( + self, + name: str, + tpars: cst.TypeParameters, + /, + *, + variant: bool = False, + ) -> None: + variant_suffixes = "_contra", "_in", "_co", "_out" + + visited_tvars = self.visited_tvars + missing_tvars = self.missing_tvars[name] + for tpar in tpars.params: + tname = tpar.param.name.value + thash = node_hash(tpar) + + if tname in visited_tvars: + if visited_tvars[tname] != thash: + raise NotImplementedError(f"Duplicate type param {tname!r}") + continue + + visited_tvars[tname] = thash + missing_tvars.append(_backport_tpar(tpar, variant=variant)) + # `default=` and `infer_variance=True` need the `typing_extensions` backports + infers = variant and isinstance(tpar.param, cst.TypeVar) + if tpar.default or (infers and not tname.endswith(variant_suffixes)): + module = "typing_extensions" + else: + module = "typing" + self._require_typing_import(module, type(tpar.param).__name__) + @override def visit_Module(self, node: cst.Module) -> None: self._scope = cst.ensure_type(self.get_metadata(ScopeProvider, node), Scope) @@ -429,15 +441,7 @@ def visit_TypeAlias(self, /, node: cst.TypeAlias) -> None: name = node.name.value assert name not in
self.missing_tvars - self.missing_tvars[name].extend( - map(_backport_tpar, tpars.params), - ) - - for tpar in tpars.params: - self._require_typing_import( - "typing_extensions" if tpar.default else "typing", - type(tpar.param).__name__, - ) + self._register_type_params(name, tpars) # TODO: additionally require the LHS/RHS order to mismatch here if len(tpars.params) > 1: @@ -445,10 +449,24 @@ def visit_TypeAlias(self, /, node: cst.TypeAlias) -> None: self._require_typing_import(import_from, import_name) + @override + def visit_FunctionDef(self, /, node: cst.FunctionDef) -> bool | None: + stack = self._stack + stack.append(node.name.value) + + if tpars := node.type_parameters: + self._register_type_params(stack[0], tpars) + + @override + def leave_FunctionDef(self, /, original_node: cst.FunctionDef) -> None: + name = self._stack.pop() + assert name == original_node.name.value + @override def visit_ClassDef(self, /, node: cst.ClassDef) -> bool | None: stack = self._stack stack.append(node.name.value) + assert len(stack) > 1 or stack[0] not in self.missing_tvars if _get_typing_baseclass(node, "Path", modules={"pathlib"}): raise NotImplementedError("subclassing 'pathlib.Path` is not supported") @@ -463,59 +481,58 @@ def visit_ClassDef(self, /, node: cst.ClassDef) -> bool | None: # this will require an additional `typing.Generic` base class self._require_typing_import("typing", "Generic") - assert len(stack) > 1 or stack[0] not in self.missing_tvars - self.missing_tvars[stack[0]].extend( - _backport_tpar(tpar, variant=True) for tpar in tpars.params - ) - - for tpar in tpars.params: - tname = type(tpar.param).__name__ - - # `default=...` and `TypeVar(..., infer_variance=True)` require importing - # from `typing_extensions` - name = tpar.param.name.value - if tpar.default or ( - tname == "TypeVar" - and not name.endswith(("_contra", "_in", "_co", "_out")) - ): - tmodule = "typing_extensions" - else: - tmodule = "typing" - - self._require_typing_import(tmodule, tname) + 
self._register_type_params(stack[0], tpars, variant=True) @override def leave_ClassDef(self, /, original_node: cst.ClassDef) -> None: name = self._stack.pop() assert name == original_node.name.value - @override - def visit_FunctionDef(self, /, node: cst.FunctionDef) -> bool | None: - stack = self._stack - stack.append(node.name.value) - if not (tpars := node.type_parameters): - return +def _new_import_statement_index(module_node: cst.Module) -> int: + # find the first import statement in the module body + i_insert = 0 + illegal_direct_imports = frozenset({"typing", "typing_extensions"}) + for i, statement in enumerate(module_node.body): + if not isinstance(statement, cst.SimpleStatementLine): + continue - self.missing_tvars[stack[0]].extend(map(_backport_tpar, tpars.params)) + for node in statement.body: + if not isinstance(node, cst.Import | cst.ImportFrom): + continue - for tpar in tpars.params: - self._require_typing_import( - "typing_extensions" if tpar.default else "typing", - type(tpar.param).__name__, - ) + _done = False + if isinstance(node, cst.Import): + if any(a.name.value in illegal_direct_imports for a in node.names): + raise NotImplementedError("import typing[_extensions]") + # insert after all `import ...` statements + i_insert = i + 1 + continue - @override - def leave_FunctionDef(self, /, original_node: cst.FunctionDef) -> None: - name = self._stack.pop() - assert name == original_node.name.value + if node.relative: + # insert before the first relative import + return i + match node.module: + case cst.Name("typing"): + # insert the (`typing_extensions`) import after `typing` + return i + 1 + case cst.Name("typing_extensions"): + # insert the (`typing`) import before `typing_extensions` + return i + case cst.Name(name): + # otherwise, assume alphabetically sorted on module, and + # and insert to maintain the order + if name > "typing_extensions": + return i + + i_insert = i + 1 + case _: + continue + return i_insert -class 
PY311Transformer(m.MatcherDecoratableTransformer): - _MATCH_TYPING_IMPORT: ClassVar = m.ImportFrom( - m.Name("typing") | m.Name("typing_extensions"), - ) +class PY311Transformer(m.MatcherDecoratableTransformer): _stack: collections.deque[str] current_imports: dict[_BuiltinModule, str] @@ -523,6 +540,15 @@ class PY311Transformer(m.MatcherDecoratableTransformer): missing_imports_from: dict[_BuiltinModule, set[str]] missing_tvars: dict[str, list[cst.Assign]] + @classmethod + def from_collector(cls, collector: PY311Collector, /) -> Self: + return cls( + current_imports=collector.current_imports, + current_imports_from=collector.current_imports_from, + missing_imports_from=collector.missing_imports_from, + missing_tvars=collector.missing_tvars, + ) + def __init__( self, /, @@ -535,8 +561,8 @@ def __init__( self._stack = collections.deque() self.current_imports = current_imports self.current_imports_from = current_imports_from - self.missing_tvars = missing_tvars self.missing_imports_from = missing_imports_from + self.missing_tvars = missing_tvars super().__init__() def _del_imports_from(self, module: _BuiltinModule, /) -> set[str]: @@ -555,27 +581,93 @@ def _add_imports_from(self, module: _BuiltinModule, /) -> set[str]: self.current_imports_from[module], ) + @override + def leave_ImportFrom( + self, + /, + original_node: cst.ImportFrom, + updated_node: cst.ImportFrom, + ) -> _NodeOptional[cst.ImportFrom]: + if updated_node.relative or not updated_node.module: + return updated_node + + module = cast(_BuiltinModule, _as_name(updated_node.module)) + + names_del = self._del_imports_from(module) + names_add = self._add_imports_from(module) + + if not (names_del or names_add): + return updated_node + + aliases = updated_node.names + assert not isinstance(aliases, cst.ImportStar) + + aliases_new = [a for a in aliases if a.name.value not in names_del] + aliases_new.extend(cst.ImportAlias(cst.Name(name)) for name in names_add) + aliases_new.sort(key=lambda a: 
cst.ensure_type(a.name, cst.Name).value) + + if not aliases_new: + return cst.RemoveFromParent() + + # remove trailing comma + if isinstance(aliases_new[-1].comma, cst.Comma): + aliases_new[-1] = aliases_new[-1].with_changes(comma=None) + + return updated_node.with_changes(names=aliases_new) + + def _prepend_tvars[N: _AnyDef]( + self, + /, + node: N, + ) -> _NodeFlat[N, N | cst.SimpleStatementLine]: + if not (tvars := self.missing_tvars.get(node.name.value, [])): + return node + + leading_lines = node.leading_lines or [cst.EmptyLine()] + lines = ( + cst.SimpleStatementLine([tvar], () if i else leading_lines) + for i, tvar in enumerate(tvars) + ) + return cst.FlattenSentinel([*lines, node]) + @m.call_if_inside(m.Module([m.ZeroOrMore(m.SimpleStatementLine())])) @m.leave(m.SimpleStatementLine([m.TypeAlias()])) @_workaround_libcst_runtime_typecheck_bug - def desugar_type_alias( # noqa: PLR6301 + def leave_type_alias_statement( self, /, - _: cst.SimpleStatementLine, + original_node: cst.SimpleStatementLine, updated_node: cst.SimpleStatementLine, ) -> _NodeFlat[cst.SimpleStatementLine, cst.SimpleStatementLine]: - return _backport_type_alias(updated_node) + node = _backport_type_alias(updated_node) - def _prepend_tvars[N: _AnyDef](self, /, node: N) -> _NodeFlat[N, cst.BaseStatement]: - if not (tvars := self.missing_tvars.get(node.name.value, [])): + alias_original = cst.ensure_type(original_node.body[0], cst.TypeAlias) + if not (tvars := self.missing_tvars.get(alias_original.name.value, [])): return node + leading_lines = node.leading_lines or [cst.EmptyLine()] lines = ( - cst.SimpleStatementLine([tvar], node.leading_lines if i == 0 else ()) + cst.SimpleStatementLine([tvar], () if i else leading_lines) for i, tvar in enumerate(tvars) ) return cst.FlattenSentinel([*lines, node]) + @override + def visit_FunctionDef(self, /, node: cst.FunctionDef) -> None: + self._stack.append(node.name.value) + + @override + def leave_FunctionDef( + self, + /, + original_node: 
cst.FunctionDef, + updated_node: cst.FunctionDef, + ) -> cst.FunctionDef | cst.FlattenSentinel[cst.BaseStatement]: + self._stack.pop() + + updated_node = _remove_tpars(updated_node) + return updated_node if self._stack else self._prepend_tvars(updated_node) + @override def visit_ClassDef(self, /, node: cst.ClassDef) -> None: self._stack.append(node.name.value) @@ -587,20 +679,16 @@ def leave_ClassDef( original_node: cst.ClassDef, updated_node: cst.ClassDef, ) -> _NodeFlat[cst.ClassDef, cst.BaseStatement]: - name = self._stack.pop() - assert name == updated_node.name.value + stack = self._stack + stack.pop() if not (tpars := original_node.type_parameters): return updated_node - subscripts = [ - cst.SubscriptElement(cst.Index(type_param.param.name)) - for type_param in tpars.params - ] - + tpar_names = (tpar.param.name for tpar in tpars.params) if base_protocol := _get_typing_baseclass(original_node, "Protocol"): new_bases = [ - cst.Arg(cst.Subscript(base_protocol, subscripts)) + cst.Arg(parse_subscript(base_protocol, *tpar_names)) if base_arg.value is base_protocol else base_arg for base_arg in original_node.bases @@ -608,67 +696,11 @@ def leave_ClassDef( else: new_bases = [ *updated_node.bases, - cst.Arg(cst.Subscript(cst.Name("Generic"), subscripts)), + cst.Arg(parse_subscript("Generic", *tpar_names)), ] updated_node = updated_node.with_changes(type_parameters=None, bases=new_bases) - - return self._prepend_tvars(updated_node) if not self._stack else updated_node - - @override - def visit_FunctionDef(self, /, node: cst.FunctionDef) -> None: - self._stack.append(node.name.value) - - @override - def leave_FunctionDef( - self, - /, - original_node: cst.FunctionDef, - updated_node: cst.FunctionDef, - ) -> cst.FunctionDef | cst.FlattenSentinel[cst.BaseStatement]: - name = self._stack.pop() - assert name == updated_node.name.value - - updated_node = _remove_tpars(updated_node) - - if self._stack: - return updated_node - - return self._prepend_tvars(updated_node) - - 
@override - def leave_ImportFrom( - self, - /, - original_node: cst.ImportFrom, - updated_node: cst.ImportFrom, - ) -> _NodeOptional[cst.ImportFrom]: - if updated_node.relative or not updated_node.module: - return updated_node - - module = cast(_BuiltinModule, _as_name(updated_node.module)) - - names_del = self._del_imports_from(module) - names_add = self._add_imports_from(module) - - if not (names_del or names_add): - return updated_node - - aliases = updated_node.names - assert not isinstance(aliases, cst.ImportStar) - - aliases_new = [a for a in aliases if a.name.value not in names_del] - aliases_new.extend(cst.ImportAlias(cst.Name(name)) for name in names_add) - aliases_new.sort(key=lambda a: cst.ensure_type(a.name, cst.Name).value) - - if not aliases_new: - return cst.RemoveFromParent() - - # remove trailing comma - if isinstance(aliases_new[-1].comma, cst.Comma): - aliases_new[-1] = aliases_new[-1].with_changes(comma=None) - - return updated_node.with_changes(names=aliases_new) + return self._prepend_tvars(updated_node) if not stack else updated_node @override def leave_Module( @@ -677,9 +709,6 @@ def leave_Module( original_node: cst.Module, updated_node: cst.Module, ) -> cst.Module: - return self._update_imports(updated_node) - - def _update_imports(self, /, module_node: cst.Module) -> cst.Module: new_statements = [ cst.SimpleStatementLine([ cst.ImportFrom( @@ -693,71 +722,17 @@ def _update_imports(self, /, module_node: cst.Module) -> cst.Module: ] if not new_statements: - return module_node - - i_insert = self._new_import_statement_index(module_node) + return updated_node - # NOTE: newlines and other formatting won't be done here; use e.g. 
ruff instead - return module_node.with_changes( + i_insert = _new_import_statement_index(updated_node) + return updated_node.with_changes( body=[ - *module_node.body[:i_insert], + *updated_node.body[:i_insert], *new_statements, - *module_node.body[i_insert:], + *updated_node.body[i_insert:], ], ) - @staticmethod - def _new_import_statement_index(module_node: cst.Module) -> int: - # find the first import statement in the module body - i_insert = 0 - illegal_direct_imports = frozenset({"typing", "typing_extensions"}) - for i, statement in enumerate(module_node.body): - if not isinstance(statement, cst.SimpleStatementLine): - continue - - for node in statement.body: - if not isinstance(node, cst.Import | cst.ImportFrom): - continue - - _done = False - if isinstance(node, cst.Import): - if any(a.name.value in illegal_direct_imports for a in node.names): - raise NotImplementedError("import typing[_extensions]") - # insert after all `import ...` statements - i_insert = i + 1 - continue - - if node.relative: - # insert before the first relative import - return i - - match node.module: - case cst.Name("typing"): - # insert the (`typing_extensions`) import after `typing` - return i + 1 - case cst.Name("typing_extensions"): - # insert the (`typing`) import before `typing_extensions` - return i - case cst.Name(name): - # otherwise, assume alphabetically sorted on module, and - # and insert to maintain the order - if name > "typing_extensions": - return i - - i_insert = i + 1 - case _: - continue - return i_insert - - @classmethod - def from_collector(cls, collector: PY311Collector, /) -> Self: - return cls( - current_imports=collector.current_imports, - current_imports_from=collector.current_imports_from, - missing_imports_from=collector.missing_imports_from, - missing_tvars=collector.missing_tvars, - ) - def transform(original: cst.Module, /) -> cst.Module: wrapper = cst.MetadataWrapper(original) diff --git a/unpy/_utils.py b/unpy/_utils.py new file mode 100644 index 
0000000..c821736 --- /dev/null +++ b/unpy/_utils.py @@ -0,0 +1,216 @@ +import functools +from collections.abc import Iterable +from itertools import starmap +from typing import Literal, cast + +import libcst as cst +from libcst.helpers import filter_node_fields, get_full_name_for_node + +__all__ = [ + "as_dict", + "as_module", + "node_code", + "node_hash", + "parse_assign", + "parse_bool", + "parse_call", + "parse_kwarg", + "parse_str", + "parse_tuple", +] + +type StringPrefix = Literal["", "r", "u", "b", "br", "rb"] +type StringQuote = Literal["'", '"', "'''", '"""'] + + +def as_module( + node: cst.CSTNode | cst.RemovalSentinel | cst.FlattenSentinel[cst.CSTNode], + /, +) -> cst.Module: + match node: + case cst.Module() as module: + return module + case cst.SimpleStatementLine() | cst.BaseCompoundStatement() as stmt: + return cst.Module([stmt]) + case cst.BaseSmallStatement() as stmt: + return cst.Module([cst.SimpleStatementLine([stmt])]) + case cst.RemovalSentinel(): + return cst.Module([]) + case cst.FlattenSentinel() as nodes: + body: list[cst.SimpleStatementLine | cst.BaseCompoundStatement] = [] + for n in nodes.nodes: + body.extend(as_module(n).body) + return cst.Module(body) + case _: + raise TypeError(type(node)) + + +def as_dict( + node: cst.CSTNode, + /, + *, + syntax: bool = False, + defaults: bool = False, + whitespace: bool = False, +) -> dict[str, object]: + kwargs = { + "syntax": syntax, + "defaults": defaults, + "whitespace": whitespace, + } + + out: dict[str, object] = {} + for field in filter_node_fields( + node, + show_syntax=syntax, + show_defaults=defaults, + show_whitespace=whitespace, + ): + key = field.name + match getattr(node, key): + case cst.CSTNode() as child: + value = as_dict(child, **kwargs) + case [*children] if children and isinstance(children[0], cst.CSTNode): + value = [ + as_dict(child, **kwargs) + for child in cast(list[cst.CSTNode], children) + ] + case _ as value: # pyright: ignore[reportAny] + pass + + out[key] = value + + 
return out + + +def as_tuple( + node: cst.CSTNode, + /, + *, + syntax: bool = False, + defaults: bool = False, + whitespace: bool = False, +) -> tuple[str, tuple[object, ...]]: + kwargs = { + "syntax": syntax, + "defaults": defaults, + "whitespace": whitespace, + } + + out: list[object] = [] + for field in filter_node_fields( + node, + show_syntax=syntax, + show_defaults=defaults, + show_whitespace=whitespace, + ): + key = field.name + match getattr(node, key): + case cst.CSTNode() as child: + value = as_tuple(child, **kwargs) + case [*children] if children and isinstance(children[0], cst.CSTNode): + value = tuple( + as_tuple(child, **kwargs) + for child in cast(list[cst.CSTNode], children) + ) + case [*values]: + value = tuple(values) + case _ as value: # pyright: ignore[reportAny] + pass + out.append(value) + + return type(node).__name__, tuple(out) + + +def node_hash(node: cst.CSTNode, /, **kwargs: bool) -> int: + return hash((type(node).__name__, as_tuple(node, **kwargs))) + + +def node_code( + node: cst.CSTNode | cst.RemovalSentinel | cst.FlattenSentinel[cst.CSTNode], + /, +) -> str: + if isinstance(node, cst.CSTNode): + return get_full_name_for_node(node) or as_module(node).code + return as_module(node).code + + +@functools.cache +def parse_bool(value: bool | Literal[0, 1], /) -> cst.Name: + return cst.Name("True" if value else "False") + + +def parse_str( + value: str, + /, + *, + quote: StringQuote = '"', + prefix: StringPrefix = "", +) -> cst.SimpleString: + return cst.SimpleString(f"{prefix}{quote}{value}{quote}") + + +def parse_kwarg(key: str, value: cst.BaseExpression, /) -> cst.Arg: + return cst.Arg( + keyword=cst.Name(key), + value=value, + equal=cst.AssignEqual(cst.SimpleWhitespace(""), cst.SimpleWhitespace("")), + ) + + +def parse_tuple( + exprs: Iterable[cst.BaseExpression], + /, + *, + star: cst.BaseExpression | None = None, + parens: bool = True, +) -> cst.Tuple: + elems: list[cst.BaseElement] = [cst.Element(el) for el in exprs] + if star is not 
None: + elems.append(cst.StarredElement(star)) + + return cst.Tuple(elems) if parens else cst.Tuple(elems, [], []) + + +def _name_or_expr[T: cst.BaseExpression](value: T | str, /) -> T | cst.Name: + return value if isinstance(value, cst.BaseExpression) else cst.Name(value) + + +def parse_call( + func: cst.BaseExpression | str, + /, + *args: cst.BaseExpression, + **kwargs: cst.BaseExpression, +) -> cst.Call: + return cst.Call( + _name_or_expr(func), + [*map(cst.Arg, args), *starmap(parse_kwarg, kwargs.items())], + ) + + +def parse_subscript( + base: cst.BaseExpression | str, + /, + *ixs: cst.BaseSlice | cst.BaseExpression, +) -> cst.Subscript: + elems = [ + cst.SubscriptElement(ix if isinstance(ix, cst.BaseSlice) else cst.Index(ix)) + for ix in ixs + ] + return cst.Subscript(_name_or_expr(base), elems) + + +def parse_assign( + target: ( + cst.BaseAssignTargetExpression + | str + | tuple[cst.BaseAssignTargetExpression | str, ...] + ), + value: cst.BaseExpression, + /, +) -> cst.Assign: + if isinstance(target, cst.BaseAssignTargetExpression | str): + targets = [_name_or_expr(target)] + else: + targets = map(_name_or_expr, target) + return cst.Assign(list(map(cst.AssignTarget, targets)), value)