diff --git a/mlir/include/mlir/Transforms/InliningUtils.h b/mlir/include/mlir/Transforms/InliningUtils.h
--- a/mlir/include/mlir/Transforms/InliningUtils.h
+++ b/mlir/include/mlir/Transforms/InliningUtils.h
@@ -211,6 +211,17 @@
                            TypeRange regionResultTypes,
                            Optional<Location> inlineLoc = llvm::None,
                            bool shouldCloneInlinedRegion = true);
+/// This function is an overload of the above 'inlineRegion' that inlines 'src'
+/// before 'inlinePoint' within 'inlineBlock', rather than after a given
+/// operation. 'inlinePoint' must be an iterator into 'inlineBlock' and may be
+/// 'inlineBlock->end()'.
+LogicalResult inlineRegion(InlinerInterface &interface, Region *src,
+                           Block *inlineBlock, Block::iterator inlinePoint,
+                           BlockAndValueMapping &mapper,
+                           ValueRange resultsToReplace,
+                           TypeRange regionResultTypes,
+                           Optional<Location> inlineLoc = llvm::None,
+                           bool shouldCloneInlinedRegion = true);
 
 /// This function is an overload of the above 'inlineRegion' that allows for
 /// providing the set of operands ('inlinedOperands') that should be used
@@ -220,6 +231,15 @@
                            ValueRange resultsToReplace,
                            Optional<Location> inlineLoc = llvm::None,
                            bool shouldCloneInlinedRegion = true);
+/// This function is an overload of the above 'inlineRegion' that inlines 'src'
+/// before 'inlinePoint' within 'inlineBlock', rather than after a given
+/// operation.
+LogicalResult inlineRegion(InlinerInterface &interface, Region *src,
+                           Block *inlineBlock, Block::iterator inlinePoint,
+                           ValueRange inlinedOperands,
+                           ValueRange resultsToReplace,
+                           Optional<Location> inlineLoc = llvm::None,
+                           bool shouldCloneInlinedRegion = true);
 
 /// This function inlines a given region, 'src', of a callable operation,
 /// 'callable', into the location defined by the given call operation.
diff --git a/mlir/lib/Transforms/Utils/InliningUtils.cpp b/mlir/lib/Transforms/Utils/InliningUtils.cpp
--- a/mlir/lib/Transforms/Utils/InliningUtils.cpp
+++ b/mlir/lib/Transforms/Utils/InliningUtils.cpp
@@ -145,11 +145,11 @@
 //===----------------------------------------------------------------------===//
 
 static LogicalResult
-inlineRegionImpl(InlinerInterface &interface, Region *src,
-                 Operation *inlinePoint, BlockAndValueMapping &mapper,
+inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock,
+                 Block::iterator inlinePoint, BlockAndValueMapping &mapper,
                  ValueRange resultsToReplace, TypeRange regionResultTypes,
                  Optional<Location> inlineLoc, bool shouldCloneInlinedRegion,
-                 Operation *call) {
+                 Operation *call = nullptr) {
   assert(resultsToReplace.size() == regionResultTypes.size());
   // We expect the region to have at least one block.
   if (src->empty())
@@ -161,26 +161,24 @@
                    [&](BlockArgument arg) { return !mapper.contains(arg); }))
     return failure();
 
   // The insertion point must be within a block.
-  Block *insertBlock = inlinePoint->getBlock();
-  if (!insertBlock)
+  if (!inlineBlock)
     return failure();
-  Region *insertRegion = insertBlock->getParent();
+  Region *insertRegion = inlineBlock->getParent();
 
   // Check that the operations within the source region are valid to inline.
   if (!interface.isLegalToInline(insertRegion, src, shouldCloneInlinedRegion,
                                  mapper) ||
       !isLegalToInline(interface, src, insertRegion, shouldCloneInlinedRegion,
                        mapper))
     return failure();
 
   // Split the insertion block.
-  Block *postInsertBlock =
-      insertBlock->splitBlock(++inlinePoint->getIterator());
+  Block *postInsertBlock = inlineBlock->splitBlock(inlinePoint);
 
   // Check to see if the region is being cloned, or moved inline. In either
-  // case, move the new blocks after the 'insertBlock' to improve IR
+  // case, move the new blocks after the 'inlineBlock' to improve IR
   // readability.
   if (shouldCloneInlinedRegion)
     src->cloneInto(insertRegion, postInsertBlock->getIterator(), mapper);
   else
@@ -189,7 +187,7 @@
                                      src->end());
 
   // Get the range of newly inserted blocks.
-  auto newBlocks = llvm::make_range(std::next(insertBlock->getIterator()),
+  auto newBlocks = llvm::make_range(std::next(inlineBlock->getIterator()),
                                     postInsertBlock->getIterator());
   Block *firstNewBlock = &*newBlocks.begin();
 
@@ -234,17 +232,17 @@
   }
 
   // Splice the instructions of the inlined entry block into the insert block.
-  insertBlock->getOperations().splice(insertBlock->end(),
+  inlineBlock->getOperations().splice(inlineBlock->end(),
                                       firstNewBlock->getOperations());
   firstNewBlock->erase();
   return success();
 }
 
 static LogicalResult
-inlineRegionImpl(InlinerInterface &interface, Region *src,
-                 Operation *inlinePoint, ValueRange inlinedOperands,
+inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock,
+                 Block::iterator inlinePoint, ValueRange inlinedOperands,
                  ValueRange resultsToReplace, Optional<Location> inlineLoc,
-                 bool shouldCloneInlinedRegion, Operation *call) {
+                 bool shouldCloneInlinedRegion, Operation *call = nullptr) {
   // We expect the region to have at least one block.
   if (src->empty())
     return failure();
@@ -265,9 +263,9 @@
   }
 
   // Call into the main region inliner function.
-  return inlineRegionImpl(interface, src, inlinePoint, mapper, resultsToReplace,
-                          resultsToReplace.getTypes(), inlineLoc,
-                          shouldCloneInlinedRegion, call);
+  return inlineRegionImpl(interface, src, inlineBlock, inlinePoint, mapper,
+                          resultsToReplace, resultsToReplace.getTypes(),
+                          inlineLoc, shouldCloneInlinedRegion, call);
 }
 
 LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src,
@@ -277,10 +275,26 @@
                                  TypeRange regionResultTypes,
                                  Optional<Location> inlineLoc,
                                  bool shouldCloneInlinedRegion) {
-  return inlineRegionImpl(interface, src, inlinePoint, mapper, resultsToReplace,
-                          regionResultTypes, inlineLoc,
-                          shouldCloneInlinedRegion,
-                          /*call=*/nullptr);
+  // The insertion point must be within a block. Check this before stepping
+  // past 'inlinePoint', which is only meaningful for a linked operation.
+  Block *inlineBlock = inlinePoint->getBlock();
+  if (!inlineBlock)
+    return failure();
+  return inlineRegion(interface, src, inlineBlock,
+                      std::next(inlinePoint->getIterator()), mapper,
+                      resultsToReplace, regionResultTypes, inlineLoc,
+                      shouldCloneInlinedRegion);
+}
+
+LogicalResult
+mlir::inlineRegion(InlinerInterface &interface, Region *src, Block *inlineBlock,
+                   Block::iterator inlinePoint, BlockAndValueMapping &mapper,
+                   ValueRange resultsToReplace, TypeRange regionResultTypes,
+                   Optional<Location> inlineLoc,
+                   bool shouldCloneInlinedRegion) {
+  return inlineRegionImpl(interface, src, inlineBlock, inlinePoint, mapper,
+                          resultsToReplace, regionResultTypes, inlineLoc,
+                          shouldCloneInlinedRegion);
 }
 
 LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src,
@@ -289,9 +303,23 @@
                                  ValueRange resultsToReplace,
                                  Optional<Location> inlineLoc,
                                  bool shouldCloneInlinedRegion) {
-  return inlineRegionImpl(interface, src, inlinePoint, inlinedOperands,
-                          resultsToReplace, inlineLoc, shouldCloneInlinedRegion,
-                          /*call=*/nullptr);
+  // The insertion point must be within a block.
+  Block *inlineBlock = inlinePoint->getBlock();
+  if (!inlineBlock)
+    return failure();
+  return inlineRegion(interface, src, inlineBlock,
+                      std::next(inlinePoint->getIterator()), inlinedOperands,
+                      resultsToReplace, inlineLoc, shouldCloneInlinedRegion);
+}
+
+LogicalResult
+mlir::inlineRegion(InlinerInterface &interface, Region *src, Block *inlineBlock,
+                   Block::iterator inlinePoint, ValueRange inlinedOperands,
+                   ValueRange resultsToReplace, Optional<Location> inlineLoc,
+                   bool shouldCloneInlinedRegion) {
+  return inlineRegionImpl(interface, src, inlineBlock, inlinePoint,
+                          inlinedOperands, resultsToReplace, inlineLoc,
+                          shouldCloneInlinedRegion);
 }
 
 /// Utility function used to generate a cast operation from the given interface,
@@ -399,7 +427,8 @@
     return cleanupState();
 
   // Attempt to inline the call.
-  if (failed(inlineRegionImpl(interface, src, call, mapper, callResults,
+  if (failed(inlineRegionImpl(interface, src, call->getBlock(),
+                              ++call->getIterator(), mapper, callResults,
                               callableResultTypes, call.getLoc(),
                               shouldCloneInlinedRegion, call)))
     return cleanupState();