diff --git a/mlir/include/mlir/Transforms/ControlFlowSinkUtils.h b/mlir/include/mlir/Transforms/ControlFlowSinkUtils.h --- a/mlir/include/mlir/Transforms/ControlFlowSinkUtils.h +++ b/mlir/include/mlir/Transforms/ControlFlowSinkUtils.h @@ -51,12 +51,19 @@ /// /// Users must supply a callback `shouldMoveIntoRegion` that determines whether /// the given operation that only has users in the given operation should be -/// moved into that region. +/// moved into that region. If this returns true, `moveIntoRegion` is called on +/// the same operation and region. +/// +/// `moveIntoRegion` must move the operation into the region such that dominance +/// of the operation is preserved; for example, by moving the operation to the +/// start of the entry block. This ensures the preservation of SSA dominance of +/// the operation's results. /// /// Returns the number of operations sunk. size_t controlFlowSink(ArrayRef<Region *> regions, DominanceInfo &domInfo, - function_ref<bool(Operation *, Region *)> shouldMoveIntoRegion); + function_ref<bool(Operation *, Region *)> shouldMoveIntoRegion, + function_ref<void(Operation *, Region *)> moveIntoRegion); /// Populates `regions` with regions of the provided region branch op that are /// executed at most once at that are reachable given the current operands of diff --git a/mlir/lib/Transforms/ControlFlowSink.cpp b/mlir/lib/Transforms/ControlFlowSink.cpp --- a/mlir/lib/Transforms/ControlFlowSink.cpp +++ b/mlir/lib/Transforms/ControlFlowSink.cpp @@ -60,9 +60,15 @@ // Get the regions are that known to be executed at most once. getSinglyExecutedRegionsToSink(branch, regionsToSink); // Sink side-effect free operations. - numSunk = - controlFlowSink(regionsToSink, domInfo, [](Operation *op, Region *) { - return isSideEffectFree(op); + numSunk = controlFlowSink( + regionsToSink, domInfo, + [](Operation *op, Region *) { return isSideEffectFree(op); }, + [](Operation *op, Region *region) { + // Move the operation to the beginning of the region's entry block.
+ // Users are sunk before their operands' defining ops and all of the + // operation's uses are in the region, so SSA dominance is preserved. + // Only the entry block is safe: other blocks may form a loop. + op->moveBefore(&region->front(), region->front().begin()); }); }); } diff --git a/mlir/lib/Transforms/Utils/ControlFlowSinkUtils.cpp b/mlir/lib/Transforms/Utils/ControlFlowSinkUtils.cpp --- a/mlir/lib/Transforms/Utils/ControlFlowSinkUtils.cpp +++ b/mlir/lib/Transforms/Utils/ControlFlowSinkUtils.cpp @@ -34,8 +34,10 @@ public: /// Create an operation sinker with given dominance info. Sinker(function_ref<bool(Operation *, Region *)> shouldMoveIntoRegion, + function_ref<void(Operation *, Region *)> moveIntoRegion, DominanceInfo &domInfo) - : shouldMoveIntoRegion(shouldMoveIntoRegion), domInfo(domInfo) {} + : shouldMoveIntoRegion(shouldMoveIntoRegion), + moveIntoRegion(moveIntoRegion), domInfo(domInfo), numSunk(0) {} /// Given a list of regions, find operations to sink and sink them. Return the /// number of operations sunk. @@ -61,6 +63,8 @@ /// The callback to determine whether an op should be moved in to a region. function_ref<bool(Operation *, Region *)> shouldMoveIntoRegion; + /// The callback to move an operation into the region. + function_ref<void(Operation *, Region *)> moveIntoRegion; /// Dominance info to determine op user dominance with respect to regions. DominanceInfo &domInfo; /// The number of operations sunk. @@ -90,12 +94,7 @@ // If the op's users are all in the region and it can be moved, then do so. if (allUsersDominatedBy(op, region) && shouldMoveIntoRegion(op, region)) { - // Move the op into the region's entry block. If the op is part of a - // subgraph, dependee ops would have been moved first, so inserting before - // the start of the block will ensure SSA dominance is preserved locally - // in the subgraph. Ops can only be safely moved into the entry block as - // the region's other blocks may for a loop. - op->moveBefore(&region->front(), region->front().begin()); + moveIntoRegion(op, region); ++numSunk; // Add the op to the work queue.
stack.push_back(op); @@ -127,8 +126,10 @@ size_t mlir::controlFlowSink( ArrayRef<Region *> regions, DominanceInfo &domInfo, - function_ref<bool(Operation *, Region *)> shouldMoveIntoRegion) { - return Sinker(shouldMoveIntoRegion, domInfo).sinkRegions(regions); + function_ref<bool(Operation *, Region *)> shouldMoveIntoRegion, + function_ref<void(Operation *, Region *)> moveIntoRegion) { + return Sinker(shouldMoveIntoRegion, moveIntoRegion, domInfo) + .sinkRegions(regions); } void mlir::getSinglyExecutedRegionsToSink(RegionBranchOpInterface branch,