
test_dummy_mha_with_nt_cuda fails on sm70, sm75 #129523


Description

@eqy

🐛 Describe the bug

Looks like it's dispatching to efficient attention backward and failing one of the shape checks:

    TORCH_CHECK(
        max_seqlen_k <= key.size(1), "Invalid max_seqlen_k:", max_seqlen_k);

failing call:

    buf0 = aten._efficient_attention_backward.default(reinterpret_tensor(tangents_1, (1, s1, 2, 3), (6*s1, 6, 3, 1), 0), unsqueeze, unsqueeze_1, unsqueeze_2, None, getitem, convert_element_type, convert_element_type_1, s2, s5, getitem_1, 0.0, getitem_2, getitem_3, 0, False)

Printing k.sizes() here shows: [1, 6, 2, 3] when max_seqlen_k is 10.
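For concreteness, here is a minimal sketch of the mismatch the check trips over. The per-sequence lengths, n_heads, and head_dim below are assumed values chosen only to reproduce the shapes quoted above ([1, 6, 2, 3] vs. max_seqlen_k = 10); they are not taken from the test itself.

    # Illustration only: seq_lens, n_heads and head_dim are assumed values.
    import torch

    seq_lens = [2, 4]                      # hypothetical per-sequence key lengths
    max_seqlen_k = 10                      # padded max length, as reported above
    total_len = sum(seq_lens)              # 6: length of the packed jagged dim
    key = torch.randn(total_len, 2, 3)     # (total_len, n_heads, head_dim)
    key = key.unsqueeze(0)                 # [1, 6, 2, 3], matching the failing call

    print(key.size(1))                     # 6
    print(max_seqlen_k <= key.size(1))     # False -> the TORCH_CHECK above fires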

This doesn't seem to happen on sm80+, presumably because those can dispatch to FA instead?
Interestingly, pinning the backend on sm80+ to efficient attention only with a decorator (one way to do that is sketched after the log below) gives:

W0625 22:21:47.487000 140137837507200 torch/nested/_internal/sdpa.py:293] Memory efficient kernel not used because:
W0625 22:21:47.488000 140137837507200 torch/nested/_internal/sdpa.py:296] Flash attention kernel not used because:
W0625 22:21:47.488000 140137837507200 torch/nested/_internal/sdpa.py:101] For NestedTensor inputs, Flash attention requires q,k,v to have the same last dimension and to be a multiple of 8 and less than or equal to 256. Got Query.size(-1): 3, Key.size(-1): 3, Value.size(-1): 3 instead.
W0625 22:21:47.488000 140137837507200 torch/nested/_internal/sdpa.py:299] Math attention kernel not used because:
/workspace/pytorch/test/test_nestedtensor.py:5317: UserWarning: Mem efficient attention requires last dimension of inputs to be divisible by 4. Got Query.size(-1): 3, Key.size(-1): 3, Value.size(-1): 3 instead. (Triggered internally at /workspace/pytorch/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp:164.)
  attn_output = torch.nn.functional.scaled_dot_product_attention(
/workspace/pytorch/test/test_nestedtensor.py:5317: UserWarning: Flash attention has been runtime disabled. (Triggered internally at /workspace/pytorch/aten/src/ATen/native/transformers/sdp_utils_cpp.h:494.)
  attn_output = torch.nn.functional.scaled_dot_product_attention(
E
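For reference, a sketch of restricting SDPA to the mem-efficient backend via torch.nn.attention.sdpa_kernel (which also works as a decorator); this is an assumption about how the backend was pinned, not necessarily the exact decorator used in the test suite. It uses a head dim of 8 so the kernel is actually accepted; the test's head dim of 3 is rejected exactly as the warnings above show.

    # Sketch: restrict SDPA to the mem-efficient backend only.
    # Shapes are illustrative (head dim 8, which the backend accepts).
    import torch
    from torch.nn.attention import SDPBackend, sdpa_kernel

    q = k = v = torch.randn(1, 2, 6, 8, device="cuda", dtype=torch.float16)
    with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
        out = torch.nn.functional.scaled_dot_product_attention(q, k, v)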

Simply removing the max_seqlen_k <= k.size(1) shape check allows the test to pass, but I'm not sure that's correct. Is there some special inductor/symbolic tracing accounting for shapes that needs to be done here?

CC @drisspg

Versions

Current 2024/06/25 source build

cc @ptrblck @msaroufim @cpuhrsch @jbschlosser @bhosmer @drisspg @soulitzer @davidberard98 @YuqingJ

Metadata

Assignees

No one assigned

    Labels

    module: cuda - Related to torch.cuda, and CUDA support in general
    module: nestedtensor - NestedTensor tag; see issue #25032
    module: sdpa - All things related to torch.nn.functional.scaled_dot_product_attention
    triaged - This issue has been looked at by a team member, and triaged and prioritized into an appropriate module
