-
Notifications
You must be signed in to change notification settings - Fork 25.5k
Open
Open
Copy link
Labels
dynamo-tensor-subclasses · dynamo-triage-june2024 · module: dynamo · tensor subclass (Related to tensor subclasses) · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
Description
🐛 Describe the bug
import torch
import torch.utils._pytree as pytree
from torch.utils._python_dispatch import return_and_correct_aliasing
# A simple tensor subclass that holds a tensor with custom metadata and a custom method
class CustomTensor(torch.Tensor):
@staticmethod
def __new__(cls, elem):
shape = elem.shape
kwargs = {}
kwargs["strides"] = elem.stride()
kwargs["storage_offset"] = elem.storage_offset()
kwargs["device"] = elem.device
kwargs["layout"] = elem.layout
kwargs["requires_grad"] = elem.requires_grad
kwargs["dtype"] = elem.dtype
return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)
def __init__(self, elem):
self.elem = elem
self.constant_attribute = 4
def __repr__(self):
inner_repr = repr(self.elem)
return f"CustomTensor({inner_repr})"
def __tensor_flatten__(self):
return ["elem"], self.constant_attribute
def add_constant(self, a):
self.constant_attribute += a
@staticmethod
def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
assert meta is not None
elem = inner_tensors["elem"]
out = CustomTensor(elem)
out.constant_attribute = meta
return out
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
if kwargs is None:
kwargs = {}
args_inner = pytree.tree_map_only(CustomTensor, lambda x: x.elem, args)
kwargs_inner = pytree.tree_map_only(CustomTensor, lambda x: x.elem, kwargs)
out_inner = func(*args_inner, **kwargs_inner)
out_inner_flat, spec = pytree.tree_flatten(out_inner)
# for aten ops that return non-tensors, just assume that
# our cust inner tensors return the same value
out_flat = [
CustomTensor(o_inner) if isinstance(o_inner, torch.Tensor) else o_inner
for o_inner in out_inner_flat
]
out = pytree.tree_unflatten(out_flat, spec)
return return_and_correct_aliasing(func, args, kwargs, out)
def f(x, y):
    """Mutate ``x`` via ``add_constant(4)``, then return ``x.cos() + y.cos()``.

    The attribute mutation on ``x`` is a side effect that the caller can
    observe on the input after the call returns.
    """
    x.add_constant(4)
    return x.cos() + y.cos()
# Two separate leaf tensors: `a` is used by the eager call, `b` by the compiled call.
a = torch.ones(4, requires_grad=True)
b = torch.ones(4, requires_grad=True)
custom_a = CustomTensor(a)
# Independent copy made before `f` mutates `custom_a`; each re-wrapped result
# goes through __init__, so this copy also starts with constant_attribute == 4.
custom_a_for_compile = custom_a.detach().clone().requires_grad_()
print(custom_a.constant_attribute) # prints 4
# Eager run: the attribute mutation inside `f` is visible on the input afterwards.
f(custom_a, a)
print(custom_a.constant_attribute) # prints 8
# Compiled run of the same function, using the "aot_eager" backend.
compiled_f = torch.compile(f, backend="aot_eager")
compiled_f(custom_a_for_compile, b)
# Reported bug: the mutation performed inside the compiled function is not
# reflected on the input object after the call.
print(custom_a_for_compile.constant_attribute) # prints 4 but should be 8
Versions
main
cc @ezyang @msaroufim @albanD @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @bdhirsh @anijain2305
Metadata
Metadata
Assignees
Labels
dynamo-tensor-subclasses · dynamo-triage-june2024 · module: dynamo · tensor subclass (Related to tensor subclasses) · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)