diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -1027,6 +1027,8 @@
 
   let hasFolder = 1;
 
+  let hasCanonicalizer = 1;
+
   let assemblyFormat = "$predicate `,` $lhs `,` $rhs attr-dict `:` type($lhs)";
 }
 
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -853,6 +853,61 @@
   return IntegerAttr::get(IntegerType::get(1, getContext()), APInt(1, val));
 }
 
+/// Returns the result of comparing a value with itself under `predicate`:
+/// the reflexive predicates (eq, sle, sge, ule, uge) always hold and the
+/// strict ones (ne, slt, sgt, ult, ugt) never do.
+static bool applyCmpPredicateToEqualOperands(CmpIPredicate predicate) {
+  switch (predicate) {
+  case CmpIPredicate::eq:
+  case CmpIPredicate::sle:
+  case CmpIPredicate::sge:
+  case CmpIPredicate::ule:
+  case CmpIPredicate::uge:
+    return true;
+  case CmpIPredicate::ne:
+  case CmpIPredicate::slt:
+  case CmpIPredicate::sgt:
+  case CmpIPredicate::ult:
+  case CmpIPredicate::ugt:
+    return false;
+  }
+  llvm_unreachable("unknown cmpi predicate");
+}
+
+namespace {
+/// cmpi %x, %x -> constant {true/false}
+///
+/// Replaces a scalar `cmpi` whose two operands are the same SSA value with the
+/// `i1` constant the comparison always evaluates to, e.g.
+///   %0 = cmpi "sgt", %x, %x : index  ->  %0 = constant false
+///   %0 = cmpi "eq", %x, %x : index   ->  %0 = constant true
+///
+/// Vector/tensor compares (whose result is a shaped `i1`) are not rewritten:
+/// replacing them with a scalar constant would change the result type.
+struct SimplifyCmpIOpWithIdenticalOperands : public OpRewritePattern<CmpIOp> {
+  using OpRewritePattern<CmpIOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(CmpIOp cmpIOp,
+                                PatternRewriter &rewriter) const override {
+    if (cmpIOp.lhs() != cmpIOp.rhs())
+      return failure();
+    if (!cmpIOp.getType().isInteger(1))
+      return failure();
+
+    // cmpi yields `i1`, so the replacement constant must be `i1` as well.
+    bool result = applyCmpPredicateToEqualOperands(cmpIOp.getPredicate());
+    rewriter.replaceOpWithNewOp<ConstantOp>(
+        cmpIOp, rewriter.getIntegerAttr(rewriter.getI1Type(), result));
+    return success();
+  }
+};
+} // end namespace
+
+void CmpIOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &results, MLIRContext *context) {
+  results.insert<SimplifyCmpIOpWithIdenticalOperands>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // CmpFOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Standard/canonicalize.mlir
@@ -0,0 +1,56 @@
+// RUN: mlir-opt %s -allow-unregistered-dialect -pass-pipeline='func(canonicalize)' -split-input-file | FileCheck %s
+
+// `cmpi "sgt" %x, %x` is always false, so the select yields its false operand.
+func @test_simplify_1(%arg0: memref<?x?xf32>) -> memref<?xf32> {
+  %c0 = constant 0 : index
+  %0 = dim %arg0, %c0 : memref<?x?xf32>
+  %c1 = constant 1 : index
+  %1 = dim %arg0, %c1 : memref<?x?xf32>
+  %3 = cmpi "sgt", %0, %0 : index
+  %4 = select %3, %0, %1 : index
+  %5 = alloc(%4) : memref<?xf32>
+  return %5 : memref<?xf32>
+
+  // CHECK-LABEL: test_simplify_1
+  // CHECK: [[DIM2:%.+]] = dim %arg0, %c1 : memref<?x?xf32>
+  // CHECK-NOT: cmpi
+  // CHECK: [[RES:%.+]] = alloc([[DIM2]]) : memref<?xf32>
+}
+
+// `cmpi "eq" %x, %x` is always true, so the select yields its true operand.
+func @test_simplify_2(%arg0: memref<?x?xf32>) -> memref<?xf32> {
+  %c0 = constant 0 : index
+  %0 = dim %arg0, %c0 : memref<?x?xf32>
+  %c1 = constant 1 : index
+  %1 = dim %arg0, %c1 : memref<?x?xf32>
+  %3 = cmpi "eq", %0, %0 : index
+  %4 = select %3, %0, %1 : index
+  %5 = alloc(%4) : memref<?xf32>
+  return %5 : memref<?xf32>
+
+  // CHECK-LABEL: test_simplify_2
+  // CHECK: [[DIM1:%.+]] = dim %arg0, %c0 : memref<?x?xf32>
+  // CHECK-NOT: cmpi
+  // CHECK: [[RES:%.+]] = alloc([[DIM1]]) : memref<?xf32>
+}
+
+// The folded value keeps cmpi's `i1` result type.
+func @test_simplify_i1_result(%arg0: index) -> i1 {
+  %0 = cmpi "ule", %arg0, %arg0 : index
+  return %0 : i1
+
+  // CHECK-LABEL: test_simplify_i1_result
+  // CHECK: [[TRUE:%.+]] = constant true
+  // CHECK-NOT: cmpi
+  // CHECK: return [[TRUE]] : i1
+}
+
+// Vector compares produce a vector of i1; the pattern must not replace them
+// with a scalar constant (mlir-opt would then fail verification).
+func @test_simplify_vector(%arg0: vector<4xi32>) -> vector<4xi1> {
+  %0 = cmpi "eq", %arg0, %arg0 : vector<4xi32>
+  return %0 : vector<4xi1>
+
+  // CHECK-LABEL: test_simplify_vector
+  // CHECK: return {{.*}} : vector<4xi1>
+}