diff --git a/mlir/docs/DeclarativeRewrites.md b/mlir/docs/DeclarativeRewrites.md --- a/mlir/docs/DeclarativeRewrites.md +++ b/mlir/docs/DeclarativeRewrites.md @@ -280,6 +280,75 @@ saying the call to `AOp::build()` cannot be resolved because of the number of parameters mismatch. +#### Optional Op Arguments + +In the case of ODS-generated rewrite rules, we use the special `build()` method +that was described in the [previous section](#building-operations). When +specifying optional attributes as arguments, they may be omitted from the +patterns in certain cases, namely when omitting them from +the parameter lists passed to the `build()` function won't cause the op +creation to fail. For example, suppose we have the following `OptionalOp`: + +```tablegen +def OptionalOp : Op<"optional_op"> { + let arguments = (ins OptionalAttr<BoolAttr>:$opt); +} +``` + +In this case, the following would be the valid rewrite rules for `OptionalOp`: + +```tablegen +def : Pat<(AOp $input, $attr), (OptionalOp)>; +def : Pat<(AOp $input, $attr), (OptionalOp ConstBoolAttrTrue)>; +def : Pat<(AOp $input, $attr), (OptionalOp ConstBoolAttrFalse)>; +``` + +The number of arguments required to construct an op is therefore determined by +the last non-optional (required) argument to the Op. Take for example the +following definition of `OptionalOp`: + +```tablegen +def OptionalOp : Op<"optional_op"> { + let arguments = (ins + OptionalAttr<BoolAttr>:$opt1, + BoolAttr:$req, + OptionalAttr<BoolAttr>:$opt2 + ); +} +``` + +In this case, `$opt1` is always required, since there is no way to construct the +op without specifying `$opt1` in order to also specify `$req`. This means that +`OptionalOp` requires at least two arguments.
However, `$opt2` is "truly" +optional, and may be omitted from result rewrite patterns: + +```tablegen +// NOT OK -- $req is not defined +def : Pat<(AOp $i, $a), (OptionalOp)>; + +// NOT OK -- $req is not defined +def : Pat<(AOp $i, $a), (OptionalOp ConstBoolAttrTrue)>; + +// OK +def : Pat<(AOp $i, $a), (OptionalOp ConstBoolAttrTrue, ConstBoolAttrTrue)>; + +// OK +def : Pat<(AOp $i, $a), + (OptionalOp ConstBoolAttrTrue, ConstBoolAttrTrue, ConstBoolAttrTrue)>; +``` + +Optional attributes can also be omitted in the source pattern, but this is +identical to using `$_` to denote an unused argument: + +```tablegen +// The following are identical: +def : Pat<(OptionalOp $opt, $req), + (OptionalOp $opt, $req, ConstBoolAttrTrue)>; + +def : Pat<(OptionalOp $opt, $req, $_), + (OptionalOp $opt, $req, ConstBoolAttrTrue)>; +``` + #### Generating DAG of operations `dag` objects can be nested to generate a DAG of operations: diff --git a/mlir/include/mlir/TableGen/Operator.h b/mlir/include/mlir/TableGen/Operator.h --- a/mlir/include/mlir/TableGen/Operator.h +++ b/mlir/include/mlir/TableGen/Operator.h @@ -151,6 +151,9 @@ // Returns the total number of arguments. int getNumArgs() const { return arguments.size(); } + // Returns the number of arguments up to and including the last required one. + int getNumRequiredArgs() const; + // Returns true of the operation has a single variadic arg.
bool hasSingleVariadicArg() const; diff --git a/mlir/lib/TableGen/Operator.cpp b/mlir/lib/TableGen/Operator.cpp --- a/mlir/lib/TableGen/Operator.cpp +++ b/mlir/lib/TableGen/Operator.cpp @@ -132,6 +132,22 @@ }); } +int Operator::getNumRequiredArgs() const { + int numRequiredArgs = getNumArgs(); + while (numRequiredArgs > 0) { + auto arg = getArg(numRequiredArgs - 1); + if (arg.is<NamedTypeConstraint *>()) { + if (!arg.get<NamedTypeConstraint *>()->isOptional()) { + break; + } + } else if (!arg.get<NamedAttribute *>()->attr.isOptional()) { + break; + } + --numRequiredArgs; // Trailing optional argument: may be omitted. + } + return numRequiredArgs; +} + bool Operator::hasSingleVariadicArg() const { return getNumArgs() == 1 && getArg(0).is<NamedTypeConstraint *>() && getOperand(0).isVariadic(); diff --git a/mlir/lib/TableGen/Pattern.cpp b/mlir/lib/TableGen/Pattern.cpp --- a/mlir/lib/TableGen/Pattern.cpp +++ b/mlir/lib/TableGen/Pattern.cpp @@ -675,19 +675,28 @@ if (tree.isOperation()) { auto &op = getDialectOp(tree); - auto numOpArgs = op.getNumArgs(); // The pattern might have the last argument specifying the location. bool hasLocDirective = false; if (numTreeArgs != 0) { - if (auto lastArg = tree.getArgAsNestedDag(numTreeArgs - 1)) + if (auto lastArg = tree.getArgAsNestedDag(numTreeArgs - 1)) { hasLocDirective = lastArg.isLocationDirective(); + } } - if (numOpArgs != numTreeArgs - hasLocDirective) { - auto err = formatv("op '{0}' argument number mismatch: " - "{1} in pattern vs. {2} in definition", - op.getOperationName(), numTreeArgs, numOpArgs); + const int numRequiredArgs = op.getNumRequiredArgs(); + const int opArgs = numTreeArgs - hasLocDirective; + if (opArgs < numRequiredArgs) { + auto err = formatv("too few arguments for op '{0}': " + "{1} in pattern vs. {2} required in definition", + op.getOperationName(), opArgs, numRequiredArgs); + PrintFatalError(&def, err); + } + + if (opArgs > op.getNumArgs()) { + auto err = formatv("too many arguments for op '{0}': " + "{1} in pattern vs.
{2} allowed in definition", + op.getOperationName(), opArgs, op.getNumArgs()); PrintFatalError(&def, err); } @@ -699,7 +708,7 @@ verifyBind(infoMap.bindOpResult(treeName, op), treeName); } - for (int i = 0; i != numTreeArgs; ++i) { + for (int i = 0; i != opArgs; ++i) { if (auto treeArg = tree.getArgAsNestedDag(i)) { // This DAG node argument is a DAG node itself. Go inside recursively. collectBoundSymbols(treeArg, infoMap, isSrcPattern); diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -966,6 +966,33 @@ def : Pat<(I32EnumAttrOp I32Case5), (I32EnumAttrOp I32Case10)>; def : Pat<(I64EnumAttrOp I64Case5), (I64EnumAttrOp I64Case10)>; +// Test partial matches for optional attributes +def OpAttrMatch5 : TEST_Op<"match_optional_op_attribute1"> { + let arguments = (ins + BoolAttr:$required_attr, + OptionalAttr:$optional_attr1, + OptionalAttr:$optional_attr2 + ); +} +def OpAttrMatch6 : TEST_Op<"match_optional_op_attribute2"> { + let arguments = (ins + BoolAttr:$required_attr1, + OptionalAttr:$optional_attr1, + OptionalAttr:$optional_attr2, + BoolAttr:$required_attr2 + ); +} +def : Pat<(OpAttrMatch5 ConstBoolAttrTrue, $optional), + (OpAttrMatch5 ConstBoolAttrFalse, $optional)>; +def : Pat<(OpAttrMatch5 ConstBoolAttrFalse, $optional, $_), + (OpAttrMatch5 ConstBoolAttrTrue, $optional, $optional)>; +def : Pat<(OpAttrMatch6 ConstBoolAttrTrue, $optional, $optional2, ConstBoolAttrFalse), + (OpAttrMatch5 ConstBoolAttrTrue, $optional)>; +def : Pat<(OpAttrMatch6 ConstBoolAttrFalse, $optional, $optional2, $required2), + (OpAttrMatch5 $required2, $optional, $optional2)>; +def : Pat<(OpAttrMatch6 $required, $optional, $optional2, $required2), + (OpAttrMatch5 $required)>; + //===----------------------------------------------------------------------===// // Test Patterns (Multi-result Ops) diff --git a/mlir/test/mlir-tblgen/op-error.td 
b/mlir/test/mlir-tblgen/op-error.td --- a/mlir/test/mlir-tblgen/op-error.td +++ b/mlir/test/mlir-tblgen/op-error.td @@ -1,6 +1,9 @@ // RUN: not mlir-tblgen -gen-op-decls -I %S/../../include -DERROR1 %s 2>&1 | FileCheck --check-prefix=ERROR1 %s // RUN: not mlir-tblgen -gen-op-decls -I %S/../../include -DERROR2 %s 2>&1 | FileCheck --check-prefix=ERROR2 %s // RUN: not mlir-tblgen -gen-op-decls -I %S/../../include -DERROR3 %s 2>&1 | FileCheck --check-prefix=ERROR3 %s +// RUN: not mlir-tblgen -gen-rewriters -I %S/../../include -DERROR4 %s 2>&1 | FileCheck --check-prefix=ERROR4 %s +// RUN: not mlir-tblgen -gen-rewriters -I %S/../../include -DERROR5 %s 2>&1 | FileCheck --check-prefix=ERROR5 %s +// RUN: not mlir-tblgen -gen-rewriters -I %S/../../include -DERROR6 %s 2>&1 | FileCheck --check-prefix=ERROR6 %s include "mlir/IR/OpBase.td" @@ -34,3 +37,31 @@ ]; } #endif + +def OpAttrMatch : Op { + let arguments = (ins + BoolAttr:$required_attr1, + OptionalAttr:$optional_attr1, + OptionalAttr:$optional_attr2, + BoolAttr:$required_attr2 + ); + let results = (outs I32); +} + +#ifdef ERROR4 +// ERROR4: error: too few arguments for op 'test_dialect.optional_attr_match' +def : Pat<(OpAttrMatch ConstBoolAttrTrue, $optional, $optional2, ConstBoolAttrFalse), + (OpAttrMatch ConstBoolAttrFalse)>; +#endif + +#ifdef ERROR5 +// ERROR5: error: too few arguments for op 'test_dialect.optional_attr_match' +def : Pat<(OpAttrMatch:$op1 ConstBoolAttrTrue, $optional, $optional2, ConstBoolAttrFalse), + (OpAttrMatch ConstBoolAttrFalse, $optional, $optional, (location $op1))>; +#endif + +#ifdef ERROR6 +// ERROR6: error: too many arguments for op 'test_dialect.optional_attr_match' +def : Pat<(OpAttrMatch ConstBoolAttrTrue, $o, $o2, ConstBoolAttrFalse), + (OpAttrMatch ConstBoolAttrFalse, $o, $o, $o, $o)>; +#endif diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp --- a/mlir/tools/mlir-tblgen/RewriterGen.cpp +++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp @@ -337,11 
+337,12 @@ // Skip if there is no defining operation (e.g., arguments to function). os << formatv("if (!{0}) return failure();\n", castedName); } - if (tree.getNumArgs() != op.getNumArgs()) { + if (tree.getNumArgs() < op.getNumRequiredArgs() || + tree.getNumArgs() > op.getNumArgs()) { PrintFatalError(loc, formatv("op '{0}' argument number mismatch: {1} in " - "pattern vs. {2} in definition", + "pattern vs. [{2},{3}] required in definition", op.getOperationName(), tree.getNumArgs(), - op.getNumArgs())); + op.getNumRequiredArgs(), op.getNumArgs())); } // If the operand's name is set, set to that variable. @@ -973,7 +974,6 @@ LLVM_DEBUG(llvm::dbgs() << '\n'); Operator &resultOp = tree.getDialectOp(opMap); - auto numOpArgs = resultOp.getNumArgs(); auto numPatArgs = tree.getNumArgs(); bool hasLocationDirective; @@ -981,11 +981,19 @@ std::tie(hasLocationDirective, locToUse) = getLocation(tree); auto inPattern = numPatArgs - hasLocationDirective; - if (numOpArgs != inPattern) { + if (resultOp.getNumRequiredArgs() > inPattern) { PrintFatalError(loc, - formatv("resultant op '{0}' argument number mismatch: " - "{1} in pattern vs. {2} in definition", - resultOp.getOperationName(), inPattern, numOpArgs)); + formatv("resultant op '{0}' has too few arguments: " + "{1} in pattern vs. {2} required in definition", + resultOp.getOperationName(), inPattern, + resultOp.getNumRequiredArgs())); + } + + if (inPattern > resultOp.getNumArgs()) { + PrintFatalError(loc, formatv("resultant op '{0}' has too many arguments: " + "{1} in pattern vs. {2} allowed in definition", + resultOp.getOperationName(), inPattern, + resultOp.getNumArgs())); } // A map to collect all nested DAG child nodes' names, with operand index as @@ -1100,8 +1108,13 @@ // * If the operand is variadic, we create a `SmallVector` local // variable.
+ int numOpArgs = node.getNumArgs(); + if (numOpArgs != 0) // Guard: an empty DAG node has no last argument. + if (auto lastArg = node.getArgAsNestedDag(numOpArgs - 1)) + numOpArgs -= lastArg.isLocationDirective(); + int valueIndex = 0; // An index for uniquing local variable names. - for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) { + for (int argIndex = 0; argIndex < numOpArgs; ++argIndex) { const auto *operand = resultOp.getArg(argIndex).dyn_cast<NamedTypeConstraint *>(); // We do not need special handling for attributes. @@ -1151,8 +1164,13 @@ void PatternEmitter::supplyValuesForOpArgs( DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth) { Operator &resultOp = node.getDialectOp(opMap); - for (int argIndex = 0, numOpArgs = resultOp.getNumArgs(); - argIndex != numOpArgs; ++argIndex) { + + int numOpArgs = node.getNumArgs(); + if (numOpArgs != 0) // Guard: an empty DAG node has no last argument. + if (auto lastArg = node.getArgAsNestedDag(numOpArgs - 1)) + numOpArgs -= lastArg.isLocationDirective(); + + for (int argIndex = 0; argIndex != numOpArgs; ++argIndex) { // Start each argument on its own line. os << ",\n "; @@ -1168,6 +1186,8 @@ // The argument in the op definition. auto opArgName = resultOp.getArgName(argIndex); if (auto subTree = node.getArgAsNestedDag(argIndex)) { + if (subTree.isLocationDirective()) + continue; if (!subTree.isNativeCodeCall()) PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node " "for creating attribute"); @@ -1205,11 +1225,15 @@ "if (auto tmpAttr = {1}) {\n" " tblgen_attrs.emplace_back(rewriter.getIdentifier(\"{0}\"), " "tmpAttr);\n}\n"; - for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) { + for (int argIndex = 0, e = node.getNumArgs(); argIndex < e; ++argIndex) { + auto subTree = node.getArgAsNestedDag(argIndex); + if (subTree && subTree.isLocationDirective()) + continue; + if (resultOp.getArg(argIndex).is<NamedAttribute *>()) { // The argument in the op definition.
auto opArgName = resultOp.getArgName(argIndex); - if (auto subTree = node.getArgAsNestedDag(argIndex)) { + if (subTree) { if (!subTree.isNativeCodeCall()) PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node " "for creating attribute");