diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h b/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h @@ -48,14 +48,6 @@ /// `alias`. Additionally, merge their equivalence classes. void insertNewBufferEquivalence(Value newValue, Value alias); - /// Return true if, under current bufferization decisions, the buffer of - /// `value` is not writable. - bool aliasesNonWritableBuffer(Value value) const; - - /// Return true if the buffer to which `operand` would bufferize is equivalent - /// to some buffer write. - bool aliasesInPlaceWrite(Value v) const; - /// Set the inPlace bufferization spec to true. /// Merge result's and operand's aliasing sets and iterate to a fixed point. void bufferizeInPlace(OpResult result, OpOperand &operand); @@ -63,23 +55,6 @@ /// Set the inPlace bufferization spec to false. void bufferizeOutOfPlace(OpResult result); - /// Return true if `value` has an ExtractSliceOp matching the given - /// InsertSliceOp in its reverse SSA use-def chain. - bool hasMatchingExtractSliceOp(Value value, - tensor::InsertSliceOp insertOp) const; - - /// Return true if bufferizing `opOperand` inplace with `opResult` would - /// create a write to a non-writable buffer. - bool wouldCreateWriteToNonWritableBuffer(OpOperand &opOperand, - OpResult opResult) const; - - /// Assume that result bufferizes in-place with one of the operation's - /// operands. Return true if it is possible to find an inplace write W that - /// creates a conflict. - bool - wouldCreateReadAfterWriteInterference(OpOperand &operand, OpResult result, - const DominanceInfo &domInfo) const; - /// Return true if `v1` and `v2` bufferize to equivalent buffers. bool areEquivalentBufferizedValues(Value v1, Value v2) const { // Return `false` if we have no information about `v1` or `v2`. 
@@ -91,14 +66,13 @@ equivalentInfo.getLeaderValue(v2); } - /// Return true if the source of an `insertSliceOp` bufferizes to an - /// equivalent ExtractSliceOp. - bool isSourceEquivalentToAMatchingInplaceExtractSliceOp( - tensor::InsertSliceOp insertSliceOp) const; - /// Apply `fun` to all the members of the equivalence class of `v`. void applyOnEquivalenceClass(Value v, function_ref fun) const; + /// Apply `fun` to all aliases of `v`. + void applyOnAliases(Value v, function_ref fun) const; + + // TODO: Move these out of BufferizationAliasInfo. /// Return true if the value is known to bufferize to writable memory. bool bufferizesToWritableMemory(Value v) const; @@ -128,22 +102,6 @@ /// Check that aliasInfo for `v` exists and return a reference to it. EquivalenceClassRangeType getAliases(Value v) const; - /// Return true if the (ExtractSliceOp, InsertSliceOp) pair match (i.e. - /// equivalent operand / result and same offset/sizes/strides specification). - /// - /// This is one particular type of relationship between ops on tensors that - /// reduce to an equivalence on buffers. This should be generalized and - /// exposed as interfaces on the proper types. - bool areEquivalentExtractSliceOps(tensor::ExtractSliceOp st, - tensor::InsertSliceOp sti) const; - - /// Given sets of uses and writes, return true if there is a RaW conflict - /// under the assumption that all given reads/writes alias the same buffer and - /// that all given writes bufferize inplace. - bool hasReadAfterWriteInterference(const DenseSet &usesRead, - const DenseSet &usesWrite, - const DominanceInfo &domInfo) const; - /// Set of tensors that are known to bufferize to writable memory. 
llvm::DenseSet bufferizeToWritableMemory; diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp @@ -497,6 +497,24 @@ // Bufferization-specific alias analysis. //===----------------------------------------------------------------------===// +/// Return true if the (ExtractSliceOp, InsertSliceOp) pair match (i.e. +/// equivalent operand / result and same offset/sizes/strides specification). +/// +/// This is one particular type of relationship between ops on tensors that +/// reduce to an equivalence on buffers. This should be generalized and +/// exposed as interfaces on the proper types. +static bool +areEquivalentExtractSliceOps(const BufferizationAliasInfo &aliasInfo, + ExtractSliceOp st, InsertSliceOp sti) { + if (!st || !sti) + return false; + if (!aliasInfo.areEquivalentBufferizedValues(st.source(), sti.dest())) + return false; + if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue)) + return false; + return true; +} + /// Return true if opOperand has been decided to bufferize in-place. static bool isInplaceMemoryWrite(OpOperand &opOperand) { // Ops that do not bufferize to a memory write, cannot be write in-place. @@ -560,24 +578,27 @@ /// Return true if, under current bufferization decisions, the buffer of `value` /// is not writable. 
-bool BufferizationAliasInfo::aliasesNonWritableBuffer(Value value) const { +static bool aliasesNonWritableBuffer(Value value, + const BufferizationAliasInfo &aliasInfo) { LDBG("----Start aliasesNonWritableBuffer\n"); - for (Value v : getAliases(value)) { + bool foundNonWritableBuffer = false; + aliasInfo.applyOnAliases(value, [&](Value v) { LDBG("-----------examine: " << printValueInfo(v) << '\n'); - if (bufferizesToWritableMemory(v)) { + if (aliasInfo.bufferizesToWritableMemory(v)) { LDBG("-----------Value is known to be writable -> skip: " << printValueInfo(v) << '\n'); - continue; + return; } if (auto bbArg = v.dyn_cast()) { if (getInPlace(bbArg) == InPlaceSpec::True) { LDBG("-----------bbArg is writable -> skip: " << printValueInfo(bbArg) << '\n'); - continue; + return; } LDBG("-----------notWritable bbArg\n"); - return true; + foundNonWritableBuffer = true; + return; } auto bufferizableOp = dyn_cast(v.getDefiningOp()); @@ -585,11 +606,15 @@ // Unknown ops are treated conservatively: Assume that it is illegal to // write to their OpResults in-place. LDBG("-----------notWritable op\n"); - return true; + foundNonWritableBuffer = true; + return; } - } - LDBG("---->value is writable\n"); - return false; + }); + + if (!foundNonWritableBuffer) + LDBG("---->value is writable\n"); + + return foundNonWritableBuffer; } bool BufferizationAliasInfo::bufferizesToWritableMemory(Value v) const { @@ -603,20 +628,26 @@ /// Return true if the buffer to which `operand` would bufferize is equivalent /// to some buffer write. 
-bool BufferizationAliasInfo::aliasesInPlaceWrite(Value value) const { +static bool aliasesInPlaceWrite(Value value, + const BufferizationAliasInfo &aliasInfo) { LDBG("----Start aliasesInPlaceWrite\n"); LDBG("-------for : " << printValueInfo(value) << '\n'); - for (Value v : getAliases(value)) { + bool foundInplaceWrite = false; + aliasInfo.applyOnAliases(value, [&](Value v) { for (auto &use : v.getUses()) { if (isInplaceMemoryWrite(use)) { LDBG("-----------wants to bufferize to inPlace write: " << printOperationInfo(use.getOwner()) << '\n'); - return true; + foundInplaceWrite = true; + return; } } - } - LDBG("----------->does not alias an inplace write\n"); - return false; + }); + + if (!foundInplaceWrite) + LDBG("----------->does not alias an inplace write\n"); + + return foundInplaceWrite; } /// Set the inPlace bufferization spec to true. @@ -723,11 +754,11 @@ /// Return true if `value` is originating from an ExtractSliceOp that matches /// the given InsertSliceOp. -bool BufferizationAliasInfo::hasMatchingExtractSliceOp( - Value value, InsertSliceOp insertOp) const { +static bool hasMatchingExtractSliceOp(const BufferizationAliasInfo &aliasInfo, + Value value, InsertSliceOp insertOp) { auto condition = [&](Value val) { if (auto extractOp = val.getDefiningOp()) - if (areEquivalentExtractSliceOps(extractOp, insertOp)) + if (areEquivalentExtractSliceOps(aliasInfo, extractOp, insertOp)) return true; return false; }; @@ -758,10 +789,11 @@ /// A conflict is: According to SSA use-def chains, a read R is supposed to read /// the result of a write W1. But because of bufferization decisions, R actually /// reads another write W2. 
-bool BufferizationAliasInfo::hasReadAfterWriteInterference( - const DenseSet &usesRead, - const DenseSet &usesWrite, - const DominanceInfo &domInfo) const { +static bool +hasReadAfterWriteInterference(const DenseSet &usesRead, + const DenseSet &usesWrite, + const DominanceInfo &domInfo, + const BufferizationAliasInfo &aliasInfo) { for (OpOperand *uRead : usesRead) { Operation *readingOp = uRead->getOwner(); @@ -842,7 +874,8 @@ // TODO: Use insertSliceOp.getDestOpOperand etc. when available. if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ && - hasMatchingExtractSliceOp(uConflictingWrite->get(), insertSliceOp)) + hasMatchingExtractSliceOp(aliasInfo, uConflictingWrite->get(), + insertSliceOp)) // Case 1: The main insight is that InsertSliceOp reads only part of // the destination tensor. The overwritten area is not read. If // uConflictingWrite writes into exactly the memory location that is @@ -859,7 +892,7 @@ if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ && uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ && - hasMatchingExtractSliceOp(uRead->get(), insertSliceOp)) + hasMatchingExtractSliceOp(aliasInfo, uRead->get(), insertSliceOp)) // Case 2: The read of the source tensor and the write to the dest // tensor via an InsertSliceOp is not a conflict if the read is // reading exactly that part of an equivalent tensor that the @@ -902,8 +935,9 @@ /// * However, adding an alias {%0, %t} would mean that the second /// TransferWriteOp overwrites the first one. Therefore, the TransferReadOp /// would no longer be reading the result of %1. 
-bool BufferizationAliasInfo::wouldCreateReadAfterWriteInterference( - OpOperand &operand, OpResult result, const DominanceInfo &domInfo) const { +static bool wouldCreateReadAfterWriteInterference( + OpOperand &operand, OpResult result, const DominanceInfo &domInfo, + const BufferizationAliasInfo &aliasInfo) { #ifndef NDEBUG SmallVector opOperands = getAliasingOpOperand(result); assert(llvm::find(opOperands, &operand) != opOperands.end() && @@ -912,20 +946,22 @@ // Helper function to iterate on aliases of `root` and capture the reads. auto getAliasingReads = [&](DenseSet &res, Value root) { - for (Value alias : getAliases(root)) + aliasInfo.applyOnAliases(root, [&](Value alias) { for (auto &use : alias.getUses()) // Read to a value that aliases root. if (bufferizesToMemoryRead(use)) res.insert(&use); + }); }; // Helper function to iterate on aliases of `root` and capture the writes. auto getAliasingInplaceWrites = [&](DenseSet &res, Value root) { - for (Value alias : getAliases(root)) + aliasInfo.applyOnAliases(root, [&](Value alias) { for (auto &use : alias.getUses()) // Inplace write to a value that aliases root. if (isInplaceMemoryWrite(use)) res.insert(&use); + }); }; // Collect reads and writes of all aliases of OpOperand and OpResult. @@ -937,13 +973,14 @@ if (bufferizesToMemoryWrite(operand)) usesWrite.insert(&operand); - return hasReadAfterWriteInterference(usesRead, usesWrite, domInfo); + return hasReadAfterWriteInterference(usesRead, usesWrite, domInfo, aliasInfo); } /// Return true if bufferizing `opOperand` inplace with `opResult` would create /// a write to a non-writable buffer. 
-bool BufferizationAliasInfo::wouldCreateWriteToNonWritableBuffer( - OpOperand &opOperand, OpResult opResult) const { +static bool +wouldCreateWriteToNonWritableBuffer(OpOperand &opOperand, OpResult opResult, + const BufferizationAliasInfo &aliasInfo) { #ifndef NDEBUG SmallVector opOperands = getAliasingOpOperand(opResult); assert(llvm::find(opOperands, &opOperand) != opOperands.end() && @@ -953,15 +990,15 @@ // Certain buffers are not writeable: // 1. A function bbArg that is not inplaceable or // 2. A constant op. - assert(!aliasesNonWritableBuffer(opResult) && + assert(!aliasesNonWritableBuffer(opResult, aliasInfo) && "expected that opResult does not alias non-writable buffer"); - bool nonWritable = aliasesNonWritableBuffer(opOperand.get()); + bool nonWritable = aliasesNonWritableBuffer(opOperand.get(), aliasInfo); if (!nonWritable) return false; // This is a problem only if the buffer is written to via some alias. - bool hasWrite = aliasesInPlaceWrite(opResult) || - aliasesInPlaceWrite(opOperand.get()) || + bool hasWrite = aliasesInPlaceWrite(opResult, aliasInfo) || + aliasesInPlaceWrite(opOperand.get(), aliasInfo) || bufferizesToMemoryWrite(opOperand); if (!hasWrite) return false; @@ -970,28 +1007,6 @@ return true; } -/// Return true if the source of a `insertSliceOp` bufferizes to an -/// equivalent ExtractSliceOp that bufferizes inplace. 
-bool BufferizationAliasInfo::isSourceEquivalentToAMatchingInplaceExtractSliceOp( - InsertSliceOp insertSliceOp) const { - LDBG("isSourceEquivalentToAMatchingInplaceExtractSliceOp: " << *insertSliceOp - << '\n'); - auto leaderIt = equivalentInfo.findLeader(insertSliceOp.source()); - for (auto mit = leaderIt, meit = equivalentInfo.member_end(); mit != meit; - ++mit) { - auto extractSliceOp = - dyn_cast_or_null(mit->getDefiningOp()); - if (extractSliceOp && - areEquivalentExtractSliceOps(extractSliceOp, insertSliceOp) && - getInPlace(extractSliceOp.result()) == InPlaceSpec::True) { - LDBG("\tfound: " << *mit->getDefiningOp() << '\n'); - return true; - } - } - LDBG("\tnot equivalent\n"); - return false; -} - /// Apply `fun` to all the members of the equivalence class of `v`. void BufferizationAliasInfo::applyOnEquivalenceClass( Value v, function_ref fun) const { @@ -1002,6 +1017,15 @@ } } +/// Apply `fun` to all aliases of `v`. +void BufferizationAliasInfo::applyOnAliases( + Value v, function_ref fun) const { + auto leaderIt = aliasInfo.findLeader(v); + for (auto mit = leaderIt, meit = aliasInfo.member_end(); mit != meit; ++mit) { + fun(*mit); + } +} + void BufferizationAliasInfo::printAliases(raw_ostream &os) const { os << "\n/===================== AliasInfo =====================\n"; for (auto it = aliasInfo.begin(), eit = aliasInfo.end(); it != eit; ++it) { @@ -1058,20 +1082,6 @@ printEquivalences(llvm::errs()); } -/// This is one particular type of relationship between ops on tensors that -/// reduce to an equivalence on buffers. This should be generalized and exposed -/// as interfaces on the proper types. 
-bool BufferizationAliasInfo::areEquivalentExtractSliceOps( - ExtractSliceOp st, InsertSliceOp sti) const { - if (!st || !sti) - return false; - if (!equivalentInfo.isEquivalent(st.source(), sti.dest())) - return false; - if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue)) - return false; - return true; -} - //===----------------------------------------------------------------------===// // Forward declarations. //===----------------------------------------------------------------------===// @@ -1467,8 +1477,9 @@ << printValueInfo(result) << '\n'); bool foundInterference = - aliasInfo.wouldCreateWriteToNonWritableBuffer(operand, result) || - aliasInfo.wouldCreateReadAfterWriteInterference(operand, result, domInfo); + wouldCreateWriteToNonWritableBuffer(operand, result, aliasInfo) || + wouldCreateReadAfterWriteInterference(operand, result, domInfo, + aliasInfo); if (foundInterference) aliasInfo.bufferizeOutOfPlace(result); @@ -3239,6 +3250,30 @@ } }; +/// Return true if the source of an `insertSliceOp` bufferizes to an +/// equivalent ExtractSliceOp that bufferizes inplace. +static bool isSourceEquivalentToAMatchingInplaceExtractSliceOp( + const BufferizationAliasInfo &aliasInfo, InsertSliceOp insertSliceOp) { + LDBG("isSourceEquivalentToAMatchingInplaceExtractSliceOp: " << *insertSliceOp + << '\n'); + bool foundOp = false; + aliasInfo.applyOnEquivalenceClass(insertSliceOp.source(), [&](Value value) { + auto extractSliceOp = value.getDefiningOp(); + if (extractSliceOp && + areEquivalentExtractSliceOps(aliasInfo, extractSliceOp, + insertSliceOp) && + getInPlace(extractSliceOp.result()) == InPlaceSpec::True) { + LDBG("\tfound: " << *extractSliceOp.getOperation() << '\n'); + foundOp = true; + } + }); + + if (!foundOp) + LDBG("\tnot equivalent\n"); + + return foundOp; +} + struct InsertSliceOpInterface : public BufferizableOpInterface::ExternalModel { @@ -3308,8 +3343,8 @@ // cloned and the clone needs to be updated. 
auto inPlace = getInPlace(insertSliceOp->getResult(0)); // TODO: Is this necessary? - if (!aliasInfo.isSourceEquivalentToAMatchingInplaceExtractSliceOp( - insertSliceOp) || + if (!isSourceEquivalentToAMatchingInplaceExtractSliceOp(aliasInfo, + insertSliceOp) || inPlace != InPlaceSpec::True) { LDBG("insert_slice needs extra source copy: " << insertSliceOp.source() << " -> copy\n");