diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h @@ -15,7 +15,9 @@ #include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/Dialect/Traits.h" +#include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpImplementation.h" +#include "mlir/IR/TypeUtilities.h" #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/LoopLikeInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" @@ -35,6 +37,48 @@ #include "mlir/Dialect/Tosa/IR/TosaInterfaces.h.inc" } // namespace tosa + +namespace OpTrait { +namespace tosa { + +// This trait verifies that the element types of the operands and result +// of a multiplication match the TOSA specification. +template <typename ConcreteType> +class MulCompatibleOperandsAndResultElementType + : public TraitBase<ConcreteType, + MulCompatibleOperandsAndResultElementType> { +public: + static LogicalResult verifyTrait(Operation *op) { + auto elementType = getElementTypeOrSelf(op->getResult(0)); + + // In case of floating point type, it requires the same element + // type for all operands and result.
+ if (llvm::isa<FloatType>(elementType)) + return impl::verifySameOperandsAndResultElementType(op); + + // In case of integer type, it requires the result element type to be + // i32 and the same element type for both operands. + if (llvm::isa<IntegerType>(elementType)) { + if (!elementType.isInteger(32)) + return op->emitOpError("requires the element type of result to be i32"); + + if (getElementTypeOrSelf(op->getOperand(0)) != + getElementTypeOrSelf(op->getOperand(1))) + return op->emitOpError( + "requires the same element type for all operands"); + return success(); + } + + // Any other element type is rejected; emit a diagnostic so that the + // verification failure is never silent. + return op->emitOpError( + "requires the element type of result to be float or i32"); + } +}; + +} // namespace tosa +} // namespace OpTrait + } // namespace mlir #define GET_ATTRDEF_CLASSES diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td @@ -756,10 +756,19 @@ ); } +def MulCompatibleOperandsAndResultElementType : + NativeOpTrait<"MulCompatibleOperandsAndResultElementType"> { + let cppNamespace = "mlir::OpTrait::tosa"; +} + //===----------------------------------------------------------------------===// // Operator: mul //===----------------------------------------------------------------------===// -def Tosa_MulOp : Tosa_ElemWiseBinaryOp<"mul", [Commutative]> { +def Tosa_MulOp : Tosa_Op<"mul", [ + DeclareOpInterfaceMethods<InferShapedTypeOpInterface, + ["inferReturnTypeComponents"]>, + ResultsBroadcastableShape, Pure, Commutative, + MulCompatibleOperandsAndResultElementType]> { let summary = "Multiplication operator"; let description = [{ diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir @@ -294,8 +294,10 @@ // CHECK-LABEL: @test_simple_i16 func.func @test_simple_i16(%arg0: tensor<1xi16>) -> () { // CHECK: linalg.generic + // CHECK: arith.extsi + // CHECK: arith.extsi // CHECK: arith.muli - %0 = "tosa.mul"(%arg0, %arg0) {shift = 0 : i32} : (tensor<1xi16>, tensor<1xi16>) -> tensor<1xi16> + %0 = "tosa.mul"(%arg0, %arg0)
{shift = 0 : i32} : (tensor<1xi16>, tensor<1xi16>) -> tensor<1xi32> return } diff --git a/mlir/test/Dialect/Tosa/constant-op-fold.mlir b/mlir/test/Dialect/Tosa/constant-op-fold.mlir --- a/mlir/test/Dialect/Tosa/constant-op-fold.mlir +++ b/mlir/test/Dialect/Tosa/constant-op-fold.mlir @@ -294,13 +294,13 @@ // ----- // CHECK-LABEL: @fold_mul_splat_i8 -func.func @fold_mul_splat_i8() -> tensor<10xi8> { +func.func @fold_mul_splat_i8() -> tensor<10xi32> { %one = "tosa.const"() {value = dense<17> : tensor<10xi8>} : () -> tensor<10xi8> %two = "tosa.const"() {value = dense<32> : tensor<10xi8>} : () -> tensor<10xi8> - %mul = "tosa.mul"(%one, %two) {shift = 3 : i32} : (tensor<10xi8>, tensor<10xi8>) -> tensor<10xi8> - // CHECK: %[[THREE:.+]] = "tosa.const"() <{value = dense<68> : tensor<10xi8>} + %mul = "tosa.mul"(%one, %two) {shift = 3 : i32} : (tensor<10xi8>, tensor<10xi8>) -> tensor<10xi32> + // CHECK: %[[THREE:.+]] = "tosa.const"() <{value = dense<68> : tensor<10xi32>} // CHECK: return %[[THREE]] - return %mul : tensor<10xi8> + return %mul : tensor<10xi32> } // -----