diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -737,6 +737,7 @@
   );
 
   let hasCanonicalizer = 1;
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -501,6 +501,118 @@
                                                             lhsTy);
 }
 
+/// Multiplies two integers with TOSA MUL semantics: the product is formed at
+/// twice the operand width, then, when `shift` is positive, rounded half-up
+/// and arithmetically shifted right by `shift`, and finally truncated back to
+/// the operand width. Requires 0 <= shift < 2 * bitwidth.
+static APInt mulInt(APInt lhs, APInt rhs, int32_t shift) {
+  unsigned bitwidth = lhs.getBitWidth();
+  lhs = lhs.sext(bitwidth * 2);
+  rhs = rhs.sext(bitwidth * 2);
+  APInt result = lhs * rhs;
+  if (shift > 0) {
+    // Adding 2^(shift-1) rounds to nearest instead of truncating; the
+    // arithmetic shift keeps the sign of negative products.
+    result += APInt::getOneBitSet(bitwidth * 2, shift - 1);
+    result.ashrInPlace(shift);
+  }
+  return result.trunc(bitwidth);
+}
+
+/// Returns true when multiplying by the integer constant `val` with the given
+/// `shift` leaves the other operand unchanged, i.e. `val` is exactly the
+/// positive value 2^shift.
+static bool isIntMulIdentity(const APInt &val, int32_t shift) {
+  return shift >= 0 && !val.isNegative() && val.exactLogBase2() == shift;
+}
+
+/// Constant-folds tosa.mul when both operands are splat constants, producing
+/// an attribute of type `ty`. Integer products honour `shift` (see mulInt).
+/// Returns a null attribute when the operands cannot be folded.
+static DenseElementsAttr MulBinaryFolder(DenseElementsAttr lhs,
+                                         DenseElementsAttr rhs,
+                                         RankedTensorType ty, int32_t shift) {
+  if (!lhs || !rhs || !lhs.isSplat() || !rhs.isSplat())
+    return {};
+
+  Type elementTy = ty.getElementType();
+  if (elementTy.isa<IntegerType>()) {
+    APInt l = lhs.getSplatValue<APInt>();
+    APInt r = rhs.getSplatValue<APInt>();
+    // Shifts outside [0, 2 * bitwidth) would trip APInt assertions.
+    if (shift < 0 || static_cast<unsigned>(shift) >= 2 * l.getBitWidth())
+      return {};
+    APInt result = mulInt(l, r, shift);
+    return DenseElementsAttr::get(ty, result);
+  }
+
+  if (elementTy.isa<FloatType>()) {
+    APFloat l = lhs.getSplatValue<APFloat>();
+    APFloat r = rhs.getSplatValue<APFloat>();
+    APFloat result = l * r;
+    return DenseElementsAttr::get(ty, result);
+  }
+
+  return {};
+}
+
+OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) {
+  auto lhs = getInput1();
+  auto rhs = getInput2();
+  auto lhsTy = lhs.getType().dyn_cast<RankedTensorType>();
+  auto rhsTy = rhs.getType().dyn_cast<RankedTensorType>();
+  auto resultTy = getType().dyn_cast<RankedTensorType>();
+  if (!lhsTy || !rhsTy || !resultTy)
+    return {};
+  // Every fold below forwards an operand or reuses an operand's attribute as
+  // the result, so the operands and the result must all share one type. This
+  // skips broadcasting and widening products (e.g. i8 x i8 -> i32).
+  if (lhsTy != rhsTy || lhsTy != resultTy)
+    return {};
+
+  auto resultETy = resultTy.getElementType();
+  auto lhsAttr = operands[0].dyn_cast_or_null<DenseElementsAttr>();
+  auto rhsAttr = operands[1].dyn_cast_or_null<DenseElementsAttr>();
+  int32_t shift = static_cast<int32_t>(getShift());
+
+  // NOTE(review): folding x * 0.0 to 0.0 is not IEEE-754 exact (NaN or Inf
+  // operands give NaN, negative operands give -0.0); confirm it is intended.
+  if (lhsAttr && lhsAttr.isSplat() && resultETy.isa<FloatType>()) {
+    auto val = lhsAttr.getSplatValue<APFloat>();
+    if (val.isZero())
+      return lhsAttr;
+    if (val.isExactlyValue(1.0))
+      return rhs;
+  }
+
+  if (rhsAttr && rhsAttr.isSplat() && resultETy.isa<FloatType>()) {
+    auto val = rhsAttr.getSplatValue<APFloat>();
+    if (val.isZero())
+      return rhsAttr;
+    if (val.isExactlyValue(1.0))
+      return lhs;
+  }
+
+  // With rounding, 0 * x is always 0 and x * 2^shift is exactly x.
+  if (lhsAttr && lhsAttr.isSplat() && resultETy.isa<IntegerType>()) {
+    auto val = lhsAttr.getSplatValue<APInt>();
+    if (val.isZero())
+      return lhsAttr;
+    if (isIntMulIdentity(val, shift))
+      return rhs;
+  }
+
+  if (rhsAttr && rhsAttr.isSplat() && resultETy.isa<IntegerType>()) {
+    auto val = rhsAttr.getSplatValue<APInt>();
+    if (val.isZero())
+      return rhsAttr;
+    if (isIntMulIdentity(val, shift))
+      return lhs;
+  }
+
+  return MulBinaryFolder(lhsAttr, rhsAttr, resultTy, shift);
+}
+
 OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
   if (getInput().getType() == getType())
     return getInput();
diff --git a/mlir/test/Dialect/Tosa/constant-op-fold.mlir b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
--- a/mlir/test/Dialect/Tosa/constant-op-fold.mlir
+++ b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
@@ -164,6 +164,126 @@
 
 // -----
 
+// CHECK-LABEL: @fold_mul_zero_rhs_f32
+func.func @fold_mul_zero_rhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
+  %zero = "tosa.const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>
+  // CHECK: %[[ZERO:.+]] = "tosa.const"() {value = dense<0.000000e+00> : tensor<f32>}
+  %mul = "tosa.mul"(%arg0, %zero) {shift = 0 : i32} : (tensor<f32>, tensor<f32>) -> tensor<f32>
+  // CHECK: return %[[ZERO]]
+  return %mul : tensor<f32>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_mul_zero_lhs_f32
+func.func @fold_mul_zero_lhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
+  %zero = "tosa.const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>
+  // CHECK: %[[ZERO:.+]] = "tosa.const"() {value = dense<0.000000e+00> : tensor<f32>}
+  %mul = "tosa.mul"(%zero, %arg0) {shift = 0 : i32} : (tensor<f32>, tensor<f32>) -> tensor<f32>
+  // CHECK: return %[[ZERO]]
+  return %mul : tensor<f32>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_mul_zero_rhs_i32
+func.func @fold_mul_zero_rhs_i32(%arg0: tensor<i32>) -> tensor<i32> {
+  %zero = "tosa.const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+  // CHECK: %[[ZERO:.+]] = "tosa.const"() {value = dense<0> : tensor<i32>}
+  %mul = "tosa.mul"(%arg0, %zero) {shift = 0 : i32} : (tensor<i32>, tensor<i32>) -> tensor<i32>
+  // CHECK: return %[[ZERO]]
+  return %mul : tensor<i32>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_mul_zero_lhs_i32
+func.func @fold_mul_zero_lhs_i32(%arg0: tensor<i32>) -> tensor<i32> {
+  %zero = "tosa.const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+  // CHECK: %[[ZERO:.+]] = "tosa.const"() {value = dense<0> : tensor<i32>}
+  %mul = "tosa.mul"(%zero, %arg0) {shift = 0 : i32} : (tensor<i32>, tensor<i32>) -> tensor<i32>
+  // CHECK: return %[[ZERO]]
+  return %mul : tensor<i32>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_mul_one_rhs_f32
+func.func @fold_mul_one_rhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
+  %one = "tosa.const"() {value = dense<1.0> : tensor<f32>} : () -> tensor<f32>
+  %mul = "tosa.mul"(%arg0, %one) {shift = 0 : i32} : (tensor<f32>, tensor<f32>) -> tensor<f32>
+  // CHECK: return %arg0
+  return %mul : tensor<f32>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_mul_one_lhs_f32
+func.func @fold_mul_one_lhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
+  %one = "tosa.const"() {value = dense<1.0> : tensor<f32>} : () -> tensor<f32>
+  %mul = "tosa.mul"(%one, %arg0) {shift = 0 : i32} : (tensor<f32>, tensor<f32>) -> tensor<f32>
+  // CHECK: return %arg0
+  return %mul : tensor<f32>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_mul_one_rhs_i32
+func.func @fold_mul_one_rhs_i32(%arg0: tensor<i32>) -> tensor<i32> {
+  %one = "tosa.const"() {value = dense<64> : tensor<i32>} : () -> tensor<i32>
+  %mul = "tosa.mul"(%arg0, %one) {shift = 6 : i32} : (tensor<i32>, tensor<i32>) -> tensor<i32>
+  // CHECK: return %arg0
+  return %mul : tensor<i32>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_mul_one_lhs_i32
+func.func @fold_mul_one_lhs_i32(%arg0: tensor<i32>) -> tensor<i32> {
+  %one = "tosa.const"() {value = dense<64> : tensor<i32>} : () -> tensor<i32>
+  %mul = "tosa.mul"(%one, %arg0) {shift = 6 : i32} : (tensor<i32>, tensor<i32>) -> tensor<i32>
+  // CHECK: return %arg0
+  return %mul : tensor<i32>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_mul_splat_i8
+func.func @fold_mul_splat_i8() -> tensor<10xi8> {
+  %one = "tosa.const"() {value = dense<17> : tensor<10xi8>} : () -> tensor<10xi8>
+  %two = "tosa.const"() {value = dense<32> : tensor<10xi8>} : () -> tensor<10xi8>
+  %mul = "tosa.mul"(%one, %two) {shift = 3 : i32} : (tensor<10xi8>, tensor<10xi8>) -> tensor<10xi8>
+  // CHECK: %[[THREE:.+]] = "tosa.const"() {value = dense<68> : tensor<10xi8>}
+  // CHECK: return %[[THREE]]
+  return %mul : tensor<10xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_mul_splat_i32_round
+func.func @fold_mul_splat_i32_round() -> tensor<10xi32> {
+  %lhs = "tosa.const"() {value = dense<-3> : tensor<10xi32>} : () -> tensor<10xi32>
+  %rhs = "tosa.const"() {value = dense<5> : tensor<10xi32>} : () -> tensor<10xi32>
+  %mul = "tosa.mul"(%lhs, %rhs) {shift = 1 : i32} : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi32>
+  // CHECK: %[[RES:.+]] = "tosa.const"() {value = dense<-7> : tensor<10xi32>}
+  // CHECK: return %[[RES]]
+  return %mul : tensor<10xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_mul_splat_f32
+func.func @fold_mul_splat_f32() -> tensor<10xf32> {
+  %one = "tosa.const"() {value = dense<3.0> : tensor<10xf32>} : () -> tensor<10xf32>
+  %two = "tosa.const"() {value = dense<2.0> : tensor<10xf32>} : () -> tensor<10xf32>
+  %mul = "tosa.mul"(%one, %two) {shift = 0 : i32} : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
+  // CHECK: %[[THREE:.+]] = "tosa.const"() {value = dense<6.000000e+00> : tensor<10xf32>}
+  // CHECK: return %[[THREE]]
+  return %mul : tensor<10xf32>
+}
+
+// -----
+
 // CHECK-LABEL: @slice_splat
 func.func @slice_splat() -> tensor<1x1x1xi32> {
   // CHECK: %[[SLICE:.+]] = "tosa.const"() {value = dense<42> : tensor<1x1x1xi32>}