diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td @@ -215,4 +215,18 @@ Op { } +class Tosa_ElemWiseUnaryOp traits = []> : + Tosa_Op, + Pure, SameOperandsAndResultElementType])> { +} + +class Tosa_ElemWiseBinaryOp traits = []> : + Tosa_Op, + ResultsBroadcastableShape, Pure, SameOperandsAndResultElementType])> { +} + #endif // TOSA_OP_BASE diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td @@ -368,10 +368,7 @@ //===----------------------------------------------------------------------===// // Operator: clamp //===----------------------------------------------------------------------===// -def Tosa_ClampOp : Tosa_Op<"clamp", [ - DeclareOpInterfaceMethods, - Pure]> { +def Tosa_ClampOp : Tosa_ElemWiseUnaryOp<"clamp"> { let summary = "Computes clamp(features, min, max)."; let description = [{ @@ -400,10 +397,7 @@ //===----------------------------------------------------------------------===// // Operator: sigmoid //===----------------------------------------------------------------------===// -def Tosa_SigmoidOp : Tosa_Op<"sigmoid", [ - DeclareOpInterfaceMethods, - Pure]> { +def Tosa_SigmoidOp : Tosa_ElemWiseUnaryOp<"sigmoid"> { let summary = "Computes elementwise sigmoid of input."; let description = [{ @@ -426,10 +420,7 @@ //===----------------------------------------------------------------------===// // Operator: tanh //===----------------------------------------------------------------------===// -def Tosa_TanhOp : Tosa_Op<"tanh", [ - DeclareOpInterfaceMethods, - Pure]> { +def Tosa_TanhOp : Tosa_ElemWiseUnaryOp<"tanh"> { let summary = "Computes elementwise hyperbolic tangent of input"; let description = [{ @@ -457,10 +448,7 @@ 
//===----------------------------------------------------------------------===// // Operator: add //===----------------------------------------------------------------------===// -def Tosa_AddOp : Tosa_Op<"add", [ - DeclareOpInterfaceMethods, - ResultsBroadcastableShape, Pure, Commutative]> { +def Tosa_AddOp : Tosa_ElemWiseBinaryOp<"add", [Commutative]> { let summary = "Elementwise addition operator"; let description = [{ @@ -483,10 +471,7 @@ //===----------------------------------------------------------------------===// // Operator: arithmetic_right_shift //===----------------------------------------------------------------------===// -def Tosa_ArithmeticRightShiftOp : Tosa_Op<"arithmetic_right_shift", [ - DeclareOpInterfaceMethods, - ResultsBroadcastableShape, Pure]> { +def Tosa_ArithmeticRightShiftOp : Tosa_ElemWiseBinaryOp<"arithmetic_right_shift"> { let summary = "Elementwise Arithmetic Right Shift"; let description = [{ @@ -508,10 +493,7 @@ //===----------------------------------------------------------------------===// // Operator: bitwise_and //===----------------------------------------------------------------------===// -def Tosa_BitwiseAndOp : Tosa_Op<"bitwise_and", [ - DeclareOpInterfaceMethods, - ResultsBroadcastableShape, Pure, Commutative]> { +def Tosa_BitwiseAndOp : Tosa_ElemWiseBinaryOp<"bitwise_and", [Commutative]> { let summary = "Bitwise AND operator"; let description = [{ @@ -532,10 +514,7 @@ //===----------------------------------------------------------------------===// // Operator: bitwise_or //===----------------------------------------------------------------------===// -def Tosa_BitwiseOrOp : Tosa_Op<"bitwise_or", [ - DeclareOpInterfaceMethods, - ResultsBroadcastableShape, Pure, Commutative]> { +def Tosa_BitwiseOrOp : Tosa_ElemWiseBinaryOp<"bitwise_or", [Commutative]> { let summary = "Bitwise OR operator"; let description = [{ @@ -556,10 +535,7 @@ //===----------------------------------------------------------------------===// // 
Operator: bitwise_xor //===----------------------------------------------------------------------===// -def Tosa_BitwiseXorOp : Tosa_Op<"bitwise_xor", [ - DeclareOpInterfaceMethods, - ResultsBroadcastableShape, Pure, Commutative]> { +def Tosa_BitwiseXorOp : Tosa_ElemWiseBinaryOp<"bitwise_xor", [Commutative]> { let summary = "Bitwise XOR operator"; let description = [{ @@ -580,10 +556,7 @@ //===----------------------------------------------------------------------===// // Operator: div //===----------------------------------------------------------------------===// -def Tosa_DivOp : Tosa_Op<"div", [ - DeclareOpInterfaceMethods, - ResultsBroadcastableShape, Pure]> { +def Tosa_DivOp : Tosa_ElemWiseBinaryOp<"div"> { let summary = "Integer divide operator"; let description = [{ @@ -606,10 +579,7 @@ //===----------------------------------------------------------------------===// // Operator: logical_and //===----------------------------------------------------------------------===// -def Tosa_LogicalAndOp : Tosa_Op<"logical_and", [ - DeclareOpInterfaceMethods, - ResultsBroadcastableShape, Commutative, Pure]> { +def Tosa_LogicalAndOp : Tosa_ElemWiseBinaryOp<"logical_and", [Commutative]> { let summary = "Returns the truth value of x AND y element-wise."; let description = [{ @@ -630,10 +600,7 @@ //===----------------------------------------------------------------------===// // Operator: logical_left_shift //===----------------------------------------------------------------------===// -def Tosa_LogicalLeftShiftOp : Tosa_Op<"logical_left_shift", [ - DeclareOpInterfaceMethods, - ResultsBroadcastableShape, Pure]> { +def Tosa_LogicalLeftShiftOp : Tosa_ElemWiseBinaryOp<"logical_left_shift"> { let summary = "Elementwise Logical Left Shift"; let description = [{ @@ -654,10 +621,7 @@ //===----------------------------------------------------------------------===// // Operator: logical_right_shift //===----------------------------------------------------------------------===// -def 
Tosa_LogicalRightShiftOp : Tosa_Op<"logical_right_shift", [ - DeclareOpInterfaceMethods, - ResultsBroadcastableShape, Pure]> { +def Tosa_LogicalRightShiftOp : Tosa_ElemWiseBinaryOp<"logical_right_shift"> { let summary = "Elementwise Logical Right Shift"; let description = [{ @@ -678,10 +642,7 @@ //===----------------------------------------------------------------------===// // Operator: logical_or //===----------------------------------------------------------------------===// -def Tosa_LogicalOrOp : Tosa_Op<"logical_or", [ - DeclareOpInterfaceMethods, - ResultsBroadcastableShape, Commutative, Pure]> { +def Tosa_LogicalOrOp : Tosa_ElemWiseBinaryOp<"logical_or", [Commutative]> { let summary = "Returns the truth value of x OR y element-wise."; let description = [{ @@ -702,10 +663,7 @@ //===----------------------------------------------------------------------===// // Operator: logical_xor //===----------------------------------------------------------------------===// -def Tosa_LogicalXorOp : Tosa_Op<"logical_xor", [ - DeclareOpInterfaceMethods, - ResultsBroadcastableShape, Commutative, Pure]> { +def Tosa_LogicalXorOp : Tosa_ElemWiseBinaryOp<"logical_xor", [Commutative]> { let summary = "Returns the truth value of x XOR y element-wise."; let description = [{ @@ -726,10 +684,7 @@ //===----------------------------------------------------------------------===// // Operator: maximum //===----------------------------------------------------------------------===// -def Tosa_MaximumOp : Tosa_Op<"maximum", [ - DeclareOpInterfaceMethods, - ResultsBroadcastableShape, Pure, Commutative]> { +def Tosa_MaximumOp : Tosa_ElemWiseBinaryOp<"maximum", [Commutative]> { let summary = "Elementwise Maximum"; let description = [{ @@ -750,10 +705,7 @@ //===----------------------------------------------------------------------===// // Operator: minimum //===----------------------------------------------------------------------===// -def Tosa_MinimumOp : Tosa_Op<"minimum", [ - 
DeclareOpInterfaceMethods, - ResultsBroadcastableShape, Pure, Commutative]> { +def Tosa_MinimumOp : Tosa_ElemWiseBinaryOp<"minimum", [Commutative]> { let summary = "Elementwise Minimum"; let description = [{ @@ -777,12 +729,13 @@ def Tosa_MulOp : Tosa_Op<"mul", [ DeclareOpInterfaceMethods, - ResultsBroadcastableShape, Pure, Commutative]> { + ResultsBroadcastableShape, Pure, Commutative, SameOperandsElementType]> { let summary = "Multiplication operator"; let description = [{ Elementwise multiplication (Hadamard product) of input1 and input2. Axis of size 1 will be broadcast, as necessary. + i8/i16 input type can be promoted to i32 result type. }]; let arguments = (ins @@ -801,10 +754,7 @@ //===----------------------------------------------------------------------===// // Operator: pow //===----------------------------------------------------------------------===// -def Tosa_PowOp : Tosa_Op<"pow", [ - DeclareOpInterfaceMethods, - ResultsBroadcastableShape, Pure]> { +def Tosa_PowOp : Tosa_ElemWiseBinaryOp<"pow"> { let summary = "Computes the power of one value to another."; let description = [{ @@ -825,10 +775,7 @@ //===----------------------------------------------------------------------===// // Operator: sub //===----------------------------------------------------------------------===// -def Tosa_SubOp : Tosa_Op<"sub", [ - DeclareOpInterfaceMethods, - ResultsBroadcastableShape, Pure]> { +def Tosa_SubOp : Tosa_ElemWiseBinaryOp<"sub"> { let summary = "Elementwise subtraction operator"; let description = [{ @@ -894,10 +841,7 @@ //===----------------------------------------------------------------------===// // Operator: abs //===----------------------------------------------------------------------===// -def Tosa_AbsOp : Tosa_Op<"abs", [ - DeclareOpInterfaceMethods, - Pure]> { +def Tosa_AbsOp : Tosa_ElemWiseUnaryOp<"abs"> { let summary = "Elementwise abs op"; let description = [{ @@ -916,10 +860,7 @@ 
//===----------------------------------------------------------------------===// // Operator: bitwise_not //===----------------------------------------------------------------------===// -def Tosa_BitwiseNotOp : Tosa_Op<"bitwise_not", [ - DeclareOpInterfaceMethods, - ResultsBroadcastableShape, Pure]> { +def Tosa_BitwiseNotOp : Tosa_ElemWiseUnaryOp<"bitwise_not"> { let summary = "Bitwise NOT operator"; let description = [{ @@ -938,10 +879,7 @@ //===----------------------------------------------------------------------===// // Operator: ceil //===----------------------------------------------------------------------===// -def Tosa_CeilOp : Tosa_Op<"ceil", [ - DeclareOpInterfaceMethods, - Pure]> { +def Tosa_CeilOp : Tosa_ElemWiseUnaryOp<"ceil"> { let summary = "Elementwise ceil op"; let description = [{ @@ -960,10 +898,7 @@ //===----------------------------------------------------------------------===// // Operator: clz //===----------------------------------------------------------------------===// -def Tosa_ClzOp : Tosa_Op<"clz", [ - DeclareOpInterfaceMethods, - Pure]> { +def Tosa_ClzOp : Tosa_ElemWiseUnaryOp<"clz"> { let summary = "Elementwise count leading zero op"; let description = [{ @@ -982,10 +917,7 @@ //===----------------------------------------------------------------------===// // Operator: exp //===----------------------------------------------------------------------===// -def Tosa_ExpOp : Tosa_Op<"exp", [ - DeclareOpInterfaceMethods, - Pure]> { +def Tosa_ExpOp : Tosa_ElemWiseUnaryOp<"exp"> { let summary = "Elementwise exp op"; let description = [{ @@ -1006,10 +938,7 @@ //===----------------------------------------------------------------------===// // Operator: floor //===----------------------------------------------------------------------===// -def Tosa_FloorOp : Tosa_Op<"floor", [ - DeclareOpInterfaceMethods, - Pure]> { +def Tosa_FloorOp : Tosa_ElemWiseUnaryOp<"floor"> { let summary = "Elementwise floor op"; let description = [{ @@ -1028,10 +957,7 
@@ //===----------------------------------------------------------------------===// // Operator: log //===----------------------------------------------------------------------===// -def Tosa_LogOp : Tosa_Op<"log", [ - DeclareOpInterfaceMethods, - Pure]> { +def Tosa_LogOp : Tosa_ElemWiseUnaryOp<"log"> { let summary = "Elementwise log op"; let description = [{ @@ -1052,10 +978,7 @@ //===----------------------------------------------------------------------===// // Operator: logical_not //===----------------------------------------------------------------------===// -def Tosa_LogicalNotOp : Tosa_Op<"logical_not", [ - DeclareOpInterfaceMethods, - Pure, SameOperandsAndResultType]> { +def Tosa_LogicalNotOp : Tosa_ElemWiseUnaryOp<"logical_not"> { let summary = "Returns the truth value of NOT x element-wise."; let description = [{ @@ -1074,10 +997,7 @@ //===----------------------------------------------------------------------===// // Operator: negate //===----------------------------------------------------------------------===// -def Tosa_NegateOp : Tosa_Op<"negate", [ - DeclareOpInterfaceMethods, - Pure]> { +def Tosa_NegateOp : Tosa_ElemWiseUnaryOp<"negate"> { let summary = "Elementwise negate op"; let description = [{ @@ -1099,10 +1019,7 @@ //===----------------------------------------------------------------------===// // Operator: reciprocal //===----------------------------------------------------------------------===// -def Tosa_ReciprocalOp : Tosa_Op<"reciprocal", [ - DeclareOpInterfaceMethods, - Pure]> { +def Tosa_ReciprocalOp : Tosa_ElemWiseUnaryOp<"reciprocal"> { let summary = "Elementwise reciprocal op"; let description = [{ @@ -1122,10 +1039,7 @@ //===----------------------------------------------------------------------===// // Operator: rsqrt //===----------------------------------------------------------------------===// -def Tosa_RsqrtOp : Tosa_Op<"rsqrt", [ - DeclareOpInterfaceMethods, - Pure]> { +def Tosa_RsqrtOp : Tosa_ElemWiseUnaryOp<"rsqrt"> { let 
summary = "Elementwise 1/sqrt op"; let description = [{ @@ -1182,7 +1096,7 @@ // Operator: equal //===----------------------------------------------------------------------===// def Tosa_EqualOp : Tosa_Op<"equal", [InferTensorType, ResultsBroadcastableShape, - Commutative, Pure]> { + Commutative, Pure, SameOperandsElementType]> { let summary = "Returns the truth value of (x == y) element-wise."; let description = [{ @@ -1213,7 +1127,7 @@ def Tosa_GreaterOp : Tosa_Op<"greater", [ DeclareOpInterfaceMethods, - ResultsBroadcastableShape, Pure]> { + ResultsBroadcastableShape, Pure, SameOperandsElementType]> { let summary = "Returns the truth value of (x > y) element-wise."; let description = [{ @@ -1238,7 +1152,7 @@ def Tosa_GreaterEqualOp : Tosa_Op<"greater_equal", [ DeclareOpInterfaceMethods, - ResultsBroadcastableShape, Pure]> { + ResultsBroadcastableShape, Pure, SameOperandsElementType]> { let summary = "Returns the truth value of (x >= y) element-wise."; let description = [{