diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp @@ -724,62 +724,42 @@ } } -/// Check the reverse SSA use-def chain (following aliasing OpOperands) for -/// non-writable tensor values. Stop searching when an out-of-place bufferized -/// OpOperand was found (or when the OpOperand was not bufferized yet). -/// `currentOpOperand` is assumed to be in-place, even if that decision was not -/// materialized in `aliasInfo` yet. -static bool -hasPrecedingAliasingNonWritableTensor(Value value, OpOperand *currentOpOperand, - const OneShotAnalysisState &state) { - SmallVector worklist; - worklist.push_back(value); - while (!worklist.empty()) { - Value nextVal = worklist.pop_back_val(); - if (!state.isWritable(nextVal)) { - if (state.getOptions().printConflicts) - annotateNonWritableTensor(nextVal); - return true; - } - - // If `nextVal` is not a BlockArgument: End of use-def chain reached. - auto opResult = nextVal.dyn_cast(); - if (!opResult) - continue; - - // Follow reverse SSA use-def chain. - AliasingOpOperandList aliasingOpOperands = - state.getAliasingOpOperands(opResult); - for (OpOperand *opOperand : aliasingOpOperands) - if (state.isInPlace(*opOperand) || currentOpOperand == opOperand) - worklist.push_back(opOperand->get()); - } - return false; -} - /// Return true if bufferizing `operand` inplace would create a write to a /// non-writable buffer. static bool wouldCreateWriteToNonWritableBuffer(OpOperand &operand, OneShotAnalysisState &state, bool checkConsistencyOnly = false) { - // Collect writes of all aliases of OpOperand and OpResult. - DenseSet usesWrite; - getAliasingInplaceWrites(usesWrite, operand.get(), state); - for (OpResult result : state.getAliasingOpResults(operand)) { - getAliasingInplaceWrites(usesWrite, result, state); + bool foundWrite = + !checkConsistencyOnly && state.bufferizesToMemoryWrite(operand); + + if (!foundWrite) { + // Collect writes of all aliases of OpOperand and OpResult. + DenseSet usesWrite; + getAliasingInplaceWrites(usesWrite, operand.get(), state); + for (OpResult result : state.getAliasingOpResults(operand)) + getAliasingInplaceWrites(usesWrite, result, state); + foundWrite = !usesWrite.empty(); } - if (!checkConsistencyOnly && state.bufferizesToMemoryWrite(operand)) - usesWrite.insert(&operand); - // Assuming that `operand` bufferizes in-place: For each write (to each - // alias), check if there is a non-writable tensor in the reverse SSA use-def - // chain. - for (OpOperand *uWrite : usesWrite) { - if (hasPrecedingAliasingNonWritableTensor(uWrite->get(), &operand, state)) { - LLVM_DEBUG(llvm::dbgs() << "=> NOT WRITABLE\n"); - return true; + if (!foundWrite) + return false; + + // Look for a read-only tensor among all aliases. + bool foundReadOnly = false; + auto checkReadOnly = [&](Value v) { + if (!state.isWritable(v)) { + foundReadOnly = true; + if (state.getOptions().printConflicts) + annotateNonWritableTensor(v); } + }; + state.applyOnAliases(operand.get(), checkReadOnly); + for (OpResult result : state.getAliasingOpResults(operand)) + state.applyOnAliases(result, checkReadOnly); + if (foundReadOnly) { + LLVM_DEBUG(llvm::dbgs() << "=> NOT WRITABLE\n"); + return true; } return false;