• Partial Lowering to Lower-Level Dialects for Optimization



    toy部分lowering到affine,从toy变成混合的mlir
    使用DialectConversion框架需要提供两个东西(和一个可选的第三个)

    • 转换目标
      不合法的操作需要重写变成合法的
    • 一组重写模式
      按照这组规则将非法的操作变成合法的

    Conversion Target

    void ToyToAffineLoweringPass::runOnOperation() {
      // The first thing to define is the conversion target. This will define the
      // final target for this lowering.
      mlir::ConversionTarget target(getContext());
    
      // We define the specific operations, or dialects, that are legal targets for
      // this lowering. In our case, we are lowering to a combination of the
      // `Affine`, `Arith`, `Func`, and `MemRef` dialects.
      target.addLegalDialect();
    
      // We also define the Toy dialect as Illegal so that the conversion will fail
      // if any of these operations are *not* converted. Given that we actually want
      // a partial lowering, we explicitly mark the Toy operations that don't want
      // to lower, `toy.print`, as *legal*. `toy.print` will still need its operands
      // to be updated though (as we convert from TensorType to MemRefType), so we
      // only treat it as `legal` if its operands are legal.
      target.addIllegalDialect();
      target.addDynamicallyLegalOp([](toy::PrintOp op) {
        return llvm::none_of(op->getOperandTypes(),
                             [](Type type) { return type.isa(); });
      });
      ...
    }
    

    在MLIR(Multi-Level Intermediate Representation)框架中,转换目标(Conversion Target)定义了在转换过程中哪些操作是合法的,哪些操作是非法的,以及在什么条件下某些操作可以被视为合法。具体来说,mlir::ConversionTarget target(getContext());这行代码的意思是在当前上下文(getContext())中创建一个转换目标对象target,用于指导后续的转换过程。

    mlir::ConversionTarget target(getContext());
    
    • 这行代码实例化了一个ConversionTarget对象。这个对象用于定义转换过程中操作的合法性规则。这是转换过程的起点。
    • getContext()函数返回当前的MLIR上下文。上下文包含了所有与MLIR相关的全局信息,如已注册的方言(Dialects)、类型(Types)等。在转换过程中,上下文提供了所需的所有全局状态和元数据。
    target.addLegalDialect();
    

    Affine、Arith、Func和MemRef方言中的操作标记为合法

    target.addIllegalDialect();
    

    将Toy方言中的操作标记为非法。这意味着在转换过程中,如果任何Toy方言中的操作没有被转换,这将导致转换失败。

    target.addDynamicallyLegalOp([](toy::PrintOp op) {
      return llvm::none_of(op->getOperandTypes(),
                           [](Type type) { return type.isa(); });
    });
    

    为toy::PrintOp操作设置了动态合法性条件。如果toy::PrintOp的所有操作数类型都不是TensorType,那么它就是合法的。这是为了支持部分转换,使得某些特定操作在特定条件下可以被保留.具体来说,如果toy::PrintOp操作的操作数(operands)类型中没有TensorType类型,那么它就是合法的。

    Conversion Patterns

    /// Lower the `toy.transpose` operation to an affine loop nest.
    struct TransposeOpLowering : public mlir::ConversionPattern {
      TransposeOpLowering(mlir::MLIRContext *ctx)
          : mlir::ConversionPattern(TransposeOp::getOperationName(), 1, ctx) {}
    
      /// Match and rewrite the given `toy.transpose` operation, with the given
      /// operands that have been remapped from `tensor<...>` to `memref<...>`.
      llvm::LogicalResult
      matchAndRewrite(mlir::Operation *op, ArrayRef operands,
                      mlir::ConversionPatternRewriter &rewriter) const final {
        auto loc = op->getLoc();
    
        // Call to a helper function that will lower the current operation to a set
        // of affine loops. We provide a functor that operates on the remapped
        // operands, as well as the loop induction variables for the inner most
        // loop body.
        lowerOpToLoops(
            op, operands, rewriter,
            [loc](mlir::PatternRewriter &rewriter,
                  ArrayRef memRefOperands,
                  ArrayRef loopIvs) {
              // Generate an adaptor for the remapped operands of the TransposeOp.
              // This allows for using the nice named accessors that are generated
              // by the ODS. This adaptor is automatically provided by the ODS
              // framework.
              TransposeOpAdaptor transposeAdaptor(memRefOperands);
              mlir::Value input = transposeAdaptor.input();
    
              // Transpose the elements by generating a load from the reverse
              // indices.
              SmallVector reverseIvs(llvm::reverse(loopIvs));
              return rewriter.create(loc, input, reverseIvs);
            });
        return success();
      }
    

    如何匹配到操作?

      1. 构造函数中的操作名匹配
    TransposeOpLowering(mlir::MLIRContext *ctx)
        : mlir::ConversionPattern(TransposeOp::getOperationName(), 1, ctx) {}
    

    构造函数中,TransposeOpLowering 调用基类 mlir::ConversionPattern 的构造函数,并传递了 TransposeOp::getOperationName() 作为操作名。这实际上是告诉 ConversionPattern,这个模式是专门用来匹配 TransposeOp 操作的。

      1. TransposeOp::getOperationName()
        TransposeOp::getOperationName() 返回的是 toy.transpose 操作的名称。这个名称是在定义 TransposeOp 操作时指定的,例如在 .td(TableGen)文件中。
      1. 模式匹配过程
        ConversionPattern 使用这个操作名来匹配 IR 中的操作。具体的匹配过程如下:
    • 注册模式:在某个地方(通常是在转换 pass 中),会将这个模式(TransposeOpLowering)注册到 MLIR 的模式列表中。

    • 模式应用:当转换 pass 运行时,它会遍历 IR 中的每个操作,并尝试应用已注册的模式。ConversionPattern 会检查每个操作的名称,如果操作名与 TransposeOp::getOperationName() 返回的名称匹配,就会调用 matchAndRewrite 方法。

    matchAndRewrite

    • 矩阵转置
      比如2*3的矩阵转3*2
    A = [ [a, b, c],
         [d, e, f] ]
    B = [ [a, d],
         [b, e],
         [c, f] ]
    

    A转B写成循环应该是

    for (int i = 0; i < 2; ++i) {
     for (int j = 0; j < 3; ++j) {
       B[j][i] = A[i][j];
     }
    }
    

    在 MLIR 中,我们需要使用操作如 AffineLoadOp 和 AffineStoreOp 来实现数据的加载和存储。考虑以下简化的伪代码,展示如何通过 MLIR 实现这一点:

    // 假设 A 是原始 2x3 矩阵,B 是目标 3x2 矩阵
    memref : memref<2x3xf32>
    memref : memref<3x2xf32>
    
    for i = 0 to 2 {
     for j = 0 to 3 {
       // 从 A 中加载值
       value = affine.load A[i, j]
       // 将值存储到 B 中
       affine.store value, B[j, i]
     }
    }
    

    lowerOpToLoops

     lowerOpToLoops(
            op, operands, rewriter,
            [loc](mlir::PatternRewriter &rewriter,
                  ArrayRef memRefOperands,
                  ArrayRef loopIvs) {
              // Generate an adaptor for the remapped operands of the TransposeOp.
              // This allows for using the nice named accessors that are generated
              // by the ODS. This adaptor is automatically provided by the ODS
              // framework.
              TransposeOpAdaptor transposeAdaptor(memRefOperands);
              mlir::Value input = transposeAdaptor.input();
    
              // Transpose the elements by generating a load from the reverse
              // indices.
              SmallVector reverseIvs(llvm::reverse(loopIvs));
              return rewriter.create(loc, input, reverseIvs);
            });
    

    函数参数

    • op:这是需要被降低的操作符。

    • operands:这是操作符的操作数。

    • rewriter:这是用于模式匹配和替换的工具。

    • callback:这是一个lambda函数,用于定义如何将操作符的具体行为映射到基础的循环和加载/存储操作中。
      Lambda函数

    • loc:位置信息,用于在创建新的MLIR操作时保留源代码的位置信息。

    • transposeAdaptor:这是一个自动生成的适配器,用于方便地访问TransposeOp的操作数。

    • input:这是TransposeOp的输入数据。

    • reverseIvs:这是一个小向量,包含反转后的循环索引,用于生成反向加载操作。
      简单理解这里生成了循环,lambda里面包含需要循环的操作。

    Partial Lowering

    def PrintOp : Toy_Op<"print"> {
      ...
    
      // The print operation takes an input tensor to print.
      // We also allow a F64MemRef to enable interop during partial lowering.
      let arguments = (ins AnyTypeOf<[F64Tensor, F64MemRef]>:$input);
    }
    

    这里意思说,让print能过够接受F64MemRef的参数,加上前面有表明什么情况下保留print,最后我们就得到混合的mlir

    func.func @main() {
      %cst = arith.constant 1.000000e+00 : f64
      %cst_0 = arith.constant 2.000000e+00 : f64
      %cst_1 = arith.constant 3.000000e+00 : f64
      %cst_2 = arith.constant 4.000000e+00 : f64
      %cst_3 = arith.constant 5.000000e+00 : f64
      %cst_4 = arith.constant 6.000000e+00 : f64
    
      // Allocating buffers for the inputs and outputs.
      %0 = memref.alloc() : memref<3x2xf64>
      %1 = memref.alloc() : memref<3x2xf64>
      %2 = memref.alloc() : memref<2x3xf64>
    
      // Initialize the input buffer with the constant values.
      affine.store %cst, %2[0, 0] : memref<2x3xf64>
      affine.store %cst_0, %2[0, 1] : memref<2x3xf64>
      affine.store %cst_1, %2[0, 2] : memref<2x3xf64>
      affine.store %cst_2, %2[1, 0] : memref<2x3xf64>
      affine.store %cst_3, %2[1, 1] : memref<2x3xf64>
      affine.store %cst_4, %2[1, 2] : memref<2x3xf64>
    
      // Load the transpose value from the input buffer and store it into the
      // next input buffer.
      affine.for %arg0 = 0 to 3 {
        affine.for %arg1 = 0 to 2 {
          %3 = affine.load %2[%arg1, %arg0] : memref<2x3xf64>
          affine.store %3, %1[%arg0, %arg1] : memref<3x2xf64>
        }
      }
    
      // Multiply and store into the output buffer.
      affine.for %arg0 = 0 to 3 {
        affine.for %arg1 = 0 to 2 {
          %3 = affine.load %1[%arg0, %arg1] : memref<3x2xf64>
          %4 = affine.load %1[%arg0, %arg1] : memref<3x2xf64>
          %5 = arith.mulf %3, %4 : f64
          affine.store %5, %0[%arg0, %arg1] : memref<3x2xf64>
        }
      }
    
      // Print the value held by the buffer.
      toy.print %0 : memref<3x2xf64>
      memref.dealloc %2 : memref<2x3xf64>
      memref.dealloc %1 : memref<3x2xf64>
      memref.dealloc %0 : memref<3x2xf64>
      return
    }
    
  • 相关阅读:
    MySQL数据表内容查询(一)
    Java编程学习知识点总结
    zabbix告警 邮件告警 钉钉告警
    Mdserver-web:一个开源、免费的 Linux 主机面板
    【LeetCode】No.79. Word Search -- Java Version
    Spring-Web(一) RestTemplate使用与源码浅析
    操作配置文件保存方式(上位机)
    揭秘梦幻般的Glam风格是什么?
    uni-app 中实现 onLaunch 异步回调后执行 onLoad 最佳实践
    spring:简介
  • 原文地址:https://blog.csdn.net/Kongxiangyunltj/article/details/140353823