diff --git a/mlir/docs/DeclarativeRewrites.md b/mlir/docs/DeclarativeRewrites.md --- a/mlir/docs/DeclarativeRewrites.md +++ b/mlir/docs/DeclarativeRewrites.md @@ -745,6 +745,27 @@ The above pattern removes the `Foo` and replaces all uses of `Foo` with `$input`. +### `commutative` + +The `commutative` directive is used to specify the operands that have the +commutative property while pattern matching. When two operands are grouped by +`commutative`, the pattern tries to match them against the constraints in both +orders. + +For example, + +```tablegen +def : Pat<(TwoArgOp (commutative $firstArg, (AnOp $secondArg))), + (...)>; +``` + +The above pattern will match both `"test.TwoArgOp"(%I32Arg, %AnOpArg)` and +`"test.TwoArgOp"(%AnOpArg, %I32Arg)`. + +Only operands (not attributes) can be grouped by `commutative`. Note that an +operation with the `Commutative` trait does not have its operands matched +commutatively by default; `commutative` must be specified explicitly. + ## Debugging Tips ### Run `mlir-tblgen` to see the generated content diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -2683,6 +2683,21 @@ def returnType; +// Directive used to specify the operands that have the commutative property // while pattern matching. When two operands are grouped by `commutative`, // it tries to match them against the constraints in both orders. Example: // // ``` // (TwoArgOp (commutative $firstArg, (AnOp $secondArg))) // ``` // The above pattern will match both `"test.TwoArgOp"(%I32Arg, %AnOpArg)` and // `"test.TwoArgOp"(%AnOpArg, %I32Arg)`. // // Only operands (not attributes) can be grouped by `commutative`. Note that an // operation with the `Commutative` trait does not have its operands matched // commutatively by default; `commutative` must be specified explicitly.
+def commutative; + //===----------------------------------------------------------------------===// // Attribute and Type generation //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/TableGen/Pattern.h b/mlir/include/mlir/TableGen/Pattern.h --- a/mlir/include/mlir/TableGen/Pattern.h +++ b/mlir/include/mlir/TableGen/Pattern.h @@ -186,6 +186,9 @@ // Returns true if this DAG node is wrapping native code call. bool isNativeCodeCall() const; + // Returns true if this DAG node is the `commutative` directive. + bool isCommutative() const; + // Returns true if this DAG node is an operation. bool isOperation() const; diff --git a/mlir/lib/TableGen/Pattern.cpp b/mlir/lib/TableGen/Pattern.cpp --- a/mlir/lib/TableGen/Pattern.cpp +++ b/mlir/lib/TableGen/Pattern.cpp @@ -113,7 +113,7 @@ bool DagNode::isOperation() const { return !isNativeCodeCall() && !isReplaceWithValue() && - !isLocationDirective() && !isReturnTypeDirective(); + !isLocationDirective() && !isReturnTypeDirective() && !isCommutative(); } llvm::StringRef DagNode::getNativeCodeTemplate() const { @@ -142,7 +142,7 @@ } int DagNode::getNumOps() const { - int count = isReplaceWithValue() ? 0 : 1; + int count = isOperation() ? 1 : 0; for (int i = 0, e = getNumArgs(); i != e; ++i) { if (auto child = getArgAsNestedDag(i)) count += child.getNumOps(); @@ -184,6 +184,11 @@ return dagOpDef->getName() == "returnType"; } +bool DagNode::isCommutative() const { + auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef(); + return dagOpDef->getName() == "commutative"; +} + void DagNode::print(raw_ostream &os) const { if (node) node->print(os); @@ -764,22 +769,25 @@ if (tree.isOperation()) { auto &op = getDialectOp(tree); auto numOpArgs = op.getNumArgs(); + int numCommutative = 0; - // The pattern might have trailing directives. + // Directives do not correspond to op arguments, while each `commutative` DAG + // groups two of them; account for both when checking the argument count.
int numDirectives = 0; for (int i = numTreeArgs - 1; i >= 0; --i) { if (auto dagArg = tree.getArgAsNestedDag(i)) { if (dagArg.isLocationDirective() || dagArg.isReturnTypeDirective()) ++numDirectives; - else - break; + else if (dagArg.isCommutative()) + ++numCommutative; } } - if (numOpArgs != numTreeArgs - numDirectives) { + if (numOpArgs != numTreeArgs - numDirectives + numCommutative) { auto err = formatv("op '{0}' argument number mismatch: " "{1} in pattern vs. {2} in definition", - op.getOperationName(), numTreeArgs, numOpArgs); + op.getOperationName(), numTreeArgs + numCommutative, + numOpArgs); PrintFatalError(&def, err); } @@ -791,10 +799,33 @@ verifyBind(infoMap.bindOpResult(treeName, op), treeName); } - for (int i = 0; i != numTreeArgs; ++i) { + // A `commutative` DAG groups two operands of `parent`. Bind them to op + // arguments `opArg` and `opArg + 1`, then advance `opArg` by one; the loop's + // own `++opArg` then accounts for the second grouped operand. + auto collectSymbolInCommutative = [&](DagNode parent, DagNode tree, + int &opArg) { + if (tree.getNumArgs() != 2) + PrintFatalError(&def, "commutative only supports grouping two operands"); + for (int i = 0; i < 2; ++i) { + if (DagNode subTree = tree.getArgAsNestedDag(i)) { + collectBoundSymbols(subTree, infoMap, isSrcPattern); + } else { + auto argName = tree.getArgName(i); + if (!argName.empty() && argName != "_") + verifyBind(infoMap.bindOpArgument(parent, argName, op, opArg + i), + argName); + } + } + ++opArg; + }; + + for (int i = 0, opArg = 0; i != numTreeArgs; ++i, ++opArg) { if (auto treeArg = tree.getArgAsNestedDag(i)) { - // This DAG node argument is a DAG node itself. Go inside recursively. - collectBoundSymbols(treeArg, infoMap, isSrcPattern); + if (treeArg.isCommutative()) + collectSymbolInCommutative(tree, treeArg, opArg); + else + // This DAG node argument is a DAG node itself. Go inside recursively.
+ collectBoundSymbols(treeArg, infoMap, isSrcPattern); continue; } @@ -806,7 +837,7 @@ if (!treeArgName.empty() && treeArgName != "_") { LLVM_DEBUG(llvm::dbgs() << "found symbol bound to op argument: " << treeArgName << '\n'); - verifyBind(infoMap.bindOpArgument(tree, treeArgName, op, i), + verifyBind(infoMap.bindOpArgument(tree, treeArgName, op, opArg), treeArgName); } } diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -1329,6 +1329,30 @@ (replaceWithValue $results__2), ConstantAttr)>; +//===----------------------------------------------------------------------===// +// Test Patterns (commutative) + +def TestCommutativeOpA : TEST_Op<"commutative_op_a"> { + let arguments = (ins AnyInteger:$arg0, AnyInteger:$arg1, AnyInteger:$arg2); + let results = (outs I32:$output); +} + +def TestCommutativeOpB : TEST_Op<"commutative_op_b"> { + let arguments = (ins AnyInteger:$arg0); + let results = (outs I32:$output); +} + +def : Pat<(TestCommutativeOpA (commutative I32:$arg1, I16:$arg2), $_), + (TestCommutativeOpB $arg2)>; + +def : Pat<(TestCommutativeOpA (commutative (TestCommutativeOpB I32:$arg1), I16:$arg2), $_), + (TestCommutativeOpB $arg2)>; + +def : Pat<(TestCommutativeOpA (commutative (TestCommutativeOpB I32:$arg1), + (TestCommutativeOpB I16:$arg2)), + $_), + (TestCommutativeOpB $arg2)>; + //===----------------------------------------------------------------------===// // Test Patterns (Location) diff --git a/mlir/test/mlir-tblgen/pattern.mlir b/mlir/test/mlir-tblgen/pattern.mlir --- a/mlir/test/mlir-tblgen/pattern.mlir +++ b/mlir/test/mlir-tblgen/pattern.mlir @@ -531,6 +531,40 @@ return %0 : i32 } +//===----------------------------------------------------------------------===// +// Test commutative directive +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: @commutative_dag_leaf_only +func
@commutative_dag_leaf_only(%arg0 : i32, %arg1 : i16, %arg2 : i8) -> () { + // CHECK: "test.commutative_op_b"(%arg1) : (i16) -> i32 + %0 = "test.commutative_op_a"(%arg0, %arg1, %arg2) : (i32, i16, i8) -> i32 + // CHECK: "test.commutative_op_b"(%arg1) : (i16) -> i32 + %1 = "test.commutative_op_a"(%arg1, %arg0, %arg2) : (i16, i32, i8) -> i32 + return +} + +// CHECK-LABEL: @commutative_dag_leaf_dag_node +func @commutative_dag_leaf_dag_node(%arg0 : i32, %arg1 : i16, %arg2 : i8) -> () { + %0 = "test.commutative_op_b"(%arg0) : (i32) -> i32 + // CHECK: "test.commutative_op_b"(%arg1) : (i16) -> i32 + %1 = "test.commutative_op_a"(%0, %arg1, %arg2) : (i32, i16, i8) -> i32 + // CHECK: "test.commutative_op_b"(%arg1) : (i16) -> i32 + %2 = "test.commutative_op_a"(%arg1, %0, %arg2) : (i16, i32, i8) -> i32 + return +} + +// CHECK-LABEL: @commutative_dag_node_dag_node +func @commutative_dag_node_dag_node(%arg0 : i32, %arg1 : i16, %arg2 : i8) -> () { + %0 = "test.commutative_op_b"(%arg0) : (i32) -> i32 + %1 = "test.commutative_op_b"(%arg1) : (i16) -> i32 + // CHECK-COUNT-3: "test.commutative_op_b"(%arg1) : (i16) -> i32 + %2 = "test.commutative_op_a"(%0, %1, %arg2) : (i32, i32, i8) -> i32 + // CHECK-NOT: "test.commutative_op_a" + %3 = "test.commutative_op_a"(%1, %0, %arg2) : (i32, i32, i8) -> i32 + return +} + //===----------------------------------------------------------------------===// // Test that ops without type deduction can be created with type builders. //===----------------------------------------------------------------------===// diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp --- a/mlir/tools/mlir-tblgen/RewriterGen.cpp +++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp @@ -117,10 +117,17 @@ void emitOpMatch(DagNode tree, StringRef opName, int depth); // Emits C++ statements for matching the `argIndex`-th argument of the given - // DAG `tree` as an operand.
operandIndex is the index in the DAG excluding - // the preceding attributes. - void emitOperandMatch(DagNode tree, StringRef opName, int argIndex, - int operandIndex, int depth); + // DAG `tree` as an operand. `operandName` is the C++ expression yielding + // the operand, `operandMatcher` its constraint, `argName` its bound symbol. + void emitOperandMatch(DagNode tree, StringRef opName, StringRef operandName, + DagLeaf operandMatcher, StringRef argName, + int argIndex); + + // Emits C++ statements for matching the two operands grouped by the + // `commutative` DAG `commutativeArgTree`, trying both operand orders. + void emitCommutativeOperandMatch(DagNode tree, DagNode commutativeArgTree, + StringRef opName, int argIndex, + int &operandIndex, int depth); // Emits C++ statements for matching the `argIndex`-th argument of the given // DAG `tree` as an attribute. @@ -470,6 +477,9 @@ for (int i = 0, e = tree.getNumArgs(); i != e; ++i) { std::string argName = formatv("arg{0}_{1}", depth, i); if (DagNode argTree = tree.getArgAsNestedDag(i)) { + if (argTree.isCommutative()) + PrintFatalError(loc, "NativeCodeCall cannot have commutative operands"); + os << "Value " << argName << ";\n"; } else { auto leaf = tree.getArgAsLeaf(i); @@ -582,12 +592,6 @@ formatv("\"{0} is not {1} type\"", castedName, op.getQualCppClassName())); - if (tree.getNumArgs() != op.getNumArgs()) - PrintFatalError(loc, formatv("op '{0}' argument number mismatch: {1} in " - "pattern vs. {2} in definition", - op.getOperationName(), tree.getNumArgs(), - op.getNumArgs())); - // If the operand's name is set, set to that variable.
auto name = tree.getSymbol(); if (!name.empty()) @@ -599,6 +603,11 @@ // Handle nested DAG construct first if (DagNode argTree = tree.getArgAsNestedDag(i)) { + if (argTree.isCommutative()) { + emitCommutativeOperandMatch(tree, argTree, castedName, i, nextOperand, + depth); + continue; + } if (auto *operand = opArg.dyn_cast<NamedTypeConstraint *>()) { if (operand->isVariableLength()) { auto error = formatv("use nested DAG construct to match op {0}'s " @@ -607,6 +616,7 @@ PrintFatalError(loc, error); } } + os << "{\n"; // Attributes don't count for getODSOperands. @@ -627,8 +637,12 @@ // Next handle DAG leaf: operand or attribute if (opArg.is<NamedTypeConstraint *>()) { - // emitOperandMatch's argument indexing counts attributes. - emitOperandMatch(tree, castedName, i, nextOperand, depth); + auto operandName = + formatv("{0}.getODSOperands({1})", castedName, nextOperand); + emitOperandMatch(tree, castedName, operandName.str(), + /*operandMatcher=*/tree.getArgAsLeaf(i), + /*argName=*/tree.getArgName(i), + /*argIndex=*/i); ++nextOperand; } else if (opArg.is<NamedAttribute *>()) { emitAttributeMatch(tree, opName, i, depth); @@ -642,24 +656,23 @@ } void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName, - int argIndex, int operandIndex, - int depth) { + StringRef operandName, + DagLeaf operandMatcher, StringRef argName, + int argIndex) { Operator &op = tree.getDialectOp(opMap); auto *operand = op.getArg(argIndex).get<NamedTypeConstraint *>(); - auto matcher = tree.getArgAsLeaf(argIndex); // If a constraint is specified, we need to generate C++ statements to // check the constraint. - if (!matcher.isUnspecified()) { - if (!matcher.isOperandMatcher()) { + if (!operandMatcher.isUnspecified()) { + if (!operandMatcher.isOperandMatcher()) PrintFatalError( loc, formatv("the {1}-th argument of op '{0}' should be an operand", op.getOperationName(), argIndex + 1)); - } // Only need to verify if the matcher's type is different from the one // of op definition.
- Constraint constraint = matcher.getAsConstraint(); + Constraint constraint = operandMatcher.getAsConstraint(); if (operand->constraint != constraint) { if (operand->isVariableLength()) { auto error = formatv( @@ -667,8 +680,7 @@ op.getOperationName(), argIndex); PrintFatalError(loc, error); } - auto self = formatv("(*{0}.getODSOperands({1}).begin()).getType()", - opName, operandIndex); + auto self = formatv("(*{0}.begin()).getType()", operandName); StringRef verifier = staticMatcherHelper.getVerifierName(constraint); emitStaticVerifierCall( verifier, opName, self.str(), @@ -681,20 +693,93 @@ } // Capture the value - auto name = tree.getArgName(argIndex); // `$_` is a special symbol to ignore op argument matching. - if (!name.empty() && name != "_") { - // We need to subtract the number of attributes before this operand to get - // the index in the operand list. - auto numPrevAttrs = std::count_if( - op.arg_begin(), op.arg_begin() + argIndex, - [](const Argument &arg) { return arg.is<NamedAttribute *>(); }); - - auto res = symbolInfoMap.findBoundSymbol(name, tree, op, argIndex); - os << formatv("{0} = {1}.getODSOperands({2});\n", - res->second.getVarName(name), opName, - argIndex - numPrevAttrs); + if (!argName.empty() && argName != "_") { + auto res = symbolInfoMap.findBoundSymbol(argName, tree, op, argIndex); + os << formatv("{0} = {1};\n", res->second.getVarName(argName), operandName); + } +} + +// Emits a lambda that matches the two operands grouped by the `commutative` +// DAG `commutativeArgTree` in a given order, followed by code that tries both +// orders. `argIndex` is the op argument index of the first grouped operand and +// `operandIndex` its ODS operand index; `operandIndex` is advanced past both. +// +// NOTE(review): emitOpMatch passes its DAG-argument index `i` as `argIndex` +// and then `continue`s, so a group spanning two op arguments presumably +// advances that index by only one; DAG arguments after a group would then be +// checked against the wrong op argument. TODO: track the op index separately. +void PatternEmitter::emitCommutativeOperandMatch(DagNode tree, + DagNode commutativeArgTree, + StringRef opName, int argIndex, + int &operandIndex, int depth) { + if (commutativeArgTree.getNumArgs() != 2) + PrintFatalError(loc, "commutative only supports grouping two operands"); + + Operator &op = tree.getDialectOp(opMap); + + // `argIndex` is part of the name because several `commutative` groups of + // the same op share `depth` and would otherwise redeclare the same lambda. + std::string lambda = formatv("commutativeLambda{0}_{1}", depth, argIndex); + os << formatv("auto {0} = [&](OperandRange v0, OperandRange v1) {{\n", + lambda); + + os.indent(); + + // `tblgen_ops.push_back` statements for matched nested ops. They are emitted + // only after both operands matched, so that a failed ordering records none. + std::string tblgenOps; + + for (int i = 0; i < 2; ++i, ++argIndex) {
+ auto *operand = op.getArg(argIndex).dyn_cast<NamedTypeConstraint *>(); + if (!operand) + PrintFatalError(loc, "commutative can only be applied on type operand"); + + if (DagNode argTree = commutativeArgTree.getArgAsNestedDag(i)) { + if (argTree.isCommutative()) + PrintFatalError(loc, "commutative cannot be nested"); + // The generated code dereferences the first value of the operand + // range, which is only meaningful for a non-variadic operand. + if (operand->isVariableLength()) + PrintFatalError(loc, formatv("commutative cannot match op {0}'s " + "variadic operand #{1} with a nested DAG", + op.getOperationName(), argIndex)); + + std::string argName = formatv("local_op_{0}", i).str(); + + os << formatv("auto {0} = (*v{1}.begin()).getDefiningOp();\n", argName, + i); + emitMatchCheck(opName, /*matchStr=*/argName, + formatv("\"Operand {0} of {1} has null definingOp\"", + operandIndex++, opName)); + emitMatch(argTree, argName, depth + 1); + tblgenOps.append(formatv("tblgen_ops.push_back({0});\n", argName).str()); + } else { + emitOperandMatch(tree, opName, /*operandName=*/formatv("v{0}", i).str(), + /*operandMatcher=*/commutativeArgTree.getArgAsLeaf(i), + /*argName=*/commutativeArgTree.getArgName(i), argIndex); + ++operandIndex; + } } + + os << tblgenOps; + os << "return success();\n"; + os.unindent() << "};\n"; + + os << "{\n"; + os.indent(); + + os << formatv("auto operand0 = {0}.getODSOperands({1});\n", opName, + operandIndex - 2); + os << formatv("auto operand1 = {0}.getODSOperands({1});\n", opName, + operandIndex - 1); + + os << formatv("if (failed({0}(operand0, operand1)) && " + "failed({0}(operand1, operand0)))\n", + lambda); + os.indent() << "return failure();\n"; + + os.unindent().unindent() << "}\n"; } void PatternEmitter::emitAttributeMatch(DagNode tree, StringRef opName,