Labels
module: cuda, module: nestedtensor, module: sdpa, triaged
Description
🐛 Describe the bug
Looks like it's dispatching to the efficient attention backward and failing one of the shape checks:

```cpp
TORCH_CHECK(
    max_seqlen_k <= key.size(1), "Invalid max_seqlen_k:", max_seqlen_k);
```
Failing call:

```python
buf0 = aten._efficient_attention_backward.default(reinterpret_tensor(tangents_1, (1, s1, 2, 3), (6*s1, 6, 3, 1), 0), unsqueeze, unsqueeze_1, unsqueeze_2, None, getitem, convert_element_type, convert_element_type_1, s2, s5, getitem_1, 0.0, getitem_2, getitem_3, 0, False)
```
Printing `k.sizes()` here shows `[1, 6, 2, 3]`, while `max_seqlen_k` is `10`.
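For context, a hedged illustration of the packing convention at play (hypothetical shapes, not the failing test): the `[1, 6, 2, 3]` above is the `(1, total_seq, heads, head_dim)` packed view of the jagged NestedTensor, while `max_seqlen_k` is meant to be the longest individual sequence, so it should never exceed `key.size(1)`:

```python
import torch

# Hypothetical example: two components of lengths 2 and 4, with 2 heads and head_dim 3.
nt = torch.nested.nested_tensor(
    [torch.randn(2, 2, 3), torch.randn(4, 2, 3)], layout=torch.jagged
)
packed = nt.values().unsqueeze(0)                # (1, total_seq=6, heads=2, head_dim=3)
offsets = nt.offsets()                           # tensor([0, 2, 6])
max_seqlen = (offsets[1:] - offsets[:-1]).max()  # 4, always <= packed.size(1) == 6
print(packed.shape, offsets.tolist(), max_seqlen.item())
```

If that reading is right, getting `max_seqlen_k == 10` against a packed key of total length 6 suggests the max-seqlen value reaching the compiled backward isn't being derived from the actual offsets.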
This doesn't seem to happen on sm80+, presumably because those devices can dispatch to flash attention instead.
Interestingly, pinning the backend on sm80+ so that only efficient attention is allowed (via a decorator; see the sketch after the log) gives:

```
W0625 22:21:47.487000 140137837507200 torch/nested/_internal/sdpa.py:293] Memory efficient kernel not used because:
W0625 22:21:47.488000 140137837507200 torch/nested/_internal/sdpa.py:296] Flash attention kernel not used because:
W0625 22:21:47.488000 140137837507200 torch/nested/_internal/sdpa.py:101] For NestedTensor inputs, Flash attention requires q,k,v to have the same last dimension and to be a multiple of 8 and less than or equal to 256. Got Query.size(-1): 3, Key.size(-1): 3, Value.size(-1): 3 instead.
W0625 22:21:47.488000 140137837507200 torch/nested/_internal/sdpa.py:299] Math attention kernel not used because:
/workspace/pytorch/test/test_nestedtensor.py:5317: UserWarning: Mem efficient attention requires last dimension of inputs to be divisible by 4. Got Query.size(-1): 3, Key.size(-1): 3, Value.size(-1): 3 instead. (Triggered internally at /workspace/pytorch/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp:164.)
attn_output = torch.nn.functional.scaled_dot_product_attention(
/workspace/pytorch/test/test_nestedtensor.py:5317: UserWarning: Flash attention has been runtime disabled. (Triggered internally at /workspace/pytorch/aten/src/ATen/native/transformers/sdp_utils_cpp.h:494.)
attn_output = torch.nn.functional.scaled_dot_product_attention(
E
```
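For reference, the backend pinning above was presumably done with something along the lines of the public `torch.nn.attention.sdpa_kernel` API; a sketch (the actual decorator used in the test may differ):

```python
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# Sketch: restrict SDPA kernel selection to the efficient-attention backend only.
# The function name and how the test applies the restriction are assumptions.
def attn_efficient_only(q, k, v):
    with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
        return F.scaled_dot_product_attention(q, k, v)
```

With only efficient attention allowed, a head dim of 3 trips the divisible-by-4 requirement, which is why every backend gets rejected in the log above.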
Simply removing the `max_seqlen_k <= key.size(1)` shape check allows the test to pass, but I'm not sure that's correct. Is there some special inductor/symbolic-tracing accounting for shapes that needs to be done here?
CC @drisspg
Versions
Current 2024/06/25 source build
cc @ptrblck @msaroufim @cpuhrsch @jbschlosser @bhosmer @drisspg @soulitzer @davidberard98 @YuqingJ