# from tvm.script import ir as I
# from tvm.script import tir as T
# TVMScript IRModule containing one TensorIR function, `main`.
#
# `main` computes, on 1024x1024 float32 buffers,
#     C[i, j] = sum over k of A[i, k] * B[j, k]
# Note that B is indexed [j, k], so this is A multiplied by B transposed.
#
# The computation is tiled into 16x16x16 tiles (64 tiles per axis). Each tile
# step copies one 16x16 tile of A and of B into intermediate buffers, updates
# a 16x16 tile of an intermediate accumulator buffer, and after the whole tile
# loop nest the accumulator is copied out to C.
@I.ir_module
class Module:
    @T.prim_func
    def main(A: T.Buffer((1024, 1024), "float32"), B: T.Buffer((1024, 1024), "float32"), C: T.Buffer((1024, 1024), "float32")):
        # Attach the "tir.noalias" function attribute (set to True).
        T.func_attr({"tir.noalias": T.bool(True)})
        # Intermediate buffers, each full-size (1024x1024) and tagged with a
        # custom storage-scope string.
        # NOTE(review): the scope names suggest target-specific register /
        # accumulator memories -- confirm against the target's scope handling.
        A_global_A_reg = T.alloc_buffer((1024, 1024), scope="global.A_reg")
        B_global_B_reg = T.alloc_buffer((1024, 1024), scope="global.B_reg")
        C_global_accumulator = T.alloc_buffer((1024, 1024), scope="global.accumulator")
        # Outer tile loops: i_0 / j_0 select the output tile row / column,
        # k_0 selects the tile along the reduction dimension.
        for i_0, j_0, k_0 in T.grid(64, 64, 64):
            # Copy the 16x16 tile of A at rows i_0*16.., columns k_0*16..
            # The indices do not involve j_0, so the same tile is copied again
            # on every j_0 iteration.
            for ax0, ax1 in T.grid(16, 16):
                with T.block("A_global.A_reg"):
                    v0 = T.axis.spatial(1024, i_0 * 16 + ax0)
                    v1 = T.axis.spatial(1024, k_0 * 16 + ax1)
                    T.reads(A[v0, v1])
                    T.writes(A_global_A_reg[v0, v1])
                    A_global_A_reg[v0, v1] = A[v0, v1]
            # Copy the 16x16 tile of B at rows j_0*16.., columns k_0*16..
            # The indices do not involve i_0, so the same tile is copied again
            # on every i_0 iteration.
            for ax0, ax1 in T.grid(16, 16):
                with T.block("B_global.B_reg"):
                    v0 = T.axis.spatial(1024, j_0 * 16 + ax0)
                    v1 = T.axis.spatial(1024, k_0 * 16 + ax1)
                    T.reads(B[v0, v1])
                    T.writes(B_global_B_reg[v0, v1])
                    B_global_B_reg[v0, v1] = B[v0, v1]
            # Outer block: one 16x16x16 tile of the product per instance.
            # "SSR" binds vi_o and vj_o as spatial axes and vk_o as a
            # reduction axis.
            with T.block("matmul_o"):
                vi_o, vj_o, vk_o = T.axis.remap("SSR", [i_0, j_0, k_0])
                # Declared access regions: one 16x16 tile of each staged
                # operand is read, one 16x16 accumulator tile is written.
                T.reads(A_global_A_reg[vi_o * 16:vi_o * 16 + 16, vk_o * 16:vk_o * 16 + 16], B_global_B_reg[vj_o * 16:vj_o * 16 + 16, vk_o * 16:vk_o * 16 + 16])
                T.writes(C_global_accumulator[vi_o * 16:vi_o * 16 + 16, vj_o * 16:vj_o * 16 + 16])
                # Reduction initializer: zero the 16x16 accumulator tile.
                # Per TensorIR reduction-block semantics this applies at the
                # start of the reduction over vk_o, not on every tile step.
                with T.init():
                    for i_1, j_1 in T.grid(16, 16):
                        with T.block("matmul_init"):
                            vi_i_init, vj_i_init = T.axis.remap("SS", [i_1, j_1])
                            T.reads()
                            T.writes(C_global_accumulator[vi_o * 16 + vi_i_init, vj_o * 16 + vj_i_init])
                            C_global_accumulator[vi_o * 16 + vi_i_init, vj_o * 16 + vj_i_init] = T.float32(0)
                # Inner tile update: acc[i, j] += A_reg[i, k] * B_reg[j, k]
                # over the 16x16x16 tile, with k_1 as the reduction axis.
                for i_1, j_1, k_1 in T.grid(16, 16, 16):
                    with T.block("matmul"):
                        vi_i, vj_i, vk_i = T.axis.remap("SSR", [i_1, j_1, k_1])
                        T.reads(C_global_accumulator[vi_o * 16 + vi_i, vj_o * 16 + vj_i], A_global_A_reg[vi_o * 16 + vi_i, vk_o * 16 + vk_i], B_global_B_reg[vj_o * 16 + vj_i, vk_o * 16 + vk_i])
                        T.writes(C_global_accumulator[vi_o * 16 + vi_i, vj_o * 16 + vj_i])
                        C_global_accumulator[vi_o * 16 + vi_i, vj_o * 16 + vj_i] = C_global_accumulator[vi_o * 16 + vi_i, vj_o * 16 + vj_i] + A_global_A_reg[vi_o * 16 + vi_i, vk_o * 16 + vk_i] * B_global_B_reg[vj_o * 16 + vj_i, vk_o * 16 + vk_i]
        # Write-back: after the tile loop nest, copy the whole 1024x1024
        # accumulator element by element into the output buffer C.
        for ax0, ax1 in T.grid(1024, 1024):
            with T.block("C_global.accumulator"):
                v0, v1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(C_global_accumulator[v0, v1])
                T.writes(C[v0, v1])
                C[v0, v1] = C_global_accumulator[v0, v1]
|