• Paddle build_cinn_pass_test Source Code Walkthrough (under the fluid directory)


    The code lives at paddle\fluid\framework\paddle2cinn\build_cinn_pass_test.cc. Since Paddle's CINN and PIR components are still updated frequently, the version you see may differ from the one read here.

    inline bool CheckNodeExisted(const std::unordered_set<Node*>& nodes,
                                 const std::string& op_name) {
      return std::find_if(nodes.begin(), nodes.end(), [&op_name](const Node* node) {
               return node->Name() == op_name;
             }) != nodes.end();
    }
    

    This inline function checks whether an unordered_set of nodes contains a node whose name is op_name, implemented with std::find_if; the third argument is a lambda. [&op_name] is the capture list, written inside the square brackets at the start of the lambda declaration; this mechanism lets outer variables be captured by value or by reference (here, op_name by reference).

    For more on lambdas and closures, see this article: https://www.cnblogs.com/pzhfei/archive/2013/01/14/lambda_expression.html

    Next is a function that returns the number of nodes whose name is op_name:

    inline int CountNode(const std::unordered_set<Node*>& nodes,
                         const std::string& op_name) {
      return std::count_if(
          nodes.begin(), nodes.end(), [&op_name](const Node* node) {
            return node->Name() == op_name;
          });
    }
    

    Next is a function that returns the node whose name matches op_name. Why is there a * in front of std::find_if? Because find_if returns an iterator, and dereferencing that iterator yields a Node*. Also note that unlike CheckNodeExisted's exact comparison, this one does a substring match via std::string::find.

    inline Node* GetNode(const std::unordered_set<Node*>& nodes,
                         const std::string& op_name) {
      return *std::find_if(
          nodes.begin(), nodes.end(), [&op_name](const Node* node) {
            return node->Name().find(op_name) != std::string::npos;
          });
    }
    

    CheckGraphIndependence defines a lambda check_node_ok, whose parameters n1 and n2 are both Node pointers.
    (Note: before Paddle's PIR, graph nodes represented both Ops and Vars.)
    It can only return true when one of n1 and n2 is an Op and the other is a Var;

    inline bool CheckGraphIndependence(const std::unordered_set<Node*>& nodes) {
      auto check_node_ok = [&nodes](Node* n1, Node* n2) -> bool {
        if (n1->IsOp() && !n2->IsVar()) {
          return false;
        }
        if (n1->IsVar() && !n2->IsOp()) {
          return false;
        }
        if (nodes.count(n2) == 0) {
          return false;
        }
        return true;
      };
    
      for (auto node : nodes) {
        for (auto in : node->inputs) {
          if (!check_node_ok(node, in)) {
            return false;
          }
        }
        for (auto out : node->outputs) {
          if (!check_node_ok(node, out)) {
            return false;
          }
        }
      }
      return true;
    }
    

    To clarify: before Paddle's PIR, both Ops and Vars were nodes, so edges are defined like this:

    var1 -> op1 -> var2
    op3 -> var3 -> op4
    

    In the top row, op1's input is var1 and its output is var2; in the bottom row,
    var3's input is op3 and var3's output is op4. Writing an op as a var's input looks a little odd, but that is indeed how it is defined.

    So what CheckGraphIndependence does is: first check that every edge is an op->var or var->op relation, and then check that each neighboring op/var is inside the given unordered_set, i.e. inside the current Graph.

    The later call sites simply pass the graph's node set g->Nodes() into CheckGraphIndependence and assert that the result is true:

      ASSERT_TRUE(CheckGraphIndependence(g->Nodes()));
    

    This function collects the operators::kCompilationKey attribute of every kCinnLaunchOp node into the compilation_keys vector; these keys are used later to look up the compiled subgraphs from CinnCompiler.

    // Get compilation_key values
    std::vector<int64_t> GetCompilationKeys(const Graph& graph) {
      std::vector<int64_t> compilation_keys;
      for (auto& node : graph.Nodes()) {
        if (node->IsOp() && node->Name() == kCinnLaunchOp) {
          compilation_keys.emplace_back(PADDLE_GET_CONST(
              int64_t, node->Op()->GetAttr(operators::kCompilationKey)));
        }
      }
      return compilation_keys;
    }
    

    Next, a graph containing no CINN subgraph is built: create an empty Graph, then add the ops and vars one by one.

    std::unique_ptr<Graph> BuildNoCinnSubgraph() {
      ProgramDesc prog;
      auto g = std::make_unique<Graph>(prog);
      // var1 --
      //        | --> fake1 --> var3 --> fake2 --> var4
      // var2 --
    
      // the *Desc classes are used below to create the OpNodes and VarNodes
      OpDesc fake1_op;
      fake1_op.SetType("fake1");
      OpDesc fake2_op;
      fake2_op.SetType("fake2");
    
      VarDesc var1("var1");
      VarDesc var2("var2");
      var2.SetPersistable(true);
      var2.SetIsParameter(true);
      VarDesc var3("var3");
      VarDesc var4("var4");
      
      // then use the graph's Create*Node methods to create the corresponding ir::Node
      ir::Node* fake1 = g->CreateOpNode(&fake1_op);
      ir::Node* fake2 = g->CreateOpNode(&fake2_op);
    
      ir::Node* v1 = g->CreateVarNode(&var1);
      ir::Node* v2 = g->CreateVarNode(&var2);
      ir::Node* v3 = g->CreateVarNode(&var3);
      ir::Node* v4 = g->CreateVarNode(&var4);
      
      // ----------- after creating the nodes, wire the ops/vars together
      // fill op node
      fake1->inputs = {v1, v2};
      fake1->outputs = {v3};
      fake2->inputs = {v3};
      fake2->outputs = {v4};
    
      // fill variable node
      v1->outputs = {fake1};
      v2->outputs = {fake1};
    
      v3->inputs = {fake1};
      v3->outputs = {fake2};
    
      v4->inputs = {fake2};
    
      return g;
    }
    

    Now comes the first unit test:

    TEST(BuildCinnPassTest, NoCinnSubgraph) {
      auto g = BuildNoCinnSubgraph();    // build the graph with the function above
      auto previous_nodes = g->Nodes();  // take the graph's node set
      
      // fetch the pass (this is the legacy-IR pass registry)
      auto pass =
          paddle::framework::ir::PassRegistry::Instance().Get("build_cinn_pass");
      // g.get() returns the raw Graph pointer; g is a unique_ptr smart pointer
      pass->Apply(g.get());
    
      // After search, origin graph should no change
      // i.e. after the pass has searched, the original graph must be unmodified
      ASSERT_EQ(previous_nodes, g->Nodes());
      ASSERT_TRUE(CheckGraphIndependence(g->Nodes())); // and it must still be valid and self-contained
    
      // After search, there should be no cinn subgraph
      ASSERT_TRUE(GetCompilationKeys(*g).empty());  // fake1/fake2 are not CINN-supported, so nothing was fused
    }
    

    Next is BuildAllOpSupportCinnGraph, which is not very different from the previous graph-building function:

    • the graph is more complex
    • the op types change from fake1/fake2 to elementwise_add | mul | relu
    std::unique_ptr<Graph> BuildAllOpSupportCinnGraph() {
      ProgramDesc prog;
      auto g = std::make_unique<Graph>(prog);
    
      // v1 --
      //      | --> mul --> v3 --
      // v2 --                   | --> add --> v5 --> relu --> v6
      //                    v4 --
    
      OpDesc add_op;
      add_op.SetType("elementwise_add");
      OpDesc mul_op;
      mul_op.SetType("mul");
      OpDesc relu_op;
      relu_op.SetType("relu");
    
      VarDesc var1("var1");
      VarDesc var2("var2");
      var2.SetPersistable(true);
      var2.SetIsParameter(true);
      VarDesc var3("var3");
      VarDesc var4("var4");
      VarDesc var5("var5");
      VarDesc var6("var6");
    
      ir::Node* add = g->CreateOpNode(&add_op);
      ir::Node* mul = g->CreateOpNode(&mul_op);
      ir::Node* relu = g->CreateOpNode(&relu_op);
    
      ir::Node* v0 = g->CreateEmptyNode("var0", Node::Type::kVariable);     // what is the empty node for?
      ir::Node* v1 = g->CreateVarNode(&var1);
      ir::Node* v2 = g->CreateVarNode(&var2);
      ir::Node* v3 = g->CreateVarNode(&var3);
      ir::Node* v4 = g->CreateVarNode(&var4);
      ir::Node* v5 = g->CreateVarNode(&var5);
      ir::Node* v6 = g->CreateVarNode(&var6);
      ir::Node* v7 = g->CreateControlDepVar();
    
      // fill op node
      mul->inputs = {v0, v1, v2};
      mul->outputs = {v3};
      add->inputs = {v3, v4};
      add->outputs = {v5};
      relu->inputs = {v5};
      relu->outputs = {v6, v7};
    
      // fill variable node
      v0->outputs = {mul};
      v1->outputs = {mul};
      v2->outputs = {mul};
    
      v3->inputs = {mul};
      v3->outputs = {add};
    
      v4->outputs = {add};
    
      v5->inputs = {add};
      v5->outputs = {relu};
    
      v6->inputs = {relu};
      v7->inputs = {relu};
    
      return g;
    }
    

    The comment above is slightly wrong:

      // v1 --
      //      | --> mul --> v3 --
      // v2 --                   | --> add --> v5 --> relu --> v6
      //                    v4 --
    

    It should read:

      // v0 --|
      // v1 --| --> mul --> v3 --|
      // v2 --|            v4 --| --> add --> v5 --> relu --> v6
      //                                                  --> v7
    

    The next TEST is like the previous one, except that since every op in this graph supports CINN, the pass fuses the whole structure into a single kCinnLaunchOp:

    TEST(BuildCinnPassTest, AllOpSupportCinn) {
      auto g = BuildAllOpSupportCinnGraph();
    
      auto pass =
          paddle::framework::ir::PassRegistry::Instance().Get("build_cinn_pass");
      pass->Apply(g.get());
    
      // After search, the graph should as following
      // v0 --|
      // v1 --|                   |--> v6
      // v2 --| --> kCinnLaunchOp |--> v7
      // v4 --|
      const auto& nodes = g->Nodes();
      ASSERT_EQ(nodes.size(), static_cast<size_t>(7));      // 7 nodes: 4 inputs, 2 outputs, and 1 op node
      ASSERT_TRUE(CheckGraphIndependence(nodes));           // check that the graph does not depend on nodes outside itself
    
      // A new op named kCinnLaunchOp should be added
      ASSERT_TRUE(CheckNodeExisted(nodes, kCinnLaunchOp));  // kCinnLaunchOp is a constant string; check the node set contains it
      auto* cinn_op = GetNode(nodes, kCinnLaunchOp);
      auto* v0 = GetNode(nodes, "var0");
      auto* v1 = GetNode(nodes, "var1");                    // fetch the corresponding var Node pointers in turn
      auto* v2 = GetNode(nodes, "var2");
      auto* v4 = GetNode(nodes, "var4");
      auto* v6 = GetNode(nodes, "var6");
      auto* v7 = GetNode(nodes, Node::kControlDepVarName);
      
      // check that cinn_op's inputs/outputs match `v0, v1, v2, v4` and `v6, v7`
      ASSERT_EQ(
          std::unordered_set<Node*>(cinn_op->inputs.begin(), cinn_op->inputs.end()),
          std::unordered_set<Node*>({v0, v1, v2, v4}));
      ASSERT_EQ(std::unordered_set<Node*>(cinn_op->outputs.begin(),
                                          cinn_op->outputs.end()),
                std::unordered_set<Node*>({v6, v7}));
      
      // check that the var nodes' inputs/outputs point at cinn_op
      ASSERT_EQ(v1->outputs, std::vector<Node*>({cinn_op}));
      ASSERT_EQ(v6->inputs, std::vector<Node*>({cinn_op}));
    
      // previous op (mul, add, relu) should all removed
      // mul/elementwise_add/relu were fused as a whole into cinn_op, so they must no longer be found in the graph
      ASSERT_FALSE(CheckNodeExisted(nodes, "mul"));
      ASSERT_FALSE(CheckNodeExisted(nodes, "elementwise_add"));
      ASSERT_FALSE(CheckNodeExisted(nodes, "relu"));
    
      // After search, there should has just one cinn subgraph
      // feed --> v1 --
      //               | --> mul --> v3 --
      // feed --> v2 --                   | --> add --> v5 --> relu --> v6 --> fetch
      //                    feed --> v4 --
      
      // get the compilation keys; a key is used below to look up the corresponding subgraph
      auto compilation_keys = GetCompilationKeys(*g);
      ASSERT_EQ(compilation_keys.size(), static_cast<size_t>(1));  // there is only one kCinnLaunchOp, so only one key
      auto* cinn_compiler = CinnCompiler::GetInstance();
      const auto& subgraph = cinn_compiler->FindGraph(compilation_keys[0]);  // fetch the subgraph by key
    
      const auto& subnodes = subgraph.Nodes();             // take the subgraph's node set
      ASSERT_EQ(subnodes.size(), static_cast<size_t>(13));
      ASSERT_TRUE(CheckGraphIndependence(subnodes));
    
      // the cinn op is the fusion of these three: mul | elementwise_add | relu
      ASSERT_TRUE(CheckNodeExisted(subnodes, "mul"));
      ASSERT_TRUE(CheckNodeExisted(subnodes, "elementwise_add"));
      ASSERT_TRUE(CheckNodeExisted(subnodes, "relu"));
      ASSERT_EQ(CountNode(subnodes, "feed"), 3);   // the comment above shows 3 feed ops
      ASSERT_EQ(CountNode(subnodes, "fetch"), 1);  // and 1 fetch op
      
      // inside kCinnLaunchOp, both parameter and non-parameter inputs should get a feed op
      // No-parameter input should has feed op
      auto new_v1 = GetNode(subnodes, "var1");
      ASSERT_EQ(new_v1->inputs.size(), static_cast<size_t>(1));
      ASSERT_EQ(new_v1->outputs.size(), static_cast<size_t>(1));
      ASSERT_EQ(new_v1->inputs[0]->Name(), "feed");
      ASSERT_EQ(new_v1->outputs[0]->Name(), "mul");
    
      // Parameter input should also have the feed op
      auto new_v2 = GetNode(subnodes, "var2");
      ASSERT_EQ(new_v2->inputs.size(), static_cast<size_t>(1));
      ASSERT_EQ(new_v2->inputs[0]->Name(), "feed");
      ASSERT_EQ(new_v2->outputs.size(), static_cast<size_t>(1));
      ASSERT_EQ(new_v2->outputs[0]->Name(), "mul");
    
      // the kCinnLaunchOp output should have a fetch op
      // output should has fetch op
      auto new_v6 = GetNode(subnodes, "var6");
      ASSERT_EQ(new_v6->inputs.size(), static_cast<size_t>(1));
      ASSERT_EQ(new_v6->outputs.size(), static_cast<size_t>(1));
      ASSERT_EQ(new_v6->inputs[0]->Name(), "relu");
      ASSERT_EQ(new_v6->outputs[0]->Name(), "fetch");
    }
    

    The first unit test had only fake ops, so nothing could be optimized by the pass; the second had ops that all support CINN. The next one mixes the two: some fake ops, and some ops that CINN can fuse.

    std::unique_ptr<Graph> BuildGraphWithOneCinnSubgraph() {
      ProgramDesc prog;
      auto g = std::make_unique<Graph>(prog);
    
      // fake1 --> v1 --
      //                | --> mul --> v3 --> relu --> v4 --> fake2
      //           v2 --
    
      OpDesc fake1_op;
      fake1_op.SetType("fake1");
      OpDesc mul_op;
      mul_op.SetType("mul");
      OpDesc relu_op;
      relu_op.SetType("relu");
      OpDesc fake2_op;
      fake2_op.SetType("fake2");
    
      VarDesc var1("var1");
      VarDesc var2("var2");
      var2.SetPersistable(true);
      var2.SetIsParameter(true);
      VarDesc var3("var3");
      VarDesc var4("var4");
    
      ir::Node* fake1 = g->CreateOpNode(&fake1_op);
      ir::Node* mul = g->CreateOpNode(&mul_op);
      ir::Node* relu = g->CreateOpNode(&relu_op);
      ir::Node* fake2 = g->CreateOpNode(&fake2_op);
    
      ir::Node* v1 = g->CreateVarNode(&var1);
      ir::Node* v2 = g->CreateVarNode(&var2);
      ir::Node* v3 = g->CreateVarNode(&var3);
      ir::Node* v4 = g->CreateVarNode(&var4);
    
      // fill op node
      fake1->outputs = {v1};
      mul->inputs = {v2, v1};
      mul->outputs = {v3};
      relu->inputs = {v3};
      relu->outputs = {v4};
      fake2->inputs = {v4};
    
      // fill variable node
      v2->outputs = {mul};
    
      v1->inputs = {fake1};
      v1->outputs = {mul};
    
      v3->inputs = {mul};
      v3->outputs = {relu};
    
      v4->inputs = {relu};
      v4->outputs = {fake2};
    
      return g;
    }
    

    The function above builds a graph like this:

      // fake1 --> v1 --
      //                | --> mul --> v3 --> relu --> v4 --> fake2
      //           v2 --
    

    After the cinn pass, the graph's nodes become:

      // fake1 --> v1 --
      //                | --> kCinnLaunchOp --> v4 --> fake2
      //           v2 --
    

    There is a single kCinnLaunchOp; its subgraph has 9 nodes:

      // feed --> v1 --
      //               | --> mul --> v3 --> relu --> v4 --> fetch
      // feed --> v2 --
    

    The previous graph produced a single cinn op; the next unit test covers the case of multiple cinn ops:

    std::unique_ptr<Graph> BuildGraphWithMultiCinnSubgraph() {
      ProgramDesc prog;
      auto g = std::make_unique<Graph>(prog);
    
      // fake1 --> v1 --
      //                | --> mul --> v3 --> fake2 --> v4 --> relu --> v5 --> fake3
      //           v2 --
    
      OpDesc fake1_op;
      fake1_op.SetType("fake1");
      OpDesc mul_op;
      mul_op.SetType("mul");
      OpDesc relu_op;
      relu_op.SetType("relu");
      OpDesc fake2_op;
      fake2_op.SetType("fake2");
      OpDesc fake3_op;
      fake3_op.SetType("fake3");
    
      VarDesc var1("var1");
      VarDesc var2("var2");
      var2.SetPersistable(true);
      var2.SetIsParameter(true);
      VarDesc var3("var3");
      VarDesc var4("var4");
      VarDesc var5("var5");
    
      ir::Node* fake1 = g->CreateOpNode(&fake1_op);
      ir::Node* mul = g->CreateOpNode(&mul_op);
      ir::Node* relu = g->CreateOpNode(&relu_op);
      ir::Node* fake2 = g->CreateOpNode(&fake2_op);
      ir::Node* fake3 = g->CreateOpNode(&fake3_op);
    
      ir::Node* v1 = g->CreateVarNode(&var1);
      ir::Node* v2 = g->CreateVarNode(&var2);
      ir::Node* v3 = g->CreateVarNode(&var3);
      ir::Node* v4 = g->CreateVarNode(&var4);
      ir::Node* v5 = g->CreateVarNode(&var5);
    
      // fill op node
      fake1->outputs = {v1};
      mul->inputs = {v2, v1};
      mul->outputs = {v3};
      fake2->inputs = {v3};
      fake2->outputs = {v4};
      relu->inputs = {v4};
      relu->outputs = {v5};
      fake3->inputs = {v5};
    
      // fill variable node
      v2->outputs = {mul};
    
      v1->inputs = {fake1};
      v1->outputs = {mul};
    
      v3->inputs = {mul};
      v3->outputs = {fake2};
    
      v4->inputs = {fake2};
      v4->outputs = {relu};
    
      v5->inputs = {relu};
      v5->outputs = {fake3};
    
      return g;
    }
    

    The code above builds this graph:

      // fake1 --> v1 --
      //                | --> mul --> v3 --> fake2 --> v4 --> relu --> v5 --> fake3
      //           v2 --
    

    With the fake2 op as the boundary, the pass can form two cinn ops:

      // fake1 -> v1 -
      //              | -> CinnOp -> v3 -> fake2 -> v4 -> CinnOp ->v5 -> fake3
      //          v2 -
    

    The cinn pass itself is just two lines:

      auto pass =
          paddle::framework::ir::PassRegistry::Instance().Get("build_cinn_pass");
      pass->Apply(g.get());
    

    Here is the code verifying that there are two cinn launch ops:

      // A new op named kCinnLaunchOp should be added
      ASSERT_TRUE(CheckNodeExisted(nodes, kCinnLaunchOp));
      ASSERT_EQ(CountNode(nodes, kCinnLaunchOp), 2);
    

    The final result is that after the cinn pass there are two subgraphs:

      // subgraph1:
      // feed --> v4 --> relu --> v5 --> fetch
      // subgraph2:
      // feed --> v1 --
      //               | --> mul --> v3 --> fetch
      //          v2 --
    

    BuildGraphWithNoNeedBufferInput builds a graph like this:

      // fake1 --> v1 --                 --> v4 --> relu_grad --> v6
      //           v2 -- | --> add_grad |
      //           v3 --                 --> v5 --> fake2
    

    Unlike the earlier builders, BuildGraphWithNoNeedBufferInput uses the SetInput API to set add_grad_op's inputs:

      OpDesc add_grad_op;
      add_grad_op.SetType("elementwise_add_grad");
      add_grad_op.SetInput(::paddle::framework::GradVarName("Out"), {"var1"});
      add_grad_op.SetInput("X", {"var2"});
      add_grad_op.SetInput("Y", {"var3"});
    

    The unit test then checks no_need_buffer_x. A "no need buffer" input is one whose tensor data is never actually read by the op; only metadata such as the shape is needed. For elementwise_add_grad, the gradients can be computed from the output gradient alone, so its X and Y inputs qualify.

      // A new op named kCinnLaunchOp should be added and
      // its input arguments are set correctly
      ASSERT_TRUE(CheckNodeExisted(nodes, kCinnLaunchOp));
      ASSERT_EQ(CountNode(nodes, kCinnLaunchOp), 1);
      auto* cinn_op_node = GetNode(nodes, kCinnLaunchOp);
      ASSERT_EQ(cinn_op_node->Op()->Input(operators::kX),
                std::vector<std::string>({"var1"}));
      auto& no_need_buffer_x = cinn_op_node->Op()->Input(operators::kNoNeedBufferX);
      ASSERT_EQ(std::unordered_set<std::string>(no_need_buffer_x.begin(),
                                                no_need_buffer_x.end()),
                std::unordered_set<std::string>({"var2", "var3"}));
    

    Judging from the assertions below, no_need_buffer_feeds is the set of feed variables of the subgraph whose buffers are not required (here var2 and var3, matching the op's no-need-buffer inputs):

      ASSERT_TRUE(CheckNodeExisted(subnodes, "elementwise_add_grad"));
      ASSERT_TRUE(CheckNodeExisted(subnodes, "relu_grad"));
      ASSERT_EQ(CountNode(subnodes, "feed"), 3);
      ASSERT_EQ(CountNode(subnodes, "fetch"), 2);
      const auto& no_need_buffer_feeds =
          subgraph.Get<std::unordered_set<std::string>>(kNoNeedBufferFeeds);
      ASSERT_EQ(no_need_buffer_feeds.size(), 2);
      ASSERT_EQ(no_need_buffer_feeds,
                std::unordered_set<std::string>({"var2", "var3"}));
    
      // check the attributes of variable lists are saved correctly
      ASSERT_TRUE(subgraph.Has(kInputVars));
      EXPECT_EQ(subgraph.Get<std::vector<std::string>>(kInputVars),
                std::vector<std::string>({"var1"}));
      ASSERT_TRUE(subgraph.Has(kInternalVars));
      EXPECT_EQ(subgraph.Get<std::vector<std::string>>(kInternalVars),
                std::vector<std::string>({"var4"}));
      ASSERT_TRUE(subgraph.Has(kOutputVars));
      const auto& output_vars = subgraph.Get<std::vector<std::string>>(kOutputVars);
      EXPECT_EQ(
          std::unordered_set<std::string>(output_vars.begin(), output_vars.end()),
          std::unordered_set<std::string>({"var5", "var6"}));
    
    The last test, TestSkipGcVars, marks some variables as not garbage-collectable:

    TEST(BuildCinnPassTest, TestSkipGcVars) {
      auto g = BuildGraphWithOneCinnSubgraph();
      
      // kSkipGcVarNames lists variables whose memory must not be garbage-collected,
      // so the pass has to keep them observable from the subgraph
      std::unordered_set<std::string> all_skip_gc_vars = {"var1", "var3"};
      g->SetNotOwned(kSkipGcVarNames, &all_skip_gc_vars);
    
      auto pass =
          paddle::framework::ir::PassRegistry::Instance().Get("build_cinn_pass");
      pass->Apply(g.get());
    
      // After search, the graph should as following
      // fake1 --> v1 --
      //                | --> kCinnLaunchOp --> v4 --> fake2
      //           v2 --
      const auto& nodes = g->Nodes();
      ASSERT_EQ(nodes.size(), static_cast<size_t>(7));  // one more node than the diagram shows: presumably v3 is kept as an extra output because it is in kSkipGcVarNames
      ASSERT_TRUE(CheckGraphIndependence(nodes));
    
      // A new op named kCinnLaunchOp should be added
      ASSERT_TRUE(CheckNodeExisted(nodes, kCinnLaunchOp));
    
      // After search, there should has just one cinn subgraph
      // Note v3 has fetched because of v3 in kSkipGcVarNames
      // And v1 is a feed var so v1 no need fetched though it in kSkipGcVarNames
      // feed --> v1 --
      //               | --> mul --> v3 --> relu --> v4 --> fetch
      // feed --> v2 --                 --> fetch
      auto compilation_keys = GetCompilationKeys(*g);
      ASSERT_EQ(compilation_keys.size(), static_cast<size_t>(1));
      auto* cinn_compiler = CinnCompiler::GetInstance();
      const auto& subgraph = cinn_compiler->FindGraph(compilation_keys[0]);
    
      const auto& subnodes = subgraph.Nodes();
      ASSERT_EQ(subnodes.size(), static_cast<size_t>(10));
      ASSERT_TRUE(CheckGraphIndependence(subnodes));
    
      ASSERT_EQ(CountNode(subnodes, "feed"), 2);
      // var3 and var4 should has fetch op
      ASSERT_EQ(CountNode(subnodes, "fetch"), 2);
    }
    

    Some details of the last two TESTs remain open questions for me.

  • Original article: https://blog.csdn.net/HaoZiHuang/article/details/133826560