Move createAlloc/createDealloc/createMemCpy onto BufferizationOptions.

Summary of the change:
  * BufferizableOpInterface.h: declares `createAlloc`, `createDealloc` and
    `createMemCpy` as const member functions of `BufferizationOptions`, and
    removes the free functions `createDealloc(..., const BufferizationOptions &)`
    and `createMemCpy(..., const BufferizationOptions &)`.
  * BufferizableOpInterface.cpp: the former file-static `createAlloc` and the
    free `bufferization::createDealloc` / `bufferization::createMemCpy` become
    member definitions. They read `allocationFn`, `deallocationFn`, `memCpyFn`
    and `bufferAlignment` directly instead of through an `options` parameter.
    When no custom function is set, each one falls back to the same default op
    creation as before. Call sites in this file now use
    `options.createMemCpy(...)`, `options.createAlloc(...)` and
    `options.createDealloc(...)`.
  * FuncBufferizableOpInterfaceImpl.cpp, SCF BufferizableOpInterfaceImpl.cpp,
    Tensor BufferizableOpInterfaceImpl.cpp: call sites now invoke
    `options.createMemCpy(...)` / `state.getOptions().createMemCpy(...)`
    instead of passing the options as a trailing argument.

NOTE(review): in this copy of the patch, line breaks have been collapsed and
template arguments appear to be missing (e.g. `FailureOr createAlloc`,
`Optional deallocationFn`, `b.create(loc, ...)`). It will not apply as-is.
Regenerate it from the source tree before applying.
NOTE(review): the unchanged context line `// Default bufferallocation via
AllocOp.` has a typo ("buffer allocation"). Fix it in a follow-up, since it
is a context line in this patch.
NOTE(review): any out-of-tree callers of the removed free functions
`bufferization::createDealloc` / `bufferization::createMemCpy` must migrate
to the new member functions.

diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h @@ -161,6 +161,19 @@ Optional deallocationFn; Optional memCpyFn; + /// Create a memref allocation with the given type and dynamic extents. + FailureOr createAlloc(OpBuilder &b, Location loc, MemRefType type, + ValueRange dynShape) const; + + /// Creates a memref deallocation. The given memref buffer must have been + /// allocated using `createAlloc`. + LogicalResult createDealloc(OpBuilder &b, Location loc, + Value allocatedBuffer) const; + + /// Creates a memcpy between two given buffers. + LogicalResult createMemCpy(OpBuilder &b, Location loc, Value from, + Value to) const; + /// Specifies whether not bufferizable ops are allowed in the input. If so, /// bufferization.to_memref and bufferization.to_tensor ops are inserted at /// the boundaries. @@ -514,15 +527,6 @@ MemRefLayoutAttrInterface layout = {}, Attribute memorySpace = {}); -/// Creates a memref deallocation. The given memref buffer must have been -/// allocated using `createAlloc`. -LogicalResult createDealloc(OpBuilder &b, Location loc, Value allocatedBuffer, - const BufferizationOptions &options); - -/// Creates a memcpy between two given buffers. -LogicalResult createMemCpy(OpBuilder &b, Location loc, Value from, Value to, - const BufferizationOptions &options); - /// Try to hoist all new buffer allocations until the next hoisting barrier. 
LogicalResult hoistBufferAllocations(Operation *op, const BufferizationOptions &options); diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp --- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp @@ -327,8 +327,7 @@ // The copy happens right before the op that is bufferized. rewriter.setInsertionPoint(op); } - if (failed( - createMemCpy(rewriter, loc, operandBuffer, *resultBuffer, options))) + if (failed(options.createMemCpy(rewriter, loc, operandBuffer, *resultBuffer))) return failure(); return resultBuffer; @@ -418,26 +417,24 @@ //===----------------------------------------------------------------------===// /// Create a memref allocation with the given type and dynamic extents. -static FailureOr createAlloc(OpBuilder &b, Location loc, MemRefType type, - ValueRange dynShape, - const BufferizationOptions &options) { - if (options.allocationFn) - return (*options.allocationFn)(b, loc, type, dynShape, - options.bufferAlignment); +FailureOr BufferizationOptions::createAlloc(OpBuilder &b, Location loc, + MemRefType type, + ValueRange dynShape) const { + if (allocationFn) + return (*allocationFn)(b, loc, type, dynShape, bufferAlignment); // Default bufferallocation via AllocOp. Value allocated = b.create( - loc, type, dynShape, b.getI64IntegerAttr(options.bufferAlignment)); + loc, type, dynShape, b.getI64IntegerAttr(bufferAlignment)); return allocated; } /// Creates a memref deallocation. The given memref buffer must have been /// allocated using `createAlloc`. 
-LogicalResult -bufferization::createDealloc(OpBuilder &b, Location loc, Value allocatedBuffer, - const BufferizationOptions &options) { - if (options.deallocationFn) - return (*options.deallocationFn)(b, loc, allocatedBuffer); +LogicalResult BufferizationOptions::createDealloc(OpBuilder &b, Location loc, + Value allocatedBuffer) const { + if (deallocationFn) + return (*deallocationFn)(b, loc, allocatedBuffer); // Default buffer deallocation via DeallocOp. b.create(loc, allocatedBuffer); @@ -523,11 +520,10 @@ } /// Create a memory copy between two memref buffers. -LogicalResult bufferization::createMemCpy(OpBuilder &b, Location loc, - Value from, Value to, - const BufferizationOptions &options) { - if (options.memCpyFn) - return (*options.memCpyFn)(b, loc, from, to); +LogicalResult BufferizationOptions::createMemCpy(OpBuilder &b, Location loc, + Value from, Value to) const { + if (memCpyFn) + return (*memCpyFn)(b, loc, from, to); b.create(loc, from, to); return success(); @@ -557,8 +553,8 @@ Block *block = allocaOp->getBlock(); rewriter.setInsertionPoint(allocaOp); FailureOr alloc = - createAlloc(rewriter, allocaOp->getLoc(), allocaOp.getType(), - allocaOp.dynamicSizes(), options); + options.createAlloc(rewriter, allocaOp->getLoc(), allocaOp.getType(), + allocaOp.dynamicSizes()); if (failed(alloc)) return WalkResult::interrupt(); rewriter.replaceOp(allocaOp, *alloc); @@ -571,7 +567,7 @@ // Create dealloc. 
rewriter.setInsertionPoint(block->getTerminator()); - if (failed(createDealloc(rewriter, alloc->getLoc(), *alloc, options))) + if (failed(options.createDealloc(rewriter, alloc->getLoc(), *alloc))) return WalkResult::interrupt(); return WalkResult::advance(); diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp @@ -495,7 +495,7 @@ // Note: This copy will fold away. It must be inserted here to ensure // that `returnVal` still has at least one use and does not fold away. if (failed( - createMemCpy(rewriter, loc, toMemrefOp, equivBbArg, options))) + options.createMemCpy(rewriter, loc, toMemrefOp, equivBbArg))) return funcOp->emitError("could not generate copy for bbArg"); continue; } diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp @@ -363,8 +363,8 @@ // TODO: We should rollback, but for now just assume that this always // succeeds. 
assert(yieldedAlloc.hasValue() && "could not create alloc"); - LogicalResult copyStatus = bufferization::createMemCpy( - rewriter, tensor.getLoc(), yieldedVal, *yieldedAlloc, state.getOptions()); + LogicalResult copyStatus = state.getOptions().createMemCpy( + rewriter, tensor.getLoc(), yieldedVal, *yieldedAlloc); (void)copyStatus; assert(succeeded(copyStatus) && "could not create memcpy"); diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp @@ -309,8 +309,8 @@ if (!inplace) { // Do not copy if the copied data is never read. if (state.getAnalysisState().isValueRead(extractSliceOp.result())) - if (failed(createMemCpy(rewriter, extractSliceOp.getLoc(), subView, - alloc, state.getOptions()))) + if (failed(state.getOptions().createMemCpy( + rewriter, extractSliceOp.getLoc(), subView, alloc))) return failure(); subView = alloc; } @@ -705,8 +705,8 @@ // tensor.extract_slice, the copy operation will eventually fold away. Value srcMemref = *state.getBuffer(rewriter, insertSliceOp->getOpOperand(0) /*source*/); - if (failed(createMemCpy(rewriter, loc, srcMemref, subView, - state.getOptions()))) + if (failed( + state.getOptions().createMemCpy(rewriter, loc, srcMemref, subView))) return failure(); replaceOpWithBufferizedValues(rewriter, op, *dstMemref);