diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -862,11 +862,93 @@ return failure(); } }; -} // end anonymous namespace. + +/// cond_br %cond, ^bb1, ^bb2 +/// ^bb1 +/// br ^bbN(...) +/// ^bb2 +/// br ^bbK(...) +/// +///  -> cond_br %cond, ^bbN(...), ^bbK(...) +/// Successors that merely forward to another block are bypassed. +struct SimplifyPassThroughCondBranch : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(CondBranchOp condbr, + PatternRewriter &rewriter) const override { + Block *trueDest = condbr.trueDest(), *falseDest = condbr.falseDest(); + ValueRange trueDestOperands = condbr.getTrueOperands(); + ValueRange falseDestOperands = condbr.getFalseOperands(); + SmallVector trueDestOperandStorage, falseDestOperandStorage; + + // Try to collapse each successor; at least one must collapse. + LogicalResult collapsedTrue = + collapseBranch(trueDest, trueDestOperands, trueDestOperandStorage); + LogicalResult collapsedFalse = + collapseBranch(falseDest, falseDestOperands, falseDestOperandStorage); + if (failed(collapsedTrue) && failed(collapsedFalse)) + return failure(); + + // Replace the cond_br with one that targets the collapsed successors. + rewriter.replaceOpWithNewOp(condbr, condbr.getCondition(), + trueDest, trueDestOperands, + falseDest, falseDestOperands); + return success(); + } + + /// Given a successor, try to collapse it to a new destination if it only + /// contains a pass-through unconditional branch. If the successor is + /// collapsible, `successor` and `successorOperands` are updated to reference + /// the new destination and values. Remapped operands are written to + /// `argStorage`, which `successorOperands` then refers to. + LogicalResult collapseBranch(Block *&successor, ValueRange &successorOperands, + SmallVectorImpl &argStorage) const { + // Check that the successor contains a single operation: its terminator.
+ if (std::next(successor->begin()) != successor->end()) + return failure(); + // Check that the terminator is an unconditional branch. + BranchOp successorBranch = dyn_cast(successor->getTerminator()); + if (!successorBranch) + return failure(); + // Check that the arguments are only used within the terminator. + for (BlockArgument arg : successor->getArguments()) { + for (Operation *user : arg.getUsers()) + if (user != successorBranch) + return failure(); + } + // Don't collapse a block that branches to itself (an infinite loop). + Block *successorDest = successorBranch.getDest(); + if (successorDest == successor) + return failure(); + + // Update the operands to the successor. If the collapsed block has no + // arguments, the branch operands can be forwarded directly. + OperandRange operands = successorBranch.getOperands(); + if (successor->args_empty()) { + successor = successorBranch.getDest(); + successorOperands = operands; + return success(); + } + + // Otherwise, map the collapsed block's arguments to their incoming values. + for (Value operand : operands) { + BlockArgument argOperand = operand.dyn_cast(); + if (argOperand && argOperand.getOwner() == successor) + argStorage.push_back(successorOperands[argOperand.getArgNumber()]); + else + argStorage.push_back(operand); + } + successor = successorBranch.getDest(); + successorOperands = argStorage; + return success(); + } +}; +} // end anonymous namespace void CondBranchOp::getCanonicalizationPatterns( OwningRewritePatternList &results, MLIRContext *context) { - results.insert(context); + results.insert( + context); } Optional CondBranchOp::getSuccessorOperands(unsigned index) { diff --git a/mlir/test/Dialect/Standard/canonicalize-cf.mlir b/mlir/test/Dialect/Standard/canonicalize-cf.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Standard/canonicalize-cf.mlir @@ -0,0 +1,93 @@ +// RUN: mlir-opt %s -allow-unregistered-dialect -pass-pipeline='func(canonicalize)' -split-input-file | FileCheck %s + +// Test the folding of BranchOp.
+ +// CHECK-LABEL: func @br_folding( +func @br_folding() -> i32 { + // CHECK-NEXT: %[[CST:.*]] = constant 0 : i32 + // CHECK-NEXT: return %[[CST]] : i32 + %c0_i32 = constant 0 : i32 + br ^bb1(%c0_i32 : i32) +^bb1(%x : i32): + return %x : i32 +} + +// Test the folding of CondBranchOp with a constant condition; the resulting pass-through blocks are then collapsed. + +// CHECK-LABEL: func @cond_br_folding( +func @cond_br_folding(%cond : i1, %a : i32) { + // CHECK-NEXT: cond_br %{{.*}}, ^bb1, ^bb1 + + %false_cond = constant 0 : i1 + %true_cond = constant 1 : i1 + cond_br %cond, ^bb1, ^bb2(%a : i32) + +^bb1: + cond_br %true_cond, ^bb3, ^bb2(%a : i32) + +^bb2(%x : i32): + cond_br %false_cond, ^bb2(%x : i32), ^bb3 + +^bb3: + // CHECK: ^bb1: + // CHECK-NEXT: return + + return +} + +// Test the compound folding of BranchOp and CondBranchOp. + +// CHECK-LABEL: func @cond_br_and_br_folding( +func @cond_br_and_br_folding(%a : i32) { + // CHECK-NEXT: return + + %false_cond = constant 0 : i1 + %true_cond = constant 1 : i1 + cond_br %true_cond, ^bb2, ^bb1(%a : i32) + +^bb1(%x : i32): + cond_br %false_cond, ^bb1(%x : i32), ^bb2 + +^bb2: + return +} + +// Test that pass-through successors of CondBranchOp get folded, remapping block arguments. + +// CHECK-LABEL: func @cond_br_pass_through( +// CHECK-SAME: %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32 +func @cond_br_pass_through(%arg0 : i32, %arg1 : i32, %arg2 : i32, %cond : i1) -> (i32, i32) { + // CHECK: cond_br %{{.*}}, ^bb1(%[[ARG0]], %[[ARG1]] : i32, i32), ^bb1(%[[ARG2]], %[[ARG2]] : i32, i32) + + cond_br %cond, ^bb1(%arg0 : i32), ^bb2(%arg2, %arg2 : i32, i32) + +^bb1(%arg3: i32): + br ^bb2(%arg3, %arg1 : i32, i32) + +^bb2(%arg4: i32, %arg5: i32): + // CHECK: ^bb1(%[[RET0:.*]]: i32, %[[RET1:.*]]: i32): + // CHECK-NEXT: return %[[RET0]], %[[RET1]] + + return %arg4, %arg5 : i32, i32 +} + +// Test the failure modes of collapsing CondBranchOp pass-through successors.
+ +// CHECK-LABEL: func @cond_br_pass_through_fail( +func @cond_br_pass_through_fail(%cond : i1) { + // CHECK: cond_br %{{.*}}, ^bb1, ^bb2 + + cond_br %cond, ^bb1, ^bb2 + +^bb1: + // CHECK: ^bb1: + // CHECK: "foo.op" + // CHECK: br ^bb2 + + // Successors can't be collapsed if they contain operations other than the branch. + "foo.op"() : () -> () + br ^bb2 + +^bb2: + return +} diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir --- a/mlir/test/Transforms/canonicalize.mlir +++ b/mlir/test/Transforms/canonicalize.mlir @@ -506,52 +506,6 @@ return %Av : memref } -// CHECK-LABEL: func @br_folding -func @br_folding() -> i32 { - // CHECK-NEXT: %[[CST:.*]] = constant 0 : i32 - // CHECK-NEXT: return %[[CST]] : i32 - %c0_i32 = constant 0 : i32 - br ^bb1(%c0_i32 : i32) -^bb1(%x : i32): - return %x : i32 -} - -// CHECK-LABEL: func @cond_br_folding -func @cond_br_folding(%cond : i1, %a : i32) { - %false_cond = constant 0 : i1 - %true_cond = constant 1 : i1 - cond_br %cond, ^bb1, ^bb2(%a : i32) - -^bb1: - // CHECK: ^bb1: - // CHECK-NEXT: br ^bb3 - cond_br %true_cond, ^bb3, ^bb2(%a : i32) - -^bb2(%x : i32): - // CHECK: ^bb2 - // CHECK: br ^bb3 - cond_br %false_cond, ^bb2(%x : i32), ^bb3 - -^bb3: - return -} - -// CHECK-LABEL: func @cond_br_and_br_folding -func @cond_br_and_br_folding(%a : i32) { - // Test the compound folding of conditional and unconditional branches. - // CHECK-NEXT: return - - %false_cond = constant 0 : i1 - %true_cond = constant 1 : i1 - cond_br %true_cond, ^bb2, ^bb1(%a : i32) - -^bb1(%x : i32): - cond_br %false_cond, ^bb1(%x : i32), ^bb2 - -^bb2: - return -} - // CHECK-LABEL: func @indirect_call_folding func @indirect_target() { return