diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -918,11 +918,56 @@
     return success();
   }
 };
+
+/// Canonicalize an `scf.if` whose then and else blocks contain nothing but
+/// their `scf.yield` terminators into one `select` per result. When both
+/// branches yield the same value, that value is forwarded directly instead of
+/// creating a degenerate select.
+struct ConvertTrivialIfToSelect : public OpRewritePattern<IfOp> {
+  using OpRewritePattern<IfOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(IfOp op,
+                                PatternRewriter &rewriter) const override {
+    // Nothing to select without results. This also guarantees that the else
+    // region accessed below is non-empty: an scf.if with results must have
+    // an else block.
+    if (op->getNumResults() == 0)
+      return failure();
+
+    // Any operation other than the terminator (possibly with side effects)
+    // would have to be speculated, so only match terminator-only blocks.
+    if (!llvm::hasSingleElement(op.thenRegion().front()) ||
+        !llvm::hasSingleElement(op.elseRegion().front()))
+      return failure();
+
+    auto cond = op.condition();
+    auto thenYieldArgs =
+        mlir::cast<scf::YieldOp>(op.thenRegion().front().getTerminator())
+            .getOperands();
+    auto elseYieldArgs =
+        mlir::cast<scf::YieldOp>(op.elseRegion().front().getTerminator())
+            .getOperands();
+    SmallVector<Value, 4> results(op->getNumResults());
+    for (auto it : llvm::enumerate(llvm::zip(thenYieldArgs, elseYieldArgs))) {
+      auto trueVal = std::get<0>(it.value());
+      auto falseVal = std::get<1>(it.value());
+      if (trueVal == falseVal)
+        results[it.index()] = trueVal;
+      else
+        results[it.index()] =
+            rewriter.create<SelectOp>(op.getLoc(), cond, trueVal, falseVal);
+    }
+
+    rewriter.replaceOp(op, results);
+    return success();
+  }
+};
 } // namespace
 
 void IfOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                        MLIRContext *context) {
-  results.insert<RemoveUnusedResults, RemoveStaticCondition>(context);
+  results.insert<RemoveUnusedResults, RemoveStaticCondition,
+                 ConvertTrivialIfToSelect>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -35,10 +35,12 @@
 
 // -----
 
+func private @side_effect()
 func @one_unused(%cond: i1) -> (index) {
   %c0 = constant 0 : index
   %c1 = constant 1 : index
   %0, %1 = scf.if %cond -> (index, index) {
+    call @side_effect() : () -> ()
     scf.yield %c0, %c1 : index, index
   } else {
     scf.yield %c0, %c1 : index, index
@@ -49,6 +51,7 @@
 // CHECK-LABEL: func @one_unused
 // CHECK:   [[C0:%.*]] = constant 1 : index
 // CHECK:   [[V0:%.*]] = scf.if %{{.*}} -> (index) {
+// CHECK:     call @side_effect() : () -> ()
 // CHECK:     scf.yield [[C0]] : index
 // CHECK:   } else
 // CHECK:     scf.yield [[C0]] : index
@@ -57,11 +60,13 @@
 
 // -----
 
+func private @side_effect()
 func @nested_unused(%cond1: i1, %cond2: i1) -> (index) {
   %c0 = constant 0 : index
   %c1 = constant 1 : index
   %0, %1 = scf.if %cond1 -> (index, index) {
     %2, %3 = scf.if %cond2 -> (index, index) {
+      call @side_effect() : () -> ()
       scf.yield %c0, %c1 : index, index
     } else {
       scf.yield %c0, %c1 : index, index
@@ -77,6 +82,7 @@
 // CHECK:   [[C0:%.*]] = constant 1 : index
 // CHECK:   [[V0:%.*]] = scf.if {{.*}} -> (index) {
 // CHECK:     [[V1:%.*]] = scf.if {{.*}} -> (index) {
+// CHECK:       call @side_effect() : () -> ()
 // CHECK:       scf.yield [[C0]] : index
 // CHECK:     } else
 // CHECK:       scf.yield [[C0]] : index
@@ -113,6 +119,25 @@
 
 // -----
 
+func @to_select(%cond: i1) -> (index, index) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %0, %1 = scf.if %cond -> (index, index) {
+    scf.yield %c0, %c1 : index, index
+  } else {
+    scf.yield %c1, %c1 : index, index
+  }
+  return %0, %1 : index, index
+}
+
+// CHECK-LABEL: func @to_select
+// CHECK:   [[C0:%.*]] = constant 0 : index
+// CHECK:   [[C1:%.*]] = constant 1 : index
+// CHECK:   [[V0:%.*]] = select {{.*}}, [[C0]], [[C1]]
+// CHECK:   return [[V0]], [[C1]] : index
+
+// -----
+
 func private @make_i32() -> i32
 
 func @for_yields_2(%lb : index, %ub : index, %step : index) -> i32 {