diff --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md --- a/mlir/docs/OpDefinitions.md +++ b/mlir/docs/OpDefinitions.md @@ -620,6 +620,15 @@ - The constraints on `inputs` and `results` are the same as the `input` of the `type` directive. +* `oilist` ( \`keyword\` elements | \`otherKeyword\` elements ...) + + - Represents an optional order-independent list of clauses. Each clause + has a keyword and corresponding assembly format. + - Each clause can appear 0 or 1 times (in any order). + - Only literals, types, variables, and custom directives can be used + within an oilist element. + - All the variables must be optional or variadic. + * `operands` - Represents all of the operands of an operation. diff --git a/mlir/test/IR/traits.mlir b/mlir/test/IR/traits.mlir --- a/mlir/test/IR/traits.mlir +++ b/mlir/test/IR/traits.mlir @@ -488,6 +488,75 @@ // ----- +// CHECK-LABEL: @succeededOilistTrivial +func @succeededOilistTrivial() { + // CHECK: test.oilist_with_keywords_only keyword + test.oilist_with_keywords_only keyword + // CHECK: test.oilist_with_keywords_only otherKeyword + test.oilist_with_keywords_only otherKeyword + // CHECK: test.oilist_with_keywords_only keyword otherKeyword + test.oilist_with_keywords_only keyword otherKeyword + // CHECK: test.oilist_with_keywords_only keyword otherKeyword + test.oilist_with_keywords_only otherKeyword keyword + return +} + +// ----- + +// CHECK-LABEL: @succeededOilistSimple +func @succeededOilistSimple(%arg0 : i32, %arg1 : i32, %arg2 : i32) { + // CHECK: test.oilist_with_simple_args keyword %{{.*}} : i32 + test.oilist_with_simple_args keyword %arg0 : i32 + // CHECK: test.oilist_with_simple_args otherKeyword %{{.*}} : i32 + test.oilist_with_simple_args otherKeyword %arg0 : i32 + // CHECK: test.oilist_with_simple_args thirdKeyword %{{.*}} : i32 + test.oilist_with_simple_args thirdKeyword %arg0 : i32 + + // CHECK: test.oilist_with_simple_args keyword %{{.*}} : i32 otherKeyword %{{.*}} : i32 + test.oilist_with_simple_args keyword %arg0 : 
i32 otherKeyword %arg1 : i32 + // CHECK: test.oilist_with_simple_args keyword %{{.*}} : i32 thirdKeyword %{{.*}} : i32 + test.oilist_with_simple_args keyword %arg0 : i32 thirdKeyword %arg1 : i32 + // CHECK: test.oilist_with_simple_args otherKeyword %{{.*}} : i32 thirdKeyword %{{.*}} : i32 + test.oilist_with_simple_args thirdKeyword %arg0 : i32 otherKeyword %arg1 : i32 + + // CHECK: test.oilist_with_simple_args keyword %{{.*}} : i32 otherKeyword %{{.*}} : i32 thirdKeyword %{{.*}} : i32 + test.oilist_with_simple_args keyword %arg0 : i32 otherKeyword %arg1 : i32 thirdKeyword %arg2 : i32 + // CHECK: test.oilist_with_simple_args keyword %{{.*}} : i32 otherKeyword %{{.*}} : i32 thirdKeyword %{{.*}} : i32 + test.oilist_with_simple_args otherKeyword %arg0 : i32 keyword %arg1 : i32 thirdKeyword %arg2 : i32 + // CHECK: test.oilist_with_simple_args keyword %{{.*}} : i32 otherKeyword %{{.*}} : i32 thirdKeyword %{{.*}} : i32 + test.oilist_with_simple_args otherKeyword %arg0 : i32 thirdKeyword %arg1 : i32 keyword %arg2 : i32 + return +} + +// ----- + +// CHECK-LABEL: @succeededOilistVariadic +// CHECK-SAME: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32) +func @succeededOilistVariadic(%arg0: i32, %arg1: i32, %arg2: i32) { + // CHECK: test.oilist_variadic_with_parens keyword(%[[ARG0]], %[[ARG1]] : i32, i32) + test.oilist_variadic_with_parens keyword (%arg0, %arg1 : i32, i32) + // CHECK: test.oilist_variadic_with_parens keyword(%[[ARG0]], %[[ARG1]] : i32, i32) otherKeyword(%[[ARG2]], %[[ARG1]] : i32, i32) + test.oilist_variadic_with_parens otherKeyword (%arg2, %arg1 : i32, i32) keyword (%arg0, %arg1 : i32, i32) + // CHECK: test.oilist_variadic_with_parens keyword(%[[ARG0]], %[[ARG1]] : i32, i32) otherKeyword(%[[ARG0]], %[[ARG1]] : i32, i32) thirdKeyword(%[[ARG2]], %[[ARG0]], %[[ARG1]] : i32, i32, i32) + test.oilist_variadic_with_parens thirdKeyword (%arg2, %arg0, %arg1 : i32, i32, i32) keyword (%arg0, %arg1 : i32, i32) otherKeyword (%arg0, %arg1 : i32, i32) + return +} + 
+// ----- +// CHECK-LABEL: @succeededOilistCustom +// CHECK-SAME: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32) +func @succeededOilistCustom(%arg0: i32, %arg1: i32, %arg2: i32) { + // CHECK: test.oilist_custom private(%[[ARG0]], %[[ARG1]] : i32, i32) + test.oilist_custom private (%arg0, %arg1 : i32, i32) + // CHECK: test.oilist_custom private(%[[ARG0]], %[[ARG1]] : i32, i32) nowait + test.oilist_custom private (%arg0, %arg1 : i32, i32) nowait + // CHECK: test.oilist_custom private(%[[ARG0]], %[[ARG1]] : i32, i32) nowait reduction (%[[ARG1]]) + test.oilist_custom nowait reduction (%arg1) private (%arg0, %arg1 : i32, i32) + return +} + +// ----- + func @failedHasDominanceScopeOutsideDominanceFreeScope() -> () { "test.ssacfg_region"() ({ test.graph_region { diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -387,6 +387,16 @@ //===----------------------------------------------------------------------===// // Parsing +static ParseResult +parseCustomOptionalOperand(OpAsmParser &parser, + Optional<OpAsmParser::OperandType> &optOperand) { + if (succeeded(parser.parseOptionalLParen())) { + optOperand.emplace(); + if (parser.parseOperand(*optOperand) || parser.parseRParen()) + return failure(); + } + return success(); +} static ParseResult parseCustomDirectiveOperands( OpAsmParser &parser, OpAsmParser::OperandType &operand, Optional<OpAsmParser::OperandType> &optOperand, @@ -505,6 +515,11 @@ //===----------------------------------------------------------------------===// // Printing +static void printCustomOptionalOperand(OpAsmPrinter &printer, Operation *, + Value optOperand) { + if (optOperand) + printer << "(" << optOperand << ") "; +} static void printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *, Value operand, Value optOperand, OperandRange varOperands) { diff --git a/mlir/test/lib/Dialect/Test/TestOps.td
b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -637,6 +637,54 @@ // is the dialect parser and printer hooks. def CustomFormatFallbackOp : TEST_Op<"dialect_custom_format_fallback">; +// Ops exercising the `oilist` (order-independent list of clauses) directive. +def OIListTrivial : TEST_Op<"oilist_with_keywords_only"> { + let assemblyFormat = [{ + oilist( `keyword` + | `otherKeyword`) attr-dict + }]; +} + +def OIListSimple : TEST_Op<"oilist_with_simple_args", [AttrSizedOperandSegments]> { + let arguments = (ins Optional<AnyType>:$arg0, + Optional<AnyType>:$arg1, + Optional<AnyType>:$arg2); + let assemblyFormat = [{ + oilist( `keyword` $arg0 `:` type($arg0) + | `otherKeyword` $arg1 `:` type($arg1) + | `thirdKeyword` $arg2 `:` type($arg2) ) attr-dict + }]; +} + +def OIListVariadic : TEST_Op<"oilist_variadic_with_parens", [AttrSizedOperandSegments]> { + let arguments = (ins Variadic<AnyType>:$arg0, + Variadic<AnyType>:$arg1, + Variadic<AnyType>:$arg2); + let assemblyFormat = [{ + oilist( `keyword` `(` $arg0 `:` type($arg0) `)` + | `otherKeyword` `(` $arg1 `:` type($arg1) `)` + | `thirdKeyword` `(` $arg2 `:` type($arg2) `)`) attr-dict + }]; +} + +def OIListCustom : TEST_Op<"oilist_custom", [AttrSizedOperandSegments]> { + let arguments = (ins Variadic<AnyType>:$arg0, + Optional<I32>:$optOperand, + UnitAttr:$nowait); + let assemblyFormat = [{ + oilist( `private` `(` $arg0 `:` type($arg0) `)` + | `nowait` + | `reduction` custom<CustomOptionalOperand>($optOperand) + ) attr-dict + }]; +} + +def OIListAllowedLiteral : TEST_Op<"oilist_allowed_literal"> { + let assemblyFormat = [{ + oilist( `foo` | `bar` ) `buzz` attr-dict + }]; +} + // This is used to test encoding of a string attribute into an SSA name of a // pretty printed value name.
def StringAttrPrettyNameOp diff --git a/mlir/test/mlir-tblgen/op-format-spec.td b/mlir/test/mlir-tblgen/op-format-spec.td --- a/mlir/test/mlir-tblgen/op-format-spec.td +++ b/mlir/test/mlir-tblgen/op-format-spec.td @@ -344,6 +344,56 @@ attr-dict }]>; +//===----------------------------------------------------------------------===// +// oilist directive + +// CHECK: error: format ambiguity because bar is used in two adjacent oilist elements. +def OIListAdjacentOIList : TestFormat_Op<[{ + oilist ( `foo` | `bar` ) oilist ( `bar` | `buzz` ) attr-dict +}]>; +// CHECK: error: expected literal, but got ')' +def OIListErrorExpectedLiteral : TestFormat_Op<[{ + oilist( `keyword` | ) attr-dict +}]>; +// CHECK: error: expected literal, but got ')' +def OIListErrorExpectedEmpty : TestFormat_Op<[{ + oilist() attr-dict +}]>; +// CHECK: error: expected literal, but got '$arg0' +def OIListErrorNoLiteral : TestFormat_Op<[{ + oilist( $arg0 `:` type($arg0) | $arg1 `:` type($arg1) ) attr-dict +}], [AttrSizedOperandSegments]>, Arguments<(ins Optional<AnyType>:$arg0, Optional<AnyType>:$arg1)>; +// CHECK: error: format ambiguity because foo is used both in oilist element and the adjacent literal.
+def OIListLiteralAmbiguity : TestFormat_Op<[{ + oilist( `foo` | `bar` ) `foo` attr-dict +}]>; +// CHECK: error: expected '(' before oilist argument list +def OIListStartingToken : TestFormat_Op<[{ + oilist `wrong` attr-dict +}]>; + +// CHECK-NOT: error +def OIListTrivial : TestFormat_Op<[{ + oilist(`keyword` `(` `)` | `otherkeyword` `(` `)`) attr-dict +}]>; +def OIListSimple : TestFormat_Op<[{ + oilist( `keyword` $arg0 `:` type($arg0) + | `otherkeyword` $arg1 `:` type($arg1) + | `thirdkeyword` $arg2 `:` type($arg2) ) + attr-dict +}], [AttrSizedOperandSegments]>, Arguments<(ins Optional<AnyType>:$arg0, Optional<AnyType>:$arg1, Optional<AnyType>:$arg2)>; +def OIListVariadic : TestFormat_Op<[{ + oilist( `keyword` `(` $args0 `:` type($args0) `)` + | `otherkeyword` `(` $args1 `:` type($args1) `)` + | `thirdkeyword` `(` $args2 `:` type($args2) `)`) + attr-dict +}], [AttrSizedOperandSegments]>, Arguments<(ins Variadic<AnyType>:$args0, Variadic<AnyType>:$args1, Variadic<AnyType>:$args2)>; +def OIListCustom : TestFormat_Op<[{ + oilist( `private` `(` $arg0 `:` type($arg0) `)` + | `nowait` + | `reduction` custom<ReductionClause>($arg1, type($arg1))) attr-dict +}], [AttrSizedOperandSegments]>, Arguments<(ins Optional<AnyType>:$arg0, Optional<AnyType>:$arg1)>; + //===----------------------------------------------------------------------===// // Optional Groups //===----------------------------------------------------------------------===// diff --git a/mlir/tools/mlir-tblgen/FormatGen.h b/mlir/tools/mlir-tblgen/FormatGen.h --- a/mlir/tools/mlir-tblgen/FormatGen.h +++ b/mlir/tools/mlir-tblgen/FormatGen.h @@ -50,6 +50,7 @@ greater, question, star, + pipe, // Keywords.
keyword_start, @@ -57,6 +58,7 @@ kw_attr_dict_w_keyword, kw_custom, kw_functional_type, + kw_oilist, kw_operands, kw_params, kw_qualified, diff --git a/mlir/tools/mlir-tblgen/FormatGen.cpp b/mlir/tools/mlir-tblgen/FormatGen.cpp --- a/mlir/tools/mlir-tblgen/FormatGen.cpp +++ b/mlir/tools/mlir-tblgen/FormatGen.cpp @@ -115,6 +115,8 @@ return formToken(FormatToken::r_paren, tokStart); case '*': return formToken(FormatToken::star, tokStart); + case '|': + return formToken(FormatToken::pipe, tokStart); // Ignore whitespace characters. case 0: @@ -164,6 +166,7 @@ .Case("attr-dict-with-keyword", FormatToken::kw_attr_dict_w_keyword) .Case("custom", FormatToken::kw_custom) .Case("functional-type", FormatToken::kw_functional_type) + .Case("oilist", FormatToken::kw_oilist) .Case("operands", FormatToken::kw_operands) .Case("params", FormatToken::kw_params) .Case("ref", FormatToken::kw_ref) diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp --- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp @@ -45,6 +45,7 @@ AttrDictDirective, CustomDirective, FunctionalTypeDirective, + OIListDirective, OperandsDirective, RefDirective, RegionsDirective, @@ -376,6 +377,76 @@ }; } // namespace +//===----------------------------------------------------------------------===// +// OIListElement + +namespace { +class OIListElement : public Element { +public: + OIListElement( + std::vector<std::unique_ptr<Element>> &&literalElements, + std::vector<std::vector<std::unique_ptr<Element>>> &&parsingElements) + : Element(Kind::OIListDirective), + literalElements(std::move(literalElements)), + parsingElements(std::move(parsingElements)) {} + + static bool classof(const Element *element) { + return element->getKind() == Kind::OIListDirective; + } + + /// Returns a range to iterate over the LiteralElements. + auto getLiteralElements() const { + // The use of std::function is unfortunate but necessary here: lambda
+ // closure types are not copy-assignable while std::function is, and the + // mapped iterators that llvm::zip advances must be copy-assignable. + std::function<LiteralElement *(const std::unique_ptr<Element> &)> + literalElementCastConverter = + [](auto &el) { return cast<LiteralElement>(el.get()); }; + return llvm::make_pointee_range( + llvm::map_range(literalElements, literalElementCastConverter)); + } + + /// Returns a range to iterate over the parsing elements + /// corresponding to the clauses. + auto getParsingElements() const { + // Maps each clause's "vector of unique pointers to Element" to a range + // over the pointed-to Elements. + // The use of std::function is unfortunate but necessary here: lambda + // closure types are not copy-assignable while std::function is, and the + // mapped iterators that llvm::zip advances must be copy-assignable. + std::function<llvm::iterator_range<llvm::pointee_iterator<decltype(( + std::begin(std::declval<std::vector<std::unique_ptr<Element>>>())))>>( + const std::vector<std::unique_ptr<Element>> &vec)> + vectorConverter = [](const std::vector<std::unique_ptr<Element>> &vec) { + return llvm::make_pointee_range(vec); + }; + return llvm::map_range(parsingElements, vectorConverter); + } + +private: + /// A vector of `LiteralElement` objects. Each element stores *the keyword* + /// for one case of oilist element. For example, an oilist element along with + /// the literalElements vector: + /// ``` + /// oilist( `keyword` `=` `(` $arg0 `)` | `otherKeyword` `<` $arg1 `>` ) + /// literalElements = { `keyword`, `otherKeyword` } + /// ``` + std::vector<std::unique_ptr<Element>> literalElements; + + /// A vector of valid declarative assembly format vectors. Each object in + /// parsing elements is a vector of elements in assembly format syntax.
+ /// For example, an oilist element along with the parsingElements vector: + /// ``` + /// oilist( `keyword` `=` `(` $arg0 `)` | `otherKeyword` `<` $arg1 `>` ) + /// parsingElements = { + /// { `=`, `(`, $arg0, `)` }, + /// { `<`, $arg1, `>` } + /// } + /// ``` + std::vector<std::vector<std::unique_ptr<Element>>> parsingElements; +}; +} // namespace + //===----------------------------------------------------------------------===// // OperationFormat //===----------------------------------------------------------------------===// @@ -819,6 +890,19 @@ return ::mlir::failure(); )"; +/// Code snippet emitted right after the keyword of an oilist clause is parsed. +/// It rejects a repeated clause and records the clause as a UnitAttr. +/// {0}: literal keyword corresponding to a case for oilist +const char *oilistParserCode = R"( + if ({0}Clause) { + return parser.emitError(parser.getNameLoc()) + << "`{0}` clause can appear at most once in the expansion of the " + "oilist directive"; + } + {0}Clause = true; + result.addAttribute("{0}", UnitAttr::get(parser.getContext())); +)"; + namespace { /// The type of length for a given parse argument.
enum class ArgumentLengthKind { @@ -908,6 +992,11 @@ for (auto &childElement : optional->getElseElements()) genElementParserStorage(&childElement, op, body); + } else if (auto *oilist = dyn_cast<OIListElement>(element)) { + for (auto pelement : oilist->getParsingElements()) + for (auto &childElement : pelement) + genElementParserStorage(&childElement, op, body); + } else if (auto *custom = dyn_cast<CustomDirective>(element)) { for (auto &paramElement : custom->getArguments()) genElementParserStorage(&paramElement, op, body); @@ -1295,6 +1384,35 @@ } body << "\n"; + // OIList Directive + } else if (OIListElement *oilist = dyn_cast<OIListElement>(element)) { + for (LiteralElement &le : oilist->getLiteralElements()) { + body << " bool " << le.getLiteral() << "Clause = false;\n"; + } + + // Generate a loop that parses clauses until no clause keyword matches. + body << " while(true) {\n"; + for (auto it : llvm::zip(oilist->getLiteralElements(), + oilist->getParsingElements())) { + LiteralElement &lelement = std::get<0>(it); + // `pelement` ranges over the elements that follow the keyword + // `lelement` within the same clause. + auto &pelement = std::get<1>(it); + body << "if (succeeded(parser.parseOptional"; + genLiteralParser(lelement.getLiteral(), body); + body << ")) {\n"; + StringRef attrName = lelement.getLiteral(); + body << formatv(oilistParserCode, attrName); + inferredAttributes.insert(attrName); + for (Element &el : pelement) + genElementParser(&el, body, attrTypeCtx); + body << " } else "; + } + body << " {\n"; + body << " break;\n"; + body << " }\n"; + body << "}\n"; + /// Literals.
} else if (LiteralElement *literal = dyn_cast<LiteralElement>(element)) { body << " if (parser.parse"; @@ -2030,6 +2148,26 @@ return; } + if (auto *oilist = dyn_cast<OIListElement>(element)) { + genLiteralPrinter(" ", body, shouldEmitSpace, lastWasPunctuation); + for (auto it : llvm::zip(oilist->getLiteralElements(), + oilist->getParsingElements())) { + LiteralElement &lelement = std::get<0>(it); + auto &pelement = std::get<1>(it); + // Print a clause only if the parser recorded its keyword as a UnitAttr. + body << " if ((*this)->hasAttrOfType<UnitAttr>(\"" + << lelement.getLiteral() << "\")) {\n"; + genLiteralPrinter(lelement.getLiteral(), body, shouldEmitSpace, + lastWasPunctuation); + for (auto &childElement : pelement) { + genElementPrinter(&childElement, body, op, shouldEmitSpace, + lastWasPunctuation); + } + body << " }\n"; + } + return; + } + // Emit the attribute dictionary. if (auto *attrDict = dyn_cast<AttrDictDirective>(element)) { genAttrDictPrinter(*this, op, body, attrDict->isWithKeyword()); @@ -2245,6 +2383,8 @@ /// Verify the state of operation successors within the format. LogicalResult verifySuccessors(llvm::SMLoc loc); + LogicalResult verifyOIListElements(llvm::SMLoc loc); + /// Given the values of an `AllTypesMatch` trait, check for inferable type /// resolution.
void handleAllTypesMatchConstraint( @@ -2309,6 +2449,9 @@ ParserContext context); LogicalResult parseTypeDirective(std::unique_ptr<Element> &element, FormatToken tok, ParserContext context); + LogicalResult parseOIListDirective(std::unique_ptr<Element> &element, + FormatToken tok, ParserContext context); + LogicalResult verifyOIListParsingElement(Element *element, llvm::SMLoc loc); LogicalResult parseTypeDirectiveOperand(std::unique_ptr<Element> &element, bool isRefChild = false); @@ -2407,7 +2550,8 @@ if (failed(verifyAttributes(loc)) || failed(verifyResults(loc, variableTyResolver)) || failed(verifyOperands(loc, variableTyResolver)) || - failed(verifyRegions(loc)) || failed(verifySuccessors(loc))) + failed(verifyRegions(loc)) || failed(verifySuccessors(loc)) || + failed(verifyOIListElements(loc))) return ::mlir::failure(); // Collect the set of used attributes in the format. @@ -2629,6 +2773,42 @@ return ::mlir::success(); } +LogicalResult FormatParser::verifyOIListElements(llvm::SMLoc loc) { + // Oilist keywords must differ from those of the following oilist/literal. + SmallVector<StringRef> prohibitedLiterals; + for (auto &it : fmt.elements) { + if (auto *oilist = dyn_cast<OIListElement>(it.get())) { + if (!prohibitedLiterals.empty()) { + // Oilists directly precede this one (ignoring whitespace), so their + // keywords must not be reused here.
+ for (auto &literal : oilist->getLiteralElements()) { + if (llvm::is_contained(prohibitedLiterals, + literal.getLiteral())) { + return emitError( + loc, "format ambiguity because " + literal.getLiteral() + + " is used in two adjacent oilist elements."); + } + } + } + for (auto &literal : oilist->getLiteralElements()) { + prohibitedLiterals.push_back(literal.getLiteral()); + } + } else if (auto *literal = dyn_cast<LiteralElement>(it.get())) { + if (llvm::is_contained(prohibitedLiterals, + literal->getLiteral())) { + return emitError( + loc, + "format ambiguity because " + literal->getLiteral() + + " is used both in oilist element and the adjacent literal."); + } + prohibitedLiterals.clear(); + } else if (!isa<WhitespaceElement>(it.get())) { + prohibitedLiterals.clear(); + } + } + return ::mlir::success(); +} + void FormatParser::handleAllTypesMatchConstraint( ArrayRef<StringRef> values, llvm::StringMap<TypeResolutionInstance> &variableTyResolver) { @@ -2819,6 +2999,8 @@ return parseReferenceDirective(element, dirTok.getLoc(), context); case FormatToken::kw_type: return parseTypeDirective(element, dirTok, context); + case FormatToken::kw_oilist: + return parseOIListDirective(element, dirTok, context); default: llvm_unreachable("unknown directive token"); @@ -2835,7 +3017,13 @@ } consumeToken(); - StringRef value = literalTok.getSpelling().drop_front().drop_back(); + StringRef value = literalTok.getSpelling(); + // Reject tokens that are not backtick-quoted literals (e.g. $arg0, or ')' + // when a literal was expected); stripping their quotes would be invalid. + if (value.size() < 2 || value.front() != '`' || value.back() != '`') + return emitError(literalTok.getLoc(), + "expected literal, but got '" + value + "'"); + value = value.drop_front().drop_back(); // The parsed literal is a space element (`` or ` `).
if (value.empty() || (value.size() == 1 && value.front() == ' ')) { @@ -3203,6 +3391,107 @@ return ::mlir::success(); } +LogicalResult +FormatParser::parseOIListDirective(std::unique_ptr<Element> &element, + FormatToken tok, ParserContext context) { + if (failed(parseToken(FormatToken::l_paren, + "expected '(' before oilist argument list"))) + return failure(); + std::vector<std::unique_ptr<Element>> literalElements; + std::vector<std::vector<std::unique_ptr<Element>>> parsingElements; + do { + llvm::SMLoc clauseLoc = curToken.getLoc(); + literalElements.emplace_back(); + parsingElements.emplace_back(); + if (failed(parseLiteral(literalElements.back(), context))) + return failure(); + // Each clause must start with a distinct keyword literal: the generated + // parser keys every clause on its keyword (a `<keyword>Clause` flag and a + // UnitAttr), so space literals and repeated keywords are rejected here. + auto *clauseKeyword = dyn_cast<LiteralElement>(literalElements.back().get()); + if (!clauseKeyword) + return emitError(clauseLoc, "expected keyword literal in oilist clause"); + for (const std::unique_ptr<Element> &prev : + llvm::makeArrayRef(literalElements).drop_back()) + if (cast<LiteralElement>(prev.get())->getLiteral() == + clauseKeyword->getLiteral()) + return emitError(clauseLoc, "duplicate keyword '" + + clauseKeyword->getLiteral() + + "' in oilist directive"); + auto &currParsingElements = parsingElements.back(); + while (curToken.getKind() != FormatToken::pipe && + curToken.getKind() != FormatToken::r_paren) { + llvm::SMLoc elementLoc = curToken.getLoc(); + currParsingElements.emplace_back(); + if (failed(parseElement(currParsingElements.back(), context)) || + failed(verifyOIListParsingElement(currParsingElements.back().get(), + elementLoc))) + return failure(); + } + if (curToken.getKind() == FormatToken::pipe) { + consumeToken(); + continue; + } + if (curToken.getKind() == FormatToken::r_paren) { + consumeToken(); + break; + } + } while (true); + + element = std::make_unique<OIListElement>(std::move(literalElements), + std::move(parsingElements)); + return success(); +} + +LogicalResult FormatParser::verifyOIListParsingElement(Element *element, + llvm::SMLoc loc) { + return TypeSwitch<Element *, LogicalResult>(element) + // Only optional attributes can be within an oilist parsing group. + .Case([&](AttributeVariable *attrEle) { + if (!attrEle->getVar()->attr.isOptional()) + return emitError(loc, "only optional attributes can be used " + "within an oilist parsing group"); + return ::mlir::success(); + }) + // Only variable-length (i.e. optional or variadic) operands can be used + // within an oilist parsing group.
+ .Case([&](OperandVariable *ele) { + if (!ele->getVar()->isVariableLength()) + return emitError(loc, "only variable length operands can be " + "used within an oilist parsing group"); + return ::mlir::success(); + }) + // Only variable-length (i.e. optional or variadic) results can be used + // within an oilist parsing group. + .Case([&](ResultVariable *ele) { + if (!ele->getVar()->isVariableLength()) + return emitError(loc, "only variable length results can be " + "used within an oilist parsing group"); + return ::mlir::success(); + }) + .Case([&](RegionVariable *) { + // TODO: When ODS has proper support for marking "optional" regions, add + // a check here. + return ::mlir::success(); + }) + .Case([&](TypeDirective *ele) { + return verifyOIListParsingElement(ele->getOperand(), loc); + }) + .Case([&](FunctionalTypeDirective *ele) { + if (failed(verifyOIListParsingElement(ele->getInputs(), loc))) + return failure(); + return verifyOIListParsingElement(ele->getResults(), loc); + }) + // Literals, whitespace, custom directives and optional groups are fine. + .Case<LiteralElement, WhitespaceElement, CustomDirective, + OptionalElement>( + [&](Element *) { return ::mlir::success(); }) + .Default([&](Element *) { + return emitError(loc, "only literals, types, and variables can be " + "used within an oilist group"); + }); +} + LogicalResult FormatParser::parseTypeDirective(std::unique_ptr<Element> &element, FormatToken tok, ParserContext context) {