trunk/459f2f93092bce2b7fe05020ccb689c24cc8cda8: Vectorized scatter_add with TMA bulk reduce on sm_90+ (#182675)
New kernel cuts index_add latency from 84ms to 18ms on GB200.
A new PyTorch PR (#182675) optimizes scatter_add along dim=0 for contiguous tensors with expanded indices, targeting NVIDIA sm_90+ (Hopper/Blackwell) hardware. The existing implementation suffered from serial atomicAdd stalls under high contention, e.g. 3.89M indices targeting a single row. The new kernel leverages NVIDIA's TMA (Tensor Memory Accelerator) via cp.reduce.async.bulk to offload accumulation directly to hardware, bypassing the GPU's warp-serial atomic bottleneck. It uses warp-per-slice scheduling, TMA shared-memory loads, double-buffering, and mbarrier-based completion. Benchmarks on GB200 (sm_100) show up to a 4.7x speedup over index_add at D=128 bf16 under high contention, and 2x–3.4x on uniform workloads. Supported dtypes are float32, float64, float16, and bfloat16.
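The heart of the fast path is a single PTX instruction that asks the TMA engine to reduce-add a shared-memory tile into global memory. Below is a minimal sketch of what such a wrapper looks like; the function name and surrounding logic are illustrative assumptions, not the PR's actual code. Note that the reduce direction (shared to global) completes via bulk-groups, while the mbarrier-based completion mentioned above applies to the TMA loads into shared memory (not shown).

```cuda
#include <cstdint>

// Hypothetical wrapper (not PyTorch's internal helper): ask the TMA engine to
// reduce-add `bytes` of f32 data from a shared-memory tile into global memory.
// Assumes sm_90+, 16-byte-aligned pointers, and bytes % 16 == 0.
__device__ inline void tma_bulk_reduce_add_f32(float* gmem_dst,
                                               const float* smem_src,
                                               uint32_t bytes) {
#if __CUDA_ARCH__ >= 900
    // Convert the generic shared-memory pointer to a 32-bit shared address.
    uint32_t smem_addr = static_cast<uint32_t>(
        __cvta_generic_to_shared(const_cast<float*>(smem_src)));
    asm volatile(
        "cp.reduce.async.bulk.global.shared::cta.bulk_group.add.f32 "
        "[%0], [%1], %2;"
        :: "l"(gmem_dst), "r"(smem_addr), "r"(bytes)
        : "memory");
    // The reduce is asynchronous: commit it to a bulk-group, then wait for the
    // shared-memory reads to drain before reusing the tile. Double-buffering
    // lets the next tile be filled while this one is being consumed.
    asm volatile("cp.async.bulk.commit_group;");
    asm volatile("cp.async.bulk.wait_group.read 0;" ::: "memory");
#endif
}
```

Because the hardware performs the accumulation, no atomicAdd is issued by the SM at all for the fast path, which is what removes the warp-serial stalls under contention.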
For pre-sm_90 GPUs, the PR provides a vectorized fallback using ld_vec<16> and per-element atomicAdd with warp-per-slice packing, maintaining performance without TMA (see the sketch after this paragraph). ROCm is excluded from the fast path to preserve its existing optimized warp-level atomic coalescing on gfx942/gfx950. This contribution, authored with Claude, is especially valuable for training and inference in recommendation systems and graph neural networks, where scatter-reduce operations are a frequent bottleneck. The PR also corrects an operator-precedence bug in the existing eligibility-check logic. The clean separation between the TMA and fallback paths ensures broad deployment across NVIDIA GPU generations.
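The fallback keeps the same warp-per-slice decomposition but accumulates with plain atomics. A minimal sketch of that shape, using illustrative names rather than the PR's actual kernel, and assuming D is divisible by 4 with 16-byte-aligned rows:

```cuda
#include <cstdint>

// Hypothetical sketch of the pre-sm_90 fallback idea: one warp owns one index
// slice; lanes load 16-byte vectors and apply per-element atomicAdd.
// Assumes D % 4 == 0 and 16-byte-aligned source/destination rows.
__global__ void scatter_add_dim0_fallback(float* __restrict__ dst,
                                          const float* __restrict__ src,
                                          const int64_t* __restrict__ index,
                                          int64_t n_index, int64_t D) {
    const int64_t warp = (blockIdx.x * (int64_t)blockDim.x + threadIdx.x) / 32;
    const int lane = threadIdx.x % 32;
    if (warp >= n_index) return;

    const float4* src_row = reinterpret_cast<const float4*>(src + warp * D);
    float* dst_row = dst + index[warp] * D;   // destination row for this slice

    // Lanes stride across the slice in 4-float (16-byte) chunks.
    for (int64_t j = lane; j < D / 4; j += 32) {
        float4 v = src_row[j];                // vectorized 16-byte load
        atomicAdd(dst_row + 4 * j + 0, v.x);
        atomicAdd(dst_row + 4 * j + 1, v.y);
        atomicAdd(dst_row + 4 * j + 2, v.z);
        atomicAdd(dst_row + 4 * j + 3, v.w);
    }
}
```

Packing a whole slice into one warp keeps the vectorized loads coalesced and spreads the atomics across lanes, which is why this path still performs reasonably without TMA.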
- 4.5x–4.7x speedup on GB200 under high contention (D=128 bf16/f32).
- Uses NVIDIA TMA cp.reduce.async.bulk to eliminate serial atomicAdd stalls.
- Supports float32, float64, float16, and bfloat16; pre-sm_90 fallback included.
Why It Matters
Faster scatter operations accelerate recommendation models and graph neural networks, reducing training time on Hopper and Blackwell GPUs.