1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121
|
@I.ir_module
class Module:
    # A fused two-layer MLP (784 -> 128 -> 10) lowered to TIR + Relax:
    # matmul/add pairs are grouped into "fused_matmul_add*" primitive Relax
    # functions, ReLU sits between the layers, and the weight constants are
    # transposed at runtime before use (see `main`).

    @T.prim_func
    def add(rxplaceholder: T.Buffer((T.int64(1), T.int64(128)), "float32"), rxplaceholder_1: T.Buffer((T.int64(128),), "float32"), T_add: T.Buffer((T.int64(1), T.int64(128)), "float32")):
        # Broadcast-add a (128,) bias vector onto a (1, 128) activation.
        T.func_attr({"tir.noalias": True})
        for ax0, ax1 in T.grid(T.int64(1), T.int64(128)):
            with T.block("T_add"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(rxplaceholder[v_ax0, v_ax1], rxplaceholder_1[v_ax1])
                T.writes(T_add[v_ax0, v_ax1])
                T_add[v_ax0, v_ax1] = rxplaceholder[v_ax0, v_ax1] + rxplaceholder_1[v_ax1]

    @T.prim_func
    def add1(rxplaceholder: T.Buffer((T.int64(1), T.int64(10)), "float32"), rxplaceholder_1: T.Buffer((T.int64(10),), "float32"), T_add: T.Buffer((T.int64(1), T.int64(10)), "float32")):
        # Same bias-add as `add`, specialized to the (1, 10) output layer.
        T.func_attr({"tir.noalias": True})
        for ax0, ax1 in T.grid(T.int64(1), T.int64(10)):
            with T.block("T_add"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(rxplaceholder[v_ax0, v_ax1], rxplaceholder_1[v_ax1])
                T.writes(T_add[v_ax0, v_ax1])
                T_add[v_ax0, v_ax1] = rxplaceholder[v_ax0, v_ax1] + rxplaceholder_1[v_ax1]

    @T.prim_func
    def matmul(rxplaceholder: T.Buffer((T.int64(1), T.int64(784)), "float32"), rxplaceholder_1: T.Buffer((T.int64(784), T.int64(128)), "float32"), T_matmul_NN: T.Buffer((T.int64(1), T.int64(128)), "float32")):
        # (1, 784) x (784, 128) -> (1, 128); reduction over k with an
        # explicit zero init. Buffer 1 (the weight) is marked layout-free
        # so layout-rewrite passes may repack it.
        T.func_attr({"layout_free_buffers": [1], "tir.noalias": True})
        for i, j, k in T.grid(T.int64(1), T.int64(128), T.int64(784)):
            with T.block("T_matmul_NN"):
                v_i, v_j, v_k = T.axis.remap("SSR", [i, j, k])
                T.reads(rxplaceholder[v_i, v_k], rxplaceholder_1[v_k, v_j])
                T.writes(T_matmul_NN[v_i, v_j])
                with T.init():
                    T_matmul_NN[v_i, v_j] = T.float32(0)
                T_matmul_NN[v_i, v_j] = T_matmul_NN[v_i, v_j] + rxplaceholder[v_i, v_k] * rxplaceholder_1[v_k, v_j]

    @T.prim_func
    def matmul1(rxplaceholder: T.Buffer((T.int64(1), T.int64(128)), "float32"), rxplaceholder_1: T.Buffer((T.int64(128), T.int64(10)), "float32"), T_matmul_NN: T.Buffer((T.int64(1), T.int64(10)), "float32")):
        # Second-layer matmul: (1, 128) x (128, 10) -> (1, 10).
        T.func_attr({"layout_free_buffers": [1], "tir.noalias": True})
        for i, j, k in T.grid(T.int64(1), T.int64(10), T.int64(128)):
            with T.block("T_matmul_NN"):
                v_i, v_j, v_k = T.axis.remap("SSR", [i, j, k])
                T.reads(rxplaceholder[v_i, v_k], rxplaceholder_1[v_k, v_j])
                T.writes(T_matmul_NN[v_i, v_j])
                with T.init():
                    T_matmul_NN[v_i, v_j] = T.float32(0)
                T_matmul_NN[v_i, v_j] = T_matmul_NN[v_i, v_j] + rxplaceholder[v_i, v_k] * rxplaceholder_1[v_k, v_j]

    @T.prim_func
    def relu(rxplaceholder: T.Buffer((T.int64(1), T.int64(128)), "float32"), compute: T.Buffer((T.int64(1), T.int64(128)), "float32")):
        # Elementwise max(x, 0) over the (1, 128) hidden activation.
        T.func_attr({"tir.noalias": True})
        for i0, i1 in T.grid(T.int64(1), T.int64(128)):
            with T.block("compute"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(rxplaceholder[v_i0, v_i1])
                T.writes(compute[v_i0, v_i1])
                compute[v_i0, v_i1] = T.max(rxplaceholder[v_i0, v_i1], T.float32(0))

    @T.prim_func
    def transpose(rxplaceholder: T.Buffer((T.int64(128), T.int64(784)), "float32"), T_transpose: T.Buffer((T.int64(784), T.int64(128)), "float32")):
        # (128, 784) -> (784, 128): puts the first-layer weight into the
        # NN layout expected by `matmul`.
        T.func_attr({"tir.noalias": True})
        for ax0, ax1 in T.grid(T.int64(784), T.int64(128)):
            with T.block("T_transpose"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(rxplaceholder[v_ax1, v_ax0])
                T.writes(T_transpose[v_ax0, v_ax1])
                T_transpose[v_ax0, v_ax1] = rxplaceholder[v_ax1, v_ax0]

    @T.prim_func
    def transpose1(rxplaceholder: T.Buffer((T.int64(10), T.int64(128)), "float32"), T_transpose: T.Buffer((T.int64(128), T.int64(10)), "float32")):
        # (10, 128) -> (128, 10): second-layer weight, same NN-layout prep.
        T.func_attr({"tir.noalias": True})
        for ax0, ax1 in T.grid(T.int64(128), T.int64(10)):
            with T.block("T_transpose"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(rxplaceholder[v_ax1, v_ax0])
                T.writes(T_transpose[v_ax0, v_ax1])
                T_transpose[v_ax0, v_ax1] = rxplaceholder[v_ax1, v_ax0]

    @R.function
    def fused_matmul_add0(x: R.Tensor((1, 784), dtype="float32"), w: R.Tensor((784, 128), dtype="float32"), b: R.Tensor((128,), dtype="float32")) -> R.Tensor((1, 128), dtype="float32"):
        # Layer 1 as one primitive group: x @ w + b -> (1, 128).
        R.func_attr({"Primitive": 1})
        cls = Module
        with R.dataflow():
            lv = R.call_tir(cls.matmul, (x, w), out_sinfo=R.Tensor((1, 128), dtype="float32"))
            gv = R.call_tir(cls.add, (lv, b), out_sinfo=R.Tensor((1, 128), dtype="float32"))
            R.output(gv)
        return gv

    @R.function
    def fused_matmul_add1(x: R.Tensor((1, 128), dtype="float32"), w: R.Tensor((128, 10), dtype="float32"), b: R.Tensor((10,), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
        # Layer 2 as one primitive group: x @ w + b -> (1, 10).
        R.func_attr({"Primitive": 1})
        cls = Module
        with R.dataflow():
            lv = R.call_tir(cls.matmul1, (x, w), out_sinfo=R.Tensor((1, 10), dtype="float32"))
            gv = R.call_tir(cls.add1, (lv, b), out_sinfo=R.Tensor((1, 10), dtype="float32"))
            R.output(gv)
        return gv

    @R.function
    def main(x: R.Tensor((1, 784), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
        # End-to-end forward pass. Constants 0 and 2 are the layer weights
        # (transposed here before use); constants 1 and 3 are the biases —
        # presumably an exported nn.Linear-style parameter layout (TODO
        # confirm against the exporting frontend).
        cls = Module
        with R.dataflow():
            lv = R.call_tir(cls.transpose, (metadata["relax.expr.Constant"][0],), out_sinfo=R.Tensor((784, 128), dtype="float32"))
            lv2: R.Tensor((1, 128), dtype="float32") = cls.fused_matmul_add0(x, lv, metadata["relax.expr.Constant"][1])
            lv3 = R.call_tir(cls.relu, (lv2,), out_sinfo=R.Tensor((1, 128), dtype="float32"))
            lv4 = R.call_tir(cls.transpose1, (metadata["relax.expr.Constant"][2],), out_sinfo=R.Tensor((128, 10), dtype="float32"))
            lv6: R.Tensor((1, 10), dtype="float32") = cls.fused_matmul_add1(lv3, lv4, metadata["relax.expr.Constant"][3])
            gv: R.Tensor((1, 10), dtype="float32") = lv6
            R.output(gv)
        return gv
|