diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h b/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h @@ -56,10 +56,9 @@ /// `alias`. Additionally, merge their equivalence classes. void insertNewBufferEquivalence(Value newValue, Value alias); - /// Return true if the buffer to which `operand` would bufferize aliases a - /// buffer that is known to not be writable. This implies that the matching - /// OpResult cannot be bufferized inplace. - bool aliasesNonWritableBuffer(OpOperand &operand) const; + /// Return true if, under current bufferization decisions, the buffer of + /// `value` is not writable. + bool aliasesNonWritableBuffer(Value value) const; /// Return true if the buffer to which `operand` would bufferize is equivalent /// to some buffer write. @@ -67,8 +66,7 @@ /// Set the inPlace bufferization spec to true. /// Merge result's and operand's aliasing sets and iterate to a fixed point. - void bufferizeInPlace(OpResult result, OpOperand &operand, - BufferRelation bufferRelation = BufferRelation::None); + void bufferizeInPlace(OpResult result, OpOperand &operand); /// Set the inPlace bufferization spec to false. void bufferizeOutOfPlace(OpResult result); @@ -110,6 +108,10 @@ wouldCreateReadAfterWriteInterference(OpResult result, const DominanceInfo &domInfo) const; + /// Return true if bufferizing `opOperand` inplace would create a write to a + /// non-writable buffer. + bool wouldCreateWriteToNonWritableBuffer(OpOperand &opOperand) const; + /// Return true if `v1` and `v2` bufferize to equivalent buffers. 
bool areEquivalentBufferizedValues(Value v1, Value v2) const { return equivalentInfo.getLeaderValue(v1) == diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp @@ -136,6 +136,8 @@ using namespace linalg; using namespace tensor; +using BufferRelation = BufferizationAliasInfo::BufferRelation; + #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") #define LDBG(X) LLVM_DEBUG(DBGS() << X) @@ -421,8 +423,10 @@ // buffers in memory. // 3. Whether an op operand, when bufferized inplace, aliases a return value. // 4. Whether an op return value, when bufferized inplace, aliases an operand. -// 5. Wheher an op bufferizes to a memory read. -// 6. Wheher an op bufferizes to a memory write. +// 5. Whether an op bufferizes to a memory read. +// 6. Whether an op bufferizes to a memory write. +// 7. The buffer relationship between an operand and its corresponding result +// (in case of in-place bufferization). // These interfaces are necessary to distinguish between various cases and allow // special inplace behavior for (ExtractSliceOp, InsertSliceOp) pairs. //===----------------------------------------------------------------------===// @@ -557,7 +561,7 @@ return None; return TypeSwitch(result.getDefiningOp()) .Case([&](tensor::CastOp op) { return &op->getOpOperand(0); }) - .Case([&](ConstantOp op) { return &op->getOpOperand(0); }) + .Case([&](ConstantOp op) { return nullptr; }) .Case([&](ExtractSliceOp op) { return &op->getOpOperand(0); }) // In the case of scf::ForOp, this currently assumes the iter_args / yield // are 1-1. This may fail and is verified at the end. @@ -682,6 +686,16 @@ getInPlace(opResult) == inPlaceSpec; } +/// Returns the relationship between the operand and its corresponding +/// OpResult that it may alias with. 
+static BufferRelation bufferRelation(OpOperand &operand) { + return TypeSwitch(operand.getOwner()) + // ExtractSliceOp returns a subview of the original tensor. + .Case([&](ExtractSliceOp op) { return BufferRelation::None; }) + // All other ops: Buffers are equivalent. + .Default([&](Operation *op) { return BufferRelation::Equivalent; }); +} + //===----------------------------------------------------------------------===// // Bufferization-specific alias analysis. //===----------------------------------------------------------------------===// @@ -721,40 +735,36 @@ equivalentInfo.unionSets(newValue, alias); } -/// Return true if the buffer to which `operand` would bufferize aliases a -/// buffer that is known to not be writable. This implies that the matching -/// OpResult cannot be bufferized inplace. -bool BufferizationAliasInfo::aliasesNonWritableBuffer( - OpOperand &operand) const { +/// Return true if, under current bufferization decisions, the buffer of `value` +/// is not writable. +bool BufferizationAliasInfo::aliasesNonWritableBuffer(Value value) const { LDBG("----Start aliasesNonWritableBuffer\n"); - LDBG("-------for -> #" << operand.getOperandNumber() << ": " - << printOperationInfo(operand.getOwner()) << '\n'); - for (Value v : getAliases(operand.get())) { + for (Value v : getAliases(value)) { LDBG("-----------examine: " << printValueInfo(v) << '\n'); if (bufferizesToWritableMemory(v)) { - LDBG("-----------Value is known to be writeable -> skip: " + LDBG("-----------Value is known to be writable -> skip: " << printValueInfo(v) << '\n'); continue; } if (auto bbArg = v.dyn_cast()) { if (getInPlace(bbArg) == InPlaceSpec::True) { - LDBG("-----------bbArg is writeable -> skip: " << printValueInfo(bbArg) - << '\n'); + LDBG("-----------bbArg is writable -> skip: " << printValueInfo(bbArg) + << '\n'); continue; } - LDBG("-----------notWriteable\n"); + LDBG("-----------notWritable bbArg\n"); return true; } if (Operation *op = v.getDefiningOp()) { if (isa(op) || 
!hasKnownBufferizationAliasingBehavior(op)) { - LDBG("-----------notWritable\n"); + LDBG("-----------notWritable op\n"); return true; } } } - LDBG("---->operand is writable\n"); + LDBG("---->value is writable\n"); return false; } @@ -787,13 +797,12 @@ /// Set the inPlace bufferization spec to true. void BufferizationAliasInfo::bufferizeInPlace(OpResult result, - OpOperand &operand, - BufferRelation bufferRelation) { + OpOperand &operand) { setInPlaceOpResult(result, InPlaceSpec::True); aliasInfo.unionSets(result, operand.get()); // Dump the updated alias analysis. LLVM_DEBUG(dumpAliases()); - if (bufferRelation == BufferRelation::Equivalent) + if (bufferRelation(operand) == BufferRelation::Equivalent) equivalentInfo.unionSets(result, operand.get()); // Dump the updated equivalence analysis. LLVM_DEBUG(dumpEquivalences()); @@ -962,6 +971,32 @@ usesWrite, domInfo); } +/// Return true if bufferizing `opOperand` inplace would create a write to a +/// non-writable buffer. +bool BufferizationAliasInfo::wouldCreateWriteToNonWritableBuffer( + OpOperand &opOperand) const { + OpResult opResult = getAliasingOpResult(opOperand); + assert(opResult && "expected that opOperand has aliasing OpResult"); + + // Certain buffers are not writeable: + // 1. A function bbArg that is not inplaceable or + // 2. A constant op. + bool nonWriteable = aliasesNonWritableBuffer(opResult) || + aliasesNonWritableBuffer(opOperand.get()); + if (!nonWriteable) + return false; + + // This is only a problem if the buffer is written to via some alias. + bool hasWrite = aliasesInPlaceWrite(opResult) || + aliasesInPlaceWrite(opOperand.get()) || + bufferizesToMemoryWrite(opOperand); + if (!hasWrite) + return false; + + LDBG("->the corresponding buffer is not writeable\n"); + return true; +} + /// Return true if the source of a `insertSliceOp` bufferizes to an /// equivalent ExtractSliceOp that bufferizes inplace. 
bool BufferizationAliasInfo::isSourceEquivalentToAMatchingInplaceExtractSliceOp( @@ -2199,56 +2234,6 @@ // Bufferization analyses. //===----------------------------------------------------------------------===// -/// -/// Rationale for bufferizing `%1 = tensor.extract_slice %0[...]` inplace. -/// =========================================================== -/// -/// When bufferized out of place, a ExtractSlice lowers to alloc + copy. This -/// cannot change the flow of information for either the source or the -/// result buffers. -/// -/// When bufferized inplace, a ExtractSliceOp does not by itself create any read -/// or write from memory. Instead, it has the effect of merging the alias sets -/// of the source and the result buffers. -/// -/// An analysis is required to ensure inplace bufferization would not result in -/// RaW dependence violations. -static LogicalResult -bufferizableInPlaceAnalysis(ExtractSliceOp extractSliceOp, - BufferizationAliasInfo &aliasInfo, - const DominanceInfo &domInfo) { - LDBG('\n'); - LDBG("Inplace analysis for extract_slice: " - << printOperationInfo(extractSliceOp) << '\n'); - - // If `extractSliceOp` were to be bufferized inplace, it cannot end up - // aliasing a write into a non-writable buffer. - bool wouldCreateAliasingWriteToNonWritableBuffer = - aliasInfo.aliasesInPlaceWrite(extractSliceOp.result()) && - aliasInfo.aliasesNonWritableBuffer(extractSliceOp->getOpOperand(0)); - - if (wouldCreateAliasingWriteToNonWritableBuffer) - LDBG("->the corresponding buffer is not writable\n"); - else - LDBG("->bufferizes to writable inplace buffer\n"); - - // In any of extractSliceOp.result's aliases, can we find 2 such that we hit - // an interfering write? 
- OpResult r = extractSliceOp->getResult(0); - OpOperand &s = extractSliceOp->getOpOperand(0); - bool foundInterference = - wouldCreateAliasingWriteToNonWritableBuffer || - aliasInfo.wouldCreateReadAfterWriteInterference(r, domInfo); - if (foundInterference) - aliasInfo.bufferizeOutOfPlace(r); - else - aliasInfo.bufferizeInPlace(r, s); - - LDBG("Done inplace analysis for extract_slice\n"); - - return success(); -} - /// Determine if `operand` can be bufferized in-place with one of the op's /// results. If so, set InPlaceSpec::True on the result. Otherwise, set /// InPlaceSpec::False on the result. @@ -2256,15 +2241,10 @@ bufferizableInPlaceAnalysis(OpOperand &operand, BufferizationAliasInfo &aliasInfo, const DominanceInfo &domInfo) { - OpResult result = getInplaceableOpResult(operand); + OpResult result = getAliasingOpResult(operand); if (!result) return success(); - Operation *op = result.getDefiningOp(); - assert(result && !isa(op) && - "expected OpResult not coming from a ExtractSliceOp"); - (void)op; - int64_t resultNumber = result.getResultNumber(); (void)resultNumber; LDBG('\n'); @@ -2272,29 +2252,14 @@ << operand.getOperandNumber() << " in " << printValueInfo(result) << '\n'); - // `result` must bufferize to a writable buffer to be a candidate. - // This means the operand must not alias either: - // 1. a function bbArg that is not inplaceable or - // 2. a constant op. 
- // to be considered for inplace bufferization - bool wouldCreateAliasingWriteToNonWritableBuffer = - aliasInfo.aliasesNonWritableBuffer(operand); - if (wouldCreateAliasingWriteToNonWritableBuffer) - LDBG("->the corresponding buffer is not writable\n"); - else - LDBG("->bufferizes to writable inplace buffer\n"); - bool foundInterference = - wouldCreateAliasingWriteToNonWritableBuffer || + aliasInfo.wouldCreateWriteToNonWritableBuffer(operand) || aliasInfo.wouldCreateReadAfterWriteInterference(result, domInfo); if (foundInterference) aliasInfo.bufferizeOutOfPlace(result); else - // TODO: Atm, all inplace bufferizations yield equivalent tensors. Support - // more cases on a per-need basis. - aliasInfo.bufferizeInPlace( - result, operand, BufferizationAliasInfo::BufferRelation::Equivalent); + aliasInfo.bufferizeInPlace(result, operand); LDBG("Done inplace analysis for result #" << resultNumber << '\n'); @@ -2311,23 +2276,11 @@ BufferizationAliasInfo &aliasInfo, const DominanceInfo &domInfo) { // Walk ops in reverse for better interference analysis. - for (Operation *op : reverse(ops)) { + for (Operation *op : reverse(ops)) for (OpOperand &opOperand : op->getOpOperands()) if (failed(bufferizableInPlaceAnalysis(opOperand, aliasInfo, domInfo))) - return failure(); - - // Special logic to analyze ExtractSliceOp. - // Note that ExtractSliceOp analysis needs to be interleaved with other ops - // to properly capture aliases. - // Walk ExtractSliceOps in reverse for better clobbering analysis behavior: - // it is easier to detect clobbers of smaller slices before larger ones. 
- if (auto extractSliceOp = dyn_cast(op)) { - if (failed( - bufferizableInPlaceAnalysis(extractSliceOp, aliasInfo, domInfo))) - return failure(); - continue; - } - } + return failure(); + return success(); } diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir --- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir +++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir @@ -620,10 +620,28 @@ return } +// ----- + //===----------------------------------------------------------------------===// // Transitive cases through extract_slice. //===----------------------------------------------------------------------===// +// CHECK-LABEL: func @write_into_constant_via_alias +func @write_into_constant_via_alias(%v : vector<5xi32>, + %s1 : index, %s2 : index, + %s3 : index) -> tensor { + %A = constant dense<[1, 2, 3, 4]> : tensor<4xi32> + // CHECK: tensor.extract_slice + // CHECK-SAME: {__inplace_results_attr__ = ["false"]} + %b = tensor.extract_slice %A[%s1][%s2][1] : tensor<4xi32> to tensor + // CHECK: vector.transfer_write + // CHECK-SAME: {__inplace_results_attr__ = ["true"]} + %r = vector.transfer_write %v, %b[%s3] : vector<5xi32>, tensor + return %r : tensor +} + +// ----- + builtin.func @matmul_on_tensors( %arg0: tensor<518x518xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, %arg1: tensor<518x518xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false},