diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -844,6 +844,40 @@ OpFoldResult CmpIOp::fold(ArrayRef<Attribute> operands) {
   assert(operands.size() == 2 && "cmpi takes two arguments");
 
+  // cmpi(pred, x, x): with identical operands the result depends only on the
+  // predicate: eq, sle, sge, ule and uge hold for every x, while ne and the
+  // strict orderings never do.
+  if (lhs() == rhs()) {
+    bool holds = false;
+    switch (getPredicate()) {
+    case CmpIPredicate::eq:
+    case CmpIPredicate::sle:
+    case CmpIPredicate::sge:
+    case CmpIPredicate::ule:
+    case CmpIPredicate::uge:
+      holds = true;
+      break;
+    case CmpIPredicate::ne:
+    case CmpIPredicate::slt:
+    case CmpIPredicate::sgt:
+    case CmpIPredicate::ult:
+    case CmpIPredicate::ugt:
+      holds = false;
+      break;
+    }
+
+    // On vector/tensor operands cmpi is elementwise and yields a shaped i1
+    // value, so a scalar i1 attribute would not match the result type; fold
+    // to a splat instead. A dynamically shaped result has no constant form.
+    if (auto shapedType = getType().dyn_cast<ShapedType>()) {
+      if (!shapedType.hasStaticShape())
+        return {};
+      return DenseElementsAttr::get(shapedType, ArrayRef<bool>(holds));
+    }
+    return IntegerAttr::get(IntegerType::get(1, getContext()),
+                            APInt(1, holds));
+  }
+
   auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
   auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
   if (!lhs || !rhs)
     return {};
diff --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Standard/canonicalize.mlir
@@ -0,0 +1,64 @@
+// RUN: mlir-opt %s -allow-unregistered-dialect -pass-pipeline='func(canonicalize)' -split-input-file | FileCheck %s
+
+// cmpi with identical operands folds to a constant: the reflexive predicates
+// (eq, sle, sge, ule, uge) to true, ne and the strict orderings to false.
+// CHECK-LABEL: func @cmpi_equal_operands
+//   CHECK-DAG:   %[[T:.*]] = constant true
+//   CHECK-DAG:   %[[F:.*]] = constant false
+//   CHECK-NOT:   cmpi
+//       CHECK:   return %[[T]], %[[T]], %[[T]], %[[T]], %[[T]],
+//  CHECK-SAME:          %[[F]], %[[F]], %[[F]], %[[F]], %[[F]]
+func @cmpi_equal_operands(%arg0: i64)
+    -> (i1, i1, i1, i1, i1, i1, i1, i1, i1, i1) {
+  %0 = cmpi "eq", %arg0, %arg0 : i64
+  %1 = cmpi "sle", %arg0, %arg0 : i64
+  %2 = cmpi "sge", %arg0, %arg0 : i64
+  %3 = cmpi "ule", %arg0, %arg0 : i64
+  %4 = cmpi "uge", %arg0, %arg0 : i64
+  %5 = cmpi "ne", %arg0, %arg0 : i64
+  %6 = cmpi "slt", %arg0, %arg0 : i64
+  %7 = cmpi "sgt", %arg0, %arg0 : i64
+  %8 = cmpi "ult", %arg0, %arg0 : i64
+  %9 = cmpi "ugt", %arg0, %arg0 : i64
+  return %0, %1, %2, %3, %4, %5, %6, %7, %8, %9
+      : i1, i1, i1, i1, i1, i1, i1, i1, i1, i1
+}
+
+// -----
+
+// The folded condition feeds select, which in turn folds to its false operand;
+// the now-dead constant is erased.
+// CHECK-LABEL: func @select_of_cmpi_equal_operands
+//  CHECK-NEXT:   return %arg1 : index
+func @select_of_cmpi_equal_operands(%arg0: index, %arg1: index) -> index {
+  %0 = cmpi "sgt", %arg0, %arg0 : index
+  %1 = select %0, %arg0, %arg1 : index
+  return %1 : index
+}
+
+// -----
+
+// On vectors cmpi is elementwise, so the fold yields a splat i1 vector.
+// CHECK-LABEL: func @cmpi_equal_vector_operands
+//   CHECK-DAG:   %[[T:.*]] = constant dense<true> : vector<4xi1>
+//   CHECK-DAG:   %[[F:.*]] = constant dense<false> : vector<4xi1>
+//   CHECK-NOT:   cmpi
+//       CHECK:   return %[[T]], %[[F]]
+func @cmpi_equal_vector_operands(%arg0: vector<4xi32>)
+    -> (vector<4xi1>, vector<4xi1>) {
+  %0 = cmpi "sge", %arg0, %arg0 : vector<4xi32>
+  %1 = cmpi "ult", %arg0, %arg0 : vector<4xi32>
+  return %0, %1 : vector<4xi1>, vector<4xi1>
+}
+
+// -----
+
+// A dynamically shaped result has no constant form, so the comparison must
+// be left in place.
+// CHECK-LABEL: func @cmpi_equal_dynamic_tensor_operands
+//       CHECK:   cmpi "eq"
+func @cmpi_equal_dynamic_tensor_operands(%arg0: tensor<?xi32>)
+    -> tensor<?xi1> {
+  %0 = cmpi "eq", %arg0, %arg0 : tensor<?xi32>
+  return %0 : tensor<?xi1>
+}