diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h b/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h @@ -71,18 +71,10 @@ /// Set the inPlace bufferization spec to false. void bufferizeOutOfPlace(OpResult result); - /// Return true if it is possible to find an inplace write W among `usesWrite` - /// and a read R among `usesRead`, such that W and R interfere. - /// Such a (W, R) pair is an interference to the inplace bufferization of - /// opResult when: - /// 1. R is not known properly dominate W (i.e. the effects of the write may - /// be visible from R). - /// 2. one cannot find an intermediate clobbering write `C` to W, such that - /// C interleaved between W and R (i.e. W -> C -> R where -> denotes - /// dominance). - bool wouldCreateReadAfterWriteInterference( - Operation *opToBufferize, DenseSet &usesRead, - DenseSet &usesWrite, const DominanceInfo &domInfo) const; + /// Return true if `value` has an ExtractSliceOp matching the given + /// InsertSliceOp in its reverse SSA use-def chain. + bool hasMatchingExtractSliceOp(Value value, + tensor::InsertSliceOp insertOp) const; /// Return true if bufferizing `opOperand` inplace with `opResult` would /// create a write to a non-writable buffer. @@ -90,25 +82,8 @@ OpResult opResult) const; /// Assume that result bufferizes in-place with one of the operation's - /// operands. Return true if it is possible to find an inplace write W (resp. - /// a read R) among the uses of `aliasInfo[result]`, and a read R (resp. an - /// inplace write W) among the uses of - /// `aliasInfo[getAliasingOpOperand(result)]`, such that W and R interfere. - /// Interference detection is needed to determine which cases may bufferize - /// inplace without interferences. 
Such cases comprise: - /// - /// ``` - /// %0 = op_to_bufferize(%1) - /// read(%1) - /// - /// %0 = op_to_bufferize(%1) - /// write(%0) - /// read(%1) - /// - /// %0 = op_to_bufferize(%1) - /// write(%1) - /// read(%0) - /// ``` + /// operands. Return true if it is possible to find an inplace write W that + /// creates a conflict. bool wouldCreateReadAfterWriteInterference(OpOperand &operand, OpResult result, const DominanceInfo &domInfo) const; @@ -171,54 +146,13 @@ bool areEquivalentExtractSliceOps(tensor::ExtractSliceOp st, tensor::InsertSliceOp sti) const; - /// Return true if there is a `candidateOp` that would write to memory after - /// bufferization and such that: - /// 1. The written buffer is equivalent to either `aliasingRead` or - /// `aliasingWrite` under the inPlace bufferization decisions taken - /// so far. - /// 2. `aliasingWrite` properly dominates `candidateOp`. - /// 3. `candidateOp` properly dominates `aliasingReadOp`. - // TODO: richer clobbering analysis with container-containee relationship - // instead of equivalence. - bool existsInterleavedValueClobber(OpOperand &aliasingRead, - OpOperand &aliasingWrite, + /// Given sets of uses and writes, return true if there is a RaW conflict + /// under the assumption that all given reads/writes alias the same buffer and + /// that all given writes bufferize inplace. + bool hasReadAfterWriteInterference(const DenseSet &usesRead, + const DenseSet &usesWrite, const DominanceInfo &domInfo) const; - /// Return true if there is a write that: - /// 1. Properly dominates aliasingReadOp. - /// 2. Is properly dominated by aliasingWriteOp. - /// 3. Clobbers the write that would be interfering with the read. - /// - /// Case discussion: - /// ================ - /// Case 1: opOperand is produced by opToBufferize, - /// Case 2: opResult is produced by opToBufferize, - /// Common case: - /// - aliasingReadOp is a read to an alias of opOperand. - /// - aliasingWriteOp is an inplace write to an alias of opResult. 
- /// - aliasingWriteOp dominates aliasingReadOp. - /// - /// ``` - /// // Either case 1: - /// %opOperand = opToBufferize(%opResult) - /// aliasingWriteOp(%aliasingWrite = alias(%opResult)) // inplace - /// aliasingReadOp( %aliasingRead = alias(%opOperand)) - /// ``` - /// - /// ``` - /// // Or case 2: - /// %opResult = opToBufferize(%opOperand) - /// aliasingWriteOp(%aliasingWrite = alias(%opResult)) // inplace - /// aliasingReadOp( %aliasingRead = alias(%opOperand)) - /// ``` - /// - /// Capture possible cases where `aliasingWriteOp(alias(%opResult))` has no - /// visible effect on `aliasingReadOp(alias(%opOperand))`. - bool isClobberedWriteBeforeRead(Operation *opToBufferize, - OpOperand &aliasingRead, - OpOperand &aliasingWrite, - const DominanceInfo &domInfo) const; - /// Set of tensors that are known to bufferize to writable memory. llvm::DenseSet bufferizeToWritableMemory; diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp @@ -816,69 +816,198 @@ setInPlaceOpResult(result, InPlaceSpec::False); } -/// Return true if it is possible to find an inplace write W among `usesWrite` -/// and a read R among `usesRead`, such that W and R interfere. -bool BufferizationAliasInfo::wouldCreateReadAfterWriteInterference( - Operation *opToBufferize, DenseSet &usesRead, - DenseSet &usesWrite, const DominanceInfo &domInfo) const { +/// Starting from `value`, follow the use-def chain in reverse, always selecting +/// the corresponding aliasing OpOperand. Try to find and return a Value for +/// which `condition` evaluates to true for the aliasing OpOperand. Return an +/// empty Value if no such Value was found. If `returnLast`, return the last +/// Value (at the end of the chain), even if it does not satisfy the condition. 
+static Value +findValueInReverseUseDefChain(Value value, + std::function condition, + bool returnLast = false) { + while (value.isa()) { + auto opResult = value.cast(); + SmallVector opOperands = getAliasingOpOperand(opResult); + assert(opOperands.size() <= 1 && "more than 1 OpOperand not supported yet"); + if (opOperands.empty()) + // No aliasing OpOperand. This could be an unsupported op or an op without + // a tensor arg such as InitTensorOp. This is the end of the chain. + return returnLast ? value : Value(); + OpOperand *opOperand = opOperands.front(); + if (condition(*opOperand)) + return value; + value = opOperand->get(); + } + // Value is a BlockArgument. Reached the end of the chain. + return returnLast ? value : Value(); +} + +/// Find the Value (result) of the last preceding write of a given Value. +/// +/// Note: Unknown ops are handled conservatively and assumed to be writes. +/// Furthermore, BlockArguments are also assumed to be writes. There is no +/// analysis across block boundaries. +static Value findLastPrecedingWrite(Value value) { + return findValueInReverseUseDefChain(value, bufferizesToMemoryWrite, true); +} + +/// Return true if `value` is originating from an ExtractSliceOp that matches +/// the given InsertSliceOp. +bool BufferizationAliasInfo::hasMatchingExtractSliceOp( + Value value, InsertSliceOp insertOp) const { + return static_cast( + findValueInReverseUseDefChain(value, [&](OpOperand &opOperand) { + if (auto extractOp = dyn_cast(opOperand.getOwner())) + if (areEquivalentExtractSliceOps(extractOp, insertOp)) + return true; + return false; + })); +} + +/// Given sets of uses and writes, return true if there is a RaW conflict under +/// the assumption that all given reads/writes alias the same buffer and that +/// all given writes bufferize inplace. +/// +/// A conflict is: According to SSA use-def chains, a read R is supposed to read +/// the result of a write W1. 
But because of bufferization decisions, R actually +/// reads another write W2. +bool BufferizationAliasInfo::hasReadAfterWriteInterference( + const DenseSet &usesRead, + const DenseSet &usesWrite, + const DominanceInfo &domInfo) const { + for (OpOperand *uRead : usesRead) { - Operation *aliasingReadOp = uRead->getOwner(); - LDBG("----++++aliasRead -> #" - << uRead->getOperandNumber() - << " in: " << printOperationInfo(aliasingReadOp) << '\n'); - for (OpOperand *uWrite : usesWrite) { - // The same operand may both read and write. - // Don't consider self-use of the same operand for interference. - // Multiple different uses within the same op is fair game though. - if (uWrite == uRead) + Operation *readingOp = uRead->getOwner(); + + // Find most recent write of uRead by following the SSA use-def chain. E.g.: + // + // %0 = "writing_op"(%t) : tensor -> tensor + // %1 = "aliasing_op"(%0) : tensor -> tensor + // %2 = "reading_op"(%1) : tensor -> not_a_tensor_type + // + // In the above example, if uRead is the OpOperand of reading_op, lastWrite + // is %0. Note that operations that create an alias but do not write (such + // as ExtractSliceOp) are skipped. + // TODO: With branches this should probably be a list of Values. + Value lastWrite = findLastPrecedingWrite(uRead->get()); + + // Look for conflicting memory writes. Potential conflicts are writes to an + // alias that have been decided to bufferize inplace. + for (OpOperand *uConflictingWrite : usesWrite) { + // Throughout this loop, check for multiple requirements that have to be + // met for uConflictingWrite to be an actual conflict. + Operation *conflictingWritingOp = uConflictingWrite->getOwner(); + + // Print some debug info. 
+ LDBG("Found potential conflict:\n"); + LDBG("READ = #" << uRead->getOperandNumber() << " of " + << printOperationInfo(readingOp) << "\n"); + LDBG("WRITE = #" << printValueInfo(lastWrite) << "\n"); + LDBG("CONFLICTING WRITE = #" + << uConflictingWrite->getOperandNumber() << " of " + << printOperationInfo(conflictingWritingOp) << "\n"); + + // No conflict if the readingOp dominates conflictingWritingOp, i.e., the + // write is not visible when reading. + if (domInfo.properlyDominates(readingOp, conflictingWritingOp)) continue; - Operation *aliasingWriteOp = uWrite->getOwner(); - LDBG("---- aliasWrite -> #" - << uWrite->getOperandNumber() - << " in: " << printOperationInfo(aliasingWriteOp) << '\n'); - // If the candidate write is the one that produces the read value (in the - // SSA def-use sense), this is not considered an interference. - if (getInplaceableOpResult(*uWrite) == uRead->get()) - continue; - // If aliasingReadOp properly dominates aliasingWriteOp, the read cannot - // be affected by the write: there is no interference. - if (domInfo.properlyDominates(aliasingReadOp, aliasingWriteOp)) + // No conflict if the conflicting write happens before the last write. + if (Operation *writingOp = lastWrite.getDefiningOp()) { + if (domInfo.properlyDominates(conflictingWritingOp, writingOp)) + // conflictingWritingOp happens before writingOp. No conflict. + continue; + } else { + auto bbArg = lastWrite.cast(); + Block *block = bbArg.getOwner(); + if (!block->findAncestorOpInBlock(*conflictingWritingOp)) + // conflictingWritingOp happens outside of the block. No + // conflict. + continue; + } + + // No conflict if the conflicting write and the last write are the same + // use. + if (getAliasingOpResult(*uConflictingWrite) == lastWrite) continue; - // At this point, aliasingWriteOp properly dominates aliasingReadOp or - // there is no clear dominance and we need to be conservative. 
- LDBG("---->found RaW interference between:\n"); - LDBG(" OpToBufferize -> " << printOperationInfo(opToBufferize) - << '\n'); - LDBG(" Interfering write -> #" - << uWrite->getOperandNumber() << ":" - << printOperationInfo(aliasingWriteOp) << '\n'); - LDBG(" Target read -> #" << uRead->getOperandNumber() << ":" - << printOperationInfo(aliasingReadOp) - << '\n'); - LDBG("---->opportunity to clobber RaW interference\n"); - if (isClobberedWriteBeforeRead(opToBufferize, *uRead, *uWrite, domInfo)) { - LDBG("---->clobbered! -> skip\n"); + + // No conflict if the same use is the read and the conflicting write. A + // use cannot conflict with itself. + if (uConflictingWrite == uRead) continue; + + // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If + // uRead is an InsertSliceOp... + if (auto insertSliceOp = dyn_cast(readingOp)) { + // As an example, consider the following IR. + // + // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace= [true] } + // %1 = linalg.fill %cst, %0 {inplace= [true] } + // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1] + // {inplace= [true] } + + // TODO: Use insertSliceOp.getDestOpOperand etc. when available. + if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ && + hasMatchingExtractSliceOp(uConflictingWrite->get(), insertSliceOp)) + // Case 1: The main insight is that InsertSliceOp reads only part of + // the destination tensor. The overwritten area is not read. If + // uConflictingWrite writes into exactly the memory location that is + // being read by uRead, this is not a conflict. + // + // In the above example: + // uRead = OpOperand 1 (%t) of tensor.insert_slice + // uConflictingWrite = OpOperand 1 (%0) of linalg.fill + // + // The read of %t does not conflict with the write of the FillOp + // (same aliases!) because the area that the FillOp operates on is + // exactly the one that is *not* read via %t. 
+ continue; + + if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ && + uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ && + hasMatchingExtractSliceOp(uRead->get(), insertSliceOp)) + // Case 2: The read of the source tensor and the write to the dest + // tensor via an InsertSliceOp is not a conflict if the read is + // reading exactly that part of an equivalent tensor that the + // InsertSliceOp is writing. + // + // In the above example: + // uRead = OpOperand 0 (%1) of tensor.insert_slice + // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice + continue; } - LDBG("---->not clobbered -> found an interference\n"); + + // All requirements are met. Conflict found! + LDBG("CONFLICT CONFIRMED!\n\n"); return true; } } - LDBG("----No interference found\n"); + + LDBG("NOT A CONFLICT!\n\n"); return false; } -/// Return true if it is possible to find an inplace write W among the uses of -/// aliasInfo[result], and a read R among the uses of aliasInfo[result], -/// such that W and R interfere. -/// Such a (W, R) pair is an interference to the inplace bufferization of -/// opResult when: -/// 1. R is not known to properly dominate W (i.e. the effects of the write -/// may be visible from R). -/// 2. one cannot find an intermediate clobbering write `C` to W, such that -/// C interleaved between W and R (i.e. W -> C -> R where -> denotes -/// dominance). +/// Return true if bufferizing result inplace would create a conflict. A read R +/// and a write W of the same alias set is a conflict if inplace bufferization +/// of W changes the value read by R to a value different from the one that +/// would be expected by tracing back R's origin through SSA use-def chains. +/// A conflict can only be introduced by a new alias and/or an inplace +/// bufferization decision. 
+/// +/// Example: +/// %0 = tensor.extract_slice %t[...][...][1, 1] {inplace?} +/// %1 = vector.transfer_write %v1, %t {inplace} : vector<5xf32>, tensor +/// %e = tensor.extract_slice %1 +/// %2 = vector.transfer_write %v2, %0 {inplace} : vector<6xf32>, tensor +/// %3 = vector.transfer_read %e, %cst : tensor, vector<7xf32> +/// +/// In the above example, the two TransferWriteOps have already been decided to +/// bufferize inplace. Bufferizing the ExtractSliceOp inplace would create a +/// conflict because: +/// * According to SSA use-def chains, we expect to read the result of %1. +/// * However, adding an alias {%0, %t} would mean that the second +/// TransferWriteOp overwrites the first one. Therefore, the TransferReadOp +/// would no longer be reading the result of %1. bool BufferizationAliasInfo::wouldCreateReadAfterWriteInterference( OpOperand &operand, OpResult result, const DominanceInfo &domInfo) const { #ifndef NDEBUG @@ -887,93 +1016,34 @@ "operand and result do not match"); #endif // NDEBUG - Operation *opToBufferize = result.getDefiningOp(); - Value opResult = result; - Value opOperand = operand.get(); - - LDBG("----Start wouldCreateReadAfterWriteInterference\n"); - LDBG("--------consider all aliases to root read: " - << printValueInfo(opOperand) << "\n"); - LDBG("--------consider all aliases to root write: " - << printValueInfo(opResult) << "\n"); - - /// Helper function to iterate on aliases of `root` and capture the reads. + // Helper function to iterate on aliases of `root` and capture the reads. auto getAliasingReads = [&](DenseSet &res, Value root) { - for (Value alias : getAliases(root)) { - for (auto &use : alias.getUses()) { + for (Value alias : getAliases(root)) + for (auto &use : alias.getUses()) // Read to a value that aliases root. 
- if (bufferizesToMemoryRead(use)) { - LDBG("------------bufferizesToMemoryRead: " - << use.getOwner()->getName().getStringRef() << "\n"); + if (bufferizesToMemoryRead(use)) res.insert(&use); - } - } - } }; - /// Helper function to iterate on aliases of `root` and capture the writes. + // Helper function to iterate on aliases of `root` and capture the writes. auto getAliasingInplaceWrites = [&](DenseSet &res, Value root) { - for (Value alias : getAliases(root)) { - for (auto &use : alias.getUses()) { + for (Value alias : getAliases(root)) + for (auto &use : alias.getUses()) // Inplace write to a value that aliases root. - if (isInplaceMemoryWrite(use)) { - LDBG("------------bufferizesToMemoryWrite: " - << use.getOwner()->getName().getStringRef() << "\n"); + if (isInplaceMemoryWrite(use)) res.insert(&use); - } - } - } }; - // Check if we can find any interference between reads to aliases[`opOperand`] - // and writes to aliases[`opResult`]. This handles the case: - // - // ``` - // %0 = op_to_bufferize_maybe_inplace(%1) - // %2 = some_alias(%0) - // inplace_write(%2) - // %3 = some_alias(%1) - // read(%3) - // ``` + // Collect reads and writes of all aliases of OpOperand and OpResult. DenseSet usesRead, usesWrite; - LDBG("--------\n"); - LDBG("--------Test reads(opOperand) vs writes(opResult)\n"); - getAliasingReads(usesRead, opOperand); - getAliasingInplaceWrites(usesWrite, opResult); - // Additionally, `result` is not yet bufferized and we need to check for - // interferences as if it were bufferized inplace: add `operand` if it is a - // write. 
This handles the case: - // - // ``` - // %0 = op_to_bufferize_maybe_inplace(%1) - // %2 = some_alias(%1) - // read(%2) - // ``` + getAliasingReads(usesRead, operand.get()); + getAliasingReads(usesRead, result); + getAliasingInplaceWrites(usesWrite, operand.get()); + getAliasingInplaceWrites(usesWrite, result); if (bufferizesToMemoryWrite(operand)) usesWrite.insert(&operand); - if (wouldCreateReadAfterWriteInterference(opToBufferize, usesRead, usesWrite, - domInfo)) - return true; - // Check if we can find any interference between writes to - // aliases[`opOperand`] and reads to aliases[`opResult`]. This handles the - // case: - // - // ``` - // %0 = op_to_bufferize_maybe_inplace(%1) - // %2 = some_alias(%1) - // inplace_write(%2) - // %3 = some_alias(%0) - // read(%3) - // ``` - LDBG("--------\n"); - LDBG("--------Test reads(opResult) vs writes(opOperand)\n"); - usesRead.clear(); - usesWrite.clear(); - getAliasingReads(usesRead, opResult); - getAliasingInplaceWrites(usesWrite, opOperand); - return wouldCreateReadAfterWriteInterference(opToBufferize, usesRead, - usesWrite, domInfo); + return hasReadAfterWriteInterference(usesRead, usesWrite, domInfo); } /// Return true if bufferizing `opOperand` inplace with `opResult` would create @@ -1105,125 +1175,12 @@ return false; if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue)) return false; + // TODO: Is the following needed? if (!equivalentInfo.isEquivalent(st.result(), sti.source())) return false; return true; } -/// Return true if there is a `candidateOp` that would write to memory after -/// bufferization and such that: -/// 1. The written buffer is equivalent to either `aliasingRead` or -/// `aliasingWrite` under the inPlace bufferization decisions taken -/// so far. -/// 2. `aliasingWrite` properly dominates `candidateOp`. -/// 3. `candidateOp` properly dominates `aliasingReadOp`. -// TODO: richer clobbering analysis with container-containee relationship -// instead of equivalence. 
-bool BufferizationAliasInfo::existsInterleavedValueClobber( - OpOperand &aliasingRead, OpOperand &aliasingWrite, - const DominanceInfo &domInfo) const { - Operation *aliasingReadOp = aliasingRead.getOwner(); - Operation *aliasingWriteOp = aliasingWrite.getOwner(); - assert(!domInfo.properlyDominates(aliasingReadOp, aliasingWriteOp) && - "Unexpected aliasingReadOp properly dominates aliasingWriteOp"); - - for (Value valueToClobber : {aliasingRead.get(), aliasingWrite.get()}) { - auto leaderIt = equivalentInfo.findLeader(valueToClobber); - for (auto mit = leaderIt, meit = equivalentInfo.member_end(); mit != meit; - ++mit) { - Operation *candidateOp = mit->v.getDefiningOp(); - if (!candidateOp) - continue; - SmallVector operands = - getAliasingOpOperand(mit->v.cast()); - assert(operands.size() <= 1 && "more than 1 OpOperand not supported yet"); - // TODO: Should we check for isInplaceMemoryWrite instead? - if (operands.empty() || !bufferizesToMemoryWrite(*operands.front())) - continue; - LDBG("---->clobbering candidate: " << printOperationInfo(candidateOp) - << '\n'); - if (domInfo.properlyDominates(aliasingWriteOp, candidateOp) && - domInfo.properlyDominates(candidateOp, aliasingReadOp)) - return true; - } - } - return false; -} - -/// Return true if there is a write that: -/// 1. Properly dominates aliasingReadOp. -/// 2. Is properly dominated by aliasingWriteOp. -/// 3. Clobbers the write that would be interfering with the read. 
-/// -bool BufferizationAliasInfo::isClobberedWriteBeforeRead( - Operation *opToBufferize, OpOperand &aliasingRead, OpOperand &aliasingWrite, - const DominanceInfo &domInfo) const { - Operation *aliasingReadOp = aliasingRead.getOwner(); - Operation *aliasingWriteOp = aliasingWrite.getOwner(); - assert(!domInfo.properlyDominates(aliasingReadOp, aliasingWriteOp) && - "Unexpected aliasingReadOp properly dominates aliasingWriteOp"); - - // Bail if the write does not dominate the read: it may clobber but only on - // a strict subset of paths, which is not enough for safety. - if (!domInfo.dominates(aliasingWriteOp, aliasingReadOp)) { - LDBG("---->no clobbering: write does not dominate read\n"); - return false; - } - - // The case `opToBufferize` isa ExtractSliceOp is important enough that we - // look for it specifically. The key information to discover is whether the - // aliasing read or write come from a matching InsertSliceOp. - // Such a pattern is introduced by tiling and is the key inplace condition - // not to miss. - if (auto extractSliceOp = dyn_cast(opToBufferize)) { - if (auto insertSliceOp = dyn_cast(aliasingReadOp)) { - // %1 = extract_slice %0[%offset_sizes_and_strides_1] - // - // ... // 0 or more of inplace compute that reduces to: %X is an - // // aliasingWrite equivalent to %1. - // %W = inplace_write(%1) - // - // // aliasingRead %Y in insert_slice - // ... = insert_slice %W into %R[%offset_sizes_and_strides_1] - if (aliasingRead.get() == insertSliceOp.dest() && - // TODO: This is currently too restrictive and misses clobberings. - // When available, use container-containee analysis: the condition - // should be that the `aliasingWrite` is contained within - // `insertSliceOp.source()`. 
- equivalentInfo.isEquivalent(aliasingWrite.get(), - insertSliceOp.source()) && - areEquivalentExtractSliceOps(extractSliceOp, insertSliceOp)) { - LDBG("---->clobbering matching extract_slice/insert_slice\n"); - return true; - } - // %1 = extract_slice %0[%offset_sizes_and_strides_1] - // - // ... // bunch of inplace ops that reduce to %X, equivalent to %1. - // %X = inplace_write(%1) - // - // // aliasingRead %X in insert_slice - // // aliasingWrite %Y in insert_slice - // ... = insert_slice %X into %Y[%offset_sizes_and_strides_1] - if (aliasingReadOp == aliasingWriteOp) { - assert(aliasingRead.get() == insertSliceOp.source() && - "expected read to source of insert_slice"); - assert(aliasingWrite.get() == insertSliceOp.dest() && - "expected write to dest of insert_slice"); - if (areEquivalentExtractSliceOps(extractSliceOp, insertSliceOp)) { - LDBG("---->clobbering matching extract_slice/insert_slice\n"); - return true; - } - } - } - } - - // General case: look for a properly interleaved clobber of either exactly - // `aliasingRead` or `aliasingWrite`. - // TODO: Relax this to inclusion instead of double inclusion (a.k.a - // equivalence). We will need to compute container-containee relationship. - return existsInterleavedValueClobber(aliasingRead, aliasingWrite, domInfo); -} - //===----------------------------------------------------------------------===// // Forward declarations. //===----------------------------------------------------------------------===// @@ -2030,6 +1987,7 @@ // - The result is not inplace. This is the case where the whole tensor is // cloned and the clone needs to be updated. auto inPlace = getInPlace(insertSliceOp->getResult(0)); + // TODO: Is this necessary? 
if (!aliasInfo.isSourceEquivalentToAMatchingInplaceExtractSliceOp( insertSliceOp) || inPlace != InPlaceSpec::True) { diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir --- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir +++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir @@ -364,11 +364,7 @@ // CHECK-NEXT: tensor.extract_slice // CHECK-SAME: {__inplace_results_attr__ = ["true"]} // CHECK-NEXT: tensor.extract_slice - // Atm, this 2nd tensor.extract_slice fails to bufferize inplace because - // clobbering analysis conservatively test for equivalent buffers. - // TODO: This is currently too restrictive and misses clobberings. - // When available, use container-containee analysis. - // CHECK-SAME: {__inplace_results_attr__ = ["false"]} + // CHECK-SAME: {__inplace_results_attr__ = ["true"]} // CHECK-NEXT: tensor.extract_slice // CHECK-SAME: {__inplace_results_attr__ = ["true"]} // CHECK-NEXT: fill @@ -744,9 +740,7 @@ %0 = linalg.fill(%cst, %arg2) : f32, tensor<62x90xf32> -> tensor<62x90xf32> // CHECK: tensor.extract_slice - // CHECK-SAME: {__inplace_results_attr__ = ["false"] - // TODO: in order to have this extract_slice bufferize inplace, we need to write a range - // analysis and determine that intersection([0, 32)x[0, 90), [32, 62)x[0, 90)) is empty. 
+ // CHECK-SAME: {__inplace_results_attr__ = ["true"] %2 = tensor.extract_slice %0[0, 0] [32, 90] [1, 1] : tensor<62x90xf32> to tensor<32x90xf32> // CHECK: vector.transfer_write // CHECK-SAME: {__inplace_results_attr__ = ["true"] @@ -793,3 +787,128 @@ return %r : tensor<10x20xf32> } +// ----- + +#accesses = [ + affine_map<(i) -> (i)>, + affine_map<(i) -> (i)>, + affine_map<(i) -> (i)> +] +#trait = { + indexing_maps = #accesses, + iterator_types = ["parallel"] +} + +// CHECK-LABEL: func @linalg_op_same_out_tensors +func @linalg_op_same_out_tensors( + %t1: tensor {linalg.inplaceable = true}, + %t2: tensor {linalg.inplaceable = true}) -> (tensor, tensor){ + + // CHECK: linalg.generic + // CHECK-SAME: {__inplace_results_attr__ = ["true", "false"] + %o:2 = linalg.generic #trait ins(%t1 : tensor) + outs (%t2, %t2 : tensor, tensor) { + ^bb(%0: f32, %1: f32, %2 : f32) : + linalg.yield %0, %0 : f32, f32 + } -> (tensor, tensor) + return %o#0, %o#1 : tensor, tensor +} + +// ----- + +// CHECK-LABEL: func @double_insert_slice_into_alias +func @double_insert_slice_into_alias( + %v1: vector<32x90xf32>, + %v2: vector<30x90xf32>, + %arg2: tensor<62x90xf32> {linalg.inplaceable = true}, + %s1: index, %s2: index, %s3: index, %s4: index) + -> (tensor<62x90xf32>, tensor) +{ + %c0 = arith.constant 0 : index + + // Cannot bufferize inplace this extract_slice because both operand and result + // are modified and returned separately. 
+ // CHECK: tensor.extract_slice + // CHECK-SAME: {__inplace_results_attr__ = ["false"] + %e = tensor.extract_slice %arg2[%s1, %s2][%s3, %s4][1, 1] : tensor<62x90xf32> to tensor + + // CHECK: tensor.extract_slice + // CHECK-SAME: {__inplace_results_attr__ = ["true"] + %2 = tensor.extract_slice %arg2[0, 0] [32, 90] [1, 1] : tensor<62x90xf32> to tensor<32x90xf32> + // CHECK: vector.transfer_write + // CHECK-SAME: {__inplace_results_attr__ = ["true"] + %7 = vector.transfer_write %v1, %2[%c0, %c0] {in_bounds = [true, true]} : vector<32x90xf32>, tensor<32x90xf32> + // CHECK: tensor.insert_slice + // CHECK-SAME: {__inplace_results_attr__ = ["true"] + %8 = tensor.insert_slice %7 into %arg2[0, 0] [32, 90] [1, 1] : tensor<32x90xf32> into tensor<62x90xf32> + + // CHECK: tensor.extract_slice + // CHECK-SAME: {__inplace_results_attr__ = ["true"] + %10 = tensor.extract_slice %e[32, 0] [30, 90] [1, 1] : tensor to tensor<30x90xf32> + // CHECK: vector.transfer_write + // CHECK-SAME: {__inplace_results_attr__ = ["true"] + %14 = vector.transfer_write %v2, %10[%c0, %c0] {in_bounds = [true, true]} : vector<30x90xf32>, tensor<30x90xf32> + // CHECK: tensor.insert_slice + // CHECK-SAME: {__inplace_results_attr__ = ["true"] + %15 = tensor.insert_slice %14 into %e[32, 0] [30, 90] [1, 1] : tensor<30x90xf32> into tensor + + return %8, %15 : tensor<62x90xf32>, tensor +} + +// ----- + +// CHECK-LABEL: func @interleaved_extract_insert_slice_chain_1 +func @interleaved_extract_insert_slice_chain_1( + %arg2: tensor<62x90xf32> {linalg.inplaceable = true}) + -> (tensor<62x90xf32>) +{ + // CHECK: tensor.extract_slice + // CHECK-SAME: {__inplace_results_attr__ = ["true"] + %2 = tensor.extract_slice %arg2[0, 0] [32, 90] [1, 1] : tensor<62x90xf32> to tensor<32x90xf32> + + // TODO: This should bufferize inplace once we have a proper range analysis. 
+ // CHECK: tensor.extract_slice + // CHECK-SAME: {__inplace_results_attr__ = ["false"] + %10 = tensor.extract_slice %arg2[32, 0] [30, 90] [1, 1] : tensor<62x90xf32> to tensor<30x90xf32> + + + // CHECK: tensor.insert_slice + // CHECK-SAME: {__inplace_results_attr__ = ["true"] + %8 = tensor.insert_slice %2 into %arg2[0, 0] [32, 90] [1, 1] : tensor<32x90xf32> into tensor<62x90xf32> + + + // CHECK: tensor.insert_slice + // CHECK-SAME: {__inplace_results_attr__ = ["true"] + %15 = tensor.insert_slice %10 into %8[32, 0] [30, 90] [1, 1] : tensor<30x90xf32> into tensor<62x90xf32> + + return %15 : tensor<62x90xf32> +} + +// ----- + +// CHECK-LABEL: func @interleaved_extract_insert_slice_chain_2 +func @interleaved_extract_insert_slice_chain_2( + %arg2: tensor<62x90xf32> {linalg.inplaceable = true}) + -> (tensor<62x90xf32>) +{ + // CHECK: tensor.extract_slice + // CHECK-SAME: {__inplace_results_attr__ = ["true"] + %2 = tensor.extract_slice %arg2[0, 0] [32, 90] [1, 1] : tensor<62x90xf32> to tensor<32x90xf32> + + // The slices are overlapping, so this can never bufferize inplace. + // CHECK: tensor.extract_slice + // CHECK-SAME: {__inplace_results_attr__ = ["false"] + %10 = tensor.extract_slice %arg2[31, 0] [30, 90] [1, 1] : tensor<62x90xf32> to tensor<30x90xf32> + + + // CHECK: tensor.insert_slice + // CHECK-SAME: {__inplace_results_attr__ = ["true"] + %8 = tensor.insert_slice %2 into %arg2[0, 0] [32, 90] [1, 1] : tensor<32x90xf32> into tensor<62x90xf32> + + + // CHECK: tensor.insert_slice + // CHECK-SAME: {__inplace_results_attr__ = ["true"] + %15 = tensor.insert_slice %10 into %8[31, 0] [30, 90] [1, 1] : tensor<30x90xf32> into tensor<62x90xf32> + + return %15 : tensor<62x90xf32> +}