diff --git a/mlir/docs/DeclarativeRewrites.md b/mlir/docs/DeclarativeRewrites.md --- a/mlir/docs/DeclarativeRewrites.md +++ b/mlir/docs/DeclarativeRewrites.md @@ -745,6 +745,24 @@ The above pattern removes the `Foo` and replaces all uses of `Foo` with `$input`. +### `either` + +The `either` directive is used to specify that two operands may be matched in +either order. When two adjacent operands are grouped with `either`, the pattern +tries to match them with both orderings of the constraints. For example, + +```tablegen +def : Pat<(TwoArgOp (either $firstArg, (AnOp $secondArg))), + (...)>; +``` + +The above pattern will accept both `"test.TwoArgOp"(%I32Arg, %AnOpArg)` and +`"test.TwoArgOp"(%AnOpArg, %I32Arg)`. + +Only operands can be grouped with `either` (attributes are not supported). Note +that an operation having the `Commutative` trait does not automatically get the +same behavior as `either` during pattern matching. + ## Debugging Tips ### Run `mlir-tblgen` to see the generated content diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -2683,6 +2683,21 @@ def returnType; +// Directive used to specify that two operands may be matched in either order. +// When two adjacent operands are grouped with `either`, the pattern tries to +// match them with both orderings of the constraints. Example: +// +// ``` +// (TwoArgOp (either $firstArg, (AnOp $secondArg))) +// ``` +// The above pattern will accept both `"test.TwoArgOp"(%I32Arg, %AnOpArg)` and +// `"test.TwoArgOp"(%AnOpArg, %I32Arg)`. +// +// Only operands can be grouped with `either` (attributes are not supported). +// Note that an operation having the `Commutative` trait does not automatically +// get the same behavior as `either` during pattern matching. 
+def either; + //===----------------------------------------------------------------------===// // Attribute and Type generation //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/TableGen/Pattern.h b/mlir/include/mlir/TableGen/Pattern.h --- a/mlir/include/mlir/TableGen/Pattern.h +++ b/mlir/include/mlir/TableGen/Pattern.h @@ -186,6 +186,9 @@ // Returns true if this DAG node is wrapping native code call. bool isNativeCodeCall() const; + // Returns true if this DAG node is the `either` directive. + bool isEither() const; + // Returns true if this DAG node is an operation. bool isOperation() const; diff --git a/mlir/lib/TableGen/Pattern.cpp b/mlir/lib/TableGen/Pattern.cpp --- a/mlir/lib/TableGen/Pattern.cpp +++ b/mlir/lib/TableGen/Pattern.cpp @@ -113,7 +113,7 @@ bool DagNode::isOperation() const { return !isNativeCodeCall() && !isReplaceWithValue() && - !isLocationDirective() && !isReturnTypeDirective(); + !isLocationDirective() && !isReturnTypeDirective() && !isEither(); } llvm::StringRef DagNode::getNativeCodeTemplate() const { @@ -142,7 +142,9 @@ } int DagNode::getNumOps() const { - int count = isReplaceWithValue() ? 0 : 1; + // Count only the operation nodes recursively involved in the DAG tree; + // directives and native code calls themselves are not counted. + int count = isOperation() ? 1 : 0; for (int i = 0, e = getNumArgs(); i != e; ++i) { if (auto child = getArgAsNestedDag(i)) count += child.getNumOps(); @@ -184,6 +186,11 @@ return dagOpDef->getName() == "returnType"; } +bool DagNode::isEither() const { + auto *dagOpDef = cast(node->getOperator())->getDef(); + return dagOpDef->getName() == "either"; +} + void DagNode::print(raw_ostream &os) const { if (node) node->print(os); @@ -764,22 +771,25 @@ if (tree.isOperation()) { auto &op = getDialectOp(tree); auto numOpArgs = op.getNumArgs(); + int numEither = 0; - // The pattern might have trailing directives. 
+ // Location/returnType directives do not map to op arguments, while each + // `either` DAG occupies one pattern slot but stands for two op arguments. int numDirectives = 0; for (int i = numTreeArgs - 1; i >= 0; --i) { if (auto dagArg = tree.getArgAsNestedDag(i)) { if (dagArg.isLocationDirective() || dagArg.isReturnTypeDirective()) ++numDirectives; - else - break; + else if (dagArg.isEither()) + ++numEither; } } - if (numOpArgs != numTreeArgs - numDirectives) { - auto err = formatv("op '{0}' argument number mismatch: " - "{1} in pattern vs. {2} in definition", - op.getOperationName(), numTreeArgs, numOpArgs); + if (numOpArgs != numTreeArgs - numDirectives + numEither) { + auto err = formatv("op '{0}' argument number mismatch: " + "{1} in pattern vs. {2} in definition", + op.getOperationName(), + numTreeArgs - numDirectives + numEither, numOpArgs); PrintFatalError(&def, err); } @@ -791,10 +801,31 @@ verifyBind(infoMap.bindOpResult(treeName, op), treeName); } - for (int i = 0; i != numTreeArgs; ++i) { + // The operands in an `either` DAG are bound to the arguments of the + // operation in the parent DagNode. `either` occupies a single slot in the + // parent DAG but stands for two consecutive op arguments, so `opArg` is + // advanced by one here and the caller's loop increment covers the other. + auto collectSymbolInEither = [&](DagNode parent, DagNode tree, int &opArg) { + for (int i = 0; i < 2; ++i) { + if (DagNode subTree = tree.getArgAsNestedDag(i)) { + collectBoundSymbols(subTree, infoMap, isSrcPattern); + } else { + auto argName = tree.getArgName(i); + if (!argName.empty() && argName != "_") + verifyBind(infoMap.bindOpArgument(parent, argName, op, opArg + i), + argName); + } + } + ++opArg; + }; + + for (int i = 0, opArg = 0; i != numTreeArgs; ++i, ++opArg) { if (auto treeArg = tree.getArgAsNestedDag(i)) { - // This DAG node argument is a DAG node itself. Go inside recursively. 
- collectBoundSymbols(treeArg, infoMap, isSrcPattern); + if (treeArg.isEither()) + collectSymbolInEither(tree, treeArg, opArg); + else + // This DAG node argument is a DAG node itself. Go inside recursively.
+ collectBoundSymbols(treeArg, infoMap, isSrcPattern); continue; } @@ -806,7 +837,7 @@ if (!treeArgName.empty() && treeArgName != "_") { LLVM_DEBUG(llvm::dbgs() << "found symbol bound to op argument: " << treeArgName << '\n'); - verifyBind(infoMap.bindOpArgument(tree, treeArgName, op, i), + verifyBind(infoMap.bindOpArgument(tree, treeArgName, op, opArg), treeArgName); } } diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -1329,6 +1329,30 @@ (replaceWithValue $results__2), ConstantAttr)>; +//===----------------------------------------------------------------------===// +// Test Patterns (either) + +def TestEitherOpA : TEST_Op<"either_op_a"> { + let arguments = (ins AnyInteger:$arg0, AnyInteger:$arg1, AnyInteger:$arg2); + let results = (outs I32:$output); +} + +def TestEitherOpB : TEST_Op<"either_op_b"> { + let arguments = (ins AnyInteger:$arg0); + let results = (outs I32:$output); +} + +def : Pat<(TestEitherOpA (either I32:$arg1, I16:$arg2), $_), + (TestEitherOpB $arg2)>; + +def : Pat<(TestEitherOpA (either (TestEitherOpB I32:$arg1), I16:$arg2), $_), + (TestEitherOpB $arg2)>; + +def : Pat<(TestEitherOpA (either (TestEitherOpB I32:$arg1), + (TestEitherOpB I16:$arg2)), + $_), + (TestEitherOpB $arg2)>; + //===----------------------------------------------------------------------===// // Test Patterns (Location) diff --git a/mlir/test/mlir-tblgen/pattern.mlir b/mlir/test/mlir-tblgen/pattern.mlir --- a/mlir/test/mlir-tblgen/pattern.mlir +++ b/mlir/test/mlir-tblgen/pattern.mlir @@ -531,6 +531,40 @@ return %0 : i32 } +//===----------------------------------------------------------------------===// +// Test either directive +//===----------------------------------------------------------------------===// + +// CHECK: @either_dag_leaf_only +func @either_dag_leaf_only_1(%arg0 : i32, %arg1 : i16, %arg2 : i8) -> () { + // CHECK: 
"test.either_op_b"(%arg1) : (i16) -> i32 %0 = "test.either_op_a"(%arg0, %arg1, %arg2) : (i32, i16, i8) -> i32 + // CHECK: "test.either_op_b"(%arg1) : (i16) -> i32 + %1 = "test.either_op_a"(%arg1, %arg0, %arg2) : (i16, i32, i8) -> i32 + return +} + +// CHECK: @either_dag_leaf_dag_node +func @either_dag_leaf_dag_node(%arg0 : i32, %arg1 : i16, %arg2 : i8) -> () { + %0 = "test.either_op_b"(%arg0) : (i32) -> i32 + // CHECK: "test.either_op_b"(%arg1) : (i16) -> i32 + %1 = "test.either_op_a"(%0, %arg1, %arg2) : (i32, i16, i8) -> i32 + // CHECK: "test.either_op_b"(%arg1) : (i16) -> i32 + %2 = "test.either_op_a"(%arg1, %0, %arg2) : (i16, i32, i8) -> i32 + return +} + +// CHECK: @either_dag_node_dag_node +func @either_dag_node_dag_node(%arg0 : i32, %arg1 : i16, %arg2 : i8) -> () { + %0 = "test.either_op_b"(%arg0) : (i32) -> i32 + %1 = "test.either_op_b"(%arg1) : (i16) -> i32 + // CHECK: "test.either_op_b"(%arg1) : (i16) -> i32 + %2 = "test.either_op_a"(%0, %1, %arg2) : (i32, i32, i8) -> i32 + // CHECK: "test.either_op_b"(%arg1) : (i16) -> i32 + %3 = "test.either_op_a"(%1, %0, %arg2) : (i32, i32, i8) -> i32 + return +} + //===----------------------------------------------------------------------===// // Test that ops without type deduction can be created with type builders. //===----------------------------------------------------------------------===// diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp --- a/mlir/tools/mlir-tblgen/RewriterGen.cpp +++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp @@ -117,10 +117,17 @@ void emitOpMatch(DagNode tree, StringRef opName, int depth); // Emits C++ statements for matching the `argIndex`-th argument of the given - // DAG `tree` as an operand. operandIndex is the index in the DAG excluding - // the preceding attributes. - void emitOperandMatch(DagNode tree, StringRef opName, int argIndex, - int operandIndex, int depth); + // DAG `tree` as an operand. `operandName` is the C++ expression for the 
+ // operand range; `operandMatcher` is its constraint; `argName` its symbol. + void emitOperandMatch(DagNode tree, StringRef opName, StringRef operandName, + DagLeaf operandMatcher, StringRef argName, + int argIndex); + + // Emits C++ statements for matching the two operands grouped by `either` in + // either order; `operandIndex` is advanced past both of them. + void emitEitherOperandMatch(DagNode tree, DagNode eitherArgTree, + StringRef opName, int argIndex, int &operandIndex, + int depth); // Emits C++ statements for matching the `argIndex`-th argument of the given // DAG `tree` as an attribute. @@ -470,6 +477,9 @@ for (int i = 0, e = tree.getNumArgs(); i != e; ++i) { std::string argName = formatv("arg{0}_{1}", depth, i); if (DagNode argTree = tree.getArgAsNestedDag(i)) { + if (argTree.isEither()) + PrintFatalError(loc, "NativeCodeCall cannot have `either` operands"); + os << "Value " << argName << ";\n"; } else { auto leaf = tree.getArgAsLeaf(i); @@ -583,12 +593,6 @@ formatv("\"{0} is not {1} type\"", castedName, op.getQualCppClassName())); - if (tree.getNumArgs() != op.getNumArgs()) - PrintFatalError(loc, formatv("op '{0}' argument number mismatch: {1} in " - "pattern vs. {2} in definition", - op.getOperationName(), tree.getNumArgs(), - op.getNumArgs())); - // If the operand's name is set, set to that variable. auto name = tree.getSymbol(); if (!name.empty()) @@ -600,6 +604,11 @@ // Handle nested DAG construct first if (DagNode argTree = tree.getArgAsNestedDag(i)) { + if (argTree.isEither()) { + emitEitherOperandMatch(tree, argTree, castedName, i, nextOperand, + depth); + continue; + } if (auto *operand = opArg.dyn_cast()) { if (operand->isVariableLength()) { auto error = formatv("use nested DAG construct to match op {0}'s " @@ -608,6 +617,7 @@ PrintFatalError(loc, error); } } + os << "{\n"; // Attributes don't count for getODSOperands. 
@@ -617,9 +627,10 @@ "(*{1}.getODSOperands({2}).begin()).getDefiningOp();\n", argName, castedName, nextOperand); // Null check of operand's definingOp - emitMatchCheck(castedName, /*matchStr=*/argName, - formatv("\"Operand {0} of {1} has null definingOp\"", - nextOperand++, castedName)); + emitMatchCheck( + castedName, /*matchStr=*/argName, + formatv("\"There's no operation that defines operand {0} of {1}\"", + nextOperand++, castedName)); emitMatch(argTree, argName, depth + 1); os << formatv("tblgen_ops.push_back({0});\n", argName); os.unindent() << "}\n"; @@ -628,8 +639,12 @@ // Next handle DAG leaf: operand or attribute if (opArg.is()) { - // emitOperandMatch's argument indexing counts attributes. - emitOperandMatch(tree, castedName, i, nextOperand, depth); + auto operandName = + formatv("{0}.getODSOperands({1})", castedName, nextOperand); + emitOperandMatch(tree, castedName, operandName.str(), + /*operandMatcher=*/tree.getArgAsLeaf(i), + /*argName=*/tree.getArgName(i), + /*argIndex=*/i); ++nextOperand; } else if (opArg.is()) { emitAttributeMatch(tree, opName, i, depth); @@ -643,24 +658,23 @@ } void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName, - int argIndex, int operandIndex, - int depth) { + StringRef operandName, + DagLeaf operandMatcher, StringRef argName, + int argIndex) { Operator &op = tree.getDialectOp(opMap); auto *operand = op.getArg(argIndex).get(); - auto matcher = tree.getArgAsLeaf(argIndex); // If a constraint is specified, we need to generate C++ statements to // check the constraint. - if (!matcher.isUnspecified()) { - if (!matcher.isOperandMatcher()) { + if (!operandMatcher.isUnspecified()) { + if (!operandMatcher.isOperandMatcher()) PrintFatalError( loc, formatv("the {1}-th argument of op '{0}' should be an operand", op.getOperationName(), argIndex + 1)); - } // Only need to verify if the matcher's type is different from the one // of op definition.
- Constraint constraint = matcher.getAsConstraint(); + Constraint constraint = operandMatcher.getAsConstraint(); if (operand->constraint != constraint) { if (operand->isVariableLength()) { auto error = formatv( @@ -668,8 +682,7 @@ op.getOperationName(), argIndex); PrintFatalError(loc, error); } - auto self = formatv("(*{0}.getODSOperands({1}).begin()).getType()", - opName, operandIndex); + auto self = formatv("(*{0}.begin()).getType()", operandName); StringRef verifier = staticMatcherHelper.getVerifierName(constraint); emitStaticVerifierCall( verifier, opName, self.str(), @@ -682,22 +695,90 @@ } // Capture the value - auto name = tree.getArgName(argIndex); // `$_` is a special symbol to ignore op argument matching. - if (!name.empty() && name != "_") { - // We need to subtract the number of attributes before this operand to get - // the index in the operand list. - auto numPrevAttrs = std::count_if( - op.arg_begin(), op.arg_begin() + argIndex, - [](const Argument &arg) { return arg.is(); }); - - auto res = symbolInfoMap.findBoundSymbol(name, tree, op, argIndex); - os << formatv("{0} = {1}.getODSOperands({2});\n", - res->second.getVarName(name), opName, - argIndex - numPrevAttrs); + if (!argName.empty() && argName != "_") { + auto res = symbolInfoMap.findBoundSymbol(argName, tree, op, argIndex); + os << formatv("{0} = {1};\n", res->second.getVarName(argName), operandName); } } +void PatternEmitter::emitEitherOperandMatch(DagNode tree, DagNode eitherArgTree, + StringRef opName, int argIndex, + int &operandIndex, int depth) { + // NOTE(review): `argIndex` is used below as an op argument index, but the + // caller passes the position of `either` within the DAG `tree`; the two + // differ when an earlier argument is also an `either`. Confirm. + if (eitherArgTree.getNumArgs() != 2) + PrintFatalError(loc, "`either` only supports grouping two operands"); + + Operator &op = tree.getDialectOp(opMap); + + std::string codeBuffer; + llvm::raw_string_ostream tblgenOps(codeBuffer); + + // Include `argIndex` so multiple `either`s in one op get distinct names. 
+ std::string lambda = formatv("eitherLambda{0}_{1}", depth, argIndex); + os << formatv("auto {0} = [&](OperandRange v0, OperandRange v1) {{\n", + lambda); + + os.indent(); + + for (int i = 0; i < 2; ++i, ++argIndex) { + if 
(DagNode argTree = eitherArgTree.getArgAsNestedDag(i)) { + if (argTree.isEither()) + PrintFatalError(loc, "either cannot be nested"); + + std::string argName = formatv("local_op_{0}", i).str(); + + os << formatv("auto {0} = (*v{1}.begin()).getDefiningOp();\n", argName, + i); + emitMatchCheck( + opName, /*matchStr=*/argName, + formatv("\"There's no operation that defines operand {0} of {1}\"", + operandIndex++, opName)); + // Match in a nested C++ scope: both `either` arguments are matched at + // `depth + 1`, so depth-based local names could otherwise collide. + os << "{\n"; + os.indent(); + emitMatch(argTree, argName, depth + 1); + os.unindent() << "}\n"; + // `tblgen_ops` is used to collect the matched operations. In either, we + // need to queue the operation only if the matching succeeds. Thus we emit + // the code at the end. + tblgenOps << formatv("tblgen_ops.push_back({0});\n", argName); + } else if (op.getArg(argIndex).is()) { + // NOTE(review): emitOperandMatch skips the type check when the matcher's + // constraint equals that of op argument `argIndex`, but in the swapped + // ordering `v{i}` holds the other operand. Confirm the constraints agree. + emitOperandMatch(tree, opName, /*operandName=*/formatv("v{0}", i).str(), + /*operandMatcher=*/eitherArgTree.getArgAsLeaf(i), + /*argName=*/eitherArgTree.getArgName(i), argIndex); + ++operandIndex; + } else + PrintFatalError(loc, "either can only be applied on type operand"); + } + + os << tblgenOps.str(); + os << "return success();\n"; + os.unindent() << "};\n"; + + os << "{\n"; + os.indent(); + + os << formatv("auto either_operand0 = {0}.getODSOperands({1});\n", opName, + operandIndex - 2); + os << formatv("auto either_operand1 = {0}.getODSOperands({1});\n", opName, + operandIndex - 1); + + os << formatv("if(failed({0}(either_operand0, either_operand1)) && " + "failed({0}(either_operand1, " + "either_operand0)))\n", + lambda); + os.indent() << "return failure();\n"; + + os.unindent().unindent() << "}\n"; +} + void PatternEmitter::emitAttributeMatch(DagNode tree, StringRef opName, int argIndex, int depth) { Operator &op = tree.getDialectOp(opMap);