diff --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td --- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td +++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td @@ -36,7 +36,7 @@ which an Operation would be created (e.g., as used in Operation::create) and the regions of the op. Be aware that this method is supposed to be called with valid arguments, e.g., operands are verified, or it may result - in an undefined behavior. + in an undefined behaviour. }], /*retTy=*/"::mlir::LogicalResult", /*methodName=*/"inferReturnTypes", @@ -47,6 +47,46 @@ "::mlir::RegionRange":$regions, "::llvm::SmallVectorImpl<::mlir::Type>&":$inferredReturnTypes) >, + StaticInterfaceMethod< + /*desc=*/[{Refine the return types that an op would generate. + + The operands and attributes correspond to those with which an Operation + would be created (e.g., as used in Operation::create) and the regions of + the op. Except the method takes an optional location which, if set, will + be used to report errors on. + + The result types, if provided, are verified to be compatible with + existing return type. A null result type provided should be treated as + unconstrained. + + Be aware that this method is supposed to be called with valid arguments, + e.g., operands are verified, or it may result in an undefined behaviour. + }], + /*retTy=*/"::mlir::LogicalResult", + /*methodName=*/"refineReturnTypes", + /*args=*/(ins "::mlir::MLIRContext *":$context, + "::llvm::Optional<::mlir::Location>":$location, + "::mlir::ValueRange":$operands, + "::mlir::DictionaryAttr":$attributes, + "::mlir::RegionRange":$regions, + "::llvm::SmallVectorImpl<::mlir::Type>&":$returnTypes), + /*methodBody=*/[{}], + /*defaultImplementation=*/[{ + llvm::SmallVector inferredReturnTypes; + if (failed(ConcreteOp::inferReturnTypes(context, location, operands, + attributes, regions, + inferredReturnTypes))) + return failure(); + if (!ConcreteOp::isCompatibleReturnTypes(inferredReturnTypes, + returnTypes)) + return emitOptionalError( + location, "'", ConcreteOp::getOperationName(), + "' op inferred type(s) ", inferredReturnTypes, + " are incompatible with return type(s) of operation ", + returnTypes); + return success(); + }] + >, StaticInterfaceMethod< /*desc=*/"Returns whether two array of types are compatible result types" " for an op.", diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -1128,6 +1128,35 @@ return success(); } +// TODO: We should be able to have some way to only define one of these. +LogicalResult OpWithRefineTypeInterfaceOp::inferReturnTypes( + MLIRContext *context, Optional location, ValueRange operands, + DictionaryAttr attributes, RegionRange regions, + SmallVectorImpl &returnTypes) { + returnTypes.clear(); + return OpWithRefineTypeInterfaceOp::refineReturnTypes( + context, location, operands, attributes, regions, returnTypes); +} + +LogicalResult OpWithRefineTypeInterfaceOp::refineReturnTypes( + MLIRContext *, Optional location, ValueRange operands, + DictionaryAttr attributes, RegionRange regions, + SmallVectorImpl &returnTypes) { + if (operands[0].getType() != operands[1].getType()) { + return emitOptionalError(location, "operand type mismatch ", + operands[0].getType(), " vs ", + operands[1].getType()); + } + // TODO: Add helper to make this more concise to write. + if (returnTypes.empty()) + returnTypes.resize(1, nullptr); + if (returnTypes[0] && returnTypes[0] != operands[0].getType()) + return emitOptionalError(location, + "required first operand and result to match"); + returnTypes[0] = operands[0].getType(); + return success(); +} + LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents( MLIRContext *context, Optional location, ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -645,8 +645,14 @@ } def OpWithInferTypeInterfaceOp : TEST_Op<"op_with_infer_type_if", [ + DeclareOpInterfaceMethods]> { + let arguments = (ins AnyTensor, AnyTensor); + let results = (outs AnyTensor); +} + +def OpWithRefineTypeInterfaceOp : TEST_Op<"op_with_refine_type_if", [ DeclareOpInterfaceMethods]> { + ["refineReturnTypes"]>]> { let arguments = (ins AnyTensor, AnyTensor); let results = (outs AnyTensor); } diff --git a/mlir/test/mlir-tblgen/return-types.mlir b/mlir/test/mlir-tblgen/return-types.mlir --- a/mlir/test/mlir-tblgen/return-types.mlir +++ b/mlir/test/mlir-tblgen/return-types.mlir @@ -39,6 +39,13 @@ // ----- +func.func @testReturnTypeOpInterface(%arg0 : tensor<10xf32>) { + // expected-error@+1 {{required first operand and result to match}} + %bad = "test.op_with_refine_type_if"(%arg0, %arg0) : (tensor<10xf32>, tensor<10xf32>) -> tensor<*xf32> + return +} + +// ----- // CHECK-LABEL: testReifyFunctions func.func @testReifyFunctions(%arg0 : tensor<10xf32>, %arg1 : tensor<20xf32>) { // expected-remark@+1 {{arith.constant 10}}