🐛 Describe the bug
```python
import contextlib

import torch
from torch._functorch.aot_autograd import aot_export_joint_with_descriptors


class Foo(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.dict = {}

    def forward(self, x):
        val = x.sum()
        self.add_val("foo", val)
        return val

    def add_val(self, name, val):
        if name not in self.dict:
            self.dict[name] = val
        else:
            print("THERERE", self.dict[name])
            self.dict[name] += val


with contextlib.ExitStack() as stack:
    joint_with_descriptors = aot_export_joint_with_descriptors(
        stack,
        Foo(),
        (torch.randn(4, 4),),
    )
    print(joint_with_descriptors.graph_module)
```
This prints:

```
THERERE FunctionalTensor(_to_functional_tensor(FakeTensor(..., size=())))
```
This happens because the first tracing pass of aot_autograd mutates the module's Python state: `self.dict["foo"]` gets populated with a traced `FunctionalTensor` wrapping a `FakeTensor` rather than a real value, so the model is corrupted for any subsequent pass (or eager call).
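Until the tracing pass stops leaking fake values into module attributes, one possible workaround (not an official API, just a pattern) is to export a deep copy of the module so only the copy's state is mutated. A torch-free sketch of the idea, using a plain Python stand-in for the stateful module above:

```python
import copy


class Foo:
    """Stand-in for the stateful nn.Module from the repro above."""

    def __init__(self):
        self.dict = {}

    def forward(self, x):
        val = x  # stand-in for x.sum()
        if "foo" not in self.dict:
            self.dict["foo"] = val
        else:
            self.dict["foo"] += val  # mutates state on every call
        return val


model = Foo()

# Trace/export a deep copy: the pass mutates only the copy's state.
traced = copy.deepcopy(model)
traced.forward(1.0)

assert model.dict == {}               # original state is intact
assert traced.dict == {"foo": 1.0}    # only the copy was mutated
```

Note that `copy.deepcopy` on a real `nn.Module` also duplicates parameters, so this trades memory for safety; it sidesteps the corruption but does not fix the underlying bug.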
Versions
main
cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @chauhang @penguinwu @eellison @Chillee @samdow @kshitij12345 @bdhirsh