diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
@@ -15,7 +15,9 @@
 
 #include "mlir/Bytecode/BytecodeOpInterface.h"
 #include "mlir/Dialect/Traits.h"
+#include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/TypeUtilities.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/LoopLikeInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
@@ -35,6 +37,66 @@
 #include "mlir/Dialect/Tosa/IR/TosaInterfaces.h.inc"
 
 } // namespace tosa
+
+namespace OpTrait {
+namespace tosa {
+
+// This trait verifies that the element types of the operands and the result
+// of a multiplication match the TOSA specification:
+//  - floating-point result: all operands and the result must share the same
+//    element type;
+//  - integer result: both operands must have the same integer element type,
+//    no wider than the result element type;
+//  - any other result element type is rejected.
+//
+// The trait inspects operands 0 and 1 and result 0, so it must only be
+// attached to ops with at least two operands and one result.
+template <typename ConcreteType>
+class MulOperandsAndResultElementType
+    : public TraitBase<ConcreteType, MulOperandsAndResultElementType> {
+public:
+  static LogicalResult verifyTrait(Operation *op) {
+    auto resElemType = getElementTypeOrSelf(op->getResult(0));
+
+    // In cases of floating point type, op requires the same element
+    // type for all operands and result.
+    if (llvm::isa<FloatType>(resElemType))
+      return impl::verifySameOperandsAndResultElementType(op);
+
+    if (auto resIntType = llvm::dyn_cast<IntegerType>(resElemType)) {
+      // Operand element types are not guaranteed to be integers here (e.g.
+      // float operands with an integer result); use dyn_cast so a mismatch
+      // is reported as a diagnostic instead of tripping cast<>'s assertion.
+      auto lhsIntType =
+          llvm::dyn_cast<IntegerType>(getElementTypeOrSelf(op->getOperand(0)));
+      auto rhsIntType =
+          llvm::dyn_cast<IntegerType>(getElementTypeOrSelf(op->getOperand(1)));
+      if (!lhsIntType || !rhsIntType)
+        return op->emitOpError("requires integer operands when the result "
+                               "element type is an integer");
+      if (lhsIntType != rhsIntType)
+        return op->emitOpError(
+            "requires the same element type for all operands");
+
+      // Though the spec requires the element type of result to be i32, a more
+      // relaxed way is provided at dialect level for easier cooperating with
+      // other dialects.
+      if (lhsIntType.getWidth() > resIntType.getWidth())
+        return op->emitOpError("invalid data type size for operands or result");
+
+      return success();
+    }
+
+    // Neither float nor integer: reject with a diagnostic instead of failing
+    // silently.
+    return op->emitOpError(
+        "requires a floating-point or integer result element type");
+  }
+};
+
+} // namespace tosa
+} // namespace OpTrait
+
 } // namespace mlir
 
 #define GET_ATTRDEF_CLASSES
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -756,10 +756,19 @@
   );
 }
 
+def MulOperandsAndResultElementType :
+  NativeOpTrait<"MulOperandsAndResultElementType"> {
+  let cppNamespace = "mlir::OpTrait::tosa";
+}
+
 //===----------------------------------------------------------------------===//
 // Operator: mul
 //===----------------------------------------------------------------------===//
-def Tosa_MulOp : Tosa_ElemWiseBinaryOp<"mul", [Commutative]> {
+def Tosa_MulOp : Tosa_Op<"mul", [
+    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+                              ["inferReturnTypeComponents"]>,
+    ResultsBroadcastableShape, Pure, Commutative,
+    MulOperandsAndResultElementType]> {
   let summary = "Multiplication operator";
 
   let description = [{
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -294,8 +294,10 @@
 // CHECK-LABEL: @test_simple_i16
 func.func @test_simple_i16(%arg0: tensor<1xi16>) -> () {
   // CHECK: linalg.generic
+  // CHECK: arith.extsi
+  // CHECK: arith.extsi
   // CHECK: arith.muli
-  %0 = "tosa.mul"(%arg0, %arg0) {shift = 0 : i32} : (tensor<1xi16>, tensor<1xi16>) -> tensor<1xi16>
+  %0 = "tosa.mul"(%arg0, %arg0) {shift = 0 : i32} : (tensor<1xi16>, tensor<1xi16>) -> tensor<1xi32>
   return
 }
 
diff --git a/mlir/test/Dialect/Tosa/constant-op-fold.mlir b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
--- a/mlir/test/Dialect/Tosa/constant-op-fold.mlir
+++ b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
@@ -294,13 +294,13 @@
 // -----
 
 // CHECK-LABEL: @fold_mul_splat_i8
-func.func @fold_mul_splat_i8() -> tensor<10xi8> {
+func.func @fold_mul_splat_i8() -> tensor<10xi32> {
   %one = "tosa.const"() {value = dense<17> : tensor<10xi8>} : () -> tensor<10xi8>
   %two = "tosa.const"() {value = dense<32> : tensor<10xi8>} : () -> tensor<10xi8>
-  %mul = "tosa.mul"(%one, %two) {shift = 3 : i32} : (tensor<10xi8>, tensor<10xi8>) -> tensor<10xi8>
-  // CHECK: %[[THREE:.+]] = "tosa.const"() <{value = dense<68> : tensor<10xi8>}
+  %mul = "tosa.mul"(%one, %two) {shift = 3 : i32} : (tensor<10xi8>, tensor<10xi8>) -> tensor<10xi32>
+  // CHECK: %[[THREE:.+]] = "tosa.const"() <{value = dense<68> : tensor<10xi32>}
   // CHECK: return %[[THREE]]
-  return %mul : tensor<10xi8>
+  return %mul : tensor<10xi32>
 }
 
 // -----
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -229,6 +229,13 @@
   return %0 : tensor<13x21x3xf32>
 }
 
+// -----
+// CHECK-LABEL: mul
+func.func @test_mul_relaxed_result_type(%arg0: tensor<13x21x3xi16>, %arg1: tensor<13x1x3xi16>) -> tensor<13x21x3xi16> {
+  %0 = "tosa.mul"(%arg0, %arg1) { shift = 1 : i32 } : (tensor<13x21x3xi16>, tensor<13x1x3xi16>) -> tensor<13x21x3xi16>
+  return %0 : tensor<13x21x3xi16>
+}
+
 // -----
 // CHECK-LABEL: pow
 func.func @test_pow(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x1xf32>) -> tensor<13x21x3xf32> {