diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h b/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h
@@ -172,16 +172,33 @@
 /// `defaultAllocationFn`.
 void defaultDeallocationFn(OpBuilder &b, Location loc, Value allocatedBuffer);
 
+/// Default memory copy function that is used by the comprehensive bufferization
+/// pass. Creates a `linalg.copy` op that copies `from` into `to`.
+void defaultMemCpyFn(OpBuilder &b, Location loc, Value from, Value to);
+
 /// Callback functions that are used by the comprehensive bufferization pass to
-/// allocate/deallocate memory. These default to use the
-/// `defaultAllocationFn`/`defaultDeallocationFn`, but can be overridden by the
-/// caller. The `deallocationFn` is gauranteed to recieve the `Value` returned
-/// by the `allocationFn`.
+/// allocate/deallocate memory and to copy between buffers. All three callbacks
+/// must be supplied explicitly; pass `defaultAllocationFn`,
+/// `defaultDeallocationFn` and `defaultMemCpyFn` to get the default behavior.
+/// The `deallocationFn` is guaranteed to receive the `Value` returned by the
+/// `allocationFn`.
 struct AllocationCallbacks {
-  std::function<Optional<Value>(OpBuilder &b, Location loc, Value shapedValue)>
-      allocationFn = defaultAllocationFn;
-  std::function<void(OpBuilder &b, Location loc, Value v)> deallocationFn =
-      defaultDeallocationFn;
+  /// `(b, loc, shapedValue)`: creates a buffer for `shapedValue`, if possible.
+  using AllocationFn =
+      std::function<Optional<Value>(OpBuilder &, Location, Value)>;
+  /// `(b, loc, allocatedBuffer)`: frees a buffer created by `AllocationFn`.
+  using DeallocationFn = std::function<void(OpBuilder &, Location, Value)>;
+  /// `(b, loc, from, to)`: copies the contents of buffer `from` into `to`.
+  using MemCpyFn = std::function<void(OpBuilder &, Location, Value, Value)>;
+
+  AllocationCallbacks(AllocationFn allocFn, DeallocationFn deallocFn,
+                      MemCpyFn copyFn)
+      : allocationFn(std::move(allocFn)), deallocationFn(std::move(deallocFn)),
+        memCpyFn(std::move(copyFn)) {}
+
+  AllocationFn allocationFn;
+  DeallocationFn deallocationFn;
+  MemCpyFn memCpyFn;
 };
 
 /// Bufferize one particular op.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
@@ -1274,7 +1274,7 @@
     if (!skipCopy) {
       // Set insertion point now that potential alloc/dealloc are introduced.
       b.setInsertionPoint(op);
-      b.create<CopyOp>(loc, operandBuffer, resultBuffer);
+      allocationFns.memCpyFn(b, loc, operandBuffer, resultBuffer);
     }
     return resultBuffer;
   }
@@ -1669,6 +1669,11 @@
   b.create<memref::DeallocOp>(loc, allocatedBuffer);
 }
 
+void mlir::linalg::defaultMemCpyFn(OpBuilder &b, Location loc, Value from,
+                                   Value to) {
+  b.create<CopyOp>(loc, from, to);
+}
+
 LogicalResult mlir::linalg::bufferizeOp(
     Operation *op, BlockAndValueMapping &bvm, BufferizationAliasInfo &aliasInfo,
     AllocationCallbacks allocationFns,
@@ -2258,11 +2263,13 @@
     // command line option. So this is set up at the start of the pass.
     if (useAlloca) {
       AllocationCallbacks allocaAllocationFns = {
-          allocationFnUsingAlloca, [](OpBuilder &b, Location loc, Value v) {}};
+          allocationFnUsingAlloca, [](OpBuilder &b, Location loc, Value v) {},
+          defaultMemCpyFn};
       allocationFns =
           std::make_unique<AllocationCallbacks>(std::move(allocaAllocationFns));
     } else {
-      allocationFns = std::make_unique<AllocationCallbacks>();
+      allocationFns = std::make_unique<AllocationCallbacks>(
+          defaultAllocationFn, defaultDeallocationFn, defaultMemCpyFn);
     }
   }
   ModuleOp moduleOp = getOperation();
@@ -3222,7 +3229,7 @@
   if (alloc) {
     // Do not copy if the copied data is never read.
     if (isValueRead(extractSliceOp.result()))
-      b.create<CopyOp>(extractSliceOp.getLoc(), subView, alloc);
+      allocationFn.memCpyFn(b, extractSliceOp.getLoc(), subView, alloc);
     subView = alloc;
   }
 
@@ -3344,7 +3351,7 @@
         insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());
     // Insert new alias.
     aliasInfo.insertNewBufferAlias(subView, dstMemref);
-    b.create<CopyOp>(insertSliceOp.getLoc(), srcMemref, subView);
+    allocationFn.memCpyFn(b, insertSliceOp.getLoc(), srcMemref, subView);
   }
 
   map(bvm, insertSliceOp.result(), dstMemref);