Gemma 3N E4B — Operator-Level Pipeline

A complete operator-level specification of Gemma 3N E4B. Every tensor is annotated with its shape and dtype so that each step maps one-to-one onto a pccx v002 instruction (see Gemma 3N E4B on pccx v002 — Execution and Scheduling).

Base dimensions

\(D=2048\), \(D_{ffn}=16384\), \(D_{patch}=256\), \(D_{router}=4\), \(Vocab=262400\).

Attention

8 Q heads / 2 KV heads, \(d_{head}=256\), \(N_{layers}=35\).


0. Core Function Definitions

Embedding — row lookup by integer id.

\[Output = W_{embed}[token\_id,\ :]\]

RMSNorm

\[RMS = \sqrt{\frac{1}{N}\sum_{i=1}^{N}x_i^2 + 10^{-6}}, \qquad Output = \frac{x}{RMS} \times \gamma\]

GELU

\[Output = 0.5\,x\,\bigl(1 + \tanh(\sqrt{2/\pi}\,(x + 0.044715\,x^3))\bigr)\]
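Both primitives reduce to a handful of array ops. A minimal NumPy sketch (the helper names rmsnorm / gelu_tanh are illustrative, not pccx API; later sketches on this page reuse them):

import numpy as np

def rmsnorm(x: np.ndarray, gamma: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    # RMS over the last axis, eps inside the square root as in the formula above.
    rms = np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)
    return x / rms * gamma

def gelu_tanh(x: np.ndarray) -> np.ndarray:
    # tanh approximation of GELU, exactly as specified above.
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))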

RoPE (Rotary Position Embedding)

\[\begin{split}Output_{2i} &= x_{2i}\cos\theta - x_{2i+1}\sin\theta \\ Output_{2i+1} &= x_{2i}\sin\theta + x_{2i+1}\cos\theta\end{split}\]

INT4 Dequantization — unpack two W4 values from a byte:

\[w_0 = W_{packed} \land \mathtt{0x0F}, \quad w_0 \mathrel{-}= 16 \text{ if } w_0 > 7\]
\[w_1 = W_{packed} \gg 4, \quad w_1 \mathrel{-}= 16 \text{ if } w_1 > 7\]
\[Output = [w_0,\ w_1] \times Scale\]
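A minimal NumPy sketch of the unpack rule — low nibble first, two's-complement sign extension, then scale. The scale layout (per-tensor vs. per-channel) is not fixed above, so here it is only assumed to broadcast:

import numpy as np

def dequant_int4(packed: np.ndarray, scale) -> np.ndarray:
    # packed: uint8 bytes; returns float32 with twice as many values.
    w0 = (packed & 0x0F).astype(np.int8)      # low nibble
    w1 = (packed >> 4).astype(np.int8)        # high nibble
    w0 = np.where(w0 > 7, w0 - 16, w0)        # sign-extend W4 into [-8, 7]
    w1 = np.where(w1 > 7, w1 - 16, w1)
    out = np.stack([w0, w1], axis=-1).reshape(*packed.shape[:-1], -1)
    return out.astype(np.float32) * scale     # scale assumed broadcastable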

GQA (Grouped-Query Attention) — groups the 8 Q heads with the 2 KV heads in a 4:1 ratio. For group \(g \in \{0, 1\}\):

\[\begin{split}Q_g &= Q[4g:4g+4,\ :] \in \mathbb{R}^{4 \times d_{head}} \\ \mathrm{scores}_g &= \frac{Q_g \cdot K_g^{T}}{\sqrt{d_{head}}} \in \mathbb{R}^{4 \times L} \\ \mathrm{out}_g &= \mathrm{softmax}(\mathrm{scores}_g)\cdot V_g \in \mathbb{R}^{4 \times d_{head}}\end{split}\]
\[GQA_{out} = \mathrm{concat}(\mathrm{out}_0, \mathrm{out}_1) \in \mathbb{R}^{8 \times d_{head}} \xrightarrow{\text{flatten}} \mathbb{R}^{1 \times D}\]

\(L\) is the current sequence length; \(K_g\), \(V_g\) come from the KV cache.
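A minimal NumPy sketch of the grouping. The max subtraction inside the softmax is a standard numerical-stability detail, not part of the definition above:

import numpy as np

def gqa(q: np.ndarray, k: np.ndarray, v: np.ndarray) -> np.ndarray:
    # q: [8, 256]; k, v: [L, 2, 256] from the KV cache. Returns [1, 2048].
    d_head = q.shape[-1]
    outs = []
    for g in range(2):                               # 2 KV groups, 4 Q heads each
        qg = q[4 * g : 4 * g + 4]                    # [4, 256]
        kg, vg = k[:, g, :], v[:, g, :]              # [L, 256] each
        scores = qg @ kg.T / np.sqrt(d_head)         # [4, L]
        scores -= scores.max(axis=-1, keepdims=True)
        probs = np.exp(scores)
        probs /= probs.sum(axis=-1, keepdims=True)   # softmax over L
        outs.append(probs @ vg)                      # [4, 256]
    return np.concatenate(outs, axis=0).reshape(1, -1)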


1. Token Embedding

\[x_0 = Embedding(token\_id, W_{embed}) \times \sqrt{2048}\]

Note

token_id is clipped to min(token_id, 262143) for the PLE lookup in step 3, since W_ple has only 262144 rows (W_embed has 262400).

W_embed  [262400 × 2048]  INT4
─────────────────────────────────
x₀       [1 × 2048]       float32
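A minimal sketch of this step, assuming W_embed has already been dequantized to float32 (in practice only the looked-up row needs dequantizing):

import numpy as np

D, PLE_VOCAB = 2048, 262144

def embed_token(token_id: int, w_embed: np.ndarray):
    # w_embed: [262400, 2048] float32 (dequantized). Returns x0 and the clipped id.
    x0 = w_embed[token_id : token_id + 1] * np.sqrt(D)   # [1, 2048]
    ple_id = min(token_id, PLE_VOCAB - 1)                # for the W_ple lookup in step 3
    return x0, ple_id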

2. AltUp Initial Projections

Inserts \(x_0\) into row 0 of the stream matrix xs; rows 1–3 are filled by projecting \(x_0\) through the AltUp projection matrices.

\[\begin{split}xs[0] &= x_0 \\ xs[k+1] &= x_0 \cdot altup\_projs[k], \quad k=0,1,2\end{split}\]
altup_projs[k]  [2048 × 2048]  float32  (k = 0, 1, 2)
─────────────────────────────────────────────────────
xs              [4 × 2048]     float32
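A minimal sketch, with the three projection matrices stacked into one [3 × 2048 × 2048] array purely for convenience:

import numpy as np

def altup_init(x0: np.ndarray, altup_projs: np.ndarray) -> np.ndarray:
    # x0: [1, 2048]; altup_projs: [3, 2048, 2048] float32.
    xs = np.empty((4, x0.shape[-1]), dtype=np.float32)
    xs[0] = x0[0]
    for k in range(3):
        xs[k + 1] = x0[0] @ altup_projs[k]   # fill the three shadow streams
    return xs                                # [4, 2048]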

3. PLE Setup (Pre-Compute)

Pre-compute the auxiliary pli_all tensor (one 256-dim vector per layer) once per token so it can be shared across all 35 layers.

Step 1 — Linear projection and normalization.

\[x_{proj} = \frac{x_0 \cdot W_{ple\_proj}^{T}}{\sqrt{2048}} \xrightarrow{\text{reshape}} [35 \times 256]\]
\[x_{proj\_normed} = \frac{x_{proj}}{RMS(x_{proj})} \times norm_{ple}\]
W_ple_proj  [8960 × 2048]  INT4    →  x_proj         [35 × 256]  float32
norm_ple    [256]          float32 →  x_proj_normed  [35 × 256]  float32

Step 2 — Patch embedding extraction (using the clipped token_id from step 1).

\[y = Embedding(token\_id, W_{ple}) \times \sqrt{256} \xrightarrow{\text{reshape}} [35 \times 256]\]
W_ple  [262144 × 8960]  INT4   →  y  [35 × 256]  float32

Step 3 — Composition.

\[pli\_all = (x_{proj\_normed} + y) \times \frac{1}{\sqrt{2}}\]
pli_all  [35 × 256]  float32
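A minimal sketch of the three steps, assuming both INT4 tables are already dequantized to float32; the step-1 RMS is taken per 256-dim row, which the [256] shape of norm_ple implies:

import numpy as np

def ple_setup(x0, ple_id, w_ple_proj, norm_ple, w_ple):
    # w_ple_proj: [8960, 2048]; w_ple: [262144, 8960]; x0: [1, 2048].
    x_proj = (x0 @ w_ple_proj.T / np.sqrt(2048.0)).reshape(35, 256)   # step 1
    rms = np.sqrt(np.mean(x_proj**2, axis=-1, keepdims=True) + 1e-6)
    x_proj_normed = x_proj / rms * norm_ple                           # per-row RMSNorm
    y = (w_ple[ple_id] * np.sqrt(256.0)).reshape(35, 256)             # step 2
    return (x_proj_normed + y) / np.sqrt(2.0)                         # pli_all [35, 256]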

4. Transformer Layer (×35)

Loop index \(i = 0, 1, \ldots, 34\).

4.A AltUp Router & Prediction

Mix the four modality vectors to generate \(xs_{pred}\).

\[\begin{split}x_n &= \frac{RMSNorm(xs[0], W_{altup\_rn})}{2048} \\ modalities &= \tanh(x_n \cdot W_{altup\_router}) \\ coef\_mat &= (W_{altup\_pred} \cdot modalities) \xrightarrow{\text{reshape}} [4 \times 4] \\ xs_{pred} &= xs + coef\_mat \cdot xs\end{split}\]
altup_rn     [2048]      float32  →  x_n         [1 × 2048]  float32
altup_router [2048 × 4]  float32  →  modalities  [4]         float32
altup_pred   [16 × 4]    float32  →  coef_mat    [4 × 4]     float32
──────────────────────────────────────────────────────────────────────
xs_pred      [4 × 2048]  float32
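A minimal sketch, reusing the rmsnorm helper from section 0:

import numpy as np

def altup_predict(xs, w_rn, w_router, w_pred):
    # xs: [4, 2048]; w_router: [2048, 4]; w_pred: [16, 4].
    x_n = rmsnorm(xs[0:1], w_rn) / 2048.0            # [1, 2048]
    modalities = np.tanh(x_n @ w_router)[0]          # [4]
    coef_mat = (w_pred @ modalities).reshape(4, 4)   # [4, 4]
    return xs + coef_mat @ xs                        # xs_pred [4, 2048]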

4.B Attention — Q / K / V Projection

\[\begin{split}x_{input} &= xs_{pred}[0], \quad x_{norm} = RMSNorm(x_{input}, W_{input\_ln}) \\ Q &= x_{norm} \cdot W_q^{T}, \quad K = x_{norm} \cdot W_k^{T}, \quad V = x_{norm} \cdot W_v^{T}\end{split}\]
input_ln  [2048]           float32
W_q       [2048 × 2048]    INT4    →  Q  [1 × 2048]  →  [8 × 256]  float32
W_k       [512  × 2048]    INT4    →  K  [1 × 512]   →  [2 × 256]  float32
W_v       [512  × 2048]    INT4    →  V  [1 × 512]   →  [2 × 256]  float32

4.B-2 Head-wise QK-Norm

RMSNorm per head (256 dims).

\[\begin{split}Q^{head}_i &= \frac{Q^{head}_i}{RMS(Q^{head}_i)} \times \gamma_q, \quad i = 0, \ldots, 7 \\ K^{head}_j &= \frac{K^{head}_j}{RMS(K^{head}_j)} \times \gamma_k, \quad j = 0, 1\end{split}\]
gamma_q  [256]  float32  (shared across the 8 Q heads)
gamma_k  [256]  float32  (shared across the 2 K heads)
──────────────────────────────────────────────────────
Q_norm   [8 × 256]  float32
K_norm   [2 × 256]  float32

4.B-3 Dynamic-θ RoPE

\[\begin{split}\theta = \begin{cases} 1{,}000{,}000 & i \bmod 5 = 4 \\ 10{,}000 & \text{otherwise} \end{cases}\end{split}\]
\[Q_{rope} = RoPE(Q_{norm}, pos, \theta), \quad K_{rope} = RoPE(K_{norm}, pos, \theta)\]
Q_rope  [8 × 256]  float32
K_rope  [2 × 256]  float32

See Gemma 3N — Attention and RoPE Constraints for the theta cycle visualization and why scaling / softcap are removed.
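A minimal sketch of the per-layer θ choice and the rotation. The spec fixes only the pairwise rotation; the θ^(−2i/d) frequency schedule below is the standard RoPE layout and is an assumption here:

import numpy as np

def rope(x: np.ndarray, pos: int, theta: float) -> np.ndarray:
    # x: [heads, 256]; rotates interleaved (2i, 2i+1) pairs as defined in section 0.
    d = x.shape[-1]
    angle = pos * theta ** (-np.arange(0, d, 2) / d)   # [128], standard schedule
    cos, sin = np.cos(angle), np.sin(angle)
    x0, x1 = x[..., 0::2], x[..., 1::2]
    out = np.empty_like(x)
    out[..., 0::2] = x0 * cos - x1 * sin
    out[..., 1::2] = x0 * sin + x1 * cos
    return out

def layer_theta(i: int) -> float:
    return 1_000_000.0 if i % 5 == 4 else 10_000.0     # every 5th layer: long base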

4.B-4 KV Cache Storage and Cross-Layer Sharing

KV cache stores float16 (downcast from float32).

Shape: K_cache[20, max_seq, 512], V_cache[20, max_seq, 512].

For layers that own a cache entry (which then read back their own slice, \(target\_K = K\_cache[i, :pos+1, :]\)):

\[i < 20:\quad K\_cache[i, pos, :] = K_{rope}, \quad V\_cache[i, pos, :] = V\]

For layers that reuse an upstream cache (\(target\_V\) is taken from \(V\_cache\) at the same layer index):

\[\begin{split}i \ge 20:\quad \begin{cases} target\_K = K\_cache[19, :pos+1, :] & i \bmod 5 = 4 \\ target\_K = K\_cache[18, :pos+1, :] & \text{otherwise} \end{cases}\end{split}\]
K_cache   [20 × max_seq × 512]  float16
V_cache   [20 × max_seq × 512]  float16
target_K  [L × 512]  →  [L × 2 × 256]  float16   (L = pos + 1)
target_V  [L × 512]  →  [L × 2 × 256]  float16
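A minimal sketch of the store/read logic described above:

import numpy as np

def kv_for_layer(i, pos, k_rope, v, k_cache, v_cache):
    # k_cache, v_cache: [20, max_seq, 512] float16; k_rope, v: [2, 256] float32.
    if i < 20:                                   # layer owns a cache entry
        k_cache[i, pos] = k_rope.reshape(-1).astype(np.float16)
        v_cache[i, pos] = v.reshape(-1).astype(np.float16)
        src = i
    else:                                        # reuse an upstream entry
        src = 19 if i % 5 == 4 else 18
    target_k = k_cache[src, : pos + 1].reshape(pos + 1, 2, 256)
    target_v = v_cache[src, : pos + 1].reshape(pos + 1, 2, 256)
    return target_k, target_v                    # float16, upcast where needed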

4.B-5 GQA and Output Projection

\[\begin{split}attn\_raw &= GQA(Q_{rope}, target\_K, target\_V) \\ attn\_output &= attn\_raw \cdot W_o^{T}\end{split}\]
scores_g     [4 × L]      float32   (g = 0, 1)
out_g        [4 × 256]    float32
attn_raw     [1 × 2048]   float32
W_o          [2048 × 2048]  INT4
──────────────────────────────────
attn_output  [1 × 2048]   float32

4.C LAuReL + Attention Residual Composition

LAuReL’s residual adds \(x_{norm}\), not the raw input:

\[\begin{split}laurel\_x &= (x_{norm} \cdot W_{laurel\_left}^{T}) \cdot W_{laurel\_right}^{T} \\ laurel\_out &= x_{norm} + RMSNorm(laurel\_x, W_{laurel\_norm})\end{split}\]
\[attn\_output = RMSNorm(attn\_output, W_{post\_attn\_ln}) + x_{input}\]
\[x_{attn} = (attn\_output + laurel\_out) \times \frac{1}{\sqrt{2}}\]
laurel_left   [64   × 2048]  INT4    →  laurel_x (intermediate) [1 × 64]    float32
laurel_right  [2048 × 64  ]  INT4    →  laurel_x                [1 × 2048]  float32
laurel_norm   [2048]         float32
post_attn_ln  [2048]         float32
──────────────────────────────────────────────────────────────────────
x_attn        [1 × 2048]     float32

See Gemma 3N — LAuReL and PLE Calibration Modules for the rationale behind the \(1/\sqrt{2}\) scale.
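A minimal sketch of the full composition, reusing rmsnorm from section 0 and assuming the two LAuReL matrices are already dequantized:

import numpy as np

def attn_residual(x_input, x_norm, attn_output, w_left, w_right, w_lnorm, w_post_ln):
    # w_left: [64, 2048]; w_right: [2048, 64] (float32 after dequant).
    laurel_x = (x_norm @ w_left.T) @ w_right.T          # [1, 64] -> [1, 2048]
    laurel_out = x_norm + rmsnorm(laurel_x, w_lnorm)    # residual off x_norm
    attn = rmsnorm(attn_output, w_post_ln) + x_input    # residual off the raw input
    return (attn + laurel_out) / np.sqrt(2.0)           # x_attn [1, 2048]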

4.D FFN (Gate-Up-Down)

\[\begin{split}x_{n2} &= RMSNorm(x_{attn}, W_{pre\_ffn\_ln}) \\ gate\_raw &= x_{n2} \cdot W_{gate}^{T}, \quad up\_out = x_{n2} \cdot W_{up}^{T}\end{split}\]
pre_ffn_ln  [2048]           float32
W_gate      [16384 × 2048]   INT4    →  gate_raw  [1 × 16384]  float32
W_up        [16384 × 2048]   INT4    →  up_out    [1 × 16384]  float32

For \(i \ge 10\) — standard GELU gate.

\[gate\_out = GELU(gate\_raw), \quad hidden = gate\_out \times up\_out\]

For \(i < 10\) — sparse gate (Gaussian Top-K; see Gemma 3N — FFN Gaussian Top-K Sparsity). The 1.6448536 factor below is the 95th-percentile z-score, so roughly the top 5% of gate activations survive the cutoff.

\[cutoff = Mean(gate\_raw) + Std(gate\_raw) \times 1.6448536\]
\[sparse\_gate = \max(gate\_raw - cutoff, 0), \quad hidden = GELU(sparse\_gate) \times up\_out\]
hidden  [1 × 16384]  float32
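A minimal sketch of the branch, reusing gelu_tanh from section 0. Whether Std is the population or sample estimator is not pinned down above; the sketch uses NumPy's default (population):

import numpy as np

Z95 = 1.6448536  # 95th-percentile z-score

def ffn_hidden(i: int, gate_raw: np.ndarray, up_out: np.ndarray) -> np.ndarray:
    if i < 10:                                    # Gaussian Top-K sparse gate
        cutoff = gate_raw.mean() + gate_raw.std() * Z95
        return gelu_tanh(np.maximum(gate_raw - cutoff, 0.0)) * up_out
    return gelu_tanh(gate_raw) * up_out           # standard GELU gate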

FFN output.

\[mlp\_out = hidden \cdot W_{down}^{T}\]
\[outputs = RMSNorm(mlp\_out, W_{post\_ffn\_ln}) + x_{attn}\]
W_down       [2048 × 16384]  INT4    →  mlp_out  [1 × 2048]  float32
post_ffn_ln  [2048]          float32
───────────────────────────────────────────────────────────────────
outputs      [1 × 2048]      float32

4.E AltUp Correction + PLE Mixing

Step 1 — Scale and innovation.

\[\begin{split}activated &= outputs \times W_{altup\_scale} \\ innovation &= activated - xs_{pred}[0]\end{split}\]
altup_scale  [2048]       float32
activated    [1 × 2048]   float32
innovation   [1 × 2048]   float32

Step 2 — Correction coefficients.

\[\begin{split}x_{n3} &= \frac{RMSNorm(activated, W_{altup\_rn})}{2048} \\ mod\_corr &= \tanh(x_{n3} \cdot W_{altup\_router}) \\ corr\_coefs &= W_{altup\_corr} \cdot mod\_corr + 1.0\end{split}\]
altup_rn      [2048]      float32  →  x_n3        [1 × 2048]  float32
altup_router  [2048 × 4]  float32  →  mod_corr    [4]         float32
altup_corr    [4 × 4]     float32  →  corr_coefs  [4]         float32

Step 3 — Modality update (broadcasting).

\[xs_{new} = xs_{pred} + corr\_coefs_{[4 \times 1]} \times innovation_{[1 \times 2048]}\]
corr_coefs[:, np.newaxis]  [4 × 1]    × innovation  [1 × 2048]  →  [4 × 2048]
xs_new                     [4 × 2048]  float32

Step 4 — PLE mixing (shadow streams only), with \(pli = pli\_all[i]\) from step 3. The \(W_{ple\_proj}\) here is the per-layer [256 × 2048] float32 matrix, distinct from the global [8960 × 2048] INT4 table of step 3.

\[\begin{split}gate\_ple &= GELU(activated \cdot W_{ple\_gate}^{T}) \times pli \\ mapped &= RMSNorm(gate\_ple \cdot W_{ple\_proj}, W_{ple\_post\_ln}) \\ xs_{new}[1:] &\mathrel{+}= mapped\end{split}\]
W_ple_gate   [256  × 2048]  INT4    →  gate_ple (INT4 matmul)  [1 × 256]   float32
pli          [256]          float32 →  gate_ple (× pli)        [1 × 256]   float32
W_ple_proj   [256  × 2048]  float32 →  gate_ple · W_ple_proj   [1 × 2048]  float32
ple_post_ln  [2048]         float32 →  mapped                  [1 × 2048]  float32
\[xs \leftarrow xs_{new} \quad [4 \times 2048]\]
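A minimal sketch of steps 1–4, with the per-layer weights gathered into an illustrative dict w (not the pccx tensor layout) and reusing the section-0 helpers:

import numpy as np

def altup_correct_and_mix(xs_pred, outputs, pli, w):
    activated = outputs * w["altup_scale"]                        # [1, 2048]
    innovation = activated - xs_pred[0:1]                         # [1, 2048]
    x_n3 = rmsnorm(activated, w["altup_rn"]) / 2048.0
    mod_corr = np.tanh(x_n3 @ w["altup_router"])[0]               # [4]
    corr_coefs = w["altup_corr"] @ mod_corr + 1.0                 # [4]
    xs_new = xs_pred + corr_coefs[:, None] * innovation           # [4, 2048]
    gate_ple = gelu_tanh(activated @ w["ple_gate"].T) * pli       # [1, 256]
    mapped = rmsnorm(gate_ple @ w["ple_proj"], w["ple_post_ln"])  # [1, 2048]
    xs_new[1:] += mapped                                          # shadow streams only
    return xs_new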

5. Decode Logits

Step 1 — Magnitude matching and unprojection.

\[target\_mag = \sqrt{Mean(xs[0]^2)}\]
\[proj\_x_k = xs[k+1] \cdot altup\_unprojs[k], \quad k = 0, 1, 2\]
\[new\_mag_k = \sqrt{Mean(proj\_x_k^2)}, \quad proj\_x_k \mathrel{\ast}= \frac{target\_mag}{\max(new\_mag_k,\ 10^{-12})}\]
altup_unprojs[k]  [2048 × 2048]  float32  (k = 0, 1, 2)
proj_x_k          [1 × 2048]     float32

Step 2 — Average and final projection.

\[x_{final} = Mean([xs[0], proj\_x_0, proj\_x_1, proj\_x_2])\]
\[x_{final\_norm} = RMSNorm(x_{final}, W_{final\_norm})\]
\[logits = x_{final\_norm} \cdot W_{lm\_head}^{T}\]
W_final_norm  [2048]           float32
W_lm_head     [262400 × 2048]  INT4
─────────────────────────────────────
x_final_norm  [1 × 2048]       float32
logits        [1 × 262400]     float32

Step 3 — Logit soft-capping.

\[Logits = 30.0 \times \tanh(logits / 30.0)\]
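A minimal sketch of steps 1–3, assuming W_lm_head and the unprojection matrices are already in float32 and reusing rmsnorm from section 0:

import numpy as np

def decode_logits(xs, altup_unprojs, w_final_norm, w_lm_head):
    target_mag = np.sqrt(np.mean(xs[0] ** 2))
    parts = [xs[0]]
    for k in range(3):
        p = xs[k + 1] @ altup_unprojs[k]
        p *= target_mag / max(np.sqrt(np.mean(p**2)), 1e-12)   # magnitude match
        parts.append(p)
    x_final = np.mean(parts, axis=0, keepdims=True)            # [1, 2048]
    logits = rmsnorm(x_final, w_final_norm) @ w_lm_head.T      # [1, 262400]
    return 30.0 * np.tanh(logits / 30.0)                       # soft-cap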

6. Generation and Sampling

Repetition penalty (\(\rho = 1.15\)):

\[\begin{split}Logits_t = \begin{cases} Logits_t \times \rho & Logits_t < 0 \\ Logits_t / \rho & Logits_t \ge 0 \end{cases}\end{split}\]

Temperature softmax. For \(T = 0\): \(next\_token = \arg\max(Logits)\). For \(T > 0\):

\[Logits_{safe} = \frac{Logits}{T} - \max\!\left(\frac{Logits}{T}\right)\]
\[probs_i = \frac{\exp(Logits_{safe, i})}{\sum_j \exp(Logits_{safe, j})}\]

Top-P sampling (\(p = 0.9\)):

\[sorted\_idx = \mathrm{argsort}(probs)[::-1]\]
\[\mathrm{keep} = \{i : \mathrm{cumsum}(probs[sorted\_idx])_i - probs[sorted\_idx_i] < p\}\]
\[probs = \frac{probs_{filtered}}{\sum probs_{filtered}}, \qquad next\_token \sim \mathrm{Categorical}(probs)\]
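A minimal sketch of the whole sampling path. Applying the repetition penalty only to previously generated ids is the usual convention and an assumption here; the default T is illustrative:

import numpy as np

def sample(logits, history, rho=1.15, T=0.7, p=0.9, rng=None):
    rng = np.random.default_rng() if rng is None else rng
    logits = logits[0].copy()                                  # [262400]
    seen = np.array(sorted(set(history)), dtype=np.int64)
    logits[seen] = np.where(logits[seen] < 0, logits[seen] * rho, logits[seen] / rho)
    if T == 0:
        return int(np.argmax(logits))                          # greedy
    z = logits / T
    probs = np.exp(z - z.max())
    probs /= probs.sum()                                       # temperature softmax
    idx = np.argsort(probs)[::-1]
    keep = np.cumsum(probs[idx]) - probs[idx] < p              # top-p nucleus
    filt = np.zeros_like(probs)
    filt[idx[keep]] = probs[idx[keep]]
    filt /= filt.sum()
    return int(rng.choice(len(filt), p=filt))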

Sampling (everything after step 5, Step 3) is done on the host CPU in the pccx reference driver; the NPU returns the logits after the soft-cap.

See also