Description
🐛 Describe the bug
When I run torch.compile() under an "infra" TorchDispatchMode, it seems that a recompile always happens, but I don't know what guard is failing:
import torch
from torch.overrides import TorchFunctionMode
from torch.utils._python_dispatch import TorchDispatchMode
from torch._dynamo import config

class MyFunctionMode(TorchFunctionMode):
    def __torch_function__(self, func, types, args, kwargs=None):
        return func(*args, **(kwargs or {}))

class MyDispatchMode(TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args, kwargs=None):
        return func(*args, **(kwargs or {}))

    @classmethod
    def is_infra_mode(cls):
        return True

def f(x, y):
    return x @ y

x = torch.ones(10, device="cuda")

mode = MyFunctionMode()
f_compiled = torch.compile(f, backend="eager")
for i in range(2):
    if i == 0:
        config.error_on_recompile = False
    if i == 1:
        config.error_on_recompile = True
    with mode:
        f_compiled(x, x)

mode = MyDispatchMode()
for i in range(2):
    if i == 0:
        config.error_on_recompile = False
    if i == 1:
        config.error_on_recompile = True
    with mode:
        f_compiled(x, x)
Running the above script on top-of-tree PyTorch produces the following output, ending in a RecompileError:
I0114 18:25:17.922947 2151712 torch/_dynamo/utils.py:1521] [0/0] ChromiumEventLogger initialized with id eeb788f2-8d2b-4de5-adf7-22df55d8491d
I0114 18:25:17.924832 2151712 torch/_dynamo/symbolic_convert.py:2744] [0/0] Step 1: torchdynamo start tracing f /home/dgalvez/code/asr/pytorch-4/repros/dispatch_mode_recompile.py:19
I0114 18:25:17.925518 2151712 torch/fx/experimental/symbolic_shapes.py:3243] [0/0] create_env
I0114 18:25:17.946520 2151712 torch/_dynamo/symbolic_convert.py:3066] [0/0] Step 1: torchdynamo done tracing f (RETURN_VALUE)
I0114 18:25:17.950973 2151712 torch/_dynamo/output_graph.py:1460] [0/0] Step 2: calling compiler function eager
I0114 18:25:17.951271 2151712 torch/_dynamo/output_graph.py:1465] [0/0] Step 2: done compiler function eager
I0114 18:25:17.954654 2151712 torch/fx/experimental/symbolic_shapes.py:4623] [0/0] produce_guards
I0114 18:25:17.956163 2151712 torch/_dynamo/pgo.py:647] [0/0] put_code_state: no cache key, skipping
I0114 18:25:17.956523 2151712 torch/_dynamo/convert_frame.py:1078] [0/0] run_gc_after_compile: running gc
I0114 18:25:17.984054 2151712 torch/_dynamo/symbolic_convert.py:2744] [0/1] Step 1: torchdynamo start tracing f /home/dgalvez/code/asr/pytorch-4/repros/dispatch_mode_recompile.py:19
I0114 18:25:17.984482 2151712 torch/fx/experimental/symbolic_shapes.py:3243] [0/1] create_env
I0114 18:25:17.988030 2151712 torch/_dynamo/symbolic_convert.py:3066] [0/1] Step 1: torchdynamo done tracing f (RETURN_VALUE)
I0114 18:25:17.989872 2151712 torch/_dynamo/output_graph.py:1460] [0/1] Step 2: calling compiler function eager
I0114 18:25:17.990141 2151712 torch/_dynamo/output_graph.py:1465] [0/1] Step 2: done compiler function eager
I0114 18:25:17.992269 2151712 torch/fx/experimental/symbolic_shapes.py:4623] [0/1] produce_guards
I0114 18:25:17.993348 2151712 torch/_dynamo/pgo.py:647] [0/1] put_code_state: no cache key, skipping
I0114 18:25:17.993675 2151712 torch/_dynamo/convert_frame.py:1078] [0/1] run_gc_after_compile: running gc
Traceback (most recent call last):
File "/home/dgalvez/code/asr/pytorch-4/repros/dispatch_mode_recompile.py", line 44, in <module>
f_compiled(x, x)
File "/home/dgalvez/code/asr/pytorch-4/torch/_dynamo/eval_frame.py", line 576, in _fn
return fn(*args, **kwargs)
File "/home/dgalvez/code/asr/pytorch-4/torch/_dynamo/convert_frame.py", line 1422, in __call__
return self._torchdynamo_orig_callable(
File "/home/dgalvez/code/asr/pytorch-4/torch/_dynamo/convert_frame.py", line 1203, in __call__
result = self._inner_convert(
File "/home/dgalvez/code/asr/pytorch-4/torch/_dynamo/convert_frame.py", line 569, in __call__
return _compile(
File "/home/dgalvez/code/asr/pytorch-4/torch/_dynamo/convert_frame.py", line 920, in _compile
recompile_reasons = get_and_maybe_log_recompilation_reason(
File "/home/dgalvez/code/asr/pytorch-4/torch/_dynamo/guards.py", line 2780, in get_and_maybe_log_recompilation_reason
raise exc.RecompileError(message)
torch._dynamo.exc.RecompileError: Recompiling function f in /home/dgalvez/code/asr/pytorch-4/repros/dispatch_mode_recompile.py:19
triggered by the following guard failure(s):
- 0/1:
- 0/0: ___check_torch_function_mode_stack()
I0114 18:25:18.005497 2151712 torch/_dynamo/eval_frame.py:398] TorchDynamo attempted to trace the following frames: [
I0114 18:25:18.005497 2151712 torch/_dynamo/eval_frame.py:398] * f /home/dgalvez/code/asr/pytorch-4/repros/dispatch_mode_recompile.py:19
I0114 18:25:18.005497 2151712 torch/_dynamo/eval_frame.py:398] * f /home/dgalvez/code/asr/pytorch-4/repros/dispatch_mode_recompile.py:19
I0114 18:25:18.005497 2151712 torch/_dynamo/eval_frame.py:398] * f /home/dgalvez/code/asr/pytorch-4/repros/dispatch_mode_recompile.py:19
I0114 18:25:18.005497 2151712 torch/_dynamo/eval_frame.py:398] ]
I0114 18:25:18.005912 2151712 torch/_dynamo/utils.py:751] TorchDynamo compilation metrics:
I0114 18:25:18.005912 2151712 torch/_dynamo/utils.py:751] Function, Runtimes (s)
I0114 18:25:18.005912 2151712 torch/_dynamo/utils.py:751] _compile.compile_inner, 0.0418
I0114 18:25:18.005912 2151712 torch/_dynamo/utils.py:751] OutputGraph.call_user_compiler, 0.0016
I0114 18:25:18.005912 2151712 torch/_dynamo/utils.py:751] gc, 0.0016
You can see from this section of the log that three compiles are attempted: the first compile under MyFunctionMode, the first compile under MyDispatchMode, and the second compile under MyDispatchMode (the one that triggers the RecompileError):
I0114 18:25:18.005497 2151712 torch/_dynamo/eval_frame.py:398] TorchDynamo attempted to trace the following frames: [
I0114 18:25:18.005497 2151712 torch/_dynamo/eval_frame.py:398] * f /home/dgalvez/code/asr/pytorch-4/repros/dispatch_mode_recompile.py:19
I0114 18:25:18.005497 2151712 torch/_dynamo/eval_frame.py:398] * f /home/dgalvez/code/asr/pytorch-4/repros/dispatch_mode_recompile.py:19
I0114 18:25:18.005497 2151712 torch/_dynamo/eval_frame.py:398] * f /home/dgalvez/code/asr/pytorch-4/repros/dispatch_mode_recompile.py:19
I0114 18:25:18.005497 2151712 torch/_dynamo/eval_frame.py:398] ]
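To double-check that count without reading the logs, here is a minimal sketch of my own (not part of the original repro) that swaps in a hypothetical counting backend; it returns the captured graph unchanged, so it behaves like backend="eager" while recording how many times Dynamo invokes it. I use CPU tensors here only to keep the sketch self-contained; I would expect the mode-stack guard behavior to be device-independent.

import torch
from torch.overrides import TorchFunctionMode
from torch.utils._python_dispatch import TorchDispatchMode

# Hypothetical counting backend: behaves like backend="eager" but records
# how many times Dynamo hands it a graph, i.e. how many compiles happen.
compile_count = 0

def counting_backend(gm, example_inputs):
    global compile_count
    compile_count += 1
    return gm.forward  # run the captured graph as-is

class MyFunctionMode(TorchFunctionMode):
    def __torch_function__(self, func, types, args, kwargs=None):
        return func(*args, **(kwargs or {}))

class MyDispatchMode(TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args, kwargs=None):
        return func(*args, **(kwargs or {}))

    @classmethod
    def is_infra_mode(cls):
        return True

def f(x, y):
    return x @ y

x = torch.ones(10)
f_compiled = torch.compile(f, backend=counting_backend)

func_mode = MyFunctionMode()
for _ in range(2):
    with func_mode:
        f_compiled(x, x)

dispatch_mode = MyDispatchMode()
for _ in range(2):
    with dispatch_mode:
        f_compiled(x, x)

# If the behavior reported above holds, this prints 3: one compile under the
# function mode and two under the dispatch mode.
print(compile_count)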
@mlazos since you worked on #131828, do you know if this is expected? I'm asking for reasons related to #140979: https://github.com/pytorch/pytorch/pull/140979/files#r1877221096
I realized just after linking to that comment that it contains a brief answer to my question, but I am filing this issue nonetheless for documentation purposes.
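As an aside, the recompile reasons can also be surfaced without raising, by enabling the recompile logs. A minimal sketch (torch._logging.set_logs should be the programmatic counterpart of setting TORCH_LOGS="recompiles" in the environment):

import torch

# Print guard-failure / recompile reasons as they happen, instead of
# erroring out via torch._dynamo.config.error_on_recompile.
torch._logging.set_logs(recompiles=True)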
Versions
Collecting environment information...
PyTorch version: 2.7.0a0+gitcd1b9e4
Is debug build: True
CUDA used to build PyTorch: 12.6
ROCM used to build PyTorch: N/A
OS: Ubuntu 22.04.1 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 3.31.2
Libc version: glibc-2.35
Python version: 3.9.21 (main, Dec 11 2024, 16:24:11) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.15.0-46-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 12.6.77
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA L40S
Nvidia driver version: 560.35.03
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: False
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 45 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 25
On-line CPU(s) list: 0-24
Vendor ID: AuthenticAMD
Model name: AMD EPYC 9454 48-Core Processor
CPU family: 25
Model: 17
Thread(s) per core: 1
Core(s) per socket: 1
Socket(s): 25
Stepping: 1
BogoMIPS: 5491.74
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl tsc_reliable nonstop_tsc cpuid extd_apicid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw topoext invpcid_single ibpb vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves avx512_bf16 clzero wbnoinvd arat avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq rdpid overflow_recov succor fsrm flush_l1d
Hypervisor vendor: VMware
Virtualization type: full
L1d cache: 800 KiB (25 instances)
L1i cache: 800 KiB (25 instances)
L2 cache: 25 MiB (25 instances)
L3 cache: 800 MiB (25 instances)
NUMA node(s): 1
NUMA node0 CPU(s): 0-24
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec store bypass: Vulnerable
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines, IBPB conditional, STIBP disabled, RSB filling
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Versions of relevant libraries:
[pip3] mypy==1.13.0
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.22.4
[pip3] onnx==1.17.0
[pip3] onnxscript==0.1.0.dev20240817
[pip3] optree==0.13.0
[pip3] torch==2.7.0a0+gitcd1b9e4
[pip3] triton==3.2.0+git35c6c7c6
[conda] numpy 1.22.4 pypi_0 pypi
[conda] optree 0.13.0 pypi_0 pypi
[conda] torch 2.7.0a0+gitcd1b9e4 dev_0 <develop>
[conda] triton 3.2.0+git35c6c7c6 pypi_0 pypi
cc @Chillee @ezyang @zou3519 @albanD @samdow @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @amjames