diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h @@ -15,6 +15,7 @@ #include "mlir/Dialect/Quant/QuantOps.h" #include "mlir/Dialect/Traits.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/LoopLikeInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td @@ -17,6 +17,7 @@ include "mlir/IR/OpBase.td" include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/LoopLikeInterface.td" include "mlir/Dialect/Tosa/IR/TosaInterfaces.td" @@ -284,7 +285,9 @@ //===----------------------------------------------------------------------===// // Operator: clamp //===----------------------------------------------------------------------===// -def Tosa_ClampOp : Tosa_Op<"clamp", [NoSideEffect, SameOperandsAndResultType]> { +def Tosa_ClampOp : Tosa_Op<"clamp", [ + DeclareOpInterfaceMethods, + NoSideEffect, SameOperandsAndResultType]> { let summary = "Computes clamp(features, min, max)."; let description = [{ @@ -309,7 +312,9 @@ //===----------------------------------------------------------------------===// // Operator: reluN //===----------------------------------------------------------------------===// -def Tosa_ReluNOp : Tosa_Op<"reluN", [NoSideEffect, SameOperandsAndResultType]> { +def Tosa_ReluNOp : Tosa_Op<"reluN", [ + DeclareOpInterfaceMethods, + NoSideEffect, SameOperandsAndResultType]> { let summary = "Computes rectified linear: `max(features, N)`."; let description = [{ @@ -330,8 +335,9 @@ //===----------------------------------------------------------------------===// // Operator: sigmoid 
//===----------------------------------------------------------------------===// -def Tosa_SigmoidOp : Tosa_Op<"sigmoid", [NoSideEffect, - SameOperandsAndResultType]> { +def Tosa_SigmoidOp : Tosa_Op<"sigmoid", [ + DeclareOpInterfaceMethods, NoSideEffect, + SameOperandsAndResultType]> { let summary = "Computes elementwise sigmoid of input."; let description = [{ @@ -354,7 +360,9 @@ //===----------------------------------------------------------------------===// // Operator: tanh //===----------------------------------------------------------------------===// -def Tosa_TanhOp : Tosa_Op<"tanh", [NoSideEffect, SameOperandsAndResultType]> { +def Tosa_TanhOp : Tosa_Op<"tanh", [ + DeclareOpInterfaceMethods, + NoSideEffect, SameOperandsAndResultType]> { let summary = "Computes elementwise hyperbolic tangent of input"; let description = [{ @@ -382,8 +390,9 @@ //===----------------------------------------------------------------------===// // Operator: add //===----------------------------------------------------------------------===// -def Tosa_AddOp : Tosa_Op<"add", [ResultsBroadcastableShape, NoSideEffect, - Commutative]> { +def Tosa_AddOp : Tosa_Op<"add", [ + DeclareOpInterfaceMethods, + ResultsBroadcastableShape, NoSideEffect, Commutative]> { let summary = "Elementwise addition operator"; let description = [{ @@ -404,9 +413,9 @@ //===----------------------------------------------------------------------===// // Operator: arithmetic_right_shift //===----------------------------------------------------------------------===// -def Tosa_ArithmeticRightShiftOp : Tosa_Op<"arithmetic_right_shift", - [ResultsBroadcastableShape, - NoSideEffect]> { +def Tosa_ArithmeticRightShiftOp : Tosa_Op<"arithmetic_right_shift", [ + DeclareOpInterfaceMethods, ResultsBroadcastableShape, + NoSideEffect]> { let summary = "Elementwise Arithmetic Right Shift"; let description = [{ @@ -429,8 +438,9 @@ //===----------------------------------------------------------------------===// // Operator: 
bitwise_and //===----------------------------------------------------------------------===// -def Tosa_BitwiseAndOp : Tosa_Op<"bitwise_and", [ResultsBroadcastableShape, - NoSideEffect, Commutative]> { +def Tosa_BitwiseAndOp : Tosa_Op<"bitwise_and", [ + DeclareOpInterfaceMethods, ResultsBroadcastableShape, + NoSideEffect, Commutative]> { let summary = "Bitwise AND operator"; let description = [{ @@ -451,8 +461,9 @@ //===----------------------------------------------------------------------===// // Operator: bitwise_or //===----------------------------------------------------------------------===// -def Tosa_BitwiseOrOp : Tosa_Op<"bitwise_or", [ResultsBroadcastableShape, - NoSideEffect, Commutative]> { +def Tosa_BitwiseOrOp : Tosa_Op<"bitwise_or", [ + DeclareOpInterfaceMethods, ResultsBroadcastableShape, + NoSideEffect, Commutative]> { let summary = "Bitwise OR operator"; let description = [{ @@ -473,8 +484,9 @@ //===----------------------------------------------------------------------===// // Operator: bitwise_xor //===----------------------------------------------------------------------===// -def Tosa_BitwiseXorOp : Tosa_Op<"bitwise_xor", [ResultsBroadcastableShape, - NoSideEffect, Commutative]> { +def Tosa_BitwiseXorOp : Tosa_Op<"bitwise_xor", [ + DeclareOpInterfaceMethods, ResultsBroadcastableShape, + NoSideEffect, Commutative]> { let summary = "Bitwise XOR operator"; let description = [{ @@ -495,8 +507,9 @@ //===----------------------------------------------------------------------===// // Operator: div //===----------------------------------------------------------------------===// -def Tosa_DivOp : Tosa_Op<"div", [ResultsBroadcastableShape, - NoSideEffect]> { +def Tosa_DivOp : Tosa_Op<"div", [ + DeclareOpInterfaceMethods, ResultsBroadcastableShape, + NoSideEffect]> { let summary = "Integer divide operator"; let description = [{ @@ -517,8 +530,9 @@ //===----------------------------------------------------------------------===// // Operator: logical_and 
//===----------------------------------------------------------------------===// -def Tosa_LogicalAndOp : Tosa_Op<"logical_and", [ResultsBroadcastableShape, - Commutative, NoSideEffect]> { +def Tosa_LogicalAndOp : Tosa_Op<"logical_and", [ + DeclareOpInterfaceMethods, ResultsBroadcastableShape, + Commutative, NoSideEffect]> { let summary = "Returns the truth value of x AND y element-wise."; let description = [{ @@ -539,9 +553,9 @@ //===----------------------------------------------------------------------===// // Operator: logical_left_shift //===----------------------------------------------------------------------===// -def Tosa_LogicalLeftShiftOp : Tosa_Op<"logical_left_shift", - [ResultsBroadcastableShape, - NoSideEffect]> { +def Tosa_LogicalLeftShiftOp : Tosa_Op<"logical_left_shift", [ + DeclareOpInterfaceMethods, ResultsBroadcastableShape, + NoSideEffect]> { let summary = "Elementwise Logical Left Shift"; let description = [{ @@ -562,9 +576,9 @@ //===----------------------------------------------------------------------===// // Operator: logical_right_shift //===----------------------------------------------------------------------===// -def Tosa_LogicalRightShiftOp : Tosa_Op<"logical_right_shift", - [ResultsBroadcastableShape, - NoSideEffect]> { +def Tosa_LogicalRightShiftOp : Tosa_Op<"logical_right_shift", [ + DeclareOpInterfaceMethods, ResultsBroadcastableShape, + NoSideEffect]> { let summary = "Elementwise Logical Right Shift"; let description = [{ @@ -586,8 +600,9 @@ //===----------------------------------------------------------------------===// // Operator: logical_or //===----------------------------------------------------------------------===// -def Tosa_LogicalOrOp : Tosa_Op<"logical_or", [ResultsBroadcastableShape, - Commutative, NoSideEffect]> { +def Tosa_LogicalOrOp : Tosa_Op<"logical_or", [ + DeclareOpInterfaceMethods, ResultsBroadcastableShape, + Commutative, NoSideEffect]> { let summary = "Returns the truth value of x OR y element-wise."; let 
description = [{ @@ -608,8 +623,9 @@ //===----------------------------------------------------------------------===// // Operator: logical_xor //===----------------------------------------------------------------------===// -def Tosa_LogicalXorOp : Tosa_Op<"logical_xor", [ResultsBroadcastableShape, - Commutative, NoSideEffect]> { +def Tosa_LogicalXorOp : Tosa_Op<"logical_xor", [ + DeclareOpInterfaceMethods, ResultsBroadcastableShape, + Commutative, NoSideEffect]> { let summary = "Returns the truth value of x XOR y element-wise."; let description = [{ @@ -630,8 +646,9 @@ //===----------------------------------------------------------------------===// // Operator: maximum //===----------------------------------------------------------------------===// -def Tosa_MaximumOp : Tosa_Op<"maximum", [ResultsBroadcastableShape, - NoSideEffect, Commutative]> { +def Tosa_MaximumOp : Tosa_Op<"maximum", [ + DeclareOpInterfaceMethods, ResultsBroadcastableShape, + NoSideEffect, Commutative]> { let summary = "Elementwise Maximum"; let description = [{ @@ -652,8 +669,9 @@ //===----------------------------------------------------------------------===// // Operator: minimum //===----------------------------------------------------------------------===// -def Tosa_MinimumOp : Tosa_Op<"minimum", [ResultsBroadcastableShape, - NoSideEffect, Commutative]> { +def Tosa_MinimumOp : Tosa_Op<"minimum", [ + DeclareOpInterfaceMethods, ResultsBroadcastableShape, + NoSideEffect, Commutative]> { let summary = "Elementwise Minimum"; let description = [{ @@ -674,8 +692,9 @@ //===----------------------------------------------------------------------===// // Operator: mul //===----------------------------------------------------------------------===// -def Tosa_MulOp : Tosa_Op<"mul", [ResultsBroadcastableShape, NoSideEffect, - Commutative]> { +def Tosa_MulOp : Tosa_Op<"mul", [ + DeclareOpInterfaceMethods, ResultsBroadcastableShape, + NoSideEffect, Commutative]> { let summary = "Multiplication operator"; 
let description = [{ @@ -698,7 +717,9 @@ //===----------------------------------------------------------------------===// // Operator: pow //===----------------------------------------------------------------------===// -def Tosa_PowOp : Tosa_Op<"pow", [ResultsBroadcastableShape, NoSideEffect]> { +def Tosa_PowOp : Tosa_Op<"pow", [ + DeclareOpInterfaceMethods, ResultsBroadcastableShape, + NoSideEffect]> { let summary = "Computes the power of one value to another."; let description = [{ @@ -720,7 +741,9 @@ //===----------------------------------------------------------------------===// // Operator: sub //===----------------------------------------------------------------------===// -def Tosa_SubOp : Tosa_Op<"sub", [ResultsBroadcastableShape, NoSideEffect]> { +def Tosa_SubOp : Tosa_Op<"sub", [ + DeclareOpInterfaceMethods, ResultsBroadcastableShape, + NoSideEffect]> { let summary = "Elementwise subtraction operator"; let description = [{ @@ -781,7 +804,9 @@ //===----------------------------------------------------------------------===// // Operator: abs //===----------------------------------------------------------------------===// -def Tosa_AbsOp : Tosa_Op<"abs", [NoSideEffect, SameOperandsAndResultType]> { +def Tosa_AbsOp : Tosa_Op<"abs", [ + DeclareOpInterfaceMethods, NoSideEffect, + SameOperandsAndResultType]> { let summary = "Elementwise abs op"; let description = [{ @@ -800,8 +825,9 @@ //===----------------------------------------------------------------------===// // Operator: bitwise_not //===----------------------------------------------------------------------===// -def Tosa_BitwiseNotOp : Tosa_Op<"bitwise_not", [ResultsBroadcastableShape, - NoSideEffect]> { +def Tosa_BitwiseNotOp : Tosa_Op<"bitwise_not", [ + DeclareOpInterfaceMethods, ResultsBroadcastableShape, + NoSideEffect]> { let summary = "Bitwise NOT operator"; let description = [{ @@ -820,7 +846,9 @@ //===----------------------------------------------------------------------===// // Operator: ceil 
//===----------------------------------------------------------------------===// -def Tosa_CeilOp : Tosa_Op<"ceil", [NoSideEffect, SameOperandsAndResultType]> { +def Tosa_CeilOp : Tosa_Op<"ceil", [ + DeclareOpInterfaceMethods, NoSideEffect, + SameOperandsAndResultType]> { let summary = "Elementwise ceil op"; let description = [{ @@ -839,7 +867,9 @@ //===----------------------------------------------------------------------===// // Operator: clz //===----------------------------------------------------------------------===// -def Tosa_ClzOp : Tosa_Op<"clz", [NoSideEffect, SameOperandsAndResultType]> { +def Tosa_ClzOp : Tosa_Op<"clz", [ + DeclareOpInterfaceMethods, NoSideEffect, + SameOperandsAndResultType]> { let summary = "Elementwise count leading zero op"; let description = [{ @@ -858,7 +888,9 @@ //===----------------------------------------------------------------------===// // Operator: exp //===----------------------------------------------------------------------===// -def Tosa_ExpOp : Tosa_Op<"exp", [NoSideEffect, SameOperandsAndResultType]> { +def Tosa_ExpOp : Tosa_Op<"exp", [ + DeclareOpInterfaceMethods, NoSideEffect, + SameOperandsAndResultType]> { let summary = "Elementwise exp op"; let description = [{ @@ -877,7 +909,9 @@ //===----------------------------------------------------------------------===// // Operator: floor //===----------------------------------------------------------------------===// -def Tosa_FloorOp : Tosa_Op<"floor", [NoSideEffect, SameOperandsAndResultType]> { +def Tosa_FloorOp : Tosa_Op<"floor", [ + DeclareOpInterfaceMethods, NoSideEffect, + SameOperandsAndResultType]> { let summary = "Elementwise floor op"; let description = [{ @@ -896,7 +930,9 @@ //===----------------------------------------------------------------------===// // Operator: log //===----------------------------------------------------------------------===// -def Tosa_LogOp : Tosa_Op<"log", [NoSideEffect, SameOperandsAndResultType]> { +def Tosa_LogOp : Tosa_Op<"log", 
[ + DeclareOpInterfaceMethods, NoSideEffect, + SameOperandsAndResultType]> { let summary = "Elementwise log op"; let description = [{ @@ -915,8 +951,9 @@ //===----------------------------------------------------------------------===// // Operator: logical_not //===----------------------------------------------------------------------===// -def Tosa_LogicalNotOp : Tosa_Op<"logical_not", [NoSideEffect, - SameOperandsAndResultType]> { +def Tosa_LogicalNotOp : Tosa_Op<"logical_not", [ + DeclareOpInterfaceMethods, NoSideEffect, + SameOperandsAndResultType]> { let summary = "Returns the truth value of NOT x element-wise."; let description = [{ @@ -935,8 +972,9 @@ //===----------------------------------------------------------------------===// // Operator: negate //===----------------------------------------------------------------------===// -def Tosa_NegateOp : Tosa_Op<"negate", [NoSideEffect, - SameOperandsAndResultType]> { +def Tosa_NegateOp : Tosa_Op<"negate", [ + DeclareOpInterfaceMethods, NoSideEffect, + SameOperandsAndResultType]> { let summary = "Elementwise negate op"; let description = [{ @@ -958,8 +996,9 @@ //===----------------------------------------------------------------------===// // Operator: reciprocal //===----------------------------------------------------------------------===// -def Tosa_ReciprocalOp : Tosa_Op<"reciprocal", [NoSideEffect, - SameOperandsAndResultType]> { +def Tosa_ReciprocalOp : Tosa_Op<"reciprocal", [ + DeclareOpInterfaceMethods, NoSideEffect, + SameOperandsAndResultType]> { let summary = "Elementwise reciprocal op"; let description = [{ @@ -979,7 +1018,9 @@ //===----------------------------------------------------------------------===// // Operator: rsqrt //===----------------------------------------------------------------------===// -def Tosa_RsqrtOp : Tosa_Op<"rsqrt", [NoSideEffect, SameOperandsAndResultType]> { +def Tosa_RsqrtOp : Tosa_Op<"rsqrt", [ + DeclareOpInterfaceMethods, NoSideEffect, + SameOperandsAndResultType]> { let 
summary = "Elementwise 1/sqrt op"; let description = [{ @@ -1005,7 +1046,8 @@ //===----------------------------------------------------------------------===// // Operator: select //===----------------------------------------------------------------------===// -def Tosa_SelectOp : Tosa_Op<"select", [NoSideEffect]> { +def Tosa_SelectOp : Tosa_Op<"select", [ + DeclareOpInterfaceMethods, NoSideEffect]> { let summary = "Elementwise select operator"; let description = [{ @@ -1031,8 +1073,9 @@ //===----------------------------------------------------------------------===// // Operator: equal //===----------------------------------------------------------------------===// -def Tosa_EqualOp : Tosa_Op<"equal", [ResultsBroadcastableShape, Commutative, - NoSideEffect]> { +def Tosa_EqualOp : Tosa_Op<"equal", [ + DeclareOpInterfaceMethods, ResultsBroadcastableShape, + Commutative, NoSideEffect]> { let summary = "Returns the truth value of (x == y) element-wise."; let description = [{ @@ -1052,8 +1095,9 @@ //===----------------------------------------------------------------------===// // Operator: greater //===----------------------------------------------------------------------===// -def Tosa_GreaterOp : Tosa_Op<"greater", [ResultsBroadcastableShape, - NoSideEffect]> { +def Tosa_GreaterOp : Tosa_Op<"greater", [ + DeclareOpInterfaceMethods, ResultsBroadcastableShape, + NoSideEffect]> { let summary = "Returns the truth value of (x > y) element-wise."; let description = [{ @@ -1073,8 +1117,9 @@ //===----------------------------------------------------------------------===// // Operator: greater_equal //===----------------------------------------------------------------------===// -def Tosa_GreaterEqualOp : Tosa_Op<"greater_equal", [ResultsBroadcastableShape, - NoSideEffect]> { +def Tosa_GreaterEqualOp : Tosa_Op<"greater_equal", [ + DeclareOpInterfaceMethods, ResultsBroadcastableShape, + NoSideEffect]> { let summary = "Returns the truth value of (x >= y) element-wise."; let 
description = [{ @@ -1269,7 +1314,7 @@ // Operator: reshape //===----------------------------------------------------------------------===// def Tosa_ReshapeOp: Tosa_Op<"reshape", [ - NoSideEffect]> { + DeclareOpInterfaceMethods, NoSideEffect]> { let summary = "Reshape operator"; let description = [{ @@ -1291,7 +1336,8 @@ //===----------------------------------------------------------------------===// // Operator: reverse //===----------------------------------------------------------------------===// -def Tosa_ReverseOp: Tosa_Op<"reverse", [NoSideEffect]> { +def Tosa_ReverseOp: Tosa_Op<"reverse", [ + DeclareOpInterfaceMethods, NoSideEffect]> { let summary = "Reverse operator"; let description = [{ diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -288,6 +288,159 @@ result.types.push_back(outputType); } +//===----------------------------------------------------------------------===// +// TOSA Operator Return Type Inference. 
+//===----------------------------------------------------------------------===//
+
+// NOTE(review): in this copy of the patch the template arguments of
+// SmallVector, SmallVectorImpl, Optional, cast() and isa() appear to be
+// missing. They are left exactly as found -- restore them from the original
+// revision before applying.
+
+// Appends `getValue().getSExtValue()` of every element of `arrayAttr` to
+// `values`, in iteration order.
+static void getI64Values(ArrayAttr arrayAttr, SmallVector &values) {
+  for (auto it : arrayAttr) {
+    values.push_back(it.cast().getValue().getSExtValue());
+  }
+}
+
+// Infers the result type of tosa.reshape: a ranked tensor whose shape is read
+// from the "new_shape" attribute and whose element type is that of the first
+// operand. Always reports success.
+// NOTE(review): the result of attributes.get("new_shape") is cast without a
+// visible null check -- confirm the attribute is guaranteed to be present.
+LogicalResult tosa::ReshapeOp::inferReturnTypes(
+    MLIRContext *context, ::llvm::Optional location,
+    ValueRange operands, DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl &inferredReturnTypes) {
+  auto newShape = attributes.get("new_shape").cast();
+  auto eTy = operands.front().getType().cast().getElementType();
+  llvm::SmallVector newShapeValue;
+  getI64Values(newShape, newShapeValue);
+
+  inferredReturnTypes.push_back(RankedTensorType::get(newShapeValue, eTy));
+  return success();
+}
+
+// Broadcasts the shapes of all `operands` against each other and stores the
+// result in `outShape`. The output rank is the largest operand rank and shapes
+// are aligned at their trailing dimensions. An extent of 1 yields to the other
+// extent; two extents that differ and are both not 1 make the function return
+// failure.
+// NOTE(review): no special handling of dynamic extents or unranked operands is
+// visible here -- confirm callers only pass ranked, statically shaped tensors.
+static LogicalResult ResolveBroadcastShape(ValueRange operands,
+                                           SmallVector &outShape) {
+  int64_t outRank = 0;
+  for (auto operand : operands) {
+    auto type = operand.getType().cast();
+    outRank = std::max(outRank, type.getRank());
+  }
+
+  // Start from all-ones so that any operand extent overrides the initial value.
+  outShape.resize(outRank, 1);
+
+  for (auto operand : operands) {
+    auto type = operand.getType().cast();
+    auto shape = type.getShape();
+    // Offset that right-aligns this operand's dimensions in the output shape.
+    auto rankDiff = outShape.size() - shape.size();
+
+    for (size_t i = 0; i < shape.size(); i++) {
+      auto dim1 = outShape[i + rankDiff];
+      auto dim2 = shape[i];
+      auto resolvedDim = dim1;
+
+      if (dim1 == 1) {
+        resolvedDim = dim2;
+      } else if (dim2 == 1) {
+        resolvedDim = dim1;
+      } else if (dim1 != dim2) {
+        return failure();
+      }
+      outShape[i + rankDiff] = resolvedDim;
+    }
+  }
+
+  return success();
+}
+
+// Pushes a single ranked tensor type onto `inferredReturnTypes`: its shape is
+// the broadcast of all operand shapes and its element type is that of the
+// first operand. Returns failure when the shapes cannot be broadcast.
+static LogicalResult
+NAryInferReturnTypes(ValueRange operands,
+                     SmallVectorImpl &inferredReturnTypes) {
+  llvm::SmallVector outShape;
+  if (ResolveBroadcastShape(operands, outShape).failed())
+    return failure();
+  inferredReturnTypes.push_back(RankedTensorType::get(
+      outShape,
+      operands.front().getType().cast().getElementType()));
+  return success();
+}
+
+// Defines OP::inferReturnTypes as a plain forward to NAryInferReturnTypes.
+#define NARY_SHAPE_INFER(OP)                                                   \
+  LogicalResult OP::inferReturnTypes(                                          \
+      MLIRContext *context, ::llvm::Optional location,                         \
+      ValueRange operands, DictionaryAttr attributes, RegionRange regions,     \
+      SmallVectorImpl &inferredReturnTypes) {                                  \
+    return NAryInferReturnTypes(operands, inferredReturnTypes);                \
+  }
+
+NARY_SHAPE_INFER(tosa::AbsOp)
+NARY_SHAPE_INFER(tosa::AddOp)
+NARY_SHAPE_INFER(tosa::ArithmeticRightShiftOp)
+NARY_SHAPE_INFER(tosa::BitwiseAndOp)
+NARY_SHAPE_INFER(tosa::BitwiseOrOp)
+NARY_SHAPE_INFER(tosa::BitwiseXorOp)
+NARY_SHAPE_INFER(tosa::BitwiseNotOp)
+NARY_SHAPE_INFER(tosa::CeilOp)
+NARY_SHAPE_INFER(tosa::ClampOp)
+NARY_SHAPE_INFER(tosa::ClzOp)
+NARY_SHAPE_INFER(tosa::DivOp)
+NARY_SHAPE_INFER(tosa::ExpOp)
+NARY_SHAPE_INFER(tosa::FloorOp)
+NARY_SHAPE_INFER(tosa::LogOp)
+NARY_SHAPE_INFER(tosa::LogicalAndOp)
+NARY_SHAPE_INFER(tosa::LogicalLeftShiftOp)
+NARY_SHAPE_INFER(tosa::LogicalNotOp)
+NARY_SHAPE_INFER(tosa::LogicalOrOp)
+NARY_SHAPE_INFER(tosa::LogicalRightShiftOp)
+NARY_SHAPE_INFER(tosa::LogicalXorOp)
+NARY_SHAPE_INFER(tosa::MaximumOp)
+NARY_SHAPE_INFER(tosa::MinimumOp)
+NARY_SHAPE_INFER(tosa::NegateOp)
+NARY_SHAPE_INFER(tosa::PowOp)
+NARY_SHAPE_INFER(tosa::ReciprocalOp)
+NARY_SHAPE_INFER(tosa::ReluNOp)
+NARY_SHAPE_INFER(tosa::ReverseOp)
+NARY_SHAPE_INFER(tosa::RsqrtOp)
+NARY_SHAPE_INFER(tosa::SubOp)
+NARY_SHAPE_INFER(tosa::TanhOp)
+NARY_SHAPE_INFER(tosa::SigmoidOp)
+#undef NARY_SHAPE_INFER
+
+// Defines OP::inferReturnTypes for ops whose result has the broadcast shape of
+// the operands and a 1-bit integer element type.
+#define PRED_SHAPE_INFER(OP)                                                   \
+  LogicalResult OP::inferReturnTypes(                                          \
+      MLIRContext *context, ::llvm::Optional location,                         \
+      ValueRange operands, DictionaryAttr attributes, RegionRange regions,     \
+      SmallVectorImpl &inferredReturnTypes) {                                  \
+    llvm::SmallVector outShape;                                                \
+    if (ResolveBroadcastShape(operands, outShape).failed())                    \
+      return failure();                                                        \
+    inferredReturnTypes.push_back(                                             \
+        RankedTensorType::get(outShape, IntegerType::get(context, 1)));        \
+    return success();                                                          \
+  }
+PRED_SHAPE_INFER(tosa::EqualOp)
+PRED_SHAPE_INFER(tosa::GreaterOp)
+PRED_SHAPE_INFER(tosa::GreaterEqualOp)
+#undef PRED_SHAPE_INFER
+
+// Infers the result type of tosa.mul: the broadcast shape of all operands with
+// the element type of operands[1], replaced by a 32-bit integer type when the
+// isa() check below holds.
+// NOTE(review): the type tested by isa() is not visible in this copy of the
+// patch -- confirm against the original revision.
+LogicalResult tosa::MulOp::inferReturnTypes(
+    MLIRContext *context, ::llvm::Optional location,
+    ValueRange operands, DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl &inferredReturnTypes) {
+  llvm::SmallVector outShape;
+  if (ResolveBroadcastShape(operands, outShape).failed())
+    return failure();
+  auto eTy = operands[1].getType().cast().getElementType();
+  if (eTy.isa())
+    eTy = IntegerType::get(context, 32);
+  inferredReturnTypes.push_back(RankedTensorType::get(outShape, eTy));
+  return success();
+}
+
+// Infers the result type of tosa.select: the broadcast shape of all operands
+// with the element type of operands[1]. operands[0] contributes only its
+// shape.
+LogicalResult tosa::SelectOp::inferReturnTypes(
+    MLIRContext *context, ::llvm::Optional location,
+    ValueRange operands, DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl &inferredReturnTypes) {
+  llvm::SmallVector outShape;
+  if (ResolveBroadcastShape(operands, outShape).failed())
+    return failure();
+  inferredReturnTypes.push_back(RankedTensorType::get(
+      outShape, operands[1].getType().cast().getElementType()));
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // TOSA Operator Definitions.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -500,7 +500,7 @@
 // CHECK-LABEL: @test_reshape_downrank_6D
 func @test_reshape_downrank_6D(%arg0: tensor<1x2x3x5x7x11xf32>) -> tensor<6x5x77xf32> {
   // CHECK: linalg.tensor_collapse_shape %arg0 {{\[}}[0, 1, 2], [3], [4, 5]]
-  %0 = "tosa.reshape"(%arg0) {new_shape = [2, 3]} : (tensor<1x2x3x5x7x11xf32>) -> tensor<6x5x77xf32>
+  %0 = "tosa.reshape"(%arg0) {new_shape = [6, 5, 77]} : (tensor<1x2x3x5x7x11xf32>) -> tensor<6x5x77xf32>
   return %0 : tensor<6x5x77xf32>
 }