Index: mlir/docs/DeclarativeRewrites.md =================================================================== --- mlir/docs/DeclarativeRewrites.md +++ mlir/docs/DeclarativeRewrites.md @@ -646,6 +646,50 @@ [TODO] +#### Match variadic operand + +Use the `variadic` DAG node to match a variadic operand with a fixed number of +actual sub-operands. + +For example, assume that `ConcatenateOp` is an operation with a variadic +operand: + +```tablegen +def ConcatenateOp : TEST_Op<"concatenate"> { + let arguments = (ins + Variadic<AnyTensor>:$inputs, + I32Attr:$axis + ); + + let results = (outs + AnyTensor:$output + ); +} +``` + +We can match a `ConcatenateOp` that has exactly 2 actual operands with: + +```tablegen +def : Pat<(ConcatenateOp (variadic $input0, $input1), $axis), + ...>; +``` + +The variadic sub-operands can be sub-DAGs to be matched: + +```tablegen +def : Pat<(ConcatenateOp (variadic (SomeOp $a), (AnotherOp $b, $c)), $axis), + (OtherOp $a, $b, $c)>; +``` + +The variadic DAG can be bound to a symbol, which refers to the full +`operand_range`: + +```tablegen +def : Pat<(ConcatenateOp (variadic:$inputs $input0, $input1), + ConstantAttr<I32Attr, "0">), + (VStackOp $inputs)>; +``` + ### Supplying additional constraints Constraints can be placed on op arguments when matching. But sometimes we need Index: mlir/include/mlir/IR/PatternBase.td =================================================================== --- mlir/include/mlir/IR/PatternBase.td +++ mlir/include/mlir/IR/PatternBase.td @@ -211,6 +211,22 @@ // `either` while pattern matching. def either; +// Directive used to match variadic operands. This directive only matches if +// the variadic operand has exactly as many values as there are specified +// sub-dags. +// +// ``` +// (VariadicOp (variadic:$input1 $input1a, $input1b), +// (variadic:$input2 $input2a, $input2b, $input2c), +// $attr1, $attr2) +// ``` +// +// The pattern above only matches if the `$input1` operand is of length 2, +// `$input2` is of length 3, and all sub-dags match respectively.
The `$input1` +// symbol denotes the full variadic operand range. The `$input1a` symbol +// denotes the first operand in the variadic sub-operands. +def variadic; + //===----------------------------------------------------------------------===// // Common value constraints //===----------------------------------------------------------------------===// Index: mlir/include/mlir/TableGen/Pattern.h =================================================================== --- mlir/include/mlir/TableGen/Pattern.h +++ mlir/include/mlir/TableGen/Pattern.h @@ -22,6 +22,7 @@ #include "llvm/ADT/StringMap.h" #include "llvm/ADT/StringSet.h" +#include #include namespace llvm { @@ -189,6 +190,9 @@ // Returns whether this DAG is an `either` specifier. bool isEither() const; + // Returns whether this DAG is a `variadic` specifier. + bool isVariadic() const; + // Returns true if this DAG node is an operation. bool isOperation() const; @@ -268,9 +272,94 @@ // Allow SymbolInfoMap to access private methods. friend class SymbolInfoMap; - // DagNode and DagLeaf are accessed by value which means it can't be used as - // identifier here. Use an opaque pointer type instead. - using DagAndConstant = std::pair; + // Structure to uniquely distinguish different locations of the symbols. + // + // * If a symbol is defined as an operand of an operation, `dag` specifies + // the DAG of the operation, `operandIndexOrNumValues` specifies the + // argument index, and `variadicSubIndex` must be set to `std::nullopt`. + // + // * If a symbol is defined in a `variadic` DAG, `dag` specifies the DAG + // of the parent operation, `operandIndexOrNumValues` specifies the + // declared argument index of the variadic operand in the parent + // operation. + // + // - If the symbol is bound to the `variadic` DAG itself, the + // `variadicSubIndex` must be set to `std::nullopt`, which means that + // the symbol binds to the full operand range.
+ // + // - If the symbol is defined as an operand, the `variadicSubIndex` must + // be set to the index within the variadic sub-operand list. + // + // * If a symbol is defined in an `either` DAG, `dag` specifies the DAG + // of the parent operation, `operandIndexOrNumValues` specifies the + // argument index in the parent operation (not necessarily the index in + // the DAG). + // + // * If a symbol is defined as a result, `dag` is nullptr and + // `operandIndexOrNumValues` specifies the number of returned values. + // + // Example 1: + // + // def : Pat<(OpA $input0, $input1), ...>; + // + // $input0: (OpA, 0, nullopt) + // $input1: (OpA, 1, nullopt) + // + // Example 2: + // + // def : Pat<(OpB (variadic:$input0 $input0a, $input0b), + // (variadic:$input1 $input1a, $input1b, $input1c)), + // ...>; + // + // $input0: (OpB, 0, nullopt) + // $input0a: (OpB, 0, 0) + // $input0b: (OpB, 0, 1) + // $input1: (OpB, 1, nullopt) + // $input1a: (OpB, 1, 0) + // $input1b: (OpB, 1, 1) + // $input1c: (OpB, 1, 2) + // + // Example 3: + // + // def : Pat<(OpC $input0, (either $input1, $input2)), ...>; + // + // $input0: (OpC, 0, nullopt) + // $input1: (OpC, 1, nullopt) + // $input2: (OpC, 2, nullopt) + // + // Example 4: + // + // def ThreeResultOp : TEST_Op<...> { + // let results = (outs + // AnyType:$result1, + // AnyType:$result2, + // AnyType:$result3 + // ); + // } + // + // def : Pat<..., + // (ThreeResultOp:$result ...)>; + // + // $result: (nullptr, 3, nullopt) + // + struct DagAndConstant { + // DagNode and DagLeaf are accessed by value which means they can't be used + // as an identifier here. Use an opaque pointer type instead.
+ const void *dag; + int operandIndexOrNumValues; + std::optional variadicSubIndex; + + DagAndConstant(const void *dag, int operandIndexOrNumValues, + std::optional variadicSubIndex) + : dag(dag), operandIndexOrNumValues(operandIndexOrNumValues), + variadicSubIndex(variadicSubIndex) {} + + bool operator==(const DagAndConstant &rhs) const { + return dag == rhs.dag && + operandIndexOrNumValues == rhs.operandIndexOrNumValues && + variadicSubIndex == rhs.variadicSubIndex; + } + }; // What kind of entity this symbol represents: // * Attr: op attribute @@ -288,14 +377,18 @@ // Static methods for creating SymbolInfo. static SymbolInfo getAttr(const Operator *op, int index) { - return SymbolInfo(op, Kind::Attr, DagAndConstant(nullptr, index)); + return SymbolInfo(op, Kind::Attr, + DagAndConstant(nullptr, index, std::nullopt)); } static SymbolInfo getAttr() { return SymbolInfo(nullptr, Kind::Attr, std::nullopt); } - static SymbolInfo getOperand(DagNode node, const Operator *op, int index) { + static SymbolInfo + getOperand(DagNode node, const Operator *op, int operandIndex, + std::optional variadicSubIndex = std::nullopt) { return SymbolInfo(op, Kind::Operand, - DagAndConstant(node.getAsOpaquePointer(), index)); + DagAndConstant(node.getAsOpaquePointer(), operandIndex, + variadicSubIndex)); } static SymbolInfo getResult(const Operator *op) { return SymbolInfo(op, Kind::Result, std::nullopt); @@ -305,7 +398,7 @@ } static SymbolInfo getMultipleValues(int numValues) { return SymbolInfo(nullptr, Kind::MultipleValues, - DagAndConstant(nullptr, numValues)); + DagAndConstant(nullptr, numValues, std::nullopt)); } // Returns the number of static values this symbol corresponds to. 
@@ -333,18 +426,23 @@ const char *separator) const; // The argument index (for `Attr` and `Operand` only) - int getArgIndex() const { return (*dagAndConstant).second; } + int getArgIndex() const { return dagAndConstant->operandIndexOrNumValues; } // The number of values in the MultipleValue - int getSize() const { return (*dagAndConstant).second; } + int getSize() const { return dagAndConstant->operandIndexOrNumValues; } + + // The variadic sub-operands index (for variadic `Operand` only) + std::optional getVariadicSubIndex() const { + return dagAndConstant->variadicSubIndex; + } const Operator *op; // The op where the bound entity belongs Kind kind; // The kind of the bound entity - // The pair of DagNode pointer and constant value (for `Attr`, `Operand` and - // the size of MultipleValue symbol). Note that operands may be bound to the - // same symbol, use the DagNode and index to distinguish them. For `Attr` - // and MultipleValue, the Dag part will be nullptr. + // The tuple of DagNode pointer and two constant values (for `Attr`, + // `Operand` and the size of MultipleValue symbol). Note that operands may + // be bound to the same symbol, use the DagNode and index to distinguish + // them. For `Attr` and MultipleValue, the Dag part will be nullptr. std::optional dagAndConstant; // Alternative name for the symbol. It is used in case the name @@ -367,7 +465,8 @@ // Binds the given `symbol` to the `argIndex`-th argument to the given `op`. // Returns false if `symbol` is already bound and symbols are not operands. bool bindOpArgument(DagNode node, StringRef symbol, const Operator &op, - int argIndex); + int argIndex, + std::optional variadicSubIndex = std::nullopt); // Binds the given `symbol` to the results the given `op`. Returns false if // `symbol` is already bound. @@ -397,7 +496,8 @@ // Returns an iterator to the information of the given symbol named as `key`, // with index `argIndex` for operator `op`. 
const_iterator findBoundSymbol(StringRef key, DagNode node, - const Operator &op, int argIndex) const; + const Operator &op, int argIndex, + std::optional variadicSubIndex) const; const_iterator findBoundSymbol(StringRef key, const SymbolInfo &symbolInfo) const; Index: mlir/lib/TableGen/Pattern.cpp =================================================================== --- mlir/lib/TableGen/Pattern.cpp +++ mlir/lib/TableGen/Pattern.cpp @@ -115,7 +115,8 @@ bool DagNode::isOperation() const { return !isNativeCodeCall() && !isReplaceWithValue() && - !isLocationDirective() && !isReturnTypeDirective() && !isEither(); + !isLocationDirective() && !isReturnTypeDirective() && !isEither() && + !isVariadic(); } llvm::StringRef DagNode::getNativeCodeTemplate() const { @@ -193,6 +194,11 @@ return dagOpDef->getName() == "either"; } +bool DagNode::isVariadic() const { + auto *dagOpDef = cast(node->getOperator())->getDef(); + return dagOpDef->getName() == "variadic"; +} + void DagNode::print(raw_ostream &os) const { if (node) node->print(os); @@ -296,9 +302,10 @@ case Kind::Operand: { assert(index < 0); auto *operand = op->getArg(getArgIndex()).get(); - // If this operand is variadic, then return a range. Otherwise, return the - // value itself. - if (operand->isVariableLength()) { + // If this operand is variadic and this SymbolInfo doesn't have a range + // index, then return the full variadic operand_range. Otherwise, return + // the value itself. 
+ if (operand->isVariableLength() && !getVariadicSubIndex().has_value()) { auto repl = formatv(fmt, name); LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicOperand)\n"); return std::string(repl); @@ -426,7 +433,8 @@ } bool SymbolInfoMap::bindOpArgument(DagNode node, StringRef symbol, - const Operator &op, int argIndex) { + const Operator &op, int argIndex, + std::optional variadicSubIndex) { StringRef name = getValuePackName(symbol); if (name != symbol) { auto error = formatv( @@ -434,9 +442,10 @@ PrintFatalError(loc, error); } - auto symInfo = op.getArg(argIndex).is() - ? SymbolInfo::getAttr(&op, argIndex) - : SymbolInfo::getOperand(node, &op, argIndex); + auto symInfo = + op.getArg(argIndex).is() + ? SymbolInfo::getAttr(&op, argIndex) + : SymbolInfo::getOperand(node, &op, argIndex, variadicSubIndex); std::string key = symbol.str(); if (symbolInfoMap.count(key)) { @@ -499,8 +508,10 @@ SymbolInfoMap::const_iterator SymbolInfoMap::findBoundSymbol(StringRef key, DagNode node, const Operator &op, - int argIndex) const { - return findBoundSymbol(key, SymbolInfo::getOperand(node, &op, argIndex)); + int argIndex, + std::optional variadicSubIndex) const { + return findBoundSymbol( + key, SymbolInfo::getOperand(node, &op, argIndex, variadicSubIndex)); } SymbolInfoMap::const_iterator @@ -821,6 +832,33 @@ } }; + // The operand in `variadic` DAG should be bound to the operation in the + // parent DagNode. The range index must be included as well to distinguish + // (potentially) repeating argName within the `variadic` DAG. + auto collectSymbolInVariadic = [&](DagNode parent, DagNode tree, + int opArgIdx) { + auto treeName = tree.getSymbol(); + if (!treeName.empty()) { + // If treeName is specified, bind to the full variadic operand_range. 
+ verifyBind(infoMap.bindOpArgument(parent, treeName, op, opArgIdx, + std::nullopt), + treeName); + } + + for (int i = 0; i < tree.getNumArgs(); ++i) { + if (DagNode subTree = tree.getArgAsNestedDag(i)) { + collectBoundSymbols(subTree, infoMap, isSrcPattern); + } else { + auto argName = tree.getArgName(i); + if (!argName.empty() && argName != "_") { + verifyBind(infoMap.bindOpArgument(parent, argName, op, opArgIdx, + /*variadicSubIndex=*/i), + argName); + } + } + } + }; + for (int i = 0, opArgIdx = 0; i != numTreeArgs; ++i, ++opArgIdx) { if (auto treeArg = tree.getArgAsNestedDag(i)) { if (treeArg.isEither()) { @@ -833,6 +871,8 @@ // // (FooOp arg0, arg1, arg2) ++opArgIdx; + } else if (treeArg.isVariadic()) { + collectSymbolInVariadic(tree, treeArg, opArgIdx); } else { // This DAG node argument is a DAG node itself. Go inside recursively. collectBoundSymbols(treeArg, infoMap, isSrcPattern); Index: mlir/test/lib/Dialect/Test/TestOps.td =================================================================== --- mlir/test/lib/Dialect/Test/TestOps.td +++ mlir/test/lib/Dialect/Test/TestOps.td @@ -1636,6 +1636,76 @@ (replaceWithValue $results__2), ConstantAttr)>; +// Variadic structured matching +def MixedVOperandOp4 : TEST_Op<"mixed_variadic_in4"> { + let arguments = (ins + Variadic:$input1, + I32:$input2, + I32Attr:$attr1 + ); +} + +def MixedVOperandOp5 : TEST_Op<"mixed_variadic_in5"> { + let arguments = (ins + I32:$input1, + I32:$input2, + I32:$input3, + I32Attr:$attr1, + StrAttr:$pattern_name + ); +} + +// Helper op to test variadic recursive pattern matching +def MixedVOperandInOutI32Op : TEST_Op<"mixed_variadic_in_out_i32"> { + let arguments = (ins + I32:$input + ); + let results = (outs + I32:$output + ); +} + +def : Pat< + (MixedVOperandOp4 (variadic $input1a, $input1b), $input2, + ConstantAttr:$attr1), + (MixedVOperandOp5 $input1a, $input1b, $input2, $attr1, + ConstantStrAttr)>; + +def : Pat< + (MixedVOperandOp4 (variadic (MixedVOperandInOutI32Op $input1a), + 
(MixedVOperandInOutI32Op $input1b)), + $input2, ConstantAttr:$attr1), + (MixedVOperandOp5 $input1a, $input1b, $input2, $attr1, + ConstantStrAttr)>; + +def : Pat< + (MixedVOperandOp4 (variadic $input1, $input1), $input2, + ConstantAttr:$attr1), + (MixedVOperandOp5 $input1, $input1, $input2, $attr1, + ConstantStrAttr)>; + +def MixedVOperandOp6 : TEST_Op<"mixed_variadic_in6", + [SameVariadicOperandSize]> { + let arguments = (ins + Variadic:$input1, + Variadic:$input2, + I32Attr:$attr1 + ); +} + +def : Pat< + (MixedVOperandOp6 (variadic:$input1 $input1a, $input1b), + (variadic:$input2 $input2a, $input2b), + ConstantAttr:$attr1), + (MixedVOperandOp6 $input2, $input1, ConstantAttr)>; + +def : Pat< + (MixedVOperandOp6 (variadic $input1a, $input1b), + (variadic $input2a, $input2b), + ConstantAttr:$attr1), + (MixedVOperandOp5 $input2a, $input2b, $input1b, $attr1, + ConstantStrAttr)>; + //===----------------------------------------------------------------------===// // Test Patterns (either) Index: mlir/test/mlir-tblgen/pattern.mlir =================================================================== --- mlir/test/mlir-tblgen/pattern.mlir +++ mlir/test/mlir-tblgen/pattern.mlir @@ -515,6 +515,67 @@ return %0 : i32 } +// CHECK-LABEL: @testMatchVariadic +func.func @testMatchVariadic(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32) -> () { + // CHECK: "test.mixed_variadic_in5"(%arg0, %arg1, %arg2) <{attr1 = 0 : i32, pattern_name = "MatchVariadic"}> : (i32, i32, i32) -> () + "test.mixed_variadic_in4"(%arg0, %arg1, %arg2) {attr1 = 0 : i32} : (i32, i32, i32) -> () + + // Note: Not rewritten because variadic operand size mismatches. 
+ // CHECK: "test.mixed_variadic_in4"(%arg0, %arg1, %arg2, %arg3) <{attr1 = 0 : i32}> : (i32, i32, i32, i32) -> () + "test.mixed_variadic_in4"(%arg0, %arg1, %arg2, %arg3) {attr1 = 0 : i32} : (i32, i32, i32, i32) -> () + + return +} + +// CHECK-LABEL: @testMatchVariadicSubDag +func.func @testMatchVariadicSubDag(%arg0: i32, %arg1: i32, %arg2: i32) -> () { + // CHECK: %[[IN0:.*]] = "test.mixed_variadic_in_out_i32"(%arg0) : (i32) -> i32 + %0 = "test.mixed_variadic_in_out_i32"(%arg0) : (i32) -> i32 + // CHECK: %[[IN1:.*]] = "test.mixed_variadic_in_out_i32"(%arg1) : (i32) -> i32 + %1 = "test.mixed_variadic_in_out_i32"(%arg1) : (i32) -> i32 + + // CHECK: "test.mixed_variadic_in5"(%arg0, %arg1, %arg2) <{attr1 = 1 : i32, pattern_name = "MatchVariadicSubDag"}> : (i32, i32, i32) -> () + "test.mixed_variadic_in4"(%0, %1, %arg2) {attr1 = 1 : i32} : (i32, i32, i32) -> () + + // Note: MatchVariadicSubDag doesn't apply + // CHECK: "test.mixed_variadic_in4"(%arg0, %arg1, %arg2) <{attr1 = 1 : i32}> : (i32, i32, i32) -> () + "test.mixed_variadic_in4"(%arg0, %arg1, %arg2) {attr1 = 1 : i32} : (i32, i32, i32) -> () + + return +} + +// CHECK-LABEL: @testMatchVariadicSameSymbol +func.func @testMatchVariadicSameSymbol(%arg0: i32, %arg1: i32, %arg2: i32) -> () { + // CHECK: "test.mixed_variadic_in5"(%arg0, %arg0, %arg2) <{attr1 = 2 : i32, pattern_name = "MatchVariadicSameSymbol"}> : (i32, i32, i32) -> () + "test.mixed_variadic_in4"(%arg0, %arg0, %arg2) {attr1 = 2 : i32} : (i32, i32, i32) -> () + + // Note: MatchVariadicSameSymbol doesn't apply. 
+ // CHECK: "test.mixed_variadic_in4"(%arg0, %arg1, %arg2) <{attr1 = 2 : i32}> : (i32, i32, i32) -> () + "test.mixed_variadic_in4"(%arg0, %arg1, %arg2) {attr1 = 2 : i32} : (i32, i32, i32) -> () + + return +} + +// CHECK-LABEL: @testMatchAndRewriteVariadicFullRange +func.func @testMatchAndRewriteVariadicFullRange(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32) -> () { + // CHECK: "test.mixed_variadic_in6"(%arg2, %arg3, %arg0, %arg1) <{attr1 = -1 : i32}> : (i32, i32, i32, i32) -> () + "test.mixed_variadic_in6"(%arg0, %arg1, %arg2, %arg3) {attr1 = 1 : i32} : (i32, i32, i32, i32) -> () + + // Note: MatchAndRewriteVariadicFullRange doesn't apply because the length of each variadic operand is not equal to 2. + // CHECK: "test.mixed_variadic_in6"(%arg0, %arg1) <{attr1 = 1 : i32}> : (i32, i32) -> () + "test.mixed_variadic_in6"(%arg0, %arg1) {attr1 = 1 : i32} : (i32, i32) -> () + + return +} + +// CHECK-LABEL: @testMatchMultiVariadicSubSymbol +func.func @testMatchMultiVariadicSubSymbol(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32) -> () { + // CHECK: "test.mixed_variadic_in5"(%arg2, %arg3, %arg1) <{attr1 = 2 : i32, pattern_name = "MatchMultiVariadicSubSymbol"}> : (i32, i32, i32) -> () + "test.mixed_variadic_in6"(%arg0, %arg1, %arg2, %arg3) {attr1 = 2 : i32} : (i32, i32, i32, i32) -> () + + return +} + //===----------------------------------------------------------------------===// // Test that natives calls are only called once during rewrites. //===----------------------------------------------------------------------===// Index: mlir/tools/mlir-tblgen/RewriterGen.cpp =================================================================== --- mlir/tools/mlir-tblgen/RewriterGen.cpp +++ mlir/tools/mlir-tblgen/RewriterGen.cpp @@ -103,8 +103,9 @@ // DAG `tree` as an operand. `operandName` and `operandMatcher` indicate the // bound name and the constraint of the operand respectively. 
void emitOperandMatch(DagNode tree, StringRef opName, StringRef operandName, - DagLeaf operandMatcher, StringRef argName, - int argIndex); + int operandIndex, DagLeaf operandMatcher, + StringRef argName, int argIndex, + std::optional<int> variadicSubIndex); // Emits C++ statements for matching the operands which can be matched in // either order. @@ -112,6 +113,11 @@ StringRef opName, int argIndex, int &operandIndex, int depth); + // Emits C++ statements for matching a variadic operand. + void emitVariadicOperandMatch(DagNode tree, DagNode variadicArgTree, + StringRef opName, int argIndex, + int &operandIndex, int depth); + // Emits C++ statements for matching the `argIndex`-th argument of the given // DAG `tree` as an attribute. void emitAttributeMatch(DagNode tree, StringRef opName, int argIndex, @@ -262,6 +268,11 @@ // Determine if we should inline the match logic or delegate to a static // function. bool useStaticMatcher(DagNode node) { + // `either`/`variadic` nodes are matched in the context of their parent op, + // so we can't emit a static matcher rooted at them.
+ if (node.isEither() || node.isVariadic()) + return false; + return refStats[node] > kStaticMatcherThreshold; } @@ -462,6 +473,8 @@ if (DagNode argTree = tree.getArgAsNestedDag(i)) { if (argTree.isEither()) PrintFatalError(loc, "NativeCodeCall cannot have `either` operands"); + if (argTree.isVariadic()) + PrintFatalError(loc, "NativeCodeCall cannot have `variadic` operands"); os << "::mlir::Value " << argName << ";\n"; } else { @@ -596,6 +609,18 @@ continue; } if (auto *operand = llvm::dyn_cast_if_present<NamedTypeConstraint *>(opArg)) { + if (argTree.isVariadic()) { + if (!operand->isVariadic()) { + auto error = formatv("variadic DAG construct can't match op {0}'s " + "non-variadic operand #{1}", + op.getOperationName(), opArgIdx); + PrintFatalError(loc, error); + } + emitVariadicOperandMatch(tree, argTree, castedName, opArgIdx, + nextOperand, depth); + ++nextOperand; + continue; + } if (operand->isVariableLength()) { auto error = formatv("use nested DAG construct to match op {0}'s " "variadic operand #{1} unsupported now", @@ -627,9 +652,10 @@ if (opArg.is<NamedTypeConstraint *>()) { auto operandName = formatv("{0}.getODSOperands({1})", castedName, nextOperand); - emitOperandMatch(tree, castedName, operandName.str(), + emitOperandMatch(tree, castedName, operandName.str(), nextOperand, /*operandMatcher=*/tree.getArgAsLeaf(i), - /*argName=*/tree.getArgName(i), opArgIdx); + /*argName=*/tree.getArgName(i), opArgIdx, + /*variadicSubIndex=*/std::nullopt); ++nextOperand; } else if (opArg.is<NamedAttribute *>()) { emitAttributeMatch(tree, opName, opArgIdx, depth); @@ -643,11 +669,12 @@ } void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName, - StringRef operandName, + StringRef operandName, int operandIndex, DagLeaf operandMatcher, StringRef argName, - int argIndex) { + int argIndex, + std::optional<int> variadicSubIndex) { Operator &op = tree.getDialectOp(opMap); auto *operand = op.getArg(argIndex).get<NamedTypeConstraint *>(); // If a constraint is specified, we need to generate C++ statements to
// check the constraint. @@ -682,7 +709,11 @@ // Capture the value // `$_` is a special symbol to ignore op argument matching. if (!argName.empty() && argName != "_") { - auto res = symbolInfoMap.findBoundSymbol(argName, tree, op, argIndex); + auto res = symbolInfoMap.findBoundSymbol(argName, tree, op, argIndex, + variadicSubIndex); + if (res == symbolInfoMap.end()) + PrintFatalError(loc, formatv("symbol not found: {0}", argName)); + os << formatv("{0} = {1};\n", res->second.getVarName(argName), operandName); } } @@ -735,8 +766,10 @@ tblgenOps << formatv("tblgen_ops.push_back({0});\n", argName); } else if (op.getArg(argIndex).is<NamedTypeConstraint *>()) { emitOperandMatch(tree, opName, /*operandName=*/formatv("v{0}", i).str(), + operandIndex, /*operandMatcher=*/eitherArgTree.getArgAsLeaf(i), - /*argName=*/eitherArgTree.getArgName(i), argIndex); + /*argName=*/eitherArgTree.getArgName(i), argIndex, + /*variadicSubIndex=*/std::nullopt); ++operandIndex; } else { PrintFatalError(loc, "either can only be applied on operand"); @@ -764,6 +797,67 @@ os.unindent().unindent() << "}\n"; } +void PatternEmitter::emitVariadicOperandMatch(DagNode tree, + DagNode variadicArgTree, + StringRef opName, int argIndex, + int &operandIndex, int depth) { + Operator &op = tree.getDialectOp(opMap); + + os << "{\n"; + os.indent(); + + os << formatv("auto variadic_operand_range = {0}.getODSOperands({1});\n", + opName, operandIndex); + os << formatv("if (variadic_operand_range.size() != {0}) " + "return ::mlir::failure();\n", + variadicArgTree.getNumArgs()); + + StringRef variadicTreeName = variadicArgTree.getSymbol(); + if (!variadicTreeName.empty()) { + auto res = + symbolInfoMap.findBoundSymbol(variadicTreeName, tree, op, argIndex, + /*variadicSubIndex=*/std::nullopt); + if (res == symbolInfoMap.end()) + PrintFatalError(loc, formatv("symbol not found: {0}", variadicTreeName)); + + os << formatv("{0} = variadic_operand_range;\n", + res->second.getVarName(variadicTreeName)); + } + + for (int i = 0; i <
variadicArgTree.getNumArgs(); ++i) { + if (DagNode argTree = variadicArgTree.getArgAsNestedDag(i)) { + if (!argTree.isOperation()) + PrintFatalError(loc, "variadic only accepts operation sub-dags"); + + os << "{\n"; + os.indent(); + + std::string argName = formatv("local_op_{0}", i).str(); + os << formatv("auto *{0} = " + "variadic_operand_range[{1}].getDefiningOp();\n", + argName, i); + emitMatchCheck( + opName, /*matchStr=*/argName, + formatv("\"There's no operation that defines variadic operand " + "{0} (variadic sub-operand #{1}) of {2}\"", + operandIndex, i, opName)); + emitMatch(argTree, argName, depth + 1); + os << formatv("tblgen_ops.push_back({0});\n", argName); + + os.unindent() << "}\n"; + } else if (op.getArg(argIndex).is<NamedTypeConstraint *>()) { + auto operandName = formatv("variadic_operand_range.slice({0}, 1)", i); + emitOperandMatch(tree, opName, operandName.str(), operandIndex, + /*operandMatcher=*/variadicArgTree.getArgAsLeaf(i), + /*argName=*/variadicArgTree.getArgName(i), argIndex, i); + } else { + PrintFatalError(loc, "variadic can only be applied on operand"); + } + } + + os.unindent() << "}\n"; +} + void PatternEmitter::emitAttributeMatch(DagNode tree, StringRef opName, int argIndex, int depth) { Operator &op = tree.getDialectOp(opMap);