[cpp_extension] Add abstract base class to OrderedDictWrapper, and clarify torch.nn.Module typing #160897
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/160897
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures
As of commit b8dc3d8 with merge base 1c16c18.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
replica._buffers = dict(replica._buffers)
replica._modules = dict(replica._modules)
MutableMapping doesn't define .copy. This was the only code change that was required to make this change work.
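To illustrate the point about .copy, here is a minimal sketch. TinyMapping and shallow_copy are hypothetical names invented for this example and are not part of PyTorch; the class only stands in for any MutableMapping that is not a dict. It suggests why dict(...) is a safer way to take a shallow copy once the attribute is typed as MutableMapping.

```python
from collections.abc import MutableMapping
from typing import Iterator


class TinyMapping(MutableMapping[str, int]):
    """Hypothetical stand-in for a non-dict mapping; not a PyTorch class."""

    def __init__(self) -> None:
        self._data: dict[str, int] = {}

    def __getitem__(self, key: str) -> int:
        return self._data[key]

    def __setitem__(self, key: str, value: int) -> None:
        self._data[key] = value

    def __delitem__(self, key: str) -> None:
        del self._data[key]

    def __iter__(self) -> Iterator[str]:
        return iter(self._data)

    def __len__(self) -> int:
        return len(self._data)


def shallow_copy(mapping: MutableMapping[str, int]) -> dict[str, int]:
    # dict(...) accepts any mapping, whereas .copy() is specific to dict.
    return dict(mapping)


source = TinyMapping()
source["weight"] = 1

copied = shallow_copy(source)
copied["bias"] = 2  # mutating the copy leaves the source untouched
```

The copy is always a plain dict, which matches the replica code above: whatever mapping type the original module used, the replica ends up with an ordinary dict.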
cc @albanD mind taking a look?
training: bool
_parameters: dict[str, Optional[Parameter]]
_buffers: dict[str, Optional[Tensor]]
_parameters: MutableMapping[str, Optional[Parameter]]
Why change this? All of these are actually dict right?
@albanD Only in some cases! For instance, they could be OrderedDictWrapper, which wraps a pybind11 C++ dict-like class.
I don't think this is really used in practice but ok.
I'll let @lolpack decide on this one since it's pure typing.
It's hard to say what the downstream effect of this change is, but if the underlying value is not always a dict, then this seems like a legitimate increase in strictness.
Sorry to kick the question back to you, but it's hard to say what the downstream impact on users would be without more context. If this is not user facing, then IMO you might as well switch. If it is user facing, then I think you need to trade off between backwards-compat, convenience, and safety.
Maybe you know that in all user-facing cases, the value is actually a dict, and users can safely use dict-only methods? In that case, I say continue using dict.
If the MutableMapping interface is sufficient for common use cases, maybe it's worth changing it for safety, even at the risk of some churn for users.
I think it's a judgment call.
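One way to size up the trade-off described above is to list which public methods exist on dict but not on MutableMapping. The snippet below is only a rough sketch of that comparison using the standard library; it says nothing about which of these methods PyTorch users actually call on these attributes.

```python
from collections.abc import MutableMapping

# Public names available on dict but missing from the MutableMapping ABC.
# Code that relies on any of these would be flagged by a type checker once
# the attribute is annotated as MutableMapping instead of dict.
dict_only = sorted(
    name
    for name in set(dir(dict)) - set(dir(MutableMapping))
    if not name.startswith("_")
)

print(dict_only)
```

On recent CPython versions the list should be short and include copy and fromkeys, which is consistent with the observation that only the .copy() calls needed changing.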
I ran a "delta" on the pyright type checking (using this utility) before and after this commit: the change is not huge, but universally positive, adding +0.12% to "completenessScore".
namespace torch::python {
namespace {
template <typename T>
void bind_ordered_dict(py::module module, const char* dict_name) {
Ah, this is going to break cpp20 because module is a keyword, fun...
# must manually override them.

def items(self):
def items(self) -> list[tuple[str, _T]]: # type: ignore[override]
def items(self) -> list[tuple[str, _T]]: # type: ignore[override]
def items(self) -> collections.abc.ItemsView[str, _T]:
@Skylion007 So this would actually be incorrect. What is returned from this method is not an ItemsView, but a list of tuple. We'd have to rewrite parts of the underlying C++ class to make an ItemsView possible, is my read after tackling this code. Same applies to the two comments below.
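The difference between an ItemsView and a list of tuples is observable at runtime, not just in annotations. The sketch below uses a plain dict to show it; the variable names are made up for the example, and it does not exercise OrderedDictWrapper itself.

```python
from collections.abc import ItemsView

params = {"weight": 1}

view = params.items()            # a live view backed by the dict
snapshot = list(params.items())  # a list of tuples, fixed at call time

params["bias"] = 2

# The view tracks later mutations; the list does not.
view_length = len(view)
snapshot_length = len(snapshot)

view_is_items_view = isinstance(view, ItemsView)
snapshot_is_items_view = isinstance(snapshot, ItemsView)
```

Annotating a method that returns a snapshot list as ItemsView would therefore promise view semantics that the returned object does not have.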
return self.cpp_dict.items()

def keys(self):
def keys(self) -> list[str]: # type: ignore[override]
def keys(self) -> list[str]: # type: ignore[override]
def keys(self) -> collections.abc.KeysView[str]:
return self.cpp_dict.keys()

def values(self):
def values(self) -> list[_T]: # type: ignore[override]
def values(self) -> list[_T]: # type: ignore[override]
def values(self) -> collections.abc.ValuesView[_T]:
def __iter__(self):
# This should return an Iterator[str], but OrderedDict::item is not currently
# designed to let us iterate over only the keys.
def __iter__(self) -> Iterator[tuple[str, _T]]: # type: ignore[override]
Why can't we just call iter keys? or is it an API change...
No API change; this wraps an underlying C++ custom class we've implemented, and we don't have methods available for simply looking at the keys. One could possibly be implemented, but I was more interested in documenting the existing behavior in this PR.
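A rough sketch of why an __iter__ that yields (key, value) pairs forces manual overrides of the mixin methods: the inherited keys() from collections.abc builds a KeysView that simply iterates the mapping, so it would yield pairs too. PairIterMapping is a hypothetical class written for this illustration; it mimics the behavior described in this thread and is not the actual OrderedDictWrapper code.

```python
from collections.abc import MutableMapping
from typing import Iterator


class PairIterMapping(MutableMapping):
    """Hypothetical mapping whose __iter__ yields (key, value) pairs."""

    def __init__(self, data: dict[str, int]) -> None:
        self._data = dict(data)

    def __getitem__(self, key: str) -> int:
        return self._data[key]

    def __setitem__(self, key: str, value: int) -> None:
        self._data[key] = value

    def __delitem__(self, key: str) -> None:
        del self._data[key]

    def __len__(self) -> int:
        return len(self._data)

    # Deviates from the Mapping contract, which expects an iterator of keys.
    def __iter__(self) -> Iterator[tuple[str, int]]:  # type: ignore[override]
        return iter(self._data.items())

    # Manual override, because the inherited mixin relies on __iter__.
    def keys(self) -> list[str]:  # type: ignore[override]
        return list(self._data)


wrapper = PairIterMapping({"weight": 1})

pairs_from_iter = list(wrapper)
# Calling the inherited mixin directly shows what would happen without the override.
inherited_keys = list(MutableMapping.keys(wrapper))
overridden_keys = wrapper.keys()
```

In this sketch the inherited keys() yields pairs rather than keys, which is the kind of mismatch the manual overrides and the type: ignore[override] comments in the diff are documenting.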
Stack from ghstack (oldest at bottom):
cc @malfet @zou3519 @albanD @mruberry @jbschlosser @walterddr @mikaylagawarecki