diff --git a/mlir/include/mlir/TableGen/Operator.h b/mlir/include/mlir/TableGen/Operator.h
--- a/mlir/include/mlir/TableGen/Operator.h
+++ b/mlir/include/mlir/TableGen/Operator.h
@@ -151,6 +151,10 @@
   // Returns the total number of arguments.
   int getNumArgs() const { return arguments.size(); }
 
+  // Returns the number of leading arguments that cannot be omitted, i.e. the
+  // total number of arguments minus any trailing optional ones.
+  int getNumRequiredArgs() const;
+
   // Returns true of the operation has a single variadic arg.
   bool hasSingleVariadicArg() const;
 
diff --git a/mlir/lib/TableGen/Operator.cpp b/mlir/lib/TableGen/Operator.cpp
--- a/mlir/lib/TableGen/Operator.cpp
+++ b/mlir/lib/TableGen/Operator.cpp
@@ -132,6 +132,22 @@
   });
 }
 
+// Only trailing optional operands/attributes may be omitted, so the number of
+// required arguments is one past the index of the last non-optional argument.
+int Operator::getNumRequiredArgs() const {
+  int numRequiredArgs = getNumArgs();
+  while (numRequiredArgs > 0) {
+    auto arg = getArg(numRequiredArgs - 1);
+    bool isOptional = arg.is<NamedTypeConstraint *>()
+                          ? arg.get<NamedTypeConstraint *>()->isOptional()
+                          : arg.get<NamedAttribute *>()->attr.isOptional();
+    if (!isOptional)
+      break;
+    --numRequiredArgs;
+  }
+  return numRequiredArgs;
+}
+
 bool Operator::hasSingleVariadicArg() const {
   return getNumArgs() == 1 && getArg(0).is<NamedTypeConstraint *>() &&
          getOperand(0).isVariadic();
diff --git a/mlir/lib/TableGen/Pattern.cpp b/mlir/lib/TableGen/Pattern.cpp
--- a/mlir/lib/TableGen/Pattern.cpp
+++ b/mlir/lib/TableGen/Pattern.cpp
@@ -675,19 +675,30 @@
 
   if (tree.isOperation()) {
     auto &op = getDialectOp(tree);
-    auto numOpArgs = op.getNumArgs();
 
     // The pattern might have the last argument specifying the location.
     bool hasLocDirective = false;
     if (numTreeArgs != 0) {
-      if (auto lastArg = tree.getArgAsNestedDag(numTreeArgs - 1))
+      if (auto lastArg = tree.getArgAsNestedDag(numTreeArgs - 1)) {
         hasLocDirective = lastArg.isLocationDirective();
+      }
+    }
+
+    // Trailing optional arguments may be omitted, so the pattern must supply
+    // at least the required and at most the total number of op arguments.
+    const int numRequiredArgs = op.getNumRequiredArgs();
+    const int opArgs = numTreeArgs - hasLocDirective;
+    if (opArgs < numRequiredArgs) {
+      auto err = formatv("too few arguments for op '{0}': "
+                         "{1} in pattern vs. {2} required in definition",
+                         op.getOperationName(), opArgs, numRequiredArgs);
+      PrintFatalError(&def, err);
     }
 
-    if (numOpArgs != numTreeArgs - hasLocDirective) {
-      auto err = formatv("op '{0}' argument number mismatch: "
-                         "{1} in pattern vs. {2} in definition",
-                         op.getOperationName(), numTreeArgs, numOpArgs);
+    if (opArgs > op.getNumArgs()) {
+      auto err = formatv("too many arguments for op '{0}': "
+                         "{1} in pattern vs. {2} allowed in definition",
+                         op.getOperationName(), opArgs, op.getNumArgs());
       PrintFatalError(&def, err);
     }
 
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -966,6 +966,31 @@
 def : Pat<(I32EnumAttrOp I32Case5), (I32EnumAttrOp I32Case10)>;
 def : Pat<(I64EnumAttrOp I64Case5), (I64EnumAttrOp I64Case10)>;
 
+// Test that trailing optional attributes may be omitted in patterns.
+def OpAttrMatch5 : TEST_Op<"match_optional_op_attribute1"> {
+  let arguments = (ins
+    UnitAttr:$required_attr,
+    OptionalAttr<I32Attr>:$optional_attr1,
+    OptionalAttr<I32Attr>:$optional_attr2
+  );
+  let results = (outs I32);
+}
+def OpAttrMatch6 : TEST_Op<"match_optional_op_attribute2"> {
+  let arguments = (ins
+    UnitAttr:$required_attr1,
+    OptionalAttr<I32Attr>:$optional_attr1,
+    OptionalAttr<I32Attr>:$optional_attr2,
+    UnitAttr:$required_attr2
+  );
+  let results = (outs I32);
+}
+def : Pat<(OpAttrMatch6 ConstUnitAttr, $optional, $optional2, ConstUnitAttr),
+          (OpAttrMatch5 ConstUnitAttr, $optional)>;
+def : Pat<(OpAttrMatch6 ConstUnitAttr, $optional, $optional2, $required2),
+          (OpAttrMatch5 $required2, $optional, $optional2)>;
+def : Pat<(OpAttrMatch6 $required, $optional, $optional2, $required2),
+          (OpAttrMatch5 $required)>;
+
 //===----------------------------------------------------------------------===//
 // Test Patterns (Multi-result Ops)
 
diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -337,11 +337,12 @@
     // Skip if there is no defining operation (e.g., arguments to function).
     os << formatv("if (!{0}) return failure();\n", castedName);
   }
-  if (tree.getNumArgs() != op.getNumArgs()) {
+  if (tree.getNumArgs() < op.getNumRequiredArgs() ||
+      tree.getNumArgs() > op.getNumArgs()) {
     PrintFatalError(loc, formatv("op '{0}' argument number mismatch: {1} in "
-                                 "pattern vs. {2} in definition",
+                                 "pattern vs. [{2},{3}] required in definition",
                                  op.getOperationName(), tree.getNumArgs(),
-                                 op.getNumArgs()));
+                                 op.getNumRequiredArgs(), op.getNumArgs()));
   }
 
   // If the operand's name is set, set to that variable.
@@ -973,7 +974,6 @@
   LLVM_DEBUG(llvm::dbgs() << '\n');
 
   Operator &resultOp = tree.getDialectOp(opMap);
-  auto numOpArgs = resultOp.getNumArgs();
   auto numPatArgs = tree.getNumArgs();
 
   bool hasLocationDirective;
@@ -981,11 +981,19 @@
   std::tie(hasLocationDirective, locToUse) = getLocation(tree);
 
   auto inPattern = numPatArgs - hasLocationDirective;
-  if (numOpArgs != inPattern) {
+  if (resultOp.getNumRequiredArgs() > inPattern) {
     PrintFatalError(loc,
-                    formatv("resultant op '{0}' argument number mismatch: "
-                            "{1} in pattern vs. {2} in definition",
-                            resultOp.getOperationName(), inPattern, numOpArgs));
+                    formatv("resultant op '{0}' has too few arguments: "
+                            "{1} in pattern vs. {2} required in definition",
+                            resultOp.getOperationName(), inPattern,
+                            resultOp.getNumRequiredArgs()));
+  }
+
+  if (inPattern > resultOp.getNumArgs()) {
+    PrintFatalError(loc, formatv("resultant op '{0}' has too many arguments: "
+                                 "{1} in pattern vs. {2} allowed in definition",
+                                 resultOp.getOperationName(), inPattern,
+                                 resultOp.getNumArgs()));
   }
 
   // A map to collect all nested DAG child nodes' names, with operand index as
@@ -1100,8 +1108,16 @@
   //   * If the operand is variadic, we create a `SmallVector<Value>` local
   //     variable.
 
+  // Only the arguments actually present in the pattern need local variables;
+  // a trailing location directive is not an op argument.
+  int numOpArgs = node.getNumArgs();
+  if (numOpArgs != 0) {
+    if (auto lastArg = node.getArgAsNestedDag(numOpArgs - 1))
+      numOpArgs -= lastArg.isLocationDirective();
+  }
+
   int valueIndex = 0; // An index for uniquing local variable names.
-  for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) {
+  for (int argIndex = 0; argIndex < numOpArgs; ++argIndex) {
     const auto *operand =
         resultOp.getArg(argIndex).dyn_cast<NamedTypeConstraint *>();
     // We do not need special handling for attributes.
@@ -1151,8 +1167,16 @@
 void PatternEmitter::supplyValuesForOpArgs(
     DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth) {
   Operator &resultOp = node.getDialectOp(opMap);
-  for (int argIndex = 0, numOpArgs = resultOp.getNumArgs();
-       argIndex != numOpArgs; ++argIndex) {
+
+  // Iterate over the arguments present in the pattern, excluding a trailing
+  // location directive; omitted trailing optional arguments are skipped.
+  int numOpArgs = node.getNumArgs();
+  if (numOpArgs != 0) {
+    if (auto lastArg = node.getArgAsNestedDag(numOpArgs - 1))
+      numOpArgs -= lastArg.isLocationDirective();
+  }
+
+  for (int argIndex = 0; argIndex != numOpArgs; ++argIndex) {
     // Start each argument on its own line.
     os << ",\n    ";
 
@@ -1168,6 +1192,8 @@
     // The argument in the op definition.
     auto opArgName = resultOp.getArgName(argIndex);
     if (auto subTree = node.getArgAsNestedDag(argIndex)) {
+      if (subTree.isLocationDirective())
+        continue;
       if (!subTree.isNativeCodeCall())
         PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
                              "for creating attribute");
@@ -1205,11 +1231,16 @@
       "if (auto tmpAttr = {1}) {\n"
       "  tblgen_attrs.emplace_back(rewriter.getIdentifier(\"{0}\"), "
       "tmpAttr);\n}\n";
-  for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) {
+  for (int argIndex = 0, e = node.getNumArgs(); argIndex < e; ++argIndex) {
+    // A location directive has no corresponding op argument.
+    auto subTree = node.getArgAsNestedDag(argIndex);
+    if (subTree && subTree.isLocationDirective())
+      continue;
+
     if (resultOp.getArg(argIndex).is<NamedAttribute *>()) {
       // The argument in the op definition.
       auto opArgName = resultOp.getArgName(argIndex);
-      if (auto subTree = node.getArgAsNestedDag(argIndex)) {
+      if (subTree) {
         if (!subTree.isNativeCodeCall())
           PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
                                "for creating attribute");