diff --git a/mlir/lib/TableGen/Pattern.cpp b/mlir/lib/TableGen/Pattern.cpp --- a/mlir/lib/TableGen/Pattern.cpp +++ b/mlir/lib/TableGen/Pattern.cpp @@ -807,15 +807,16 @@ // The operand in `either` DAG should be bound to the operation in the // parent DagNode. auto collectSymbolInEither = [&](DagNode parent, DagNode tree, - int &opArgIdx) { + int opArgIdx) { for (int i = 0; i < tree.getNumArgs(); ++i, ++opArgIdx) { if (DagNode subTree = tree.getArgAsNestedDag(i)) { collectBoundSymbols(subTree, infoMap, isSrcPattern); } else { auto argName = tree.getArgName(i); - if (!argName.empty() && argName != "_") + if (!argName.empty() && argName != "_") { verifyBind(infoMap.bindOpArgument(parent, argName, op, opArgIdx), argName); + } } } }; @@ -824,6 +825,14 @@ if (auto treeArg = tree.getArgAsNestedDag(i)) { if (treeArg.isEither()) { collectSymbolInEither(tree, treeArg, opArgIdx); + // `either` DAG is *flattened*. For example, + // + // (FooOp (either arg0, arg1), arg2) + // + // can be viewed as: + // + // (FooOp arg0, arg1, arg2) + ++opArgIdx; } else { // This DAG node argument is a DAG node itself. Go inside recursively. collectBoundSymbols(treeArg, infoMap, isSrcPattern); diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -1691,20 +1691,20 @@ } def TestEitherOpB : TEST_Op<"either_op_b"> { - let arguments = (ins AnyInteger:$arg0); + let arguments = (ins AnyInteger:$arg0, AnyInteger:$arg1); let results = (outs I32:$output); } -def : Pat<(TestEitherOpA (either I32:$arg1, I16:$arg2), $_), - (TestEitherOpB $arg2)>; +def : Pat<(TestEitherOpA (either I32:$arg1, I16:$arg2), $x), + (TestEitherOpB $arg2, $x)>; -def : Pat<(TestEitherOpA (either (TestEitherOpB I32:$arg1), I16:$arg2), $_), - (TestEitherOpB $arg2)>; +def : Pat<(TestEitherOpA (either (TestEitherOpB I32:$arg1, $_), I16:$arg2), $x), + (TestEitherOpB $arg2, $x)>; -def : Pat<(TestEitherOpA (either (TestEitherOpB I32:$arg1), - (TestEitherOpB I16:$arg2)), - $_), - (TestEitherOpB $arg2)>; +def : Pat<(TestEitherOpA (either (TestEitherOpB I32:$arg1, $_), + (TestEitherOpB I16:$arg2, $_)), + $x), + (TestEitherOpB $arg2, $x)>; //===----------------------------------------------------------------------===// // Test Patterns (Location) diff --git a/mlir/test/mlir-tblgen/pattern.mlir b/mlir/test/mlir-tblgen/pattern.mlir --- a/mlir/test/mlir-tblgen/pattern.mlir +++ b/mlir/test/mlir-tblgen/pattern.mlir @@ -532,30 +532,30 @@ // CHECK: @either_dag_leaf_only func.func @either_dag_leaf_only_1(%arg0 : i32, %arg1 : i16, %arg2 : i8) -> () { - // CHECK: "test.either_op_b"(%arg1) : (i16) -> i32 + // CHECK: "test.either_op_b"(%arg1, %arg2) : (i16, i8) -> i32 %0 = "test.either_op_a"(%arg0, %arg1, %arg2) : (i32, i16, i8) -> i32 - // CHECK: "test.either_op_b"(%arg1) : (i16) -> i32 + // CHECK: "test.either_op_b"(%arg1, %arg2) : (i16, i8) -> i32 %1 = "test.either_op_a"(%arg1, %arg0, %arg2) : (i16, i32, i8) -> i32 return } // CHECK: @either_dag_leaf_dag_node func.func @either_dag_leaf_dag_node(%arg0 : i32, %arg1 : i16, %arg2 : i8) -> () { - %0 = "test.either_op_b"(%arg0) : (i32) -> i32 - // CHECK: "test.either_op_b"(%arg1) : (i16) -> i32 + %0 = "test.either_op_b"(%arg0, %arg0) : (i32, i32) -> i32 + // CHECK: "test.either_op_b"(%arg1, %arg2) : (i16, i8) -> i32 %1 = "test.either_op_a"(%0, %arg1, %arg2) : (i32, i16, i8) -> i32 - // CHECK: "test.either_op_b"(%arg1) : (i16) -> i32 + // CHECK: "test.either_op_b"(%arg1, %arg2) : (i16, i8) -> i32 %2 = "test.either_op_a"(%arg1, %0, %arg2) : (i16, i32, i8) -> i32 return } // CHECK: @either_dag_node_dag_node func.func @either_dag_node_dag_node(%arg0 : i32, %arg1 : i16, %arg2 : i8) -> () { - %0 = "test.either_op_b"(%arg0) : (i32) -> i32 - %1 = "test.either_op_b"(%arg1) : (i16) -> i32 - // CHECK: "test.either_op_b"(%arg1) : (i16) -> i32 + %0 = "test.either_op_b"(%arg0, %arg0) : (i32, i32) -> i32 + %1 = "test.either_op_b"(%arg1, %arg1) : (i16, i16) -> i32 + // CHECK: "test.either_op_b"(%arg1, %arg2) : (i16, i8) -> i32 %2 = "test.either_op_a"(%0, %1, %arg2) : (i32, i32, i8) -> i32 - // CHECK: "test.either_op_b"(%arg1) : (i16) -> i32 + // CHECK: "test.either_op_b"(%arg1, %arg2) : (i16, i8) -> i32 %3 = "test.either_op_a"(%1, %0, %arg2) : (i32, i32, i8) -> i32 return } diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp --- a/mlir/tools/mlir-tblgen/RewriterGen.cpp +++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp @@ -582,22 +582,24 @@ if (!name.empty()) os << formatv("{0} = {1};\n", name, castedName); - for (int i = 0, e = tree.getNumArgs(), nextOperand = 0; i != e; ++i) { - auto opArg = op.getArg(i); + for (int i = 0, opArgIdx = 0, e = tree.getNumArgs(), nextOperand = 0; i != e; + ++i, ++opArgIdx) { + auto opArg = op.getArg(opArgIdx); std::string argName = formatv("op{0}", depth + 1); // Handle nested DAG construct first if (DagNode argTree = tree.getArgAsNestedDag(i)) { if (argTree.isEither()) { - emitEitherOperandMatch(tree, argTree, castedName, i, nextOperand, + emitEitherOperandMatch(tree, argTree, castedName, opArgIdx, nextOperand, depth); + ++opArgIdx; continue; } if (auto *operand = opArg.dyn_cast()) { if (operand->isVariableLength()) { auto error = formatv("use nested DAG construct to match op {0}'s " "variadic operand #{1} unsupported now", - op.getOperationName(), i); + op.getOperationName(), opArgIdx); PrintFatalError(loc, error); } } @@ -627,11 +629,10 @@ formatv("{0}.getODSOperands({1})", castedName, nextOperand); emitOperandMatch(tree, castedName, operandName.str(), /*operandMatcher=*/tree.getArgAsLeaf(i), - /*argName=*/tree.getArgName(i), - /*argIndex=*/i); + /*argName=*/tree.getArgName(i), opArgIdx); ++nextOperand; } else if (opArg.is()) { - emitAttributeMatch(tree, opName, i, depth); + emitAttributeMatch(tree, opName, opArgIdx, depth); } else { PrintFatalError(loc, "unhandled case when matching op"); }