Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ Options:
- [ ] [PEP 673][PEP673]: `typing.Self` => `typing_extensions.Self`
- [ ] [PEP 655][PEP655]: `typing.[Not]Required` => `typing_extensions.[Not]Required`
- [ ] [PEP 646][PEP646]: `*Ts` => `typing_extensions.Unpack[Ts]`
- [ ] Remove `typing.Any` when used as base class
- Generated `TypeVar`s
- [ ] Prefix extracted `TypeVar`s names with `_`
- [ ] De-duplicate extracted `TypeVar`s
Expand Down
21 changes: 21 additions & 0 deletions tests/test_py311.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import textwrap

import pytest
from unpy.convert import PythonVersion, convert


Expand Down Expand Up @@ -353,3 +354,23 @@ class HasArrayStruct(Protocol):
""")
pyi_out = convert(pyi_in, python=PythonVersion.PY311)
assert pyi_out == pyi_expect


def test_subclass_path():
    """Converting a class that subclasses a bare-imported ``Path`` to 3.11 must fail.

    Subclassing ``pathlib.Path`` is rejected by the converter with a
    ``NotImplementedError`` whose message names ``pathlib.Path``.
    """
    pyi_in = _src("""
        from pathlib import Path

        class MyPath(Path): ...
    """)
    # Pin the message: ``convert`` raises NotImplementedError for several
    # unrelated unsupported constructs, so a bare exception-type check could
    # pass even if the Path-subclass detection stopped working.
    with pytest.raises(NotImplementedError, match=r"pathlib\.Path"):
        convert(pyi_in, python=PythonVersion.PY311)


def test_subclass_pathlib_path():
    """Converting a class that subclasses ``pathlib.Path`` (qualified) to 3.11 must fail.

    The attribute form ``pathlib.Path`` must be detected the same way as the
    bare-name form, raising a ``NotImplementedError`` that names ``pathlib.Path``.
    """
    pyi_in = _src("""
        import pathlib

        class MyPath(pathlib.Path): ...
    """)
    # Pin the message: ``convert`` raises NotImplementedError for several
    # unrelated unsupported constructs, so a bare exception-type check could
    # pass even if the Path-subclass detection stopped working.
    with pytest.raises(NotImplementedError, match=r"pathlib\.Path"):
        convert(pyi_in, python=PythonVersion.PY311)
28 changes: 17 additions & 11 deletions unpy/_py311.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,11 @@ def _get_typing_baseclass(
node: cst.ClassDef,
base_name: LiteralString,
/,
modules: set[str] | None = None,
) -> cst.Name | cst.Attribute | None:
if modules is None:
modules = {"typing", "typing_extensions"}

base_expr_matches: list[cst.Name | cst.Attribute] = []
for base_arg in node.bases:
if base_arg.keyword or base_arg.star:
Expand All @@ -184,20 +188,19 @@ def _get_typing_baseclass(
match base_expr := base_arg.value:
case cst.Name(_name) if _name == base_name:
return base_expr
case cst.Attribute(
cst.Name("typing" | "typing_extensions"),
cst.Name(_name),
) if _name == base_name:
case cst.Attribute(cst.Name(_module), cst.Name(_name)) if (
_name == base_name and _module in modules
):
base_expr_matches.append(base_expr)
case cst.Subscript(
cst.Name(_name)
| cst.Attribute(
cst.Name("typing" | "typing_extensions"),
cst.Name(_name),
),
) if _name == base_name:
case cst.Subscript(cst.Name(_name)) if _name == base_name:
raise NotImplementedError(f"{base_name!r} base class with type params")
case cst.Subscript(cst.Attribute(cst.Name(_module), cst.Name(_name))) if (
_name == base_name and _module in modules
):
base_qname = f"{_module}.{_name}"
raise NotImplementedError(f"{base_qname!r} base class with type params")
case _:
# maybe raise here?
pass

match base_expr_matches:
Expand Down Expand Up @@ -447,6 +450,9 @@ def visit_ClassDef(self, /, node: cst.ClassDef) -> bool | None:
stack = self._stack
stack.append(node.name.value)

if _get_typing_baseclass(node, "Path", modules={"pathlib"}):
raise NotImplementedError("subclassing 'pathlib.Path` is not supported")

if not (tpars := node.type_parameters):
return

Expand Down