diff --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md --- a/mlir/docs/OpDefinitions.md +++ b/mlir/docs/OpDefinitions.md @@ -699,6 +699,14 @@ - `input` must be either an operand or result [variable](#variables), the `operands` directive, or the `results` directive. +* `type_ref` ( input ) + + - Represents a reference to the type of the given input, which must + already have been bound earlier in the format (e.g. by a `type` directive). + - `input` must be either an operand or result [variable](#variables), the + `operands` directive, or the `results` directive. + - Used to pass previously parsed types to custom directives. + #### Literals A literal is either a keyword or punctuation surrounded by \`\`. @@ -762,6 +770,10 @@ - Single: `Type &` - Optional: `Type &` - Variadic: `SmallVectorImpl &` +* TypeRef Directives + - Single: `Type` + - Optional: `Type` + - Variadic: `const SmallVectorImpl &` When a variable is optional, the value should only be specified if the variable is present. Otherwise, the value should remain `None` or null. @@ -788,6 +800,10 @@ - Single: `Type` - Optional: `Type` - Variadic: `TypeRange` +* TypeRef Directives + - Single: `Type` + - Optional: `Type` + - Variadic: `TypeRange` When a variable is optional, the provided value may be null.
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -308,6 +308,25 @@ return failure(); return success(); } +static ParseResult +parseCustomDirectiveWithTypeRefs(OpAsmParser &parser, Type operandType, + Type optOperandType, + const SmallVectorImpl &varOperandTypes) { + if (parser.parseKeyword("type_refs_capture")) + return failure(); + + Type operandType2, optOperandType2; + SmallVector varOperandTypes2; + if (parseCustomDirectiveResults(parser, operandType2, optOperandType2, + varOperandTypes2)) + return failure(); + + if (operandType != operandType2 || optOperandType != optOperandType2 || + varOperandTypes != varOperandTypes2) + return failure(); + + return success(); +} static ParseResult parseCustomDirectiveOperandsAndTypes( OpAsmParser &parser, OpAsmParser::OperandType &operand, Optional &optOperand, @@ -365,6 +384,14 @@ printer << ", " << optOperandType; printer << " -> (" << varOperandTypes << ")"; } +static void printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer, + Type operandType, + Type optOperandType, + TypeRange varOperandTypes) { + printer << " type_refs_capture "; + printCustomDirectiveResults(printer, operandType, optOperandType, + varOperandTypes); +} static void printCustomDirectiveOperandsAndTypes(OpAsmPrinter &printer, Value operand, Value optOperand, OperandRange varOperands, diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -1518,6 +1518,22 @@ }]; } +def FormatCustomDirectiveResultsWithTypeRefs + : TEST_Op<"format_custom_directive_results_with_type_refs", + [AttrSizedResultSegments]> { + let results = (outs AnyType:$result, Optional:$optResult, + Variadic:$varResults); + let assemblyFormat = [{ + custom( + type($result), type($optResult), type($varResults) + ) + 
custom( + type_ref($result), type_ref($optResult), type_ref($varResults) + ) + attr-dict + }]; +} + def FormatCustomDirectiveSuccessors : TEST_Op<"format_custom_directive_successors", [Terminator]> { let successors = (successor AnySuccessor:$successor, diff --git a/mlir/test/mlir-tblgen/op-format-spec.td b/mlir/test/mlir-tblgen/op-format-spec.td --- a/mlir/test/mlir-tblgen/op-format-spec.td +++ b/mlir/test/mlir-tblgen/op-format-spec.td @@ -230,8 +230,66 @@ def DirectiveTypeZOperandInvalidI : TestFormat_Op<"type_operand_invalid_i", [{ type($result) type($result) }]>, Results<(outs I64:$result)>; + +//===----------------------------------------------------------------------===// +// type_ref + +// CHECK: error: 'type_ref' of 'operand' is not bound by a prior 'type' directive +def DirectiveTypeZZTypeRefOperandInvalidC : TestFormat_Op<"type_ref_operand_invalid_c", [{ + type_ref($operand) type(operands) +}]>, Arguments<(ins I64:$operand)>; +// CHECK: error: 'operands' 'type_ref' is not bound by a prior 'type' directive +def DirectiveTypeZZTypeRefOperandInvalidD : TestFormat_Op<"type_ref_operand_invalid_d", [{ + type_ref(operands) type($operand) +}]>, Arguments<(ins I64:$operand)>; +// CHECK: error: 'type_ref' of 'operand' is not bound by a prior 'type' directive +def DirectiveTypeZZTypeRefOperandInvalidE : TestFormat_Op<"type_ref_operand_invalid_e", [{ + type_ref($operand) type($operand) +}]>, Arguments<(ins I64:$operand)>; +// CHECK: error: 'type_ref' of 'result' is not bound by a prior 'type' directive +def DirectiveTypeZZTypeRefOperandInvalidG : TestFormat_Op<"type_ref_operand_invalid_g", [{ + type_ref($result) type(results) +}]>, Results<(outs I64:$result)>; +// CHECK: error: 'results' 'type_ref' is not bound by a prior 'type' directive +def DirectiveTypeZZTypeRefOperandInvalidH : TestFormat_Op<"type_ref_operand_invalid_h", [{ + type_ref(results) type($result) +}]>, Results<(outs I64:$result)>; +// CHECK: error: 'type_ref' of 'result' is not bound by a prior 'type' 
directive +def DirectiveTypeZZTypeRefOperandInvalidI : TestFormat_Op<"type_ref_operand_invalid_i", [{ + type_ref($result) type($result) +}]>, Results<(outs I64:$result)>; + +// CHECK-NOT: error +def DirectiveTypeZZTypeRefOperandB : TestFormat_Op<"type_ref_operand_valid_b", [{ + type_ref(operands) attr-dict +}]>; +// CHECK-NOT: error +def DirectiveTypeZZTypeRefOperandD : TestFormat_Op<"type_ref_operand_valid_d", [{ + type(operands) type_ref($operand) attr-dict +}]>, Arguments<(ins I64:$operand)>; +// CHECK-NOT: error +def DirectiveTypeZZTypeRefOperandE : TestFormat_Op<"type_ref_operand_valid_e", [{ + type($operand) type_ref($operand) attr-dict +}]>, Arguments<(ins I64:$operand)>; +// CHECK-NOT: error +def DirectiveTypeZZTypeRefOperandF : TestFormat_Op<"type_ref_operand_valid_f", [{ + type(results) type_ref(results) attr-dict +}]>; +// CHECK-NOT: error +def DirectiveTypeZZTypeRefOperandG : TestFormat_Op<"type_ref_operand_valid_g", [{ + type($result) type_ref(results) attr-dict +}]>, Results<(outs I64:$result)>; +// CHECK-NOT: error +def DirectiveTypeZZTypeRefOperandH : TestFormat_Op<"type_ref_operand_valid_h", [{ + type(results) type_ref($result) attr-dict +}]>, Results<(outs I64:$result)>; +// CHECK-NOT: error +def DirectiveTypeZZTypeRefOperandI : TestFormat_Op<"type_ref_operand_valid_i", [{ + type($result) type_ref($result) attr-dict +}]>, Results<(outs I64:$result)>; + // CHECK-NOT: error: -def DirectiveTypeZOperandValid : TestFormat_Op<"type_operand_valid", [{ +def DirectiveTypeZZZOperandValid : TestFormat_Op<"type_operand_valid", [{ type(operands) type(results) attr-dict }]>; diff --git a/mlir/test/mlir-tblgen/op-format.mlir b/mlir/test/mlir-tblgen/op-format.mlir --- a/mlir/test/mlir-tblgen/op-format.mlir +++ b/mlir/test/mlir-tblgen/op-format.mlir @@ -237,6 +237,12 @@ // CHECK: test.format_custom_directive_results : i64 -> (i64) test.format_custom_directive_results : i64 -> (i64) +// CHECK: test.format_custom_directive_results_with_type_refs : i64, i64 -> (i64) 
type_refs_capture : i64, i64 -> (i64) +test.format_custom_directive_results_with_type_refs : i64, i64 -> (i64) type_refs_capture : i64, i64 -> (i64) + +// CHECK: test.format_custom_directive_results_with_type_refs : i64 -> (i64) type_refs_capture : i64 -> (i64) +test.format_custom_directive_results_with_type_refs : i64 -> (i64) type_refs_capture : i64 -> (i64) + func @foo() { // CHECK: test.format_custom_directive_successors ^bb1, ^bb2 test.format_custom_directive_successors ^bb1, ^bb2 diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp --- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp @@ -53,6 +53,7 @@ ResultsDirective, SuccessorsDirective, TypeDirective, + TypeRefDirective, /// This element is a literal. Literal, @@ -230,7 +231,19 @@ /// The operand that is used to format the directive. std::unique_ptr operand; }; -} // end anonymous namespace + +/// This class represents the `type_ref` directive. +class TypeRefDirective + : public DirectiveElement { +public: + TypeRefDirective(std::unique_ptr arg) : operand(std::move(arg)) {} + Element *getOperand() const { return operand.get(); } + +private: + /// The operand that is used to format the directive. + std::unique_ptr operand; +}; +} // namespace //===----------------------------------------------------------------------===// // LiteralElement @@ -805,6 +818,19 @@ << llvm::formatv( " ::llvm::ArrayRef<::mlir::Type> {0}Types({0}RawTypes);\n", name); + } else if (auto *dir = dyn_cast(element)) { + ArgumentLengthKind lengthKind; + StringRef name = getTypeListName(dir->getOperand(), lengthKind); + // Refer to the previously encountered TypeDirective for name. + // Take a `const ::mlir::SmallVector<::mlir::Type, 1> &` in the declaration + // to properly track the types that will be parsed and pushed later on. 
+ if (lengthKind != ArgumentLengthKind::Single) + body << " const ::mlir::SmallVector<::mlir::Type, 1> &" << name + << "TypesRef(" << name << "Types);\n"; + else + body << llvm::formatv( + " ::llvm::ArrayRef<::mlir::Type> {0}RawTypesRef({0}RawTypes);\n", + name); } else if (auto *dir = dyn_cast(element)) { ArgumentLengthKind ignored; body << " ::llvm::ArrayRef<::mlir::Type> " @@ -844,6 +870,15 @@ else body << llvm::formatv("{0}Successor", name); + } else if (auto *dir = dyn_cast(¶m)) { + ArgumentLengthKind lengthKind; + StringRef listName = getTypeListName(dir->getOperand(), lengthKind); + if (lengthKind == ArgumentLengthKind::Variadic) + body << llvm::formatv("{0}TypesRef", listName); + else if (lengthKind == ArgumentLengthKind::Optional) + body << llvm::formatv("{0}TypeRef", listName); + else + body << llvm::formatv("{0}RawTypesRef[0]", listName); } else if (auto *dir = dyn_cast(¶m)) { ArgumentLengthKind lengthKind; StringRef listName = getTypeListName(dir->getOperand(), lengthKind); @@ -876,6 +911,16 @@ "{0}Operand;\n", operand->getVar()->name); } + } else if (auto *dir = dyn_cast(¶m)) { + // Reference to an optional type which may or may not have been parsed: + // use the stored type if present, otherwise a null ::mlir::Type. + ArgumentLengthKind lengthKind; + StringRef listName = getTypeListName(dir->getOperand(), lengthKind); + if (lengthKind == ArgumentLengthKind::Optional) + body << llvm::formatv( + " ::mlir::Type {0}TypeRef = {0}TypesRef.empty() " + "? ::mlir::Type() : {0}TypesRef[0];\n", + listName); } else if (auto *dir = dyn_cast(¶m)) { ArgumentLengthKind lengthKind; StringRef listName = getTypeListName(dir->getOperand(), lengthKind); @@ -907,6 +952,9 @@ body << llvm::formatv(" if ({0}Operand.hasValue())\n" " {0}Operands.push_back(*{0}Operand);\n", var->name); + } else if (auto *dir = dyn_cast(¶m)) { + // `type_ref` only refers to already-parsed types, so no new type needs + // to be added here.
} else if (auto *dir = dyn_cast(¶m)) { ArgumentLengthKind lengthKind; StringRef listName = getTypeListName(dir->getOperand(), lengthKind); @@ -1098,6 +1146,15 @@ } else if (isa(element)) { body << llvm::formatv(successorListParserCode, "full"); + } else if (auto *dir = dyn_cast(element)) { + ArgumentLengthKind lengthKind; + StringRef listName = getTypeListName(dir->getOperand(), lengthKind); + if (lengthKind == ArgumentLengthKind::Variadic) + body << llvm::formatv(variadicTypeParserCode, listName); + else if (lengthKind == ArgumentLengthKind::Optional) + body << llvm::formatv(optionalTypeParserCode, listName); + else + body << formatv(typeParserCode, listName); } else if (auto *dir = dyn_cast(element)) { ArgumentLengthKind lengthKind; StringRef listName = getTypeListName(dir->getOperand(), lengthKind); @@ -1428,6 +1485,17 @@ } else if (auto *successor = dyn_cast(¶m)) { body << successor->getVar()->name << "()"; + } else if (auto *dir = dyn_cast(¶m)) { + auto *typeOperand = dir->getOperand(); + auto *operand = dyn_cast(typeOperand); + auto *var = operand ? operand->getVar() + : cast(typeOperand)->getVar(); + if (var->isVariadic()) + body << var->name << "().getTypes()"; + else if (var->isOptional()) + body << llvm::formatv("({0}() ? {0}().getType() : Type())", var->name); + else + body << var->name << "().getType()"; } else if (auto *dir = dyn_cast(¶m)) { auto *typeOperand = dir->getOperand(); auto *operand = dyn_cast(typeOperand); @@ -1601,6 +1669,9 @@ } else if (auto *dir = dyn_cast(element)) { body << " p << "; genTypeOperandPrinter(dir->getOperand(), body) << ";\n"; + } else if (auto *dir = dyn_cast(element)) { + body << " p << "; + genTypeOperandPrinter(dir->getOperand(), body) << ";\n"; } else if (auto *dir = dyn_cast(element)) { body << " p.printFunctionalType("; genTypeOperandPrinter(dir->getInputs(), body) << ", "; @@ -1666,6 +1737,7 @@ kw_results, kw_successors, kw_type, + kw_type_ref, keyword_end, // String valued tokens. 
@@ -1870,6 +1942,7 @@ .Case("results", Token::kw_results) .Case("successors", Token::kw_successors) .Case("type", Token::kw_type) + .Case("type_ref", Token::kw_type_ref) .Default(Token::identifier); return Token(kind, str); } @@ -1990,8 +2063,9 @@ LogicalResult parseSuccessorsDirective(std::unique_ptr &element, llvm::SMLoc loc, bool isTopLevel); LogicalResult parseTypeDirective(std::unique_ptr &element, Token tok, - bool isTopLevel); - LogicalResult parseTypeDirectiveOperand(std::unique_ptr &element); + bool isTopLevel, bool isTypeRef = false); + LogicalResult parseTypeDirectiveOperand(std::unique_ptr &element, + bool isTypeRef = false); //===--------------------------------------------------------------------===// // Lexer Utilities @@ -2436,6 +2510,8 @@ return parseResultsDirective(element, dirTok.getLoc(), isTopLevel); case Token::kw_successors: return parseSuccessorsDirective(element, dirTok.getLoc(), isTopLevel); + case Token::kw_type_ref: + return parseTypeDirective(element, dirTok, isTopLevel, /*isTypeRef=*/true); case Token::kw_type: return parseTypeDirective(element, dirTok, isTopLevel); @@ -2501,7 +2577,10 @@ return ::mlir::success(); }; for (auto &ele : elements) { - if (auto *typeEle = dyn_cast(ele.get())) { + if (auto *typeEle = dyn_cast(ele.get())) { + if (failed(checkTypeOperand(typeEle->getOperand()))) + return ::mlir::failure(); + } else if (auto *typeEle = dyn_cast(ele.get())) { if (failed(checkTypeOperand(typeEle->getOperand()))) return ::mlir::failure(); } else if (auto *typeEle = dyn_cast(ele.get())) { @@ -2561,7 +2640,7 @@ // Literals, custom directives, and type directives may be used, // but they can't anchor the group. .Case([&](Element *) { + OptionalElement, TypeRefDirective, TypeDirective>([&](Element *) { if (isAnchor) return emitError(childLoc, "only variables can be used to anchor " "an optional group"); @@ -2624,6 +2703,13 @@ // After parsing all of the elements, ensure that all type directives refer // only to variables.
for (auto &ele : elements) { + if (auto *typeEle = dyn_cast(ele.get())) { + if (!isa(typeEle->getOperand())) { + return emitError(curLoc, + "type_ref directives within a custom directive " + "may only refer to variables"); + } + } if (auto *typeEle = dyn_cast(ele.get())) { if (!isa(typeEle->getOperand())) { return emitError(curLoc, "type directives within a custom directive " @@ -2645,8 +2731,8 @@ return ::mlir::failure(); // Verify that the element can be placed within a custom directive. - if (!isa(parameters.back().get())) { + if (!isa(parameters.back().get())) { return emitError(childLoc, "only variables and types may be used as " "parameters to a custom directive"); } @@ -2723,22 +2809,26 @@ LogicalResult FormatParser::parseTypeDirective(std::unique_ptr &element, Token tok, - bool isTopLevel) { + bool isTopLevel, bool isTypeRef) { llvm::SMLoc loc = tok.getLoc(); if (!isTopLevel) return emitError(loc, "'type' is only valid as a top-level directive"); std::unique_ptr operand; if (failed(parseToken(Token::l_paren, "expected '(' before argument list")) || - failed(parseTypeDirectiveOperand(operand)) || + failed(parseTypeDirectiveOperand(operand, isTypeRef)) || failed(parseToken(Token::r_paren, "expected ')' after argument list"))) return ::mlir::failure(); - element = std::make_unique(std::move(operand)); + if (isTypeRef) + element = std::make_unique(std::move(operand)); + else + element = std::make_unique(std::move(operand)); return ::mlir::success(); } LogicalResult -FormatParser::parseTypeDirectiveOperand(std::unique_ptr &element) { +FormatParser::parseTypeDirectiveOperand(std::unique_ptr &element, + bool isTypeRef) { llvm::SMLoc loc = curToken.getLoc(); if (failed(parseElement(element, /*isTopLevel=*/false))) return ::mlir::failure(); @@ -2748,23 +2838,36 @@ if (auto *var = dyn_cast(element.get())) { unsigned opIdx = var->getVar() - op.operand_begin(); - if (fmt.allOperandTypes || seenOperandTypes.test(opIdx)) + if (!isTypeRef && (fmt.allOperandTypes || 
seenOperandTypes.test(opIdx))) return emitError(loc, "'type' of '" + var->getVar()->name + "' is already bound"); + if (isTypeRef && !(fmt.allOperandTypes || seenOperandTypes.test(opIdx))) + return emitError(loc, "'type_ref' of '" + var->getVar()->name + + "' is not bound by a prior 'type' directive"); seenOperandTypes.set(opIdx); } else if (auto *var = dyn_cast(element.get())) { unsigned resIdx = var->getVar() - op.result_begin(); - if (fmt.allResultTypes || seenResultTypes.test(resIdx)) + if (!isTypeRef && (fmt.allResultTypes || seenResultTypes.test(resIdx))) return emitError(loc, "'type' of '" + var->getVar()->name + "' is already bound"); + if (isTypeRef && !(fmt.allResultTypes || seenResultTypes.test(resIdx))) + return emitError(loc, "'type_ref' of '" + var->getVar()->name + + "' is not bound by a prior 'type' directive"); seenResultTypes.set(resIdx); } else if (isa(&*element)) { - if (fmt.allOperandTypes || seenOperandTypes.any()) + if (!isTypeRef && (fmt.allOperandTypes || seenOperandTypes.any())) return emitError(loc, "'operands' 'type' is already bound"); + if (isTypeRef && !(fmt.allOperandTypes || seenOperandTypes.all())) + return emitError( + loc, + "'operands' 'type_ref' is not bound by a prior 'type' directive"); fmt.allOperandTypes = true; } else if (isa(&*element)) { - if (fmt.allResultTypes || seenResultTypes.any()) + if (!isTypeRef && (fmt.allResultTypes || seenResultTypes.any())) return emitError(loc, "'results' 'type' is already bound"); + if (isTypeRef && !(fmt.allResultTypes || seenResultTypes.all())) + return emitError( + loc, "'results' 'type_ref' is not bound by a prior 'type' directive"); fmt.allResultTypes = true; } else { return emitError(loc, "invalid argument to 'type' directive");