| 12
 3
 4
 5
 6
 7
 8
 9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
 100
 101
 102
 103
 104
 105
 106
 107
 108
 109
 110
 111
 112
 113
 114
 115
 116
 117
 118
 119
 120
 121
 
 | 
 
 
 @I.ir_module
 class Module:
     # IRModule made of seven TIR kernels (add/add1, matmul/matmul1, relu,
     # transpose/transpose1) and three Relax functions. `main` chains them as
     # transpose -> matmul+add -> relu -> transpose -> matmul+add on a
     # (1, 784) float32 input, yielding a (1, 10) float32 output.
     # NOTE(review): `metadata[...]` constants are not defined in this listing;
     # they must be supplied alongside the script when it is parsed.

     # T_add[r, c] = rxplaceholder[r, c] + rxplaceholder_1[c] over a (1, 128) grid.
     @T.prim_func
     def add(
         rxplaceholder: T.Buffer((T.int64(1), T.int64(128)), "float32"),
         rxplaceholder_1: T.Buffer((T.int64(128),), "float32"),
         T_add: T.Buffer((T.int64(1), T.int64(128)), "float32"),
     ):
         T.func_attr({"tir.noalias": True})
         for row, col in T.grid(T.int64(1), T.int64(128)):
             with T.block("T_add"):
                 # Both axes are spatial ("SS"): no reduction happens here.
                 v_row, v_col = T.axis.remap("SS", [row, col])
                 T.reads(rxplaceholder[v_row, v_col], rxplaceholder_1[v_col])
                 T.writes(T_add[v_row, v_col])
                 T_add[v_row, v_col] = rxplaceholder[v_row, v_col] + rxplaceholder_1[v_col]

     # Same computation as `add`, specialised to a (1, 10) grid.
     @T.prim_func
     def add1(
         rxplaceholder: T.Buffer((T.int64(1), T.int64(10)), "float32"),
         rxplaceholder_1: T.Buffer((T.int64(10),), "float32"),
         T_add: T.Buffer((T.int64(1), T.int64(10)), "float32"),
     ):
         T.func_attr({"tir.noalias": True})
         for row, col in T.grid(T.int64(1), T.int64(10)):
             with T.block("T_add"):
                 v_row, v_col = T.axis.remap("SS", [row, col])
                 T.reads(rxplaceholder[v_row, v_col], rxplaceholder_1[v_col])
                 T.writes(T_add[v_row, v_col])
                 T_add[v_row, v_col] = rxplaceholder[v_row, v_col] + rxplaceholder_1[v_col]

     # (1, 784) x (784, 128) -> (1, 128) matrix product, accumulated over the
     # third loop axis.
     # NOTE(review): "layout_free_buffers": [1] presumably flags argument 1 as
     # eligible for layout rewriting - confirm against the TVM documentation.
     @T.prim_func
     def matmul(
         rxplaceholder: T.Buffer((T.int64(1), T.int64(784)), "float32"),
         rxplaceholder_1: T.Buffer((T.int64(784), T.int64(128)), "float32"),
         T_matmul_NN: T.Buffer((T.int64(1), T.int64(128)), "float32"),
     ):
         T.func_attr({"layout_free_buffers": [1], "tir.noalias": True})
         for row, col, red in T.grid(T.int64(1), T.int64(128), T.int64(784)):
             with T.block("T_matmul_NN"):
                 # "SSR": two spatial axes followed by one reduction axis.
                 v_row, v_col, v_red = T.axis.remap("SSR", [row, col, red])
                 T.reads(rxplaceholder[v_row, v_red], rxplaceholder_1[v_red, v_col])
                 T.writes(T_matmul_NN[v_row, v_col])
                 with T.init():
                     # The accumulator starts at zero.
                     T_matmul_NN[v_row, v_col] = T.float32(0)
                 T_matmul_NN[v_row, v_col] = T_matmul_NN[v_row, v_col] + rxplaceholder[v_row, v_red] * rxplaceholder_1[v_red, v_col]

     # Same computation as `matmul`, specialised to (1, 128) x (128, 10) -> (1, 10).
     @T.prim_func
     def matmul1(
         rxplaceholder: T.Buffer((T.int64(1), T.int64(128)), "float32"),
         rxplaceholder_1: T.Buffer((T.int64(128), T.int64(10)), "float32"),
         T_matmul_NN: T.Buffer((T.int64(1), T.int64(10)), "float32"),
     ):
         T.func_attr({"layout_free_buffers": [1], "tir.noalias": True})
         for row, col, red in T.grid(T.int64(1), T.int64(10), T.int64(128)):
             with T.block("T_matmul_NN"):
                 v_row, v_col, v_red = T.axis.remap("SSR", [row, col, red])
                 T.reads(rxplaceholder[v_row, v_red], rxplaceholder_1[v_red, v_col])
                 T.writes(T_matmul_NN[v_row, v_col])
                 with T.init():
                     T_matmul_NN[v_row, v_col] = T.float32(0)
                 T_matmul_NN[v_row, v_col] = T_matmul_NN[v_row, v_col] + rxplaceholder[v_row, v_red] * rxplaceholder_1[v_red, v_col]

     # compute[r, c] = max(rxplaceholder[r, c], 0) over a (1, 128) grid.
     @T.prim_func
     def relu(
         rxplaceholder: T.Buffer((T.int64(1), T.int64(128)), "float32"),
         compute: T.Buffer((T.int64(1), T.int64(128)), "float32"),
     ):
         T.func_attr({"tir.noalias": True})
         for row, col in T.grid(T.int64(1), T.int64(128)):
             with T.block("compute"):
                 v_row, v_col = T.axis.remap("SS", [row, col])
                 T.reads(rxplaceholder[v_row, v_col])
                 T.writes(compute[v_row, v_col])
                 compute[v_row, v_col] = T.max(rxplaceholder[v_row, v_col], T.float32(0))

     # Transposes a (128, 784) buffer into a (784, 128) buffer.
     @T.prim_func
     def transpose(
         rxplaceholder: T.Buffer((T.int64(128), T.int64(784)), "float32"),
         T_transpose: T.Buffer((T.int64(784), T.int64(128)), "float32"),
     ):
         T.func_attr({"tir.noalias": True})
         # The loops walk the output; the input is read with swapped indices.
         for out_row, out_col in T.grid(T.int64(784), T.int64(128)):
             with T.block("T_transpose"):
                 v_out_row, v_out_col = T.axis.remap("SS", [out_row, out_col])
                 T.reads(rxplaceholder[v_out_col, v_out_row])
                 T.writes(T_transpose[v_out_row, v_out_col])
                 T_transpose[v_out_row, v_out_col] = rxplaceholder[v_out_col, v_out_row]

     # Same computation as `transpose`, specialised to (10, 128) -> (128, 10).
     @T.prim_func
     def transpose1(
         rxplaceholder: T.Buffer((T.int64(10), T.int64(128)), "float32"),
         T_transpose: T.Buffer((T.int64(128), T.int64(10)), "float32"),
     ):
         T.func_attr({"tir.noalias": True})
         for out_row, out_col in T.grid(T.int64(128), T.int64(10)):
             with T.block("T_transpose"):
                 v_out_row, v_out_col = T.axis.remap("SS", [out_row, out_col])
                 T.reads(rxplaceholder[v_out_col, v_out_row])
                 T.writes(T_transpose[v_out_row, v_out_col])
                 T_transpose[v_out_row, v_out_col] = rxplaceholder[v_out_col, v_out_row]

     # Relax wrapper: calls `matmul` on (x, w), then `add` with b; returns (1, 128).
     # NOTE(review): the "Primitive": 1 attribute presumably marks this as a
     # fused group produced by an operator-fusion pass - confirm.
     @R.function
     def fused_matmul_add0(
         x: R.Tensor((1, 784), dtype="float32"),
         w: R.Tensor((784, 128), dtype="float32"),
         b: R.Tensor((128,), dtype="float32"),
     ) -> R.Tensor((1, 128), dtype="float32"):
         R.func_attr({"Primitive": 1})
         cls = Module
         with R.dataflow():
             product = R.call_tir(cls.matmul, (x, w), out_sinfo=R.Tensor((1, 128), dtype="float32"))
             biased = R.call_tir(cls.add, (product, b), out_sinfo=R.Tensor((1, 128), dtype="float32"))
             R.output(biased)
         return biased

     # Relax wrapper: calls `matmul1` on (x, w), then `add1` with b; returns (1, 10).
     @R.function
     def fused_matmul_add1(
         x: R.Tensor((1, 128), dtype="float32"),
         w: R.Tensor((128, 10), dtype="float32"),
         b: R.Tensor((10,), dtype="float32"),
     ) -> R.Tensor((1, 10), dtype="float32"):
         R.func_attr({"Primitive": 1})
         cls = Module
         with R.dataflow():
             product = R.call_tir(cls.matmul1, (x, w), out_sinfo=R.Tensor((1, 10), dtype="float32"))
             biased = R.call_tir(cls.add1, (product, b), out_sinfo=R.Tensor((1, 10), dtype="float32"))
             R.output(biased)
         return biased

     # Entry point: maps a (1, 784) float32 tensor to a (1, 10) float32 tensor
     # using four constants taken from `metadata["relax.expr.Constant"]`.
     @R.function
     def main(x: R.Tensor((1, 784), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
         cls = Module
         with R.dataflow():
             # Constant 0 is transposed to (784, 128) before the first product.
             w0_t = R.call_tir(cls.transpose, (metadata["relax.expr.Constant"][0],), out_sinfo=R.Tensor((784, 128), dtype="float32"))
             dense0: R.Tensor((1, 128), dtype="float32") = cls.fused_matmul_add0(x, w0_t, metadata["relax.expr.Constant"][1])
             act0 = R.call_tir(cls.relu, (dense0,), out_sinfo=R.Tensor((1, 128), dtype="float32"))
             # Constant 2 is transposed to (128, 10) before the second product.
             w1_t = R.call_tir(cls.transpose1, (metadata["relax.expr.Constant"][2],), out_sinfo=R.Tensor((128, 10), dtype="float32"))
             dense1: R.Tensor((1, 10), dtype="float32") = cls.fused_matmul_add1(act0, w1_t, metadata["relax.expr.Constant"][3])
             # Alias binding kept so the dataflow block has the same bindings as before.
             out: R.Tensor((1, 10), dtype="float32") = dense1
             R.output(out)
         return out
 
 
 
 |