diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp @@ -421,8 +421,10 @@ // buffers in memory. // 3. Whether an op operand, when bufferized inplace, aliases a return value. // 4. Whether an op return value, when bufferized inplace, aliases an operand. -// 5. Wheher an op bufferizes to a memory read. -// 6. Wheher an op bufferizes to a memory write. +// 5. Whether an op bufferizes to a memory read. +// 6. Whether an op bufferizes to a memory write. +// 7. The buffer relationship between an operand and its corresponding result +// (in case of in-place bufferization). // These interfaces are necessary to distinguish between various cases and allow // special inplace behavior for (ExtractSliceOp, InsertSliceOp) pairs. //===----------------------------------------------------------------------===// @@ -557,7 +559,7 @@ return None; return TypeSwitch(result.getDefiningOp()) .Case([&](tensor::CastOp op) { return &op->getOpOperand(0); }) - .Case([&](ConstantOp op) { return &op->getOpOperand(0); }) + .Case([&](ConstantOp op) { return nullptr; }) .Case([&](ExtractSliceOp op) { return &op->getOpOperand(0); }) // In the case of scf::ForOp, this currently assumes the iter_args / yield // are 1-1. This may fail and is verified at the end. @@ -682,6 +684,24 @@ getInPlace(opResult) == inPlaceSpec; } +/// Specify fine-grained relationship between buffers to enable more analysis. +enum class BufferRelation { + None, + // TODO: ResultContainsOperand, + // TODO: OperandContainsResult, + Equivalent +}; + +/// Returns the relationship between the operand and its corresponding +/// OpResult that it may alias with.
+static BufferRelation bufferRelation(OpOperand &operand) { + return TypeSwitch(operand.getOwner()) + // ExtractSliceOp returns a subview of the original tensor. + .Case([&](ExtractSliceOp op) { return BufferRelation::None; }) + // All other ops: Buffers are equivalent. + .Default([&](Operation *op) { return BufferRelation::Equivalent; }); +} + //===----------------------------------------------------------------------===// // Bufferization-specific alias analysis. //===----------------------------------------------------------------------===// @@ -700,14 +720,6 @@ /// uses BufferizationAliasInfo. class BufferizationAliasInfo { public: - /// Specify fine-grain relationship between buffers to enable more analysis. - enum class BufferRelation { - None, - // TODO: ResultContainsOperand, - // TODO: OperandContainsResult, - Equivalent - }; - explicit BufferizationAliasInfo(Operation *rootOp); /// Add a new entry for `v` in the `aliasInfo` and `equivalentInfo`. In the @@ -722,10 +734,9 @@ /// `alias`. Additionally, merge their equivalence classes. void insertNewBufferEquivalence(Value newValue, Value alias); - /// Return true if the buffer to which `operand` would bufferize aliases a - /// buffer that is known to not be writeable. This implies that the matching - /// OpResult cannot be bufferized inplace. - bool aliasesNonWriteableBuffer(OpOperand &operand) const; + /// Return true if, under current bufferization decisions, the buffer of + /// `value` is not writeable. + bool aliasesNonWriteableBuffer(Value value) const; /// Return true if the buffer to which `operand` would bufferize is equivalent /// to some buffer write. @@ -733,8 +744,7 @@ /// Set the inPlace bufferization spec to true. /// Merge result's and operand's aliasing sets and iterate to a fixed point.
- void bufferizeInPlace(OpResult result, OpOperand &operand, - BufferRelation bufferRelation = BufferRelation::None); + void bufferizeInPlace(OpResult result, OpOperand &operand); /// Set the inPlace bufferization spec to false. void bufferizeOutOfPlace(OpResult result); @@ -776,6 +786,10 @@ wouldCreateReadAfterWriteInterference(OpResult result, const DominanceInfo &domInfo) const; + /// Return true if bufferizing `opOperand` inplace would create a write to a + /// non-writeable buffer. + bool wouldCreateWriteToNonWriteableBuffer(OpOperand &opOperand) const; + /// Return true if `v1` and `v2` bufferize to equivalent buffers. bool areEquivalentBufferizedValues(Value v1, Value v2) const { return equivalentInfo.getLeaderValue(v1) == @@ -920,15 +934,11 @@ equivalentInfo.unionSets(newValue, alias); } -/// Return true if the buffer to which `operand` would bufferize aliases a -/// buffer that is known to not be writeable. This implies that the matching -/// OpResult cannot be bufferized inplace. -bool BufferizationAliasInfo::aliasesNonWriteableBuffer( - OpOperand &operand) const { +/// Return true if, under current bufferization decisions, the buffer of `value` +/// is not writeable.
+bool BufferizationAliasInfo::aliasesNonWriteableBuffer(Value value) const { LDBG("----Start aliasesNonWriteableBuffer\n"); - LDBG("-------for -> #" << operand.getOperandNumber() << ": " - << printOperationInfo(operand.getOwner()) << '\n'); - for (Value v : getAliases(operand.get())) { + for (Value v : getAliases(value)) { LDBG("-----------examine: " << printValueInfo(v) << '\n'); if (auto bbArg = v.dyn_cast()) { if (getInPlace(bbArg) == InPlaceSpec::True) { @@ -936,18 +946,18 @@ << '\n'); continue; } - LDBG("-----------notWriteable\n"); + LDBG("-----------notWriteable bbArg\n"); return true; } if (Operation *op = v.getDefiningOp()) { if (isa(op) || !hasKnownBufferizationAliasingBehavior(op)) { - LDBG("-----------notWriteable\n"); + LDBG("-----------notWriteable op\n"); return true; } } } - LDBG("---->operand is writeable\n"); + LDBG("---->value is writeable\n"); return false; } @@ -971,13 +981,12 @@ /// Set the inPlace bufferization spec to true. void BufferizationAliasInfo::bufferizeInPlace(OpResult result, - OpOperand &operand, - BufferRelation bufferRelation) { + OpOperand &operand) { setInPlaceOpResult(result, InPlaceSpec::True); aliasInfo.unionSets(result, operand.get()); // Dump the updated alias analysis. LLVM_DEBUG(dumpAliases()); - if (bufferRelation == BufferRelation::Equivalent) + if (bufferRelation(operand) == BufferRelation::Equivalent) equivalentInfo.unionSets(result, operand.get()); // Dump the updated equivalence analysis. LLVM_DEBUG(dumpEquivalences()); @@ -1146,6 +1155,32 @@ usesWrite, domInfo); } +/// Return true if bufferizing `opOperand` inplace would create a write to a +/// non-writeable buffer. +bool BufferizationAliasInfo::wouldCreateWriteToNonWriteableBuffer( + OpOperand &opOperand) const { + OpResult opResult = getAliasingOpResult(opOperand); + assert(opResult && "expected that opOperand has aliasing OpResult"); + + // Certain buffers are not writeable: + // 1. A function bbArg that is not inplaceable or + // 2. A constant op or an op with unknown bufferization aliasing behavior.
+ bool nonWriteable = aliasesNonWriteableBuffer(opResult) || + aliasesNonWriteableBuffer(opOperand.get()); + if (!nonWriteable) + return false; + + // This is only a problem if the buffer is written to via some alias. + bool hasWrite = aliasesInPlaceWrite(opResult) || + aliasesInPlaceWrite(opOperand.get()) || + bufferizesToMemoryWrite(opOperand); + if (!hasWrite) + return false; + + LDBG("->the corresponding buffer is not writeable\n"); + return true; +} + /// Return true if the source of a `insertSliceOp` bufferizes to an /// equivalent ExtractSliceOp that bufferizes inplace. bool BufferizationAliasInfo::isSourceEquivalentToAMatchingInplaceExtractSliceOp( @@ -2369,56 +2404,6 @@ // Bufferization analyses. //===----------------------------------------------------------------------===// -/// -/// Rationale for bufferizing `%1 = tensor.extract_slice %0[...]` inplace. -/// =========================================================== -/// -/// When bufferized out of place, a ExtractSlice lowers to alloc + copy. This -/// cannot change the flow of information for either the source or the -/// result buffers. -/// -/// When bufferized inplace, a ExtractSliceOp does not by itself create any read -/// or write from memory. Instead, it has the effect of merging the alias sets -/// of the source and the result buffers. -/// -/// An analysis is required to ensure inplace bufferization would not result in -/// RaW dependence violations. -static LogicalResult -bufferizableInPlaceAnalysis(ExtractSliceOp extractSliceOp, - BufferizationAliasInfo &aliasInfo, - const DominanceInfo &domInfo) { - LDBG('\n'); - LDBG("Inplace analysis for extract_slice: " - << printOperationInfo(extractSliceOp) << '\n'); - - // If `extractSliceOp` were to be bufferized inplace, it cannot end up - // aliasing a write into a non-writeable buffer.
- bool wouldCreateAliasingWriteToNonWriteableBuffer = - aliasInfo.aliasesInPlaceWrite(extractSliceOp.result()) && - aliasInfo.aliasesNonWriteableBuffer(extractSliceOp->getOpOperand(0)); - - if (wouldCreateAliasingWriteToNonWriteableBuffer) - LDBG("->the corresponding buffer is not writeable\n"); - else - LDBG("->bufferizes to writeable inplace buffer\n"); - - // In any of extractSliceOp.result's aliases, can we find 2 such that we hit - // an interfering write? - OpResult r = extractSliceOp->getResult(0); - OpOperand &s = extractSliceOp->getOpOperand(0); - bool foundInterference = - wouldCreateAliasingWriteToNonWriteableBuffer || - aliasInfo.wouldCreateReadAfterWriteInterference(r, domInfo); - if (foundInterference) - aliasInfo.bufferizeOutOfPlace(r); - else - aliasInfo.bufferizeInPlace(r, s); - - LDBG("Done inplace analysis for extract_slice\n"); - - return success(); -} - /// Determine if `operand` can be bufferized in-place with one of the op's /// results. If so, set InPlaceSpec::True on the result. Otherwise, set /// InPlaceSpec::False on the result. @@ -2426,15 +2411,10 @@ bufferizableInPlaceAnalysis(OpOperand &operand, BufferizationAliasInfo &aliasInfo, const DominanceInfo &domInfo) { - OpResult result = getInplaceableOpResult(operand); + OpResult result = getAliasingOpResult(operand); if (!result) return success(); - Operation *op = result.getDefiningOp(); - assert(result && !isa(op) && - "expected OpResult not coming from a ExtractSliceOp"); - (void)op; - int64_t resultNumber = result.getResultNumber(); (void)resultNumber; LDBG('\n'); @@ -2442,29 +2422,14 @@ << operand.getOperandNumber() << " in " << printValueInfo(result) << '\n'); - // `result` must bufferize to a writeable buffer to be a candidate. - // This means the operand must not alias either: - // 1. a function bbArg that is not inplaceable or - // 2. a constant op.
- // to be considered for inplace bufferization - bool wouldCreateAliasingWriteToNonWriteableBuffer = - aliasInfo.aliasesNonWriteableBuffer(operand); - if (wouldCreateAliasingWriteToNonWriteableBuffer) - LDBG("->the corresponding buffer is not writeable\n"); - else - LDBG("->bufferizes to writeable inplace buffer\n"); - bool foundInterference = - wouldCreateAliasingWriteToNonWriteableBuffer || + aliasInfo.wouldCreateWriteToNonWriteableBuffer(operand) || aliasInfo.wouldCreateReadAfterWriteInterference(result, domInfo); if (foundInterference) aliasInfo.bufferizeOutOfPlace(result); else - // TODO: Atm, all inplace bufferizations yield equivalent tensors. Support - // more cases on a per-need basis. - aliasInfo.bufferizeInPlace( - result, operand, BufferizationAliasInfo::BufferRelation::Equivalent); + aliasInfo.bufferizeInPlace(result, operand); LDBG("Done inplace analysis for result #" << resultNumber << '\n'); @@ -2496,24 +2461,11 @@ }); // Walk ops in reverse for better interference analysis. - for (Operation *op : reverse(ops)) { + for (Operation *op : reverse(ops)) for (OpOperand &opOperand : op->getOpOperands()) if (failed(bufferizableInPlaceAnalysis(opOperand, aliasInfo, domInfo))) return failure(); - // Special logic to analyze ExtractSliceOp. - // Note that ExtractSliceOp analysis needs to be interleaved with other ops - // to properly capture aliases. - // Walk ExtractSliceOps in reverse for better clobbering analysis behavior: - // it is easier to detect clobbers of smaller slices before larger ones.
- if (auto extractSliceOp = dyn_cast(op)) { - if (failed( - bufferizableInPlaceAnalysis(extractSliceOp, aliasInfo, domInfo))) - return failure(); - continue; - } - } - LDBG("End InPlaceAnalysisFuncOpInternals:\n" << funcOp << '\n'); return success(); diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir --- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir +++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir @@ -620,10 +620,28 @@ return } +// ----- + //===----------------------------------------------------------------------===// // Transitive cases through extract_slice. //===----------------------------------------------------------------------===// +// CHECK-LABEL: func @write_into_constant_via_alias +func @write_into_constant_via_alias(%v : vector<5xi32>, + %s1 : index, %s2 : index, + %s3 : index) -> tensor { + %A = constant dense<[1, 2, 3, 4]> : tensor<4xi32> + // CHECK: tensor.extract_slice + // CHECK-SAME: {__inplace_results_attr__ = ["false"]} + %b = tensor.extract_slice %A[%s1][%s2][1] : tensor<4xi32> to tensor + // CHECK: vector.transfer_write + // CHECK-SAME: {__inplace_results_attr__ = ["true"]} + %r = vector.transfer_write %v, %b[%s3] : vector<5xi32>, tensor + return %r : tensor +} + +// ----- + builtin.func @matmul_on_tensors( %arg0: tensor<518x518xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, %arg1: tensor<518x518xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false},