Gemma 3N E4B — 연산자 수준 파이프라인¶
Gemma 3N E4B 의 전체 연산 사양. 각 텐서에 shape / dtype 을 표기하여, 각 단계가 pccx v002 의 명령어와 1:1 로 매핑되도록 했습니다 (Gemma 3N E4B 를 pccx v002 에서 실행 — Execution / Scheduling 참고).
- 기본 차원
\(D=2048\), \(D_{ffn}=16384\), \(D_{patch}=256\), \(D_{router}=4\), \(Vocab=262400\).
- 어텐션
Q 8 / KV 2, \(d_{head}=256\), \(N_{layers}=35\).
0. 핵심 함수 정의¶
Embedding — 정수 id 로 행 조회.
RMSNorm
GELU
RoPE (Rotary Position Embedding)
INT4 역양자화 — 한 바이트에서 W4 두 개 풀어내기:
GQA (Grouped-Query Attention) — Q 8 헤드와 KV 2 헤드를 4:1 로 묶음. \(g \in \{0, 1\}\) 에 대해:
\(L\) 은 현재 시퀀스 길이, \(K_g\), \(V_g\) 는 KV 캐시에서 가져옵니다.
1. 토큰 임베딩¶
참고
token_id 는 PLE vocab 범위를 초과하지 않도록
min(token_id, 262143) 로 clip 합니다.
W_embed [262400 × 2048] INT4
─────────────────────────────────
x₀ [1 × 2048] float32
2. AltUp 초기 Projection¶
xs 의 0 번 행에 \(x_0\) 를 넣고, 1~3 행은 \(x_0\) 와
AltUp projection 행렬의 내적으로 채웁니다.
altup_projs[k] [2048 × 2048] float32 (k = 0, 1, 2)
─────────────────────────────────────────────────────
xs [4 × 2048] float32
3. PLE 사전 계산¶
보조 벡터 pli_all 을 토큰 진입 시 1 회만 계산해 35 레이어가 공유합니다.
Step 1 — Linear projection & normalization.
W_ple_proj [8960 × 2048] INT4 → x_proj [35 × 256] float32
norm_ple [256] float32 → x_proj_normed [35 × 256] float32
Step 2 — Patch 임베딩 추출.
W_ple [262144 × 8960] INT4 → y [35 × 256] float32
Step 3 — 합성.
pli_all [35 × 256] float32
4. Transformer 레이어 (×35)¶
루프 인덱스 \(i = 0, 1, \ldots, 34\).
4.A AltUp Router & Prediction¶
4 개의 modality 벡터를 혼합해 \(xs_{pred}\) 를 만듭니다.
altup_rn [2048] float32 → x_n [1 × 2048] float32
altup_router [2048 × 4] float32 → modalities [4] float32
altup_pred [16 × 4] float32 → coef_mat [4 × 4] float32
──────────────────────────────────────────────────────────────────────
xs_pred [4 × 2048] float32
4.B Attention — Q / K / V Projection¶
input_ln [2048] float32
W_q [2048 × 2048] INT4 → Q [1 × 2048] → [8 × 256] float32
W_k [512 × 2048] INT4 → K [1 × 512] → [2 × 256] float32
W_v [512 × 2048] INT4 → V [1 × 512] → [2 × 256] float32
4.B-2 Head-wise QK-Norm¶
헤드 단위 RMSNorm (256 차원).
gamma_q [256] float32 (Q 8 헤드에 공유)
gamma_k [256] float32 (K 2 헤드에 공유)
──────────────────────────────────────────
Q_norm [8 × 256] float32
K_norm [2 × 256] float32
4.B-3 동적 θ RoPE¶
Q_rope [8 × 256] float32
K_rope [2 × 256] float32
자세한 θ 주기 시각화 및 scaling / softcap 제거 근거는 Gemma 3N — Attention 및 RoPE 제약 참고.
4.B-4 KV 캐시 저장과 레이어 간 공유¶
KV 캐시는 float32 를 float16 으로 downcast 해서 저장합니다.
Shape: K_cache[20, max_seq, 512], V_cache[20, max_seq, 512].
고유 엔트리를 가지는 레이어:
상위 캐시를 재사용하는 레이어:
K_cache [20 × max_seq × 512] float16
V_cache [20 × max_seq × 512] float16
target_K [L × 512] → [L × 2 × 256] float16 (L = pos + 1)
target_V [L × 512] → [L × 2 × 256] float16
4.B-5 GQA & Output Projection¶
scores_g [4 × L] float32 (g = 0, 1)
out_g [4 × 256] float32
attn_raw [1 × 2048] float32
W_o [2048 × 2048] INT4
──────────────────────────────────
attn_output [1 × 2048] float32
4.C LAuReL + Attention Residual 합성¶
LAuReL 의 residual 에는 원본이 아니라 정규화된 입력 (\(x_{norm}\)) 이 더해집니다:
laurel_left [64 × 2048] INT4 → laurel_x (중간) [1 × 64] float32
laurel_right [2048 × 64 ] INT4 → laurel_x [1 × 2048] float32
laurel_norm [2048] float32
post_attn_ln [2048] float32
──────────────────────────────────────────────────────────────────
x_attn [1 × 2048] float32
\(1/\sqrt{2}\) 스케일의 근거는 Gemma 3N — LAuReL 과 PLE Calibration 모듈 참고.
4.D FFN (Gate-Up-Down)¶
pre_ffn_ln [2048] float32
W_gate [16384 × 2048] INT4 → gate_raw [1 × 16384] float32
W_up [16384 × 2048] INT4 → up_out [1 × 16384] float32
:math:`i ge 10` — 표준 GELU 게이트.
:math:`i < 10` — 희소 게이트 (Gaussian Top-K, :doc:`gemma3n_ffn_sparsity` 참고).
hidden [1 × 16384] float32
FFN 출력.
W_down [2048 × 16384] INT4 → mlp_out [1 × 2048] float32
post_ffn_ln [2048] float32
───────────────────────────────────────────────────────────────────
outputs [1 × 2048] float32
4.E AltUp Correction + PLE Mixing¶
Step 1 — Scale 과 innovation.
altup_scale [2048] float32
activated [1 × 2048] float32
innovation [1 × 2048] float32
Step 2 — 보정 계수.
altup_rn [2048] float32 → x_n3 [1 × 2048] float32
altup_router [2048 × 4] float32 → mod_corr [4] float32
altup_corr [4 × 4] float32 → corr_coefs [4] float32
Step 3 — Modality 업데이트 (broadcasting).
corr_coefs[:, np.newaxis] [4 × 1] × innovation [1 × 2048] → [4 × 2048]
xs_new [4 × 2048] float32
Step 4 — PLE mixing (shadow stream 전용).
W_ple_gate [256 × 2048] INT4 → gate_ple (INT4 matmul) [1 × 256] float32
pli [256] float32 → gate_ple (× pli) [1 × 256] float32
W_ple_proj [256 × 2048] float32 → gate_ple · W_ple_proj [1 × 2048] float32
ple_post_ln [2048] float32 → mapped [1 × 2048] float32
5. Decode Logits¶
Step 1 — Magnitude 매칭과 unprojection.
altup_unprojs[k] [2048 × 2048] float32 (k = 0, 1, 2)
proj_x_k [1 × 2048] float32
Step 2 — 평균과 최종 projection.
W_final_norm [2048] float32
W_lm_head [262400 × 2048] INT4
─────────────────────────────────────
x_final_norm [1 × 2048] float32
logits [1 × 262400] float32
Step 3 — Logit soft-capping.
6. 생성 및 샘플링¶
반복 페널티 (\(\rho = 1.15\)):
Temperature softmax. \(T = 0\) 이면 \(next\_token = \arg\max(Logits)\). \(T > 0\) 이면:
Top-P 샘플링 (\(p = 0.9\)):
샘플링 (§5 Step 3 이후) 은 pccx 레퍼런스 드라이버에서 호스트 CPU 가 처리합니다. NPU 는 soft-cap 후 logits 를 반환합니다.
더 보기
명령어 수준 매핑: Gemma 3N E4B 를 pccx v002 에서 실행 — Execution / Scheduling.
KV 캐시 전략: KV 캐시 최적화 전략.