趋近智
实现自定义方言,使角色从简单使用MLIR的现有功能转变为积极构建新的编译器组件。尽管像linalg和affine这样的预定义方言涵盖了标准张量操作,但定制硬件加速器或特定领域框架通常需要无法清晰映射到现有IR结构的操作。通过定义自定义方言,你可以在机器学习 (machine learning)模型的高层逻辑和后续优化过程之间建立一种约定。
本次实践主要介绍使用TableGen的Operation Definition Specification (ODS) 框架。我们将构建一个名为tmath(张量数学)的最小方言,旨在处理简化矩阵操作。随后,我们将实现一个降级过程,将这些高层操作转换为affine方言,从而实现前几章中提及的多面体优化。
为每个编译器操作编写C++样板代码容易出错且难以维护。MLIR通过使用TableGen(一种基于记录的特定领域语言)来解决此问题,以定义操作的结构、约束和验证逻辑。构建系统随后自动生成相应的C++类(头文件和实现)。
该工作流程包括三个主要阶段:定义方言、定义操作,以及将生成的产物链接到主编译器二进制文件。
TableGen 将声明性规范转换为 C++ 源代码,随后编译成最终二进制文件的构建过程。
方言定义作为操作的命名空间和注册点。在名为TMathDialect.td的文件中,我们继承了ODS提供的Dialect类。
// TMathDialect.td
include "mlir/IR/OpBase.td"
def TMath_Dialect : Dialect {
let name = "tmath";
let summary = "一个用于演示的最小张量数学方言";
let description = [{
tmath方言提供高层矩阵操作,旨在降级为仿射循环。
}];
let cppNamespace = "::mlir::tmath";
}
此定义会在指定命名空间中生成一个C++类TMathDialect。name字段决定IR的前缀,因此操作在文本格式中将显示为tmath.op_name。
方言确定后,我们定义操作。MLIR中的操作通过其名称、参数 (parameter)(操作数和属性)、结果和特性来表征。特性使编译器能够判断操作的行为,例如它是否有副作用或结果类型是否依赖于输入类型。
我们将定义一个MatMulOp。不同于通用的linalg.matmul,我们的操作将强制使用严格的2D张量以简化降级逻辑。
// TMathOps.td
include "TMathDialect.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
def MatMulOp : Op<TMath_Dialect, "matmul", [Pure]> {
let summary = "执行矩阵乘法";
let description = [{
计算两个2D张量的乘积。
给定输入[M, K]和[K, N],返回一个维度为[M, N]的新张量。
}];
let arguments = (ins
F32Tensor:$lhs,
F32Tensor:$rhs
);
let results = (outs F32Tensor:$result);
let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `*` type($rhs) `->` type($result)";
let hasVerifier = 1;
}
在此规范中:
[Pure] 特性:表明该操作本质上不访问内存或全局状态,允许死代码消除(DCE)在未使用时将其移除。F32Tensor。ODS自动生成类型检查代码,确保只接受32位浮点张量。hasVerifier = 1设置为告诉TableGen,我们将提供一个C++实现来验证运行时约束,例如保证矩阵的内维度匹配(中的)。TableGen生成声明,但我们必须在相应的.cpp文件中实现其逻辑。此验证步骤有助于在编译流程早期发现形状不匹配问题。
// TMathOps.cpp
llvm::LogicalResult MatMulOp::verify() {
auto lhsType = getLhs().getType().cast<RankedTensorType>();
auto rhsType = getRhs().getType().cast<RankedTensorType>();
if (lhsType.getRank() != 2 || rhsType.getRank() != 2)
return emitOpError("操作数必须是2D张量");
if (lhsType.getDimSize(1) != rhsType.getDimSize(0)) {
return emitOpError("维度不匹配:左操作数列大小 ")
<< lhsType.getDimSize(1) << " 必须匹配右操作数行大小 "
<< rhsType.getDimSize(0);
}
return success();
}
方言只有在能转换为可执行代码时才有价值。我们将实现一个重写模式,将tmath.matmul降级到affine方言。此转换将高层矩阵乘法替换为三个嵌套循环、显式加载和存储。
MLIR中的模式重写机制围绕OpRewritePattern类展开。我们会重写matchAndRewrite方法。
struct MatMulLowering : public OpRewritePattern<tmath::MatMulOp> {
using OpRewritePattern<tmath::MatMulOp>::OpRewritePattern;
LogicalResult matchAndRewrite(tmath::MatMulOp op, PatternRewriter &rewriter) const override {
auto loc = op.getLoc();
Value lhs = op.getLhs();
Value rhs = op.getRhs();
// 获取形状
auto lhsType = lhs.getType().cast<RankedTensorType>();
auto rhsType = rhs.getType().cast<RankedTensorType>();
int64_t M = lhsType.getDimSize(0);
int64_t K = lhsType.getDimSize(1);
int64_t N = rhsType.getDimSize(1);
// 使用memref为结果分配缓冲区
auto resultMemRefType = MemRefType::get({M, N}, rewriter.getF32Type());
Value resultAlloc = rewriter.create<memref::AllocOp>(loc, resultMemRefType);
// 为循环创建仿射映射
// 循环i从0到M
// 循环j从0到N
// 循环k从0到K
// 构建循环嵌套
buildAffineLoopNest(rewriter, loc, {M, N, K},
[&](OpBuilder &b, Location loc, ValueRange ivs) {
Value i = ivs[0];
Value j = ivs[1];
Value k = ivs[2];
// 加载 A[i, k]
Value aVal = b.create<affine::AffineLoadOp>(loc, lhs, ValueRange{i, k});
// 加载 B[k, j]
Value bVal = b.create<affine::AffineLoadOp>(loc, rhs, ValueRange{k, j});
// 计算乘积
Value product = b.create<arith::MulFOp>(loc, aVal, bVal);
// 加载当前 C[i, j](累加器)
Value cVal = b.create<affine::AffineLoadOp>(loc, resultAlloc, ValueRange{i, j});
// 累加
Value sum = b.create<arith::AddFOp>(loc, cVal, product);
// 存储结果
b.create<affine::AffineStoreOp>(loc, sum, resultAlloc, ValueRange{i, j});
}
);
// 由于我们从Tensor转换到MemRef,通常需要在此处生成tensor_to_memref
// 或缓冲区化过程。
// 对于此代码片段,我们假设输入已进行缓冲区化。
rewriter.replaceOp(op, resultAlloc);
return success();
}
};
要在mlir-opt等工具中使用此方言,必须在MLIRContext中注册它。随后将降级过程添加到PassManager中。
int main(int argc, char **argv) {
mlir::DialectRegistry registry;
// 注册我们的自定义方言
registry.insert<mlir::tmath::TMathDialect>();
// 注册我们降级到的标准方言
registry.insert<mlir::affine::AffineDialect, mlir::memref::MemRefDialect>();
mlir::MLIRContext context(registry);
// 加载一个模块,运行过程管理器...
// ...
}
完成此工作流程后,你已成功扩展了编译器的中间表示。tmath方言现在可以作为前端(如Python解析器)的目标,并且可以优化或降级为标准方言,最终映射到用于CPU执行的LLVM IR或用于GPU执行的SPIR-V。这种可扩展性是MLIR支持各种硬件架构而无需重写整个编译栈的机制。
这部分内容有帮助吗?
© 2026 ApX Machine LearningAI伦理与透明度•