Index: mlir/docs/DefiningDialects/AttributesAndTypes.md =================================================================== --- mlir/docs/DefiningDialects/AttributesAndTypes.md +++ mlir/docs/DefiningDialects/AttributesAndTypes.md @@ -866,17 +866,33 @@ respectively generate calls to: ```c++ -LogicalResult parseFoo(AsmParser &parser, FailureOr<int> &foo); +LogicalResult parseFoo(AsmParser &parser, int &foo); void printFoo(AsmPrinter &printer, int foo); ``` +As you can see, by default parameters are passed into the parse function by +reference. This is only possible if the C++ type is default constructible. +If the C++ type is not default constructible, the parameter is wrapped in a +`FailureOr`. For example, given the following definition: + +```tablegen +let parameters = (ins "NotDefaultConstructible":$foobar); +let assemblyFormat = "custom<Fizz>($foobar)"; +``` + +the generated parser calls `parseFizz`, which must have the following signature: + +```c++ +LogicalResult parseFizz(AsmParser &parser, FailureOr<NotDefaultConstructible> &foobar); +``` + A previously bound variable can be passed as a parameter to a `custom` directive by wrapping it in a `ref` directive. In the previous example, `$foo` is bound by the first directive.
The second directive references it and expects the following printer and parser signatures: ```c++ -LogicalResult parseBar(AsmParser &parser, FailureOr<int> &bar, int foo); +LogicalResult parseBar(AsmParser &parser, int &bar, int foo); void printBar(AsmPrinter &printer, int bar, int foo); ``` @@ -885,8 +901,7 @@ For example, `StringRefParameter` expects the parser and printer signatures as: ```c++ -LogicalResult parseStringParam(AsmParser &parser, - FailureOr<std::string> &value); +LogicalResult parseStringParam(AsmParser &parser, std::string &value); void printStringParam(AsmPrinter &printer, StringRef value); ``` Index: mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h =================================================================== --- mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h +++ mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h @@ -219,7 +219,7 @@ } // namespace detail /// Parse any MLIR type or a concise syntax for LLVM types. -ParseResult parsePrettyLLVMType(AsmParser &p, FailureOr<Type> &type); +ParseResult parsePrettyLLVMType(AsmParser &p, Type &type); /// Print any MLIR type or a concise syntax for LLVM types. void printPrettyLLVMType(AsmPrinter &p, Type type); Index: mlir/include/mlir/IR/AttributeSupport.h =================================================================== --- mlir/include/mlir/IR/AttributeSupport.h +++ mlir/include/mlir/IR/AttributeSupport.h @@ -264,6 +264,19 @@ static void initializeAttributeStorage(AttributeStorage *storage, MLIRContext *ctx, TypeID attrID); }; + +// Internal function called by ODS generated code. +// Default initializes the type within a FailureOr<T> if T is default +// constructible and returns a reference to the instance. +// Otherwise, returns a reference to the FailureOr<T>.
+template <typename T> +decltype(auto) unwrapForCustomParse(FailureOr<T> &failureOr) { + if constexpr (std::is_default_constructible_v<T>) + return failureOr.emplace(); + else + return failureOr; +} + } // namespace detail } // namespace mlir Index: mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp =================================================================== --- mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp +++ mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp @@ -356,9 +356,8 @@ return type; } -ParseResult LLVM::parsePrettyLLVMType(AsmParser &p, FailureOr<Type> &type) { - type.emplace(); - return dispatchParse(p, *type); +ParseResult LLVM::parsePrettyLLVMType(AsmParser &p, Type &type) { + return dispatchParse(p, type); } void LLVM::printPrettyLLVMType(AsmPrinter &p, Type type) { Index: mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp =================================================================== --- mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp +++ mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp @@ -34,9 +34,8 @@ //===----------------------------------------------------------------------===// static ParseResult parseFunctionTypes(AsmParser &p, - FailureOr<SmallVector<Type>> &params, - FailureOr<bool> &isVarArg) { - params.emplace(); + SmallVector<Type> &params, + bool &isVarArg) { isVarArg = false; // `(` `)` if (succeeded(p.parseOptionalRParen())) @@ -49,10 +48,10 @@ } // type (`,` type)* (`,` `...`)?
- FailureOr<Type> type; + Type type; if (parsePrettyLLVMType(p, type)) return failure(); - params->push_back(*type); + params.push_back(type); while (succeeded(p.parseOptionalComma())) { if (succeeded(p.parseOptionalEllipsis())) { isVarArg = true; @@ -60,7 +59,7 @@ } if (parsePrettyLLVMType(p, type)) return failure(); - params->push_back(*type); + params.push_back(type); } return p.parseRParen(); } @@ -81,11 +80,10 @@ // custom<Pointer> //===----------------------------------------------------------------------===// -static ParseResult parsePointer(AsmParser &p, FailureOr<Type> &elementType, - FailureOr<unsigned> &addressSpace) { - addressSpace = 0; +static ParseResult parsePointer(AsmParser &p, Type &elementType, + unsigned &addressSpace) { // `<` addressSpace `>` - OptionalParseResult result = p.parseOptionalInteger(*addressSpace); + OptionalParseResult result = p.parseOptionalInteger(addressSpace); if (result.has_value()) { if (failed(result.value())) return failure(); @@ -96,7 +94,7 @@ if (parsePrettyLLVMType(p, elementType)) return failure(); if (succeeded(p.parseOptionalComma())) - return p.parseInteger(*addressSpace); + return p.parseInteger(addressSpace); return success(); } Index: mlir/test/lib/Dialect/Test/TestAttributes.cpp =================================================================== --- mlir/test/lib/Dialect/Test/TestAttributes.cpp +++ mlir/test/lib/Dialect/Test/TestAttributes.cpp @@ -165,11 +165,11 @@ //===----------------------------------------------------------------------===// static ParseResult parseTrueFalse(AsmParser &p, - FailureOr<std::optional<int>> &result) { + std::optional<int> &result) { bool b; if (p.parseInteger(b)) return failure(); - result = std::optional<int>(b); + result = b; return success(); } Index: mlir/test/lib/Dialect/Test/TestTypes.cpp =================================================================== --- mlir/test/lib/Dialect/Test/TestTypes.cpp +++ mlir/test/lib/Dialect/Test/TestTypes.cpp @@ -89,23 +89,21 @@ // TestCustomType
//===----------------------------------------------------------------------===// -static LogicalResult parseCustomTypeA(AsmParser &parser, - FailureOr<int> &aResult) { - aResult.emplace(); - return parser.parseInteger(*aResult); +static LogicalResult parseCustomTypeA(AsmParser &parser, int &aResult) { + return parser.parseInteger(aResult); } static void printCustomTypeA(AsmPrinter &printer, int a) { printer << a; } static LogicalResult parseCustomTypeB(AsmParser &parser, int a, - FailureOr<std::optional<int>> &bResult) { + std::optional<int> &bResult) { if (a < 0) return success(); for (int i : llvm::seq(0, a)) if (failed(parser.parseInteger(i))) return failure(); bResult.emplace(0); - return parser.parseInteger(**bResult); + return parser.parseInteger(*bResult); } static void printCustomTypeB(AsmPrinter &printer, int a, std::optional<int> b) { @@ -117,8 +115,7 @@ printer << *b; } -static LogicalResult parseFooString(AsmParser &parser, - FailureOr<std::string> &foo) { +static LogicalResult parseFooString(AsmParser &parser, std::string &foo) { std::string result; if (parser.parseString(&result)) return failure(); Index: mlir/test/mlir-tblgen/attr-or-type-format.td =================================================================== --- mlir/test/mlir-tblgen/attr-or-type-format.td +++ mlir/test/mlir-tblgen/attr-or-type-format.td @@ -593,7 +593,7 @@ // TYPE-LABEL: ::mlir::Type TestNType::parse // TYPE: parseFoo( -// TYPE-NEXT: _result_a, +// TYPE-NEXT: ::mlir::detail::unwrapForCustomParse(_result_a), // TYPE-NEXT: 1); // TYPE-LABEL: void TestNType::print Index: mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp =================================================================== --- mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp +++ mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp @@ -612,7 +612,8 @@ for (FormatElement *arg : el->getArguments()) { os << ",\n"; if (auto *param = dyn_cast<ParameterElement>(arg)) - os << "_result_" << param->getName(); + os << "::mlir::detail::unwrapForCustomParse(_result_" << param->getName() + << ")"; else
if (auto *ref = dyn_cast<RefDirective>(arg)) os << "*_result_" << cast<ParameterElement>(ref->getArg())->getName(); else