FlashAttention

FlashAttention is a high-performance attention algorithm that accelerates the attention mechanism in transformers by tiling the computation so the full attention matrix is never materialized in GPU global memory, reducing memory overhead and computation cost. This repository contains my attempts at implementing flash attention in raw CUDA.

Executing program

Run the flash attention kernels and benchmark them against PyTorch's implementation and standard attention:

python bench.py

Run the standalone version from the standalone folder:

nvcc fa_1.cu
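
The plain nvcc invocation above is enough to build. Optionally, you can target the benchmark card's architecture explicitly (the GTX 1650 Ti is compute capability 7.5) and enable host-side optimizations; the output name below is arbitrary:

nvcc -O3 -arch=sm_75 fa_1.cu -o fa_1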

Benchmarks

These benchmarks were run on a GTX 1650 Ti with batch size = 8, heads = 16, head dimension = 32, sequence length = 1024:

  • Naive flash attention v1 implementation - 11.204 ms
  • Naive flash attention v2 implementation - 10.934 ms
  • Optimized flash attention implementation - 6.172 ms
  • PyTorch flash attention implementation - 15.842 ms

To-do

  • Naive implementation of the forward pass of v1
  • Naive implementation of the forward pass of v2
  • Naive implementation of the backward pass of v1
  • Naive implementation of the backward pass of v2
  • Optimized implementation of the forward pass of v2 (coalesced memory access using chunk-based reads, loop unrolling, etc.)
  • Use warp-level primitives in the optimized implementation for further speedups (see the sketch after this list)
  • Add mixed-precision support
  • Optimized implementation of the backward pass of v2
  • Attention masking (supports causal masking)
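
The sketch below is not the code in this repository; it is a minimal, self-contained illustration of two techniques named in the list above, on a toy single-tile problem: coalesced chunk-based loads of a key tile into shared memory, and warp-level shuffle reductions (__shfl_xor_sync) for the softmax max and sum. The head dimension is assumed to be 32 so each warp lane owns one element of the query and one key row; all names and sizes are illustrative.

// warp_sketch.cu - illustrative only, not the kernels in this repo.
#include <cstdio>
#include <cmath>
#include <cuda_runtime.h>

constexpr int D  = 32;   // head dimension (matches the benchmark setting)
constexpr int BC = 32;   // keys per tile, one key row per warp lane

// Butterfly reductions across the 32 lanes of a warp.
__device__ float warp_max(float v) {
    for (int off = 16; off > 0; off >>= 1)
        v = fmaxf(v, __shfl_xor_sync(0xffffffffu, v, off));
    return v;
}
__device__ float warp_sum(float v) {
    for (int off = 16; off > 0; off >>= 1)
        v += __shfl_xor_sync(0xffffffffu, v, off);
    return v;
}

// One warp computes softmax(q . K^T / sqrt(d)) for a single query row
// against a tile of BC keys.
__global__ void softmax_scores(const float* __restrict__ q,
                               const float* __restrict__ K,
                               float* __restrict__ p) {
    __shared__ float qs[D];
    __shared__ float Ks[BC][D + 1];          // +1 pad avoids bank conflicts
    const int lane = threadIdx.x;            // 0..31

    // Coalesced chunk-based load: on every iteration the 32 lanes read
    // 32 consecutive floats of K (one key row per iteration).
    qs[lane] = q[lane];
    for (int row = 0; row < BC; ++row)
        Ks[row][lane] = K[row * D + lane];
    __syncthreads();

    // Each lane owns one key row: s = dot(q, K[lane]) / sqrt(d).
    float s = 0.0f;
    for (int j = 0; j < D; ++j)
        s += qs[j] * Ks[lane][j];
    s *= rsqrtf((float)D);

    // Softmax over the 32 scores held one per lane, reduced with shuffles
    // instead of shared-memory round trips.
    const float m = warp_max(s);
    const float e = __expf(s - m);
    const float denom = warp_sum(e);
    p[lane] = e / denom;
}

int main() {
    float hq[D], hK[BC * D], hp[BC];
    for (int i = 0; i < D; ++i)      hq[i] = 0.01f * i;
    for (int i = 0; i < BC * D; ++i) hK[i] = 0.001f * (i % 97);

    float *dq, *dK, *dp;
    cudaMalloc(&dq, sizeof(hq));
    cudaMalloc(&dK, sizeof(hK));
    cudaMalloc(&dp, sizeof(hp));
    cudaMemcpy(dq, hq, sizeof(hq), cudaMemcpyHostToDevice);
    cudaMemcpy(dK, hK, sizeof(hK), cudaMemcpyHostToDevice);

    softmax_scores<<<1, 32>>>(dq, dK, dp);
    cudaMemcpy(hp, dp, sizeof(hp), cudaMemcpyDeviceToHost);

    float total = 0.0f;
    for (int i = 0; i < BC; ++i) total += hp[i];
    printf("sum of softmax weights = %f (should be ~1)\n", total);

    cudaFree(dq); cudaFree(dK); cudaFree(dp);
    return 0;
}

In a full flash-attention kernel the same shuffle reductions would maintain the running row max and row sum across successive key tiles rather than a single tile.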
