diff --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md --- a/mlir/docs/OpDefinitions.md +++ b/mlir/docs/OpDefinitions.md @@ -619,6 +619,43 @@ Attribute variables are printed with their respective value type, unless that value type is buildable. In those cases, the type of the attribute is elided. +#### Optional Groups + +In certain situations operations may have "optional" information, e.g. +attributes or an empty set of variadic operands. In these situations a section +of the assembly format can be marked as `optional` based on the presence of this +information. An optional group is defined by wrapping a set of elements within +`()` followed by a `?` and has the following requirements: + +* The first element of the group must either be a literal or an operand. + - This is because the first element must be optionally parsable. +* Exactly one argument variable within the group must be marked as the anchor + of the group. + - The anchor is the element whose presence controls whether the group + should be printed/parsed. + - An element is marked as the anchor by adding a trailing `^`. + - The first element is *not* required to be the anchor of the group. +* Literals, variables, and type directives are the only valid elements within + the group. + - Any attribute variable may be used, but only optional attributes can be + marked as the anchor. + - Only variadic, i.e. optional, operand arguments can be used. + - The operands to the type directive must be defined within the optional + group. + +An example of an operation with an optional group is `std.return`, which has a +variadic number of operands. + +``` +def ReturnOp : ... { + let arguments = (ins Variadic<AnyType>:$operands); + + // We only print the operands and types if there is a non-zero number + // of operands.
+ let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?"; +} +``` + #### Requirements The format specification has a certain set of requirements that must be adhered diff --git a/mlir/include/mlir/Dialect/StandardOps/Ops.td b/mlir/include/mlir/Dialect/StandardOps/Ops.td --- a/mlir/include/mlir/Dialect/StandardOps/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/Ops.td @@ -1059,6 +1059,8 @@ let builders = [OpBuilder< "Builder *b, OperationState &result", [{ build(b, result, llvm::None); }] >]; + + let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?"; } def SelectOp : Std_Op<"select", [NoSideEffect, SameOperandsAndResultShape, diff --git a/mlir/lib/Dialect/StandardOps/Ops.cpp b/mlir/lib/Dialect/StandardOps/Ops.cpp --- a/mlir/lib/Dialect/StandardOps/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/Ops.cpp @@ -1735,21 +1735,6 @@ // ReturnOp //===----------------------------------------------------------------------===// -static ParseResult parseReturnOp(OpAsmParser &parser, OperationState &result) { - SmallVector opInfo; - SmallVector types; - llvm::SMLoc loc = parser.getCurrentLocation(); - return failure(parser.parseOperandList(opInfo) || - (!opInfo.empty() && parser.parseColonTypeList(types)) || - parser.resolveOperands(opInfo, types, loc, result.operands)); -} - -static void print(OpAsmPrinter &p, ReturnOp op) { - p << "return"; - if (op.getNumOperands() != 0) - p << ' ' << op.getOperands() << " : " << op.getOperandTypes(); -} - static LogicalResult verify(ReturnOp op) { auto function = cast(op.getParentOp()); diff --git a/mlir/test/mlir-tblgen/op-format-spec.td b/mlir/test/mlir-tblgen/op-format-spec.td --- a/mlir/test/mlir-tblgen/op-format-spec.td +++ b/mlir/test/mlir-tblgen/op-format-spec.td @@ -46,7 +46,7 @@ def DirectiveFunctionalTypeInvalidB : TestFormat_Op<"functype_invalid_b", [{ functional-type }]>; -// CHECK: error: expected directive, literal, or variable +// CHECK: error: expected directive, literal, variable, or optional group def 
DirectiveFunctionalTypeInvalidC : TestFormat_Op<"functype_invalid_c", [{ functional-type( }]>; @@ -54,7 +54,7 @@ def DirectiveFunctionalTypeInvalidD : TestFormat_Op<"functype_invalid_d", [{ functional-type(operands }]>; -// CHECK: error: expected directive, literal, or variable +// CHECK: error: expected directive, literal, variable, or optional group def DirectiveFunctionalTypeInvalidE : TestFormat_Op<"functype_invalid_e", [{ functional-type(operands, }]>; @@ -98,7 +98,7 @@ def DirectiveTypeInvalidA : TestFormat_Op<"type_invalid_a", [{ type }]>; -// CHECK: error: expected directive, literal, or variable +// CHECK: error: expected directive, literal, variable, or optional group def DirectiveTypeInvalidB : TestFormat_Op<"type_invalid_b", [{ type( }]>; @@ -165,7 +165,7 @@ `1` }]>; // CHECK: error: unexpected end of file in literal -// CHECK: error: expected directive, literal, or variable +// CHECK: error: expected directive, literal, variable, or optional group def LiteralInvalidB : TestFormat_Op<"literal_invalid_b", [{ ` }]>; @@ -175,6 +175,55 @@ attr-dict }]>; +//===----------------------------------------------------------------------===// +// Optional Groups +//===----------------------------------------------------------------------===// + +// CHECK: error: optional groups can only be used as top-level elements +def OptionalInvalidA : TestFormat_Op<"optional_invalid_a", [{ + type(($attr^)?) attr-dict +}]>, Arguments<(ins OptionalAttr:$attr)>; +// CHECK: error: expected directive, literal, variable, or optional group +def OptionalInvalidB : TestFormat_Op<"optional_invalid_b", [{ + () attr-dict +}]>, Arguments<(ins OptionalAttr:$attr)>; +// CHECK: error: optional group specified no anchor element +def OptionalInvalidC : TestFormat_Op<"optional_invalid_c", [{ + ($attr)? 
attr-dict +}]>, Arguments<(ins OptionalAttr:$attr)>; +// CHECK: error: first element of an operand group must be a literal or operand +def OptionalInvalidD : TestFormat_Op<"optional_invalid_d", [{ + ($attr^)? attr-dict +}]>, Arguments<(ins OptionalAttr:$attr)>; +// CHECK: error: type directive can only refer to variables within the optional group +def OptionalInvalidE : TestFormat_Op<"optional_invalid_e", [{ + (`,` $attr^ type(operands))? attr-dict +}]>, Arguments<(ins OptionalAttr:$attr)>; +// CHECK: error: only one element can be marked as the anchor of an optional group +def OptionalInvalidF : TestFormat_Op<"optional_invalid_f", [{ + ($attr^ $attr2^) attr-dict +}]>, Arguments<(ins OptionalAttr:$attr, OptionalAttr:$attr2)>; +// CHECK: error: only optional attributes can be used to anchor an optional group +def OptionalInvalidG : TestFormat_Op<"optional_invalid_g", [{ + ($attr^) attr-dict +}]>, Arguments<(ins I64Attr:$attr)>; +// CHECK: error: only variadic operands can be used within an optional group +def OptionalInvalidH : TestFormat_Op<"optional_invalid_h", [{ + ($arg^) attr-dict +}]>, Arguments<(ins I64:$arg)>; +// CHECK: error: only variables can be used to anchor an optional group +def OptionalInvalidI : TestFormat_Op<"optional_invalid_i", [{ + ($arg type($arg)^) attr-dict +}]>, Arguments<(ins Variadic:$arg)>; +// CHECK: error: only literals, types, and variables can be used within an optional group +def OptionalInvalidJ : TestFormat_Op<"optional_invalid_j", [{ + (attr-dict) +}]>; +// CHECK: error: expected '?' 
after optional group +def OptionalInvalidK : TestFormat_Op<"optional_invalid_k", [{ + ($arg^) +}]>, Arguments<(ins Variadic:$arg)>; + //===----------------------------------------------------------------------===// // Variables //===----------------------------------------------------------------------===// diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp --- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp @@ -58,6 +58,9 @@ AttributeVariable, OperandVariable, ResultVariable, + + /// This element is an optional element. + Optional, }; Element(Kind kind) : kind(kind) {} virtual ~Element() = default; @@ -164,7 +167,7 @@ class LiteralElement : public Element { public: LiteralElement(StringRef literal) - : Element{Kind::Literal}, literal(literal){}; + : Element{Kind::Literal}, literal(literal) {} static bool classof(const Element *element) { return element->getKind() == Kind::Literal; } @@ -203,6 +206,36 @@ }); } +//===----------------------------------------------------------------------===// +// OptionalElement + +namespace { +/// This class represents a group of elements that are optionally emitted based +/// upon an optional variable of the operation. +class OptionalElement : public Element { +public: + OptionalElement(std::vector> &&elements, + unsigned anchor) + : Element{Kind::Optional}, elements(std::move(elements)), anchor(anchor) { + } + static bool classof(const Element *element) { + return element->getKind() == Kind::Optional; + } + + /// Return the nested elements of this grouping. + auto getElements() const { return llvm::make_pointee_range(elements); } + + /// Return the anchor of this optional group. + Element *getAnchor() const { return elements[anchor].get(); } + +private: + /// The child elements of this optional. + std::vector> elements; + /// The index of the element that acts as the anchor for the optional group. 
+ unsigned anchor; +}; +} // end anonymous namespace + //===----------------------------------------------------------------------===// // OperationFormat //===----------------------------------------------------------------------===// @@ -327,32 +360,26 @@ const char *const variadicOperandParserCode = R"( llvm::SMLoc {0}OperandsLoc = parser.getCurrentLocation(); (void){0}OperandsLoc; - SmallVector {0}Operands; if (parser.parseOperandList({0}Operands)) return failure(); )"; const char *const operandParserCode = R"( llvm::SMLoc {0}OperandsLoc = parser.getCurrentLocation(); (void){0}OperandsLoc; - OpAsmParser::OperandType {0}RawOperands[1]; if (parser.parseOperand({0}RawOperands[0])) return failure(); - ArrayRef {0}Operands({0}RawOperands); )"; /// The code snippet used to generate a parser call for a type list. /// /// {0}: The name for the type list. const char *const variadicTypeParserCode = R"( - SmallVector {0}Types; if (parser.parseTypeList({0}Types)) return failure(); )"; const char *const typeParserCode = R"( - Type {0}RawTypes[1] = {{nullptr}; if (parser.parseType({0}RawTypes[0])) return failure(); - ArrayRef {0}Types({0}RawTypes); )"; /// The code snippet used to generate a parser call for a functional type. @@ -363,8 +390,8 @@ FunctionType {0}__{1}_functionType; if (parser.parseType({0}__{1}_functionType)) return failure(); - ArrayRef {0}Types = {0}__{1}_functionType.getInputs(); - ArrayRef {1}Types = {0}__{1}_functionType.getResults(); + {0}Types = {0}__{1}_functionType.getInputs(); + {1}Types = {0}__{1}_functionType.getResults(); )"; /// Get the name used for the type list for the given type directive operand. @@ -388,25 +415,144 @@ /// Generate the parser for a literal value. static void genLiteralParser(StringRef value, OpMethodBody &body) { - body << " if (parser.parse"; - // Handle the case of a keyword/identifier. 
if (value.front() == '_' || isalpha(value.front())) { body << "Keyword(\"" << value << "\")"; + return; + } + body << (StringRef)llvm::StringSwitch(value) + .Case("->", "Arrow()") + .Case(":", "Colon()") + .Case(",", "Comma()") + .Case("=", "Equal()") + .Case("<", "Less()") + .Case(">", "Greater()") + .Case("(", "LParen()") + .Case(")", "RParen()") + .Case("[", "LSquare()") + .Case("]", "RSquare()"); +} + +/// Generate the storage code required for parsing the given element. +static void genElementParserStorage(Element *element, OpMethodBody &body) { + if (auto *optional = dyn_cast(element)) { + for (auto &childElement : optional->getElements()) + genElementParserStorage(&childElement, body); + } else if (auto *operand = dyn_cast(element)) { + StringRef name = operand->getVar()->name; + if (operand->getVar()->isVariadic()) + body << " SmallVector " << name + << "Operands;\n"; + else + body << " OpAsmParser::OperandType " << name << "RawOperands[1];\n" + << " ArrayRef " << name << "Operands(" + << name << "RawOperands);"; + } else if (auto *dir = dyn_cast(element)) { + bool variadic = false; + StringRef name = getTypeListName(dir->getOperand(), variadic); + if (variadic) + body << " SmallVector " << name << "Types;\n"; + else + body << llvm::formatv(" Type {0}RawTypes[1];\n", name) + << llvm::formatv(" ArrayRef {0}Types({0}RawTypes);\n", name); + } else if (auto *dir = dyn_cast(element)) { + bool ignored = false; + body << " ArrayRef " << getTypeListName(dir->getInputs(), ignored) + << "Types;\n"; + body << " ArrayRef " << getTypeListName(dir->getResults(), ignored) + << "Types;\n"; + } +} + +/// Generate the parser for a single format element. +static void genElementParser(Element *element, OpMethodBody &body, + FmtContext &attrTypeCtx) { + /// Optional Group. + if (auto *optional = dyn_cast(element)) { + auto elements = optional->getElements(); + + // Generate a special optional parser for the first element to gate the + // parsing of the rest of the elements. 
+ if (auto *literal = dyn_cast(&*elements.begin())) { + body << " if (succeeded(parser.parseOptional"; + genLiteralParser(literal->getLiteral(), body); + body << ")) {\n"; + } else if (auto *opVar = dyn_cast(&*elements.begin())) { + genElementParser(opVar, body, attrTypeCtx); + body << " if (!" << opVar->getVar()->name << "Operands.empty()) {\n"; + } + + // Generate the rest of the elements normally. + for (auto &childElement : llvm::drop_begin(elements, 1)) + genElementParser(&childElement, body, attrTypeCtx); + body << " }\n"; + + /// Literals. + } else if (LiteralElement *literal = dyn_cast(element)) { + body << " if (parser.parse"; + genLiteralParser(literal->getLiteral(), body); + body << ")\n return failure();\n"; + + /// Arguments. + } else if (auto *attr = dyn_cast(element)) { + const NamedAttribute *var = attr->getVar(); + + // Check to see if we can parse this as an enum attribute. + if (canFormatEnumAttr(var)) { + const EnumAttr &enumAttr = cast(var->attr); + + // Generate the code for building an attribute for this enum. + std::string attrBuilderStr; + { + llvm::raw_string_ostream os(attrBuilderStr); + os << tgfmt(enumAttr.getConstBuilderTemplate(), &attrTypeCtx, + "attrOptional.getValue()"); + } + + body << formatv(enumAttrParserCode, var->name, enumAttr.getCppNamespace(), + enumAttr.getStringToSymbolFnName(), attrBuilderStr); + return; + } + + // If this attribute has a buildable type, use that when parsing the + // attribute. + std::string attrTypeStr; + if (Optional attrType = var->attr.getValueType()) { + if (Optional typeBuilder = attrType->getBuilderCall()) { + llvm::raw_string_ostream os(attrTypeStr); + os << ", " << tgfmt(*typeBuilder, &attrTypeCtx); + } + } + + body << formatv(attrParserCode, var->attr.getStorageType(), var->name, + attrTypeStr); + } else if (auto *operand = dyn_cast(element)) { + bool isVariadic = operand->getVar()->isVariadic(); + body << formatv(isVariadic ? 
variadicOperandParserCode : operandParserCode, + operand->getVar()->name); + + /// Directives. + } else if (isa(element)) { + body << " if (parser.parseOptionalAttrDict(result.attributes))\n" + << " return failure();\n"; + } else if (isa(element)) { + body << " llvm::SMLoc allOperandLoc = parser.getCurrentLocation();\n" + << " SmallVector allOperands;\n" + << " if (parser.parseOperandList(allOperands))\n" + << " return failure();\n"; + } else if (auto *dir = dyn_cast(element)) { + bool isVariadic = false; + StringRef listName = getTypeListName(dir->getOperand(), isVariadic); + body << formatv(isVariadic ? variadicTypeParserCode : typeParserCode, + listName); + } else if (auto *dir = dyn_cast(element)) { + bool ignored = false; + body << formatv(functionalTypeParserCode, + getTypeListName(dir->getInputs(), ignored), + getTypeListName(dir->getResults(), ignored)); } else { - body << (StringRef)llvm::StringSwitch(value) - .Case("->", "Arrow()") - .Case(":", "Colon()") - .Case(",", "Comma()") - .Case("=", "Equal()") - .Case("<", "Less()") - .Case(">", "Greater()") - .Case("(", "LParen()") - .Case(")", "RParen()") - .Case("[", "LSquare()") - .Case("]", "RSquare()"); + llvm_unreachable("unknown format element"); } - body << ")\n return failure();\n"; } void OperationFormat::genParser(Operator &op, OpClass &opClass) { @@ -415,79 +561,19 @@ OpMethod::MP_Static); auto &body = method.body(); + // Generate variables to store the operands and type within the format. This + // allows for referencing these variables in the presence of optional + // groupings. + for (auto &element : elements) + genElementParserStorage(&*element, body); + // A format context used when parsing attributes with buildable types. FmtContext attrTypeCtx; attrTypeCtx.withBuilder("parser.getBuilder()"); // Generate parsers for each of the elements. - for (auto &element : elements) { - /// Literals. 
- if (LiteralElement *literal = dyn_cast(element.get())) { - genLiteralParser(literal->getLiteral(), body); - - /// Arguments. - } else if (auto *attr = dyn_cast(element.get())) { - const NamedAttribute *var = attr->getVar(); - - // Check to see if we can parse this as an enum attribute. - if (canFormatEnumAttr(var)) { - const EnumAttr &enumAttr = cast(var->attr); - - // Generate the code for building an attribute for this enum. - std::string attrBuilderStr; - { - llvm::raw_string_ostream os(attrBuilderStr); - os << tgfmt(enumAttr.getConstBuilderTemplate(), &attrTypeCtx, - "attrOptional.getValue()"); - } - - body << formatv(enumAttrParserCode, var->name, - enumAttr.getCppNamespace(), - enumAttr.getStringToSymbolFnName(), attrBuilderStr); - continue; - } - - // If this attribute has a buildable type, use that when parsing the - // attribute. - std::string attrTypeStr; - if (Optional attrType = var->attr.getValueType()) { - if (Optional typeBuilder = attrType->getBuilderCall()) { - llvm::raw_string_ostream os(attrTypeStr); - os << ", " << tgfmt(*typeBuilder, &attrTypeCtx); - } - } - - body << formatv(attrParserCode, var->attr.getStorageType(), var->name, - attrTypeStr); - } else if (auto *operand = dyn_cast(element.get())) { - bool isVariadic = operand->getVar()->isVariadic(); - body << formatv(isVariadic ? variadicOperandParserCode - : operandParserCode, - operand->getVar()->name); - - /// Directives. - } else if (isa(element.get())) { - body << " if (parser.parseOptionalAttrDict(result.attributes))\n" - << " return failure();\n"; - } else if (isa(element.get())) { - body << " llvm::SMLoc allOperandLoc = parser.getCurrentLocation();\n" - << " SmallVector allOperands;\n" - << " if (parser.parseOperandList(allOperands))\n" - << " return failure();\n"; - } else if (auto *dir = dyn_cast(element.get())) { - bool isVariadic = false; - StringRef listName = getTypeListName(dir->getOperand(), isVariadic); - body << formatv(isVariadic ? 
variadicTypeParserCode : typeParserCode, - listName); - } else if (auto *dir = dyn_cast(element.get())) { - bool ignored = false; - body << formatv(functionalTypeParserCode, - getTypeListName(dir->getInputs(), ignored), - getTypeListName(dir->getResults(), ignored)); - } else { - llvm_unreachable("unknown format element"); - } - } + for (auto &element : elements) + genElementParser(element.get(), body, attrTypeCtx); // Generate the code to resolve the operand and result types now that they // have been parsed. @@ -676,7 +762,7 @@ lastWasPunctuation = !(value.front() == '_' || isalpha(value.front())); } -/// Generate the c++ for an operand to a (*-)type directive. +/// Generate the C++ for an operand to a (*-)type directive. static OpMethodBody &genTypeOperandPrinter(Element *arg, OpMethodBody &body) { if (isa(arg)) return body << "getOperation()->getOperandTypes()"; @@ -689,6 +775,79 @@ return body << "ArrayRef(" << var->name << "().getType())"; } +/// Generate the code for printing the given element. +static void genElementPrinter(Element *element, OpMethodBody &body, + OperationFormat &fmt, bool &shouldEmitSpace, + bool &lastWasPunctuation) { + if (LiteralElement *literal = dyn_cast(element)) + return genLiteralPrinter(literal->getLiteral(), body, shouldEmitSpace, + lastWasPunctuation); + + // Emit an optional group. + if (OptionalElement *optional = dyn_cast(element)) { + // Emit the check for the presence of the anchor element. + Element *anchor = optional->getAnchor(); + if (AttributeVariable *attrVar = dyn_cast(anchor)) + body << " if (getAttr(\"" << attrVar->getVar()->name << "\")) {\n"; + else + body << " if (!" << cast(anchor)->getVar()->name + << "().empty()) {\n"; + + // Emit each of the elements. + for (Element &childElement : optional->getElements()) + genElementPrinter(&childElement, body, fmt, shouldEmitSpace, + lastWasPunctuation); + body << " }\n"; + return; + } + + // Emit the attribute dictionary. 
+ if (isa(element)) { + genAttrDictPrinter(fmt, body); + lastWasPunctuation = false; + return; + } + + // Optionally insert a space before the next element. The AttrDict printer + // already adds a space as necessary. + if (shouldEmitSpace || !lastWasPunctuation) + body << " p << \" \";\n"; + lastWasPunctuation = false; + shouldEmitSpace = true; + + if (auto *attr = dyn_cast(element)) { + const NamedAttribute *var = attr->getVar(); + + // If we are formatting as a enum, symbolize the attribute as a string. + if (canFormatEnumAttr(var)) { + const EnumAttr &enumAttr = cast(var->attr); + body << " p << \"\\\"\" << " << enumAttr.getSymbolToStringFnName() << "(" + << var->name << "()) << \"\\\"\";\n"; + return; + } + + // Elide the attribute type if it is buildable. + Optional attrType = var->attr.getValueType(); + if (attrType && attrType->getBuilderCall()) + body << " p.printAttributeWithoutType(" << var->name << "Attr());\n"; + else + body << " p.printAttribute(" << var->name << "Attr());\n"; + } else if (auto *operand = dyn_cast(element)) { + body << " p << " << operand->getVar()->name << "();\n"; + } else if (isa(element)) { + body << " p << getOperation()->getOperands();\n"; + } else if (auto *dir = dyn_cast(element)) { + body << " p << "; + genTypeOperandPrinter(dir->getOperand(), body) << ";\n"; + } else if (auto *dir = dyn_cast(element)) { + body << " p.printFunctionalType("; + genTypeOperandPrinter(dir->getInputs(), body) << ", "; + genTypeOperandPrinter(dir->getResults(), body) << ");\n"; + } else { + llvm_unreachable("unknown format element"); + } +} + void OperationFormat::genPrinter(Operator &op, OpClass &opClass) { auto &method = opClass.newMethod("void", "print", "OpAsmPrinter &p"); auto &body = method.body(); @@ -706,60 +865,9 @@ // Flags for if we should emit a space, and if the last element was // punctuation. bool shouldEmitSpace = true, lastWasPunctuation = false; - for (auto &element : elements) { - // Emit a literal element. 
- if (LiteralElement *literal = dyn_cast(element.get())) { - genLiteralPrinter(literal->getLiteral(), body, shouldEmitSpace, - lastWasPunctuation); - continue; - } - - // Emit the attribute dictionary. - if (isa(element.get())) { - genAttrDictPrinter(*this, body); - lastWasPunctuation = false; - continue; - } - - // Optionally insert a space before the next element. The AttrDict printer - // already adds a space as necessary. - if (shouldEmitSpace || !lastWasPunctuation) - body << " p << \" \";\n"; - lastWasPunctuation = false; - shouldEmitSpace = true; - - if (auto *attr = dyn_cast(element.get())) { - const NamedAttribute *var = attr->getVar(); - - // If we are formatting as a enum, symbolize the attribute as a string. - if (canFormatEnumAttr(var)) { - const EnumAttr &enumAttr = cast(var->attr); - body << " p << \"\\\"\" << " << enumAttr.getSymbolToStringFnName() - << "(" << var->name << "()) << \"\\\"\";\n"; - continue; - } - - // Elide the attribute type if it is buildable. - Optional attrType = var->attr.getValueType(); - if (attrType && attrType->getBuilderCall()) - body << " p.printAttributeWithoutType(" << var->name << "Attr());\n"; - else - body << " p.printAttribute(" << var->name << "Attr());\n"; - } else if (auto *operand = dyn_cast(element.get())) { - body << " p << " << operand->getVar()->name << "();\n"; - } else if (isa(element.get())) { - body << " p << getOperation()->getOperands();\n"; - } else if (auto *dir = dyn_cast(element.get())) { - body << " p << "; - genTypeOperandPrinter(dir->getOperand(), body) << ";\n"; - } else if (auto *dir = dyn_cast(element.get())) { - body << " p.printFunctionalType("; - genTypeOperandPrinter(dir->getInputs(), body) << ", "; - genTypeOperandPrinter(dir->getResults(), body) << ");\n"; - } else { - llvm_unreachable("unknown format element"); - } - } + for (auto &element : elements) + genElementPrinter(element.get(), body, *this, shouldEmitSpace, + lastWasPunctuation); } 
//===----------------------------------------------------------------------===// @@ -778,8 +886,10 @@ // Tokens with no info. l_paren, r_paren, + caret, comma, equal, + question, // Keywords. keyword_start, @@ -908,10 +1018,14 @@ return formToken(Token::eof, tokStart); // Lex punctuation. + case '^': + return formToken(Token::caret, tokStart); case ',': return formToken(Token::comma, tokStart); case '=': return formToken(Token::equal, tokStart); + case '?': + return formToken(Token::question, tokStart); case '(': return formToken(Token::l_paren, tokStart); case ')': @@ -1026,6 +1140,12 @@ LogicalResult parseDirective(std::unique_ptr &element, bool isTopLevel); LogicalResult parseLiteral(std::unique_ptr &element); + LogicalResult parseOptional(std::unique_ptr &element, + bool isTopLevel); + LogicalResult parseOptionalChildElement( + std::vector> &childElements, + SmallPtrSetImpl &seenVariables, + Optional &anchorIdx); /// Parse the various different directives. LogicalResult parseAttrDictDirective(std::unique_ptr &element, @@ -1077,6 +1197,7 @@ llvm::SmallBitVector seenOperandTypes, seenResultTypes; llvm::DenseSet seenOperands; llvm::DenseSet seenAttrs; + llvm::DenseSet optionalVariables; }; } // end anonymous namespace @@ -1236,11 +1357,14 @@ // Literals. if (curToken.getKind() == Token::literal) return parseLiteral(element); + // Optionals. + if (curToken.getKind() == Token::l_paren) + return parseOptional(element, isTopLevel); // Variables. 
if (curToken.getKind() == Token::variable) return parseVariable(element, isTopLevel); return emitError(curToken.getLoc(), - "expected directive, literal, or variable"); + "expected directive, literal, variable, or optional group"); } LogicalResult FormatParser::parseVariable(std::unique_ptr &element, @@ -1314,6 +1438,115 @@ return success(); } +LogicalResult FormatParser::parseOptional(std::unique_ptr &element, + bool isTopLevel) { + llvm::SMLoc curLoc = curToken.getLoc(); + if (!isTopLevel) + return emitError(curLoc, "optional groups can only be used as top-level " + "elements"); + consumeToken(); + + // Parse the child elements for this optional group. + std::vector> elements; + SmallPtrSet seenVariables; + Optional anchorIdx; + do { + if (failed(parseOptionalChildElement(elements, seenVariables, anchorIdx))) + return failure(); + } while (curToken.getKind() != Token::r_paren); + consumeToken(); + if (failed(parseToken(Token::question, "expected '?' after optional group"))) + return failure(); + + // The optional group is required to have an anchor. + if (!anchorIdx) + return emitError(curLoc, "optional group specified no anchor element"); + + // The first element of the group must be one that can be parsed/printed in an + // optional fashion. + if (!isa(&*elements.front()) && + !isa(&*elements.front())) + return emitError(curLoc, "first element of an operand group must be a " + "literal or operand"); + + // After parsing all of the elements, ensure that all type directives refer + // only to elements within the group. + auto checkTypeOperand = [&](Element *typeEle) { + auto *opVar = dyn_cast(typeEle); + const NamedTypeConstraint *var = opVar ? 
opVar->getVar() : nullptr; + if (!seenVariables.count(var)) + return emitError(curLoc, "type directive can only refer to variables " + "within the optional group"); + return success(); + }; + for (auto &ele : elements) { + if (auto *typeEle = dyn_cast(ele.get())) { + if (failed(checkTypeOperand(typeEle->getOperand()))) + return failure(); + } else if (auto *typeEle = dyn_cast(ele.get())) { + if (failed(checkTypeOperand(typeEle->getInputs())) || + failed(checkTypeOperand(typeEle->getResults()))) + return failure(); + } + } + + optionalVariables.insert(seenVariables.begin(), seenVariables.end()); + element = std::make_unique(std::move(elements), *anchorIdx); + return success(); +} + +LogicalResult FormatParser::parseOptionalChildElement( + std::vector> &childElements, + SmallPtrSetImpl &seenVariables, + Optional &anchorIdx) { + llvm::SMLoc childLoc = curToken.getLoc(); + childElements.push_back({}); + if (failed(parseElement(childElements.back(), /*isTopLevel=*/true))) + return failure(); + + // Check to see if this element is the anchor of the optional group. + bool isAnchor = curToken.getKind() == Token::caret; + if (isAnchor) { + if (anchorIdx) + return emitError(childLoc, "only one element can be marked as the anchor " + "of an optional group"); + anchorIdx = childElements.size() - 1; + consumeToken(); + } + + return TypeSwitch(childElements.back().get()) + // All attributes can be within the optional group, but only optional + // attributes can be the anchor. + .Case([&](AttributeVariable *attrEle) { + if (isAnchor && !attrEle->getVar()->attr.isOptional()) + return emitError(childLoc, "only optional attributes can be used to " + "anchor an optional group"); + return success(); + }) + // Only optional-like(i.e. variadic) operands can be within an optional + // group. 
+ .Case([&](auto *ele) { + if (!ele->getVar()->isVariadic()) + return emitError(childLoc, "only variadic operands can be used within" + " an optional group"); + seenVariables.insert(ele->getVar()); + return success(); + }) + // Literals and type directives may be used, but they can't anchor the + // group. + .Case( + [&](auto *) { + if (isAnchor) + return emitError(childLoc, "only variables can be used to anchor " + "an optional group"); + return success(); + }) + .Default([&](auto *) { + return emitError(childLoc, "only literals, types, and variables can be " + "used within an optional group"); + }); +} + LogicalResult FormatParser::parseAttrDictDirective(std::unique_ptr &element, llvm::SMLoc loc, bool isTopLevel) { @@ -1344,8 +1577,6 @@ failed(parseTypeDirectiveOperand(results)) || failed(parseToken(Token::r_paren, "expected ')' after argument list"))) return failure(); - - // Get the proper directive kind and create it. element = std::make_unique(std::move(inputs), std::move(results)); return success();