diff --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md --- a/mlir/docs/OpDefinitions.md +++ b/mlir/docs/OpDefinitions.md @@ -620,6 +620,13 @@ - The constraints on `inputs` and `results` are the same as the `input` of the `type` directive. +* `oilist` ( \`keyword\` elements | \`otherKeyword\` elements ...) + + - Represents an optional order-independent list of clauses. Each clause + has a keyword associated with it and a string of other assembly format + elements. + - Each clause can appear 0 or 1 times (in any order). + * `operands` - Represents all of the operands of an operation. diff --git a/mlir/test/IR/traits.mlir b/mlir/test/IR/traits.mlir --- a/mlir/test/IR/traits.mlir +++ b/mlir/test/IR/traits.mlir @@ -488,6 +488,75 @@ // ----- +// CHECK-LABEL: @succeededOilistTrivial +func @succeededOilistTrivial() { + // CHECK: test.oilist_with_keywords_only keyword + test.oilist_with_keywords_only keyword + // CHECK: test.oilist_with_keywords_only otherKeyword + test.oilist_with_keywords_only otherKeyword + // CHECK: test.oilist_with_keywords_only keyword otherKeyword + test.oilist_with_keywords_only keyword otherKeyword + // CHECK: test.oilist_with_keywords_only keyword otherKeyword + test.oilist_with_keywords_only otherKeyword keyword + return +} + +// ----- + +// CHECK-LABEL: @succeededOilistSimple +func @succeededOilistSimple(%arg0 : i32, %arg1 : i32, %arg2 : i32) { + // CHECK: test.oilist_with_simple_args keyword %{{.*}} : i32 + test.oilist_with_simple_args keyword %arg0 : i32 + // CHECK: test.oilist_with_simple_args otherKeyword %{{.*}} : i32 + test.oilist_with_simple_args otherKeyword %arg0 : i32 + // CHECK: test.oilist_with_simple_args thirdKeyword %{{.*}} : i32 + test.oilist_with_simple_args thirdKeyword %arg0 : i32 + + // CHECK: test.oilist_with_simple_args keyword %{{.*}} : i32 otherKeyword %{{.*}} : i32 + test.oilist_with_simple_args keyword %arg0 : i32 otherKeyword %arg1 : i32 + // CHECK: test.oilist_with_simple_args keyword %{{.*}} : i32 thirdKeyword %{{.*}} : i32 + test.oilist_with_simple_args keyword %arg0 : i32 thirdKeyword %arg1 : i32 + // CHECK: test.oilist_with_simple_args otherKeyword %{{.*}} : i32 thirdKeyword %{{.*}} : i32 + test.oilist_with_simple_args thirdKeyword %arg0 : i32 otherKeyword %arg1 : i32 + + // CHECK: test.oilist_with_simple_args keyword %{{.*}} : i32 otherKeyword %{{.*}} : i32 thirdKeyword %{{.*}} : i32 + test.oilist_with_simple_args keyword %arg0 : i32 otherKeyword %arg1 : i32 thirdKeyword %arg2 : i32 + // CHECK: test.oilist_with_simple_args keyword %{{.*}} : i32 otherKeyword %{{.*}} : i32 thirdKeyword %{{.*}} : i32 + test.oilist_with_simple_args otherKeyword %arg0 : i32 keyword %arg1 : i32 thirdKeyword %arg2 : i32 + // CHECK: test.oilist_with_simple_args keyword %{{.*}} : i32 otherKeyword %{{.*}} : i32 thirdKeyword %{{.*}} : i32 + test.oilist_with_simple_args otherKeyword %arg0 : i32 thirdKeyword %arg1 : i32 keyword %arg2 : i32 + return +} + +// ----- + +// CHECK-LABEL: @succeededOilistVariadic +// CHECK-SAME: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32) +func @succeededOilistVariadic(%arg0: i32, %arg1: i32, %arg2: i32) { + // CHECK: test.oilist_variadic_with_parens keyword(%[[ARG0]], %[[ARG1]] : i32, i32) + test.oilist_variadic_with_parens keyword (%arg0, %arg1 : i32, i32) + // CHECK: test.oilist_variadic_with_parens keyword(%[[ARG0]], %[[ARG1]] : i32, i32) otherKeyword(%[[ARG2]], %[[ARG1]] : i32, i32) + test.oilist_variadic_with_parens otherKeyword (%arg2, %arg1 : i32, i32) keyword (%arg0, %arg1 : i32, i32) + // CHECK: test.oilist_variadic_with_parens keyword(%[[ARG0]], %[[ARG1]] : i32, i32) otherKeyword(%[[ARG0]], %[[ARG1]] : i32, i32) thirdKeyword(%[[ARG2]], %[[ARG0]], %[[ARG1]] : i32, i32, i32) + test.oilist_variadic_with_parens thirdKeyword (%arg2, %arg0, %arg1 : i32, i32, i32) keyword (%arg0, %arg1 : i32, i32) otherKeyword (%arg0, %arg1 : i32, i32) + return +} + +// ----- +// CHECK-LABEL: succeededOilistCustom +// CHECK-SAME: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32) +func @succeededOilistCustom(%arg0: i32, %arg1: i32, %arg2: i32) { + // CHECK: test.oilist_custom private(%[[ARG0]], %[[ARG1]] : i32, i32) + test.oilist_custom private (%arg0, %arg1 : i32, i32) + // CHECK: test.oilist_custom private(%[[ARG0]], %[[ARG1]] : i32, i32) nowait + test.oilist_custom private (%arg0, %arg1 : i32, i32) nowait + // CHECK: test.oilist_custom private(%arg0, %arg1 : i32, i32) nowait reduction (%arg1) + test.oilist_custom nowait reduction (%arg1) private (%arg0, %arg1 : i32, i32) + return +} + +// ----- + func @failedHasDominanceScopeOutsideDominanceFreeScope() -> () { "test.ssacfg_region"() ({ test.graph_region { diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -398,6 +398,16 @@ //===----------------------------------------------------------------------===// // Parsing +static ParseResult +parseCustomOptionalOperand(OpAsmParser &parser, + Optional &optOperand) { + if (succeeded(parser.parseOptionalLParen())) { + optOperand.emplace(); + if (parser.parseOperand(*optOperand) || parser.parseRParen()) + return failure(); + } + return success(); +} static ParseResult parseCustomDirectiveOperands( OpAsmParser &parser, OpAsmParser::OperandType &operand, Optional &optOperand, @@ -516,6 +526,11 @@ //===----------------------------------------------------------------------===// // Printing +static void printCustomOptionalOperand(OpAsmPrinter &printer, Operation *, + Value optOperand) { + if (optOperand) + printer << "(" << optOperand << ") "; +} static void printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *, Value operand, Value optOperand, OperandRange varOperands) { diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -601,6 +601,48 @@ // is the dialect parser and printer hooks. def CustomFormatFallbackOp : TEST_Op<"dialect_custom_format_fallback">; +// Ops related to OIList primitive +def OIListTrivial : TEST_Op<"oilist_with_keywords_only"> { + let assemblyFormat = [{ + oilist( `keyword` + | `otherKeyword`) attr-dict + }]; +} + +def OIListSimple : TEST_Op<"oilist_with_simple_args", [AttrSizedOperandSegments]> { + let arguments = (ins Optional:$arg0, + Optional:$arg1, + Optional:$arg2); + let assemblyFormat = [{ + oilist( `keyword` $arg0 `:` type($arg0) + | `otherKeyword` $arg1 `:` type($arg1) + | `thirdKeyword` $arg2 `:` type($arg2) ) attr-dict + }]; +} + +def OIListVariadic : TEST_Op<"oilist_variadic_with_parens", [AttrSizedOperandSegments]> { + let arguments = (ins Variadic:$arg0, + Variadic:$arg1, + Variadic:$arg2); + let assemblyFormat = [{ + oilist( `keyword` `(` $arg0 `:` type($arg0) `)` + | `otherKeyword` `(` $arg1 `:` type($arg1) `)` + | `thirdKeyword` `(` $arg2 `:` type($arg2) `)`) attr-dict + }]; +} + +def OIListCustom : TEST_Op<"oilist_custom", [AttrSizedOperandSegments]> { + let arguments = (ins Variadic:$arg0, + Optional:$optOperand, + UnitAttr:$nowait); + let assemblyFormat = [{ + oilist( `private` `(` $arg0 `:` type($arg0) `)` + | `nowait` + | `reduction` custom($optOperand) + ) attr-dict + }]; +} + // This is used to test encoding of a string attribute into an SSA name of a // pretty printed value name. def StringAttrPrettyNameOp diff --git a/mlir/test/mlir-tblgen/op-format-spec.td b/mlir/test/mlir-tblgen/op-format-spec.td --- a/mlir/test/mlir-tblgen/op-format-spec.td +++ b/mlir/test/mlir-tblgen/op-format-spec.td @@ -344,6 +344,48 @@ attr-dict }]>; +//===----------------------------------------------------------------------===// +// oilist + +// CHECK: error: expected literal, but got ')' +def OIListErrorExpectedLiteral : TestFormat_Op<[{ + oilist( `keyword` | ) attr-dict +}]>; +// CHECK: error: expected literal, but got ')' +def OIListErrorExpectedEmpty : TestFormat_Op<[{ + oilist() attr-dict +}]>; +// CHECK: error: expected literal, but got '$arg0' +def OIListErrorNoLiteral : TestFormat_Op<[{ + oilist( $arg0 `:` type($arg0) | $arg1 `:` type($arg1) ) attr-dict +}], [AttrSizedOperandSegments]>, Arguments<(ins Optional:$arg0, Optional:$arg1)>; +// CHECK: error: expected '(' before oilist argument list +def OIListStartingToken : TestFormat_Op<[{ + oilist `wrong` attr-dict +}]>; + +// CHECK-NOT: error +def OIListTrivial : TestFormat_Op<[{ + oilist(`keyword` `(` `)` | `otherkeyword` `(` `)`) attr-dict +}]>; +def OIListSimple : TestFormat_Op<[{ + oilist( `keyword` $arg0 `:` type($arg0) + | `otherkeyword` $arg1 `:` type($arg1) + | `thirdkeyword` $arg2 `:` type($arg2) ) + attr-dict +}], [AttrSizedOperandSegments]>, Arguments<(ins Optional:$arg0, Optional:$arg1, Optional:$arg2)>; +def OIListVariadic : TestFormat_Op<[{ + oilist( `keyword` `(` $args0 `:` type($args0) `)` + | `otherkeyword` `(` $args1 `:` type($args1) `)` + | `thirdkeyword` `(` $args2 `:` type($args2) `)`) + attr-dict +}], [AttrSizedOperandSegments]>, Arguments<(ins Variadic:$args0, Variadic:$args1, Variadic:$args2)>; +def OIListCustom : TestFormat_Op<[{ + oilist( `private` `(` $arg0 `:` type($arg0) `)` + | `nowait` + | `reduction` custom($arg1, type($arg1))) attr-dict +}], [AttrSizedOperandSegments]>, Arguments<(ins Optional:$arg0, Optional:$arg1)>; + //===----------------------------------------------------------------------===// // Optional Groups //===----------------------------------------------------------------------===// diff --git a/mlir/tools/mlir-tblgen/FormatGen.h b/mlir/tools/mlir-tblgen/FormatGen.h --- a/mlir/tools/mlir-tblgen/FormatGen.h +++ b/mlir/tools/mlir-tblgen/FormatGen.h @@ -50,6 +50,7 @@ greater, question, star, + pipe, // Keywords. keyword_start, @@ -57,6 +58,7 @@ kw_attr_dict_w_keyword, kw_custom, kw_functional_type, + kw_oilist, kw_operands, kw_params, kw_ref, diff --git a/mlir/tools/mlir-tblgen/FormatGen.cpp b/mlir/tools/mlir-tblgen/FormatGen.cpp --- a/mlir/tools/mlir-tblgen/FormatGen.cpp +++ b/mlir/tools/mlir-tblgen/FormatGen.cpp @@ -115,6 +115,8 @@ return formToken(FormatToken::r_paren, tokStart); case '*': return formToken(FormatToken::star, tokStart); + case '|': + return formToken(FormatToken::pipe, tokStart); // Ignore whitespace characters. case 0: @@ -164,6 +166,7 @@ .Case("attr-dict-with-keyword", FormatToken::kw_attr_dict_w_keyword) .Case("custom", FormatToken::kw_custom) .Case("functional-type", FormatToken::kw_functional_type) + .Case("oilist", FormatToken::kw_oilist) .Case("operands", FormatToken::kw_operands) .Case("params", FormatToken::kw_params) .Case("ref", FormatToken::kw_ref) diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp --- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp @@ -44,6 +44,7 @@ AttrDictDirective, CustomDirective, FunctionalTypeDirective, + OIListDirective, OperandsDirective, RefDirective, RegionsDirective, @@ -356,6 +357,73 @@ }; } // namespace +//===----------------------------------------------------------------------===// +// OIListElement + +namespace { +class OIListElement : public Element { + +private: + // A vector of `LiteralElement` objects. Each element stores *the keyword* for + // one case of oilist element. For example, an oilist element along with the + // literalElements vector: + // ``` + // oilist [ `keyword` `=` `(` $arg0 `)` | `otherKeyword` `<` $arg1 `>`] + // literalElements = { `keyword`, `otherKeyword` } + // ``` + std::vector> literalElements; + + // A vector of valid declarative assembly format vectors. Each object in + // parsing elements is a vector of elements in assembly format syntax. + // For example, an oilist element along with the parsingElements vector: + // ``` + // oilist [ `keyword` `=` `(` $arg0 `)` | `otherKeyword` `<` $arg1 `>`] + // parsingElements = { + // { `=`, `(`, $arg0, `)` }, + // { `<`, $arg1, `>` } + // } + // ``` + std::vector>> parsingElements; + + // A helper function to take a "vector of unique pointers to Element" and + // return an iterator for a "vector to Elements". This function is used in + // `getParsingElements`. + static auto + vectorConverter(const std::vector> &vec) { + std::function &)> converter = + [](auto &ell) { return ell.get(); }; + return llvm::make_pointee_range(llvm::map_range(vec, converter)); + } + +public: + OIListElement( + std::vector> &&literalElements, + std::vector>> &&parsingElements) + : Element(Kind::OIListDirective), + literalElements(std::move(literalElements)), + parsingElements(std::move(parsingElements)) {} + + static bool classof(const Element *element) { + return element->getKind() == Kind::OIListDirective; + } + + // Returns an llvm::iterator_range to iterate over the LiteralElements. + auto getLiteralElements() const { + std::function &)> + literalElementCastConverter = + [](auto &el) { return cast(el.get()); }; + return llvm::make_pointee_range( + llvm::map_range(literalElements, literalElementCastConverter)); + } + + // Returns an llvm::iterator_range to iterate over the parsing elements + // coressponding to the clauses. + auto getParsingElements() const { + return llvm::map_range(parsingElements, OIListElement::vectorConverter); + } +}; +} // end anonymous namespace + //===----------------------------------------------------------------------===// // OperationFormat //===----------------------------------------------------------------------===// @@ -786,6 +854,19 @@ return ::mlir::failure(); )"; +/// The code snippet used to generate a parser for OIList +/// +/// {0}: literal keyword corresponding to a case for oilist +const char *oilistParserCode = R"( + if ({0}Clause) { + return parser.emitError(parser.getNameLoc()) + << "`{0}` clause can appear at most once in the expansion of the " + "oilist directive"; + } + {0}Clause = true; + result.addAttribute("{0}", UnitAttr::get(parser.getContext())); +)"; + namespace { /// The type of length for a given parse argument. enum class ArgumentLengthKind { @@ -875,6 +956,11 @@ for (auto &childElement : optional->getElseElements()) genElementParserStorage(&childElement, op, body); + } else if (auto *oilist = dyn_cast(element)) { + for (auto pelement : oilist->getParsingElements()) + for (auto &element : pelement) + genElementParserStorage(&element, op, body); + } else if (auto *custom = dyn_cast(element)) { for (auto ¶mElement : custom->getArguments()) genElementParserStorage(¶mElement, op, body); @@ -1258,6 +1344,33 @@ } body << "\n"; + // OIList Directive + } else if (OIListElement *oilist = dyn_cast(element)) { + for (LiteralElement &le : oilist->getLiteralElements()) { + body << " bool " << le.getLiteral() << "Clause = false;\n"; + } + + // Generate the parsing loop + body << " while(true) {\n"; + for (auto it : llvm::zip(oilist->getLiteralElements(), + oilist->getParsingElements())) { + LiteralElement &lelement = std::get<0>(it); + auto pelement = std::get<1>(it); + body << "if (succeeded(parser.parseOptional"; + genLiteralParser(lelement.getLiteral(), body); + body << ")) {\n"; + StringRef attrname = lelement.getLiteral(); + body << formatv(oilistParserCode, attrname); + inferredAttributes.insert(attrname); + for (auto &el : pelement) + genElementParser(&el, body, attrTypeCtx); + body << " } else "; + } + body << " {\n"; + body << " break;\n"; + body << " }\n"; + body << "}\n"; + /// Literals. } else if (LiteralElement *literal = dyn_cast(element)) { body << " if (parser.parse"; @@ -1990,6 +2103,26 @@ return; } + if (auto *oilist = dyn_cast(element)) { + genLiteralPrinter(" ", body, shouldEmitSpace, lastWasPunctuation); + for (auto it : llvm::zip(oilist->getLiteralElements(), + oilist->getParsingElements())) { + LiteralElement &lelement = std::get<0>(it); + auto pelement = std::get<1>(it); + + body << " if ((*this)->hasAttrOfType(\"" + << lelement.getLiteral() << "\")) {\n"; + genLiteralPrinter(lelement.getLiteral(), body, shouldEmitSpace, + lastWasPunctuation); + for (auto &element : pelement) { + genElementPrinter(&element, body, op, shouldEmitSpace, + lastWasPunctuation); + } + body << " }\n"; + } + return; + } + // Emit the attribute dictionary. if (auto *attrDict = dyn_cast(element)) { genAttrDictPrinter(*this, op, body, attrDict->isWithKeyword()); @@ -2264,6 +2397,8 @@ ParserContext context); LogicalResult parseTypeDirective(std::unique_ptr &element, FormatToken tok, ParserContext context); + LogicalResult parseOIListDirective(std::unique_ptr &element, + FormatToken tok, ParserContext context); LogicalResult parseTypeDirectiveOperand(std::unique_ptr &element, bool isRefChild = false); @@ -2765,6 +2900,8 @@ return parseReferenceDirective(element, dirTok.getLoc(), context); case FormatToken::kw_type: return parseTypeDirective(element, dirTok, context); + case FormatToken::kw_oilist: + return parseOIListDirective(element, dirTok, context); default: llvm_unreachable("unknown directive token"); @@ -2781,7 +2918,13 @@ } consumeToken(); - StringRef value = literalTok.getSpelling().drop_front().drop_back(); + StringRef value = literalTok.getSpelling(); + // Prevents things like `$arg0` or empty literals (when a literal is expected + // but not found) from getting segmentation faults. + if (value.size() < 2 || value[0] != '`' || value[value.size() - 1] != '`') + return emitError(literalTok.getLoc(), + "expected literal, but got '" + value + "'"); + value = value.drop_front().drop_back(); // The parsed literal is a space element (`` or ` `). if (value.empty() || (value.size() == 1 && value.front() == ' ')) { @@ -3149,6 +3292,41 @@ return ::mlir::success(); } +LogicalResult +FormatParser::parseOIListDirective(std::unique_ptr &element, + FormatToken tok, ParserContext context) { + if (failed(parseToken(FormatToken::l_paren, + "expected '(' before oilist argument list"))) + return failure(); + std::vector> literalElements; + std::vector>> parsingElements; + do { + literalElements.emplace_back(); + parsingElements.emplace_back(); + if (failed(parseLiteral(literalElements.back(), context))) + return failure(); + auto &currParsingElements = parsingElements.back(); + while (curToken.getKind() != FormatToken::pipe && + curToken.getKind() != FormatToken::r_paren) { + currParsingElements.emplace_back(); + if (failed(parseElement(currParsingElements.back(), context))) + return failure(); + } + if (curToken.getKind() == FormatToken::pipe) { + consumeToken(); + continue; + } + if (curToken.getKind() == FormatToken::r_paren) { + consumeToken(); + break; + } + } while (true); + + element = std::make_unique(std::move(literalElements), + std::move(parsingElements)); + return success(); +} + LogicalResult FormatParser::parseTypeDirective(std::unique_ptr &element, FormatToken tok, ParserContext context) {