
Flash attention soft capping support #23300

@sparsh35

Description


In the JAX experimental Pallas kernels for TPU, attention logits softcapping is supported for paged attention but not for flash attention. If support were added to the Pallas flash attention kernels as well, they could then be used from PyTorch/XLA and in the vLLM implementation. The Gemma 2 9B model works even though it uses logit softcapping, but the 27B model does not.
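
For context, the softcapping used by Gemma 2 just bounds the pre-softmax attention scores with a scaled tanh, so the kernel change is local to where each q @ k^T block is computed. Below is a minimal JAX sketch of the transformation; the `soft_cap_logits` helper, the block shapes, and the 50.0 cap value are illustrative assumptions, not code from the Pallas kernels.

```python
import jax
import jax.numpy as jnp


def soft_cap_logits(logits: jax.Array, soft_cap: float) -> jax.Array:
    """Squash pre-softmax attention logits into (-soft_cap, soft_cap) with a scaled tanh."""
    return soft_cap * jnp.tanh(logits / soft_cap)


# Toy example: raw scores for one (block_q, block_k) tile of the attention matrix.
key = jax.random.PRNGKey(0)
q = jax.random.normal(key, (4, 128))             # (block_q, head_dim)
k = jax.random.normal(key, (8, 128))             # (block_k, head_dim)
scores = q @ k.T / jnp.sqrt(128.0)               # scaled dot-product logits
capped = soft_cap_logits(scores, soft_cap=50.0)  # illustrative cap value, as used by Gemma 2
probs = jax.nn.softmax(capped, axis=-1)          # softmax over the capped logits
print(probs.shape)                               # (4, 8)
```

In the flash attention kernel this transform would need to be applied to each logits block before the running-softmax update, mirroring how the softcapping option is handled in the paged attention kernel.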

PR for support of soft capping for paged attention

PyTorch/XLA custom kernel integration for paged attention

Need for flash attention support for running Gemma 2 with vLLM on TPUs

Labels

enhancement: New feature or request
pallas: Issues pertaining to Pallas (GPU or TPU)
