diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -32,10 +32,7 @@
 //===----------------------------------------------------------------------===//
 // Operator: argmax
 //===----------------------------------------------------------------------===//
-def Tosa_ArgMaxOp : Tosa_Op<"argmax", [
-    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>,
-    Pure]> {
+def Tosa_ArgMaxOp : Tosa_Op<"argmax", [InferShapedTypeOpAdaptor, Pure]> {
   let summary = "Perform argmax on the input.";

   let description = [{
@@ -62,10 +59,7 @@
 //===----------------------------------------------------------------------===//
 // Operator: avg_pool2d
 //===----------------------------------------------------------------------===//
-def Tosa_AvgPool2dOp : Tosa_Op<"avg_pool2d", [
-    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>,
-    Pure]> {
+def Tosa_AvgPool2dOp : Tosa_Op<"avg_pool2d", [InferShapedTypeOpAdaptor, Pure]> {
   let summary = "Performs max pooling on the input.";

   let description = [{
@@ -95,10 +89,7 @@
 //===----------------------------------------------------------------------===//
 // Operator: conv2d
 //===----------------------------------------------------------------------===//
-def Tosa_Conv2DOp : Tosa_Op<"conv2d", [
-    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>,
-    Pure]> {
+def Tosa_Conv2DOp : Tosa_Op<"conv2d", [InferShapedTypeOpAdaptor, Pure]> {
   let summary = "2D Convolution Operator";

   let description = [{
@@ -128,10 +119,7 @@
 //===----------------------------------------------------------------------===//
 // Operator: conv3d
 //===----------------------------------------------------------------------===//
-def Tosa_Conv3DOp : Tosa_Op<"conv3d", [
-    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>,
-    Pure]> {
+def Tosa_Conv3DOp : Tosa_Op<"conv3d", [InferShapedTypeOpAdaptor, Pure]> {
   let summary = "3D Convolution operator";

   let description = [{
@@ -160,10 +148,8 @@
 //===----------------------------------------------------------------------===//
 // Operator: depthwise_conv2d
 //===----------------------------------------------------------------------===//
-def Tosa_DepthwiseConv2DOp : Tosa_Op<"depthwise_conv2d", [
-    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>,
-    Pure]> {
+def Tosa_DepthwiseConv2DOp : Tosa_Op<"depthwise_conv2d",
+    [InferShapedTypeOpAdaptor, Pure]> {
   let summary = "Depthwise 2D Convolution operator";

   let description = [{
@@ -193,10 +179,7 @@
 //===----------------------------------------------------------------------===//
 // Operator: fft2d
 //===----------------------------------------------------------------------===//
-def Tosa_FFT2dOp : Tosa_Op<"fft2d", [
-    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>,
-    Pure]> {
+def Tosa_FFT2dOp : Tosa_Op<"fft2d", [InferShapedTypeOpAdaptor, Pure]> {
   let summary = "Performs FFT2D operation on the input.";

   let description = [{
@@ -224,9 +207,7 @@
 // Operator: fully_connected
 //===----------------------------------------------------------------------===//
 def Tosa_FullyConnectedOp : Tosa_Op<"fully_connected", [
-    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>,
-    Pure]> {
+    InferShapedTypeOpAdaptor, Pure]> {
   let summary = "Fully Connected operator";

   let description = [{
@@ -251,10 +232,7 @@
 //===----------------------------------------------------------------------===//
 // Operator: matmul
 //===----------------------------------------------------------------------===//
-def Tosa_MatMulOp : Tosa_Op<"matmul", [
-    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>,
-    Pure]> {
+def Tosa_MatMulOp : Tosa_Op<"matmul", [InferShapedTypeOpAdaptor, Pure]> {
   let summary = "Matrix multiplication with bias";

   let description = [{
@@ -279,10 +257,7 @@
 //===----------------------------------------------------------------------===//
 // Operator: max_pool2d
 //===----------------------------------------------------------------------===//
-def Tosa_MaxPool2dOp : Tosa_Op<"max_pool2d", [
-    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>,
-    Pure]> {
+def Tosa_MaxPool2dOp : Tosa_Op<"max_pool2d", [InferShapedTypeOpAdaptor, Pure]> {
   let summary = "Performs max pooling on the input.";

   let description = [{
@@ -310,10 +285,7 @@
 //===----------------------------------------------------------------------===//
 // Operator: rfft2d
 //===----------------------------------------------------------------------===//
-def Tosa_RFFT2dOp : Tosa_Op<"rfft2d", [
-    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>,
-    Pure]> {
+def Tosa_RFFT2dOp : Tosa_Op<"rfft2d", [InferShapedTypeOpAdaptor, Pure]> {
   let summary = "Performs RFFT2D operation on the input.";

   let description = [{
@@ -338,10 +310,8 @@
 //===----------------------------------------------------------------------===//
 // Operator: transpose_conv2d
 //===----------------------------------------------------------------------===//
-def Tosa_TransposeConv2DOp : Tosa_Op<"transpose_conv2d", [
-    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>,
-    Pure]> {
+def Tosa_TransposeConv2DOp : Tosa_Op<"transpose_conv2d",
+    [InferShapedTypeOpAdaptor, Pure]> {
   let summary = "Transpose 2D Convolution operator.";

   let description = [{
@@ -828,10 +798,7 @@
 //===----------------------------------------------------------------------===//
 // Operator: table
 //===----------------------------------------------------------------------===//
-def Tosa_TableOp : Tosa_Op<"table", [
-    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>,
-    Pure]> {
+def Tosa_TableOp : Tosa_Op<"table", [InferShapedTypeOpAdaptor, Pure]> {
   let summary = "Table lookup op";

   let description = [{
@@ -1214,7 +1181,7 @@
 // Operator: reduce_all
 //===----------------------------------------------------------------------===//
 def Tosa_ReduceAllOp : Tosa_Op<"reduce_all", [
-    InferTensorType, Pure]> {
+    InferTensorTypeAdaptor, Pure]> {
   let summary = "Reduce All operator";

   let description = [{
@@ -1243,7 +1210,7 @@
 // Operator: reduce_any
 //===----------------------------------------------------------------------===//
 def Tosa_ReduceAnyOp : Tosa_Op<"reduce_any", [
-    InferTensorType, Pure]> {
+    InferTensorTypeAdaptor, Pure]> {
   let summary = "Reduce Any operator";

   let description = [{
@@ -1272,7 +1239,7 @@
 // Operator: reduce_max
 //===----------------------------------------------------------------------===//
 def Tosa_ReduceMaxOp : Tosa_Op<"reduce_max", [
-    InferTensorType, Pure]> {
+    InferTensorTypeAdaptor, Pure]> {
   let summary = "Reduce Max operator";

   let description = [{
@@ -1301,7 +1268,7 @@
 // Operator: reduce_min
 //===----------------------------------------------------------------------===//
 def Tosa_ReduceMinOp : Tosa_Op<"reduce_min", [
-    InferTensorType, Pure]> {
+    InferTensorTypeAdaptor, Pure]> {
   let summary = "Reduce Min operator";

   let description = [{
@@ -1330,7 +1297,7 @@
 // Operator: reduce_prod
 //===----------------------------------------------------------------------===//
 def Tosa_ReduceProdOp : Tosa_Op<"reduce_prod", [
-    InferTensorType, Pure]> {
+    InferTensorTypeAdaptor, Pure]> {
   let summary = "Reduce Prod operator";

   let description = [{
@@ -1359,7 +1326,7 @@
 // Operator: reduce_sum
 //===----------------------------------------------------------------------===//
 def Tosa_ReduceSumOp : Tosa_Op<"reduce_sum", [
-    InferTensorType, Pure]> {
+    InferTensorTypeAdaptor, Pure]> {
   let summary = "Reduce Sum operator";

   let description = [{
@@ -1393,7 +1360,7 @@
 // Operator: concat
 //===----------------------------------------------------------------------===//
 def Tosa_ConcatOp : Tosa_Op<"concat", [
-    InferTensorType, Pure]> {
+    InferTensorTypeAdaptor, Pure]> {
   let summary = "Concatenates tensors along one dimension.";

   let description = [{
@@ -1423,10 +1390,7 @@
 //===----------------------------------------------------------------------===//
 // Operator: pad
 //===----------------------------------------------------------------------===//
-def Tosa_PadOp : Tosa_Op<"pad", [
-    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>,
-    Pure]> {
+def Tosa_PadOp : Tosa_Op<"pad", [InferShapedTypeOpAdaptor, Pure]> {
   let summary = "Pads a tensor with value specified.";

   let description = [{
@@ -1471,7 +1435,7 @@
 // Operator: reshape
 //===----------------------------------------------------------------------===//
 def Tosa_ReshapeOp: Tosa_Op<"reshape", [
-    InferTensorType, Pure]> {
+    InferTensorTypeAdaptor, Pure]> {
   let summary = "Reshape operator";

   let description = [{
@@ -1529,9 +1493,7 @@
 //===----------------------------------------------------------------------===//
 // Operator: slice
 //===----------------------------------------------------------------------===//
-def Tosa_SliceOp: Tosa_Op<"slice", [
-    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>, Pure]> {
+def Tosa_SliceOp: Tosa_Op<"slice", [InferShapedTypeOpAdaptor, Pure]> {
   let summary = "Slice operator";

   let description = [{
@@ -1557,10 +1519,7 @@
 //===----------------------------------------------------------------------===//
 // Operator: tile
 //===----------------------------------------------------------------------===//
-def Tosa_TileOp: Tosa_Op<"tile", [
-    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>,
-    Pure]> {
+def Tosa_TileOp: Tosa_Op<"tile", [InferShapedTypeOpAdaptor, Pure]> {
   let summary = "Tile operator";

   let description = [{
@@ -1581,10 +1540,7 @@
 //===----------------------------------------------------------------------===//
 // Operator: transpose
 //===----------------------------------------------------------------------===//
-def Tosa_TransposeOp : Tosa_Op<"transpose", [
-    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>,
-    Pure]> {
+def Tosa_TransposeOp : Tosa_Op<"transpose", [InferShapedTypeOpAdaptor, Pure]> {
   let summary = "Transpose operator";

   let description = [{
@@ -1616,10 +1572,7 @@
 //===----------------------------------------------------------------------===//
 // Operator: gather
 //===----------------------------------------------------------------------===//
-def Tosa_GatherOp : Tosa_Op<"gather", [
-    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>,
-    Pure]> {
+def Tosa_GatherOp : Tosa_Op<"gather", [InferShapedTypeOpAdaptor, Pure]> {
   let summary = "Gather operation,";

   let description = [{
@@ -1640,10 +1593,7 @@
 //===----------------------------------------------------------------------===//
 // Operator: scatter
 //===----------------------------------------------------------------------===//
-def Tosa_ScatterOp : Tosa_Op<"scatter", [
-    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>,
-    Pure]> {
+def Tosa_ScatterOp : Tosa_Op<"scatter", [InferShapedTypeOpAdaptor, Pure]> {
   let summary = "Scatter operation,";

   let description = [{
@@ -1670,10 +1620,7 @@
 //===----------------------------------------------------------------------===//
 // Operator: resize
 //===----------------------------------------------------------------------===//
-def Tosa_ResizeOp : Tosa_Op<"resize", [
-    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>,
-    Pure]> {
+def Tosa_ResizeOp : Tosa_Op<"resize", [InferShapedTypeOpAdaptor, Pure]> {
   let summary = "Resize operation, supports various resize/upsample modes";
@@ -1899,9 +1846,8 @@
 //===----------------------------------------------------------------------===//
 // Further described in docs/Rationale/RationaleTOSADialect.md .
 //===----------------------------------------------------------------------===//
-def Tosa_IfOp : Tosa_Op<"cond_if", [
-      DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                                ["inferReturnTypeComponents"]>,
+def Tosa_IfOp : Tosa_Op<"cond_if",
+      [InferShapedTypeOpAdaptor,
       SingleBlockImplicitTerminator<"YieldOp">,
       RecursiveMemoryEffects]> {
   let summary = "Conditional if operator";
@@ -1934,8 +1880,7 @@
 //===----------------------------------------------------------------------===//
 def Tosa_WhileOp : Tosa_Op<"while_loop", [
       DeclareOpInterfaceMethods<LoopLikeOpInterface>,
-      DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                                ["inferReturnTypeComponents"]>,
+      InferShapedTypeOpAdaptor,
       SingleBlockImplicitTerminator<"YieldOp">,
       RecursiveMemoryEffects]> {
   let summary = "output = input; While (Cond(output)) {output = Body(output)}";
diff --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h
--- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h
+++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h
@@ -262,6 +262,10 @@
 class InferTypeOpInterfaceAdaptor
     : public TraitBase<ConcreteType, InferTypeOpInterfaceAdaptor> {};

+template <typename ConcreteType>
+class InferShapedTypeOpAdaptor
+    : public TraitBase<ConcreteType, InferShapedTypeOpAdaptor> {};
+
 /// Tensor type inference trait that constructs a tensor from the inferred
 /// shape and elemental types.
 /// Requires: Op implements InferShapedTypeOpInterface and InferTypeOpInterface.
diff --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
--- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
+++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
@@ -215,6 +215,42 @@
   >
 ]>;

+// Convenient trait to define a wrapper to inferReturnTypeComponents that passes
+// in the Op Adaptor directly
+class InferShapedTypeOpAdaptorBase<list<string> overridenMethods = []> : TraitList<
+  [
+    // Op implements infer type op interface.
+    DeclareOpInterfaceMethods<InferShapedTypeOpInterface, overridenMethods>,
+    NativeOpTrait<
+      /*name=*/"InferShapedTypeOpAdaptor",
+      /*traits=*/[],
+      /*extraOpDeclaration=*/[{
+        static ::mlir::LogicalResult
+        inferReturnTypeComponents(::mlir::MLIRContext *context,
+                                  std::optional<::mlir::Location> location,
+                                  Adaptor adaptor,
+                                  ::llvm::SmallVectorImpl<::mlir::ShapedTypeComponents> &inferredReturnShapes);
+      }],
+      /*extraOpDefinition=*/[{
+        ::mlir::LogicalResult
+        $cppClass::inferReturnTypeComponents(::mlir::MLIRContext *context,
+                                  std::optional<::mlir::Location> location,
+                                  ::mlir::ValueShapeRange operands, ::mlir::DictionaryAttr attributes,
+                                  ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions,
+                                  ::llvm::SmallVectorImpl<::mlir::ShapedTypeComponents> &inferredReturnShapes) {
+          $cppClass::Adaptor adaptor(operands, attributes, properties, regions);
+          return $cppClass::inferReturnTypeComponents(context,
+            location, adaptor, inferredReturnShapes);
+        }
+      }]
+    >
+  ]>;
+
+def InferShapedTypeOpAdaptor : InferShapedTypeOpAdaptorBase<[
+    "inferReturnTypeComponents"]>;
+def InferShapedTypeOpAdaptorWithReify : InferShapedTypeOpAdaptorBase<[
+    "inferReturnTypeComponents", "reifyReturnTypeShapes"]>;
+
 // Convenience class grouping together type and shaped type op interfaces for
 // ops that have tensor return types.
 class InferTensorTypeBase<list<string> overridenMethods = []> : TraitList<
@@ -232,11 +268,11 @@
       /*extraOpDeclaration=*/[{}],
       /*extraOpDefinition=*/[{
         LogicalResult
-        $cppClass::inferReturnTypes(MLIRContext *context,
-                          std::optional<Location> location,
-                          ValueRange operands, DictionaryAttr attributes,
-                          OpaqueProperties properties, RegionRange regions,
-                          SmallVectorImpl<Type> &inferredReturnTypes) {
+        $cppClass::inferReturnTypes(::mlir::MLIRContext *context,
+                          std::optional<::mlir::Location> location,
+                          ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes,
+                          ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions,
+                          ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
           SmallVector<ShapedTypeComponents, 2> retComponents;
           if (failed($cppClass::inferReturnTypeComponents(context, location,
                                     operands, attributes, properties, regions,
@@ -253,6 +289,44 @@
 def InferTensorTypeWithReify: InferTensorTypeBase<[
     "inferReturnTypeComponents", "reifyReturnTypeShapes"]>;

+// Convenience class grouping together type and shaped type op interfaces for
+// ops that have tensor return types.
+class InferTensorTypeAdaptorBase<list<string> overridenMethods = []> : TraitList<
+  [
+    // Op implements infer type op interface.
+    DeclareOpInterfaceMethods<InferTypeOpInterface>,
+    // The op will have methods implementing the ShapedType type inference
+    // interface.
+    InferShapedTypeOpAdaptorBase<overridenMethods>,
+    // The op produces tensors and will use the ShapedType type infer interface
+    // along with knowledge that it is producing Tensors to infer the type.
+    NativeOpTrait<
+      /*name=*/"InferTensorType",
+      /*traits=*/[],
+      /*extraOpDeclaration=*/[{}],
+      /*extraOpDefinition=*/[{
+        LogicalResult
+        $cppClass::inferReturnTypes(::mlir::MLIRContext *context,
+                          std::optional<::mlir::Location> location,
+                          ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes,
+                          ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions,
+                          ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
+          SmallVector<ShapedTypeComponents, 2> retComponents;
+          if (failed($cppClass::inferReturnTypeComponents(context, location,
+                                    operands, attributes, properties, regions,
+                                    retComponents)))
+            return failure();
+          return ::mlir::detail::inferReturnTensorTypes(retComponents,
+                                                        inferredReturnTypes);
+        }
+      }]
+    >
+  ]>;
+
+def InferTensorTypeAdaptor : InferTensorTypeAdaptorBase<["inferReturnTypeComponents"]>;
+def InferTensorTypeAdaptorWithReify: InferTensorTypeAdaptorBase<[
+    "inferReturnTypeComponents", "reifyReturnTypeShapes"]>;
+
 def ReifyRankedShapedTypeOpInterface
     : OpInterface<"ReifyRankedShapedTypeOpInterface"> {
   let description = [{
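Concretely, an op opts in by listing the new trait and then writing only the adaptor-based overload; the wrapper emitted by extraOpDefinition rebuilds the Adaptor from the original six-argument signature and forwards to it. A minimal sketch of adoption follows — the dialect MyDialect, the op "identity", and its single $input operand are hypothetical, not part of this patch:

// TableGen (hypothetical op): one trait replaces the explicit
// DeclareOpInterfaceMethods<InferShapedTypeOpInterface, ...> spelling.
def MyDialect_IdentityOp : MyDialect_Op<"identity",
    [InferShapedTypeOpAdaptor, Pure]> {
  let arguments = (ins AnyTensor:$input);
  let results = (outs AnyTensor:$output);
}

// C++ (hypothetical op): only the adaptor-based overload is hand-written.
LogicalResult MyDialect::IdentityOp::inferReturnTypeComponents(
    MLIRContext *context, std::optional<Location> location,
    IdentityOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  // Named accessor instead of positional operands.getShape(0).
  ShapeAdaptor inputShape = dyn_cast<ShapedType>(adaptor.getInput().getType());
  if (!inputShape.hasRank()) {
    inferredReturnShapes.push_back(ShapedTypeComponents());
    return success();
  }
  SmallVector<int64_t> outputShape;
  inputShape.getDims(outputShape);
  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}

The same named-accessor pattern is what the TOSA rewrites below apply across the dialect.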
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -22,6 +22,7 @@
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Transforms/InliningUtils.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/TypeSwitch.h"
@@ -404,12 +405,10 @@
 LogicalResult tosa::ArgMaxOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
-    ValueShapeRange operands, DictionaryAttr attributes,
-    OpaqueProperties properties, RegionRange regions,
+    ArgMaxOp::Adaptor adaptor,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
-  ShapeAdaptor inputShape = operands.getShape(0);
-  auto *prop = properties.as<Properties *>();
-  IntegerAttr axis = prop->axis;
+  ShapeAdaptor inputShape = dyn_cast<ShapedType>(adaptor.getInput().getType());
+  IntegerAttr axis = adaptor.getProperties().axis;
   int32_t axisVal = axis.getValue().getSExtValue();

   if (!inputShape.hasRank()) {
@@ -431,10 +430,9 @@
 LogicalResult tosa::RFFT2dOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
-    ValueShapeRange operands, DictionaryAttr attributes,
-    OpaqueProperties properties, RegionRange regions,
+    RFFT2dOp::Adaptor adaptor,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
-  ShapeAdaptor inputShape = operands.getShape(0);
+  ShapeAdaptor inputShape = dyn_cast<ShapedType>(adaptor.getInput().getType());

   if (!inputShape.hasRank())
     return failure();
@@ -458,26 +456,26 @@
 LogicalResult tosa::FFT2dOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
-    ValueShapeRange operands, DictionaryAttr attributes,
-    OpaqueProperties properties, RegionRange regions,
+    FFT2dOp::Adaptor adaptor,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
-  inferredReturnShapes.push_back(ShapedTypeComponents(operands.getShape(0)));
-  inferredReturnShapes.push_back(ShapedTypeComponents(operands.getShape(1)));
+  inferredReturnShapes.push_back(ShapedTypeComponents(
+      dyn_cast<ShapedType>(adaptor.getInputReal().getType())));
+  inferredReturnShapes.push_back(ShapedTypeComponents(
+      dyn_cast<ShapedType>(adaptor.getInputImag().getType())));
   return success();
 }

 LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
-    ValueShapeRange operands, DictionaryAttr attributes,
-    OpaqueProperties properties, RegionRange regions,
+    ConcatOp::Adaptor adaptor,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
   // Infer all dimension sizes by reducing based on inputs.
-  auto *prop = properties.as<Properties *>();
-  int32_t axis = prop->axis.getValue().getSExtValue();
+  const Properties &prop = adaptor.getProperties();
+  int32_t axis = prop.axis.getValue().getSExtValue();
   llvm::SmallVector<int64_t> outputShape;
   bool hasRankedInput = false;
-  for (auto operand : operands) {
-    ShapeAdaptor operandShape = operands.getShape(operand);
+  for (auto operand : adaptor.getOperands()) {
+    ShapeAdaptor operandShape = dyn_cast<ShapedType>(operand.getType());
     if (!operandShape.hasRank())
       continue;

@@ -501,7 +499,7 @@
     hasRankedInput = true;
   }
   Type inputType =
-      llvm::cast<TensorType>(operands.getType()[0]).getElementType();
+      llvm::cast<TensorType>(adaptor.getInput1().getType()[0]).getElementType();
   if (!hasRankedInput) {
     inferredReturnShapes.push_back(ShapedTypeComponents(inputType));
     return success();
@@ -509,8 +507,8 @@

   // Determine the dimension size along the concatenation axis.
   int64_t concatDimSize = 0;
-  for (auto operand : operands) {
-    ShapeAdaptor operandShape = operands.getShape(operand);
+  for (auto operand : adaptor.getOperands()) {
+    ShapeAdaptor operandShape = dyn_cast<ShapedType>(operand.getType());

     // We need to know the length of the concatenation axis of all inputs to
     // determine the dimension size of the output shape.
@@ -553,12 +551,12 @@
 LogicalResult tosa::FullyConnectedOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
-    ValueShapeRange operands, DictionaryAttr attributes,
-    OpaqueProperties properties, RegionRange regions,
+    FullyConnectedOp::Adaptor adaptor,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
-  ShapeAdaptor inputShape = operands.getShape(0);
-  ShapeAdaptor weightShape = operands.getShape(1);
-  ShapeAdaptor biasShape = operands.getShape(2);
+  ShapeAdaptor inputShape = dyn_cast<ShapedType>(adaptor.getInput().getType());
+  ShapeAdaptor weightShape =
+      dyn_cast<ShapedType>(adaptor.getWeight().getType());
+  ShapeAdaptor biasShape = dyn_cast<ShapedType>(adaptor.getBias().getType());

   // All shapes are dynamic.
   SmallVector<int64_t> outShape;
@@ -585,11 +583,10 @@
 LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
-    ValueShapeRange operands, DictionaryAttr attributes,
-    OpaqueProperties properties, RegionRange regions,
+    MatMulOp::Adaptor adaptor,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
-  ShapeAdaptor lhsShape = operands.getShape(0);
-  ShapeAdaptor rhsShape = operands.getShape(1);
+  ShapeAdaptor lhsShape = dyn_cast<ShapedType>(adaptor.getA().getType());
+  ShapeAdaptor rhsShape = dyn_cast<ShapedType>(adaptor.getB().getType());

   // All shapes are dynamic.
   SmallVector<int64_t> outShape;
@@ -612,11 +609,11 @@
 LogicalResult tosa::PadOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
-    ValueShapeRange operands, DictionaryAttr attributes,
-    OpaqueProperties properties, RegionRange regions,
+    PadOp::Adaptor adaptor,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
-  ShapeAdaptor inputShape = operands.getShape(0);
-  ShapeAdaptor paddingShape = operands.getShape(1);
+  ShapeAdaptor inputShape = dyn_cast<ShapedType>(adaptor.getInput1().getType());
+  ShapeAdaptor paddingShape =
+      dyn_cast<ShapedType>(adaptor.getPadding().getType());
   SmallVector<int64_t> outputShape;

   // If both inputs have unknown shape, we cannot determine the shape of the
@@ -641,7 +638,7 @@

   DenseIntElementsAttr paddings;
   // If the paddings value is not a constant, all dimensions must be dynamic.
-  if (!matchPattern(operands[1], m_Constant(&paddings))) {
+  if (!matchPattern(adaptor.getPadding(), m_Constant(&paddings))) {
     outputShape.resize(inputShape.getRank(), ShapedType::kDynamic);
     inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
     return success();
@@ -675,22 +672,18 @@
 LogicalResult tosa::SliceOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
-    ValueShapeRange operands, DictionaryAttr attributes,
-    OpaqueProperties properties, RegionRange regions,
+    SliceOp::Adaptor adaptor,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
-  inferredReturnShapes.push_back(ShapedTypeComponents(
-      convertToMlirShape(SliceOpAdaptor(operands, attributes,
-                                        *properties.as<Properties *>(), regions)
-                             .getSize())));
+  inferredReturnShapes.push_back(
+      ShapedTypeComponents(convertToMlirShape(adaptor.getSize())));
   return success();
 }

 LogicalResult tosa::TableOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
-    ValueShapeRange operands, DictionaryAttr attributes,
-    OpaqueProperties properties, RegionRange regions,
+    TableOp::Adaptor adaptor,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
-  ShapeAdaptor inputShape = operands.getShape(0);
+  ShapeAdaptor inputShape = dyn_cast<ShapedType>(adaptor.getInput().getType());

   if (!inputShape.hasRank()) {
     inferredReturnShapes.push_back(ShapedTypeComponents());
@@ -704,13 +697,10 @@
 LogicalResult tosa::TileOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
-    ValueShapeRange operands, DictionaryAttr attributes,
-    OpaqueProperties properties, RegionRange regions,
+    TileOp::Adaptor adaptor,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
-  TileOpAdaptor adaptor(operands, attributes, *properties.as<Properties *>(),
-                        regions);
   ArrayRef<int64_t> multiples = adaptor.getMultiples();
-  ShapeAdaptor inputShape = operands.getShape(0);
+  ShapeAdaptor inputShape = dyn_cast<ShapedType>(adaptor.getInput1().getType());
   SmallVector<int64_t> outputShape;
   if (!inputShape.hasRank()) {
     outputShape.resize(multiples.size(), ShapedType::kDynamic);
@@ -739,13 +729,10 @@
 LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
-    ValueShapeRange operands, DictionaryAttr attributes,
-    OpaqueProperties properties, RegionRange regions,
+    ReshapeOp::Adaptor adaptor,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
-  ReshapeOpAdaptor adaptor(operands, attributes, *properties.as<Properties *>(),
-                           regions);
-  ShapeAdaptor inputShape = operands.getShape(0);
-  Type inputType = getElementTypeOrSelf(operands.getType()[0]);
+  ShapeAdaptor inputShape = dyn_cast<ShapedType>(adaptor.getInput1().getType());
+  Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType());
   llvm::SmallVector<int64_t> newShapeValue =
       convertToMlirShape(adaptor.getNewShape());

@@ -814,11 +801,10 @@
 LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
-    ValueShapeRange operands, DictionaryAttr attributes,
-    OpaqueProperties properties, RegionRange regions,
+    TransposeOp::Adaptor adaptor,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
-  ShapeAdaptor inputShape = operands.getShape(0);
-  ShapeAdaptor permsShape = operands.getShape(1);
+  ShapeAdaptor inputShape = dyn_cast<ShapedType>(adaptor.getInput1().getType());
+  ShapeAdaptor permsShape = dyn_cast<ShapedType>(adaptor.getPerms().getType());

   // If input rank and permutation length is unknown, the output rank is
   // unknown.
@@ -869,7 +855,10 @@
   outputShape.resize(inputShape.getRank(), ShapedType::kDynamic);
   // If the permuations are a constant we can directly determine the output
   // shape.
-  if (ShapeAdaptor permShape = operands.getValueAsShape(1)) {
+  DenseIntElementsAttr attr;
+  if (matchPattern(adaptor.getPerms(), m_Constant(&attr)) &&
+      attr.getType().getRank() == 1) {
+    ShapeAdaptor permShape = attr;
     outputShape.reserve(inputShape.getRank());
     for (int i = 0, s = inputShape.getRank(); i < s; i++) {
       outputShape[i] = inputShape.getDimSize(permShape.getDimSize(i));
@@ -882,19 +871,20 @@
 LogicalResult tosa::GatherOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
-    ValueShapeRange operands, DictionaryAttr attributes,
-    OpaqueProperties properties, RegionRange regions,
+    GatherOp::Adaptor adaptor,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
   llvm::SmallVector<int64_t> outputShape;
   outputShape.resize(3, ShapedType::kDynamic);

-  ShapeAdaptor valuesShape = operands.getShape(0);
+  ShapeAdaptor valuesShape =
+      dyn_cast<ShapedType>(adaptor.getValues().getType());
   if (valuesShape.hasRank()) {
     outputShape[0] = valuesShape.getDimSize(0);
     outputShape[2] = valuesShape.getDimSize(2);
   }

-  ShapeAdaptor indicesShape = operands.getShape(1);
+  ShapeAdaptor indicesShape =
+      dyn_cast<ShapedType>(adaptor.getIndices().getType());
   if (indicesShape.hasRank()) {
     if (outputShape[0] == ShapedType::kDynamic)
       outputShape[0] = indicesShape.getDimSize(0);
@@ -908,15 +898,12 @@
 LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
-    ValueShapeRange operands, DictionaryAttr attributes,
-    OpaqueProperties properties, RegionRange regions,
+    ResizeOp::Adaptor adaptor,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
-  ResizeOpAdaptor adaptor(operands, attributes, *properties.as<Properties *>(),
-                          regions);
   llvm::SmallVector<int64_t, 4> outputShape;
   outputShape.resize(4, ShapedType::kDynamic);

-  ShapeAdaptor inputShape = operands.getShape(adaptor.getInput());
+  ShapeAdaptor inputShape = dyn_cast<ShapedType>(adaptor.getInput().getType());
   if (!inputShape.hasRank())
     return failure();
@@ -950,26 +937,27 @@
 LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
-    ValueShapeRange operands, DictionaryAttr attributes,
-    OpaqueProperties properties, RegionRange regions,
+    ScatterOp::Adaptor adaptor,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
   llvm::SmallVector<int64_t> outputShape;
   outputShape.resize(3, ShapedType::kDynamic);

-  ShapeAdaptor valuesInShape = operands.getShape(0);
+  ShapeAdaptor valuesInShape =
+      dyn_cast<ShapedType>(adaptor.getValuesIn().getType());
   if (valuesInShape.hasRank()) {
     outputShape[0] = valuesInShape.getDimSize(0);
     outputShape[1] = valuesInShape.getDimSize(1);
     outputShape[2] = valuesInShape.getDimSize(2);
   }

-  ShapeAdaptor indicesShape = operands.getShape(1);
+  ShapeAdaptor indicesShape =
+      dyn_cast<ShapedType>(adaptor.getIndices().getType());
   if (indicesShape.hasRank()) {
     if (outputShape[0] == ShapedType::kDynamic)
       outputShape[0] = indicesShape.getDimSize(0);
   }

-  ShapeAdaptor inputShape = operands.getShape(2);
+  ShapeAdaptor inputShape = dyn_cast<ShapedType>(adaptor.getInput().getType());
   if (inputShape.hasRank()) {
     if (outputShape[0] == ShapedType::kDynamic)
       outputShape[0] = inputShape.getDimSize(0);
@@ -1009,13 +997,14 @@
 #define REDUCE_SHAPE_INFER(OP)                                                 \
   LogicalResult OP::inferReturnTypeComponents(                                 \
       MLIRContext *context, ::std::optional<Location> location,                \
-      ValueShapeRange operands, DictionaryAttr attributes,                     \
-      OpaqueProperties properties, RegionRange regions,                        \
+      OP::Adaptor adaptor,                                                     \
       SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {           \
     Type inputType =                                                           \
-        llvm::cast<TensorType>(operands.getType()[0]).getElementType();        \
-    return ReduceInferReturnTypes(operands.getShape(0), inputType,             \
-                                  properties.as<Properties *>()->axis,         \
+        llvm::cast<TensorType>(adaptor.getInput().getType()).getElementType(); \
+    ShapeAdaptor inputShape =                                                   \
+        dyn_cast<ShapedType>(adaptor.getInput().getType());                    \
+    const Properties &prop = adaptor.getProperties();                          \
+    return ReduceInferReturnTypes(inputShape, inputType, prop.axis,            \
                                   inferredReturnShapes);                       \
   }                                                                            \
   COMPATIBLE_RETURN_TYPES(OP)
@@ -1092,10 +1081,9 @@
 #undef PRED_SHAPE_INFER

 static LogicalResult poolingInferReturnTypes(
-    const ValueShapeRange &operands, DictionaryAttr attributes,
-    ArrayRef<int64_t> kernel, ArrayRef<int64_t> stride, ArrayRef<int64_t> pad,
+    ShapeAdaptor inputShape, ArrayRef<int64_t> kernel, ArrayRef<int64_t> stride,
+    ArrayRef<int64_t> pad,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
-  ShapeAdaptor inputShape = operands.getShape(0);
   llvm::SmallVector<int64_t> outputShape;
   outputShape.resize(4, ShapedType::kDynamic);
@@ -1128,12 +1116,9 @@
 LogicalResult Conv2DOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
-    ValueShapeRange operands, DictionaryAttr attributes,
-    OpaqueProperties properties, RegionRange regions,
+    Conv2DOp::Adaptor adaptor,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
   llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
-  Conv2DOp::Adaptor adaptor(operands, attributes,
-                            *properties.as<Properties *>(), regions);

   int64_t inputWidth = ShapedType::kDynamic;
   int64_t inputHeight = ShapedType::kDynamic;
@@ -1142,7 +1127,7 @@

   // Input shape describes input width/height and batch.
-  ShapeAdaptor inputShape = operands.getShape(adaptor.getInput());
+  ShapeAdaptor inputShape = dyn_cast<ShapedType>(adaptor.getInput().getType());
   if (inputShape.hasRank()) {
     outputShape[0] = inputShape.getDimSize(0);
     inputHeight = inputShape.getDimSize(1);
@@ -1150,7 +1135,8 @@
   }

   // Weight shapes describes the filter width/height and the output channels.
-  ShapeAdaptor weightShape = operands.getShape(adaptor.getWeight());
+  ShapeAdaptor weightShape =
+      dyn_cast<ShapedType>(adaptor.getWeight().getType());
   if (weightShape.hasRank()) {
     outputShape[3] = weightShape.getDimSize(0);
     weightHeight = weightShape.getDimSize(1);
@@ -1158,7 +1144,7 @@
   }

   // Bias shape can describe the output channels.
-  ShapeAdaptor biasShape = operands.getShape(adaptor.getBias());
+  ShapeAdaptor biasShape = dyn_cast<ShapedType>(adaptor.getBias().getType());
   if (biasShape.hasRank()) {
     outputShape[3] = ShapedType::isDynamic(outputShape[3])
                          ? biasShape.getDimSize(0)
                          : outputShape[3];
@@ -1193,12 +1179,9 @@
 LogicalResult Conv3DOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
-    ValueShapeRange operands, DictionaryAttr attributes,
-    OpaqueProperties properties, RegionRange regions,
+    Conv3DOp::Adaptor adaptor,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
   llvm::SmallVector<int64_t> outputShape(5, ShapedType::kDynamic);
-  Conv3DOp::Adaptor adaptor(operands, attributes,
-                            *properties.as<Properties *>(), regions);

   int64_t inputWidth = ShapedType::kDynamic;
   int64_t inputHeight = ShapedType::kDynamic;
@@ -1209,7 +1192,7 @@
   int64_t weightDepth = ShapedType::kDynamic;

   // Input shape describes input width/height and batch.
-  ShapeAdaptor inputShape = operands.getShape(adaptor.getInput());
+  ShapeAdaptor inputShape = dyn_cast<ShapedType>(adaptor.getInput().getType());
   if (inputShape.hasRank()) {
     outputShape[0] = inputShape.getDimSize(0);
     inputDepth = inputShape.getDimSize(1);
@@ -1218,7 +1201,8 @@
   }

   // Weight shapes describes the filter width/height and the output channels.
-  ShapeAdaptor weightShape = operands.getShape(adaptor.getWeight());
+  ShapeAdaptor weightShape =
+      dyn_cast<ShapedType>(adaptor.getWeight().getType());
   if (weightShape.hasRank()) {
     outputShape[4] = weightShape.getDimSize(0);
     weightDepth = weightShape.getDimSize(1);
@@ -1227,7 +1211,7 @@
   }

   // Bias shape can describe the output channels.
-  ShapeAdaptor biasShape = operands.getShape(adaptor.getBias());
+  ShapeAdaptor biasShape = dyn_cast<ShapedType>(adaptor.getBias().getType());
   if (biasShape.hasRank() && ShapedType::isDynamic(outputShape[4])) {
     outputShape[4] = biasShape.getDimSize(0);
   }
@@ -1268,32 +1252,29 @@
 LogicalResult AvgPool2dOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
-    ValueShapeRange operands, DictionaryAttr attributes,
-    OpaqueProperties properties, RegionRange regions,
+    AvgPool2dOp::Adaptor adaptor,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
-  Properties &prop = *properties.as<Properties *>();
-  return poolingInferReturnTypes(operands, attributes, prop.kernel, prop.stride,
-                                 prop.pad, inferredReturnShapes);
+  ShapeAdaptor inputShape = dyn_cast<ShapedType>(adaptor.getInput().getType());
+  const Properties &prop = adaptor.getProperties();
+  return poolingInferReturnTypes(inputShape, prop.kernel, prop.stride, prop.pad,
+                                 inferredReturnShapes);
 }

 LogicalResult MaxPool2dOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
-    ValueShapeRange operands, DictionaryAttr attributes,
-    OpaqueProperties properties, RegionRange regions,
+    MaxPool2dOp::Adaptor adaptor,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
-  Properties &prop = *properties.as<Properties *>();
-  return poolingInferReturnTypes(operands, attributes, prop.kernel, prop.stride,
-                                 prop.pad, inferredReturnShapes);
+  ShapeAdaptor inputShape = dyn_cast<ShapedType>(adaptor.getInput().getType());
+  const Properties &prop = adaptor.getProperties();
+  return poolingInferReturnTypes(inputShape, prop.kernel, prop.stride, prop.pad,
+                                 inferredReturnShapes);
 }

 LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
-    ValueShapeRange operands, DictionaryAttr attributes,
-    OpaqueProperties properties, RegionRange regions,
+    DepthwiseConv2DOp::Adaptor adaptor,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
   llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
-  DepthwiseConv2DOp::Adaptor adaptor(operands, attributes,
-                                     *properties.as<Properties *>(), regions);

   int64_t inputWidth = ShapedType::kDynamic;
   int64_t inputHeight = ShapedType::kDynamic;
@@ -1304,7 +1285,7 @@
   int64_t depthChannels = ShapedType::kDynamic;

   // Input shape describes input width/height and batch.
-  ShapeAdaptor inputShape = operands.getShape(adaptor.getInput());
+  ShapeAdaptor inputShape = dyn_cast<ShapedType>(adaptor.getInput().getType());
   if (inputShape.hasRank()) {
     outputShape[0] = inputShape.getDimSize(0);
     inputHeight = inputShape.getDimSize(1);
@@ -1313,7 +1294,8 @@
   }

   // Weight shapes describes the filter width/height and the output channels.
-  ShapeAdaptor weightShape = operands.getShape(adaptor.getWeight());
+  ShapeAdaptor weightShape =
+      dyn_cast<ShapedType>(adaptor.getWeight().getType());
   if (weightShape.hasRank()) {
     weightHeight = weightShape.getDimSize(0);
     weightWidth = weightShape.getDimSize(1);
@@ -1331,7 +1313,7 @@
   }

   // Bias shape can describe the output channels.
-  ShapeAdaptor biasShape = operands.getShape(adaptor.getBias());
+  ShapeAdaptor biasShape = dyn_cast<ShapedType>(adaptor.getBias().getType());
   if (biasShape.hasRank()) {
     outputShape[3] = ShapedType::isDynamic(outputShape[3])
                          ? biasShape.getDimSize(0)
                          : outputShape[3];
@@ -1366,11 +1348,8 @@
 LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
-    ValueShapeRange operands, DictionaryAttr attributes,
-    OpaqueProperties properties, RegionRange regions,
+    TransposeConv2DOp::Adaptor adaptor,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
-  TransposeConv2DOp::Adaptor adaptor(operands, attributes,
-                                     *properties.as<Properties *>(), regions);
   // outputShape is mutable.
   llvm::SmallVector<int64_t> outputShape =
       convertToMlirShape(adaptor.getOutShape());
@@ -1381,7 +1360,7 @@
   int64_t weightHeight = ShapedType::kDynamic;

   // Input shape describes input width/height and batch.
-  ShapeAdaptor inputShape = operands.getShape(adaptor.getInput());
+  ShapeAdaptor inputShape = dyn_cast<ShapedType>(adaptor.getInput().getType());
   if (inputShape.hasRank()) {
     outputShape[0] = ShapedType::isDynamic(outputShape[0])
                          ? inputShape.getDimSize(0)
@@ -1391,7 +1370,8 @@
   }

   // Weight shapes describes the filter width/height and the output channels.
-  ShapeAdaptor weightShape = operands.getShape(adaptor.getFilter());
+  ShapeAdaptor weightShape =
+      dyn_cast<ShapedType>(adaptor.getFilter().getType());
   if (weightShape.hasRank()) {
     outputShape[3] = ShapedType::isDynamic(outputShape[3])
                          ? weightShape.getDimSize(0)
@@ -1401,7 +1381,7 @@
   }

   // Bias shape can describe the output channels.
-  ShapeAdaptor biasShape = operands.getShape(adaptor.getInput());
+  ShapeAdaptor biasShape = dyn_cast<ShapedType>(adaptor.getInput().getType());
   if (biasShape.hasRank()) {
     outputShape[3] = ShapedType::isDynamic(outputShape[3])
                          ? biasShape.getDimSize(0)
@@ -1433,11 +1413,10 @@
 LogicalResult IfOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
-    ValueShapeRange operands, DictionaryAttr attributes,
-    OpaqueProperties properties, RegionRange regions,
+    IfOp::Adaptor adaptor,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
   llvm::SmallVector<tosa::YieldOp> yieldOps;
-  for (Region *region : regions) {
+  for (Region *region : adaptor.getRegions()) {
     for (auto &block : *region)
       if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
         yieldOps.push_back(returnOp);
@@ -1478,11 +1457,10 @@
 LogicalResult WhileOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
-    ValueShapeRange operands, DictionaryAttr attributes,
-    OpaqueProperties properties, RegionRange regions,
+    WhileOp::Adaptor adaptor,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
   llvm::SmallVector<tosa::YieldOp> yieldOps;
-  for (auto &block : *regions[1])
+  for (auto &block : adaptor.getBody())
     if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
       yieldOps.push_back(returnOp);
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -1445,6 +1445,36 @@
   return success();
 }

+LogicalResult
+OpWithShapedTypeInferTypeAdaptorInterfaceOp::inferReturnTypeComponents(
+    MLIRContext *context, std::optional<Location> location,
+    OpWithShapedTypeInferTypeAdaptorInterfaceOp::Adaptor adaptor,
+    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+  // Create return type consisting of the last element of the first operand.
+  auto operandType = adaptor.getOperand1().getType();
+  auto sval = dyn_cast<ShapedType>(operandType);
+  if (!sval) {
+    return emitOptionalError(location, "only shaped type operands allowed");
+  }
+  int64_t dim = sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamic;
+  auto type = IntegerType::get(context, 17);
+
+  Attribute encoding;
+  if (auto rankedTy = dyn_cast<RankedTensorType>(sval))
+    encoding = rankedTy.getEncoding();
+  inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type, encoding));
+  return success();
+}
+
+LogicalResult
+OpWithShapedTypeInferTypeAdaptorInterfaceOp::reifyReturnTypeShapes(
+    OpBuilder &builder, ValueRange operands,
+    llvm::SmallVectorImpl<Value> &shapes) {
+  shapes = SmallVector<Value, 1>{
+      builder.createOrFold<tensor::DimOp>(getLoc(), operands.front(), 0)};
+  return success();
+}
+
 LogicalResult OpWithResultShapeInterfaceOp::reifyReturnTypeShapes(
     OpBuilder &builder, ValueRange operands,
     llvm::SmallVectorImpl<Value> &shapes) {
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -774,6 +774,13 @@
   let results = (outs AnyTensor);
 }

+def OpWithShapedTypeInferTypeAdaptorInterfaceOp :
+    TEST_Op<"op_with_shaped_type_infer_type_adaptor_if",
+            [InferTensorTypeAdaptorWithReify]> {
+  let arguments = (ins AnyTensor:$operand1, AnyTensor:$operand2);
+  let results = (outs AnyTensor:$result);
+}
+
 def OpWithResultShapeInterfaceOp : TEST_Op<"op_with_result_shape_interface",
     [DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> {
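Finally, a sketch of IR that would exercise the new test op end-to-end through shape inference. The function name and types below are illustrative, modeled on the existing op_with_shaped_type_infer_type_if tests rather than taken from this patch; per the interface implementation above, the result's single dimension comes from the first operand's leading dimension and the element type is fixed to i17:

func.func @adaptor_shape_inference(%a : tensor<10x20xi8>, %b : tensor<?xf32>) {
  // dim 10 is sval.getShape().front() of the first operand.
  %0 = "test.op_with_shaped_type_infer_type_adaptor_if"(%a, %b)
      : (tensor<10x20xi8>, tensor<?xf32>) -> tensor<10xi17>
  return
}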