Matt J. Borowski

Warp work distribution in Flash Attention

Mar 05, 2026

Using 8×32×16 WMMA tiles together with a d-axis split of warp work allows more warps per block in Flash Attention and doubles occupancy (8→16 warps per scheduler) for Br=64. SRAM pressure, however, rules out padding, which leads to bank conflicts. A detailed Nsight Compute analysis of the parallelism trade-offs.
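When shared memory caps how many blocks fit on an SM, putting more warps into each block is the only way to add resident warps. The minimal Python sketch below shows that arithmetic. The hardware limits and the 40 KB per-block shared-memory figure are assumed example values, not measurements from this kernel, and register limits and per-block reserved shared memory are ignored.

```python
# Minimal occupancy sketch. All limits are ASSUMED example values
# (roughly Ampere-class); register limits and the per-block reserved
# shared memory are ignored for simplicity.
MAX_WARPS_PER_SM = 64        # assumed resident-warp limit per SM
MAX_BLOCKS_PER_SM = 32       # assumed resident-block limit per SM
SCHEDULERS_PER_SM = 4        # warp schedulers per SM
SMEM_PER_SM = 164 * 1024     # assumed usable shared memory per SM (bytes)


def warps_per_scheduler(warps_per_block: int, smem_per_block: int) -> float:
    """Resident warps per scheduler for one launch configuration."""
    by_warps = MAX_WARPS_PER_SM // warps_per_block
    by_smem = SMEM_PER_SM // smem_per_block  # blocks that fit in shared memory
    blocks = min(by_warps, by_smem, MAX_BLOCKS_PER_SM)
    return blocks * warps_per_block / SCHEDULERS_PER_SM


if __name__ == "__main__":
    smem = 40 * 1024  # hypothetical per-block shared memory, unchanged by the split
    print(warps_per_scheduler(8, smem))        # 8.0  (8 warps per block)
    print(warps_per_scheduler(16, smem))       # 16.0 (16 warps per block)
    print(warps_per_scheduler(16, 48 * 1024))  # 12.0 (16 warps, extra padding)
```

With shared memory limiting residency to 4 blocks per SM, going from 8 to 16 warps per block takes each scheduler from 8 to 16 resident warps. The last line shows the flip side: any extra per-block shared memory, such as padding, can drop the block count and give part of that gain back.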

See the kernel code on my GitHub here and the full analysis here.

Highlights

  • Occupancy Increase: Splitting work along the d-dimension (8×32×16 WMMA tiles instead of 16×16×16) doubled the active warps for Br=64, but exposed hidden costs.
  • SRAM Pressure: Accumulation buffers (~32 KB each) push per-block shared memory (SRAM) close to its limit, so increasing the padding causes overflows and silent failures.
  • Bank Conflicts: Doubling the warps lowered compute throughput because ≈3-way shared-memory bank conflicts serialized accesses and reduced the work issued.
  • Scheduler Bottlenecks: Scheduler and warp-state statistics show more active warps, but only a ~35% increase in eligible warps and a >120% rise in warp cycles per instruction (Stall Barrier, MIO Throttle).
  • Validation Run: A run with Br=32 (fa_tc_v2a, PAD=16) eliminated the bank conflicts and lowered latency (~7.5 ms), confirming that the slowdown at Br=64 is conflict-driven.
  • Next: Root-cause fixes and next steps: the race in the accumulation is now guarded (fixed); the priorities are mitigating bank conflicts and reducing per-warp/shared-memory allocations (re-indexing, padding, swizzling, or preserving single-warp work).
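How quickly padding exhausts shared memory can be checked on the host before launching. The Python sketch below is a hypothetical budget: the buffer list (Q, K, V tiles in fp16 plus fp32 accumulators S and O), the shapes Br=64, Bc=128, d=128, and the 163 KB per-block limit are all assumptions, chosen so that each fp32 accumulator comes out at 32 KB. They are not the actual layout of fa_tc_v2a.

```python
# Hypothetical shared-memory budget for one Flash Attention block.
# The buffer list, tile shapes, and dtypes are ASSUMPTIONS chosen so that
# each fp32 accumulator is 32 KB; they are not the kernel's actual layout.
FP16, FP32 = 2, 4                   # bytes per element
SMEM_LIMIT_PER_BLOCK = 163 * 1024   # assumed opt-in per-block limit (bytes)


def tile_bytes(rows: int, cols: int, pad: int, elem_bytes: int) -> int:
    """Bytes for a row-major tile whose rows are padded by `pad` elements."""
    return rows * (cols + pad) * elem_bytes


def block_smem_bytes(br: int, bc: int, d: int, pad: int) -> int:
    return (
        tile_bytes(br, d, pad, FP16)     # Q tile
        + tile_bytes(bc, d, pad, FP16)   # K tile
        + tile_bytes(bc, d, pad, FP16)   # V tile
        + tile_bytes(br, bc, pad, FP32)  # score accumulator S = Q K^T
        + tile_bytes(br, d, pad, FP32)   # output accumulator O
    )


def fits(br: int, bc: int, d: int, pad: int) -> bool:
    return block_smem_bytes(br, bc, d, pad) <= SMEM_LIMIT_PER_BLOCK


if __name__ == "__main__":
    for pad in (0, 16, 32):
        total = block_smem_bytes(64, 128, 128, pad)
        print(pad, total, fits(64, 128, 128, pad))
    # 0 147456 True
    # 16 165888 True
    # 32 184320 False
```

Under these assumptions a 16-element pad still squeezes in (165,888 of 166,912 bytes) while a 32-element pad does not. On the device side, a launch that requests more dynamic shared memory than allowed fails, and unless the error status is checked (for example with `cudaGetLastError`) that failure is easy to miss.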
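Bank conflicts, and the padding and swizzling remedies listed above, can be reasoned about with a small address model. The Python sketch below assumes 32 four-byte banks and one 8-thread phase of 16-byte loads, each thread reading one 8-element fp16 chunk from consecutive rows of a row-major tile with a hypothetical head dimension of 64. This is an illustrative pattern, not the exact access pattern of WMMA `load_matrix_sync`, whose fragment layout is implementation-defined. The XOR swizzle shown is one common scheme, named here as an alternative to padding that costs no extra memory.

```python
# Bank-conflict model: 32 banks, 4 bytes wide. Threads hitting the SAME
# 4-byte word are served by broadcast, so only distinct words per bank count.
NUM_BANKS, BANK_BYTES = 32, 4


def conflict_degree(byte_addrs) -> int:
    """n-way conflict of one shared-memory request (1 means conflict-free)."""
    words_per_bank = {}
    for addr in byte_addrs:
        word = addr // BANK_BYTES
        words_per_bank.setdefault(word % NUM_BANKS, set()).add(word)
    return max(len(words) for words in words_per_bank.values())


def row_chunk_addrs(row_stride_halves: int, chunk_of_row, rows: int = 8) -> list:
    """Bytes touched when each of `rows` threads loads one 16-byte chunk
    (8 fp16 values) from consecutive rows of a row-major fp16 tile.
    `chunk_of_row(r)` returns the PHYSICAL chunk index read from row r."""
    addrs = []
    for r in range(rows):
        base = r * row_stride_halves * 2 + chunk_of_row(r) * 16
        addrs.extend(range(base, base + 16))
    return addrs


def plain_chunk0(r: int) -> int:
    return 0  # every thread reads logical chunk 0, stored unswizzled


def swizzled_chunk0(r: int) -> int:
    return 0 ^ (r % 8)  # XOR swizzle: logical chunk c of row r lives at c ^ (r % 8)


if __name__ == "__main__":
    d = 64  # hypothetical head dim: 64 fp16 = 128 bytes = 8 chunks per row
    print(conflict_degree(row_chunk_addrs(d, plain_chunk0)))      # 8 (unpadded)
    print(conflict_degree(row_chunk_addrs(d + 8, plain_chunk0)))  # 1 (16-byte pad)
    print(conflict_degree(row_chunk_addrs(d, swizzled_chunk0)))   # 1 (swizzled)
```

Without padding, every row starts in bank 0, so all eight chunks land in banks 0 to 3 (8-way). A 16-byte pad shifts each row by four banks, and the XOR swizzle achieves the same spread by permuting chunks within the row instead of growing it, which is why swizzling is attractive when SRAM pressure rules out padding.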
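The scheduler figures can be tied together with Little's law, a general queueing identity rather than anything specific to this kernel. Per scheduler, average active warps equal the issue rate times the warp cycles per issued instruction. The Python sketch below assumes that active warps actually doubled and treats the reported >120% rise in cycles per instruction as exactly 2.2×, its lower bound.

```python
def issue_rate_ratio(active_warps_ratio: float, cpi_ratio: float) -> float:
    """Little's law per scheduler:
        active warps = issued instructions per cycle * warp cycles per issued instruction
    so the issue rate scales as active_warps_ratio / cpi_ratio (new / old)."""
    return active_warps_ratio / cpi_ratio


if __name__ == "__main__":
    # Assumed inputs: active warps doubled (x2.0) and warp cycles per
    # instruction rose by 120% (x2.2, the lower bound of the reported rise).
    print(round(issue_rate_ratio(2.0, 2.2), 3))  # 0.909
```

Under those assumptions each scheduler issues at least ~9% fewer instructions per cycle while holding twice as many warps, which matches the reduced issued work described above. If achieved active warps grew by less than 2×, the drop is larger.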