diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h b/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h @@ -56,10 +56,9 @@ /// `alias`. Additionally, merge their equivalence classes. void insertNewBufferEquivalence(Value newValue, Value alias); - /// Return true if the buffer to which `operand` would bufferize aliases a - /// buffer that is known to not be writable. This implies that the matching - /// OpResult cannot be bufferized inplace. - bool aliasesNonWritableBuffer(OpOperand &operand) const; + /// Return true if, under current bufferization decisions, the buffer of + /// `value` is not writable. + bool aliasesNonWritableBuffer(Value value) const; /// Return true if the buffer to which `operand` would bufferize is equivalent /// to some buffer write. @@ -67,8 +66,7 @@ /// Set the inPlace bufferization spec to true. /// Merge result's and operand's aliasing sets and iterate to a fixed point. - void bufferizeInPlace(OpResult result, OpOperand &operand, - BufferRelation bufferRelation = BufferRelation::None); + void bufferizeInPlace(OpResult result, OpOperand &operand); /// Set the inPlace bufferization spec to false. void bufferizeOutOfPlace(OpResult result); @@ -107,9 +105,14 @@ /// read(%0) /// ``` bool - wouldCreateReadAfterWriteInterference(OpResult result, + wouldCreateReadAfterWriteInterference(OpOperand &operand, OpResult result, const DominanceInfo &domInfo) const; + /// Return true if bufferizing `opOperand` inplace with `opResult` would + /// create a write to a non-writable buffer. + bool wouldCreateWriteToNonWritableBuffer(OpOperand &opOperand, + OpResult opResult) const; + /// Return true if `v1` and `v2` bufferize to equivalent buffers. 
bool areEquivalentBufferizedValues(Value v1, Value v2) const { return equivalentInfo.getLeaderValue(v1) == @@ -228,7 +231,7 @@ llvm::EquivalenceClasses equivalentInfo; }; -/// Analyze the `ops` to determine which OpResults are inplaceable: +/// Analyze the `ops` to determine which OpResults are inplaceable. LogicalResult inPlaceAnalysis(SmallVector &ops, BufferizationAliasInfo &aliasInfo, const DominanceInfo &domInfo); diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp @@ -136,6 +136,8 @@ using namespace linalg; using namespace tensor; +using BufferRelation = BufferizationAliasInfo::BufferRelation; + #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") #define LDBG(X) LLVM_DEBUG(DBGS() << X) @@ -421,8 +423,10 @@ // buffers in memory. // 3. Whether an op operand, when bufferized inplace, aliases a return value. // 4. Whether an op return value, when bufferized inplace, aliases an operand. -// 5. Wheher an op bufferizes to a memory read. -// 6. Wheher an op bufferizes to a memory write. +// 5. Whether an op bufferizes to a memory read. +// 6. Whether an op bufferizes to a memory write. +// 7. The buffer relationship between an operand and its corresponding result +// (in case of in-place bufferization). // These interfaces are necessary to distinguish between various cases and allow // special inplace behavior for (ExtractSliceOp, InsertSliceOp) pairs. 
//===----------------------------------------------------------------------===// @@ -557,7 +561,7 @@ return None; return TypeSwitch(result.getDefiningOp()) .Case([&](tensor::CastOp op) { return &op->getOpOperand(0); }) - .Case([&](ConstantOp op) { return &op->getOpOperand(0); }) + .Case([&](ConstantOp op) { return nullptr; }) .Case([&](ExtractSliceOp op) { return &op->getOpOperand(0); }) // In the case of scf::ForOp, this currently assumes the iter_args / yield // are 1-1. This may fail and is verified at the end. @@ -682,6 +686,16 @@ getInPlace(opResult) == inPlaceSpec; } +/// Returns the relationship between the operand and its corresponding +/// OpResult that it may alias with. +static BufferRelation bufferRelation(OpOperand &operand) { + return TypeSwitch(operand.getOwner()) + // ExtractSliceOp returns a subview of the original tensor. + .Case([&](ExtractSliceOp op) { return BufferRelation::None; }) + // All other ops: Buffers are equivalent. + .Default([&](Operation *op) { return BufferRelation::Equivalent; }); +} + //===----------------------------------------------------------------------===// // Bufferization-specific alias analysis. //===----------------------------------------------------------------------===// @@ -721,40 +735,36 @@ equivalentInfo.unionSets(newValue, alias); } -/// Return true if the buffer to which `operand` would bufferize aliases a -/// buffer that is known to not be writable. This implies that the matching -/// OpResult cannot be bufferized inplace. -bool BufferizationAliasInfo::aliasesNonWritableBuffer( - OpOperand &operand) const { +/// Return true if, under current bufferization decisions, the buffer of `value` +/// is not writable. 
+bool BufferizationAliasInfo::aliasesNonWritableBuffer(Value value) const { LDBG("----Start aliasesNonWritableBuffer\n"); - LDBG("-------for -> #" << operand.getOperandNumber() << ": " - << printOperationInfo(operand.getOwner()) << '\n'); - for (Value v : getAliases(operand.get())) { + for (Value v : getAliases(value)) { LDBG("-----------examine: " << printValueInfo(v) << '\n'); if (bufferizesToWritableMemory(v)) { - LDBG("-----------Value is known to be writeable -> skip: " + LDBG("-----------Value is known to be writable -> skip: " << printValueInfo(v) << '\n'); continue; } if (auto bbArg = v.dyn_cast()) { if (getInPlace(bbArg) == InPlaceSpec::True) { - LDBG("-----------bbArg is writeable -> skip: " << printValueInfo(bbArg) - << '\n'); + LDBG("-----------bbArg is writable -> skip: " << printValueInfo(bbArg) + << '\n'); continue; } - LDBG("-----------notWriteable\n"); + LDBG("-----------notWritable bbArg\n"); return true; } if (Operation *op = v.getDefiningOp()) { if (isa(op) || !hasKnownBufferizationAliasingBehavior(op)) { - LDBG("-----------notWritable\n"); + LDBG("-----------notWritable op\n"); return true; } } } - LDBG("---->operand is writable\n"); + LDBG("---->value is writable\n"); return false; } @@ -787,13 +797,12 @@ /// Set the inPlace bufferization spec to true. void BufferizationAliasInfo::bufferizeInPlace(OpResult result, - OpOperand &operand, - BufferRelation bufferRelation) { + OpOperand &operand) { setInPlaceOpResult(result, InPlaceSpec::True); aliasInfo.unionSets(result, operand.get()); // Dump the updated alias analysis. LLVM_DEBUG(dumpAliases()); - if (bufferRelation == BufferRelation::Equivalent) + if (bufferRelation(operand) == BufferRelation::Equivalent) equivalentInfo.unionSets(result, operand.get()); // Dump the updated equivalence analysis. LLVM_DEBUG(dumpEquivalences()); @@ -868,14 +877,10 @@ /// C interleaved between W and R (i.e. W -> C -> R where -> denotes /// dominance). 
bool BufferizationAliasInfo::wouldCreateReadAfterWriteInterference( - OpResult result, const DominanceInfo &domInfo) const { - Optional maybeAliasingOperand = getAliasingOpOperand(result); - if (!maybeAliasingOperand) - return false; - + OpOperand &operand, OpResult result, const DominanceInfo &domInfo) const { Operation *opToBufferize = result.getDefiningOp(); Value opResult = result; - Value opOperand = (*maybeAliasingOperand)->get(); + Value opOperand = operand.get(); LDBG("----Start wouldCreateReadAfterWriteInterference\n"); LDBG("--------consider all aliases to root read: " @@ -927,16 +932,16 @@ getAliasingReads(usesRead, opOperand); getAliasingInplaceWrites(usesWrite, opResult); // Additionally, `result` is not yet bufferized and we need to check for - // interferences as if it were bufferized inplace: add `maybeAliasingOperand` - // if it is a write. This handles the case: + // interferences as if it were bufferized inplace: add `operand` if it is a + // write. This handles the case: // // ``` // %0 = op_to_bufferize_maybe_inplace(%1) // %2 = some_alias(%1) // read(%2) // ``` - if (bufferizesToMemoryWrite(**maybeAliasingOperand)) - usesWrite.insert(*maybeAliasingOperand); + if (bufferizesToMemoryWrite(operand)) + usesWrite.insert(&operand); if (wouldCreateReadAfterWriteInterference(opToBufferize, usesRead, usesWrite, domInfo)) return true; @@ -962,6 +967,29 @@ usesWrite, domInfo); } +/// Return true if bufferizing `opOperand` inplace with `opResult` would create +/// a write to a non-writable buffer. +bool BufferizationAliasInfo::wouldCreateWriteToNonWritableBuffer( + OpOperand &opOperand, OpResult opResult) const { + // Certain buffers are not writeable: + // 1. A function bbArg that is not inplaceable or + // 2. A constant op. + bool nonWriteable = aliasesNonWritableBuffer(opResult) || + aliasesNonWritableBuffer(opOperand.get()); + if (!nonWriteable) + return false; + + // This is a problem only if the buffer is written to via some alias. 
+ bool hasWrite = aliasesInPlaceWrite(opResult) || + aliasesInPlaceWrite(opOperand.get()) || + bufferizesToMemoryWrite(opOperand); + if (!hasWrite) + return false; + + LDBG("->the corresponding buffer is not writeable\n"); + return true; +} + /// Return true if the source of a `insertSliceOp` bufferizes to an /// equivalent ExtractSliceOp that bufferizes inplace. bool BufferizationAliasInfo::isSourceEquivalentToAMatchingInplaceExtractSliceOp( @@ -2199,52 +2227,30 @@ // Bufferization analyses. //===----------------------------------------------------------------------===// -/// -/// Rationale for bufferizing `%1 = tensor.extract_slice %0[...]` inplace. -/// =========================================================== -/// -/// When bufferized out of place, a ExtractSlice lowers to alloc + copy. This -/// cannot change the flow of information for either the source or the -/// result buffers. -/// -/// When bufferized inplace, a ExtractSliceOp does not by itself create any read -/// or write from memory. Instead, it has the effect of merging the alias sets -/// of the source and the result buffers. -/// -/// An analysis is required to ensure inplace bufferization would not result in -/// RaW dependence violations. +/// Determine if `operand` can be bufferized in-place with `result`. If so, set +/// InPlaceSpec::True on the result. Otherwise, set InPlaceSpec::False on the +/// result. static LogicalResult -bufferizableInPlaceAnalysis(ExtractSliceOp extractSliceOp, +bufferizableInPlaceAnalysis(OpOperand &operand, OpResult result, BufferizationAliasInfo &aliasInfo, const DominanceInfo &domInfo) { + int64_t resultNumber = result.getResultNumber(); + (void)resultNumber; LDBG('\n'); - LDBG("Inplace analysis for extract_slice: " - << printOperationInfo(extractSliceOp) << '\n'); - - // If `extractSliceOp` were to be bufferized inplace, it cannot end up - // aliasing a write into a non-writable buffer. 
- bool wouldCreateAliasingWriteToNonWritableBuffer = - aliasInfo.aliasesInPlaceWrite(extractSliceOp.result()) && - aliasInfo.aliasesNonWritableBuffer(extractSliceOp->getOpOperand(0)); - - if (wouldCreateAliasingWriteToNonWritableBuffer) - LDBG("->the corresponding buffer is not writable\n"); - else - LDBG("->bufferizes to writable inplace buffer\n"); + LDBG("Inplace analysis for <- #" << resultNumber << " -> #" + << operand.getOperandNumber() << " in " + << printValueInfo(result) << '\n'); - // In any of extractSliceOp.result's aliases, can we find 2 such that we hit - // an interfering write? - OpResult r = extractSliceOp->getResult(0); - OpOperand &s = extractSliceOp->getOpOperand(0); bool foundInterference = - wouldCreateAliasingWriteToNonWritableBuffer || - aliasInfo.wouldCreateReadAfterWriteInterference(r, domInfo); + aliasInfo.wouldCreateWriteToNonWritableBuffer(operand, result) || + aliasInfo.wouldCreateReadAfterWriteInterference(operand, result, domInfo); + if (foundInterference) - aliasInfo.bufferizeOutOfPlace(r); + aliasInfo.bufferizeOutOfPlace(result); else - aliasInfo.bufferizeInPlace(r, s); + aliasInfo.bufferizeInPlace(result, operand); - LDBG("Done inplace analysis for extract_slice\n"); + LDBG("Done inplace analysis for result #" << resultNumber << '\n'); return success(); } @@ -2259,54 +2265,12 @@ OpResult result = getInplaceableOpResult(operand); if (!result) return success(); - - Operation *op = result.getDefiningOp(); - assert(result && !isa(op) && - "expected OpResult not coming from a ExtractSliceOp"); - (void)op; - - int64_t resultNumber = result.getResultNumber(); - (void)resultNumber; - LDBG('\n'); - LDBG("Inplace analysis for <- #" << resultNumber << " -> #" - << operand.getOperandNumber() << " in " - << printValueInfo(result) << '\n'); - - // `result` must bufferize to a writable buffer to be a candidate. - // This means the operand must not alias either: - // 1. a function bbArg that is not inplaceable or - // 2. a constant op. 
- // to be considered for inplace bufferization - bool wouldCreateAliasingWriteToNonWritableBuffer = - aliasInfo.aliasesNonWritableBuffer(operand); - if (wouldCreateAliasingWriteToNonWritableBuffer) - LDBG("->the corresponding buffer is not writable\n"); - else - LDBG("->bufferizes to writable inplace buffer\n"); - - bool foundInterference = - wouldCreateAliasingWriteToNonWritableBuffer || - aliasInfo.wouldCreateReadAfterWriteInterference(result, domInfo); - - if (foundInterference) - aliasInfo.bufferizeOutOfPlace(result); - else - // TODO: Atm, all inplace bufferizations yield equivalent tensors. Support - // more cases on a per-need basis. - aliasInfo.bufferizeInPlace( - result, operand, BufferizationAliasInfo::BufferRelation::Equivalent); - - LDBG("Done inplace analysis for result #" << resultNumber << '\n'); - - return success(); + return bufferizableInPlaceAnalysis(operand, result, aliasInfo, domInfo); } -/// Analyze the `ops` to determine which OpResults are inplaceable: -/// 1. First, analyze InsertSliceOp greedily: we almost never want to -/// bufferize the tensor "inserted into" to become out-of-place. -/// 2. Walk the other ops in reverse. This is a good starter heuristic. -/// ExtractSliceOps are interleaved with other ops in traversal order. -/// +/// Analyze the `ops` to determine which OpResults are inplaceable. Walk ops in +/// reverse and bufferize ops greedily. This is a good starter heuristic. +/// ExtractSliceOps are interleaved with other ops in traversal order. LogicalResult mlir::linalg::inPlaceAnalysis(SmallVector &ops, BufferizationAliasInfo &aliasInfo, const DominanceInfo &domInfo) { @@ -2314,20 +2278,24 @@ for (Operation *op : reverse(ops)) { for (OpOperand &opOperand : op->getOpOperands()) if (failed(bufferizableInPlaceAnalysis(opOperand, aliasInfo, domInfo))) - return failure(); - - // Special logic to analyze ExtractSliceOp. - // Note that ExtractSliceOp analysis needs to be interleaved with other ops - // to properly capture aliases. 
- // Walk ExtractSliceOps in reverse for better clobbering analysis behavior: - // it is easier to detect clobbers of smaller slices before larger ones. - if (auto extractSliceOp = dyn_cast(op)) { - if (failed( - bufferizableInPlaceAnalysis(extractSliceOp, aliasInfo, domInfo))) - return failure(); - continue; - } + op->emitWarning() << "Inplace analysis treated conservatively"; + + // Special logic to analyze ExtractSliceOp. When bufferized out of place, an + // ExtractSliceOp lowers to alloc + copy. This cannot change the flow of + // information for either the source or the result buffers. + // + // When bufferized inplace, a ExtractSliceOp does not by itself create any + // read or write from memory. Instead, it has the effect of merging the + // alias sets of the source and the result buffers. + // + // An analysis is required to ensure inplace bufferization would not result + // in RaW dependence violations. + if (auto extractSliceOp = dyn_cast(op)) + if (failed(bufferizableInPlaceAnalysis( + op->getOpOperand(0), op->getOpResult(0), aliasInfo, domInfo))) + op->emitWarning() << "Inplace analysis treated conservatively"; } + return success(); } diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir --- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir +++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir @@ -620,10 +620,28 @@ return } +// ----- + //===----------------------------------------------------------------------===// // Transitive cases through extract_slice. 
//===----------------------------------------------------------------------===// +// CHECK-LABEL: func @write_into_constant_via_alias +func @write_into_constant_via_alias(%v : vector<5xi32>, + %s1 : index, %s2 : index, + %s3 : index) -> tensor { + %A = constant dense<[1, 2, 3, 4]> : tensor<4xi32> + // CHECK: tensor.extract_slice + // CHECK-SAME: {__inplace_results_attr__ = ["false"]} + %b = tensor.extract_slice %A[%s1][%s2][1] : tensor<4xi32> to tensor + // CHECK: vector.transfer_write + // CHECK-SAME: {__inplace_results_attr__ = ["true"]} + %r = vector.transfer_write %v, %b[%s3] : vector<5xi32>, tensor + return %r : tensor +} + +// ----- + builtin.func @matmul_on_tensors( %arg0: tensor<518x518xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, %arg1: tensor<518x518xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false},