diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -1454,6 +1454,7 @@
       Value getFalseValue() { return false_value(); }
   }];
 
+  let hasCanonicalizer = 1;
   let hasFolder = 1;
 }
 
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -1528,6 +1528,42 @@
 // SelectOp
 //===----------------------------------------------------------------------===//
 
+// Transforms an i1 select whose arms are the constants false/true into a
+// logical not, expressed as an xor with true:
+//
+//   select %arg, %false, %true : i1
+//
+// becomes
+//
+//   xor %arg, %true : i1
+struct SelectToNot : public OpRewritePattern<SelectOp> {
+  using OpRewritePattern<SelectOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(SelectOp op,
+                                PatternRewriter &rewriter) const override {
+    // Restrict to scalar i1 results so the replacement xor (of the i1
+    // condition and an i1 constant) has the same type as the select.
+    if (!op.getType().isInteger(1))
+      return failure();
+
+    if (!matchPattern(op.getTrueValue(), m_Zero()))
+      return failure();
+
+    if (!matchPattern(op.getFalseValue(), m_One()))
+      return failure();
+
+    // select(c, false, true) == !c == c ^ true; the false value is already
+    // the `true` constant, so reuse it as the xor's right-hand side.
+    rewriter.replaceOpWithNewOp<XOrOp>(op, op.condition(), op.getFalseValue());
+    return success();
+  }
+};
+
+void SelectOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                           MLIRContext *context) {
+  results.insert<SelectToNot>(context);
+}
+
 OpFoldResult SelectOp::fold(ArrayRef<Attribute> operands) {
   auto trueVal = getTrueValue();
   auto falseVal = getFalseValue();
diff --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir
--- a/mlir/test/Dialect/Standard/canonicalize.mlir
+++ b/mlir/test/Dialect/Standard/canonicalize.mlir
@@ -319,3 +319,16 @@
 ^exit:
   return
 }
+
+// -----
+
+// CHECK-LABEL: @selToNot
+// CHECK:       %[[trueval:.+]] = constant true
+// CHECK:       %[[res:.+]] = xor %arg0, %[[trueval]] : i1
+// CHECK-NEXT:  return %[[res]] : i1
+func @selToNot(%arg0: i1) -> i1 {
+  %true = constant true
+  %false = constant false
+  %res = select %arg0, %false, %true : i1
+  return %res : i1
+}