diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -813,29 +813,31 @@
 // SelectOp
 //===----------------------------------------------------------------------===//
 
-// Transforms a select to a not, where relevant.
+// Transforms a select of a boolean to arithmetic operations.
 //
-//  select %arg, %false, %true
+//  select %arg, %x, %y : i1
 //
 //  becomes
 //
-//  xor %arg, %true
-struct SelectToNot : public OpRewritePattern<SelectOp> {
+//  or(and(%arg, %x), and(xor(%arg, true), %y))
+struct SelectI1Simplify : public OpRewritePattern<SelectOp> {
   using OpRewritePattern<SelectOp>::OpRewritePattern;
 
   LogicalResult matchAndRewrite(SelectOp op,
                                 PatternRewriter &rewriter) const override {
-    if (!matchPattern(op.getTrueValue(), m_Zero()))
-      return failure();
-
-    if (!matchPattern(op.getFalseValue(), m_One()))
-      return failure();
-
     if (!op.getType().isInteger(1))
       return failure();
 
-    rewriter.replaceOpWithNewOp<arith::XOrIOp>(op, op.getCondition(),
-                                               op.getFalseValue());
+    Value trueConstant =
+        rewriter.create<arith::ConstantIntOp>(op.getLoc(), true, 1);
+    Value notCondition = rewriter.create<arith::XOrIOp>(
+        op.getLoc(), op.getCondition(), trueConstant);
+
+    Value trueVal = rewriter.create<arith::AndIOp>(
+        op.getLoc(), op.getCondition(), op.getTrueValue());
+    Value falseVal = rewriter.create<arith::AndIOp>(op.getLoc(), notCondition,
+                                                    op.getFalseValue());
+    rewriter.replaceOpWithNewOp<arith::OrIOp>(op, trueVal, falseVal);
     return success();
   }
 };
@@ -876,7 +878,7 @@
 
 void SelectOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                            MLIRContext *context) {
-  results.insert<SelectToNot>(context);
+  results.insert<SelectI1Simplify>(context);
 }
 
 OpFoldResult SelectOp::fold(ArrayRef<Attribute> operands) {
diff --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir
--- a/mlir/test/Dialect/Standard/canonicalize.mlir
+++ b/mlir/test/Dialect/Standard/canonicalize.mlir
@@ -88,10 +88,23 @@
 
 // CHECK-LABEL: @selToNot
 // CHECK: %[[trueval:.+]] = arith.constant true
-// CHECK: %{{.+}} = arith.xori %arg0, %[[trueval]] : i1
+// CHECK: %[[res:.+]] = arith.xori %arg0, %[[trueval]] : i1
+// CHECK: return %[[res]]
 func @selToNot(%arg0: i1) -> i1 {
   %true = arith.constant true
   %false = arith.constant false
   %res = select %arg0, %false, %true : i1
   return %res : i1
 }
+
+// CHECK-LABEL: @selToArith
+// CHECK-NEXT: %[[trueval:.+]] = arith.constant true
+// CHECK-NEXT: %[[notcmp:.+]] = arith.xori %arg0, %[[trueval]] : i1
+// CHECK-NEXT: %[[condtrue:.+]] = arith.andi %arg0, %arg1 : i1
+// CHECK-NEXT: %[[condfalse:.+]] = arith.andi %[[notcmp]], %arg2 : i1
+// CHECK-NEXT: %[[res:.+]] = arith.ori %[[condtrue]], %[[condfalse]] : i1
+// CHECK: return %[[res]]
+func @selToArith(%arg0: i1, %arg1 : i1, %arg2 : i1) -> i1 {
+  %res = select %arg0, %arg1, %arg2 : i1
+  return %res : i1
+}