diff --git a/mlir/docs/Bufferization.md b/mlir/docs/Bufferization.md --- a/mlir/docs/Bufferization.md +++ b/mlir/docs/Bufferization.md @@ -30,40 +30,38 @@ One-Shot Bufferize is: -* **Monolithic**: A single MLIR pass does the entire -work, whereas the previous bufferization in MLIR was split across multiple -passes residing in different dialects. In One-Shot Bufferize, -`BufferizableOpInterface` implementations are spread across different dialects. - -* A **whole-function at a time analysis**. In-place bufferization decisions are -made by analyzing SSA use-def chains on tensors. Op interface implementations -not only provide the rewrite logic from tensor ops to memref ops, but also -helper methods for One-Shot Bufferize's analysis to query information about an -op's bufferization/memory semantics. - -* **Extensible** via an op interface: All -ops that implement `BufferizableOpInterface` can be bufferized. - -* **2-Pass**: -Bufferization is internally broken down into 2 steps: First, analyze the entire -IR and make bufferization decisions. Then, bufferize (rewrite) the IR. The -analysis has access to exact SSA use-def information. It incrementally builds -alias and equivalence sets and does not rely on a posteriori-alias analysis from -preallocated memory. - -* **Greedy**: Operations are analyzed one-by-one and it is -decided on the spot whether a tensor OpOperand must be copied or not. Heuristics -determine the order of analysis. - -* **Modular**: The current One-Shot Analysis -can be replaced with a different analysis. The result of the analysis are -queried by the bufferization via `BufferizationState`, in particular -`BufferizationState::isInPlace`. Any derived class of `BufferizationState` that -implements a small number virtual functions can serve as a custom analysis. 
It -is even possible to run One-Shot Bufferize without any analysis -(`AlwaysCopyBufferizationState`), in which case One-Shot Bufferize behaves -exactly like the old dialect conversion-based bufferization (i.e., copy every -buffer before writing to it). +* **Monolithic**: A single MLIR pass does the entire work, whereas the + previous bufferization in MLIR was split across multiple passes residing in + different dialects. In One-Shot Bufferize, `BufferizableOpInterface` + implementations are spread across different dialects. + +* A **whole-function at a time analysis**. In-place bufferization decisions + are made by analyzing SSA use-def chains on tensors. Op interface + implementations not only provide the rewrite logic from tensor ops to memref + ops, but also helper methods for One-Shot Bufferize's analysis to query + information about an op's bufferization/memory semantics. + +* **Extensible** via an op interface: All ops that implement + `BufferizableOpInterface` can be bufferized. + +* **2-Pass**: Bufferization is internally broken down into 2 steps: First, + analyze the entire IR and make bufferization decisions. Then, bufferize + (rewrite) the IR. The analysis has access to exact SSA use-def information. + It incrementally builds alias and equivalence sets and does not rely on a + posteriori-alias analysis from preallocated memory. + +* **Greedy**: Operations are analyzed one-by-one and it is decided on the spot + whether a tensor OpOperand must be copied or not. Heuristics determine the + order of analysis. + +* **Modular**: The current One-Shot Analysis can be replaced with a different + analysis. The results of the analysis are queried by the bufferization via + `AnalysisState`, in particular `AnalysisState::isInPlace`. Any derived class + of `AnalysisState` that implements a small number of virtual functions can + serve as a custom analysis.
It is even possible to run One-Shot Bufferize + without any analysis (`AlwaysCopyAnalysisState`), in which case One-Shot + Bufferize behaves exactly like the old dialect conversion-based + bufferization (i.e., copy every buffer before writing to it). To reduce complexity, One-Shot Bufferize should be [run after other transformations](https://llvm.discourse.group/t/rfc-linalg-on-tensors-update-and-comprehensive-bufferization-rfc/3373), diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h @@ -236,7 +236,7 @@ /// /// Note: Deactivating this flag can lead to incorrect bufferization results /// when used incorrectly. This flag is useful with - /// `AlwaysCopyBufferizationState` which bufferizes all writing tensor + /// `AlwaysCopyAnalysisState` which bufferizes all writing tensor /// OpOperands out-of-place. bool enforceAliasingInvariants = true; @@ -464,33 +464,6 @@ const BufferizationOptions &options; }; -/// BufferizationState provides helper functions for performing bufferization -/// rewrites and handling memref buffers. -struct BufferizationState { - BufferizationState(const BufferizationOptions &options) : options(options) {} - - /// Lookup the buffer for the given value. If the value was not bufferized - /// yet, wrap it in a ToMemrefOp. Otherwise, it is the result of a ToTensorOp, - /// from which the memref operand is returned. - Value getBuffer(RewriterBase &rewriter, Value value); - - /// Return the buffer type for a given Value (tensor) after bufferization. - /// - /// Note: Op implementations should preferrably call `getBuffer()->getType()`. - /// This function should only be used if `getBuffer` cannot be used. 
- BaseMemRefType getBufferType(Value value) const; - - /// Return a reference to the BufferizationOptions. - const BufferizationOptions &getOptions() const { return options; } - -protected: - // BufferizationState should be passed as a reference. - BufferizationState(const BufferizationState &) = delete; - -private: - const BufferizationOptions &options; -}; - /// Create an AllocTensorOp for the given shaped value (memref or tensor). /// If `copy` is set, the shaped value is copied. Otherwise, a tensor with /// undefined contents is allocated. @@ -498,6 +471,18 @@ Value shapedValue, bool escape, bool copy = true); +/// Look up the buffer for the given value. If the value was not bufferized +/// yet, wrap it in a ToMemrefOp. Otherwise, it is the result of a ToTensorOp, +/// from which the memref operand is returned. +Value getBuffer(RewriterBase &rewriter, Value value, + const BufferizationOptions &options); + +/// Return the buffer type for a given Value (tensor) after bufferization. +/// +/// Note: Op implementations should preferably call `getBuffer()->getType()`. +/// This function should only be used if `getBuffer` cannot be used. +BaseMemRefType getBufferType(Value value, const BufferizationOptions &options); + /// Replace an op with replacement values. The op is deleted. Tensor OpResults /// must be replaced with memref values. void replaceOpWithBufferizedValues(RewriterBase &rewriter, Operation *op, diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td @@ -221,7 +221,7 @@ InterfaceMethod< /*desc=*/[{ Bufferize this op, i.e., rewrite it into a memref-based equivalent. - Buffers of tensor SSA values can be retrieved via `state.getBuffer`. 
+ Buffers of tensor SSA values can be retrieved via `getBuffer`.
Uses of tensor results of the existing tensor op can be replaced with `replaceOpWithBufferizedValues` or `replaceOpWithNewBufferizedOp`. These two functions automatically handle the tensor-to-memref type @@ -233,12 +233,6 @@ a) A buffer that aliases one of buffers in getAliasingOpOperand(r). b) Or: A newly allocated buffer. - Regions of an op should be inlined into the new op instead of cloning - them. This is not only more efficient, but also necessary so that no - analysis results are lost. (Bufferization decisions are tracked via - OpOperand pointers and cloned ops have new OpOperands.) If regions are - cloned instead of inlined, additional buffer copies may be inserted. - This method will never be called on ops that do not have at least one tensor operand/result. @@ -252,7 +246,7 @@ /*retType=*/"LogicalResult", /*methodName=*/"bufferize", /*args=*/(ins "RewriterBase &":$rewriter, - "BufferizationState &":$state), + "const BufferizationOptions &":$options), /*methodBody=*/"", /*defaultImplementation=*/[{ llvm_unreachable("bufferize not implemented"); diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td @@ -71,7 +71,8 @@ let results = (outs AnyTensor:$result); let extraClassDeclaration = [{ - LogicalResult bufferize(RewriterBase &rewriter, BufferizationState &state); + LogicalResult bufferize(RewriterBase &rewriter, + const BufferizationOptions &options); bool isMemoryWrite(OpResult opResult, const AnalysisState &state); @@ -242,7 +243,7 @@ // results as not writable enforces a buffer copy and has the same effect. LogicalResult bufferize(RewriterBase &rewriter, - BufferizationState &state) const { + const BufferizationOptions &options) const { // to_tensor cannot be bufferized. 
However, other ops that are using // to_tensor's result will eventually be bufferized. At that point, they // will start using to_tensor's memref operand. Once all users of @@ -334,7 +335,7 @@ } LogicalResult bufferize(RewriterBase &rewriter, - BufferizationState &state); + const BufferizationOptions &options); }]; let assemblyFormat = "$tensor attr-dict `:` type($memref)"; diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h @@ -25,7 +25,6 @@ namespace bufferization { class AnalysisState; -struct BufferizationState; struct BufferizationOptions; class OpFilter; diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h @@ -15,7 +15,6 @@ class ModuleOp; namespace bufferization { -struct BufferizationState; class OneShotAnalysisState; struct OneShotBufferizationOptions; diff --git a/mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp @@ -23,7 +23,7 @@ : public BufferizableOpInterface::ExternalModel { LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - BufferizationState &state) const { + const BufferizationOptions &options) const { auto constantOp = cast(op); // Only ranked tensors are supported. @@ -38,7 +38,7 @@ // Create global memory segment and replace tensor with memref pointing to // that memory segment. 
FailureOr globalOp = - getGlobalFor(constantOp, state.getOptions().bufferAlignment); + getGlobalFor(constantOp, options.bufferAlignment); if (failed(globalOp)) return failure(); memref::GlobalOp globalMemref = globalOp.getValue(); @@ -80,11 +80,11 @@ } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - BufferizationState &state) const { + const BufferizationOptions &options) const { auto castOp = cast(op); auto resultTensorType = castOp.getType().cast(); - Value source = state.getBuffer(rewriter, castOp.getIn()); + Value source = getBuffer(rewriter, castOp.getIn(), options); auto sourceType = source.getType().cast(); // Result type should have same layout and address space as the source type. @@ -132,7 +132,7 @@ } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - BufferizationState &state) const { + const BufferizationOptions &options) const { auto selectOp = cast(op); Location loc = selectOp.getLoc(); @@ -140,8 +140,8 @@ // instead of its OpOperands. In the worst case, 2 copies are inserted at // the moment (one for each tensor). When copying the op result, only one // copy would be needed. - Value trueBuffer = state.getBuffer(rewriter, selectOp.getTrueValue()); - Value falseBuffer = state.getBuffer(rewriter, selectOp.getFalseValue()); + Value trueBuffer = getBuffer(rewriter, selectOp.getTrueValue(), options); + Value falseBuffer = getBuffer(rewriter, selectOp.getFalseValue(), options); // The "true" and the "false" operands must have the same type. If the // buffers have different types, they differ only in their layout map. 
Cast diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp --- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp @@ -477,7 +477,8 @@ #endif } -Value BufferizationState::getBuffer(RewriterBase &rewriter, Value value) { +Value bufferization::getBuffer(RewriterBase &rewriter, Value value, + const BufferizationOptions &options) { auto tensorType = value.getType().dyn_cast(); assert(tensorType && "unexpected non-tensor type"); @@ -488,21 +489,22 @@ // Insert to_memref op. OpBuilder::InsertionGuard g(rewriter); setInsertionPointAfter(rewriter, value); - Type memrefType = getMemRefType(tensorType, getOptions()); + Type memrefType = getMemRefType(tensorType, options); ensureToMemrefOpIsValid(value, memrefType); return rewriter.create(value.getLoc(), memrefType, value); } /// Return the buffer type for a given Value (tensor) after bufferization. 
-BaseMemRefType BufferizationState::getBufferType(Value value) const { +BaseMemRefType +bufferization::getBufferType(Value value, const BufferizationOptions &options) { auto tensorType = value.getType().dyn_cast(); assert(tensorType && "unexpected non-tensor type"); if (auto toTensorOp = value.getDefiningOp()) return toTensorOp.memref().getType().cast(); - return getMemRefType(tensorType, getOptions()); + return getMemRefType(tensorType, options); } void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter, diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp --- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp @@ -150,7 +150,7 @@ //===----------------------------------------------------------------------===// LogicalResult AllocTensorOp::bufferize(RewriterBase &rewriter, - BufferizationState &state) { + const BufferizationOptions &options) { OpBuilder::InsertionGuard g(rewriter); Location loc = getLoc(); @@ -163,7 +163,7 @@ // Create buffer allocation. Value copyBuffer; if (copy()) - copyBuffer = state.getBuffer(rewriter, copy()); + copyBuffer = getBuffer(rewriter, copy(), options); auto allocType = MemRefType::get(getType().getShape(), getType().getElementType()); SmallVector dynamicDims = dynamicSizes(); @@ -172,25 +172,24 @@ populateDynamicDimSizes(rewriter, loc, copyBuffer, dynamicDims); } FailureOr alloc = - state.getOptions().createAlloc(rewriter, loc, allocType, dynamicDims); + options.createAlloc(rewriter, loc, allocType, dynamicDims); if (failed(alloc)) return failure(); // Create memory copy (if any). if (copy()) { - if (failed( - state.getOptions().createMemCpy(rewriter, loc, copyBuffer, *alloc))) + if (failed(options.createMemCpy(rewriter, loc, copyBuffer, *alloc))) return failure(); } // Should the buffer be deallocated? 
- AnalysisState analysisState(state.getOptions()); + AnalysisState analysisState(options); bool dealloc; if (escape().hasValue()) { dealloc = !*escape(); } else { // No "escape" annotation found. - if (state.getOptions().createDeallocs) { + if (options.createDeallocs) { // Perform an ad-hoc analysis. dealloc = !analysisState.isTensorYielded(getResult()); } else { @@ -206,7 +205,7 @@ return success(); rewriter.setInsertionPoint(rewriter.getInsertionBlock()->getTerminator()); - if (failed(state.getOptions().createDealloc(rewriter, loc, *alloc))) + if (failed(options.createDealloc(rewriter, loc, *alloc))) return failure(); return success(); } @@ -627,7 +626,7 @@ } LogicalResult ToMemrefOp::bufferize(RewriterBase &rewriter, - BufferizationState &state) { + const BufferizationOptions &options) { // Fold to_memref(to_tensor(x)) to x. Insert a cast if necessary. (void)foldToMemrefToTensorPair(rewriter, *this); // Note: The return value of `bufferize` indicates whether there was an error diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp @@ -401,7 +401,6 @@ DenseSet erasedOps; // Bufferize all ops. - BufferizationState bufferizationState(options); BufferizationRewriter rewriter(op->getContext(), erasedOps, toMemrefOps, worklist, options, opFilter); for (unsigned i = 0; i < worklist.size(); ++i) { @@ -420,7 +419,7 @@ continue; // Bufferize the op. rewriter.setInsertionPoint(op); - if (failed(bufferizableOp.bufferize(rewriter, bufferizationState))) + if (failed(bufferizableOp.bufferize(rewriter, options))) return op->emitError("failed to bufferize op"); } @@ -433,7 +432,7 @@ /// Check the result of bufferization. Return an error if an op was not /// bufferized, unless partial bufferization is allowed. 
- if (bufferizationState.getOptions().allowUnknownOps) + if (options.allowUnknownOps) return success(); for (Operation *op : worklist) { diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp @@ -258,7 +258,7 @@ /// All function arguments are writable. It is the responsibility of the /// CallOp to insert buffer copies where necessary. LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - BufferizationState &state) const { + const BufferizationOptions &options) const { func::CallOp callOp = cast(op); unsigned numResults = callOp.getNumResults(); unsigned numOperands = callOp->getNumOperands(); @@ -307,7 +307,7 @@ // Retrieve buffers for tensor operands. Value buffer = newOperands[idx]; if (!buffer) - buffer = state.getBuffer(rewriter, opOperand.get()); + buffer = getBuffer(rewriter, opOperand.get(), options); // Caller / callee type mismatch is handled with a CastOp. auto memRefType = funcType.getInput(idx); @@ -364,7 +364,7 @@ } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - BufferizationState &state) const { + const BufferizationOptions &options) const { #ifndef NDEBUG auto returnOp = cast(op); assert(isa(returnOp->getParentOp()) && @@ -386,11 +386,9 @@ /// All function bbArgs are writable unless they are explicitly marked as /// read-only. Callers must insert copies when needed. LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - BufferizationState &state) const { + const BufferizationOptions &options) const { auto funcOp = cast(op); FunctionType funcType = funcOp.getFunctionType(); - const OneShotBufferizationOptions &options = - static_cast(state.getOptions()); // Construct the bufferized function type. 
SmallVector argTypes; diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp @@ -429,7 +429,6 @@ assert(options.bufferizeFunctionBoundaries && "expected that function boundary bufferization is activated"); IRRewriter rewriter(moduleOp.getContext()); - BufferizationState bufferizationState(options); // A list of functions in the order in which they are analyzed + bufferized. SmallVector orderedFuncOps; diff --git a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp @@ -20,11 +20,9 @@ namespace { -// TODO: Ops in the linalg dialect can directly implement this interface. - /// Generic conversion for any LinalgOp on tensors. static LogicalResult bufferizeLinalgOp(RewriterBase &rewriter, LinalgOp op, - BufferizationState &state) { + const BufferizationOptions &options) { // Take a guard before anything else. OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPoint(op); @@ -46,14 +44,14 @@ newInputBuffers.push_back(opOperand->get()); continue; } - newInputBuffers.push_back(state.getBuffer(rewriter, opOperand->get())); + newInputBuffers.push_back(getBuffer(rewriter, opOperand->get(), options)); } // New output operands for the cloned op. 
SmallVector newOutputBuffers; for (OpResult opResult : op->getOpResults()) { OpOperand *opOperand = op.getOutputOperand(opResult.getResultNumber()); - Value resultBuffer = state.getBuffer(rewriter, opOperand->get()); + Value resultBuffer = getBuffer(rewriter, opOperand->get(), options); newOutputBuffers.push_back(resultBuffer); } @@ -123,8 +121,8 @@ } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - BufferizationState &state) const { - return bufferizeLinalgOp(rewriter, cast(op), state); + const BufferizationOptions &options) const { + return bufferizeLinalgOp(rewriter, cast(op), options); } }; diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp @@ -73,7 +73,7 @@ } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - BufferizationState &state) const { + const BufferizationOptions &options) const { auto executeRegionOp = cast(op); // Compute new result types. @@ -81,7 +81,7 @@ for (Type type : executeRegionOp->getResultTypes()) { if (auto tensorType = type.dyn_cast()) { // TODO: Infer the result type instead of computing it. - newResultTypes.push_back(getMemRefType(tensorType, state.getOptions())); + newResultTypes.push_back(getMemRefType(tensorType, options)); } else { newResultTypes.push_back(type); } @@ -183,7 +183,7 @@ } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - BufferizationState &state) const { + const BufferizationOptions &options) const { auto ifOp = cast(op); // Compute new types of the bufferized scf.if op. @@ -191,7 +191,7 @@ for (Type returnType : ifOp->getResultTypes()) { if (auto tensorType = returnType.dyn_cast()) { // TODO: Infer the result type instead of computing it. 
- newTypes.push_back(getMemRefType(tensorType, state.getOptions())); + newTypes.push_back(getMemRefType(tensorType, options)); } else { newTypes.push_back(returnType); } @@ -309,11 +309,11 @@ /// given OpOperands. If an operand is not a tensor, return the original value. static SmallVector getBuffers(RewriterBase &rewriter, MutableArrayRef operands, - BufferizationState &state) { + const BufferizationOptions &options) { SmallVector result; for (OpOperand &opOperand : operands) { if (opOperand.get().getType().isa()) { - Value resultBuffer = state.getBuffer(rewriter, opOperand.get()); + Value resultBuffer = getBuffer(rewriter, opOperand.get(), options); result.push_back(resultBuffer); } else { result.push_back(opOperand.get()); @@ -325,10 +325,11 @@ /// Helper function for loop bufferization. Compute the buffer that should be /// yielded from a loop block (loop body or loop condition). static Value getYieldedBuffer(RewriterBase &rewriter, Value tensor, - BaseMemRefType type, BufferizationState &state) { + BaseMemRefType type, + const BufferizationOptions &options) { assert(tensor.getType().isa() && "expected tensor"); ensureToMemrefOpIsValid(tensor, type); - Value yieldedVal = state.getBuffer(rewriter, tensor); + Value yieldedVal = getBuffer(rewriter, tensor, options); return castBuffer(rewriter, yieldedVal, type); } @@ -352,12 +353,12 @@ SmallVector getYieldedValues(RewriterBase &rewriter, ValueRange values, TypeRange bufferizedTypes, const DenseSet &tensorIndices, - BufferizationState &state) { + const BufferizationOptions &options) { return convertTensorValues( values, tensorIndices, [&](Value val, int64_t index) { return getYieldedBuffer(rewriter, val, bufferizedTypes[index].cast(), - state); + options); }); } @@ -472,7 +473,7 @@ } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - BufferizationState &state) const { + const BufferizationOptions &options) const { auto forOp = cast(op); Block *oldLoopBody = &forOp.getLoopBody().front(); @@ -482,7 
+483,7 @@ // The new memref init_args of the loop. SmallVector initArgs = - getBuffers(rewriter, forOp.getIterOpOperands(), state); + getBuffers(rewriter, forOp.getIterOpOperands(), options); // Construct a new scf.for op with memref instead of tensor values. auto newForOp = rewriter.create( @@ -511,7 +512,7 @@ auto yieldOp = cast(loopBody->getTerminator()); rewriter.setInsertionPoint(yieldOp); SmallVector yieldValues = getYieldedValues( - rewriter, yieldOp.getResults(), initArgsTypes, indices, state); + rewriter, yieldOp.getResults(), initArgsTypes, indices, options); yieldOp.getResultsMutable().assign(yieldValues); // Replace loop results. @@ -704,7 +705,7 @@ } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - BufferizationState &state) const { + const BufferizationOptions &options) const { auto whileOp = cast(op); assert(whileOp.getBefore().getBlocks().size() == 1 && @@ -722,12 +723,12 @@ // The new memref init_args of the loop. SmallVector initArgs = - getBuffers(rewriter, whileOp->getOpOperands(), state); + getBuffers(rewriter, whileOp->getOpOperands(), options); // The result types of a WhileOp are the same as the "after" bbArg types. SmallVector argsTypesAfter = llvm::to_vector( llvm::map_range(whileOp.getAfterArguments(), [&](BlockArgument bbArg) { - return state.getBufferType(bbArg).cast(); + return getBufferType(bbArg, options).cast(); })); // Construct a new scf.while op with memref instead of tensor values. @@ -761,7 +762,7 @@ // TODO: This could be relaxed for better bufferization results. SmallVector newConditionArgs = getYieldedValues(rewriter, newConditionOp.getArgs(), argsTypesAfter, - indicesAfter, state); + indicesAfter, options); newConditionOp.getArgsMutable().assign(newConditionArgs); // Set up new iter_args and move the loop body block to the new op. @@ -780,7 +781,7 @@ // TODO: This could be relaxed for better bufferization results. 
SmallVector newYieldValues = getYieldedValues(rewriter, newYieldOp.getResults(), argsTypesBefore, - indicesBefore, state); + indicesBefore, options); newYieldOp.getResultsMutable().assign(newYieldValues); // Replace loop results. @@ -866,7 +867,7 @@ } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - BufferizationState &state) const { + const BufferizationOptions &options) const { auto yieldOp = cast(op); if (!isa( yieldOp->getParentOp())) @@ -954,7 +955,7 @@ } LogicalResult bufferize(Operation *op, RewriterBase &b, - BufferizationState &state) const { + const BufferizationOptions &options) const { OpBuilder::InsertionGuard g(b); auto foreachThreadOp = cast(op); @@ -966,7 +967,7 @@ // Insert copies right before the PerformConcurrentlyOp terminator. They // should not be inside terminator (which would be the default insertion // point). - Value buffer = state.getBuffer(b, insertDest->get()); + Value buffer = getBuffer(b, insertDest->get(), options); newResults.push_back(buffer); } @@ -991,8 +992,7 @@ performConcurrentlyOp.walk([&](ParallelInsertSliceOp insertOp) { Location loc = insertOp.getLoc(); Type srcType = getMemRefType( - insertOp.getSource().getType().cast(), - state.getOptions()); + insertOp.getSource().getType().cast(), options); // ParallelInsertSliceOp bufferizes to a copy. auto srcMemref = b.create( loc, srcType, insertOp.getSource()); @@ -1001,8 +1001,8 @@ loc, destMemref, insertOp.getMixedOffsets(), insertOp.getMixedSizes(), insertOp.getMixedStrides()); // This memcpy will fold away if everything bufferizes in-place. 
- if (failed(state.getOptions().createMemCpy(b, insertOp.getLoc(), - srcMemref, subview))) + if (failed(options.createMemCpy(b, insertOp.getLoc(), srcMemref, + subview))) return WalkResult::interrupt(); b.eraseOp(insertOp); return WalkResult::advance(); @@ -1022,7 +1022,7 @@ : public BufferizableOpInterface::ExternalModel< PerformConcurrentlyOpInterface, PerformConcurrentlyOp> { LogicalResult bufferize(Operation *op, RewriterBase &b, - BufferizationState &state) const { + const BufferizationOptions &options) const { llvm_unreachable("op does not have any tensor OpOperands / OpResults"); return failure(); } @@ -1110,7 +1110,7 @@ } LogicalResult bufferize(Operation *op, RewriterBase &b, - BufferizationState &state) const { + const BufferizationOptions &options) const { // Will be bufferized as part of ForeachThreadOp. return failure(); } diff --git a/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp @@ -59,7 +59,7 @@ } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - BufferizationState &state) const { + const BufferizationOptions &options) const { auto assumingOp = cast(op); // Compute new result types. @@ -67,7 +67,7 @@ for (Type type : assumingOp->getResultTypes()) { if (auto tensorType = type.dyn_cast()) { // TODO: Infer the result type instead of computing it. - newResultTypes.push_back(getMemRefType(tensorType, state.getOptions())); + newResultTypes.push_back(getMemRefType(tensorType, options)); } else { newResultTypes.push_back(type); } @@ -152,7 +152,7 @@ } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - BufferizationState &state) const { + const BufferizationOptions &options) const { // Op is bufferized as part of AssumingOp. 
return failure(); } diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp @@ -48,11 +48,11 @@ } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - BufferizationState &state) const { + const BufferizationOptions &options) const { auto castOp = cast(op); // The result buffer still has the old (pre-cast) type. - Value resultBuffer = state.getBuffer(rewriter, castOp.source()); + Value resultBuffer = getBuffer(rewriter, castOp.source(), options); auto sourceMemRefType = resultBuffer.getType().cast(); Attribute memorySpace = sourceMemRefType.getMemorySpace(); TensorType resultTensorType = @@ -64,8 +64,8 @@ layout = rankedMemRefType.getLayout(); // Compute the new memref type. - Type resultMemRefType = getMemRefType(resultTensorType, state.getOptions(), - layout, memorySpace); + Type resultMemRefType = + getMemRefType(resultTensorType, options, layout, memorySpace); // Replace the op with a memref.cast. assert(memref::CastOp::areCastCompatible(resultBuffer.getType(), @@ -105,10 +105,10 @@ } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - BufferizationState &state) const { + const BufferizationOptions &options) const { auto collapseShapeOp = cast(op); RankedTensorType tensorResultType = collapseShapeOp.getResultType(); - Value buffer = state.getBuffer(rewriter, collapseShapeOp.src()); + Value buffer = getBuffer(rewriter, collapseShapeOp.src(), options); auto bufferType = buffer.getType().cast(); if (tensorResultType.getRank() == 0) { @@ -146,7 +146,7 @@ bufferType, collapseShapeOp.getReassociationIndices()); if (!canBeCollapsed) { // TODO: Create alloc_tensor ops during TensorCopyInsertion. 
- AnalysisState analysisState(state.getOptions()); + AnalysisState analysisState(options); Value tensorAlloc = allocateTensorForShapedValue( rewriter, op->getLoc(), collapseShapeOp.src(), analysisState.isTensorYielded(collapseShapeOp.result())); @@ -185,9 +185,9 @@ } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - BufferizationState &state) const { + const BufferizationOptions &options) const { auto dimOp = cast(op); - auto v = state.getBuffer(rewriter, dimOp.source()); + auto v = getBuffer(rewriter, dimOp.source(), options); replaceOpWithNewBufferizedOp(rewriter, op, v, dimOp.index()); return success(); } @@ -220,10 +220,10 @@ } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - BufferizationState &state) const { + const BufferizationOptions &options) const { auto expandShapeOp = cast(op); auto tensorResultType = expandShapeOp.getResultType(); - auto buffer = state.getBuffer(rewriter, expandShapeOp.src()); + auto buffer = getBuffer(rewriter, expandShapeOp.src(), options); // Memref result type is inferred by the builder based on reassociation // indices and result shape. @@ -261,13 +261,13 @@ } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - BufferizationState &state) const { + const BufferizationOptions &options) const { auto extractSliceOp = cast(op); Location loc = extractSliceOp.getLoc(); // Even if this op was decided to bufferize out-of-place, do not insert the // buffer copy yet. This is done later in this function. 
- auto srcMemref = state.getBuffer(rewriter, extractSliceOp.source()); + auto srcMemref = getBuffer(rewriter, extractSliceOp.source(), options); auto srcMemrefType = srcMemref.getType().cast(); auto dstTensorType = extractSliceOp.result().getType().cast(); @@ -319,9 +319,9 @@ } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - BufferizationState &state) const { + const BufferizationOptions &options) const { auto extractOp = cast(op); - Value srcMemref = state.getBuffer(rewriter, extractOp.tensor()); + Value srcMemref = getBuffer(rewriter, extractOp.tensor(), options); replaceOpWithNewBufferizedOp(rewriter, op, srcMemref, extractOp.indices()); return success(); @@ -355,7 +355,7 @@ : public BufferizableOpInterface::ExternalModel { LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - BufferizationState &state) const { + const BufferizationOptions &options) const { auto fromElementsOp = cast(op); // Allocate a buffer for the result. @@ -363,7 +363,7 @@ auto tensorType = fromElementsOp.getType().cast(); auto shape = tensorType.getShape(); // TODO: Create alloc_tensor ops during TensorCopyInsertion. - AnalysisState analysisState(state.getOptions()); + AnalysisState analysisState(options); Value tensorAlloc = allocateTensorForShapedValue( rewriter, loc, fromElementsOp.result(), analysisState.isTensorYielded(fromElementsOp.result()), @@ -410,13 +410,13 @@ : public BufferizableOpInterface::ExternalModel { LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - BufferizationState &state) const { + const BufferizationOptions &options) const { auto generateOp = cast(op); auto tensorType = generateOp.getType().cast(); // Allocate memory. Location loc = op->getLoc(); // TODO: Create alloc_tensor ops during TensorCopyInsertion. 
- AnalysisState analysisState(state.getOptions()); + AnalysisState analysisState(options); Value tensorAlloc = allocateTensorForShapedValue( rewriter, loc, generateOp.result(), analysisState.isTensorYielded(generateOp.result()), @@ -493,9 +493,9 @@ } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - BufferizationState &state) const { + const BufferizationOptions &options) const { auto insertOp = cast(op); - Value destMemref = state.getBuffer(rewriter, insertOp.dest()); + Value destMemref = getBuffer(rewriter, insertOp.dest(), options); rewriter.create(insertOp.getLoc(), insertOp.scalar(), destMemref, insertOp.indices()); replaceOpWithBufferizedValues(rewriter, op, destMemref); @@ -645,7 +645,7 @@ } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - BufferizationState &state) const { + const BufferizationOptions &options) const { // insert_slice ops arise from tiling and bufferizing them out-of-place is // generally a deal breaker. When used with loops, this ends up cloning the // whole tensor on every single iteration and is a symptom of a @@ -653,7 +653,7 @@ // TODO: be very loud about it or even consider failing the pass. auto insertSliceOp = cast(op); Location loc = insertSliceOp.getLoc(); - Value dstMemref = state.getBuffer(rewriter, insertSliceOp.dest()); + Value dstMemref = getBuffer(rewriter, insertSliceOp.dest(), options); // Expand offsets, sizes and strides to the full rank to handle the // rank-reducing case. @@ -681,9 +681,8 @@ // Copy tensor. If this tensor.insert_slice has a matching // tensor.extract_slice, the copy operation will eventually fold away. 
- auto srcMemref = state.getBuffer(rewriter, insertSliceOp.source()); - if (failed( - state.getOptions().createMemCpy(rewriter, loc, srcMemref, subView))) + auto srcMemref = getBuffer(rewriter, insertSliceOp.source(), options); + if (failed(options.createMemCpy(rewriter, loc, srcMemref, subView))) return failure(); replaceOpWithBufferizedValues(rewriter, op, dstMemref); @@ -711,9 +710,9 @@ } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - BufferizationState &state) const { + const BufferizationOptions &options) const { auto rankOp = cast(op); - auto v = state.getBuffer(rewriter, rankOp.tensor()); + auto v = getBuffer(rewriter, rankOp.tensor(), options); replaceOpWithNewBufferizedOp(rewriter, op, rankOp.getType(), v); return success(); @@ -747,12 +746,12 @@ } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - BufferizationState &state) const { + const BufferizationOptions &options) const { auto reshapeOp = cast(op); - Value srcBuffer = state.getBuffer(rewriter, reshapeOp.source()); - Value shapeBuffer = state.getBuffer(rewriter, reshapeOp.shape()); + Value srcBuffer = getBuffer(rewriter, reshapeOp.source(), options); + Value shapeBuffer = getBuffer(rewriter, reshapeOp.shape(), options); auto resultTensorType = reshapeOp.getResult().getType().cast(); - auto resultMemRefType = getMemRefType(resultTensorType, state.getOptions()); + auto resultMemRefType = getMemRefType(resultTensorType, options); replaceOpWithNewBufferizedOp( rewriter, op, resultMemRefType, srcBuffer, shapeBuffer); return success(); diff --git a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp @@ -46,11 +46,11 @@ } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - BufferizationState &state) const { + const BufferizationOptions 
&options) const { auto readOp = cast(op); assert(readOp.getShapedType().isa() && "only tensor types expected"); - Value buffer = state.getBuffer(rewriter, readOp.getSource()); + Value buffer = getBuffer(rewriter, readOp.getSource(), options); replaceOpWithNewBufferizedOp( rewriter, readOp, readOp.getVectorType(), buffer, readOp.getIndices(), readOp.getPermutationMap(), readOp.getPadding(), readOp.getMask(), @@ -91,13 +91,13 @@ } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - BufferizationState &state) const { + const BufferizationOptions &options) const { auto writeOp = cast(op); assert(writeOp.getShapedType().isa() && "only tensor types expected"); // Create a new transfer_write on buffer that doesn't have a return value. - Value resultBuffer = state.getBuffer(rewriter, writeOp.getSource()); + Value resultBuffer = getBuffer(rewriter, writeOp.getSource(), options); rewriter.create( writeOp.getLoc(), writeOp.getVector(), resultBuffer, writeOp.getIndices(), writeOp.getPermutationMapAttr(),