diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -216,3 +216,24 @@
 }
 
+// Base class for elementwise unary ops. Adds result shape inference via
+// InferShapedTypeOpInterface and the Pure and SameOperandsAndResultElementType
+// traits to the caller-supplied `traits` list.
+class Tosa_ElemWiseUnaryOp<string mnemonic, list<Trait> traits = []> :
+    Tosa_Op<mnemonic, !listconcat(traits, [
+        DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+                                  ["inferReturnTypeComponents"]>,
+        Pure, SameOperandsAndResultElementType])> {
+}
+
+// Base class for elementwise binary ops: Tosa_ElemWiseUnaryOp's traits plus
+// ResultsBroadcastableShape. Ops whose result element type may differ from
+// their operands' (e.g. tosa.mul, which promotes i8/i16 inputs to an i32
+// result) must not derive from this class.
+class Tosa_ElemWiseBinaryOp<string mnemonic, list<Trait> traits = []> :
+    Tosa_Op<mnemonic, !listconcat(traits, [
+        DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+                                  ["inferReturnTypeComponents"]>,
+        ResultsBroadcastableShape, Pure, SameOperandsAndResultElementType])> {
+}
+
 #endif // TOSA_OP_BASE
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -375,10 +375,7 @@
 //===----------------------------------------------------------------------===//
 // Operator: clamp
 //===----------------------------------------------------------------------===//
-def Tosa_ClampOp : Tosa_Op<"clamp", [
-    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>,
-    Pure]> {
+def Tosa_ClampOp : Tosa_ElemWiseUnaryOp<"clamp"> {
   let summary = "Computes clamp(features, min, max).";
 
   let description = [{
@@ -407,10 +404,7 @@
 //===----------------------------------------------------------------------===//
 // Operator: sigmoid
 //===----------------------------------------------------------------------===//
-def Tosa_SigmoidOp : Tosa_Op<"sigmoid", [
-    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>,
-    Pure]> {
+def Tosa_SigmoidOp : Tosa_ElemWiseUnaryOp<"sigmoid"> {
   let summary = "Computes elementwise sigmoid of input.";
 
   let description = [{
@@ -433,10 +427,7 @@
 //===----------------------------------------------------------------------===//
 // Operator: tanh
 //===----------------------------------------------------------------------===//
-def Tosa_TanhOp : Tosa_Op<"tanh", [
-    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>,
-    Pure]> {
+def Tosa_TanhOp : Tosa_ElemWiseUnaryOp<"tanh"> {
   let summary = "Computes elementwise hyperbolic tangent of input";
 
   let description = [{
@@ -490,10 +481,7 @@
 //===----------------------------------------------------------------------===//
 // Operator: add
 //===----------------------------------------------------------------------===//
-def Tosa_AddOp : Tosa_Op<"add", [
-    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>,
-    ResultsBroadcastableShape, Pure, Commutative]> {
+def Tosa_AddOp : Tosa_ElemWiseBinaryOp<"add", [Commutative]> {
   let summary = "Elementwise addition operator";
 
   let description = [{
@@ -516,10 +504,7 @@
 //===----------------------------------------------------------------------===//
 // Operator: arithmetic_right_shift
 //===----------------------------------------------------------------------===//
-def Tosa_ArithmeticRightShiftOp : Tosa_Op<"arithmetic_right_shift", [
-    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>,
-    ResultsBroadcastableShape, Pure]> {
+def Tosa_ArithmeticRightShiftOp : Tosa_ElemWiseBinaryOp<"arithmetic_right_shift"> {
   let summary = "Elementwise Arithmetic Right Shift";
 
   let description = [{
@@ -541,10 +526,7 @@
 //===----------------------------------------------------------------------===//
 // Operator: bitwise_and
 //===----------------------------------------------------------------------===//
-def Tosa_BitwiseAndOp : Tosa_Op<"bitwise_and", [
-    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>,
-    ResultsBroadcastableShape, Pure, Commutative]> {
+def Tosa_BitwiseAndOp : Tosa_ElemWiseBinaryOp<"bitwise_and", [Commutative]> {
   let summary = "Bitwise AND operator";
 
   let description = [{
@@ -565,10 +547,7 @@
 //===----------------------------------------------------------------------===//
 // Operator: bitwise_or
 //===----------------------------------------------------------------------===//
-def Tosa_BitwiseOrOp : Tosa_Op<"bitwise_or", [
-    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>,
-    ResultsBroadcastableShape, Pure, Commutative]> {
+def Tosa_BitwiseOrOp : Tosa_ElemWiseBinaryOp<"bitwise_or", [Commutative]> {
   let summary = "Bitwise OR operator";
 
   let description = [{
@@ -589,10 +568,7 @@
 //===----------------------------------------------------------------------===//
 // Operator: bitwise_xor
 //===----------------------------------------------------------------------===//
-def Tosa_BitwiseXorOp : Tosa_Op<"bitwise_xor", [
-    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>,
-    ResultsBroadcastableShape, Pure, Commutative]> {
+def Tosa_BitwiseXorOp : Tosa_ElemWiseBinaryOp<"bitwise_xor", [Commutative]> {
   let summary = "Bitwise XOR operator";
 
   let description = [{
@@ -613,10 +589,7 @@
 //===----------------------------------------------------------------------===//
 // Operator: div
 //===----------------------------------------------------------------------===//
-def Tosa_DivOp : Tosa_Op<"div", [
-    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>,
-    ResultsBroadcastableShape, Pure]> {
+def Tosa_DivOp : Tosa_ElemWiseBinaryOp<"div"> {
   let summary = "Integer divide operator";
 
   let description = [{
@@ -639,10 +612,7 @@
 //===----------------------------------------------------------------------===//
 // Operator: logical_and
 //===----------------------------------------------------------------------===//
-def Tosa_LogicalAndOp : Tosa_Op<"logical_and", [
-    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>,
-    ResultsBroadcastableShape, Commutative, Pure]> {
+def Tosa_LogicalAndOp : Tosa_ElemWiseBinaryOp<"logical_and", [Commutative]> {
   let summary = "Returns the truth value of x AND y element-wise.";
 
   let description = [{
@@ -663,10 +633,7 @@
 //===----------------------------------------------------------------------===//
 // Operator: logical_left_shift
 //===----------------------------------------------------------------------===//
-def Tosa_LogicalLeftShiftOp : Tosa_Op<"logical_left_shift", [
-    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>,
-    ResultsBroadcastableShape, Pure]> {
+def Tosa_LogicalLeftShiftOp : Tosa_ElemWiseBinaryOp<"logical_left_shift"> {
   let summary = "Elementwise Logical Left Shift";
 
   let description = [{
@@ -687,10 +654,7 @@
 //===----------------------------------------------------------------------===//
 // Operator: logical_right_shift
 //===----------------------------------------------------------------------===//
-def Tosa_LogicalRightShiftOp : Tosa_Op<"logical_right_shift", [
-    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>,
-    ResultsBroadcastableShape, Pure]> {
+def Tosa_LogicalRightShiftOp : Tosa_ElemWiseBinaryOp<"logical_right_shift"> {
   let summary = "Elementwise Logical Right Shift";
 
   let description = [{
@@ -711,10 +675,7 @@
 //===----------------------------------------------------------------------===//
 // Operator: logical_or
 //===----------------------------------------------------------------------===//
-def Tosa_LogicalOrOp : Tosa_Op<"logical_or", [
-    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>,
-    ResultsBroadcastableShape, Commutative, Pure]> {
+def Tosa_LogicalOrOp : Tosa_ElemWiseBinaryOp<"logical_or", [Commutative]> {
   let summary = "Returns the truth value of x OR y element-wise.";
 
   let description = [{
@@ -735,10 +696,7 @@
 //===----------------------------------------------------------------------===//
 // Operator: logical_xor
 //===----------------------------------------------------------------------===//
-def Tosa_LogicalXorOp : Tosa_Op<"logical_xor", [
-    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>,
-    ResultsBroadcastableShape, Commutative, Pure]> {
+def Tosa_LogicalXorOp : Tosa_ElemWiseBinaryOp<"logical_xor", [Commutative]> {
   let summary = "Returns the truth value of x XOR y element-wise.";
 
   let description = [{
@@ -759,10 +717,7 @@
 //===----------------------------------------------------------------------===//
 // Operator: maximum
 //===----------------------------------------------------------------------===//
-def Tosa_MaximumOp : Tosa_Op<"maximum", [
-    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>,
-    ResultsBroadcastableShape, Pure, Commutative]> {
+def Tosa_MaximumOp : Tosa_ElemWiseBinaryOp<"maximum", [Commutative]> {
   let summary = "Elementwise Maximum";
 
   let description = [{
@@ -783,10 +738,7 @@
 //===----------------------------------------------------------------------===//
 // Operator: minimum
 //===----------------------------------------------------------------------===//
-def Tosa_MinimumOp : Tosa_Op<"minimum", [
-    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>,
-    ResultsBroadcastableShape, Pure, Commutative]> {
+def Tosa_MinimumOp : Tosa_ElemWiseBinaryOp<"minimum", [Commutative]> {
   let summary = "Elementwise Minimum";
 
   let description = [{
@@ -810,12 +762,13 @@
 def Tosa_MulOp : Tosa_Op<"mul", [
     DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
                               ["inferReturnTypeComponents"]>,
-    ResultsBroadcastableShape, Pure, Commutative]> {
+    ResultsBroadcastableShape, Pure, Commutative, SameOperandsElementType]> {
   let summary = "Multiplication operator";
 
   let description = [{
     Elementwise multiplication (Hadamard product) of input1 and input2.
     Axis of size 1 will be broadcast, as necessary.
+    i8/i16 input type can be promoted to i32 result type.
   }];
 
   let arguments = (ins
@@ -834,10 +787,7 @@
 //===----------------------------------------------------------------------===//
 // Operator: pow
 //===----------------------------------------------------------------------===//
-def Tosa_PowOp : Tosa_Op<"pow", [
-    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>,
-    ResultsBroadcastableShape, Pure]> {
+def Tosa_PowOp : Tosa_ElemWiseBinaryOp<"pow"> {
   let summary = "Computes the power of one value to another.";
 
   let description = [{
@@ -858,10 +808,7 @@
 //===----------------------------------------------------------------------===//
 // Operator: sub
 //===----------------------------------------------------------------------===//
-def Tosa_SubOp : Tosa_Op<"sub", [
-    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>,
-    ResultsBroadcastableShape, Pure]> {
+def Tosa_SubOp : Tosa_ElemWiseBinaryOp<"sub"> {
   let summary = "Elementwise subtraction operator";
 
   let description = [{
@@ -927,10 +874,7 @@
 //===----------------------------------------------------------------------===//
 // Operator: abs
 //===----------------------------------------------------------------------===//
-def Tosa_AbsOp : Tosa_Op<"abs", [
-    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>,
-    Pure]> {
+def Tosa_AbsOp : Tosa_ElemWiseUnaryOp<"abs"> {
   let summary = "Elementwise abs op";
 
   let description = [{
@@ -951,10 +895,7 @@
 //===----------------------------------------------------------------------===//
 // Operator: bitwise_not
 //===----------------------------------------------------------------------===//
-def Tosa_BitwiseNotOp : Tosa_Op<"bitwise_not", [
-    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>,
-    ResultsBroadcastableShape, Pure]> {
+def Tosa_BitwiseNotOp : Tosa_ElemWiseUnaryOp<"bitwise_not"> {
   let summary = "Bitwise NOT operator";
 
   let description = [{
@@ -973,10 +914,7 @@
 //===----------------------------------------------------------------------===//
 // Operator: ceil
 //===----------------------------------------------------------------------===//
-def Tosa_CeilOp : Tosa_Op<"ceil", [
-    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>,
-    Pure]> {
+def Tosa_CeilOp : Tosa_ElemWiseUnaryOp<"ceil"> {
   let summary = "Elementwise ceil op";
 
   let description = [{
@@ -995,10 +933,7 @@
 //===----------------------------------------------------------------------===//
 // Operator: clz
 //===----------------------------------------------------------------------===//
-def Tosa_ClzOp : Tosa_Op<"clz", [
-    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>,
-    Pure]> {
+def Tosa_ClzOp : Tosa_ElemWiseUnaryOp<"clz"> {
   let summary = "Elementwise count leading zero op";
 
   let description = [{
@@ -1017,10 +952,7 @@
 //===----------------------------------------------------------------------===//
 // Operator: exp
 //===----------------------------------------------------------------------===//
-def Tosa_ExpOp : Tosa_Op<"exp", [
-    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>,
-    Pure]> {
+def Tosa_ExpOp : Tosa_ElemWiseUnaryOp<"exp"> {
   let summary = "Elementwise exp op";
 
   let description = [{
@@ -1041,10 +973,7 @@
 //===----------------------------------------------------------------------===//
 // Operator: floor
 //===----------------------------------------------------------------------===//
-def Tosa_FloorOp : Tosa_Op<"floor", [
-    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>,
-    Pure]> {
+def Tosa_FloorOp : Tosa_ElemWiseUnaryOp<"floor"> {
   let summary = "Elementwise floor op";
 
   let description = [{
@@ -1063,10 +992,7 @@
 //===----------------------------------------------------------------------===//
 // Operator: log
 //===----------------------------------------------------------------------===//
-def Tosa_LogOp : Tosa_Op<"log", [
-    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>,
-    Pure]> {
+def Tosa_LogOp : Tosa_ElemWiseUnaryOp<"log"> {
   let summary = "Elementwise log op";
 
   let description = [{
@@ -1087,10 +1013,7 @@
 //===----------------------------------------------------------------------===//
 // Operator: logical_not
 //===----------------------------------------------------------------------===//
-def Tosa_LogicalNotOp : Tosa_Op<"logical_not", [
-    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>,
-    Pure, SameOperandsAndResultType]> {
+def Tosa_LogicalNotOp : Tosa_ElemWiseUnaryOp<"logical_not"> {
   let summary = "Returns the truth value of NOT x element-wise.";
 
   let description = [{
@@ -1109,10 +1032,7 @@
 //===----------------------------------------------------------------------===//
 // Operator: negate
 //===----------------------------------------------------------------------===//
-def Tosa_NegateOp : Tosa_Op<"negate", [
-    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>,
-    Pure]> {
+def Tosa_NegateOp : Tosa_ElemWiseUnaryOp<"negate"> {
   let summary = "Elementwise negate op";
 
   let description = [{
@@ -1136,10 +1056,7 @@
 //===----------------------------------------------------------------------===//
 // Operator: reciprocal
 //===----------------------------------------------------------------------===//
-def Tosa_ReciprocalOp : Tosa_Op<"reciprocal", [
-    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>,
-    Pure]> {
+def Tosa_ReciprocalOp : Tosa_ElemWiseUnaryOp<"reciprocal"> {
   let summary = "Elementwise reciprocal op";
 
   let description = [{
@@ -1159,10 +1076,7 @@
 //===----------------------------------------------------------------------===//
 // Operator: rsqrt
 //===----------------------------------------------------------------------===//
-def Tosa_RsqrtOp : Tosa_Op<"rsqrt", [
-    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>,
-    Pure]> {
+def Tosa_RsqrtOp : Tosa_ElemWiseUnaryOp<"rsqrt"> {
   let summary = "Elementwise 1/sqrt op";
 
   let description = [{
@@ -1219,7 +1133,7 @@
 // Operator: equal
 //===----------------------------------------------------------------------===//
 def Tosa_EqualOp : Tosa_Op<"equal", [InferTensorType, ResultsBroadcastableShape,
-    Commutative, Pure]> {
+    Commutative, Pure, SameOperandsElementType]> {
   let summary = "Returns the truth value of (x == y) element-wise.";
 
   let description = [{
@@ -1250,7 +1164,7 @@
 def Tosa_GreaterOp : Tosa_Op<"greater", [
     DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
                               ["inferReturnTypeComponents"]>,
-    ResultsBroadcastableShape, Pure]> {
+    ResultsBroadcastableShape, Pure, SameOperandsElementType]> {
   let summary = "Returns the truth value of (x > y) element-wise.";
 
   let description = [{
@@ -1275,7 +1189,7 @@
 def Tosa_GreaterEqualOp : Tosa_Op<"greater_equal", [
     DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
                               ["inferReturnTypeComponents"]>,
-    ResultsBroadcastableShape, Pure]> {
+    ResultsBroadcastableShape, Pure, SameOperandsElementType]> {
   let summary = "Returns the truth value of (x >= y) element-wise.";
 
   let description = [{