diff --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
--- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
+++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
@@ -116,4 +116,24 @@
   ];
 }
 
+// Convenience class grouping together type and shaped type op interfaces for
+// ops that have tensor return types. `overridenMethods` lists the
+// InferShapedTypeOpInterface methods that the op implements itself instead of
+// using the interface's default implementation, so that their declarations
+// are generated. Ops reference the `traits` field as their trait list.
+class InferTensorType<list<string> overridenMethods = []> {
+  list<OpTrait> traits = [
+    // Op implements infer type op interface.
+    InferTypeOpInterface,
+    // The op will have methods implementing the ShapedType type inference
+    // interface.
+    DeclareOpInterfaceMethods<InferShapedTypeOpInterface, overridenMethods>,
+    // The op produces tensors and will use the ShapedType type infer interface
+    // along with knowledge that it is producing Tensors to infer shape.
+    NativeOpTrait<"InferTensorType">
+  ];
+}
+// Trait list for ops that also provide their own `reifyReturnTypeShapes`.
+defvar InferTensorTypeWithReify = InferTensorType<["reifyReturnTypeShapes"]>;
+
 #endif // MLIR_INFERTYPEOPINTERFACE
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -504,22 +504,8 @@
 }
 
-def InferTensorType : NativeOpTrait<"InferTensorType">;
 def OpWithShapedTypeInferTypeInterfaceOp : TEST_Op<"op_with_shaped_type_infer_type_if",
-      [
-      // Op implements infer type op interface.
-      InferTypeOpInterface,
-      // The op will have methods implementing the ShapedType type infer interface.
-      DeclareOpInterfaceMethods<InferShapedTypeOpInterface>,
-      // The op produces tensors and will use the ShapedType type infer interface
-      // along with knowledge that it is producing Tensors to infer shape.
-      InferTensorType
-      ]> {
+      InferTensorTypeWithReify.traits> {
   let arguments = (ins AnyTensor, AnyTensor);
   let results = (outs AnyTensor);
-
-  let extraClassDeclaration = [{
-    LogicalResult reifyReturnTypeShapes(OpBuilder &builder,
-        SmallVectorImpl<Value> &shapes);
-  }];
 }
 