FlashAttention-T: Towards Fully Tensorized Attention by Exploiting Tensor-Vector Parallelism
The attention mechanism is central to modern deep learning, particularly in large language models (LLMs), but suffers from quadratic computational complexity. To accelerate attention computation on GPUs, fused attention techniques (e.g., FlashAttention) consolidate the matrix multiplication (GEMM) and softmax computations into a single kernel. However, these operations remain computationally decoupled: the GEMMs leverage high-performance tensor units (Tensor Cores), while the softmax executes on slower vector units (CUDA cores). This imbalance induces severe vector intervals, i.e., periods during which tensor units sit idle waiting for the vector units to finish, leaving the tensor units significantly underutilized. Moreover, ongoing hardware advances that deliver ever-faster tensor units only exacerbate this bottleneck.
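For concreteness, the tensor/vector split described above can be spelled out on the standard scaled-dot-product attention computation; the symbols below ($Q$, $K$, $V$, $S$, $P$, $O$, head dimension $d$) are the conventional names for these quantities and are used here only as an illustrative sketch of the mapping, not as notation taken from this paper:
\begin{align*}
S &= QK^{\top}/\sqrt{d}        && \text{GEMM, executed on tensor units (Tensor Cores)}\\
P &= \operatorname{softmax}(S) && \text{row-wise max, exponentiation, and sum, executed on vector units (CUDA cores)}\\
O &= PV                        && \text{GEMM, executed on tensor units}
\end{align*}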
To resolve the vector interval bottleneck, this paper proposes FlashAttention-T, a fused attention implementation that advances toward fully tensorized attention. Our key insight is to offload critical softmax primitives to the otherwise-idle tensor units, maximizing hardware utilization and throughput. Concretely, we first introduce a series of operand value assignment methods that repurpose tensor matrix multiply-add (MMA) instructions to execute softmax primitives (e.g., element-wise scaling). Building on this, we design a tensorized online softmax algorithm with numerical stability guarantees that adheres to the constraints of the repurposed tensor MMA instructions. To maximize performance, we parallelize the online softmax computation across tensor and vector units via architecture-aware scheduling techniques, fully exploiting the GPU's heterogeneous parallelism.
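To illustrate the kind of softmax primitive that fits an MMA shape, consider the standard online-softmax recurrence used by FlashAttention-style kernels, written here as a generic sketch rather than this paper's exact formulation: $j$ indexes the key/value block, $m$ and $\ell$ are the running row-wise max and sum, and $\tilde{P}^{(j)} = \exp\bigl(S^{(j)} - m^{(j)}\bigr)$:
\begin{align*}
m^{(j)}    &= \max\!\bigl(m^{(j-1)}, \operatorname{rowmax}(S^{(j)})\bigr),\\
\ell^{(j)} &= e^{\,m^{(j-1)} - m^{(j)}}\,\ell^{(j-1)} + \operatorname{rowsum}\!\bigl(\tilde{P}^{(j)}\bigr),\\
O^{(j)}    &= \operatorname{diag}\!\bigl(e^{\,m^{(j-1)} - m^{(j)}}\bigr)\,O^{(j-1)} + \tilde{P}^{(j)} V^{(j)}.
\end{align*}
The per-row rescaling of the running output in the last line is an element-wise scaling, but written as a product with a diagonal matrix it already has the multiply-accumulate ($AB + C$) shape that a tensor MMA instruction computes; assigning the exponentials to a diagonal operand is one example of the kind of operand value assignment that lets such a primitive run on tensor units instead of vector units.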
Extensive evaluation across diverse attention configurations and multiple platforms, including NVIDIA Ampere (A100 and AGX Orin) and Hopper (H100) GPUs, demonstrates that FlashAttention-T achieves vector interval ratios 1.17–2.18x lower than the baseline on Ampere GPUs and reduces the ratio to 2.7% on the Hopper H100, delivering average speedups of up to 1.17x over FlashAttention-2 and FlashAttention-3 while preserving numerical stability and accuracy.