행렬 코어 (GEMM)¶
DSP48E2 기반 MAC 으로 구성된 32×32 시스톨릭 어레이입니다. HP0/HP1
(128 비트/clk) 으로 가중치 타일을, L2 캐시에서 활성화 행을 받아
양 방향으로 multiply-accumulate 를 순환시키고, 패킹된 INT48 결과
벡터를 mat_result_normalizer 에 넘겨 writeback 전에 BF16 barrel
shift 로 정규화합니다.
더 보기
- pccx: Parallel Compute Core eXecutor
GEMM 이 3 코어 아키텍처 내 어디 위치하는지.
어레이 쉘¶
- GEMM_systolic_top.sv — 컨트롤러가 보는 래퍼. 디스패치된 gemm_control_uop_t 수신 후 done 어서트.
- GEMM_systolic_array.sv — 32×32 PE 격자 본체.
- GEMM_Array.svh — 어레이와 주변 feeder 가 공유하는 매개변수 헤더.
GEMM_systolic_top.sv
`timescale 1ns / 1ps
`include "GLOBAL_CONST.svh"
`include "GEMM_Array.svh"
/**
* Module: GEMM_systolic_top
* Target: Kria KV260 @ 400MHz
*
* Architecture V2:
* - Weight Dispatcher (Unpacker)
* - Staggered Delay Lines for FMap & Instructions
* - 32x32 Systolic Array Core
* - e_max Pipe for Synchronization with Result Output
*/
// GEMM_systolic_top
//
// Wrapper that wires together, in order:
//   1. GEMM_weight_dispatcher       (u_weight_unpacker) - weight FIFO word in,
//                                    array of weight_size-bit weights out.
//   2. GEMM_fmap_staggered_dispatch (u_delay_line)      - per-lane delay line
//                                    for the feature-map broadcast and the
//                                    3-bit instruction.
//   3. GEMM_systolic_array          (u_compute_core)    - PE grid and
//                                    per-column accumulators.
//   4. A shift pipe that delays IN_cached_emax_out by `SYSTOLIC_TOTAL_LATENCY
//      clock cycles and presents it on delayed_emax_32.
//
// All widths and counts come from macros in GLOBAL_CONST.svh / GEMM_Array.svh.
module GEMM_systolic_top #(
    parameter weight_lane_cnt       = `HP_PORT_CNT,
    parameter weight_width_per_lane = `HP_PORT_SINGLE_WIDTH,
    parameter weight_size           = `INT4,
    // 32 = 128bit / int4(4bit)
    parameter weight_cnt            = `HP_WEIGHT_CNT(`HP_PORT_SINGLE_WIDTH, `INT4),
    parameter array_horizontal      = `ARRAY_SIZE_H,
    parameter array_vertical        = `ARRAY_SIZE_V,
    parameter dsp_A_port            = `ACIN,
    parameter IN_fmap_brodcast      = `FIXED_MANT_WIDTH
) (
    input logic clk,
    input logic rst_n,    // Active-low, sampled synchronously
    input logic i_clear,  // Forwarded to the array core

    // Control & Inst
    input logic       global_weight_valid,  // Drives the array's i_weight_valid
    input logic [2:0] global_inst,
    input logic       global_inst_valid,

    // Feature Map Broadcast (from SRAM Cache)
    input logic [IN_fmap_brodcast-1:0] IN_fmap_broadcast[0:`ARRAY_SIZE_H-1],
    input logic                        IN_fmap_broadcast_valid,

    // e_max (from Cache for Normalization alignment)
    input logic [`BF16_EXP_WIDTH-1:0] IN_cached_emax_out[0:`ARRAY_SIZE_H-1],

    // Weight Input from FIFO (Direct)
    input  logic [`HP_PORT_MAX_WIDTH-1:0] IN_weight_fifo_data,
    input  logic                          IN_weight_fifo_valid,
    output logic                          weight_fifo_ready,

    // Output Results (Raw)
    output logic [`DSP48E2_POUT_SIZE-1:0] raw_res_sum      [0:`ARRAY_SIZE_H-1],
    output logic                          raw_res_sum_valid[0:`ARRAY_SIZE_H-1],

    // Delayed e_max for Normalizers
    output logic [`BF16_EXP_WIDTH-1:0] delayed_emax_32[0:`ARRAY_SIZE_H-1]
);

  // ===| Weight Dispatcher (The Unpacker) |=======
  logic [weight_size-1:0] unpacked_weights[0:weight_cnt-1];
  // Driven by the dispatcher but not consumed in this module; the array's
  // weight-shift enable comes from global_weight_valid instead.
  logic                   weights_ready_for_array;

  GEMM_weight_dispatcher #(
      .weight_lane_cnt      (weight_lane_cnt),
      .weight_width_per_lane(weight_width_per_lane),
      .weight_size          (weight_size),
      .weight_cnt           (weight_cnt),  // FIX: separating comma was missing
      .array_horizontal     (array_horizontal),
      .array_vertical       (array_vertical)
  ) u_weight_unpacker (
      .clk         (clk),
      .rst_n       (rst_n),
      .fifo_data   (IN_weight_fifo_data),
      .fifo_valid  (IN_weight_fifo_valid),
      .fifo_ready  (weight_fifo_ready),
      .weight_out  (unpacked_weights),
      .weight_valid(weights_ready_for_array)
  );

  // ===| Staggered Delay Line for FMap & Instructions |=======
  logic [dsp_A_port-1:0] staggered_fmap      [0:`ARRAY_SIZE_H-1];
  logic                  staggered_fmap_valid[0:`ARRAY_SIZE_H-1];
  logic [           2:0] staggered_inst      [0:`ARRAY_SIZE_H-1];
  logic                  staggered_inst_valid[0:`ARRAY_SIZE_H-1];

  GEMM_fmap_staggered_dispatch #(
      .fmap_width    (IN_fmap_brodcast),
      .array_size    (array_vertical),
      .fmap_out_width(dsp_A_port)
  ) u_delay_line (
      .clk              (clk),
      .rst_n            (rst_n),
      .fmap_in          (IN_fmap_broadcast),
      .fmap_valid       (IN_fmap_broadcast_valid),
      .global_inst      (global_inst),
      .global_inst_valid(global_inst_valid),
      .row_data         (staggered_fmap),
      .row_valid        (staggered_fmap_valid),
      .row_inst         (staggered_inst),
      .row_inst_valid   (staggered_inst_valid)
  );

  // ===| Systolic Array Core (The Engine) |=======
  // V_out of the core is captured here but not forwarded to any port.
  logic [`DSP48E2_POUT_SIZE-1:0] raw_res_seq[0:`ARRAY_SIZE_H-1];

  GEMM_systolic_array #(
      // FIX: the core declares this parameter as lower-case array_horizontal;
      // SystemVerilog identifiers are case-sensitive.
      .array_horizontal(`ARRAY_SIZE_H),
      .array_vertical  (`ARRAY_SIZE_V),
      .h_in_size       (`GEMM_MAC_UNIT_IN_H),
      .v_in_size       (`GEMM_MAC_UNIT_IN_V)
  ) u_compute_core (
      .clk           (clk),
      .rst_n         (rst_n),
      .i_clear       (i_clear),
      .i_weight_valid(global_weight_valid),
      // Horizontal: Weights
      .H_in          (unpacked_weights),
      // Vertical: Feature Map Broadcast & Instructions (Staggered)
      .V_in          (staggered_fmap),
      .in_valid      (staggered_fmap_valid),
      .inst_in       (staggered_inst),
      .inst_valid_in (staggered_inst_valid),
      .V_out         (raw_res_seq),
      .V_ACC_out     (raw_res_sum),
      .V_ACC_valid   (raw_res_sum_valid)
  );

  // ===| e_max Delay Pipe for Normalization alignment |=======
  // One TOTAL_LATENCY-deep shift register per lane. Stage 0 samples the
  // input; the last stage feeds delayed_emax_32 combinationally, so the
  // input-to-output delay is TOTAL_LATENCY clock cycles.
  localparam TOTAL_LATENCY = `SYSTOLIC_TOTAL_LATENCY;
  logic [`BF16_EXP_WIDTH-1:0] emax_pipe[0:`ARRAY_SIZE_H-1][0:TOTAL_LATENCY-1];

  always_ff @(posedge clk) begin
    if (!rst_n) begin
      for (int c = 0; c < `ARRAY_SIZE_H; c++) begin
        for (int d = 0; d < TOTAL_LATENCY; d++) begin
          emax_pipe[c][d] <= 0;
        end
      end
    end else begin
      for (int c = 0; c < `ARRAY_SIZE_H; c++) begin
        emax_pipe[c][0] <= IN_cached_emax_out[c];
        for (int d = 1; d < TOTAL_LATENCY; d++) begin
          emax_pipe[c][d] <= emax_pipe[c][d-1];
        end
      end
    end
  end

  always_comb begin
    for (int c = 0; c < `ARRAY_SIZE_H; c++) begin
      delayed_emax_32[c] = emax_pipe[c][TOTAL_LATENCY-1];
    end
  end

endmodule
GEMM_systolic_array.sv
`timescale 1ns / 1ps
`include "GLOBAL_CONST.svh"
`include "GEMM_Array.svh"
// GEMM_systolic_array
//
// Two-dimensional grid of DSP processing elements plus one accumulator per
// column.
//   - H_in feeds column 0 of every row and is forwarded left-to-right through
//     each PE's in_H/out_H pair.
//   - V_in, in_valid, inst_in and inst_valid_in enter at row 0 and are
//     forwarded top-to-bottom (ACIN/ACOUT, o_valid, instruction_out_V).
//   - Partial results travel top-to-bottom on V_result_in/V_result_out.
//   - Row BREAK_ROW takes its A operand and previous result from fabric
//     signals (gemm_in_V_fabric / gemm_P_fabric_wire) instead of the cascade.
//   - The last row is a GEMM_dsp_unit_last_ROW whose P output drives V_out;
//     its V_result_out feeds a GEMM_accumulator that drives V_ACC_out.
//
// NOTE(review): inst_in, V_out and V_ACC_out are indexed per column (0 ..
// array_vertical-1) but dimensioned with array_horizontal. That is harmless
// only while the array is square; confirm before using a non-square shape.
module GEMM_systolic_array #(
    parameter array_horizontal = `ARRAY_SIZE_H,
    parameter array_vertical   = `ARRAY_SIZE_V,
    parameter h_in_size        = `GEMM_MAC_UNIT_IN_H,
    parameter v_in_size        = `GEMM_MAC_UNIT_IN_V
) (
    input logic clk,
    input logic rst_n,
    input logic i_clear,

    // =| Global Controls |=
    input logic i_weight_valid,  // Enables horizontal weight shifting

    // =| Delay line input (from FMap Cache and Weight Dispatcher) |=
    input logic [h_in_size-1:0] H_in    [0:array_horizontal-1],
    input logic [v_in_size-1:0] V_in    [  0:array_vertical-1],
    input logic                 in_valid[  0:array_vertical-1],  // Staggered valid from FMap delay line

    // =| VLIW Instruction Input (Staggered along with V_in) |=
    input logic [2:0] inst_in      [0:array_horizontal-1],
    input logic       inst_valid_in[  0:array_vertical-1],

    // =| Outputs |=
    output logic [`DSP48E2_POUT_SIZE-1:0] V_out      [0:array_horizontal-1],
    output logic [`DSP48E2_POUT_SIZE-1:0] V_ACC_out  [0:array_horizontal-1],
    // FIX: was dimensioned with the undefined identifier "array_horizontalE_V".
    // The port is driven per column (col < array_vertical) below.
    output logic                          V_ACC_valid[  0:array_vertical-1]
);

  // Row whose PEs break the vertical DSP cascade and use fabric routing.
  localparam int BREAK_ROW = 16;

  // ===| Systolic Array Internal Wires |==================================
  // Horizontal logic wires (Weights)
  // Size is [Row][Col], data flows Left to Right.
  // H_in feeds into Col 0.
  logic [`GEMM_MAC_UNIT_IN_H - 1 : 0] gemm_H_wire[0 : array_horizontal-1][0 : array_vertical];

  // Vertical logic wires (Feature Map / ACIN)
  // Size is [Row][Col], data flows Top to Bottom.
  logic [29:0] gemm_ACIN_wire[0 : array_horizontal][0 : array_vertical-1];

  // Instruction logic wires (Top to Bottom)
  logic [2:0] gemm_inst_wire[0 : array_horizontal][0 : array_vertical-1];
  logic gemm_inst_valid_wire[0 : array_horizontal][0 : array_vertical-1];

  // Valid signal logic wires (Top to Bottom)
  logic gemm_V_valid_wire[0 : array_horizontal][0 : array_vertical-1];

  // Result shift wires (Top to Bottom)
  logic [`DSP48E2_POUT_SIZE - 1 : 0] gemm_V_result_wire[0 : array_horizontal][0 : array_vertical-1];

  // Fabric copy of every PE's P output; row BREAK_ROW reads row BREAK_ROW-1.
  logic [47:0] gemm_P_fabric_wire[0 : array_horizontal-1][0 : array_vertical-1];

  // V_in fabric delay line: one register per row, read by row BREAK_ROW.
  logic [29:0] gemm_in_V_fabric[0 : array_horizontal][0 : array_vertical-1];

  // ===| Input Assignments |==============================================
  // >>>| TOP INPUT LANE |<<<
  genvar i;
  generate
    for (i = 0; i < array_vertical; i++) begin : assign_v_inputs
      // Top row ACIN is not used (A is used directly)
      assign gemm_ACIN_wire[0][i]       = '0;
      assign gemm_inst_wire[0][i]       = inst_in[i];
      assign gemm_inst_valid_wire[0][i] = inst_valid_in[i];
      assign gemm_V_valid_wire[0][i]    = in_valid[i];
      // Top row PCIN is 0
      assign gemm_V_result_wire[0][i]   = '0;
      // Head of the fabric delay line: V_in with three zero bits prepended
      assign gemm_in_V_fabric[0][i]     = {3'd0, V_in[i]};
    end

    for (i = 0; i < array_horizontal; i++) begin : assign_h_inputs
      assign gemm_H_wire[i][0] = H_in[i];
    end
  endgenerate

  // >>>| Normal Lane |<<<
  // Fabric delay line for V_in: each stage advances only while the matching
  // row's valid is high. These registers have no reset.
  genvar d_row, d_col;
  generate
    for (d_row = 0; d_row < array_horizontal; d_row++) begin : v_delay_row
      for (d_col = 0; d_col < array_vertical; d_col++) begin : v_delay_col
        always_ff @(posedge clk) begin
          if (gemm_V_valid_wire[d_row][d_col]) begin
            gemm_in_V_fabric[d_row+1][d_col] <= gemm_in_V_fabric[d_row][d_col];
          end
        end
      end
    end
  endgenerate

  // ===| 2D Array Instantiation |=========================================
  genvar row, col;
  generate
    for (row = 0; row < array_horizontal; row++) begin : gemm_row_loop
      for (col = 0; col < array_vertical; col++) begin : gemm_col_loop

        if (row == array_horizontal - 1) begin : last_row
          // Bottom row: P goes to V_out, V_result_out goes to the accumulator.
          GEMM_dsp_unit_last_ROW #(
              .IS_TOP_ROW(0)
          ) dsp_unit_last_ROW (
              .clk              (clk),
              .rst_n            (rst_n),
              .i_clear          (i_clear),
              .i_valid          (gemm_V_valid_wire[row][col]),
              .inst_valid_in_V  (gemm_inst_valid_wire[row][col]),
              .i_weight_valid   (i_weight_valid),
              .o_valid          (gemm_V_valid_wire[row+1][col]),
              .in_H             (gemm_H_wire[row][col]),
              .out_H            (gemm_H_wire[row][col+1]),
              .ACIN_in          (gemm_ACIN_wire[row][col]),
              .ACOUT_out        (gemm_ACIN_wire[row+1][col]),
              .instruction_in_V (gemm_inst_wire[row][col]),
              .instruction_out_V(gemm_inst_wire[row+1][col]),
              .inst_valid_out_V (gemm_inst_valid_wire[row+1][col]),
              .V_result_in      (gemm_V_result_wire[row][col]),
              .V_result_out     (gemm_V_result_wire[row+1][col]),
              .gemm_unit_results(V_out[col])
          );

        end else if (row == BREAK_ROW) begin : break_row
          GEMM_dsp_unit #(
              .IS_TOP_ROW   (0),
              .BREAK_CASCADE(1)
          ) dsp_unit_break (
              .clk              (clk),
              .rst_n            (rst_n),
              .i_clear          (i_clear),
              .i_valid          (gemm_V_valid_wire[row][col]),
              .inst_valid_in_V  (gemm_inst_valid_wire[row][col]),
              .i_weight_valid   (i_weight_valid),
              .o_valid          (gemm_V_valid_wire[row+1][col]),
              .in_H             (gemm_H_wire[row][col]),
              .out_H            (gemm_H_wire[row][col+1]),
              // Take delayed input from fabric shift register instead of CASCADE
              .in_V             (gemm_in_V_fabric[row][col]),
              .ACIN_in          (30'd0),
              .ACOUT_out        (gemm_ACIN_wire[row+1][col]),
              .instruction_in_V (gemm_inst_wire[row][col]),
              .instruction_out_V(gemm_inst_wire[row+1][col]),
              .inst_valid_out_V (gemm_inst_valid_wire[row+1][col]),
              // Take result from previous row's P fabric out
              .V_result_in      (gemm_P_fabric_wire[row-1][col]),
              .V_result_out     (gemm_V_result_wire[row+1][col]),
              .P_fabric_out     (gemm_P_fabric_wire[row][col])
          );

        end else begin : normal_row
          GEMM_dsp_unit #(
              .IS_TOP_ROW   (row == 0 ? 1 : 0),
              .BREAK_CASCADE(0)
          ) dsp_unit (
              .clk              (clk),
              .rst_n            (rst_n),
              .i_clear          (i_clear),
              .i_valid          (gemm_V_valid_wire[row][col]),
              .inst_valid_in_V  (gemm_inst_valid_wire[row][col]),
              .i_weight_valid   (i_weight_valid),
              .o_valid          (gemm_V_valid_wire[row+1][col]),
              .in_H             (gemm_H_wire[row][col]),
              .out_H            (gemm_H_wire[row][col+1]),
              // Only the top row takes V_in directly; other rows use ACIN.
              .in_V             (row == 0 ? {3'd0, V_in[col]} : 30'd0),
              .ACIN_in          (gemm_ACIN_wire[row][col]),
              .ACOUT_out        (gemm_ACIN_wire[row+1][col]),
              .instruction_in_V (gemm_inst_wire[row][col]),
              .instruction_out_V(gemm_inst_wire[row+1][col]),
              .inst_valid_out_V (gemm_inst_valid_wire[row+1][col]),
              .V_result_in      (gemm_V_result_wire[row][col]),
              .V_result_out     (gemm_V_result_wire[row+1][col]),
              .P_fabric_out     (gemm_P_fabric_wire[row][col])
          );
        end

      end
    end

    // Accumulators for the final row
    for (col = 0; col < array_vertical; col++) begin : gemm_ACC_col_loop
      // The accumulator is enabled by the instruction-valid that leaves the
      // last row of this column.
      assign V_ACC_valid[col] = gemm_inst_valid_wire[array_horizontal][col];

      GEMM_accumulator #() gemm_ACC (
          .clk            (clk),
          .rst_n          (rst_n),
          .i_clear        (i_clear),
          .i_valid        (V_ACC_valid[col]),
          // PCIN connects to the V_result_out of the LAST_ROW
          // (stored at index [array_horizontal])
          .PCIN           (gemm_V_result_wire[array_horizontal][col]),
          .gemm_ACC_result(V_ACC_out[col])
      );
    end
  endgenerate

endmodule
GEMM_Array.svh
// Systolic array dimensions. These are the defaults for the array_horizontal /
// array_vertical parameters of the GEMM modules and size their port arrays.
`define ARRAY_SIZE_H 32
`define ARRAY_SIZE_V 32
// NOTE(review): presumably the number of clock cycles spent in the GEMM
// instruction dispatcher; not referenced by the sources shown with this
// header -- confirm against its users.
`define gemm_instruction_dispatcher_CLOCK_CONSUMPTION 1
// systolic delay line
// NOTE(review): presumably the shortest allowed delay-line depth; confirm.
`define MINIMUM_DELAY_LINE_LENGTH 1
// systolic delay line V | TYPE:INT4
// NOTE(review): the array modules route INT4 weights horizontally (H_in), so
// the V/H labels in these two comments may be swapped -- confirm.
`define INT4_WIDTH 4
// systolic delay line H | TYPE: BFLOAT 16
`define BFLOAT_WIDTH 16
DSP MAC 유닛¶
- GEMM_dsp_unit.sv — 단일 row PE. 피드백 P-레지스터를 가진 DSP48E2 MAC.
- GEMM_dsp_unit_last_ROW.sv — 누산기 체인을 종결하고 partial sum 을 내보내는 최하단 row 변형.
GEMM_dsp_unit.sv
`timescale 1ns / 1ps
`include "GLOBAL_CONST.svh"
`include "GEMM_Array.svh"
// GEMM_dsp_unit
//
// One processing element of the systolic array, built around a single
// DSP48E2 primitive.
//
// Parameters
//   IS_TOP_ROW    - 1: the A operand comes from in_V (DSP A port, "DIRECT").
//   BREAK_CASCADE - 1: the A operand comes from in_V and the previous result
//                   comes in through the C port instead of PCIN.
//   ACIN_size / ACOUT_size / POUT_size - bit widths of the vertical ports.
//
// Instruction bits, as decoded below
//   [0] calc enable : selects the multiply-accumulate OPMODE and enables CEP.
//   [1] B-reg mode  : when set, CEB1/CEB2 follow i_valid; otherwise CEB1 follows
//                     load_trigger | i_weight_valid and CEB2 follows load_trigger.
//   [2] flush       : starts the 4-stage flush/load shift sequence.
//
// FIX: the vertical ports were declared [ACIN_size:0] / [POUT_size:0], one bit
// wider than the 30-bit / 48-bit buses they connect to. They are now
// [size-1:0], matching the [`DSP48E2_POUT_SIZE-1:0] convention used by the
// array. This assumes `DSP48E2_A_WIDTH is the A-port bit count (30).
module GEMM_dsp_unit #(
    parameter IS_TOP_ROW    = 0,
    parameter BREAK_CASCADE = 0,  // If 1, break the vertical cascade chain here
    parameter ACIN_size     = `DSP48E2_A_WIDTH,
    parameter ACOUT_size    = `DSP48E2_A_WIDTH,
    parameter POUT_size     = `DSP48E2_POUT_SIZE
) (
    input logic clk,
    input logic rst_n,
    input logic i_clear,

    input  logic i_valid,         // Feature Map Data Valid
    input  logic i_weight_valid,  // Background Weight Shift Enable
    output logic o_valid,         // i_valid delayed by one clock

    // [Horizontal] int4 (4-bit) -> external FF -> DSP B port
    input  logic [`GEMM_MAC_UNIT_IN_H - 1:0] in_H,
    output logic [`GEMM_MAC_UNIT_IN_H - 1:0] out_H,

    // [Vertical] 30-bit -> DSP A/ACIN port
    input  logic [ ACIN_size-1:0] in_V,      // Used if IS_TOP_ROW == 1 or BREAK_CASCADE == 1
    input  logic [ ACIN_size-1:0] ACIN_in,   // Used if IS_TOP_ROW == 0 and BREAK_CASCADE == 0
    output logic [ACOUT_size-1:0] ACOUT_out,

    // [3-Bit VLIW Instruction]
    input  logic [2:0] instruction_in_V,
    output logic [2:0] instruction_out_V,
    input  logic       inst_valid_in_V,   // Cascaded Instruction Valid
    output logic       inst_valid_out_V,

    // vertical shift port
    input  logic [POUT_size-1:0] V_result_in,   // PCIN (or Fabric C) from upper DSP
    output logic [POUT_size-1:0] V_result_out,  // PCOUT to lower DSP's PCIN
    output logic [POUT_size-1:0] P_fabric_out   // P to lower DSP's Fabric C if broken
);

  // ===| [Instruction Latch (Event-Driven)] |============================
  // Holds the most recent valid instruction until the next one arrives.
  logic [2:0] current_inst;
  always_ff @(posedge clk) begin
    if (!rst_n || i_clear) begin
      current_inst <= 3'b000;
    end else if (inst_valid_in_V) begin
      current_inst <= instruction_in_V;
    end
  end

  // Pass instruction and its valid signal down to the next PE (one-clock
  // delay). Note: cleared by rst_n only, not by i_clear.
  always_ff @(posedge clk) begin
    if (!rst_n) begin
      instruction_out_V <= 3'b000;
      inst_valid_out_V  <= 1'b0;
    end else begin
      instruction_out_V <= instruction_in_V;
      inst_valid_out_V  <= inst_valid_in_V;
    end
  end

  // ===| [The "Flush & Load" Sequencer] |================================
  // A one-hot pulse enters bit 0 when a valid instruction has bit 2 set and
  // then shifts up one position per clock.
  logic [3:0] flush_sequence;
  always_ff @(posedge clk) begin
    if (!rst_n || i_clear) begin
      flush_sequence <= 4'd0;
    end else begin
      flush_sequence <= {flush_sequence[2:0], 1'b0};
      if (inst_valid_in_V && instruction_in_V[2] == 1'b1) begin
        flush_sequence[0] <= 1'b1;
      end
    end
  end

  // ===| [Hardware Mapping (VLIW Decoding)] |============================
  logic [8:0] dynamic_opmode;
  logic [3:0] dynamic_alumode;
  logic       is_flushing;
  assign is_flushing = flush_sequence[1] | flush_sequence[2];

  // OPMODE layout: W(2), Z(3), Y(2), X(2)
  // If BREAK_CASCADE == 1, the previous result arrives on the C port instead
  // of PCIN. Z-mux: 001 is PCIN, 011 is C.
  localparam logic [2:0] Z_MUX = BREAK_CASCADE ? 3'b011 : 3'b001;

  always_comb begin
    if (is_flushing) begin
      // Flush: P = 0 + 0 + 0 (Clear accumulator)
      dynamic_opmode  = 9'b00_000_00_00;
      dynamic_alumode = 4'b0000;
    end else if (current_inst[0] == 1'b1) begin
      // Calc: P = P_prev + A*B
      dynamic_opmode  = {2'b00, Z_MUX, 2'b01, 2'b01};
      dynamic_alumode = 4'b0000;
    end else begin
      // Idle: P = P_prev (Pass through)
      dynamic_opmode  = {2'b00, Z_MUX, 2'b00, 2'b00};
      dynamic_alumode = 4'b0000;
    end
  end

  // The P register only updates while calculating or flushing.
  logic dsp_ce_p;
  assign dsp_ce_p = current_inst[0] | is_flushing;

  // ===| [Fabric FF & Weight Pipeline] |=================================
  // Forward the weight to the next column while weight shifting is enabled.
  always_ff @(posedge clk) begin
    if (!rst_n || i_clear) begin
      out_H <= 0;
    end else begin
      if (i_weight_valid) begin
        out_H <= in_H;
      end
    end
  end

  // ===| [Dual B-Register Control] |================
  logic dsp_ce_b1;
  logic dsp_ce_b2;
  logic load_trigger;
  assign load_trigger = flush_sequence[3];  // Last step of the flush sequence

  always_comb begin
    if (current_inst[1] == 1'b1) begin
      dsp_ce_b1 = i_valid;
      dsp_ce_b2 = i_valid;
    end else begin
      dsp_ce_b1 = load_trigger | i_weight_valid;
      dsp_ce_b2 = load_trigger;
    end
  end

  // o_valid is i_valid delayed by one clock.
  logic valid_delay;
  always_ff @(posedge clk) begin
    if (!rst_n) valid_delay <= 1'b0;
    else valid_delay <= i_valid;
  end
  assign o_valid = valid_delay;

  // <><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><>
  // [DSP48E2 primitive instantiation] <><><><><><><><><><><><><><><><><>

  // Sign-extend the weight to the 18-bit B port (14 sign bits + in_H).
  logic [17:0] in_H_padded;
  assign in_H_padded = {{14{in_H[`GEMM_MAC_UNIT_IN_H-1]}}, in_H};

  // If TOP_ROW or BREAK_CASCADE, we get A from Fabric (in_V). Otherwise from ACIN.
  logic [29:0] dsp_a_input;
  assign dsp_a_input = (IS_TOP_ROW || BREAK_CASCADE) ? in_V : 30'd0;

  logic [29:0] dsp_acin_input;
  assign dsp_acin_input = (IS_TOP_ROW || BREAK_CASCADE) ? 30'd0 : ACIN_in;

  // If BREAK_CASCADE, we receive the accumulated result from Fabric C (V_result_in)
  logic [47:0] dsp_c_input;
  assign dsp_c_input = BREAK_CASCADE ? V_result_in : 48'd0;

  logic [47:0] dsp_pcin_input;
  assign dsp_pcin_input = BREAK_CASCADE ? 48'd0 : V_result_in;

  logic [47:0] p_internal;

  // Note: the primitive's resets are tied to i_clear only, not rst_n.
  DSP48E2 #(
      .A_INPUT   ((IS_TOP_ROW || BREAK_CASCADE) ? "DIRECT" : "CASCADE"),
      .B_INPUT   ("DIRECT"),
      .AREG      (1),
      .BREG      (2),
      .CREG      (0),
      .MREG      (1),
      .PREG      (1),
      .OPMODEREG (1),
      .ALUMODEREG(1),
      .USE_MULT  ("MULTIPLY")
  ) DSP_HARD_BLOCK (
      .CLK          (clk),
      .RSTA         (i_clear),
      .RSTB         (i_clear),
      .RSTM         (i_clear),
      .RSTP         (i_clear),
      .RSTCTRL      (i_clear),
      .RSTALLCARRYIN(i_clear),
      .RSTALUMODE   (i_clear),
      .RSTC         (i_clear),
      .CEA1         (i_valid),
      .CEA2         (i_valid),
      .CEB1         (dsp_ce_b1),
      .CEB2         (dsp_ce_b2),
      .CEM          (i_valid),
      .CEP          (dsp_ce_p),
      .CECTRL       (1'b1),
      .CEALUMODE    (1'b1),
      .CEC          (1'b1),  // C clock enable; CREG is 0 in this configuration
      .A            (dsp_a_input),
      .ACIN         (dsp_acin_input),
      .ACOUT        (ACOUT_out),
      .B            (in_H_padded),
      .C            (dsp_c_input),
      .PCIN         (dsp_pcin_input),
      .PCOUT        (V_result_out),
      .OPMODE       (dynamic_opmode),
      .ALUMODE      (dynamic_alumode),
      .P            (p_internal)
  );

  // We must expose the internal P so it can be routed via fabric to the next DSP
  assign P_fabric_out = p_internal;

endmodule
GEMM_dsp_unit_last_ROW.sv
`include "GLOBAL_CONST.svh"
`timescale 1ns / 1ps
`include "GEMM_Array.svh"
// GEMM_dsp_unit_last_ROW
//
// Bottom-row processing element. It uses the same instruction latch, flush
// sequencer and B-register control as the regular PE, with different OPMODE
// values, and drives the DSP48E2 P output onto gemm_unit_results.
//
// OPMODE used here (W(2), Z(3), Y(2), X(2)):
//   flushing          : 00_001_00_00
//   calc (inst[0]==1) : 00_010_01_01
//   otherwise         : 00_010_00_00
module GEMM_dsp_unit_last_ROW #(
    parameter IS_TOP_ROW = 0  // Selects A_INPUT: nonzero -> "DIRECT", zero -> "CASCADE"
) (
    input logic clk,
    input logic rst_n,
    input logic i_clear,

    input  logic i_valid,
    input  logic inst_valid_in_V,
    input  logic i_weight_valid,
    output logic o_valid,

    // [Horizontal] weight -> DSP B port, forwarded to the next column
    input  logic [`GEMM_MAC_UNIT_IN_H - 1:0] in_H,
    output logic [`GEMM_MAC_UNIT_IN_H - 1:0] out_H,

    // [Vertical] 30-bit -> DSP ACIN port
    input  logic [29:0] ACIN_in,
    output logic [29:0] ACOUT_out,

    // [3-Bit VLIW Instruction]
    input  logic [2:0] instruction_in_V,
    output logic [2:0] instruction_out_V,
    output logic       inst_valid_out_V,

    // Vertical result chain
    input  logic [47:0] V_result_in,   // PCIN from upper DSP's PCOUT
    output logic [47:0] V_result_out,  // PCOUT
    // DSP P output
    output logic [47:0] gemm_unit_results
);

  // ---------------------------------------------------------------------
  // Instruction forwarding: one register stage, reset by rst_n only.
  // ---------------------------------------------------------------------
  always_ff @(posedge clk) begin
    if (rst_n) begin
      inst_valid_out_V  <= inst_valid_in_V;
      instruction_out_V <= instruction_in_V;
    end else begin
      inst_valid_out_V  <= 1'b0;
      instruction_out_V <= 3'b000;
    end
  end

  // ---------------------------------------------------------------------
  // Instruction latch: keeps the last valid instruction.
  // ---------------------------------------------------------------------
  logic [2:0] latched_inst;

  always_ff @(posedge clk) begin
    if (!rst_n || i_clear) latched_inst <= 3'b000;
    else if (inst_valid_in_V) latched_inst <= instruction_in_V;
  end

  // ---------------------------------------------------------------------
  // Flush & load sequencer: a pulse enters bit 0 when a valid instruction
  // carries bit 2, then moves up one position per clock.
  // ---------------------------------------------------------------------
  logic [3:0] flush_sr;

  always_ff @(posedge clk) begin
    if (!rst_n || i_clear) begin
      flush_sr <= 4'd0;
    end else if (inst_valid_in_V && instruction_in_V[2] == 1'b1) begin
      flush_sr <= {flush_sr[2:0], 1'b1};
    end else begin
      flush_sr <= {flush_sr[2:0], 1'b0};
    end
  end

  logic flush_active;
  logic b_load_pulse;
  assign flush_active = flush_sr[1] | flush_sr[2];
  assign b_load_pulse = flush_sr[3];

  // ---------------------------------------------------------------------
  // OPMODE / ALUMODE decode. ALUMODE is always 0000.
  // ---------------------------------------------------------------------
  logic [8:0] opmode_sel;
  logic [3:0] alumode_sel;

  always_comb begin
    alumode_sel = 4'b0000;
    opmode_sel  = 9'b00_010_00_00;
    if (is_flush_or_calc_flush()) begin
      opmode_sel = 9'b00_001_00_00;
    end else if (latched_inst[0] == 1'b1) begin
      opmode_sel = 9'b00_010_01_01;
    end
  end

  // Small helper so the decode above reads as a priority list.
  function automatic logic is_flush_or_calc_flush();
    return flush_active;
  endfunction

  // P register clock enable
  logic p_reg_ce;
  assign p_reg_ce = latched_inst[0] | flush_active;

  // ---------------------------------------------------------------------
  // Weight forwarding register
  // ---------------------------------------------------------------------
  always_ff @(posedge clk) begin
    if (!rst_n || i_clear) out_H <= 0;
    else if (i_weight_valid) out_H <= in_H;
  end

  // ---------------------------------------------------------------------
  // B1/B2 clock enables
  // ---------------------------------------------------------------------
  logic b1_ce;
  logic b2_ce;

  always_comb begin
    b1_ce = b_load_pulse | i_weight_valid;
    b2_ce = b_load_pulse;
    if (latched_inst[1] == 1'b1) begin
      b1_ce = i_valid;
      b2_ce = i_valid;
    end
  end

  // ---------------------------------------------------------------------
  // o_valid: i_valid delayed by one clock, reset by rst_n only.
  // ---------------------------------------------------------------------
  always_ff @(posedge clk) begin
    if (!rst_n) o_valid <= 1'b0;
    else o_valid <= i_valid;
  end

  // ---------------------------------------------------------------------
  // DSP48E2 primitive
  // ---------------------------------------------------------------------
  // Weight widened to the 18-bit B port by replicating its top bit 14 times.
  logic [17:0] b_operand;
  assign b_operand = {{14{in_H[`GEMM_MAC_UNIT_IN_H-1]}}, in_H};

  DSP48E2 #(
      .USE_MULT  ("MULTIPLY"),
      .A_INPUT   (IS_TOP_ROW ? "DIRECT" : "CASCADE"),
      .B_INPUT   ("DIRECT"),
      .AREG      (1),
      .BREG      (2),
      .CREG      (0),
      .MREG      (1),
      .PREG      (1),
      .OPMODEREG (1),
      .ALUMODEREG(1)
  ) DSP_HARD_BLOCK (
      .CLK          (clk),
      // Resets: all tied to i_clear
      .RSTA         (i_clear),
      .RSTB         (i_clear),
      .RSTC         (i_clear),
      .RSTM         (i_clear),
      .RSTP         (i_clear),
      .RSTCTRL      (i_clear),
      .RSTALUMODE   (i_clear),
      .RSTALLCARRYIN(i_clear),
      // Clock enables
      .CEA1         (i_valid),
      .CEA2         (i_valid),
      .CEB1         (b1_ce),
      .CEB2         (b2_ce),
      .CEM          (i_valid),
      .CEP          (p_reg_ce),
      .CECTRL       (1'b1),
      .CEALUMODE    (1'b1),
      // Operands
      .A            (30'd0),
      .ACIN         (ACIN_in),
      .B            (b_operand),
      .C            (48'd0),
      .PCIN         (V_result_in),
      // Control
      .OPMODE       (opmode_sel),
      .ALUMODE      (alumode_sel),
      // Results
      .ACOUT        (ACOUT_out),
      .PCOUT        (V_result_out),
      .P            (gemm_unit_results)
  );

endmodule
공급단¶
- GEMM_weight_dispatcher.sv — HP0/HP1 스트림을 column 별 가중치 타일로 분해.
- GEMM_fmap_staggered_delay.sv — 활성화 행을 어레이에 삼각형 주입하도록 단계 지연시키는 shift register bank.
- GEMM_accumulator.sv — 다중 타일 동안 DSP P-레지스터 체인 뒤로 따라가는 partial sum 누산기.
GEMM_weight_dispatcher.sv
`timescale 1ns / 1ps
`include "GLOBAL_CONST.svh"
`include "GEMM_Array.svh"
/**
* Module: GEMM_weight_dispatcher
* Description:
* Unpacks 128-bit wide data into 32 individual 4-bit INT4 weights.
* Provides registered outputs to maintain 400MHz timing.
*/
// GEMM_weight_dispatcher
//
// Registers one FIFO word and splits its low weight_cnt * weight_size bits
// into weight_cnt separate weight_size-bit outputs (weight i comes from bits
// [i*weight_size +: weight_size]). weight_valid is fifo_valid delayed by the
// same single register stage. fifo_ready is tied high.
//
// FIX: the ports were declared as unpacked arrays, but the body treated
// fifo_valid / fifo_ready as scalars (illegal assignments) and the reset
// branch indexed weight_out with an undefined identifier. The ports are now a
// packed data word with scalar valid/ready, which matches how
// GEMM_systolic_top connects this module.
//
// NOTE(review): fifo_data is declared `HP_PORT_MAX_WIDTH wide to match the
// caller. Any bits above weight_cnt * weight_size are ignored -- confirm that
// `HP_PORT_MAX_WIDTH >= weight_cnt * weight_size and that dropping the upper
// bits is intended.
//
// weight_lane_cnt, weight_width_per_lane, array_horizontal and array_vertical
// are accepted for interface compatibility and are not used in the body.
module GEMM_weight_dispatcher #(
    parameter weight_lane_cnt       = `HP_PORT_CNT,
    parameter weight_width_per_lane = `HP_PORT_SINGLE_WIDTH,
    parameter weight_size           = `INT4,
    parameter weight_cnt            = `HP_WEIGHT_CNT(`HP_PORT_SINGLE_WIDTH, `INT4),
    parameter array_horizontal      = `ARRAY_SIZE_H,
    parameter array_vertical        = `ARRAY_SIZE_V
) (
    input logic clk,
    input logic rst_n,  // Active-low, sampled synchronously

    // ===| Packed weight word from FIFO |===
    input  logic [`HP_PORT_MAX_WIDTH-1:0] fifo_data,
    input  logic                          fifo_valid,
    output logic                          fifo_ready,

    // ===| weight_cnt x weight_size-bit outputs to the systolic array |===
    output logic [weight_size-1:0] weight_out[0:weight_cnt-1],
    output logic                   weight_valid
);

  // ===| Flow Control: always ready (no downstream back-pressure input) |===
  assign fifo_ready = 1'b1;

  // ===| Unpacking Logic with Pipeline Registers |===
  // Registering the outputs keeps the fan-out to the array off the FIFO path.
  always_ff @(posedge clk) begin
    if (!rst_n) begin
      weight_valid <= 1'b0;
      for (int i = 0; i < weight_cnt; i++) begin
        weight_out[i] <= '0;
      end
    end else begin
      weight_valid <= fifo_valid;
      // Slice the packed word into weight_size-bit fields, lowest bits first.
      for (int i = 0; i < weight_cnt; i++) begin
        weight_out[i] <= fifo_data[i*weight_size+:weight_size];
      end
    end
  end

endmodule
GEMM_fmap_staggered_delay.sv
`include "GLOBAL_CONST.svh"
`timescale 1ns / 1ps
`include "GEMM_Array.svh"
// GEMM_fmap_staggered_dispatch
//
// Per-lane delay line for the feature-map broadcast, its valid, and the
// global instruction with its valid. Lane n passes all four signals through
// n + 1 register stages, so lane 0 has one stage and each following lane is
// one clock later than the previous one.
//
// The data registers are fmap_width bits wide; row_data is fmap_out_width
// bits wide and receives the last stage by plain assignment.
module GEMM_fmap_staggered_dispatch #(
    // Fixed-point width after shifter
    parameter fmap_width     = 27,
    parameter array_size     = 32,
    parameter fmap_out_width = `ACIN
) (
    input logic clk,
    input logic rst_n,

    // =| Input from FMap Cache Broadcast |=
    input logic [fmap_width-1:0] fmap_in   [0:array_size-1],
    input logic                  fmap_valid,

    // =| Global Instruction from FSM |=
    input logic [2:0] global_inst,
    input logic       global_inst_valid,

    // =| Staggered Outputs to Systolic Array Vertical Lanes |=
    output logic [fmap_out_width-1:0] row_data      [0:array_size-1],
    output logic                      row_valid     [0:array_size-1],
    output logic [               2:0] row_inst      [0:array_size-1],
    output logic                      row_inst_valid[0:array_size-1]
);

  genvar lane;
  generate
    for (lane = 0; lane < array_size; lane++) begin : col_gen
      // Number of register stages on this lane.
      localparam int STAGES = lane + 1;

      logic [fmap_width-1:0] data_q      [0:STAGES-1];
      logic                  valid_q     [0:STAGES-1];
      logic [           2:0] inst_q      [0:STAGES-1];
      logic                  inst_valid_q[0:STAGES-1];

      always_ff @(posedge clk) begin
        if (rst_n) begin
          // Later stages take the value of the stage before them ...
          for (int s = STAGES - 1; s > 0; s--) begin
            data_q[s]       <= data_q[s-1];
            valid_q[s]      <= valid_q[s-1];
            inst_q[s]       <= inst_q[s-1];
            inst_valid_q[s] <= inst_valid_q[s-1];
          end
          // ... and stage 0 samples the module inputs.
          data_q[0]       <= fmap_in[lane];
          valid_q[0]      <= fmap_valid;
          inst_q[0]       <= global_inst;
          inst_valid_q[0] <= global_inst_valid;
        end else begin
          for (int s = 0; s < STAGES; s++) begin
            data_q[s]       <= '0;
            valid_q[s]      <= 1'b0;
            inst_q[s]       <= 3'b000;
            inst_valid_q[s] <= 1'b0;
          end
        end
      end

      // Lane outputs come from the final stage.
      assign row_data[lane]       = data_q[STAGES-1];
      assign row_valid[lane]      = valid_q[STAGES-1];
      assign row_inst[lane]       = inst_q[STAGES-1];
      assign row_inst_valid[lane] = inst_valid_q[STAGES-1];
    end
  endgenerate

endmodule
GEMM_accumulator.sv
`include "GLOBAL_CONST.svh"
`timescale 1ns / 1ps
`include "GEMM_Array.svh"
// GEMM_accumulator
//
// Running-sum stage built from one DSP48E2 with the multiplier disabled.
// With the constant OPMODE below (Z = PCIN, X = P) the P register is updated
// to P + PCIN on every clock where i_valid is high, and is reset when i_clear
// is high or rst_n is low. Only the P register is used (PREG = 1); every
// other pipeline register is bypassed and its enable/reset is tied off.
module GEMM_accumulator (
    input logic clk,
    input logic rst_n,
    input logic i_clear,
    input logic i_valid,

    input logic [47:0] PCIN,

    // Accumulated value, taken from the DSP P port
    output logic [47:0] gemm_ACC_result
);

  // OPMODE: W=00, Z=PCIN(001), Y=0(00), X=P(10) -> P = P + PCIN
  localparam logic [8:0] ACC_OPMODE = 9'b00_001_00_10;
  localparam logic [3:0] ACC_ALUMODE = 4'b0000;

  DSP48E2 #(
      // Disable multiplier
      .USE_MULT  ("NONE"),
      .PREG      (1),
      .AREG      (0),
      .ACASCREG  (0),
      .BREG      (0),
      .BCASCREG  (0),
      .CREG      (0),
      .MREG      (0),
      .OPMODEREG (0),
      .ALUMODEREG(0)
  ) DSP_ACC (
      .CLK          (clk),
      // The P register is the only stateful element in use.
      .CEP          (i_valid),
      .RSTP         (i_clear || ~rst_n),
      // Unused register enables
      .CEA1         (1'b0),
      .CEA2         (1'b0),
      .CEB1         (1'b0),
      .CEB2         (1'b0),
      .CEC          (1'b0),
      .CEM          (1'b0),
      .CECTRL       (1'b0),
      .CEALUMODE    (1'b0),
      // Unused register resets
      .RSTA         (1'b0),
      .RSTB         (1'b0),
      .RSTC         (1'b0),
      .RSTM         (1'b0),
      .RSTCTRL      (1'b0),
      .RSTALUMODE   (1'b0),
      .RSTALLCARRYIN(1'b0),
      // Data inputs
      .A            (30'd0),
      .B            (18'd0),
      .C            (48'd0),
      .PCIN         (PCIN),
      // Control
      .OPMODE       (ACC_OPMODE),
      .ALUMODE      (ACC_ALUMODE),
      // Outputs
      .P            (gemm_ACC_result),
      .PCOUT        (),
      .ACOUT        ()
  );

endmodule
결과 경로¶
- FROM_mat_result_packer.sv — 정규화된 16 비트(BF16) 결과 32 개를 모아 128 비트 AXI beat 로 패킹.
- mat_result_normalizer.sv — 48 비트(INT48) 누산 결과를 Emax 정렬과 함께 BF16 으로 barrel shift 다운.
FROM_mat_result_packer.sv
`include "GLOBAL_CONST.svh"
`timescale 1ns / 1ps
`include "GEMM_Array.svh"
/**
* Module: FROM_gemm_result_packer
* Description:
* Collects 32 staggered 16-bit results and packs them into 128-bit DMA words.
* Uses an internal FSM to sequentially scan and pack results from all columns.
*/
// FROM_gemm_result_packer
//
// Captures per-column BF16 results as they arrive (row_res_valid[i] pulses)
// and emits them as packed words, LANES_PER_BEAT columns per word, lowest
// column in the least-significant bits.
//
// Handshake: packed_valid is raised together with packed_data and both are
// held until a cycle where packed_ready is high; that cycle completes the
// transfer and packed_valid drops on the next clock.
//
// Changes relative to the previous revision:
//   - send_idx is declared before its first use.
//   - packed_data is loaded when packed_valid is raised, so valid is never
//     asserted over stale data, and valid is dropped after each transfer
//     instead of staying high through CHECK_VALID (which duplicated beats).
//   - A chunk's capture_valid bits are cleared when the chunk is copied into
//     packed_data, and a capture in the same cycle wins over the clear, so a
//     result that arrives during a send is kept for the next pass.
//   - The chunk size (was literal 8) and the last-chunk index (was literal
//     24) are derived from `AXI_DATA_WIDTH / `BF16_WIDTH and ARRAY_SIZE.
//
// Assumes ARRAY_SIZE is a multiple of LANES_PER_BEAT.
module FROM_gemm_result_packer #(
    parameter ARRAY_SIZE = 32
) (
    input logic clk,
    input logic rst_n,  // Active-low, sampled synchronously

    // =| Input from Normalizers (16-bit BF16) |=
    input logic [`BF16_WIDTH-1:0] row_res      [0:ARRAY_SIZE-1],
    input logic                   row_res_valid[0:ARRAY_SIZE-1],

    // =| Output to FIFO (128-bit) |=
    output logic [`AXI_DATA_WIDTH-1:0] packed_data,
    output logic                       packed_valid,
    input  logic                       packed_ready,

    // =| Status |=
    output logic o_busy
);

  // Results per output word (8 when the bus is 128 bits and a result is 16).
  localparam int LANES_PER_BEAT = `AXI_DATA_WIDTH / `BF16_WIDTH;
  // Index of the first column of the final chunk.
  localparam int LAST_CHUNK_IDX = ARRAY_SIZE - LANES_PER_BEAT;
  // One bit wider than needed to address a column (6 bits for ARRAY_SIZE 32).
  localparam int IDX_WIDTH = $clog2(ARRAY_SIZE) + 1;

  // ===| Internal Buffer to Hold Results |=======
  // Results are staggered in time, so each column is latched when it arrives.
  logic [`BF16_WIDTH-1:0] capture_reg  [0:ARRAY_SIZE-1];
  logic [ ARRAY_SIZE-1:0] capture_valid;

  // ===| State Machine for Packing |=======
  typedef enum logic [1:0] {
    IDLE,
    CHECK_VALID,
    SEND_DATA
  } state_t;

  state_t                 state;
  logic   [IDX_WIDTH-1:0] send_idx;  // First column of the chunk being handled

  // All columns of the current chunk have been captured.
  logic chunk_complete;
  assign chunk_complete = &capture_valid[send_idx+:LANES_PER_BEAT];

  // The cycle in which the current chunk is copied into packed_data.
  logic load_beat;
  assign load_beat = (state == CHECK_VALID) && chunk_complete;

  // Busy if any capture_valid bit is set or we are in a non-IDLE state
  assign o_busy = (|capture_valid) || (state != IDLE);

  // ===| Capture registers |=======
  always_ff @(posedge clk) begin
    if (!rst_n) begin
      capture_valid <= '0;
      for (int i = 0; i < ARRAY_SIZE; i++) capture_reg[i] <= '0;
    end else begin
      // Release the chunk that is being copied out this cycle ...
      if (load_beat) begin
        for (int k = 0; k < LANES_PER_BEAT; k++) begin
          capture_valid[send_idx+k] <= 1'b0;
        end
      end
      // ... then record new arrivals. This comes last so that a result
      // arriving in the same cycle as the release is not lost.
      for (int i = 0; i < ARRAY_SIZE; i++) begin
        if (row_res_valid[i]) begin
          capture_reg[i]   <= row_res[i];
          capture_valid[i] <= 1'b1;
        end
      end
    end
  end

  // ===| Packing FSM |=======
  always_ff @(posedge clk) begin
    if (!rst_n) begin
      state        <= IDLE;
      send_idx     <= '0;
      packed_valid <= 1'b0;
      packed_data  <= '0;
    end else begin
      case (state)
        IDLE: begin
          packed_valid <= 1'b0;
          if (|capture_valid) begin
            state    <= CHECK_VALID;
            send_idx <= '0;
          end
        end

        CHECK_VALID: begin
          // Wait until every column of the chunk has been captured, then
          // present the word and raise valid in the same clock edge.
          if (chunk_complete) begin
            for (int k = 0; k < LANES_PER_BEAT; k++) begin
              packed_data[k*`BF16_WIDTH+:`BF16_WIDTH] <= capture_reg[send_idx+k];
            end
            packed_valid <= 1'b1;
            state        <= SEND_DATA;
          end
        end

        SEND_DATA: begin
          // packed_valid and packed_data are held until the consumer is ready.
          if (packed_ready) begin
            packed_valid <= 1'b0;
            if (send_idx >= LAST_CHUNK_IDX) begin
              state    <= IDLE;
              send_idx <= '0;
            end else begin
              send_idx <= send_idx + IDX_WIDTH'(LANES_PER_BEAT);
              state    <= CHECK_VALID;
            end
          end
        end

        default: state <= IDLE;
      endcase
    end
  end

endmodule
mat_result_normalizer.sv
`include "GLOBAL_CONST.svh"
`timescale 1ns / 1ps
`include "GEMM_Array.svh"
/**
* Module: gemm_result_normalizer
* Description:
* Converts 48-bit 2's complement to Normalized Format (BF16-like).
* Pipeline: [1] Sign-Mag -> [2] LOD -> [3] Barrel Shift -> [4] Exp Adj
*/
// gemm_result_normalizer
//
// Converts a 48-bit two's-complement value into a 16-bit {sign, 8-bit
// exponent, 7-bit mantissa} word in four register stages:
//   [1] sign/magnitude split
//   [2] leading-one position
//   [3] mantissa extraction (truncating) and exponent computation
//   [4] output packing
// valid_in appears on valid_out four clocks later.
//
// Exponent = e_max + leading_one_position - ONE_BIT_POS, computed modulo 256.
// NOTE(review): the exponent is not saturated, so a result below 0 or above
// 255 wraps -- confirm that the caller's value range rules this out.
//
// FIX: the leading-one search used to start at bit 46. The magnitude of the
// most negative input (-2^47) has only bit 47 set, so it was reported as
// position 0. The search now covers bits 47..0.
//
// Parameters
//   ONE_BIT_POS - bit position that represents 1.0 in the incoming
//                 fixed-point value (previously the literal 26).
module gemm_result_normalizer #(
    parameter int ONE_BIT_POS = 26
) (
    input logic clk,
    input logic rst_n,  // Active-low, sampled synchronously

    input logic [47:0] data_in,  // 48-bit Accumulator Result
    input logic [ 7:0] e_max,    // Original delayed exponent for this column
    input logic        valid_in,

    output logic [15:0] data_out,  // 1:Sign, 8:Exp, 7:Mantissa (BF16 format)
    output logic        valid_out
);

  // ===| Stage 1: Sign-Magnitude Conversion |=======
  logic [47:0] s1_abs_data;
  logic        s1_sign;
  logic [ 7:0] s1_emax;
  logic        s1_valid;

  always_ff @(posedge clk) begin
    if (!rst_n) begin
      s1_valid    <= 1'b0;
      s1_sign     <= 1'b0;
      s1_emax     <= 8'd0;
      s1_abs_data <= 48'd0;
    end else begin
      s1_valid    <= valid_in;
      s1_sign     <= data_in[47];
      s1_emax     <= e_max;
      // If negative, invert and add 1 (2's complement to absolute).
      // For -2^47 the result is 2^47, i.e. only bit 47 set.
      s1_abs_data <= (data_in[47]) ? (~data_in + 1'b1) : data_in;
    end
  end

  // ===| Stage 2: Leading One Detection (LOD) |=======
  logic [ 5:0] s2_first_one_pos;
  logic [47:0] s2_abs_data;
  logic        s2_sign;
  logic [ 7:0] s2_emax;
  logic        s2_valid;

  always_ff @(posedge clk) begin
    if (!rst_n) begin
      s2_valid         <= 1'b0;
      s2_sign          <= 1'b0;
      s2_abs_data      <= 48'd0;
      s2_emax          <= 8'd0;
      s2_first_one_pos <= 6'd0;
    end else begin
      s2_valid         <= s1_valid;
      s2_sign          <= s1_sign;
      s2_abs_data      <= s1_abs_data;
      s2_emax          <= s1_emax;

      // Ascending scan: the last matching assignment wins, so the register
      // ends up holding the index of the highest set bit (0 if none is set).
      s2_first_one_pos <= 6'd0;
      for (int i = 0; i <= 47; i++) begin
        if (s1_abs_data[i]) begin
          s2_first_one_pos <= i[5:0];
        end
      end
    end
  end

  // ===| Stage 3: Normalization Barrel Shift & Exponent Update |=======
  logic [6:0] s3_mantissa;
  logic [7:0] s3_new_exp;
  logic       s3_sign;
  logic       s3_valid;

  always_ff @(posedge clk) begin
    if (!rst_n) begin
      s3_valid    <= 1'b0;
      s3_sign     <= 1'b0;
      s3_new_exp  <= 8'd0;
      s3_mantissa <= 7'd0;
    end else begin
      s3_valid <= s2_valid;
      s3_sign  <= s2_sign;

      if (s2_abs_data == 0) begin
        // Zero magnitude: exponent and mantissa are both zero.
        s3_new_exp  <= 8'd0;
        s3_mantissa <= 7'd0;
      end else begin
        s3_new_exp <= s2_emax + s2_first_one_pos - 8'(ONE_BIT_POS);

        if (s2_first_one_pos >= 7)
          // Take the 7 bits immediately below the leading one (truncation).
          s3_mantissa <= s2_abs_data[s2_first_one_pos-1-:7];
        else
          // Fewer than 7 bits below the leading one: shift left so the
          // leading one falls off the top of the 7-bit field and the
          // remaining bits are zero-padded on the right.
          s3_mantissa <= s2_abs_data[6:0] << (7 - s2_first_one_pos);
      end
    end
  end

  // ===| Stage 4: Final Packing |=======
  always_ff @(posedge clk) begin
    if (!rst_n) begin
      valid_out <= 1'b0;
      data_out  <= 16'd0;
    end else begin
      valid_out <= s3_valid;
      data_out  <= {s3_sign, s3_new_exp, s3_mantissa};
    end
  end

endmodule