Skip to content

torch.compile with Custom tensor subclass doesn't inline the tensor subclass methods  #128149

@tugsbayasgalan

Description

@tugsbayasgalan

🐛 Describe the bug

import torch
import torch.utils._pytree as pytree
from torch.utils._python_dispatch import return_and_correct_aliasing


# A simple tensor subclass that holds a tensor with custom metadata and custom method
class CustomTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, elem):
        shape = elem.shape
        kwargs = {}
        kwargs["strides"] = elem.stride()
        kwargs["storage_offset"] = elem.storage_offset()
        kwargs["device"] = elem.device
        kwargs["layout"] = elem.layout
        kwargs["requires_grad"] = elem.requires_grad
        kwargs["dtype"] = elem.dtype
        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)

    def __init__(self, elem):
        self.elem = elem
        self.constant_attribute = 4

    def __repr__(self):
        inner_repr = repr(self.elem)
        return f"CustomTensor({inner_repr})"

    def __tensor_flatten__(self):
        return ["elem"], self.constant_attribute

    def add_constant(self, a):
        self.constant_attribute += a

    @staticmethod
    def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
        assert meta is not None
        elem = inner_tensors["elem"]
        out = CustomTensor(elem)
        out.constant_attribute = meta
        return out

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs):
        if kwargs is None:
            kwargs = {}
        args_inner = pytree.tree_map_only(CustomTensor, lambda x: x.elem, args)

        kwargs_inner = pytree.tree_map_only(CustomTensor, lambda x: x.elem, kwargs)

        out_inner = func(*args_inner, **kwargs_inner)
        out_inner_flat, spec = pytree.tree_flatten(out_inner)
        # for aten ops that return non-tensors, just assume that
        # our cust inner tensors return the same value
        out_flat = [
            CustomTensor(o_inner) if isinstance(o_inner, torch.Tensor) else o_inner
            for o_inner in out_inner_flat
        ]
        out = pytree.tree_unflatten(out_flat, spec)
        return return_and_correct_aliasing(func, args, kwargs, out)


def f(x, y):
    """Call ``x.add_constant(4)``, then return ``x.cos() + y.cos()``."""
    x.add_constant(4)
    cos_of_x = x.cos()
    cos_of_y = y.cos()
    return cos_of_x + cos_of_y


# Two plain tensors: `a` is wrapped and used in the eager run, `b` is the
# plain-tensor argument of the compiled run.
a = torch.ones(4, requires_grad=True)
b = torch.ones(4, requires_grad=True)
custom_a = CustomTensor(a)
# Separate input for the compiled call, derived from custom_a before f mutates it.
custom_a_for_compile = custom_a.detach().clone().requires_grad_()

# Eager reference run: f calls add_constant(4) on its first argument.
print(custom_a.constant_attribute)  # prints 4
f(custom_a, a)
print(custom_a.constant_attribute)  # prints 8

# Same function through torch.compile with the "aot_eager" backend; the
# attribute mutation performed by add_constant is not reflected afterwards.
compiled_f = torch.compile(f, backend="aot_eager")
compiled_f(custom_a_for_compile, b)
print(custom_a_for_compile.constant_attribute)  # prints 4 but should be 8

Versions

main

cc @ezyang @msaroufim @albanD @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @bdhirsh @anijain2305

Metadata

Metadata

Assignees

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions