diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h @@ -467,38 +467,12 @@ /// BufferizationState provides helper functions for performing bufferization /// rewrites and handling memref buffers. struct BufferizationState { - enum ForceInPlacability { FORCE_INPLACE, FORCE_OUT_OF_PLACE }; + BufferizationState(const BufferizationOptions &options) : options(options) {} - BufferizationState(const AnalysisState &analysisState) - : analysisState(analysisState) {} - - /// Creates a memref allocation for the given shaped value. `dealloc` - /// indicates whether the buffer should be deallocated or not. When `dealloc` - /// is `false`, this would create a memory leak, unless the buffer is - /// deallocated through some other mechanism. - /// - /// `dealloc` is optional. By default, this function will figure out by itself - /// if it is safe to deallocate the buffer. In essence, when returning the - /// buffer from a block, it is not safe to deallocate the buffer. This - /// information is queried via `AnalysisState::isTensorYielded`. - /// - /// Note: `shapedValue` is typically a tensor value. However, if it is a - /// memref value, `dealloc` is no longer optional and must be specified. - FailureOr createAlloc(OpBuilder &b, Location loc, Value shapedValue, - Optional dealloc = None); - - /// Return the buffer (memref) for a given OpOperand (tensor). Allocate - /// a new buffer and copy over data from the existing buffer if out-of-place - /// bufferization was decided. - /// - /// Whether a buffer is in-place or out-of-place is queried from the analysis - /// state. Some analyses may always conservatively opt for out-of-place - /// bufferization. 
Inplacability decisions can be overridden with the optional - /// `overrideInPlace` parameter. - FailureOr - getBuffer(RewriterBase &rewriter, OpOperand &opOperand, - Optional overrideInPlace = None, - Optional customCopyInsertionPoint = None); + /// Lookup the buffer for the given value. If the value was not bufferized + /// yet, wrap it in a ToMemrefOp. Otherwise, it is the result of a ToTensorOp, + /// from which the memref operand is returned. + Value getBuffer(RewriterBase &rewriter, Value value); /// Return the buffer type for a given Value (tensor) after bufferization. /// @@ -507,36 +481,28 @@ BaseMemRefType getBufferType(Value value) const; /// Return a reference to the BufferizationOptions. - const BufferizationOptions &getOptions() const { - return analysisState.getOptions(); - } - - const AnalysisState &getAnalysisState() const { return analysisState; } + const BufferizationOptions &getOptions() const { return options; } protected: // BufferizationState should be passed as a reference. BufferizationState(const BufferizationState &) = delete; private: - const AnalysisState &analysisState; + const BufferizationOptions &options; }; +/// Create an AllocTensorOp for the given shaped value (memref or tensor). +/// If `copy` is set, the shaped value is copied. Otherwise, a tensor with +/// undefined contents is allocated. +Value allocateTensorForShapedValue(OpBuilder &b, Location loc, + Value shapedValue, bool escape, + bool copy = true); + /// Replace an op with replacement values. The op is deleted. Tensor OpResults /// must be replaced with memref values. void replaceOpWithBufferizedValues(RewriterBase &rewriter, Operation *op, ValueRange values); -/// Lookup the buffer for the given value. If the value was not bufferized yet, -/// wrap it in a ToMemrefOp. Otherwise, it is the result of a ToTensorOp, from -/// which the memref operand is returned. -/// -/// Note: Use `BufferizationState::getBuffer` during bufferization. 
-/// `lookupBuffer` is just for compatibility and gradual migration of -/// bufferization patterns to BufferizableOpInterface-based bufferization. It -/// does not insert any buffer copies. -Value lookupBuffer(RewriterBase &rewriter, Value tensor, - const BufferizationOptions &options); - /// Replace an op with a new op. The new op must have the same number of /// results as the replaced op. The new op may not return any tensor values. template diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h b/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h --- a/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h +++ b/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h @@ -33,6 +33,11 @@ namespace mlir { namespace bufferization { +/// Populate `dynamicDims` with tensor::DimOp / memref::DimOp results for all +/// dynamic dimensions of the given shaped value. +void populateDynamicDimSizes(OpBuilder &b, Location loc, Value shapedValue, + SmallVector &dynamicDims); + /// Try to cast the given ranked MemRef-typed value to the given ranked MemRef /// type. Insert a reallocation + copy if it cannot be statically guaranteed /// that a direct cast would be valid. diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h @@ -54,40 +54,22 @@ BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns); /// Bufferize `op` and its nested ops that implement `BufferizableOpInterface`. -/// Whether buffer copies are needed or not is queried from `state`. +/// If `copyBeforeWrite`, buffers are duplicated and copied before any tensor +/// use that bufferizes to a memory write. 
/// -/// Note: If `allowUnknownOps` is set to false, bufferization fails when an -/// unknown op (that does not implement `BufferizableOpInterface`) is found. No -/// to_tensor/to_memref ops are inserted in that case. -/// -/// Note: The layout map chosen to bufferize is the most dynamic canonical -/// strided layout of the proper rank. This ensures compatibility with expected -/// layouts after transformations. Combinations of memref.cast + -/// canonicalization are responsible for clean ups. -// TODO: Extract `options` from `state` and pass as separate argument. -LogicalResult bufferizeOp(Operation *op, const AnalysisState &analysisState); - -/// Bufferize `op` and its nested ops that implement `BufferizableOpInterface`. -/// Buffers are duplicated and copied before any tensor use that bufferizes to -/// a memory write. +/// Note: In the general case, it is unsafe to run with `copyBeforeWrite = false` +/// because read-after-write conflicts may materialize during bufferization. +/// `copyBeforeWrite = false` is safe only if the input IR is guaranteed to +/// *not* require any out-of-place bufferization. /// /// Note: This function bufferizes ops without utilizing analysis results. It /// can be used to implement partial bufferization passes. -LogicalResult bufferizeOp(Operation *op, const BufferizationOptions &options); +LogicalResult bufferizeOp(Operation *op, const BufferizationOptions &options, + bool copyBeforeWrite = true, + const OpFilter *opFilter = nullptr); BufferizationOptions getPartialBufferizationOptions(); -//===----------------------------------------------------------------------===// -// Helper functions for extending Bufferization -//===----------------------------------------------------------------------===// - -/// Bufferize `op` and its nested ops that implement `BufferizableOpInterface`. -/// Reuse an existing `BufferizationState`. -/// -/// Note: This function overload is useful for extending the bufferization.
-LogicalResult bufferizeOp(Operation *op, BufferizationState &bufferizationState, - const OpFilter *opFilter = nullptr); - } // namespace bufferization } // namespace mlir diff --git a/mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp @@ -84,7 +84,7 @@ auto castOp = cast(op); auto resultTensorType = castOp.getType().cast(); - Value source = *state.getBuffer(rewriter, op->getOpOperand(0) /*in*/); + Value source = state.getBuffer(rewriter, castOp.getIn()); auto sourceType = source.getType().cast(); // Result type should have same layout and address space as the source type. @@ -136,15 +136,12 @@ auto selectOp = cast(op); Location loc = selectOp.getLoc(); - // `getBuffer` introduces copies if an OpOperand bufferizes out-of-place. // TODO: It would be more efficient to copy the result of the `select` op // instead of its OpOperands. In the worst case, 2 copies are inserted at // the moment (one for each tensor). When copying the op result, only one // copy would be needed. - Value trueBuffer = - *state.getBuffer(rewriter, selectOp->getOpOperand(1) /*true_value*/); - Value falseBuffer = - *state.getBuffer(rewriter, selectOp->getOpOperand(2) /*false_value*/); + Value trueBuffer = state.getBuffer(rewriter, selectOp.getTrueValue()); + Value falseBuffer = state.getBuffer(rewriter, selectOp.getFalseValue()); // The "true" and the "false" operands must have the same type. If the // buffers have different types, they differ only in their layout map. 
Cast diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp --- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp @@ -43,30 +43,49 @@ constexpr const ::llvm::StringLiteral bufferization::BufferizableOpInterface::kInplaceableAttrName; -/// Create an AllocTensorOp for the given shaped value. Only ranked tensors are -/// supported at the moment. If `copy` is set, the shaped value is copied. -/// Otherwise, a tensor with undefined contents is allocated. -static Value allocateTensorForShapedValue(OpBuilder &b, Location loc, - Value shapedValue, bool escape, - bool copy = true) { - auto tensorType = shapedValue.getType().dyn_cast(); - assert(tensorType && "only RankedTensorType supported at the moment"); - Value alloc; - if (!copy) { - // No copy needed: Just allocate. - SmallVector dynamicSizes; - for (int64_t i = 0; i < tensorType.getRank(); ++i) - if (tensorType.isDynamicDim(i)) - dynamicSizes.push_back(b.create(loc, shapedValue, i)); - alloc = b.create(loc, tensorType, dynamicSizes, - /*copy=*/Value(), escape); +/// Create an AllocTensorOp for the given shaped value. If `copy` is set, the +/// shaped value is copied. Otherwise, a tensor with undefined contents is +/// allocated. +Value bufferization::allocateTensorForShapedValue(OpBuilder &b, Location loc, + Value shapedValue, + bool escape, bool copy) { + Value tensor; + if (shapedValue.getType().isa()) { + tensor = shapedValue; + } else if (shapedValue.getType().isa()) { + tensor = b.create(loc, shapedValue); } else { - // Allocate and copy. - alloc = b.create(loc, tensorType, - /*dynamicSizes=*/ValueRange(), shapedValue, - escape); + llvm_unreachable("expected RankedTensorType or MemRefType"); + } + RankedTensorType tensorType = tensor.getType().cast(); + SmallVector dynamicSizes; + if (!copy) { + // Compute the dynamic part of the shape. 
+ // First try to query the shape via ReifyRankedShapedTypeOpInterface. + bool reifiedShapes = false; + if (shapedValue.getType().isa() && + shapedValue.isa()) { + if (auto rankedOp = dyn_cast_or_null( + shapedValue.getDefiningOp())) { + ReifiedRankedShapedTypeDims resultDims; + if (succeeded(rankedOp.reifyResultShapes(b, resultDims))) { + reifiedShapes = true; + auto &shape = + resultDims[shapedValue.cast().getResultNumber()]; + for (const auto &dim : enumerate(tensorType.getShape())) + if (ShapedType::isDynamic(dim.value())) + dynamicSizes.push_back(shape[dim.index()]); + } + } + } + + // If the shape could not be reified, create DimOps. + if (!reifiedShapes) + populateDynamicDimSizes(b, loc, tensor, dynamicSizes); } - return alloc; + + return b.create(loc, tensorType, dynamicSizes, + copy ? tensor : Value(), escape); } LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts( @@ -379,6 +398,10 @@ } bool AnalysisState::isInPlace(OpOperand &opOperand) const { + // ToMemrefOps are always in-place. + if (isa(opOperand.getOwner())) + return true; + // In the absence of analysis information, OpOperands that bufferize to a // memory write are out-of-place, i.e., an alloc and copy is inserted. return !bufferizesToMemoryWrite(opOperand); @@ -454,85 +477,21 @@ #endif } -Value mlir::bufferization::lookupBuffer(RewriterBase &rewriter, Value tensor, - const BufferizationOptions &options) { - auto tensorType = tensor.getType().dyn_cast(); +Value BufferizationState::getBuffer(RewriterBase &rewriter, Value value) { + auto tensorType = value.getType().dyn_cast(); assert(tensorType && "unexpected non-tensor type"); // Replace "%t = to_tensor %m" with %m. - if (auto toTensorOp = tensor.getDefiningOp()) + if (auto toTensorOp = value.getDefiningOp()) return toTensorOp.memref(); // Insert to_memref op. 
OpBuilder::InsertionGuard g(rewriter); - setInsertionPointAfter(rewriter, tensor); - Type memrefType = getMemRefType(tensorType, options); - ensureToMemrefOpIsValid(tensor, memrefType); - return rewriter.create(tensor.getLoc(), memrefType, - tensor); -} - -/// Return the buffer (memref) for a given OpOperand (tensor). Allocate -/// a new buffer and copy over data from the existing buffer if out-of-place -/// bufferization was decided. -FailureOr -BufferizationState::getBuffer(RewriterBase &rewriter, OpOperand &opOperand, - Optional overrideInPlace, - Optional customCopyInsertionPoint) { - const BufferizationOptions &options = analysisState.getOptions(); - OpBuilder::InsertionGuard guard(rewriter); - Operation *op = opOperand.getOwner(); - Location loc = op->getLoc(); - SmallVector aliasingOpResults = - analysisState.getAliasingOpResult(opOperand); - Value operand = opOperand.get(); - Value operandBuffer = lookupBuffer(rewriter, operand, options); - - // Can `operandBuffer` be used directly or do we need a copy? - bool inplace = - overrideInPlace != FORCE_OUT_OF_PLACE && - (overrideInPlace == FORCE_INPLACE || analysisState.isInPlace(opOperand)); - if (inplace) - return operandBuffer; - - // Bufferizing out-of-place: Allocate a new buffer. - // Move insertion point right after `operandBuffer`. That is where the - // allocation should be inserted (in the absence of allocation hoisting). - setInsertionPointAfter(rewriter, operandBuffer); - // Allocate the result buffer. The buffer should be deallocated if the tensor - // is not yielded and deallocs are enabled in general. - bool dealloc = llvm::none_of(aliasingOpResults, [&](Value v) { - return getAnalysisState().isTensorYielded(v); - }); - FailureOr resultBuffer = createAlloc( - rewriter, loc, operandBuffer, dealloc && getOptions().createDeallocs); - if (failed(resultBuffer)) - return failure(); - // Do not copy the buffer if its contents are undefined. 
- if (analysisState.hasUndefinedContents(&opOperand)) - return resultBuffer; - // Do not copy if the copied data is never read. - if (!aliasingOpResults.empty() && - !analysisState.bufferizesToMemoryRead(opOperand) && - llvm::none_of(aliasingOpResults, [&](OpResult opResult) { - return analysisState.isValueRead(opResult); - })) - return resultBuffer; - // Do not copy if this op does not read the data, but writes it. - if (analysisState.bufferizesToMemoryWrite(opOperand) && - !analysisState.bufferizesToMemoryRead(opOperand)) - return resultBuffer; - - if (customCopyInsertionPoint) { - rewriter.setInsertionPoint(*customCopyInsertionPoint); - } else { - // The copy happens right before the op that is bufferized. - rewriter.setInsertionPoint(op); - } - if (failed(options.createMemCpy(rewriter, loc, operandBuffer, *resultBuffer))) - return failure(); - - return resultBuffer; + setInsertionPointAfter(rewriter, value); + Type memrefType = getMemRefType(tensorType, getOptions()); + ensureToMemrefOpIsValid(value, memrefType); + return rewriter.create(value.getLoc(), memrefType, + value); } /// Return the buffer type for a given Value (tensor) after bufferization. @@ -588,9 +547,12 @@ return (*allocationFn)(b, loc, type, dynShape, bufferAlignment); // Default bufferallocation via AllocOp. - Value allocated = b.create( - loc, type, dynShape, b.getI64IntegerAttr(bufferAlignment)); - return allocated; + if (bufferAlignment != 0) + return b + .create(loc, type, dynShape, + b.getI64IntegerAttr(bufferAlignment)) + .getResult(); + return b.create(loc, type, dynShape).getResult(); } /// Creates a memref deallocation. The given memref buffer must have been @@ -650,48 +612,6 @@ return allocMemRefType; } -/// Create an allocation after `shapedValue.getDefiningOp` (or at the top of the -/// block in case of a bbArg). -FailureOr BufferizationState::createAlloc(OpBuilder &b, Location loc, - Value shapedValue, - Optional dealloc) { - // Take a guard before anything else. 
- OpBuilder::InsertionGuard g(b); - - // Compute allocation memref type. - assert(shapedValue.getType().isa()); - SmallVector dynShape; - MemRefType allocMemRefType = - getAllocationTypeAndShape(b, loc, shapedValue, dynShape); - - // Create the buffer allocation. - FailureOr buffer = - getOptions().createAlloc(b, loc, allocMemRefType, dynShape); - if (failed(buffer)) - return failure(); - - // Should be the buffer be deallocated again or should we let it leak? - if (dealloc) { - if (!dealloc.getValue()) - return *buffer; - } else { - assert(shapedValue.getType().isa() && - "must specify `dealloc` if non-tensor value is passed"); - // Buffer should be not be deallocated if deallocs are generally deactivated - // or if the tensor is yielded from a block. - if (!getOptions().createDeallocs || - getAnalysisState().isTensorYielded(shapedValue)) - return *buffer; - } - - // Create buffer deallocation. - b.setInsertionPoint(b.getInsertionBlock()->getTerminator()); - if (failed(getOptions().createDealloc(b, loc, *buffer))) - return failure(); - - return *buffer; -} - /// Create a memory copy between two memref buffers. 
LogicalResult BufferizationOptions::createMemCpy(OpBuilder &b, Location loc, Value from, Value to) const { diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp --- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp @@ -129,33 +129,85 @@ return success(); } +void mlir::bufferization::populateDynamicDimSizes( + OpBuilder &b, Location loc, Value shapedValue, + SmallVector &dynamicDims) { + auto shapedType = shapedValue.getType().cast(); + for (int64_t i = 0; i < shapedType.getRank(); ++i) { + if (shapedType.isDynamicDim(i)) { + if (shapedType.isa()) { + dynamicDims.push_back(b.create(loc, shapedValue, i)); + } else { + assert(shapedType.isa() && "expected tensor"); + dynamicDims.push_back(b.create(loc, shapedValue, i)); + } + } + } +} + //===----------------------------------------------------------------------===// // AllocTensorOp //===----------------------------------------------------------------------===// LogicalResult AllocTensorOp::bufferize(RewriterBase &rewriter, BufferizationState &state) { + OpBuilder::InsertionGuard g(rewriter); + Location loc = getLoc(); + // Nothing to do for dead AllocTensorOps. - if (getOperation()->getUses().empty()) + if (getOperation()->getUses().empty()) { + rewriter.eraseOp(getOperation()); return success(); + } - Optional dealloc = llvm::None; - if (escape().hasValue()) - dealloc = !*escape(); + // Create buffer allocation. 
+ Value copyBuffer; + if (copy()) + copyBuffer = state.getBuffer(rewriter, copy()); + auto allocType = + MemRefType::get(getType().getShape(), getType().getElementType()); + SmallVector dynamicDims = dynamicSizes(); + if (copy()) { + assert(dynamicDims.empty() && "expected either `copy` or `dynamicDims`"); + populateDynamicDimSizes(rewriter, loc, copyBuffer, dynamicDims); + } FailureOr alloc = - state.createAlloc(rewriter, getLoc(), getResult(), dealloc); + state.getOptions().createAlloc(rewriter, loc, allocType, dynamicDims); if (failed(alloc)) return failure(); + + // Create memory copy (if any). if (copy()) { - FailureOr copyValueBuffer = state.getBuffer( - rewriter, getOperation()->getOpOperand(getNumOperands() - 1)); - if (failed(copyValueBuffer)) - return failure(); - if (failed(state.getOptions().createMemCpy(rewriter, getLoc(), - *copyValueBuffer, *alloc))) + if (failed( + state.getOptions().createMemCpy(rewriter, loc, copyBuffer, *alloc))) return failure(); } + + // Should the buffer be deallocated? + AnalysisState analysisState(state.getOptions()); + bool dealloc; + if (escape().hasValue()) { + dealloc = !*escape(); + } else { + // No "escape" annotation found. + if (state.getOptions().createDeallocs) { + // Perform an ad-hoc analysis. + dealloc = !analysisState.isTensorYielded(getResult()); + } else { + dealloc = false; + } + } + + // Replace op. replaceOpWithBufferizedValues(rewriter, getOperation(), *alloc); + + // Create buffer deallocation (if requested). 
+ if (!dealloc) + return success(); + + rewriter.setInsertionPoint(rewriter.getInsertionBlock()->getTerminator()); + if (failed(state.getOptions().createDealloc(rewriter, loc, *alloc))) + return failure(); return success(); } diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp @@ -14,6 +14,7 @@ #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" #include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h" #include "mlir/Dialect/Bufferization/Transforms/Passes.h" +#include "mlir/Dialect/Bufferization/Transforms/TensorCopyInsertion.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/IR/Operation.h" @@ -288,30 +289,17 @@ return hasTensorResult || hasTensorOperand; } -LogicalResult bufferization::bufferizeOp(Operation *op, - const AnalysisState &analysisState) { - // Catch incorrect API usage. - assert((analysisState.hasDialectState( - func::FuncDialect::getDialectNamespace()) || - !analysisState.getOptions().bufferizeFunctionBoundaries) && - "must use ModuleBufferize to bufferize function boundaries"); - - BufferizationState bufferizationState(analysisState); - if (failed(bufferizeOp(op, bufferizationState))) - return failure(); - return success(); -} - namespace { /// A rewriter that keeps track of extra information during bufferization. 
class BufferizationRewriter : public IRRewriter { public: BufferizationRewriter(MLIRContext *ctx, DenseSet &erasedOps, DenseSet &toMemrefOps, + SmallVector &worklist, const BufferizationOptions &options, const OpFilter *opFilter) : IRRewriter(ctx), erasedOps(erasedOps), toMemrefOps(toMemrefOps), - options(options), opFilter(opFilter) {} + worklist(worklist), analysisState(options), opFilter(opFilter) {} protected: void notifyOperationRemoved(Operation *op) override { @@ -323,6 +311,7 @@ void notifyOperationInserted(Operation *op) override { IRRewriter::notifyOperationInserted(op); + erasedOps.erase(op); // Keep track of to_memref ops. if (isa(op)) { @@ -338,14 +327,24 @@ if (!hasTensorSemantics(op)) return; - // Skip ops that are not allowed. + // Skip ops that are not allowed to be bufferized. + auto const &options = analysisState.getOptions(); if (!options.isOpAllowed(op) || (opFilter && !opFilter->isOpAllowed(op))) return; - // Adding new bufferizable ops is not allowed during bufferization. Such ops - // would not be analyzed and can lead to surprising behavior. - llvm_unreachable( - "creating new tensor ops is not allowed during bufferization"); +#ifndef NDEBUG + // Read-only tensor ops may be created during bufferization. Ops that are + // writing should not be created because such ops were never analyzed. + // Bufferizing such ops could introduce a RaW conflict. + for (OpOperand &operand : op->getOpOperands()) + if (operand.get().getType().isa()) + assert(!analysisState.bufferizesToMemoryWrite(operand) && + "creating tensor ops that bufferize to a memory write is not " + "allowed during bufferization"); +#endif // NDEBUG + + // Add op to worklist. + worklist.push_back(op); } private: @@ -355,23 +354,32 @@ /// A set of all to_memref ops. DenseSet &toMemrefOps; - /// The bufferization options. - /// Used for debug modes. - LLVM_ATTRIBUTE_UNUSED - const BufferizationOptions &options; + /// The worklist of ops to be bufferized. 
+ SmallVector &worklist; + + /// The analysis state. Used for debug assertions and access to the + /// bufferization options. + const AnalysisState analysisState; + /// An extra op filter for bufferization. const OpFilter *opFilter; }; } // namespace LogicalResult bufferization::bufferizeOp(Operation *op, - BufferizationState &bufferizationState, + const BufferizationOptions &options, + bool copyBeforeWrite, const OpFilter *opFilter) { - const auto &options = bufferizationState.getOptions(); assert(options.unknownTypeConversion != BufferizationOptions::LayoutMapOption::InferLayoutMap && "invalid layout map option"); + if (copyBeforeWrite) { + AnalysisState state(options); + if (failed(insertTensorCopies(op, state))) + return failure(); + } + // Keep track of to_memref ops. DenseSet toMemrefOps; op->walk([&](ToMemrefOp toMemrefOp) { toMemrefOps.insert(toMemrefOp); }); @@ -393,8 +401,9 @@ DenseSet erasedOps; // Bufferize all ops. + BufferizationState bufferizationState(options); BufferizationRewriter rewriter(op->getContext(), erasedOps, toMemrefOps, - bufferizationState.getOptions(), opFilter); + worklist, options, opFilter); for (unsigned i = 0; i < worklist.size(); ++i) { Operation *op = worklist[i]; // Skip ops that were erased. @@ -443,23 +452,22 @@ // Ops without any uses and no side effects will fold away. if (op->getUses().empty() && MemoryEffectOpInterface::hasNoEffect(op)) continue; + // ToTensorOps/ToMemrefOps without any uses fold away. 
+ if (op->getUses().empty() && isa(op)) + continue; return op->emitError("op was not bufferized"); } return success(); } -LogicalResult bufferization::bufferizeOp(Operation *op, - const BufferizationOptions &options) { - AnalysisState state(options); - return bufferizeOp(op, state); -} - BufferizationOptions bufferization::getPartialBufferizationOptions() { BufferizationOptions options; options.allowUnknownOps = true; options.createDeallocs = false; + options.enforceAliasingInvariants = false; options.unknownTypeConversion = BufferizationOptions::LayoutMapOption::IdentityLayoutMap; + options.opFilter.allowDialect(); return options; } diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp @@ -306,12 +306,8 @@ // Retrieve buffers for tensor operands. Value buffer = newOperands[idx]; - if (!buffer) { - FailureOr bufferOrFailure = state.getBuffer(rewriter, opOperand); - if (failed(bufferOrFailure)) - return failure(); - buffer = *bufferOrFailure; - } + if (!buffer) + buffer = state.getBuffer(rewriter, opOperand.get()); // Caller / callee type mismatch is handled with a CastOp. 
auto memRefType = funcType.getInput(idx); diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp @@ -46,6 +46,7 @@ #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" +#include "mlir/Dialect/Bufferization/Transforms/TensorCopyInsertion.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/IR/AsmState.h" @@ -989,9 +990,9 @@ bufferization::runOneShotBufferize(Operation *op, const OneShotBufferizationOptions &options) { OneShotAnalysisState state(op, options); - if (failed(analyzeOp(op, state))) + if (failed(insertTensorCopies(op, options))) return failure(); if (options.testAnalysisOnly) return success(); - return bufferizeOp(op, state); + return bufferizeOp(op, options, /*copyBeforeWrite=*/false); } diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp @@ -64,6 +64,7 @@ #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" #include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h" #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" +#include "mlir/Dialect/Bufferization/Transforms/TensorCopyInsertion.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/IR/Operation.h" @@ -428,7 +429,7 @@ assert(options.bufferizeFunctionBoundaries && "expected that function boundary bufferization is activated"); IRRewriter 
rewriter(moduleOp.getContext()); - BufferizationState bufferizationState(analysisState); + BufferizationState bufferizationState(options); // A list of functions in the order in which they are analyzed + bufferized. SmallVector orderedFuncOps; @@ -443,7 +444,7 @@ for (func::FuncOp funcOp : orderedFuncOps) { // Note: It would be good to apply cleanups here but we cannot as aliasInfo // would be invalidated. - if (failed(bufferizeOp(funcOp, bufferizationState))) + if (failed(bufferizeOp(funcOp, options, /*copyBeforeWrite=*/false))) return failure(); // Change buffer return types to more precise layout maps. if (options.functionBoundaryTypeConversion == @@ -465,7 +466,7 @@ assert(options.bufferizeFunctionBoundaries && "expected that function boundary bufferization is activated"); OneShotAnalysisState analysisState(moduleOp, options); - if (failed(analyzeModuleOp(moduleOp, analysisState))) + if (failed(insertTensorCopies(moduleOp, options))) return failure(); if (options.testAnalysisOnly) return success(); diff --git a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp @@ -46,23 +46,15 @@ newInputBuffers.push_back(opOperand->get()); continue; } - // Input operands are never written to. - newInputBuffers.push_back(*state.getBuffer( - rewriter, *opOperand, - BufferizationState::ForceInPlacability::FORCE_INPLACE)); + newInputBuffers.push_back(state.getBuffer(rewriter, opOperand->get())); } // New output operands for the cloned op. 
SmallVector newOutputBuffers; for (OpResult opResult : op->getOpResults()) { - SmallVector aliasingOpOperands = - state.getAnalysisState().getAliasingOpOperand(opResult); - assert(aliasingOpOperands.size() == 1 && "expected 1 OpOperand"); - FailureOr resultBuffer = - state.getBuffer(rewriter, *aliasingOpOperands.front()); - if (failed(resultBuffer)) - return failure(); - newOutputBuffers.push_back(*resultBuffer); + OpOperand *opOperand = op.getOutputOperand(opResult.getResultNumber()); + Value resultBuffer = state.getBuffer(rewriter, opOperand->get()); + newOutputBuffers.push_back(resultBuffer); } // Merge input/output operands. diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp @@ -313,10 +313,8 @@ SmallVector result; for (OpOperand &opOperand : operands) { if (opOperand.get().getType().isa()) { - FailureOr resultBuffer = state.getBuffer(rewriter, opOperand); - if (failed(resultBuffer)) - return {}; - result.push_back(*resultBuffer); + Value resultBuffer = state.getBuffer(rewriter, opOperand.get()); + result.push_back(resultBuffer); } else { result.push_back(opOperand.get()); } @@ -325,55 +323,13 @@ } /// Helper function for loop bufferization. Compute the buffer that should be -/// yielded from a loop block (loop body or loop condition). If the given tensor -/// is equivalent to the corresponding block argument (as indicated by -/// `isEquivalent`), the buffer can be yielded directly. Otherwise, a new buffer -/// copy must be yielded. -/// -/// According to the `BufferizableOpInterface` implementation of scf loops, a -/// a bufferized OpResult may alias only with the corresponding bufferized -/// init_arg and with no other buffers. I.e., the i-th OpResult may alias with -/// the i-th init_arg; but not with any other OpOperand. 
If a corresponding -/// OpResult/init_arg pair bufferized to equivalent buffers (as indicated by -/// `isEquivalent`), this aliasing requirement is satisfied. Otherwise, we -/// cannot be sure and must yield a new buffer copy. (New buffer copies do not -/// alias with any buffer.) +/// yielded from a loop block (loop body or loop condition). static Value getYieldedBuffer(RewriterBase &rewriter, Value tensor, - BaseMemRefType type, bool isEquivalent, - BufferizationState &state) { + BaseMemRefType type, BufferizationState &state) { assert(tensor.getType().isa() && "expected tensor"); ensureToMemrefOpIsValid(tensor, type); - Value yieldedVal = - bufferization::lookupBuffer(rewriter, tensor, state.getOptions()); - - if (isEquivalent) - // Yielded value is equivalent to the corresponding iter_arg bbArg. - // Yield the value directly. Most IR should be like that. Everything - // else must be resolved with copies and is potentially inefficient. - // By default, such problematic IR would already have been rejected - // during `verifyAnalysis`, unless `allow-return-allocs`. - return castBuffer(rewriter, yieldedVal, type); - - // It is not certain that the yielded value and the iter_arg bbArg - // have the same buffer. Allocate a new buffer and copy. The yielded - // buffer will get deallocated by `deallocateBuffers`. - - // TODO: There are cases in which it is not neccessary to return a new - // buffer allocation. E.g., when equivalent values are yielded in a - // different order. This could be resolved with copies. - Optional yieldedAlloc = state.createAlloc( - rewriter, tensor.getLoc(), yieldedVal, /*deallocMemref=*/false); - // TODO: We should rollback, but for now just assume that this always - // succeeds. 
- assert(yieldedAlloc.hasValue() && "could not create alloc"); - LogicalResult copyStatus = state.getOptions().createMemCpy( - rewriter, tensor.getLoc(), yieldedVal, *yieldedAlloc); - (void)copyStatus; - assert(succeeded(copyStatus) && "could not create memcpy"); - - // The iter_arg memref type may have a layout map. Cast the new buffer - // to the same type if needed. - return castBuffer(rewriter, *yieldedAlloc, type); + Value yieldedVal = state.getBuffer(rewriter, tensor); + return castBuffer(rewriter, yieldedVal, type); } /// Helper function for loop bufferization. Given a range of values, apply @@ -396,13 +352,12 @@ SmallVector getYieldedValues(RewriterBase &rewriter, ValueRange values, TypeRange bufferizedTypes, const DenseSet &tensorIndices, - const DenseSet &equivalentTensors, BufferizationState &state) { return convertTensorValues( values, tensorIndices, [&](Value val, int64_t index) { return getYieldedBuffer(rewriter, val, bufferizedTypes[index].cast(), - equivalentTensors.contains(index), state); + state); }); } @@ -519,18 +474,11 @@ LogicalResult bufferize(Operation *op, RewriterBase &rewriter, BufferizationState &state) const { auto forOp = cast(op); - auto oldYieldOp = - cast(forOp.getLoopBody().front().getTerminator()); Block *oldLoopBody = &forOp.getLoopBody().front(); // Indices of all iter_args that have tensor type. These are the ones that // are bufferized. DenseSet indices = getTensorIndices(forOp.getInitArgs()); - // For every yielded value, is the value equivalent to its corresponding - // bbArg? - DenseSet equivalentYields = - getEquivalentBuffers(forOp.getRegionIterArgs(), oldYieldOp.getResults(), - state.getAnalysisState()); // The new memref init_args of the loop. SmallVector initArgs = @@ -562,9 +510,8 @@ // Update scf.yield of new loop. 
auto yieldOp = cast(loopBody->getTerminator()); rewriter.setInsertionPoint(yieldOp); - SmallVector yieldValues = - getYieldedValues(rewriter, yieldOp.getResults(), initArgsTypes, indices, - equivalentYields, state); + SmallVector yieldValues = getYieldedValues( + rewriter, yieldOp.getResults(), initArgsTypes, indices, state); yieldOp.getResultsMutable().assign(yieldValues); // Replace loop results. @@ -773,15 +720,6 @@ DenseSet indicesAfter = getTensorIndices(whileOp.getAfterArguments()); - // For every yielded value, is the value equivalent to its corresponding - // bbArg? - DenseSet equivalentYieldsBefore = getEquivalentBuffers( - whileOp.getBeforeArguments(), whileOp.getConditionOp().getArgs(), - state.getAnalysisState()); - DenseSet equivalentYieldsAfter = getEquivalentBuffers( - whileOp.getAfterArguments(), whileOp.getYieldOp().getResults(), - state.getAnalysisState()); - // The new memref init_args of the loop. SmallVector initArgs = getBuffers(rewriter, whileOp->getOpOperands(), state); @@ -823,7 +761,7 @@ // TODO: This could be relaxed for better bufferization results. SmallVector newConditionArgs = getYieldedValues(rewriter, newConditionOp.getArgs(), argsTypesAfter, - indicesAfter, equivalentYieldsBefore, state); + indicesAfter, state); newConditionOp.getArgsMutable().assign(newConditionArgs); // Set up new iter_args and move the loop body block to the new op. @@ -842,7 +780,7 @@ // TODO: This could be relaxed for better bufferization results. SmallVector newYieldValues = getYieldedValues(rewriter, newYieldOp.getResults(), argsTypesBefore, - indicesBefore, equivalentYieldsAfter, state); + indicesBefore, state); newYieldOp.getResultsMutable().assign(newYieldValues); // Replace loop results. @@ -1023,16 +961,12 @@ // Gather new results of the ForeachThreadOp. 
SmallVector newResults; for (OpResult opResult : foreachThreadOp->getOpResults()) { - SmallVector insertDestOperands = - state.getAnalysisState().getAliasingOpOperand(opResult); - assert(insertDestOperands.size() == 1 && - "expected exactly one aliasing OpOperand"); + OpOperand *insertDest = + getInsertionDest(foreachThreadOp)[opResult.getResultNumber()]; // Insert copies right before the PerformConcurrentlyOp terminator. They // should not be inside terminator (which would be the default insertion // point). - Value buffer = *state.getBuffer(b, *insertDestOperands.front(), - /*forceInPlace=*/llvm::None, - /*customCopyInsertionPoint=*/op); + Value buffer = state.getBuffer(b, insertDest->get()); newResults.push_back(buffer); } @@ -1089,7 +1023,7 @@ PerformConcurrentlyOpInterface, PerformConcurrentlyOp> { LogicalResult bufferize(Operation *op, RewriterBase &b, BufferizationState &state) const { - assert(false && "op does not have any tensor OpOperands / OpResults"); + llvm_unreachable("op does not have any tensor OpOperands / OpResults"); return failure(); } }; diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp @@ -9,6 +9,7 @@ #include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" @@ -51,11 +52,8 @@ auto castOp = cast(op); // The result buffer still has the old (pre-cast) type. 
- FailureOr resultBuffer = - state.getBuffer(rewriter, castOp->getOpOperand(0) /*source*/); - if (failed(resultBuffer)) - return failure(); - auto sourceMemRefType = resultBuffer->getType().cast(); + Value resultBuffer = state.getBuffer(rewriter, castOp.source()); + auto sourceMemRefType = resultBuffer.getType().cast(); Attribute memorySpace = sourceMemRefType.getMemorySpace(); TensorType resultTensorType = castOp.getResult().getType().cast(); @@ -70,11 +68,11 @@ layout, memorySpace); // Replace the op with a memref.cast. - assert(memref::CastOp::areCastCompatible(resultBuffer->getType(), + assert(memref::CastOp::areCastCompatible(resultBuffer.getType(), resultMemRefType) && "CallOp::bufferize: cast incompatible"); replaceOpWithNewBufferizedOp(rewriter, op, resultMemRefType, - *resultBuffer); + resultBuffer); return success(); } @@ -110,14 +108,11 @@ BufferizationState &state) const { auto collapseShapeOp = cast(op); RankedTensorType tensorResultType = collapseShapeOp.getResultType(); - OpOperand &srcOperand = collapseShapeOp->getOpOperand(0) /*src*/; - auto bufferType = state.getBufferType(srcOperand.get()).cast(); + Value buffer = state.getBuffer(rewriter, collapseShapeOp.src()); + auto bufferType = buffer.getType().cast(); if (tensorResultType.getRank() == 0) { // 0-d collapses must go through a different op builder. - auto buffer = state.getBuffer(rewriter, srcOperand); - if (failed(buffer)) - return failure(); MemRefType resultType; if (bufferType.getLayout().isIdentity()) { @@ -140,7 +135,7 @@ } replaceOpWithNewBufferizedOp( - rewriter, op, resultType, *buffer, collapseShapeOp.reassociation()); + rewriter, op, resultType, buffer, collapseShapeOp.reassociation()); return success(); } @@ -149,18 +144,23 @@ // newly allocated buffer will have no layout map and thus be collapsible. bool canBeCollapsed = memref::CollapseShapeOp::isGuaranteedCollapsible( bufferType, collapseShapeOp.getReassociationIndices()); - Optional overrideInPlace = - canBeCollapsed - ? 
None - : Optional( - BufferizationState::ForceInPlacability::FORCE_OUT_OF_PLACE); - auto buffer = state.getBuffer(rewriter, srcOperand, overrideInPlace); - if (failed(buffer)) - return failure(); + if (!canBeCollapsed) { + // TODO: Create alloc_tensor ops during TensorCopyInsertion. + AnalysisState analysisState(state.getOptions()); + Value tensorAlloc = allocateTensorForShapedValue( + rewriter, op->getLoc(), collapseShapeOp.src(), + analysisState.isTensorYielded(collapseShapeOp.result())); + auto memrefType = + MemRefType::get(collapseShapeOp.getSrcType().getShape(), + collapseShapeOp.getSrcType().getElementType(), + AffineMap(), bufferType.getMemorySpaceAsInt()); + buffer = rewriter.create( + op->getLoc(), memrefType, tensorAlloc); + } // Result type is inferred by the builder. replaceOpWithNewBufferizedOp( - rewriter, op, *buffer, collapseShapeOp.getReassociationIndices()); + rewriter, op, buffer, collapseShapeOp.getReassociationIndices()); return success(); } }; @@ -187,11 +187,8 @@ LogicalResult bufferize(Operation *op, RewriterBase &rewriter, BufferizationState &state) const { auto dimOp = cast(op); - auto v = state.getBuffer(rewriter, dimOp->getOpOperand(0) /*source*/); - if (failed(v)) - return failure(); - replaceOpWithNewBufferizedOp(rewriter, op, *v, - dimOp.index()); + auto v = state.getBuffer(rewriter, dimOp.source()); + replaceOpWithNewBufferizedOp(rewriter, op, v, dimOp.index()); return success(); } }; @@ -226,15 +223,12 @@ BufferizationState &state) const { auto expandShapeOp = cast(op); auto tensorResultType = expandShapeOp.getResultType(); - auto buffer = - state.getBuffer(rewriter, expandShapeOp->getOpOperand(0) /*src*/); - if (failed(buffer)) - return failure(); + auto buffer = state.getBuffer(rewriter, expandShapeOp.src()); // Memref result type is inferred by the builder based on reassociation // indices and result shape. 
replaceOpWithNewBufferizedOp( - rewriter, op, tensorResultType.getShape(), *buffer, + rewriter, op, tensorResultType.getShape(), buffer, expandShapeOp.getReassociationIndices()); return success(); } @@ -273,34 +267,18 @@ // Even if this op was decided to bufferize out-of-place, do not insert the // buffer copy yet. This is done later in this function. - auto srcMemref = - state.getBuffer(rewriter, extractSliceOp->getOpOperand(0) /*source*/, - BufferizationState::ForceInPlacability::FORCE_INPLACE); - if (failed(srcMemref)) - return failure(); - auto srcMemrefType = srcMemref->getType().cast(); + auto srcMemref = state.getBuffer(rewriter, extractSliceOp.source()); + auto srcMemrefType = srcMemref.getType().cast(); auto dstTensorType = extractSliceOp.result().getType().cast(); - // If not inplaceable, alloc. - bool inplace = - state.getAnalysisState().isInPlace(extractSliceOp->getOpOperand(0)); - Value alloc; - if (!inplace) { - FailureOr allocOrFailure = - state.createAlloc(rewriter, loc, extractSliceOp.result()); - if (failed(allocOrFailure)) - return failure(); - alloc = *allocOrFailure; - } - // Expand offsets, sizes and strides to the full rank to handle the // rank-reducing case. SmallVector mixedOffsets = extractSliceOp.getMixedOffsets(); SmallVector mixedSizes = extractSliceOp.getMixedSizes(); SmallVector mixedStrides = extractSliceOp.getMixedStrides(); OffsetSizeAndStrideOpInterface::expandToRank( - *srcMemref, mixedOffsets, mixedSizes, mixedStrides, + srcMemref, mixedOffsets, mixedSizes, mixedStrides, [&](Value target, int64_t dim) -> OpFoldResult { auto shapedType = target.getType().cast(); if (shapedType.isDynamicDim(dim)) @@ -313,19 +291,9 @@ mixedOffsets, mixedSizes, mixedStrides) .cast(); Value subView = rewriter.create( - loc, subviewMemRefType, *srcMemref, mixedOffsets, mixedSizes, + loc, subviewMemRefType, srcMemref, mixedOffsets, mixedSizes, mixedStrides); - // If not inplaceable, copy. 
- if (!inplace) { - // Do not copy if the copied data is never read. - if (state.getAnalysisState().isValueRead(extractSliceOp.result())) - if (failed(state.getOptions().createMemCpy( - rewriter, extractSliceOp.getLoc(), subView, alloc))) - return failure(); - subView = alloc; - } - replaceOpWithBufferizedValues(rewriter, op, subView); return success(); } @@ -353,11 +321,8 @@ LogicalResult bufferize(Operation *op, RewriterBase &rewriter, BufferizationState &state) const { auto extractOp = cast(op); - auto srcMemref = - state.getBuffer(rewriter, extractOp->getOpOperand(0) /*tensor*/); - if (failed(srcMemref)) - return failure(); - replaceOpWithNewBufferizedOp(rewriter, op, *srcMemref, + Value srcMemref = state.getBuffer(rewriter, extractOp.tensor()); + replaceOpWithNewBufferizedOp(rewriter, op, srcMemref, extractOp.indices()); return success(); } @@ -397,11 +362,16 @@ Location loc = op->getLoc(); auto tensorType = fromElementsOp.getType().cast(); auto shape = tensorType.getShape(); - FailureOr maybeBuffer = - state.createAlloc(rewriter, loc, fromElementsOp.result()); - if (failed(maybeBuffer)) - return failure(); - Value buffer = *maybeBuffer; + // TODO: Create alloc_tensor ops during TensorCopyInsertion. + AnalysisState analysisState(state.getOptions()); + Value tensorAlloc = allocateTensorForShapedValue( + rewriter, loc, fromElementsOp.result(), + analysisState.isTensorYielded(fromElementsOp.result()), + /*copy=*/false); + auto memrefType = + MemRefType::get(tensorType.getShape(), tensorType.getElementType()); + Value buffer = rewriter.create( + op->getLoc(), memrefType, tensorAlloc); // Case: tensor<0xelem_type>. if (fromElementsOp.elements().empty()) { @@ -442,15 +412,19 @@ LogicalResult bufferize(Operation *op, RewriterBase &rewriter, BufferizationState &state) const { auto generateOp = cast(op); - + auto tensorType = generateOp.getType().cast(); // Allocate memory. 
Location loc = op->getLoc(); - FailureOr maybeResult = - state.createAlloc(rewriter, loc, generateOp.result()); - if (failed(maybeResult)) - return failure(); - Value result = *maybeResult; - MemRefType memrefType = result.getType().cast(); + // TODO: Create alloc_tensor ops during TensorCopyInsertion. + AnalysisState analysisState(state.getOptions()); + Value tensorAlloc = allocateTensorForShapedValue( + rewriter, loc, generateOp.result(), + analysisState.isTensorYielded(generateOp.result()), + /*copy=*/false); + auto memrefType = + MemRefType::get(tensorType.getShape(), tensorType.getElementType()); + Value buffer = rewriter.create( + op->getLoc(), memrefType, tensorAlloc); // Collect loop bounds. int64_t rank = memrefType.getRank(); @@ -483,10 +457,10 @@ Operation *elementYield = parallelBody->getTerminator()->getPrevNode(); rewriter.setInsertionPointAfter(elementYield); rewriter.replaceOpWithNewOp( - elementYield, elementYield->getOperands()[0], result, + elementYield, elementYield->getOperands()[0], buffer, parallelBody->getArguments()); - replaceOpWithBufferizedValues(rewriter, op, result); + replaceOpWithBufferizedValues(rewriter, op, buffer); return success(); } }; @@ -521,13 +495,10 @@ LogicalResult bufferize(Operation *op, RewriterBase &rewriter, BufferizationState &state) const { auto insertOp = cast(op); - FailureOr destMemref = - state.getBuffer(rewriter, insertOp->getOpOperand(1) /*dest*/); - if (failed(destMemref)) - return failure(); + Value destMemref = state.getBuffer(rewriter, insertOp.dest()); rewriter.create(insertOp.getLoc(), insertOp.scalar(), - *destMemref, insertOp.indices()); - replaceOpWithBufferizedValues(rewriter, op, *destMemref); + destMemref, insertOp.indices()); + replaceOpWithBufferizedValues(rewriter, op, destMemref); return success(); } @@ -682,12 +653,7 @@ // TODO: be very loud about it or even consider failing the pass. 
auto insertSliceOp = cast(op); Location loc = insertSliceOp.getLoc(); - - // When bufferizing out-of-place, `getResultBuffer` allocates. - FailureOr dstMemref = - state.getBuffer(rewriter, insertSliceOp->getOpOperand(1) /*dest*/); - if (failed(dstMemref)) - return failure(); + Value dstMemref = state.getBuffer(rewriter, insertSliceOp.dest()); // Expand offsets, sizes and strides to the full rank to handle the // rank-reducing case. @@ -695,7 +661,7 @@ SmallVector mixedSizes = insertSliceOp.getMixedSizes(); SmallVector mixedStrides = insertSliceOp.getMixedStrides(); OffsetSizeAndStrideOpInterface::expandToRank( - *dstMemref, mixedOffsets, mixedSizes, mixedStrides, + dstMemref, mixedOffsets, mixedSizes, mixedStrides, [&](Value target, int64_t dim) -> OpFoldResult { auto shapedType = target.getType().cast(); if (shapedType.isDynamicDim(dim)) @@ -703,25 +669,24 @@ return rewriter.getIndexAttr(shapedType.getDimSize(dim)); }); // Take a subview of the dst. - auto dstMemrefType = dstMemref->getType().cast(); + auto dstMemrefType = dstMemref.getType().cast(); auto subviewMemRefType = memref::SubViewOp::inferRankReducedResultType( insertSliceOp.getSourceType().getRank(), dstMemrefType, mixedOffsets, mixedSizes, mixedStrides) .cast(); Value subView = rewriter.create( - loc, subviewMemRefType, *dstMemref, mixedOffsets, mixedSizes, + loc, subviewMemRefType, dstMemref, mixedOffsets, mixedSizes, mixedStrides); // Copy tensor. If this tensor.insert_slice has a matching // tensor.extract_slice, the copy operation will eventually fold away. 
- auto srcMemref = - state.getBuffer(rewriter, insertSliceOp->getOpOperand(0) /*source*/); - if (failed(srcMemref) || failed(state.getOptions().createMemCpy( - rewriter, loc, *srcMemref, subView))) + auto srcMemref = state.getBuffer(rewriter, insertSliceOp.source()); + if (failed( + state.getOptions().createMemCpy(rewriter, loc, srcMemref, subView))) return failure(); - replaceOpWithBufferizedValues(rewriter, op, *dstMemref); + replaceOpWithBufferizedValues(rewriter, op, dstMemref); return success(); } }; @@ -748,11 +713,9 @@ LogicalResult bufferize(Operation *op, RewriterBase &rewriter, BufferizationState &state) const { auto rankOp = cast(op); - auto v = state.getBuffer(rewriter, rankOp->getOpOperand(0) /*source*/); - if (failed(v)) - return failure(); + auto v = state.getBuffer(rewriter, rankOp.tensor()); replaceOpWithNewBufferizedOp(rewriter, op, rankOp.getType(), - *v); + v); return success(); } }; @@ -786,21 +749,12 @@ LogicalResult bufferize(Operation *op, RewriterBase &rewriter, BufferizationState &state) const { auto reshapeOp = cast(op); - auto &srcOperand = reshapeOp->getOpOperand(0); - auto srcBuffer = state.getBuffer(rewriter, srcOperand); - if (failed(srcBuffer)) - return failure(); - - auto &shapeOperand = reshapeOp->getOpOperand(1); - auto shapeBuffer = state.getBuffer(rewriter, shapeOperand); - if (failed(shapeBuffer)) - return failure(); - + Value srcBuffer = state.getBuffer(rewriter, reshapeOp.source()); + Value shapeBuffer = state.getBuffer(rewriter, reshapeOp.shape()); auto resultTensorType = reshapeOp.getResult().getType().cast(); auto resultMemRefType = getMemRefType(resultTensorType, state.getOptions()); - replaceOpWithNewBufferizedOp( - rewriter, op, resultMemRefType, *srcBuffer, *shapeBuffer); + rewriter, op, resultMemRefType, srcBuffer, shapeBuffer); return success(); } }; diff --git a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp --- 
a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp @@ -50,10 +50,7 @@ auto readOp = cast(op); assert(readOp.getShapedType().isa() && "only tensor types expected"); - - // TransferReadOp always reads from the bufferized op.source(). - Value buffer = - *state.getBuffer(rewriter, readOp->getOpOperand(0) /*source*/); + Value buffer = state.getBuffer(rewriter, readOp.getSource()); replaceOpWithNewBufferizedOp( rewriter, readOp, readOp.getVectorType(), buffer, readOp.getIndices(), readOp.getPermutationMap(), readOp.getPadding(), readOp.getMask(), @@ -100,17 +97,12 @@ "only tensor types expected"); // Create a new transfer_write on buffer that doesn't have a return value. - // Leave the previous transfer_write to dead code as it still has uses at - // this point. - FailureOr resultBuffer = - state.getBuffer(rewriter, op->getOpOperand(1) /*source*/); - if (failed(resultBuffer)) - return failure(); + Value resultBuffer = state.getBuffer(rewriter, writeOp.getSource()); rewriter.create( - writeOp.getLoc(), writeOp.getVector(), *resultBuffer, + writeOp.getLoc(), writeOp.getVector(), resultBuffer, writeOp.getIndices(), writeOp.getPermutationMapAttr(), writeOp.getInBoundsAttr()); - replaceOpWithBufferizedValues(rewriter, op, *resultBuffer); + replaceOpWithBufferizedValues(rewriter, op, resultBuffer); return success(); } diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-alloc-tensor-elimination.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-alloc-tensor-elimination.mlir --- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-alloc-tensor-elimination.mlir +++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-alloc-tensor-elimination.mlir @@ -9,18 +9,17 @@ -> (tensor, tensor) { %f0 = arith.constant 0.0: f32 + + // CHECK: %[[EXTRACT_SLICE_ALLOC:.*]] = memref.alloc(%[[sz]]) + // CHECK: linalg.fill ins({{.*}} : f32) 
outs(%[[EXTRACT_SLICE_ALLOC]] : memref) // Alloc is needed for the **first** insert_slice (due to backward traversal during analysis). // CHECK: %[[DIM:.*]] = memref.dim %[[FUNC_ARG]] // This allocs the whole dim to allow for a full clone of t. // CHECK: %[[ALLOC:.*]] = memref.alloc(%[[DIM]]) - // alloc_tensor itself does not alloc but forwards to the **second** // insert_slice. AllocTensorOp replaces the alloc_tensor with an out-of-place // extract_slice. - // CHECK: %[[EXTRACT_SLICE_ALLOC:.*]] = memref.alloc(%[[sz]]) %a = bufferization.alloc_tensor(%sz) : tensor - - // CHECK: linalg.fill ins({{.*}} : f32) outs(%[[EXTRACT_SLICE_ALLOC]] : memref) %f = linalg.fill ins(%f0 : f32) outs(%a : tensor) -> tensor // CHECK: memref.copy %[[FUNC_ARG]], %[[ALLOC]] : memref to memref diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-partial.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-partial.mlir --- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-partial.mlir +++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-partial.mlir @@ -8,8 +8,8 @@ // RUN: mlir-opt %s -allow-unregistered-dialect -one-shot-bufferize="allow-return-allocs test-analysis-only analysis-fuzzer-seed=59" -split-input-file -o /dev/null // RUN: mlir-opt %s -allow-unregistered-dialect -one-shot-bufferize="allow-return-allocs test-analysis-only analysis-fuzzer-seed=91" -split-input-file -o /dev/null -// RUN: mlir-opt %s -allow-unregistered-dialect -one-shot-bufferize="dialect-filter=tensor allow-unknown-ops allow-return-allocs" -canonicalize -split-input-file | FileCheck %s --check-prefix=CHECK-TENSOR -// RUN: mlir-opt %s -allow-unregistered-dialect -one-shot-bufferize="dialect-filter=scf allow-unknown-ops allow-return-allocs" -canonicalize -split-input-file | FileCheck %s --check-prefix=CHECK-SCF +// RUN: mlir-opt %s -allow-unregistered-dialect -one-shot-bufferize="dialect-filter=tensor,bufferization allow-unknown-ops 
allow-return-allocs" -canonicalize -split-input-file | FileCheck %s --check-prefix=CHECK-TENSOR +// RUN: mlir-opt %s -allow-unregistered-dialect -one-shot-bufferize="dialect-filter=scf,bufferization allow-unknown-ops allow-return-allocs" -canonicalize -split-input-file | FileCheck %s --check-prefix=CHECK-SCF // CHECK: #[[$MAP:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)> @@ -141,14 +141,13 @@ // One alloc for the alloc_tensor, another one because the transfer_write // bufferizes out-of-place. // CHECK: %[[m1:.*]] = memref.alloc() {{.*}} : memref<10xf32> - // CHECK: %[[alloc:.*]] = memref.alloc() {{.*}} : memref<10xf32> - %t1 = bufferization.alloc_tensor() : tensor<10xf32> - // CHECK: linalg.fill ins(%{{.*}}{{.*}}outs(%[[m1]] // CHECK: %[[filled_tensor:.*]] = bufferization.to_tensor %[[m1]] + %t1 = bufferization.alloc_tensor() : tensor<10xf32> %filled = linalg.fill ins(%cst : f32) outs(%t1 : tensor<10xf32>) -> tensor<10xf32> // The transfer_write is out-of-place because "dummy_op" may read. 
+ // CHECK: %[[alloc:.*]] = memref.alloc() {{.*}} : memref<10xf32> // CHECK: memref.copy %[[m1]], %[[alloc]] // CHECK: vector.transfer_write %{{.*}}, %[[alloc]] // CHECK: %[[alloc_tensor:.*]] = bufferization.to_tensor %[[alloc]] @@ -193,10 +192,10 @@ // CHECK-TENSOR: %[[t1_memref:.*]] = bufferization.to_memref %[[t1]] %c0 = arith.constant 0 : index // CHECK-TENSOR: %[[alloc:.*]] = memref.alloc - // CHECK-TENSOR: %[[casted_alloc:.*]] = bufferization.to_tensor %[[alloc]] // CHECK-TENSOR: memref.copy %[[t1_memref]], %[[alloc]] // CHECK-TENSOR: memref.store %{{.*}}, %[[alloc]] %0 = tensor.insert %f into %t1[%c0] : tensor + // CHECK-TENSOR: %[[casted_alloc:.*]] = bufferization.to_tensor %[[alloc]] // CHECK-TENSOR: return %[[casted_alloc]] return %0 : tensor } diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-allow-return-allocs.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-allow-return-allocs.mlir --- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-allow-return-allocs.mlir +++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-allow-return-allocs.mlir @@ -48,12 +48,11 @@ // CHECK-LABEL: func @main( // CHECK-SAME: %[[t:.*]]: memref, %sz: index, %idx: index) -> (f32, f32) { %cst = arith.constant 1.0 : f32 diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir --- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir +++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir @@ -528,12 +528,12 @@ // conflict. However, inside `entry`, the writes do cause a conflict because // %A, %B and %C are not inplaceable. This test case shows that this kind of // conflict detection has a "transitive" nature. 
-// CHECK-DAG: %[[ALLOC_C:.*]] = memref.alloc -// CHECK-DAG: %[[CASTED_C:.*]] = memref.cast %[[ALLOC_C]] -// CHECK-DAG: %[[ALLOC_B:.*]] = memref.alloc -// CHECK-DAG: %[[CASTED_B:.*]] = memref.cast %[[ALLOC_B]] // CHECK-DAG: %[[ALLOC_A:.*]] = memref.alloc // CHECK-DAG: %[[CASTED_A:.*]] = memref.cast %[[ALLOC_A]] +// CHECK-DAG: %[[ALLOC_B:.*]] = memref.alloc +// CHECK-DAG: %[[CASTED_B:.*]] = memref.cast %[[ALLOC_B]] +// CHECK-DAG: %[[ALLOC_C:.*]] = memref.alloc +// CHECK-DAG: %[[CASTED_C:.*]] = memref.cast %[[ALLOC_C]] // CHECK-DAG: memref.copy %[[A]], %[[ALLOC_A]] // CHECK-DAG: memref.copy %[[B]], %[[ALLOC_B]] // CHECK-DAG: memref.copy %[[C]], %[[ALLOC_C]] diff --git a/mlir/test/Dialect/Linalg/bufferize.mlir b/mlir/test/Dialect/Linalg/bufferize.mlir --- a/mlir/test/Dialect/Linalg/bufferize.mlir +++ b/mlir/test/Dialect/Linalg/bufferize.mlir @@ -71,8 +71,8 @@ #map0 = affine_map<(d0) -> (d0)> // CHECK-LABEL: func @multiple_results -// CHECK: %[[RESULT1:.*]] = memref.alloc() {{.*}} : memref<4xf32> // CHECK: %[[RESULT0:.*]] = memref.alloc() {{.*}} : memref<4xf32> +// CHECK: %[[RESULT1:.*]] = memref.alloc() {{.*}} : memref<4xf32> // CHECK: linalg.generic // CHECK-SAME: ins(%{{.*}} : memref<4xf32>) // CHECK-SAME: outs(%[[RESULT0]], %[[RESULT1]] : memref<4xf32>, memref<4xf32>) @@ -101,11 +101,11 @@ // CHECK-SAME: %[[ARG:.*]]: tensor // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index -// CHECK: %[[DIM0:.*]] = tensor.dim %[[ARG]], %[[C0]] : tensor -// CHECK: %[[DIM1:.*]] = tensor.dim %[[ARG]], %[[C1]] : tensor -// CHECK: %[[RESULT1:.*]] = memref.alloc(%[[DIM0]], %[[DIM1]]) {{.*}} : memref -// CHECK: %[[RESULT0:.*]] = memref.alloc(%[[DIM0]], %[[DIM1]]) {{.*}} : memref -// CHECK: %[[MEMREF_ARG:.*]] = bufferization.to_memref %[[ARG]] : memref +// CHECK-DAG: %[[DIM0:.*]] = tensor.dim %[[ARG]], %[[C0]] : tensor +// CHECK-DAG: %[[DIM1:.*]] = tensor.dim %[[ARG]], %[[C1]] : tensor +// CHECK-DAG: %[[RESULT0:.*]] = 
memref.alloc(%[[DIM0]], %[[DIM1]]) {{.*}} : memref +// CHECK-DAG: %[[RESULT1:.*]] = memref.alloc(%[[DIM0]], %[[DIM1]]) {{.*}} : memref +// CHECK-DAG: %[[MEMREF_ARG:.*]] = bufferization.to_memref %[[ARG]] : memref // CHECK: linalg.generic // CHECK-SAME: ins(%[[MEMREF_ARG]] : memref) // CHECK-SAME: outs(%[[RESULT0]], %[[RESULT1]] : memref, memref) diff --git a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir --- a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir +++ b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir @@ -481,23 +481,20 @@ // CHECK-LABEL: func @scf_while_iter_arg_result_mismatch( // CHECK-SAME: %[[arg0:.*]]: memref<5xi1, #{{.*}}>, %[[arg1:.*]]: memref<5xi1, #{{.*}}> -// CHECK: %[[alloc2:.*]] = memref.alloc() {{.*}} : memref<5xi1> // CHECK: %[[clone:.*]] = bufferization.clone %[[arg1]] // CHECK: scf.while (%[[arg3:.*]] = %[[clone]]) : (memref<5xi1, #{{.*}}) -> () { // CHECK-DAG: memref.dealloc %[[arg3]] // CHECK-DAG: %[[load:.*]] = memref.load %[[arg0]] // CHECK: scf.condition(%[[load]]) // CHECK: } do { +// CHECK: %[[alloc2:.*]] = memref.alloc() {{.*}} : memref<5xi1> // CHECK: memref.copy %[[arg0]], %[[alloc2]] // CHECK: memref.store %{{.*}}, %[[alloc2]] -// CHECK: %[[alloc1:.*]] = memref.alloc() {{.*}} : memref<5xi1> -// CHECK: memref.copy %[[alloc2]], %[[alloc1]] -// CHECK: %[[casted:.*]] = memref.cast %[[alloc1]] : memref<5xi1> to memref<5xi1, #{{.*}}> +// CHECK: %[[casted:.*]] = memref.cast %[[alloc2]] : memref<5xi1> to memref<5xi1, #{{.*}}> // CHECK: %[[cloned:.*]] = bufferization.clone %[[casted]] -// CHECK: memref.dealloc %[[alloc1]] +// CHECK: memref.dealloc %[[alloc2]] // CHECK: scf.yield %[[cloned]] // CHECK: } -// CHECK-DAG: memref.dealloc %[[alloc2]] func.func @scf_while_iter_arg_result_mismatch(%arg0: tensor<5xi1>, %arg1: tensor<5xi1>, %arg2: index) { diff --git a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir --- 
a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir +++ b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir @@ -22,24 +22,22 @@ %t1 : tensor<4xf32> {bufferization.writable = true}) -> (tensor, tensor, tensor, tensor) { - // Hoisted allocs. - // CHECK: %[[REALLOC1:.*]] = memref.alloc - // CHECK: %[[REALLOC2:.*]] = memref.alloc - // CHECK: %[[REALLOC3:.*]] = memref.alloc - // Alloc and copy the whole result tensor. Copy the tensor.extract_slice. + // CHECK: %[[REALLOC3:.*]] = memref.alloc // CHECK: memref.copy %[[A0]], %[[REALLOC3]] // CHECK: %[[SV_A0:.*]] = memref.subview %[[REALLOC3]] // CHECK: memref.copy %[[t0]], %[[SV_A0]] %r0 = tensor.insert_slice %t0 into %A0[0][4][1] : tensor<4xf32> into tensor // Alloc and copy the whole result tensor. Copy the tensor.extract_slice. + // CHECK: %[[REALLOC2:.*]] = memref.alloc // CHECK: memref.copy %[[A0]] // CHECK: %[[SV_A0_2:.*]] = memref.subview %[[REALLOC2]] // CHECK: memref.copy %[[t1]], %[[SV_A0_2]] %r1 = tensor.insert_slice %t1 into %A0[0][4][1] : tensor<4xf32> into tensor // Still alloc the large tensor because %A1 is read after. Copy the tensor.extract_slice. + // CHECK: %[[REALLOC1:.*]] = memref.alloc // CHECK: memref.copy %[[A1]] // CHECK: %[[SV_A1:.*]] = memref.subview %[[REALLOC1]] // CHECK: memref.copy %[[t0]], %[[SV_A1]]