diff --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md --- a/mlir/docs/OpDefinitions.md +++ b/mlir/docs/OpDefinitions.md @@ -699,6 +699,14 @@ - `input` must be either an operand or result [variable](#variables), the `operands` directive, or the `results` directive. +* `type_ref` ( input ) + + - Represents a reference to the type of the given input that must have + already been resolved. + - `input` must be either an operand or result [variable](#variables), the + `operands` directive, or the `results` directive. + - Used to pass previously parsed types to custom directives. + #### Literals A literal is either a keyword or punctuation surrounded by \`\`. diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp --- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp @@ -53,6 +53,7 @@ ResultsDirective, SuccessorsDirective, TypeDirective, + TypeRefDirective, /// This element is a literal. Literal, @@ -226,6 +227,18 @@ TypeDirective(std::unique_ptr arg) : operand(std::move(arg)) {} Element *getOperand() const { return operand.get(); } +private: + /// The operand that is used to format the directive. + std::unique_ptr operand; +}; + +/// This class represents the `type_ref` directive. +class TypeRefDirective + : public DirectiveElement { +public: + TypeRefDirective(std::unique_ptr arg) : operand(std::move(arg)) {} + Element *getOperand() const { return operand.get(); } + private: /// The operand that is used to format the directive. std::unique_ptr operand; @@ -805,6 +818,18 @@ << llvm::formatv( " ::llvm::ArrayRef<::mlir::Type> {0}Types({0}RawTypes);\n", name); + } else if (auto *dir = dyn_cast(element)) { + ArgumentLengthKind lengthKind; + StringRef name = getTypeListName(dir->getOperand(), lengthKind); + // Refer to the previously encountered TypeDirective for name. 
+ // Take a `::mlir::SmallVector<::mlir::Type, 1> &` in the declaration to + // properly track the types that will be parsed and pushed later on. + if (lengthKind != ArgumentLengthKind::Single) + body << " ::mlir::SmallVector<::mlir::Type, 1> &" << name << "TypesRef(" + << name << "Types);\n"; + else + body << llvm::formatv( + " ::llvm::ArrayRef<::mlir::Type> {0}TypesRef({0}RawTypes);\n", name); } else if (auto *dir = dyn_cast(element)) { ArgumentLengthKind ignored; body << " ::llvm::ArrayRef<::mlir::Type> " @@ -844,6 +869,15 @@ else body << llvm::formatv("{0}Successor", name); + } else if (auto *dir = dyn_cast(&param)) { + ArgumentLengthKind lengthKind; + StringRef listName = getTypeListName(dir->getOperand(), lengthKind); + if (lengthKind == ArgumentLengthKind::Variadic) + body << llvm::formatv("{0}TypesRef", listName); + else if (lengthKind == ArgumentLengthKind::Optional) + body << llvm::formatv("{0}Type", listName); + else + body << formatv("{0}RawTypes[0]", listName); } else if (auto *dir = dyn_cast(&param)) { ArgumentLengthKind lengthKind; StringRef listName = getTypeListName(dir->getOperand(), lengthKind); @@ -876,6 +910,8 @@ "{0}Operand;\n", operand->getVar()->name); } + } else if (auto *dir = dyn_cast(&param)) { + // TODO: is noop good enough? } else if (auto *dir = dyn_cast(&param)) { ArgumentLengthKind lengthKind; StringRef listName = getTypeListName(dir->getOperand(), lengthKind); @@ -907,6 +943,8 @@ body << llvm::formatv(" if ({0}Operand.hasValue())\n" " {0}Operands.push_back(*{0}Operand);\n", var->name); + } else if (auto *dir = dyn_cast(&param)) { + // TODO: is noop good enough? 
} else if (auto *dir = dyn_cast(&param)) { ArgumentLengthKind lengthKind; StringRef listName = getTypeListName(dir->getOperand(), lengthKind); @@ -1098,6 +1136,8 @@ } else if (isa(element)) { body << llvm::formatv(successorListParserCode, "full"); + } else if (auto *dir = dyn_cast(element)) { + llvm_unreachable("Cannot use `type_ref` directive inside ElementParser"); } else if (auto *dir = dyn_cast(element)) { ArgumentLengthKind lengthKind; StringRef listName = getTypeListName(dir->getOperand(), lengthKind); @@ -1428,6 +1468,17 @@ } else if (auto *successor = dyn_cast(&param)) { body << successor->getVar()->name << "()"; + } else if (auto *dir = dyn_cast(&param)) { + auto *typeOperand = dir->getOperand(); + auto *operand = dyn_cast(typeOperand); + auto *var = operand ? operand->getVar() + : cast(typeOperand)->getVar(); + if (var->isVariadic()) + body << var->name << "().getTypes()"; + else if (var->isOptional()) + body << llvm::formatv("({0}() ? {0}().getType() : Type())", var->name); + else + body << var->name << "().getType()"; } else if (auto *dir = dyn_cast(&param)) { auto *typeOperand = dir->getOperand(); auto *operand = dyn_cast(typeOperand); @@ -1605,6 +1656,8 @@ body << " p.printFunctionalType("; genTypeOperandPrinter(dir->getInputs(), body) << ", "; genTypeOperandPrinter(dir->getResults(), body) << ");\n"; + } else if (auto *dir = dyn_cast(element)) { + llvm_unreachable("type_ref does not print"); } else { llvm_unreachable("unknown format element"); } @@ -1666,6 +1719,7 @@ kw_results, kw_successors, kw_type, + kw_type_ref, keyword_end, // String valued tokens. 
@@ -1870,6 +1924,7 @@ .Case("results", Token::kw_results) .Case("successors", Token::kw_successors) .Case("type", Token::kw_type) + .Case("type_ref", Token::kw_type_ref) .Default(Token::identifier); return Token(kind, str); } @@ -1990,8 +2045,9 @@ LogicalResult parseSuccessorsDirective(std::unique_ptr &element, llvm::SMLoc loc, bool isTopLevel); LogicalResult parseTypeDirective(std::unique_ptr &element, Token tok, - bool isTopLevel); - LogicalResult parseTypeDirectiveOperand(std::unique_ptr &element); + bool isTopLevel, bool isTypeRef = false); + LogicalResult parseTypeDirectiveOperand(std::unique_ptr &element, + bool isTypeRef = false); //===--------------------------------------------------------------------===// // Lexer Utilities @@ -2436,6 +2492,8 @@ return parseResultsDirective(element, dirTok.getLoc(), isTopLevel); case Token::kw_successors: return parseSuccessorsDirective(element, dirTok.getLoc(), isTopLevel); + case Token::kw_type_ref: + return parseTypeDirective(element, dirTok, isTopLevel, /*isTypeRef=*/true); case Token::kw_type: return parseTypeDirective(element, dirTok, isTopLevel); @@ -2501,7 +2559,10 @@ return ::mlir::success(); }; for (auto &ele : elements) { - if (auto *typeEle = dyn_cast(ele.get())) { + if (auto *typeEle = dyn_cast(ele.get())) { + if (failed(checkTypeOperand(typeEle->getOperand()))) + return failure(); + } else if (auto *typeEle = dyn_cast(ele.get())) { if (failed(checkTypeOperand(typeEle->getOperand()))) return ::mlir::failure(); } else if (auto *typeEle = dyn_cast(ele.get())) { @@ -2561,7 +2622,7 @@ // Literals, custom directives, and type directives may be used, // but they can't anchor the group. .Case([&](Element *) { + OptionalElement, TypeRefDirective, TypeDirective>([&](Element *) { if (isAnchor) return emitError(childLoc, "only variables can be used to anchor " "an optional group"); @@ -2624,6 +2685,13 @@ // After parsing all of the elements, ensure that all type directives refer // only to variables. 
for (auto &ele : elements) { + if (auto *typeEle = dyn_cast(ele.get())) { + if (!isa(typeEle->getOperand())) { + return emitError(curLoc, + "type_ref directives within a custom directive " + "may only refer to variables"); + } + } if (auto *typeEle = dyn_cast(ele.get())) { if (!isa(typeEle->getOperand())) { return emitError(curLoc, "type directives within a custom directive " @@ -2645,8 +2713,8 @@ return ::mlir::failure(); // Verify that the element can be placed within a custom directive. - if (!isa(parameters.back().get())) { + if (!isa(parameters.back().get())) { return emitError(childLoc, "only variables and types may be used as " "parameters to a custom directive"); } @@ -2723,22 +2791,26 @@ LogicalResult FormatParser::parseTypeDirective(std::unique_ptr &element, Token tok, - bool isTopLevel) { + bool isTopLevel, bool isTypeRef) { llvm::SMLoc loc = tok.getLoc(); if (!isTopLevel) return emitError(loc, "'type' is only valid as a top-level directive"); std::unique_ptr operand; if (failed(parseToken(Token::l_paren, "expected '(' before argument list")) || - failed(parseTypeDirectiveOperand(operand)) || + failed(parseTypeDirectiveOperand(operand, isTypeRef)) || failed(parseToken(Token::r_paren, "expected ')' after argument list"))) return ::mlir::failure(); - element = std::make_unique(std::move(operand)); + if (isTypeRef) + element = std::make_unique(std::move(operand)); + else + element = std::make_unique(std::move(operand)); return ::mlir::success(); } LogicalResult -FormatParser::parseTypeDirectiveOperand(std::unique_ptr &element) { +FormatParser::parseTypeDirectiveOperand(std::unique_ptr &element, + bool isTypeRef) { llvm::SMLoc loc = curToken.getLoc(); if (failed(parseElement(element, /*isTopLevel=*/false))) return ::mlir::failure(); @@ -2749,22 +2821,26 @@ if (auto *var = dyn_cast(element.get())) { unsigned opIdx = var->getVar() - op.operand_begin(); if (fmt.allOperandTypes || seenOperandTypes.test(opIdx)) - return emitError(loc, "'type' of '" + 
var->getVar()->name + - "' is already bound"); + if (!isTypeRef) + return emitError(loc, "'type' of '" + var->getVar()->name + + "' is already bound"); seenOperandTypes.set(opIdx); } else if (auto *var = dyn_cast(element.get())) { unsigned resIdx = var->getVar() - op.result_begin(); if (fmt.allResultTypes || seenResultTypes.test(resIdx)) - return emitError(loc, "'type' of '" + var->getVar()->name + - "' is already bound"); + if (!isTypeRef) + return emitError(loc, "'type' of '" + var->getVar()->name + + "' is already bound"); seenResultTypes.set(resIdx); } else if (isa(&*element)) { if (fmt.allOperandTypes || seenOperandTypes.any()) - return emitError(loc, "'operands' 'type' is already bound"); + if (!isTypeRef) + return emitError(loc, "'operands' 'type' is already bound"); fmt.allOperandTypes = true; } else if (isa(&*element)) { if (fmt.allResultTypes || seenResultTypes.any()) - return emitError(loc, "'results' 'type' is already bound"); + if (!isTypeRef) + return emitError(loc, "'results' 'type' is already bound"); fmt.allResultTypes = true; } else { return emitError(loc, "invalid argument to 'type' directive");