diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h --- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h +++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h @@ -64,14 +64,14 @@ std::unique_ptr defaultAllocationCallbacks(); /// PostAnalysisSteps can be registered with `BufferizationOptions` and are -/// executed after the analysis, but before bufferization. They can be used +/// executed after the analysis, but before bufferization. They can be used to /// implement custom dialect-specific optimizations. struct PostAnalysisStep { virtual ~PostAnalysisStep() {} /// Run the post analysis step. This function may modify the IR, but must keep - /// `aliasInfo` (inside `state`) consistent. Newly created operations and - /// operations that should be re-analyzed must be stored in `newOps`. + /// `aliasInfo` consistent. Newly created operations and operations that + /// should be re-analyzed must be added to `newOps`. virtual LogicalResult run(Operation *op, BufferizationState &state, BufferizationAliasInfo &aliasInfo, SmallVector &newOps) = 0; @@ -102,7 +102,8 @@ } /// Allow-list the given dialects in the dialect filter. Only ops from - /// allow-listed dialects will be bufferized. + /// allow-listed dialects will be bufferized. If no dialect is added, ops from + /// any dialect will be bufferized. template void addToDialectFilter() { // The following expands a call to addToDialectFilterImpl for each dialect @@ -288,17 +289,7 @@ }; /// BufferizationState provides a variety of helper functions for dealing with -/// tensor values and memref buffers. In particular, -/// `BufferizableOpInterface::bufferize` implementation should utilize the -/// following helper functions. -/// -/// * `createAlloc` / `createDealloc` / `createAllocDeallocPair` creates ops -/// that allocate and/or deallocate memref buffers. -/// * `lookupBuffer` returns the memref buffer of a given tensor value. -/// * `getResultBuffer` returns the memref buffer for a given tensor OpResult. -/// Based on inplace bufferization decisions of the analysis, it may either -/// directly return a mapped buffer or allocate a new brand new buffer. -/// * `replaceOp` replaces an op with new values. +/// tensor values and memref buffers. class BufferizationState { public: BufferizationState(Operation *op, const BufferizationOptions &options); @@ -413,7 +404,7 @@ /// Return the result buffer (memref) for a given OpResult (tensor). Allocate /// a new buffer and copy over data from the existing buffer if out-of-place /// bufferization is necessary. - Value getResultBuffer(RewriterBase &rewriter, OpResult result); + FailureOr getResultBuffer(RewriterBase &rewriter, OpResult result); /// Return dialect-specific bufferization state. template StateT &getDialectState(StringRef name) { @@ -453,12 +444,9 @@ MemRefLayoutAttrInterface layout = {}, Attribute memorySpace = {}); -/// Return a contiguous MemRefType (i.e. with canonical/empty layout map) -/// with the same shape as `shapedType` and specified `layout` and -/// `addressSpace` or an UnrankedMemRefType otherwise. -Type getContiguousOrUnrankedMemRefType(Type type, - MemRefLayoutAttrInterface layout = {}, - Attribute memorySpace = {}); +/// Return an UnrankedMemRefType with the given element type and memory space. 
+UnrankedMemRefType getUnrankedMemRefType(Type elementType,
+                                         Attribute memorySpace = {});

/// Return a MemRefType to which the `tensorType` can be bufferized in a
/// composable fashion. The layout must be the most dynamic possible and
@@ -491,7 +479,7 @@
  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               BufferizationState &state) const {
-    return false;
+    return true;
  }

  SmallVector<OpOperand *>
diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h
@@ -23,6 +23,12 @@
namespace linalg_ext {

struct InitTensorEliminationStep : public PostAnalysisStep {
+  /// A function that matches anchor OpOperands for InitTensorOp elimination.
+  using AnchorMatchFn = std::function<bool(OpOperand &)>;
+
+  /// A function that rewrites matched anchors.
+  using RewriteFn = std::function<Value(OpBuilder &, Location, OpOperand &)>;
+
  /// Try to eliminate InitTensorOps inside `op`.
  ///
  /// * `rewriteFunc` generates the replacement for the InitTensorOp.
  ///
@@ -33,12 +39,11 @@
  ///   InitTensorOp.
  /// * The result of `rewriteFunc` must usually be analyzed for inplacability.
  ///   This analysis can be skipped with `skipAnalysis`.
-  LogicalResult eliminateInitTensors(
-      Operation *op, BufferizationState &state,
-      BufferizationAliasInfo &aliasInfo,
-      std::function<bool(OpOperand &)> anchorMatchFunc,
-      std::function<Value(OpBuilder &, Location, OpOperand &)> rewriteFunc,
-      SmallVector<Operation *> &newOps);
+  LogicalResult eliminateInitTensors(Operation *op, BufferizationState &state,
+                                     BufferizationAliasInfo &aliasInfo,
+                                     AnchorMatchFn anchorMatchFunc,
+                                     RewriteFn rewriteFunc,
+                                     SmallVector<Operation *> &newOps);
};

/// Try to eliminate InitTensorOps inside `op` that are anchored on an
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.cpp
@@ -13,6 +13,8 @@

void mlir::linalg::comprehensive_bufferize::affine_ext::
    registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) {
+  // AffineParallelOp bufferization not implemented yet. However, never hoist
+  // memref allocations across AffineParallelOp boundaries.
  registry.addOpInterface<AffineParallelOp,
                          AllocationHoistingBarrierOnly<AffineParallelOp>>();
}
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp
@@ -20,23 +20,30 @@
namespace comprehensive_bufferize {
namespace arith_ext {

+/// Bufferization of arith.constant. Replace with memref.get_global.
struct ConstantOpInterface
    : public BufferizableOpInterface::ExternalModel<ConstantOpInterface,
                                                    arith::ConstantOp> {
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto constantOp = cast<arith::ConstantOp>(op);
-    assert(constantOp.getType().dyn_cast<RankedTensorType>() &&
-           "not a constant ranked tensor");
+
+    // Only ranked tensors are supported.
+    if (!constantOp.getType().isa<RankedTensorType>())
+      return failure();
+
+    // Only constants inside a module are supported.
auto moduleOp = constantOp->getParentOfType(); if (!moduleOp) - return constantOp.emitError( - "cannot bufferize constants not within builtin.module op"); + return failure(); + // Create global memory segment and replace tensor with memref pointing to + // that memory segment. GlobalCreator globalCreator(moduleOp); auto globalMemref = globalCreator.getGlobalFor(constantOp); state.replaceOpWithNewBufferizedOp( rewriter, op, globalMemref.type(), globalMemref.getName()); + return success(); } diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp @@ -74,6 +74,21 @@ BufferizationOptions::BufferizationOptions() : allocationFns(defaultAllocationCallbacks()) {} +BufferizableOpInterface mlir::linalg::comprehensive_bufferize:: + BufferizationOptions::dynCastBufferizableOp(Operation *op) const { + if (isOpAllowed(op)) + return dyn_cast(op); + return nullptr; +} + +BufferizableOpInterface mlir::linalg::comprehensive_bufferize:: + BufferizationOptions::dynCastBufferizableOp(Value value) const { + if (auto bufferizableOp = value.getDefiningOp()) + if (isOpAllowed(bufferizableOp.getOperation())) + return bufferizableOp; + return nullptr; +} + //===----------------------------------------------------------------------===// // BufferizationAliasInfo //===----------------------------------------------------------------------===// @@ -180,21 +195,6 @@ } } -BufferizableOpInterface mlir::linalg::comprehensive_bufferize:: - BufferizationOptions::dynCastBufferizableOp(Operation *op) const { - if (isOpAllowed(op)) - return dyn_cast(op); - return nullptr; -} - -BufferizableOpInterface mlir::linalg::comprehensive_bufferize:: - BufferizationOptions::dynCastBufferizableOp(Value value) const { - if (auto bufferizableOp = value.getDefiningOp()) - if (isOpAllowed(bufferizableOp.getOperation())) - return bufferizableOp; - return nullptr; -} - /// Determine which OpOperand* will alias with `result` if the op is bufferized /// in place. Return an empty vector if the op is not bufferizable. SmallVector @@ -359,8 +359,9 @@ /// Return the result buffer (memref) for a given OpResult (tensor). Allocate /// a new buffer and copy over data from the existing buffer if out-of-place /// bufferization is necessary. -Value mlir::linalg::comprehensive_bufferize::BufferizationState:: - getResultBuffer(RewriterBase &rewriter, OpResult result) { +FailureOr +mlir::linalg::comprehensive_bufferize::BufferizationState::getResultBuffer( + RewriterBase &rewriter, OpResult result) { OpBuilder::InsertionGuard guard(rewriter); Operation *op = result.getOwner(); SmallVector aliasingOperands = getAliasingOpOperand(result); @@ -376,10 +377,8 @@ if (aliasingOperands.size() > 1 && !llvm::all_of(aliasingOperands, [&](OpOperand *o) { return lookupBuffer(rewriter, o->get()) == operandBuffer; - })) { - op->emitError("result buffer is ambiguous"); - return Value(); - } + })) + return FailureOr(op->emitError("result buffer is ambiguous")); // If bufferizing out-of-place, allocate a new buffer. if (!aliasInfo.isInPlace(result)) { @@ -611,10 +610,13 @@ // Insert to_memref op. OpBuilder::InsertionGuard g(rewriter); setInsertionPointAfter(rewriter, tensor); - Type memrefType = - tensor.getType().isa() - ? 
getDynamicMemRefType(tensor.getType().cast<RankedTensorType>())
-          : getContiguousOrUnrankedMemRefType(tensor.getType());
+  Type memrefType;
+  if (auto rankedTensorType = tensor.getType().dyn_cast<RankedTensorType>()) {
+    memrefType = getDynamicMemRefType(rankedTensorType);
+  } else {
+    memrefType = getUnrankedMemRefType(
+        tensor.getType().cast<TensorType>().getElementType());
+  }
  return rewriter.create<bufferization::ToMemrefOp>(tensor.getLoc(), memrefType,
                                                     tensor);
}
@@ -631,13 +633,9 @@
                                 layout, memorySpace);
}

-Type mlir::linalg::comprehensive_bufferize::getContiguousOrUnrankedMemRefType(
-    Type type, MemRefLayoutAttrInterface layout, Attribute memorySpace) {
-  if (type.isa())
-    return getContiguousMemRefType(type.cast(), layout,
-                                   memorySpace);
-  assert(!layout && "expected empty layout with UnrankedMemRefType");
-  return UnrankedMemRefType::get(getElementTypeOrSelf(type), memorySpace);
+UnrankedMemRefType mlir::linalg::comprehensive_bufferize::getUnrankedMemRefType(
+    Type elementType, Attribute memorySpace) {
+  return UnrankedMemRefType::get(elementType, memorySpace);
}

MemRefType mlir::linalg::comprehensive_bufferize::getDynamicMemRefType(
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp
@@ -25,6 +25,9 @@
// TODO: These ops should implement BufferizableOpInterface directly when moved
// to the Bufferization dialect.

+/// Bufferization of bufferization.to_memref. to_memref(to_tensor(x)) is folded
+/// to x. Other to_memref ops are ignored during bufferization.
+///
/// ToMemrefOp casts a tensor into a memref. The resulting memref is the memory
/// location of the incoming tensor once it will be bufferized. In the analysis,
/// the incoming tensor is assumed to bufferize to a memory read and to an
@@ -41,7 +44,7 @@
                                                    bufferization::ToMemrefOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              BufferizationState &state) const {
-    // It is unknown whether the resulting MemRef will be read or not.
+    // It is unknown whether the resulting memref will be read or not.
    return true;
  }

@@ -58,10 +61,13 @@
    if (auto toTensorOp =
            toMemrefOp.tensor().getDefiningOp<bufferization::ToTensorOp>()) {
      Value buffer = toTensorOp.memref();
+
+      // Insert cast in case to_memref(to_tensor(x))'s type is different from
+      // x's type.
      if (toTensorOp.memref().getType() != toMemrefOp.getType())
        buffer = rewriter.create<memref::CastOp>(toMemrefOp.getLoc(), buffer,
                                                 toMemrefOp.getType());
-      rewriter.replaceOp(toMemrefOp, buffer);
+      state.replaceOpWithBufferizedValues(rewriter, toMemrefOp, buffer);
      return success();
    }

@@ -69,16 +75,19 @@
  }
};

-/// ToTensorOp conceptually loads a tensor from a memory location. Such ops do
-/// not lower any further, and they should have disappeared by the time the
-/// input is fully bufferized.
+/// Bufferization of bufferization.to_tensor. Such ops cannot be bufferized.
+/// However, other ops that are using to_tensor's result will eventually be
+/// bufferized. At that point, they will start using to_tensor's memref operand.
+/// Once all users of to_tensor are bufferized, the op will not have any users
+/// anymore and will DCE away.
///
-/// The analysis has no information about the memref that is loaded from by the
-/// ToTensorOp. We have to assume that the loaded tensor may after bufferization
-/// potentially alias with any other bufferized tensor.
Since ToTensorOp and -/// ToMemrefOp have no aliasing OpOperand/OpResult pairs, this cannot be encoded -/// directly in the analysis. However, declaring ToTensorOp results as not -/// writable also enforces a buffer copy and has the same effect. +/// ToTensorOp conceptually loads a tensor from a memory location. The analysis +/// has no information about the memref that is loaded from by ToTensorOp. We +/// have to assume that the loaded tensor may after bufferization potentially +/// alias with any other bufferized tensor. Since ToTensorOp and ToMemrefOp have +/// no aliasing OpOperand/OpResult pairs, this cannot be encoded directly in the +/// analysis. However, declaring ToTensorOp results as not writable enforces a +/// buffer copy and has the same effect. struct ToTensorOpInterface : public BufferizableOpInterface::ExternalModel { @@ -88,7 +97,7 @@ } bool isWritable(Operation *op, Value value, BufferizationState &state) const { - // It is unknown whether the MemRef operand is writable or not. + // It is unknown whether the memref operand is writable or not. return false; } }; diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp @@ -6,98 +6,37 @@ // //===----------------------------------------------------------------------===// // -// Perform inplace bufferization within function boundaries. -// This is a specialized pass that supports inplace analysis for a fixed subset -// of ops that have well-defined inplace semantics. -// This pass caters to high-performance codegen where buffer reuse is deemed -// critical: the pass should fail if the bufferized form of the function needs -// to return any buffer. -// Generic control-flow and branching are unsupported. -// Composability with extensible set of ops is not a first-class concern. -// -// Bufferization occurs by: -// a. performing an inPlace analysis `inPlaceAnalysis` which marks each -// operation within the op with the `kInPlaceResultsAttrName` attribute. -// b. traversing each operation in the op and rewriting it in -// buffer form and keeping a BlockAndValueMapping mapping of the -// rewrites. New allocations are introduced during this step. -// TODO: Allocation + depending op hoisting to outermost enclosing -// sequential scope. -// c. at the end of this bufferization, 3 cases may occur: -// i. inplaceable function arguments may be reused in place after the -// function itself has been bufferized. This is encoded by IR resembling: -// -// ``` -// #map = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)> -// func @foo(%A: tensor {linalg.inplaceable = true}) -// -> tensor { -// %0 = bufferization.to_memref %A : memref -// // ... uses of %0 -// %res = bufferization.to_tensor %0 : memref -// return %res : tensor -// } -// ``` +// Comprehensive Bufferize bufferizes function bodies. Function boundaries +// (FuncOp bbArgs, CallOps, ReturnOps) are treated as "unknown" ops. +// ModuleBufferization.cpp is an extension of Comprehensive Bufferize for simple +// call graphs. // -// this is the cue for the bufferization of the function foo (and calls -// to it) may bufferize to `func @foo(%A: memref)`. -// To fully achieve bufferization, an additional analysis is needed to -// determine whether function argument/operand pairs bufferize to a -// single inplace buffer argument (i.e. 
functions may return tensors in -// arbitrary order that may not match argument numbers). +// Comprehensive Bufferize consists of two phases. // -// ii. results that don't map to an inplaceable function argument are -// generally allocated. Since memref semantics wrt ownership of the -// underlying memory region are not well-defined, comprehensive -// bufferization chooses to perform allocations in a scoped fashion: -// returning memrefs is always considered illegal. -// Such scenarios are encoded by IR resembling: +// 1. Analyze ops to decide which OpResults can bufferize inplace, i.e., without +// inserting buffer copies. The analysis queries op bufferization semantics +// via `BufferizableOpInterface`. +// 2. Bufferize ops by calling `BufferizableOpInterface::bufferize`. This +// function does not generate buffer copies for OpResults that were decided +// to bufferize inplace during the analysis phase. // -// ``` -// #map = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)> -// func @foo(%A: tensor {linalg.inplaceable = true}) -// -> tensor { -// %0 = bufferization.to_memref %A : memref -// %1 = memref.dim %0, %c0 : memref -// %2 = memref.alloc(%1) : memref -// %3 = memref.cast %2 : memref to memref -// // ... uses of %3 -// memref.dealloc %2 : memref -// %res = bufferization.to_tensor %3 : memref -// return %res : tensor -// } -// ``` +// Inplace bufferization decisions are passed from the analysis to the +// bufferization phase via `BufferizationState` and `BufferizationAliasInfo`. +// They can be printed for debugging purposes with `testAnalysisOnly`. // -// this is the cue for the bufferization of the function foo (and calls -// to it) that it must bufferize to `func @foo(%A: memref, -// %B: memref)` (i.e. make a cloned -// allocation of the result tensor) -// To fully achieve bufferization, the alloc/dealloc pair must be lifted -// out of the function at each call site. +// Ops that do not implement `BufferizableOpInterface` can be analyzed but are +// treated conservatively. E.g., the analysis has to assume that their +// OpOperands bufferize to memory writes. While such ops can be analyzed, they +// are not bufferized and remain in the IR. to_tensor and to_memref ops are +// inserted at the bufferization boundary. // -// iii. as an optimization over ii., it may be possible to reuse an argument -// and only want to return a slice. -// This may forego allocation by letting *all* callers decide whether to -// pass a new *aliasing* memref function argument (i.e. a subview). -// Without loss of generality, callers may agree to allocate a new buffer -// to avoid this aliasing. Such scenarios are encoded by IR resembling: +// Note: If `allowUnknownOps` is set to false, bufferization fails when an +// unknown op (that does not implement `BufferizableOpInterface`) is found. No +// to_tensor/to_memref ops are inserted. // -// ``` -// #map = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)> -// func @foo(%arg0: tensor {linalg.inplaceable = true}) -// -> tensor<4xf32> { -// %0 = bufferization.to_memref %arg0 : memref -// %1 = memref.subview %0[0] [4] [1] : memref to -// memref<4xf32, #map> -// // ... inplace computes into %1 -// %3 = bufferization.to_tensor %1 : memref<4xf32, #map> -// return %3 : tensor<4xf32> -// } -// ``` -// -// Note: In the future, it may be worthwhile to design special bufferization -// ops to encode the desired semantics at function boundaries for i., ii. and -// iii. 
+// This pass caters to high-performance codegen where buffer reuse is deemed +// critical: the pass should fail if the bufferized form of the function needs +// to return any buffer, unless `allowReturnMemref` is enabled. // // Lastly, note that layout map chosen to bufferize is the most dynamic // canonical strided layout of the proper rank. This ensures compatibility with diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp @@ -38,6 +38,7 @@ if (!op.hasTensorSemantics()) return op->emitError() << "op does not have tensor semantics"; + // New input operands for the cloned op. SmallVector newInputBuffers; newInputBuffers.reserve(op.getNumInputs()); for (OpOperand *opOperand : op.getInputOperands()) { @@ -48,22 +49,23 @@ newInputBuffers.push_back(state.lookupBuffer(rewriter, opOperand->get())); } + // New output operands for the cloned op. SmallVector newOutputBuffers; for (OpOperand *opOperand : op.getOutputOperands()) { OpResult opResult = op.getTiedOpResult(opOperand); assert(opResult && "could not find correspond OpResult"); - Value resultBuffer = state.getResultBuffer(rewriter, opResult); - if (!resultBuffer) - return failure(); - newOutputBuffers.push_back(resultBuffer); + FailureOr resultBuffer = state.getResultBuffer(rewriter, opResult); + newOutputBuffers.push_back(*resultBuffer); } - // Clone the newly bufferized op. + // Merge input/output operands. SmallVector newOperands = newInputBuffers; newOperands.append(newOutputBuffers.begin(), newOutputBuffers.end()); // Set insertion point now that potential alloc/dealloc are introduced. rewriter.setInsertionPoint(op); + // Clone the op, but use the new operands. Since the new op does not have any + // tensor results, it does not return anything. op.clone(rewriter, op.getLoc(), /*resultTypes=*/TypeRange{}, newOperands); // Replace the results of the old op with the new output buffers. @@ -135,18 +137,23 @@ return mapping; } +/// Bufferization of linalg.generic. Replace with a new linalg.generic that +/// operates entirely on memrefs. template struct LinalgOpInterface : public BufferizableOpInterface::ExternalModel, OpTy> { bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, BufferizationState &state) const { + // Operand is read if it is used in the computation. auto genericOp = cast(op); return genericOp.payloadUsesValueFromOperand(&opOperand); } bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, BufferizationState &state) const { + // Operand is written to if it has an aliasing OpResult. For more details, + // see `computeAliasingPairs`. auto bufferizableOp = cast(op); return static_cast( bufferizableOp.getAliasingOpResult(opOperand, state)); @@ -156,6 +163,8 @@ getAliasingOpOperand(Operation *op, OpResult opResult, BufferizationState &state) const { auto genericOp = cast(op); + + // Aliasing OpOperand/OpResult pairs are computed by `computeAliasingPairs`. DenseMap pairs = computeAliasingPairs(genericOp); for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) if (pairs[opOperand] == opResult) @@ -166,6 +175,8 @@ OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand, BufferizationState &state) const { auto genericOp = cast(op); + + // Aliasing OpOperand/OpResult pairs are computed by `computeAliasingPairs`. 
DenseMap pairs = computeAliasingPairs(genericOp); return pairs[&opOperand]; } @@ -207,22 +218,26 @@ } }; +/// Bufferization of linalg.tiled_loop. Replace with a new linalg.tiled_loop +/// that operates entirely on memrefs. struct TiledLoopOpInterface : public BufferizableOpInterface::ExternalModel { bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, BufferizationState &state) const { - // TiledLoop alone doesn't bufferize to a memory read, one of the uses of - // its matching bbArg may. auto tiledLoopOp = cast(op); + + // linalg.tiled_loop operands alone do not bufferize to a memory read, but + // one of the uses of their matching bbArgs may. return state.isValueRead(tiledLoopOp.getTiedBlockArgument(opOperand)); } bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, BufferizationState &state) const { - // TiledLoop alone doesn't bufferize to a memory write, one of the uses of - // its matching bbArg may. auto bufferizableOp = cast(op); + + // Only operands with an aliasing OpResult (i.e., output operands) bufferize + // to a memory write. return static_cast( bufferizableOp.getAliasingOpResult(opOperand, state)); } @@ -230,6 +245,8 @@ OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand, BufferizationState &state) const { auto tiledLoopOp = cast(op); + + // Output operands are tied to their corresponding OpResults. return tiledLoopOp.getTiedOpResult(opOperand); } @@ -240,8 +257,8 @@ } bool isWritable(Operation *op, Value value, BufferizationState &state) const { - // Interestingly, linalg::TiledLoopOp's bbArg can **always** be viewed - // inplace from the perspective of ops nested under: + // Interestingly, linalg::TiledLoopOp's bbArgs can **always** be viewed + // inplace from the perspective of nested ops: // 1. Either the matching iter operand is not bufferized inplace and an // alloc + optional copy makes the bbArg itself inplaceable. // 2. Or the matching iter operand is bufferized inplace and bbArg just @@ -267,10 +284,10 @@ int nextResultNum = 0; for (Value value : tiledLoopOp.outputs()) { if (value.getType().isa()) { - Value buffer = state.getResultBuffer( + FailureOr buffer = state.getResultBuffer( rewriter, tiledLoopOp->getResult(nextResultNum++)); - newOutputs.push_back(buffer); - newResults.push_back(buffer); + newOutputs.push_back(*buffer); + newResults.push_back(*buffer); } else { newOutputs.push_back(value); } @@ -340,6 +357,8 @@ } }; +/// Bufferization of linalg.yield. Bufferized as part of linalg.tiled_loop's +/// bufferization. struct YieldOpInterface : public BufferizableOpInterface::ExternalModel { @@ -398,13 +417,12 @@ /// OpOperand. "Anchored" means that there is a path on the reverse SSA use-def /// chain, starting from the OpOperand and always following the aliasing /// OpOperand, that eventually ends at a single InitTensorOp. 
-LogicalResult mlir::linalg::comprehensive_bufferize::linalg_ext:: - InitTensorEliminationStep::eliminateInitTensors( - Operation *op, BufferizationState &state, - BufferizationAliasInfo &aliasInfo, - std::function anchorMatchFunc, - std::function rewriteFunc, - SmallVector &newOps) { +LogicalResult +mlir::linalg::comprehensive_bufferize::linalg_ext::InitTensorEliminationStep:: + eliminateInitTensors(Operation *op, BufferizationState &state, + BufferizationAliasInfo &aliasInfo, + AnchorMatchFn anchorMatchFunc, RewriteFn rewriteFunc, + SmallVector &newOps) { OpBuilder b(op->getContext()); WalkResult status = op->walk([&](Operation *op) { @@ -499,6 +517,7 @@ BufferizationAliasInfo &aliasInfo, SmallVector &newOps) { return eliminateInitTensors( op, state, aliasInfo, + /*anchorMatchFunc=*/ [&](OpOperand &operand) { auto insertSliceOp = dyn_cast(operand.getOwner()); @@ -509,6 +528,7 @@ return false; return &operand == &insertSliceOp->getOpOperand(0) /*source*/; }, + /*rewriteFunc=*/ [](OpBuilder &b, Location loc, OpOperand &operand) { auto insertSliceOp = cast(operand.getOwner()); auto extractOp = b.create( diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp @@ -5,6 +5,88 @@ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// +// +// Module bufferization is an extension of Comprehensive Bufferize that +// bufferizes function boundaries. It provides `BufferizableOpInterface` +// implementations for FuncOp, CallOp and ReturnOp, along with a few helper +// functions that control the order in which functions are bufferized. +// +// Three cases can occur during bufferization of FuncOps. +// +// i. inplaceable function arguments may be reused in place after the +// function itself has been bufferized. This is encoded by IR resembling: +// +// ``` +// #map = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)> +// func @foo(%A: tensor {linalg.inplaceable = true}) +// -> tensor { +// %0 = bufferization.to_memref %A : memref +// // ... uses of %0 +// %res = bufferization.to_tensor %0 : memref +// return %res : tensor +// } +// ``` +// +// this is the cue for the bufferization of the function foo (and calls +// to it) may bufferize to `func @foo(%A: memref)`. +// To fully achieve bufferization, an additional analysis is needed to +// determine whether function argument/operand pairs bufferize to a +// single inplace buffer argument (i.e. functions may return tensors in +// arbitrary order that may not match argument numbers). +// +// ii. results that don't map to an inplaceable function argument are +// generally allocated. Since memref semantics wrt ownership of the +// underlying memory region are not well-defined, comprehensive +// bufferization chooses to perform allocations in a scoped fashion: +// returning memrefs is always considered illegal. +// Such scenarios are encoded by IR resembling: +// +// ``` +// #map = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)> +// func @foo(%A: tensor {linalg.inplaceable = true}) +// -> tensor { +// %0 = bufferization.to_memref %A : memref +// %1 = memref.dim %0, %c0 : memref +// %2 = memref.alloc(%1) : memref +// %3 = memref.cast %2 : memref to memref +// // ... 
uses of %3 +// memref.dealloc %2 : memref +// %res = bufferization.to_tensor %3 : memref +// return %res : tensor +// } +// ``` +// +// this is the cue for the bufferization of the function foo (and calls +// to it) that it must bufferize to `func @foo(%A: memref, +// %B: memref)` (i.e. make a cloned +// allocation of the result tensor) +// To fully achieve bufferization, the alloc/dealloc pair must be lifted +// out of the function at each call site. +// +// iii. as an optimization over ii., it may be possible to reuse an argument +// and only want to return a slice. +// This may forego allocation by letting *all* callers decide whether to +// pass a new *aliasing* memref function argument (i.e. a subview). +// Without loss of generality, callers may agree to allocate a new buffer +// to avoid this aliasing. Such scenarios are encoded by IR resembling: +// +// ``` +// #map = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)> +// func @foo(%arg0: tensor {linalg.inplaceable = true}) +// -> tensor<4xf32> { +// %0 = bufferization.to_memref %arg0 : memref +// %1 = memref.subview %0[0] [4] [1] : memref to +// memref<4xf32, #map> +// // ... inplace computes into %1 +// %3 = bufferization.to_tensor %1 : memref<4xf32, #map> +// return %3 : tensor<4xf32> +// } +// ``` +// +// Note: In the future, it may be worthwhile to design special bufferization +// ops to encode the desired semantics at function boundaries for i., ii. and +// iii. #include "mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h" @@ -153,7 +235,7 @@ if (auto rankedTensorType = t.dyn_cast()) return getDynamicMemRefType(rankedTensorType); if (auto tensorType = t.dyn_cast()) - return getContiguousOrUnrankedMemRefType(tensorType); + return getUnrankedMemRefType(tensorType.getElementType()); return t; }; auto argTypes = llvm::to_vector<4>(llvm::map_range(argumentTypes, rewrite)); diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp @@ -19,6 +19,8 @@ namespace comprehensive_bufferize { namespace scf_ext { +/// Bufferization of scf.execute_region. Can be analyzed, but bufferization not +/// fully implemented at the moment. struct ExecuteRegionOpInterface : public BufferizableOpInterface::ExternalModel { @@ -80,6 +82,7 @@ } }; +/// Bufferization of scf.if. Replace with a new scf.if that yields memrefs. struct IfOpInterface : public BufferizableOpInterface::ExternalModel { SmallVector @@ -213,6 +216,8 @@ } }; +/// Bufferization of scf.for. Replace with a new scf.for that operates on +/// memrefs. struct ForOpInterface : public BufferizableOpInterface::ExternalModel { @@ -292,7 +297,7 @@ // Construct a new scf.for op with memref instead of tensor values. SmallVector initArgs = convert(forOp.getInitArgs(), [&](Value val, int64_t index) { - return state.getResultBuffer(rewriter, forOp->getOpResult(index)); + return *state.getResultBuffer(rewriter, forOp->getOpResult(index)); }); auto newForOp = rewriter.create( forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(), @@ -399,6 +404,8 @@ return status; } +/// Bufferization of scf.yield. Bufferized as part of their enclosing ops, so +/// this is for analysis only. 
struct YieldOpInterface : public BufferizableOpInterface::ExternalModel { diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp @@ -38,6 +38,7 @@ tensor::TensorDialect::getDialectNamespace()); } +/// Bufferization of tensor.cast. Replace with memref.cast. struct CastOpInterface : public BufferizableOpInterface::ExternalModel { @@ -66,29 +67,38 @@ BufferizationState &state) const { auto castOp = cast(op); - Value resultBuffer = state.getResultBuffer(rewriter, castOp->getResult(0)); - if (!resultBuffer) - return failure(); - Type sourceType = resultBuffer.getType(); - auto rankedMemRefType = sourceType.dyn_cast(); - auto unrankedMemRefType = sourceType.dyn_cast(); - assert(rankedMemRefType || unrankedMemRefType); - Attribute memorySpace = rankedMemRefType - ? rankedMemRefType.getMemorySpace() - : unrankedMemRefType.getMemorySpace(); - TensorType tensorType = castOp.getResult().getType().cast(); - MemRefLayoutAttrInterface layout = - rankedMemRefType && tensorType.isa() - ? rankedMemRefType.getLayout() - : MemRefLayoutAttrInterface(); - Type memRefType = getContiguousOrUnrankedMemRefType( - castOp.getResult().getType(), layout, memorySpace); - state.replaceOpWithNewBufferizedOp(rewriter, op, memRefType, - resultBuffer); + // The result buffer still has the old (pre-cast) type. + FailureOr resultBuffer = + state.getResultBuffer(rewriter, castOp->getResult(0)); + auto sourceMemRefType = resultBuffer->getType().cast(); + Attribute memorySpace = sourceMemRefType.getMemorySpace(); + TensorType resultTensorType = + castOp.getResult().getType().cast(); + MemRefLayoutAttrInterface layout; + + if (auto rankedMemRefType = sourceMemRefType.dyn_cast()) + if (resultTensorType.isa()) + layout = rankedMemRefType.getLayout(); + + // Compute the new memref type. + Type resultMemRefType; + if (auto rankedTensorType = resultTensorType.isa()) { + resultMemRefType = + getContiguousMemRefType(resultTensorType, layout, memorySpace); + } else { + resultMemRefType = + getUnrankedMemRefType(resultTensorType.getElementType(), memorySpace); + } + + // Replace the op with a memref.cast. + state.replaceOpWithNewBufferizedOp( + rewriter, op, resultMemRefType, *resultBuffer); + return success(); } }; +/// Bufferization of tensor.dim. Replace with memref.dim. struct DimOpInterface : public BufferizableOpInterface::ExternalModel { @@ -110,8 +120,6 @@ LogicalResult bufferize(Operation *op, RewriterBase &rewriter, BufferizationState &state) const { auto dimOp = cast(op); - if (!dimOp.source().getType().isa()) - return dimOp.emitError("unranked tensor not supported"); Value v = state.lookupBuffer(rewriter, dimOp.source()); state.replaceOpWithNewBufferizedOp(rewriter, op, v, dimOp.index()); @@ -119,6 +127,7 @@ } }; +/// Bufferization of tensor.extract_slice. Replace with memref.subview. struct ExtractSliceOpInterface : public BufferizableOpInterface::ExternalModel { @@ -172,7 +181,7 @@ loc, subviewMemRefType, srcMemref, extractSliceOp.getMixedOffsets(), extractSliceOp.getMixedSizes(), extractSliceOp.getMixedStrides()); - /// If not inplaceable, copy. + // If not inplaceable, copy. if (!inplace) { // Do not copy if the copied data is never read. if (state.isValueRead(extractSliceOp.result())) @@ -185,6 +194,7 @@ } }; +/// Bufferization of tensor.extract. 
Replace with memref.load. struct ExtractOpInterface : public BufferizableOpInterface::ExternalModel { @@ -213,6 +223,7 @@ } }; +/// Bufferization of tensor.insert. Replace with memref.store. struct InsertOpInterface : public BufferizableOpInterface::ExternalModel { @@ -242,12 +253,11 @@ LogicalResult bufferize(Operation *op, RewriterBase &rewriter, BufferizationState &state) const { auto insertOp = cast(op); - Location loc = insertOp.getLoc(); - Value destMemref = + FailureOr destMemref = state.getResultBuffer(rewriter, insertOp->getOpResult(0)); - rewriter.create(loc, insertOp.scalar(), destMemref, - insertOp.indices()); - state.replaceOpWithBufferizedValues(rewriter, op, destMemref); + rewriter.create(insertOp.getLoc(), insertOp.scalar(), + *destMemref, insertOp.indices()); + state.replaceOpWithBufferizedValues(rewriter, op, *destMemref); return success(); } @@ -309,6 +319,8 @@ condition); } +/// Bufferization of tensor.insert_slice. Replace with a memory copy. Under +/// certain circumstances, this op can also be a no-op. struct InsertSliceOpInterface : public BufferizableOpInterface::ExternalModel { @@ -424,16 +436,14 @@ TensorBufferizationState &tensorState = getTensorBufferizationState(state); // When bufferizing out-of-place, `getResultBuffer` allocates. - Value dstMemref = + FailureOr dstMemref = state.getResultBuffer(rewriter, insertSliceOp->getResult(0)); - if (!dstMemref) - return failure(); bool needCopy = !tensorState.insertSliceOpsWithoutCopy.contains(insertSliceOp); if (needCopy) { // Take a subview of the dst. - auto dstMemrefType = dstMemref.getType().cast(); + auto dstMemrefType = dstMemref->getType().cast(); auto subviewMemRefType = memref::SubViewOp::inferRankReducedResultType( insertSliceOp.getSourceType().getRank(), dstMemrefType, @@ -441,18 +451,18 @@ insertSliceOp.getMixedStrides()) .cast(); Value subView = rewriter.create( - loc, subviewMemRefType, dstMemref, insertSliceOp.getMixedOffsets(), + loc, subviewMemRefType, *dstMemref, insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides()); // Copy tensor. Value srcMemref = state.lookupBuffer(rewriter, insertSliceOp.source()); - state.createMemCpy(rewriter, insertSliceOp.getLoc(), srcMemref, subView); + state.createMemCpy(rewriter, loc, srcMemref, subView); } else { // Make sure that `source` does not DCE away. rewriter.create( loc, insertSliceOp.source()); } - state.replaceOpWithBufferizedValues(rewriter, op, dstMemref); + state.replaceOpWithBufferizedValues(rewriter, op, *dstMemref); return success(); } }; diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp @@ -17,6 +17,8 @@ namespace comprehensive_bufferize { namespace vector_ext { +/// Bufferization of vector.transfer_read. Replaced with a new +/// vector.transfer_read that operates on a memref. struct TransferReadOpInterface : public BufferizableOpInterface::ExternalModel { @@ -55,6 +57,8 @@ } }; +/// Bufferization of vector.transfer_write. Replace with a new +/// vector.transfer_write that operates on a memref. struct TransferWriteOpInterface : public BufferizableOpInterface::ExternalModel { @@ -94,13 +98,12 @@ // Create a new transfer_write on buffer that doesn't have a return value. 
// Leave the previous transfer_write to dead code as it still has uses at // this point. - Value resultBuffer = state.getResultBuffer(rewriter, op->getResult(0)); - if (!resultBuffer) - return failure(); + FailureOr resultBuffer = + state.getResultBuffer(rewriter, op->getResult(0)); rewriter.create( - writeOp.getLoc(), writeOp.vector(), resultBuffer, writeOp.indices(), + writeOp.getLoc(), writeOp.vector(), *resultBuffer, writeOp.indices(), writeOp.permutation_mapAttr(), writeOp.in_boundsAttr()); - state.replaceOpWithBufferizedValues(rewriter, op, resultBuffer); + state.replaceOpWithBufferizedValues(rewriter, op, *resultBuffer); return success(); }
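
For reference, a minimal sketch of how a BufferizableOpInterface external model is expected to consume the FailureOr<Value> that `getResultBuffer` now returns. The op `MyOp` and the struct name are hypothetical placeholders; only `getResultBuffer`, `replaceOpWithBufferizedValues`, and the `ExternalModel` base are taken from the patch above.

struct MyOpInterface
    : public BufferizableOpInterface::ExternalModel<MyOpInterface, MyOp> {
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    // getResultBuffer may allocate (and copy into) a new buffer if the
    // analysis decided that the result bufferizes out-of-place.
    FailureOr<Value> resultBuffer =
        state.getResultBuffer(rewriter, op->getResult(0));
    if (failed(resultBuffer))
      return failure();
    // ... create the memref-based replacement ops using *resultBuffer ...
    state.replaceOpWithBufferizedValues(rewriter, op, *resultBuffer);
    return success();
  }
};

A custom PostAnalysisStep follows the contract documented in BufferizableOpInterface.h above. The step below is a hypothetical skeleton; registration goes through BufferizationOptions and is not shown here.

struct MyPostAnalysisStep : public PostAnalysisStep {
  LogicalResult run(Operation *op, BufferizationState &state,
                    BufferizationAliasInfo &aliasInfo,
                    SmallVector<Operation *> &newOps) override {
    // Inspect or rewrite the IR here. Any op created by this step (or any op
    // that must be re-analyzed) has to be added to `newOps`, and `aliasInfo`
    // must be kept consistent with the rewrites.
    return success();
  }
};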