void ToyToAffineLoweringPass::runOnOperation() {
// The first thing to define is the conversion target. This will define the
// final target for this lowering.
mlir::ConversionTarget target(getContext());
// We define the specific operations, or dialects, that are legal targets for
// this lowering. In our case, we are lowering to a combination of the
// `Affine`, `Arith`, `Func`, and `MemRef` dialects.
target.addLegalDialect();
// We also define the Toy dialect as Illegal so that the conversion will fail
// if any of these operations are *not* converted. Given that we actually want
// a partial lowering, we explicitly mark the Toy operations that don't want
// to lower, `toy.print`, as *legal*. `toy.print` will still need its operands
// to be updated though (as we convert from TensorType to MemRefType), so we
// only treat it as `legal` if its operands are legal.
target.addIllegalDialect();
target.addDynamicallyLegalOp([](toy::PrintOp op) {
return llvm::none_of(op->getOperandTypes(),
[](Type type) { return type.isa(); });
});
...
}
在MLIR(Multi-Level Intermediate Representation)框架中,转换目标(Conversion Target)定义了在转换过程中哪些操作是合法的,哪些操作是非法的,以及在什么条件下某些操作可以被视为合法。具体来说,mlir::ConversionTarget target(getContext());这行代码的意思是在当前上下文(getContext())中创建一个转换目标对象target,用于指导后续的转换过程。
mlir::ConversionTarget target(getContext());
target.addLegalDialect<affine::AffineDialect, arith::ArithDialect,
                       func::FuncDialect, memref::MemRefDialect>();
将Affine、Arith、Func和MemRef方言中的操作标记为合法
target.addIllegalDialect<toy::ToyDialect>();
将Toy方言中的操作标记为非法。这意味着在转换过程中,如果任何Toy方言中的操作没有被转换,这将导致转换失败。
target.addDynamicallyLegalOp<toy::PrintOp>([](toy::PrintOp op) {
return llvm::none_of(op->getOperandTypes(),
[](Type type) { return type.isa<TensorType>(); });
});
为toy::PrintOp操作设置了动态合法性条件:只有当toy::PrintOp的所有操作数(operands)类型都不是TensorType时,它才是合法的。这是为了支持部分转换,使得某些特定操作在特定条件下可以被保留。
/// Lower the `toy.transpose` operation to an affine loop nest.
struct TransposeOpLowering : public mlir::ConversionPattern {
TransposeOpLowering(mlir::MLIRContext *ctx)
: mlir::ConversionPattern(TransposeOp::getOperationName(), 1, ctx) {}
/// Match and rewrite the given `toy.transpose` operation, with the given
/// operands that have been remapped from `tensor<...>` to `memref<...>`.
llvm::LogicalResult
matchAndRewrite(mlir::Operation *op, ArrayRef operands,
mlir::ConversionPatternRewriter &rewriter) const final {
auto loc = op->getLoc();
// Call to a helper function that will lower the current operation to a set
// of affine loops. We provide a functor that operates on the remapped
// operands, as well as the loop induction variables for the inner most
// loop body.
lowerOpToLoops(
op, operands, rewriter,
[loc](mlir::PatternRewriter &rewriter,
ArrayRef memRefOperands,
ArrayRef loopIvs) {
// Generate an adaptor for the remapped operands of the TransposeOp.
// This allows for using the nice named accessors that are generated
// by the ODS. This adaptor is automatically provided by the ODS
// framework.
TransposeOpAdaptor transposeAdaptor(memRefOperands);
mlir::Value input = transposeAdaptor.input();
// Transpose the elements by generating a load from the reverse
// indices.
SmallVector reverseIvs(llvm::reverse(loopIvs));
return rewriter.create(loc, input, reverseIvs);
});
return success();
}
TransposeOpLowering(mlir::MLIRContext *ctx)
: mlir::ConversionPattern(TransposeOp::getOperationName(), 1, ctx) {}
在构造函数中,TransposeOpLowering 调用基类 mlir::ConversionPattern 的构造函数,并传递了 TransposeOp::getOperationName() 作为操作名。这实际上是告诉 ConversionPattern,这个模式是专门用来匹配 TransposeOp 操作的。
注册模式:在某个地方(通常是在转换 pass 中),会将这个模式(TransposeOpLowering)注册到 MLIR 的模式列表中。
模式应用:当转换 pass 运行时,它会遍历 IR 中的每个操作,并尝试应用已注册的模式。ConversionPattern 会检查每个操作的名称,如果操作名与 TransposeOp::getOperationName() 返回的名称匹配,就会调用 matchAndRewrite 方法。
A = [ [a, b, c],
[d, e, f] ]
B = [ [a, d],
[b, e],
[c, f] ]
A转B写成循环应该是
// Walk A (2 rows x 3 columns) and write each element to the position with
// swapped indices in B (3 rows x 2 columns).
for (int i = 0; i < 2; ++i) {
for (int j = 0; j < 3; ++j) {
B[j][i] = A[i][j];
}
}
在 MLIR 中,我们需要使用操作如 AffineLoadOp 和 AffineStoreOp 来实现数据的加载和存储。考虑以下简化的伪代码,展示如何通过 MLIR 实现这一点:
// 假设 A 是原始 2x3 矩阵,B 是目标 3x2 矩阵
memref A : memref<2x3xf32>
memref B : memref<3x2xf32>
for i = 0 to 2 {
for j = 0 to 3 {
// 从 A 中加载值
value = affine.load A[i, j]
// 将值存储到 B 中
affine.store value, B[j, i]
}
}
lowerOpToLoops(
op, operands, rewriter,
[loc](mlir::PatternRewriter &rewriter,
ArrayRef<mlir::Value> memRefOperands,
ArrayRef<mlir::Value> loopIvs) {
// Generate an adaptor for the remapped operands of the TransposeOp.
// This allows for using the nice named accessors that are generated
// by the ODS. This adaptor is automatically provided by the ODS
// framework.
TransposeOpAdaptor transposeAdaptor(memRefOperands);
mlir::Value input = transposeAdaptor.input();
// Transpose the elements by generating a load from the reverse
// indices.
SmallVector<mlir::Value, 2> reverseIvs(llvm::reverse(loopIvs));
return rewriter.create<mlir::affine::AffineLoadOp>(loc, input, reverseIvs);
});
函数参数:
op:这是需要被降低(lowering)的操作(operation)。
operands:这是该操作的操作数。
rewriter:这是用于模式匹配和替换的工具。
callback:这是一个lambda函数,用于定义如何将操作符的具体行为映射到基础的循环和加载/存储操作中。
Lambda函数:
loc:位置信息,用于在创建新的MLIR操作时保留源代码的位置信息。
transposeAdaptor:这是一个自动生成的适配器,用于方便地访问TransposeOp的操作数。
input:这是TransposeOp的输入数据。
reverseIvs:这是一个小向量,包含反转后的循环索引,用于生成反向加载操作。
简单理解这里生成了循环,lambda里面包含需要循环的操作。
// ODS definition of `toy.print`; the body is elided except for its arguments.
def PrintOp : Toy_Op<"print"> {
...
// The print operation takes an input tensor to print.
// We also allow a F64MemRef to enable interop during partial lowering.
// The single operand `$input` is therefore declared with both types listed.
let arguments = (ins AnyTypeOf<[F64Tensor, F64MemRef]>:$input);
}
这里的意思是,让print能够接受F64MemRef类型的参数;再加上前面已经说明了在什么情况下保留print,最后我们就得到混合的mlir:
// Partially lowered module: arith/memref/affine operations mixed with the
// one remaining Toy operation, `toy.print`.
func.func @main() {
// The six f64 constants that are stored into the 2x3 input buffer below.
%cst = arith.constant 1.000000e+00 : f64
%cst_0 = arith.constant 2.000000e+00 : f64
%cst_1 = arith.constant 3.000000e+00 : f64
%cst_2 = arith.constant 4.000000e+00 : f64
%cst_3 = arith.constant 5.000000e+00 : f64
%cst_4 = arith.constant 6.000000e+00 : f64
// Allocating buffers for the inputs and outputs.
%0 = memref.alloc() : memref<3x2xf64>
%1 = memref.alloc() : memref<3x2xf64>
%2 = memref.alloc() : memref<2x3xf64>
// Initialize the input buffer with the constant values.
affine.store %cst, %2[0, 0] : memref<2x3xf64>
affine.store %cst_0, %2[0, 1] : memref<2x3xf64>
affine.store %cst_1, %2[0, 2] : memref<2x3xf64>
affine.store %cst_2, %2[1, 0] : memref<2x3xf64>
affine.store %cst_3, %2[1, 1] : memref<2x3xf64>
affine.store %cst_4, %2[1, 2] : memref<2x3xf64>
// Load the transpose value from the input buffer and store it into the
// next input buffer.
// The load uses [%arg1, %arg0] while the store uses [%arg0, %arg1]; swapping
// the indices is what performs the transpose.
affine.for %arg0 = 0 to 3 {
affine.for %arg1 = 0 to 2 {
%3 = affine.load %2[%arg1, %arg0] : memref<2x3xf64>
affine.store %3, %1[%arg0, %arg1] : memref<3x2xf64>
}
}
// Multiply and store into the output buffer.
// Both loads read the same element of %1, so each element is multiplied by
// itself.
affine.for %arg0 = 0 to 3 {
affine.for %arg1 = 0 to 2 {
%3 = affine.load %1[%arg0, %arg1] : memref<3x2xf64>
%4 = affine.load %1[%arg0, %arg1] : memref<3x2xf64>
%5 = arith.mulf %3, %4 : f64
affine.store %5, %0[%arg0, %arg1] : memref<3x2xf64>
}
}
// Print the value held by the buffer.
// `toy.print` is still a Toy operation, but its operand is now a memref.
toy.print %0 : memref<3x2xf64>
// Release the three buffers allocated above.
memref.dealloc %2 : memref<2x3xf64>
memref.dealloc %1 : memref<3x2xf64>
memref.dealloc %0 : memref<3x2xf64>
return
}