Gemma 3N E4B — Operator-Level Pipeline¶
Complete operation specification of Gemma 3N E4B. Every tensor has its shape and dtype annotated so that each step can be mapped one-to-one to a pccx v002 instruction (see Gemma 3N E4B on pccx v002 — Execution and Scheduling).
- Base dimensions
\(D=2048\), \(D_{ffn}=16384\), \(D_{patch}=256\), \(D_{router}=4\), \(Vocab=262400\).
- Attention
8 Q heads / 2 KV heads, \(d_{head}=256\), \(N_{layers}=35\).
0. Core Function Definitions¶
Embedding — row lookup by integer id.
RMSNorm
GELU
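RMSNorm and GELU can be sketched in NumPy as follows; the epsilon value and the tanh-approximated GELU variant are assumptions, not stated by this document:

```python
import numpy as np

def rms_norm(x: np.ndarray, gamma: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    # Normalize by the root-mean-square over the last axis, then scale by gamma.
    # eps = 1e-6 is a common default (assumed).
    return x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps) * gamma

def gelu(x: np.ndarray) -> np.ndarray:
    # tanh-approximated GELU; which GELU variant Gemma 3N uses is assumed here.
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))
```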
RoPE (Rotary Position Embedding)
INT4 Dequantization — unpack two W4 values from a byte:
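A minimal sketch of the unpacking; the nibble order (low nibble first) and the symmetric zero-point scheme are assumptions, since pccx may pack and calibrate differently:

```python
import numpy as np

def unpack_int4(packed: np.ndarray) -> np.ndarray:
    # Each byte expands to two 4-bit codes; low nibble first (assumed order).
    lo = packed & 0x0F
    hi = (packed >> 4) & 0x0F
    return np.stack([lo, hi], axis=-1).reshape(*packed.shape[:-1], -1)

def dequantize_int4(packed: np.ndarray, scale: float, zero: int = 8) -> np.ndarray:
    # Hypothetical symmetric scheme: center the 4-bit code on `zero`, then scale.
    return (unpack_int4(packed).astype(np.float32) - zero) * scale
```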
GQA (Grouped-Query Attention) — groups the 8 Q heads with the 2 KV heads in a 4:1 ratio. For group \(g \in \{0, 1\}\):
\(L\) is the current sequence length; \(K_g\), \(V_g\) come from the KV cache.
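Under the shapes above, the grouped attention can be sketched as below. Note the absence of the usual \(1/\sqrt{d}\) scaling, per the QK-norm discussion referenced later in this document:

```python
import numpy as np

def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def gqa(Q: np.ndarray, K: np.ndarray, V: np.ndarray) -> np.ndarray:
    """Q: [8, 256]; K, V: [L, 2, 256] from the KV cache.
    Q heads 4g..4g+3 share KV head g (4:1 grouping).
    No 1/sqrt(d) scaling, per this document's QK-norm note."""
    attn_raw = np.empty_like(Q)
    for g in range(2):                          # the two KV groups
        Qg = Q[4 * g:4 * g + 4]                 # [4, 256]
        scores = Qg @ K[:, g, :].T              # scores_g [4, L]
        attn_raw[4 * g:4 * g + 4] = softmax(scores) @ V[:, g, :]  # out_g
    return attn_raw.reshape(1, -1)              # [1, 2048], before W_o
```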
1. Token Embedding¶
Note
token_id is clipped to min(token_id, 262143) so it stays
within the PLE vocab range.
W_embed [262400 × 2048] INT4
─────────────────────────────────
x₀ [1 × 2048] float32
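As a sketch, with W_embed assumed already dequantized to float32; the \(\sqrt{D}\) embedding scale is standard in the Gemma family but is an assumption here, since this document does not state it:

```python
import numpy as np

D = 2048

def embed(token_id: int, W_embed: np.ndarray) -> np.ndarray:
    # Row lookup; the sqrt(D) scale is a Gemma-family convention (assumed).
    return (W_embed[token_id].astype(np.float32) * np.sqrt(D)).reshape(1, D)

def ple_token_id(token_id: int) -> int:
    # Clip into the 262144-entry PLE table, per the note above.
    return min(token_id, 262143)
```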
2. AltUp Initial Projections¶
Inserts \(x_0\) into row 0 of the stream matrix xs; rows 1–3 are produced by
multiplying \(x_0\) with the three AltUp projection matrices.
altup_projs[k] [2048 × 2048] float32 (k = 0, 1, 2)
─────────────────────────────────────────────────────
xs [4 × 2048] float32
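The construction above can be sketched as:

```python
import numpy as np

def altup_init(x0: np.ndarray, altup_projs) -> np.ndarray:
    """x0: [1, D]; altup_projs: three [D, D] matrices.
    Row 0 carries x0 unchanged; rows 1-3 are its projections."""
    xs = np.empty((4, x0.shape[1]), dtype=np.float32)
    xs[0] = x0[0]
    for k, P in enumerate(altup_projs):
        xs[k + 1] = x0[0] @ P
    return xs
```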
3. PLE Setup (Pre-Compute)¶
Pre-compute the auxiliary pli_all vector once so it can be shared
across all 35 layers.
Step 1 — Linear projection and normalization.
W_ple_proj [8960 × 2048] INT4 → x_proj [35 × 256] float32
norm_ple [256] float32 → x_proj_normed [35 × 256] float32
Step 2 — Patch embedding extraction.
W_ple [262144 × 8960] INT4 → y [35 × 256] float32
Step 3 — Composition.
pli_all [35 × 256] float32
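The three steps can be sketched as below. The document only labels Step 3 "Composition"; the \((\text{normed} + y)/\sqrt{2}\) form follows the public Gemma 3n reference implementation and is an assumption here:

```python
import numpy as np

def rms_norm(x, gamma, eps=1e-6):
    return x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps) * gamma

def ple_setup(x0, token_id, W_ple_proj, norm_ple, W_ple):
    """x0: [1, 2048]; W_ple_proj: [8960, 2048]; W_ple: [262144, 8960].
    8960 = 35 layers x 256 dims, so both paths reshape to [35, 256]."""
    x_proj = (x0 @ W_ple_proj.T).reshape(35, 256)        # Step 1: projection
    x_proj_normed = rms_norm(x_proj, norm_ple)           # Step 1: per-row norm
    y = W_ple[min(token_id, 262143)].reshape(35, 256)    # Step 2: row lookup
    return (x_proj_normed + y) / np.sqrt(2.0)            # Step 3 (assumed form)
```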
4. Transformer Layer (×35)¶
Loop index \(i = 0, 1, \ldots, 34\).
4.A AltUp Router & Prediction¶
Mix the four modality vectors to generate \(xs_{pred}\).
altup_rn [2048] float32 → x_n [1 × 2048] float32
altup_router [2048 × 4] float32 → modalities [4] float32
altup_pred [16 × 4] float32 → coef_mat [4 × 4] float32
──────────────────────────────────────────────────────────────────────
xs_pred [4 × 2048] float32
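A sketch of the mix under the shapes above. The tanh routing and the residual form of the prediction follow the public Gemma 3n reference implementation; whether pccx matches them exactly (including the transpose) is an assumption:

```python
import numpy as np

def rms_norm(x, gamma, eps=1e-6):
    return x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps) * gamma

def altup_predict(xs, altup_rn, altup_router, altup_pred):
    """xs: [4, 2048]; altup_router: [2048, 4]; altup_pred: [16, 4]."""
    x_n = rms_norm(xs[0:1], altup_rn)                    # [1, 2048]
    modalities = np.tanh(x_n @ altup_router)[0]          # [4]
    coef_mat = (altup_pred @ modalities).reshape(4, 4)   # [4, 4]
    return xs + coef_mat.T @ xs                          # xs_pred [4, 2048]
```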
4.B Attention — Q / K / V Projection¶
input_ln [2048] float32
W_q [2048 × 2048] INT4 → Q [1 × 2048] → [8 × 256] float32
W_k [512 × 2048] INT4 → K [1 × 512] → [2 × 256] float32
W_v [512 × 2048] INT4 → V [1 × 512] → [2 × 256] float32
4.B-2 Head-wise QK-Norm¶
RMSNorm per head (256 dims).
gamma_q [256] float32 (shared across the 8 Q heads)
gamma_k [256] float32 (shared across the 2 K heads)
──────────────────────────────────────────────────────
Q_norm [8 × 256] float32
K_norm [2 × 256] float32
4.B-3 Dynamic-θ RoPE¶
Q_rope [8 × 256] float32
K_rope [2 × 256] float32
See Gemma 3N — Attention and RoPE Constraints for the theta cycle visualization and why scaling / softcap are removed.
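The rotation itself can be sketched as below; the interleaved-pair convention and the per-layer theta bases (e.g. 10,000 local vs 1,000,000 global, as in Gemma 3) are assumptions, with the "dynamic θ" details in the linked document:

```python
import numpy as np

def rope(x: np.ndarray, pos: int, theta_base: float) -> np.ndarray:
    """x: [num_heads, 256]. Rotates pairs (x[2i], x[2i+1]); theta_base
    varies per layer, which is what makes the theta 'dynamic'."""
    d = x.shape[-1]
    inv_freq = theta_base ** (-np.arange(0, d, 2) / d)   # [d/2]
    ang = pos * inv_freq
    cos, sin = np.cos(ang), np.sin(ang)
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = np.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out
```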
4.B-4 KV Cache Storage and Cross-Layer Sharing¶
KV cache stores float16 (downcast from float32).
Shape: K_cache[20, max_seq, 512], V_cache[20, max_seq, 512].
For layers that own a cache entry, the freshly computed K/V row is downcast to float16 and written at position pos in that layer's slot. For layers that reuse an upstream cache, nothing is written; they read from the owning layer's slot.
K_cache [20 × max_seq × 512] float16
V_cache [20 × max_seq × 512] float16
target_K [L × 512] → [L × 2 × 256] float16 (L = pos + 1)
target_V [L × 512] → [L × 2 × 256] float16
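Store and load can be sketched as below; which of the 20 slots each of the 35 layers maps to is defined by the cross-layer sharing scheme in the linked KV cache document and is not modeled here:

```python
import numpy as np

MAX_SEQ = 64  # illustrative; the real max_seq is a deployment parameter
K_cache = np.zeros((20, MAX_SEQ, 512), dtype=np.float16)
V_cache = np.zeros((20, MAX_SEQ, 512), dtype=np.float16)

def store_kv(slot: int, pos: int, K: np.ndarray, V: np.ndarray) -> None:
    # Owner layers downcast their float32 K/V [2, 256] and write one row.
    K_cache[slot, pos] = K.reshape(512).astype(np.float16)
    V_cache[slot, pos] = V.reshape(512).astype(np.float16)

def load_kv(slot: int, pos: int):
    # Both owner and reuse layers read L = pos + 1 rows as [L, 2, 256].
    L = pos + 1
    return (K_cache[slot, :L].reshape(L, 2, 256),
            V_cache[slot, :L].reshape(L, 2, 256))
```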
4.B-5 GQA and Output Projection¶
scores_g [4 × L] float32 (g = 0, 1)
out_g [4 × 256] float32
attn_raw [1 × 2048] float32
W_o [2048 × 2048] INT4
──────────────────────────────────
attn_output [1 × 2048] float32
4.C LAuReL + Attention Residual Composition¶
LAuReL’s residual adds \(x_{norm}\), not the raw input:
laurel_left [64 × 2048] INT4 → laurel_x (intermediate) [1 × 64] float32
laurel_right [2048 × 64] INT4 → laurel_x [1 × 2048] float32
laurel_norm [2048] float32
post_attn_ln [2048] float32
──────────────────────────────────────────────────────────────────────
x_attn [1 × 2048] float32
See Gemma 3N — LAuReL and PLE Calibration Modules for the rationale behind the \(1/\sqrt{2}\) scale.
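The composition can be sketched as below. Per the note above, the LAuReL branch adds \(x_{norm}\) rather than the raw input; the placement of post_attn_ln on the attention branch follows the Gemma 3n reference and is an assumption:

```python
import numpy as np

def rms_norm(x, gamma, eps=1e-6):
    return x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps) * gamma

def laurel_residual(x, x_norm, attn_output, laurel_left, laurel_right,
                    laurel_norm, post_attn_ln):
    """x, x_norm, attn_output: [1, 2048].
    laurel_left: [64, 2048]; laurel_right: [2048, 64]."""
    lx = (x_norm @ laurel_left.T) @ laurel_right.T       # [1, 64] -> [1, 2048]
    laurel_out = x_norm + rms_norm(lx, laurel_norm)      # residual adds x_norm
    attn_gated = x + rms_norm(attn_output, post_attn_ln)
    return (attn_gated + laurel_out) / np.sqrt(2.0)      # x_attn [1, 2048]
```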
4.D FFN (Gate-Up-Down)¶
pre_ffn_ln [2048] float32
W_gate [16384 × 2048] INT4 → gate_raw [1 × 16384] float32
W_up [16384 × 2048] INT4 → up_out [1 × 16384] float32
For \(i \ge 10\) — standard GELU gate.
For \(i < 10\) — sparse gate (Gaussian Top-K, see Gemma 3N — FFN Gaussian Top-K Sparsity).
hidden [1 × 16384] float32
FFN output.
W_down [2048 × 16384] INT4 → mlp_out [1 × 2048] float32
post_ffn_ln [2048] float32
───────────────────────────────────────────────────────────────────
outputs [1 × 2048] float32
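The dense path (layers \(i \ge 10\)) can be sketched as below, with weights assumed pre-dequantized to float32; the sparse Gaussian Top-K gate for \(i < 10\) is covered in the linked document:

```python
import numpy as np

def gelu(x):
    # tanh-approximated GELU (the exact variant is an assumption).
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))

def rms_norm(x, gamma, eps=1e-6):
    return x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps) * gamma

def ffn(x_attn, pre_ffn_ln, W_gate, W_up, W_down, post_ffn_ln):
    """x_attn: [1, 2048]; W_gate, W_up: [16384, 2048]; W_down: [2048, 16384]."""
    h_in = rms_norm(x_attn, pre_ffn_ln)                # [1, 2048]
    gate_raw = h_in @ W_gate.T                         # [1, 16384]
    up_out = h_in @ W_up.T                             # [1, 16384]
    hidden = gelu(gate_raw) * up_out                   # [1, 16384]
    mlp_out = hidden @ W_down.T                        # [1, 2048]
    return x_attn + rms_norm(mlp_out, post_ffn_ln)     # outputs [1, 2048]
```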
4.E AltUp Correction + PLE Mixing¶
Step 1 — Scale and innovation.
altup_scale [2048] float32
activated [1 × 2048] float32
innovation [1 × 2048] float32
Step 2 — Correction coefficients.
altup_rn [2048] float32 → x_n3 [1 × 2048] float32
altup_router [2048 × 4] float32 → mod_corr [4] float32
altup_corr [4 × 4] float32 → corr_coefs [4] float32
Step 3 — Modality update (broadcasting).
corr_coefs[:, np.newaxis] [4 × 1] × innovation [1 × 2048] → [4 × 2048]
xs_new [4 × 2048] float32
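Steps 1–3 can be sketched as below; the \(1 + \text{coef}\) form of the correction coefficients follows the Gemma 3n reference implementation and is an assumption here:

```python
import numpy as np

def rms_norm(x, gamma, eps=1e-6):
    return x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps) * gamma

def altup_correct(xs_pred, outputs, altup_scale, altup_rn, altup_router,
                  altup_corr):
    """xs_pred: [4, 2048]; outputs: [1, 2048] (post-FFN)."""
    activated = outputs * altup_scale                   # Step 1: scale
    innovation = activated - xs_pred[0:1]               # Step 1: innovation
    x_n3 = rms_norm(activated, altup_rn)                # Step 2
    mod_corr = np.tanh(x_n3 @ altup_router)[0]          # [4]
    corr_coefs = altup_corr @ mod_corr + 1.0            # [4] (assumed +1 form)
    return xs_pred + corr_coefs[:, np.newaxis] * innovation  # Step 3 broadcast
```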
Step 4 — PLE mixing (shadow-stream only).
W_ple_gate [256 × 2048] INT4 → gate_ple (INT4 matmul) [1 × 256] float32
pli [256] float32 → gate_ple (× pli) [1 × 256] float32
W_ple_proj [256 × 2048] float32 → gate_ple · W_ple_proj [1 × 2048] float32
ple_post_ln [2048] float32 → mapped [1 × 2048] float32
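Step 4 can be sketched as below. Gating from row 0 and adding the mapped result to the shadow rows 1–3 follows the Gemma 3n reference implementation; the document only says "shadow-stream only", so this placement is an assumption:

```python
import numpy as np

def gelu(x):
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))

def rms_norm(x, gamma, eps=1e-6):
    return x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps) * gamma

def ple_mix(xs_new, pli, W_ple_gate, W_ple_proj, ple_post_ln):
    """xs_new: [4, 2048]; pli: [256] (this layer's slice of pli_all).
    W_ple_gate: [256, 2048]; W_ple_proj: [256, 2048]."""
    gate_ple = gelu(xs_new[0:1] @ W_ple_gate.T)              # [1, 256]
    gate_ple = gate_ple * pli                                # elementwise
    mapped = rms_norm(gate_ple @ W_ple_proj, ple_post_ln)    # [1, 2048]
    xs_out = xs_new.copy()
    xs_out[1:] += mapped[0]                                  # shadow rows only
    return xs_out
```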
5. Decode Logits¶
Step 1 — Magnitude matching and unprojection.
altup_unprojs[k] [2048 × 2048] float32 (k = 0, 1, 2)
proj_x_k [1 × 2048] float32
Step 2 — Average and final projection.
W_final_norm [2048] float32
W_lm_head [262400 × 2048] INT4
─────────────────────────────────────
x_final_norm [1 × 2048] float32
logits [1 × 262400] float32
Step 3 — Logit soft-capping.
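The soft-cap squashes logits into a bounded range. The cap value 30.0 is the Gemma family's final-logit setting and is an assumption here, since this document does not state it:

```python
import numpy as np

def softcap(logits: np.ndarray, cap: float = 30.0) -> np.ndarray:
    # Smoothly bounds every logit into (-cap, cap) while preserving ordering.
    return cap * np.tanh(logits / cap)
```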
6. Generation and Sampling¶
Repetition penalty (\(\rho = 1.15\)):
Temperature softmax. For \(T = 0\): \(next\_token = \arg\max(Logits)\). For \(T > 0\):
Top-P sampling (\(p = 0.9\)):
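The three sampling stages above can be sketched together as below; this is the standard decoding recipe (penalize seen tokens, temperature softmax, nucleus truncation), not a verbatim transcription of the pccx driver:

```python
import numpy as np

def sample(logits, prev_tokens, rho=1.15, T=0.7, top_p=0.9, rng=None):
    """logits: post-softcap vector. Positive logits of seen tokens are
    divided by rho, negative ones multiplied (the usual convention)."""
    logits = logits.astype(np.float64).copy()
    for t in set(prev_tokens):                         # repetition penalty
        logits[t] = logits[t] / rho if logits[t] > 0 else logits[t] * rho
    if T == 0:
        return int(np.argmax(logits))                  # greedy decoding
    p = np.exp(logits / T - np.max(logits / T))
    p /= p.sum()                                       # temperature softmax
    order = np.argsort(-p)                             # descending probability
    keep = np.cumsum(p[order]) - p[order] < top_p      # smallest prefix >= p
    p_kept = p[order][keep] / p[order][keep].sum()     # renormalize
    if rng is None:
        rng = np.random.default_rng()
    return int(rng.choice(order[keep], p=p_kept))
```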
Sampling (everything after Section 5, Step 3) is done on the host CPU in the pccx reference driver; the NPU returns the logits after the soft-cap.
See also
Instruction-level mapping: Gemma 3N E4B on pccx v002 — Execution and Scheduling.
KV cache strategy: KV Cache Optimization Strategy.