diff --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
--- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
+++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
@@ -1332,11 +1332,38 @@
     }
   }
 
+  // Move constant to the right side.
+  if (operands[0] && !operands[1]) {
+    // Do not use invertPredicate, as it will change eq to ne and vice versa.
+    using Pred = CmpIPredicate;
+    const std::pair<Pred, Pred> invPreds[] = {
+        {Pred::slt, Pred::sgt}, {Pred::sgt, Pred::slt}, {Pred::sle, Pred::sge},
+        {Pred::sge, Pred::sle}, {Pred::ult, Pred::ugt}, {Pred::ugt, Pred::ult},
+        {Pred::ule, Pred::uge}, {Pred::uge, Pred::ule}, {Pred::eq, Pred::eq},
+        {Pred::ne, Pred::ne},
+    };
+    Pred origPred = getPredicate();
+    for (auto pred : invPreds) {
+      if (origPred == pred.first) {
+        setPredicateAttr(CmpIPredicateAttr::get(getContext(), pred.second));
+        Value lhs = getLhs();
+        Value rhs = getRhs();
+        getLhsMutable().assign(rhs);
+        getRhsMutable().assign(lhs);
+        return getResult();
+      }
+    }
+    llvm_unreachable("unknown cmpi predicate kind");
+  }
+
   auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
-  auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
-  if (!lhs || !rhs)
+  if (!lhs)
     return {};
 
+  // A lone constant lhs was swapped to the right side (and returned) above,
+  // so reaching this point with a constant lhs means rhs is constant as well.
+  auto rhs = operands.back().cast<IntegerAttr>();
+
   auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
   return BoolAttr::get(getContext(), val);
 }
diff --git a/mlir/test/Dialect/Arithmetic/canonicalize.mlir b/mlir/test/Dialect/Arithmetic/canonicalize.mlir
--- a/mlir/test/Dialect/Arithmetic/canonicalize.mlir
+++ b/mlir/test/Dialect/Arithmetic/canonicalize.mlir
@@ -127,6 +127,41 @@
 
 // -----
 
+// Test case: Move constant to the right side.
+// CHECK-LABEL: @cmpi_const_right(
+// CHECK-SAME: %[[ARG:.*]]:
+// CHECK: %[[C:.*]] = arith.constant 1 : i64
+// CHECK: %[[R0:.*]] = arith.cmpi eq, %[[ARG]], %[[C]] : i64
+// CHECK: %[[R1:.*]] = arith.cmpi sge, %[[ARG]], %[[C]] : i64
+// CHECK: %[[R2:.*]] = arith.cmpi sle, %[[ARG]], %[[C]] : i64
+// CHECK: %[[R3:.*]] = arith.cmpi uge, %[[ARG]], %[[C]] : i64
+// CHECK: %[[R4:.*]] = arith.cmpi ule, %[[ARG]], %[[C]] : i64
+// CHECK: %[[R5:.*]] = arith.cmpi ne, %[[ARG]], %[[C]] : i64
+// CHECK: %[[R6:.*]] = arith.cmpi sgt, %[[ARG]], %[[C]] : i64
+// CHECK: %[[R7:.*]] = arith.cmpi slt, %[[ARG]], %[[C]] : i64
+// CHECK: %[[R8:.*]] = arith.cmpi ugt, %[[ARG]], %[[C]] : i64
+// CHECK: %[[R9:.*]] = arith.cmpi ult, %[[ARG]], %[[C]] : i64
+// CHECK: return %[[R0]], %[[R1]], %[[R2]], %[[R3]], %[[R4]],
+// CHECK-SAME: %[[R5]], %[[R6]], %[[R7]], %[[R8]], %[[R9]]
+func.func @cmpi_const_right(%arg0: i64)
+    -> (i1, i1, i1, i1, i1, i1, i1, i1, i1, i1) {
+  %c1 = arith.constant 1 : i64
+  %0 = arith.cmpi eq, %c1, %arg0 : i64
+  %1 = arith.cmpi sle, %c1, %arg0 : i64
+  %2 = arith.cmpi sge, %c1, %arg0 : i64
+  %3 = arith.cmpi ule, %c1, %arg0 : i64
+  %4 = arith.cmpi uge, %c1, %arg0 : i64
+  %5 = arith.cmpi ne, %c1, %arg0 : i64
+  %6 = arith.cmpi slt, %c1, %arg0 : i64
+  %7 = arith.cmpi sgt, %c1, %arg0 : i64
+  %8 = arith.cmpi ult, %c1, %arg0 : i64
+  %9 = arith.cmpi ugt, %c1, %arg0 : i64
+  return %0, %1, %2, %3, %4, %5, %6, %7, %8, %9
+      : i1, i1, i1, i1, i1, i1, i1, i1, i1, i1
+}
+
+// -----
+
 // CHECK-LABEL: @cmpOfExtSI
 // CHECK-NEXT: return %arg0
 func.func @cmpOfExtSI(%arg0: i1) -> i1 {
diff --git a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
--- a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
@@ -819,10 +819,10 @@
 // CHECK: %[[c0:.*]] = arith.constant 0 : index
 // CHECK: %[[c1:.*]] = arith.constant 1 : index
 // CHECK: %[[T0:.*]] = vector.create_mask %[[B]] : vector<3xi1>
-// CHECK: %[[T1:.*]] = arith.cmpi slt, %[[c0]], %[[A]] : index
+// CHECK: %[[T1:.*]] = arith.cmpi sgt, %[[A]], %[[c0]] : index
 // CHECK: %[[T2:.*]] = arith.select %[[T1]], %[[T0]], %[[C1]] : vector<3xi1>
 // CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[C2]] [0] : vector<3xi1> into vector<2x3xi1>
-// CHECK: %[[T4:.*]] = arith.cmpi slt, %[[c1]], %[[A]] : index
+// CHECK: %[[T4:.*]] = arith.cmpi sgt, %[[A]], %[[c1]] : index
 // CHECK: %[[T5:.*]] = arith.select %[[T4]], %[[T0]], %[[C1]] : vector<3xi1>
 // CHECK: %[[T6:.*]] = vector.insert %[[T5]], %[[T3]] [1] : vector<3xi1> into vector<2x3xi1>
 // CHECK: return %[[T6]] : vector<2x3xi1>
@@ -842,13 +842,13 @@
 // CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
 // CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
 // CHECK: %[[T0:.*]] = vector.create_mask %[[C]] : vector<7xi1>
-// CHECK: %[[T1:.*]] = arith.cmpi slt, %[[c0]], %[[B]] : index
+// CHECK: %[[T1:.*]] = arith.cmpi sgt, %[[B]], %[[c0]] : index
 // CHECK: %[[T2:.*]] = arith.select %[[T1]], %[[T0]], %[[C1]] : vector<7xi1>
 // CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[C2]] [0] : vector<7xi1> into vector<1x7xi1>
-// CHECK: %[[T4:.*]] = arith.cmpi slt, %[[c0]], %[[A]] : index
+// CHECK: %[[T4:.*]] = arith.cmpi sgt, %[[A]], %[[c0]] : index
 // CHECK: %[[T5:.*]] = arith.select %[[T4]], %[[T3]], %[[C2]] : vector<1x7xi1>
 // CHECK: %[[T6:.*]] = vector.insert %[[T5]], %[[C3]] [0] : vector<1x7xi1> into vector<2x1x7xi1>
-// CHECK: %[[T7:.*]] = arith.cmpi slt, %[[c1]], %[[A]] : index
+// CHECK: %[[T7:.*]] = arith.cmpi sgt, %[[A]], %[[c1]] : index
 // CHECK: %[[T8:.*]] = arith.select %[[T7]], %[[T3]], %[[C2]] : vector<1x7xi1>
 // CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T6]] [1] : vector<1x7xi1> into vector<2x1x7xi1>
 // CHECK: return %[[T9]] : vector<2x1x7xi1>
diff --git a/mlir/test/Transforms/sccp-structured.mlir b/mlir/test/Transforms/sccp-structured.mlir
--- a/mlir/test/Transforms/sccp-structured.mlir
+++ b/mlir/test/Transforms/sccp-structured.mlir
@@ -141,7 +141,7 @@
 
   %c2_i32 = arith.constant 2 : i32
   %0 = scf.while (%arg2 = %c2_i32) : (i32) -> (i32) {
-    %1 = arith.cmpi slt, %arg2, %arg1 : i32
+    %1 = arith.cmpi sgt, %arg1, %arg2 : i32
     scf.condition(%1) %arg2 : i32
   } do {
   ^bb0(%arg2: i32):