diff --git a/mlir/docs/DefiningDialects/AttributesAndTypes.md b/mlir/docs/DefiningDialects/AttributesAndTypes.md --- a/mlir/docs/DefiningDialects/AttributesAndTypes.md +++ b/mlir/docs/DefiningDialects/AttributesAndTypes.md @@ -866,17 +866,33 @@ respectively generate calls to: ```c++ -LogicalResult parseFoo(AsmParser &parser, FailureOr<int> &foo); +LogicalResult parseFoo(AsmParser &parser, int &foo); void printFoo(AsmPrinter &printer, int foo); ``` +As you can see, by default parameters are passed into the parse function by +reference. This is only possible if the C++ type is default constructible. +If the C++ type is not default constructible, the parameter is wrapped in a +`FailureOr`. Therefore, given the following definition: + +```tablegen +let parameters = (ins "NotDefaultConstructible":$foobar); +let assemblyFormat = "custom<Fizz>($foobar)"; +``` + +It will generate calls expecting the following signature for `parseFizz`: + +```c++ +LogicalResult parseFizz(AsmParser &parser, FailureOr<NotDefaultConstructible> &foobar); +``` + A previously bound variable can be passed as a parameter to a `custom` directive by wrapping it in a `ref` directive. In the previous example, `$foo` is bound by the first directive. 
The second directive references it and expects the following printer and parser signatures: ```c++ -LogicalResult parseBar(AsmParser &parser, FailureOr<int> &bar, int foo); +LogicalResult parseBar(AsmParser &parser, int &bar, int foo); void printBar(AsmPrinter &printer, int bar, int foo); ``` @@ -885,8 +901,7 @@ For example, `StringRefParameter` expects the parser and printer signatures as: ```c++ -LogicalResult parseStringParam(AsmParser &parser, - FailureOr<std::string> &value); +LogicalResult parseStringParam(AsmParser &parser, std::string &value); void printStringParam(AsmPrinter &printer, StringRef value); ``` diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h @@ -219,7 +219,7 @@ } // namespace detail /// Parse any MLIR type or a concise syntax for LLVM types. -ParseResult parsePrettyLLVMType(AsmParser &p, FailureOr<Type> &type); +ParseResult parsePrettyLLVMType(AsmParser &p, Type &type); /// Print any MLIR type or a concise syntax for LLVM types. void printPrettyLLVMType(AsmPrinter &p, Type type); diff --git a/mlir/include/mlir/IR/AttributeSupport.h b/mlir/include/mlir/IR/AttributeSupport.h --- a/mlir/include/mlir/IR/AttributeSupport.h +++ b/mlir/include/mlir/IR/AttributeSupport.h @@ -264,6 +264,19 @@ static void initializeAttributeStorage(AttributeStorage *storage, MLIRContext *ctx, TypeID attrID); }; + +// Internal function called by ODS generated code. +// Default initializes the type within a FailureOr<T> if T is default +// constructible and returns a reference to the instance. +// Otherwise, returns a reference to the FailureOr<T>. 
+template <typename T> +decltype(auto) unwrapForCustomParse(FailureOr<T> &failureOr) { + if constexpr (std::is_default_constructible_v<T>) + return failureOr.emplace(); + else + return failureOr; +} + } // namespace detail } // namespace mlir diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp @@ -356,9 +356,8 @@ return type; } -ParseResult LLVM::parsePrettyLLVMType(AsmParser &p, FailureOr<Type> &type) { - type.emplace(); - return dispatchParse(p, *type); +ParseResult LLVM::parsePrettyLLVMType(AsmParser &p, Type &type) { + return dispatchParse(p, type); } void LLVM::printPrettyLLVMType(AsmPrinter &p, Type type) { diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp @@ -33,10 +33,8 @@ // custom<FunctionTypes> //===----------------------------------------------------------------------===// -static ParseResult parseFunctionTypes(AsmParser &p, - FailureOr<SmallVector<Type>> &params, - FailureOr<bool> &isVarArg) { - params.emplace(); +static ParseResult parseFunctionTypes(AsmParser &p, SmallVector<Type> &params, + bool &isVarArg) { isVarArg = false; // `(` `)` if (succeeded(p.parseOptionalRParen())) @@ -49,10 +47,10 @@ } // type (`,` type)* (`,` `...`)? 
- FailureOr type; + Type type; if (parsePrettyLLVMType(p, type)) return failure(); - params->push_back(*type); + params.push_back(type); while (succeeded(p.parseOptionalComma())) { if (succeeded(p.parseOptionalEllipsis())) { isVarArg = true; @@ -60,7 +58,7 @@ } if (parsePrettyLLVMType(p, type)) return failure(); - params->push_back(*type); + params.push_back(type); } return p.parseRParen(); } @@ -81,11 +79,10 @@ // custom //===----------------------------------------------------------------------===// -static ParseResult parsePointer(AsmParser &p, FailureOr &elementType, - FailureOr &addressSpace) { - addressSpace = 0; +static ParseResult parsePointer(AsmParser &p, Type &elementType, + unsigned &addressSpace) { // `<` addressSpace `>` - OptionalParseResult result = p.parseOptionalInteger(*addressSpace); + OptionalParseResult result = p.parseOptionalInteger(addressSpace); if (result.has_value()) { if (failed(result.value())) return failure(); @@ -96,7 +93,7 @@ if (parsePrettyLLVMType(p, elementType)) return failure(); if (succeeded(p.parseOptionalComma())) - return p.parseInteger(*addressSpace); + return p.parseInteger(addressSpace); return success(); } diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.cpp b/mlir/test/lib/Dialect/Test/TestAttributes.cpp --- a/mlir/test/lib/Dialect/Test/TestAttributes.cpp +++ b/mlir/test/lib/Dialect/Test/TestAttributes.cpp @@ -164,12 +164,11 @@ // TestCustomAnchorAttr //===----------------------------------------------------------------------===// -static ParseResult parseTrueFalse(AsmParser &p, - FailureOr> &result) { +static ParseResult parseTrueFalse(AsmParser &p, std::optional &result) { bool b; if (p.parseInteger(b)) return failure(); - result = std::optional(b); + result = b; return success(); } diff --git a/mlir/test/lib/Dialect/Test/TestTypes.cpp b/mlir/test/lib/Dialect/Test/TestTypes.cpp --- a/mlir/test/lib/Dialect/Test/TestTypes.cpp +++ b/mlir/test/lib/Dialect/Test/TestTypes.cpp @@ -89,23 +89,21 @@ // TestCustomType 
//===----------------------------------------------------------------------===// -static LogicalResult parseCustomTypeA(AsmParser &parser, - FailureOr &aResult) { - aResult.emplace(); - return parser.parseInteger(*aResult); +static LogicalResult parseCustomTypeA(AsmParser &parser, int &aResult) { + return parser.parseInteger(aResult); } static void printCustomTypeA(AsmPrinter &printer, int a) { printer << a; } static LogicalResult parseCustomTypeB(AsmParser &parser, int a, - FailureOr> &bResult) { + std::optional &bResult) { if (a < 0) return success(); for (int i : llvm::seq(0, a)) if (failed(parser.parseInteger(i))) return failure(); bResult.emplace(0); - return parser.parseInteger(**bResult); + return parser.parseInteger(*bResult); } static void printCustomTypeB(AsmPrinter &printer, int a, std::optional b) { @@ -117,8 +115,7 @@ printer << *b; } -static LogicalResult parseFooString(AsmParser &parser, - FailureOr &foo) { +static LogicalResult parseFooString(AsmParser &parser, std::string &foo) { std::string result; if (parser.parseString(&result)) return failure(); diff --git a/mlir/test/mlir-tblgen/attr-or-type-format.td b/mlir/test/mlir-tblgen/attr-or-type-format.td --- a/mlir/test/mlir-tblgen/attr-or-type-format.td +++ b/mlir/test/mlir-tblgen/attr-or-type-format.td @@ -593,7 +593,7 @@ // TYPE-LABEL: ::mlir::Type TestNType::parse // TYPE: parseFoo( -// TYPE-NEXT: _result_a, +// TYPE-NEXT: ::mlir::detail::unwrapForCustomParse(_result_a), // TYPE-NEXT: 1); // TYPE-LABEL: void TestNType::print diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp --- a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp @@ -612,7 +612,8 @@ for (FormatElement *arg : el->getArguments()) { os << ",\n"; if (auto *param = dyn_cast(arg)) - os << "_result_" << param->getName(); + os << "::mlir::detail::unwrapForCustomParse(_result_" << param->getName() + << ")"; else if (auto *ref = 
dyn_cast(arg)) os << "*_result_" << cast(ref->getArg())->getName(); else