diff --git a/mlir/include/mlir/Transforms/RegionUtils.h b/mlir/include/mlir/Transforms/RegionUtils.h --- a/mlir/include/mlir/Transforms/RegionUtils.h +++ b/mlir/include/mlir/Transforms/RegionUtils.h @@ -51,6 +51,24 @@ void getUsedValuesDefinedAbove(MutableArrayRef regions, SetVector &values); +/// Make a region isolated from above +/// - Capture the values that are defined above the region and used within it. +/// - Append to the entry block arguments that represent the captured values +/// (one per captured value). +/// - Replace all uses within the region of the captured values with the +/// newly added arguments. +/// - `cloneOperationIntoRegion` is a callback that allows the caller to specify +/// if the operation defining an `OpOperand` needs to be cloned into the +/// region. Then the operands of this operation become part of the captured +/// values set (unless the operations that define the operands themselves +/// are to be cloned). The cloned operations are added to the entry block +/// of the region. +/// Return the set of captured values for the operation. +SmallVector makeRegionIsolatedFromAbove( + RewriterBase &rewriter, Region ®ion, + llvm::function_ref cloneOperationIntoRegion = + [](Operation *) { return false; }); + /// Run a set of structural simplifications over the given regions. This /// includes transformations like unreachable block elimination, dead argument /// elimination, as well as some other DCE. 
This function returns success if any diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp --- a/mlir/lib/Transforms/Utils/RegionUtils.cpp +++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp @@ -8,17 +8,21 @@ #include "mlir/Transforms/RegionUtils.h" #include "mlir/IR/Block.h" +#include "mlir/IR/IRMapping.h" #include "mlir/IR/Operation.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/RegionGraphTraits.h" #include "mlir/IR/Value.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Transforms/TopologicalSortUtils.h" #include "llvm/ADT/DepthFirstIterator.h" #include "llvm/ADT/PostOrderIterator.h" #include "llvm/ADT/SmallSet.h" +#include + using namespace mlir; void mlir::replaceAllUsesInRegionWith(Value orig, Value replacement, @@ -69,6 +73,91 @@ getUsedValuesDefinedAbove(region, region, values); } +//===----------------------------------------------------------------------===// +// Make block isolated from above. 
+//===----------------------------------------------------------------------===// + +SmallVector mlir::makeRegionIsolatedFromAbove( + RewriterBase &rewriter, Region ®ion, + llvm::function_ref cloneOperationIntoRegion) { + + llvm::SetVector initialCapturedValues; + mlir::getUsedValuesDefinedAbove(region, initialCapturedValues); + + std::deque worklist(initialCapturedValues.begin(), + initialCapturedValues.end()); + llvm::DenseSet visited; + llvm::DenseSet visitedOps; + + llvm::SetVector finalCapturedValues; + SmallVector clonedOperations; + while (!worklist.empty()) { + Value currValue = worklist.front(); + worklist.pop_front(); + if (visited.count(currValue)) + continue; + visited.insert(currValue); + + Operation *definingOp = currValue.getDefiningOp(); + if (!definingOp || visitedOps.count(definingOp)) + continue; + visitedOps.insert(definingOp); + + if (!cloneOperationIntoRegion(definingOp)) { + finalCapturedValues.insert(currValue); + continue; + } + + for (Value operand : definingOp->getOperands()) { + if (visited.count(operand)) + continue; + worklist.push_back(operand); + } + clonedOperations.push_back(definingOp); + } + + mlir::computeTopologicalSorting(clonedOperations); + + OpBuilder::InsertionGuard g(rewriter); + // Collect types of existing block + Block *entryBlock = ®ion.front(); + SmallVector newArgTypes = + llvm::to_vector(entryBlock->getArgumentTypes()); + SmallVector newArgLocs = llvm::to_vector(llvm::map_range( + entryBlock->getArguments(), [](BlockArgument b) { return b.getLoc(); })); + + // Append the types of the captured values. + for (auto value : finalCapturedValues) { + newArgTypes.push_back(value.getType()); + newArgLocs.push_back(value.getLoc()); + } + + // Create a new entry block. + Block *newEntryBlock = + rewriter.createBlock(®ion, region.begin(), newArgTypes, newArgLocs); + + // Create a mapping between the captured values and the new arguments added. 
+ IRMapping map; + auto replaceIfFn = [&](OpOperand &use) { + return use.getOwner()->getBlock()->getParent() == ®ion; + }; + for (auto [arg, capturedVal] : llvm::zip( + newEntryBlock->getArguments().take_back(finalCapturedValues.size()), + finalCapturedValues)) { + map.map(capturedVal, arg); + rewriter.replaceUsesWithIf(capturedVal, arg, replaceIfFn); + } + rewriter.setInsertionPointToStart(newEntryBlock); + for (auto clonedOp : clonedOperations) { + Operation *newOp = rewriter.clone(*clonedOp, map); + rewriter.replaceOpWithIf(clonedOp, newOp->getResults(), replaceIfFn); + } + rewriter.mergeBlocks( + entryBlock, newEntryBlock, + newEntryBlock->getArguments().take_front(entryBlock->getNumArguments())); + return llvm::to_vector(finalCapturedValues); +} + //===----------------------------------------------------------------------===// // Unreachable Block Elimination //===----------------------------------------------------------------------===//