diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -916,17 +916,54 @@
   llvm_unreachable("unknown comparison predicate");
 }
 
+// Returns true if the predicate is true for two equal operands.
+static bool applyCmpPredicateToEqualOperands(CmpIPredicate predicate) {
+  switch (predicate) {
+  case CmpIPredicate::eq:
+  case CmpIPredicate::sle:
+  case CmpIPredicate::sge:
+  case CmpIPredicate::ule:
+  case CmpIPredicate::uge:
+    return true;
+  case CmpIPredicate::ne:
+  case CmpIPredicate::slt:
+  case CmpIPredicate::sgt:
+  case CmpIPredicate::ult:
+  case CmpIPredicate::ugt:
+    return false;
+  }
+  llvm_unreachable("unknown comparison predicate");
+}
+
+// Returns an i1 constant attribute holding `value` that is compatible with
+// `type`: a BoolAttr for scalar types, or a splat DenseElementsAttr for vector
+// and tensor types (cmpi operates elementwise on those).
+static Attribute getBoolAttribute(Type type, MLIRContext *ctx, bool value) {
+  auto boolAttr = BoolAttr::get(value, ctx);
+  auto shapedType = type.dyn_cast<ShapedType>();
+  if (!shapedType)
+    return boolAttr;
+  return DenseElementsAttr::get(shapedType, boolAttr);
+}
+
 // Constant folding hook for comparisons.
 OpFoldResult CmpIOp::fold(ArrayRef<Attribute> operands) {
   assert(operands.size() == 2 && "cmpi takes two arguments");
 
+  // cmpi(x, x) folds without knowing x. The folded attribute must match the
+  // result type, which is a vector/tensor of i1 when the operands are shaped.
+  if (lhs() == rhs()) {
+    auto val = applyCmpPredicateToEqualOperands(getPredicate());
+    return getBoolAttribute(getType(), getContext(), val);
+  }
+
   auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
   auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
   if (!lhs || !rhs)
     return {};
 
   auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
-  return IntegerAttr::get(IntegerType::get(1, getContext()), APInt(1, val));
+  return BoolAttr::get(val, getContext());
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir
--- a/mlir/test/Dialect/Standard/canonicalize.mlir
+++ b/mlir/test/Dialect/Standard/canonicalize.mlir
@@ -59,3 +59,38 @@
   %1 = dim %0, %c3 : tensor<2x?x4x?x5xindex>
   return %1 : index
 }
+
+// Test case: Folding of comparisons with equal operands.
+// CHECK-LABEL: @cmpi_equal_operands
+// CHECK-DAG: %[[T:.*]] = constant true
+// CHECK-DAG: %[[F:.*]] = constant false
+// CHECK: return %[[T]], %[[T]], %[[T]], %[[T]], %[[T]],
+// CHECK-SAME: %[[F]], %[[F]], %[[F]], %[[F]], %[[F]]
+func @cmpi_equal_operands(%arg0: i64)
+    -> (i1, i1, i1, i1, i1, i1, i1, i1, i1, i1) {
+  %0 = cmpi "eq", %arg0, %arg0 : i64
+  %1 = cmpi "sle", %arg0, %arg0 : i64
+  %2 = cmpi "sge", %arg0, %arg0 : i64
+  %3 = cmpi "ule", %arg0, %arg0 : i64
+  %4 = cmpi "uge", %arg0, %arg0 : i64
+  %5 = cmpi "ne", %arg0, %arg0 : i64
+  %6 = cmpi "slt", %arg0, %arg0 : i64
+  %7 = cmpi "sgt", %arg0, %arg0 : i64
+  %8 = cmpi "ult", %arg0, %arg0 : i64
+  %9 = cmpi "ugt", %arg0, %arg0 : i64
+  return %0, %1, %2, %3, %4, %5, %6, %7, %8, %9
+      : i1, i1, i1, i1, i1, i1, i1, i1, i1, i1
+}
+
+// Test case: Folding of vector comparisons with equal operands produces splat
+// constants of the vector result type.
+// CHECK-LABEL: @cmpi_equal_vector_operands
+// CHECK-DAG: %[[T:.*]] = constant dense<true> : vector<1x8xi1>
+// CHECK-DAG: %[[F:.*]] = constant dense<false> : vector<1x8xi1>
+// CHECK: return %[[T]], %[[F]]
+func @cmpi_equal_vector_operands(%arg0: vector<1x8xi64>)
+    -> (vector<1x8xi1>, vector<1x8xi1>) {
+  %0 = cmpi "eq", %arg0, %arg0 : vector<1x8xi64>
+  %1 = cmpi "ne", %arg0, %arg0 : vector<1x8xi64>
+  return %0, %1 : vector<1x8xi1>, vector<1x8xi1>
+}