diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -2292,6 +2292,7 @@
     ```
   }];
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -3011,6 +3011,70 @@
                                        [](APInt a, APInt b) { return a ^ b; });
 }
 
+namespace {
+/// Canonicalizes the logical negation of an integer comparison into a single
+/// comparison with the inverted predicate, for example:
+///   not(cmp eq A, B) => cmp ne A, B
+/// A logical not is expressed as `xor val, 1`, so the pattern is rooted on
+/// XOrOp. Only a constant in operand 1 is matched.
+struct NotICmp : public OpRewritePattern<XOrOp> {
+  using OpRewritePattern<XOrOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(XOrOp op,
+                                PatternRewriter &rewriter) const override {
+    // The right-hand operand must be the integer constant 1.
+    APInt constValue;
+    if (!matchPattern(op.getOperand(1), m_ConstantInt(&constValue)) ||
+        constValue != 1)
+      return failure();
+
+    // The left-hand operand must be produced by an integer comparison.
+    auto cmp = op.getOperand(0).getDefiningOp<CmpIOp>();
+    if (!cmp)
+      return failure();
+
+    // Replaces the xor with a comparison of the same operands under the
+    // given (inverted) predicate.
+    auto replaceWith = [&](CmpIPredicate inverted) {
+      rewriter.replaceOpWithNewOp<CmpIOp>(op, inverted, cmp.lhs(), cmp.rhs());
+      return success();
+    };
+
+    switch (cmp.predicate()) {
+    case CmpIPredicate::eq:
+      return replaceWith(CmpIPredicate::ne);
+    case CmpIPredicate::ne:
+      return replaceWith(CmpIPredicate::eq);
+
+    case CmpIPredicate::slt:
+      return replaceWith(CmpIPredicate::sge);
+    case CmpIPredicate::sle:
+      return replaceWith(CmpIPredicate::sgt);
+    case CmpIPredicate::sgt:
+      return replaceWith(CmpIPredicate::sle);
+    case CmpIPredicate::sge:
+      return replaceWith(CmpIPredicate::slt);
+
+    case CmpIPredicate::ult:
+      return replaceWith(CmpIPredicate::uge);
+    case CmpIPredicate::ule:
+      return replaceWith(CmpIPredicate::ugt);
+    case CmpIPredicate::ugt:
+      return replaceWith(CmpIPredicate::ule);
+    case CmpIPredicate::uge:
+      return replaceWith(CmpIPredicate::ult);
+    }
+    return failure();
+  }
+};
+} // namespace
+
+/// Adds the NotICmp pattern to the canonicalization patterns of XOrOp.
+void XOrOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                        MLIRContext *context) {
+  results.insert<NotICmp>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // ZeroExtendIOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir
--- a/mlir/test/Dialect/Standard/canonicalize.mlir
+++ b/mlir/test/Dialect/Standard/canonicalize.mlir
@@ -538,3 +538,113 @@
   %add2 = subi %add1, %c42 : index
   return %add2 : index
 }
+
+// CHECK-LABEL: @notCmpEQ
+//       CHECK: %[[cres:.+]] = cmpi ne, %arg0, %arg1 : i8
+//       CHECK: return %[[cres]]
+func @notCmpEQ(%arg0: i8, %arg1: i8) -> i1 {
+  %true = constant true
+  %cmp = cmpi "eq", %arg0, %arg1 : i8
+  %ncmp = xor %cmp, %true : i1
+  return %ncmp : i1
+}
+
+// CHECK-LABEL: @notCmpEQ2
+//       CHECK: %[[cres:.+]] = cmpi ne, %arg0, %arg1 : i8
+//       CHECK: return %[[cres]]
+func @notCmpEQ2(%arg0: i8, %arg1: i8) -> i1 {
+  %true = constant true
+  %cmp = cmpi "eq", %arg0, %arg1 : i8
+  %ncmp = xor %true, %cmp : i1
+  return %ncmp : i1
+}
+
+// CHECK-LABEL: @notCmpNE
+//       CHECK: %[[cres:.+]] = cmpi eq, %arg0, %arg1 : i8
+//       CHECK: return %[[cres]]
+func @notCmpNE(%arg0: i8, %arg1: i8) -> i1 {
+  %true = constant true
+  %cmp = cmpi "ne", %arg0, %arg1 : i8
+  %ncmp = xor %cmp, %true : i1
+  return %ncmp : i1
+}
+
+// CHECK-LABEL: @notCmpSLT
+//       CHECK: %[[cres:.+]] = cmpi sge, %arg0, %arg1 : i8
+//       CHECK: return %[[cres]]
+func @notCmpSLT(%arg0: i8, %arg1: i8) -> i1 {
+  %true = constant true
+  %cmp = cmpi "slt", %arg0, %arg1 : i8
+  %ncmp = xor %cmp, %true : i1
+  return %ncmp : i1
+}
+
+// CHECK-LABEL: @notCmpSLE
+//       CHECK: %[[cres:.+]] = cmpi sgt, %arg0, %arg1 : i8
+//       CHECK: return %[[cres]]
+func @notCmpSLE(%arg0: i8, %arg1: i8) -> i1 {
+  %true = constant true
+  %cmp = cmpi "sle", %arg0, %arg1 : i8
+  %ncmp = xor %cmp, %true : i1
+  return %ncmp : i1
+}
+
+// CHECK-LABEL: @notCmpSGT
+//       CHECK: %[[cres:.+]] = cmpi sle, %arg0, %arg1 : i8
+//       CHECK: return %[[cres]]
+func @notCmpSGT(%arg0: i8, %arg1: i8) -> i1 {
+  %true = constant true
+  %cmp = cmpi "sgt", %arg0, %arg1 : i8
+  %ncmp = xor %cmp, %true : i1
+  return %ncmp : i1
+}
+
+// CHECK-LABEL: @notCmpSGE
+//       CHECK: %[[cres:.+]] = cmpi slt, %arg0, %arg1 : i8
+//       CHECK: return %[[cres]]
+func @notCmpSGE(%arg0: i8, %arg1: i8) -> i1 {
+  %true = constant true
+  %cmp = cmpi "sge", %arg0, %arg1 : i8
+  %ncmp = xor %cmp, %true : i1
+  return %ncmp : i1
+}
+
+// CHECK-LABEL: @notCmpULT
+//       CHECK: %[[cres:.+]] = cmpi uge, %arg0, %arg1 : i8
+//       CHECK: return %[[cres]]
+func @notCmpULT(%arg0: i8, %arg1: i8) -> i1 {
+  %true = constant true
+  %cmp = cmpi "ult", %arg0, %arg1 : i8
+  %ncmp = xor %cmp, %true : i1
+  return %ncmp : i1
+}
+
+// CHECK-LABEL: @notCmpULE
+//       CHECK: %[[cres:.+]] = cmpi ugt, %arg0, %arg1 : i8
+//       CHECK: return %[[cres]]
+func @notCmpULE(%arg0: i8, %arg1: i8) -> i1 {
+  %true = constant true
+  %cmp = cmpi "ule", %arg0, %arg1 : i8
+  %ncmp = xor %cmp, %true : i1
+  return %ncmp : i1
+}
+
+// CHECK-LABEL: @notCmpUGT
+//       CHECK: %[[cres:.+]] = cmpi ule, %arg0, %arg1 : i8
+//       CHECK: return %[[cres]]
+func @notCmpUGT(%arg0: i8, %arg1: i8) -> i1 {
+  %true = constant true
+  %cmp = cmpi "ugt", %arg0, %arg1 : i8
+  %ncmp = xor %cmp, %true : i1
+  return %ncmp : i1
+}
+
+// CHECK-LABEL: @notCmpUGE
+//       CHECK: %[[cres:.+]] = cmpi ult, %arg0, %arg1 : i8
+//       CHECK: return %[[cres]]
+func @notCmpUGE(%arg0: i8, %arg1: i8) -> i1 {
+  %true = constant true
+  %cmp = cmpi "uge", %arg0, %arg1 : i8
+  %ncmp = xor %cmp, %true : i1
+  return %ncmp : i1
+}