torch.compile() within TorchDispatchMode always causes an unknown guard failure. #144787

🐛 Describe the bug

When I run a torch.compile()'d function under an "infra" TorchDispatchMode, a recompile appears to happen on every call, and the guard-failure message does not say which guard failed:

import torch
from torch.overrides import TorchFunctionMode
from torch.utils._python_dispatch import TorchDispatchMode

from torch._dynamo import config

# A no-op TorchFunctionMode that simply forwards every call.
class MyFunctionMode(TorchFunctionMode):
    def __torch_function__(self, func, types, args, kwargs=None):
        return func(*args, **(kwargs or {}))

# A no-op TorchDispatchMode marked as an "infra" mode, which Dynamo is
# supposed to be able to trace under.
class MyDispatchMode(TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args, kwargs=None):
        return func(*args, **(kwargs or {}))

    @classmethod
    def is_infra_mode(cls):
        return True

def f(x, y):
    return x @ y

x = torch.ones(10, device="cuda")

mode = MyFunctionMode()

f_compiled = torch.compile(f, backend="eager")

# First pass: allow the initial compile. Second pass: the call should hit
# the compiled-code cache, so turn any recompile into a hard error.
for i in range(2):
    config.error_on_recompile = i == 1
    with mode:
        f_compiled(x, x)

mode = MyDispatchMode()

# Same two-pass check under the dispatch mode. The second call should hit
# the cache, but instead raises RecompileError with an empty guard reason.
for i in range(2):
    config.error_on_recompile = i == 1
    with mode:
        f_compiled(x, x)
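
As an aside, the two-pass pattern above can be written more directly with torch._dynamo.config.patch; this is a minimal sketch equivalent to the loops above:

import torch._dynamo

with mode:
    f_compiled(x, x)  # warm-up call: the initial compile is expected
# Any recompile on the second call now raises RecompileError.
with torch._dynamo.config.patch(error_on_recompile=True), mode:
    f_compiled(x, x)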

Running the above script on top-of-tree PyTorch produces the following output:

I0114 18:25:17.922947 2151712 torch/_dynamo/utils.py:1521] [0/0] ChromiumEventLogger initialized with id eeb788f2-8d2b-4de5-adf7-22df55d8491d
I0114 18:25:17.924832 2151712 torch/_dynamo/symbolic_convert.py:2744] [0/0] Step 1: torchdynamo start tracing f /home/dgalvez/code/asr/pytorch-4/repros/dispatch_mode_recompile.py:19
I0114 18:25:17.925518 2151712 torch/fx/experimental/symbolic_shapes.py:3243] [0/0] create_env
I0114 18:25:17.946520 2151712 torch/_dynamo/symbolic_convert.py:3066] [0/0] Step 1: torchdynamo done tracing f (RETURN_VALUE)
I0114 18:25:17.950973 2151712 torch/_dynamo/output_graph.py:1460] [0/0] Step 2: calling compiler function eager
I0114 18:25:17.951271 2151712 torch/_dynamo/output_graph.py:1465] [0/0] Step 2: done compiler function eager
I0114 18:25:17.954654 2151712 torch/fx/experimental/symbolic_shapes.py:4623] [0/0] produce_guards
I0114 18:25:17.956163 2151712 torch/_dynamo/pgo.py:647] [0/0] put_code_state: no cache key, skipping
I0114 18:25:17.956523 2151712 torch/_dynamo/convert_frame.py:1078] [0/0] run_gc_after_compile: running gc
I0114 18:25:17.984054 2151712 torch/_dynamo/symbolic_convert.py:2744] [0/1] Step 1: torchdynamo start tracing f /home/dgalvez/code/asr/pytorch-4/repros/dispatch_mode_recompile.py:19
I0114 18:25:17.984482 2151712 torch/fx/experimental/symbolic_shapes.py:3243] [0/1] create_env
I0114 18:25:17.988030 2151712 torch/_dynamo/symbolic_convert.py:3066] [0/1] Step 1: torchdynamo done tracing f (RETURN_VALUE)
I0114 18:25:17.989872 2151712 torch/_dynamo/output_graph.py:1460] [0/1] Step 2: calling compiler function eager
I0114 18:25:17.990141 2151712 torch/_dynamo/output_graph.py:1465] [0/1] Step 2: done compiler function eager
I0114 18:25:17.992269 2151712 torch/fx/experimental/symbolic_shapes.py:4623] [0/1] produce_guards
I0114 18:25:17.993348 2151712 torch/_dynamo/pgo.py:647] [0/1] put_code_state: no cache key, skipping
I0114 18:25:17.993675 2151712 torch/_dynamo/convert_frame.py:1078] [0/1] run_gc_after_compile: running gc
Traceback (most recent call last):
  File "/home/dgalvez/code/asr/pytorch-4/repros/dispatch_mode_recompile.py", line 44, in <module>
    f_compiled(x, x)
  File "/home/dgalvez/code/asr/pytorch-4/torch/_dynamo/eval_frame.py", line 576, in _fn
    return fn(*args, **kwargs)
  File "/home/dgalvez/code/asr/pytorch-4/torch/_dynamo/convert_frame.py", line 1422, in __call__
    return self._torchdynamo_orig_callable(
  File "/home/dgalvez/code/asr/pytorch-4/torch/_dynamo/convert_frame.py", line 1203, in __call__
    result = self._inner_convert(
  File "/home/dgalvez/code/asr/pytorch-4/torch/_dynamo/convert_frame.py", line 569, in __call__
    return _compile(
  File "/home/dgalvez/code/asr/pytorch-4/torch/_dynamo/convert_frame.py", line 920, in _compile
    recompile_reasons = get_and_maybe_log_recompilation_reason(
  File "/home/dgalvez/code/asr/pytorch-4/torch/_dynamo/guards.py", line 2780, in get_and_maybe_log_recompilation_reason
    raise exc.RecompileError(message)
torch._dynamo.exc.RecompileError: Recompiling function f in /home/dgalvez/code/asr/pytorch-4/repros/dispatch_mode_recompile.py:19
    triggered by the following guard failure(s):
    - 0/1:
    - 0/0: ___check_torch_function_mode_stack()
I0114 18:25:18.005497 2151712 torch/_dynamo/eval_frame.py:398] TorchDynamo attempted to trace the following frames: [
I0114 18:25:18.005497 2151712 torch/_dynamo/eval_frame.py:398]   * f /home/dgalvez/code/asr/pytorch-4/repros/dispatch_mode_recompile.py:19
I0114 18:25:18.005497 2151712 torch/_dynamo/eval_frame.py:398]   * f /home/dgalvez/code/asr/pytorch-4/repros/dispatch_mode_recompile.py:19
I0114 18:25:18.005497 2151712 torch/_dynamo/eval_frame.py:398]   * f /home/dgalvez/code/asr/pytorch-4/repros/dispatch_mode_recompile.py:19
I0114 18:25:18.005497 2151712 torch/_dynamo/eval_frame.py:398] ]
I0114 18:25:18.005912 2151712 torch/_dynamo/utils.py:751] TorchDynamo compilation metrics:
I0114 18:25:18.005912 2151712 torch/_dynamo/utils.py:751] Function, Runtimes (s)
I0114 18:25:18.005912 2151712 torch/_dynamo/utils.py:751] _compile.compile_inner, 0.0418
I0114 18:25:18.005912 2151712 torch/_dynamo/utils.py:751] OutputGraph.call_user_compiler, 0.0016
I0114 18:25:18.005912 2151712 torch/_dynamo/utils.py:751] gc, 0.0016

You can see from this section of the log that three compiles happen: the first compile under MyFunctionMode, the first compile under MyDispatchMode, and the second compile under MyDispatchMode:

I0114 18:25:18.005497 2151712 torch/_dynamo/eval_frame.py:398] TorchDynamo attempted to trace the following frames: [
I0114 18:25:18.005497 2151712 torch/_dynamo/eval_frame.py:398]   * f /home/dgalvez/code/asr/pytorch-4/repros/dispatch_mode_recompile.py:19
I0114 18:25:18.005497 2151712 torch/_dynamo/eval_frame.py:398]   * f /home/dgalvez/code/asr/pytorch-4/repros/dispatch_mode_recompile.py:19
I0114 18:25:18.005497 2151712 torch/_dynamo/eval_frame.py:398]   * f /home/dgalvez/code/asr/pytorch-4/repros/dispatch_mode_recompile.py:19
I0114 18:25:18.005497 2151712 torch/_dynamo/eval_frame.py:398] ]
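
If it helps triage, the same guard-failure information can be surfaced without error_on_recompile by enabling the recompiles logging artifact (a sketch using the public torch._logging API; TORCH_LOGS="recompiles" is the environment-variable equivalent):

import torch._logging

# Log each Dynamo recompile together with the guard(s) that failed.
torch._logging.set_logs(recompiles=True)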

@mlazos, since you worked on #131828, do you know whether this is expected? I am asking for reasons related to #140979: https://github.com/pytorch/pytorch/pull/140979/files#r1877221096

I realized just after linking to that comment that it contains a brief answer to my question, but I am filing this issue nonetheless for documentation purposes.
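
For documentation purposes, here is how the state that ___check_torch_function_mode_stack() guards on can be inspected. _get_current_function_mode_stack and _get_current_dispatch_mode_stack are private helpers and may change, so treat this as a sketch:

from torch.overrides import _get_current_function_mode_stack
from torch.utils._python_dispatch import _get_current_dispatch_mode_stack

with MyFunctionMode():
    print(_get_current_function_mode_stack())  # [MyFunctionMode instance]
    print(_get_current_dispatch_mode_stack())  # []

with MyDispatchMode():
    # The function mode stack is empty here, which matches the
    # ___check_torch_function_mode_stack() failure reported for frame 0/0.
    print(_get_current_function_mode_stack())  # []
    print(_get_current_dispatch_mode_stack())  # [MyDispatchMode instance]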

Versions


Collecting environment information...
PyTorch version: 2.7.0a0+gitcd1b9e4
Is debug build: True
CUDA used to build PyTorch: 12.6
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.1 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 3.31.2
Libc version: glibc-2.35

Python version: 3.9.21 (main, Dec 11 2024, 16:24:11)  [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.15.0-46-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 12.6.77
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA L40S
Nvidia driver version: 560.35.03
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: False

CPU:
Architecture:                    x86_64
CPU op-mode(s):                  32-bit, 64-bit
Address sizes:                   45 bits physical, 48 bits virtual
Byte Order:                      Little Endian
CPU(s):                          25
On-line CPU(s) list:             0-24
Vendor ID:                       AuthenticAMD
Model name:                      AMD EPYC 9454 48-Core Processor
CPU family:                      25
Model:                           17
Thread(s) per core:              1
Core(s) per socket:              1
Socket(s):                       25
Stepping:                        1
BogoMIPS:                        5491.74
Flags:                           fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl tsc_reliable nonstop_tsc cpuid extd_apicid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw topoext invpcid_single ibpb vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves avx512_bf16 clzero wbnoinvd arat avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq rdpid overflow_recov succor fsrm flush_l1d
Hypervisor vendor:               VMware
Virtualization type:             full
L1d cache:                       800 KiB (25 instances)
L1i cache:                       800 KiB (25 instances)
L2 cache:                        25 MiB (25 instances)
L3 cache:                        800 MiB (25 instances)
NUMA node(s):                    1
NUMA node0 CPU(s):               0-24
Vulnerability Itlb multihit:     Not affected
Vulnerability L1tf:              Not affected
Vulnerability Mds:               Not affected
Vulnerability Meltdown:          Not affected
Vulnerability Mmio stale data:   Not affected
Vulnerability Retbleed:          Not affected
Vulnerability Spec store bypass: Vulnerable
Vulnerability Spectre v1:        Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:        Mitigation; Retpolines, IBPB conditional, STIBP disabled, RSB filling
Vulnerability Srbds:             Not affected
Vulnerability Tsx async abort:   Not affected

Versions of relevant libraries:
[pip3] mypy==1.13.0
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.22.4
[pip3] onnx==1.17.0
[pip3] onnxscript==0.1.0.dev20240817
[pip3] optree==0.13.0
[pip3] torch==2.7.0a0+gitcd1b9e4
[pip3] triton==3.2.0+git35c6c7c6
[conda] numpy                     1.22.4                   pypi_0    pypi
[conda] optree                    0.13.0                   pypi_0    pypi
[conda] torch                     2.7.0a0+gitcd1b9e4           dev_0    <develop>
[conda] triton                    3.2.0+git35c6c7c6          pypi_0    pypi

cc @Chillee @ezyang @zou3519 @albanD @samdow @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @amjames
