CVO 코어 (SFU)
Complex Vector Operation 코어는 Transformer 가 요구하는 모든 비선형 활성화 — Softmax, GELU, RMSNorm, RoPE — 를 2 개의 μCVO-core 로 처리합니다. 각 μCVO-core 는 CORDIC 유닛 (회전 기반 sin/cos) 과 SFU (exp / sqrt / GELU / 역수용 LUT + 조각 다항식) 를 동시에 가집니다.
더 보기
- pccx ISA 사양: OP_CVO 인코딩과 함수 코드 테이블 (§3.4).
모듈
- CVO_top.sv — 최상위 오케스트레이터. 디스패처가 보낸 cvo_control_uop_t 를 받아 cvo_func 에 따라 CORDIC 이나 SFU 유닛으로 라우팅하고, sub_emax 와 accm 플래그를 처리합니다.
- CVO_cordic_unit.sv — RoPE 가 필요로 하는 삼각함수 primitive 와 sin/cos 를 위한 CORDIC 회전 파이프라인.
- CVO_sfu_unit.sv — exp, sqrt, GELU, 역수, scale, reduce-sum 을 위한 LUT + 다항식 엔진.
소스
CVO_top.sv
`timescale 1ns / 1ps
`include "GLOBAL_CONST.svh"
import isa_pkg::*;
import bf16_math_pkg::*;
// ===| CVO Top |=================================================================
// Wraps CVO_sfu_unit (EXP/SQRT/GELU/RECIP/SCALE/REDUCE_SUM) and
// CVO_cordic_unit (SIN/COS) behind a unified streaming interface.
//
// Data flow:
//   Host issues OP_CVO via AXI-Lite → Global_Scheduler produces cvo_control_uop_t
//   → CVO_top latches the uop, accepts uop.length BF16 elements from the input
//   stream and forwards the selected unit's results to the output stream.
//
//   FLAG_SUB_EMAX: subtract IN_e_max from each input before the function.
//                  Implements exp(x - e_max) for numerically stable softmax.
//   FLAG_ACCM    : accumulate output into dst (add OUT_result to prior value).
//                  Handled externally by the mem subsystem; CVO_top only
//                  signals it via OUT_accm.
//
// FSM states:
//   IDLE    : waiting for a valid uop (OUT_uop_ready = 1)
//   RUNNING : accepting uop.length input elements
//   DRAIN   : all inputs accepted; waiting until the selected unit has emitted
//             every expected result (or DrainTimeout cycles have passed) so the
//             next uop cannot re-route results that are still in flight
//   DONE    : sets OUT_done (registered one-cycle pulse, visible in the
//             following cycle) and returns to IDLE
//   A uop with length == 0 goes straight from IDLE to DONE.
//
// Expected results per uop (see expected_results):
//   REDUCE_SUM : 1          SCALE : length - 1 (first word is the scalar)
//   EXP/SQRT/GELU/RECIP/SIN/COS : length      anything else : 0
//
// NOTE: the sub-unit pipelines cannot stall. A result produced while
// IN_result_ready is low is dropped (OUT_result_valid stays low for it).
// ===============================================================================
module CVO_top (
    input  logic             clk,
    input  logic             rst_n,
    input  logic             i_clear,
    // ===| Dispatch from Global_Scheduler |=====================================
    input  cvo_control_uop_t IN_uop,
    input  logic             IN_uop_valid,
    output logic             OUT_uop_ready,
    // ===| BF16 Input Stream (from L2 via mem_dispatcher) |=====================
    input  logic [15:0]      IN_data,
    input  logic             IN_data_valid,
    output logic             OUT_data_ready,
    // ===| BF16 Output Stream (to L2 via mem_dispatcher) |======================
    output logic [15:0]      OUT_result,
    output logic             OUT_result_valid,
    input  logic             IN_result_ready,
    // ===| e_max for FLAG_SUB_EMAX |============================================
    // Passed in as BF16; CVO subtracts this from each element before the function.
    input  logic [15:0]      IN_e_max,
    // ===| Status |=============================================================
    output logic             OUT_busy,
    output logic             OUT_done,
    output logic             OUT_accm   // mirrors the latched uop's flags.accm
);

  // ===| FSM |===================================================================
  typedef enum logic [1:0] {
    ST_IDLE    = 2'b00,
    ST_RUNNING = 2'b01,
    ST_DONE    = 2'b10,
    ST_DRAIN   = 2'b11
  } cvo_state_e;
  cvo_state_e state;

  // Upper bound on cycles spent in DRAIN. It exceeds the longest sub-unit
  // latency (16 clocks for CVO_cordic_unit), so it only matters if fewer
  // results than expected arrive; it guarantees the FSM cannot hang.
  localparam int DrainTimeout = 32;

  // ===| Latched UOP & progress counters |=======================================
  cvo_func_e   uop_func;
  cvo_flags_t  uop_flags;
  logic [15:0] uop_length;
  logic [15:0] elem_count;     // inputs accepted for the current uop
  logic [15:0] result_count;   // results emitted by the selected unit
  logic [15:0] result_target;  // results expected for the current uop
  logic [ 5:0] drain_cycles;   // cycles spent in ST_DRAIN

  // Number of results the selected unit emits for a uop of `len` inputs.
  // Unsupported functions produce nothing, so DRAIN must not wait for them.
  function automatic logic [15:0] expected_results(input cvo_func_e   func,
                                                   input logic [15:0] len);
    case (func)
      CVO_EXP, CVO_SQRT, CVO_GELU, CVO_RECIP, CVO_SIN, CVO_COS: return len;
      CVO_SCALE:      return (len == 16'd0) ? 16'd0 : len - 16'd1;  // word 0 = scalar
      CVO_REDUCE_SUM: return (len == 16'd0) ? 16'd0 : 16'd1;        // one scalar
      default:        return 16'd0;
    endcase
  endfunction

  // ===| Opcode Routing |========================================================
  // Declared before the sub-unit instantiations that use it.
  logic is_cordic_op_wire;
  always_comb begin
    is_cordic_op_wire = (uop_func == CVO_SIN) || (uop_func == CVO_COS);
  end

  // ===| BF16 subtract e_max (combinational) |===================================
  // x - e_max computed as bf16_add(x, -e_max): e_max is negated by flipping its
  // sign bit. bf16_add is not defined in this module; it comes from the
  // file-scope imports (presumably bf16_math_pkg).
  logic [15:0] sub_emax_result_wire;
  always_comb begin : comb_sub_emax
    sub_emax_result_wire = bf16_add(IN_data, {~IN_e_max[15], IN_e_max[14:0]});
  end

  // ===| Input handshake & operand select |======================================
  logic        in_fire_wire;       // an input element is accepted this cycle
  logic [15:0] data_to_unit_wire;  // element after optional e_max subtraction
  always_comb begin
    in_fire_wire      = IN_data_valid && OUT_data_ready;
    data_to_unit_wire = uop_flags.sub_emax ? sub_emax_result_wire : IN_data;
  end

  // ===| SFU Instantiation |=====================================================
  logic [15:0] sfu_result;
  logic        sfu_result_valid;
  logic        sfu_ready;
  CVO_sfu_unit u_CVO_sfu_unit (
      .clk             (clk),
      .rst_n           (rst_n),
      .i_clear         (i_clear),
      .IN_func         (uop_func),
      .IN_length       (uop_length),
      .IN_flags        (uop_flags),
      .IN_data         (data_to_unit_wire),
      .IN_valid        (in_fire_wire && !is_cordic_op_wire),
      .OUT_data_ready  (sfu_ready),
      .OUT_result      (sfu_result),
      .OUT_result_valid(sfu_result_valid)
  );

  // ===| CORDIC Instantiation |==================================================
  logic [15:0] cordic_sin;
  logic [15:0] cordic_cos;
  logic        cordic_valid;
  CVO_cordic_unit u_CVO_cordic_unit (
      .clk          (clk),
      .rst_n        (rst_n),
      .IN_angle_bf16(data_to_unit_wire),
      .IN_valid     (in_fire_wire && is_cordic_op_wire),
      .OUT_sin_bf16 (cordic_sin),
      .OUT_cos_bf16 (cordic_cos),
      .OUT_valid    (cordic_valid)
  );

  // ===| Output Mux |============================================================
  // CORDIC produces sin and cos for every input; pick the one the uop asked for.
  logic [15:0] result_mux_wire;
  logic        result_valid_mux_wire;
  always_comb begin
    if (is_cordic_op_wire) begin
      result_mux_wire       = (uop_func == CVO_SIN) ? cordic_sin : cordic_cos;
      result_valid_mux_wire = cordic_valid;
    end else begin
      result_mux_wire       = sfu_result;
      result_valid_mux_wire = sfu_result_valid;
    end
  end

  // ===| FSM Logic |=============================================================
  always_ff @(posedge clk) begin
    if (!rst_n || i_clear) begin
      state         <= ST_IDLE;
      uop_func      <= CVO_EXP;
      uop_flags     <= '0;
      uop_length    <= 16'd0;
      elem_count    <= 16'd0;
      result_count  <= 16'd0;
      result_target <= 16'd0;
      drain_cycles  <= 6'd0;
      OUT_done      <= 1'b0;
    end else begin
      OUT_done <= 1'b0;

      // Count every result the selected unit emits while a uop is active
      // (including results later dropped for lack of IN_result_ready).
      if (state != ST_IDLE && result_valid_mux_wire) begin
        result_count <= result_count + 16'd1;
      end

      case (state)
        // ===| IDLE: wait for dispatch |===
        ST_IDLE: begin
          if (IN_uop_valid) begin
            uop_func      <= IN_uop.cvo_func;
            uop_flags     <= IN_uop.flags;
            uop_length    <= IN_uop.length;
            result_target <= expected_results(IN_uop.cvo_func, IN_uop.length);
            elem_count    <= 16'd0;
            result_count  <= 16'd0;
            drain_cycles  <= 6'd0;
            // A zero-length uop has nothing to stream (and length - 1 would wrap).
            state         <= (IN_uop.length == 0) ? ST_DONE : ST_RUNNING;
          end
        end
        // ===| RUNNING: count consumed elements |===
        ST_RUNNING: begin
          if (in_fire_wire) begin
            elem_count <= elem_count + 16'd1;
            if (elem_count == uop_length - 16'd1) begin
              state <= ST_DRAIN;
            end
          end
        end
        // ===| DRAIN: wait for in-flight results |===
        ST_DRAIN: begin
          drain_cycles <= drain_cycles + 6'd1;
          if (result_count >= result_target || drain_cycles == 6'(DrainTimeout - 1)) begin
            state <= ST_DONE;
          end
        end
        // ===| DONE: pulse and return |===
        ST_DONE: begin
          OUT_done <= 1'b1;
          state    <= ST_IDLE;
        end
        default: state <= ST_IDLE;
      endcase
    end
  end

  // ===| Output Registers |======================================================
  // A result is only flagged valid if the consumer is ready in the same cycle;
  // otherwise it is lost (the units cannot be stalled).
  always_ff @(posedge clk) begin
    if (!rst_n || i_clear) begin
      OUT_result       <= 16'd0;
      OUT_result_valid <= 1'b0;
    end else begin
      OUT_result       <= result_mux_wire;
      OUT_result_valid <= result_valid_mux_wire && IN_result_ready;
    end
  end

  // ===| Status & Control |======================================================
  assign OUT_busy       = (state != ST_IDLE);
  assign OUT_uop_ready  = (state == ST_IDLE);
  assign OUT_data_ready = sfu_ready && (state == ST_RUNNING);
  assign OUT_accm       = uop_flags.accm;
endmodule
CVO_cordic_unit.sv
`timescale 1ns / 1ps
`include "GLOBAL_CONST.svh"
import bf16_math_pkg::*;
// ===| CVO CORDIC Unit |=========================================================
// Computes sin(θ) and cos(θ) for a BF16 input angle (radians) using a 14-stage
// pipelined CORDIC algorithm (rotation mode).
//
// Internal format : Q4.12 signed fixed-point (16-bit: sign + 3 integer bits +
//                   12 fractional bits, range [-8, 8))
//                   1.0 = 0x1000 = 4096, π ≈ 0x3244 = 12868
//
// Range folding   : rotation-mode CORDIC only converges for |θ| up to the sum of
//                   the atan table (≈ 1.743 rad). Stage 0 therefore folds θ into
//                   [-π, π] with one ±2π step, then into [-π/2, π/2] with one
//                   ±π step, and negates both outputs for the latter
//                   (sin(θ ± π) = -sin θ, cos(θ ± π) = -cos θ).
//                   |θ| >= 8 saturates during conversion, so callers must reduce
//                   larger angles (e.g. modulo 2π) before issuing them.
//
// CORDIC gain K (14 iterations) ≈ 0.60725, pre-baked into x0 = K * 4096 = 0x09B8.
// Pipeline latency : 16 clocks (1 convert-in + 14 CORDIC + 1 convert-out)
// ===============================================================================
module CVO_cordic_unit (
    input  logic        clk,
    input  logic        rst_n,
    // ===| Input |===============================================================
    input  logic [15:0] IN_angle_bf16,  // BF16 angle in radians
    input  logic        IN_valid,
    // ===| Output |==============================================================
    output logic [15:0] OUT_sin_bf16,
    output logic [15:0] OUT_cos_bf16,
    output logic        OUT_valid
);

  // ===| CORDIC Constants (Q4.12) |==============================================
  // atan(2^-i) * 4096, i = 0..13
  localparam logic signed [15:0] AtanLut[14] = '{
      16'sh0C91,  // atan(2^0 ) = π/4 ≈ 0.7854
      16'sh076B,  // atan(2^-1) ≈ 0.4636
      16'sh03EB,  // atan(2^-2) ≈ 0.2450
      16'sh01FD,  // atan(2^-3) ≈ 0.1244
      16'sh0100,  // atan(2^-4) ≈ 0.0624
      16'sh0080,  // atan(2^-5) ≈ 0.0312
      16'sh0040,  // atan(2^-6) ≈ 0.0156
      16'sh0020,  // atan(2^-7) ≈ 0.0078
      16'sh0010,  // atan(2^-8) ≈ 0.0039
      16'sh0008,  // atan(2^-9) ≈ 0.0020
      16'sh0004,  // atan(2^-10) ≈ 0.0010
      16'sh0002,  // atan(2^-11) ≈ 0.0005
      16'sh0001,  // atan(2^-12) ≈ 0.0002
      16'sh0001   // atan(2^-13) ≈ 0.0001
  };
  // CORDIC gain K pre-scaled: K * 4096 = 0x09B8
  localparam logic signed [15:0] CordicKQ412 = 16'sh09B8;
  // Folding constants (Q4.12)
  localparam logic signed [15:0] PiQ412      = 16'sh3244;  // π   → 12868
  localparam logic signed [15:0] HalfPiQ412  = 16'sh1922;  // π/2 → 6434
  localparam logic signed [15:0] TwoPiQ412   = 16'sh6488;  // 2π  → 25736

  // ===| Stage 0: BF16 → Q4.12 Conversion + range folding |======================
  logic signed [15:0] s0_angle_fixed;  // folded angle, within [-π/2, π/2]
  logic               s0_negate;       // 1 → negate sin and cos at the output
  logic               s0_valid;
  always_comb begin : bf16_to_q412
    automatic logic               sign_bit;
    automatic logic        [ 7:0] exp_raw;
    automatic logic        [ 7:0] sig;        // {1, mant} = 1.mant * 2^7
    automatic logic        [15:0] magnitude;
    automatic logic signed [15:0] theta;
    automatic int                 shift_amt;
    sign_bit = IN_angle_bf16[15];
    exp_raw  = IN_angle_bf16[14:7];
    sig      = {1'b1, IN_angle_bf16[6:0]};
    // θ * 2^12 = (sig / 2^7) * 2^(exp - 127) * 2^12 = sig * 2^(exp - 122)
    shift_amt = int'(exp_raw) - 122;
    if (exp_raw == 8'd0) begin
      magnitude = 16'd0;                          // zero / denormal
    end else if (shift_amt > 7) begin
      magnitude = 16'h7FFF;                       // |θ| >= 8: saturate
    end else if (shift_amt < -7) begin
      magnitude = 16'd0;                          // |θ| < 2^-12: rounds to 0
    end else if (shift_amt >= 0) begin
      magnitude = 16'(sig) << shift_amt;          // at most 255 << 7 = 32640
    end else begin
      magnitude = 16'(sig) >> (-shift_amt);
    end
    theta = sign_bit ? -$signed(magnitude) : $signed(magnitude);

    // Fold into [-π, π]; one step suffices because |θ| < 8 < 3π.
    if (theta > PiQ412) begin
      theta = theta - TwoPiQ412;
    end else if (theta < -PiQ412) begin
      theta = theta + TwoPiQ412;
    end
    // Fold into [-π/2, π/2] (inside the CORDIC convergence range); shifting by
    // π flips the sign of both sin and cos.
    s0_negate = 1'b0;
    if (theta > HalfPiQ412) begin
      theta     = theta - PiQ412;
      s0_negate = 1'b1;
    end else if (theta < -HalfPiQ412) begin
      theta     = theta + PiQ412;
      s0_negate = 1'b1;
    end
    s0_angle_fixed = theta;
  end

  // Register stage 0
  logic signed [15:0] s0_angle_ff;
  logic               s0_negate_ff;
  always_ff @(posedge clk) begin
    if (!rst_n) begin
      s0_angle_ff  <= 16'sh0;
      s0_negate_ff <= 1'b0;
      s0_valid     <= 1'b0;
    end else begin
      s0_angle_ff  <= s0_angle_fixed;
      s0_negate_ff <= s0_negate;
      s0_valid     <= IN_valid;
    end
  end

  // ===| Stages 1-14: CORDIC Iterations |========================================
  // Each stage i: x_{i} → x_{i+1}, y_{i} → y_{i+1}, z_{i} → z_{i+1}
  //   d_i     = sign(z_i)
  //   x_{i+1} = x_i - d_i * (y_i >>> i)
  //   y_{i+1} = y_i + d_i * (x_i >>> i)
  //   z_{i+1} = z_i - d_i * AtanLut[i]
  logic signed [15:0] cx[0:14];  // CORDIC x pipeline (→ cos)
  logic signed [15:0] cy[0:14];  // CORDIC y pipeline (→ sin)
  logic signed [15:0] cz[0:14];  // CORDIC z pipeline (residual angle)
  logic               cv[0:14];  // valid pipeline
  logic               cn[0:14];  // negate-output flag pipeline

  // Initialize iteration 0 from stage-0 output
  assign cx[0] = CordicKQ412;
  assign cy[0] = 16'sh0;
  assign cz[0] = s0_angle_ff;
  assign cv[0] = s0_valid;
  assign cn[0] = s0_negate_ff;

  genvar gi;
  generate
    for (gi = 0; gi < 14; gi++) begin : gen_cordic_iter
      always_ff @(posedge clk) begin
        if (!rst_n) begin
          cx[gi+1] <= 16'sh0;
          cy[gi+1] <= 16'sh0;
          cz[gi+1] <= 16'sh0;
          cv[gi+1] <= 1'b0;
          cn[gi+1] <= 1'b0;
        end else begin
          cv[gi+1] <= cv[gi];
          cn[gi+1] <= cn[gi];
          if (cz[gi] >= 0) begin
            // d_i = +1: rotate counter-clockwise
            cx[gi+1] <= cx[gi] - (cy[gi] >>> gi);
            cy[gi+1] <= cy[gi] + (cx[gi] >>> gi);
            cz[gi+1] <= cz[gi] - AtanLut[gi];
          end else begin
            // d_i = -1: rotate clockwise
            cx[gi+1] <= cx[gi] + (cy[gi] >>> gi);
            cy[gi+1] <= cy[gi] - (cx[gi] >>> gi);
            cz[gi+1] <= cz[gi] + AtanLut[gi];
          end
        end
      end
    end
  endgenerate

  // ===| Stage 15: Q4.12 → BF16 Conversion |=====================================
  // Truncating conversion. value = mag * 2^-12, so the unbiased BF16 exponent is
  // (index of the leading one) - 12.
  function automatic logic [15:0] q412_to_bf16(input logic signed [15:0] val);
    logic        sign_out;
    logic [15:0] mag;   // |val|; 16 bits so that -32768 stays representable
    logic [15:0] norm;  // mag shifted so the leading one sits at bit 15
    int          lead;
    sign_out = val[15];
    mag      = sign_out ? (~val + 16'd1) : val;
    if (mag == 16'd0) return 16'd0;
    lead = 0;
    for (int b = 0; b < 16; b++) begin
      if (mag[b]) lead = b;  // ascending scan: the highest set bit wins
    end
    norm = mag << (15 - lead);
    return {sign_out, 8'(127 + lead - 12), norm[14:8]};
  endfunction

  // cos result = cx[14], sin result = cy[14]; undo the ±π fold where flagged.
  always_ff @(posedge clk) begin
    if (!rst_n) begin
      OUT_cos_bf16 <= 16'd0;
      OUT_sin_bf16 <= 16'd0;
      OUT_valid    <= 1'b0;
    end else begin
      OUT_valid    <= cv[14];
      OUT_cos_bf16 <= q412_to_bf16(cn[14] ? -cx[14] : cx[14]);
      OUT_sin_bf16 <= q412_to_bf16(cn[14] ? -cy[14] : cy[14]);
    end
  end
endmodule
CVO_sfu_unit.sv
`timescale 1ns / 1ps
`include "GLOBAL_CONST.svh"
import isa_pkg::*;
import bf16_math_pkg::*;
// ===| CVO Special Function Unit |===============================================
// Streaming BF16 SFU for: EXP, SQRT, GELU, RECIP, SCALE, REDUCE_SUM.
// One element per cycle throughput once the pipeline is filled. The pipelines
// have no stall input: every result is presented for exactly one cycle.
//
// Pipeline latencies (IN_valid to OUT_result_valid, clock cycles):
//   EXP        : 3
//   SQRT       : 3
//   RECIP      : 4
//   SCALE      : 2  (first input word of each op = scalar, the rest = data)
//   REDUCE_SUM : 1 after the last of IN_length inputs (one scalar per op)
//   GELU       : 9  (MUL → NEG → EXP[2] → ADD → RECIP[3] → MUL)
//
// Arithmetic: BF16 with truncation. Operands with a zero exponent field (±0 and
// denormals) are treated as zero, results below the normal range flush to +0,
// and results above it saturate to ±inf.
//
// BF16 raw format: {sign[15], exp[14:7], mant[6:0]}
// ===============================================================================
module CVO_sfu_unit (
    input  logic        clk,
    input  logic        rst_n,
    input  logic        i_clear,
    // ===| Operation Select |====================================================
    input  cvo_func_e   IN_func,
    input  logic [15:0] IN_length,
    input  cvo_flags_t  IN_flags,
    // ===| Streaming Input |=====================================================
    input  logic [15:0] IN_data,
    input  logic        IN_valid,
    output logic        OUT_data_ready,
    // ===| Streaming Output |====================================================
    output logic [15:0] OUT_result,
    output logic        OUT_result_valid
);

  // ===| BF16 Constants |========================================================
  localparam logic [15:0] Bf16One       = 16'h3F80;   // 1.0
  localparam logic [15:0] Bf16Two       = 16'h4000;   // 2.0
  localparam logic [15:0] Bf16PosInf    = 16'h7F80;   // +inf
  // GELU sigmoid scale: 1.703125, the BF16 value nearest to 1.702
  // (16'h3FD9 would be 1.6953125).
  localparam logic [15:0] Bf16Scale1702 = 16'h3FDA;
  // log2(e) ≈ 1.442688 with 14 fractional bits (round(1.442695 * 2^14)).
  localparam logic [15:0] Log2EQ14      = 16'd23637;

  // ===| BF16 Multiply |===
  // Truncating multiply. An operand with a zero exponent field gives +0;
  // exponent overflow saturates to ±inf, underflow flushes to +0.
  function automatic logic [15:0] bf16_mul(input logic [15:0] a, input logic [15:0] b);
    logic        s;
    logic [15:0] prod;  // (1.a)(1.b) in [1, 4) with 14 fractional bits
    logic [ 6:0] m;
    int          e;
    if (a[14:7] == 8'd0 || b[14:7] == 8'd0) return 16'd0;
    s    = a[15] ^ b[15];
    prod = {1'b1, a[6:0]} * {1'b1, b[6:0]};
    e    = int'(a[14:7]) + int'(b[14:7]) - 127;
    if (prod[15]) begin
      e = e + 1;  // product in [2, 4): renormalize by one bit
      m = prod[14:8];
    end else begin
      m = prod[13:7];
    end
    if (e >= 255) return {s, 8'hFF, 7'd0};
    if (e <= 0) return 16'd0;
    return {s, 8'(e), m};
  endfunction

  // ===| BF16 Add |===
  // Truncating add/subtract with full renormalization after cancellation.
  // An operand with a zero exponent field is treated as zero.
  function automatic logic [15:0] bf16_add(input logic [15:0] a, input logic [15:0] b);
    logic [15:0] hi, lo;  // operands ordered so that |hi| >= |lo|
    logic [ 7:0] align;   // exponent difference
    logic [10:0] sig_hi;  // {hidden 1, mant[6:0], 3 guard bits}
    logic [10:0] sig_lo;
    logic [11:0] sum;     // extra MSB catches the carry of an addition
    int          e;
    if (a[14:7] == 8'd0) return b;
    if (b[14:7] == 8'd0) return a;
    if (a[14:0] >= b[14:0]) begin
      hi = a;
      lo = b;
    end else begin
      hi = b;
      lo = a;
    end
    align  = hi[14:7] - lo[14:7];
    sig_hi = {1'b1, hi[6:0], 3'b000};
    sig_lo = {1'b1, lo[6:0], 3'b000} >> align;
    e      = int'(hi[14:7]);
    if (hi[15] == lo[15]) begin
      sum = {1'b0, sig_hi} + {1'b0, sig_lo};
      if (sum[11]) begin
        sum = sum >> 1;  // carry out: renormalize right by one bit
        e   = e + 1;
      end
    end else begin
      sum = {1'b0, sig_hi - sig_lo};  // never negative since |hi| >= |lo|
      if (sum == 12'd0) return 16'd0;  // exact cancellation
      // Cancellation can clear up to 10 leading bits: move the leading one
      // back to bit 10 (bounded loop, fully unrolled by synthesis).
      for (int i = 0; i < 10; i++) begin
        if (!sum[10]) begin
          sum = sum << 1;
          e   = e - 1;
        end
      end
    end
    if (e >= 255) return {hi[15], 8'hFF, 7'd0};
    if (e <= 0) return 16'd0;
    return {hi[15], 8'(e), sum[9:3]};
  endfunction

  // ===| Reciprocal helpers |===
  // Seed for 1/|x| with |x| = 2^E * (1 + f): 1/|x| = 2^(-E-1) * 2/(1 + f).
  // 2/(1+f) - 1 is approximated linearly as (113 - 0.9375*m)/128, clamped at 0,
  // which keeps the seed within roughly ±6%; exact powers of two get an exact
  // seed. Zero exponent field → +inf; an underflowing reciprocal → +0.
  // The sign of x is ignored.
  function automatic logic [15:0] recip_seed(input logic [15:0] x);
    int e_r;
    int m_r;
    if (x[14:7] == 8'd0) return Bf16PosInf;
    if (x[6:0] == 7'd0) begin
      e_r = 254 - int'(x[14:7]);
      m_r = 0;
    end else begin
      e_r = 253 - int'(x[14:7]);
      m_r = 113 - ((15 * int'(x[6:0])) >>> 4);
      if (m_r < 0) m_r = 0;
    end
    if (e_r <= 0) return 16'd0;
    return {1'b0, 8'(e_r), 7'(m_r)};
  endfunction

  // Combinational 1/x: seed plus one Newton-Raphson step
  // r1 = r0 * (2 - |x|*r0); the sign of x is reattached.
  function automatic logic [15:0] bf16_recip(input logic [15:0] x);
    logic [15:0] r0, xr0, corr, r1;
    r0   = recip_seed(x);
    xr0  = bf16_mul({1'b0, x[14:0]}, r0);
    corr = bf16_add(Bf16Two, {1'b1, xr0[14:0]});
    r1   = bf16_mul(r0, corr);
    return {x[15], r1[14:0]};
  endfunction

  // ===| EXP helpers |===
  // BF16 → signed Q8.7 (x * 2^7). Saturates for |x| >= 256; |x| < 2^-7 and
  // zero-exponent inputs give 0.
  function automatic logic signed [15:0] bf16_to_q8_7(input logic [15:0] x);
    logic [14:0] sig;  // {1, mant, 7'b0} = (1.mant) * 2^14
    logic [15:0] mag;
    int          sh;   // unbiased exponent
    sig = {1'b1, x[6:0], 7'b0};
    sh  = int'(x[14:7]) - 127;
    if (x[14:7] == 8'd0) mag = 16'd0;
    else if (sh >= 8) mag = 16'h7FFF;
    else if (sh < -7) mag = 16'd0;
    else mag = 16'(sig >> (7 - sh));  // (1.mant) * 2^(sh + 7)
    return x[15] ? -$signed(mag) : $signed(mag);
  endfunction

  // x * log2(e) as a signed fixed-point value with 21 fractional bits
  // (Q8.7 operand times a 14-fractional-bit constant).
  function automatic logic signed [31:0] bf16_times_log2e(input logic [15:0] x);
    return bf16_to_q8_7(x) * $signed({1'b0, Log2EQ14});
  endfunction

  // Mantissa bits of 2^(k/128), k = 0..127:
  //   128 * (2^f - 1) ≈ k - k(128-k)(315 + 5k/8) / 2^17,   f = k/128
  // (cubic fit with rounding; within about one LSB of the exact value).
  function automatic logic [6:0] exp2_frac_mant(input logic [6:0] k);
    int kk;
    int corr;
    kk   = int'(k);
    corr = (kk * (128 - kk) * (315 + ((5 * kk) >>> 3)) + (1 << 16)) >>> 17;
    return 7'(kk - corr);
  endfunction

  // 2^(n + frac/128) → BF16. Saturates to +inf above the largest normal
  // exponent and flushes to +0 below the smallest.
  function automatic logic [15:0] exp2_to_bf16(input logic signed [10:0] n,
                                               input logic        [ 6:0] frac);
    if (n > 127) return Bf16PosInf;
    if (n < -126) return 16'd0;
    return {1'b0, 8'(n + 127), exp2_frac_mant(frac)};
  endfunction

  // ===| SQRT helpers |===
  // 2^19 * 128 * (sqrt(1 + k/128) - 1): the chord through the end points
  // (53.019 * k / 128) plus a k(128-k) curvature correction.
  function automatic int sqrt_frac_q19(input logic [6:0] k);
    int kk;
    kk = int'(k);
    return kk * 217167 + kk * (128 - kk) * (351 - ((13 * kk) >>> 4));
  endfunction

  // Mantissa of sqrt(1 + k/128) (even unbiased exponent), range [0, 53].
  function automatic logic [6:0] sqrt_mant_even(input logic [6:0] k);
    return 7'((sqrt_frac_q19(k) + (1 << 18)) >>> 19);
  endfunction

  // Mantissa of sqrt(2 * (1 + k/128)) (odd unbiased exponent), range [53, 127]:
  //   128*(sqrt2*(1+s) - 1) = 128*(sqrt2 - 1) + sqrt2 * (128*s)
  // 1737338 = 128*(sqrt2 - 1) * 2^15 and 181/128 ≈ sqrt2.
  function automatic logic [6:0] sqrt_mant_odd(input logic [6:0] k);
    int v;
    v = (1737338 + (((sqrt_frac_q19(k) >>> 4) * 181) >>> 7) + (1 << 14)) >>> 15;
    return (v > 127) ? 7'd127 : 7'(v);
  endfunction

  // Square root of |x| (the sign bit is ignored). A zero exponent field gives
  // +0; an all-ones exponent (inf / NaN) passes through with the sign cleared.
  function automatic logic [15:0] bf16_sqrt(input logic [15:0] x);
    int e;
    if (x[14:7] == 8'd0) return 16'd0;
    if (x[14:7] == 8'hFF) return {1'b0, x[14:0]};
    e = int'(x[14:7]) - 127;  // unbiased exponent, may be negative
    // sqrt(2^e * m) = 2^floor(e/2) * sqrt(m * 2^(e mod 2)); >>> floors negatives.
    return {1'b0, 8'(127 + (e >>> 1)),
            e[0] ? sqrt_mant_odd(x[6:0]) : sqrt_mant_even(x[6:0])};
  endfunction

  // ===| EXP Pipeline (3 stages) |===============================================
  // s1: capture x.  s2: split x*log2(e) into integer n and a 7-bit fraction.
  // s3: 2^n * 2^(frac/128) → BF16.
  logic        exp_s1_valid;
  logic [15:0] exp_s1_x;
  always_ff @(posedge clk) begin
    if (!rst_n || i_clear) begin
      exp_s1_valid <= 1'b0;
    end else begin
      exp_s1_valid <= IN_valid && (IN_func == CVO_EXP);
      exp_s1_x     <= IN_data;
    end
  end

  logic               exp_s2_valid;
  logic signed [10:0] exp_s2_n;       // floor(x * log2(e))
  logic        [ 6:0] exp_s2_frac;    // fractional part in 1/128 units
  logic signed [31:0] exp_s1_y_wire;  // x * log2(e), 21 fractional bits
  always_comb begin : comb_exp_convert
    exp_s1_y_wire = bf16_times_log2e(exp_s1_x);
  end
  always_ff @(posedge clk) begin
    if (!rst_n || i_clear) begin
      exp_s2_valid <= 1'b0;
    end else begin
      exp_s2_valid <= exp_s1_valid;
      exp_s2_n     <= exp_s1_y_wire[31:21];
      exp_s2_frac  <= exp_s1_y_wire[20:14];
    end
  end

  logic        exp_s3_valid;
  logic [15:0] exp_s3_result;
  always_ff @(posedge clk) begin
    if (!rst_n || i_clear) begin
      exp_s3_valid <= 1'b0;
    end else begin
      exp_s3_valid  <= exp_s2_valid;
      exp_s3_result <= exp2_to_bf16(exp_s2_n, exp_s2_frac);
    end
  end

  // ===| SQRT Pipeline (3 stages) |==============================================
  // s1: capture x.  s2: compute sqrt(|x|).  s3: alignment register.
  logic        sqrt_s1_valid;
  logic [15:0] sqrt_s1_x;
  always_ff @(posedge clk) begin
    if (!rst_n || i_clear) begin
      sqrt_s1_valid <= 1'b0;
    end else begin
      sqrt_s1_valid <= IN_valid && (IN_func == CVO_SQRT);
      sqrt_s1_x     <= IN_data;
    end
  end

  logic        sqrt_s2_valid;
  logic [15:0] sqrt_s2_result;
  always_ff @(posedge clk) begin
    if (!rst_n || i_clear) begin
      sqrt_s2_valid <= 1'b0;
    end else begin
      sqrt_s2_valid  <= sqrt_s1_valid;
      sqrt_s2_result <= bf16_sqrt(sqrt_s1_x);
    end
  end

  logic        sqrt_s3_valid;
  logic [15:0] sqrt_s3_result;
  always_ff @(posedge clk) begin
    if (!rst_n || i_clear) begin
      sqrt_s3_valid <= 1'b0;
    end else begin
      sqrt_s3_valid  <= sqrt_s2_valid;
      sqrt_s3_result <= sqrt_s2_result;
    end
  end

  // ===| RECIP Pipeline (4 stages) |=============================================
  // 1/x via one Newton-Raphson step r1 = r0 * (2 - |x|*r0), seeded by recip_seed.
  logic        recip_s1_valid;
  logic        recip_s1_sign;
  logic [15:0] recip_s1_x;   // |x|
  logic [15:0] recip_s1_r0;  // seed ≈ 1/|x|
  always_ff @(posedge clk) begin
    if (!rst_n || i_clear) begin
      recip_s1_valid <= 1'b0;
    end else begin
      recip_s1_valid <= IN_valid && (IN_func == CVO_RECIP);
      recip_s1_sign  <= IN_data[15];
      recip_s1_x     <= {1'b0, IN_data[14:0]};
      recip_s1_r0    <= recip_seed(IN_data);
    end
  end

  logic        recip_s2_valid;
  logic [15:0] recip_s2_xr0;
  logic [15:0] recip_s2_r0;
  logic        recip_s2_sign;
  always_ff @(posedge clk) begin
    if (!rst_n || i_clear) begin
      recip_s2_valid <= 1'b0;
    end else begin
      recip_s2_valid <= recip_s1_valid;
      recip_s2_xr0   <= bf16_mul(recip_s1_x, recip_s1_r0);
      recip_s2_r0    <= recip_s1_r0;
      recip_s2_sign  <= recip_s1_sign;
    end
  end

  logic        recip_s3_valid;
  logic [15:0] recip_s3_corr;  // 2 - |x|*r0
  logic [15:0] recip_s3_r0;
  logic        recip_s3_sign;
  always_ff @(posedge clk) begin
    if (!rst_n || i_clear) begin
      recip_s3_valid <= 1'b0;
    end else begin
      recip_s3_valid <= recip_s2_valid;
      recip_s3_corr  <= bf16_add(Bf16Two, {1'b1, recip_s2_xr0[14:0]});
      recip_s3_r0    <= recip_s2_r0;
      recip_s3_sign  <= recip_s2_sign;
    end
  end

  logic        recip_s4_valid;
  logic [15:0] recip_s4_result;
  logic [15:0] recip_s3_mag_wire;
  always_comb begin
    recip_s3_mag_wire = bf16_mul(recip_s3_r0, recip_s3_corr);
  end
  always_ff @(posedge clk) begin
    if (!rst_n || i_clear) begin
      recip_s4_valid <= 1'b0;
    end else begin
      recip_s4_valid  <= recip_s3_valid;
      recip_s4_result <= {recip_s3_sign, recip_s3_mag_wire[14:0]};
    end
  end

  // ===| SCALE Pipeline (2 stages) |=============================================
  // The first element of each SCALE op is the scalar (≈ 1/scalar when
  // IN_flags.recip_scale is set); the remaining IN_length-1 elements are
  // multiplied by it. scale_count tracks the position inside the op so that the
  // scalar is re-captured at the start of every op, including back-to-back
  // SCALE ops.
  logic        scale_scalar_loaded;
  logic [15:0] scale_scalar;
  logic [15:0] scale_count;
  logic        scale_s1_valid;
  logic [15:0] scale_s1_product;
  logic        scale_s2_valid;
  logic [15:0] scale_s2_result;
  logic [15:0] scale_scalar_next_wire;
  always_comb begin
    scale_scalar_next_wire = IN_flags.recip_scale ? bf16_recip(IN_data) : IN_data;
  end
  always_ff @(posedge clk) begin
    if (!rst_n || i_clear) begin
      scale_scalar_loaded <= 1'b0;
      scale_scalar        <= 16'd0;
      scale_count         <= 16'd0;
      scale_s1_valid      <= 1'b0;
      scale_s2_valid      <= 1'b0;
    end else begin
      if (IN_valid && IN_func == CVO_SCALE) begin
        if (!scale_scalar_loaded) begin
          scale_scalar        <= scale_scalar_next_wire;
          scale_scalar_loaded <= 1'b1;
          scale_s1_valid      <= 1'b0;
        end else begin
          scale_s1_valid   <= 1'b1;
          scale_s1_product <= bf16_mul(IN_data, scale_scalar);
        end
        // Last element of the op: re-arm the scalar capture for the next op.
        // Placed after the capture above so it also wins for a length-1 op.
        if (scale_count == IN_length - 16'd1) begin
          scale_count         <= 16'd0;
          scale_scalar_loaded <= 1'b0;
        end else begin
          scale_count <= scale_count + 16'd1;
        end
      end else begin
        if (IN_func != CVO_SCALE) begin
          scale_scalar_loaded <= 1'b0;
          scale_count         <= 16'd0;
        end
        scale_s1_valid <= 1'b0;
      end
      scale_s2_valid  <= scale_s1_valid;
      scale_s2_result <= scale_s1_product;
    end
  end

  // ===| REDUCE_SUM (sequential BF16 accumulation) |=============================
  // Accumulates IN_length elements and emits their sum one cycle after the last.
  logic [15:0] rsum_count;
  logic [15:0] rsum_accum;
  logic        rsum_out_valid;
  logic [15:0] rsum_out;
  logic [15:0] rsum_next_wire;
  always_comb begin
    rsum_next_wire = bf16_add(rsum_accum, IN_data);
  end
  always_ff @(posedge clk) begin
    if (!rst_n || i_clear) begin
      rsum_count     <= 16'd0;
      rsum_accum     <= 16'd0;
      rsum_out_valid <= 1'b0;
      rsum_out       <= 16'd0;
    end else begin
      rsum_out_valid <= 1'b0;
      if (IN_valid && IN_func == CVO_REDUCE_SUM) begin
        if (rsum_count == IN_length - 16'd1) begin
          rsum_out       <= rsum_next_wire;
          rsum_out_valid <= 1'b1;
          rsum_count     <= 16'd0;
          rsum_accum     <= 16'd0;
        end else begin
          rsum_count <= rsum_count + 16'd1;
          rsum_accum <= rsum_next_wire;
        end
      end else if (IN_func != CVO_REDUCE_SUM) begin
        rsum_count <= 16'd0;
        rsum_accum <= 16'd0;
      end
    end
  end

  // ===| GELU Pipeline (9 stages) |==============================================
  // GELU(x) ≈ x * sigmoid(1.702*x), sigmoid(y) = 1/(1+exp(-y))
  // Register stages ahead of the final multiply:
  //   g1 MUL(1.702) → g2 NEG → e1,e2 EXP → g6 ADD(1) → r1,r2,r3 RECIP  = 8
  // so x travels through an 8-deep delay line and gelu_x_pipe[GeluDelay-1]
  // lines up with gelu_r3_sigmoid at the final multiply.
  localparam int GeluDelay = 8;
  logic [15:0] gelu_x_pipe[GeluDelay];

  // g1: y = 1.702 * x; x enters the delay line
  logic        gelu_g1_valid;
  logic [15:0] gelu_g1_y;
  always_ff @(posedge clk) begin
    if (!rst_n || i_clear) begin
      gelu_g1_valid <= 1'b0;
      for (int d = 0; d < GeluDelay; d++) gelu_x_pipe[d] <= 16'd0;
    end else begin
      gelu_x_pipe[0] <= IN_data;
      for (int d = 1; d < GeluDelay; d++) gelu_x_pipe[d] <= gelu_x_pipe[d-1];
      gelu_g1_valid <= IN_valid && (IN_func == CVO_GELU);
      gelu_g1_y     <= bf16_mul(IN_data, Bf16Scale1702);
    end
  end

  // g2: negate → -1.702x
  logic        gelu_g2_valid;
  logic [15:0] gelu_g2_neg_y;
  always_ff @(posedge clk) begin
    if (!rst_n || i_clear) begin
      gelu_g2_valid <= 1'b0;
    end else begin
      gelu_g2_valid <= gelu_g1_valid;
      gelu_g2_neg_y <= {~gelu_g1_y[15], gelu_g1_y[14:0]};
    end
  end

  // e1: split -y * log2(e) into integer and fraction (same math as EXP s2)
  logic               gelu_e1_valid;
  logic signed [10:0] gelu_e1_n;
  logic        [ 6:0] gelu_e1_frac;
  logic signed [31:0] gelu_g2_y_wire;
  always_comb begin : comb_gelu_exp_convert
    gelu_g2_y_wire = bf16_times_log2e(gelu_g2_neg_y);
  end
  always_ff @(posedge clk) begin
    if (!rst_n || i_clear) begin
      gelu_e1_valid <= 1'b0;
    end else begin
      gelu_e1_valid <= gelu_g2_valid;
      gelu_e1_n     <= gelu_g2_y_wire[31:21];
      gelu_e1_frac  <= gelu_g2_y_wire[20:14];
    end
  end

  // e2: exp(-y) → BF16 (same math as EXP s3)
  logic        gelu_e2_valid;
  logic [15:0] gelu_e2_expval;
  always_ff @(posedge clk) begin
    if (!rst_n || i_clear) begin
      gelu_e2_valid <= 1'b0;
    end else begin
      gelu_e2_valid  <= gelu_e1_valid;
      gelu_e2_expval <= exp2_to_bf16(gelu_e1_n, gelu_e1_frac);
    end
  end

  // g6: denominator 1 + exp(-y)
  logic        gelu_g6_valid;
  logic [15:0] gelu_g6_denom;
  always_ff @(posedge clk) begin
    if (!rst_n || i_clear) begin
      gelu_g6_valid <= 1'b0;
    end else begin
      gelu_g6_valid <= gelu_e2_valid;
      gelu_g6_denom <= bf16_add(Bf16One, gelu_e2_expval);
    end
  end

  // r1-r3: sigmoid = 1 / (1 + exp(-y)) via seed + one Newton-Raphson step
  logic        gelu_r1_valid;
  logic [15:0] gelu_r1_r0;
  logic [15:0] gelu_r1_x;
  always_ff @(posedge clk) begin
    if (!rst_n || i_clear) begin
      gelu_r1_valid <= 1'b0;
    end else begin
      gelu_r1_valid <= gelu_g6_valid;
      gelu_r1_x     <= gelu_g6_denom;
      gelu_r1_r0    <= recip_seed(gelu_g6_denom);
    end
  end

  logic        gelu_r2_valid;
  logic [15:0] gelu_r2_xr0;
  logic [15:0] gelu_r2_r0;
  always_ff @(posedge clk) begin
    if (!rst_n || i_clear) begin
      gelu_r2_valid <= 1'b0;
    end else begin
      gelu_r2_valid <= gelu_r1_valid;
      gelu_r2_xr0   <= bf16_mul(gelu_r1_x, gelu_r1_r0);
      gelu_r2_r0    <= gelu_r1_r0;
    end
  end

  logic        gelu_r3_valid;
  logic [15:0] gelu_r3_sigmoid;
  always_ff @(posedge clk) begin
    if (!rst_n || i_clear) begin
      gelu_r3_valid <= 1'b0;
    end else begin
      gelu_r3_valid   <= gelu_r2_valid;
      gelu_r3_sigmoid <= bf16_mul(gelu_r2_r0, bf16_add(Bf16Two, {1'b1, gelu_r2_xr0[14:0]}));
    end
  end

  // Final stage: x * sigmoid = GELU(x), with x taken from the aligned delay tap
  logic        gelu_out_valid;
  logic [15:0] gelu_out;
  always_ff @(posedge clk) begin
    if (!rst_n || i_clear) begin
      gelu_out_valid <= 1'b0;
    end else begin
      gelu_out_valid <= gelu_r3_valid;
      gelu_out       <= bf16_mul(gelu_x_pipe[GeluDelay-1], gelu_r3_sigmoid);
    end
  end

  // ===| Output Mux |============================================================
  // Input is always accepted; the output is selected by the current IN_func.
  always_comb begin
    OUT_data_ready   = 1'b1;
    OUT_result       = 16'd0;
    OUT_result_valid = 1'b0;
    case (IN_func)
      CVO_EXP: begin
        OUT_result       = exp_s3_result;
        OUT_result_valid = exp_s3_valid;
      end
      CVO_SQRT: begin
        OUT_result       = sqrt_s3_result;
        OUT_result_valid = sqrt_s3_valid;
      end
      CVO_GELU: begin
        OUT_result       = gelu_out;
        OUT_result_valid = gelu_out_valid;
      end
      CVO_RECIP: begin
        OUT_result       = recip_s4_result;
        OUT_result_valid = recip_s4_valid;
      end
      CVO_SCALE: begin
        OUT_result       = scale_s2_result;
        OUT_result_valid = scale_s2_valid;
      end
      CVO_REDUCE_SUM: begin
        OUT_result       = rsum_out;
        OUT_result_valid = rsum_out_valid;
      end
      default: ;
    endcase
  end
endmodule