mstar.utils.adarms_norm
Fused AdaRMS normalisation + scale/shift/gate Triton kernel.
Replaces the three-step sequence in Pi05AdaRMSNorm.forward():
- _rms_normalize(x): two passes over x (cast, square, mean, rsqrt, mul)
- modulation.chunk(3, dim=-1): returns views, so no extra pass over memory
- normed * (1 + scale) + shift: two more passes over x-sized tensors
with a single pass:
- Load the x row, compute the RMS in float32, and normalise.
- Load the (scale, shift, gate) row from modulation and apply the conditioning.
- Store the normed output and the gate.
Falls back to the original eager path on CPU or when Triton is not available.
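A minimal NumPy sketch of the two orderings above, for illustration only. eager_adarms and fused_rowwise are hypothetical names, and NumPy arrays stand in for torch Tensors. The sketch assumes three things: _rms_normalize computes x * rsqrt(mean(x**2) + eps) in float32; modulation has shape [BS, 3H] and splits as (scale, shift, gate); and modulation row b conditions the AH consecutive rows b*AH to b*AH + AH - 1 of x. The Python loop only mirrors the one-row-per-program structure of a Triton kernel. It is not a performance model.

```python
import numpy as np

def eager_adarms(x, modulation, ah, eps=1e-6):
    """Eager three-step path: normalise, chunk, then scale/shift (assumed semantics)."""
    # Step 1: RMS-normalise in float32 (one reduction pass plus one elementwise pass).
    xf = x.astype(np.float32)
    normed = xf / np.sqrt(np.mean(xf * xf, axis=-1, keepdims=True) + eps)
    # Step 2: split [BS, 3H] modulation into scale, shift, gate views (no data pass).
    scale, shift, gate = np.split(modulation, 3, axis=-1)
    # Assumption: modulation row b applies to x rows b*ah .. b*ah + ah - 1.
    scale, shift, gate = (np.repeat(t, ah, axis=0) for t in (scale, shift, gate))
    # Step 3: conditioning, further passes over x-sized tensors.
    return normed * (1 + scale) + shift, gate

def fused_rowwise(x, modulation, ah, eps=1e-6):
    """Single pass: each loop iteration plays the role of one kernel program (one row)."""
    h = x.shape[-1]
    normed = np.empty(x.shape, dtype=np.float32)
    gate_out = np.empty(x.shape, dtype=modulation.dtype)
    for row in range(x.shape[0]):
        b = row // ah                       # modulation row for this x row (assumed layout)
        xr = x[row].astype(np.float32)      # load the x row once, upcast to float32
        inv_rms = 1.0 / np.sqrt(np.mean(xr * xr) + eps)
        scale, shift, gate = modulation[b, :h], modulation[b, h:2 * h], modulation[b, 2 * h:]
        # Normalise and condition in one step (held in registers in a real kernel).
        normed[row] = xr * inv_rms * (1 + scale) + shift
        gate_out[row] = gate
    return normed, gate_out

bs, ah, h = 2, 3, 4
rng = np.random.default_rng(0)
x = rng.standard_normal((bs * ah, h)).astype(np.float32)
modulation = rng.standard_normal((bs, 3 * h)).astype(np.float32)
n_eager, g_eager = eager_adarms(x, modulation, ah)
n_fused, g_fused = fused_rowwise(x, modulation, ah)
print(np.allclose(n_eager, n_fused, atol=1e-5), np.array_equal(g_eager, g_fused))  # True True
```

Both orderings compute the same values. Per the description above, the fused kernel's benefit is fewer passes over memory, not different arithmetic.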
Functions
- mstar.utils.adarms_norm.adarms_norm_fused(x, scale, shift, gate_mod, eps=1e-6)
Fused AdaRMS norm: RMS-normalise x then apply scale/shift conditioning.
- Parameters:
x (Tensor) – floating-point tensor of shape [BS * AH, H], in any floating dtype.
scale (Tensor) – shape [BS, H]; the normalised x is multiplied by (1 + scale).
shift (Tensor) – shape [BS, H]; additive shift applied after the norm.
gate_mod (Tensor) – shape [BS, H]; gate vector, not applied here but passed through to the gate output.
eps (float) – variance epsilon for numerical stability.
- Returns:
(normed, gate), both of shape [BS * AH, H] and the same dtype as x.
- Return type:
tuple[Tensor, Tensor]
Falls back to an eager implementation on CPU or when Triton is not available.
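To sanity-check shapes and dtypes against the contract above, the sketch below defines a NumPy stand-in that mirrors the documented signature. adarms_norm_reference is hypothetical and is not part of mstar. It assumes x is a [BS, AH, H] tensor flattened row-major to [BS * AH, H], computes in float32 as the module description states, and casts both outputs back to x's dtype.

```python
import numpy as np

def adarms_norm_reference(x, scale, shift, gate_mod, eps=1e-6):
    """Hypothetical NumPy stand-in mirroring adarms_norm_fused's documented contract."""
    rows, h = x.shape
    bs = scale.shape[0]
    ah = rows // bs  # assumption: x is [BS, AH, H] flattened to [BS * AH, H]
    xf = x.astype(np.float32).reshape(bs, ah, h)
    inv_rms = 1.0 / np.sqrt(np.mean(xf * xf, axis=-1, keepdims=True) + eps)
    # Each [H] modulation row broadcasts over the AH axis.
    s = scale.astype(np.float32)[:, None, :]
    t = shift.astype(np.float32)[:, None, :]
    normed = xf * inv_rms * (1 + s) + t
    # The gate values are passed through untouched, only broadcast to x's row count.
    gate = np.broadcast_to(gate_mod[:, None, :], (bs, ah, h))
    return normed.reshape(rows, h).astype(x.dtype), gate.reshape(rows, h).astype(x.dtype)

bs, ah, h = 2, 3, 8
rng = np.random.default_rng(0)
x = rng.standard_normal((bs * ah, h)).astype(np.float16)
scale, shift, gate_mod = (rng.standard_normal((bs, h)).astype(np.float32) for _ in range(3))
normed, gate = adarms_norm_reference(x, scale, shift, gate_mod)
print(normed.shape, normed.dtype, gate.dtype)  # (6, 8) float16 float16
```

Reshaping to [BS, AH, H] lets each scale and shift row broadcast over the AH axis without materialising repeated copies.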