import tvm


def cache_read_and_coop_fetch(sch, block, nthread, read_idx, read_loc):
    # Stage the read buffer into shared memory and cooperatively fetch it:
    # all threads in the block load the tile together.
    read_cache = sch.cache_read(block=block, read_buffer_index=read_idx, storage_scope="shared")
    sch.compute_at(block=read_cache, loop=read_loc)
    # Fuse the two innermost loops of the cache block, then split so that
    # `nthread` threads each issue 4-wide vectorized loads.
    inner0, inner1 = sch.get_loops(block=read_cache)[-2:]
    inner = sch.fuse(inner0, inner1)
    _, tx, vec = sch.split(loop=inner, factors=[None, nthread, 4])
    sch.vectorize(vec)
    sch.bind(tx, "threadIdx.x")


def blocking_with_shared(sch, tile_local_y, tile_local_x, tile_block_y, tile_block_x, tile_k):
    block_C = sch.get_block("C")
    # Accumulate each thread's tile of C in registers ("local" scope).
    C_local = sch.cache_write(block_C, 0, "local")

    i, j, k = sch.get_loops(block=block_C)

    # Tile the spatial loops into (grid, block, thread-local) levels and the
    # reduction loop into (outer, inner) with a tile_k-sized inner step.
    i0, i1, i2 = sch.split(loop=i, factors=[None, tile_block_y, tile_local_y])
    j0, j1, j2 = sch.split(loop=j, factors=[None, tile_block_x, tile_local_x])
    k0, k1 = sch.split(loop=k, factors=[None, tile_k])

    sch.reorder(i0, j0, i1, j1, k0, k1, i2, j2)
    # Write the local accumulation back to C once per thread tile.
    sch.reverse_compute_at(C_local, j1)

    sch.bind(i0, "blockIdx.y")
    sch.bind(j0, "blockIdx.x")

    tx = sch.fuse(i1, j1)
    sch.bind(tx, "threadIdx.x")

    # Cooperatively fetch both inputs of block C into shared memory at the
    # outer reduction loop k0, then split out the reduction initialization.
    nthread = tile_block_y * tile_block_x
    cache_read_and_coop_fetch(sch, block_C, nthread, 0, k0)
    cache_read_and_coop_fetch(sch, block_C, nthread, 1, k0)
    sch.decompose_reduction(block_C, k0)

    return sch


# MyModuleMatmul is assumed to be a TVMScript IRModule containing a matmul
# block named "C".
sch = tvm.tir.Schedule(MyModuleMatmul)
sch = blocking_with_shared(sch, 8, 8, 8, 8, 8)
sch.mod.show()
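
# A minimal sketch of building and timing the scheduled module. It assumes a
# CUDA GPU is available and that MyModuleMatmul is a float32 matmul with
# 1024x1024 operands whose entry function is named "main"; adjust the shapes
# and function name to match the actual module.
import numpy as np

rt_mod = tvm.build(sch.mod, target="cuda")
dev = tvm.cuda(0)

A_np = np.random.uniform(size=(1024, 1024)).astype("float32")
B_np = np.random.uniform(size=(1024, 1024)).astype("float32")
A_nd = tvm.nd.array(A_np, dev)
B_nd = tvm.nd.array(B_np, dev)
C_nd = tvm.nd.array(np.zeros((1024, 1024), dtype="float32"), dev)

# 2 * M * N * K floating-point operations for a matmul of this shape.
num_flop = 2 * 1024 * 1024 * 1024
evaluator = rt_mod.time_evaluator("main", dev, number=10)
print("GEMM-Blocking: %.2f GFLOPS" % (num_flop / evaluator(A_nd, B_nd, C_nd).mean / 1e9))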