🐛 Describe the bug
I get failures when running python test/inductor/test_fp8.py TestFP8Lowering.test_tensorwise_scaling_float32_shape_16,32,32_has_bias_False_use_fast_accum_True_persistent_matmul_True
Mismatched elements: 510 / 512 (99.6%)
Greatest absolute difference: 21.93155288696289 at index (10, 30) (up to 0.05 allowed)
Greatest relative difference: 14.029411315917969 at index (10, 30) (up to 0.01 allowed)
Printing the elements at that index gives tensor(20.3683, device='cuda:0') vs tensor(-1.5633, device='cuda:0'),
so the "compiled" variant of the function is wildly off.
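For context, here is a rough sketch of what that test exercises, so the failure mode is easier to see: a tensorwise-scaled FP8 matmul compiled with max-autotune, compared against the eager torch._scaled_mm result. This is my own approximation based on the test name (shape 16,32,32, float32 output, no bias, fast accum), not the exact code in test/inductor/test_fp8.py.

# Rough repro sketch; shapes/dtypes assumed from the test name, not the exact test code.
import torch

def fp8_mm(a_fp8, b_fp8, scale_a, scale_b):
    # Tensorwise scaling, no bias, fast accum, float32 output.
    return torch._scaled_mm(a_fp8, b_fp8, scale_a=scale_a, scale_b=scale_b,
                            out_dtype=torch.float32, use_fast_accum=True)

M, K, N = 16, 32, 32
a = torch.randn(M, K, device="cuda")
b = torch.randn(N, K, device="cuda")
fp8_max = torch.finfo(torch.float8_e4m3fn).max
scale_a = a.abs().max() / fp8_max   # 0-dim float32 scale tensors
scale_b = b.abs().max() / fp8_max
a_fp8 = (a / scale_a).to(torch.float8_e4m3fn)
b_fp8 = (b / scale_b).to(torch.float8_e4m3fn).t()  # second operand must be column-major

eager = fp8_mm(a_fp8, b_fp8, scale_a, scale_b)
compiled = torch.compile(fp8_mm, mode="max-autotune")(a_fp8, b_fp8, scale_a, scale_b)
torch.testing.assert_close(compiled, eager, atol=0.05, rtol=0.01)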
However, this does not always happen: when the autotuner chooses the ATen _scaled_mm over the Triton template, the test succeeds, so it is the Triton template that produces the wrong result. I.e., when I see this autotune output, the test fails:
AUTOTUNE scaled_mm(16x32, 32x32, , )
triton_scaled_mm_device_tma_1 0.0070 ms 100.0% ACC_TYPE='tl.float32', BLOCK_K=32, BLOCK_M=16, BLOCK_N=32, EVEN_K=True, GROUP_M=8, NUM_SMS=132, SCALING_ROWWISE=False, TMA_SIZE=128, USE_FAST_ACCUM=True, num_stages=4, num_warps=2
triton_scaled_mm_device_tma_0 0.0071 ms 98.6% ACC_TYPE='tl.float32', BLOCK_K=32, BLOCK_M=16, BLOCK_N=32, EVEN_K=True, GROUP_M=8, NUM_SMS=132, SCALING_ROWWISE=False, TMA_SIZE=128, USE_FAST_ACCUM=True, num_stages=3, num_warps=2
triton_scaled_mm_device_tma_3 0.0072 ms 97.3% ACC_TYPE='tl.float32', BLOCK_K=32, BLOCK_M=16, BLOCK_N=32, EVEN_K=True, GROUP_M=8, NUM_SMS=132, SCALING_ROWWISE=False, TMA_SIZE=128, USE_FAST_ACCUM=True, num_stages=6, num_warps=2
triton_scaled_mm_device_tma_2 0.0072 ms 96.9% ACC_TYPE='tl.float32', BLOCK_K=32, BLOCK_M=16, BLOCK_N=32, EVEN_K=True, GROUP_M=8, NUM_SMS=132, SCALING_ROWWISE=False, TMA_SIZE=128, USE_FAST_ACCUM=True, num_stages=5, num_warps=2
_scaled_mm 0.0080 ms 87.6%
While it succeeds with these autotune statistics:
AUTOTUNE scaled_mm(16x32, 32x32, , )
_scaled_mm 0.0062 ms 100.0%
triton_scaled_mm_device_tma_0 0.0068 ms 91.1% ACC_TYPE='tl.float32', BLOCK_K=32, BLOCK_M=16, BLOCK_N=32, EVEN_K=True, GROUP_M=8, NUM_SMS=132, SCALING_ROWWISE=False, TMA_SIZE=128, USE_FAST_ACCUM=True, num_stages=3, num_warps=2
triton_scaled_mm_device_tma_1 0.0070 ms 89.0% ACC_TYPE='tl.float32', BLOCK_K=32, BLOCK_M=16, BLOCK_N=32, EVEN_K=True, GROUP_M=8, NUM_SMS=132, SCALING_ROWWISE=False, TMA_SIZE=128, USE_FAST_ACCUM=True, num_stages=4, num_warps=2
triton_scaled_mm_device_tma_2 0.0070 ms 89.0% ACC_TYPE='tl.float32', BLOCK_K=32, BLOCK_M=16, BLOCK_N=32, EVEN_K=True, GROUP_M=8, NUM_SMS=132, SCALING_ROWWISE=False, TMA_SIZE=128, USE_FAST_ACCUM=True, num_stages=5, num_warps=2
triton_scaled_mm_device_tma_3 0.0072 ms 85.8% ACC_TYPE='tl.float32', BLOCK_K=32, BLOCK_M=16, BLOCK_N=32, EVEN_K=True, GROUP_M=8, NUM_SMS=132, SCALING_ROWWISE=False, TMA_SIZE=128, USE_FAST_ACCUM=True, num_stages=6, num_warps=2
I can force the failure by modifying the choices list to a single element at pytorch/torch/_inductor/kernel/mm.py, line 1165 (commit aeb5321):
return autotune_select_algorithm("scaled_mm", choices, input_nodes, layout)
All triton_scaled_mm_device_tma_* variants fail with the same result.
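As an aside, restricting Inductor's GEMM autotune backends should force the same failure without editing mm.py, assuming the scaled_mm tuning path honors max_autotune_gemm_backends the same way the regular mm path does (an assumption on my part, not something I verified in the source):

# Sketch: keep only the Triton templates among the autotune choices (assumed to apply to scaled_mm).
import torch._inductor.config as inductor_config

inductor_config.max_autotune = True
inductor_config.max_autotune_gemm_backends = "TRITON"  # drops the ATen _scaled_mm choice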
In PyTorch 2.8 and 2.9 there are new kernels triton_mm_4 to triton_mm_8 which work, but the "old" triton_scaled_mm_device_tma_* variants are still present and fail in the same way.
Versions
PyTorch version: 2.8.0.dev20250627+cu126
Is debug build: False
CUDA used to build PyTorch: 12.6
ROCM used to build PyTorch: N/A
OS: AlmaLinux 9.4 (Seafoam Ocelot) (x86_64)
GCC version: (GCC) 13.3.0
Clang version: Could not collect
CMake version: version 3.29.3
Libc version: glibc-2.34
Python version: 3.12.3 (main, Mar 19 2025, 10:36:13) [GCC 13.3.0] (64-bit runtime)
Python platform: Linux-5.14.0-427.33.1.el9_4.x86_64-x86_64-with-glibc2.34
Is CUDA available: True
CUDA runtime version: 12.6.20
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA H100
GPU 1: NVIDIA H100
GPU 2: NVIDIA H100
GPU 3: NVIDIA H100
Nvidia driver version: 560.35.03
cuDNN version: Could not collect
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 52 bits physical, 57 bits virtual
Byte Order: Little Endian
CPU(s): 64
On-line CPU(s) list: 0-63
Vendor ID: AuthenticAMD
Model name: AMD EPYC 9334 32-Core Processor
Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] nvidia-cublas-cu12==12.6.4.1
[pip3] nvidia-cuda-cupti-cu12==12.6.80
[pip3] nvidia-cuda-nvrtc-cu12==12.6.77
[pip3] nvidia-cuda-runtime-cu12==12.6.77
[pip3] nvidia-cudnn-cu12==9.10.2.21
[pip3] nvidia-cufft-cu12==11.3.0.4
[pip3] nvidia-curand-cu12==10.3.7.77
[pip3] nvidia-cusolver-cu12==11.7.1.2
[pip3] nvidia-cusparse-cu12==12.5.4.2
[pip3] nvidia-cusparselt-cu12==0.7.1
[pip3] nvidia-nccl-cu12==2.27.3
[pip3] nvidia-nvjitlink-cu12==12.6.85
[pip3] nvidia-nvtx-cu12==12.6.77
[pip3] optree==0.14.1
[pip3] pytorch-triton==3.3.1+gitc8757738
[pip3] torch==2.8.0.dev20250627+cu126
[pip3] triton==3.3.1
cc @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @aakhundov @coconutruben @bertmaher @int3 @davidberard98 @nmacchioni @embg @peterbell10