diff --git a/mlir/include/mlir/Transforms/LoopUtils.h b/mlir/include/mlir/Transforms/LoopUtils.h
--- a/mlir/include/mlir/Transforms/LoopUtils.h
+++ b/mlir/include/mlir/Transforms/LoopUtils.h
@@ -15,6 +15,7 @@
 #ifndef MLIR_TRANSFORMS_LOOP_UTILS_H
 #define MLIR_TRANSFORMS_LOOP_UTILS_H
 
+#include "mlir/Analysis/Utils.h"
 #include "mlir/IR/Block.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Support/LogicalResult.h"
@@ -185,6 +186,32 @@
                                 Optional filterMemRef,
                                 DenseSet &copyNests);
 
+/// Output of generateDataCopyAroundOp.
+struct CopyGenerateResult {
+  /// Byte count written by the underlying copy generation (presumably the
+  /// footprint of the fast buffer; confirm against generateCopy).
+  uint64_t sizeInBytes = 0;
+  /// Operation defining the fast buffer recorded for the memref.
+  Operation *alloc = nullptr;
+  /// The single generated copy nest, or nullptr if none was generated.
+  Operation *copyNest = nullptr;
+};
+
+/// generateDataCopyAroundOp is similar to affineDataCopyGenerate, but with some
+/// simplifications:
+/// * The logic of "find relevant memrefs and their uses" is decoupled and
+///   pushed back to the users. It focuses on generating fast buffers and
+///   associated loops/DMAs.
+/// * It handles a single memref per call.
+/// * The prologue and epilogue always surround `where`, not potentially
+///   arbitrary places.
+/// Returns failure if copy generation fails; `result` is only meaningful on
+/// success.
+LogicalResult generateDataCopyAroundOp(const MemRefRegion &memrefRegion,
+                                       Operation *where,
+                                       const AffineCopyOptions &copyOptions,
+                                       CopyGenerateResult &result);
+
 /// Tile a nest of standard for loops rooted at `rootForOp` by finding such
 /// parametric tile sizes that the outer loops have a fixed number of iterations
 /// as defined in `sizes`.
diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp
--- a/mlir/lib/Transforms/Utils/LoopUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp
@@ -1797,6 +1797,51 @@
                                 filterMemRef, copyNests);
 }
 
+/// Generates the fast buffer and copy code for `memrefRegion` around `where`.
+///
+/// The iterator range [where, next(where)) in `where`'s block is passed to
+/// generateCopy twice (same block/begin/end triple both times).
+///
+/// On success:
+///   * `result.sizeInBytes` holds the value written by generateCopy;
+///   * `result.alloc` is the operation defining the value that generateCopy
+///     recorded for `memrefRegion.memref` in its fast-buffer map;
+///   * `result.copyNest` is the single generated copy nest, or nullptr if no
+///     copy nest was generated.
+/// Returns failure if generateCopy fails or records no buffer for the memref.
+LogicalResult mlir::generateDataCopyAroundOp(
+    const MemRefRegion &memrefRegion, Operation *where,
+    const AffineCopyOptions &copyOptions, CopyGenerateResult &result) {
+  Block *block = where->getBlock();
+  auto begin = where->getIterator();
+  auto end = std::next(begin);
+  DenseMap<Value, Value> fastBufferMap;
+  DenseSet<Operation *> copyNests;
+
+  // `begin`/`end` are also passed by address, so generateCopy may update them.
+  auto err = generateCopy(memrefRegion, block, begin, end, block, begin, end,
+                          copyOptions, fastBufferMap, copyNests,
+                          &result.sizeInBytes, &begin, &end);
+  if (failed(err))
+    return err;
+
+  // find() yields end() if no buffer was recorded for this memref;
+  // dereferencing it would be undefined behavior, so report failure instead.
+  auto bufferIt = fastBufferMap.find(memrefRegion.memref);
+  if (bufferIt == fastBufferMap.end())
+    return failure();
+  result.alloc = bufferIt->second.getDefiningOp();
+
+  if (copyNests.empty()) {
+    result.copyNest = nullptr;
+  } else {
+    assert(copyNests.size() == 1 &&
+           "Multiple copy nests generated appear for a single memref.");
+    result.copyNest = *copyNests.begin();
+  }
+  return success();
+}
+
 /// Gathers all AffineForOps in 'block' at 'currLoopDepth' in 'depthToLoops'.
 static void gatherLoopsInBlock(Block *block, unsigned currLoopDepth,