[P] Fused MoE Dispatch in Pure Triton: Beating CUDA-Optimized Megablocks at Inference Batch Sizes
A pure Triton kernel runs at 131% of the speed of Stanford's CUDA Megablocks at small batch sizes while cutting memory traffic by 35%.
Developer Subhadip Mitra has open-sourced a novel, high-performance kernel for Mixture-of-Experts (MoE) models, written entirely in Triton. The code executes the critical 'dispatch' step of an MoE forward pass and directly challenges the dominance of hand-optimized CUDA. In head-to-head benchmarks on an NVIDIA A100 with the popular Mixtral-8x7B model, the Triton kernel runs at 131% of the speed of Stanford's Megablocks library at a batch size of 32 tokens, and at 124% at 128 tokens. The lead is significant because these smaller batch sizes are precisely the regime of real-time inference and chatbot serving, where users expect low latency.
The kernel's speed and efficiency stem from two architectural innovations. First, it fuses the gate and up-projection GEMMs (general matrix multiplies): both projections share the same loaded input data, and the SiLU activation is applied in GPU registers, which eliminates approximately 470MB of intermediate buffer data per forward pass and cuts memory traffic, a major GPU bottleneck, by 35%. Second, it employs a 'block-scheduled grouped GEMM' strategy: a precomputed map schedules work blocks over the variable-sized batches routed to each expert, so the whole dispatch runs in a single kernel launch without wasteful padding.
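To make the fusion concrete, here is a minimal sketch of the idea in Triton, not the released kernel: the name `fused_gate_up_silu`, the dense single-expert shapes, and the K×N weight layout are assumptions made to keep the example short. Each tile loads the input activations once, accumulates both the gate and up projections from that shared tile, applies SiLU in registers, and writes only the final product, so the activated-gate intermediate never round-trips through HBM.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def fused_gate_up_silu(                      # hypothetical name, not from the repo
    x_ptr, wg_ptr, wu_ptr, out_ptr,
    M, N, K,
    stride_xm, stride_xk,
    stride_wk, stride_wn,
    stride_om, stride_on,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
    # Each program computes one BLOCK_M x BLOCK_N tile of silu(x @ Wg) * (x @ Wu).
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)

    acc_gate = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    acc_up = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    for k in range(0, K, BLOCK_K):
        # The input tile is loaded once and reused by both projections.
        x_tile = tl.load(
            x_ptr + offs_m[:, None] * stride_xm + (k + offs_k)[None, :] * stride_xk,
            mask=(offs_m[:, None] < M) & ((k + offs_k)[None, :] < K), other=0.0)
        wg_tile = tl.load(
            wg_ptr + (k + offs_k)[:, None] * stride_wk + offs_n[None, :] * stride_wn,
            mask=((k + offs_k)[:, None] < K) & (offs_n[None, :] < N), other=0.0)
        wu_tile = tl.load(
            wu_ptr + (k + offs_k)[:, None] * stride_wk + offs_n[None, :] * stride_wn,
            mask=((k + offs_k)[:, None] < K) & (offs_n[None, :] < N), other=0.0)
        acc_gate += tl.dot(x_tile, wg_tile)
        acc_up += tl.dot(x_tile, wu_tile)

    # SiLU applied in registers; the activated gate tensor is never written to HBM.
    out = acc_gate * tl.sigmoid(acc_gate) * acc_up
    tl.store(
        out_ptr + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on,
        out.to(tl.float16),
        mask=(offs_m[:, None] < M) & (offs_n[None, :] < N))


# Example launch at Mixtral-like sizes (32 tokens, hidden 4096, intermediate 14336).
M, K, N = 32, 4096, 14336
x = torch.randn(M, K, device="cuda", dtype=torch.float16)
wg = torch.randn(K, N, device="cuda", dtype=torch.float16)
wu = torch.randn(K, N, device="cuda", dtype=torch.float16)
out = torch.empty(M, N, device="cuda", dtype=torch.float16)
grid = (triton.cdiv(M, 64), triton.cdiv(N, 64))
fused_gate_up_silu[grid](x, wg, wu, out, M, N, K,
                         x.stride(0), x.stride(1),
                         wg.stride(0), wg.stride(1),
                         out.stride(0), out.stride(1),
                         BLOCK_M=64, BLOCK_N=64, BLOCK_K=32)
```

The grouped-GEMM scheduling can likewise be sketched on the host side. The helper below and its map layout are assumptions, not the repository's API: tokens are sorted by their routed expert, and a per-tile map records which expert's weights each GEMM tile should read, so one kernel launch covers every expert and only each expert's final partial tile needs masking rather than padding all batches to a common size.

```python
import torch

def build_dispatch_map(expert_ids: torch.Tensor, num_experts: int, block_m: int):
    """Hypothetical helper: precompute the schedule for a block-scheduled grouped GEMM.

    expert_ids: (num_tokens,) expert index chosen by the router for each token.
    Returns the gather order that groups tokens by expert, the per-expert token
    counts, and a block-to-expert map telling each GEMM tile whose weights to load.
    """
    sorted_token_ids = torch.argsort(expert_ids)          # tokens grouped by expert
    counts = torch.bincount(expert_ids, minlength=num_experts)
    tiles_per_expert = (counts + block_m - 1) // block_m  # partial tiles are masked, not padded
    block_to_expert = torch.repeat_interleave(
        torch.arange(num_experts, device=expert_ids.device), tiles_per_expert)
    return sorted_token_ids, counts, block_to_expert
```

Inside the kernel, tile `i` would look up `block_to_expert[i]`, point at that expert's gate and up weights, and process its slice of the sorted token indices.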
Remarkably, the kernel demonstrates true hardware portability. While benchmarked on NVIDIA's A100, the same unmodified Triton code successfully passes a full test suite on AMD's competing MI300X accelerator. It has also been validated on other major MoE architectures like DeepSeek-V3 (with 256 experts) and Qwen2-MoE. This work proves that high-level languages like Triton can not only match but surpass meticulously crafted low-level CUDA in specific, performance-critical domains, potentially lowering the barrier to efficient AI inference across hardware platforms.
- Runs at 131% of the speed of Stanford's CUDA Megablocks at a 32-token batch size for Mixtral-8x7B inference, and at 124% at 128 tokens.
- Fuses the gate and up-projection GEMMs to eliminate roughly 470MB of intermediate buffers per forward pass, reducing memory traffic by 35%.
- Runs on both NVIDIA and AMD GPUs (A100 & MI300X) with zero code changes, ensuring hardware portability.
Why It Matters
Enables faster, more efficient inference for MoE models like Mixtral, reducing latency for end-users and lowering dependence on any single GPU vendor.