diff --git a/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td b/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td --- a/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td +++ b/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td @@ -369,6 +369,12 @@ /// Returns true if the operation type referenced supports result type /// inference. bool hasTypeInference(); + + /// Returns true if the operation type referenced might support result type + /// inference, i.e. it implements InferTypeOpInterface or is currently not + /// registered in the context. Returns false if the root operation name + /// has not been set. + bool mightHaveTypeInference(); }]; let hasVerifier = 1; } diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h --- a/mlir/include/mlir/IR/OperationSupport.h +++ b/mlir/include/mlir/IR/OperationSupport.h @@ -167,6 +167,17 @@ return impl->interfaceMap.contains(interfaceID); } + /// Returns true if the operation *might* have the provided interface. This + /// means that either the operation is unregistered, or it was registered with + /// the provided interface. + template <typename T> + bool mightHaveInterface() const { + return mightHaveInterface(TypeID::get<T>()); + } + bool mightHaveInterface(TypeID interfaceID) const { + return !isRegistered() || hasInterface(interfaceID); + } + /// Return the dialect this operation is registered to if the dialect is /// loaded in the context, or nullptr if the dialect isn't loaded. Dialect *getDialect() const { diff --git a/mlir/lib/Dialect/PDL/IR/PDL.cpp b/mlir/lib/Dialect/PDL/IR/PDL.cpp --- a/mlir/lib/Dialect/PDL/IR/PDL.cpp +++ b/mlir/lib/Dialect/PDL/IR/PDL.cpp @@ -198,6 +198,36 @@ if (llvm::any_of(op.op().getUses(), canInferTypeFromUse)) return success(); + // Handle the case where the operation has no explicit result types. + if (resultTypes.empty()) { + // If we don't know the concrete operation, assume the user actually meant + // zero-results.
+ Optional<StringRef> rawOpName = op.name(); + if (!rawOpName) + return success(); + Optional<RegisteredOperationName> opName = + RegisteredOperationName::lookup(*rawOpName, op.getContext()); + if (!opName) + return success(); + + // If no explicit result types were provided, check to see if the operation + // expected at least one result. This doesn't cover all cases, but this + // should cover many cases in which the user intended to infer the results + // of an operation, but it isn't actually possible. + bool expectedAtLeastOneResult = + !opName->hasTrait<OpTrait::ZeroResults>() && + !opName->hasTrait<OpTrait::VariadicResults>(); + if (expectedAtLeastOneResult) { + return op + .emitOpError("must have inferable or constrained result types when " + "nested within `pdl.rewrite`") + .attachNote() + .append("operation is created in a non-inferrable context, but '", + *opName, "' does not implement InferTypeOpInterface"); + } + return success(); + } + // Otherwise, make sure each of the types can be inferred. for (const auto &it : llvm::enumerate(resultTypes)) { Operation *resultTypeOp = it.value().getDefiningOp(); @@ -248,7 +278,7 @@ // If the operation is within a rewrite body and doesn't have type inference, // ensure that the result types can be resolved.
- if (isWithinRewrite && !hasTypeInference()) { + if (isWithinRewrite && !mightHaveTypeInference()) { if (failed(verifyResultTypesAreInferrable(*this, types()))) return failure(); } @@ -257,12 +287,18 @@ } bool OperationOp::hasTypeInference() { - Optional<StringRef> opName = name(); - if (!opName) - return false; + if (Optional<StringRef> rawOpName = name()) { + OperationName opName(*rawOpName, getContext()); + return opName.hasInterface<InferTypeOpInterface>(); + } + return false; +} - if (auto rInfo = RegisteredOperationName::lookup(*opName, getContext())) - return rInfo->hasInterface<InferTypeOpInterface>(); +bool OperationOp::mightHaveTypeInference() { + if (Optional<StringRef> rawOpName = name()) { + OperationName opName(*rawOpName, getContext()); + return opName.mightHaveInterface<InferTypeOpInterface>(); + } return false; } diff --git a/mlir/test/Dialect/PDL/invalid.mlir b/mlir/test/Dialect/PDL/invalid.mlir --- a/mlir/test/Dialect/PDL/invalid.mlir +++ b/mlir/test/Dialect/PDL/invalid.mlir @@ -136,7 +136,21 @@ // expected-error@below {{op must have inferable or constrained result types when nested within `pdl.rewrite`}} // expected-note@below {{result type #0 was not constrained}} - %newOp = operation "foo.op" -> (%type : !pdl.type) + %newOp = operation "builtin.unrealized_conversion_cast" -> (%type : !pdl.type) + } +} + +// ----- + +// Unused operation only necessary to ensure the func dialect is loaded. +func.func private @unusedOpToLoadFuncDialect() + +pdl.pattern : benefit(1) { + %op = operation "foo.op" + rewrite %op { + // expected-error@below {{op must have inferable or constrained result types when nested within `pdl.rewrite`}} + // expected-note@below {{operation is created in a non-inferrable context, but 'func.constant' does not implement InferTypeOpInterface}} + %newOp = operation "func.constant" } }