diff --git a/mlir/lib/TableGen/Pattern.cpp b/mlir/lib/TableGen/Pattern.cpp
--- a/mlir/lib/TableGen/Pattern.cpp
+++ b/mlir/lib/TableGen/Pattern.cpp
@@ -54,9 +54,7 @@
   return isSubClassOf("EnumAttrCaseInfo");
 }
 
-bool DagLeaf::isStringAttr() const {
-  return isa<llvm::StringInit>(def);
-}
+bool DagLeaf::isStringAttr() const { return isa<llvm::StringInit>(def); }
 
 Constraint DagLeaf::getAsConstraint() const {
   assert((isOperandMatcher() || isAttrMatcher()) &&
@@ -257,10 +255,17 @@
     auto *operand = op->getArg(*argIndex).get<NamedTypeConstraint *>();
-    // If this operand is variadic, then return a range. Otherwise, return the
-    // value itself.
-    if (operand->isVariableLength()) {
+    // If this operand is variadic, then return a range. If it is optional,
+    // return its value when present and a null Value otherwise. Otherwise,
+    // return the value itself.
+    if (operand->isVariadic()) {
       auto repl = formatv(fmt, name);
       LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicOperand)\n");
       return std::string(repl);
+    }
+    if (operand->isOptional()) {
+      auto repl = formatv("({0}.size() > 0? (*{0}.begin()): Value())",
+                          formatv(fmt, name));
+      LLVM_DEBUG(llvm::dbgs() << repl << " (OptionalOperand)\n");
+      return std::string(repl);
     }
     auto repl = formatv(fmt, formatv("(*{0}.begin())", name));
     LLVM_DEBUG(llvm::dbgs() << repl << " (SingleOperand)\n");
diff --git a/mlir/test/mlir-tblgen/rewriter-indexing.td b/mlir/test/mlir-tblgen/rewriter-indexing.td
--- a/mlir/test/mlir-tblgen/rewriter-indexing.td
+++ b/mlir/test/mlir-tblgen/rewriter-indexing.td
@@ -85,3 +85,24 @@
 // CHECK: nativeCall(rewriter, odsLoc, (*v1.begin()), (*v2.begin()), (*v3.begin()), (*v4.begin()), (*v5.begin()), (*v6.begin()), (*v7.begin()), (*v8.begin()), (*v9.begin()), (*v10.begin()))
 def test4 : Pat<(DOp $v1, $v2, $v3, $v4, $v5, $v6, $v7, $v8, $v9, $v10),
                 (NativeBuilder $v1, $v2, $v3, $v4, $v5, $v6, $v7, $v8, $v9, $v10)>;
+
+// Tests dag operand indexing for ops with optional operands.
+// ---
+
+def EOp : NS_Op<"e_op", []> {
+  let arguments = (ins
+    AnyInteger: $any_integer,
+    Optional<AnyInteger>: $optional_integer
+  );
+}
+
+def FOp : NS_Op<"f_op", []> {
+  let arguments = (ins
+    AnyInteger: $any_integer,
+    Optional<AnyInteger>: $optional_integer
+  );
+}
+
+// CHECK: if (auto tmpOperand = (*arg0.begin())) {
+// CHECK: if (auto tmpOperand = (arg1.size() > 0? (*arg1.begin()): Value())) {
+def test5 : Pat<(EOp $arg0, $arg1), (FOp $arg0, $arg1)>;
diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -1221,6 +1221,10 @@
       "if (auto tmpAttr = {1}) {\n"
       "  tblgen_attrs.emplace_back(rewriter.getIdentifier(\"{0}\"), "
       "tmpAttr);\n}\n";
+  // Operands are appended only when non-null: an absent optional operand is
+  // represented by a null Value and is skipped rather than pushed.
+  const char *addOperandCmd = "if (auto tmpOperand = {0}) {{\n"
+                              "  tblgen_values.push_back(tmpOperand);\n}\n";
   for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) {
     if (resultOp.getArg(argIndex).is<NamedAttribute *>()) {
       // The argument in the op definition.
@@ -1257,22 +1261,22 @@
       os << formatv("for (auto v: {0}) {{\n  tblgen_values.push_back(v);\n}\n",
                     range);
     } else {
-      os << formatv("tblgen_values.push_back(");
+      std::string tmpSymbol;
       if (node.isNestedDagArg(argIndex)) {
-        os << symbolInfoMap.getValueAndRangeUse(
-            childNodeNames.lookup(argIndex));
+        tmpSymbol =
+            symbolInfoMap.getValueAndRangeUse(childNodeNames.lookup(argIndex));
       } else {
         DagLeaf leaf = node.getArgAsLeaf(argIndex);
         auto symbol =
             symbolInfoMap.getValueAndRangeUse(node.getArgName(argIndex));
         if (leaf.isNativeCodeCall()) {
-          os << std::string(
+          tmpSymbol = std::string(
               tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(symbol)));
         } else {
-          os << symbol;
+          tmpSymbol = symbol;
         }
       }
-      os << ");\n";
+      os << formatv(addOperandCmd, tmpSymbol);
     }
   }
 }