Gemma 3N — FFN Gaussian Top-K Sparsity¶
Gemma 3N’s early FFN layers use Gaussian Top-K Sparsity to zero out 95 % of the gate activations. Because the mechanism replaces a sort with two reductions, it maps cleanly to the pccx v002 SFU.
1. What the Rule Does¶
For layers 0 through 9 (10 layers total), the gate-projection output is thresholded so that only the top 5 % of activations survive:

\[
\mathrm{cutoff} = \mathrm{Mean}(gate\_raw) + 1.6448536 \cdot \mathrm{Std}(gate\_raw),
\qquad
sparse\_gate = \max(gate\_raw - \mathrm{cutoff},\ 0)
\]
The constant 1.6448536 is the Z-score of the 95th percentile of a
standard normal distribution; the rule assumes the gate outputs are
approximately Gaussian, which holds empirically for FFN pre-activations.
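The rule is a few lines of NumPy (a minimal reference sketch, not the kernel implementation; the function name is illustrative):

```python
import numpy as np

Z_95 = 1.6448536  # 95th-percentile Z-score of a standard normal

def gaussian_topk_gate(gate_raw: np.ndarray) -> np.ndarray:
    """Keep roughly the top 5 % of activations, zero the rest."""
    cutoff = gate_raw.mean() + Z_95 * gate_raw.std()   # two reductions, no sort
    return np.maximum(gate_raw - cutoff, 0.0)          # shift-and-clip at the cutoff

rng = np.random.default_rng(0)
gate_raw = rng.standard_normal(16384).astype(np.float32)
sparse = gaussian_topk_gate(gate_raw)
density = float((sparse > 0).mean())                   # close to 0.05 for Gaussian inputs
```

Note that the surviving activations are measured relative to the cutoff (a shift-and-clip), not passed through unchanged.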
Layers 10–34 skip this step and use the standard GELU gate:

\[
\mathrm{hidden} = \mathrm{GELU}(gate\_raw) \odot up\_out
\]
2. Why It’s Done This Way¶
- **Sorting is expensive on parallel hardware.** A real “top 5 %” requires a sort over 16 384 elements, which is not a good shape for an NPU.
- **Reductions are cheap.** `Mean` and `Std` are two `CVO_REDUCE_SUM` calls on the same buffer, plus a pair of `CVO_SCALE` calls. Total cost: a handful of cycles, negligible compared to a 16 384 × 2048 GEMV.
- **A Gaussian approximation is accurate enough.** The 5 % boundary isn’t sharp; small misclassifications at the edge don’t hurt downstream accuracy.
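The accuracy claim is easy to check: on Gaussian data the two-reduction cutoff lands very close to the exact 95th-percentile boundary a sort would produce (NumPy sketch; `np.partition` stands in for the sort):

```python
import numpy as np

rng = np.random.default_rng(1)
gate_raw = rng.standard_normal(16384)

# Two-reduction route: Mean and Std, then a scalar cutoff.
approx_cutoff = gate_raw.mean() + 1.6448536 * gate_raw.std()

# Exact route: a partial sort to find the true 95th-percentile boundary.
k = int(0.95 * gate_raw.size)
exact_cutoff = np.partition(gate_raw, k)[k]

# The two masks disagree only for elements that fall between the two cutoffs.
disagreement = float(np.mean((gate_raw > approx_cutoff) != (gate_raw > exact_cutoff)))
```

On a 16 384-element Gaussian draw the two cutoffs differ by a few hundredths, so only a small fraction of a percent of elements are classified differently.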
3. Pipeline View¶
```mermaid
flowchart TD
  X["Input x_n2"] --> GateOp(("x · W_gate^T"))
  GateOp --> GateOut["gate_raw (16384)"]
  GateOut --> CalcStats["Reduce-sum × 2<br/>→ Mean and Std"]
  CalcStats --> Threshold["cutoff = Mean + 1.6448·Std"]
  GateOut --> FilterOp
  Threshold --> FilterOp["max(gate_raw - cutoff, 0)"]
  FilterOp --> SparseGate["sparse_gate<br/>(~95% zeros)"]
  SparseGate --> GELU["GELU"]
  GELU --> MultOp(("element-wise ×"))
  X --> UpOp(("x · W_up^T"))
  UpOp --> UpOut["up_out (16384)"]
  UpOut --> MultOp
  MultOp --> FFN_Mid["hidden → W_down"]
```
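The flowchart maps to a few lines of NumPy (a reference sketch with random stand-in weights and scaled-down shapes; the real layer is a 16 384 × 2048 GEMV, and the SFU's GELU is approximated here by the tanh form):

```python
import numpy as np

# Scaled-down shapes for illustration; the real layer uses d_model = 2048, d_ff = 16384.
d_model, d_ff = 256, 2048
rng = np.random.default_rng(2)
W_gate = rng.standard_normal((d_ff, d_model), dtype=np.float32) * 0.02
W_up   = rng.standard_normal((d_ff, d_model), dtype=np.float32) * 0.02
W_down = rng.standard_normal((d_model, d_ff), dtype=np.float32) * 0.02
x = rng.standard_normal(d_model, dtype=np.float32)

gate_raw = W_gate @ x                                    # x · W_gate^T
cutoff = gate_raw.mean() + 1.6448536 * gate_raw.std()    # reduce-sum × 2 -> cutoff
sparse_gate = np.maximum(gate_raw - cutoff, 0.0)         # ~95 % zeros
act = 0.5 * sparse_gate * (1.0 + np.tanh(                # tanh-approx GELU,
    0.7978845608 * (sparse_gate + 0.044715 * sparse_gate ** 3)))
up_out = W_up @ x                                        # x · W_up^T
hidden = act * up_out                                    # element-wise ×
y = W_down @ hidden                                      # hidden -> W_down
density = float((sparse_gate > 0).mean())
```

Because `gate_raw` is a sum of many independent products, it is close to Gaussian even with random weights, so the measured `density` comes out near 5 %.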
4. Mapping onto pccx v002¶
| Operation | Instruction | Notes |
|---|---|---|
| \(\mathrm{Mean}(gate\_raw)\) | `CVO_REDUCE_SUM` + `CVO_SCALE` | Scale preloaded via |
| \(\mathrm{Var}(gate\_raw)\) | `CVO_REDUCE_SUM` + `CVO_SCALE` (×2) | Four CVO calls. Can overlap with the `W_up` GEMV. |
| \(\mathrm{cutoff} = \mathrm{Mean} + 1.6448536 \cdot \mathrm{Std}\) | Scalar compute done in the SFU’s ALU | Single cycle. |
| \(\max(gate\_raw - \mathrm{cutoff},\ 0)\) | Custom CVO fused with GELU? No — currently realized as the SFU GELU’s bias-and-clip front end (bias = −cutoff, clip at 0) | The SFU’s GELU path includes a configurable bias-and-clip front end for exactly this reason. |
| Elementwise multiply with `up_out` | Direct FIFO inside the GEMV / SFU pair | No L2 round-trip. |
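The first three table rows are plain arithmetic; a sketch of the decomposition (the numerics only, not the CVO encoding), assuming Var is formed from the raw second moment:

```python
import numpy as np

rng = np.random.default_rng(3)
gate_raw = rng.standard_normal(16384)
n = gate_raw.size

# Row 1: Mean — one reduce-sum plus one scale by 1/n.
s1 = gate_raw.sum()
mean = s1 * (1.0 / n)

# Row 2: Var — a second reduce-sum (over squares) plus a scale,
# combined with the mean: Var = E[x^2] - Mean^2.
s2 = (gate_raw * gate_raw).sum()
var = s2 * (1.0 / n) - mean * mean

# Row 3: scalar cutoff, a single multiply-add in the SFU's ALU.
cutoff = mean + 1.6448536 * np.sqrt(var)
```

Both reduce-sums read the same buffer, which is why the whole statistics pass costs only a handful of cycles next to the GEMVs.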
5. Throughput Impact¶
With 95 % of `sparse_gate` equal to zero, the downstream
`W_down` GEMV can skip masked rows entirely. The pccx v002 driver
emits a sparse GEMV variant for layers 0–9: the weight streamer
checks each row’s mask bit and skips the DSP cycles for rows whose
bit is zero.
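What the sparse variant computes can be sketched as follows (a NumPy stand-in for the weight streamer, with small shapes for illustration; the mask-driven skip is the point, not the loop structure):

```python
import numpy as np

rng = np.random.default_rng(4)
d_model, d_ff = 64, 512                      # toy shapes
W_down = rng.standard_normal((d_model, d_ff))
hidden = rng.standard_normal(d_ff)
hidden[rng.random(d_ff) < 0.95] = 0.0        # ~95 % zeros, as in layers 0-9

mask = hidden != 0
# Streamer view: weight rows whose mask bit is zero are never fetched;
# only ~5 % of the stored rows (columns of W_down here) contribute.
y_sparse = W_down[:, mask] @ hidden[mask]

y_dense = W_down @ hidden                    # reference dense GEMV
```

Skipping masked rows changes nothing numerically: the sparse and dense results agree to floating-point rounding.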
| Layer range | Gate density | FFN GMAC/s (effective) | Notes |
|---|---|---|---|
| Layers 0–9 | ~5 % | ~40 GMAC/s per token | Dominated by the dense `W_gate` / `W_up` GEMVs. |
| Layers 10–34 | 100 % | ~130 GMAC/s per token | Full dense FFN. |
Note
The sparsity mask is computed on the fly from the gate output of
the current token — it is not learned. The driver must therefore
issue the `MEMCPY` for the skip mask from L2 back into the weight
streamer after the cutoff is known; on pccx v002 this is done with
`MEMCPY` `async=1` overlapped with the `W_up` GEMV.
See also
Full FFN spec: Gemma 3N E4B — Operator-Level Pipeline §4.D.
SFU function codes: Per-Instruction Encoding §4.