diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1143,6 +1143,8 @@
     /// InferTypeOpInterface.
     static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
   }];
+
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -1191,6 +1193,8 @@
   let results = (outs
     I1Tensor:$output
   );
+
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -675,6 +675,11 @@
   APIntFoldGreater() {}
   APInt operator()(APInt l, APInt r) { return APInt(1, l.sgt(r)); }
 };
+
+struct APIntFoldGreaterEqual {
+  APIntFoldGreaterEqual() {}
+  APInt operator()(APInt l, APInt r) { return APInt(1, l.sge(r)); }
+};
 } // namespace
 
 OpFoldResult GreaterOp::fold(ArrayRef<Attribute> operands) {
@@ -689,6 +694,46 @@
       lhsAttr, rhsAttr, resultTy);
 }
 
+OpFoldResult GreaterEqualOp::fold(ArrayRef<Attribute> operands) {
+  auto resultTy = getType().dyn_cast<RankedTensorType>();
+  auto lhsAttr = operands[0].dyn_cast_or_null<DenseElementsAttr>();
+  auto rhsAttr = operands[1].dyn_cast_or_null<DenseElementsAttr>();
+
+  // The folded constant is built with resultTy, so require a ranked,
+  // statically shaped result and constant operands.
+  if (!resultTy || !resultTy.hasStaticShape() || !lhsAttr || !rhsAttr)
+    return {};
+
+  return BinaryFolder<APIntFoldGreaterEqual,
+                      ComparisonFold<std::greater_equal<APFloat>>>(
+      lhsAttr, rhsAttr, resultTy);
+}
+
+OpFoldResult EqualOp::fold(ArrayRef<Attribute> operands) {
+  auto resultTy = getType().dyn_cast<RankedTensorType>();
+  auto lhsAttr = operands[0].dyn_cast_or_null<DenseElementsAttr>();
+  auto rhsAttr = operands[1].dyn_cast_or_null<DenseElementsAttr>();
+  Value lhs = getInput1();
+  Value rhs = getInput2();
+  auto lhsTy = lhs.getType().cast<ShapedType>();
+
+  // Comparing an integer value to itself is always true. This does not hold
+  // for floats because NaN compares unequal to itself. resultTy is null when
+  // the result is unranked, so check it before querying its shape.
+  if (lhsTy.getElementType().isa<IntegerType>() && resultTy &&
+      resultTy.hasStaticShape() && lhs == rhs) {
+    return DenseElementsAttr::get(resultTy, true);
+  }
+
+  // Same requirements as the self-comparison fold, plus constant operands.
+  if (!resultTy || !resultTy.hasStaticShape() || !lhsAttr || !rhsAttr)
+    return {};
+
+  return BinaryFolder<ComparisonFold<std::equal_to<APInt>>,
+                      ComparisonFold<std::equal_to<APFloat>>>(lhsAttr, rhsAttr,
+                                                              resultTy);
+}
+
 OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
   if (getInput().getType() == getType())
     return getInput();
diff --git a/mlir/test/Dialect/Tosa/constant-op-fold.mlir b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
--- a/mlir/test/Dialect/Tosa/constant-op-fold.mlir
+++ b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
@@ -350,50 +350,108 @@
 
 // -----
 
-// CHECK-LABEL: @fold_greater_splat_f32_true
-func.func @fold_greater_splat_f32_true() -> tensor<10xi1> {
-  %one = "tosa.const"() {value = dense<4.0> : tensor<10xf32>} : () -> tensor<10xf32>
-  %two = "tosa.const"() {value = dense<2.0> : tensor<10xf32>} : () -> tensor<10xf32>
-  %add = "tosa.greater"(%one, %two) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi1>
-  // CHECK: %[[BOOL:.+]] = "tosa.const"() {value = dense<true> : tensor<10xi1>}
-  // CHECK: return %[[BOOL]]
-  return %add : tensor<10xi1>
-}
-
-// -----
-
-// CHECK-LABEL: @fold_greater_splat_f32_false
-func.func @fold_greater_splat_f32_false() -> tensor<10xi1> {
-  %one = "tosa.const"() {value = dense<1.0> : tensor<10xf32>} : () -> tensor<10xf32>
-  %two = "tosa.const"() {value = dense<2.0> : tensor<10xf32>} : () -> tensor<10xf32>
-  %add = "tosa.greater"(%one, %two) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi1>
-  // CHECK: %[[BOOL:.+]] = "tosa.const"() {value = dense<false> : tensor<10xi1>}
-  // CHECK: return %[[BOOL]]
-  return %add : tensor<10xi1>
-}
-
-// -----
-
-// CHECK-LABEL: @fold_greater_splat_i32_false
-func.func @fold_greater_splat_i32_false() -> tensor<10xi1> {
-  %one = "tosa.const"() {value = dense<-10> : tensor<10xi32>} : () -> tensor<10xi32>
-  %two = "tosa.const"() {value = dense<8> : tensor<10xi32>} : () -> tensor<10xi32>
-  %add = "tosa.greater"(%one, %two) : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi1>
-  // CHECK: %[[BOOL:.+]] = "tosa.const"() {value = dense<false> : tensor<10xi1>}
-  // CHECK: return %[[BOOL]]
-  return %add : tensor<10xi1>
-}
-
-// -----
-
-// CHECK-LABEL: @fold_greater_splat_i32_true
-func.func @fold_greater_splat_i32_true() -> tensor<10xi1> {
-  %one = "tosa.const"() {value = dense<-10> : tensor<10xi32>} : () -> tensor<10xi32>
-  %two = "tosa.const"() {value = dense<-12> : tensor<10xi32>} : () -> tensor<10xi32>
-  %add = "tosa.greater"(%one, %two) : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi1>
-  // CHECK: %[[BOOL:.+]] = "tosa.const"() {value = dense<true> : tensor<10xi1>}
-  // CHECK: return %[[BOOL]]
-  return %add : tensor<10xi1>
+// CHECK-LABEL: @fold_greater_splat_f32
+func.func @fold_greater_splat_f32() -> (tensor<10xi1>, tensor<10xi1>) {
+  %0 = "tosa.const"() {value = dense<4.0> : tensor<10xf32>} : () -> tensor<10xf32>
+  %1 = "tosa.const"() {value = dense<2.0> : tensor<10xf32>} : () -> tensor<10xf32>
+  %2 = "tosa.const"() {value = dense<1.0> : tensor<10xf32>} : () -> tensor<10xf32>
+  %3 = "tosa.const"() {value = dense<2.0> : tensor<10xf32>} : () -> tensor<10xf32>
+  %true = "tosa.greater"(%0, %1) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi1>
+  %false = "tosa.greater"(%2, %3) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi1>
+  // CHECK-DAG: %[[TRUE:.+]] = "tosa.const"() {value = dense<true> : tensor<10xi1>}
+  // CHECK-DAG: %[[FALSE:.+]] = "tosa.const"() {value = dense<false> : tensor<10xi1>}
+  // CHECK: return %[[TRUE]], %[[FALSE]]
+  return %true, %false : tensor<10xi1>, tensor<10xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_greater_splat_i32
+func.func @fold_greater_splat_i32() -> (tensor<10xi1>, tensor<10xi1>) {
+  %0 = "tosa.const"() {value = dense<-10> : tensor<10xi32>} : () -> tensor<10xi32>
+  %1 = "tosa.const"() {value = dense<8> : tensor<10xi32>} : () -> tensor<10xi32>
+  %2 = "tosa.const"() {value = dense<-10> : tensor<10xi32>} : () -> tensor<10xi32>
+  %3 = "tosa.const"() {value = dense<-12> : tensor<10xi32>} : () -> tensor<10xi32>
+  %false = "tosa.greater"(%0, %1) : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi1>
+  %true = "tosa.greater"(%2, %3) : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi1>
+  // CHECK-DAG: %[[FALSE:.+]] = "tosa.const"() {value = dense<false> : tensor<10xi1>}
+  // CHECK-DAG: %[[TRUE:.+]] = "tosa.const"() {value = dense<true> : tensor<10xi1>}
+  // CHECK: return %[[FALSE]], %[[TRUE]]
+  return %false, %true : tensor<10xi1>, tensor<10xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_greater_eq_splat_f32
+func.func @fold_greater_eq_splat_f32() -> (tensor<10xi1>, tensor<10xi1>) {
+  %0 = "tosa.const"() {value = dense<4.0> : tensor<10xf32>} : () -> tensor<10xf32>
+  %1 = "tosa.const"() {value = dense<4.0> : tensor<10xf32>} : () -> tensor<10xf32>
+  %2 = "tosa.const"() {value = dense<1.0> : tensor<10xf32>} : () -> tensor<10xf32>
+  %3 = "tosa.const"() {value = dense<2.0> : tensor<10xf32>} : () -> tensor<10xf32>
+  %true = "tosa.greater_equal"(%0, %1) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi1>
+  %false = "tosa.greater_equal"(%2, %3) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi1>
+  // CHECK-DAG: %[[TRUE:.+]] = "tosa.const"() {value = dense<true> : tensor<10xi1>}
+  // CHECK-DAG: %[[FALSE:.+]] = "tosa.const"() {value = dense<false> : tensor<10xi1>}
+  // CHECK: return %[[TRUE]], %[[FALSE]]
+  return %true, %false : tensor<10xi1>, tensor<10xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_greater_eq_splat_i32
+func.func @fold_greater_eq_splat_i32() -> (tensor<10xi1>, tensor<10xi1>) {
+  %0 = "tosa.const"() {value = dense<-10> : tensor<10xi32>} : () -> tensor<10xi32>
+  %1 = "tosa.const"() {value = dense<8> : tensor<10xi32>} : () -> tensor<10xi32>
+  %2 = "tosa.const"() {value = dense<-10> : tensor<10xi32>} : () -> tensor<10xi32>
+  %3 = "tosa.const"() {value = dense<-10> : tensor<10xi32>} : () -> tensor<10xi32>
+  %true = "tosa.greater_equal"(%2, %3) : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi1>
+  %false = "tosa.greater_equal"(%0, %1) : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi1>
+  // CHECK-DAG: %[[TRUE:.+]] = "tosa.const"() {value = dense<true> : tensor<10xi1>}
+  // CHECK-DAG: %[[FALSE:.+]] = "tosa.const"() {value = dense<false> : tensor<10xi1>}
+  // CHECK: return %[[TRUE]], %[[FALSE]]
+  return %true, %false : tensor<10xi1>, tensor<10xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_eq_splat_f32
+func.func @fold_eq_splat_f32() -> (tensor<10xi1>, tensor<10xi1>) {
+  %0 = "tosa.const"() {value = dense<4.0> : tensor<10xf32>} : () -> tensor<10xf32>
+  %1 = "tosa.const"() {value = dense<4.0> : tensor<10xf32>} : () -> tensor<10xf32>
+  %2 = "tosa.const"() {value = dense<1.0> : tensor<10xf32>} : () -> tensor<10xf32>
+  %3 = "tosa.const"() {value = dense<2.0> : tensor<10xf32>} : () -> tensor<10xf32>
+  %true = "tosa.equal"(%0, %1) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi1>
+  %false = "tosa.equal"(%2, %3) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi1>
+  // CHECK-DAG: %[[TRUE:.+]] = "tosa.const"() {value = dense<true> : tensor<10xi1>}
+  // CHECK-DAG: %[[FALSE:.+]] = "tosa.const"() {value = dense<false> : tensor<10xi1>}
+  // CHECK: return %[[TRUE]], %[[FALSE]]
+  return %true, %false : tensor<10xi1>, tensor<10xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_eq_splat_i32
+func.func @fold_eq_splat_i32() -> (tensor<10xi1>, tensor<10xi1>) {
+  %0 = "tosa.const"() {value = dense<-10> : tensor<10xi32>} : () -> tensor<10xi32>
+  %1 = "tosa.const"() {value = dense<8> : tensor<10xi32>} : () -> tensor<10xi32>
+  %2 = "tosa.const"() {value = dense<-10> : tensor<10xi32>} : () -> tensor<10xi32>
+  %3 = "tosa.const"() {value = dense<-10> : tensor<10xi32>} : () -> tensor<10xi32>
+  %true = "tosa.equal"(%2, %3) : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi1>
+  %false = "tosa.equal"(%0, %1) : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi1>
+  // CHECK-DAG: %[[TRUE:.+]] = "tosa.const"() {value = dense<true> : tensor<10xi1>}
+  // CHECK-DAG: %[[FALSE:.+]] = "tosa.const"() {value = dense<false> : tensor<10xi1>}
+  // CHECK: return %[[TRUE]], %[[FALSE]]
+  return %true, %false : tensor<10xi1>, tensor<10xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_eq_i32
+func.func @fold_eq_i32(%arg0 : tensor<10xi32>) -> (tensor<10xi1>) {
+  // CHECK: %[[TRUE:.+]] = "tosa.const"() {value = dense<true> : tensor<10xi1>}
+  %0 = "tosa.equal"(%arg0, %arg0) : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi1>
+  // CHECK: return %[[TRUE]]
+  return %0 : tensor<10xi1>
 }
 
 // -----