75 changes: 75 additions & 0 deletions test/test_transformers.py
@@ -2075,6 +2075,81 @@ def test_scaled_dot_product_attention_fp16_overflow(self, device):
y = torch.nn.functional.scaled_dot_product_attention(x, x, x)
self.assertFalse(y.isnan().any().item())

@parametrize("dtype", [torch.float32, torch.bfloat16, torch.half])
@parametrize("n_heads", [[8, 8], [16, 8], [10, 2]]) # [q_heads, kv_heads]
@parametrize("is_causal", [True, False])
@parametrize("allow_reduced_precision", [True, False])
def test_reference_implementation_bitwise_match_math_backend(self, device, dtype, n_heads, is_causal, allow_reduced_precision):
"""Regression test for scaled_dot_product_attention documentation [1] implementation.
Should produces bitwise identical results to the MATH backend.

[1] https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html

If you modify the reference implementation, update this test.
"""
# Efficient implementation equivalent to the following:
def scaled_dot_product_attention(
query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None, enable_gqa=False
) -> torch.Tensor:
L, S = query.size(-2), key.size(-2)

# Default float32 upcast across all backends
origin_dtype = query.dtype
if not torch.backends.cuda.fp16_bf16_reduction_math_sdp_allowed():
query, key, value = query.float(), key.float(), value.float()

scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
query, key = query * math.sqrt(scale_factor), key * math.sqrt(scale_factor)

attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
if is_causal:
assert attn_mask is None
temp_mask = torch.ones(L, S, dtype=torch.bool, device=attn_bias.device).tril(diagonal=0)
attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))

if attn_mask is not None:
if attn_mask.dtype == torch.bool:
attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
else:
attn_bias = attn_mask + attn_bias

if enable_gqa:
key = key.repeat_interleave(query.size(-3) // key.size(-3), -3)
value = value.repeat_interleave(query.size(-3) // value.size(-3), -3)

attn_weight = query @ key.transpose(-2, -1)
attn_weight += attn_bias
attn_weight = torch.softmax(attn_weight, dim=-1)
attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
out = attn_weight @ value
return out.to(origin_dtype)

torch.manual_seed(42)
q_heads, kv_heads = n_heads
enable_gqa = q_heads != kv_heads
batch_size, seq_len, head_dim = 2, 8, 32

query = torch.randn(batch_size, q_heads, seq_len, head_dim, device=device, dtype=dtype)
key = torch.randn(batch_size, kv_heads, seq_len, head_dim, device=device, dtype=dtype)
value = torch.randn(batch_size, kv_heads, seq_len, head_dim, device=device, dtype=dtype)

origin_flag = torch.backends.cuda.fp16_bf16_reduction_math_sdp_allowed()
try:
torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp(allow_reduced_precision)
doc_result = scaled_dot_product_attention(
query, key, value, dropout_p=0.0, is_causal=is_causal, enable_gqa=enable_gqa
)
with sdpa_kernel(backends=[SDPBackend.MATH]):
math_ref = F.scaled_dot_product_attention(
query, key, value, dropout_p=0.0, is_causal=is_causal, enable_gqa=enable_gqa
)
# Must be exact bitwise match - no tolerance allowed
self.assertEqual(doc_result, math_ref, atol=0., rtol=0.)
finally:
# Restore flag
torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp(origin_flag)

class TestSDPACpuOnly(NNTestCase):
""" Used to test CPU only functionality of scaled_dot_product_attention """

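Not part of the diff: a minimal sketch of how a user would exercise the same pieces the test above relies on, assuming a CUDA device is available. It forces the MATH backend with sdpa_kernel and toggles the reduced-precision flag around the call, restoring the previous value afterwards just as the test's try/finally does.

import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

# Grouped-query attention shapes: 8 query heads sharing 2 key/value heads.
q = torch.randn(2, 8, 8, 32, dtype=torch.float16, device="cuda")
k = torch.randn(2, 2, 8, 32, dtype=torch.float16, device="cuda")
v = torch.randn(2, 2, 8, 32, dtype=torch.float16, device="cuda")

prev = torch.backends.cuda.fp16_bf16_reduction_math_sdp_allowed()
try:
    # Permit reduced-precision accumulation in the MATH backend (off by default,
    # in which case the reference implementation upcasts to float32 instead).
    torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp(True)
    with sdpa_kernel(backends=[SDPBackend.MATH]):
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=True)
finally:
    torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp(prev)
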
15 changes: 12 additions & 3 deletions torch/nn/functional.py
@@ -5817,11 +5817,19 @@ def _in_projection(
def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0,
is_causal=False, scale=None, enable_gqa=False) -> torch.Tensor:
L, S = query.size(-2), key.size(-2)

# Default float32 upcast across all backends
origin_dtype = query.dtype
if not torch.backends.cuda.fp16_bf16_reduction_math_sdp_allowed():
query, key, value = query.float(), key.float(), value.float()

scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
query, key = query * math.sqrt(scale_factor), key * math.sqrt(scale_factor)

attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
if is_causal:
assert attn_mask is None
temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
temp_mask = torch.ones(L, S, dtype=torch.bool, device=attn_bias.device).tril(diagonal=0)
attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))

if attn_mask is not None:
@@ -5834,11 +5842,12 @@ def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.
key = key.repeat_interleave(query.size(-3)//key.size(-3), -3)
value = value.repeat_interleave(query.size(-3)//value.size(-3), -3)

attn_weight = query @ key.transpose(-2, -1) * scale_factor
attn_weight = query @ key.transpose(-2, -1)
attn_weight += attn_bias
attn_weight = torch.softmax(attn_weight, dim=-1)
attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
return attn_weight @ value
out = attn_weight @ value
return out.to(origin_dtype)

.. warning::
This function is beta and subject to change.
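
A note on the scaling change in the reference above, not part of the diff: folding the scale into the query and key as a square root before the matmul is algebraically the same as scaling the logits afterwards, since

\[
(\sqrt{s}\,Q)(\sqrt{s}\,K)^\top = s\,\bigl(QK^\top\bigr), \qquad s = \tfrac{1}{\sqrt{d_k}} \text{ when scale is None,}
\]

where \(d_k\) is the head dimension query.size(-1). In floating point, though, the two orderings generally round differently, so only one of them can match a given backend bit for bit. The updated reference follows the ordering that the new test above verifies bitwise against the MATH backend, and the float32 upcast gated on fp16_bf16_reduction_math_sdp_allowed() keeps half and bfloat16 inputs consistent with that backend, as the test checks for both flag settings.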