diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -968,9 +968,43 @@ } }; +// select %arg, %c1, %c0 => extui %arg (inverted form: extui of xor %arg, true) +struct SelectToExtUI : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(SelectOp op, + PatternRewriter &rewriter) const override { + // Cannot extui i1 to i1, or i1 to f32 + if (!op.getType().isa() || op.getType().isInteger(1)) + return failure(); + + // select %arg, %c1, %c0 => extui %arg + if (matchPattern(op.getTrueValue(), m_One())) + if (matchPattern(op.getFalseValue(), m_Zero())) { + rewriter.replaceOpWithNewOp(op, op.getType(), + op.getCondition()); + return success(); + } + + // select %arg, %c0, %c1 => extui (xor %arg, true) + if (matchPattern(op.getTrueValue(), m_Zero())) + if (matchPattern(op.getFalseValue(), m_One())) { + rewriter.replaceOpWithNewOp( + op, op.getType(), + rewriter.create( + op.getLoc(), op.getCondition(), + rewriter.create( + op.getLoc(), 1, op.getCondition().getType()))); + return success(); + } + + return failure(); + } +}; + void SelectOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { - results.insert(context); + results.insert(context); } OpFoldResult SelectOp::fold(ArrayRef operands) { @@ -989,6 +1023,12 @@ if (matchPattern(condition, m_Zero())) return falseVal; + // select %x, true, false => %x (i1 result is the condition itself) + if (getType().isInteger(1)) + if (matchPattern(getTrueValue(), m_One())) + if (matchPattern(getFalseValue(), m_Zero())) + return condition; + if (auto cmp = dyn_cast_or_null(condition.getDefiningOp())) { auto pred = cmp.getPredicate(); if (pred == arith::CmpIPredicate::eq || pred == arith::CmpIPredicate::ne) { diff --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir --- a/mlir/test/Dialect/Standard/canonicalize.mlir +++ 
b/mlir/test/Dialect/Standard/canonicalize.mlir @@ -29,6 +29,41 @@ // ----- +// CHECK-LABEL: @select_extui +// CHECK: %[[res:.+]] = arith.extui %arg0 : i1 to i64 +// CHECK: return %[[res]] +func @select_extui(%arg0: i1) -> i64 { + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %res = select %arg0, %c1_i64, %c0_i64 : i64 + return %res : i64 +} + +// CHECK-LABEL: @select_extui2 +// CHECK-DAG: %true = arith.constant true +// CHECK-DAG: %[[xor:.+]] = arith.xori %arg0, %true : i1 +// CHECK-DAG: %[[res:.+]] = arith.extui %[[xor]] : i1 to i64 +// CHECK: return %[[res]] +func @select_extui2(%arg0: i1) -> i64 { + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %res = select %arg0, %c0_i64, %c1_i64 : i64 + return %res : i64 +} + +// ----- + +// CHECK-LABEL: @select_extui_i1 +// CHECK-NEXT: return %arg0 +func @select_extui_i1(%arg0: i1) -> i1 { + %c0_i1 = arith.constant false + %c1_i1 = arith.constant true + %res = select %arg0, %c1_i1, %c0_i1 : i1 + return %res : i1 +} + +// ----- + // CHECK-LABEL: @branchCondProp // CHECK: %[[trueval:.+]] = arith.constant true // CHECK: %[[falseval:.+]] = arith.constant false