diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1161,6 +1161,8 @@
   let results = (outs
     I1Tensor:$output
   );
+
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -441,20 +441,27 @@
 
 template <typename IntFolder, typename FloatFolder>
 DenseElementsAttr BinaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
-                               RankedTensorType ty) {
+                               RankedTensorType returnTy) {
   if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) {
-    if (ty.getElementType().isa<IntegerType>()) {
+    // Dispatch on the operands' element type: the result element type may
+    // differ (e.g. comparisons yield i1), so it cannot select the folder.
+    auto lETy = lhs.getType().cast<ShapedType>().getElementType();
+    auto rETy = rhs.getType().cast<ShapedType>().getElementType();
+    if (lETy != rETy)
+      return {};
+
+    if (lETy.isa<IntegerType>()) {
       APInt l = lhs.getSplatValue<APInt>();
       APInt r = rhs.getSplatValue<APInt>();
-      APInt result = IntFolder()(l, r);
-      return DenseElementsAttr::get(ty, result);
+      auto result = IntFolder()(l, r);
+      return DenseElementsAttr::get(returnTy, result);
     }
 
-    if (ty.getElementType().isa<FloatType>()) {
+    if (lETy.isa<FloatType>()) {
       APFloat l = lhs.getSplatValue<APFloat>();
       APFloat r = rhs.getSplatValue<APFloat>();
-      APFloat result = FloatFolder()(l, r);
-      return DenseElementsAttr::get(ty, result);
+      auto result = FloatFolder()(l, r);
+      return DenseElementsAttr::get(returnTy, result);
     }
   }
 
@@ -501,6 +508,45 @@
                                                              lhsTy);
 }
 
+namespace {
+/// Adapts a standard comparison functor `Cmp` into a folder returning the
+/// comparison result as a 1-bit APInt.
+template <typename Cmp>
+struct ComparisonFold {
+  ComparisonFold() {}
+  APInt operator()(const APInt &l, const APInt &r) {
+    return APInt(1, Cmp()(l, r));
+  }
+
+  APInt operator()(const APFloat &l, const APFloat &r) {
+    return APInt(1, Cmp()(l, r));
+  }
+};
+
+/// Folds a signed integer greater-than comparison into a 1-bit APInt.
+struct APIntFoldGreater {
+  APIntFoldGreater() {}
+  APInt operator()(const APInt &l, const APInt &r) {
+    return APInt(1, l.sgt(r));
+  }
+};
+} // namespace
+
+/// Folds `tosa.greater` when both operands are constant splats.
+OpFoldResult GreaterOp::fold(ArrayRef<Attribute> operands) {
+  auto resultTy = getType().dyn_cast<RankedTensorType>();
+  auto lhsAttr = operands[0].dyn_cast_or_null<DenseElementsAttr>();
+  auto rhsAttr = operands[1].dyn_cast_or_null<DenseElementsAttr>();
+
+  // A constant can only be materialized for a ranked, statically shaped
+  // result type; bail out instead of building an attribute with a null type.
+  if (!resultTy || !resultTy.hasStaticShape() || !lhsAttr || !rhsAttr)
+    return {};
+
+  return BinaryFolder<APIntFoldGreater, ComparisonFold<std::greater<APFloat>>>(
+      lhsAttr, rhsAttr, resultTy);
+}
+
 OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
   if (getInput().getType() == getType())
     return getInput();
diff --git a/mlir/test/Dialect/Tosa/constant-op-fold.mlir b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
--- a/mlir/test/Dialect/Tosa/constant-op-fold.mlir
+++ b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
@@ -164,6 +164,54 @@
 
 // -----
 
+// CHECK-LABEL: @fold_greater_splat_f32_true
+func.func @fold_greater_splat_f32_true() -> tensor<10xi1> {
+  %one = "tosa.const"() {value = dense<4.0> : tensor<10xf32>} : () -> tensor<10xf32>
+  %two = "tosa.const"() {value = dense<2.0> : tensor<10xf32>} : () -> tensor<10xf32>
+  %add = "tosa.greater"(%one, %two) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi1>
+  // CHECK: %[[BOOL:.+]] = "tosa.const"() {value = dense<true> : tensor<10xi1>}
+  // CHECK: return %[[BOOL]]
+  return %add : tensor<10xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_greater_splat_f32_false
+func.func @fold_greater_splat_f32_false() -> tensor<10xi1> {
+  %one = "tosa.const"() {value = dense<1.0> : tensor<10xf32>} : () -> tensor<10xf32>
+  %two = "tosa.const"() {value = dense<2.0> : tensor<10xf32>} : () -> tensor<10xf32>
+  %add = "tosa.greater"(%one, %two) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi1>
+  // CHECK: %[[BOOL:.+]] = "tosa.const"() {value = dense<false> : tensor<10xi1>}
+  // CHECK: return %[[BOOL]]
+  return %add : tensor<10xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_greater_splat_i32_false
+func.func @fold_greater_splat_i32_false() -> tensor<10xi1> {
+  %one = "tosa.const"() {value = dense<-10> : tensor<10xi32>} : () -> tensor<10xi32>
+  %two = "tosa.const"() {value = dense<8> : tensor<10xi32>} : () -> tensor<10xi32>
+  %add = "tosa.greater"(%one, %two) : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi1>
+  // CHECK: %[[BOOL:.+]] = "tosa.const"() {value = dense<false> : tensor<10xi1>}
+  // CHECK: return %[[BOOL]]
+  return %add : tensor<10xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_greater_splat_i32_true
+func.func @fold_greater_splat_i32_true() -> tensor<10xi1> {
+  %one = "tosa.const"() {value = dense<-10> : tensor<10xi32>} : () -> tensor<10xi32>
+  %two = "tosa.const"() {value = dense<-12> : tensor<10xi32>} : () -> tensor<10xi32>
+  %add = "tosa.greater"(%one, %two) : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi1>
+  // CHECK: %[[BOOL:.+]] = "tosa.const"() {value = dense<true> : tensor<10xi1>}
+  // CHECK: return %[[BOOL]]
+  return %add : tensor<10xi1>
+}
+
+// -----
+
 // CHECK-LABEL: @slice_splat
 func.func @slice_splat() -> tensor<1x1x1xi32> {
   // CHECK: %[[SLICE:.+]] = "tosa.const"() {value = dense<42> : tensor<1x1x1xi32>}