diff --git a/mlir/include/mlir/Transforms/InliningUtils.h b/mlir/include/mlir/Transforms/InliningUtils.h
--- a/mlir/include/mlir/Transforms/InliningUtils.h
+++ b/mlir/include/mlir/Transforms/InliningUtils.h
@@ -198,12 +198,15 @@
 /// provided, will be used to update the inlined operations' location
 /// information. 'shouldCloneInlinedRegion' corresponds to whether the source
 /// region should be cloned into the 'inlinePoint' or spliced directly.
+/// 'clonedAttrs' lists attributes to set on every inlined operation (including
+/// nested ones), overwriting any existing attribute of the same name.
 LogicalResult inlineRegion(InlinerInterface &interface, Region *src,
                            Operation *inlinePoint, BlockAndValueMapping &mapper,
                            ValueRange resultsToReplace,
                            TypeRange regionResultTypes,
                            Optional<Location> inlineLoc = llvm::None,
-                           bool shouldCloneInlinedRegion = true);
+                           bool shouldCloneInlinedRegion = true,
+                           ArrayRef<NamedAttribute> clonedAttrs = {});
 
 /// This function is an overload of the above 'inlineRegion' that allows for
 /// providing the set of operands ('inlinedOperands') that should be used
@@ -220,10 +223,13 @@
 /// function returns failure if inlining is not possible, success otherwise. On
 /// failure, no changes are made to the module. 'shouldCloneInlinedRegion'
 /// corresponds to whether the source region should be cloned into the 'call' or
-/// spliced directly.
+/// spliced directly. 'clonedAttrs' lists attributes to set on every inlined
+/// operation (including nested ones), overwriting any existing attribute of
+/// the same name.
 LogicalResult inlineCall(InlinerInterface &interface, CallOpInterface call,
                          CallableOpInterface callable, Region *src,
-                         bool shouldCloneInlinedRegion = true);
+                         bool shouldCloneInlinedRegion = true,
+                         ArrayRef<NamedAttribute> clonedAttrs = {});
 
 } // end namespace mlir
 
diff --git a/mlir/lib/Transforms/Utils/InliningUtils.cpp b/mlir/lib/Transforms/Utils/InliningUtils.cpp
--- a/mlir/lib/Transforms/Utils/InliningUtils.cpp
+++ b/mlir/lib/Transforms/Utils/InliningUtils.cpp
@@ -42,6 +42,22 @@
   block.walk(remapOpLoc);
 }
 
+/// Set each attribute in 'clonedAttrs' on every operation within
+/// 'inlinedBlocks', including operations nested in their regions. Existing
+/// attributes with the same name are overwritten.
+static void cloneNamedAttributes(iterator_range<Region::iterator> inlinedBlocks,
+                                 ArrayRef<NamedAttribute> clonedAttrs) {
+  // Skip the walk over every inlined operation when there is nothing to set.
+  if (clonedAttrs.empty())
+    return;
+  auto setAttrs = [&](Operation *op) {
+    for (const NamedAttribute &attr : clonedAttrs)
+      op->setAttr(attr.first, attr.second);
+  };
+  for (Block &block : inlinedBlocks)
+    block.walk(setAttrs);
+}
+
 static void remapInlinedOperands(iterator_range<Region::iterator> inlinedBlocks,
                                  BlockAndValueMapping &mapper) {
   auto remapOperands = [&](Operation *op) {
@@ -137,13 +153,12 @@
 // Inline Methods
 //===----------------------------------------------------------------------===//
 
-LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src,
-                                 Operation *inlinePoint,
-                                 BlockAndValueMapping &mapper,
-                                 ValueRange resultsToReplace,
-                                 TypeRange regionResultTypes,
-                                 Optional<Location> inlineLoc,
-                                 bool shouldCloneInlinedRegion) {
+LogicalResult
+mlir::inlineRegion(InlinerInterface &interface, Region *src,
+                   Operation *inlinePoint, BlockAndValueMapping &mapper,
+                   ValueRange resultsToReplace, TypeRange regionResultTypes,
+                   Optional<Location> inlineLoc, bool shouldCloneInlinedRegion,
+                   ArrayRef<NamedAttribute> clonedAttrs) {
   assert(resultsToReplace.size() == regionResultTypes.size());
   // We expect the region to have at least one block.
   if (src->empty())
@@ -197,6 +212,9 @@
   if (!shouldCloneInlinedRegion)
     remapInlinedOperands(newBlocks, mapper);
 
+  // Apply the requested attributes to the inlined operations.
+  cloneNamedAttributes(newBlocks, clonedAttrs);
+
   // Process the newly inlined blocks.
   interface.processInlinedBlocks(newBlocks);
 
@@ -297,7 +315,8 @@
 LogicalResult mlir::inlineCall(InlinerInterface &interface,
                                CallOpInterface call,
                                CallableOpInterface callable, Region *src,
-                               bool shouldCloneInlinedRegion) {
+                               bool shouldCloneInlinedRegion,
+                               ArrayRef<NamedAttribute> clonedAttrs) {
   // We expect the region to have at least one block.
   if (src->empty())
     return failure();
@@ -373,7 +392,7 @@
   // Attempt to inline the call.
   if (failed(inlineRegion(interface, src, call, mapper, callResults,
                           callableResultTypes, call.getLoc(),
-                          shouldCloneInlinedRegion)))
+                          shouldCloneInlinedRegion, clonedAttrs)))
     return cleanupState();
   return success();
 }