diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp @@ -421,8 +421,10 @@ // buffers in memory. // 3. Whether an op operand, when bufferized inplace, aliases a return value. // 4. Whether an op return value, when bufferized inplace, aliases an operand. -// 5. Wheher an op bufferizes to a memory read. -// 6. Wheher an op bufferizes to a memory write. +// 5. Whether an op bufferizes to a memory read. +// 6. Whether an op bufferizes to a memory write. +// 7. The buffer relationship between an operand and its corresponding result +// (in case of in-place bufferization). // These interfaces are necessary to distinguish between various cases and allow // special inplace behavior for (ExtractSliceOp, InsertSliceOp) pairs. //===----------------------------------------------------------------------===// @@ -682,6 +684,24 @@ getInPlace(opResult) == inPlaceSpec; } +/// Specify fine-grained relationship between buffers to enable more analysis. +enum class BufferRelation { + None, + // TODO: ResultContainsOperand, + // TODO: OperandContainsResult, + Equivalent +}; + +/// Return the relationship between `operand` and the OpResult that it may +/// alias with when bufferized in-place. +static BufferRelation bufferRelation(OpOperand &operand) { + return TypeSwitch<Operation *, BufferRelation>(operand.getOwner()) + // ExtractSliceOp returns a subview of the original tensor. + .Case([&](ExtractSliceOp op) { return BufferRelation::None; }) + // All other ops: Buffers are equivalent. + .Default([&](Operation *op) { return BufferRelation::Equivalent; }); +} + //===----------------------------------------------------------------------===// // Bufferization-specific alias analysis. 
//===----------------------------------------------------------------------===// @@ -700,14 +720,6 @@ /// uses BufferizationAliasInfo. class BufferizationAliasInfo { public: - /// Specify fine-grain relationship between buffers to enable more analysis. - enum class BufferRelation { - None, - // TODO: ResultContainsOperand, - // TODO: OperandContainsResult, - Equivalent - }; - explicit BufferizationAliasInfo(Operation *rootOp); /// Add a new entry for `v` in the `aliasInfo` and `equivalentInfo`. In the @@ -733,8 +745,7 @@ /// Set the inPlace bufferization spec to true. /// Merge result's and operand's aliasing sets and iterate to a fixed point. - void bufferizeInPlace(OpResult result, OpOperand &operand, - BufferRelation bufferRelation = BufferRelation::None); + void bufferizeInPlace(OpResult result, OpOperand &operand); /// Set the inPlace bufferization spec to false. void bufferizeOutOfPlace(OpResult result); @@ -971,13 +982,12 @@ /// Set the inPlace bufferization spec to true. void BufferizationAliasInfo::bufferizeInPlace(OpResult result, - OpOperand &operand, - BufferRelation bufferRelation) { + OpOperand &operand) { setInPlaceOpResult(result, InPlaceSpec::True); aliasInfo.unionSets(result, operand.get()); // Dump the updated alias analysis. LLVM_DEBUG(dumpAliases()); - if (bufferRelation == BufferRelation::Equivalent) + if (bufferRelation(operand) == BufferRelation::Equivalent) equivalentInfo.unionSets(result, operand.get()); // Dump the updated equivalence analysis. LLVM_DEBUG(dumpEquivalences()); @@ -2372,34 +2382,11 @@ /// Determine if `operand` can be bufferized in-place with one of the op's /// results. If so, set InPlaceSpec::True on the result. Otherwise, set /// InPlaceSpec::False on the result. -/// -/// Rationale for bufferizing `%1 = tensor.extract_slice %0[...]` inplace. -/// =========================================================== -/// -/// When bufferized out of place, a ExtractSlice lowers to alloc + copy. 
This -/// cannot change the flow of information for either the source or the -/// result buffers. -/// -/// When bufferized inplace, a ExtractSliceOp does not by itself create any read -/// or write from memory. Instead, it has the effect of merging the alias sets -/// of the source and the result buffers. -/// -/// An analysis is required to ensure inplace bufferization would not result in -/// RaW dependence violations. static LogicalResult bufferizableInPlaceAnalysis(OpOperand &operand, BufferizationAliasInfo &aliasInfo, const DominanceInfo &domInfo) { - Operation *op = operand.getOwner(); - bool isExtractSliceOp = false; - OpResult result; - if (isa<ExtractSliceOp>(op) && (operand.getOperandNumber() == 0)) { - result = op->getOpResult(0); - isExtractSliceOp = true; - } else { - result = getInplaceableOpResult(operand); - } - + OpResult result = getAliasingOpResult(operand); if (!result) return success(); @@ -2410,14 +2397,20 @@ << operand.getOperandNumber() << " in " << printValueInfo(result) << '\n'); - // `result` must bufferize to a writeable buffer to be a candidate. - // This means the operand must not alias either: - // 1. a function bbArg that is not inplaceable or - // 2. a constant op. - // to be considered for inplace bufferization + // The operand may create an alias when bufferized in-place but it does not + // necessarily write to memory. E.g., tensor.extract_slice is such an op. + bool memoryWrite = bufferizesToMemoryWrite(operand); + + // The op must not write to a non-writable buffer, i.e.: + // 1. A function bbArg that is not inplaceable or + // 2. A constant op. + // + // If the op does not write, the newly introduced alias could still result in + // a memory write via another writing op for which in-place bufferization has + // already been decided. This must also be avoided. 
bool wouldCreateAliasingWriteToNonWriteableBuffer = aliasInfo.aliasesNonWriteableBuffer(operand) && - (!isExtractSliceOp || aliasInfo.aliasesInPlaceWrite(result)); + (memoryWrite || aliasInfo.aliasesInPlaceWrite(result)); if (wouldCreateAliasingWriteToNonWriteableBuffer) LDBG("->the corresponding buffer is not writeable\n"); @@ -2431,10 +2424,7 @@ if (foundInterference) aliasInfo.bufferizeOutOfPlace(result); else - aliasInfo.bufferizeInPlace( - result, operand, - isExtractSliceOp ? BufferizationAliasInfo::BufferRelation::None - : BufferizationAliasInfo::BufferRelation::Equivalent); + aliasInfo.bufferizeInPlace(result, operand); LDBG("Done inplace analysis for result #" << resultNumber << '\n');