MLC-Lesson4-自动程序优化
Lesson4 自动程序优化
前文知识回顾:
- 驱动高层执行的计算图抽象
- 元张量函数的抽象
- 通过注册环境函数从而能被调用的库函数
所有的元素被分装在一个IRModule中,大多数MLC过程可以看作是元张量函数之间的变换。
1 随机调度变换(Stochastic Schedule Transformation)
概率式编程:给我们的变换增加一些随机元素,方法如下:
1 | def stochastic_schedule_mm(sch: tvm.tir.Schedule): |
代码解读:这里的sch.sample_perfect_tile(loop=j, n=2)方法表示对循环j进行随机采样分解成2个因子
通过print(sch.trace)我们可以发现多次运行j_factors会得到不同的值。
2 深入研究随机变换
在我们尝试逐步运行随机变换并且检获得的j_factors时,我们会发现j_factors不是一个实整数,而是被采样随机变量的符号变量。
1 | sch = tvm.tir.Schedule(MyModule) |
如果我们查看当前时间点的代码,我们可以发现 IRModule 保持不变,因为我们只对随机变量进行了采样,但还没有基于它们进行任何变换操作。
1 | j_0, j_1 = sch.split(loop=j, factors=j_factors) |
在这之后IRModule才会发生变化
3 随机变换搜索
事实上,stochastic_schedule_mm创建了一个可能程序的搜索空间。
所以它指定的是一组程序,那么问题是什么是最佳选择呢?
为此我们需要一个搜索算法,例如:连续运行很多次,取运行时间最短的那一次作为最佳选择。当然这是最直接简单的想法。在实践中,TVM 的 Meta-Schedule API 提供了一些附加功能:
- 跨越多个进程的并行基准测试。
- 使用代价模型 (cost model) 来避免每次都进行基准测试。
- 基于历史轨迹进行遗传搜索 (evolutionary search),而不是每次都随机采样。
纵使有这么多工具,我们的核心思想是保持不变的:使用随机变换来指定好的程序的搜索空间,使用 tune_tir API 帮助在搜索空间内搜索并找到最优的调度变换。
以下是在指定的搜索空间进搜索,tune_tir 函数返回在调优过程中找到的优化后的调度。
1 | from tvm import meta_schedule as ms |
Meta-Schedule 带有内置通用随机变换集合,能够适用于广泛的 TensorIR 计算。这种方法也称为自动调度 (auto-scheduling),因为搜索空间是由系统生成的。我们可以通过删除行 space=ms.space_generator.ScheduleFn(stochastic_schedule_mm) 来运行它。
1 | database = ms.tune_tir( |
4 集成到端到端模型部署中
主要步骤一览:
- 构建原模型
MyModuleMixture - 注册原环境中的算子
env.linear和`env.relu - 对IRModule的自定义函数
linear0进行自动程序优化 - 用自动程序优化好的
linear0代替原来的`linear0 - 构建运行
以下为具体的实现:
(1)构建原模型MyModuleMixture:
1 |
|
(2)注册原环境中的算子env.linear和env.relu
1 |
|
(3)对IRModule的自定义函数linear0进行自动程序优化
调优 API 只接受一个带有一个 main 函数的 IRModule,所以我们首先将 linear0 取出到另一个模块的 main 函数中并将其传递给 tune_tir。(据说以后会优化这个操作?)
1 | mod_linear = tvm.IRModule.from_expr(MyModuleMixture["linear0"].with_attr("global_symbol", "main")) |
(4)用自动程序优化好的linear0代替原来的linear0
现在我们需要在调优后用新函数替换原来的 linear0。我们可以通过首先获得一个 global_var(一个指向 IRModule 中函数的 pointer 引用),然后调用 update_func 来用新的函数替换原本的函数。
1 | MyModuleWithParams2 = relax.transform.BindParams("main", nd_params)(MyModuleMixture) |
(5)构建运行
1 | ex = relax.build(MyModuleWithParams2, target="llvm") |
