[mlir][bufferization] Make analyzeOp/analyzeSingleOp members of OneShotAnalysisState

Replace the two file-local `inPlaceAnalysis` overloads in OneShotAnalysis.cpp
with public member functions of OneShotAnalysisState:

  * analyzeSingleOp(op, domInfo) runs the in-place analysis on every tensor
    OpOperand of `op`; nested ops are not visited.
  * analyzeOp(op, domInfo) collects `op` and its nested ops that have tensor
    semantics, orders them according to `analysisHeuristic` (or shuffles them
    when `analysisFuzzerSeed` is set), analyzes each of them with
    analyzeSingleOp and finally runs the equivalence analysis on `op`.

The fuzzer seed is now read from the options instead of being passed in, and
the equivalence analysis moves from the caller into analyzeOp, so the
analysis statistics are now collected after it has run.

diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
@@ -14,6 +14,8 @@
 #include <string>
 
 namespace mlir {
+class DominanceInfo;
+
 namespace bufferization {
 
 struct OneShotBufferizationOptions;
@@ -63,6 +65,16 @@
         AnalysisState::getOptions());
   }
 
+  /// Analyze the given op and its nested ops. Ops are analyzed one by one
+  /// (see `analyzeSingleOp`) in the order selected by the `analysisHeuristic`
+  /// option, or in a random order if `analysisFuzzerSeed` is set. Afterwards,
+  /// the equivalence analysis is run on `op`.
+  LogicalResult analyzeOp(Operation *op, const DominanceInfo &domInfo);
+
+  /// Analyze a single op (without nested ops): run the in-place analysis on
+  /// each of its OpOperands of tensor type.
+  LogicalResult analyzeSingleOp(Operation *op, const DominanceInfo &domInfo);
+
   /// Apply `fun` to all the members of the equivalence class of `v`.
   void applyOnEquivalenceClass(Value v, function_ref<void(Value)> fun) const;
 
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
@@ -812,61 +812,13 @@
   return success();
 }
 
-/// Analyze the `ops` to determine which OpOperands are inplaceable. Walk ops in
-/// reverse and bufferize ops greedily. This is a good starter heuristic.
-///
-/// Even if an op does not read or write, it may still create an alias when
-/// bufferized in-place. An example of such ops is tensor.extract_slice.
-///
-/// Rationale for bufferizing `%1 = tensor.extract_slice %0[...]` inplace:
-///
-/// When bufferized out of place, an ExtractSliceOp lowers to alloc + copy. This
-/// cannot change the flow of information for either the source or the
-/// result buffers.
-///
-/// When bufferized inplace, an ExtractSliceOp does not by itself create any
-/// read or write from memory. Instead, it has the effect of merging the alias
-/// sets of the source and the result buffers.
-///
-/// An analysis is required to ensure inplace bufferization would not result in
-/// RaW dependence violations.
-static LogicalResult inPlaceAnalysis(SmallVector<Operation *> &ops,
-                                     OneShotAnalysisState &state,
-                                     const DominanceInfo &domInfo,
-                                     unsigned analysisFuzzerSeed = 0) {
-  if (analysisFuzzerSeed) {
-    // This is a fuzzer. For testing purposes only. Randomize the order in which
-    // operations are analyzed. The bufferization quality is likely worse, but
-    // we want to make sure that no assertions are triggered anywhere.
-    std::mt19937 g(analysisFuzzerSeed);
-    llvm::shuffle(ops.begin(), ops.end(), g);
-  }
-
-  // Analyze a single op.
-  auto analyzeOp = [&](Operation *op) {
-    for (OpOperand &opOperand : op->getOpOperands())
-      if (opOperand.get().getType().isa<TensorType>())
-        if (failed(bufferizableInPlaceAnalysisImpl(opOperand, state, domInfo)))
-          return failure();
-    return success();
-  };
-
-  OneShotBufferizationOptions::AnalysisHeuristic heuristic =
-      state.getOptions().analysisHeuristic;
-  if (heuristic == OneShotBufferizationOptions::AnalysisHeuristic::BottomUp) {
-    // Default: Walk ops in reverse for better interference analysis.
-    for (Operation *op : reverse(ops))
-      if (failed(analyzeOp(op)))
-        return failure();
-  } else if (heuristic ==
-             OneShotBufferizationOptions::AnalysisHeuristic::TopDown) {
-    for (Operation *op : ops)
-      if (failed(analyzeOp(op)))
+LogicalResult
+OneShotAnalysisState::analyzeSingleOp(Operation *op,
+                                      const DominanceInfo &domInfo) {
+  for (OpOperand &opOperand : op->getOpOperands())
+    if (opOperand.get().getType().isa<TensorType>())
+      if (failed(bufferizableInPlaceAnalysisImpl(opOperand, *this, domInfo)))
         return failure();
-  } else {
-    llvm_unreachable("unsupported heuristic");
-  }
-
   return success();
 }
 
@@ -877,23 +829,6 @@
   return hasTensorResult || hasTensorOperand;
 }
 
-/// Analyze all ops that are contained in `op`.
-static LogicalResult inPlaceAnalysis(Operation *op,
-                                     OneShotAnalysisState &state,
-                                     const DominanceInfo &domInfo,
-                                     unsigned analysisFuzzerSeed = 0) {
-  // Collect ops so we can build our own reverse traversal.
-  SmallVector<Operation *> ops;
-  op->walk([&](Operation *op) {
-    // No tensors => no buffers.
-    if (!hasTensorSemantics(op))
-      return;
-    ops.push_back(op);
-  });
-
-  return inPlaceAnalysis(ops, state, domInfo, analysisFuzzerSeed);
-}
-
 /// Analyze equivalence of tied OpResult/OpOperand pairs of the given ops.
 static void equivalenceAnalysis(SmallVector<Operation *> &ops,
                                 OneShotAnalysisState &state) {
@@ -924,6 +859,66 @@
   equivalenceAnalysis(ops, state);
 }
 
+/// Analyze `op` and all of its nested ops to determine which OpOperands are
+/// inplaceable. Ops are analyzed one by one (see `analyzeSingleOp`) in the
+/// order selected by the `analysisHeuristic` option (bottom-up by default) and
+/// bufferized greedily. This is a good starter heuristic. Afterwards, the
+/// equivalence analysis is run on `op`.
+///
+/// Even if an op does not read or write, it may still create an alias when
+/// bufferized in-place. An example of such ops is tensor.extract_slice.
+///
+/// Rationale for bufferizing `%1 = tensor.extract_slice %0[...]` inplace:
+///
+/// When bufferized out of place, an ExtractSliceOp lowers to alloc + copy. This
+/// cannot change the flow of information for either the source or the
+/// result buffers.
+///
+/// When bufferized inplace, an ExtractSliceOp does not by itself create any
+/// read or write from memory. Instead, it has the effect of merging the alias
+/// sets of the source and the result buffers.
+///
+/// An analysis is required to ensure inplace bufferization would not result in
+/// RaW dependence violations.
+LogicalResult OneShotAnalysisState::analyzeOp(Operation *op,
+                                              const DominanceInfo &domInfo) {
+  // Collect ops so we can build our own reverse traversal.
+  SmallVector<Operation *> ops;
+  op->walk([&](Operation *nestedOp) {
+    // No tensors => no buffers.
+    if (!hasTensorSemantics(nestedOp))
+      return;
+    ops.push_back(nestedOp);
+  });
+
+  if (getOptions().analysisFuzzerSeed) {
+    // This is a fuzzer. For testing purposes only. Randomize the order in which
+    // operations are analyzed. The bufferization quality is likely worse, but
+    // we want to make sure that no assertions are triggered anywhere.
+    std::mt19937 g(getOptions().analysisFuzzerSeed);
+    llvm::shuffle(ops.begin(), ops.end(), g);
+  }
+
+  OneShotBufferizationOptions::AnalysisHeuristic heuristic =
+      getOptions().analysisHeuristic;
+  if (heuristic == OneShotBufferizationOptions::AnalysisHeuristic::BottomUp) {
+    // Default: Walk ops in reverse for better interference analysis.
+    for (Operation *nestedOp : reverse(ops))
+      if (failed(analyzeSingleOp(nestedOp, domInfo)))
+        return failure();
+  } else if (heuristic ==
+             OneShotBufferizationOptions::AnalysisHeuristic::TopDown) {
+    for (Operation *nestedOp : ops)
+      if (failed(analyzeSingleOp(nestedOp, domInfo)))
+        return failure();
+  } else {
+    llvm_unreachable("unsupported heuristic");
+  }
+
+  equivalenceAnalysis(op, *this);
+  return success();
+}
+
 /// Assert that the current bufferization decisions are consistent.
 static LogicalResult checkAliasInfoConsistency(Operation *op,
                                                const DominanceInfo &domInfo,
@@ -1060,7 +1055,7 @@
     return failure();
 
   // If the analysis fails, just return.
-  if (failed(inPlaceAnalysis(op, state, domInfo, options.analysisFuzzerSeed)))
+  if (failed(state.analyzeOp(op, domInfo)))
     return failure();
 
   if (statistics) {
@@ -1068,8 +1063,6 @@
     statistics->numTensorOutOfPlace = state.getStatNumTensorOutOfPlace();
   }
 
-  equivalenceAnalysis(op, state);
-
   bool failedAnalysis = false;
   if (!options.allowReturnAllocs)
     failedAnalysis |= failed(assertNoAllocsReturned(op, state));