diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h @@ -355,7 +355,8 @@ /// traversed any further. /// /// When reaching the end of a chain (BlockArgument or Value without aliasing - /// OpOperands), also return the last Value of that chain. + /// OpOperands), also return the last Value of that chain if + /// `alwaysIncludeLeaves` is set. /// /// Example: /// @@ -374,20 +375,41 @@ /// { 2, 7, 8, 5 } /// /// If `followEquivalentOnly` is set, only equivalent OpOperands are selected. - SetVector - findValueInReverseUseDefChain(Value value, - llvm::function_ref condition, - bool followEquivalentOnly = false) const; - - /// Find the Values of the last preceding write of a given Value. + SetVector findValueInReverseUseDefChain( + Value value, llvm::function_ref condition, + bool followEquivalentOnly = false, bool alwaysIncludeLeaves = true) const; + + /// Find the values that may define the contents of the given value at + /// runtime. A block argument is always a definition. An OpResult is a + /// definition if it bufferizes to a memory write. If it does not bufferize to + /// a memory write but has aliasing operands, we continue the lookup on these + /// values. + /// + /// Example: %r = tensor.insert %f into %t[%c0] : tensor<?xf32> + /// findDefinitions(%r) = {%r} because %r bufferizes to a memory write. + /// + /// Example: %r = tensor.empty() : tensor<10xf32> + /// findDefinitions(%r) = {} because tensor.empty does not define the + /// contents of its result (i.e., it does not bufferize to a memory write) + /// and it has no aliasing OpOperands. + /// + /// Example: + /// %a = arith.constant ... 
: tensor<10xf32> + /// %b1 = tensor.insert %f into %t[%c0] : tensor<50xf32> + /// %b2 = tensor.extract_slice %b1[0][10][1] : tensor<50xf32> tensor<10xf32> + /// %r = arith.select %cond, %a, %b2 : tensor<10xf32> + /// findDefinitions(%r) = {%a, %b1}. %r and %b2 are skipped (lookup continues + /// in the operands) because their defining ops do not define the contents of + /// the tensor. /// - /// Note: Unknown ops are handled conservatively and assumed to be writes. - /// Furthermore, BlockArguments are also assumed to be writes. There is no - /// analysis across block boundaries. + /// Note: OpResults of unknown ops are handled conservatively and assumed to + /// be definitions. /// /// Note: When reaching an end of the reverse SSA use-def chain, that value - /// is returned regardless of whether it is a memory write or not. - SetVector findLastPrecedingWrite(Value value) const; + /// is included regardless of whether it is a definition, unless + /// `alwaysIncludeLeaves` is unset. + SetVector findDefinitions(Value value, + bool alwaysIncludeLeaves = true) const; /// Return `true` if the given OpResult has been decided to bufferize inplace. virtual bool isInPlace(OpOperand &opOperand) const; diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp --- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp @@ -444,7 +444,7 @@ // further. 
llvm::SetVector AnalysisState::findValueInReverseUseDefChain( Value value, llvm::function_ref condition, - bool followEquivalentOnly) const { + bool followEquivalentOnly, bool alwaysIncludeLeaves) const { llvm::SetVector result, workingSet; workingSet.insert(value); @@ -469,7 +469,8 @@ (followEquivalentOnly && bufferizableOp.bufferRelation(opResult, *this) != BufferRelation::Equivalent)) { - result.insert(value); + if (alwaysIncludeLeaves) + result.insert(value); continue; } @@ -480,11 +481,12 @@ return result; } -// Find the Values of the last preceding write of a given Value. +// Find the values that define the contents of the given value. llvm::SetVector -AnalysisState::findLastPrecedingWrite(Value value) const { +AnalysisState::findDefinitions(Value value, bool alwaysIncludeLeaves) const { return findValueInReverseUseDefChain( - value, [&](Value v) { return this->bufferizesToMemoryWrite(v); }); + value, [&](Value v) { return this->bufferizesToMemoryWrite(v); }, + /*followEquivalentOnly=*/false, alwaysIncludeLeaves); } AnalysisState::AnalysisState(const BufferizationOptions &options) diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp @@ -270,16 +270,9 @@ if (!opResult.getType().isa()) continue; - // If there is no preceding memory write, the tensor contents are + // If there is no preceding definition, the tensor contents are // undefined. - // Note: If `findLastPrecedingWrite` reaches the end of the reverse SSA - // use-def chain, it returns that value, regardless of whether it is a - // memory write or not. 
- SetVector lastWrites = findLastPrecedingWrite(opResult); - bool isUndefined = llvm::none_of(lastWrites, [&](Value lastWrite) { - return this->bufferizesToMemoryWrite(lastWrite); - }); - if (isUndefined) + if (findDefinitions(opResult, /*alwaysIncludeLeaves=*/false).empty()) for (OpOperand &use : opResult.getUses()) undefinedTensorUses.insert(&use); } @@ -471,7 +464,7 @@ /// Annotate IR with details about the detected RaW conflict. static void annotateConflict(OpOperand *uRead, OpOperand *uConflictingWrite, - Value lastWrite) { + Value definition) { static uint64_t counter = 0; Operation *readingOp = uRead->getOwner(); Operation *conflictingWritingOp = uConflictingWrite->getOwner(); @@ -489,16 +482,15 @@ id + "[READ: " + std::to_string(uRead->getOperandNumber()) + "]"; readingOp->setAttr(readAttr, b.getUnitAttr()); - if (auto opResult = lastWrite.dyn_cast()) { - std::string lastWriteAttr = id + "[LAST-WRITE: result " + - std::to_string(opResult.getResultNumber()) + - "]"; - opResult.getDefiningOp()->setAttr(lastWriteAttr, b.getUnitAttr()); + if (auto opResult = definition.dyn_cast()) { + std::string defAttr = + id + "[DEF: result " + std::to_string(opResult.getResultNumber()) + "]"; + opResult.getDefiningOp()->setAttr(defAttr, b.getUnitAttr()); } else { - auto bbArg = lastWrite.cast(); - std::string lastWriteAttr = - id + "[LAST-WRITE: bbArg " + std::to_string(bbArg.getArgNumber()) + "]"; - bbArg.getOwner()->getParentOp()->setAttr(lastWriteAttr, b.getUnitAttr()); + auto bbArg = definition.cast(); + std::string defAttr = + id + "[DEF: bbArg " + std::to_string(bbArg.getArgNumber()) + "]"; + bbArg.getOwner()->getParentOp()->setAttr(defAttr, b.getUnitAttr()); } } @@ -507,8 +499,8 @@ /// all given writes bufferize inplace. /// /// A conflict is: According to SSA use-def chains, a read R is supposed to read -/// the result of a write W1. But because of bufferization decisions, R actually -/// reads another write W2. +/// the result of a definition W1. 
But because of bufferization decisions, R +/// actually reads another definition W2. static bool hasReadAfterWriteInterference( const DenseSet &usesRead, const DenseSet &usesWrite, const DominanceInfo &domInfo, @@ -529,10 +521,10 @@ // %1 = "aliasing_op"(%0) : tensor -> tensor // %2 = "reading_op"(%1) : : tensor -> not_a_tensor_type // - // In the above example, if uRead is the OpOperand of reading_op, lastWrite - // is %0. Note that operations that create an alias but do not write (such - // as ExtractSliceOp) are skipped. - SetVector lastWrites = state.findLastPrecedingWrite(uRead->get()); + // In the above example, if uRead is the OpOperand of reading_op, the + // definition is %0. Note that operations that create an alias but do not + // bufferize to a memory write (such as ExtractSliceOp) are skipped. + SetVector definitions = state.findDefinitions(uRead->get()); // Look for conflicting memory writes. Potential conflicts are writes to an // alias that have been decided to bufferize inplace. @@ -611,31 +603,30 @@ } } - // Check all possible last writes. - for (Value lastWrite : lastWrites) { - LLVM_DEBUG(llvm::dbgs() << " * lastWrite = " << lastWrite << "\n"); + // Check all possible definitions. + for (Value definition : definitions) { + LLVM_DEBUG(llvm::dbgs() << " * definition = " << definition << "\n"); - // No conflict if the conflicting write happens before the last - // write. - if (Operation *writingOp = lastWrite.getDefiningOp()) { + // No conflict if the conflicting write happens before the definition. + if (Operation *writingOp = definition.getDefiningOp()) { if (happensBefore(conflictingWritingOp, writingOp, domInfo)) { // conflictingWritingOp happens before writingOp. No conflict. LLVM_DEBUG(llvm::dbgs() - << " no conflict: write happens before last write\n"); + << " no conflict: write happens before definition\n"); continue; } // No conflict if conflictingWritingOp is contained in writingOp. 
if (writingOp->isProperAncestor(conflictingWritingOp)) { LLVM_DEBUG( llvm::dbgs() - << " no conflict: write is contained in last write\n"); + << " no conflict: write is contained in definition\n"); continue; } } else { - auto bbArg = lastWrite.cast(); + auto bbArg = definition.cast(); Block *block = bbArg.getOwner(); if (!block->findAncestorOpInBlock(*conflictingWritingOp)) { - LLVM_DEBUG(llvm::dbgs() << " no conflict: last write is bbArg " + LLVM_DEBUG(llvm::dbgs() << " no conflict: definition is bbArg " "and write happens outside of block\n"); // conflictingWritingOp happens outside of the block. No // conflict. @@ -643,20 +634,20 @@ } } - // No conflict if the conflicting write and the last write are the same + // No conflict if the conflicting write and the definition are the same // use. SmallVector aliasingOpResult = state.getAliasingOpResult(*uConflictingWrite); - if (aliasingOpResult.size() == 1 && aliasingOpResult[0] == lastWrite) { + if (aliasingOpResult.size() == 1 && aliasingOpResult[0] == definition) { LLVM_DEBUG(llvm::dbgs() - << " no conflict: last write and write are same\n"); + << " no conflict: definition and write are same\n"); continue; } // All requirements are met. Conflict found! if (options.printConflicts) - annotateConflict(uRead, uConflictingWrite, lastWrite); + annotateConflict(uRead, uConflictingWrite, definition); LLVM_DEBUG(llvm::dbgs() << " => RaW CONFLICT FOUND\n"); return true; } @@ -734,8 +725,8 @@ /// conflict because: /// * According to SSA use-def chains, we expect to read the result of %1. /// * However, adding an alias {%0, %t} would mean that the second -/// TransferWriteOp overwrites the first one. Therefore, the TransferReadOp -/// would no longer be reading the result of %1. +/// TransferWriteOp overwrites the result of the first one. Therefore, the +/// TransferReadOp would no longer be reading the result of %1. 
/// /// If `checkConsistencyOnly` is true, this function checks if there is a /// read-after-write conflict without bufferizing `operand` inplace. This would diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp @@ -712,7 +712,7 @@ // In the above example: // uRead = OpOperand 0 (%1) of vector.transfer_read // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice - // lastWrite = %1 + // definition = %1 // // This is not a conflict because the InsertSliceOp overwrites the // memory segment of %1 with the exact same data. (Effectively, there