diff --git a/mlir/docs/DeclarativeRewrites.md b/mlir/docs/DeclarativeRewrites.md
--- a/mlir/docs/DeclarativeRewrites.md
+++ b/mlir/docs/DeclarativeRewrites.md
@@ -646,6 +646,29 @@
 
 [TODO]
 
+#### Matching a variadic operand with a fixed number of sub-operands
+
+The `variadic` directive matches a variadic operand only when it has exactly
+as many values as there are sub-operands listed inside the directive; each
+sub-operand can be bound to a symbol or matched against a nested op DAG. For
+example, given a `concatenate` op with a variadic `inputs` operand, the
+following pattern matches a `concatenate` op with exactly 2 inputs:
+
+```tablegen
+def ConcatenateOp : Op<"concatenate"> {
+  let arguments = (ins
+    Variadic<AnyTensor>:$inputs
+  );
+
+  let results = (outs
+    AnyTensor:$output
+  );
+}
+
+def : Pat<(ConcatenateOp (variadic $input0, $input1)),
+          (SomeOtherOp $input0, $input1)>;
+```
+
 ### Supplying additional constraints
 
 Constraints can be placed on op arguments when matching. But sometimes we need
diff --git a/mlir/include/mlir/IR/PatternBase.td b/mlir/include/mlir/IR/PatternBase.td
--- a/mlir/include/mlir/IR/PatternBase.td
+++ b/mlir/include/mlir/IR/PatternBase.td
@@ -211,6 +211,22 @@
 // `either` while pattern matching.
 def either;
 
+// Directive used to match a variadic operand against a fixed number of
+// sub-dags. This directive only matches if the variadic operand has exactly as
+// many values as there are sub-dags listed in the directive.
+//
+// ```
+// (VariadicOp (variadic $input1a, $input1b),
+//             (variadic $input2a, $input2b, $input2c),
+//             $attr1, $attr2)
+// ```
+//
+// The pattern above only matches if the first variadic operand has 2 values,
+// the second one has 3 values, and all sub-dags match respectively. Symbols
+// are bound to the sub-dags only; a symbol attached to the `variadic` DAG
+// itself (e.g. `(variadic:$input1 ...)`) is currently not bound.
+def variadic;
+
 //===----------------------------------------------------------------------===//
 // Common value constraints
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/TableGen/Pattern.h b/mlir/include/mlir/TableGen/Pattern.h
--- a/mlir/include/mlir/TableGen/Pattern.h
+++ b/mlir/include/mlir/TableGen/Pattern.h
@@ -189,6 +189,9 @@
   // Returns whether this DAG is an `either` specifier.
   bool isEither() const;
 
+  // Returns whether this DAG is a `variadic` specifier.
+  bool isVariadic() const;
+
   // Returns true if this DAG node is an operation.
   bool isOperation() const;
 
@@ -270,7 +273,22 @@
 
     // DagNode and DagLeaf are accessed by value which means it can't be used as
     // identifier here. Use an opaque pointer type instead.
-    using DagAndConstant = std::pair<const void *, int>;
+    struct DagAndConstant {
+      const void *dag;             // Opaque DagNode pointer, or nullptr.
+      int operandIndexOrNumValues; // Op argument index, or number of values.
+      int rangeIndex;              // Index inside a `variadic` DAG, or -1.
+
+      DagAndConstant(const void *dag, int operandIndexOrNumValues,
+                     int rangeIndex)
+          : dag(dag), operandIndexOrNumValues(operandIndexOrNumValues),
+            rangeIndex(rangeIndex) {}
+
+      bool operator==(const DagAndConstant &rhs) const {
+        return dag == rhs.dag &&
+               operandIndexOrNumValues == rhs.operandIndexOrNumValues &&
+               rangeIndex == rhs.rangeIndex;
+      }
+    };
 
     // What kind of entity this symbol represents:
     // * Attr: op attribute
@@ -288,14 +306,16 @@
 
     // Static methods for creating SymbolInfo.
     static SymbolInfo getAttr(const Operator *op, int index) {
-      return SymbolInfo(op, Kind::Attr, DagAndConstant(nullptr, index));
+      return SymbolInfo(op, Kind::Attr, DagAndConstant(nullptr, index, -1));
     }
     static SymbolInfo getAttr() {
       return SymbolInfo(nullptr, Kind::Attr, std::nullopt);
     }
-    static SymbolInfo getOperand(DagNode node, const Operator *op, int index) {
-      return SymbolInfo(op, Kind::Operand,
-                        DagAndConstant(node.getAsOpaquePointer(), index));
+    static SymbolInfo getOperand(DagNode node, const Operator *op,
+                                 int operandIndex, int rangeIndex = -1) {
+      return SymbolInfo(
+          op, Kind::Operand,
+          DagAndConstant(node.getAsOpaquePointer(), operandIndex, rangeIndex));
     }
     static SymbolInfo getResult(const Operator *op) {
       return SymbolInfo(op, Kind::Result, std::nullopt);
@@ -305,7 +325,7 @@
     }
     static SymbolInfo getMultipleValues(int numValues) {
       return SymbolInfo(nullptr, Kind::MultipleValues,
-                        DagAndConstant(nullptr, numValues));
+                        DagAndConstant(nullptr, numValues, -1));
     }
 
     // Returns the number of static values this symbol corresponds to.
@@ -333,10 +353,10 @@
                                const char *separator) const;
 
     // The argument index (for `Attr` and `Operand` only)
-    int getArgIndex() const { return (*dagAndConstant).second; }
+    int getArgIndex() const { return dagAndConstant->operandIndexOrNumValues; }
 
     // The number of values in the MultipleValue
-    int getSize() const { return (*dagAndConstant).second; }
+    int getSize() const { return dagAndConstant->operandIndexOrNumValues; }
 
     const Operator *op; // The op where the bound entity belongs
     Kind kind;          // The kind of the bound entity
@@ -367,7 +387,7 @@
   // Binds the given `symbol` to the `argIndex`-th argument to the given `op`.
   // Returns false if `symbol` is already bound and symbols are not operands.
   bool bindOpArgument(DagNode node, StringRef symbol, const Operator &op,
-                      int argIndex);
+                      int argIndex, int rangeIndex = -1);
 
   // Binds the given `symbol` to the results the given `op`. Returns false if
   // `symbol` is already bound.
@@ -397,7 +417,8 @@ // Returns an iterator to the information of the given symbol named as `key`, // with index `argIndex` for operator `op`. const_iterator findBoundSymbol(StringRef key, DagNode node, - const Operator &op, int argIndex) const; + const Operator &op, int argIndex, + int rangeIndex = -1) const; const_iterator findBoundSymbol(StringRef key, const SymbolInfo &symbolInfo) const; diff --git a/mlir/lib/TableGen/Pattern.cpp b/mlir/lib/TableGen/Pattern.cpp --- a/mlir/lib/TableGen/Pattern.cpp +++ b/mlir/lib/TableGen/Pattern.cpp @@ -115,7 +115,8 @@ bool DagNode::isOperation() const { return !isNativeCodeCall() && !isReplaceWithValue() && - !isLocationDirective() && !isReturnTypeDirective() && !isEither(); + !isLocationDirective() && !isReturnTypeDirective() && !isEither() && + !isVariadic(); } llvm::StringRef DagNode::getNativeCodeTemplate() const { @@ -193,6 +194,11 @@ return dagOpDef->getName() == "either"; } +bool DagNode::isVariadic() const { + auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef(); + return dagOpDef->getName() == "variadic"; +} + void DagNode::print(raw_ostream &os) const { if (node) node->print(os); @@ -298,7 +304,7 @@ auto *operand = op->getArg(getArgIndex()).get<NamedTypeConstraint *>(); // If this operand is variadic, then return a range. Otherwise, return the // value itself. - if (operand->isVariableLength()) { + if (operand->isVariableLength() && dagAndConstant->rangeIndex == -1) { auto repl = formatv(fmt, name); LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicOperand)\n"); return std::string(repl); @@ -426,7 +432,8 @@ } bool SymbolInfoMap::bindOpArgument(DagNode node, StringRef symbol, - const Operator &op, int argIndex) { + const Operator &op, int argIndex, + int rangeIndex) { StringRef name = getValuePackName(symbol); if (name != symbol) { auto error = formatv( @@ -436,7 +443,7 @@ auto symInfo = op.getArg(argIndex).is<NamedAttribute *>() ?
SymbolInfo::getAttr(&op, argIndex) - : SymbolInfo::getOperand(node, &op, argIndex); + : SymbolInfo::getOperand(node, &op, argIndex, rangeIndex); std::string key = symbol.str(); if (symbolInfoMap.count(key)) { @@ -499,8 +506,9 @@ SymbolInfoMap::const_iterator SymbolInfoMap::findBoundSymbol(StringRef key, DagNode node, const Operator &op, - int argIndex) const { - return findBoundSymbol(key, SymbolInfo::getOperand(node, &op, argIndex)); + int argIndex, int rangeIndex) const { + return findBoundSymbol( + key, SymbolInfo::getOperand(node, &op, argIndex, rangeIndex)); } SymbolInfoMap::const_iterator @@ -821,6 +829,25 @@ } }; + // Operands inside a `variadic` DAG are bound to the op of the parent + // DagNode. The range index is included as well so that a repeated argName + // within the `variadic` DAG (e.g. `(variadic $x, $x)`) stays distinguishable. + auto collectSymbolInVariadic = [&](DagNode parent, DagNode tree, + int opArgIdx) { + for (int i = 0; i < tree.getNumArgs(); ++i) { + if (DagNode subTree = tree.getArgAsNestedDag(i)) { + collectBoundSymbols(subTree, infoMap, isSrcPattern); + } else { + auto argName = tree.getArgName(i); + if (!argName.empty() && argName != "_") { + verifyBind(infoMap.bindOpArgument(parent, argName, op, opArgIdx, + /*rangeIndex=*/i), + argName); + } + } + } + }; + for (int i = 0, opArgIdx = 0; i != numTreeArgs; ++i, ++opArgIdx) { if (auto treeArg = tree.getArgAsNestedDag(i)) { if (treeArg.isEither()) { @@ -833,6 +860,8 @@ // // (FooOp arg0, arg1, arg2) ++opArgIdx; + } else if (treeArg.isVariadic()) { + collectSymbolInVariadic(tree, treeArg, opArgIdx); } else { // This DAG node argument is a DAG node itself. Go inside recursively. collectBoundSymbols(treeArg, infoMap, isSrcPattern);
collectBoundSymbols(treeArg, infoMap, isSrcPattern); diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -1636,6 +1636,54 @@ (replaceWithValue $results__2), ConstantAttr)>; +// Variadic structured matching +def MixedVOperandOp4 : TEST_Op<"mixed_variadic_in4"> { + let arguments = (ins + Variadic:$input1, + I32:$input2, + I32Attr:$attr1 + ); +} + +def MixedVOperandOp5 : TEST_Op<"mixed_variadic_in5"> { + let arguments = (ins + I32:$input1, + I32:$input2, + I32:$input3, + I32Attr:$attr1, + StrAttr:$pattern_name + ); +} + +// Helper op to test variadic recursive pattern matching +def MixedVOperandInOutI32Op : TEST_Op<"mixed_variadic_in_out_i32"> { + let arguments = (ins + I32:$input + ); + let results = (outs + I32:$output + ); +} + +def : Pat< + (MixedVOperandOp4 (variadic $input1a, $input1b), $input2, + ConstantAttr:$attr1), + (MixedVOperandOp5 $input1a, $input1b, $input2, $attr1, + ConstantStrAttr)>; + +def : Pat< + (MixedVOperandOp4 (variadic (MixedVOperandInOutI32Op $input1a), + (MixedVOperandInOutI32Op $input1b)), + $input2, ConstantAttr:$attr1), + (MixedVOperandOp5 $input1a, $input1b, $input2, $attr1, + ConstantStrAttr)>; + +def : Pat< + (MixedVOperandOp4 (variadic $input1, $input1), $input2, + ConstantAttr:$attr1), + (MixedVOperandOp5 $input1, $input1, $input2, $attr1, + ConstantStrAttr)>; + //===----------------------------------------------------------------------===// // Test Patterns (either) diff --git a/mlir/test/mlir-tblgen/pattern.mlir b/mlir/test/mlir-tblgen/pattern.mlir --- a/mlir/test/mlir-tblgen/pattern.mlir +++ b/mlir/test/mlir-tblgen/pattern.mlir @@ -515,6 +515,47 @@ return %0 : i32 } +// CHECK-LABEL: @testMatchVariadic +func.func @testMatchVariadic(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32) -> () { + // CHECK: "test.mixed_variadic_in5"(%arg0, %arg1, %arg2) <{attr1 = 0 : i32, pattern_name = "MatchVariadic"}> : (i32, 
i32, i32) -> () + "test.mixed_variadic_in4"(%arg0, %arg1, %arg2) {attr1 = 0 : i32} : (i32, i32, i32) -> () + + // Note: Not rewritten because variadic operand size mismatches. + // CHECK: "test.mixed_variadic_in4"(%arg0, %arg1, %arg2, %arg3) <{attr1 = 0 : i32}> : (i32, i32, i32, i32) -> () + "test.mixed_variadic_in4"(%arg0, %arg1, %arg2, %arg3) {attr1 = 0 : i32} : (i32, i32, i32, i32) -> () + + return +} + +// CHECK-LABEL: @testMatchVariadicSubDag +func.func @testMatchVariadicSubDag(%arg0: i32, %arg1: i32, %arg2: i32) -> () { + // CHECK: %[[IN0:.*]] = "test.mixed_variadic_in_out_i32"(%arg0) : (i32) -> i32 + %0 = "test.mixed_variadic_in_out_i32"(%arg0) : (i32) -> i32 + // CHECK: %[[IN1:.*]] = "test.mixed_variadic_in_out_i32"(%arg1) : (i32) -> i32 + %1 = "test.mixed_variadic_in_out_i32"(%arg1) : (i32) -> i32 + + // CHECK: "test.mixed_variadic_in5"(%arg0, %arg1, %arg2) <{attr1 = 1 : i32, pattern_name = "MatchVariadicSubDag"}> : (i32, i32, i32) -> () + "test.mixed_variadic_in4"(%0, %1, %arg2) {attr1 = 1 : i32} : (i32, i32, i32) -> () + + // Note: MatchVariadicSubDag doesn't apply + // CHECK: "test.mixed_variadic_in4"(%arg0, %arg1, %arg2) <{attr1 = 1 : i32}> : (i32, i32, i32) -> () + "test.mixed_variadic_in4"(%arg0, %arg1, %arg2) {attr1 = 1 : i32} : (i32, i32, i32) -> () + + return +} + +// CHECK-LABEL: @testMatchVariadicSameSymbol +func.func @testMatchVariadicSameSymbol(%arg0: i32, %arg1: i32, %arg2: i32) -> () { + // CHECK: "test.mixed_variadic_in5"(%arg0, %arg0, %arg2) <{attr1 = 2 : i32, pattern_name = "MatchVariadicSameSymbol"}> : (i32, i32, i32) -> () + "test.mixed_variadic_in4"(%arg0, %arg0, %arg2) {attr1 = 2 : i32} : (i32, i32, i32) -> () + + // Note: MatchVariadicSameSymbol doesn't apply. 
+  // CHECK: "test.mixed_variadic_in4"(%arg0, %arg1, %arg2) <{attr1 = 2 : i32}> : (i32, i32, i32) -> ()
+  "test.mixed_variadic_in4"(%arg0, %arg1, %arg2) {attr1 = 2 : i32} : (i32, i32, i32) -> ()
+
+  return
+}
+
 //===----------------------------------------------------------------------===//
 // Test that natives calls are only called once during rewrites.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -103,8 +103,8 @@
   // DAG `tree` as an operand. `operandName` and `operandMatcher` indicate the
   // bound name and the constraint of the operand respectively.
   void emitOperandMatch(DagNode tree, StringRef opName, StringRef operandName,
-                        DagLeaf operandMatcher, StringRef argName,
-                        int argIndex);
+                        int operandIndex, DagLeaf operandMatcher,
+                        StringRef argName, int argIndex, int rangeIndex);
 
   // Emits C++ statements for matching the operands which can be matched in
   // either order.
@@ -112,6 +112,11 @@
                               StringRef opName, int argIndex, int &operandIndex,
                               int depth);
 
+  // Emits C++ statements for matching a variadic operand with a `variadic` DAG.
+  void emitVariadicOperandMatch(DagNode tree, DagNode variadicArgTree,
+                                StringRef opName, int argIndex,
+                                int &operandIndex, int depth);
+
   // Emits C++ statements for matching the `argIndex`-th argument of the given
   // DAG `tree` as an attribute.
   void emitAttributeMatch(DagNode tree, StringRef opName, int argIndex,
@@ -462,6 +467,8 @@
     if (DagNode argTree = tree.getArgAsNestedDag(i)) {
       if (argTree.isEither())
         PrintFatalError(loc, "NativeCodeCall cannot have `either` operands");
+      if (argTree.isVariadic())
+        PrintFatalError(loc, "NativeCodeCall cannot have `variadic` operands");
 
       os << "::mlir::Value " << argName << ";\n";
     } else {
@@ -596,6 +603,18 @@
         continue;
       }
       if (auto *operand = llvm::dyn_cast_if_present<NamedTypeConstraint *>(opArg)) {
+        if (argTree.isVariadic()) {
+          if (!operand->isVariadic()) {
+            auto error = formatv("variadic DAG construct can't match op {0}'s "
+                                 "non-variadic operand #{1}",
+                                 op.getOperationName(), opArgIdx);
+            PrintFatalError(loc, error);
+          }
+          emitVariadicOperandMatch(tree, argTree, castedName, opArgIdx,
+                                   nextOperand, depth);
+          ++nextOperand;
+          continue;
+        }
         if (operand->isVariableLength()) {
           auto error = formatv("use nested DAG construct to match op {0}'s "
                                "variadic operand #{1} unsupported now",
@@ -627,9 +646,9 @@
     if (opArg.is<NamedTypeConstraint *>()) {
       auto operandName =
           formatv("{0}.getODSOperands({1})", castedName, nextOperand);
-      emitOperandMatch(tree, castedName, operandName.str(),
+      emitOperandMatch(tree, castedName, operandName.str(), nextOperand,
                        /*operandMatcher=*/tree.getArgAsLeaf(i),
-                       /*argName=*/tree.getArgName(i), opArgIdx);
+                       /*argName=*/tree.getArgName(i), opArgIdx, -1);
       ++nextOperand;
     } else if (opArg.is<NamedAttribute *>()) {
       emitAttributeMatch(tree, opName, opArgIdx, depth);
@@ -643,19 +662,20 @@
 }
 
 void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName,
-                                      StringRef operandName,
+                                      StringRef operandName, int operandIndex,
                                       DagLeaf operandMatcher, StringRef argName,
-                                      int argIndex) {
+                                      int argIndex, int rangeIndex) {
   Operator &op = tree.getDialectOp(opMap);
-  auto *operand = op.getArg(argIndex).get<NamedTypeConstraint *>();
+  // `operandIndex` counts ODS operands only; `argIndex` counts all op args.
+  auto *operand = &op.getOperand(operandIndex);
 
   // If a constraint is specified, we need to generate C++ statements to
   // check the constraint.
   if (!operandMatcher.isUnspecified()) {
     if (!operandMatcher.isOperandMatcher())
       PrintFatalError(
           loc, formatv("the {1}-th argument of op '{0}' should be an operand",
                        op.getOperationName(), argIndex + 1));
 
     // Only need to verify if the matcher's type is different from the one
     // of op definition.
@@ -682,7 +702,11 @@
   // Capture the value
   // `$_` is a special symbol to ignore op argument matching.
   if (!argName.empty() && argName != "_") {
-    auto res = symbolInfoMap.findBoundSymbol(argName, tree, op, argIndex);
+    // Symbols are bound by op argument index, not by ODS operand index.
+    auto res = symbolInfoMap.findBoundSymbol(argName, tree, op, argIndex,
+                                             rangeIndex);
+    if (res == symbolInfoMap.end())
+      PrintFatalError(loc, formatv("symbol not found: {0}", argName));
     os << formatv("{0} = {1};\n", res->second.getVarName(argName), operandName);
   }
 }
@@ -726,8 +750,9 @@
       tblgenOps << formatv("tblgen_ops.push_back({0});\n", argName);
     } else if (op.getArg(argIndex).is<NamedTypeConstraint *>()) {
       emitOperandMatch(tree, opName, /*operandName=*/formatv("v{0}", i).str(),
+                       operandIndex,
                        /*operandMatcher=*/eitherArgTree.getArgAsLeaf(i),
-                       /*argName=*/eitherArgTree.getArgName(i), argIndex);
+                       /*argName=*/eitherArgTree.getArgName(i), argIndex, -1);
       ++operandIndex;
     } else {
       PrintFatalError(loc, "either can only be applied on operand");
@@ -755,6 +780,53 @@
   os.unindent().unindent() << "}\n";
 }
 
+void PatternEmitter::emitVariadicOperandMatch(DagNode tree,
+                                              DagNode variadicArgTree,
+                                              StringRef opName, int argIndex,
+                                              int &operandIndex, int depth) {
+  Operator &op = tree.getDialectOp(opMap);
+  // Emit the match in its own C++ scope so `variadic_operand_range` is local.
+  os << "{\n";
+  os.indent();
+
+  os << formatv("auto variadic_operand_range = {0}.getODSOperands({1});\n",
+                opName, operandIndex);
+  os << formatv("if (variadic_operand_range.size() != {0}) "
+                "return ::mlir::failure();\n",
+                variadicArgTree.getNumArgs());
+
+  for (int i = 0; i < variadicArgTree.getNumArgs(); ++i) {
+    if (DagNode argTree = variadicArgTree.getArgAsNestedDag(i)) {
+      if (!argTree.isOperation())
+        PrintFatalError(loc, "variadic only accepts operation sub-dags");
+
+      os << "{\n";
+      os.indent();
+      std::string argName = formatv("local_op_{0}", i).str();
+      os << formatv("auto *{0} = "
+                    "variadic_operand_range[{1}].getDefiningOp();\n",
+                    argName, i);
+      emitMatchCheck(
+          opName, /*matchStr=*/argName,
+          formatv("\"There's no operation that defines variadic operand "
+                  "{0} (variadic sub-operand #{1}) of {2}\"",
+                  operandIndex, i, opName));
+      emitMatch(argTree, argName, depth + 1);
+      os << formatv("tblgen_ops.push_back({0});\n", argName);
+      os.unindent() << "}\n";
+    } else if (op.getArg(argIndex).is<NamedTypeConstraint *>()) {
+      auto operandName = formatv("variadic_operand_range.slice({0}, 1)", i);
+      emitOperandMatch(tree, opName, operandName.str(), operandIndex,
+                       /*operandMatcher=*/variadicArgTree.getArgAsLeaf(i),
+                       /*argName=*/variadicArgTree.getArgName(i), argIndex, i);
+    } else {
+      PrintFatalError(loc, "variadic can only be applied on operand");
+    }
+  }
+
+  os.unindent() << "}\n";
+}
+
 void PatternEmitter::emitAttributeMatch(DagNode tree, StringRef opName,
                                         int argIndex, int depth) {
   Operator &op = tree.getDialectOp(opMap);