FlashAttention-T: Towards Fully Tensorized Attention by Exploiting Tensor-Vector Parallelism
The attention mechanism is central to modern deep learning, particularly in large language models (LLMs), but suffers from quadratic computational complexity. To accelerate attention on GPUs, fused attention techniques (e.g., FlashAttention) consolidate the matrix multiplication (GEMM) and softmax computations into a single kernel. However, the two operations remain computationally decoupled: the GEMMs run on high-performance tensor units (Tensor Cores), while the softmax executes on slower vector units (CUDA cores). This imbalance induces severe vector intervals (periods in which tensor units sit idle awaiting vector-unit completion), significantly underutilizing the tensor units. Moreover, each hardware generation delivers faster tensor units, further exacerbating this bottleneck.
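To make the division of labor concrete, below is a minimal NumPy sketch of the standard online-softmax attention loop used by fused kernels in the FlashAttention family. The two matrix products map to tensor-unit GEMMs, while the running max/exp/rescale updates are the vector-unit work behind the intervals described above. This is an illustrative reference implementation, not any paper's kernel; the function name and block size are ours.

```python
import numpy as np

def online_softmax_attention(Q, K, V, block_size=64):
    """Numerically stable online-softmax attention (FlashAttention-style).

    The Q @ K^T and P @ V products map to tensor-unit GEMMs; the running
    max/sum updates and exp/rescale steps are the element-wise work that
    fused kernels place on vector units.
    """
    n, d = Q.shape
    O = np.zeros((n, d))
    m = np.full(n, -np.inf)        # running row maxima
    l = np.zeros(n)                # running softmax denominators
    for j in range(0, K.shape[0], block_size):
        Kj, Vj = K[j:j + block_size], V[j:j + block_size]
        S = Q @ Kj.T / np.sqrt(d)              # GEMM (tensor units)
        m_new = np.maximum(m, S.max(axis=1))   # row-max update (vector units)
        alpha = np.exp(m - m_new)              # rescale factor for old state
        P = np.exp(S - m_new[:, None])         # unnormalized probabilities
        l = alpha * l + P.sum(axis=1)          # denominator update
        O = alpha[:, None] * O + P @ Vj        # P @ V is another GEMM
        m = m_new
    return O / l[:, None]
```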
To resolve this vector-interval bottleneck, we propose FlashAttention-T, a fused attention implementation that advances toward fully tensorized attention. Our key insight is to offload critical softmax primitives to otherwise-idle tensor units, maximizing hardware utilization and throughput. Concretely, we first introduce a series of operand value assignment methods that repurpose tensor matrix multiply-add (MMA) instructions to execute softmax primitives (e.g., element-wise scaling). Building on this, we design a tensorized online softmax algorithm with numerical-stability guarantees that adheres to the constraints of the repurposed MMA instructions. To maximize performance, we parallelize the online softmax computation across tensor and vector units via architecture-aware scheduling, fully exploiting this heterogeneous parallelism.
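The following NumPy sketch illustrates the general kind of operand value assignment described above, using a hypothetical `mma` helper as a stand-in for a tensor MMA instruction (D = A·B + C): a ones-vector operand turns the softmax row-sum into an MMA, and a diagonal operand turns element-wise row rescaling into one while folding the P·V product into the accumulator slot. This is our own illustration of the underlying identities; the actual operand layouts and instruction shapes in FlashAttention-T may differ.

```python
import numpy as np

def mma(A, B, C):
    """Stand-in for a tensor MMA instruction: D = A @ B + C."""
    return A @ B + C

n, b, d = 4, 8, 16
rng = np.random.default_rng(0)
P = np.exp(rng.standard_normal((n, b)))   # unnormalized probabilities
O = rng.standard_normal((n, d))           # running output accumulator
V = rng.standard_normal((b, d))
alpha = rng.random(n)                     # per-row rescale factors

# 1) Row-sum reduction as an MMA: multiplying by a ones vector turns
#    the softmax-denominator sum into tensor-unit work.
row_sums = mma(P, np.ones((b, 1)), np.zeros((n, 1)))
assert np.allclose(row_sums[:, 0], P.sum(axis=1))

# 2) Element-wise row scaling as an MMA: placing alpha on the diagonal
#    of one operand makes diag(alpha) @ O a rescale, and the accumulator
#    operand folds in the P @ V product in the same instruction.
O_new = mma(np.diag(alpha), O, P @ V)
assert np.allclose(O_new, alpha[:, None] * O + P @ V)
```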
Extensive evaluation across a range of attention configurations on multiple platforms, including NVIDIA Ampere (A100 and AGX Orin) and Hopper (H100) GPUs, demonstrates that FlashAttention-T achieves vector-interval ratios 1.17–2.18× lower than the baseline on Ampere GPUs and reduces the ratio to as little as 2.7% on the Hopper H100, yielding average speedups of up to 1.17× over FlashAttention-2 and FlashAttention-3 while preserving numerical stability and accuracy.
Tue 3 Feb (displayed time zone: Hobart)
17:15 - 18:15 | Optimizing Transformers (Main Conference) at Pyrmont | Chair(s): Shaoshuai Zhang (University of Electronic Science and Technology of China)

17:15 | 20m Talk | FlashAttention-T: Towards Fully Tensorized Attention by Exploiting Tensor-Vector Parallelism (Main Conference) | Jianxing Xu (University of Science and Technology of China), Yuanbo Wen, Jun Bi (Chinese Academy of Sciences), Ruibai Xu (University of Science and Technology of China), Guanglin Xu (Chinese Academy of Sciences), Rui Zhang (Chinese Academy of Sciences), Wei Li (Chinese Academy of Sciences), Ling Li (Institute of Software, Chinese Academy of Sciences), Tianshi Chen (Cambricon Technologies), Qi Guo (Chinese Academy of Sciences), Yunji Chen (Chinese Academy of Sciences)

17:35 | 20m Talk | Accelerating Sparse Transformer Inference on GPU (Main Conference) | Wenhao Dai (China University of Petroleum-Beijing), Haodong Deng (China University of Petroleum), Mengfei Rong (China University of Petroleum), Xinyu Yang (Beihang University), Hongyu Liu (Baidu Inc.), Fangxin Liu (Shanghai Jiao Tong University), Hailong Yang (Beihang University), Qianwen Cao (China University of Petroleum), Qingxiao Sun (Beihang University)

17:55 | 20m Talk | MetaAttention: A Unified and Performant Attention Framework Across Hardware Backends (Main Conference) | Feiyang Chen (Shanghai Jiao Tong University), Yu Cheng (Peking University), Lei Wang (Peking University), Yuqing Xia (Microsoft Research), Ziming Miao (Microsoft Research), Lingxiao Ma (Microsoft Research), Fan Yang (Microsoft Research Asia), Jilong Xue (Microsoft Research), Zhi Yang (Peking University), Mao Yang (Microsoft Research), Xingda Wei (Shanghai Jiao Tong University), Haibo Chen (Shanghai Jiao Tong University)