[mlir][bufferization] Add debug output to One-Shot Analysis

Adds LLVM_DEBUG tracing to OneShotAnalysis.cpp under
DEBUG_TYPE "one-shot-analysis". Enable it with
`mlir-opt -debug-only="one-shot-analysis"`; LLVM_DEBUG output is only
available in builds with assertions enabled. Traced events:
  * each operand analyzed by bufferizableInPlaceAnalysisImpl, framed by
    //===---===// separator lines;
  * the useDominance decision and every (uRead, uConflictingWrite) pair
    examined by the read-after-write check, together with the rule that
    ruled the conflict out, or "=> RaW CONFLICT FOUND";
  * "=> NOT WRITABLE" when hasPrecedingAliasingNonWritableTensor reports a
    write to a non-writable tensor.

Single-statement `if`/`for` bodies gain braces so that each early
`continue`/`return` can be preceded by its debug message. Apart from the
debug output, control flow is unchanged.

diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
@@ -57,6 +57,10 @@
 #include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/SetVector.h"
 
+// Run mlir-opt with `-debug-only="one-shot-analysis"` for detailed debug
+// output.
+#define DEBUG_TYPE "one-shot-analysis"
+
 using namespace mlir;
 using namespace mlir::bufferization;
 
@@ -553,6 +557,7 @@
 
   // Check if op dominance can be used to rule out read-after-write conflicts.
   bool useDominance = canUseOpDominance(usesRead, usesWrite, state);
+  LLVM_DEBUG(llvm::dbgs() << "\n- useDominance = " << useDominance << "\n");
 
   for (OpOperand *uRead : usesRead) {
     Operation *readingOp = uRead->getOwner();
@@ -572,6 +577,14 @@
     // Look for conflicting memory writes. Potential conflicts are writes to an
     // alias that have been decided to bufferize inplace.
     for (OpOperand *uConflictingWrite : usesWrite) {
+      LLVM_DEBUG(llvm::dbgs() << "\n- check conflict:\n");
+      LLVM_DEBUG(llvm::dbgs()
+                 << " uRead = operand " << uRead->getOperandNumber() << " of "
+                 << *uRead->getOwner() << "\n");
+      LLVM_DEBUG(llvm::dbgs() << " uConflictingWrite = operand "
+                              << uConflictingWrite->getOperandNumber() << " of "
+                              << *uConflictingWrite->getOwner() << "\n");
+
       // Throughout this loop, check for multiple requirements that have to be
       // met for uConflictingWrite to be an actual conflict.
       Operation *conflictingWritingOp = uConflictingWrite->getOwner();
@@ -585,8 +598,11 @@
         // Note: If ops are executed multiple times (e.g., because they are
         //       inside a loop), there may be no meaningful `happensBefore`
         //       relationship.
-        if (happensBefore(readingOp, conflictingWritingOp, domInfo))
+        if (happensBefore(readingOp, conflictingWritingOp, domInfo)) {
+          LLVM_DEBUG(llvm::dbgs()
+                     << " no conflict: read happens before write\n");
           continue;
+        }
 
         // No conflict if the reading use equals the use of the conflicting
         // write. A use cannot conflict with itself.
@@ -595,61 +611,93 @@
         //       use.
         // Note: If the op is executed multiple times (e.g., because it is
         //       inside a loop), it may be conflicting with itself.
-        if (uConflictingWrite == uRead)
+        if (uConflictingWrite == uRead) {
+          LLVM_DEBUG(llvm::dbgs()
+                     << " no conflict: read and write are same use\n");
           continue;
+        }
 
         // Ops are not conflicting if they are in mutually exclusive regions.
         //
         // Note: If ops are executed multiple times (e.g., because they are
         //       inside a loop), mutually exclusive regions may be executed
         //       multiple times.
-        if (insideMutuallyExclusiveRegions(readingOp, conflictingWritingOp))
+        if (insideMutuallyExclusiveRegions(readingOp, conflictingWritingOp)) {
+          LLVM_DEBUG(llvm::dbgs() << " no conflict: read and write are in "
+                                     "mutually exclusive regions\n");
           continue;
+        }
       }
 
       // No conflict if the op interface says so.
-      if (auto bufferizableOp = options.dynCastBufferizableOp(readingOp))
-        if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite, state))
+      if (auto bufferizableOp = options.dynCastBufferizableOp(readingOp)) {
+        if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite, state)) {
+          LLVM_DEBUG(llvm::dbgs()
+                     << " no conflict: op interface of reading op says 'no'\n");
           continue;
+        }
+      }
 
-      if (conflictingWritingOp != readingOp)
+      if (conflictingWritingOp != readingOp) {
         if (auto bufferizableOp =
-                options.dynCastBufferizableOp(conflictingWritingOp))
-          if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite, state))
+                options.dynCastBufferizableOp(conflictingWritingOp)) {
+          if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite,
+                                              state)) {
+            LLVM_DEBUG(
+                llvm::dbgs()
+                << " no conflict: op interface of writing op says 'no'\n");
             continue;
+          }
+        }
+      }
 
       // Check all possible last writes.
       for (Value lastWrite : lastWrites) {
+        LLVM_DEBUG(llvm::dbgs() << " * lastWrite = " << lastWrite << "\n");
+
         // No conflict if the conflicting write happens before the last
         // write.
         if (Operation *writingOp = lastWrite.getDefiningOp()) {
-          if (happensBefore(conflictingWritingOp, writingOp, domInfo))
+          if (happensBefore(conflictingWritingOp, writingOp, domInfo)) {
             // conflictingWritingOp happens before writingOp. No conflict.
+            LLVM_DEBUG(llvm::dbgs()
+                       << " no conflict: write happens before last write\n");
             continue;
+          }
           // No conflict if conflictingWritingOp is contained in writingOp.
-          if (writingOp->isProperAncestor(conflictingWritingOp))
+          if (writingOp->isProperAncestor(conflictingWritingOp)) {
+            LLVM_DEBUG(
+                llvm::dbgs()
+                << " no conflict: write is contained in last write\n");
             continue;
+          }
         } else {
           auto bbArg = lastWrite.cast<BlockArgument>();
           Block *block = bbArg.getOwner();
-          if (!block->findAncestorOpInBlock(*conflictingWritingOp))
+          if (!block->findAncestorOpInBlock(*conflictingWritingOp)) {
+            LLVM_DEBUG(llvm::dbgs() << " no conflict: last write is bbArg "
+                                       "and write happens outside of block\n");
             // conflictingWritingOp happens outside of the block. No
             // conflict.
             continue;
+          }
         }
 
         // No conflict if the conflicting write and the last write are the same
         // use.
         SmallVector<OpResult> aliasingOpResult =
             state.getAliasingOpResult(*uConflictingWrite);
-        if (aliasingOpResult.size() == 1 && aliasingOpResult[0] == lastWrite)
+        if (aliasingOpResult.size() == 1 && aliasingOpResult[0] == lastWrite) {
+          LLVM_DEBUG(llvm::dbgs()
+                     << " no conflict: last write and write are same\n");
           continue;
+        }
 
         // All requirements are met. Conflict found!
 
         if (options.printConflicts)
           annotateConflict(uRead, uConflictingWrite, lastWrite);
-
+        LLVM_DEBUG(llvm::dbgs() << " => RaW CONFLICT FOUND\n");
         return true;
       }
     }
@@ -803,10 +851,13 @@
   // Assuming that `operand` bufferizes in-place: For each write (to each
   // alias), check if there is a non-writable tensor in the reverse SSA use-def
   // chain.
-  for (OpOperand *uWrite : usesWrite)
+  for (OpOperand *uWrite : usesWrite) {
     if (hasPrecedingAliasingNonWritableTensor(uWrite->get(), &operand,
-                                              aliasInfo, state))
+                                              aliasInfo, state)) {
+      LLVM_DEBUG(llvm::dbgs() << "=> NOT WRITABLE\n");
       return true;
+    }
+  }
 
   return false;
 }
@@ -819,6 +870,11 @@
 static LogicalResult bufferizableInPlaceAnalysisImpl(
     OpOperand &operand, BufferizationAliasInfo &aliasInfo,
     OneShotAnalysisState &state, const DominanceInfo &domInfo) {
+  LLVM_DEBUG(
+      llvm::dbgs() << "//===-------------------------------------------===//\n"
+                   << "Analyzing operand #" << operand.getOperandNumber()
+                   << " of " << *operand.getOwner() << "\n");
+
   bool foundInterference =
       wouldCreateWriteToNonWritableBuffer(operand, aliasInfo, state) ||
       wouldCreateReadAfterWriteInterference(operand, domInfo, state, aliasInfo);
@@ -828,6 +884,8 @@
   else
     aliasInfo.bufferizeInPlace(operand, state);
 
+  LLVM_DEBUG(llvm::dbgs()
+             << "//===-------------------------------------------===//\n");
   return success();
 }
 