Tensor Cores + Multi-Head Attention
Mar 04, 2026
I profiled the standard fused Flash Attention kernel against a Tensor Core–optimized implementation and measured a 30.7% runtime speedup (8.33ms → 5.77ms). The implementation uses WMMA for Q@K^T and P@V; each warp owns a 16×d chunk of Q and processes 16×16×16 tiles serially.
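The per-warp tiling described above can be sketched with the WMMA API. This is a minimal illustration, not the actual kernel: the function name, pointer arguments, and row-major fp16 layout with an fp32 accumulator are all assumptions, and `d` is assumed to be a multiple of 16.

```cuda
#include <mma.h>
using namespace nvcuda;

// Illustrative sketch: one warp computes a 16x16 tile of S = Q @ K^T,
// accumulating over the head dimension d in 16-wide WMMA steps.
__device__ void warp_qk_tile(const half *q_tile,  // 16 x d, row-major
                             const half *k_tile,  // 16 x d, row-major (rows of K)
                             float      *s_tile,  // 16 x 16 output, row-major
                             int d) {             // assumed multiple of 16
    wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> a;
    // K is stored row-major; declaring the B fragment col_major makes the
    // MMA consume it as K^T without an explicit transpose.
    wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::col_major> b;
    wmma::fragment<wmma::accumulator, 16, 16, 16, float> acc;
    wmma::fill_fragment(acc, 0.0f);

    for (int k = 0; k < d; k += 16) {
        wmma::load_matrix_sync(a, q_tile + k, d);  // 16x16 slice of Q
        wmma::load_matrix_sync(b, k_tile + k, d);  // 16x16 slice of K^T
        wmma::mma_sync(acc, a, b, acc);            // acc += Q_slice @ K_slice^T
    }
    wmma::store_matrix_sync(s_tile, acc, 16, wmma::mem_row_major);
}
```

The same fragment pattern would apply to `P@V`, with `P` as the A operand and `V` loaded row-major as the B operand.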
The kernel code and the full analysis are available on my GitHub.
Highlights
- Speedup: 30.7% runtime improvement (8.33ms → 5.77ms).
- Approach: WMMA for `Q@K^T` and `P@V`; one warp processes 16×16×16 tiles for a 16×d Q chunk.
- Efficiency: Throughput (memory/compute) dropped from 67.83% to 46.05% while runtime improved, indicating fewer idle cycles and more work per clock.
- Challenge: Low occupancy (~16.7%) due to `Br=64` row SRAM pressure on the L4 GPU; this forces serial tile processing per warp, trading parallelism for memory fit.
- Next steps: Try 8×8×32 tiles to improve occupancy, and add multi-warp d-splits to increase the theoretical number of active warps per scheduler.