// Patch summary (commentary only; this text sits before the first `diff --git`
// header and is not part of any hunk).
//
// Purpose: fold `cmpi` when both operands are the same SSA value.
//
// 1) mlir/lib/Dialect/StandardOps/IR/Ops.cpp, CmpIOp::fold
//    - After the operand-count assert, and before the operand attributes are
//      inspected, the hunk adds a check `lhs() == rhs()`.
//    - When it holds, predicates sgt, slt, ugt, ult and ne return
//      BoolAttr::get(false, getContext()); predicates eq, sle, ule, sge and
//      uge return BoolAttr::get(true, getContext()).
//    - When it does not hold, control reaches the existing code that casts
//      both operand attributes and checks them for null.
//    - Only part of the function is visible in the hunk; the code following
//      `if (!lhs || !rhs)` is outside this patch.
//
// 2) mlir/test/Dialect/Standard/canonicalize.mlir (new file)
//    - Runs mlir-opt with -pass-pipeline='func(canonicalize)' and pipes the
//      output to FileCheck.
//    - Contains ten functions, one per predicate, each comparing %arg0 with
//      itself. The checks require that no `cmpi` remains and that the
//      function returns `constant false` (sgt, slt, ugt, ult, ne) or
//      `constant true` (eq, sle, ule, sge, uge).
//
// NOTE(review): the new fold returns a scalar BoolAttr without looking at the
//   result type. If cmpi can also produce a vector or tensor of i1, this
//   presumably needs a splat attribute of the result type instead - confirm
//   against the op definition and add a non-scalar test.
// NOTE(review): every test uses `index` operands; consider covering a
//   fixed-width integer type as well - TODO confirm this is worth adding.
// NOTE(review): the RUN line passes -split-input-file, but no `// -----`
//   separators are visible in the test file - confirm whether the flag is
//   intended.
// NOTE(review): this copy of the patch looks extraction-damaged: line breaks
//   are collapsed and template arguments appear to be missing (`ArrayRef
//   operands`, `dyn_cast_or_null()`). Verify against the original patch
//   before applying.
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -844,6 +844,25 @@ OpFoldResult CmpIOp::fold(ArrayRef operands) { assert(operands.size() == 2 && "cmpi takes two arguments"); + // Operands of the compare are equal. + if (lhs() == rhs()) { + // Cmp Ops which are false when operands are the same. + if (getPredicate() == CmpIPredicate::sgt || + getPredicate() == CmpIPredicate::slt || + getPredicate() == CmpIPredicate::ugt || + getPredicate() == CmpIPredicate::ult || + getPredicate() == CmpIPredicate::ne) + return BoolAttr::get(false, getContext()); + + // Cmp Ops which are true when operands are the same. + if (getPredicate() == CmpIPredicate::eq || + getPredicate() == CmpIPredicate::sle || + getPredicate() == CmpIPredicate::ule || + getPredicate() == CmpIPredicate::sge || + getPredicate() == CmpIPredicate::uge) + return BoolAttr::get(true, getContext()); + } + auto lhs = operands.front().dyn_cast_or_null(); auto rhs = operands.back().dyn_cast_or_null(); if (!lhs || !rhs) diff --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Standard/canonicalize.mlir @@ -0,0 +1,101 @@ +// RUN: mlir-opt %s -allow-unregistered-dialect -pass-pipeline='func(canonicalize)' -split-input-file | FileCheck %s + +// CHECK-LABEL: test_simplify_sgt +func @test_simplify_sgt(%arg0: index) -> i1 { + %0 = cmpi "sgt", %arg0, %arg0 : index + return %0 : i1 + + // CHECK-NOT: cmpi + // CHECK: [[FALSE:%.+]] = constant false + // CHECK: return [[FALSE]] : i1 +} + +// CHECK-LABEL: test_simplify_slt +func @test_simplify_slt(%arg0: index) -> i1 { + %0 = cmpi "slt", %arg0, %arg0 : index + return %0 : i1 + + // CHECK-NOT: cmpi + // CHECK: [[FALSE:%.+]] = constant false + // CHECK: return [[FALSE]] : i1 +} + +// CHECK-LABEL: test_simplify_ugt +func 
@test_simplify_ugt(%arg0: index) -> i1 { + %0 = cmpi "ugt", %arg0, %arg0 : index + return %0 : i1 + + // CHECK-NOT: cmpi + // CHECK: [[FALSE:%.+]] = constant false + // CHECK: return [[FALSE]] : i1 +} + +// CHECK-LABEL: test_simplify_ult +func @test_simplify_ult(%arg0: index) -> i1 { + %0 = cmpi "ult", %arg0, %arg0 : index + return %0 : i1 + + // CHECK-NOT: cmpi + // CHECK: [[FALSE:%.+]] = constant false + // CHECK: return [[FALSE]] : i1 +} + +// CHECK-LABEL: test_simplify_ne +func @test_simplify_ne(%arg0: index) -> i1 { + %0 = cmpi "ne", %arg0, %arg0 : index + return %0 : i1 + + // CHECK-NOT: cmpi + // CHECK: [[FALSE:%.+]] = constant false + // CHECK: return [[FALSE]] : i1 +} + +// CHECK-LABEL: test_simplify_eq +func @test_simplify_eq(%arg0: index) -> i1 { + %0 = cmpi "eq", %arg0, %arg0 : index + return %0 : i1 + + // CHECK-NOT: cmpi + // CHECK: [[TRUE:%.+]] = constant true + // CHECK: return [[TRUE]] : i1 +} + +// CHECK-LABEL: test_simplify_sle +func @test_simplify_sle(%arg0: index) -> i1 { + %0 = cmpi "sle", %arg0, %arg0 : index + return %0 : i1 + + // CHECK-NOT: cmpi + // CHECK: [[TRUE:%.+]] = constant true + // CHECK: return [[TRUE]] : i1 +} + +// CHECK-LABEL: test_simplify_ule +func @test_simplify_ule(%arg0: index) -> i1 { + %0 = cmpi "ule", %arg0, %arg0 : index + return %0 : i1 + + // CHECK-NOT: cmpi + // CHECK: [[TRUE:%.+]] = constant true + // CHECK: return [[TRUE]] : i1 +} + +// CHECK-LABEL: test_simplify_sge +func @test_simplify_sge(%arg0: index) -> i1 { + %0 = cmpi "sge", %arg0, %arg0 : index + return %0 : i1 + + // CHECK-NOT: cmpi + // CHECK: [[TRUE:%.+]] = constant true + // CHECK: return [[TRUE]] : i1 +} + +// CHECK-LABEL: test_simplify_uge +func @test_simplify_uge(%arg0: index) -> i1 { + %0 = cmpi "uge", %arg0, %arg0 : index + return %0 : i1 + + // CHECK-NOT: cmpi + // CHECK: [[TRUE:%.+]] = constant true + // CHECK: return [[TRUE]] : i1 +}