diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -1064,12 +1064,55 @@ return success(); } }; + +/// cond_br %cond, ^bb1(...), ^bb2(...) +/// ... +/// ^bb1: +/// cond_br %cond, ^bb3(...), ^bb4(...) +/// +/// -> +/// +/// cond_br %cond, ^bb1(...), ^bb2(...) +/// ... +/// ^bb1: +/// br ^bb3(...) +/// +struct SimplifyCondBranchFromCondBranchOnSameCondition + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(CondBranchOp condbr, + PatternRewriter &rewriter) const override { + // Require exactly one incoming edge, so we know which edge was taken. + Block *currentBlock = condbr.getOperation()->getBlock(); + Block *predecessor = currentBlock->getSinglePredecessor(); + if (!predecessor) + return failure(); + + // Check that the predecessor terminates with a conditional branch to this + // block and that it branches on the same condition. + auto predBranch = dyn_cast(predecessor->getTerminator()); + if (!predBranch || condbr.getCondition() != predBranch.getCondition()) + return failure(); + + // The condition value is known here, so fold to an unconditional branch. 
+ if (currentBlock == predBranch.trueDest()) { + rewriter.replaceOpWithNewOp(condbr, condbr.trueDest(), + condbr.trueDestOperands()); + } else { + rewriter.replaceOpWithNewOp(condbr, condbr.falseDest(), + condbr.falseDestOperands()); + } + return success(); + } +}; } // end anonymous namespace void CondBranchOp::getCanonicalizationPatterns( OwningRewritePatternList &results, MLIRContext *context) { results.insert(context); + SimplifyCondBranchIdenticalSuccessors, + SimplifyCondBranchFromCondBranchOnSameCondition>(context); } Optional diff --git a/mlir/test/Dialect/Standard/canonicalize-cf.mlir b/mlir/test/Dialect/Standard/canonicalize-cf.mlir --- a/mlir/test/Dialect/Standard/canonicalize-cf.mlir +++ b/mlir/test/Dialect/Standard/canonicalize-cf.mlir @@ -139,6 +139,27 @@ return } +/// Test folding a conditional branch whose block is only reachable from a +/// conditional branch on the same condition. + +// CHECK-LABEL: func @cond_br_from_cond_br_with_same_condition +func @cond_br_from_cond_br_with_same_condition(%cond : i1) { + // CHECK: cond_br %{{.*}}, ^bb1, ^bb2 + // CHECK: ^bb1: + // CHECK: return + + cond_br %cond, ^bb1, ^bb2 + +^bb1: + cond_br %cond, ^bb3, ^bb2 + +^bb2: + "foo.terminator"() : () -> () + +^bb3: + return +} + // ----- // Erase assertion if condition is known to be true at compile time. diff --git a/mlir/test/Transforms/canonicalize-block-merge.mlir b/mlir/test/Transforms/canonicalize-block-merge.mlir --- a/mlir/test/Transforms/canonicalize-block-merge.mlir +++ b/mlir/test/Transforms/canonicalize-block-merge.mlir @@ -178,23 +178,23 @@ // block is used in another. 
// CHECK-LABEL: func @mismatch_loop( -// CHECK-SAME: %[[ARG:.*]]: i1 -func @mismatch_loop(%cond : i1) { - // CHECK: cond_br %{{.*}}, ^bb1(%[[ARG]] : i1), ^bb2 +// CHECK-SAME: %[[ARG:.*]]: i1, %[[ARG2:.*]]: i1 +func @mismatch_loop(%cond : i1, %cond2 : i1) { + // CHECK: cond_br %{{.*}}, ^bb1(%[[ARG2]] : i1), ^bb2 cond_br %cond, ^bb2, ^bb3 ^bb1: - // CHECK: ^bb1(%[[ARG2:.*]]: i1): + // CHECK: ^bb1(%[[ARG3:.*]]: i1): // CHECK-NEXT: %[[LOOP_CARRY:.*]] = "foo.op" - // CHECK-NEXT: cond_br %[[ARG2]], ^bb1(%[[LOOP_CARRY]] : i1), ^bb2 + // CHECK-NEXT: cond_br %[[ARG3]], ^bb1(%[[LOOP_CARRY]] : i1), ^bb2 %ignored = "foo.op"() : () -> (i1) - cond_br %cond2, ^bb1, ^bb3 + cond_br %cond3, ^bb1, ^bb3 ^bb2: - %cond2 = "foo.op"() : () -> (i1) - cond_br %cond, ^bb1, ^bb3 + %cond3 = "foo.op"() : () -> (i1) + cond_br %cond2, ^bb1, ^bb3 ^bb3: // CHECK: ^bb2: