
test_flex_attention unit test failures #155894

@nWEIdia

Description

🐛 Describe the bug


| Test class | Test | File |
| --- | --- | --- |
| TestFlexDecodingCUDA | test_non_equal_head_dims_score_mod2_float32_head_dims0_cuda_float32 | inductor/test_flex_decoding.py |
| TestFlexDecodingCUDA | test_builtin_score_mods_float32_score_mod2_head_dims2_cuda_float32 | inductor/test_flex_decoding.py |
| TestFlexDecodingCUDA | test_builtin_score_mods_float32_score_mod2_head_dims0_cuda_float32 | inductor/test_flex_decoding.py |
| TestFlexAttentionCUDA | test_GQA_score_mod5_cuda_float16 | inductor/test_flex_attention.py |
| TestFlexAttentionCUDA | test_block_mask_non_divisible_cuda | inductor/test_flex_attention.py |
| TestFlexAttentionCUDA | test_num_warps_8_error_cuda | inductor/test_flex_attention.py |
| TestFlexAttentionCUDA | test_GQA_score_mod7_cuda_float16 | inductor/test_flex_attention.py |
| TestFlexAttentionCUDA | test_GQA_score_mod6_cuda_float16 | inductor/test_flex_attention.py |
| TestFlexAttentionCUDA | test_tma_with_customer_kernel_options_cuda | inductor/test_flex_attention.py |

Error message:

```
FAILED [5.0394s] test_flex_attention.py::TestFlexAttentionCUDA::test_GQA_score_mod7_cuda_float16 - AssertionError: False is not true : Output/Grad with NaN
```
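
For context, a minimal sketch of the pattern the failing GQA tests exercise (assumed shapes and an illustrative score_mod, not the exact parameters of test_GQA_score_mod7_cuda_float16): compiled flex_attention with grouped query heads and float16 inputs, followed by the NaN check on the output and gradients that produces the assertion above.

```python
# Hypothetical, simplified sketch -- shapes and score_mod are assumptions,
# not copied from the test suite.
import torch
from torch.nn.attention.flex_attention import flex_attention

def score_mod(score, b, h, q_idx, kv_idx):
    # Illustrative relative-position bias; the real tests sweep builtin score mods.
    return score + (q_idx - kv_idx)

B, Hq, Hkv, S, D = 2, 8, 2, 256, 64  # GQA: Hq is a multiple of Hkv
q = torch.randn(B, Hq, S, D, device="cuda", dtype=torch.float16, requires_grad=True)
k = torch.randn(B, Hkv, S, D, device="cuda", dtype=torch.float16, requires_grad=True)
v = torch.randn(B, Hkv, S, D, device="cuda", dtype=torch.float16, requires_grad=True)

compiled_fa = torch.compile(flex_attention)
out = compiled_fa(q, k, v, score_mod=score_mod, enable_gqa=True)
out.sum().backward()

# The failing assertion corresponds to checks along these lines.
assert not out.isnan().any().item(), "Output with NaN"
assert not any(t.grad.isnan().any().item() for t in (q, k, v)), "Grad with NaN"
```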

Versions

Reproducible (e.g. on H100 or B200) with:
ghcr.io/pytorch/pytorch-nightly:2.8.0.dev20250612-cuda12.8-cudnn9-devel

(Make sure nvidia-container-toolkit is installed; if not, follow the official installation guide.)

Example reproducer command:

```
python test/inductor/test_flex_attention.py TestFlexAttentionCUDA.test_GQA_score_mod7_cuda_float16
```

cc @ptrblck @msaroufim @eqy @jerryzh168 @mruberry @chauhang @penguinwu @zou3519 @ydwu4 @bdhirsh @Chillee @drisspg @yanboliang @BoyuanFeng @atalman @malfet @tinglvv

Labels

module: cuda, module: flex attention, module: higher order operators, module: pt2-dispatcher, module: tests, oncall: pt2, triaged
