diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td @@ -215,4 +215,18 @@ Op { } +class Tosa_ElemWiseUnaryOp traits = []> : + Tosa_Op, + Pure, SameOperandsAndResultElementType])> { +} + +class Tosa_ElemWiseBinaryOp traits = []> : + Tosa_Op, + ResultsBroadcastableShape, Pure, SameOperandsAndResultElementType])> { +} + #endif // TOSA_OP_BASE diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td @@ -457,10 +457,7 @@ //===----------------------------------------------------------------------===// // Operator: add //===----------------------------------------------------------------------===// -def Tosa_AddOp : Tosa_Op<"add", [ - DeclareOpInterfaceMethods, - ResultsBroadcastableShape, Pure, Commutative]> { +def Tosa_AddOp : Tosa_ElemWiseBinaryOp<"add", [Commutative]> { let summary = "Elementwise addition operator"; let description = [{ @@ -483,10 +480,7 @@ //===----------------------------------------------------------------------===// // Operator: arithmetic_right_shift //===----------------------------------------------------------------------===// -def Tosa_ArithmeticRightShiftOp : Tosa_Op<"arithmetic_right_shift", [ - DeclareOpInterfaceMethods, - ResultsBroadcastableShape, Pure]> { +def Tosa_ArithmeticRightShiftOp : Tosa_ElemWiseBinaryOp<"arithmetic_right_shift"> { let summary = "Elementwise Arithmetic Right Shift"; let description = [{ @@ -508,10 +502,7 @@ //===----------------------------------------------------------------------===// // Operator: bitwise_and //===----------------------------------------------------------------------===// -def Tosa_BitwiseAndOp : Tosa_Op<"bitwise_and", [ - DeclareOpInterfaceMethods, - ResultsBroadcastableShape, Pure, Commutative]> { +def Tosa_BitwiseAndOp : Tosa_ElemWiseBinaryOp<"bitwise_and", [Commutative]> { let summary = "Bitwise AND operator"; let description = [{ @@ -532,10 +523,7 @@ //===----------------------------------------------------------------------===// // Operator: bitwise_or //===----------------------------------------------------------------------===// -def Tosa_BitwiseOrOp : Tosa_Op<"bitwise_or", [ - DeclareOpInterfaceMethods, - ResultsBroadcastableShape, Pure, Commutative]> { +def Tosa_BitwiseOrOp : Tosa_ElemWiseBinaryOp<"bitwise_or", [Commutative]> { let summary = "Bitwise OR operator"; let description = [{ @@ -556,10 +544,7 @@ //===----------------------------------------------------------------------===// // Operator: bitwise_xor //===----------------------------------------------------------------------===// -def Tosa_BitwiseXorOp : Tosa_Op<"bitwise_xor", [ - DeclareOpInterfaceMethods, - ResultsBroadcastableShape, Pure, Commutative]> { +def Tosa_BitwiseXorOp : Tosa_ElemWiseBinaryOp<"bitwise_xor", [Commutative]> { let summary = "Bitwise XOR operator"; let description = [{ @@ -580,10 +565,7 @@ //===----------------------------------------------------------------------===// // Operator: div //===----------------------------------------------------------------------===// -def Tosa_DivOp : Tosa_Op<"div", [ - DeclareOpInterfaceMethods, - ResultsBroadcastableShape, Pure]> { +def Tosa_DivOp : Tosa_ElemWiseBinaryOp<"div"> { let summary = "Integer divide operator"; let description = [{ @@ -606,10 +588,7 @@ //===----------------------------------------------------------------------===// // Operator: logical_and //===----------------------------------------------------------------------===// -def Tosa_LogicalAndOp : Tosa_Op<"logical_and", [ - DeclareOpInterfaceMethods, - ResultsBroadcastableShape, Commutative, Pure]> { +def Tosa_LogicalAndOp : Tosa_ElemWiseBinaryOp<"logical_and", [Commutative]> { let summary = "Returns the truth value of x AND y element-wise."; let description = [{ @@ -630,10 +609,7 @@ //===----------------------------------------------------------------------===// // Operator: logical_left_shift //===----------------------------------------------------------------------===// -def Tosa_LogicalLeftShiftOp : Tosa_Op<"logical_left_shift", [ - DeclareOpInterfaceMethods, - ResultsBroadcastableShape, Pure]> { +def Tosa_LogicalLeftShiftOp : Tosa_ElemWiseBinaryOp<"logical_left_shift"> { let summary = "Elementwise Logical Left Shift"; let description = [{ @@ -654,10 +630,7 @@ //===----------------------------------------------------------------------===// // Operator: logical_right_shift //===----------------------------------------------------------------------===// -def Tosa_LogicalRightShiftOp : Tosa_Op<"logical_right_shift", [ - DeclareOpInterfaceMethods, - ResultsBroadcastableShape, Pure]> { +def Tosa_LogicalRightShiftOp : Tosa_ElemWiseBinaryOp<"logical_right_shift"> { let summary = "Elementwise Logical Right Shift"; let description = [{ @@ -678,10 +651,7 @@ //===----------------------------------------------------------------------===// // Operator: logical_or //===----------------------------------------------------------------------===// -def Tosa_LogicalOrOp : Tosa_Op<"logical_or", [ - DeclareOpInterfaceMethods, - ResultsBroadcastableShape, Commutative, Pure]> { +def Tosa_LogicalOrOp : Tosa_ElemWiseBinaryOp<"logical_or", [Commutative]> { let summary = "Returns the truth value of x OR y element-wise."; let description = [{ @@ -702,10 +672,7 @@ //===----------------------------------------------------------------------===// // Operator: logical_xor //===----------------------------------------------------------------------===// -def Tosa_LogicalXorOp : Tosa_Op<"logical_xor", [ - DeclareOpInterfaceMethods, - ResultsBroadcastableShape, Commutative, Pure]> { +def Tosa_LogicalXorOp : Tosa_ElemWiseBinaryOp<"logical_xor", [Commutative]> { let summary = "Returns the truth value of x XOR y element-wise."; let description = [{ @@ -726,10 +693,7 @@ //===----------------------------------------------------------------------===// // Operator: maximum //===----------------------------------------------------------------------===// -def Tosa_MaximumOp : Tosa_Op<"maximum", [ - DeclareOpInterfaceMethods, - ResultsBroadcastableShape, Pure, Commutative]> { +def Tosa_MaximumOp : Tosa_ElemWiseBinaryOp<"maximum", [Commutative]> { let summary = "Elementwise Maximum"; let description = [{ @@ -750,10 +714,7 @@ //===----------------------------------------------------------------------===// // Operator: minimum //===----------------------------------------------------------------------===// -def Tosa_MinimumOp : Tosa_Op<"minimum", [ - DeclareOpInterfaceMethods, - ResultsBroadcastableShape, Pure, Commutative]> { +def Tosa_MinimumOp : Tosa_ElemWiseBinaryOp<"minimum", [Commutative]> { let summary = "Elementwise Minimum"; let description = [{ @@ -777,12 +738,13 @@ def Tosa_MulOp : Tosa_Op<"mul", [ DeclareOpInterfaceMethods, - ResultsBroadcastableShape, Pure, Commutative]> { + ResultsBroadcastableShape, Pure, Commutative, SameOperandsElementType]> { let summary = "Multiplication operator"; let description = [{ Elementwise multiplication (Hadamard product) of input1 and input2. Axis of size 1 will be broadcast, as necessary. + i8/i16 input type can be promoted to i32 result type. }]; let arguments = (ins @@ -801,10 +763,7 @@ //===----------------------------------------------------------------------===// // Operator: pow //===----------------------------------------------------------------------===// -def Tosa_PowOp : Tosa_Op<"pow", [ - DeclareOpInterfaceMethods, - ResultsBroadcastableShape, Pure]> { +def Tosa_PowOp : Tosa_ElemWiseBinaryOp<"pow"> { let summary = "Computes the power of one value to another."; let description = [{ @@ -825,10 +784,7 @@ //===----------------------------------------------------------------------===// // Operator: sub //===----------------------------------------------------------------------===// -def Tosa_SubOp : Tosa_Op<"sub", [ - DeclareOpInterfaceMethods, - ResultsBroadcastableShape, Pure]> { +def Tosa_SubOp : Tosa_ElemWiseBinaryOp<"sub"> { let summary = "Elementwise subtraction operator"; let description = [{ @@ -894,10 +850,7 @@ //===----------------------------------------------------------------------===// // Operator: abs //===----------------------------------------------------------------------===// -def Tosa_AbsOp : Tosa_Op<"abs", [ - DeclareOpInterfaceMethods, - Pure]> { +def Tosa_AbsOp : Tosa_ElemWiseUnaryOp<"abs"> { let summary = "Elementwise abs op"; let description = [{ @@ -916,10 +869,7 @@ //===----------------------------------------------------------------------===// // Operator: bitwise_not //===----------------------------------------------------------------------===// -def Tosa_BitwiseNotOp : Tosa_Op<"bitwise_not", [ - DeclareOpInterfaceMethods, - ResultsBroadcastableShape, Pure]> { +def Tosa_BitwiseNotOp : Tosa_ElemWiseUnaryOp<"bitwise_not"> { let summary = "Bitwise NOT operator"; let description = [{ @@ -938,10 +888,7 @@ //===----------------------------------------------------------------------===// // Operator: ceil //===----------------------------------------------------------------------===// -def Tosa_CeilOp : Tosa_Op<"ceil", [ - DeclareOpInterfaceMethods, - Pure]> { +def Tosa_CeilOp : Tosa_ElemWiseUnaryOp<"ceil"> { let summary = "Elementwise ceil op"; let description = [{ @@ -960,10 +907,7 @@ //===----------------------------------------------------------------------===// // Operator: clz //===----------------------------------------------------------------------===// -def Tosa_ClzOp : Tosa_Op<"clz", [ - DeclareOpInterfaceMethods, - Pure]> { +def Tosa_ClzOp : Tosa_ElemWiseUnaryOp<"clz"> { let summary = "Elementwise count leading zero op"; let description = [{ @@ -982,10 +926,7 @@ //===----------------------------------------------------------------------===// // Operator: exp //===----------------------------------------------------------------------===// -def Tosa_ExpOp : Tosa_Op<"exp", [ - DeclareOpInterfaceMethods, - Pure]> { +def Tosa_ExpOp : Tosa_ElemWiseUnaryOp<"exp"> { let summary = "Elementwise exp op"; let description = [{ @@ -1006,10 +947,7 @@ //===----------------------------------------------------------------------===// // Operator: floor //===----------------------------------------------------------------------===// -def Tosa_FloorOp : Tosa_Op<"floor", [ - DeclareOpInterfaceMethods, - Pure]> { +def Tosa_FloorOp : Tosa_ElemWiseUnaryOp<"floor"> { let summary = "Elementwise floor op"; let description = [{ @@ -1028,10 +966,7 @@ //===----------------------------------------------------------------------===// // Operator: log //===----------------------------------------------------------------------===// -def Tosa_LogOp : Tosa_Op<"log", [ - DeclareOpInterfaceMethods, - Pure]> { +def Tosa_LogOp : Tosa_ElemWiseUnaryOp<"log"> { let summary = "Elementwise log op"; let description = [{ @@ -1052,10 +987,7 @@ //===----------------------------------------------------------------------===// // Operator: logical_not //===----------------------------------------------------------------------===// -def Tosa_LogicalNotOp : Tosa_Op<"logical_not", [ - DeclareOpInterfaceMethods, - Pure, SameOperandsAndResultType]> { +def Tosa_LogicalNotOp : Tosa_ElemWiseUnaryOp<"logical_not"> { let summary = "Returns the truth value of NOT x element-wise."; let description = [{ @@ -1074,10 +1006,7 @@ //===----------------------------------------------------------------------===// // Operator: negate //===----------------------------------------------------------------------===// -def Tosa_NegateOp : Tosa_Op<"negate", [ - DeclareOpInterfaceMethods, - Pure]> { +def Tosa_NegateOp : Tosa_ElemWiseUnaryOp<"negate"> { let summary = "Elementwise negate op"; let description = [{ @@ -1099,10 +1028,7 @@ //===----------------------------------------------------------------------===// // Operator: reciprocal //===----------------------------------------------------------------------===// -def Tosa_ReciprocalOp : Tosa_Op<"reciprocal", [ - DeclareOpInterfaceMethods, - Pure]> { +def Tosa_ReciprocalOp : Tosa_ElemWiseUnaryOp<"reciprocal"> { let summary = "Elementwise reciprocal op"; let description = [{ @@ -1122,10 +1048,7 @@ //===----------------------------------------------------------------------===// // Operator: rsqrt //===----------------------------------------------------------------------===// -def Tosa_RsqrtOp : Tosa_Op<"rsqrt", [ - DeclareOpInterfaceMethods, - Pure]> { +def Tosa_RsqrtOp : Tosa_ElemWiseUnaryOp<"rsqrt"> { let summary = "Elementwise 1/sqrt op"; let description = [{ @@ -1182,7 +1105,7 @@ // Operator: equal //===----------------------------------------------------------------------===// def Tosa_EqualOp : Tosa_Op<"equal", [InferTensorType, ResultsBroadcastableShape, - Commutative, Pure]> { + Commutative, Pure, SameOperandsElementType]> { let summary = "Returns the truth value of (x == y) element-wise."; let description = [{ @@ -1213,7 +1136,7 @@ def Tosa_GreaterOp : Tosa_Op<"greater", [ DeclareOpInterfaceMethods, - ResultsBroadcastableShape, Pure]> { + ResultsBroadcastableShape, Pure, SameOperandsElementType]> { let summary = "Returns the truth value of (x > y) element-wise."; let description = [{ @@ -1238,7 +1161,7 @@ def Tosa_GreaterEqualOp : Tosa_Op<"greater_equal", [ DeclareOpInterfaceMethods, - ResultsBroadcastableShape, Pure]> { + ResultsBroadcastableShape, Pure, SameOperandsElementType]> { let summary = "Returns the truth value of (x >= y) element-wise."; let description = [{