컴퓨트 코어 모듈

GitHub RTL 원본

이 페이지가 참조하는 SystemVerilog 원본 파일:

1. 행렬 코어 — Systolic Top

GEMM_systolic_top.sv는 32 × 32 시스톨릭 어레이(cascade가 16행에서 끊겨 32 × 16 서브체인 2개로 나뉨)를 감싸는 래퍼입니다. HP0/HP1에서 가중치 타일을, L2 캐시에서 활성화 행을 수신하고, 누산 결과를 후처리기(post-processor)로 스트리밍합니다. 또한 캐시에서 받은 e_max를 어레이 전체 지연만큼 지연시켜, 해당 누산 결과와 같은 시점에 정규화기로 전달합니다.

리스팅 3 hw/rtl/MAT_CORE/GEMM_systolic_top.sv
`timescale 1ns / 1ps

`include "GLOBAL_CONST.svh"
`include "GEMM_Array.svh"

/**
 * Module: GEMM_systolic_top
 * Target: Kria KV260 @ 400MHz
 *
 * Architecture V2:
 * - Weight Dispatcher (Unpacker): turns the weight FIFO stream into
 *   weight_cnt weights of weight_size bits each for the array.
 * - Staggered Delay Lines for FMap & Instructions
 * - 32x32 Systolic Array Core
 * - e_max Pipe: delays each cached e_max by `SYSTOLIC_TOTAL_LATENCY cycles so
 *   it reaches the normalizers together with the matching raw result.
 *
 * Note: several port widths use the `ARRAY_SIZE_H macro directly, so
 * overriding array_horizontal / array_vertical alone does not resize ports.
 */

module GEMM_systolic_top #(
    parameter weight_lane_cnt = `HP_PORT_CNT,
    parameter weight_width_per_lane = `HP_PORT_SINGLE_WIDTH,
    parameter weight_size     = `INT4,

    // 32 = 128bit / int4(4bit)
    parameter weight_cnt      = `HP_WEIGHT_CNT(`HP_PORT_SINGLE_WIDTH, `INT4),

    parameter array_horizontal = `ARRAY_SIZE_H,
    parameter array_vertical   = `ARRAY_SIZE_V,

    // Width of each staggered feature-map word handed to the DSP A port.
    parameter dsp_A_port       = `ACIN,

    // Width of one broadcast feature-map element.
    // ("brodcast" spelling kept: renaming would break named overrides.)
    parameter IN_fmap_brodcast = `FIXED_MANT_WIDTH

)(
    input logic clk,
    input logic rst_n,    // synchronous, active-low
    input logic i_clear,  // forwarded to the systolic array core

    // Control & Inst
    input logic global_weight_valid,
    input logic [2:0] global_inst,
    input logic global_inst_valid,

    // Feature Map Broadcast (from SRAM Cache)
    input logic [IN_fmap_brodcast-1:0] IN_fmap_broadcast      [0:`ARRAY_SIZE_H-1],
    input logic                        IN_fmap_broadcast_valid,

    // e_max (from Cache for Normalization alignment)
    input logic [`BF16_EXP_WIDTH-1:0]  IN_cached_emax_out[0:`ARRAY_SIZE_H-1],

    // Weight Input from FIFO (Direct)
    input  logic [`HP_PORT_MAX_WIDTH-1:0] IN_weight_fifo_data,
    input  logic                          IN_weight_fifo_valid,
    output logic                          weight_fifo_ready,

    // Output Results (Raw)
    output logic [`DSP48E2_POUT_SIZE-1:0] raw_res_sum      [0:`ARRAY_SIZE_H-1],
    output logic                          raw_res_sum_valid[0:`ARRAY_SIZE_H-1],

    // Delayed e_max for Normalizers
    output logic [`BF16_EXP_WIDTH-1:0] delayed_emax_32[0:`ARRAY_SIZE_H-1]
);

  // ===| Weight Dispatcher (The Unpacker) |=======
  logic [weight_size-1:0] unpacked_weights [0:weight_cnt-1];
  // NOTE(review): weights_ready_for_array is not consumed below; the array's
  // i_weight_valid is driven by global_weight_valid instead. Confirm intended.
  logic             weights_ready_for_array;

  GEMM_weight_dispatcher #(
    .weight_lane_cnt(weight_lane_cnt),
    .weight_width_per_lane(weight_width_per_lane),
    .weight_size(weight_size),
    .weight_cnt(weight_cnt),  // comma was missing here (syntax error)
    .array_horizontal(array_horizontal),
    .array_vertical(array_vertical)
  ) u_weight_unpacker (
      .clk(clk),
      .rst_n(rst_n),
      .fifo_data(IN_weight_fifo_data),
      .fifo_valid(IN_weight_fifo_valid),
      .fifo_ready(weight_fifo_ready),
      .weight_out(unpacked_weights),
      .weight_valid(weights_ready_for_array)
  );

  // ===| Staggered Delay Line for FMap & Instructions |=======
  logic [dsp_A_port-1:0] staggered_fmap      [0:`ARRAY_SIZE_H-1];
  logic                  staggered_fmap_valid[0:`ARRAY_SIZE_H-1];
  logic [           2:0] staggered_inst      [0:`ARRAY_SIZE_H-1];
  logic                  staggered_inst_valid[0:`ARRAY_SIZE_H-1];

  // NOTE(review): array_size is taken from array_vertical while the connected
  // arrays are sized by `ARRAY_SIZE_H; equal for the 32x32 build, confirm otherwise.
  GEMM_fmap_staggered_dispatch #(
      .fmap_width(IN_fmap_brodcast),
      .array_size(array_vertical),
      .fmap_out_width(dsp_A_port)
  ) u_delay_line (
      .clk(clk),
      .rst_n(rst_n),
      .fmap_in(IN_fmap_broadcast),
      .fmap_valid(IN_fmap_broadcast_valid),
      .global_inst(global_inst),
      .global_inst_valid(global_inst_valid),
      .row_data(staggered_fmap),
      .row_valid(staggered_fmap_valid),
      .row_inst(staggered_inst),
      .row_inst_valid(staggered_inst_valid)
  );

  // ===| Systolic Array Core (The Engine) |=======
  // raw_res_seq receives V_out but is not used further in this module.
  logic [`DSP48E2_POUT_SIZE-1:0] raw_res_seq[0:`ARRAY_SIZE_H-1];

  // NOTE(review): parameter names are case-sensitive; ARRAY_HORIZONTAL vs
  // array_vertical - confirm both match GEMM_systolic_array's declarations.
  GEMM_systolic_array #(
      .ARRAY_HORIZONTAL(`ARRAY_SIZE_H),
      .array_vertical  (`ARRAY_SIZE_V),
      .h_in_size(`GEMM_MAC_UNIT_IN_H),
      .v_in_size(`GEMM_MAC_UNIT_IN_V)
  ) u_compute_core (
      .clk(clk),
      .rst_n(rst_n),
      .i_clear(i_clear),
      .i_weight_valid(global_weight_valid),

      // Horizontal: Weights
      .H_in(unpacked_weights),

      // Vertical: Feature Map Broadcast & Instructions (Staggered)
      .V_in(staggered_fmap),
      .in_valid(staggered_fmap_valid),
      .inst_in(staggered_inst),
      .inst_valid_in(staggered_inst_valid),

      .V_out(raw_res_seq),
      .V_ACC_out(raw_res_sum),
      .V_ACC_valid(raw_res_sum_valid)
  );

  // ===| e_max Delay Pipe for Normalization alignment |=======
  // One shift register per column, TOTAL_LATENCY stages deep (must be >= 1).
  localparam TOTAL_LATENCY = `SYSTOLIC_TOTAL_LATENCY;
  logic [`BF16_EXP_WIDTH-1:0] emax_pipe[0:`ARRAY_SIZE_H-1][0:TOTAL_LATENCY-1];

  always_ff @(posedge clk) begin
    if (!rst_n) begin
      for (int c = 0; c < `ARRAY_SIZE_H; c++) begin
        for (int d = 0; d < TOTAL_LATENCY; d++) begin
          emax_pipe[c][d] <= 0;
        end
      end
    end else begin
      for (int c = 0; c < `ARRAY_SIZE_H; c++) begin
        emax_pipe[c][0] <= IN_cached_emax_out[c];
        for (int d = 1; d < TOTAL_LATENCY; d++) begin
          emax_pipe[c][d] <= emax_pipe[c][d-1];
        end
      end
    end
  end

  // Last stage of each column's pipe drives the normalizer-side e_max.
  always_comb begin
    for (int c = 0; c < `ARRAY_SIZE_H; c++) begin
      delayed_emax_32[c] = emax_pipe[c][TOTAL_LATENCY-1];
    end
  end

endmodule

2. 벡터 코어 — GEMV Top

GEMV_top.sv는 하나의 LUT 생성기(GEMV_generate_lut)를 공유하는 4개의 병렬 GEMV 리덕션 브랜치(A–D)를 인스턴스화합니다. 각 브랜치는 32-wide LUT 기반 MAC과 5단 reduction tree(Stage 1: DSP48E2 16 슬라이스, Stage 2–5: LUT 가산기)를 가지며, HP2/HP3에서 가중치를 스트리밍합니다.

리스팅 4 hw/rtl/VEC_CORE/GEMV_top.sv
`timescale 1ns / 1ps

`include "GEMV_Vec_Matrix_MUL.svh"
`include "GLOBAL_CONST.svh"

// ===| GEMV Top |================================================================
// Vector (GEMV) core: one shared GEMV_generate_lut feeding four
// GEMV_reduction_branch instances (A, B, C, D). Each branch has its own weight
// stream, valid flag and result output; all share the feature-map LUT, the
// fmap-ready flag and IN_num_recur.
//   weight size      = 4bit (per design note; port width is param.weight_width)
//   feature_map size = bf16
// Every sub-module is parameterized with this module's `param`, so overriding
// GEMV_top's configuration keeps the children's port widths consistent.
// A/B/C/D are the IN_activated_lane indices used by each branch.
// ===============================================================================
module GEMV_top
  import vec_core_pkg::*;
#(
    parameter gemv_cfg_t param = VecCoreDefaultCfg,
    parameter A = 0,
    parameter B = 1,
    parameter C = 2,
    parameter D = 3
) (
    input logic clk,
    input logic rst_n,

    input logic IN_weight_valid_A,
    input logic IN_weight_valid_B,
    input logic IN_weight_valid_C,
    input logic IN_weight_valid_D,

    input logic [param.weight_width - 1:0] IN_weight_A[0:param.weight_cnt -1],
    input logic [param.weight_width - 1:0] IN_weight_B[0:param.weight_cnt -1],
    input logic [param.weight_width - 1:0] IN_weight_C[0:param.weight_cnt -1],
    input logic [param.weight_width - 1:0] IN_weight_D[0:param.weight_cnt -1],

    // NOTE(review): OUT_weight_ready_A..D are not driven anywhere in this module.
    output logic OUT_weight_ready_A,
    output logic OUT_weight_ready_B,
    output logic OUT_weight_ready_C,
    output logic OUT_weight_ready_D,

    input logic [param.fixed_mant_width-1:0] IN_fmap_broadcast      [0:param.fmap_cache_out_cnt-1],
    input logic                              IN_fmap_broadcast_valid,

    input logic [16:0] IN_num_recur,
    // e_max (from Cache for Normalization alignment)
    input logic [dtype_pkg::Bf16ExpWidth-1:0] IN_cached_emax_out[0:param.fmap_cache_out_cnt-1],

    input logic IN_activated_lane[0:param.num_gemv_pipeline-1],

    output logic [param.fmap_type_mixed_precision - 1:0] OUT_final_fmap_A,
    output logic [param.fmap_type_mixed_precision - 1:0] OUT_final_fmap_B,
    output logic [param.fmap_type_mixed_precision - 1:0] OUT_final_fmap_C,
    output logic [param.fmap_type_mixed_precision - 1:0] OUT_final_fmap_D,

    output logic OUT_result_valid_A,
    output logic OUT_result_valid_B,
    output logic OUT_result_valid_C,
    output logic OUT_result_valid_D
);

  // LUT produced once from the broadcast feature map and shared by all branches.
  logic [param.fixed_mant_width+2:0] fmap_LUT_wire[0:param.fmap_cache_out_cnt-1][0:param.weight_width-1];

  logic fmap_ready_wire;

  GEMV_generate_lut #(
      .param(param)
  ) u_GEMV_generate_lut (
      .IN_fmap_broadcast(IN_fmap_broadcast),
      .IN_fmap_broadcast_valid(IN_fmap_broadcast_valid),
      .IN_cached_emax_out(IN_cached_emax_out),

      .OUT_fmap_LUT  (fmap_LUT_wire),
      .OUT_fmap_ready(fmap_ready_wire)
  );

  // ===| Branch A |===
  GEMV_reduction_branch #(
      .param(param)
  ) u_GEMV_reduction_branch_A (
      .clk  (clk),
      .rst_n(rst_n),

      .IN_weight_valid(IN_weight_valid_A),
      .IN_weight(IN_weight_A),

      .fmap_ready(fmap_ready_wire),
      .IN_num_recur(IN_num_recur),  // shape x * y * z

      .IN_activated_lane(IN_activated_lane[A]),
      .IN_fmap_LUT(fmap_LUT_wire),

      .OUT_GEMV_result_vector(OUT_final_fmap_A),
      .OUT_valid(OUT_result_valid_A)
  );

  // ===| Branch B |===
  GEMV_reduction_branch #(
      .param(param)
  ) u_GEMV_reduction_branch_B (
      .clk  (clk),
      .rst_n(rst_n),

      .IN_weight_valid(IN_weight_valid_B),
      .IN_weight(IN_weight_B),

      .fmap_ready(fmap_ready_wire),
      .IN_num_recur(IN_num_recur),  // shape x * y * z

      .IN_activated_lane(IN_activated_lane[B]),
      .IN_fmap_LUT(fmap_LUT_wire),

      .OUT_GEMV_result_vector(OUT_final_fmap_B),
      .OUT_valid(OUT_result_valid_B)
  );

  // ===| Branch C |===
  GEMV_reduction_branch #(
      .param(param)
  ) u_GEMV_reduction_branch_C (
      .clk  (clk),
      .rst_n(rst_n),

      .IN_weight_valid(IN_weight_valid_C),
      .IN_weight(IN_weight_C),

      .fmap_ready(fmap_ready_wire),
      .IN_num_recur(IN_num_recur),  // shape x * y * z

      .IN_activated_lane(IN_activated_lane[C]),
      .IN_fmap_LUT(fmap_LUT_wire),

      .OUT_GEMV_result_vector(OUT_final_fmap_C),
      .OUT_valid(OUT_result_valid_C)
  );

  // ===| Branch D |===
  GEMV_reduction_branch #(
      .param(param)
  ) u_GEMV_reduction_branch_D (
      .clk  (clk),
      .rst_n(rst_n),

      .IN_weight_valid(IN_weight_valid_D),
      .IN_weight(IN_weight_D),

      .fmap_ready(fmap_ready_wire),
      .IN_num_recur(IN_num_recur),  // shape x * y * z

      .IN_activated_lane(IN_activated_lane[D]),
      .IN_fmap_LUT(fmap_LUT_wire),

      .OUT_GEMV_result_vector(OUT_final_fmap_D),
      .OUT_valid(OUT_result_valid_D)
  );

endmodule

더 보기: GEMV 코어

3. CVO / SFU 코어

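CVO_top.sv는 비선형 연산(Softmax, GELU, RMSNorm, RoPE)을 위해 LUT 기반 SFU(CVO_sfu_unit: EXP/SQRT/GELU/RECIP/SCALE/REDUCE_SUM)와 CORDIC 유닛(CVO_cordic_unit: SIN/COS)을 하나의 스트리밍 인터페이스 뒤에서 조율합니다. 입출력 스트림은 BF16이며, 모든 연산은 내부에서 BF16/FP32 정밀도로 승격되어 수행됩니다.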

리스팅 5 hw/rtl/CVO_CORE/CVO_top.sv
`timescale 1ns / 1ps
`include "GLOBAL_CONST.svh"

import isa_pkg::*;
import bf16_math_pkg::*;

// ===| CVO Top |=================================================================
// Wraps CVO_sfu_unit (EXP/SQRT/GELU/RECIP/SCALE/REDUCE_SUM) and
// CVO_cordic_unit (SIN/COS) behind a unified streaming interface.
//
// Data flow:
//   Host issues OP_CVO via AXI-Lite → Global_Scheduler produces cvo_control_uop_t
//   → CVO_top latches uop, processes IN_length BF16 elements from L2 stream,
//     writes results back via output stream.
//
// FLAG_SUB_EMAX: subtract IN_e_max from each input before the function.
//   Implements exp(x - e_max) for numerically stable softmax.
// FLAG_ACCM: accumulate output into dst (add OUT_result to prior value).
//   Handled externally by the mem subsystem; CVO_top only signals it via OUT_accm.
//
// FSM states:
//   IDLE    : waiting for valid uop (a zero-length uop goes straight to DONE)
//   RUNNING : streaming IN_length elements through the chosen unit
//   DONE    : pulse OUT_done for one cycle, return to IDLE
//
// Timing notes:
//   - OUT_done marks that the last INPUT element was accepted; results still in
//     the SFU/CORDIC pipelines may appear on OUT_result after OUT_done.
//   - OUT_result_valid is raised only while IN_result_ready is high. The
//     sub-units have no output back-pressure, so a result produced while
//     IN_result_ready is low is not re-offered.
//     NOTE(review): confirm the consumer holds IN_result_ready high during an op.
// ===============================================================================

module CVO_top (
    input  logic        clk,
    input  logic        rst_n,
    input  logic        i_clear,

    // ===| Dispatch from Global_Scheduler |=====================================
    input  cvo_control_uop_t IN_uop,
    input  logic             IN_uop_valid,
    output logic             OUT_uop_ready,

    // ===| BF16 Input Stream (from L2 via mem_dispatcher) |=====================
    input  logic [15:0]  IN_data,
    input  logic         IN_data_valid,
    output logic         OUT_data_ready,

    // ===| BF16 Output Stream (to L2 via mem_dispatcher) |=====================
    output logic [15:0]  OUT_result,
    output logic         OUT_result_valid,
    input  logic         IN_result_ready,

    // ===| e_max for FLAG_SUB_EMAX |============================================
    // Passed in as BF16; CVO subtracts this from each element before the function.
    input  logic [15:0]  IN_e_max,

    // ===| Status |=============================================================
    output logic         OUT_busy,
    output logic         OUT_done,
    output logic         OUT_accm   // mirrors IN_uop.flags.accm to mem subsystem
);

  // ===| FSM |===================================================================
  typedef enum logic [1:0] {
    ST_IDLE    = 2'b00,
    ST_RUNNING = 2'b01,
    ST_DONE    = 2'b10
  } cvo_state_e;

  cvo_state_e state;

  // ===| Latched UOP |===========================================================
  cvo_func_e   uop_func;
  cvo_flags_t  uop_flags;
  logic [15:0] uop_length;
  logic [15:0] elem_count;   // elements processed in current operation

  // ===| Opcode Routing |========================================================
  // Declared before the sub-unit instances because their port expressions use it
  // (SystemVerilog requires declaration before use).
  logic is_cordic_op_wire;
  always_comb begin
    is_cordic_op_wire = (uop_func == CVO_SIN) || (uop_func == CVO_COS);
  end

  // ===| BF16 subtract e_max (combinational) |===================================
  // Implements x - e_max in BF16 via bf16_add(x, -e_max).
  logic [15:0] sub_emax_result_wire;

  always_comb begin : comb_sub_emax
    // Negate e_max by flipping sign bit, then add to x
    sub_emax_result_wire = bf16_add(IN_data, {~IN_e_max[15], IN_e_max[14:0]});
  end

  // ===| Input to sub-units (after optional e_max subtraction) |=================
  logic [15:0] data_to_unit_wire;
  logic        data_valid_to_unit_wire;

  always_comb begin
    data_to_unit_wire       = uop_flags.sub_emax ? sub_emax_result_wire : IN_data;
    data_valid_to_unit_wire = (state == ST_RUNNING) && IN_data_valid;
  end

  // ===| SFU Instantiation |=====================================================
  logic [15:0] sfu_result;
  logic        sfu_result_valid;
  logic        sfu_ready;

  CVO_sfu_unit u_CVO_sfu_unit (
      .clk             (clk),
      .rst_n           (rst_n),
      .i_clear         (i_clear),

      .IN_func         (uop_func),
      .IN_length       (uop_length),
      .IN_flags        (uop_flags),

      .IN_data         (data_to_unit_wire),
      .IN_valid        (data_valid_to_unit_wire && !is_cordic_op_wire),
      .OUT_data_ready  (sfu_ready),

      .OUT_result      (sfu_result),
      .OUT_result_valid(sfu_result_valid)
  );

  // ===| CORDIC Instantiation |==================================================
  logic [15:0] cordic_sin;
  logic [15:0] cordic_cos;
  logic        cordic_valid;

  CVO_cordic_unit u_CVO_cordic_unit (
      .clk          (clk),
      .rst_n        (rst_n),

      .IN_angle_bf16(data_to_unit_wire),
      .IN_valid     (data_valid_to_unit_wire && is_cordic_op_wire),

      .OUT_sin_bf16 (cordic_sin),
      .OUT_cos_bf16 (cordic_cos),
      .OUT_valid    (cordic_valid)
  );

  // ===| FSM Logic |=============================================================
  always_ff @(posedge clk) begin
    if (!rst_n || i_clear) begin
      state      <= ST_IDLE;
      uop_func   <= CVO_EXP;
      uop_flags  <= '0;
      uop_length <= 16'd0;
      elem_count <= 16'd0;
      OUT_done   <= 1'b0;
    end else begin
      OUT_done <= 1'b0;

      case (state)
        // ===| IDLE: wait for dispatch |===
        ST_IDLE: begin
          if (IN_uop_valid) begin
            uop_func   <= IN_uop.cvo_func;
            uop_flags  <= IN_uop.flags;
            uop_length <= IN_uop.length;
            elem_count <= 16'd0;
            // A zero-length op has nothing to stream. Entering RUNNING would make
            // the terminal compare (uop_length - 1) wrap to 16'hFFFF and consume
            // 65536 elements, so finish immediately instead.
            state      <= (IN_uop.length == 16'd0) ? ST_DONE : ST_RUNNING;
          end
        end

        // ===| RUNNING: count consumed elements |===
        ST_RUNNING: begin
          if (IN_data_valid && OUT_data_ready) begin
            elem_count <= elem_count + 16'd1;
            if (elem_count == uop_length - 16'd1) begin
              state    <= ST_DONE;
            end
          end
        end

        // ===| DONE: pulse and return |===
        ST_DONE: begin
          OUT_done <= 1'b1;
          state    <= ST_IDLE;
        end

        default: state <= ST_IDLE;
      endcase
    end
  end

  // ===| Output Mux |============================================================
  // CORDIC outputs two results per input; select sin or cos based on function.
  logic [15:0] result_mux_wire;
  logic        result_valid_mux_wire;

  always_comb begin
    if (is_cordic_op_wire) begin
      result_mux_wire       = (uop_func == CVO_SIN) ? cordic_sin : cordic_cos;
      result_valid_mux_wire = cordic_valid;
    end else begin
      result_mux_wire       = sfu_result;
      result_valid_mux_wire = sfu_result_valid;
    end
  end

  // ===| Output Registers |======================================================
  // Results are registered every cycle; valid is masked by IN_result_ready
  // (see "Timing notes" in the header).
  always_ff @(posedge clk) begin
    if (!rst_n || i_clear) begin
      OUT_result       <= 16'd0;
      OUT_result_valid <= 1'b0;
    end else begin
      OUT_result       <= result_mux_wire;
      OUT_result_valid <= result_valid_mux_wire && IN_result_ready;
    end
  end

  // ===| Status & Control |======================================================
  assign OUT_busy       = (state != ST_IDLE);
  assign OUT_uop_ready  = (state == ST_IDLE);
  // NOTE(review): input readiness follows sfu_ready even for CORDIC (SIN/COS) ops.
  assign OUT_data_ready = sfu_ready && (state == ST_RUNNING);
  assign OUT_accm       = uop_flags.accm;

endmodule

4. DSP48E2 MAC 유닛

GEMM_dsp_unit.sv는 단일 DSP48E2 슬라이스를 사용한 듀얼 채널 W4A8 MAC을 구현합니다. 비트 패킹 유도 과정은 「DSP48E2 W4A8 비트 패킹과 부호 복원」 문서를 참조하세요.

리스팅 6 hw/rtl/MAT_CORE/GEMM_dsp_unit.sv
`timescale 1ns / 1ps


`include "GLOBAL_CONST.svh"
`include "GEMM_Array.svh"

// ===| GEMM_dsp_unit |==========================================================
// One processing element of the systolic array, built around a single DSP48E2.
//   - Horizontal: a `GEMM_MAC_UNIT_IN_H-bit weight enters on in_H, is
//     sign-extended to 18 bits for the B port, and is forwarded to the next PE
//     through the out_H fabric register when i_weight_valid is high.
//   - Vertical: the A operand comes from the fabric (in_V) on the top row or at
//     a cascade break, otherwise from ACIN. The incoming partial sum comes from
//     PCIN, or from the fabric C port at a cascade break.
//   - A 3-bit instruction is latched on inst_valid_in_V and also forwarded,
//     registered, to the next PE:
//       bit[0] : multiply-accumulate (OPMODE P = Z + M) and P-register enable
//       bit[1] : both B registers are clocked by i_valid
//       bit[2] : start the flush sequence (P cleared for two cycles, then the
//                B registers are loaded when bit[1] is 0)
//
// ACIN_size / ACOUT_size / POUT_size are bit WIDTHS: ports are [size-1:0],
// matching the DSP48E2 A/ACOUT (30b) and P/PCOUT (48b) ports and the
// [`DSP48E2_POUT_SIZE-1:0] convention used by GEMM_systolic_top.
// ==============================================================================
module GEMM_dsp_unit #(
    parameter IS_TOP_ROW = 0,
    parameter BREAK_CASCADE = 0,  // If 1, break the vertical cascade chain here
    parameter ACIN_size = `DSP48E2_A_WIDTH,    // bit width of in_V / ACIN_in
    parameter ACOUT_size = `DSP48E2_A_WIDTH,   // bit width of ACOUT_out
    parameter POUT_size = `DSP48E2_POUT_SIZE   // bit width of P / PCIN / PCOUT paths
) (
    input logic clk,
    input logic rst_n,

    input  logic i_clear,
    input  logic i_valid,         // Feature Map Data Valid
    input  logic i_weight_valid,  // Background Weight Shift Enable
    output logic o_valid,

    // [Horizontal] int4 (4-bit) -> external FF -> DSP B port
    input  logic [`GEMM_MAC_UNIT_IN_H - 1:0] in_H,
    output logic [`GEMM_MAC_UNIT_IN_H - 1:0] out_H,

    // [Vertical] 30-bit -> DSP A/ACIN port
    input  logic [ACIN_size-1:0]  in_V,       // Used if IS_TOP_ROW == 1 or BREAK_CASCADE == 1
    input  logic [ACIN_size-1:0]  ACIN_in,    // Used if IS_TOP_ROW == 0 and BREAK_CASCADE == 0
    output logic [ACOUT_size-1:0] ACOUT_out,

    // [3-Bit VLIW Instruction]
    input  logic [2:0] instruction_in_V,
    output logic [2:0] instruction_out_V,
    input  logic       inst_valid_in_V,    // Cascaded Instruction Valid
    output logic       inst_valid_out_V,

    // vertical shift port
    input  logic [POUT_size-1:0] V_result_in,   // PCIN (or Fabric C) from upper DSP
    output logic [POUT_size-1:0] V_result_out,  // PCOUT to lower DSP's PCIN
    output logic [POUT_size-1:0] P_fabric_out   // P to lower DSP's Fabric C if broken
);

  // ===| [Instruction Latch (Event-Driven)] |============================
  // current_inst holds the last instruction seen with inst_valid_in_V.
  logic [2:0] current_inst;

  always_ff @(posedge clk) begin
    if (!rst_n || i_clear) begin
      current_inst <= 3'b000;
    end else if (inst_valid_in_V) begin
      current_inst <= instruction_in_V;
    end
  end

  // Pass instruction and its valid signal down to the next PE (1-cycle delay).
  always_ff @(posedge clk) begin
    if (!rst_n) begin
      instruction_out_V <= 3'b000;
      inst_valid_out_V  <= 1'b0;
    end else begin
      instruction_out_V <= instruction_in_V;
      inst_valid_out_V  <= inst_valid_in_V;
    end
  end

  // ===| [The "Flush & Load" Sequencer] |================================
  // A valid instruction with bit[2] set injects a 1 into a 4-bit shift register:
  // stages [1] and [2] flush P, stage [3] triggers the B-register load.
  logic [3:0] flush_sequence;

  always_ff @(posedge clk) begin
    if (!rst_n || i_clear) begin
      flush_sequence <= 4'd0;
    end else begin
      flush_sequence <= {flush_sequence[2:0], 1'b0};
      // Later non-blocking assignment overrides bit 0 of the shift above.
      if (inst_valid_in_V && instruction_in_V[2] == 1'b1) begin
        flush_sequence[0] <= 1'b1;
      end
    end
  end

  // ===| [Hardware Mapping (VLIW Decoding)] |============================
  logic [8:0] dynamic_opmode;
  logic [3:0] dynamic_alumode;

  logic is_flushing;
  assign is_flushing = flush_sequence[1] | flush_sequence[2];

  // OPMODE Selection
  // W(2), Z(3), Y(2), X(2)
  // If BREAK_CASCADE == 1, we must take the previous result from the C port instead of PCIN.
  // Z-mux: 001 is PCIN, 011 is C.
  localparam logic [2:0] Z_MUX = BREAK_CASCADE ? 3'b011 : 3'b001;

  always_comb begin
    if (is_flushing) begin
      // Flush: P = 0 + 0 + 0 (Clear accumulator)
      dynamic_opmode  = 9'b00_000_00_00;
      dynamic_alumode = 4'b0000;
    end else if (current_inst[0] == 1'b1) begin
      // Calc: P = Z + A*B  (Z = upper result via PCIN or C; X/Y = 01 select M)
      dynamic_opmode  = {2'b00, Z_MUX, 2'b01, 2'b01};
      dynamic_alumode = 4'b0000;
    end else begin
      // Idle: P = Z (upper result passes through, no multiply term).
      // CEP is low in this state, so P is only updated while flushing/calculating.
      dynamic_opmode  = {2'b00, Z_MUX, 2'b00, 2'b00};
      dynamic_alumode = 4'b0000;
    end
  end

  logic dsp_ce_p;
  assign dsp_ce_p = current_inst[0] | is_flushing;

  // ===| [Fabric FF & Weight Pipeline] |=================================
  // Weight shift register toward the neighbouring PE.
  always_ff @(posedge clk) begin
    if (!rst_n || i_clear) begin
      out_H <= 0;
    end else begin
      if (i_weight_valid) begin
        out_H <= in_H;
      end
    end
  end

  // ===| [Dual B-Register Control] |================
  // bit[1] set: B1/B2 both follow i_valid.
  // otherwise : B1 captures on load_trigger or a background weight shift,
  //             B2 only on load_trigger.
  logic dsp_ce_b1;
  logic dsp_ce_b2;
  logic load_trigger;

  assign load_trigger = flush_sequence[3];

  always_comb begin
    if (current_inst[1] == 1'b1) begin
      dsp_ce_b1 = i_valid;
      dsp_ce_b2 = i_valid;
    end else begin
      dsp_ce_b1 = load_trigger | i_weight_valid;
      dsp_ce_b2 = load_trigger;
    end
  end

  // o_valid is i_valid delayed by one cycle.
  logic valid_delay;
  always_ff @(posedge clk) begin
    if (!rst_n) valid_delay <= 1'b0;
    else valid_delay <= i_valid;
  end
  assign o_valid = valid_delay;

  // <><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><>
  // [DSP48E2 primitive instantiation] <><><><><><><><><><><><><><><><><>
  // Sign-extend the weight to the 18-bit B port (width derived, not hard-coded).
  logic [17:0] in_H_padded;
  assign in_H_padded = {{(18 - `GEMM_MAC_UNIT_IN_H){in_H[`GEMM_MAC_UNIT_IN_H-1]}}, in_H};

  // If TOP_ROW or BREAK_CASCADE, we get A from Fabric (in_V). Otherwise from ACIN.
  logic [29:0] dsp_a_input;
  assign dsp_a_input = (IS_TOP_ROW || BREAK_CASCADE) ? in_V : 30'd0;

  logic [29:0] dsp_acin_input;
  assign dsp_acin_input = (IS_TOP_ROW || BREAK_CASCADE) ? 30'd0 : ACIN_in;

  // If BREAK_CASCADE, we receive the accumulated result from Fabric C (V_result_in)
  logic [47:0] dsp_c_input;
  assign dsp_c_input = BREAK_CASCADE ? V_result_in : 48'd0;

  logic [47:0] dsp_pcin_input;
  assign dsp_pcin_input = BREAK_CASCADE ? 48'd0 : V_result_in;

  logic [47:0] p_internal;

  // NOTE(review): the DSP's internal registers are reset by i_clear only,
  // not by rst_n - confirm this is intended.
  DSP48E2 #(
      .A_INPUT((IS_TOP_ROW || BREAK_CASCADE) ? "DIRECT" : "CASCADE"),
      .B_INPUT("DIRECT"),
      .AREG(1),
      .BREG(2),
      .CREG(0),
      .MREG(1),
      .PREG(1),
      .OPMODEREG(1),
      .ALUMODEREG(1),
      .USE_MULT("MULTIPLY")
  ) DSP_HARD_BLOCK (
      .CLK(clk),
      .RSTA(i_clear),
      .RSTB(i_clear),
      .RSTM(i_clear),
      .RSTP(i_clear),
      .RSTCTRL(i_clear),
      .RSTALLCARRYIN(i_clear),
      .RSTALUMODE(i_clear),
      .RSTC(i_clear),

      .CEA1(i_valid),
      .CEA2(i_valid),
      .CEB1(dsp_ce_b1),
      .CEB2(dsp_ce_b2),
      .CEM(i_valid),
      .CEP(dsp_ce_p),
      .CECTRL(1'b1),
      .CEALUMODE(1'b1),
      .CEC(1'b1),  // CREG=0: C is unregistered, so this enable has no effect

      .A(dsp_a_input),
      .ACIN(dsp_acin_input),
      .ACOUT(ACOUT_out),

      .B(in_H_padded),
      .C(dsp_c_input),

      .PCIN (dsp_pcin_input),
      .PCOUT(V_result_out),

      .OPMODE(dynamic_opmode),
      .ALUMODE(dynamic_alumode),
      .P(p_internal)
  );

  // We must expose the internal P so it can be routed via fabric to the next DSP
  assign P_fabric_out = p_internal;

endmodule