CVO 코어 (SFU)

Complex Vector Operation 코어는 Transformer 가 요구하는 모든 비선형 활성화 — Softmax, GELU, RMSNorm, RoPE — 를 2 개의 μCVO-core 로 처리합니다. 각 μCVO-core 는 CORDIC 유닛 (회전 기반 sin/cos) 과 SFU (exp / sqrt / GELU / 역수용 LUT + 조각 다항식) 를 동시에 가집니다.

더 보기

pccx ISA 사양

OP_CVO 인코딩과 함수 코드 테이블 (§3.4).

모듈

  • CVO_top.sv — 최상위 오케스트레이터. 디스패처(Global_Scheduler)가 보낸 cvo_control_uop_t 를 받아 cvo_func 에 따라 CORDIC 또는 SFU 유닛으로 라우팅하고, sub_emax 플래그(각 입력에서 e_max 를 먼저 뺌)와 accm 플래그(OUT_accm 으로 메모리 서브시스템에 전달)를 처리.

  • CVO_cordic_unit.sv — RoPE 에 필요한 sin/cos 를 계산하는 14 단 파이프라인 CORDIC (회전 모드) 유닛.

  • CVO_sfu_unit.sv — exp, sqrt, GELU, 역수, 스케일(SCALE), reduce-sum 을 위한 LUT + 다항식 근사 엔진.

소스

CVO_top.sv
`timescale 1ns / 1ps
`include "GLOBAL_CONST.svh"

import isa_pkg::*;
import bf16_math_pkg::*;

// ===| CVO Top |=================================================================
// Wraps CVO_sfu_unit (EXP/SQRT/GELU/RECIP/SCALE/REDUCE_SUM) and
// CVO_cordic_unit (SIN/COS) behind a unified streaming interface.
//
// Data flow:
//   Host issues OP_CVO via AXI-Lite → Global_Scheduler produces cvo_control_uop_t
//   → CVO_top latches uop, processes IN_length BF16 elements from L2 stream,
//     writes results back via output stream.
//
// FLAG_SUB_EMAX: subtract IN_e_max from each input before the function.
//   Implements exp(x - e_max) for numerically stable softmax.
// FLAG_ACCM: accumulate output into dst (add OUT_result to prior value).
//   Handled externally by the mem subsystem; CVO_top only signals it via OUT_accm.
//
// FSM states:
//   IDLE    : waiting for valid uop; a zero-length uop goes straight to DONE
//   RUNNING : streaming IN_length elements through the chosen unit
//   DRAIN   : all inputs accepted; wait DrainCycles clocks so results still in
//             the sub-unit pipelines are emitted before the uop retires (and
//             before a new uop can change the output routing)
//   DONE    : pulse OUT_done for one cycle, return to IDLE
//
// NOTE(review): the sub-unit pipelines have no stall input. OUT_result_valid is
//   masked by IN_result_ready, so a result produced while the consumer is not
//   ready is dropped, not held. Lossless backpressure would need a skid FIFO at
//   least as deep as the longest sub-unit pipeline.
// ===============================================================================

module CVO_top (
    input  logic        clk,
    input  logic        rst_n,
    input  logic        i_clear,

    // ===| Dispatch from Global_Scheduler |=====================================
    input  cvo_control_uop_t IN_uop,
    input  logic             IN_uop_valid,
    output logic             OUT_uop_ready,

    // ===| BF16 Input Stream (from L2 via mem_dispatcher) |=====================
    input  logic [15:0]  IN_data,
    input  logic         IN_data_valid,
    output logic         OUT_data_ready,

    // ===| BF16 Output Stream (to L2 via mem_dispatcher) |=====================
    output logic [15:0]  OUT_result,
    output logic         OUT_result_valid,
    input  logic         IN_result_ready,

    // ===| e_max for FLAG_SUB_EMAX |============================================
    // Passed in as BF16; CVO subtracts this from each element before the function.
    input  logic [15:0]  IN_e_max,

    // ===| Status |=============================================================
    output logic         OUT_busy,
    output logic         OUT_done,
    output logic         OUT_accm   // mirrors the latched uop's flags.accm to mem subsystem
);

  // ===| FSM |===================================================================
  typedef enum logic [1:0] {
    ST_IDLE    = 2'b00,
    ST_RUNNING = 2'b01,
    ST_DONE    = 2'b10,
    ST_DRAIN   = 2'b11
  } cvo_state_e;

  cvo_state_e state;

  // Clocks spent in ST_DRAIN after the last input handshake. Must cover the
  // longest sub-unit latency: CVO_cordic_unit is 16 clocks, and every
  // CVO_sfu_unit path is shorter.
  localparam int unsigned DrainCycles = 16;

  logic [4:0] drain_count;

  // ===| Latched UOP |===========================================================
  cvo_func_e   uop_func;
  cvo_flags_t  uop_flags;
  logic [15:0] uop_length;
  logic [15:0] elem_count;   // elements consumed in the current operation

  // ===| Opcode Routing |========================================================
  // Declared ahead of the sub-unit instances that use it: SystemVerilog requires
  // an identifier to be declared before it is referenced.
  logic is_cordic_op_wire;
  always_comb begin
    is_cordic_op_wire = (uop_func == CVO_SIN) || (uop_func == CVO_COS);
  end

  // ===| BF16 subtract e_max (combinational) |===================================
  // Implements x - e_max in BF16 via bf16_add(x, -e_max) from bf16_math_pkg.

  logic [15:0] sub_emax_result_wire;

  always_comb begin : comb_sub_emax
    // Negate e_max by flipping sign bit, then add to x
    sub_emax_result_wire = bf16_add(IN_data, {~IN_e_max[15], IN_e_max[14:0]});
  end

  // ===| Input to sub-units (after optional e_max subtraction) |=================
  logic [15:0] data_to_unit_wire;
  logic        data_valid_to_unit_wire;

  always_comb begin
    data_to_unit_wire       = uop_flags.sub_emax ? sub_emax_result_wire : IN_data;
    data_valid_to_unit_wire = (state == ST_RUNNING) && IN_data_valid;
  end

  // ===| SFU Instantiation |=====================================================
  logic [15:0] sfu_result;
  logic        sfu_result_valid;
  logic        sfu_ready;

  CVO_sfu_unit u_CVO_sfu_unit (
      .clk             (clk),
      .rst_n           (rst_n),
      .i_clear         (i_clear),

      .IN_func         (uop_func),
      .IN_length       (uop_length),
      .IN_flags        (uop_flags),

      .IN_data         (data_to_unit_wire),
      .IN_valid        (data_valid_to_unit_wire && !is_cordic_op_wire),
      .OUT_data_ready  (sfu_ready),

      .OUT_result      (sfu_result),
      .OUT_result_valid(sfu_result_valid)
  );

  // ===| CORDIC Instantiation |==================================================
  logic [15:0] cordic_sin;
  logic [15:0] cordic_cos;
  logic        cordic_valid;

  CVO_cordic_unit u_CVO_cordic_unit (
      .clk          (clk),
      .rst_n        (rst_n),

      .IN_angle_bf16(data_to_unit_wire),
      .IN_valid     (data_valid_to_unit_wire && is_cordic_op_wire),

      .OUT_sin_bf16 (cordic_sin),
      .OUT_cos_bf16 (cordic_cos),
      .OUT_valid    (cordic_valid)
  );

  // ===| FSM Logic |=============================================================
  always_ff @(posedge clk) begin
    if (!rst_n || i_clear) begin
      state       <= ST_IDLE;
      uop_func    <= CVO_EXP;
      uop_flags   <= '0;
      uop_length  <= 16'd0;
      elem_count  <= 16'd0;
      drain_count <= '0;
      OUT_done    <= 1'b0;
    end else begin
      OUT_done <= 1'b0;

      case (state)
        // ===| IDLE: wait for dispatch |===
        ST_IDLE: begin
          if (IN_uop_valid) begin
            uop_func   <= IN_uop.cvo_func;
            uop_flags  <= IN_uop.flags;
            uop_length <= IN_uop.length;
            elem_count <= 16'd0;
            // A zero-length uop has nothing to stream. Without this check the
            // "length - 1" comparison in RUNNING wraps to 0xFFFF and the FSM
            // would swallow 65536 elements.
            if (IN_uop.length == 16'd0) state <= ST_DONE;
            else                        state <= ST_RUNNING;
          end
        end

        // ===| RUNNING: count consumed elements |===
        ST_RUNNING: begin
          if (IN_data_valid && OUT_data_ready) begin
            elem_count <= elem_count + 16'd1;
            if (elem_count == uop_length - 16'd1) begin
              drain_count <= '0;
              state       <= ST_DRAIN;
            end
          end
        end

        // ===| DRAIN: let in-flight results leave the sub-unit pipelines |===
        // uop_func stays latched here, so the output mux keeps routing results
        // of this uop until they have all been registered on OUT_result.
        ST_DRAIN: begin
          if (drain_count == 5'(DrainCycles - 1)) state <= ST_DONE;
          else                                    drain_count <= drain_count + 5'd1;
        end

        // ===| DONE: pulse and return |===
        ST_DONE: begin
          OUT_done <= 1'b1;
          state    <= ST_IDLE;
        end

        default: state <= ST_IDLE;
      endcase
    end
  end

  // ===| Output Mux |============================================================
  // CORDIC outputs two results per input; select sin or cos based on function.
  logic [15:0] result_mux_wire;
  logic        result_valid_mux_wire;

  always_comb begin
    if (is_cordic_op_wire) begin
      result_mux_wire       = (uop_func == CVO_SIN) ? cordic_sin : cordic_cos;
      result_valid_mux_wire = cordic_valid;
    end else begin
      result_mux_wire       = sfu_result;
      result_valid_mux_wire = sfu_result_valid;
    end
  end

  // ===| Output Registers |======================================================
  // NOTE(review): valid is masked by IN_result_ready (see header note): a result
  // arriving while the consumer is not ready is discarded.
  always_ff @(posedge clk) begin
    if (!rst_n || i_clear) begin
      OUT_result       <= 16'd0;
      OUT_result_valid <= 1'b0;
    end else begin
      OUT_result       <= result_mux_wire;
      OUT_result_valid <= result_valid_mux_wire && IN_result_ready;
    end
  end

  // ===| Status & Control |======================================================
  assign OUT_busy       = (state != ST_IDLE);   // includes DRAIN
  assign OUT_uop_ready  = (state == ST_IDLE);
  assign OUT_data_ready = sfu_ready && (state == ST_RUNNING);
  assign OUT_accm       = uop_flags.accm;

endmodule
CVO_cordic_unit.sv
`timescale 1ns / 1ps
`include "GLOBAL_CONST.svh"

import bf16_math_pkg::*;

// ===| CVO CORDIC Unit |=========================================================
// Computes sin(θ) and cos(θ) for a BF16 input angle (radians) using a 14-stage
// pipelined CORDIC algorithm (rotation mode).
//
// Internal format : Q4.12 signed fixed-point (16-bit)
//   1.0 = 0x1000 = 4096,  π ≈ 0x3244 = 12868
//
// Input range: |θ| < 8 rad (the Q4.12 limit). Larger magnitudes saturate to
//   ±7.9998 rad and therefore give wrong results. Stage 0 folds the angle into
//   [-π/2, π/2], because rotation-mode CORDIC only converges for
//   |θ| ≤ Σ atan(2^-i) ≈ 1.74 rad. A fold by ±π is undone at the output by
//   negating both sin and cos.
//
// CORDIC gain K (14 iterations) ≈ 0.60725, pre-baked into x0 = K * 4096 = 0x09B8.
// Pipeline latency : 16 clocks  (1 convert-in + 14 CORDIC + 1 convert-out)
// ===============================================================================

module CVO_cordic_unit (
    input  logic        clk,
    input  logic        rst_n,

    // ===| Input |===============================================================
    input  logic [15:0] IN_angle_bf16,    // BF16 angle in radians
    input  logic        IN_valid,

    // ===| Output |==============================================================
    output logic [15:0] OUT_sin_bf16,
    output logic [15:0] OUT_cos_bf16,
    output logic        OUT_valid
);

  // ===| CORDIC Constants (Q4.12) |==============================================
  // atan(2^-i) * 4096, i = 0..13
  localparam logic signed [15:0] AtanLut [14] = '{
    16'sh0C91,  // atan(2^0 ) = π/4     ≈ 0.7854
    16'sh076B,  // atan(2^-1)           ≈ 0.4636
    16'sh03EB,  // atan(2^-2)           ≈ 0.2450
    16'sh01FD,  // atan(2^-3)           ≈ 0.1244
    16'sh0100,  // atan(2^-4)           ≈ 0.0624
    16'sh0080,  // atan(2^-5)           ≈ 0.0312
    16'sh0040,  // atan(2^-6)           ≈ 0.0156
    16'sh0020,  // atan(2^-7)           ≈ 0.0078
    16'sh0010,  // atan(2^-8)           ≈ 0.0039
    16'sh0008,  // atan(2^-9)           ≈ 0.0020
    16'sh0004,  // atan(2^-10)          ≈ 0.0010
    16'sh0002,  // atan(2^-11)          ≈ 0.0005
    16'sh0001,  // atan(2^-12)          ≈ 0.0002
    16'sh0001   // atan(2^-13)          ≈ 0.0001
  };

  // CORDIC gain K pre-scaled: K * 4096 = 0x09B8
  localparam logic signed [15:0] CordicKQ412 = 16'sh09B8;

  // Range-reduction constants (Q4.12), held in 17-bit signed so that a ±2π step
  // applied to any 16-bit Q4.12 angle cannot overflow.
  localparam logic signed [16:0] PiQ412     = 17'sd12868;  // π   * 4096
  localparam logic signed [16:0] HalfPiQ412 = 17'sd6434;   // π/2 * 4096
  localparam logic signed [16:0] TwoPiQ412  = 17'sd25736;  // 2π  * 4096

  // ===| Stage 0: BF16 → Q4.12 Conversion + Range Reduction |==================
  logic signed [15:0] s0_angle_fixed;   // reduced angle, within [-π/2, π/2]
  logic               s0_negate;        // 1 → negate sin and cos at the output
  logic               s0_valid;

  always_comb begin : bf16_to_q412
    automatic logic               sign_bit;
    automatic logic        [ 7:0] exp_raw;
    automatic logic        [ 6:0] mant_raw;
    automatic logic        [15:0] magnitude;
    automatic logic signed [16:0] theta;
    automatic int                 shift_amt;

    sign_bit = IN_angle_bf16[15];
    exp_raw  = IN_angle_bf16[14:7];
    mant_raw = IN_angle_bf16[ 6:0];

    // value = 1.mant * 2^(exp-127); its Q4.12 code is value * 2^12.
    // {1,mant} is the 8-bit integer 1.mant * 2^7, so the Q4.12 code is
    //   {1,mant} * 2^(exp - 127 - 7 + 12) = {1,mant} * 2^(exp - 122).
    shift_amt = int'(exp_raw) - 122;

    if (exp_raw == 8'd0) begin
      magnitude = 16'd0;                               // zero / denormal → 0
    end else if (shift_amt >= 8) begin
      magnitude = 16'h7FFF;                            // |θ| >= 8 → saturate
    end else if (shift_amt >= 0) begin
      magnitude = 16'({1'b1, mant_raw}) << shift_amt;  // max 255<<7 = 0x7F80
    end else if (shift_amt >= -7) begin
      magnitude = 16'({1'b1, mant_raw}) >> (-shift_amt);
    end else begin
      magnitude = 16'd0;                               // below Q4.12 resolution
    end

    theta = sign_bit ? -$signed({1'b0, magnitude}) : $signed({1'b0, magnitude});

    // Fold into [-π, π]. |θ| < 8 < 3π, so a single 2π step is always enough.
    if (theta > PiQ412)            theta = theta - TwoPiQ412;
    else if (theta < -PiQ412)      theta = theta + TwoPiQ412;

    // Fold into [-π/2, π/2]; sin(θ ± π) = -sin θ and cos(θ ± π) = -cos θ.
    s0_negate = 1'b0;
    if (theta > HalfPiQ412) begin
      theta     = theta - PiQ412;
      s0_negate = 1'b1;
    end else if (theta < -HalfPiQ412) begin
      theta     = theta + PiQ412;
      s0_negate = 1'b1;
    end

    s0_angle_fixed = theta[15:0];   // |θ| <= π/2 fits 16-bit signed Q4.12
  end

  // Register stage 0
  logic signed [15:0] s0_angle_ff;
  logic               s0_negate_ff;
  always_ff @(posedge clk) begin
    if (!rst_n) begin
      s0_angle_ff  <= 16'sh0;
      s0_negate_ff <= 1'b0;
      s0_valid     <= 1'b0;
    end else begin
      s0_angle_ff  <= s0_angle_fixed;
      s0_negate_ff <= s0_negate;
      s0_valid     <= IN_valid;
    end
  end

  // ===| Stages 1-14: CORDIC Iterations |========================================
  // Each stage i: x_{i} → x_{i+1}, y_{i} → y_{i+1}, z_{i} → z_{i+1}
  // d_i = sign(z_i)
  // x_{i+1} = x_i - d_i * (y_i >>> i)
  // y_{i+1} = y_i + d_i * (x_i >>> i)
  // z_{i+1} = z_i - d_i * ATAN_LUT[i]

  logic signed [15:0] cx [0:14];  // CORDIC x pipeline
  logic signed [15:0] cy [0:14];  // CORDIC y pipeline
  logic signed [15:0] cz [0:14];  // CORDIC z pipeline
  logic               cv [0:14];  // valid pipeline
  logic               cn [0:14];  // output-negate flag pipeline (from range fold)

  // Initialize iteration 0 from stage-0 output
  assign cx[0] = CordicKQ412;
  assign cy[0] = 16'sh0;
  assign cz[0] = s0_angle_ff;
  assign cv[0] = s0_valid;
  assign cn[0] = s0_negate_ff;

  genvar gi;
  generate
    for (gi = 0; gi < 14; gi++) begin : gen_cordic_iter
      always_ff @(posedge clk) begin
        if (!rst_n) begin
          cx[gi+1] <= 16'sh0;
          cy[gi+1] <= 16'sh0;
          cz[gi+1] <= 16'sh0;
          cv[gi+1] <= 1'b0;
          cn[gi+1] <= 1'b0;
        end else begin
          cv[gi+1] <= cv[gi];
          cn[gi+1] <= cn[gi];
          if (cz[gi] >= 0) begin
            // d_i = +1: rotate counter-clockwise
            cx[gi+1] <= cx[gi] - (cy[gi] >>> gi);
            cy[gi+1] <= cy[gi] + (cx[gi] >>> gi);
            cz[gi+1] <= cz[gi] - AtanLut[gi];
          end else begin
            // d_i = -1: rotate clockwise
            cx[gi+1] <= cx[gi] + (cy[gi] >>> gi);
            cy[gi+1] <= cy[gi] - (cx[gi] >>> gi);
            cz[gi+1] <= cz[gi] + AtanLut[gi];
          end
        end
      end
    end
  endgenerate

  // ===| Stage 15: Q4.12 → BF16 Conversion |=====================================
  // cos result = cx[14], sin result = cy[14] (negated when the angle was folded)

  // Converts a signed Q4.12 value to BF16 with a truncated mantissa.
  function automatic logic [15:0] q412_to_bf16(input logic signed [15:0] val);
    logic        sign_out;
    logic [14:0] mag;
    logic [14:0] norm;
    logic [ 3:0] leading;
    logic [ 7:0] exp_out;

    sign_out = val[15];
    mag      = sign_out ? 15'(-val) : 15'(val);

    if (mag == 15'd0) return 16'd0;

    // Position of the most significant 1 (0..14); the highest match wins.
    leading = 4'd0;
    for (int b = 0; b < 15; b++) begin
      if (mag[b]) leading = 4'(b);
    end

    // value = mag * 2^-12 = 1.xxx * 2^(leading - 12)  →  biased exp = 115 + leading
    exp_out = 8'd115 + 8'(leading);

    // Left-justify so the leading 1 sits in bit 14; the next 7 bits are the
    // mantissa. A shift replaces the former variable-width part-select, which
    // is not legal SystemVerilog (and was out of range for leading == 0).
    norm = mag << (4'd14 - leading);
    return {sign_out, exp_out, norm[13:7]};
  endfunction

  always_ff @(posedge clk) begin
    if (!rst_n) begin
      OUT_cos_bf16 <= 16'd0;
      OUT_sin_bf16 <= 16'd0;
      OUT_valid    <= 1'b0;
    end else begin
      OUT_valid    <= cv[14];
      // |cx|, |cy| stay near 1.0 (4096), so the fixed-point negation is safe.
      OUT_cos_bf16 <= q412_to_bf16(cn[14] ? -cx[14] : cx[14]);
      OUT_sin_bf16 <= q412_to_bf16(cn[14] ? -cy[14] : cy[14]);
    end
  end

endmodule
CVO_sfu_unit.sv
`timescale 1ns / 1ps
`include "GLOBAL_CONST.svh"

import isa_pkg::*;
import bf16_math_pkg::*;

// ===| CVO Special Function Unit |===============================================
// Streaming BF16 SFU for: EXP, SQRT, GELU, RECIP, SCALE, REDUCE_SUM.
// One element per cycle throughput once the pipeline is filled.
//
// Pipeline latencies (register stages from IN_valid to OUT_result_valid):
//   EXP        :  3 cycles
//   SQRT       :  3 cycles
//   RECIP      :  4 cycles
//   SCALE      :  2 cycles  (first input word of each op = scalar, rest = data)
//   REDUCE_SUM :  accumulates IN_length elements, emits the scalar sum
//                 1 cycle after the last one
//   GELU       :  9 cycles  (MUL, NEG, EXP[2], ADD, RECIP[3], MUL)
//
// Arithmetic: results are truncated (round toward zero). In the BF16 helpers,
// exponent underflow flushes to signed zero and overflow gives signed infinity.
//
// BF16 raw format: {sign[15], exp[14:7], mant[6:0]}
// ===============================================================================

module CVO_sfu_unit (
    input logic clk,
    input logic rst_n,
    input logic i_clear,

    // ===| Operation Select |====================================================
    input cvo_func_e IN_func,
    input logic [15:0] IN_length,
    input cvo_flags_t IN_flags,

    // ===| Streaming Input |=====================================================
    input  logic [15:0] IN_data,
    input  logic        IN_valid,
    output logic        OUT_data_ready,

    // ===| Streaming Output |====================================================
    output logic [15:0] OUT_result,
    output logic        OUT_result_valid
);

  // ===| BF16 Constants |========================================================
  localparam logic [15:0] Bf16One       = 16'h3F80;   // 1.0
  localparam logic [15:0] Bf16Two       = 16'h4000;   // 2.0
  localparam logic [15:0] Bf16PosInf    = 16'h7F80;   // +inf
  localparam logic [15:0] Bf16Scale1702 = 16'h3FD9;   // 1.702 (GELU sigmoid scale)
  // log2(e) ≈ 1.442695 as unsigned Q2.14 (23637 / 2^14 ≈ 1.442688). A Q1.7
  // constant (1.4375) would put a ~0.36 % error on every computed exponent.
  localparam logic [15:0] Log2EQ214     = 16'd23637;

  // ===| BF16 Arithmetic (combinational) |=======================================

  // ===| BF16 Multiply |===
  // Truncating product. A zero operand gives +0; exponent underflow gives signed
  // zero and overflow gives signed infinity instead of a wrapped exponent.
  function automatic logic [15:0] bf16_mul(input logic [15:0] a, input logic [15:0] b);
    logic        s;
    logic [15:0] mp;
    logic [ 6:0] mr;
    int          e;
    if (a[14:0] == 0 || b[14:0] == 0) return 16'd0;
    s  = a[15] ^ b[15];
    mp = {1'b1, a[6:0]} * {1'b1, b[6:0]};   // (1.ma)(1.mb) * 2^14, in [2^14, 2^16)
    if (mp[15]) begin
      e  = int'(a[14:7]) + int'(b[14:7]) - 126;   // product in [2,4): renormalize
      mr = mp[14:8];
    end else begin
      e  = int'(a[14:7]) + int'(b[14:7]) - 127;
      mr = mp[13:7];
    end
    if (e <= 0)   return {s, 15'd0};
    if (e >= 255) return {s, 8'hFF, 7'd0};
    return {s, 8'(e), mr};
  endfunction

  // ===| BF16 Add |===
  // Aligns to the larger exponent with 3 guard bits, adds/subtracts magnitudes,
  // then renormalizes from the actual leading 1 (handles carry-out and any
  // amount of cancellation). An operand with exponent 0xFF is returned as-is.
  function automatic logic [15:0] bf16_add(input logic [15:0] a, input logic [15:0] b);
    logic [ 7:0] ea, eb, e_big, diff;
    logic [10:0] ma, mb;     // {hidden 1, mant[6:0], 3 guard bits}
    logic [11:0] sum;        // one extra bit for the carry-out
    logic [11:0] norm;
    logic        s;
    int          lead;
    int          e;
    if (a[14:0] == 0) return b;
    if (b[14:0] == 0) return a;
    if (a[14:7] == 8'hFF) return a;
    if (b[14:7] == 8'hFF) return b;
    ea = a[14:7];
    eb = b[14:7];
    ma = {1'b1, a[6:0], 3'b000};
    mb = {1'b1, b[6:0], 3'b000};
    if (ea >= eb) begin
      e_big = ea;
      diff  = ea - eb;
      mb    = mb >> diff;
    end else begin
      e_big = eb;
      diff  = eb - ea;
      ma    = ma >> diff;
    end
    if (a[15] == b[15]) begin
      s   = a[15];
      sum = {1'b0, ma} + {1'b0, mb};
    end else if (ma >= mb) begin
      s   = a[15];
      sum = {1'b0, ma} - {1'b0, mb};
    end else begin
      s   = b[15];
      sum = {1'b0, mb} - {1'b0, ma};
    end
    if (sum == 12'd0) return 16'd0;
    // Most significant set bit of sum (bit 10 = aligned hidden-1 position).
    lead = 0;
    for (int i = 0; i < 12; i++) begin
      if (sum[i]) lead = i;
    end
    e = int'(e_big) + lead - 10;
    if (e <= 0)   return {s, 15'd0};
    if (e >= 255) return {s, 8'hFF, 7'd0};
    norm = (lead >= 10) ? (sum >> (lead - 10)) : (sum << (10 - lead));
    return {s, 8'(e), norm[9:3]};
  endfunction

  // ===| EXP helpers |===
  // 7-bit mantissa of 2^(k/128), k = 0..127, i.e. 128 * (2^f - 1) with f = k/128.
  // Quadratic fit 0.65625 f + 0.34375 f^2 (exact at f = 0, 0.5, 1); the fit error
  // is under 0.5 LSB before truncation.
  function automatic logic [6:0] exp_mant_lut(input logic [6:0] k);
    logic [31:0] v;
    v = (32'd84 * k + ((32'd44 * k * k) >> 7)) >> 7;
    return v[6:0];
  endfunction

  // BF16 → signed Q8.7. {1,mant} is 1.mant * 2^7, i.e. the Q8.7 code of 1.mant,
  // so shifting it by the unbiased exponent gives the Q8.7 code of |x|.
  // |x| >= 256 saturates to 0x7FFF; |x| < 2^-7 becomes 0.
  function automatic logic signed [15:0] bf16_to_q8_7(input logic [15:0] x);
    logic [15:0] mag;
    int          sh;
    sh = int'(x[14:7]) - 127;
    if (x[14:7] == 8'd0) mag = 16'd0;
    else if (sh >= 8)    mag = 16'h7FFF;
    else if (sh >= 0)    mag = 16'({1'b1, x[6:0]}) << sh;
    else if (sh >= -7)   mag = 16'({1'b1, x[6:0]}) >> (-sh);
    else                 mag = 16'd0;
    return x[15] ? -$signed(mag) : $signed(mag);
  endfunction

  // Builds 2^(n + frac/128) as BF16, where n = floor(x * log2 e) and frac is the
  // 7-bit fractional index. Out-of-range exponents give +inf or 0.
  function automatic logic [15:0] exp2_to_bf16(input logic signed [10:0] n,
                                               input logic        [ 6:0] frac);
    int e;
    e = int'(n) + 127;
    if (e >= 255) return Bf16PosInf;
    if (e <= 0)   return 16'd0;
    return {1'b0, 8'(e), exp_mant_lut(frac)};
  endfunction

  // ===| SQRT mantissa LUTs |===
  function automatic logic [6:0] sqrt_mant_even(input logic [6:0] k);
    // 128 * (sqrt(1 + f) - 1), f = k/128 ≈ 0.484375 k - 0.0703125 k^2/128 → 0..52
    logic [31:0] v;
    v = (32'd62 * k - ((32'd9 * k * k) >> 7)) >> 7;
    return v[6:0];
  endfunction

  function automatic logic [6:0] sqrt_mant_odd(input logic [6:0] k);
    // 128 * (sqrt(2 * (1 + f)) - 1) ≈ 53 + 0.6875 k - 0.1015625 k^2/128 → 53..127
    logic [31:0] v;
    v = 32'd53 + ((32'd88 * k - ((32'd13 * k * k) >> 7)) >> 7);
    return v[6:0];
  endfunction

  // ===| RECIP helpers |===
  // Linear seed for 1/|x|: with |x| = 1.f * 2^E, use (0.964 - f/2) * 2^-E
  // (relative error within about ±7.2 %, so one Newton-Raphson step converges).
  // Encoded as exponent 253 - exp, mantissa 119 - mant (clamped at 0).
  // Zero maps to +inf; exp >= 253 (incl. inf) maps to 0 (result would be subnormal).
  function automatic logic [15:0] bf16_recip_seed(input logic [15:0] x);
    logic [7:0] e;
    logic [6:0] m;
    e = x[14:7];
    m = x[6:0];
    if (e == 8'd0)   return Bf16PosInf;
    if (e >= 8'd253) return 16'd0;
    return {1'b0, 8'(8'd253 - e), (m >= 7'd119) ? 7'd0 : 7'(7'd119 - m)};
  endfunction

  // Single-cycle 1/x: seed plus one Newton-Raphson step r1 = r0 * (2 - |x|*r0).
  // NOTE(review): long combinational path (mul + add + mul); check timing.
  function automatic logic [15:0] bf16_recip(input logic [15:0] x);
    logic [15:0] r0, xr0, r1;
    r0  = bf16_recip_seed(x);
    xr0 = bf16_mul({1'b0, x[14:0]}, r0);
    r1  = bf16_mul(r0, bf16_add(Bf16Two, {1'b1, xr0[14:0]}));
    return {x[15], r1[14:0]};
  endfunction

  // ===| EXP Pipeline (3 stages) |===============================================
  // exp(x) = 2^(x * log2 e) = 2^n * 2^(frac/128)

  // Stage 1 — register operand
  logic        exp_s1_valid;
  logic [15:0] exp_s1_x;

  always_ff @(posedge clk) begin
    if (!rst_n || i_clear) begin
      exp_s1_valid <= 1'b0;
    end else begin
      exp_s1_valid <= IN_valid && (IN_func == CVO_EXP || IN_func == CVO_GELU);
      exp_s1_x     <= IN_data;
    end
  end

  // Stage 2 — BF16 → Q8.7, multiply by log2(e) (Q2.14) → Q.21 product
  logic               exp_s2_valid;
  logic signed [10:0] exp_s2_n;     // floor(x * log2e)
  logic        [ 6:0] exp_s2_frac;  // fractional 7-bit index for LUT

  logic signed [15:0] exp_s1_xfixed_wire;  // Q8.7 signed
  logic signed [31:0] exp_s1_y_wire;       // x*log2e, 21 fraction bits

  always_comb begin : comb_exp_convert
    exp_s1_xfixed_wire = bf16_to_q8_7(exp_s1_x);
    exp_s1_y_wire      = exp_s1_xfixed_wire * $signed({1'b0, Log2EQ214});
  end

  always_ff @(posedge clk) begin
    if (!rst_n || i_clear) begin
      exp_s2_valid <= 1'b0;
    end else begin
      exp_s2_valid <= exp_s1_valid;
      // Two's-complement bit slices give floor() and the non-negative remainder.
      exp_s2_n     <= exp_s1_y_wire[31:21];
      exp_s2_frac  <= exp_s1_y_wire[20:14];
    end
  end

  // Stage 3 — assemble output BF16
  logic        exp_s3_valid;
  logic [15:0] exp_s3_result;

  always_ff @(posedge clk) begin
    if (!rst_n || i_clear) begin
      exp_s3_valid <= 1'b0;
    end else begin
      exp_s3_valid  <= exp_s2_valid;
      exp_s3_result <= exp2_to_bf16(exp_s2_n, exp_s2_frac);
    end
  end

  // ===| SQRT Pipeline (3 stages) |===============================================
  // sqrt(1.m * 2^E) = 2^floor(E/2) * sqrt(1.m)      (E even)
  //                 = 2^floor(E/2) * sqrt(2 * 1.m)  (E odd)
  // The input sign is ignored (result is sqrt(|x|)).

  logic       sqrt_s1_valid;
  logic [7:0] sqrt_s1_exp;
  logic [6:0] sqrt_s1_mant;

  always_ff @(posedge clk) begin
    if (!rst_n || i_clear) begin
      sqrt_s1_valid <= 1'b0;
    end else begin
      sqrt_s1_valid <= IN_valid && (IN_func == CVO_SQRT);
      sqrt_s1_exp   <= IN_data[14:7];
      sqrt_s1_mant  <= IN_data[6:0];
    end
  end

  logic        sqrt_s2_valid;
  logic [15:0] sqrt_s2_result;

  logic [ 7:0] sqrt_s1_unbiased_wire;
  logic [ 7:0] sqrt_s1_out_exp_wire;
  logic [ 6:0] sqrt_s1_out_mant_wire;

  always_comb begin
    sqrt_s1_unbiased_wire = sqrt_s1_exp - 8'd127;
    // floor(E/2) needs an arithmetic shift so that E < 0 (x < 1) stays negative.
    sqrt_s1_out_exp_wire  = 8'd127 + {sqrt_s1_unbiased_wire[7], sqrt_s1_unbiased_wire[7:1]};
    sqrt_s1_out_mant_wire = sqrt_s1_unbiased_wire[0] ? sqrt_mant_odd(sqrt_s1_mant)
                                                     : sqrt_mant_even(sqrt_s1_mant);
  end

  always_ff @(posedge clk) begin
    if (!rst_n || i_clear) begin
      sqrt_s2_valid <= 1'b0;
    end else begin
      sqrt_s2_valid <= sqrt_s1_valid;
      if (sqrt_s1_exp == 8'd0)       sqrt_s2_result <= 16'd0;       // zero/denormal → 0
      else if (sqrt_s1_exp == 8'hFF) sqrt_s2_result <= Bf16PosInf;  // exp all-ones → +inf
      else sqrt_s2_result <= {1'b0, sqrt_s1_out_exp_wire, sqrt_s1_out_mant_wire};
    end
  end

  logic        sqrt_s3_valid;
  logic [15:0] sqrt_s3_result;

  always_ff @(posedge clk) begin
    if (!rst_n || i_clear) begin
      sqrt_s3_valid <= 1'b0;
    end else begin
      sqrt_s3_valid  <= sqrt_s2_valid;
      sqrt_s3_result <= sqrt_s2_result;
    end
  end

  // ===| RECIP Pipeline (4 stages) |=============================================
  // 1/x via 1 Newton-Raphson step: r1 = r0 * (2 - x*r0), r0 from bf16_recip_seed.

  logic        recip_s1_valid;
  logic        recip_s1_sign;
  logic [15:0] recip_s1_r0;
  logic [15:0] recip_s1_x;

  always_ff @(posedge clk) begin
    if (!rst_n || i_clear) begin
      recip_s1_valid <= 1'b0;
    end else begin
      recip_s1_valid <= IN_valid && (IN_func == CVO_RECIP);
      recip_s1_sign  <= IN_data[15];
      recip_s1_x     <= {1'b0, IN_data[14:0]};  // |x|
      recip_s1_r0    <= bf16_recip_seed(IN_data);
    end
  end

  logic        recip_s2_valid;
  logic [15:0] recip_s2_xr0;
  logic [15:0] recip_s2_r0;
  logic        recip_s2_sign;

  always_ff @(posedge clk) begin
    if (!rst_n || i_clear) begin
      recip_s2_valid <= 1'b0;
    end else begin
      recip_s2_valid <= recip_s1_valid;
      recip_s2_xr0   <= bf16_mul(recip_s1_x, recip_s1_r0);
      recip_s2_r0    <= recip_s1_r0;
      recip_s2_sign  <= recip_s1_sign;
    end
  end

  logic        recip_s3_valid;
  logic [15:0] recip_s3_corr;
  logic [15:0] recip_s3_r0;
  logic        recip_s3_sign;

  always_ff @(posedge clk) begin
    if (!rst_n || i_clear) begin
      recip_s3_valid <= 1'b0;
    end else begin
      recip_s3_valid <= recip_s2_valid;
      recip_s3_corr  <= bf16_add(Bf16Two, {1'b1, recip_s2_xr0[14:0]});  // 2 - x*r0
      recip_s3_r0    <= recip_s2_r0;
      recip_s3_sign  <= recip_s2_sign;
    end
  end

  logic        recip_s4_valid;
  logic [15:0] recip_s4_result;

  logic [15:0] recip_s3_mag_wire;
  always_comb begin
    recip_s3_mag_wire = bf16_mul(recip_s3_r0, recip_s3_corr);
  end

  always_ff @(posedge clk) begin
    if (!rst_n || i_clear) begin
      recip_s4_valid <= 1'b0;
    end else begin
      recip_s4_valid  <= recip_s3_valid;
      recip_s4_result <= {recip_s3_sign, recip_s3_mag_wire[14:0]};
    end
  end

  // ===| SCALE Pipeline (2 stages) |=============================================
  // First element of each op = scalar (or 1/scalar if FLAG_RECIP_SCALE); the
  // remaining IN_length-1 elements are multiplied by it. After the op's
  // IN_length-th element the unit re-arms so the next SCALE op loads a new scalar.

  logic        scale_scalar_loaded;
  logic [15:0] scale_scalar;
  logic [15:0] scale_count;   // elements of the current SCALE op seen so far

  logic        scale_s1_valid;
  logic [15:0] scale_s1_product;

  logic        scale_s2_valid;
  logic [15:0] scale_s2_result;

  logic [15:0] scale_scalar_next_wire;

  always_comb begin
    scale_scalar_next_wire = IN_flags.recip_scale ? bf16_recip(IN_data) : IN_data;
  end

  always_ff @(posedge clk) begin
    if (!rst_n || i_clear) begin
      scale_scalar_loaded <= 1'b0;
      scale_scalar        <= 16'd0;
      scale_count         <= 16'd0;
      scale_s1_valid      <= 1'b0;
      scale_s2_valid      <= 1'b0;
    end else begin
      if (IN_valid && IN_func == CVO_SCALE) begin
        if (!scale_scalar_loaded) begin
          scale_scalar        <= scale_scalar_next_wire;
          scale_scalar_loaded <= 1'b1;
          scale_s1_valid      <= 1'b0;
        end else begin
          scale_s1_valid   <= 1'b1;
          scale_s1_product <= bf16_mul(IN_data, scale_scalar);
        end
        // Last element of this op (scalar included): re-arm. Written after the
        // load above so that, for IN_length == 1, the clear takes priority.
        if (scale_count == IN_length - 16'd1) begin
          scale_count         <= 16'd0;
          scale_scalar_loaded <= 1'b0;
        end else begin
          scale_count <= scale_count + 16'd1;
        end
      end else begin
        if (IN_func != CVO_SCALE) begin
          scale_scalar_loaded <= 1'b0;
          scale_count         <= 16'd0;
        end
        scale_s1_valid <= 1'b0;
      end

      scale_s2_valid  <= scale_s1_valid;
      scale_s2_result <= scale_s1_product;
    end
  end

  // ===| REDUCE_SUM (sequential BF16 accumulation) |=============================

  logic [15:0] rsum_count;
  logic [15:0] rsum_accum;
  logic        rsum_out_valid;
  logic [15:0] rsum_out;

  always_ff @(posedge clk) begin
    if (!rst_n || i_clear) begin
      rsum_count     <= 16'd0;
      rsum_accum     <= 16'd0;
      rsum_out_valid <= 1'b0;
      rsum_out       <= 16'd0;
    end else begin
      rsum_out_valid <= 1'b0;
      if (IN_valid && IN_func == CVO_REDUCE_SUM) begin
        rsum_count <= rsum_count + 16'd1;
        rsum_accum <= bf16_add(rsum_accum, IN_data);
        if (rsum_count == IN_length - 16'd1) begin
          rsum_out       <= bf16_add(rsum_accum, IN_data);
          rsum_out_valid <= 1'b1;
          rsum_count     <= 16'd0;
          rsum_accum     <= 16'd0;
        end
      end else if (IN_func != CVO_REDUCE_SUM) begin
        rsum_count <= 16'd0;
        rsum_accum <= 16'd0;
      end
    end
  end

  // ===| GELU Pipeline (9 stages) |==============================================
  // GELU(x) ≈ x * sigmoid(1.702*x),  sigmoid(y) = 1/(1+exp(-y))
  // Chain: MUL(1.702)[1] → NEG[1] → EXP[2] → ADD(1)[1] → RECIP[3] → MUL(x)[1]
  // sigmoid is ready after 8 register stages, so x is delayed GeluDelay = 8
  // clocks and gelu_x_pipe[GeluDelay-1] pairs with the sigmoid of the same
  // element at the final multiply.

  localparam int GeluDelay = 8;

  logic [15:0] gelu_x_pipe   [GeluDelay];

  // Stage g1: 1.702 * x
  logic        gelu_g1_valid;
  logic [15:0] gelu_g1_y;

  always_ff @(posedge clk) begin
    if (!rst_n || i_clear) begin
      gelu_g1_valid <= 1'b0;
      for (int d = 0; d < GeluDelay; d++) gelu_x_pipe[d] <= 16'd0;
    end else begin
      gelu_x_pipe[0] <= IN_data;
      for (int d = 1; d < GeluDelay; d++) gelu_x_pipe[d] <= gelu_x_pipe[d-1];
      gelu_g1_valid <= IN_valid && (IN_func == CVO_GELU);
      gelu_g1_y     <= bf16_mul(IN_data, Bf16Scale1702);
    end
  end

  // Stage g2: negate → -1.702x
  logic        gelu_g2_valid;
  logic [15:0] gelu_g2_neg_y;

  always_ff @(posedge clk) begin
    if (!rst_n || i_clear) begin
      gelu_g2_valid <= 1'b0;
    end else begin
      gelu_g2_valid <= gelu_g1_valid;
      gelu_g2_neg_y <= {~gelu_g1_y[15], gelu_g1_y[14:0]};
    end
  end

  // Stages g3-g4: EXP(-y), same helpers as the EXP pipeline
  logic               gelu_e1_valid;
  logic signed [10:0] gelu_e1_n;
  logic        [ 6:0] gelu_e1_frac;

  logic signed [15:0] gelu_g2_xf_wire;
  logic signed [31:0] gelu_g2_y_wire;

  always_comb begin : comb_gelu_exp_convert
    gelu_g2_xf_wire = bf16_to_q8_7(gelu_g2_neg_y);
    gelu_g2_y_wire  = gelu_g2_xf_wire * $signed({1'b0, Log2EQ214});
  end

  always_ff @(posedge clk) begin
    if (!rst_n || i_clear) begin
      gelu_e1_valid <= 1'b0;
    end else begin
      gelu_e1_valid <= gelu_g2_valid;
      gelu_e1_n     <= gelu_g2_y_wire[31:21];
      gelu_e1_frac  <= gelu_g2_y_wire[20:14];
    end
  end

  logic        gelu_e2_valid;
  logic [15:0] gelu_e2_expval;

  always_ff @(posedge clk) begin
    if (!rst_n || i_clear) begin
      gelu_e2_valid <= 1'b0;
    end else begin
      gelu_e2_valid  <= gelu_e1_valid;
      gelu_e2_expval <= exp2_to_bf16(gelu_e1_n, gelu_e1_frac);
    end
  end

  // Stage g5: 1 + exp(-y)
  logic        gelu_g6_valid;
  logic [15:0] gelu_g6_denom;

  always_ff @(posedge clk) begin
    if (!rst_n || i_clear) begin
      gelu_g6_valid <= 1'b0;
    end else begin
      gelu_g6_valid <= gelu_e2_valid;
      gelu_g6_denom <= bf16_add(Bf16One, gelu_e2_expval);
    end
  end

  // Stages g6-g8: RECIP(1 + exp(-y)) = sigmoid
  logic        gelu_r1_valid;
  logic [15:0] gelu_r1_r0;
  logic [15:0] gelu_r1_x;

  always_ff @(posedge clk) begin
    if (!rst_n || i_clear) begin
      gelu_r1_valid <= 1'b0;
    end else begin
      gelu_r1_valid <= gelu_g6_valid;
      gelu_r1_x     <= gelu_g6_denom;
      gelu_r1_r0    <= bf16_recip_seed(gelu_g6_denom);
    end
  end

  logic        gelu_r2_valid;
  logic [15:0] gelu_r2_xr0;
  logic [15:0] gelu_r2_r0;

  always_ff @(posedge clk) begin
    if (!rst_n || i_clear) begin
      gelu_r2_valid <= 1'b0;
    end else begin
      gelu_r2_valid <= gelu_r1_valid;
      gelu_r2_xr0   <= bf16_mul(gelu_r1_x, gelu_r1_r0);
      gelu_r2_r0    <= gelu_r1_r0;
    end
  end

  logic        gelu_r3_valid;
  logic [15:0] gelu_r3_sigmoid;

  always_ff @(posedge clk) begin
    if (!rst_n || i_clear) begin
      gelu_r3_valid <= 1'b0;
    end else begin
      gelu_r3_valid   <= gelu_r2_valid;
      gelu_r3_sigmoid <= bf16_mul(gelu_r2_r0, bf16_add(Bf16Two, {1'b1, gelu_r2_xr0[14:0]}));
    end
  end

  // Stage g9: x * sigmoid = GELU(x)
  logic        gelu_out_valid;
  logic [15:0] gelu_out;

  always_ff @(posedge clk) begin
    if (!rst_n || i_clear) begin
      gelu_out_valid <= 1'b0;
    end else begin
      gelu_out_valid <= gelu_r3_valid;
      gelu_out       <= bf16_mul(gelu_x_pipe[GeluDelay-1], gelu_r3_sigmoid);
    end
  end

  // ===| Output Mux |============================================================
  always_comb begin
    OUT_data_ready   = 1'b1;
    OUT_result       = 16'd0;
    OUT_result_valid = 1'b0;
    case (IN_func)
      CVO_EXP: begin
        OUT_result = exp_s3_result;
        OUT_result_valid = exp_s3_valid;
      end
      CVO_SQRT: begin
        OUT_result = sqrt_s3_result;
        OUT_result_valid = sqrt_s3_valid;
      end
      CVO_GELU: begin
        OUT_result = gelu_out;
        OUT_result_valid = gelu_out_valid;
      end
      CVO_RECIP: begin
        OUT_result = recip_s4_result;
        OUT_result_valid = recip_s4_valid;
      end
      CVO_SCALE: begin
        OUT_result = scale_s2_result;
        OUT_result_valid = scale_s2_valid;
      end
      CVO_REDUCE_SUM: begin
        OUT_result = rsum_out;
        OUT_result_valid = rsum_out_valid;
      end
      default: ;
    endcase
  end

endmodule