diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
@@ -122,6 +122,17 @@
   /// Return true if the buffer of the given tensor value is writable.
   bool isWritable(Value value) const;
 
+  /// Find the definitions of the given tensor value or retrieve them from the
+  /// cache.
+  ///
+  /// Note: The returned reference points into the cache. It may be invalidated
+  /// by subsequent calls that insert a new cache entry and by `resetCache`.
+  const SetVector<Value> &findDefinitionsCached(Value value);
+
+  /// Reset cached data structures. Should be called after the IR was
+  /// modified, as cached definitions may have become stale.
+  void resetCache();
+
   /// Union the alias sets of `v1` and `v2`.
   void unionAliasSets(Value v1, Value v2);
 
@@ -226,6 +237,9 @@
   /// Check that aliasInfo for `v` exists and return a reference to it.
   EquivalenceClassRangeType getAliases(Value v) const;
 
+  /// Cache of tensor value definitions, populated by `findDefinitionsCached`.
+  DenseMap<Value, SetVector<Value>> cachedDefinitions;
+
   /// Set of all OpResults that were decided to bufferize in-place.
   llvm::DenseSet<OpOperand *> inplaceBufferized;
 
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h
@@ -16,6 +16,7 @@
 namespace bufferization {
 class AnalysisState;
 struct BufferizationStatistics;
+class OneShotAnalysisState;
 struct OneShotBufferizationOptions;
 
 /// A function that matches anchor OpOperands for tensor::EmptyOp elimination.
@@ -36,7 +37,7 @@
 /// following the aliasing OpOperand, that eventually ends at a single
 /// tensor::EmptyOp.
 LogicalResult eliminateEmptyTensors(RewriterBase &rewriter, Operation *op,
-                                    bufferization::AnalysisState &state,
+                                    OneShotAnalysisState &state,
                                     AnchorMatchFn anchorMatchFunc,
                                     RewriteFn rewriteFunc);
 
@@ -44,7 +45,7 @@
 /// InsertSliceOp, i.e., if it is eventually inserted into another tensor
 /// (and some other conditions are met).
 LogicalResult insertSliceAnchoredEmptyTensorEliminationStep(
-    RewriterBase &rewriter, Operation *op, bufferization::AnalysisState &state);
+    RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state);
 
 /// Resolve RaW and other conflicts by inserting bufferization.alloc_tensor ops.
 /// After applying this transform, the IR can be bufferized without inserting
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
--- a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
@@ -105,7 +105,7 @@
 /// chain, starting from the OpOperand and always following the aliasing
 /// OpOperand, that eventually ends at a single tensor::EmptyOp.
 LogicalResult mlir::bufferization::eliminateEmptyTensors(
-    RewriterBase &rewriter, Operation *op, AnalysisState &state,
+    RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state,
     AnchorMatchFn anchorMatchFunc, RewriteFn rewriteFunc) {
   OpBuilder::InsertionGuard g(rewriter);
 
@@ -153,6 +153,8 @@
 
       // Replace the tensor::EmptyOp.
       rewriter.replaceOp(emptyTensor.getDefiningOp(), replacement);
+      // The IR was modified, so cached definitions may be stale.
+      state.resetCache();
     }
 
     // Advance to the next operation.
@@ -189,7 +191,7 @@
 /// tensor::EmptyOp.
 template <typename OpTy>
 static LogicalResult insertSliceLikeAnchoredEmptyTensorEliminationStep(
-    RewriterBase &rewriter, Operation *op, AnalysisState &state) {
+    RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state) {
   return eliminateEmptyTensors(
       rewriter, op, state,
       /*anchorMatchFunc=*/
@@ -224,7 +226,7 @@
 
 LogicalResult
 mlir::bufferization::insertSliceAnchoredEmptyTensorEliminationStep(
-    RewriterBase &rewriter, Operation *op, AnalysisState &state) {
+    RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state) {
   if (failed(insertSliceLikeAnchoredEmptyTensorEliminationStep<
                  tensor::InsertSliceOp>(rewriter, op, state)))
     return failure();
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
@@ -222,7 +222,7 @@
 
       // If there is no preceding definition, the tensor contents are
       // undefined.
-      if (findDefinitions(opResult).empty())
+      if (findDefinitionsCached(opResult).empty())
         for (OpOperand &use : opResult.getUses())
           undefinedTensorUses.insert(&use);
     }
@@ -473,7 +473,8 @@
     // In the above example, if uRead is the OpOperand of reading_op, the
     // definition is %0. Note that operations that create an alias but do not
     // bufferize to a memory write (such as ExtractSliceOp) are skipped.
-    SetVector<Value> definitions = state.findDefinitions(uRead->get());
+    const SetVector<Value> &definitions =
+        state.findDefinitionsCached(uRead->get());
     if (definitions.empty()) {
       // Fast path: No conflict if there are no definitions.
       LLVM_DEBUG(llvm::dbgs()
@@ -769,6 +770,20 @@
 // Bufferization analyses.
 //===----------------------------------------------------------------------===//
 
+// Find the values that define the contents of the given value. The result of
+// `findDefinitions` is computed on the first query and cached afterwards.
+// The returned reference points into `cachedDefinitions`; it may be
+// invalidated by later insertions into the cache and by `resetCache`.
+const llvm::SetVector<Value> &
+OneShotAnalysisState::findDefinitionsCached(Value value) {
+  auto it = cachedDefinitions.find(value);
+  if (it == cachedDefinitions.end())
+    it = cachedDefinitions.try_emplace(value, findDefinitions(value)).first;
+  return it->second;
+}
+
+void OneShotAnalysisState::resetCache() { cachedDefinitions.clear(); }
+
 /// Determine if `operand` can be bufferized in-place.
 static LogicalResult
 bufferizableInPlaceAnalysisImpl(OpOperand &operand, OneShotAnalysisState &state,