diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -421,7 +421,12 @@
   ///
   /// Note: OpResults of unknown ops are handled conservatively and assumed to
   /// be definitions.
-  SetVector<Value> findDefinitions(Value value) const;
+  ///
+  /// Results are memoized per value. The returned reference points into the
+  /// cache and may be invalidated by a later call to this function for another
+  /// value, so do not hold it across such calls. The cache is never cleared:
+  /// results may be stale if the IR is modified after a query.
+  const SetVector<Value> &findDefinitions(Value value) const;
 
   /// Return `true` if the given OpResult has been decided to bufferize inplace.
   virtual bool isInPlace(OpOperand &opOperand) const;
@@ -465,6 +470,11 @@
 
   /// The type of analysis.
   TypeID type;
+
+  /// Memoized results of `findDefinitions`, keyed by the queried value.
+  /// `mutable` so that the const `findDefinitions` can fill it lazily. Not
+  /// synchronized: concurrent queries on the same state are unsafe.
+  mutable DenseMap<Value, SetVector<Value>> cachedDefinitions;
 };
 
 /// Create an AllocTensorOp for the given shaped value (memref or tensor).
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -510,10 +510,20 @@
 }
 
 // Find the values that define the contents of the given value.
-llvm::SetVector<Value> AnalysisState::findDefinitions(Value value) const {
-  return findValueInReverseUseDefChain(
-      value, [&](Value v) { return this->bufferizesToMemoryWrite(v); },
-      /*followEquivalentOnly=*/false, /*alwaysIncludeLeaves=*/false);
+// The result is memoized in `cachedDefinitions`. The returned reference points
+// into that map and is invalidated when another value is inserted into it.
+const llvm::SetVector<Value> &
+AnalysisState::findDefinitions(Value value) const {
+  auto it = cachedDefinitions.find(value);
+  if (it != cachedDefinitions.end())
+    return it->second;
+  // Compute before inserting: inserting first and filling in afterwards would
+  // leave a dangling iterator if the computation itself queried this cache.
+  llvm::SetVector<Value> definitions = findValueInReverseUseDefChain(
+      value, [&](Value v) { return this->bufferizesToMemoryWrite(v); },
+      /*followEquivalentOnly=*/false, /*alwaysIncludeLeaves=*/false);
+  return cachedDefinitions.try_emplace(value, std::move(definitions))
+      .first->second;
 }
 
 AnalysisState::AnalysisState(const BufferizationOptions &options)
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
@@ -511,7 +511,9 @@
     // In the above example, if uRead is the OpOperand of reading_op, the
     // definition is %0. Note that operations that create an alias but do not
     // bufferize to a memory write (such as ExtractSliceOp) are skipped.
-    SetVector<Value> definitions = state.findDefinitions(uRead->get());
+    // `definitions` points into the state's definition cache; it must not be
+    // used after another value's definitions are queried.
+    const SetVector<Value> &definitions = state.findDefinitions(uRead->get());
     if (definitions.empty()) {
       // Fast path: No conflict if there are no definitions.
       LLVM_DEBUG(llvm::dbgs()