
Can't call torch.compile inside of a custom op #151328

@zou3519

Description

import torch

lib = torch.library.Library("mylib", "FRAGMENT")
lib.define("foo(Tensor x) -> Tensor")


def inner(x):
    return x.sin().cos()

def foo_impl(x):
    # The eager kernel itself calls torch.compile; this nested compile is
    # what fails when foo is invoked from an outer compiled region.
    return torch.compile(inner, fullgraph=True)(x)

lib.impl("foo", foo_impl, "CompositeExplicitAutograd")

# Calling the custom op from a compiled function produces the error below.
@torch.compile(fullgraph=True)
def f(x):
    return torch.ops.mylib.foo.default(x)

x = torch.randn(3)
f(x)
"""
File ~/dev/misc_cpu11/pt-misc_cpu11/torch/_subclasses/meta_utils.py:894, in MetaConverter.meta_tensor(self, t, shape_env, callback_, source, symbolic_context)
    886     source = ConstantSource(
    887         f"__meta_utils_unknown_tensor{len(self.tensor_memo)}"
    888     )
    890 # This indicates you set no_dispatch() before calling into this
    891 # function.  This is an error: we may be creating fake tensors and
    892 # will perform operations on them which need fake tensor mode to
    893 # be active.  You will segfault if you are in a no_dispatch() block.
--> 894 assert not torch._C._dispatch_tls_local_exclude_set().has(
    895     torch._C.DispatchKey.Python
    896 )
    897 self.arg_cnt += 1
    899 # When we make as_strided calls, we end up generating a guard
    900 # that the new as_strided tensor is in bounds for the old storage
    901 # for the base (since as_strided calls can "bust" out of their
   (...)
    921 # as we allocate variables, and we do need to register guards for
    922 # these cases.

TorchRuntimeError: Dynamo failed to run FX node with fake tensors: call_function mylib.foo.default(*(FakeTensor(..., size=(3,)),), **{}): got AssertionError('\n\nfrom user code:\n   File "<ipython-input-2-9e7ce20b02c0>", line 8, in inner\n    return x.sin().cos()\n\nSet TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you\'re reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"\n')

from user code:
   File "<ipython-input-2-9e7ce20b02c0>", line 17, in f
    return torch.ops.mylib.foo.default(x)

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
"""

The motivation is that we want the custom op to be backed by a torch.compile'd implementation.
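For concreteness, a minimal sketch of what that could look like with the torch.library.custom_op API (PyTorch 2.4+); the op name and helper names below are illustrative, and making the compiled body compose with an outer torch.compile is exactly the feature being requested:

import torch

@torch.compile(fullgraph=True)
def _compiled_body(x):
    return x.sin().cos()

# Hypothetical op name; the custom op is opaque to an outer compile,
# while its eager kernel runs the compiled body.
@torch.library.custom_op("mylib::compiled_foo", mutates_args=())
def compiled_foo(x: torch.Tensor) -> torch.Tensor:
    return _compiled_body(x)

# Fake kernel for tracing: metadata propagation only.
@compiled_foo.register_fake
def _(x):
    return torch.empty_like(x)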

cc @ezyang @gchanan @kadeng @msaroufim @chauhang @penguinwu @eellison @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @amjames @bdhirsh

Metadata

Labels

dynamo-triage-jan2025, feature (A request for a proper, new feature), high priority, module: dynamo, module: fakeTensor, module: pt2-dispatcher (PT2 dispatcher-related issues, e.g. aotdispatch, functionalization, faketensor, custom-op), oncall: pt2, triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
