diff --git a/mlir/test/mlir-tblgen/attr-or-type-format.td b/mlir/test/mlir-tblgen/attr-or-type-format.td --- a/mlir/test/mlir-tblgen/attr-or-type-format.td +++ b/mlir/test/mlir-tblgen/attr-or-type-format.td @@ -645,8 +645,26 @@ let assemblyFormat = "`<` (`?`) : (struct($a, $b)^)? `>`"; } +// TYPE-LABEL: TestQType::parse +// TYPE: if (auto optResult = [&]() -> ::mlir::OptionalParseResult { +// TYPE: auto odsCustomResult = parseAB(odsParser +// TYPE: if (!odsCustomResult) return {}; +// TYPE: if (::mlir::failed(*odsCustomResult)) return ::mlir::failure(); +// TYPE: return ::mlir::success(); +// TYPE: }(); optResult.has_value() && ::mlir::failed(*optResult)) { +// TYPE: return {}; +// TYPE: } else if (!optResult.has_value()) { +// TYPE: // Parse literal 'y' +// TYPE: } else { +// TYPE: // Parse literal 'x' +def TypeO : TestType<"TestQ"> { + let parameters = (ins OptionalParameter<"int">:$a); + let mnemonic = "type_o"; + let assemblyFormat = "(custom($a)^ `x`) : (`y`)?"; +} + // DEFAULT_TYPE_PARSER: TestDialect::parseType(::mlir::DialectAsmParser &parser) // DEFAULT_TYPE_PARSER: auto parseResult = parseOptionalDynamicType(mnemonic, parser, genType); // DEFAULT_TYPE_PARSER: if (parseResult.has_value()) { // DEFAULT_TYPE_PARSER: if (::mlir::succeeded(parseResult.value())) -// DEFAULT_TYPE_PARSER: return genType; \ No newline at end of file +// DEFAULT_TYPE_PARSER: return genType; diff --git a/mlir/test/mlir-tblgen/op-format-invalid.td b/mlir/test/mlir-tblgen/op-format-invalid.td --- a/mlir/test/mlir-tblgen/op-format-invalid.td +++ b/mlir/test/mlir-tblgen/op-format-invalid.td @@ -357,7 +357,7 @@ def OptionalInvalidC : TestFormat_Op<[{ ($attr)? attr-dict }]>, Arguments<(ins OptionalAttr:$attr)>; -// CHECK: error: first parsable element of an optional group must be a literal or variable +// CHECK: error: first parsable element of an optional group must be a literal, variable, or custom directive def OptionalInvalidD : TestFormat_Op<[{ (type($operand) $operand^)?
attr-dict }]>, Arguments<(ins Optional:$operand)>; diff --git a/mlir/test/mlir-tblgen/op-format.td b/mlir/test/mlir-tblgen/op-format.td --- a/mlir/test/mlir-tblgen/op-format.td +++ b/mlir/test/mlir-tblgen/op-format.td @@ -84,9 +84,20 @@ ($a^)? attr-dict }]>, Arguments<(ins DefaultValuedStrAttr:$a)>; +// CHECK-LABEL: OptionalGroupD::parse +// CHECK: if (auto optResult = [&]() -> ::mlir::OptionalParseResult { +// CHECK: auto odsResult = parseCustom(parser, aOperand, bOperand); +// CHECK: if (!odsResult) return {}; +// CHECK: if (::mlir::failed(*odsResult)) return ::mlir::failure(); +// CHECK: return ::mlir::success(); +// CHECK: }(); optResult.has_value() && ::mlir::failed(*optResult)) { +// CHECK: return ::mlir::failure(); +// CHECK: } else if (optResult.has_value()) { + // CHECK-LABEL: OptionalGroupD::print // CHECK-NEXT: if (((getA()) || (getB()))) { -// CHECK-NEXT: odsPrinter << "(" +// CHECK-NEXT: odsPrinter << ' '; +// CHECK-NEXT: printCustom def OptionalGroupD : TestFormat_Op<[{ - (`(` custom($a, $b)^ `)`)? attr-dict + (custom($a, $b)^)? attr-dict }], [AttrSizedOperandSegments]>, Arguments<(ins Optional:$a, Optional:$b)>; diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp --- a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp @@ -201,7 +201,8 @@ /// Generate the parser code for a `struct` directive. void genStructParser(StructDirective *el, FmtContext &ctx, MethodBody &os); /// Generate the parser code for a `custom` directive. - void genCustomParser(CustomDirective *el, FmtContext &ctx, MethodBody &os); + void genCustomParser(CustomDirective *el, FmtContext &ctx, MethodBody &os, + bool isOptional = false); /// Generate the parser code for an optional group.
void genOptionalGroupParser(OptionalElement *el, FmtContext &ctx, MethodBody &os); @@ -598,7 +599,7 @@ } void DefFormat::genCustomParser(CustomDirective *el, FmtContext &ctx, - MethodBody &os) { + MethodBody &os, bool isOptional) { os << "{\n"; os.indent(); @@ -620,7 +621,12 @@ os << tgfmt(cast(arg)->getValue(), &ctx); } os.unindent() << ");\n"; - os << "if (::mlir::failed(odsCustomResult)) return {};\n"; + if (isOptional) { + os << "if (!odsCustomResult) return {};\n"; + os << "if (::mlir::failed(*odsCustomResult)) return ::mlir::failure();\n"; + } else { + os << "if (::mlir::failed(odsCustomResult)) return {};\n"; + } for (FormatElement *arg : el->getArguments()) { if (auto *param = dyn_cast(arg)) { if (param->isOptional()) @@ -629,7 +635,7 @@ os.indent() << tgfmt("$_parser.emitError(odsCustomLoc, ", &ctx) << "\"custom parser failed to parse parameter '" << param->getName() << "'\");\n"; - os << "return {};\n"; + os << "return " << (isOptional ? "::mlir::failure()" : "{}") << ";\n"; os.unindent() << "}\n"; } } @@ -663,6 +669,19 @@ } else if (auto *params = dyn_cast(first)) { genParamsParser(params, ctx, os); guardOn(params->getParams()); + } else if (auto *custom = dyn_cast(first)) { + // The code emitted after this chain parses the else-group in this branch + // and the then-group under `} else {`, so only enter it when absent. + os << "if (auto optResult = [&]() -> ::mlir::OptionalParseResult {\n"; + os.indent(); + genCustomParser(custom, ctx, os, /*isOptional=*/true); + os << "return ::mlir::success();\n"; + os.unindent(); + os << "}(); optResult.has_value() && ::mlir::failed(*optResult)) {\n"; + os.indent(); + os << "return {};\n"; + os.unindent(); + os << "} else if (!optResult.has_value()) {\n"; } else { auto *strct = cast(first); genStructParser(strct, ctx, os); diff --git a/mlir/tools/mlir-tblgen/FormatGen.cpp b/mlir/tools/mlir-tblgen/FormatGen.cpp --- a/mlir/tools/mlir-tblgen/FormatGen.cpp +++ b/mlir/tools/mlir-tblgen/FormatGen.cpp @@ -383,9 +383,9 @@ unsigned thenParseStart = std::distance(thenElements.begin(), thenParseBegin); unsigned elseParseStart = std::distance(elseElements.begin(), elseParseBegin); - if
(!isa(*thenParseBegin)) { + if (!isa(*thenParseBegin)) { return emitError(loc, "first parsable element of an optional group must be " - "a literal or variable"); + "a literal, variable, or custom directive"); } return create(std::move(thenElements), std::move(elseElements), thenParseStart, diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp --- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp @@ -953,7 +953,8 @@ /// Generate the parser for a custom directive. static void genCustomDirectiveParser(CustomDirective *dir, MethodBody &body, bool useProperties, - StringRef opCppClassName) { + StringRef opCppClassName, + bool isOptional = false) { body << " {\n"; // Preprocess the directive variables. @@ -1011,14 +1012,19 @@ } } - body << " if (parse" << dir->getName() << "(parser"; + body << " auto odsResult = parse" << dir->getName() << "(parser"; for (FormatElement *param : dir->getArguments()) { body << ", "; genCustomParameterParser(param, body); } + body << ");\n"; - body << "))\n" - << " return ::mlir::failure();\n"; + if (isOptional) { + body << " if (!odsResult) return {};\n" + << " if (::mlir::failed(*odsResult)) return ::mlir::failure();\n"; + } else { + body << " if (odsResult) return ::mlir::failure();\n"; + } // After parsing, add handling for any of the optional constructs. for (FormatElement *param : dir->getArguments()) { @@ -1273,6 +1279,16 @@ body << llvm::formatv(regionEnsureSingleBlockParserCode, region->name); } + } else if (auto *custom = dyn_cast(firstElement)) { + // Avoid naming this `result`: generated op parsers use that name for the + // OperationState, which the directive's emitted code may reference. + body << " if (auto optResult = [&]() -> ::mlir::OptionalParseResult {\n"; + genCustomDirectiveParser(custom, body, useProperties, opCppClassName, + /*isOptional=*/true); + body << " return ::mlir::success();\n" + << " }(); optResult.has_value() && ::mlir::failed(*optResult)) {\n" + << " return ::mlir::failure();\n" + << " } else if (optResult.has_value()) {\n"; } genElementParsers(firstElement, thenElements.drop_front(),