diff --git a/flang/test/Fir/boxproc.fir b/flang/test/Fir/boxproc.fir --- a/flang/test/Fir/boxproc.fir +++ b/flang/test/Fir/boxproc.fir @@ -93,7 +93,7 @@ // CHECK: %[[VAL_4:.*]] = load { ptr, i64 }, ptr %[[VAL_3]], align 8 // CHECK: %[[VAL_5:.*]] = extractvalue { ptr, i64 } %[[VAL_4]], 0 // CHECK: %[[VAL_6:.*]] = extractvalue { ptr, i64 } %[[VAL_4]], 1 -// CHECK: %[[VAL_8:.*]] = icmp slt i64 10, %[[VAL_6]] +// CHECK: %[[VAL_8:.*]] = icmp sgt i64 %[[VAL_6]], 10 // CHECK: %[[VAL_9:.*]] = select i1 %[[VAL_8]], i64 10, i64 %[[VAL_6]] // CHECK: call void @llvm.memmove.p0.p0.i64(ptr %[[VAL_0]], ptr %[[VAL_5]], i64 %[[VAL_9]], i1 false) // CHECK: %[[VAL_10:.*]] = sub i64 10, %[[VAL_9]] @@ -129,7 +129,7 @@ // CHECK: %[[VAL_27:.*]] = load [1 x i8], ptr %[[VAL_26]], align 1 // CHECK: %[[VAL_29:.*]] = getelementptr [1 x i8], ptr %[[VAL_14]], i64 %[[VAL_18]] // CHECK: store [1 x i8] %[[VAL_27]], ptr %[[VAL_29]], align 1 -// CHECK: %[[VAL_30:.*]] = icmp slt i64 40, %[[VAL_13]] +// CHECK: %[[VAL_30:.*]] = icmp sgt i64 %[[VAL_13]], 40 // CHECK: %[[VAL_31:.*]] = select i1 %[[VAL_30]], i64 40, i64 %[[VAL_13]] // CHECK: call void @llvm.memmove.p0.p0.i64(ptr %[[VAL_0]], ptr %[[VAL_14]], i64 %[[VAL_31]], i1 false) // CHECK: %[[VAL_32:.*]] = sub i64 40, %[[VAL_31]] diff --git a/flang/test/Lower/array-character.f90 b/flang/test/Lower/array-character.f90 --- a/flang/test/Lower/array-character.f90 +++ b/flang/test/Lower/array-character.f90 @@ -22,7 +22,7 @@ ! CHECK: %[[VAL_16:.*]] = arith.addi %[[VAL_13]], %[[VAL_7]] : index ! CHECK: %[[VAL_17:.*]] = fir.array_coor %[[VAL_11]](%[[VAL_12]]) %[[VAL_16]] typeparams %[[VAL_10]]#1 : (!fir.ref>>, !fir.shape<1>, index, index) -> !fir.ref> ! CHECK: %[[VAL_18:.*]] = fir.array_coor %[[VAL_9]](%[[VAL_12]]) %[[VAL_16]] : (!fir.ref>>, !fir.shape<1>, index) -> !fir.ref> - ! CHECK: %[[VAL_19:.*]] = arith.cmpi slt, %[[VAL_5]], %[[VAL_10]]#1 : index + ! CHECK: %[[VAL_19:.*]] = arith.cmpi sgt, %[[VAL_10]]#1, %[[VAL_5]] : index ! CHECK: %[[VAL_20:.*]] = arith.select %[[VAL_19]], %[[VAL_5]], %[[VAL_10]]#1 : index ! CHECK: %[[VAL_21:.*]] = fir.convert %[[VAL_20]] : (index) -> i64 ! CHECK: %[[VAL_22:.*]] = fir.convert %[[VAL_18]] : (!fir.ref>) -> !fir.ref diff --git a/flang/test/Lower/host-associated.f90 b/flang/test/Lower/host-associated.f90 --- a/flang/test/Lower/host-associated.f90 +++ b/flang/test/Lower/host-associated.f90 @@ -540,7 +540,7 @@ ! CHECK: %[[VAL_10:.*]] = fir.load %[[VAL_9]] : !fir.ref> ! CHECK: %[[VAL_11:.*]]:2 = fir.unboxchar %[[VAL_10]] : (!fir.boxchar<1>) -> (!fir.ref>, index) ! CHECK: %[[VAL_12:.*]] = fir.convert %[[VAL_0]] : (!fir.ref>) -> !fir.ref> -! CHECK: %[[VAL_13:.*]] = arith.cmpi slt, %[[VAL_4]], %[[VAL_11]]#1 : index +! CHECK: %[[VAL_13:.*]] = arith.cmpi sgt, %[[VAL_11]]#1, %[[VAL_4]] : index ! CHECK: %[[VAL_14:.*]] = arith.select %[[VAL_13]], %[[VAL_4]], %[[VAL_11]]#1 : index ! CHECK: %[[VAL_15:.*]] = fir.convert %[[VAL_14]] : (index) -> i64 ! CHECK: %[[VAL_16:.*]] = fir.convert %[[VAL_12]] : (!fir.ref>) -> !fir.ref @@ -607,7 +607,7 @@ ! CHECK: %[[VAL_34:.*]] = arith.subi %[[VAL_25]], %[[VAL_6]] : index ! CHECK: br ^bb1(%[[VAL_33]], %[[VAL_34]] : index, index) ! CHECK: ^bb3: -! CHECK: %[[VAL_35:.*]] = arith.cmpi slt, %[[VAL_3]], %[[VAL_19]] : index +! CHECK: %[[VAL_35:.*]] = arith.cmpi sgt, %[[VAL_19]], %[[VAL_3]] : index ! CHECK: %[[VAL_36:.*]] = arith.select %[[VAL_35]], %[[VAL_3]], %[[VAL_19]] : index ! CHECK: %[[VAL_37:.*]] = fir.convert %[[VAL_36]] : (index) -> i64 ! CHECK: %[[VAL_38:.*]] = fir.convert %[[VAL_9]] : (!fir.ref>) -> !fir.ref diff --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp --- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp +++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp @@ -1332,11 +1332,38 @@ } } + // Move constant to the right side. + if (operands[0] && !operands[1]) { + // Do not use invertPredicate, as it will change eq to ne and vice versa. + using Pred = CmpIPredicate; + const std::pair invPreds[] = { + {Pred::slt, Pred::sgt}, {Pred::sgt, Pred::slt}, {Pred::sle, Pred::sge}, + {Pred::sge, Pred::sle}, {Pred::ult, Pred::ugt}, {Pred::ugt, Pred::ult}, + {Pred::ule, Pred::uge}, {Pred::uge, Pred::ule}, {Pred::eq, Pred::eq}, + {Pred::ne, Pred::ne}, + }; + Pred origPred = getPredicate(); + for (auto pred : invPreds) { + if (origPred == pred.first) { + setPredicateAttr(CmpIPredicateAttr::get(getContext(), pred.second)); + Value lhs = getLhs(); + Value rhs = getRhs(); + getLhsMutable().assign(rhs); + getRhsMutable().assign(lhs); + return getResult(); + } + } + llvm_unreachable("unknown cmpi predicate kind"); + } + auto lhs = operands.front().dyn_cast_or_null(); - auto rhs = operands.back().dyn_cast_or_null(); - if (!lhs || !rhs) + if (!lhs) return {}; + // We are moving constants to the right side; So if lhs is constant rhs is + // guaranteed to be a constant. + auto rhs = operands.back().cast(); + auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue()); return BoolAttr::get(getContext(), val); } diff --git a/mlir/test/Dialect/Arithmetic/canonicalize.mlir b/mlir/test/Dialect/Arithmetic/canonicalize.mlir --- a/mlir/test/Dialect/Arithmetic/canonicalize.mlir +++ b/mlir/test/Dialect/Arithmetic/canonicalize.mlir @@ -127,6 +127,41 @@ // ----- +// Test case: Move constant to the right side. +// CHECK-LABEL: @cmpi_const_right( +// CHECK-SAME: %[[ARG:.*]]: +// CHECK: %[[C:.*]] = arith.constant 1 : i64 +// CHECK: %[[R0:.*]] = arith.cmpi eq, %[[ARG]], %[[C]] : i64 +// CHECK: %[[R1:.*]] = arith.cmpi sge, %[[ARG]], %[[C]] : i64 +// CHECK: %[[R2:.*]] = arith.cmpi sle, %[[ARG]], %[[C]] : i64 +// CHECK: %[[R3:.*]] = arith.cmpi uge, %[[ARG]], %[[C]] : i64 +// CHECK: %[[R4:.*]] = arith.cmpi ule, %[[ARG]], %[[C]] : i64 +// CHECK: %[[R5:.*]] = arith.cmpi ne, %[[ARG]], %[[C]] : i64 +// CHECK: %[[R6:.*]] = arith.cmpi sgt, %[[ARG]], %[[C]] : i64 +// CHECK: %[[R7:.*]] = arith.cmpi slt, %[[ARG]], %[[C]] : i64 +// CHECK: %[[R8:.*]] = arith.cmpi ugt, %[[ARG]], %[[C]] : i64 +// CHECK: %[[R9:.*]] = arith.cmpi ult, %[[ARG]], %[[C]] : i64 +// CHECK: return %[[R0]], %[[R1]], %[[R2]], %[[R3]], %[[R4]], +// CHECK-SAME: %[[R5]], %[[R6]], %[[R7]], %[[R8]], %[[R9]] +func.func @cmpi_const_right(%arg0: i64) + -> (i1, i1, i1, i1, i1, i1, i1, i1, i1, i1) { + %c1 = arith.constant 1 : i64 + %0 = arith.cmpi eq, %c1, %arg0 : i64 + %1 = arith.cmpi sle, %c1, %arg0 : i64 + %2 = arith.cmpi sge, %c1, %arg0 : i64 + %3 = arith.cmpi ule, %c1, %arg0 : i64 + %4 = arith.cmpi uge, %c1, %arg0 : i64 + %5 = arith.cmpi ne, %c1, %arg0 : i64 + %6 = arith.cmpi slt, %c1, %arg0 : i64 + %7 = arith.cmpi sgt, %c1, %arg0 : i64 + %8 = arith.cmpi ult, %c1, %arg0 : i64 + %9 = arith.cmpi ugt, %c1, %arg0 : i64 + return %0, %1, %2, %3, %4, %5, %6, %7, %8, %9 + : i1, i1, i1, i1, i1, i1, i1, i1, i1, i1 +} + +// ----- + // CHECK-LABEL: @cmpOfExtSI // CHECK-NEXT: return %arg0 func.func @cmpOfExtSI(%arg0: i1) -> i1 { diff --git a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir --- a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir @@ -819,10 +819,10 @@ // CHECK: %[[c0:.*]] = arith.constant 0 : index // CHECK: %[[c1:.*]] = arith.constant 1 : index // CHECK: %[[T0:.*]] = vector.create_mask %[[B]] : vector<3xi1> -// CHECK: %[[T1:.*]] = arith.cmpi slt, %[[c0]], %[[A]] : index +// CHECK: %[[T1:.*]] = arith.cmpi sgt, %[[A]], %[[c0]] : index // CHECK: %[[T2:.*]] = arith.select %[[T1]], %[[T0]], %[[C1]] : vector<3xi1> // CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[C2]] [0] : vector<3xi1> into vector<2x3xi1> -// CHECK: %[[T4:.*]] = arith.cmpi slt, %[[c1]], %[[A]] : index +// CHECK: %[[T4:.*]] = arith.cmpi sgt, %[[A]], %[[c1]] : index // CHECK: %[[T5:.*]] = arith.select %[[T4]], %[[T0]], %[[C1]] : vector<3xi1> // CHECK: %[[T6:.*]] = vector.insert %[[T5]], %[[T3]] [1] : vector<3xi1> into vector<2x3xi1> // CHECK: return %[[T6]] : vector<2x3xi1> @@ -842,13 +842,13 @@ // CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index // CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index // CHECK: %[[T0:.*]] = vector.create_mask %[[C]] : vector<7xi1> -// CHECK: %[[T1:.*]] = arith.cmpi slt, %[[c0]], %[[B]] : index +// CHECK: %[[T1:.*]] = arith.cmpi sgt, %[[B]], %[[c0]] : index // CHECK: %[[T2:.*]] = arith.select %[[T1]], %[[T0]], %[[C1]] : vector<7xi1> // CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[C2]] [0] : vector<7xi1> into vector<1x7xi1> -// CHECK: %[[T4:.*]] = arith.cmpi slt, %[[c0]], %[[A]] : index +// CHECK: %[[T4:.*]] = arith.cmpi sgt, %[[A]], %[[c0]] : index // CHECK: %[[T5:.*]] = arith.select %[[T4]], %[[T3]], %[[C2]] : vector<1x7xi1> // CHECK: %[[T6:.*]] = vector.insert %[[T5]], %[[C3]] [0] : vector<1x7xi1> into vector<2x1x7xi1> -// CHECK: %[[T7:.*]] = arith.cmpi slt, %[[c1]], %[[A]] : index +// CHECK: %[[T7:.*]] = arith.cmpi sgt, %[[A]], %[[c1]] : index // CHECK: %[[T8:.*]] = arith.select %[[T7]], %[[T3]], %[[C2]] : vector<1x7xi1> // CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T6]] [1] : vector<1x7xi1> into vector<2x1x7xi1> // CHECK: return %[[T9]] : vector<2x1x7xi1> diff --git a/mlir/test/Transforms/sccp-structured.mlir b/mlir/test/Transforms/sccp-structured.mlir --- a/mlir/test/Transforms/sccp-structured.mlir +++ b/mlir/test/Transforms/sccp-structured.mlir @@ -141,7 +141,7 @@ %c2_i32 = arith.constant 2 : i32 %0 = scf.while (%arg2 = %c2_i32) : (i32) -> (i32) { - %1 = arith.cmpi slt, %arg2, %arg1 : i32 + %1 = arith.cmpi sgt, %arg1, %arg2 : i32 scf.condition(%1) %arg2 : i32 } do { ^bb0(%arg2: i32):