diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h @@ -15,6 +15,7 @@ #include "mlir/Dialect/Quant/QuantOps.h" #include "mlir/Dialect/Traits.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/LoopLikeInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td @@ -17,6 +17,7 @@ include "mlir/IR/OpBase.td" include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/LoopLikeInterface.td" include "mlir/Dialect/Tosa/IR/TosaInterfaces.td" @@ -284,7 +285,10 @@ //===----------------------------------------------------------------------===// // Operator: clamp //===----------------------------------------------------------------------===// -def Tosa_ClampOp : Tosa_Op<"clamp", [NoSideEffect, SameOperandsAndResultType]> { +def Tosa_ClampOp : Tosa_Op<"clamp", [ + DeclareOpInterfaceMethods, + NoSideEffect, SameOperandsAndResultType]> { let summary = "Computes clamp(features, min, max)."; let description = [{ @@ -309,7 +313,10 @@ //===----------------------------------------------------------------------===// // Operator: reluN //===----------------------------------------------------------------------===// -def Tosa_ReluNOp : Tosa_Op<"reluN", [NoSideEffect, SameOperandsAndResultType]> { +def Tosa_ReluNOp : Tosa_Op<"reluN", [ + DeclareOpInterfaceMethods, + NoSideEffect, SameOperandsAndResultType]> { let summary = "Computes rectified linear: `max(features, N)`."; let description = [{ @@ -330,8 +337,10 @@ //===----------------------------------------------------------------------===// // Operator: sigmoid 
//===----------------------------------------------------------------------===// -def Tosa_SigmoidOp : Tosa_Op<"sigmoid", [NoSideEffect, - SameOperandsAndResultType]> { +def Tosa_SigmoidOp : Tosa_Op<"sigmoid", [ + DeclareOpInterfaceMethods, + NoSideEffect, SameOperandsAndResultType]> { let summary = "Computes elementwise sigmoid of input."; let description = [{ @@ -354,7 +363,10 @@ //===----------------------------------------------------------------------===// // Operator: tanh //===----------------------------------------------------------------------===// -def Tosa_TanhOp : Tosa_Op<"tanh", [NoSideEffect, SameOperandsAndResultType]> { +def Tosa_TanhOp : Tosa_Op<"tanh", [ + DeclareOpInterfaceMethods, + NoSideEffect, SameOperandsAndResultType]> { let summary = "Computes elementwise hyperbolic tangent of input"; let description = [{ @@ -382,8 +394,10 @@ //===----------------------------------------------------------------------===// // Operator: add //===----------------------------------------------------------------------===// -def Tosa_AddOp : Tosa_Op<"add", [ResultsBroadcastableShape, NoSideEffect, - Commutative]> { +def Tosa_AddOp : Tosa_Op<"add", [ + DeclareOpInterfaceMethods, + ResultsBroadcastableShape, NoSideEffect, Commutative]> { let summary = "Elementwise addition operator"; let description = [{ @@ -404,9 +418,10 @@ //===----------------------------------------------------------------------===// // Operator: arithmetic_right_shift //===----------------------------------------------------------------------===// -def Tosa_ArithmeticRightShiftOp : Tosa_Op<"arithmetic_right_shift", - [ResultsBroadcastableShape, - NoSideEffect]> { +def Tosa_ArithmeticRightShiftOp : Tosa_Op<"arithmetic_right_shift", [ + DeclareOpInterfaceMethods, + ResultsBroadcastableShape, NoSideEffect]> { let summary = "Elementwise Arithmetic Right Shift"; let description = [{ @@ -429,8 +444,10 @@ //===----------------------------------------------------------------------===// // Operator: 
bitwise_and //===----------------------------------------------------------------------===// -def Tosa_BitwiseAndOp : Tosa_Op<"bitwise_and", [ResultsBroadcastableShape, - NoSideEffect, Commutative]> { +def Tosa_BitwiseAndOp : Tosa_Op<"bitwise_and", [ + DeclareOpInterfaceMethods, + ResultsBroadcastableShape, NoSideEffect, Commutative]> { let summary = "Bitwise AND operator"; let description = [{ @@ -451,8 +468,10 @@ //===----------------------------------------------------------------------===// // Operator: bitwise_or //===----------------------------------------------------------------------===// -def Tosa_BitwiseOrOp : Tosa_Op<"bitwise_or", [ResultsBroadcastableShape, - NoSideEffect, Commutative]> { +def Tosa_BitwiseOrOp : Tosa_Op<"bitwise_or", [ + DeclareOpInterfaceMethods, + ResultsBroadcastableShape, NoSideEffect, Commutative]> { let summary = "Bitwise OR operator"; let description = [{ @@ -473,8 +492,10 @@ //===----------------------------------------------------------------------===// // Operator: bitwise_xor //===----------------------------------------------------------------------===// -def Tosa_BitwiseXorOp : Tosa_Op<"bitwise_xor", [ResultsBroadcastableShape, - NoSideEffect, Commutative]> { +def Tosa_BitwiseXorOp : Tosa_Op<"bitwise_xor", [ + DeclareOpInterfaceMethods, + ResultsBroadcastableShape, NoSideEffect, Commutative]> { let summary = "Bitwise XOR operator"; let description = [{ @@ -495,8 +516,10 @@ //===----------------------------------------------------------------------===// // Operator: div //===----------------------------------------------------------------------===// -def Tosa_DivOp : Tosa_Op<"div", [ResultsBroadcastableShape, - NoSideEffect]> { +def Tosa_DivOp : Tosa_Op<"div", [ + DeclareOpInterfaceMethods, + ResultsBroadcastableShape, NoSideEffect]> { let summary = "Integer divide operator"; let description = [{ @@ -517,8 +540,10 @@ //===----------------------------------------------------------------------===// // Operator: logical_and 
//===----------------------------------------------------------------------===// -def Tosa_LogicalAndOp : Tosa_Op<"logical_and", [ResultsBroadcastableShape, - Commutative, NoSideEffect]> { +def Tosa_LogicalAndOp : Tosa_Op<"logical_and", [ + DeclareOpInterfaceMethods, + ResultsBroadcastableShape, Commutative, NoSideEffect]> { let summary = "Returns the truth value of x AND y element-wise."; let description = [{ @@ -539,9 +564,10 @@ //===----------------------------------------------------------------------===// // Operator: logical_left_shift //===----------------------------------------------------------------------===// -def Tosa_LogicalLeftShiftOp : Tosa_Op<"logical_left_shift", - [ResultsBroadcastableShape, - NoSideEffect]> { +def Tosa_LogicalLeftShiftOp : Tosa_Op<"logical_left_shift", [ + DeclareOpInterfaceMethods, + ResultsBroadcastableShape, NoSideEffect]> { let summary = "Elementwise Logical Left Shift"; let description = [{ @@ -562,9 +588,9 @@ //===----------------------------------------------------------------------===// // Operator: logical_right_shift //===----------------------------------------------------------------------===// -def Tosa_LogicalRightShiftOp : Tosa_Op<"logical_right_shift", - [ResultsBroadcastableShape, - NoSideEffect]> { +def Tosa_LogicalRightShiftOp : Tosa_Op<"logical_right_shift", [ + DeclareOpInterfaceMethods, ResultsBroadcastableShape, + NoSideEffect]> { let summary = "Elementwise Logical Right Shift"; let description = [{ @@ -586,8 +612,10 @@ //===----------------------------------------------------------------------===// // Operator: logical_or //===----------------------------------------------------------------------===// -def Tosa_LogicalOrOp : Tosa_Op<"logical_or", [ResultsBroadcastableShape, - Commutative, NoSideEffect]> { +def Tosa_LogicalOrOp : Tosa_Op<"logical_or", [ + DeclareOpInterfaceMethods, + ResultsBroadcastableShape, Commutative, NoSideEffect]> { let summary = "Returns the truth value of x OR y element-wise."; 
let description = [{ @@ -608,8 +636,10 @@ //===----------------------------------------------------------------------===// // Operator: logical_xor //===----------------------------------------------------------------------===// -def Tosa_LogicalXorOp : Tosa_Op<"logical_xor", [ResultsBroadcastableShape, - Commutative, NoSideEffect]> { +def Tosa_LogicalXorOp : Tosa_Op<"logical_xor", [ + DeclareOpInterfaceMethods, + ResultsBroadcastableShape, Commutative, NoSideEffect]> { let summary = "Returns the truth value of x XOR y element-wise."; let description = [{ @@ -630,8 +660,10 @@ //===----------------------------------------------------------------------===// // Operator: maximum //===----------------------------------------------------------------------===// -def Tosa_MaximumOp : Tosa_Op<"maximum", [ResultsBroadcastableShape, - NoSideEffect, Commutative]> { +def Tosa_MaximumOp : Tosa_Op<"maximum", [ + DeclareOpInterfaceMethods, + ResultsBroadcastableShape, NoSideEffect, Commutative]> { let summary = "Elementwise Maximum"; let description = [{ @@ -652,8 +684,10 @@ //===----------------------------------------------------------------------===// // Operator: minimum //===----------------------------------------------------------------------===// -def Tosa_MinimumOp : Tosa_Op<"minimum", [ResultsBroadcastableShape, - NoSideEffect, Commutative]> { +def Tosa_MinimumOp : Tosa_Op<"minimum", [ + DeclareOpInterfaceMethods, + ResultsBroadcastableShape, NoSideEffect, Commutative]> { let summary = "Elementwise Minimum"; let description = [{ @@ -674,8 +708,10 @@ //===----------------------------------------------------------------------===// // Operator: mul //===----------------------------------------------------------------------===// -def Tosa_MulOp : Tosa_Op<"mul", [ResultsBroadcastableShape, NoSideEffect, - Commutative]> { +def Tosa_MulOp : Tosa_Op<"mul", [ + DeclareOpInterfaceMethods, + ResultsBroadcastableShape, NoSideEffect, Commutative]> { let summary = "Multiplication 
operator"; let description = [{ @@ -698,7 +734,10 @@ //===----------------------------------------------------------------------===// // Operator: pow //===----------------------------------------------------------------------===// -def Tosa_PowOp : Tosa_Op<"pow", [ResultsBroadcastableShape, NoSideEffect]> { +def Tosa_PowOp : Tosa_Op<"pow", [ + DeclareOpInterfaceMethods, + ResultsBroadcastableShape, NoSideEffect]> { let summary = "Computes the power of one value to another."; let description = [{ @@ -720,7 +759,10 @@ //===----------------------------------------------------------------------===// // Operator: sub //===----------------------------------------------------------------------===// -def Tosa_SubOp : Tosa_Op<"sub", [ResultsBroadcastableShape, NoSideEffect]> { +def Tosa_SubOp : Tosa_Op<"sub", [ + DeclareOpInterfaceMethods, + ResultsBroadcastableShape, NoSideEffect]> { let summary = "Elementwise subtraction operator"; let description = [{ @@ -781,7 +823,10 @@ //===----------------------------------------------------------------------===// // Operator: abs //===----------------------------------------------------------------------===// -def Tosa_AbsOp : Tosa_Op<"abs", [NoSideEffect, SameOperandsAndResultType]> { +def Tosa_AbsOp : Tosa_Op<"abs", [ + DeclareOpInterfaceMethods, + NoSideEffect, SameOperandsAndResultType]> { let summary = "Elementwise abs op"; let description = [{ @@ -800,8 +845,10 @@ //===----------------------------------------------------------------------===// // Operator: bitwise_not //===----------------------------------------------------------------------===// -def Tosa_BitwiseNotOp : Tosa_Op<"bitwise_not", [ResultsBroadcastableShape, - NoSideEffect]> { +def Tosa_BitwiseNotOp : Tosa_Op<"bitwise_not", [ + DeclareOpInterfaceMethods, + ResultsBroadcastableShape, NoSideEffect]> { let summary = "Bitwise NOT operator"; let description = [{ @@ -820,7 +867,10 @@ //===----------------------------------------------------------------------===// // 
Operator: ceil //===----------------------------------------------------------------------===// -def Tosa_CeilOp : Tosa_Op<"ceil", [NoSideEffect, SameOperandsAndResultType]> { +def Tosa_CeilOp : Tosa_Op<"ceil", [ + DeclareOpInterfaceMethods, + NoSideEffect, SameOperandsAndResultType]> { let summary = "Elementwise ceil op"; let description = [{ @@ -839,7 +889,10 @@ //===----------------------------------------------------------------------===// // Operator: clz //===----------------------------------------------------------------------===// -def Tosa_ClzOp : Tosa_Op<"clz", [NoSideEffect, SameOperandsAndResultType]> { +def Tosa_ClzOp : Tosa_Op<"clz", [ + DeclareOpInterfaceMethods, + NoSideEffect, SameOperandsAndResultType]> { let summary = "Elementwise count leading zero op"; let description = [{ @@ -858,7 +911,10 @@ //===----------------------------------------------------------------------===// // Operator: exp //===----------------------------------------------------------------------===// -def Tosa_ExpOp : Tosa_Op<"exp", [NoSideEffect, SameOperandsAndResultType]> { +def Tosa_ExpOp : Tosa_Op<"exp", [ + DeclareOpInterfaceMethods, + NoSideEffect, SameOperandsAndResultType]> { let summary = "Elementwise exp op"; let description = [{ @@ -877,7 +933,10 @@ //===----------------------------------------------------------------------===// // Operator: floor //===----------------------------------------------------------------------===// -def Tosa_FloorOp : Tosa_Op<"floor", [NoSideEffect, SameOperandsAndResultType]> { +def Tosa_FloorOp : Tosa_Op<"floor", [ + DeclareOpInterfaceMethods, + NoSideEffect, SameOperandsAndResultType]> { let summary = "Elementwise floor op"; let description = [{ @@ -896,7 +955,10 @@ //===----------------------------------------------------------------------===// // Operator: log //===----------------------------------------------------------------------===// -def Tosa_LogOp : Tosa_Op<"log", [NoSideEffect, SameOperandsAndResultType]> { +def 
Tosa_LogOp : Tosa_Op<"log", [ + DeclareOpInterfaceMethods, + NoSideEffect, SameOperandsAndResultType]> { let summary = "Elementwise log op"; let description = [{ @@ -915,8 +977,10 @@ //===----------------------------------------------------------------------===// // Operator: logical_not //===----------------------------------------------------------------------===// -def Tosa_LogicalNotOp : Tosa_Op<"logical_not", [NoSideEffect, - SameOperandsAndResultType]> { +def Tosa_LogicalNotOp : Tosa_Op<"logical_not", [ + DeclareOpInterfaceMethods, + NoSideEffect, SameOperandsAndResultType]> { let summary = "Returns the truth value of NOT x element-wise."; let description = [{ @@ -935,8 +999,10 @@ //===----------------------------------------------------------------------===// // Operator: negate //===----------------------------------------------------------------------===// -def Tosa_NegateOp : Tosa_Op<"negate", [NoSideEffect, - SameOperandsAndResultType]> { +def Tosa_NegateOp : Tosa_Op<"negate", [ + DeclareOpInterfaceMethods, + NoSideEffect, SameOperandsAndResultType]> { let summary = "Elementwise negate op"; let description = [{ @@ -958,8 +1024,10 @@ //===----------------------------------------------------------------------===// // Operator: reciprocal //===----------------------------------------------------------------------===// -def Tosa_ReciprocalOp : Tosa_Op<"reciprocal", [NoSideEffect, - SameOperandsAndResultType]> { +def Tosa_ReciprocalOp : Tosa_Op<"reciprocal", [ + DeclareOpInterfaceMethods, + NoSideEffect, SameOperandsAndResultType]> { let summary = "Elementwise reciprocal op"; let description = [{ @@ -979,7 +1047,10 @@ //===----------------------------------------------------------------------===// // Operator: rsqrt //===----------------------------------------------------------------------===// -def Tosa_RsqrtOp : Tosa_Op<"rsqrt", [NoSideEffect, SameOperandsAndResultType]> { +def Tosa_RsqrtOp : Tosa_Op<"rsqrt", [ + DeclareOpInterfaceMethods, + NoSideEffect, 
SameOperandsAndResultType]> { let summary = "Elementwise 1/sqrt op"; let description = [{ @@ -1005,7 +1076,9 @@ //===----------------------------------------------------------------------===// // Operator: select //===----------------------------------------------------------------------===// -def Tosa_SelectOp : Tosa_Op<"select", [NoSideEffect]> { +def Tosa_SelectOp : Tosa_Op<"select", [ + DeclareOpInterfaceMethods, NoSideEffect]> { let summary = "Elementwise select operator"; let description = [{ @@ -1031,8 +1104,10 @@ //===----------------------------------------------------------------------===// // Operator: equal //===----------------------------------------------------------------------===// -def Tosa_EqualOp : Tosa_Op<"equal", [ResultsBroadcastableShape, Commutative, - NoSideEffect]> { +def Tosa_EqualOp : Tosa_Op<"equal", [ + DeclareOpInterfaceMethods, + ResultsBroadcastableShape, Commutative, NoSideEffect]> { let summary = "Returns the truth value of (x == y) element-wise."; let description = [{ @@ -1052,8 +1127,10 @@ //===----------------------------------------------------------------------===// // Operator: greater //===----------------------------------------------------------------------===// -def Tosa_GreaterOp : Tosa_Op<"greater", [ResultsBroadcastableShape, - NoSideEffect]> { +def Tosa_GreaterOp : Tosa_Op<"greater", [ + DeclareOpInterfaceMethods, + ResultsBroadcastableShape, NoSideEffect]> { let summary = "Returns the truth value of (x > y) element-wise."; let description = [{ @@ -1073,8 +1150,10 @@ //===----------------------------------------------------------------------===// // Operator: greater_equal //===----------------------------------------------------------------------===// -def Tosa_GreaterEqualOp : Tosa_Op<"greater_equal", [ResultsBroadcastableShape, - NoSideEffect]> { +def Tosa_GreaterEqualOp : Tosa_Op<"greater_equal", [ + DeclareOpInterfaceMethods, + ResultsBroadcastableShape, NoSideEffect]> { let summary = "Returns the truth value 
of (x >= y) element-wise."; let description = [{ @@ -1269,7 +1348,9 @@ // Operator: reshape //===----------------------------------------------------------------------===// def Tosa_ReshapeOp: Tosa_Op<"reshape", [ - NoSideEffect]> { + DeclareOpInterfaceMethods, + NoSideEffect]> { let summary = "Reshape operator"; let description = [{ @@ -1291,7 +1372,9 @@ //===----------------------------------------------------------------------===// // Operator: reverse //===----------------------------------------------------------------------===// -def Tosa_ReverseOp: Tosa_Op<"reverse", [NoSideEffect]> { +def Tosa_ReverseOp: Tosa_Op<"reverse", [ + DeclareOpInterfaceMethods, NoSideEffect]> { let summary = "Reverse operator"; let description = [{ diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h --- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h @@ -13,11 +13,13 @@ #ifndef MLIR_DIALECT_TOSA_TRANSFORMS_PASSES_H #define MLIR_DIALECT_TOSA_TRANSFORMS_PASSES_H +#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Pass/Pass.h" namespace mlir { namespace tosa { +std::unique_ptr createTosaInferShapesPass(); std::unique_ptr createTosaMakeBroadcastablePass(); std::unique_ptr createTosaTestQuantUtilAPIPass(); diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td --- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td @@ -15,6 +15,21 @@ include "mlir/Pass/PassBase.td" +def TosaInferShapes : FunctionPass<"tosa-infer-shapes"> { + let summary = "Propagate shapes across TOSA operations"; + let description = [{ + Pass that uses operand types and propagates shapes to TOSA operations. + This includes legalizing rankless and dynamic shapes towards static. 
+ }]; + + let constructor = "createTosaInferShapesPass()"; + let dependentDialects = [ + "StandardOpsDialect", + "tensor::TensorDialect", + "tosa::TosaDialect", + ]; +} + def TosaMakeBroadcastable : FunctionPass<"tosa-make-broadcastable"> { let summary = "TOSA rank Reshape to enable Broadcasting"; let description = [{ diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -291,6 +291,148 @@ result.types.push_back(outputType); } +//===----------------------------------------------------------------------===// +// TOSA Operator Return Type Inference. +//===----------------------------------------------------------------------===// + +static void getI64Values(ArrayAttr arrayAttr, SmallVector &values) { + for (auto it : arrayAttr) { + values.push_back(it.cast().getValue().getSExtValue()); + } +} + +LogicalResult tosa::ReshapeOp::inferReturnTypeComponents( + MLIRContext *context, ::llvm::Optional location, + ValueRange operands, DictionaryAttr attributes, RegionRange regions, + SmallVectorImpl &inferredReturnShapes) { + ShapedType type = operands.front().getType().cast(); + + auto newShape = attributes.get("new_shape").cast(); + llvm::SmallVector newShapeValue; + getI64Values(newShape, newShapeValue); + + // We cannot infer from the total number of elements so we must take the + // shape attribute as exact. + if (!type.hasRank() || !type.hasStaticShape()) { + inferredReturnShapes.push_back(ShapedTypeComponents(newShapeValue)); + return success(); + } + + // Determine the number of elements covered by the slice of all static + // dimensions. This allows us to infer the length of the remaining dynamic + // dimension. + int64_t numElements = type.getNumElements(); + int64_t staticMul = 1; + for (auto val : newShapeValue) { + if (val != -1) { + staticMul *= val; + } + } + + // Determine the length of the dynamic dimension. 
+ for (auto &val : newShapeValue) { + if (val == -1) + val = numElements / staticMul; + } + + inferredReturnShapes.push_back(ShapedTypeComponents(newShapeValue)); + return success(); +} + +static LogicalResult resolveBroadcastShape(ValueRange operands, + SmallVector &outShape) { + int64_t outRank = 0; + for (auto operand : operands) { + auto type = operand.getType().cast(); + if (!type.hasRank()) + return failure(); + outRank = std::max(outRank, type.getRank()); + } + + outShape.resize(outRank, 1); + + for (auto operand : operands) { + auto type = operand.getType().cast(); + auto shape = type.getShape(); + auto rankDiff = outShape.size() - shape.size(); + + for (size_t i = 0; i < shape.size(); i++) { + auto dim1 = outShape[i + rankDiff]; + auto dim2 = shape[i]; + auto resolvedDim = dim1; + + if (dim1 == 1) { + resolvedDim = dim2; + } else if (dim2 == 1) { + resolvedDim = dim1; + } else if (dim1 != dim2) { + return failure(); + } + outShape[i + rankDiff] = resolvedDim; + } + } + + return success(); +} + +static LogicalResult NAryInferReturnTypes( + ValueRange operands, + SmallVectorImpl &inferredReturnShapes) { + llvm::SmallVector outShape; + if (resolveBroadcastShape(operands, outShape).failed()) { + inferredReturnShapes.push_back(ShapedTypeComponents()); + } else { + inferredReturnShapes.push_back(ShapedTypeComponents(outShape)); + } + return success(); +} + +#define NARY_SHAPE_INFER(OP) \ + LogicalResult OP::inferReturnTypeComponents( \ + MLIRContext *context, ::llvm::Optional location, \ + ValueRange operands, DictionaryAttr attributes, RegionRange regions, \ + SmallVectorImpl &inferredReturnShapes) { \ + return NAryInferReturnTypes(operands, inferredReturnShapes); \ + } + +NARY_SHAPE_INFER(tosa::AbsOp) +NARY_SHAPE_INFER(tosa::AddOp) +NARY_SHAPE_INFER(tosa::ArithmeticRightShiftOp) +NARY_SHAPE_INFER(tosa::BitwiseAndOp) +NARY_SHAPE_INFER(tosa::BitwiseOrOp) +NARY_SHAPE_INFER(tosa::BitwiseXorOp) +NARY_SHAPE_INFER(tosa::BitwiseNotOp) +NARY_SHAPE_INFER(tosa::CeilOp) 
+NARY_SHAPE_INFER(tosa::ClampOp) +NARY_SHAPE_INFER(tosa::ClzOp) +NARY_SHAPE_INFER(tosa::DivOp) +NARY_SHAPE_INFER(tosa::EqualOp) +NARY_SHAPE_INFER(tosa::ExpOp) +NARY_SHAPE_INFER(tosa::FloorOp) +NARY_SHAPE_INFER(tosa::GreaterEqualOp) +NARY_SHAPE_INFER(tosa::GreaterOp) +NARY_SHAPE_INFER(tosa::LogOp) +NARY_SHAPE_INFER(tosa::LogicalAndOp) +NARY_SHAPE_INFER(tosa::LogicalLeftShiftOp) +NARY_SHAPE_INFER(tosa::LogicalNotOp) +NARY_SHAPE_INFER(tosa::LogicalOrOp) +NARY_SHAPE_INFER(tosa::LogicalRightShiftOp) +NARY_SHAPE_INFER(tosa::LogicalXorOp) +NARY_SHAPE_INFER(tosa::MaximumOp) +NARY_SHAPE_INFER(tosa::MinimumOp) +NARY_SHAPE_INFER(tosa::MulOp) +NARY_SHAPE_INFER(tosa::NegateOp) +NARY_SHAPE_INFER(tosa::PowOp) +NARY_SHAPE_INFER(tosa::ReciprocalOp) +NARY_SHAPE_INFER(tosa::ReluNOp) +NARY_SHAPE_INFER(tosa::ReverseOp) +NARY_SHAPE_INFER(tosa::RsqrtOp) +NARY_SHAPE_INFER(tosa::SelectOp) +NARY_SHAPE_INFER(tosa::SubOp) +NARY_SHAPE_INFER(tosa::TanhOp) +NARY_SHAPE_INFER(tosa::SigmoidOp) +#undef NARY_SHAPE_INFER + //===----------------------------------------------------------------------===// // TOSA Operator Definitions. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt @@ -1,4 +1,5 @@ add_mlir_dialect_library(MLIRTosaTransforms + TosaInferShapes.cpp TosaMakeBroadcastable.cpp ADDITIONAL_HEADER_DIRS diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp @@ -0,0 +1,247 @@ +//===- TosaInferShapes.cpp ------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Propagate shapes forward along TOSA operations to resolve dynamic shape +// operations. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Analysis/DataFlowAnalysis.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Tosa/IR/TosaOps.h" +#include "mlir/Dialect/Tosa/Transforms/PassDetail.h" +#include "mlir/Dialect/Tosa/Transforms/Passes.h" +#include "mlir/IR/BlockAndValueMapping.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Matchers.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +using namespace mlir; +using namespace mlir::tosa; + +namespace { + +// ----------------------------------------------------------------------------- +// Analysis. +// ----------------------------------------------------------------------------- + +static Type joinElementTypes(Type lhs, Type rhs) { + return lhs == rhs ? lhs : Type(); +} + +namespace { +// Statically known information for a particular Value. +// +// This struct currently tracks only information relevant for tensor/array-like +// shaped types. It is fine to associate a `ValueKnowledge` with a non-shaped +// type as long as it is in the default "no knowledge" state returned by +// `getPessimisticValueState`. The important invariant is that we cannot +// claim to know something about a value which is false. +// +// This class could also be called "dataflow facts", "lattice value", etc. 
+struct ValueKnowledge { + ValueKnowledge() = delete; + ValueKnowledge(bool hasSizes, std::vector sizes, Type dtype) + : hasSizes(hasSizes), sizes(sizes), dtype(dtype) { + assert(sizes.size() == 0 || hasSizes); + } + + // Get the static knowledge intrinsic to `type`. + static ValueKnowledge getKnowledgeFromType(Type type) { + ValueKnowledge result = getPessimisticValueState(type.getContext()); + if (auto shapedType = type.dyn_cast()) { + if (shapedType.hasRank()) { + result.hasSizes = true; + result.sizes = shapedType.getShape(); + } + result.dtype = shapedType.getElementType(); + } + return result; + } + + // Return a pessimistic/conservative value state without assuming any knowledge + // about the IR. + static ValueKnowledge getPessimisticValueState(MLIRContext *context) { + return ValueKnowledge(false, {}, Type()); + } + + Type getType() const { + if (hasSizes) { + return RankedTensorType::get(llvm::makeArrayRef(sizes), dtype); + } + return UnrankedTensorType::get(dtype); + } + + bool operator==(const ValueKnowledge &rhs) const { + return std::make_tuple(hasSizes, sizes, dtype) == + std::make_tuple(rhs.hasSizes, rhs.sizes, rhs.dtype); + } + + // Given two pieces of static knowledge, calculate conservatively the + // information we can be sure about. + static ValueKnowledge join(const ValueKnowledge &lhs, + const ValueKnowledge &rhs) { + // Mental model: All conditions are checking how to change from the safe "no + // knowledge" default-initialized state to a state with more knowledge + // consistent with lhs and rhs. 
+ ValueKnowledge result = getPessimisticValueState(nullptr); + + if (lhs.hasSizes && !rhs.hasSizes) { + result.hasSizes = true; + result.sizes = lhs.sizes; + } else if (!lhs.hasSizes && rhs.hasSizes) { + result.hasSizes = true; + result.sizes = rhs.sizes; + } else if (lhs.hasSizes && rhs.hasSizes && + lhs.sizes.size() == rhs.sizes.size()) { + result.hasSizes = true; + result.sizes.resize(lhs.sizes.size(), ShapedType::kDynamicSize); + for (int i = 0, e = result.sizes.size(); i != e; i++) { + int64_t lhsSize = lhs.sizes[i]; + int64_t rhsSize = rhs.sizes[i]; + int64_t &resultSize = result.sizes[i]; + if (lhsSize == ShapedType::kDynamicSize) { + resultSize = rhsSize; + } else if (rhsSize == ShapedType::kDynamicSize) { + resultSize = lhsSize; + } else if (lhsSize == rhsSize) { + resultSize = lhsSize; + } + } + } + + result.dtype = joinElementTypes(lhs.dtype, rhs.dtype); + return result; + } + + // Whether the Value is known to have a list of sizes. + bool hasSizes; + // If `hasSizes`, the sizes along each rank. Unknown sizes are represented as + // `ShapedType::kDynamicSize`. + std::vector sizes; + // The dtype of a tensor. + // This is equal to nullptr if we don't know that it is a specific concrete + // type. + Type dtype; +}; + +} // namespace + +/// Pass that propagates operand shapes forward across TOSA operations to
/// refine dynamic and unranked result types toward static shapes. +struct TosaInferShapes : public TosaInferShapesBase { +public: + void runOnFunction() override { + FuncOp func = getOperation(); + + IRRewriter rewriter(func.getContext()); + + func.walk([&](Operation *op) { + if (op->getDialect()->getNamespace() != + tosa::TosaDialect::getDialectNamespace()) + return; + InferShapedTypeOpInterface shapeInterface = + dyn_cast(op); + if (!shapeInterface) + return; + + SmallVector returnedShapes; + if (shapeInterface + .inferReturnTypeComponents( + op->getContext(), op->getLoc(), op->getOperands(), + op->getAttrDictionary(), op->getRegions(), returnedShapes) + .succeeded()) { + for (auto it : llvm::zip(op->getResults(), returnedShapes)) { + Value result = std::get<0>(it); + ShapedTypeComponents predictedShape = std::get<1>(it); + + // Check whether this use case is replaceable. We define an op as + // being replaceable if it is used by a ReturnOp or a TosaOp. + bool replaceable = true; + for (auto user : result.getUsers()) { + if (isa(user)) + continue; + if (user->getDialect()->getNamespace() == + tosa::TosaDialect::getDialectNamespace()) + continue; + + replaceable = false; + } + + // Determine the knowledge based on the output type. + Type resultTy = result.getType(); + auto currentKnowledge = + ValueKnowledge::getKnowledgeFromType(resultTy); + + // Compute the knowledge based on the inferred type. + auto inferredKnowledge = + ValueKnowledge::getPessimisticValueState(op->getContext()); + inferredKnowledge.dtype = + resultTy.cast().getElementType(); + inferredKnowledge.hasSizes = predictedShape.hasRank(); + if (predictedShape.hasRank()) { + for (auto dim : predictedShape.getDims()) { + inferredKnowledge.sizes.push_back(dim); + } + } + + if (!replaceable) + continue; + + // Compute the new type based on the joined version. 
+ auto newKnowledge = + ValueKnowledge::join(currentKnowledge, inferredKnowledge); + result.setType(newKnowledge.getType()); + } + } + }); + + // Insert UnrealizedConversionCasts to guarantee ReturnOp agrees with + // the FuncOp type. + func.walk([&](ReturnOp op) { + FuncOp parent = dyn_cast(op->getParentOp()); + if (!parent) + return; + + rewriter.setInsertionPoint(op); + FunctionType funcTy = func.getType(); + auto resultTys = funcTy.getResults(); + + bool castAdded = false; + SmallVector castedValues; + for (auto it : llvm::zip(op->getOperands(), resultTys)) { + auto operand = std::get<0>(it); + auto currentTy = operand.getType(); + auto castTy = std::get<1>(it); + if (currentTy == castTy) { + castedValues.push_back(operand); + continue; + } + + castedValues.push_back( + rewriter.create(op.getLoc(), castTy, operand) + .getResult()); + + castAdded = true; + } + + if (castAdded) { + rewriter.replaceOpWithNewOp(op, castedValues); + } + }); + } +}; +} // end anonymous namespace + +std::unique_ptr mlir::tosa::createTosaInferShapesPass() { + return std::make_unique(); +} diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp --- a/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp @@ -11,6 +11,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tosa/IR//TosaOps.h" #include "mlir/Dialect/Tosa/Transforms/PassDetail.h" #include "mlir/Dialect/Tosa/Transforms/Passes.h" diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir @@ -500,7 +500,7 @@ // CHECK-LABEL: @test_reshape_downrank_6D func 
@test_reshape_downrank_6D(%arg0: tensor<1x2x3x5x7x11xf32>) -> tensor<6x5x77xf32> { // CHECK: linalg.tensor_collapse_shape %arg0 {{\[}}[0, 1, 2], [3], [4, 5]] - %0 = "tosa.reshape"(%arg0) {new_shape = [2, 3]} : (tensor<1x2x3x5x7x11xf32>) -> tensor<6x5x77xf32> + %0 = "tosa.reshape"(%arg0) {new_shape = [6, 5, 77]} : (tensor<1x2x3x5x7x11xf32>) -> tensor<6x5x77xf32> return %0 : tensor<6x5x77xf32> } diff --git a/mlir/test/Dialect/Tosa/tosa_infer_shapes.mlir b/mlir/test/Dialect/Tosa/tosa_infer_shapes.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Tosa/tosa_infer_shapes.mlir @@ -0,0 +1,278 @@ +// RUN: mlir-opt --split-input-file --tosa-infer-shapes %s | FileCheck %s + +// CHECK-LABEL: @test_return +func @test_return(%arg0 : tensor<4xf32>) -> tensor<*xf32> { + // CHECK: [[LOG:%.+]] = "tosa.log"(%arg0) : (tensor<4xf32>) -> tensor<4xf32> + // CHECK: tensor.cast [[LOG]] : tensor<4xf32> to tensor<*xf32> + %0 = "tosa.log"(%arg0) : (tensor<4xf32>) -> tensor<*xf32> + return %0 : tensor<*xf32> +} + +// ----- + +// CHECK-LABEL: @test_multiple +func @test_multiple(%arg0 : tensor<4xf32>, %arg1 : tensor<1xf32>, %arg2 : tensor) -> tensor<*xf32> { + // CHECK: [[ADD:%.+]] = "tosa.add"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32> + %0 = "tosa.add"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32> + + // CHECK: [[LOG:%.+]] = "tosa.log"(%0) : (tensor<4xf32>) -> tensor<4xf32> + %1 = "tosa.log"(%0) : (tensor<*xf32>) -> tensor<*xf32> + + // CHECK: [[SUB:%.+]] = "tosa.sub"(%0, %arg2) : (tensor<4xf32>, tensor) -> tensor<4xf32> + %2 = "tosa.sub"(%0, %arg2) : (tensor<*xf32>, tensor) -> tensor<*xf32> + return %0 : tensor<*xf32> +} + +// ----- + +// CHECK-LABEL: @test_unary_f32 +func @test_unary_f32(%arg0 : tensor<4xf32>) -> () { + // CHECK: "tosa.abs"(%arg0) : (tensor<4xf32>) -> tensor<4xf32> + %0 = "tosa.abs"(%arg0) : (tensor<4xf32>) -> tensor<*xf32> + + // CHECK: "tosa.ceil"(%arg0) : (tensor<4xf32>) -> tensor<4xf32> + %1 = "tosa.ceil"(%arg0) : 
(tensor<4xf32>) -> tensor<*xf32> + + // CHECK: "tosa.clamp"(%arg0) {{.+}} : (tensor<4xf32>) -> tensor<4xf32> + %2 = "tosa.clamp"(%arg0) { max_int = 10 : i64, min_int = 0 : i64, min_fp = 0.0 : f32, max_fp = 10.0 : f32 } : (tensor<4xf32>) -> tensor<*xf32> + + // CHECK: "tosa.exp"(%arg0) : (tensor<4xf32>) -> tensor<4xf32> + %3 = "tosa.exp"(%arg0) : (tensor<4xf32>) -> tensor<*xf32> + + // CHECK: "tosa.floor"(%arg0) : (tensor<4xf32>) -> tensor<4xf32> + %4 = "tosa.floor"(%arg0) : (tensor<4xf32>) -> tensor<*xf32> + + // CHECK: "tosa.log"(%arg0) : (tensor<4xf32>) -> tensor<4xf32> + %5 = "tosa.log"(%arg0) : (tensor<4xf32>) -> tensor<*xf32> + + // CHECK: "tosa.negate"(%arg0) : (tensor<4xf32>) -> tensor<4xf32> + %6 = "tosa.negate"(%arg0) : (tensor<4xf32>) -> tensor<*xf32> + + // CHECK: "tosa.reciprocal"(%arg0) : (tensor<4xf32>) -> tensor<4xf32> + %7 = "tosa.reciprocal"(%arg0) : (tensor<4xf32>) -> tensor<*xf32> + + // CHECK: "tosa.reluN"(%arg0) {{.+}} : (tensor<4xf32>) -> tensor<4xf32> + %8 = "tosa.reluN"(%arg0) { max_int = 10 : i64, min_int = 0 : i64, min_fp = 0.0 : f32, max_fp = 10.0 : f32 } : (tensor<4xf32>) -> tensor<*xf32> + + // CHECK: "tosa.reverse"(%arg0) {axis = 0 : i64} : (tensor<4xf32>) -> tensor<4xf32> + %9 = "tosa.reverse"(%arg0) { axis = 0 : i64 } : (tensor<4xf32>) -> tensor + + // CHECK: "tosa.rsqrt"(%arg0) : (tensor<4xf32>) -> tensor<4xf32> + %10 = "tosa.rsqrt"(%arg0) : (tensor<4xf32>) -> tensor<*xf32> + + // CHECK: "tosa.tanh"(%arg0) : (tensor<4xf32>) -> tensor<4xf32> + %11 = "tosa.tanh"(%arg0) : (tensor<4xf32>) -> tensor<*xf32> + + // CHECK: "tosa.sigmoid"(%arg0) : (tensor<4xf32>) -> tensor<4xf32> + %12 = "tosa.sigmoid"(%arg0) : (tensor<4xf32>) -> tensor<*xf32> + return +} + +// ----- + +// CHECK-LABEL: @test_unary_i32 +func @test_unary_i32(%arg0 : tensor<4xi32>) -> () { + // CHECK: "tosa.abs"(%arg0) : (tensor<4xi32>) -> tensor<4xi32> + %0 = "tosa.abs"(%arg0) : (tensor<4xi32>) -> tensor<*xi32> + + // CHECK: "tosa.bitwise_not"(%arg0) : (tensor<4xi32>) -> 
tensor<4xi32> + %1 = "tosa.bitwise_not"(%arg0) : (tensor<4xi32>) -> tensor<*xi32> + + // CHECK: "tosa.clamp"(%arg0) {{.+}} : (tensor<4xi32>) -> tensor<4xi32> + %2 = "tosa.clamp"(%arg0) { max_int = 10 : i64, min_int = 0 : i64, min_fp = 0.0 : f32, max_fp = 10.0 : f32 } : (tensor<4xi32>) -> tensor<*xi32> + + // CHECK: "tosa.clz"(%arg0) : (tensor<4xi32>) -> tensor<4xi32> + %3 = "tosa.clz"(%arg0) : (tensor<4xi32>) -> tensor<*xi32> + + // CHECK: "tosa.negate"(%arg0) : (tensor<4xi32>) -> tensor<4xi32> + %4 = "tosa.negate"(%arg0) : (tensor<4xi32>) -> tensor<*xi32> + + // CHECK: "tosa.reluN"(%arg0) {{.+}} : (tensor<4xi32>) -> tensor<4xi32> + %5 = "tosa.reluN"(%arg0) { max_int = 10 : i64, min_int = 0 : i64, min_fp = 0.0 : f32, max_fp = 10.0 : f32 } : (tensor<4xi32>) -> tensor<*xi32> + + // CHECK: "tosa.reverse"(%arg0) {axis = 0 : i64} : (tensor<4xi32>) -> tensor<4xi32> + %6 = "tosa.reverse"(%arg0) { axis = 0 : i64 } : (tensor<4xi32>) -> tensor + return +} + +// ----- + +// CHECK-LABEL: @test_unary_i1 +func @test_unary_i1(%arg0 : tensor<4xi1>) -> () { + // CHECK: "tosa.logical_not"(%arg0) : (tensor<4xi1>) -> tensor<4xi1> + %0 = "tosa.logical_not"(%arg0) : (tensor<4xi1>) -> tensor<*xi1> + return +} + +// ----- + +// CHECK-LABEL: @test_binary_scalar_f32 +func @test_binary_scalar_f32(%arg0 : tensor<4xf32>, %arg1 : tensor) -> () { + // CHECK: "tosa.add"(%arg0, %arg1) : (tensor<4xf32>, tensor) -> tensor<4xf32> + %0 = "tosa.add"(%arg0, %arg1) : (tensor<4xf32>, tensor) -> tensor<*xf32> + + // CHECK: "tosa.maximum"(%arg0, %arg1) : (tensor<4xf32>, tensor) -> tensor<4xf32> + %1 = "tosa.maximum"(%arg0, %arg1) : (tensor<4xf32>, tensor) -> tensor<*xf32> + + // CHECK: "tosa.minimum"(%arg0, %arg1) : (tensor<4xf32>, tensor) -> tensor<4xf32> + %2 = "tosa.minimum"(%arg0, %arg1) : (tensor<4xf32>, tensor) -> tensor<*xf32> + + // CHECK: "tosa.mul"(%arg0, %arg1) {shift = 0 : i32} : (tensor<4xf32>, tensor) -> tensor<4xf32> + %3 = "tosa.mul"(%arg0, %arg1) { shift = 0 : i32 }: (tensor<4xf32>, tensor) 
-> tensor<*xf32> + + // CHECK: "tosa.pow"(%arg0, %arg1) : (tensor<4xf32>, tensor) -> tensor<4xf32> + %4 = "tosa.pow"(%arg0, %arg1) : (tensor<4xf32>, tensor) -> tensor<*xf32> + + // CHECK: "tosa.sub"(%arg0, %arg1) : (tensor<4xf32>, tensor) -> tensor<4xf32> + %5 = "tosa.sub"(%arg0, %arg1) : (tensor<4xf32>, tensor) -> tensor<*xf32> + + // CHECK: "tosa.equal"(%arg0, %arg1) : (tensor<4xf32>, tensor) -> tensor<4xi1> + %6 = "tosa.equal"(%arg0, %arg1) : (tensor<4xf32>, tensor) -> tensor<*xi1> + + // CHECK: "tosa.greater"(%arg0, %arg1) : (tensor<4xf32>, tensor) -> tensor<4xi1> + %7 = "tosa.greater"(%arg0, %arg1) : (tensor<4xf32>, tensor) -> tensor<*xi1> + + // CHECK: "tosa.greater_equal"(%arg0, %arg1) : (tensor<4xf32>, tensor) -> tensor<4xi1> + %8 = "tosa.greater_equal"(%arg0, %arg1) : (tensor<4xf32>, tensor) -> tensor<*xi1> + + return +} + +// ----- + +// CHECK-LABEL: @test_binary_broadcast_f32 +func @test_binary_broadcast_f32(%arg0 : tensor<4xf32>, %arg1 : tensor<1xf32>) -> () { + // CHECK: "tosa.add"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32> + %0 = "tosa.add"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32> + + // CHECK: "tosa.maximum"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32> + %1 = "tosa.maximum"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32> + + // CHECK: "tosa.minimum"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32> + %2 = "tosa.minimum"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32> + + // CHECK: "tosa.mul"(%arg0, %arg1) {shift = 0 : i32} : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32> + %3 = "tosa.mul"(%arg0, %arg1) { shift = 0 : i32 } : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32> + + // CHECK: "tosa.pow"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32> + %4 = "tosa.pow"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32> + + // CHECK: "tosa.sub"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32> + %5 = 
"tosa.sub"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32> + + // CHECK: "tosa.equal"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xi1> + %6 = "tosa.equal"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xi1> + + // CHECK: "tosa.greater"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xi1> + %7 = "tosa.greater"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xi1> + + // CHECK: "tosa.greater_equal"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xi1> + %8 = "tosa.greater_equal"(%arg0, %arg1) : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xi1> + + return +} + +// ----- + +// CHECK-LABEL: @test_binary_i32 +func @test_binary_i32(%arg0 : tensor<4xi32>, %arg1 : tensor) -> () { + // CHECK: "tosa.add"(%arg0, %arg1) : (tensor<4xi32>, tensor) -> tensor<4xi32> + %0 = "tosa.add"(%arg0, %arg1) : (tensor<4xi32>, tensor) -> tensor<*xi32> + + // CHECK: "tosa.bitwise_and"(%arg0, %arg1) {shift = 0 : i32} : (tensor<4xi32>, tensor) -> tensor<4xi32> + %1 = "tosa.bitwise_and"(%arg0, %arg1) { shift = 0 : i32 }: (tensor<4xi32>, tensor) -> tensor<*xi32> + + // CHECK: "tosa.bitwise_or"(%arg0, %arg1) {shift = 0 : i32} : (tensor<4xi32>, tensor) -> tensor<4xi32> + %2 = "tosa.bitwise_or"(%arg0, %arg1) { shift = 0 : i32 }: (tensor<4xi32>, tensor) -> tensor<*xi32> + + // CHECK: "tosa.bitwise_xor"(%arg0, %arg1) {shift = 0 : i32} : (tensor<4xi32>, tensor) -> tensor<4xi32> + %3 = "tosa.bitwise_xor"(%arg0, %arg1) { shift = 0 : i32 }: (tensor<4xi32>, tensor) -> tensor<*xi32> + + // CHECK: "tosa.equal"(%arg0, %arg1) : (tensor<4xi32>, tensor) -> tensor<4xi1> + %4 = "tosa.equal"(%arg0, %arg1) : (tensor<4xi32>, tensor) -> tensor<*xi1> + + // CHECK: "tosa.greater"(%arg0, %arg1) : (tensor<4xi32>, tensor) -> tensor<4xi1> + %5 = "tosa.greater"(%arg0, %arg1) : (tensor<4xi32>, tensor) -> tensor<*xi1> + + // CHECK: "tosa.greater_equal"(%arg0, %arg1) : (tensor<4xi32>, tensor) -> tensor<4xi1> + %6 = "tosa.greater_equal"(%arg0, %arg1) : 
(tensor<4xi32>, tensor) -> tensor<*xi1> + + // CHECK: "tosa.logical_left_shift"(%arg0, %arg1) {shift = 0 : i32} : (tensor<4xi32>, tensor) -> tensor<4xi32> + %7 = "tosa.logical_left_shift"(%arg0, %arg1) { shift = 0 : i32 }: (tensor<4xi32>, tensor) -> tensor<*xi32> + + // CHECK: "tosa.logical_right_shift"(%arg0, %arg1) {shift = 0 : i32} : (tensor<4xi32>, tensor) -> tensor<4xi32> + %8 = "tosa.logical_right_shift"(%arg0, %arg1) { shift = 0 : i32 }: (tensor<4xi32>, tensor) -> tensor<*xi32> + + // CHECK: "tosa.maximum"(%arg0, %arg1) : (tensor<4xi32>, tensor) -> tensor<4xi32> + %9 = "tosa.maximum"(%arg0, %arg1) : (tensor<4xi32>, tensor) -> tensor<*xi32> + + // CHECK: "tosa.minimum"(%arg0, %arg1) : (tensor<4xi32>, tensor) -> tensor<4xi32> + %10 = "tosa.minimum"(%arg0, %arg1) : (tensor<4xi32>, tensor) -> tensor<*xi32> + + // CHECK: "tosa.mul"(%arg0, %arg1) {shift = 0 : i32} : (tensor<4xi32>, tensor) -> tensor<4xi32> + %11 = "tosa.mul"(%arg0, %arg1) { shift = 0 : i32 }: (tensor<4xi32>, tensor) -> tensor<*xi32> + + // CHECK: "tosa.pow"(%arg0, %arg1) : (tensor<4xi32>, tensor) -> tensor<4xi32> + %12 = "tosa.pow"(%arg0, %arg1) : (tensor<4xi32>, tensor) -> tensor<*xi32> + + // CHECK: "tosa.sub"(%arg0, %arg1) : (tensor<4xi32>, tensor) -> tensor<4xi32> + %13 = "tosa.sub"(%arg0, %arg1) : (tensor<4xi32>, tensor) -> tensor<*xi32> + + return +} + +// ----- + +// CHECK-LABEL: @test_binary_i1 +func @test_binary_i1(%arg0 : tensor<4xi1>, %arg1 : tensor) -> () { + // CHECK "tosa.logical_and"(%arg0, %arg1) : (tensor<4xi1>, tensor) -> tensor<4xi1> + %0 = "tosa.logical_and"(%arg0, %arg1): (tensor<4xi1>, tensor) -> tensor<*xi1> + + // CHECK "tosa.logical_or"(%arg0, %arg1) : (tensor<4xi1>, tensor) -> tensor<4xi1> + %1 = "tosa.logical_or"(%arg0, %arg1): (tensor<4xi1>, tensor) -> tensor<*xi1> + + // CHECK "tosa.logical_xor"(%arg0, %arg1) : (tensor<4xi1>, tensor) -> tensor<*4i1> + %2 = "tosa.logical_xor"(%arg0, %arg1): (tensor<4xi1>, tensor) -> tensor<*xi1> + + return +} + +// ----- + +// 
CHECK-LABEL: @test_select_i32 +func @test_select_i32(%arg0 : tensor<4xi1>, %arg1 : tensor, %arg2 : tensor<4xi32>) -> () { + // CHECK: "tosa.select"(%arg0, %arg1, %arg2) : (tensor<4xi1>, tensor, tensor<4xi32>) -> tensor<4xi32> + %0 = "tosa.select"(%arg0, %arg1, %arg2): (tensor<4xi1>, tensor, tensor<4xi32>) -> tensor<*xi32> + + return +} + +// ----- + +func @test_static_reshape(%arg0 : tensor<4x4xi32>) -> () { + // CHECK: "tosa.reshape"(%arg0) {new_shape = [16]} : (tensor<4x4xi32>) -> tensor<16xi32> + %0 = "tosa.reshape"(%arg0) {new_shape = [16]} : (tensor<4x4xi32>) -> tensor + + // CHECK: "tosa.reshape"(%arg0) {new_shape = [-1]} : (tensor<4x4xi32>) -> tensor<16xi32> + %1 = "tosa.reshape"(%arg0) {new_shape = [-1]} : (tensor<4x4xi32>) -> tensor + + // CHECK: "tosa.reshape"(%arg0) {new_shape = [2, -1]} : (tensor<4x4xi32>) -> tensor<2x8xi32> + %2 = "tosa.reshape"(%arg0) {new_shape = [2, -1]} : (tensor<4x4xi32>) -> tensor + + return +} +// ----- + +func @test_dynamic_reshape(%arg0 : tensor<4x?xi32>) -> () { + // CHECK: %0 = "tosa.reshape"(%arg0) {new_shape = [16]} : (tensor<4x?xi32>) -> tensor<16xi32> + %0 = "tosa.reshape"(%arg0) {new_shape = [16]} : (tensor<4x?xi32>) -> tensor + + // CHECK: %1 = "tosa.reshape"(%arg0) {new_shape = [-1]} : (tensor<4x?xi32>) -> tensor + %1 = "tosa.reshape"(%arg0) {new_shape = [-1]} : (tensor<4x?xi32>) -> tensor + + // CHECK: %2 = "tosa.reshape"(%arg0) {new_shape = [2, -1]} : (tensor<4x?xi32>) -> tensor<2x?xi32> + %2 = "tosa.reshape"(%arg0) {new_shape = [2, -1]} : (tensor<4x?xi32>) -> tensor + + return +} + diff --git a/mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp b/mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp --- a/mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp +++ b/mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp @@ -11,7 +11,8 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/StandardOps/IR/Ops.h" -#include "mlir/Dialect/Tosa/IR//TosaOps.h" +#include 
"mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/Dialect/Tosa/Transforms/PassDetail.h" #include "mlir/Dialect/Tosa/Transforms/Passes.h" #include "mlir/Dialect/Tosa/Utils/QuantUtils.h"