diff --git a/mlir/include/mlir/Dialect/Index/IR/IndexOps.h b/mlir/include/mlir/Dialect/Index/IR/IndexOps.h --- a/mlir/include/mlir/Dialect/Index/IR/IndexOps.h +++ b/mlir/include/mlir/Dialect/Index/IR/IndexOps.h @@ -14,6 +14,7 @@ #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpImplementation.h" +#include "mlir/IR/PatternMatch.h" #include "mlir/Interfaces/CastInterfaces.h" #include "mlir/Interfaces/InferIntRangeInterface.h" #include "mlir/Interfaces/InferTypeOpInterface.h" diff --git a/mlir/include/mlir/Dialect/Index/IR/IndexOps.td b/mlir/include/mlir/Dialect/Index/IR/IndexOps.td --- a/mlir/include/mlir/Dialect/Index/IR/IndexOps.td +++ b/mlir/include/mlir/Dialect/Index/IR/IndexOps.td @@ -542,6 +542,8 @@ let results = (outs I1:$result); let assemblyFormat = "`` $pred `(` $lhs `,` $rhs `)` attr-dict"; let hasFolder = 1; + + let hasCanonicalizeMethod = 1; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Index/IR/IndexOps.cpp b/mlir/lib/Dialect/Index/IR/IndexOps.cpp --- a/mlir/lib/Dialect/Index/IR/IndexOps.cpp +++ b/mlir/lib/Dialect/Index/IR/IndexOps.cpp @@ -549,6 +549,38 @@ return {}; } +LogicalResult CmpOp::canonicalize(CmpOp op, ::mlir::PatternRewriter &rewriter) { + /// Canonicalize + /// `x - y cmp 0` to `x cmp y`. or `x - y cmp 0` to `x cmp y`. + /// `0 cmp x - y` to `y cmp x`. or `0 cmp x - y` to `y cmp x`. + IntegerAttr cmpRhs; + IntegerAttr cmpLhs; + if (mlir::matchPattern(op.getRhs(), mlir::m_Constant(&cmpRhs)) && + cmpRhs.getValue().isZero()) { + auto subOp = op.getLhs().getDefiningOp(); + if (!subOp) + return rewriter.notifyMatchFailure(op.getLoc(), "LHS is not subtraction"); + + auto newCmp = rewriter.create( + op.getLoc(), op.getPred(), subOp.getLhs(), subOp.getRhs()); + rewriter.replaceOp(op, newCmp); + return success(); + } else if (mlir::matchPattern(op.getLhs(), mlir::m_Constant(&cmpLhs)) && + cmpLhs.getValue().isZero()) { + + auto subOp = op.getRhs().getDefiningOp(); + if (!subOp) + return rewriter.notifyMatchFailure(op.getLoc(), "RHS is not subtraction"); + + auto newCmp = rewriter.create(op.getLoc(), op.getPred(), + subOp.getRhs(), subOp.getLhs()); + + rewriter.replaceOp(op, newCmp); + return success(); + } + return rewriter.notifyMatchFailure(op.getLoc(), "cmp is not with 0."); +} + //===----------------------------------------------------------------------===// // ConstantOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Index/index-canonicalize.mlir b/mlir/test/Dialect/Index/index-canonicalize.mlir --- a/mlir/test/Dialect/Index/index-canonicalize.mlir +++ b/mlir/test/Dialect/Index/index-canonicalize.mlir @@ -473,7 +473,7 @@ } // CHECK-LABEL: @cmp -func.func @cmp() -> (i1, i1, i1, i1) { +func.func @cmp(%arg0: index) -> (i1, i1, i1, i1, i1, i1) { %a = index.constant 0 %b = index.constant -1 %c = index.constant -2 @@ -484,10 +484,19 @@ %2 = index.cmp ne(%d, %a) %3 = index.cmp sgt(%b, %a) + %4 = index.sub %a, %arg0 + %5 = index.cmp sgt(%4, %a) + + %6 = index.sub %a, %arg0 + %7 = index.cmp sgt(%a, %6) + // CHECK-DAG: %[[TRUE:.*]] = index.bool.constant true // CHECK-DAG: %[[FALSE:.*]] = index.bool.constant false + // CHECK-DAG: [[IDX0:%.*]] = index.constant 0 + // CHECK-DAG: [[V4:%.*]] = index.cmp sgt([[IDX0]], %arg0) + // CHECK-DAG: [[V5:%.*]] = index.cmp sgt(%arg0, [[IDX0]]) // CHECK: return %[[FALSE]], %[[TRUE]], %[[TRUE]], %[[FALSE]] - return %0, %1, %2, %3 : i1, i1, i1, i1 + return %0, %1, %2, %3, %5, %7 : i1, i1, i1, i1, i1, i1 } // CHECK-LABEL: @cmp_nofold