Flash attention support #2

@carmocca

In PyTorch 2.0, `torch.nn.functional.scaled_dot_product_attention` always scales the attention scores by `1 / sqrt(Q.size(-1))`, i.e. it derives the normalization factor from the last dimension of the query: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html

However, in our model implementation this value differs from the head size, because a rotary percentage of 0.25 is used by default. As a result, we cannot use `scaled_dot_product_attention` whenever `rotary_percentage != 1.0` and have to fall back to a manual implementation:
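To make the mismatch concrete, here is a dependency-free sketch of the math the PyTorch docs describe, for a single head with no batch dimension and no dropout. `sdpa_reference` and the toy inputs are hypothetical; this models the documented formula, not PyTorch's fused kernels.

```python
import math


def softmax(row):
    # Subtract the max for numerical stability; exp(-inf) == 0.0 handles masked entries
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def sdpa_reference(q, k, v, is_causal=False, scale=None):
    """softmax(q @ k^T * scale) @ v for one head; q is T x E, k is S x E, v is S x Ev.

    Like the documented PyTorch behavior, `scale` defaults to 1 / sqrt(E) with E = len(q[0]).
    """
    if scale is None:
        scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for i, qi in enumerate(q):
        # Causal masking: query i may only attend to keys 0..i
        scores = [
            float("-inf") if is_causal and j > i
            else scale * sum(a * b for a, b in zip(qi, kj))
            for j, kj in enumerate(k)
        ]
        weights = softmax(scores)
        out.append([sum(w * vj[d] for w, vj in zip(weights, v)) for d in range(len(v[0]))])
    return out


q = [[1.0, 0.0, 2.0, 1.0], [0.5, 1.0, 0.0, 1.0]]
k = [[1.0, 1.0, 0.0, 0.0], [0.0, 2.0, 1.0, 1.0]]
v = [[1.0, 2.0], [3.0, 4.0]]

default = sdpa_reference(q, k, v)                            # implicitly uses 1/sqrt(4)
head16 = sdpa_reference(q, k, v, scale=1.0 / math.sqrt(16))  # what a head size of 16 would need
```

With a query width of 4 the default factor is `1/sqrt(4)`; if the model's head size were 16 it would need `1/sqrt(16)`, and the PyTorch 2.0 signature offers no way to request that.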

```python
import math

import torch
import torch.nn.functional as F

# In __init__: only register an explicit causal mask when the manual path is needed
if self.rotary_percentage != 1.0:
    self.register_buffer(
        "bias",
        torch.tril(torch.ones(config.block_size, config.block_size)).view(
            1, 1, config.block_size, config.block_size
        ),
    )

# ...

# In forward:
if hasattr(self, "bias"):
    # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
    att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(head_size))
    att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf"))
    att = F.softmax(att, dim=-1)
    y = att @ v  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
else:
    # efficient attention using Flash Attention CUDA kernels (scale is implicitly 1/sqrt(q.size(-1)))
    y = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True)
```

PyTorch nightly (to be released with 2.1) conveniently adds a `scale` argument to `scaled_dot_product_attention`, which lets us pass the correct normalization factor explicitly: https://pytorch.org/docs/main/generated/torch.nn.functional.scaled_dot_product_attention.html
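With `scale` available, the manual branch above could in principle be dropped. Below is a minimal sketch, assuming a PyTorch build that accepts `scale` (nightly or 2.1+), with a manual fallback for 2.0. The helper name `causal_sdpa` and the `TypeError` probe are illustrative choices, not part of the existing code.

```python
import math

import torch
import torch.nn.functional as F


def causal_sdpa(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, head_size: int) -> torch.Tensor:
    """Causal attention scaled by 1/sqrt(head_size) instead of 1/sqrt(q.size(-1)).

    q, k, v: (B, nh, T, E). Hypothetical helper, not part of the model code.
    """
    scale = 1.0 / math.sqrt(head_size)
    try:
        # Nightly / 2.1+: pass the normalization factor explicitly
        return F.scaled_dot_product_attention(
            q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True, scale=scale
        )
    except TypeError:
        # PyTorch 2.0 rejects the unknown `scale` keyword; fall back to the manual path
        T = q.size(-2)
        att = (q @ k.transpose(-2, -1)) * scale
        causal = torch.ones(T, T, dtype=torch.bool, device=q.device).tril()
        att = att.masked_fill(~causal, float("-inf"))
        return F.softmax(att, dim=-1) @ v
```

Catching `TypeError` keeps the sketch working on both versions, at the cost of potentially hiding unrelated argument errors; requiring a nightly build, as proposed, would let the fallback be removed entirely.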

My proposal is to require a nightly version of PyTorch in our requirements.