diff --git a/mlir/test/mlir-tblgen/op-format-spec.td b/mlir/test/mlir-tblgen/op-format-spec.td --- a/mlir/test/mlir-tblgen/op-format-spec.td +++ b/mlir/test/mlir-tblgen/op-format-spec.td @@ -344,6 +344,44 @@ attr-dict }]>; +//===----------------------------------------------------------------------===// +// oilist: `|`-separated clauses, each introduced by a literal + +// CHECK: error: expected literal, but got ']' +def OIListErrorExpectedLiteral : TestFormat_Op<[{ + oilist[ `keyword` | ] attr-dict +}]>; +// CHECK: error: expected literal, but got ']' +def OIListErrorExpectedEmpty : TestFormat_Op<[{ + oilist[] attr-dict +}]>; +// CHECK: error: expected literal, but got '$arg0' +def OIListErrorNoLiteral : TestFormat_Op<[{ + oilist[ $arg0 `:` type($arg0) | $arg1 `:` type($arg1) ] attr-dict +}], [AttrSizedOperandSegments]>, Arguments<(ins Optional:$arg0, Optional:$arg1)>; + +// CHECK-NOT: error +def OIListTrivial : TestFormat_Op<[{ + oilist[`keyword` `(` `)` | `otherkeyword` `(` `)`] attr-dict +}]>; +def OIListSimple : TestFormat_Op<[{ + oilist[ `keyword` $arg0 `:` type($arg0) + | `otherkeyword` $arg1 `:` type($arg1) + | `thirdkeyword` $arg2 `:` type($arg2) ] + attr-dict +}], [AttrSizedOperandSegments]>, Arguments<(ins Optional:$arg0, Optional:$arg1, Optional:$arg2)>; +def OIListVariadic : TestFormat_Op<[{ + oilist[ `keyword` `(` $args0 `:` type($args0) `)` + | `otherkeyword` `(` $args1 `:` type($args1) `)` + | `thirdkeyword` `(` $args2 `:` type($args2) `)`] + attr-dict +}], [AttrSizedOperandSegments]>, Arguments<(ins Variadic:$args0, Variadic:$args1, Variadic:$args2)>; +def OIListCustom : TestFormat_Op<[{ + oilist[ `private` `(` $arg0 `:` type($arg0) `)` + | `nowait` + | `reduction` custom($arg1, type($arg1))] attr-dict +}], [AttrSizedOperandSegments]>, Arguments<(ins Optional:$arg0, Optional:$arg1)>; + //===----------------------------------------------------------------------===// // Optional Groups //===----------------------------------------------------------------------===// diff --git 
a/mlir/tools/mlir-tblgen/FormatGen.h b/mlir/tools/mlir-tblgen/FormatGen.h --- a/mlir/tools/mlir-tblgen/FormatGen.h +++ b/mlir/tools/mlir-tblgen/FormatGen.h @@ -42,6 +42,8 @@ // Tokens with no info. l_paren, r_paren, + l_bracket, + r_bracket, caret, colon, comma, @@ -50,6 +52,7 @@ greater, question, star, + pipe, // Keywords. keyword_start, @@ -65,6 +68,7 @@ kw_struct, kw_successors, kw_type, + kw_oilist, keyword_end, // String valued tokens. diff --git a/mlir/tools/mlir-tblgen/FormatGen.cpp b/mlir/tools/mlir-tblgen/FormatGen.cpp --- a/mlir/tools/mlir-tblgen/FormatGen.cpp +++ b/mlir/tools/mlir-tblgen/FormatGen.cpp @@ -113,8 +113,14 @@ return formToken(FormatToken::l_paren, tokStart); case ')': return formToken(FormatToken::r_paren, tokStart); + case '[': + return formToken(FormatToken::l_bracket, tokStart); + case ']': + return formToken(FormatToken::r_bracket, tokStart); case '*': return formToken(FormatToken::star, tokStart); + case '|': + return formToken(FormatToken::pipe, tokStart); // Ignore whitespace characters. case 0: @@ -172,6 +178,7 @@ .Case("struct", FormatToken::kw_struct) .Case("successors", FormatToken::kw_successors) .Case("type", FormatToken::kw_type) + .Case("oilist", FormatToken::kw_oilist) .Default(FormatToken::identifier); return FormatToken(kind, str); } diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp --- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp @@ -67,6 +67,7 @@ /// This element is an optional element. 
Optional, + OIList, }; Element(Kind kind) : kind(kind) {} virtual ~Element() = default;
@@ -356,6 +357,38 @@
 };
 } // end anonymous namespace
 
+//===----------------------------------------------------------------------===//
+// OIListElement
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// This class represents an `oilist` directive: a list of clauses, each made
+/// of a leading literal followed by the elements to parse after it.
+class OIListElement : public Element {
+public:
+  /// `literalElements[i]` is the leading literal of clause `i` and
+  /// `parsingElements[i]` holds the elements following it, so both vectors
+  /// are expected to have the same length.
+  OIListElement(
+      std::vector<std::unique_ptr<Element>> &&literalElements,
+      std::vector<std::vector<std::unique_ptr<Element>>> &&parsingElements)
+      : Element{Kind::OIList}, literalElements(std::move(literalElements)),
+        parsingElements(std::move(parsingElements)) {}
+
+  static bool classof(const Element *element) {
+    return element->getKind() == Kind::OIList;
+  }
+
+  /// Returns the number of clauses in the list.
+  int numElements() const { return static_cast<int>(literalElements.size()); }
+
+  /// The literal introducing each clause.
+  std::vector<std::unique_ptr<Element>> literalElements;
+  /// The elements following the literal of each clause.
+  std::vector<std::vector<std::unique_ptr<Element>>> parsingElements;
+};
+} // end anonymous namespace
+
 //===----------------------------------------------------------------------===//
 // OperationFormat
 //===----------------------------------------------------------------------===//
@@ -766,6 +799,20 @@
   return ::mlir::failure();
 )";
 
+/// The code snippet used to generate a parser for one clause of an OIList. It
+/// rejects a clause that was already parsed and records the clause as a unit
+/// attribute named after its keyword.
+///
+/// {0}: The clause keyword (literal).
+const char *oilistParserCode = R"(
+  if(done[{0}Clause])
+    return parser.emitError(parser.getCurrentLocation())
+           << "at most one {0} element can appear on the "
+           << result.name.getStringRef() << " operation";
+  done[{0}Clause] = true;
+  result.addAttribute("{0}", ::mlir::UnitAttr::get(parser.getContext()));
+)";
+
 namespace {
 /// The type of length for a given parse argument.
enum class ArgumentLengthKind { @@ -854,7 +890,12 @@ genElementParserStorage(&childElement, op, body); for (auto &childElement : optional->getElseElements()) genElementParserStorage(&childElement, op, body); - + } else if (auto *oilist = dyn_cast(element)) { + for (auto &clause : oilist->parsingElements) { + for (auto &element : clause) { + genElementParserStorage(element.get(), op, body); + } + } } else if (auto *custom = dyn_cast(element)) { for (auto ¶mElement : custom->getArguments()) genElementParserStorage(¶mElement, op, body); @@ -1238,6 +1279,37 @@ } body << "\n"; + } else if (OIListElement *oilist = dyn_cast(element)) { + + // Store done vector and positions for clauses + body << " llvm::BitVector done(" << oilist->numElements() << ", false);\n"; + for (int idx = 0; idx < oilist->numElements(); idx++) { + auto literalElement = + dyn_cast(oilist->literalElements[idx].get()); + body << " int " << literalElement->getLiteral() << "Clause = " << idx + << ";\n"; + } + + // Generate the parsing loop + body << " while(true) {\n"; + for (int idx = 0; idx < oilist->numElements(); idx++) { + body << (idx == 0 ? " " : " else "); + auto literalElement = + dyn_cast(oilist->literalElements[idx].get()); + body << "if(succeeded(parser.parseOptionalKeyword(\"" + << literalElement->getLiteral() << "\"))) {\n"; + StringRef attrname = literalElement->getLiteral(); + body << formatv(oilistParserCode, attrname); + inferredAttributes.insert(attrname); + for (auto &parsingElement : oilist->parsingElements[idx]) + genElementParser(parsingElement.get(), body, attrTypeCtx); + body << " }\n"; + } + body << " else {\n"; + body << " break;\n"; + body << " }\n"; + body << " }\n"; + /// Literals. 
} else if (LiteralElement *literal = dyn_cast(element)) { body << " if (parser.parse"; @@ -1946,6 +2018,25 @@ return; } + if (auto *oilist = dyn_cast(element)) { + body << " _odsPrinter << \" \";\n"; + for (int i = 0; i < oilist->numElements(); i++) { + auto lelement = + dyn_cast(oilist->literalElements[i].get()); + auto &pelements = oilist->parsingElements[i]; + body << " if((*this)->getAttr(\"" << lelement->getLiteral() + << "\").template dyn_cast_or_null<::mlir::UnitAttr>()) {\n"; + body << " _odsPrinter << \"" << lelement->getLiteral() << "\";\n"; + shouldEmitSpace = false; + for (auto &element : pelements) { + genElementPrinter(element.get(), body, op, shouldEmitSpace, + lastWasPunctuation); + } + body << " _odsPrinter << \" \";\n"; + body << " }\n"; + } + return; + } // Emit the attribute dictionary. if (auto *attrDict = dyn_cast(element)) { genAttrDictPrinter(*this, op, body, attrDict->isWithKeyword()); @@ -2193,6 +2284,8 @@ ParserContext context); LogicalResult parseTypeDirective(std::unique_ptr &element, FormatToken tok, ParserContext context); + LogicalResult parseOIListDirective(std::unique_ptr &element, + FormatToken tok, ParserContext context); LogicalResult parseTypeDirectiveOperand(std::unique_ptr &element, bool isRefChild = false); @@ -2694,6 +2787,8 @@ return parseReferenceDirective(element, dirTok.getLoc(), context); case FormatToken::kw_type: return parseTypeDirective(element, dirTok, context); + case FormatToken::kw_oilist: + return parseOIListDirective(element, dirTok, context); default: llvm_unreachable("unknown directive token"); @@ -2710,7 +2805,12 @@ } consumeToken(); - StringRef value = literalTok.getSpelling().drop_front().drop_back(); + StringRef value = literalTok.getSpelling(); + if (value.size() >= 2 && value[0] == '`' && value[value.size() - 1] == '`') + value = value.drop_front().drop_back(); + else + return emitError(literalTok.getLoc(), + "expected literal, but got '" + value + "'"); // The parsed literal is a space element (`` or 
` `). if (value.empty() || (value.size() == 1 && value.front() == ' ')) {
@@ -3078,6 +3178,54 @@
   return ::mlir::success();
 }
 
+/// Parse an `oilist` directive:
+///   oilist `[` clause (`|` clause)* `]`, with clause ::= literal element*
+/// The current token is expected to be the opening `[`. On success `element`
+/// is set to an OIListElement holding, per clause, the leading literal and the
+/// elements that follow it.
+LogicalResult
+FormatParser::parseOIListDirective(std::unique_ptr<Element> &element,
+                                   FormatToken tok, ParserContext context) {
+  if (failed(parseToken(FormatToken::l_bracket,
+                        "expected '[' before argument list")))
+    return failure();
+  std::vector<std::unique_ptr<Element>> literalElements;
+  std::vector<std::vector<std::unique_ptr<Element>>> parsingElements;
+  while (true) {
+    // Every clause starts with the literal that introduces it.
+    auto literalLoc = curToken.getLoc();
+    literalElements.emplace_back();
+    parsingElements.emplace_back();
+    if (failed(parseLiteral(literalElements.back(), context)))
+      return failure();
+    // The generators dereference this element as a LiteralElement, so reject
+    // anything else here (parseLiteral has a separate path for the space
+    // literals, which presumably do not yield a LiteralElement).
+    if (!isa<LiteralElement>(literalElements.back().get()))
+      return emitError(literalLoc,
+                       "expected a keyword literal to start an oilist clause");
+
+    // Everything up to the next `|` or the closing `]` belongs to the clause.
+    auto &clauseElements = parsingElements.back();
+    while (curToken.getKind() != FormatToken::pipe &&
+           curToken.getKind() != FormatToken::r_bracket) {
+      clauseElements.emplace_back();
+      if (failed(parseElement(clauseElements.back(), context)))
+        return failure();
+    }
+    if (curToken.getKind() == FormatToken::r_bracket)
+      break;
+    consumeToken(); // Eat the `|` separating two clauses.
+  }
+
+  if (failed(parseToken(FormatToken::r_bracket,
+                        "expected ']' after argument list")))
+    return failure();
+  element = std::make_unique<OIListElement>(std::move(literalElements),
+                                            std::move(parsingElements));
+  return success();
+}
+
 LogicalResult FormatParser::parseTypeDirective(std::unique_ptr &element, FormatToken tok, ParserContext context) {