diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -566,6 +566,54 @@
 // BranchOp
 //===----------------------------------------------------------------------===//
 
+/// Given a successor, try to collapse it to a new destination if it only
+/// contains a passthrough unconditional branch. If the successor is
+/// collapsible, `successor` and `successorOperands` are updated to reference
+/// the new destination and values. `argStorage` is an optional storage to use
+/// if operands to the collapsed successor need to be remapped.
+static LogicalResult collapseBranch(Block *&successor,
+                                    ValueRange &successorOperands,
+                                    SmallVectorImpl<Value> &argStorage) {
+  // Check that the successor contains nothing but its terminator.
+  if (std::next(successor->begin()) != successor->end())
+    return failure();
+  // Check that the terminator is an unconditional branch.
+  BranchOp successorBranch = dyn_cast<BranchOp>(successor->getTerminator());
+  if (!successorBranch)
+    return failure();
+  // Check that the arguments are only used within the terminator.
+  for (BlockArgument arg : successor->getArguments()) {
+    for (Operation *user : arg.getUsers())
+      if (user != successorBranch)
+        return failure();
+  }
+  // Don't try to collapse branches to infinite loops.
+  Block *successorDest = successorBranch.getDest();
+  if (successorDest == successor)
+    return failure();
+
+  // Update the operands to the successor. If the successor block has no
+  // arguments, we can use the branch operands directly.
+  OperandRange operands = successorBranch.getOperands();
+  if (successor->args_empty()) {
+    successor = successorDest;
+    successorOperands = operands;
+    return success();
+  }
+
+  // Otherwise, we need to remap any argument operands.
+  for (Value operand : operands) {
+    BlockArgument argOperand = operand.dyn_cast<BlockArgument>();
+    if (argOperand && argOperand.getOwner() == successor)
+      argStorage.push_back(successorOperands[argOperand.getArgNumber()]);
+    else
+      argStorage.push_back(operand);
+  }
+  successor = successorDest;
+  successorOperands = argStorage;
+  return success();
+}
+
 namespace {
 /// Simplify a branch to a block that has a single predecessor. This effectively
 /// merges the two blocks.
@@ -586,6 +634,33 @@
     return success();
   }
 };
+
+/// br ^bb1
+/// ^bb1:
+///   br ^bbN(...)
+///
+///  -> br ^bbN(...)
+///
+struct SimplifyPassThroughBr : public OpRewritePattern<BranchOp> {
+  using OpRewritePattern<BranchOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(BranchOp op,
+                                PatternRewriter &rewriter) const override {
+    Block *dest = op.getDest();
+    ValueRange destOperands = op.getOperands();
+    SmallVector<Value, 4> destOperandStorage;
+
+    // Try to collapse the successor if it points somewhere other than this
+    // block.
+    if (dest == op.getOperation()->getBlock() ||
+        failed(collapseBranch(dest, destOperands, destOperandStorage)))
+      return failure();
+
+    // Create a new branch with the collapsed successor.
+    rewriter.replaceOpWithNewOp<BranchOp>(op, dest, destOperands);
+    return success();
+  }
+};
 } // end anonymous namespace.
 
 Block *BranchOp::getDest() { return getSuccessor(); }
@@ -598,7 +673,8 @@
 
 void BranchOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                            MLIRContext *context) {
-  results.insert<SimplifyBrToBlockWithSinglePred>(context);
+  results.insert<SimplifyBrToBlockWithSinglePred, SimplifyPassThroughBr>(
+      context);
 }
 
 Optional<OperandRange> BranchOp::getSuccessorOperands(unsigned index) {
@@ -889,53 +965,6 @@
                                               falseDest, falseDestOperands);
     return success();
   }
-
-  /// Given a successor, try to collapse it to a new destination if it only
-  /// contains a passthrough unconditional branch. If the successor is
-  /// collapsable, `successor` and `successorOperands` are updated to reference
-  /// the new destination and values. `argStorage` is an optional storage to use
-  /// if operands to the collapsed successor need to be remapped.
-  LogicalResult collapseBranch(Block *&successor, ValueRange &successorOperands,
-                               SmallVectorImpl<Value> &argStorage) const {
-    // Check that the successor only contains a unconditional branch.
-    if (std::next(successor->begin()) != successor->end())
-      return failure();
-    // Check that the terminator is an unconditional branch.
-    BranchOp successorBranch = dyn_cast<BranchOp>(successor->getTerminator());
-    if (!successorBranch)
-      return failure();
-    // Check that the arguments are only used within the terminator.
-    for (BlockArgument arg : successor->getArguments()) {
-      for (Operation *user : arg.getUsers())
-        if (user != successorBranch)
-          return failure();
-    }
-    // Don't try to collapse branches to infinite loops.
-    Block *successorDest = successorBranch.getDest();
-    if (successorDest == successor)
-      return failure();
-
-    // Update the operands to the successor. If the branch parent has no
-    // arguments, we can use the branch operands directly.
-    OperandRange operands = successorBranch.getOperands();
-    if (successor->args_empty()) {
-      successor = successorBranch.getDest();
-      successorOperands = operands;
-      return success();
-    }
-
-    // Otherwise, we need to remap any argument operands.
-    for (Value operand : operands) {
-      BlockArgument argOperand = operand.dyn_cast<BlockArgument>();
-      if (argOperand && argOperand.getOwner() == successor)
-        argStorage.push_back(successorOperands[argOperand.getArgNumber()]);
-      else
-        argStorage.push_back(operand);
-    }
-    successor = successorBranch.getDest();
-    successorOperands = argStorage;
-    return success();
-  }
 };
 
 /// cond_br %cond, ^bb1(A, ..., N), ^bb1(A, ..., N)
diff --git a/mlir/test/Dialect/Standard/canonicalize-cf.mlir b/mlir/test/Dialect/Standard/canonicalize-cf.mlir
--- a/mlir/test/Dialect/Standard/canonicalize-cf.mlir
+++ b/mlir/test/Dialect/Standard/canonicalize-cf.mlir
@@ -12,6 +12,26 @@
   return %x : i32
 }
 
+/// Test that pass-through successors of BranchOp get folded.
+
+// CHECK-LABEL: func @br_passthrough(
+// CHECK-SAME: %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32
+func @br_passthrough(%arg0 : i32, %arg1 : i32) -> (i32, i32) {
+  "foo.switch"() [^bb1, ^bb2, ^bb3] : () -> ()
+
+^bb1:
+  // CHECK: ^bb1:
+  // CHECK-NEXT: br ^bb3(%[[ARG0]], %[[ARG1]] : i32, i32)
+
+  br ^bb2(%arg0 : i32)
+
+^bb2(%arg2 : i32):
+  br ^bb3(%arg2, %arg1 : i32, i32)
+
+^bb3(%arg4 : i32, %arg5 : i32):
+  return %arg4, %arg5 : i32, i32
+}
+
 /// Test the folding of CondBranchOp with a constant condition.
 
 // CHECK-LABEL: func @cond_br_folding(
@@ -103,9 +123,9 @@
 
 /// Test that pass-through successors of CondBranchOp get folded.
 
-// CHECK-LABEL: func @cond_br_pass_through(
+// CHECK-LABEL: func @cond_br_passthrough(
 // CHECK-SAME: %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[COND:.*]]: i1
-func @cond_br_pass_through(%arg0 : i32, %arg1 : i32, %arg2 : i32, %cond : i1) -> (i32, i32) {
+func @cond_br_passthrough(%arg0 : i32, %arg1 : i32, %arg2 : i32, %cond : i1) -> (i32, i32) {
   // CHECK: %[[RES:.*]] = select %[[COND]], %[[ARG0]], %[[ARG2]]
   // CHECK: %[[RES2:.*]] = select %[[COND]], %[[ARG1]], %[[ARG2]]
   // CHECK: return %[[RES]], %[[RES2]]