diff --git a/mlir/test/IR/traits.mlir b/mlir/test/IR/traits.mlir --- a/mlir/test/IR/traits.mlir +++ b/mlir/test/IR/traits.mlir @@ -488,6 +488,76 @@ // ----- +// CHECK-LABEL: @succeededOilistTrivial +func @succeededOilistTrivial() { + // CHECK: test.oilist_with_keywords_only keyword + test.oilist_with_keywords_only keyword + // CHECK: test.oilist_with_keywords_only otherKeyword + test.oilist_with_keywords_only otherKeyword + // CHECK: test.oilist_with_keywords_only keyword otherKeyword + test.oilist_with_keywords_only keyword otherKeyword + // CHECK: test.oilist_with_keywords_only keyword otherKeyword + test.oilist_with_keywords_only otherKeyword keyword + return +} + +// ----- + +// CHECK-LABEL: @succeededOilistSimple +func @succeededOilistSimple(%arg0 : i32, %arg1 : i32, %arg2 : i32) { + // CHECK: test.oilist_with_simple_args keyword %{{.*}} : i32 + test.oilist_with_simple_args keyword %arg0 : i32 + // CHECK: test.oilist_with_simple_args otherKeyword %{{.*}} : i32 + test.oilist_with_simple_args otherKeyword %arg0 : i32 + // CHECK: test.oilist_with_simple_args thirdKeyword %{{.*}} : i32 + test.oilist_with_simple_args thirdKeyword %arg0 : i32 + + // CHECK: test.oilist_with_simple_args keyword %{{.*}} : i32 otherKeyword %{{.*}} : i32 + test.oilist_with_simple_args keyword %arg0 : i32 otherKeyword %arg1 : i32 + // CHECK: test.oilist_with_simple_args keyword %{{.*}} : i32 thirdKeyword %{{.*}} : i32 + test.oilist_with_simple_args keyword %arg0 : i32 thirdKeyword %arg1 : i32 + // CHECK: test.oilist_with_simple_args otherKeyword %{{.*}} : i32 thirdKeyword %{{.*}} : i32 + test.oilist_with_simple_args thirdKeyword %arg0 : i32 otherKeyword %arg1 : i32 + + // CHECK: test.oilist_with_simple_args keyword %{{.*}} : i32 otherKeyword %{{.*}} : i32 thirdKeyword %{{.*}} : i32 + test.oilist_with_simple_args keyword %arg0 : i32 otherKeyword %arg1 : i32 thirdKeyword %arg2 : i32 + // CHECK: test.oilist_with_simple_args keyword %{{.*}} : i32 otherKeyword %{{.*}} : i32 thirdKeyword
%{{.*}} : i32 + test.oilist_with_simple_args otherKeyword %arg0 : i32 keyword %arg1 : i32 thirdKeyword %arg2 : i32 + // CHECK: test.oilist_with_simple_args keyword %{{.*}} : i32 otherKeyword %{{.*}} : i32 thirdKeyword %{{.*}} : i32 + test.oilist_with_simple_args otherKeyword %arg0 : i32 thirdKeyword %arg1 : i32 keyword %arg2 : i32 + return +} + +// ----- + +// CHECK-LABEL: @succeededOilistVariadic +// CHECK-SAME: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32) +func @succeededOilistVariadic(%arg0: i32, %arg1: i32, %arg2: i32) { + // CHECK: test.oilist_variadic_with_parens keyword(%[[ARG0]], %[[ARG1]] : i32, i32) + test.oilist_variadic_with_parens keyword (%arg0, %arg1 : i32, i32) + // CHECK: test.oilist_variadic_with_parens keyword(%[[ARG0]], %[[ARG1]] : i32, i32) otherKeyword(%[[ARG2]], %[[ARG1]] : i32, i32) + test.oilist_variadic_with_parens otherKeyword (%arg2, %arg1 : i32, i32) keyword (%arg0, %arg1 : i32, i32) + // CHECK: test.oilist_variadic_with_parens keyword(%[[ARG0]], %[[ARG1]] : i32, i32) otherKeyword(%[[ARG0]], %[[ARG1]] : i32, i32) thirdKeyword(%[[ARG2]], %[[ARG0]], %[[ARG1]] : i32, i32, i32) + test.oilist_variadic_with_parens thirdKeyword (%arg2, %arg0, %arg1 : i32, i32, i32) keyword (%arg0, %arg1 : i32, i32) otherKeyword (%arg0, %arg1 : i32, i32) + return +} + +// ----- + +// CHECK-LABEL: @succeededOilistCustom +// CHECK-SAME: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32) +func @succeededOilistCustom(%arg0: i32, %arg1: i32, %arg2: i32) { + // CHECK: test.oilist_custom private(%[[ARG0]], %[[ARG1]] : i32, i32) + test.oilist_custom private (%arg0, %arg1 : i32, i32) + // CHECK: test.oilist_custom private(%[[ARG0]], %[[ARG1]] : i32, i32) nowait + test.oilist_custom private (%arg0, %arg1 : i32, i32) nowait + // CHECK: test.oilist_custom private(%[[ARG0]], %[[ARG1]] : i32, i32) nowait reduction (%[[ARG1]]) + test.oilist_custom nowait reduction (%arg1) private (%arg0, %arg1 : i32, i32) + return +} + +// ----- + func
@failedHasDominanceScopeOutsideDominanceFreeScope() -> () { "test.ssacfg_region"() ({ test.graph_region { diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -388,6 +388,16 @@ //===----------------------------------------------------------------------===// // Parsing +static ParseResult +parseCustomOptionalOperand(OpAsmParser &parser, + Optional &optOperand) { + if (succeeded(parser.parseOptionalLParen())) { + optOperand.emplace(); + if (parser.parseOperand(*optOperand) || parser.parseRParen()) + return failure(); + } + return success(); +} static ParseResult parseCustomDirectiveOperands( OpAsmParser &parser, OpAsmParser::OperandType &operand, Optional &optOperand, @@ -506,6 +516,12 @@ //===----------------------------------------------------------------------===// // Printing +static void printCustomOptionalOperand(OpAsmPrinter &printer, Operation *, + Value optOperand) { + if (optOperand) { + printer << "(" << optOperand << ") "; + } +} static void printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *, Value operand, Value optOperand, OperandRange varOperands) { diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -597,6 +597,48 @@ ); } +// Ops exercising the `oilist` directive of the declarative assembly format. +def OIListTrivial : TEST_Op<"oilist_with_keywords_only"> { + let assemblyFormat = [{ + oilist[ `keyword` + | `otherKeyword`] attr-dict + }]; +} + +def OIListSimple : TEST_Op<"oilist_with_simple_args", [AttrSizedOperandSegments]> { + let arguments = (ins Optional:$arg0, + Optional:$arg1, + Optional:$arg2); + let assemblyFormat = [{ + oilist[ `keyword` $arg0 `:` type($arg0) + | `otherKeyword` $arg1 `:` type($arg1) + | `thirdKeyword` $arg2 `:` type($arg2) ] attr-dict + }]; +} + +def OIListVariadic :
TEST_Op<"oilist_variadic_with_parens", [AttrSizedOperandSegments]> { + let arguments = (ins Variadic:$arg0, + Variadic:$arg1, + Variadic:$arg2); + let assemblyFormat = [{ + oilist[ `keyword` `(` $arg0 `:` type($arg0) `)` + | `otherKeyword` `(` $arg1 `:` type($arg1) `)` + | `thirdKeyword` `(` $arg2 `:` type($arg2) `)`] attr-dict + }]; +} + +def OIListCustom : TEST_Op<"oilist_custom", [AttrSizedOperandSegments]> { + let arguments = (ins Variadic:$arg0, + Optional:$optOperand, + UnitAttr:$nowait); + let assemblyFormat = [{ + oilist[ `private` `(` $arg0 `:` type($arg0) `)` + | `nowait` + | `reduction` custom($optOperand) + ] attr-dict + }]; +} + // This is used to test encoding of a string attribute into an SSA name of a // pretty printed value name. def StringAttrPrettyNameOp diff --git a/mlir/test/mlir-tblgen/op-format-spec.td b/mlir/test/mlir-tblgen/op-format-spec.td --- a/mlir/test/mlir-tblgen/op-format-spec.td +++ b/mlir/test/mlir-tblgen/op-format-spec.td @@ -344,6 +344,44 @@ attr-dict }]>; +//===----------------------------------------------------------------------===// +// oilist + +// CHECK: error: expected literal, but got ']' +def OIListErrorExpectedLiteral : TestFormat_Op<[{ + oilist[ `keyword` | ] attr-dict +}]>; +// CHECK: error: expected literal, but got ']' +def OIListErrorExpectedEmpty : TestFormat_Op<[{ + oilist[] attr-dict +}]>; +// CHECK: error: expected literal, but got '$arg0' +def OIListErrorNoLiteral : TestFormat_Op<[{ + oilist[ $arg0 `:` type($arg0) | $arg1 `:` type($arg1) ] attr-dict +}], [AttrSizedOperandSegments]>, Arguments<(ins Optional:$arg0, Optional:$arg1)>; + +// CHECK-NOT: error +def OIListTrivial : TestFormat_Op<[{ + oilist[`keyword` `(` `)` | `otherkeyword` `(` `)`] attr-dict +}]>; +def OIListSimple : TestFormat_Op<[{ + oilist[ `keyword` $arg0 `:` type($arg0) + | `otherkeyword` $arg1 `:` type($arg1) + | `thirdkeyword` $arg2 `:` type($arg2) ] + attr-dict +}], [AttrSizedOperandSegments]>, Arguments<(ins Optional:$arg0, Optional:$arg1,
Optional:$arg2)>; +def OIListVariadic : TestFormat_Op<[{ + oilist[ `keyword` `(` $args0 `:` type($args0) `)` + | `otherkeyword` `(` $args1 `:` type($args1) `)` + | `thirdkeyword` `(` $args2 `:` type($args2) `)`] + attr-dict +}], [AttrSizedOperandSegments]>, Arguments<(ins Variadic:$args0, Variadic:$args1, Variadic:$args2)>; +def OIListCustom : TestFormat_Op<[{ + oilist[ `private` `(` $arg0 `:` type($arg0) `)` + | `nowait` + | `reduction` custom($arg1, type($arg1))] attr-dict +}], [AttrSizedOperandSegments]>, Arguments<(ins Optional:$arg0, Optional:$arg1)>; + //===----------------------------------------------------------------------===// // Optional Groups //===----------------------------------------------------------------------===// diff --git a/mlir/tools/mlir-tblgen/FormatGen.h b/mlir/tools/mlir-tblgen/FormatGen.h --- a/mlir/tools/mlir-tblgen/FormatGen.h +++ b/mlir/tools/mlir-tblgen/FormatGen.h @@ -42,6 +42,8 @@ // Tokens with no info. l_paren, r_paren, + l_square, + r_square, caret, colon, comma, @@ -50,6 +52,7 @@ greater, question, star, + pipe, // Keywords. keyword_start, @@ -65,6 +68,7 @@ kw_struct, kw_successors, kw_type, + kw_oilist, keyword_end, // String valued tokens. diff --git a/mlir/tools/mlir-tblgen/FormatGen.cpp b/mlir/tools/mlir-tblgen/FormatGen.cpp --- a/mlir/tools/mlir-tblgen/FormatGen.cpp +++ b/mlir/tools/mlir-tblgen/FormatGen.cpp @@ -113,8 +113,14 @@ return formToken(FormatToken::l_paren, tokStart); case ')': return formToken(FormatToken::r_paren, tokStart); + case '[': + return formToken(FormatToken::l_square, tokStart); + case ']': + return formToken(FormatToken::r_square, tokStart); case '*': return formToken(FormatToken::star, tokStart); + case '|': + return formToken(FormatToken::pipe, tokStart); // Ignore whitespace characters.
case 0: @@ -172,6 +178,7 @@ .Case("struct", FormatToken::kw_struct) .Case("successors", FormatToken::kw_successors) .Case("type", FormatToken::kw_type) + .Case("oilist", FormatToken::kw_oilist) .Default(FormatToken::identifier); return FormatToken(kind, str); } diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp --- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp @@ -50,6 +50,7 @@ ResultsDirective, SuccessorsDirective, TypeDirective, + OIListDirective, /// This element is a literal. Literal, @@ -356,6 +357,49 @@ }; } // namespace +//===----------------------------------------------------------------------===// +// OIListElement + +namespace { +class OIListElement : public Element { +public: + OIListElement( + std::vector> &&literalElements, + std::vector>> &&parsingElements) + : Element{Kind::OIListDirective}, + literalElements(std::move(literalElements)), + parsingElements(std::move(parsingElements)) {} + + static bool classof(const Element *element) { + return element->getKind() == Kind::OIListDirective; + } + + int numElements() const { return literalElements.size(); } + +public: + // The keyword literal of each clause, one `LiteralElement` per clause, in + // declaration order. For example, an oilist element along with the + // literalElements vector: + // ``` + // oilist [ `keyword` `=` `(` $arg0 `)` | `otherKeyword` `<` $arg1 `>`] + // literalElements = { `keyword`, `otherKeyword` } + // ``` + std::vector> literalElements; + + // One list of format elements per clause, parallel to `literalElements`. + // Each list holds the elements that follow the clause keyword, in order.
+ // For example, an oilist element along with the parsingElements vector: + // ``` + // oilist [ `keyword` `=` `(` $arg0 `)` | `otherKeyword` `<` $arg1 `>`] + // parsingElements = { + // { `=`, `(`, $arg0, `)` }, + // { `<`, $arg1, `>` } + // } + // ``` + std::vector<std::vector<std::unique_ptr<Element>>> parsingElements; +}; +} // namespace + //===----------------------------------------------------------------------===// // OperationFormat //===----------------------------------------------------------------------===// @@ -782,6 +826,19 @@ return ::mlir::failure(); )"; +/// The code snippet used to generate the parser for one clause of an oilist. +/// Opening C++ block braces are written `{{` so that formatv emits them as-is. +/// {0}: The keyword literal that introduces the clause. +const char *oilistParserCode = R"( + if ({0}Clause) {{ + return parser.emitError(parser.getCurrentLocation()) + << "the `{0}` clause can appear at most once in the oilist of the " + << result.name.getStringRef() << " operation"; + } + {0}Clause = true; + result.addAttribute("{0}", ::mlir::UnitAttr::get(parser.getContext())); +)"; + namespace { /// The type of length for a given parse argument.
enum class ArgumentLengthKind { @@ -871,6 +928,11 @@ for (auto &childElement : optional->getElseElements()) genElementParserStorage(&childElement, op, body); + } else if (auto *oilist = dyn_cast<OIListElement>(element)) { + for (auto &clause : oilist->parsingElements) + for (auto &clauseElement : clause) + genElementParserStorage(clauseElement.get(), op, body); + } else if (auto *custom = dyn_cast<CustomDirective>(element)) { for (auto &paramElement : custom->getArguments()) genElementParserStorage(&paramElement, op, body); @@ -1254,6 +1316,37 @@ } body << "\n"; + } else if (OIListElement *oilist = dyn_cast<OIListElement>(element)) { + + // Declare one flag per clause so that a repeated clause can be diagnosed. + for (int idx = 0; idx < oilist->numElements(); idx++) { + auto *literalElement = + cast<LiteralElement>(oilist->literalElements[idx].get()); + body << " bool " << literalElement->getLiteral() << "Clause = false" + << ";\n"; + } + + // Generate a loop that parses the clauses in any order until none matches. + body << " while(true) {\n"; + for (int idx = 0; idx < oilist->numElements(); idx++) { + body << (idx == 0 ? " " : " else "); + auto *literalElement = + cast<LiteralElement>(oilist->literalElements[idx].get()); + body << "if (succeeded(parser.parseOptional"; + genLiteralParser(literalElement->getLiteral(), body); + body << ")) {\n"; + StringRef attrname = literalElement->getLiteral(); + body << formatv(oilistParserCode, attrname); + inferredAttributes.insert(attrname); + for (auto &parsingElement : oilist->parsingElements[idx]) + genElementParser(parsingElement.get(), body, attrTypeCtx); + body << " }\n"; + } + body << " else {\n"; + body << " break;\n"; + body << " }\n"; + body << " }\n"; + /// Literals.
} else if (LiteralElement *literal = dyn_cast<LiteralElement>(element)) { body << " if (parser.parse"; @@ -1981,6 +2074,25 @@ return; } + if (auto *oilist = dyn_cast<OIListElement>(element)) { + genLiteralPrinter(" ", body, shouldEmitSpace, lastWasPunctuation); + for (int i = 0; i < oilist->numElements(); i++) { + auto *literalElement = + cast<LiteralElement>(oilist->literalElements[i].get()); + auto &clauseElements = oilist->parsingElements[i]; + body << " if ((*this)->hasAttrOfType<::mlir::UnitAttr>(\"" + << literalElement->getLiteral() << "\")) {\n"; + genLiteralPrinter(literalElement->getLiteral(), body, shouldEmitSpace, + lastWasPunctuation); + for (auto &clauseElement : clauseElements) { + genElementPrinter(clauseElement.get(), body, op, shouldEmitSpace, + lastWasPunctuation); + } + body << " }\n"; + } + return; + } + // Emit the attribute dictionary. if (auto *attrDict = dyn_cast<AttrDictDirective>(element)) { genAttrDictPrinter(*this, op, body, attrDict->isWithKeyword()); @@ -2255,6 +2367,8 @@ ParserContext context); LogicalResult parseTypeDirective(std::unique_ptr<Element> &element, FormatToken tok, ParserContext context); + LogicalResult parseOIListDirective(std::unique_ptr<Element> &element, + FormatToken tok, ParserContext context); LogicalResult parseTypeDirectiveOperand(std::unique_ptr<Element> &element, bool isRefChild = false); @@ -2756,6 +2870,8 @@ return parseReferenceDirective(element, dirTok.getLoc(), context); case FormatToken::kw_type: return parseTypeDirective(element, dirTok, context); + case FormatToken::kw_oilist: + return parseOIListDirective(element, dirTok, context); default: llvm_unreachable("unknown directive token"); @@ -2772,7 +2888,14 @@ } consumeToken(); - StringRef value = literalTok.getSpelling().drop_front().drop_back(); + StringRef value = literalTok.getSpelling(); + // Reject tokens that are not backtick-quoted literals (e.g. `$arg0`, or `]` + // in an empty oilist clause) rather than stripping delimiters they lack.
+ if (value.size() >= 2 && value.front() == '`' && value.back() == '`') + value = value.drop_front().drop_back(); + else + return emitError(literalTok.getLoc(), + "expected literal, but got '" + value + "'"); // The parsed literal is a space element (`` or ` `). if (value.empty() || (value.size() == 1 && value.front() == ' ')) { @@ -3140,6 +3263,72 @@ return ::mlir::success(); } +/// Returns true if `keyword` may introduce an oilist clause. The keyword is +/// spliced into the generated parser as part of a C++ variable name +/// (`<keyword>Clause`) and used as a unit attribute name, so it must be an +/// identifier: non-empty, made of [a-zA-Z0-9_], and not starting with a digit. +static bool isValidOIListKeyword(StringRef keyword) { + if (keyword.empty() || (keyword.front() >= '0' && keyword.front() <= '9')) + return false; + return keyword.find_first_not_of("abcdefghijklmnopqrstuvwxyz" + "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + "0123456789_") == StringRef::npos; +} + +/// Parses `oilist` `[` clause (`|` clause)* `]`, where each clause is a keyword +/// literal followed by zero or more format elements. In the custom assembly the +/// clauses may then appear in any order, each at most once. +LogicalResult +FormatParser::parseOIListDirective(std::unique_ptr<Element> &element, + FormatToken tok, ParserContext context) { + if (failed(parseToken(FormatToken::l_square, + "expected '[' before argument list"))) + return failure(); + std::vector<std::unique_ptr<Element>> literalElements; + std::vector<std::vector<std::unique_ptr<Element>>> parsingElements; + while (true) { + // Each clause is introduced by a keyword literal. + FormatToken keywordTok = curToken; + std::unique_ptr<Element> keywordElement; + if (failed(parseLiteral(keywordElement, context))) + return failure(); + auto *literal = dyn_cast<LiteralElement>(keywordElement.get()); + if (!literal || !isValidOIListKeyword(literal->getLiteral())) + return emitError(keywordTok.getLoc(), + "expected a keyword literal at the start of each " + "oilist clause"); + StringRef keyword = literal->getLiteral(); + // The generated parser declares one `<keyword>Clause` flag per clause, so a + // repeated keyword would produce generated code that does not compile. + for (const std::unique_ptr<Element> &previous : literalElements) + if (cast<LiteralElement>(previous.get())->getLiteral() == keyword) + return emitError(keywordTok.getLoc(), "duplicate `" + keyword + + "` clause in oilist directive"); + literalElements.push_back(std::move(keywordElement)); + + // The clause body extends up to the next `|` or the closing `]`. + parsingElements.emplace_back(); + std::vector<std::unique_ptr<Element>> &clauseElements = + parsingElements.back(); + while (curToken.getKind() != FormatToken::pipe && + curToken.getKind() != FormatToken::r_square) { + clauseElements.emplace_back(); + if (failed(parseElement(clauseElements.back(), context))) + return failure(); + } + + // A `|` starts another clause; `]` ends the directive. + bool isLastClause = curToken.getKind() == FormatToken::r_square; + consumeToken(); + if (isLastClause) + break; + } + + element = std::make_unique<OIListElement>(std::move(literalElements), + std::move(parsingElements)); + return success(); +} + LogicalResult FormatParser::parseTypeDirective(std::unique_ptr<Element> &element, FormatToken tok, ParserContext context) {