[P] I replaced Dot-Product Attention with distance-based RBF-Attention (so you don't have to...)
Swapping dot-product for distance metrics broke everything from memory to RoPE, requiring custom kernels.
A deep dive into transformer architecture reveals a fundamental quirk: in standard dot-product attention, a key vector with a massive norm can 'bully' the softmax, overpowering better-aligned but smaller keys. To fix this, a researcher replaced the dot product with a Radial Basis Function (RBF) kernel, a distance-based score of the form exp(-γ‖q − k‖²) under which vectors must be genuinely close in high-dimensional space to score highly. This simple conceptual change, however, triggered a cascade of technical failures, exposing how deeply the dot product is hardcoded into the entire machine learning stack.
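To make the contrast concrete, here is a minimal PyTorch sketch (not the author's implementation; `gamma` is an assumed bandwidth hyperparameter) of the two scoring rules:

```python
import torch
import torch.nn.functional as F

def dot_product_attention(q, k, v):
    # Standard scaled dot-product: a key with a huge norm can dominate every softmax row.
    scores = q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5
    return F.softmax(scores, dim=-1) @ v

def rbf_attention_naive(q, k, v, gamma=1.0):
    # RBF score exp(-gamma * ||q - k||^2): a key scores highly only if it is genuinely
    # close to the query; inflating its magnitude pushes it *away* from every query.
    # torch.cdist materializes the full N x N distance matrix -- the OOM problem below.
    sq_dists = torch.cdist(q, k) ** 2          # (N_q, N_k)
    return F.softmax(-gamma * sq_dists, dim=-1) @ v
```

Scaling one key by 100x sharply increases its share of attention in the first function and drives it to essentially zero in the second.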
The first major hurdle was memory. Naively computing pairwise Euclidean distances materializes a full N x N distance matrix (something fused SDPA/FlashAttention kernels are specifically designed to avoid), causing instant Out-Of-Memory (OOM) errors. The fix was an algebraic reformulation: RBF attention is mathematically equivalent to dot-product attention with a built-in L2 penalty on key norms. Exploiting that equivalence meant bypassing PyTorch's optimized SDPA and writing a custom, memory-efficient Triton kernel. The change also broke established mechanisms like 'attention sinks,' where models use large-magnitude tokens (e.g., <BOS>) as dumping grounds for attention: in Euclidean space, a key with a huge norm is far from every query and attracts almost no weight, so the researcher introduced learnable 'register tokens' initialized at the origin to serve as safe sinks.
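A rough sketch of that reformulation (assumed from the description above, not the author's Triton kernel): expanding ||q − k||² = ||q||² − 2q·k + ||k||², the ||q||² term is constant along each softmax row and cancels, so the RBF logits reduce to a scaled dot product minus a per-key γ||k||² penalty.

```python
import torch
import torch.nn.functional as F

def rbf_attention_reformulated(q, k, v, gamma=1.0):
    # -gamma * ||q - k||^2 = -gamma*||q||^2 + 2*gamma*(q . k) - gamma*||k||^2.
    # The ||q||^2 term is identical for every key in a row, so the softmax cancels it;
    # what remains is dot-product attention with an L2 penalty on key norms.
    logits = 2.0 * gamma * (q @ k.transpose(-2, -1))   # dot-product term, (N_q, N_k)
    logits = logits - gamma * (k * k).sum(-1)          # per-key penalty, broadcast over rows
    return F.softmax(logits, dim=-1) @ v

# The 'register tokens' from the article would be extra learnable keys/values,
# initialized at the origin (close to everything), prepended to k and v as safe sinks.

# Sanity check against the naive distance-based version.
q, k, v = (torch.randn(8, 16) for _ in range(3))
naive = F.softmax(-torch.cdist(q, k) ** 2, dim=-1) @ v
assert torch.allclose(rbf_attention_reformulated(q, k, v), naive, atol=1e-4)
```

In this form the score is a matmul plus an additive per-key bias, which a fused, FlashAttention-style Triton kernel can compute tile by tile without ever materializing the full score matrix.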
Finally, the modification made popular positional encoding schemes like Rotary Position Embedding (RoPE) a poor fit: RoPE's rotations are engineered around the dot product, making the query-key score depend only on relative position, a construction that does not transfer cleanly to a pure distance metric. This experiment underscores the immense engineering inertia behind core AI components and highlights the non-trivial trade-offs, such as losing magnitude-based signaling, involved in altering foundational operations.
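For reference, the relative-position identity that RoPE is built on (standard RoFormer notation, with $R_m$ the block-diagonal rotation assigned to position $m$) is phrased entirely in terms of the dot product:

$$
\langle R_m q,\; R_n k \rangle \;=\; q^{\top} R_m^{\top} R_n\, k \;=\; q^{\top} R_{\,n-m}\, k
$$

Each 2x2 block of $R_m$ rotates by an angle $m\theta_i$, so the composition $R_m^{\top} R_n$ rotates by $(n-m)\theta_i$ and the score depends only on the relative offset $n - m$.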
- Replacing dot-product with RBF kernel prevents key vectors from dominating attention via magnitude alone, requiring genuine proximity.
- Implementation required a custom Triton kernel to avoid OOM errors and introduced 'register tokens' to replace attention sinks.
- The change breaks assumptions in modern architectures, making techniques like RoPE irrelevant and revealing deep dependencies on the dot product.
Why It Matters
Challenges core transformer assumptions, showing the trade-offs and engineering lock-in behind foundational AI components like attention.