diff --git a/mlir/include/mlir/Conversion/ControlFlowToSCF/ControlFlowToSCF.h b/mlir/include/mlir/Conversion/ControlFlowToSCF/ControlFlowToSCF.h --- a/mlir/include/mlir/Conversion/ControlFlowToSCF/ControlFlowToSCF.h +++ b/mlir/include/mlir/Conversion/ControlFlowToSCF/ControlFlowToSCF.h @@ -32,10 +32,9 @@ MutableArrayRef regions) override; /// Creates an `scf.yield` op returning the given results. - LogicalResult - createStructuredBranchRegionTerminatorOp(Location loc, OpBuilder &builder, - Operation *branchRegionOp, - ValueRange results) override; + LogicalResult createStructuredBranchRegionTerminatorOp( + Location loc, OpBuilder &builder, Operation *branchRegionOp, + Operation *replacedControlFlowOp, ValueRange results) override; /// Creates an `scf.while` op. The loop body is made the before-region of the /// while op and terminated with an `scf.condition` op. The after-region does diff --git a/mlir/include/mlir/Transforms/CFGToSCF.h b/mlir/include/mlir/Transforms/CFGToSCF.h --- a/mlir/include/mlir/Transforms/CFGToSCF.h +++ b/mlir/include/mlir/Transforms/CFGToSCF.h @@ -42,12 +42,14 @@ /// Creates a return-like terminator for a branch region of the op returned /// by `createStructuredBranchRegionOp`. `branchRegionOp` is the operation - /// returned by `createStructuredBranchRegionOp` while `results` are the - /// values that should be returned by the branch region. - virtual LogicalResult - createStructuredBranchRegionTerminatorOp(Location loc, OpBuilder &builder, - Operation *branchRegionOp, - ValueRange results) = 0; + /// returned by `createStructuredBranchRegionOp`. + /// `replacedControlFlowOp` is the control flow op being replaced by the + /// terminator or nullptr if the terminator is not replacing any existing + /// control flow op. `results` are the values that should be returned by the + /// branch region. 
+ virtual LogicalResult createStructuredBranchRegionTerminatorOp( + Location loc, OpBuilder &builder, Operation *branchRegionOp, + Operation *replacedControlFlowOp, ValueRange results) = 0; /// Creates a structured control flow operation representing a do-while loop. /// The do-while loop is expected to have the exact same result types as the @@ -77,8 +79,10 @@ /// `caseDestinations` or `defaultDest`. This is used by the transformation /// for intermediate transformations before lifting to structured control /// flow. The switch op branches based on `flag` which is guaranteed to be of - /// the same type as values returned by `getCFGSwitchValue`. Note: - /// `caseValues` and other related ranges may be empty to represent an + /// the same type as values returned by `getCFGSwitchValue`. The insertion + /// block of the builder is guaranteed to have its predecessors already set + /// to create an equivalent CFG after this operation. + /// Note: `caseValues` and other related ranges may be empty to represent an /// unconditional branch. 
virtual void createCFGSwitchOp(Location loc, OpBuilder &builder, Value flag, ArrayRef caseValues, diff --git a/mlir/lib/Conversion/ControlFlowToSCF/ControlFlowToSCF.cpp b/mlir/lib/Conversion/ControlFlowToSCF/ControlFlowToSCF.cpp --- a/mlir/lib/Conversion/ControlFlowToSCF/ControlFlowToSCF.cpp +++ b/mlir/lib/Conversion/ControlFlowToSCF/ControlFlowToSCF.cpp @@ -76,7 +76,7 @@ LogicalResult ControlFlowToSCFTransformation::createStructuredBranchRegionTerminatorOp( Location loc, OpBuilder &builder, Operation *branchRegionOp, - ValueRange results) { + Operation *replacedControlFlowOp, ValueRange results) { builder.create(loc, results); return success(); } diff --git a/mlir/lib/Transforms/Utils/CFGToSCF.cpp b/mlir/lib/Transforms/Utils/CFGToSCF.cpp --- a/mlir/lib/Transforms/Utils/CFGToSCF.cpp +++ b/mlir/lib/Transforms/Utils/CFGToSCF.cpp @@ -306,7 +306,8 @@ /// Creates a switch op using `builder` which dispatches to the original /// successors of the edges passed to `create` minus the ones in `excluded`. /// The builder's insertion point has to be in a block dominated by the - /// multiplexer block. + /// multiplexer block. All edges to the multiplexer block must have already + /// been redirected using `redirectEdge`. 
void createSwitch( Location loc, OpBuilder &builder, CFGToSCFInterface &interface, const SmallPtrSetImpl &excluded = SmallPtrSet{}) { @@ -337,6 +338,8 @@ Block *defaultDest = caseDestinations.pop_back_val(); ValueRange defaultArgs = caseArguments.pop_back_val(); + assert(!builder.getInsertionBlock()->hasNoPredecessors() && + "Edges need to be redirected prior to creating switch."); interface.createCFGSwitchOp(loc, builder, realDiscriminator, caseValues, caseDestinations, caseArguments, defaultDest, defaultArgs); @@ -507,12 +510,14 @@ loc, llvm::map_to_vector(entryEdges, std::mem_fn(&Edge::getSuccessor)), getSwitchValue, getUndefValue); - auto builder = OpBuilder::atBlockBegin(result.getMultiplexerBlock()); - result.createSwitch(loc, builder, interface); - + // Redirect the edges prior to creating the switch op. + // We guarantee that predecessors are up to date. for (Edge edge : entryEdges) result.redirectEdge(edge); + auto builder = OpBuilder::atBlockBegin(result.getMultiplexerBlock()); + result.createSwitch(loc, builder, interface); + return result; } @@ -565,6 +570,17 @@ // Since this is a loop, all back edges point to the same loop header. Block *loopHeader = backEdges.front().getSuccessor(); + // Redirect the edges prior to creating the switch op. + // We guarantee that predecessors are up to date. + + // Redirecting back edges with `shouldRepeat` as 1. + for (Edge backEdge : backEdges) + multiplexer.redirectEdge(backEdge, /*extraArgs=*/getSwitchValue(1)); + + // Redirecting exit edges with `shouldRepeat` as 0. + for (Edge exitEdge : exitEdges) + multiplexer.redirectEdge(exitEdge, /*extraArgs=*/getSwitchValue(0)); + // Create the new only back edge to the loop header. Branch to the // exit block otherwise. Value shouldRepeat = latchBlock->getArguments().back(); @@ -603,14 +619,6 @@ } } - // Redirecting back edges with `shouldRepeat` as 1. 
- for (Edge backEdge : backEdges) - multiplexer.redirectEdge(backEdge, /*extraArgs=*/getSwitchValue(1)); - - // Redirecting exits edges with `shouldRepeat` as 0. - for (Edge exitEdge : exitEdges) - multiplexer.redirectEdge(exitEdge, /*extraArgs=*/getSwitchValue(0)); - return StructuredLoopProperties{latchBlock, /*condition=*/shouldRepeat, exitBlock}; } @@ -794,13 +802,14 @@ // First turn the cycle into a loop by creating a single entry block if // needed. if (edges.entryEdges.size() > 1) { + SmallVector edgesToEntryBlocks; + llvm::append_range(edgesToEntryBlocks, edges.entryEdges); + llvm::append_range(edgesToEntryBlocks, edges.backEdges); + EdgeMultiplexer multiplexer = createSingleEntryBlock( - loopHeader->getTerminator()->getLoc(), edges.entryEdges, + loopHeader->getTerminator()->getLoc(), edgesToEntryBlocks, getSwitchValue, getUndefValue, interface); - for (Edge edge : edges.backEdges) - multiplexer.redirectEdge(edge); - loopHeader = multiplexer.getMultiplexerBlock(); } cycleBlockSet.insert(loopHeader); @@ -1140,7 +1149,8 @@ for (auto &&[block, valueRange] : createdEmptyBlocks) { auto builder = OpBuilder::atBlockEnd(block); LogicalResult result = interface.createStructuredBranchRegionTerminatorOp( - structuredCondOp->getLoc(), builder, structuredCondOp, valueRange); + structuredCondOp->getLoc(), builder, structuredCondOp, nullptr, + valueRange); if (failed(result)) return failure(); } @@ -1153,7 +1163,7 @@ assert(user->getNumSuccessors() == 1); auto builder = OpBuilder::atBlockTerminator(user->getBlock()); LogicalResult result = interface.createStructuredBranchRegionTerminatorOp( - user->getLoc(), builder, structuredCondOp, + user->getLoc(), builder, structuredCondOp, user, static_cast( getMutableSuccessorOperands(user->getBlock(), 0))); if (failed(result))