In PyTorch 2.0, torch.nn.functional.scaled_dot_product_attention takes the normalization factor from Q.size(-1): https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
However, in our model implementation this value is different from the head size, because a rotary percentage of 0.25 is used by default, so we cannot use scaled_dot_product_attention in that case:
if self.rotary_percentage != 1.0:
    self.register_buffer(
        "bias",
        torch.tril(torch.ones(config.block_size, config.block_size)).view(
            1, 1, config.block_size, config.block_size
        ),
    )
...
if hasattr(self, "bias"):
    # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
    att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(head_size))
    att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf"))
    att = F.softmax(att, dim=-1)
    y = att @ v  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
else:
    # efficient attention using Flash Attention CUDA kernels
    y = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True)
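To make the mismatch concrete, here is a small standalone check (not from the repository; the shapes and the head_size divisor are illustrative). In 2.0, scaled_dot_product_attention always scales by 1 / sqrt(Q.size(-1)), so it only reproduces the manual branch when that dimension happens to equal the normalization factor we want:

```python
import math
import torch
import torch.nn.functional as F

B, nh, T, hs = 1, 4, 8, 64  # hypothetical shapes
head_size = hs  # divisor used by the manual branch; here it matches q.size(-1)
q, k, v = (torch.randn(B, nh, T, hs) for _ in range(3))

# Manual causal attention with an explicit normalization factor.
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(head_size))
mask = torch.tril(torch.ones(T, T, dtype=torch.bool))
att = att.masked_fill(~mask, float("-inf"))
y_manual = F.softmax(att, dim=-1) @ v

# SDPA in 2.0 derives its scale from q.size(-1); the outputs match here,
# but they diverge as soon as the desired factor differs from q.size(-1).
y_sdpa = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True)
print(torch.allclose(y_manual, y_sdpa, atol=1e-5))  # True only when the scales agree
```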
PyTorch nightly (to be released with 2.1) conveniently added a scale argument to scaled_dot_product_attention: https://pytorch.org/docs/main/generated/torch.nn.functional.scaled_dot_product_attention.html
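With that argument available, the fallback branch could go away and the normalization factor could be passed explicitly. A rough sketch, reusing the q, k, v and head_size names from the snippet above (not a committed patch):

```python
# Sketch only: pass the model's own normalization factor instead of relying
# on the default 1 / sqrt(q.size(-1)); requires the nightly `scale` argument.
y = F.scaled_dot_product_attention(
    q, k, v,
    attn_mask=None,
    dropout_p=0.0,
    is_causal=True,
    scale=1.0 / math.sqrt(head_size),
)
```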
My proposal would be to install a nightly version in our requirements so that we can pass the scale explicitly.