Conversation

benjaminglass1
Collaborator

@benjaminglass1 benjaminglass1 commented Aug 18, 2025

[ghstack-poisoned]

pytorch-bot bot commented Aug 18, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/160897

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit b8dc3d8 with merge base 1c16c18 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

Comment on lines +3024 to +3025
replica._buffers = dict(replica._buffers)
replica._modules = dict(replica._modules)
Collaborator Author

MutableMapping doesn't define .copy, so these copies are made with dict(...) instead. This was the only code change required to make this PR work.
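
A minimal sketch of the point above, using a hypothetical `Registry` class as a stand-in for any mapping that is not a `dict`. The `MutableMapping` ABC supplies mixin methods such as `update` and `pop` but no `copy`, whereas `dict(mapping)` makes a shallow copy of any mapping:

```python
from collections.abc import MutableMapping
from typing import Iterator


class Registry(MutableMapping[str, int]):
    """Hypothetical non-dict mapping, used only for illustration."""

    def __init__(self) -> None:
        self._data: dict[str, int] = {}

    def __getitem__(self, key: str) -> int:
        return self._data[key]

    def __setitem__(self, key: str, value: int) -> None:
        self._data[key] = value

    def __delitem__(self, key: str) -> None:
        del self._data[key]

    def __iter__(self) -> Iterator[str]:
        return iter(self._data)

    def __len__(self) -> int:
        return len(self._data)


registry = Registry()
registry["weight"] = 1

has_copy = hasattr(registry, "copy")  # the ABC provides no copy() mixin
snapshot = dict(registry)  # shallow copy that works for any mapping
registry["bias"] = 2  # a later mutation does not affect the snapshot
```

Here `has_copy` is `False`, and `snapshot` still holds only the `weight` entry after `registry` is mutated.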

[ghstack-poisoned]
[ghstack-poisoned]
@benjaminglass1 benjaminglass1 marked this pull request as ready for review August 19, 2025 21:44
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
benjaminglass1 added a commit that referenced this pull request Aug 20, 2025
…arify torch.nn.Module typing

ghstack-source-id: bd07986
Pull Request resolved: #160897
@eellison
Contributor

eellison commented Sep 3, 2025

cc @albanD mind taking a look?

@amjames amjames added module: cpp-extensions Related to torch.utils.cpp_extension module: nn Related to torch.nn labels Sep 9, 2025
[ghstack-poisoned]
benjaminglass1 added a commit that referenced this pull request Sep 10, 2025
…arify torch.nn.Module typing

ghstack-source-id: fd1955f
Pull Request resolved: #160897
 training: bool
-_parameters: dict[str, Optional[Parameter]]
-_buffers: dict[str, Optional[Tensor]]
+_parameters: MutableMapping[str, Optional[Parameter]]
Collaborator

Why change this? All of these are actually dict right?

Collaborator Author

@albanD Only in some cases! For instance, they could be OrderedDictWrapper, which wraps a pybind11 C++ dict-like class.

Collaborator

I don't think this is really used in practice but ok.
I'll let @lolpack decide on this one since it's pure typing.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's hard to say what the downstream effect of this change is, but if the underlying value is not always a dict, then this seems like a legitimate increase in strictness.

Sorry to kick the question back to you, but it's hard to say what the downstream impact on users would be without more context. If this is not user facing, then IMO you might as well switch. If it is user facing, then I think you need to trade off between backwards-compat, convenience, and safety.

Maybe you know that in all user-facing cases, the value is actually a dict, and users can safely use dict-only methods? In that case, I say continue using dict.

If the MutableMapping interface is sufficient for common use cases, maybe it's worth changing it for safety, even at the risk of some churn for users.

I think it's a judgment call.
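
To make the trade-off concrete, here is a small sketch. The function name is hypothetical, and `collections.UserDict` stands in for any mapping that is not a `dict` subclass. Code written against the `MutableMapping` interface accepts both, while an `isinstance(..., dict)` check does not hold for the wrapper:

```python
from collections import UserDict
from collections.abc import MutableMapping
from typing import Optional


def count_registered(entries: MutableMapping[str, Optional[float]]) -> int:
    """Count non-None entries using only the MutableMapping interface."""
    return sum(1 for name in entries if entries[name] is not None)


plain: dict[str, Optional[float]] = {"weight": 1.0, "bias": None}
wrapped = UserDict(plain)  # stand-in for a mapping that is not a dict

plain_count = count_registered(plain)
wrapped_count = count_registered(wrapped)
wrapped_is_dict = isinstance(wrapped, dict)
wrapped_is_mapping = isinstance(wrapped, MutableMapping)
```

Both calls return the same count, but `wrapped_is_dict` is `False`, so downstream code that assumes a real `dict` would be affected by a non-dict value regardless of how the attribute is annotated.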

@rec
Collaborator

rec commented Sep 23, 2025

I ran a "delta" on the pyright type checking (using this utility) before and after this commit: the change is not huge, but universally positive, raising "completenessScore" by about 0.12%.

[
    {
        "absolute": {
            "exportedSymbolCounts": {
                "withKnownType": 8,
                "withUnknownType": -6
            },
            "completenessScore": 0.0004445407249853428
        },
        "percent": {
            "exportedSymbolCounts": {
                "withKnownType": 0.13054830287206268,
                "withUnknownType": -0.06430178973314757
            },
            "completenessScore": 0.11826611356782056
        },
        "symbols": {
            "added": [
                "torch.nn.cpp.OrderedDictWrapper.__delitem__",
                "torch.nn.cpp.OrderedDictWrapper.__setitem__"
            ],
            "common": {
                "torch.nn.cpp.OrderedDictWrapper.__contains__": {
                    "isTypeKnown": true,
                    "diagnostics": {
                        "removed": [
                            "Return type annotation is missing",
                            "Type annotation for parameter \"key\" is missing"
                        ]
                    }
                },
                "torch.nn.cpp.OrderedDictWrapper.__getitem__": {
                    "isTypeKnown": true,
                    "diagnostics": {
                        "removed": [
                            "Return type annotation is missing",
                            "Type annotation for parameter \"key\" is missing"
                        ]
                    }
                },
                "torch.nn.cpp.OrderedDictWrapper.__init__": {
                    "diagnostics": {
                        "removed": [
                            "Type annotation for parameter \"attr\" is missing"
                        ]
                    }
                },
                "torch.nn.cpp.OrderedDictWrapper.__iter__": {
                    "isTypeKnown": true,
                    "diagnostics": {
                        "removed": [
                            "Return type annotation is missing"
                        ]
                    }
                },
                "torch.nn.cpp.OrderedDictWrapper.__len__": {
                    "isTypeKnown": true,
                    "diagnostics": {
                        "removed": [
                            "Return type annotation is missing"
                        ]
                    }
                },
                "torch.nn.cpp.OrderedDictWrapper.cpp_module": {
                    "isTypeKnown": false,
                    "diagnostics": {
                        "added": [
                            "Type unknown for variable \"torch.nn.cpp.OrderedDictWrapper.cpp_module\""
                        ]
                    }
                },
                "torch.nn.cpp.OrderedDictWrapper.items": {
                    "isTypeKnown": true,
                    "diagnostics": {
                        "removed": [
                            "Return type annotation is missing"
                        ]
                    }
                },
                "torch.nn.cpp.OrderedDictWrapper.keys": {
                    "isTypeKnown": true,
                    "diagnostics": {
                        "removed": [
                            "Return type annotation is missing"
                        ]
                    }
                },
                "torch.nn.cpp.OrderedDictWrapper.values": {
                    "isTypeKnown": true,
                    "diagnostics": {
                        "removed": [
                            "Return type annotation is missing"
                        ]
                    }
                }
            },
            "removed": []
        },
        "file_names": [
            "/home/rec/git/pytorch/outputs/pyright.1c16c18a534.json",
            "/home/rec/git/pytorch/outputs/pyright.11e97bc7bd4.json"
        ]
    }
]
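
The "absolute" and "percent" blocks in the output above are before/after differences. As a rough sketch of that arithmetic (this is not the utility's actual code, the function name is hypothetical, and the sample numbers are made up rather than taken from the two reports):

```python
def report_delta(before: dict, after: dict) -> dict:
    """Return absolute and percent change for each numeric field."""
    absolute: dict = {}
    percent: dict = {}
    for key, old in before.items():
        new = after[key]
        if isinstance(old, dict):
            # Recurse into nested sections such as exportedSymbolCounts.
            nested = report_delta(old, new)
            absolute[key] = nested["absolute"]
            percent[key] = nested["percent"]
        else:
            absolute[key] = new - old
            # Report 0.0 when the baseline is zero to avoid dividing by zero.
            percent[key] = 100.0 * (new - old) / old if old else 0.0
    return {"absolute": absolute, "percent": percent}


# Illustrative numbers only; they do not come from the reports above.
before = {
    "completenessScore": 0.50,
    "exportedSymbolCounts": {"withKnownType": 1000, "withUnknownType": 1000},
}
after = {
    "completenessScore": 0.51,
    "exportedSymbolCounts": {"withKnownType": 1010, "withUnknownType": 990},
}
delta = report_delta(before, after)
```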

namespace torch::python {
namespace {
template <typename T>
void bind_ordered_dict(py::module module, const char* dict_name) {
Collaborator

Ah, this is going to break C++20 because module is a keyword there, fun...

# must manually override them.

-def items(self):
+def items(self) -> list[tuple[str, _T]]: # type: ignore[override]
Collaborator

Suggested change
-def items(self) -> list[tuple[str, _T]]: # type: ignore[override]
+def items(self) -> collections.abc.ItemsView[str, _T]:

Collaborator Author

@Skylion007 This would actually be incorrect. What this method returns is not an ItemsView but a list of tuples. From my read of this code, we'd have to rewrite parts of the underlying C++ class to make an ItemsView possible. The same applies to the two comments below.
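
A short sketch of why the two return types are not interchangeable, using a plain `dict` for illustration rather than the wrapper itself. An `ItemsView` is a live view onto the mapping, while a list of tuples is a snapshot taken at call time:

```python
from collections.abc import ItemsView

plain = {"weight": 1, "bias": 2}

view = plain.items()  # live, set-like view onto the dict
snapshot = list(plain.items())  # behaves like a list-returning items()

plain["scale"] = 3  # mutate after both objects were created

view_is_itemsview = isinstance(view, ItemsView)
snapshot_is_itemsview = isinstance(snapshot, ItemsView)
view_len = len(view)  # reflects the new entry
snapshot_len = len(snapshot)  # still the size at call time
```

Annotating a list-returning method as `ItemsView` would therefore promise behaviour (live updates, set operations) that the returned object does not have.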

return self.cpp_dict.items()

-def keys(self):
+def keys(self) -> list[str]: # type: ignore[override]
Collaborator

Suggested change
-def keys(self) -> list[str]: # type: ignore[override]
+def keys(self) -> collections.abc.KeysView[str]:

return self.cpp_dict.keys()

-def values(self):
+def values(self) -> list[_T]: # type: ignore[override]
Collaborator

Suggested change
-def values(self) -> list[_T]: # type: ignore[override]
+def values(self) -> collections.abc.ValuesView[_T]:

-def __iter__(self):
+# This should return an Iterator[str], but OrderedDict::item is not currently
+# designed to let us iterate over only the keys.
+def __iter__(self) -> Iterator[tuple[str, _T]]: # type: ignore[override]
Collaborator

Why can't we just call iter on keys? Or is that an API change...

Collaborator Author

No API change; this wraps an underlying C++ custom class we've implemented, and we don't have methods available for simply looking at the keys. One could possibly be implemented, but I was more interested in documenting the existing behavior in this PR.
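
A sketch of the consequence of that existing behavior, using a hypothetical `PairIterating` class rather than the real wrapper. The mixin `keys()` that `collections.abc.Mapping` provides is built on `__iter__`, so when `__iter__` yields (key, value) pairs the inherited `keys()` yields pairs too, which is why methods like it have to be overridden by hand:

```python
from collections.abc import Mapping
from typing import Iterator


class PairIterating(Mapping):
    """Hypothetical mapping whose __iter__ yields (key, value) pairs."""

    def __init__(self, pairs: list[tuple[str, int]]) -> None:
        self._pairs = pairs

    def __getitem__(self, key: str) -> int:
        for name, value in self._pairs:
            if name == key:
                return value
        raise KeyError(key)

    def __len__(self) -> int:
        return len(self._pairs)

    def __iter__(self) -> Iterator[tuple[str, int]]:  # type: ignore[override]
        return iter(self._pairs)


class WithOverrides(PairIterating):
    """Same class with keys() overridden by hand."""

    def keys(self) -> list[str]:  # type: ignore[override]
        return [name for name, _ in self._pairs]


pairs = [("weight", 1), ("bias", 2)]
inherited_keys = list(PairIterating(pairs).keys())  # pairs, not keys
overridden_keys = WithOverrides(pairs).keys()  # actual keys
```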

Labels

module: cpp-extensions Related to torch.utils.cpp_extension module: nn Related to torch.nn open source release notes: cpp release notes category

8 participants