diff --git a/mlir/include/mlir/Dialect/Index/IR/IndexOps.h b/mlir/include/mlir/Dialect/Index/IR/IndexOps.h
--- a/mlir/include/mlir/Dialect/Index/IR/IndexOps.h
+++ b/mlir/include/mlir/Dialect/Index/IR/IndexOps.h
@@ -23,10 +23,15 @@
 // Forward Declarations
 //===----------------------------------------------------------------------===//
 
-namespace mlir::index {
+namespace mlir {
+class PatternRewriter;
+
+namespace index {
 enum class IndexCmpPredicate : uint32_t;
 class IndexCmpPredicateAttr;
-} // namespace mlir::index
+} // namespace index
+
+} // namespace mlir
 
 //===----------------------------------------------------------------------===//
 // ODS-Generated Declarations
diff --git a/mlir/include/mlir/Dialect/Index/IR/IndexOps.td b/mlir/include/mlir/Dialect/Index/IR/IndexOps.td
--- a/mlir/include/mlir/Dialect/Index/IR/IndexOps.td
+++ b/mlir/include/mlir/Dialect/Index/IR/IndexOps.td
@@ -542,6 +542,7 @@
   let results = (outs I1:$result);
   let assemblyFormat = "`` $pred `(` $lhs `,` $rhs `)` attr-dict";
   let hasFolder = 1;
+  let hasCanonicalizeMethod = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Index/IR/IndexOps.cpp b/mlir/lib/Dialect/Index/IR/IndexOps.cpp
--- a/mlir/lib/Dialect/Index/IR/IndexOps.cpp
+++ b/mlir/lib/Dialect/Index/IR/IndexOps.cpp
@@ -12,6 +12,7 @@
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
 #include "llvm/ADT/SmallString.h"
 #include "llvm/ADT/TypeSwitch.h"
@@ -549,6 +550,41 @@
   return {};
 }
 
+/// Canonicalize `(x - y) cmp 0` to `x cmp y` and `0 cmp (x - y)` to
+/// `y cmp x`.
+///
+/// Index arithmetic wraps on overflow (at a target-dependent bitwidth), so
+/// the rewrite is only sound for the equality predicates: `x - y == 0` holds
+/// exactly when `x == y`, whereas e.g. `x - y slt 0` differs from `x slt y`
+/// when the subtraction overflows, and `x - y ult 0` is always false.
+LogicalResult CmpOp::canonicalize(CmpOp op, PatternRewriter &rewriter) {
+  IndexCmpPredicate pred = op.getPred();
+  if (pred != IndexCmpPredicate::EQ && pred != IndexCmpPredicate::NE)
+    return rewriter.notifyMatchFailure(
+        op.getLoc(), "only eq/ne predicates are preserved by the rewrite");
+
+  IntegerAttr cmpRhs;
+  IntegerAttr cmpLhs;
+  bool rhsIsZero = matchPattern(op.getRhs(), m_Constant(&cmpRhs)) &&
+                   cmpRhs.getValue().isZero();
+  bool lhsIsZero = matchPattern(op.getLhs(), m_Constant(&cmpLhs)) &&
+                   cmpLhs.getValue().isZero();
+  if (!rhsIsZero && !lhsIsZero)
+    return rewriter.notifyMatchFailure(op.getLoc(),
+                                       "cmp is not comparing something with 0");
+  SubOp subOp = rhsIsZero ? op.getLhs().getDefiningOp<SubOp>()
+                          : op.getRhs().getDefiningOp<SubOp>();
+  if (!subOp)
+    return rewriter.notifyMatchFailure(
+        op.getLoc(), "non-zero operand is not a result of subtraction");
+
+  // `(x - y) cmp 0` becomes `x cmp y`; `0 cmp (x - y)` becomes `y cmp x`.
+  Value newLhs = rhsIsZero ? subOp.getLhs() : subOp.getRhs();
+  Value newRhs = rhsIsZero ? subOp.getRhs() : subOp.getLhs();
+  rewriter.replaceOpWithNewOp<CmpOp>(op, pred, newLhs, newRhs);
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // ConstantOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Index/index-canonicalize.mlir b/mlir/test/Dialect/Index/index-canonicalize.mlir
--- a/mlir/test/Dialect/Index/index-canonicalize.mlir
+++ b/mlir/test/Dialect/Index/index-canonicalize.mlir
@@ -473,7 +473,7 @@
 }
 
 // CHECK-LABEL: @cmp
-func.func @cmp() -> (i1, i1, i1, i1) {
+func.func @cmp(%arg0: index, %arg1: index) -> (i1, i1, i1, i1, i1, i1, i1, i1) {
   %a = index.constant 0
   %b = index.constant -1
   %c = index.constant -2
@@ -484,10 +484,28 @@
   %2 = index.cmp ne(%d, %a)
   %3 = index.cmp sgt(%b, %a)
 
+  // `(x - y) eq 0` and `0 ne (x - y)` are rewritten to compare x and y.
+  %4 = index.sub %arg0, %arg1
+  %5 = index.cmp eq(%4, %a)
+  %6 = index.cmp ne(%a, %4)
+
+  // Ordered predicates must not be rewritten: with %arg0 = the minimum signed
+  // value, `0 - %arg0` wraps back to it, so `(0 - %arg0) sgt 0` is false
+  // while `0 sgt %arg0` is true.
+  %7 = index.sub %a, %arg0
+  %8 = index.cmp sgt(%7, %a)
+  %9 = index.cmp sgt(%a, %7)
+
   // CHECK-DAG: %[[TRUE:.*]] = index.bool.constant true
   // CHECK-DAG: %[[FALSE:.*]] = index.bool.constant false
-  // CHECK: return %[[FALSE]], %[[TRUE]], %[[TRUE]], %[[FALSE]]
-  return %0, %1, %2, %3 : i1, i1, i1, i1
+  // CHECK-DAG: %[[ZERO:.*]] = index.constant 0
+  // CHECK-DAG: %[[EQ:.*]] = index.cmp eq(%arg0, %arg1)
+  // CHECK-DAG: %[[NE:.*]] = index.cmp ne(%arg1, %arg0)
+  // CHECK-DAG: %[[SUB:.*]] = index.sub %[[ZERO]], %arg0
+  // CHECK-DAG: %[[SGT0:.*]] = index.cmp sgt(%[[SUB]], %[[ZERO]])
+  // CHECK-DAG: %[[SGT1:.*]] = index.cmp sgt(%[[ZERO]], %[[SUB]])
+  // CHECK: return %[[FALSE]], %[[TRUE]], %[[TRUE]], %[[FALSE]], %[[EQ]], %[[NE]], %[[SGT0]], %[[SGT1]]
+  return %0, %1, %2, %3, %5, %6, %8, %9 : i1, i1, i1, i1, i1, i1, i1, i1
 }
 
 // CHECK-LABEL: @cmp_nofold