diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -2288,6 +2288,7 @@
     ```
   }];
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -2747,6 +2747,68 @@
                                         [](APInt a, APInt b) { return a ^ b; });
 }
 
+/// Returns the predicate whose result is the logical negation of `pred`, so
+/// that for all operands `cmpi(negated, a, b) == !cmpi(pred, a, b)`.
+static CmpIPredicate getNegatedCmpIPredicate(CmpIPredicate pred) {
+  switch (pred) {
+  case CmpIPredicate::eq:
+    return CmpIPredicate::ne;
+  case CmpIPredicate::ne:
+    return CmpIPredicate::eq;
+  case CmpIPredicate::slt:
+    return CmpIPredicate::sge;
+  case CmpIPredicate::sle:
+    return CmpIPredicate::sgt;
+  case CmpIPredicate::sgt:
+    return CmpIPredicate::sle;
+  case CmpIPredicate::sge:
+    return CmpIPredicate::slt;
+  case CmpIPredicate::ult:
+    return CmpIPredicate::uge;
+  case CmpIPredicate::ule:
+    return CmpIPredicate::ugt;
+  case CmpIPredicate::ugt:
+    return CmpIPredicate::ule;
+  case CmpIPredicate::uge:
+    return CmpIPredicate::ult;
+  }
+  llvm_unreachable("unknown cmpi predicate");
+}
+
+namespace {
+/// xor(cmpi(pred, a, b), true) => cmpi(negate(pred), a, b)
+///
+/// XOrOp is commutative, so canonicalization has already moved a constant
+/// operand to the right-hand side; only that operand order is matched.
+struct NotICmp : public OpRewritePattern<XOrOp> {
+  using OpRewritePattern<XOrOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(XOrOp op,
+                                PatternRewriter &rewriter) const override {
+    // m_ConstantInt binds scalar integer constants as well as splat
+    // vector/tensor constants and fails cleanly on anything else. Casting the
+    // constant's attribute straight to IntegerAttr would assert on e.g.
+    // `constant dense<true> : vector<4xi1>`.
+    APInt constValue;
+    if (!matchPattern(op.rhs(), m_ConstantInt(&constValue)) ||
+        constValue != 1)
+      return failure();
+    auto prev = op.lhs().getDefiningOp<CmpIOp>();
+    if (!prev)
+      return failure();
+    // The matched cmpi is left untouched for any other users it may have.
+    rewriter.replaceOpWithNewOp<CmpIOp>(
+        op, getNegatedCmpIPredicate(prev.predicate()), prev.lhs(), prev.rhs());
+    return success();
+  }
+};
+} // namespace
+
+void XOrOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                        MLIRContext *context) {
+  results.insert<NotICmp>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // ZeroExtendIOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir
--- a/mlir/test/Dialect/Standard/canonicalize.mlir
+++ b/mlir/test/Dialect/Standard/canonicalize.mlir
@@ -306,3 +306,39 @@
   %1 = select %0, %arg0, %arg1 : i64
   return %1 : i64
 }
+
+// -----
+
+// CHECK-LABEL: @notCmp
+// CHECK:         %[[cres:.+]] = cmpi ne, %arg0, %arg1 : i8
+// CHECK:         return %[[cres]]
+func @notCmp(%arg0: i8, %arg1: i8) -> i1 {
+  %true = constant true
+  %cmp = cmpi "eq", %arg0, %arg1 : i8
+  %ncmp = xor %cmp, %true : i1
+  return %ncmp : i1
+}
+
+// -----
+
+// CHECK-LABEL: @notCmpSLT
+// CHECK:         %[[cres:.+]] = cmpi sge, %arg0, %arg1 : i8
+// CHECK:         return %[[cres]]
+func @notCmpSLT(%arg0: i8, %arg1: i8) -> i1 {
+  %true = constant true
+  %cmp = cmpi "slt", %arg0, %arg1 : i8
+  %ncmp = xor %cmp, %true : i1
+  return %ncmp : i1
+}
+
+// -----
+
+// CHECK-LABEL: @notCmpVector
+// CHECK:         %[[cres:.+]] = cmpi ugt, %arg0, %arg1 : vector<4xi8>
+// CHECK:         return %[[cres]]
+func @notCmpVector(%arg0: vector<4xi8>, %arg1: vector<4xi8>) -> vector<4xi1> {
+  %true = constant dense<true> : vector<4xi1>
+  %cmp = cmpi "ule", %arg0, %arg1 : vector<4xi8>
+  %ncmp = xor %cmp, %true : vector<4xi1>
+  return %ncmp : vector<4xi1>
+}