diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -1360,15 +1360,43 @@
 //===----------------------------------------------------------------------===//
 
 OpFoldResult SelectOp::fold(ArrayRef<Attribute> operands) {
+  auto trueVal = getTrueValue();
+  auto falseVal = getFalseValue();
+  if (trueVal == falseVal)
+    return trueVal;
+
   auto condition = getCondition();
 
   // select true, %0, %1 => %0
   if (matchPattern(condition, m_One()))
-    return getTrueValue();
+    return trueVal;
 
   // select false, %0, %1 => %1
   if (matchPattern(condition, m_Zero()))
-    return getFalseValue();
+    return falseVal;
+
+  if (auto cmp = dyn_cast_or_null<CmpIOp>(condition.getDefiningOp())) {
+    auto pred = cmp.predicate();
+    if (pred == mlir::CmpIPredicate::eq || pred == mlir::CmpIPredicate::ne) {
+      auto cmpLhs = cmp.lhs();
+      auto cmpRhs = cmp.rhs();
+
+      // If the compared values are equal, both select arms hold the same
+      // value, so the result is whichever arm is chosen when they differ:
+
+      // %0 = cmpi eq, %arg0, %arg1
+      // %1 = select %0, %arg0, %arg1 => %arg1
+
+      // %0 = cmpi ne, %arg0, %arg1
+      // %1 = select %0, %arg0, %arg1 => %arg0
+
+      // eq/ne are symmetric, so the cmpi operands may appear in either order
+      // relative to the select operands.
+      if ((cmpLhs == trueVal && cmpRhs == falseVal) ||
+          (cmpRhs == trueVal && cmpLhs == falseVal))
+        return pred == mlir::CmpIPredicate::ne ? trueVal : falseVal;
+    }
+  }
   return nullptr;
 }
 
diff --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir
--- a/mlir/test/Dialect/Standard/canonicalize.mlir
+++ b/mlir/test/Dialect/Standard/canonicalize.mlir
@@ -339,3 +339,52 @@
 // CHECK: %[[GENERATE:.+]] = tensor.generate
 // CHECK: %[[RESULT:.+]] = subtensor_insert %[[ARG0]] into %[[GENERATE]]
 // CHECK: return %[[RESULT]]
+
+// -----
+
+// CHECK-LABEL: @select_same_val
+// CHECK: return %arg1
+func @select_same_val(%arg0: i1, %arg1: i64) -> i64 {
+  %0 = select %arg0, %arg1, %arg1 : i64
+  return %0 : i64
+}
+
+// -----
+
+// CHECK-LABEL: @select_cmp_eq_select
+// CHECK: return %arg1
+func @select_cmp_eq_select(%arg0: i64, %arg1: i64) -> i64 {
+  %0 = cmpi eq, %arg0, %arg1 : i64
+  %1 = select %0, %arg0, %arg1 : i64
+  return %1 : i64
+}
+
+// -----
+
+// CHECK-LABEL: @select_cmp_ne_select
+// CHECK: return %arg0
+func @select_cmp_ne_select(%arg0: i64, %arg1: i64) -> i64 {
+  %0 = cmpi ne, %arg0, %arg1 : i64
+  %1 = select %0, %arg0, %arg1 : i64
+  return %1 : i64
+}
+
+// -----
+
+// CHECK-LABEL: @select_swapped_cmp_eq
+// CHECK: return %arg1
+func @select_swapped_cmp_eq(%arg0: i64, %arg1: i64) -> i64 {
+  %0 = cmpi eq, %arg1, %arg0 : i64
+  %1 = select %0, %arg0, %arg1 : i64
+  return %1 : i64
+}
+
+// -----
+
+// CHECK-LABEL: @select_swapped_cmp_ne
+// CHECK: return %arg0
+func @select_swapped_cmp_ne(%arg0: i64, %arg1: i64) -> i64 {
+  %0 = cmpi ne, %arg1, %arg0 : i64
+  %1 = select %0, %arg0, %arg1 : i64
+  return %1 : i64
+}