diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -844,6 +844,28 @@ OpFoldResult CmpIOp::fold(ArrayRef<Attribute> operands) {
   assert(operands.size() == 2 && "cmpi takes two arguments");
 
+  // cmpi(pred, x, x): the result depends only on the predicate. eq, sle, sge,
+  // ule and uge hold for equal operands; ne, slt, sgt, ult and ugt never do.
+  if (lhs() == rhs()) {
+    CmpIPredicate predicate = getPredicate();
+    bool holdsForEqualOperands = predicate == CmpIPredicate::eq ||
+                                 predicate == CmpIPredicate::sle ||
+                                 predicate == CmpIPredicate::sge ||
+                                 predicate == CmpIPredicate::ule ||
+                                 predicate == CmpIPredicate::uge;
+    Attribute boolAttr = BoolAttr::get(holdsForEqualOperands, getContext());
+
+    // cmpi also accepts vector and tensor operands; the result is then a
+    // shaped i1 type, so fold to a splat of that shape, not a scalar BoolAttr.
+    // A dense splat cannot be built for a dynamically shaped result.
+    auto shapedType = getType().dyn_cast<ShapedType>();
+    if (!shapedType)
+      return boolAttr;
+    if (!shapedType.hasStaticShape())
+      return {};
+    return DenseElementsAttr::get(shapedType, boolAttr);
+  }
+
   auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
   auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
   if (!lhs || !rhs)
     return {};
diff --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Standard/canonicalize.mlir
@@ -0,0 +1,110 @@
+// RUN: mlir-opt %s -allow-unregistered-dialect -pass-pipeline='func(canonicalize)' -split-input-file | FileCheck %s
+func @test_simplify_sgt(%arg0: index, %arg1 : index, %arg2 : index) -> index {
+  %0 = cmpi "sgt", %arg0, %arg0 : index
+  %1 = select %0, %arg1, %arg2 : index
+  return %1 : index
+
+  // CHECK-LABEL: test_simplify_sgt
+  // CHECK-NOT: cmpi "sgt", %arg0, %arg0 : index
+  // CHECK: return %arg2 : index
+}
+
+func @test_simplify_slt(%arg0: index, %arg1 : index, %arg2 : index) -> index {
+  %0 = cmpi "slt", %arg0, %arg0 : index
+  %1 = select %0, %arg1, %arg2 : index
+  return %1 : index
+
+  // CHECK-LABEL: test_simplify_slt
+  // CHECK-NOT: cmpi "slt", %arg0, %arg0 : index
+  // CHECK: return %arg2 : index
+}
+
+func @test_simplify_ugt(%arg0: index, %arg1 : index, %arg2 : index) -> index {
+  %0 = cmpi "ugt", %arg0, %arg0 : index
+  %1 = select %0, %arg1, %arg2 : index
+  return %1 : index
+
+  // CHECK-LABEL: test_simplify_ugt
+  // CHECK-NOT: cmpi "ugt", %arg0, %arg0 : index
+  // CHECK: return %arg2 : index
+}
+
+func @test_simplify_ult(%arg0: index, %arg1 : index, %arg2 : index) -> index {
+  %0 = cmpi "ult", %arg0, %arg0 : index
+  %1 = select %0, %arg1, %arg2 : index
+  return %1 : index
+
+  // CHECK-LABEL: test_simplify_ult
+  // CHECK-NOT: cmpi "ult", %arg0, %arg0 : index
+  // CHECK: return %arg2 : index
+}
+
+func @test_simplify_ne(%arg0: index, %arg1 : index, %arg2 : index) -> index {
+  %0 = cmpi "ne", %arg0, %arg0 : index
+  %1 = select %0, %arg1, %arg2 : index
+  return %1 : index
+
+  // CHECK-LABEL: test_simplify_ne
+  // CHECK-NOT: cmpi "ne", %arg0, %arg0 : index
+  // CHECK: return %arg2 : index
+}
+
+func @test_simplify_eq(%arg0: index, %arg1 : index, %arg2 : index) -> index {
+  %0 = cmpi "eq", %arg0, %arg0 : index
+  %1 = select %0, %arg1, %arg2 : index
+  return %1 : index
+
+  // CHECK-LABEL: test_simplify_eq
+  // CHECK-NOT: cmpi "eq", %arg0, %arg0 : index
+  // CHECK: return %arg1 : index
+}
+
+func @test_simplify_sle(%arg0: index, %arg1 : index, %arg2 : index) -> index {
+  %0 = cmpi "sle", %arg0, %arg0 : index
+  %1 = select %0, %arg1, %arg2 : index
+  return %1 : index
+
+  // CHECK-LABEL: test_simplify_sle
+  // CHECK-NOT: cmpi "sle", %arg0, %arg0 : index
+  // CHECK: return %arg1 : index
+}
+
+func @test_simplify_ule(%arg0: index, %arg1 : index, %arg2 : index) -> index {
+  %0 = cmpi "ule", %arg0, %arg0 : index
+  %1 = select %0, %arg1, %arg2 : index
+  return %1 : index
+
+  // CHECK-LABEL: test_simplify_ule
+  // CHECK-NOT: cmpi "ule", %arg0, %arg0 : index
+  // CHECK: return %arg1 : index
+}
+
+func @test_simplify_sge(%arg0: index, %arg1 : index, %arg2 : index) -> index {
+  %0 = cmpi "sge", %arg0, %arg0 : index
+  %1 = select %0, %arg1, %arg2 : index
+  return %1 : index
+
+  // CHECK-LABEL: test_simplify_sge
+  // CHECK-NOT: cmpi "sge", %arg0, %arg0 : index
+  // CHECK: return %arg1 : index
+}
+
+func @test_simplify_uge(%arg0: index, %arg1 : index, %arg2 : index) -> index {
+  %0 = cmpi "uge", %arg0, %arg0 : index
+  %1 = select %0, %arg1, %arg2 : index
+  return %1 : index
+
+  // CHECK-LABEL: test_simplify_uge
+  // CHECK-NOT: cmpi "uge", %arg0, %arg0 : index
+  // CHECK: return %arg1 : index
+}
+
+func @test_simplify_eq_vector(%arg0 : vector<4xi32>) -> vector<4xi1> {
+  %0 = cmpi "eq", %arg0, %arg0 : vector<4xi32>
+  return %0 : vector<4xi1>
+
+  // CHECK-LABEL: test_simplify_eq_vector
+  // CHECK: %[[TRUE:.*]] = constant dense<true> : vector<4xi1>
+  // CHECK-NOT: cmpi
+  // CHECK: return %[[TRUE]] : vector<4xi1>
+}