In the JAX experimental Pallas kernels for TPU, attention logits softcapping is supported for paged attention but not for flash attention. Adding support to the Pallas flash attention kernel as well would allow it to be used from PyTorch/XLA and in the vLLM implementation. The Gemma 2 9B model still produces usable output even when logit softcapping is not applied, but the 27B model does not.

- [PR for support of soft capping for Paged Attention](https://github.com/google/jax/commit/f0e36d5083457278339387285b06f170d28d87ac)
- [PyTorch/XLA custom kernel integration for paged attention](https://github.com/pytorch/xla/pull/7704)
- [Need for flash attention support for running Gemma 2 with vLLM on TPUs](https://github.com/vllm-project/vllm/issues/7950)
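
For context, the capping operation itself is small: the attention scores are squashed through a tanh before the softmax, the same tanh-based capping Gemma 2 expects and that the paged attention kernel applies. Below is a minimal plain-JAX sketch of that step, not the Pallas kernel itself; the function name and shapes are illustrative, and `softcap=50.0` is the `attn_logit_softcapping` value from the Gemma 2 config.

```python
import jax
import jax.numpy as jnp


def attention_with_softcap(q, k, v, softcap=50.0):
    """Reference (non-Pallas) attention with tanh soft capping of the logits.

    Shapes (illustrative): q [q_len, num_heads, head_dim],
    k, v [kv_len, num_heads, head_dim].
    """
    scale = q.shape[-1] ** -0.5
    # Raw attention logits per head: [num_heads, q_len, kv_len].
    logits = jnp.einsum("qhd,khd->hqk", q, k) * scale
    # Soft capping: squash the logits into (-softcap, softcap) with tanh,
    # the step the flash attention kernel would need to apply per block.
    logits = softcap * jnp.tanh(logits / softcap)
    weights = jax.nn.softmax(logits, axis=-1)
    return jnp.einsum("hqk,khd->qhd", weights, v)
```

In the flash attention kernel this capping would have to happen on each block of scores before the running softmax accumulation, which is why it needs kernel-level support rather than a wrapper around the existing kernel.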