Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions test/dynamo/test_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -1261,6 +1261,26 @@ def test_self_mutating1(self):
else:
self.assertExpectedInline(cnt.frame_count, """1""")

def test_nn_module_setattr(self):
    """Rebinding a plain attribute on an nn.Module input must invalidate guards.

    ``f`` is compiled with ``dynamic=False`` and reads ``m.var``. After
    ``m.var`` is rebound, the previously compiled code must not be reused:
    the result must reflect the new value and a second compile must occur.
    On Python 3.13.0, ``setattr`` does not fire the instance ``__dict__``'s
    watchers, so a guard relying on ``__dict__`` versions could miss the change.
    """

    class Mod(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.var = 0

    # Use a CompileCounter backend so the recompile itself can be asserted,
    # not just the numerical result.
    cnt = torch._dynamo.testing.CompileCounter()

    @torch.compile(backend=cnt, dynamic=False)
    def f(x, m):
        return x + m.var

    inp = torch.ones(3)
    m = Mod()

    self.assertEqual(f(inp, m), inp)
    self.assertEqual(cnt.frame_count, 1)
    # In 3.13.0, setattr will not fire a __dict__'s watchers,
    # so guards may not be invalidated.
    m.var = 1
    # The stale guard must fail: expect the new value AND a second compile.
    self.assertEqual(f(inp, m), inp + 1)
    self.assertEqual(cnt.frame_count, 2)

@patch.object(torch._dynamo.config, "raise_on_ctx_manager_usage", False)
def test_generation_tag(self):
cnt = torch._dynamo.testing.CompileCounter()
Expand Down
35 changes: 27 additions & 8 deletions torch/_dynamo/guards.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import textwrap
import time
import types
import warnings
import weakref
from contextlib import contextmanager
from copy import deepcopy
Expand Down Expand Up @@ -651,6 +652,20 @@ def guard_on_dict_keys_and_order(self, value, guard):
key, get_verbose_code_parts(f"{key_source} == {key!r}", guard)
)

@staticmethod
def _get_generic_dict_manager_example_value(example_value):
    """Return the example value to pass to ``get_generic_dict_manager``.

    Args:
        example_value: the ``__dict__`` of the object being guarded.

    Returns:
        ``example_value`` unchanged on unaffected Python versions. On
        CPython 3.13.0 it returns ``None`` and emits a ``RuntimeWarning``.
        NOTE(review): presumably ``None`` makes the guard manager skip its
        ``__dict__``-version fast path; confirm against
        ``get_generic_dict_manager``.
    """
    # Due to a bug in 3.13.0 (introduced by https://github.com/python/cpython/pull/116115,
    # reported in https://github.com/python/cpython/issues/125608,
    # fixed by https://github.com/python/cpython/pull/125611), we cannot take
    # advantage of __dict__ versions to speed up guard checks.
    if (3, 13) <= sys.version_info < (3, 13, 1):
        # RuntimeWarning rather than UserWarning: this is a runtime
        # performance caveat, not API misuse. Under the default warnings
        # filter a warning is reported once per call site, so repeated
        # guard construction does not spam the user.
        warnings.warn(
            "Guards may run slower on Python 3.13.0. Consider upgrading to Python 3.13.1+.",
            RuntimeWarning,
        )
        return None
    return example_value

def getattr_on_nn_module(
self,
source,
Expand Down Expand Up @@ -776,7 +791,7 @@ def getitem_on_dict_mgr(
# Guard Manager
mod_generic_dict_manager = base_guard_manager.get_generic_dict_manager(
source=mod_dict_source,
example_value=mod_dict,
example_value=self._get_generic_dict_manager_example_value(mod_dict),
guard_manager_enum=GuardManagerType.GUARD_MANAGER,
)

Expand Down Expand Up @@ -1271,7 +1286,7 @@ def NOT_PRESENT_IN_GENERIC_DICT(self, guard: Guard, attr=None) -> None:
mod_dict_source = f"{guard.name}.__dict__"
mod_generic_dict_manager = base_manager.get_generic_dict_manager(
source=mod_dict_source,
example_value=val.__dict__,
example_value=self._get_generic_dict_manager_example_value(val.__dict__),
guard_manager_enum=GuardManagerType.GUARD_MANAGER,
)

Expand Down Expand Up @@ -2261,12 +2276,16 @@ def add_code_part(code_part, guard, log_only=False):
structured_guard_fns.append(
lambda: {
"code": code_part,
"stack": structured.from_traceback(guard.stack.summary())
if guard.stack
else None,
"user_stack": structured.from_traceback(guard.user_stack)
if guard.user_stack
else None,
"stack": (
structured.from_traceback(guard.stack.summary())
if guard.stack
else None
),
"user_stack": (
structured.from_traceback(guard.user_stack)
if guard.user_stack
else None
),
}
)

Expand Down
4 changes: 3 additions & 1 deletion torch/csrc/dynamo/guards.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
#include <torch/csrc/utils/pythoncapi_compat.h>
#include <torch/extension.h>

#include <torch/csrc/dynamo/debug_macros.h>

#ifdef USE_CUDA
#include <ATen/cuda/EmptyTensor.h>
#endif
Expand Down Expand Up @@ -655,7 +657,7 @@ static PyObject* check_obj_id(PyObject* dummy, PyObject* args) {

static std::unordered_map<PyObject*, uint64_t> dict_version_map;
static int dict_version_watcher_id;
static uint64_t global_dict_version_id = 0;
static uint64_t global_dict_version_id = 1;
static int dict_version_watch_callback(
PyDict_WatchEvent event,
PyObject* dict,
Expand Down