diff --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md --- a/mlir/docs/OpDefinitions.md +++ b/mlir/docs/OpDefinitions.md @@ -664,6 +664,12 @@ - Represents the attribute dictionary of the operation, but prefixes the dictionary with an `attributes` keyword. +* `custom` < UserDirective > ( Params ) + + - Represents a custom directive implemented by the user in C++. + - See the [Custom Directives](#custom-directives) section below for more + details. + * `functional-type` ( inputs , results ) - Formats the `inputs` and `results` arguments as a @@ -705,6 +711,75 @@ Attribute variables are printed with their respective value type, unless that value type is buildable. In those cases, the type of the attribute is elided. +#### Custom Directives + +The declarative assembly format specification allows for handling a large +majority of the common cases when formatting an operation. For the operations +that require or desire specifying parts of the operation in a form not supported +by the declarative syntax, custom directives may be specified. A custom +directive essentially allows for users to use C++ for printing and parsing +subsections of an otherwise declaratively specified format. Looking at the +specification of a custom directive above: + +``` +custom-directive ::= `custom` `<` UserDirective `>` `(` Params `)` +``` + +A custom directive has two main parts: The `UserDirective` and the `Params`. A +custom directive is transformed into a call to a `print*` and a `parse*` method +when generating the C++ code for the format. The `UserDirective` is an +identifier used as a suffix to these two calls, i.e., `custom<MyDirective>(...)` +would result in calls to `parseMyDirective` and `printMyDirective` within the +parser and printer respectively. `Params` may be any combination of variables +(i.e. Attribute, Operand, Successor, etc.) and type directives. The type +directives must refer to a variable, but that variable need not also be a +parameter to the custom directive. 
+ +The arguments to the `parse` method are firstly a reference to the +`OpAsmParser` (`OpAsmParser &`), and secondly a set of output parameters +corresponding to the parameters specified in the format. The mapping of +declarative parameter to `parse` method argument is detailed below: + +* Attribute Variables + - Single: `<Attribute-Storage-Type>(e.g. Attribute) &` + - Optional: `<Attribute-Storage-Type>(e.g. Attribute) &` +* Operand Variables + - Single: `OpAsmParser::OperandType &` + - Optional: `Optional<OpAsmParser::OperandType> &` + - Variadic: `SmallVectorImpl<OpAsmParser::OperandType> &` +* Successor Variables + - Single: `Block *&` + - Variadic: `SmallVectorImpl<Block *> &` +* Type Directives + - Single: `Type &` + - Optional: `Type &` + - Variadic: `SmallVectorImpl<Type> &` + +When a variable is optional, the value should only be specified if the variable +is present. Otherwise, the value should remain `None` or null. + +The arguments to the `print` method are firstly a reference to the +`OpAsmPrinter` (`OpAsmPrinter &`), and secondly a set of input parameters +corresponding to the parameters specified in the format. The mapping of +declarative parameter to `print` method argument is detailed below: + +* Attribute Variables + - Single: `<Attribute-Storage-Type>(e.g. Attribute)` + - Optional: `<Attribute-Storage-Type>(e.g. Attribute)` +* Operand Variables + - Single: `Value` + - Optional: `Value` + - Variadic: `OperandRange` +* Successor Variables + - Single: `Block *` + - Variadic: `SuccessorRange` +* Type Directives + - Single: `Type` + - Optional: `Type` + - Variadic: `TypeRange` + +When a variable is optional, the provided value may be null. + #### Optional Groups In certain situations operations may have "optional" information, e.g. @@ -722,8 +797,8 @@ should be printed/parsed. - An element is marked as the anchor by adding a trailing `^`. - The first element is *not* required to be the anchor of the group. -* Literals, variables, and type directives are the only valid elements within - the group. +* Literals, variables, custom directives, and type directives are the only + valid elements within the group. 
- Any attribute variable may be used, but only optional attributes can be marked as the anchor. - Only variadic or optional operand arguments can be used. diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -202,6 +202,10 @@ llvm::interleaveComma(types, p); return p; } +inline OpAsmPrinter &operator<<(OpAsmPrinter &p, const TypeRange &types) { + llvm::interleaveComma(types, p); + return p; +} inline OpAsmPrinter &operator<<(OpAsmPrinter &p, ArrayRef types) { llvm::interleaveComma(types, p); return p; diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -267,6 +267,108 @@ results.insert(context); } +//===----------------------------------------------------------------------===// +// Test Format* operations +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// Parsing + +static ParseResult parseCustomDirectiveOperands( + OpAsmParser &parser, OpAsmParser::OperandType &operand, + Optional &optOperand, + SmallVectorImpl &varOperands) { + if (parser.parseOperand(operand)) + return failure(); + if (succeeded(parser.parseOptionalComma())) { + optOperand.emplace(); + if (parser.parseOperand(*optOperand)) + return failure(); + } + if (parser.parseArrow() || parser.parseLParen() || + parser.parseOperandList(varOperands) || parser.parseRParen()) + return failure(); + return success(); +} +static ParseResult +parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType, + Type &optOperandType, + SmallVectorImpl &varOperandTypes) { + if (parser.parseColon()) + return failure(); + + if (parser.parseType(operandType)) + return failure(); + if 
(succeeded(parser.parseOptionalComma())) { + if (parser.parseType(optOperandType)) + return failure(); + } + if (parser.parseArrow() || parser.parseLParen() || + parser.parseTypeList(varOperandTypes) || parser.parseRParen()) + return failure(); + return success(); +} +static ParseResult parseCustomDirectiveOperandsAndTypes( + OpAsmParser &parser, OpAsmParser::OperandType &operand, + Optional &optOperand, + SmallVectorImpl &varOperands, Type &operandType, + Type &optOperandType, SmallVectorImpl &varOperandTypes) { + if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) || + parseCustomDirectiveResults(parser, operandType, optOperandType, + varOperandTypes)) + return failure(); + return success(); +} +static ParseResult +parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor, + SmallVectorImpl &varSuccessors) { + if (parser.parseSuccessor(successor)) + return failure(); + if (failed(parser.parseOptionalComma())) + return success(); + Block *varSuccessor; + if (parser.parseSuccessor(varSuccessor)) + return failure(); + varSuccessors.append(2, varSuccessor); + return success(); +} + +//===----------------------------------------------------------------------===// +// Printing + +static void printCustomDirectiveOperands(OpAsmPrinter &printer, Value operand, + Value optOperand, + OperandRange varOperands) { + printer << operand; + if (optOperand) + printer << ", " << optOperand; + printer << " -> (" << varOperands << ")"; +} +static void printCustomDirectiveResults(OpAsmPrinter &printer, Type operandType, + Type optOperandType, + TypeRange varOperandTypes) { + printer << " : " << operandType; + if (optOperandType) + printer << ", " << optOperandType; + printer << " -> (" << varOperandTypes << ")"; +} +static void +printCustomDirectiveOperandsAndTypes(OpAsmPrinter &printer, Value operand, + Value optOperand, OperandRange varOperands, + Type operandType, Type optOperandType, + TypeRange varOperandTypes) { + 
printCustomDirectiveOperands(printer, operand, optOperand, varOperands); + printCustomDirectiveResults(printer, operandType, optOperandType, + varOperandTypes); +} +static void printCustomDirectiveSuccessors(OpAsmPrinter &printer, + Block *successor, + SuccessorRange varSuccessors) { + printer << successor; + if (!varSuccessors.empty()) + printer << ", " << varSuccessors.front(); +} + //===----------------------------------------------------------------------===// // Test IsolatedRegionOp - parse passthrough region arguments. //===----------------------------------------------------------------------===// diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -1404,8 +1404,60 @@ } //===----------------------------------------------------------------------===// -// AllTypesMatch type inference +// Custom Directives + +def FormatCustomDirectiveOperands + : TEST_Op<"format_custom_directive_operands", [AttrSizedOperandSegments]> { + let arguments = (ins I64:$operand, Optional:$optOperand, + Variadic:$varOperands); + let assemblyFormat = [{ + custom( + $operand, $optOperand, $varOperands + ) + attr-dict + }]; +} + +def FormatCustomDirectiveOperandsAndTypes + : TEST_Op<"format_custom_directive_operands_and_types", + [AttrSizedOperandSegments]> { + let arguments = (ins AnyType:$operand, Optional:$optOperand, + Variadic:$varOperands); + let assemblyFormat = [{ + custom( + $operand, $optOperand, $varOperands, + type($operand), type($optOperand), type($varOperands) + ) + attr-dict + }]; +} + +def FormatCustomDirectiveResults + : TEST_Op<"format_custom_directive_results", [AttrSizedResultSegments]> { + let results = (outs AnyType:$result, Optional:$optResult, + Variadic:$varResults); + let assemblyFormat = [{ + custom( + type($result), type($optResult), type($varResults) + ) + attr-dict + }]; +} + +def FormatCustomDirectiveSuccessors + : 
TEST_Op<"format_custom_directive_successors", [Terminator]> { + let successors = (successor AnySuccessor:$successor, + VariadicSuccessor:$successors); + let assemblyFormat = [{ + custom( + $successor, $successors + ) + attr-dict + }]; +} + //===----------------------------------------------------------------------===// +// AllTypesMatch type inference def FormatAllTypesMatchVarOp : TEST_Op<"format_all_types_match_var", [ AllTypesMatch<["value1", "value2", "result"]> @@ -1425,7 +1477,6 @@ //===----------------------------------------------------------------------===// // TypesMatchWith type inference -//===----------------------------------------------------------------------===// def FormatTypesMatchVarOp : TEST_Op<"format_types_match_var", [ TypesMatchWith<"result type matches operand", "value", "result", "$_self"> diff --git a/mlir/test/mlir-tblgen/op-format-spec.td b/mlir/test/mlir-tblgen/op-format-spec.td --- a/mlir/test/mlir-tblgen/op-format-spec.td +++ b/mlir/test/mlir-tblgen/op-format-spec.td @@ -42,6 +42,49 @@ attr-dict-with-keyword }]>; +//===----------------------------------------------------------------------===// +// custom + +// CHECK: error: expected '<' before custom directive name +def DirectiveCustomInvalidA : TestFormat_Op<"custom_invalid_a", [{ + custom( +}]>; +// CHECK: error: expected custom directive name identifier +def DirectiveCustomInvalidB : TestFormat_Op<"custom_invalid_b", [{ + custom<> +}]>; +// CHECK: error: expected '>' after custom directive name +def DirectiveCustomInvalidC : TestFormat_Op<"custom_invalid_c", [{ + custom; +// CHECK: error: expected '(' before custom directive parameters +def DirectiveCustomInvalidD : TestFormat_Op<"custom_invalid_d", [{ + custom) +}]>; +// CHECK: error: only variables and types may be used as parameters to a custom directive +def DirectiveCustomInvalidE : TestFormat_Op<"custom_invalid_e", [{ + custom(operands) +}]>; +// CHECK: error: expected ')' after custom directive parameters +def 
DirectiveCustomInvalidF : TestFormat_Op<"custom_invalid_f", [{ + custom($operand< +}]>, Arguments<(ins I64:$operand)>; +// CHECK: error: type directives within a custom directive may only refer to variables +def DirectiveCustomInvalidH : TestFormat_Op<"custom_invalid_h", [{ + custom(type(operands)) +}]>; + +// CHECK-NOT: error +def DirectiveCustomValidA : TestFormat_Op<"custom_valid_a", [{ + custom($operand) attr-dict +}]>, Arguments<(ins Optional:$operand)>; +def DirectiveCustomValidB : TestFormat_Op<"custom_valid_b", [{ + custom($operand, type($operand), type($result)) attr-dict +}]>, Arguments<(ins I64:$operand)>, Results<(outs I64:$result)>; +def DirectiveCustomValidC : TestFormat_Op<"custom_valid_c", [{ + custom($attr) attr-dict +}]>, Arguments<(ins I64Attr:$attr)>; + //===----------------------------------------------------------------------===// // functional-type @@ -238,6 +281,10 @@ def OptionalInvalidK : TestFormat_Op<"optional_invalid_k", [{ ($arg^) }]>, Arguments<(ins Variadic:$arg)>; +// CHECK: error: only variables can be used to anchor an optional group +def OptionalInvalidL : TestFormat_Op<"optional_invalid_l", [{ + (custom($arg)^)? 
+}]>, Arguments<(ins I64:$arg)>; //===----------------------------------------------------------------------===// // Variables diff --git a/mlir/test/mlir-tblgen/op-format.mlir b/mlir/test/mlir-tblgen/op-format.mlir --- a/mlir/test/mlir-tblgen/op-format.mlir +++ b/mlir/test/mlir-tblgen/op-format.mlir @@ -122,6 +122,40 @@ // CHECK: test.format_optional_operand_result_b_op : i64 test.format_optional_operand_result_b_op : i64 +//===----------------------------------------------------------------------===// +// Format custom directives +//===----------------------------------------------------------------------===// + +// CHECK: test.format_custom_directive_operands %[[I64]], %[[I64]] -> (%[[I64]]) +test.format_custom_directive_operands %i64, %i64 -> (%i64) + +// CHECK: test.format_custom_directive_operands %[[I64]] -> (%[[I64]]) +test.format_custom_directive_operands %i64 -> (%i64) + +// CHECK: test.format_custom_directive_operands_and_types %[[I64]], %[[I64]] -> (%[[I64]]) : i64, i64 -> (i64) +test.format_custom_directive_operands_and_types %i64, %i64 -> (%i64) : i64, i64 -> (i64) + +// CHECK: test.format_custom_directive_operands_and_types %[[I64]] -> (%[[I64]]) : i64 -> (i64) +test.format_custom_directive_operands_and_types %i64 -> (%i64) : i64 -> (i64) + +// CHECK: test.format_custom_directive_results : i64, i64 -> (i64) +test.format_custom_directive_results : i64, i64 -> (i64) + +// CHECK: test.format_custom_directive_results : i64 -> (i64) +test.format_custom_directive_results : i64 -> (i64) + +func @foo() { + // CHECK: test.format_custom_directive_successors ^bb1, ^bb2 + test.format_custom_directive_successors ^bb1, ^bb2 + +^bb1: + // CHECK: test.format_custom_directive_successors ^bb2 + test.format_custom_directive_successors ^bb2 + +^bb2: + return +} + //===----------------------------------------------------------------------===// // Format trait type inference //===----------------------------------------------------------------------===// diff --git 
a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp --- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp @@ -45,6 +45,7 @@ enum class Kind { /// This element is a directive. AttrDictDirective, + CustomDirective, FunctionalTypeDirective, OperandsDirective, ResultsDirective, @@ -164,6 +165,33 @@ bool withKeyword; }; +/// This class represents a custom format directive that is implemented by the +/// user in C++. +class CustomDirective : public Element { +public: + CustomDirective(StringRef name, + std::vector> &&arguments) + : Element{Kind::CustomDirective}, name(name), + arguments(std::move(arguments)) {} + + static bool classof(const Element *element) { + return element->getKind() == Kind::CustomDirective; + } + + /// Return the user-provided name of this custom directive. + StringRef getName() const { return name; } + + /// Return the arguments to the custom directive. + auto getArguments() const { return llvm::make_pointee_range(arguments); } + +private: + /// The user provided name of the directive. + StringRef name; + + /// The arguments to the custom directive. + std::vector> arguments; +}; + /// This class represents the `functional-type` directive. This directive takes /// two arguments and formats them, respectively, as the inputs and results of a /// FunctionType. @@ -370,19 +398,16 @@ /// The code snippet used to generate a parser call for an attribute. /// -/// {0}: The storage type of the attribute. -/// {1}: The name of the attribute. -/// {2}: The type for the attribute. +/// {0}: The name of the attribute. +/// {1}: The type for the attribute. 
const char *const attrParserCode = R"( - {0} {1}Attr; - if (parser.parseAttribute({1}Attr{2}, "{1}", result.attributes)) + if (parser.parseAttribute({0}Attr{1}, "{0}", result.attributes)) return failure(); )"; const char *const optionalAttrParserCode = R"( - {0} {1}Attr; { ::mlir::OptionalParseResult parseResult = - parser.parseOptionalAttribute({1}Attr{2}, "{1}", result.attributes); + parser.parseOptionalAttribute({0}Attr{1}, "{0}", result.attributes); if (parseResult.hasValue() && failed(*parseResult)) return failure(); } @@ -408,11 +433,11 @@ return parser.emitError(loc, "invalid ") << "{0} attribute specification: " << attrVal; - result.addAttribute("{0}", {3}); + {0}Attr = {3}; + result.addAttribute("{0}", {0}Attr); } )"; const char *const optionalEnumAttrParserCode = R"( - Attribute {0}Attr; { ::mlir::StringAttr attrVal; ::mlir::NamedAttrList attrStorage; @@ -440,11 +465,13 @@ /// /// {0}: The name of the operand. const char *const variadicOperandParserCode = R"( + {0}OperandsLoc = parser.getCurrentLocation(); if (parser.parseOperandList({0}Operands)) return failure(); )"; const char *const optionalOperandParserCode = R"( { + {0}OperandsLoc = parser.getCurrentLocation(); ::mlir::OpAsmParser::OperandType operand; ::mlir::OptionalParseResult parseResult = parser.parseOptionalOperand(operand); @@ -456,6 +483,7 @@ } )"; const char *const operandParserCode = R"( + {0}OperandsLoc = parser.getCurrentLocation(); if (parser.parseOperand({0}RawOperands[0])) return failure(); )"; @@ -500,7 +528,6 @@ /// /// {0}: The name for the successor list. const char *successorListParserCode = R"( - ::llvm::SmallVector<::mlir::Block *, 2> {0}Successors; { ::mlir::Block *succ; auto firstSucc = parser.parseOptionalSuccessor(succ); @@ -523,7 +550,6 @@ /// /// {0}: The name of the successor. 
const char *successorParserCode = R"( - ::mlir::Block *{0}Successor = nullptr; if (parser.parseSuccessor({0}Successor)) return failure(); )"; @@ -595,8 +621,34 @@ /// Generate the storage code required for parsing the given element. static void genElementParserStorage(Element *element, OpMethodBody &body) { if (auto *optional = dyn_cast(element)) { - for (auto &childElement : optional->getElements()) - genElementParserStorage(&childElement, body); + auto elements = optional->getElements(); + + // If the anchor is a unit attribute, it won't be parsed directly so elide + // it. + auto *anchor = dyn_cast(optional->getAnchor()); + Element *elidedAnchorElement = nullptr; + if (anchor && anchor != &*elements.begin() && anchor->isUnitAttr()) + elidedAnchorElement = anchor; + for (auto &childElement : elements) + if (&childElement != elidedAnchorElement) + genElementParserStorage(&childElement, body); + + } else if (auto *custom = dyn_cast(element)) { + for (auto ¶mElement : custom->getArguments()) + genElementParserStorage(¶mElement, body); + + } else if (isa(element)) { + body << " ::mlir::SmallVector<::mlir::OpAsmParser::OperandType, 4> " + "allOperands;\n"; + + } else if (isa(element)) { + body << " ::llvm::SmallVector<::mlir::Block *, 2> fullSuccessors;\n"; + + } else if (auto *attr = dyn_cast(element)) { + const NamedAttribute *var = attr->getVar(); + body << llvm::formatv(" {0} {1}Attr;\n", var->attr.getStorageType(), + var->name); + } else if (auto *operand = dyn_cast(element)) { StringRef name = operand->getVar()->name; if (operand->getVar()->isVariableLength()) { @@ -608,10 +660,19 @@ << " ::llvm::ArrayRef<::mlir::OpAsmParser::OperandType> " << name << "Operands(" << name << "RawOperands);"; } - body << llvm::formatv( - " ::llvm::SMLoc {0}OperandsLoc = parser.getCurrentLocation();\n" - " (void){0}OperandsLoc;\n", - name); + body << llvm::formatv(" ::llvm::SMLoc {0}OperandsLoc;\n" + " (void){0}OperandsLoc;\n", + name); + } else if (auto *successor = 
dyn_cast(element)) { + StringRef name = successor->getVar()->name; + if (successor->getVar()->isVariadic()) { + body << llvm::formatv(" ::llvm::SmallVector<::mlir::Block *, 2> " + "{0}Successors;\n", + name); + } else { + body << llvm::formatv(" ::mlir::Block *{0}Successor = nullptr;\n", name); + } + } else if (auto *dir = dyn_cast(element)) { ArgumentLengthKind lengthKind; StringRef name = getTypeListName(dir->getOperand(), lengthKind); @@ -631,6 +692,106 @@ } } +/// Generate the parser for a parameter to a custom directive. +static void genCustomParameterParser(Element ¶m, OpMethodBody &body) { + body << ", "; + if (auto *attr = dyn_cast(¶m)) { + body << attr->getVar()->name << "Attr"; + + } else if (auto *operand = dyn_cast(¶m)) { + StringRef name = operand->getVar()->name; + ArgumentLengthKind lengthKind = getArgumentLengthKind(operand->getVar()); + if (lengthKind == ArgumentLengthKind::Variadic) + body << llvm::formatv("{0}Operands", name); + else if (lengthKind == ArgumentLengthKind::Optional) + body << llvm::formatv("{0}Operand", name); + else + body << formatv("{0}RawOperands[0]", name); + + } else if (auto *successor = dyn_cast(¶m)) { + StringRef name = successor->getVar()->name; + if (successor->getVar()->isVariadic()) + body << llvm::formatv("{0}Successors", name); + else + body << llvm::formatv("{0}Successor", name); + + } else if (auto *dir = dyn_cast(¶m)) { + ArgumentLengthKind lengthKind; + StringRef listName = getTypeListName(dir->getOperand(), lengthKind); + if (lengthKind == ArgumentLengthKind::Variadic) + body << llvm::formatv("{0}Types", listName); + else if (lengthKind == ArgumentLengthKind::Optional) + body << llvm::formatv("{0}Type", listName); + else + body << formatv("{0}RawTypes[0]", listName); + } else { + llvm_unreachable("unknown custom directive parameter"); + } +} + +/// Generate the parser for a custom directive. 
+static void genCustomDirectiveParser(CustomDirective *dir, OpMethodBody &body) { + body << " {\n"; + + // Preprocess the directive variables. + // * Add a local variable for optional operands and types. This provides a + // better API to the user defined parser methods. + // * Set the location of operand variables. + for (Element ¶m : dir->getArguments()) { + if (auto *operand = dyn_cast(¶m)) { + body << " " << operand->getVar()->name + << "OperandsLoc = parser.getCurrentLocation();\n"; + if (operand->getVar()->isOptional()) { + body << llvm::formatv( + " llvm::Optional<::mlir::OpAsmParser::OperandType> " + "{0}Operand;\n", + operand->getVar()->name); + } + } else if (auto *dir = dyn_cast(¶m)) { + ArgumentLengthKind lengthKind; + StringRef listName = getTypeListName(dir->getOperand(), lengthKind); + if (lengthKind == ArgumentLengthKind::Optional) + body << llvm::formatv(" ::mlir::Type {0}Type;\n", listName); + } + } + + body << " if (parse" << dir->getName() << "(parser"; + for (Element ¶m : dir->getArguments()) + genCustomParameterParser(param, body); + + body << "))\n" + << " return failure();\n"; + + // After parsing, add handling for any of the optional constructs. 
+ for (Element ¶m : dir->getArguments()) { + if (auto *attr = dyn_cast(¶m)) { + const NamedAttribute *var = attr->getVar(); + if (var->attr.isOptional()) + body << llvm::formatv(" if ({0}Attr)\n ", var->name); + + body << llvm::formatv( + " result.attributes.addAttribute(\"{0}\", {0}Attr);", var->name); + } else if (auto *operand = dyn_cast(¶m)) { + const NamedTypeConstraint *var = operand->getVar(); + if (!var->isOptional()) + continue; + body << llvm::formatv(" if ({0}Operand.hasValue())\n" + " {0}Operands.push_back(*{0}Operand);\n", + var->name); + } else if (auto *dir = dyn_cast(¶m)) { + ArgumentLengthKind lengthKind; + StringRef listName = getTypeListName(dir->getOperand(), lengthKind); + if (lengthKind == ArgumentLengthKind::Optional) { + body << llvm::formatv(" if ({0}Type)\n" + " {0}Types.push_back({0}Type);\n", + listName); + } + } + } + + body << " }\n"; +} + /// Generate the parser for a single format element. static void genElementParser(Element *element, OpMethodBody &body, FmtContext &attrTypeCtx) { @@ -711,7 +872,7 @@ body << formatv(var->attr.isOptional() ? optionalAttrParserCode : attrParserCode, - var->attr.getStorageType(), var->name, attrTypeStr); + var->name, attrTypeStr); } else if (auto *operand = dyn_cast(element)) { ArgumentLengthKind lengthKind = getArgumentLengthKind(operand->getVar()); StringRef name = operand->getVar()->name; @@ -732,10 +893,11 @@ << (attrDict->isWithKeyword() ? 
"WithKeyword" : "") << "(result.attributes))\n" << " return failure();\n"; + } else if (auto *customDir = dyn_cast(element)) { + genCustomDirectiveParser(customDir, body); + } else if (isa(element)) { body << " ::llvm::SMLoc allOperandLoc = parser.getCurrentLocation();\n" - << " ::mlir::SmallVector<::mlir::OpAsmParser::OperandType, 4> " - "allOperands;\n" << " if (parser.parseOperandList(allOperands))\n" << " return failure();\n"; } else if (isa(element)) { @@ -980,6 +1142,20 @@ llvm::interleaveComma(op.getOperands(), body, interleaveFn); body << "}));\n"; } + + if (!allResultTypes && op.getTrait("OpTrait::AttrSizedResultSegments")) { + body << " result.addAttribute(\"result_segment_sizes\", " + << "parser.getBuilder().getI32VectorAttr({"; + auto interleaveFn = [&](const NamedTypeConstraint &result) { + // If the result is variadic emit the parsed size. + if (result.isVariableLength()) + body << "static_cast(" << result.name << "Types.size())"; + else + body << "1"; + }; + llvm::interleaveComma(op.getResults(), body, interleaveFn); + body << "}));\n"; + } } //===----------------------------------------------------------------------===// @@ -1007,6 +1183,8 @@ // Elide the variadic segment size attributes if necessary. if (!fmt.allOperands && op.getTrait("OpTrait::AttrSizedOperandSegments")) body << "\"operand_segment_sizes\", "; + if (!fmt.allResultTypes && op.getTrait("OpTrait::AttrSizedResultSegments")) + body << "\"result_segment_sizes\", "; llvm::interleaveComma(usedAttributes, body, [&](const NamedAttribute *attr) { body << "\"" << attr->name << "\""; }); @@ -1038,6 +1216,42 @@ lastWasPunctuation = !(value.front() == '_' || isalpha(value.front())); } +/// Generate the printer for a custom directive. This emits a call to the +/// user-defined `print*` function named after the directive, passing the +/// printer followed by one argument per directive parameter. 
+static void genCustomDirectivePrinter(CustomDirective *customDir, + OpMethodBody &body) { + body << " print" << customDir->getName() << "(p"; + for (Element ¶m : customDir->getArguments()) { + body << ", "; + if (auto *attr = dyn_cast(¶m)) { + body << attr->getVar()->name << "Attr()"; + + } else if (auto *operand = dyn_cast(¶m)) { + body << operand->getVar()->name << "()"; + + } else if (auto *successor = dyn_cast(¶m)) { + body << successor->getVar()->name << "()"; + + } else if (auto *dir = dyn_cast(¶m)) { + auto *typeOperand = dir->getOperand(); + auto *operand = dyn_cast(typeOperand); + auto *var = operand ? operand->getVar() + : cast(typeOperand)->getVar(); + if (var->isVariadic()) + body << var->name << "().getTypes()"; + else if (var->isOptional()) + body << llvm::formatv("({0}() ? {0}().getType() : Type())", var->name); + else + body << var->name << "().getType()"; + } else { + llvm_unreachable("unknown custom directive parameter"); + } + } + + body << ");\n"; +} + /// Generate the C++ for an operand to a (*-)type directive. static OpMethodBody &genTypeOperandPrinter(Element *arg, OpMethodBody &body) { if (isa(arg)) @@ -1145,6 +1359,8 @@ body << " ::llvm::interleaveComma(" << var->name << "(), p);\n"; else body << " p << " << var->name << "();\n"; + } else if (auto *dir = dyn_cast(element)) { + genCustomDirectivePrinter(dir, body); } else if (isa(element)) { body << " p << getOperation()->getOperands();\n"; } else if (isa(element)) { @@ -1202,12 +1418,15 @@ caret, comma, equal, + less, + greater, question, // Keywords. 
keyword_start, kw_attr_dict, kw_attr_dict_w_keyword, + kw_custom, kw_functional_type, kw_operands, kw_results, @@ -1353,6 +1572,10 @@ return formToken(Token::comma, tokStart); case '=': return formToken(Token::equal, tokStart); + case '<': + return formToken(Token::less, tokStart); + case '>': + return formToken(Token::greater, tokStart); case '?': return formToken(Token::question, tokStart); case '(': @@ -1406,6 +1629,7 @@ llvm::StringSwitch(str) .Case("attr-dict", Token::kw_attr_dict) .Case("attr-dict-with-keyword", Token::kw_attr_dict_w_keyword) + .Case("custom", Token::kw_custom) .Case("functional-type", Token::kw_functional_type) .Case("operands", Token::kw_operands) .Case("results", Token::kw_results) @@ -1513,6 +1737,10 @@ LogicalResult parseAttrDictDirective(std::unique_ptr &element, llvm::SMLoc loc, bool isTopLevel, bool withKeyword); + LogicalResult parseCustomDirective(std::unique_ptr &element, + llvm::SMLoc loc, bool isTopLevel); + LogicalResult parseCustomDirectiveParameter( + std::vector> ¶meters); LogicalResult parseFunctionalTypeDirective(std::unique_ptr &element, Token tok, bool isTopLevel); LogicalResult parseOperandsDirective(std::unique_ptr &element, @@ -1930,6 +2158,8 @@ case Token::kw_attr_dict_w_keyword: return parseAttrDictDirective(element, dirTok.getLoc(), isTopLevel, /*withKeyword=*/true); + case Token::kw_custom: + return parseCustomDirective(element, dirTok.getLoc(), isTopLevel); case Token::kw_functional_type: return parseFunctionalTypeDirective(element, dirTok, isTopLevel); case Token::kw_operands: @@ -2054,15 +2284,15 @@ seenVariables.insert(ele->getVar()); return success(); }) - // Literals and type directives may be used, but they can't anchor the - // group. - .Case( - [&](Element *) { - if (isAnchor) - return emitError(childLoc, "only variables can be used to anchor " - "an optional group"); - return success(); - }) + // Literals, custom directives, and type directives may be used, + // but they can't anchor the group. 
+ .Case([&](Element *) { + if (isAnchor) + return emitError(childLoc, "only variables can be used to anchor " + "an optional group"); + return success(); + }) .Default([&](Element *) { return emitError(childLoc, "only literals, types, and variables can be " "used within an optional group"); @@ -2084,6 +2314,71 @@ return success(); } +LogicalResult +FormatParser::parseCustomDirective(std::unique_ptr &element, + llvm::SMLoc loc, bool isTopLevel) { + llvm::SMLoc curLoc = curToken.getLoc(); + + // Parse the custom directive name. + if (failed( + parseToken(Token::less, "expected '<' before custom directive name"))) + return failure(); + + Token nameTok = curToken; + if (failed(parseToken(Token::identifier, + "expected custom directive name identifier")) || + failed(parseToken(Token::greater, + "expected '>' after custom directive name")) || + failed(parseToken(Token::l_paren, + "expected '(' before custom directive parameters"))) + return failure(); + + // Parse the comma-separated parameters of this custom directive. + std::vector> elements; + do { + if (failed(parseCustomDirectiveParameter(elements))) + return failure(); + if (curToken.getKind() != Token::comma) + break; + consumeToken(); + } while (true); + + if (failed(parseToken(Token::r_paren, + "expected ')' after custom directive parameters"))) + return failure(); + + // After parsing all of the elements, ensure that all type directives refer + // only to variables. 
+ for (auto &ele : elements) { + if (auto *typeEle = dyn_cast(ele.get())) { + if (!isa(typeEle->getOperand())) { + return emitError(curLoc, "type directives within a custom directive " + "may only refer to variables"); + } + } + } + + element = std::make_unique(nameTok.getSpelling(), + std::move(elements)); + return success(); +} + +LogicalResult FormatParser::parseCustomDirectiveParameter( + std::vector> ¶meters) { + llvm::SMLoc childLoc = curToken.getLoc(); + parameters.push_back({}); + if (failed(parseElement(parameters.back(), /*isTopLevel=*/true))) + return failure(); + + // Verify that the element can be placed within a custom directive. + if (!isa(parameters.back().get())) { + return emitError(childLoc, "only variables and types may be used as " + "parameters to a custom directive"); + } + return success(); +} + LogicalResult FormatParser::parseFunctionalTypeDirective(std::unique_ptr &element, Token tok, bool isTopLevel) {