diff --git a/mlir/docs/PDLL.md b/mlir/docs/PDLL.md --- a/mlir/docs/PDLL.md +++ b/mlir/docs/PDLL.md @@ -577,9 +577,10 @@ #### Operands The operands section corresponds to the operands of the operation. This section -of an operation expression may be elided, in which case the operands are not -constrained in any way. When present, the operands of an operation expression -are interpreted in the following ways: +of an operation expression may be elided, which within a `match` section means +that the operands are not constrained in any way. If elided within a `rewrite` +section, the operation is treated as having no operands. When present, the +operands of an operation expression are interpreted in the following ways: 1) A single instance of type `ValueRange`: @@ -612,10 +613,11 @@ #### Results -The results section corresponds to the result types of the operation. This -section of an operation expression may be elided, in which case the result types -are not constrained in any way. When present, the result types of an operation -expression are interpreted in the following ways: +The results section corresponds to the result types of the operation. This section +of an operation expression may be elided, which within a `match` section means +that the result types are not constrained in any way. If elided within a `rewrite` +section, the results of the operation are [inferred](#inferred-results). When present, +the result types of an operation expression are interpreted in the following ways: 1) A single instance of type `TypeRange`: @@ -646,6 +648,87 @@ let root = op -> (result: Type, otherResults: TypeRange); ``` +#### Inferred Results + +Within the `rewrite` section of a pattern, the result types of an +operation are inferred if they are elided or otherwise not +previously bound. The ["variable binding"](#variable-binding) section above +discusses the concept of "binding" in more detail. 
Below are various examples +that build upon this to help showcase how a result type may be "bound": + +* Binding to a [constant](#type-expression): + +```pdll +op -> (type<"i32">); +``` + +* Binding to types within the `match` section: + +```pdll +Pattern { + replace op -> (resultTypes: TypeRange) + with op -> (resultTypes); +} +``` + +* Binding to previously inferred types: + +```pdll +Pattern { + rewrite root: Op with { + // `resultTypes` here is *not* yet bound, and will be inferred when + // creating `dialect.op`. Any uses of `resultTypes` after this expression + // will use the types inferred when creating this operation. + op -> (resultTypes: TypeRange); + + // `resultTypes` here is bound to the types inferred when creating `dialect.op`. + op -> (resultTypes); + }; +} +``` + +* Binding to a [`Native Rewrite`](#native-rewriters) method result: + +```pdll +Rewrite BuildTypes() -> TypeRange; + +Pattern { + rewrite root: Op with { + op -> (BuildTypes()); + }; +} +``` + +Below is the set of contexts in which result type inference is supported: + +##### Inferred Results of Replacement Operation + +Replacements have the invariant that the types of the replacement values must +match the result types of the input operation. This means that when replacing +one operation with another, the result types of the replacement operation may +be inferred from the result types of the operation being replaced. For example, +consider the following pattern: + +```pdll +Pattern => replace op with op; +``` + +This pattern could be written in a more explicit way as: + +```pdll +Pattern { + replace op -> (resultTypes: TypeRange) + with op -> (resultTypes); +} +``` + +##### Inferred Results with InferTypeOpInterface + +`InferTypeOpInterface` is an interface that enables an operation to infer its result +types from its input attributes, operands, regions, etc. 
When the result types of +an operation cannot be inferred from any other context, this interface is invoked +to infer the result types of the operation. + #### Attributes The attributes section of the operation expression corresponds to the attribute diff --git a/mlir/include/mlir/Tools/PDLL/ODS/Context.h b/mlir/include/mlir/Tools/PDLL/ODS/Context.h --- a/mlir/include/mlir/Tools/PDLL/ODS/Context.h +++ b/mlir/include/mlir/Tools/PDLL/ODS/Context.h @@ -62,7 +62,8 @@ /// and a boolean indicating if the operation newly inserted (false if the /// operation already existed). std::pair - insertOperation(StringRef name, StringRef summary, StringRef desc, SMLoc loc); + insertOperation(StringRef name, StringRef summary, StringRef desc, + bool supportsResultTypeInferrence, SMLoc loc); /// Lookup an operation registered with the given name, or null if no /// operation with that name is registered. diff --git a/mlir/include/mlir/Tools/PDLL/ODS/Dialect.h b/mlir/include/mlir/Tools/PDLL/ODS/Dialect.h --- a/mlir/include/mlir/Tools/PDLL/ODS/Dialect.h +++ b/mlir/include/mlir/Tools/PDLL/ODS/Dialect.h @@ -34,7 +34,8 @@ /// and a boolean indicating if the operation newly inserted (false if the /// operation already existed). std::pair - insertOperation(StringRef name, StringRef summary, StringRef desc, SMLoc loc); + insertOperation(StringRef name, StringRef summary, StringRef desc, + bool supportsResultTypeInferrence, SMLoc loc); /// Lookup an operation registered with the given name, or null if no /// operation with that name is registered. diff --git a/mlir/include/mlir/Tools/PDLL/ODS/Operation.h b/mlir/include/mlir/Tools/PDLL/ODS/Operation.h --- a/mlir/include/mlir/Tools/PDLL/ODS/Operation.h +++ b/mlir/include/mlir/Tools/PDLL/ODS/Operation.h @@ -75,6 +75,12 @@ /// Return the name of this value. StringRef getName() const { return name; } + /// Returns true if this value is variable length, i.e. if it is Variadic or + /// Optional. 
+ bool isVariableLength() const { + return variableLengthKind != VariableLengthKind::Single; + } + /// Returns true if this value is variadic (Note this is false if the value is /// Optional). bool isVariadic() const { @@ -157,8 +163,12 @@ /// Returns the results of this operation. ArrayRef getResults() const { return results; } + /// Returns true if the operation is known to support result type inference. + bool hasResultTypeInferrence() const { return supportsTypeInferrence; } + private: - Operation(StringRef name, StringRef summary, StringRef desc, SMLoc loc); + Operation(StringRef name, StringRef summary, StringRef desc, + bool supportsTypeInferrence, SMLoc loc); /// The name of the operation. std::string name; @@ -167,6 +177,9 @@ std::string summary; std::string description; + /// Flag indicating if the operation is known to support type inference. + bool supportsTypeInferrence; + /// The source location of this operation. SMRange location; diff --git a/mlir/lib/Tools/PDLL/ODS/Context.cpp b/mlir/lib/Tools/PDLL/ODS/Context.cpp --- a/mlir/lib/Tools/PDLL/ODS/Context.cpp +++ b/mlir/lib/Tools/PDLL/ODS/Context.cpp @@ -59,13 +59,12 @@ return it == dialects.end() ? 
nullptr : &*it->second; } -std::pair Context::insertOperation(StringRef name, - StringRef summary, - StringRef desc, - SMLoc loc) { +std::pair +Context::insertOperation(StringRef name, StringRef summary, StringRef desc, + bool supportsResultTypeInferrence, SMLoc loc) { std::pair dialectAndName = name.split('.'); return insertDialect(dialectAndName.first) - .insertOperation(name, summary, desc, loc); + .insertOperation(name, summary, desc, supportsResultTypeInferrence, loc); } const Operation *Context::lookupOperation(StringRef name) const { diff --git a/mlir/lib/Tools/PDLL/ODS/Dialect.cpp b/mlir/lib/Tools/PDLL/ODS/Dialect.cpp --- a/mlir/lib/Tools/PDLL/ODS/Dialect.cpp +++ b/mlir/lib/Tools/PDLL/ODS/Dialect.cpp @@ -21,15 +21,15 @@ Dialect::Dialect(StringRef name) : name(name.str()) {} Dialect::~Dialect() = default; -std::pair Dialect::insertOperation(StringRef name, - StringRef summary, - StringRef desc, - llvm::SMLoc loc) { +std::pair +Dialect::insertOperation(StringRef name, StringRef summary, StringRef desc, + bool supportsResultTypeInferrence, llvm::SMLoc loc) { std::unique_ptr &operation = operations[name]; if (operation) return std::make_pair(&*operation, /*wasInserted*/ false); - operation.reset(new Operation(name, summary, desc, loc)); + operation.reset( + new Operation(name, summary, desc, supportsResultTypeInferrence, loc)); return std::make_pair(&*operation, /*wasInserted*/ true); } diff --git a/mlir/lib/Tools/PDLL/ODS/Operation.cpp b/mlir/lib/Tools/PDLL/ODS/Operation.cpp --- a/mlir/lib/Tools/PDLL/ODS/Operation.cpp +++ b/mlir/lib/Tools/PDLL/ODS/Operation.cpp @@ -18,8 +18,9 @@ //===----------------------------------------------------------------------===// Operation::Operation(StringRef name, StringRef summary, StringRef desc, - llvm::SMLoc loc) + bool supportsTypeInferrence, llvm::SMLoc loc) : name(name.str()), summary(summary.str()), + supportsTypeInferrence(supportsTypeInferrence), location(loc, llvm::SMLoc::getFromPointer(loc.getPointer() + 1)) { 
llvm::raw_string_ostream descOS(description); raw_indented_ostream(descOS).printReindented(desc.rtrim(" \t")); diff --git a/mlir/lib/Tools/PDLL/Parser/Parser.cpp b/mlir/lib/Tools/PDLL/Parser/Parser.cpp --- a/mlir/lib/Tools/PDLL/Parser/Parser.cpp +++ b/mlir/lib/Tools/PDLL/Parser/Parser.cpp @@ -76,6 +76,19 @@ Rewrite, }; + /// The current specification context of an operations result type. This + /// indicates how the result types of an operation may be inferred. + enum class OpResultTypeContext { + /// The result types of the operation are not known to be inferred. + Explicit, + /// The result types of the operation are inferred from the root input of a + /// `replace` statement. + Replacement, + /// The result types of the operation are inferred by using the + /// `InferTypeOpInterface` interface provided by the operation. + Interface, + }; + //===--------------------------------------------------------------------===// // Parsing //===--------------------------------------------------------------------===// @@ -280,7 +293,9 @@ FailureOr parseMemberAccessExpr(ast::Expr *parentExpr); FailureOr parseOperationName(bool allowEmptyName = false); FailureOr parseWrappedOperationName(bool allowEmptyName); - FailureOr parseOperationExpr(); + FailureOr + parseOperationExpr(OpResultTypeContext inputResultTypeContext = + OpResultTypeContext::Explicit); FailureOr parseTupleExpr(); FailureOr parseTypeExpr(); FailureOr parseUnderscoreExpr(); @@ -378,6 +393,7 @@ StringRef name, SMRange loc); FailureOr createOperationExpr(SMRange loc, const ast::OpNameDecl *name, + OpResultTypeContext resultTypeContext, MutableArrayRef operands, MutableArrayRef attributes, MutableArrayRef results); @@ -388,6 +404,8 @@ LogicalResult validateOperationResults(SMRange loc, Optional name, const ods::Operation *odsOp, MutableArrayRef results); + void checkOperationResultTypeInferrence(SMRange loc, StringRef name, + const ods::Operation *odsOp); LogicalResult validateOperationOperandsOrResults( StringRef 
groupName, SMRange loc, Optional odsOpLoc, Optional name, MutableArrayRef values, @@ -795,11 +813,15 @@ for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Op")) { tblgen::Operator op(def); + // Check to see if this operation is known to support type inferrence. + bool supportsResultTypeInferrence = + op.getTrait("::mlir::InferTypeOpInterface::Trait"); + bool inserted = false; ods::Operation *odsOp = nullptr; - std::tie(odsOp, inserted) = - odsContext.insertOperation(op.getOperationName(), op.getSummary(), - op.getDescription(), op.getLoc().front()); + std::tie(odsOp, inserted) = odsContext.insertOperation( + op.getOperationName(), op.getSummary(), op.getDescription(), + supportsResultTypeInferrence, op.getLoc().front()); // Ignore operations that have already been added. if (!inserted) @@ -1917,7 +1939,8 @@ return opNameDecl; } -FailureOr Parser::parseOperationExpr() { +FailureOr +Parser::parseOperationExpr(OpResultTypeContext inputResultTypeContext) { SMRange loc = curToken.getLoc(); consumeToken(Token::kw_op); @@ -1994,13 +2017,23 @@ return failure(); } - // Check for the optional list of result types. + // Handle the result types of the operation. SmallVector resultTypes; + OpResultTypeContext resultTypeContext = inputResultTypeContext; + + // Check for an explicit list of result types. if (consumeIf(Token::arrow)) { if (failed(parseToken(Token::l_paren, "expected `(` before operation result type list"))) return failure(); + // If result types are provided, initially assume that the operation does + // not rely on type inferrence. We don't assert that it isn't, because we + // may be inferring the value of some type/type range variables, but given + // that these variables may be defined in calls we can't always discern when + // this is the case. + resultTypeContext = OpResultTypeContext::Explicit; + // Handle the case of an empty result list. if (!consumeIf(Token::r_paren)) { do { @@ -2027,10 +2060,14 @@ // "unconstrained results". 
resultTypes.push_back(createImplicitRangeVar( ast::TypeRangeConstraintDecl::create(ctx, loc), typeRangeTy)); + } else if (resultTypeContext == OpResultTypeContext::Explicit) { + // If the result list isn't specified and we are in a rewrite, try to infer + // them at runtime instead. + resultTypeContext = OpResultTypeContext::Interface; } - return createOperationExpr(loc, *opNameDecl, operands, attributes, - resultTypes); + return createOperationExpr(loc, *opNameDecl, resultTypeContext, operands, + attributes, resultTypes); } FailureOr Parser::parseTupleExpr() { @@ -2294,7 +2331,13 @@ "expected `)` after replacement values"))) return failure(); } else { - FailureOr replExpr = parseExpr(); + // Handle replacement with an operation uniquely, as the replacement + // operation supports type inferrence from the root operation. + FailureOr replExpr; + if (curToken.is(Token::kw_op)) + replExpr = parseOperationExpr(OpResultTypeContext::Replacement); + else + replExpr = parseExpr(); if (failed(replExpr)) return failure(); replValues.emplace_back(*replExpr); @@ -2707,6 +2750,7 @@ FailureOr Parser::createOperationExpr( SMRange loc, const ast::OpNameDecl *name, + OpResultTypeContext resultTypeContext, MutableArrayRef operands, MutableArrayRef attributes, MutableArrayRef results) { @@ -2728,9 +2772,22 @@ } } - // Verify the result types. - if (failed(validateOperationResults(loc, opNameRef, odsOp, results))) - return failure(); + assert( + (resultTypeContext == OpResultTypeContext::Explicit || results.empty()) && + "unexpected inferrence when results were explicitly specified"); + + // If we aren't relying on type inferrence, or explicit results were provided, + // validate them. + if (resultTypeContext == OpResultTypeContext::Explicit) { + if (failed(validateOperationResults(loc, opNameRef, odsOp, results))) + return failure(); + + // Validate the use of interface based type inferrence for this operation. 
+ } else if (resultTypeContext == OpResultTypeContext::Interface) { + assert(opNameRef && + "expected valid operation name when inferring operation results"); + checkOperationResultTypeInferrence(loc, *opNameRef, odsOp); + } return ast::OperationExpr::create(ctx, loc, name, operands, results, attributes); @@ -2755,6 +2812,48 @@ results, odsOp ? odsOp->getResults() : llvm::None, typeTy, typeRangeTy); } +void Parser::checkOperationResultTypeInferrence(SMRange loc, StringRef opName, + const ods::Operation *odsOp) { + // If the operation might not have inferrence support, emit a warning to the + // user. We don't emit an error because the interface might be added to the + // operation at runtime. It's rare, but it could still happen. We emit a + // warning here instead. + + // Handle inferrence warnings for unknown operations. + if (!odsOp) { + ctx.getDiagEngine().emitWarning( + loc, llvm::formatv( + "operation result types are marked to be inferred, but " + "`{0}` is unknown. Ensure that `{0}` supports zero " + "results or implements `InferTypeOpInterface`. Include " + "the ODS definition of this operation to remove this warning.", + opName)); + return; + } + + // Handle inferrence warnings for known operations that expected at least one + // result, but don't have inference support. An elided results list can mean + // "zero-results", and we don't want to warn when that is the expected + // behavior. + bool requiresInferrence = + llvm::any_of(odsOp->getResults(), [](const ods::OperandOrResult &result) { + return !result.isVariableLength(); + }); + if (requiresInferrence && !odsOp->hasResultTypeInferrence()) { + ast::InFlightDiagnostic diag = ctx.getDiagEngine().emitWarning( + loc, + llvm::formatv("operation result types are marked to be inferred, but " + "`{0}` does not provide an implementation of " + "`InferTypeOpInterface`. 
Ensure that `{0}` attaches " + "`InferTypeOpInterface` at runtime, or add support to " + "the ODS definition to remove this warning.", + opName)); + diag->attachNote(llvm::formatv("see the definition of `{0}` here", opName), + odsOp->getLoc()); + return; + } +} + LogicalResult Parser::validateOperationOperandsOrResults( StringRef groupName, SMRange loc, Optional odsOpLoc, Optional name, MutableArrayRef values, diff --git a/mlir/test/mlir-pdll/Parser/expr-failure.pdll b/mlir/test/mlir-pdll/Parser/expr-failure.pdll --- a/mlir/test/mlir-pdll/Parser/expr-failure.pdll +++ b/mlir/test/mlir-pdll/Parser/expr-failure.pdll @@ -296,6 +296,36 @@ // ----- +Pattern { + // CHECK: warning: operation result types are marked to be inferred, but + // CHECK-SAME: `test.unknown_inferred_result_op` is unknown. + // CHECK-SAME: Ensure that `test.unknown_inferred_result_op` supports zero + // CHECK-SAME: results or implements `InferTypeOpInterface`. + // CHECK-SAME: Include the ODS definition of this operation to remove this + // CHECK-SAME: warning. + rewrite _: Op with { + op; + }; +} + +// ----- + +#include "include/ops.td" + +Pattern { + // CHECK: warning: operation result types are marked to be inferred, but + // CHECK-SAME: `test.multiple_single_result` does not provide an implementation + // CHECK-SAME: of `InferTypeOpInterface`. Ensure that `test.multiple_single_result` + // CHECK-SAME: attaches `InferTypeOpInterface` at runtime, or add support + // CHECK-SAME: to the ODS definition to remove this warning. + // CHECK: see the definition of `test.multiple_single_result` here + rewrite _: Op with { + op; + }; +} + +// ----- + //===----------------------------------------------------------------------===// // `type` Expr //===----------------------------------------------------------------------===//