Flash Attention is a high-performance attention algorithm that speeds up the transformer attention mechanism by reducing memory overhead and computation cost. This repository contains my attempts at implementing Flash Attention in raw CUDA.
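The core idea is to stream over the keys and values while maintaining a running softmax maximum and normalizer, so the full score matrix is never written to memory. Below is a minimal, illustrative CUDA sketch of that online-softmax accumulation for a single head; it is not one of this repository's kernels, and the one-query-row-per-thread layout, sizes, and names are assumptions chosen for clarity.

```cuda
// Minimal illustration of the online-softmax trick behind Flash Attention:
// each thread owns one query row and streams over all keys/values, keeping a
// running max (m), running normalizer (l), and an output accumulator, so the
// full SEQ_LEN x SEQ_LEN score matrix is never materialized in memory.
#include <cuda_runtime.h>
#include <math.h>
#include <stdio.h>

#define SEQ_LEN 64
#define HEAD_DIM 32

__global__ void online_softmax_attention(const float* Q, const float* K,
                                         const float* V, float* O) {
    int row = blockIdx.x * blockDim.x + threadIdx.x;  // one query row per thread
    if (row >= SEQ_LEN) return;

    float m = -INFINITY;            // running row max
    float l = 0.0f;                 // running softmax normalizer
    float acc[HEAD_DIM] = {0.0f};   // running weighted sum of V rows
    const float scale = 1.0f / sqrtf((float)HEAD_DIM);

    for (int j = 0; j < SEQ_LEN; ++j) {
        // score s = (q . k_j) * scale
        float s = 0.0f;
        for (int d = 0; d < HEAD_DIM; ++d)
            s += Q[row * HEAD_DIM + d] * K[j * HEAD_DIM + d];
        s *= scale;

        // update running max, then rescale the old normalizer and accumulator
        float m_new = fmaxf(m, s);
        float correction = expf(m - m_new);
        float p = expf(s - m_new);
        l = l * correction + p;
        for (int d = 0; d < HEAD_DIM; ++d)
            acc[d] = acc[d] * correction + p * V[j * HEAD_DIM + d];
        m = m_new;
    }

    for (int d = 0; d < HEAD_DIM; ++d)
        O[row * HEAD_DIM + d] = acc[d] / l;
}

int main() {
    static float hQ[SEQ_LEN * HEAD_DIM], hK[SEQ_LEN * HEAD_DIM],
                 hV[SEQ_LEN * HEAD_DIM], hO[SEQ_LEN * HEAD_DIM];
    for (int i = 0; i < SEQ_LEN * HEAD_DIM; ++i) {
        hQ[i] = 0.01f * (i % 7); hK[i] = 0.02f * (i % 5); hV[i] = 0.03f * (i % 3);
    }
    size_t bytes = SEQ_LEN * HEAD_DIM * sizeof(float);
    float *Q, *K, *V, *O;
    cudaMalloc((void**)&Q, bytes); cudaMalloc((void**)&K, bytes);
    cudaMalloc((void**)&V, bytes); cudaMalloc((void**)&O, bytes);
    cudaMemcpy(Q, hQ, bytes, cudaMemcpyHostToDevice);
    cudaMemcpy(K, hK, bytes, cudaMemcpyHostToDevice);
    cudaMemcpy(V, hV, bytes, cudaMemcpyHostToDevice);
    online_softmax_attention<<<1, SEQ_LEN>>>(Q, K, V, O);
    cudaMemcpy(hO, O, bytes, cudaMemcpyDeviceToHost);
    printf("O[0][0] = %f\n", hO[0]);
    return 0;
}
```

The full algorithm also loads keys and values in shared-memory tiles per thread block, but the rescaling step above is what avoids materializing the O(N²) score matrix.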
To benchmark the flash attention kernel against PyTorch's implementation and standard attention, run:
python bench.py
To build the standalone version in the standalone folder, compile it with:
nvcc fa_1.cu
These benchmarks were run on a GTX 1650 Ti with batch size = 8, heads = 16, head dimension = 32, and sequence length = 1024:
- naive flash attention v1 implementation - 11.204 ms
- naive flash attention v2 implementation - 10.934 ms
- optimized flash attention implementation - 6.172 ms
- PyTorch flash attention implementation - 15.842 ms
- naive implementation of the forward pass of v1
- naive implementation of the forward pass of v2
- naive implementation of the backward pass of v1
- naive implementation of the backward pass of v2
- optimized implementation of the forward pass of v2 (coalesced memory access using chunk-based reads, loop unrolling, etc.; see the coalesced-load sketch after this list)
- use warp-level primitives in the optimized implementation for further speedups (see the warp-shuffle sketch after this list)
- add mixed-precision support
- optimized implementation of the backward pass of v2
- attention masking (supports causal masking)
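For the coalesced, chunk-based reads mentioned in the optimized forward-pass item above, here is an isolated sketch of the access pattern, assuming a row-major K matrix; the tile size, block size, and names are illustrative and not taken from this repository's kernels.

```cuda
// Illustration of coalesced, chunk-based reads: a thread block cooperatively
// loads a tile of K into shared memory with consecutive threads touching
// consecutive global addresses, then works on the tile from shared memory.
#include <cuda_runtime.h>
#include <stdio.h>

#define SEQ_LEN   1024
#define HEAD_DIM  32
#define TILE_ROWS 32                     // key rows loaded per chunk
#define BLOCK_THREADS 256

__global__ void tiled_key_row_sums(const float* K, float* row_sums) {
    __shared__ float tile[TILE_ROWS * HEAD_DIM];

    int tile_start = blockIdx.x * TILE_ROWS;   // first key row of this chunk
    int tile_elems = TILE_ROWS * HEAD_DIM;

    // Chunked, coalesced load: thread t reads elements t, t+256, t+512, ...
    for (int i = threadIdx.x; i < tile_elems; i += blockDim.x)
        tile[i] = K[tile_start * HEAD_DIM + i];
    __syncthreads();

    // Trivial consumer of the tile: one thread per key row sums its elements.
    if (threadIdx.x < TILE_ROWS) {
        float s = 0.0f;
        for (int d = 0; d < HEAD_DIM; ++d)
            s += tile[threadIdx.x * HEAD_DIM + d];
        row_sums[tile_start + threadIdx.x] = s;
    }
}

int main() {
    size_t k_bytes = SEQ_LEN * HEAD_DIM * sizeof(float);
    size_t s_bytes = SEQ_LEN * sizeof(float);
    static float hK[SEQ_LEN * HEAD_DIM], hS[SEQ_LEN];
    for (int i = 0; i < SEQ_LEN * HEAD_DIM; ++i) hK[i] = 0.001f * (i % 97);

    float *dK, *dS;
    cudaMalloc((void**)&dK, k_bytes);
    cudaMalloc((void**)&dS, s_bytes);
    cudaMemcpy(dK, hK, k_bytes, cudaMemcpyHostToDevice);

    tiled_key_row_sums<<<SEQ_LEN / TILE_ROWS, BLOCK_THREADS>>>(dK, dS);

    cudaMemcpy(hS, dS, s_bytes, cudaMemcpyDeviceToHost);
    printf("sum of key row 0 = %f\n", hS[0]);
    return 0;
}
```

Because loop index i maps directly to consecutive addresses in K, each warp's 32 loads fall in adjacent locations and coalesce into a small number of memory transactions, unlike a per-row strided load.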
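For the warp-level primitives item, this is a sketch of the kind of __shfl_down_sync reduction that could compute a softmax row max and exp-sum without shared memory; it is an assumed direction for that optimization, not existing code in this repository.

```cuda
// Sketch of warp-level reductions with __shfl_down_sync: each warp computes
// the max and the exp-sum of 32 scores without touching shared memory.
#include <cuda_runtime.h>
#include <math.h>
#include <stdio.h>

#define FULL_MASK 0xffffffffu

__global__ void warp_softmax_stats(const float* scores, float* row_max,
                                   float* row_sum, int n_rows) {
    int row  = blockIdx.x;              // one 32-element row per block
    int lane = threadIdx.x;             // blockDim.x == 32 (one warp)
    if (row >= n_rows) return;

    float v = scores[row * 32 + lane];

    // warp-wide max: lane 0 accumulates, then broadcasts to all lanes
    float m = v;
    for (int offset = 16; offset > 0; offset >>= 1)
        m = fmaxf(m, __shfl_down_sync(FULL_MASK, m, offset));
    m = __shfl_sync(FULL_MASK, m, 0);

    // warp-wide sum of exp(v - m)
    float s = expf(v - m);
    for (int offset = 16; offset > 0; offset >>= 1)
        s += __shfl_down_sync(FULL_MASK, s, offset);

    if (lane == 0) { row_max[row] = m; row_sum[row] = s; }
}

int main() {
    const int n_rows = 4;
    float h_scores[n_rows * 32];
    for (int i = 0; i < n_rows * 32; ++i) h_scores[i] = 0.1f * (i % 32);

    float *d_scores, *d_max, *d_sum;
    cudaMalloc((void**)&d_scores, sizeof(h_scores));
    cudaMalloc((void**)&d_max, n_rows * sizeof(float));
    cudaMalloc((void**)&d_sum, n_rows * sizeof(float));
    cudaMemcpy(d_scores, h_scores, sizeof(h_scores), cudaMemcpyHostToDevice);

    warp_softmax_stats<<<n_rows, 32>>>(d_scores, d_max, d_sum, n_rows);

    float h_max[n_rows], h_sum[n_rows];
    cudaMemcpy(h_max, d_max, sizeof(h_max), cudaMemcpyDeviceToHost);
    cudaMemcpy(h_sum, d_sum, sizeof(h_sum), cudaMemcpyDeviceToHost);
    printf("row 0: max=%f sum=%f\n", h_max[0], h_sum[0]);
    return 0;
}
```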