diff --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md
--- a/mlir/docs/OpDefinitions.md
+++ b/mlir/docs/OpDefinitions.md
@@ -681,6 +681,10 @@
 
     -   Represents all of the operands of an operation.
 
+*   `regions`
+
+    -   Represents all of the regions of an operation.
+
 *   `results`
 
     -   Represents all of the results of an operation.
@@ -700,13 +704,14 @@
 A literal is either a keyword or punctuation surrounded by \`\`.
 
 The following are the set of valid punctuation:
-    `:`, `,`, `=`, `<`, `>`, `(`, `)`, `[`, `]`, `->`
+
+`:`, `,`, `=`, `<`, `>`, `(`, `)`, `{`, `}`, `[`, `]`, `->`
 
 #### Variables
 
 A variable is an entity that has been registered on the operation itself, i.e.
-an argument(attribute or operand), result, successor, etc. In the `CallOp`
-example above, the variables would be `$callee` and `$args`.
+an argument (attribute or operand), region, result, successor, etc. In the
+`CallOp` example above, the variables would be `$callee` and `$args`.
 
 Attribute variables are printed with their respective value type, unless that
 value type is buildable. In those cases, the type of the attribute is elided.
@@ -747,6 +752,9 @@
     -   Single: `OpAsmParser::OperandType &`
     -   Optional: `Optional<OpAsmParser::OperandType> &`
     -   Variadic: `SmallVectorImpl<OpAsmParser::OperandType> &`
+*   Region Variables
+    -   Single: `Region &`
+    -   Variadic: `SmallVectorImpl<std::unique_ptr<Region>> &`
 *   Successor Variables
     -   Single: `Block *&`
     -   Variadic: `SmallVectorImpl<Block *> &`
@@ -770,6 +778,9 @@
     -   Single: `Value`
     -   Optional: `Value`
     -   Variadic: `OperandRange`
+*   Region Variables
+    -   Single: `Region &`
+    -   Variadic: `MutableArrayRef<Region>`
 *   Successor Variables
     -   Single: `Block *`
     -   Variadic: `SuccessorRange`
@@ -788,20 +799,24 @@
 information. An optional group is defined by wrapping a set of elements within
 `()` followed by a `?` and has the following requirements:
 
-*   The first element of the group must either be a literal, attribute, or an
-    operand.
-*   This is because the first element must be optionally parsable.
+*   The first element of the group must be an attribute, literal, operand, or
+    variadic region.
+    -   This is because the first element must be optionally parsable.
 *   Exactly one argument variable within the group must be marked as the anchor
     of the group.
     -   The anchor is the element whose presence controls whether the group
         should be printed/parsed.
     -   An element is marked as the anchor by adding a trailing `^`.
     -   The first element is *not* required to be the anchor of the group.
+    -   When a non-variadic region anchors a group, the group is printed only
+        if the region is not empty.
 *   Literals, variables, custom directives, and type directives are the only
     valid elements within the group.
     -   Any attribute variable may be used, but only optional attributes can be
         marked as the anchor.
     -   Only variadic or optional operand arguments can be used.
+    -   All region variables can be used. When a non-variadic region is used
+        and the group is not present, the region is left empty.
     -   The operands to a type directive must be defined within the optional
         group.
 
@@ -853,18 +868,22 @@
 The format specification has a certain set of requirements that must be adhered
 to:
 
-1.  The output and operation name are never shown as they are fixed and cannot be
-    altered.
-1.  All operands within the operation must appear within the format, either
-    individually or with the `operands` directive.
-1.  All operand and result types must appear within the format using the various
-    `type` directives, either individually or with the `operands` or `results`
-    directives.
-1.  The `attr-dict` directive must always be present.
-1.  Must not contain overlapping information; e.g. multiple instances of
-    'attr-dict', types, operands, etc.
-    -   Note that `attr-dict` does not overlap with individual attributes. These
-        attributes will simply be elided when printing the attribute dictionary.
+1.  The output and operation name are never shown as they are fixed and cannot
+    be altered.
+1.  All operands within the operation must appear within the format, either
+    individually or with the `operands` directive.
+1.  All regions within the operation must appear within the format, either
+    individually or with the `regions` directive.
+1.  All successors within the operation must appear within the format, either
+    individually or with the `successors` directive.
+1.  All operand and result types must appear within the format using the various
+    `type` directives, either individually or with the `operands` or `results`
+    directives.
+1.  The `attr-dict` directive must always be present.
+1.  Must not contain overlapping information; e.g. multiple instances of
+    'attr-dict', types, operands, etc.
+    -   Note that `attr-dict` does not overlap with individual attributes. These
+        attributes will simply be elided when printing the attribute dictionary.
 
 ##### Type Inference
 
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -621,16 +621,23 @@
   /// can only be set to true for regions attached to operations that are
   /// "IsolatedFromAbove".
   virtual ParseResult parseRegion(Region &region,
-                                  ArrayRef<OperandType> arguments,
-                                  ArrayRef<Type> argTypes,
+                                  ArrayRef<OperandType> arguments = {},
+                                  ArrayRef<Type> argTypes = {},
                                   bool enableNameShadowing = false) = 0;
 
   /// Parses a region if present.
   virtual ParseResult parseOptionalRegion(Region &region,
-                                          ArrayRef<OperandType> arguments,
-                                          ArrayRef<Type> argTypes,
+                                          ArrayRef<OperandType> arguments = {},
+                                          ArrayRef<Type> argTypes = {},
                                           bool enableNameShadowing = false) = 0;
 
+  /// Parses a region if present. If the region is present, a new region is
+  /// allocated and placed in `region`. If no region is present or on failure,
+  /// `region` remains untouched.
+  virtual OptionalParseResult parseOptionalRegion(
+      std::unique_ptr<Region> &region, ArrayRef<OperandType> arguments = {},
+      ArrayRef<Type> argTypes = {}, bool enableNameShadowing = false) = 0;
+
   /// Parse a region argument, this argument is resolved when calling
   /// 'parseRegion'.
   virtual ParseResult parseRegionArgument(OperandType &argument) = 0;
diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -414,6 +414,11 @@
   /// region is null, a new empty region will be attached to the Operation.
   void addRegion(std::unique_ptr<Region> &&region);
 
+  /// Take ownership of a set of regions that should be attached to the
+  /// Operation; each entry of `regions` is moved from. The bodies of the
+  /// regions will be transferred when the Operation is constructed.
+  void addRegions(MutableArrayRef<std::unique_ptr<Region>> regions);
+
   /// Get the context held by this operation state.
   MLIRContext *getContext() const { return location->getContext(); }
 };
diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp
--- a/mlir/lib/IR/OperationSupport.cpp
+++ b/mlir/lib/IR/OperationSupport.cpp
@@ -199,6 +199,12 @@
   regions.push_back(std::move(region));
 }
 
+void OperationState::addRegions(
+    MutableArrayRef<std::unique_ptr<Region>> regions) {
+  for (std::unique_ptr<Region> &region : regions)
+    addRegion(std::move(region));
+}
+
 //===----------------------------------------------------------------------===//
 // OperandStorage
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -1355,6 +1355,23 @@
     return parseRegion(region, arguments, argTypes, enableNameShadowing);
   }
 
+  /// Parses a region if present. If the region is present, a new region is
+  /// allocated and placed in `region`. If no region is present or on failure,
+  /// `region` remains untouched.
+  OptionalParseResult
+  parseOptionalRegion(std::unique_ptr<Region> &region,
+                      ArrayRef<OperandType> arguments, ArrayRef<Type> argTypes,
+                      bool enableNameShadowing = false) override {
+    if (parser.getToken().isNot(Token::l_brace))
+      return llvm::None;
+    std::unique_ptr<Region> newRegion = std::make_unique<Region>();
+    if (parseRegion(*newRegion, arguments, argTypes, enableNameShadowing))
+      return failure();
+
+    region = std::move(newRegion);
+    return success();
+  }
+
   /// Parse a region argument. The type of the argument will be resolved later
   /// by a call to `parseRegion`.
   ParseResult parseRegionArgument(OperandType &argument) override {
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -319,6 +319,19 @@
     return failure();
   return success();
 }
+static ParseResult parseCustomDirectiveRegions(
+    OpAsmParser &parser, Region &region,
+    SmallVectorImpl<std::unique_ptr<Region>> &varRegions) {
+  if (parser.parseRegion(region))
+    return failure();
+  while (succeeded(parser.parseOptionalComma())) {
+    std::unique_ptr<Region> varRegion = std::make_unique<Region>();
+    if (parser.parseRegion(*varRegion))
+      return failure();
+    varRegions.emplace_back(std::move(varRegion));
+  }
+  return success();
+}
 static ParseResult
 parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor,
                                SmallVectorImpl<Block *> &varSuccessors) {
@@ -361,6 +374,15 @@
   printCustomDirectiveResults(printer, operandType, optOperandType,
                               varOperandTypes);
 }
+static void printCustomDirectiveRegions(OpAsmPrinter &printer, Region &region,
+                                        MutableArrayRef<Region> varRegions) {
+  printer.printRegion(region);
+  if (!varRegions.empty()) {
+    printer << ", ";
+    llvm::interleaveComma(varRegions, printer,
+                          [&](Region &r) { printer.printRegion(r); });
+  }
+}
 static void printCustomDirectiveSuccessors(OpAsmPrinter &printer,
                                            Block *successor,
                                            SuccessorRange varSuccessors) {
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1323,6 +1323,33 @@
   let assemblyFormat = "$buildable attr-dict";
 }
 
+// Test various mixings of region formatting.
+class FormatRegionBase<string suffix, string fmt>
+    : TEST_Op<"format_region_" # suffix # "_op"> {
+  let regions = (region AnyRegion:$region);
+  let assemblyFormat = fmt;
+}
+def FormatRegionAOp : FormatRegionBase<"a", [{
+  regions attr-dict
+}]>;
+def FormatRegionBOp : FormatRegionBase<"b", [{
+  $region attr-dict
+}]>;
+def FormatRegionCOp : FormatRegionBase<"c", [{
+  (`region` $region^)? attr-dict
+}]>;
+class FormatVariadicRegionBase<string suffix, string fmt>
+    : TEST_Op<"format_variadic_region_" # suffix # "_op"> {
+  let regions = (region VariadicRegion<AnyRegion>:$regions);
+  let assemblyFormat = fmt;
+}
+def FormatVariadicRegionAOp : FormatVariadicRegionBase<"a", [{
+  $regions attr-dict
+}]>;
+def FormatVariadicRegionBOp : FormatVariadicRegionBase<"b", [{
+  ($regions^ `found_regions`)? attr-dict
+}]>;
+
 // Test various mixings of result type formatting.
class FormatResultBase : TEST_Op<"format_result_" # suffix # "_op"> { @@ -1444,6 +1471,16 @@ }]; } +def FormatCustomDirectiveRegions : TEST_Op<"format_custom_directive_regions"> { + let regions = (region AnyRegion:$region, VariadicRegion:$regions); + let assemblyFormat = [{ + custom( + $region, $regions + ) + attr-dict + }]; +} + def FormatCustomDirectiveResults : TEST_Op<"format_custom_directive_results", [AttrSizedResultSegments]> { let results = (outs AnyType:$result, Optional:$optResult, diff --git a/mlir/test/mlir-tblgen/op-format-spec.td b/mlir/test/mlir-tblgen/op-format-spec.td --- a/mlir/test/mlir-tblgen/op-format-spec.td +++ b/mlir/test/mlir-tblgen/op-format-spec.td @@ -133,6 +133,28 @@ operands attr-dict }]>; +//===----------------------------------------------------------------------===// +// regions + +// CHECK: error: 'regions' directive creates overlap in format +def DirectiveRegionsInvalidA : TestFormat_Op<"regions_invalid_a", [{ + regions regions attr-dict +}]>; +// CHECK: error: 'regions' directive creates overlap in format +def DirectiveRegionsInvalidB : TestFormat_Op<"regions_invalid_b", [{ + $region regions attr-dict +}]> { + let regions = (region AnyRegion:$region); +} +// CHECK: error: 'regions' is only valid as a top-level directive +def DirectiveRegionsInvalidC : TestFormat_Op<"regions_invalid_c", [{ + type(regions) +}]>; +// CHECK-NOT: error: +def DirectiveRegionsValid : TestFormat_Op<"regions_valid", [{ + regions attr-dict +}]>; + //===----------------------------------------------------------------------===// // results @@ -249,7 +271,7 @@ def OptionalInvalidC : TestFormat_Op<"optional_invalid_c", [{ ($attr)? 
attr-dict }]>, Arguments<(ins OptionalAttr:$attr)>; -// CHECK: error: first element of an operand group must be an attribute, literal, or operand +// CHECK: error: first element of an operand group must be an attribute, literal, operand, or variadic region def OptionalInvalidD : TestFormat_Op<"optional_invalid_d", [{ (type($operand) $operand^)? attr-dict }]>, Arguments<(ins Optional:$operand)>; @@ -290,7 +312,7 @@ // Variables //===----------------------------------------------------------------------===// -// CHECK: error: expected variable to refer to an argument, result, or successor +// CHECK: error: expected variable to refer to an argument, region, result, or successor def VariableInvalidA : TestFormat_Op<"variable_invalid_a", [{ $unknown_arg attr-dict }]>; @@ -330,11 +352,35 @@ def VariableInvalidI : TestFormat_Op<"variable_invalid_i", [{ (`foo` $attr^)? `:` attr-dict }]>, Arguments<(ins OptionalAttr:$attr)>; -// CHECK-NOT: error: +// CHECK: error: region 'region' is already bound def VariableInvalidJ : TestFormat_Op<"variable_invalid_j", [{ + $region $region attr-dict +}]> { + let regions = (region AnyRegion:$region); +} +// CHECK: error: region 'region' is already bound +def VariableInvalidK : TestFormat_Op<"variable_invalid_K", [{ + regions $region attr-dict +}]> { + let regions = (region AnyRegion:$region); +} +// CHECK: error: regions can only be used at the top level +def VariableInvalidL : TestFormat_Op<"variable_invalid_l", [{ + type($region) +}]> { + let regions = (region AnyRegion:$region); +} +// CHECK: error: region #0, named 'region', not found +def VariableInvalidM : TestFormat_Op<"variable_invalid_m", [{ + attr-dict +}]> { + let regions = (region AnyRegion:$region); +} +// CHECK-NOT: error: +def VariableValidA : TestFormat_Op<"variable_valid_a", [{ $attr `:` attr-dict }]>, Arguments<(ins OptionalAttr:$attr)>; -def VariableInvalidK : TestFormat_Op<"variable_invalid_k", [{ +def VariableValidB : TestFormat_Op<"variable_valid_b", [{ (`foo` 
$attr^)? `:` attr-dict }]>, Arguments<(ins OptionalAttr:$attr)>; diff --git a/mlir/test/mlir-tblgen/op-format.mlir b/mlir/test/mlir-tblgen/op-format.mlir --- a/mlir/test/mlir-tblgen/op-format.mlir +++ b/mlir/test/mlir-tblgen/op-format.mlir @@ -40,6 +40,56 @@ // CHECK: test.format_buildable_type_op %[[I64]] %ignored = test.format_buildable_type_op %i64 +//===----------------------------------------------------------------------===// +// Format regions +//===----------------------------------------------------------------------===// + +// CHECK: test.format_region_a_op { +// CHECK-NEXT: test.return +test.format_region_a_op { + "test.return"() : () -> () +} + +// CHECK: test.format_region_b_op { +// CHECK-NEXT: test.return +test.format_region_b_op { + "test.return"() : () -> () +} + +// CHECK: test.format_region_c_op region { +// CHECK-NEXT: test.return +test.format_region_c_op region { + "test.return"() : () -> () +} +// CHECK: test.format_region_c_op +// CHECK-NOT: region { +test.format_region_c_op + +// CHECK: test.format_variadic_region_a_op { +// CHECK-NEXT: test.return +// CHECK-NEXT: }, { +// CHECK-NEXT: test.return +// CHECK-NEXT: } +test.format_variadic_region_a_op { + "test.return"() : () -> () +}, { + "test.return"() : () -> () +} +// CHECK: test.format_variadic_region_b_op { +// CHECK-NEXT: test.return +// CHECK-NEXT: }, { +// CHECK-NEXT: test.return +// CHECK-NEXT: } found_regions +test.format_variadic_region_b_op { + "test.return"() : () -> () +}, { + "test.return"() : () -> () +} found_regions +// CHECK: test.format_variadic_region_b_op +// CHECK-NOT: { +// CHECK-NOT: found_regions +test.format_variadic_region_b_op + //===----------------------------------------------------------------------===// // Format results //===----------------------------------------------------------------------===// @@ -147,6 +197,24 @@ // CHECK: test.format_custom_directive_operands_and_types %[[I64]] -> (%[[I64]]) : i64 -> (i64) 
test.format_custom_directive_operands_and_types %i64 -> (%i64) : i64 -> (i64) +// CHECK: test.format_custom_directive_regions { +// CHECK-NEXT: test.return +// CHECK-NEXT: } +test.format_custom_directive_regions { + "test.return"() : () -> () +} + +// CHECK: test.format_custom_directive_regions { +// CHECK-NEXT: test.return +// CHECK-NEXT: }, { +// CHECK-NEXT: test.return +// CHECK-NEXT: } +test.format_custom_directive_regions { + "test.return"() : () -> () +}, { + "test.return"() : () -> () +} + // CHECK: test.format_custom_directive_results : i64, i64 -> (i64) test.format_custom_directive_results : i64, i64 -> (i64) diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp --- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp @@ -48,6 +48,7 @@ CustomDirective, FunctionalTypeDirective, OperandsDirective, + RegionsDirective, ResultsDirective, SuccessorsDirective, TypeDirective, @@ -58,6 +59,7 @@ /// This element is an variable value. AttributeVariable, OperandVariable, + RegionVariable, ResultVariable, SuccessorVariable, @@ -119,6 +121,10 @@ using OperandVariable = VariableElement; +/// This class represents a variable that refers to a region. +using RegionVariable = + VariableElement; + /// This class represents a variable that refers to a result. using ResultVariable = VariableElement; @@ -133,8 +139,7 @@ namespace { /// This class implements single kind directives. -template -class DirectiveElement : public Element { +template class DirectiveElement : public Element { public: DirectiveElement() : Element(type){}; static bool classof(const Element *ele) { return ele->getKind() == type; } @@ -143,6 +148,10 @@ /// all of the operands of an operation. using OperandsDirective = DirectiveElement; +/// This class represents the `regions` directive. This directive represents +/// all of the regions of an operation. 
+using RegionsDirective = DirectiveElement<Element::RegionsDirective>;
+
 /// This class represents the `results` directive. This directive represents
 /// all of the results of an operation.
 using ResultsDirective = DirectiveElement<Element::ResultsDirective>;
@@ -358,6 +367,8 @@
   /// Generate the c++ to resolve the types of operands and results during
   /// parsing.
   void genParserTypeResolution(Operator &op, OpMethodBody &body);
+  /// Generate the c++ to resolve regions during parsing.
+  void genParserRegionResolution(Operator &op, OpMethodBody &body);
   /// Generate the c++ to resolve successors during parsing.
   void genParserSuccessorResolution(Operator &op, OpMethodBody &body);
   /// Generate the c++ to handling variadic segment size traits.
@@ -542,6 +553,37 @@
   {1}Types = {0}__{1}_functionType.getResults();
 )";
 
+/// The code snippet used to generate a parser call for a region list.
+///
+/// {0}: The name for the region list.
+const char *regionListParserCode = R"(
+  {
+    std::unique_ptr<::mlir::Region> region;
+    auto firstRegionResult = parser.parseOptionalRegion(region);
+    if (firstRegionResult.hasValue()) {
+      if (failed(*firstRegionResult))
+        return failure();
+      {0}Regions.emplace_back(std::move(region));
+
+      // Parse any trailing regions.
+      while (succeeded(parser.parseOptionalComma())) {
+        region = std::make_unique<::mlir::Region>();
+        if (parser.parseRegion(*region))
+          return failure();
+        {0}Regions.emplace_back(std::move(region));
+      }
+    }
+  }
+)";
+
+/// The code snippet used to generate a parser call for a region.
+///
+/// {0}: The name of the region.
+const char *regionParserCode = R"(
+  if (parser.parseRegion(*{0}Region))
+    return failure();
+)";
+
 /// The code snippet used to generate a parser call for a successor list.
 ///
 /// {0}: The name for the successor list.
@@ -659,6 +701,10 @@
     body << "  ::mlir::SmallVector<::mlir::OpAsmParser::OperandType, 4> "
             "allOperands;\n";
 
+  } else if (isa<RegionsDirective>(element)) {
+    body << "  ::llvm::SmallVector<std::unique_ptr<::mlir::Region>, 2> "
+            "fullRegions;\n";
+
   } else if (isa<SuccessorsDirective>(element)) {
     body << "  ::llvm::SmallVector<::mlir::Block *, 2> fullSuccessors;\n";
 
@@ -681,6 +727,20 @@
     body << llvm::formatv("  ::llvm::SMLoc {0}OperandsLoc;\n"
                           "  (void){0}OperandsLoc;\n",
                           name);
+
+  } else if (auto *region = dyn_cast<RegionVariable>(element)) {
+    StringRef name = region->getVar()->name;
+    if (region->getVar()->isVariadic()) {
+      body << llvm::formatv(
+          "  ::llvm::SmallVector<std::unique_ptr<::mlir::Region>, 2> "
+          "{0}Regions;\n",
+          name);
+    } else {
+      body << llvm::formatv("  std::unique_ptr<::mlir::Region> {0}Region = "
+                            "std::make_unique<::mlir::Region>();\n",
+                            name);
+    }
+
   } else if (auto *successor = dyn_cast<SuccessorVariable>(element)) {
     StringRef name = successor->getVar()->name;
     if (successor->getVar()->isVariadic()) {
@@ -726,6 +786,13 @@
     else
       body << formatv("{0}RawOperands[0]", name);
 
+  } else if (auto *region = dyn_cast<RegionVariable>(&param)) {
+    StringRef name = region->getVar()->name;
+    if (region->getVar()->isVariadic())
+      body << llvm::formatv("{0}Regions", name);
+    else
+      body << llvm::formatv("*{0}Region", name);
+
   } else if (auto *successor = dyn_cast<SuccessorVariable>(&param)) {
     StringRef name = successor->getVar()->name;
     if (successor->getVar()->isVariadic())
@@ -830,6 +897,9 @@
   } else if (auto *opVar = dyn_cast<OperandVariable>(firstElement)) {
     genElementParser(opVar, body, attrTypeCtx);
     body << "  if (!" << opVar->getVar()->name << "Operands.empty()) {\n";
+  } else if (auto *regionVar = dyn_cast<RegionVariable>(firstElement)) {
+    genElementParser(regionVar, body, attrTypeCtx);
+    body << "  if (!" << regionVar->getVar()->name << "Regions.empty()) {\n";
   }
 
   // If the anchor is a unit attribute, we don't need to print it. When
@@ -908,6 +978,12 @@
       body << llvm::formatv(optionalOperandParserCode, name);
     else
       body << formatv(operandParserCode, name);
+
+  } else if (auto *region = dyn_cast<RegionVariable>(element)) {
+    bool isVariadic = region->getVar()->isVariadic();
+    body << formatv(isVariadic ? regionListParserCode : regionParserCode,
+                    region->getVar()->name);
+
   } else if (auto *successor = dyn_cast<SuccessorVariable>(element)) {
     bool isVariadic = successor->getVar()->isVariadic();
     body << formatv(isVariadic ? successorListParserCode : successorParserCode,
@@ -926,8 +1002,13 @@
     body << "  ::llvm::SMLoc allOperandLoc = parser.getCurrentLocation();\n"
          << "  if (parser.parseOperandList(allOperands))\n"
          << "    return failure();\n";
+
+  } else if (isa<RegionsDirective>(element)) {
+    body << llvm::formatv(regionListParserCode, "full");
+
   } else if (isa<SuccessorsDirective>(element)) {
     body << llvm::formatv(successorListParserCode, "full");
+
   } else if (auto *dir = dyn_cast<TypeDirective>(element)) {
     ArgumentLengthKind lengthKind;
     StringRef listName = getTypeListName(dir->getOperand(), lengthKind);
@@ -971,6 +1052,7 @@
   // Generate the code to resolve the operand/result types and successors now
   // that they have been parsed.
   genParserTypeResolution(op, body);
+  genParserRegionResolution(op, body);
   genParserSuccessorResolution(op, body);
   genParserVariadicSegmentResolution(op, body);
 
@@ -1134,6 +1216,45 @@
   }
 }
 
+/// Returns true if the given operation has the SingleBlockImplicitTerminator
+/// trait. The trait is parameterized on the terminator operation (its name has
+/// the form `SingleBlockImplicitTerminator<"...">`), so it is matched by name
+/// rather than with an exact `Operator::getTrait` lookup.
+static bool hasImplicitTermTrait(Operator &op) {
+  return llvm::any_of(op.getTraits(), [](const tblgen::OpTrait &trait) {
+    const auto *native = dyn_cast<tblgen::NativeOpTrait>(&trait);
+    return native && native->getTrait().find(
+                         "SingleBlockImplicitTerminator<") != std::string::npos;
+  });
+}
+
+void OperationFormat::genParserRegionResolution(Operator &op,
+                                                OpMethodBody &body) {
+  // Check for the case where all regions were parsed.
+  bool hasAllRegions = llvm::any_of(
+      elements, [](auto &elt) { return isa<RegionsDirective>(elt.get()); });
+  if (hasAllRegions) {
+    body << "  result.addRegions(fullRegions);\n";
+
+    // Otherwise, handle each region individually.
+  } else {
+    for (const NamedRegion &region : op.getRegions()) {
+      if (region.isVariadic())
+        body << "  result.addRegions(" << region.name << "Regions);\n";
+      else
+        body << "  result.addRegion(std::move(" << region.name << "Region));\n";
+    }
+  }
+
+  // If this operation has the SingleBlockImplicitTerminator trait, ensure that
+  // each of the regions has terminators.
+  if (hasImplicitTermTrait(op)) {
+    body << "  for (auto &region : result.regions)\n"
+         << "    ensureTerminator(*region, parser.getBuilder(), "
+            "result.location);\n";
+  }
+}
+
 void OperationFormat::genParserSuccessorResolution(Operator &op,
                                                    OpMethodBody &body) {
   // Check for the case where all successors were parsed.
@@ -1187,6 +1308,23 @@
 //===----------------------------------------------------------------------===//
 // PrinterGen
 
+/// The code snippet used to print a region of an operation with the
+/// SingleBlockImplicitTerminator trait (`{{` escapes a literal formatv brace).
+///
+/// {0}: The name of the region.
+const char *regionSingleBlockImplicitTerminatorPrinterCode = R"(
+  {
+    bool printTerminator = true;
+    if (auto *term = {0}.empty() ? nullptr : {0}.begin()->getTerminator()) {{
+      printTerminator = !term->getMutableAttrDict().empty() ||
+                        term->getNumOperands() != 0 ||
+                        term->getNumResults() != 0;
+    }
+    p.printRegion({0}, /*printEntryBlockArgs=*/true,
+                  /*printBlockTerminators=*/printTerminator);
+  }
+)";
+
 /// Generate the printer for the 'attr-dict' directive.
 static void genAttrDictPrinter(OperationFormat &fmt, Operator &op,
                                OpMethodBody &body, bool withKeyword) {
@@ -1256,6 +1394,9 @@
   } else if (auto *operand = dyn_cast<OperandVariable>(&param)) {
     body << operand->getVar()->name << "()";
 
+  } else if (auto *region = dyn_cast<RegionVariable>(&param)) {
+    body << region->getVar()->name << "()";
+
   } else if (auto *successor = dyn_cast<SuccessorVariable>(&param)) {
     body << successor->getVar()->name << "()";
 
@@ -1278,6 +1419,23 @@
   body << ");\n";
 }
 
+/// Generate the printer for a region with the given variable name.
+static void genRegionPrinter(Twine regionName, OpMethodBody &body,
+                             Operator &op) {
+  if (hasImplicitTermTrait(op))
+    body << llvm::formatv(regionSingleBlockImplicitTerminatorPrinterCode,
+                          regionName);
+  else
+    body << "p.printRegion(" << regionName << ");\n";
+}
+static void genVariadicRegionPrinter(Twine regionListName, OpMethodBody &body,
+                                     Operator &op) {
+  body << "    ::llvm::interleaveComma(" << regionListName
+       << ", p, [&](::mlir::Region &region) {\n      ";
+  genRegionPrinter("region", body, op);
+  body << "    });\n";
+}
+
 /// Generate the C++ for an operand to a (*-)type directive.
 static OpMethodBody &genTypeOperandPrinter(Element *arg, OpMethodBody &body) {
   if (isa<OperandsDirective>(arg))
@@ -1315,6 +1473,11 @@
       body << "  if (" << var->name << "()) {\n";
     else if (var->isVariadic())
       body << "  if (!" << var->name << "().empty()) {\n";
+  } else if (auto *region = dyn_cast<RegionVariable>(anchor)) {
+    const NamedRegion *var = region->getVar();
+    // TODO: Add a check for optional here when ODS supports it.
+    body << "  if (!" << var->name << "().empty()) {\n";
+
   } else {
     body << "  if (getAttr(\"" << cast<AttributeVariable>(anchor)->getVar()->name
          << "\")) {\n";
@@ -1385,6 +1548,13 @@
     } else {
       body << "  p << " << operand->getVar()->name << "();\n";
     }
+  } else if (auto *region = dyn_cast<RegionVariable>(element)) {
+    const NamedRegion *var = region->getVar();
+    if (var->isVariadic()) {
+      genVariadicRegionPrinter(var->name + "()", body, op);
+    } else {
+      genRegionPrinter(var->name + "()", body, op);
+    }
   } else if (auto *successor = dyn_cast<SuccessorVariable>(element)) {
     const NamedSuccessor *var = successor->getVar();
     if (var->isVariadic())
@@ -1395,6 +1565,8 @@
     genCustomDirectivePrinter(dir, body);
   } else if (isa<OperandsDirective>(element)) {
     body << "  p << getOperation()->getOperands();\n";
+  } else if (isa<RegionsDirective>(element)) {
+    genVariadicRegionPrinter("getOperation()->getRegions()", body, op);
   } else if (isa<SuccessorsDirective>(element)) {
     body << "  ::llvm::interleaveComma(getOperation()->getSuccessors(), p);\n";
   } else if (auto *dir = dyn_cast<TypeDirective>(element)) {
@@ -1461,6 +1633,7 @@
   kw_custom,
   kw_functional_type,
   kw_operands,
+  kw_regions,
   kw_results,
   kw_successors,
   kw_type,
@@ -1664,6 +1837,7 @@
           .Case("custom", Token::kw_custom)
           .Case("functional-type", Token::kw_functional_type)
           .Case("operands", Token::kw_operands)
+          .Case("regions", Token::kw_regions)
           .Case("results", Token::kw_results)
           .Case("successors", Token::kw_successors)
           .Case("type", Token::kw_type)
@@ -1677,8 +1851,7 @@
 
 /// Function to find an element within the given range that has the same name as
 /// 'name'.
-template <typename RangeT>
-static auto findArg(RangeT &&range, StringRef name) {
+template <typename RangeT> static auto findArg(RangeT &&range, StringRef name) {
   auto it = llvm::find_if(range, [=](auto &arg) { return arg.name == name; });
   return it != range.end() ? &*it : nullptr;
 }
@@ -1721,6 +1894,9 @@
   verifyOperands(llvm::SMLoc loc,
                  llvm::StringMap<TypeResolutionInstance> &variableTyResolver);
 
+  /// Verify the state of operation regions within the format.
+  LogicalResult verifyRegions(llvm::SMLoc loc);
+
   /// Verify the state of operation results within the format.
   LogicalResult
   verifyResults(llvm::SMLoc loc,
@@ -1777,6 +1953,8 @@
                                     Token tok, bool isTopLevel);
   LogicalResult parseOperandsDirective(std::unique_ptr<Element> &element,
                                        llvm::SMLoc loc, bool isTopLevel);
+  LogicalResult parseRegionsDirective(std::unique_ptr<Element> &element,
+                                      llvm::SMLoc loc, bool isTopLevel);
   LogicalResult parseResultsDirective(std::unique_ptr<Element> &element,
                                       llvm::SMLoc loc, bool isTopLevel);
   LogicalResult parseSuccessorsDirective(std::unique_ptr<Element> &element,
@@ -1824,10 +2002,11 @@
   // The following are various bits of format state used for verification
   // during parsing.
   bool hasAllOperands = false, hasAttrDict = false;
-  bool hasAllSuccessors = false;
+  bool hasAllRegions = false, hasAllSuccessors = false;
   llvm::SmallBitVector seenOperandTypes, seenResultTypes;
-  llvm::DenseSet<const NamedTypeConstraint *> seenOperands;
   llvm::DenseSet<const NamedAttribute *> seenAttrs;
+  llvm::DenseSet<const NamedTypeConstraint *> seenOperands;
+  llvm::DenseSet<const NamedRegion *> seenRegions;
   llvm::DenseSet<const NamedSuccessor *> seenSuccessors;
   llvm::DenseSet<const NamedTypeConstraint *> optionalVariables;
 };
@@ -1869,7 +2048,7 @@
   if (failed(verifyAttributes(loc)) ||
       failed(verifyResults(loc, variableTyResolver)) ||
       failed(verifyOperands(loc, variableTyResolver)) ||
-      failed(verifySuccessors(loc)))
+      failed(verifyRegions(loc)) || failed(verifySuccessors(loc)))
     return failure();
 
   // Check to see if we are formatting all of the operands.
@@ -1993,6 +2172,24 @@
   return success();
 }
 
+LogicalResult FormatParser::verifyRegions(llvm::SMLoc loc) {
+  // Check that all of the regions are within the format.
+  if (hasAllRegions)
+    return success();
+
+  for (unsigned i = 0, e = op.getNumRegions(); i != e; ++i) {
+    const NamedRegion &region = op.getRegion(i);
+    if (!seenRegions.count(&region)) {
+      return emitErrorAndNote(loc,
+                              "region #" + Twine(i) + ", named '" +
+                                  region.name + "', not found",
+                              "suggest adding a '$" + region.name +
+                                  "' directive to the custom assembly format");
+    }
+  }
+  return success();
+}
+
 LogicalResult FormatParser::verifyResults(
     llvm::SMLoc loc,
     llvm::StringMap<TypeResolutionInstance> &variableTyResolver) {
@@ -2158,6 +2355,15 @@
     element = std::make_unique<OperandVariable>(operand);
     return success();
   }
+  /// Regions.
+  if (const NamedRegion *region = findArg(op.getRegions(), name)) {
+    if (!isTopLevel)
+      return emitError(loc, "regions can only be used at the top level");
+    if (hasAllRegions || !seenRegions.insert(region).second)
+      return emitError(loc, "region '" + name + "' is already bound");
+    element = std::make_unique<RegionVariable>(region);
+    return success();
+  }
   /// Results.
   if (const auto *result = findArg(op.getResults(), name)) {
     if (isTopLevel)
@@ -2174,8 +2380,8 @@
     element = std::make_unique<SuccessorVariable>(successor);
     return success();
   }
-  return emitError(
-      loc, "expected variable to refer to an argument, result, or successor");
+  return emitError(loc, "expected variable to refer to an argument, region, "
+                        "result, or successor");
 }
 
 LogicalResult FormatParser::parseDirective(std::unique_ptr<Element> &element,
@@ -2196,6 +2402,8 @@
     return parseFunctionalTypeDirective(element, dirTok, isTopLevel);
   case Token::kw_operands:
     return parseOperandsDirective(element, dirTok.getLoc(), isTopLevel);
+  case Token::kw_regions:
+    return parseRegionsDirective(element, dirTok.getLoc(), isTopLevel);
   case Token::kw_results:
     return parseResultsDirective(element, dirTok.getLoc(), isTopLevel);
   case Token::kw_successors:
@@ -2249,9 +2457,12 @@
   // optional fashion.
   Element *firstElement = &*elements.front();
   if (!isa<AttributeVariable>(firstElement) &&
-      !isa<LiteralElement>(firstElement) && !isa<OperandVariable>(firstElement))
+      !isa<LiteralElement>(firstElement) &&
+      !isa<OperandVariable>(firstElement) &&
+      !(isa<RegionVariable>(firstElement) &&
+        cast<RegionVariable>(firstElement)->getVar()->isVariadic()))
     return emitError(curLoc, "first element of an operand group must be an "
-                             "attribute, literal, or operand");
+                             "attribute, literal, operand, or variadic region");
 
   // After parsing all of the elements, ensure that all type directives refer
   // only to elements within the group.
@@ -2316,6 +2527,11 @@
         seenVariables.insert(ele->getVar());
         return success();
       })
+      .Case<RegionVariable>([&](RegionVariable *) {
+        // TODO: When ODS has proper support for marking "optional" regions, add
+        // a check here.
+        return success();
+      })
       // Literals, custom directives, and type directives may be used,
       // but they can't anchor the group.
       .Case<LiteralElement, CustomDirective, FunctionalTypeDirective,
@@ -2401,7 +2617,7 @@
     return failure();
 
   // Verify that the element can be placed within a custom directive.
-  if (!isa<TypeDirective, AttributeVariable, OperandVariable,
+  if (!isa<TypeDirective, AttributeVariable, OperandVariable, RegionVariable,
            SuccessorVariable>(parameters.back().get())) {
     return emitError(childLoc, "only variables and types may be used as "
                                "parameters to a custom directive");
@@ -2442,6 +2658,18 @@
   return success();
 }
 
+LogicalResult
+FormatParser::parseRegionsDirective(std::unique_ptr<Element> &element,
+                                    llvm::SMLoc loc, bool isTopLevel) {
+  if (!isTopLevel)
+    return emitError(loc, "'regions' is only valid as a top-level directive");
+  if (hasAllRegions || !seenRegions.empty())
+    return emitError(loc, "'regions' directive creates overlap in format");
+  hasAllRegions = true;
+  element = std::make_unique<RegionsDirective>();
+  return success();
+}
+
 LogicalResult
 FormatParser::parseResultsDirective(std::unique_ptr<Element> &element,
                                     llvm::SMLoc loc, bool isTopLevel) {