Fake tensor leakage in aot_autograd by itself #164732

@tugsbayasgalan

Description

🐛 Describe the bug

import contextlib

import torch
from torch._functorch.aot_autograd import aot_export_joint_with_descriptors


class Foo(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.dict = {}

    def forward(self, x):
        val = x.sum()
        # Side effect: stashes the traced value on the module instance.
        self.add_val("foo", val)
        return val

    def add_val(self, name, val):
        if name not in self.dict:
            self.dict[name] = val
        else:
            # On a second call, the previously stored value is still here —
            # during export it is a functionalized FakeTensor, not a real tensor.
            print("THERERE", self.dict[name])
            self.dict[name] += val


with contextlib.ExitStack() as stack:
    joint_with_descriptors = aot_export_joint_with_descriptors(
        stack,
        Foo(),
        (torch.randn(4, 4),),
    )
print(joint_with_descriptors.graph_module)

This prints:

THERERE FunctionalTensor(_to_functional_tensor(FakeTensor(..., size=())))

This happens because the first pass of aot_autograd runs the module's `forward` on the original module with fake/functional tensors, so the side-effecting `add_val` call stores one of those tracing-time tensors in `self.dict`, corrupting the model's state.
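A minimal sketch of the failure mode, without torch: `FakeValue`, the toy `Foo`, and the `trace` helper below are illustrative stand-ins for FakeTensors and the aot_autograd tracing pass, and the deep-copy workaround is only a sketch of one possible mitigation, not the actual fix in PyTorch.

```python
import copy


class FakeValue:
    """Stand-in for a FakeTensor seen during tracing (illustration only)."""

    def __repr__(self):
        return "FakeValue()"


class Foo:
    def __init__(self):
        self.dict = {}

    def forward(self, x):
        # Side effect: stores whatever value forward() was called with.
        self.dict.setdefault("foo", x)
        return x


def trace(module, fake_input):
    # A tracing pass runs forward() on the *original* module with fake
    # values, so any state mutation leaks those fake values into it.
    module.forward(fake_input)


# Leaky: tracing the module directly corrupts its state.
leaky = Foo()
trace(leaky, FakeValue())
assert isinstance(leaky.dict["foo"], FakeValue)  # fake value leaked

# Workaround sketch: trace a deep copy so the user's module is untouched.
safe = Foo()
trace(copy.deepcopy(safe), FakeValue())
assert safe.dict == {}  # original state preserved
```

The second `trace` call mutates only the throwaway copy, which mirrors why tracing should not run user code against the real module instance when that code can mutate module state.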

Versions

main

cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @chauhang @penguinwu @eellison @Chillee @samdow @kshitij12345 @bdhirsh
