diff --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td --- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td +++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td @@ -47,6 +47,58 @@ "::mlir::RegionRange":$regions, "::llvm::SmallVectorImpl<::mlir::Type>&":$inferredReturnTypes) >, + StaticInterfaceMethod< + /*desc=*/[{Refine the return types that an op would generate. + + This method computes the return types as `inferReturnTypes` does, but + additionally takes the existing result types as input. The existing types + can be checked to emit more op-specific error messages, and additional + information they carry (e.g., attributes) can be merged into the inferred + types. It is called during verification of ops that implement this + interface; the default implementation reports a mismatch between the + current and inferred types, printing both. + + The operands and attributes correspond to those with which an Operation + would be created (e.g., as used in Operation::create) and the regions of + the op. The method takes an optional location which, if set, will be used + to report errors on. + + The return types may be elided, or individual elements may be null, for + results that should just be returned but not verified. + + The default implementation calls `inferReturnTypes` and checks the + inferred types against `returnTypes` using `isCompatibleReturnTypes`; it + does not write the inferred types back into `returnTypes`. + + This method must be called with valid arguments (e.g., operands that have + already been verified); otherwise the behavior is undefined. 
+ }], + /*retTy=*/"::mlir::LogicalResult", + /*methodName=*/"refineReturnTypes", + /*args=*/(ins "::mlir::MLIRContext *":$context, + "::llvm::Optional<::mlir::Location>":$location, + "::mlir::ValueRange":$operands, + "::mlir::DictionaryAttr":$attributes, + "::mlir::RegionRange":$regions, + "::llvm::SmallVectorImpl<::mlir::Type>&":$returnTypes), + /*methodBody=*/[{}], + /*defaultImplementation=*/[{ + llvm::SmallVector<::mlir::Type, 4> inferredReturnTypes; + if (failed(ConcreteOp::inferReturnTypes(context, location, operands, + attributes, regions, + inferredReturnTypes))) + return failure(); + if (!ConcreteOp::isCompatibleReturnTypes(inferredReturnTypes, + returnTypes)) { + return emitOptionalError( + location, "'", ConcreteOp::getOperationName(), + "' op inferred type(s) ", inferredReturnTypes, + " are incompatible with return type(s) of operation ", + returnTypes); + } + return success(); + }] + >, StaticInterfaceMethod< /*desc=*/"Returns whether two array of types are compatible result types" " for an op.", diff --git a/mlir/lib/Interfaces/InferTypeOpInterface.cpp b/mlir/lib/Interfaces/InferTypeOpInterface.cpp --- a/mlir/lib/Interfaces/InferTypeOpInterface.cpp +++ b/mlir/lib/Interfaces/InferTypeOpInterface.cpp @@ -204,17 +204,10 @@ } LogicalResult mlir::detail::verifyInferredResultTypes(Operation *op) { - SmallVector<Type, 4> inferredReturnTypes; + // Pass the current result types in so that they get verified. + SmallVector<Type, 4> returnTypes(op->getResultTypes()); auto retTypeFn = cast<InferTypeOpInterface>(op); - if (failed(retTypeFn.inferReturnTypes( - op->getContext(), op->getLoc(), op->getOperands(), - op->getAttrDictionary(), op->getRegions(), inferredReturnTypes))) - return failure(); - if (!retTypeFn.isCompatibleReturnTypes(inferredReturnTypes, - op->getResultTypes())) - return op->emitOpError("inferred type(s) ") - << inferredReturnTypes - << " are incompatible with return type(s) of operation " - << op->getResultTypes(); - return success(); + return retTypeFn.refineReturnTypes(op->getContext(), op->getLoc(), + op->getOperands(), op->getAttrDictionary(), + 
+ op->getRegions(), returnTypes); } diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -1128,6 +1128,36 @@ return success(); } +// TODO: We should be able to define only one of inferReturnTypes or +// refineReturnTypes; currently only refineReturnTypes can be omitted. +LogicalResult OpWithRefineTypeInterfaceOp::inferReturnTypes( + MLIRContext *context, Optional<Location> location, ValueRange operands, + DictionaryAttr attributes, RegionRange regions, + SmallVectorImpl<Type> &returnTypes) { + returnTypes.clear(); + return OpWithRefineTypeInterfaceOp::refineReturnTypes( + context, location, operands, attributes, regions, returnTypes); +} + +LogicalResult OpWithRefineTypeInterfaceOp::refineReturnTypes( + MLIRContext *, Optional<Location> location, ValueRange operands, + DictionaryAttr attributes, RegionRange regions, + SmallVectorImpl<Type> &returnTypes) { + if (operands[0].getType() != operands[1].getType()) { + return emitOptionalError(location, "operand type mismatch ", + operands[0].getType(), " vs ", + operands[1].getType()); + } + // TODO: Add helper to make this more concise to write. 
+ if (returnTypes.empty()) + returnTypes.resize(1, nullptr); + if (returnTypes[0] && returnTypes[0] != operands[0].getType()) + return emitOptionalError(location, + "required first operand and result to match"); + returnTypes[0] = operands[0].getType(); + return success(); +} + LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents( MLIRContext *context, Optional<Location> location, ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -645,8 +645,14 @@ } def OpWithInferTypeInterfaceOp : TEST_Op<"op_with_infer_type_if", [ + DeclareOpInterfaceMethods]> { + let arguments = (ins AnyTensor, AnyTensor); + let results = (outs AnyTensor); +} + +def OpWithRefineTypeInterfaceOp : TEST_Op<"op_with_refine_type_if", [ DeclareOpInterfaceMethods]> { + ["refineReturnTypes"]>]> { let arguments = (ins AnyTensor, AnyTensor); let results = (outs AnyTensor); } diff --git a/mlir/test/mlir-tblgen/return-types.mlir b/mlir/test/mlir-tblgen/return-types.mlir --- a/mlir/test/mlir-tblgen/return-types.mlir +++ b/mlir/test/mlir-tblgen/return-types.mlir @@ -39,6 +39,13 @@ // ----- +func.func @testReturnTypeOpInterface(%arg0 : tensor<10xf32>) { + // expected-error@+1 {{required first operand and result to match}} + %bad = "test.op_with_refine_type_if"(%arg0, %arg0) : (tensor<10xf32>, tensor<10xf32>) -> tensor<*xf32> + return +} + +// ----- // CHECK-LABEL: testReifyFunctions func.func @testReifyFunctions(%arg0 : tensor<10xf32>, %arg1 : tensor<20xf32>) { // expected-remark@+1 {{arith.constant 10}}