diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -444,6 +444,7 @@
   );
 
   let hasCanonicalizer = 1;
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -26,6 +26,8 @@
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/TypeSwitch.h"
 
+#include <functional>
+
 using namespace mlir;
 using namespace mlir::tosa;
 
@@ -437,6 +439,104 @@
 // Operator Folders.
 //===----------------------------------------------------------------------===//
 
+/// Returns true if `attr` is a splat constant whose value is zero for element
+/// type `elementTy` (integer 0, or floating-point +0.0 / -0.0).
+static bool isSplatZero(Type elementTy, DenseElementsAttr attr) {
+  if (!attr || !attr.isSplat())
+    return false;
+  if (elementTy.isa<FloatType>())
+    return attr.getSplatValue<APFloat>().isZero();
+  if (elementTy.isa<IntegerType>())
+    return attr.getSplatValue<APInt>().isZero();
+  return false;
+}
+
+/// Folds an elementwise binary op whose operands are both constants.
+///
+/// `IntFolder` and `FloatFolder` are default-constructible callables applied
+/// to pairs of `APInt` and `APFloat` element values respectively. `ty` is the
+/// type of the produced attribute and must describe both operands. Splat
+/// operands always fold to a splat; otherwise the op is folded element by
+/// element, and only when `ty` has fewer than `kFoldLimit` elements so that
+/// large constants are not materialized. Returns a null attribute when
+/// nothing was folded.
+template <typename IntFolder, typename FloatFolder>
+static DenseElementsAttr binaryFolder(DenseElementsAttr lhs,
+                                      DenseElementsAttr rhs,
+                                      RankedTensorType ty) {
+  // Both operands must be constants, and the folded attribute needs a static
+  // shape: getNumElements() and DenseElementsAttr::get() assert otherwise.
+  if (!lhs || !rhs || !ty.hasStaticShape())
+    return {};
+
+  Type elementTy = ty.getElementType();
+
+  if (lhs.isSplat() && rhs.isSplat()) {
+    if (elementTy.isa<IntegerType>()) {
+      APInt result = IntFolder()(lhs.getSplatValue<APInt>(),
+                                 rhs.getSplatValue<APInt>());
+      return DenseElementsAttr::get(ty, result);
+    }
+    if (elementTy.isa<FloatType>()) {
+      APFloat result = FloatFolder()(lhs.getSplatValue<APFloat>(),
+                                     rhs.getSplatValue<APFloat>());
+      return DenseElementsAttr::get(ty, result);
+    }
+    return {};
+  }
+
+  static constexpr int64_t kFoldLimit = 16;
+  if (ty.getNumElements() >= kFoldLimit)
+    return {};
+
+  if (elementTy.isa<FloatType>()) {
+    llvm::SmallVector<APFloat> values;
+    for (auto it :
+         llvm::zip(lhs.getValues<APFloat>(), rhs.getValues<APFloat>()))
+      values.push_back(FloatFolder()(std::get<0>(it), std::get<1>(it)));
+    return DenseElementsAttr::get(ty, values);
+  }
+
+  if (elementTy.isa<IntegerType>()) {
+    llvm::SmallVector<APInt> values;
+    for (auto it : llvm::zip(lhs.getValues<APInt>(), rhs.getValues<APInt>()))
+      values.push_back(IntFolder()(std::get<0>(it), std::get<1>(it)));
+    return DenseElementsAttr::get(ty, values);
+  }
+
+  return {};
+}
+
+OpFoldResult AddOp::fold(ArrayRef<Attribute> operands) {
+  auto lhsTy = getInput1().getType().dyn_cast<RankedTensorType>();
+  auto rhsTy = getInput2().getType().dyn_cast<RankedTensorType>();
+  auto resultTy = getType().dyn_cast<RankedTensorType>();
+  if (!lhsTy || !rhsTy || !resultTy)
+    return {};
+  // Any folded value must have exactly the result type, so neither the
+  // identity folds nor the constant fold may cross a broadcast or a shape
+  // refinement between the operands and the result.
+  if (lhsTy != rhsTy || lhsTy != resultTy)
+    return {};
+
+  Type resultETy = resultTy.getElementType();
+  auto lhsAttr = operands[0].dyn_cast_or_null<DenseElementsAttr>();
+  auto rhsAttr = operands[1].dyn_cast_or_null<DenseElementsAttr>();
+
+  // 0 + x -> x and x + 0 -> x. Both +0.0 and -0.0 count as zero, so for
+  // x == -0.0 folding x + (+0.0) yields -0.0 where IEEE-754 gives +0.0.
+  if (isSplatZero(resultETy, lhsAttr))
+    return getInput2();
+  if (isSplatZero(resultETy, rhsAttr))
+    return getInput1();
+
+  if (!lhsAttr || !rhsAttr)
+    return {};
+
+  return binaryFolder<std::plus<APInt>, std::plus<APFloat>>(lhsAttr, rhsAttr,
+                                                            resultTy);
+}
+
 OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
   if (getInput().getType() == getType())
     return getInput();
diff --git a/mlir/test/Dialect/Tosa/constant-op-fold.mlir b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
--- a/mlir/test/Dialect/Tosa/constant-op-fold.mlir
+++ b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
@@ -97,3 +97,102 @@
   %0 = "tosa.transpose"(%input, %perms) : (tensor<1x1x1x16xi8>, tensor<4xi32>) -> tensor<1x1x16x1x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,2.100000e+00,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01}>>
   return %0: tensor<1x1x16x1x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,2.100000e+00,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01}>>
 }
+
+// -----
+
+// CHECK-LABEL: @fold_add_zero_rhs_f32
+func.func @fold_add_zero_rhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
+  %zero = "tosa.const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>
+  %add = "tosa.add"(%arg0, %zero) : (tensor<f32>, tensor<f32>) -> tensor<f32>
+  // CHECK: return %arg0
+  return %add : tensor<f32>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_add_zero_lhs_f32
+func.func @fold_add_zero_lhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
+  %zero = "tosa.const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>
+  %add = "tosa.add"(%zero, %arg0) : (tensor<f32>, tensor<f32>) -> tensor<f32>
+  // CHECK: return %arg0
+  return %add : tensor<f32>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_add_zero_rhs_i32
+func.func @fold_add_zero_rhs_i32(%arg0: tensor<i32>) -> tensor<i32> {
+  %zero = "tosa.const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+  %add = "tosa.add"(%arg0, %zero) : (tensor<i32>, tensor<i32>) -> tensor<i32>
+  // CHECK: return %arg0
+  return %add : tensor<i32>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_add_zero_lhs_i32
+func.func @fold_add_zero_lhs_i32(%arg0: tensor<i32>) -> tensor<i32> {
+  %zero = "tosa.const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+  %add = "tosa.add"(%zero, %arg0) : (tensor<i32>, tensor<i32>) -> tensor<i32>
+  // CHECK: return %arg0
+  return %add : tensor<i32>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_add_splat_i32
+func.func @fold_add_splat_i32() -> tensor<10xi32> {
+  %one = "tosa.const"() {value = dense<1> : tensor<10xi32>} : () -> tensor<10xi32>
+  %two = "tosa.const"() {value = dense<2> : tensor<10xi32>} : () -> tensor<10xi32>
+  %add = "tosa.add"(%one, %two) : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi32>
+  // CHECK: %[[THREE:.+]] = "tosa.const"() {value = dense<3> : tensor<10xi32>}
+  // CHECK: return %[[THREE]]
+  return %add : tensor<10xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_add_splat_f32
+func.func @fold_add_splat_f32() -> tensor<10xf32> {
+  %one = "tosa.const"() {value = dense<1.0> : tensor<10xf32>} : () -> tensor<10xf32>
+  %two = "tosa.const"() {value = dense<2.0> : tensor<10xf32>} : () -> tensor<10xf32>
+  %add = "tosa.add"(%one, %two) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
+  // CHECK: %[[THREE:.+]] = "tosa.const"() {value = dense<3.000000e+00> : tensor<10xf32>}
+  // CHECK: return %[[THREE]]
+  return %add : tensor<10xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_add_i32
+func.func @fold_add_i32() -> tensor<4xi32> {
+  %one = "tosa.const"() {value = dense<[1, 2, 3, 4]> : tensor<4xi32>} : () -> tensor<4xi32>
+  %two = "tosa.const"() {value = dense<[5, 6, 7, 8]> : tensor<4xi32>} : () -> tensor<4xi32>
+  %add = "tosa.add"(%one, %two) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
+  // CHECK: %[[FOLD:.+]] = "tosa.const"() {value = dense<[6, 8, 10, 12]> : tensor<4xi32>}
+  // CHECK: return %[[FOLD]]
+  return %add : tensor<4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_add_f32
+func.func @fold_add_f32() -> tensor<2xf32> {
+  %one = "tosa.const"() {value = dense<[1.0, 2.0]> : tensor<2xf32>} : () -> tensor<2xf32>
+  %two = "tosa.const"() {value = dense<[5.0, 6.0]> : tensor<2xf32>} : () -> tensor<2xf32>
+  %add = "tosa.add"(%one, %two) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
+  // CHECK: %[[FOLD:.+]] = "tosa.const"() {value = dense<[6.000000e+00, 8.000000e+00]> : tensor<2xf32>}
+  // CHECK: return %[[FOLD]]
+  return %add : tensor<2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @nofold_add_zero_rhs_refined_result
+func.func @nofold_add_zero_rhs_refined_result(%arg0: tensor<4xf32>) -> tensor<?xf32> {
+  %zero = "tosa.const"() {value = dense<0.0> : tensor<4xf32>} : () -> tensor<4xf32>
+  // CHECK: %[[ADD:.+]] = "tosa.add"
+  %add = "tosa.add"(%arg0, %zero) : (tensor<4xf32>, tensor<4xf32>) -> tensor<?xf32>
+  // CHECK: return %[[ADD]]
+  return %add : tensor<?xf32>
+}