diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h @@ -371,6 +371,23 @@ /// Return `true` if the given value is a BlockArgument of a func::FuncOp. bool isFunctionArgument(Value value); +/// Traversal parameters for `findValueInReverseUseDefChain`. +struct TraversalConfig { + /// Specifies if leaves (that have no further OpOperands to follow) should + /// be returned even if they do not match the specified condition. + bool alwaysIncludeLeaves = true; + + /// If set, out-of-place/undecided OpOperands are not followed. + bool followInPlaceOnly = false; + + /// If set, non-equivalent OpOperands are not followed. + bool followEquivalentOnly = false; + + /// Specifies whether unknown/non-bufferizable ops and ops excluded by the + /// OpFilter of BufferizationOptions should be followed. + bool followUnknownOps = false; +}; + /// AnalysisState provides a variety of helper functions for dealing with /// tensor values. class AnalysisState { @@ -416,9 +433,8 @@ /// `condition` evaluates to true. OpOperands of such matching Values are not /// traversed any further. /// - /// When reaching the end of a chain (BlockArgument or Value without aliasing - /// OpOperands), also return the last Value of that chain if - /// `alwaysIncludeLeaves` is set. + /// When reaching the end of a chain, also return the last Value of that + /// chain if `config.alwaysIncludeLeaves` is set. /// /// Example: /// @@ -436,10 +452,11 @@ /// starting the traversal from Value 1, the resulting SetVector is: /// { 2, 7, 8, 5 } /// - /// If `followEquivalentOnly` is set, only equivalent OpOperands are selected. + /// Additional stopping conditions for the traversal can be specified in + /// `config`.
SetVector findValueInReverseUseDefChain( Value value, llvm::function_ref condition, - bool followEquivalentOnly = false, bool alwaysIncludeLeaves = true) const; + TraversalConfig config = TraversalConfig()) const; /// Find the values that may define the contents of the given value at /// runtime. A block argument is always a definition. An OpResult is a diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp --- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp @@ -358,10 +358,7 @@ BufferizableOpInterface BufferizationOptions::dynCastBufferizableOp(Value value) const { - if (auto bufferizableOp = value.getDefiningOp()) - if (isOpAllowed(bufferizableOp.getOperation())) - return bufferizableOp; - return nullptr; + return dynCastBufferizableOp(getOwnerOfValue(value)); } //===----------------------------------------------------------------------===// @@ -476,13 +473,14 @@ // further. llvm::SetVector AnalysisState::findValueInReverseUseDefChain( Value value, llvm::function_ref condition, - bool followEquivalentOnly, bool alwaysIncludeLeaves) const { + TraversalConfig config) const { llvm::SetVector result, workingSet; workingSet.insert(value); while (!workingSet.empty()) { Value value = workingSet.pop_back_val(); if (condition(value) || value.isa()) { + // Stop iterating if the value is a match or a BlockArgument was reached. result.insert(value); continue; } @@ -490,25 +488,42 @@ OpResult opResult = value.cast(); BufferizableOpInterface bufferizableOp = options.dynCastBufferizableOp(opResult.getDefiningOp()); - AliasingOpOperandList aliases = getAliasingOpOperands(opResult); + if (!config.followUnknownOps && !bufferizableOp) { + // Stop iterating if `followUnknownOps` is unset and the op is either + // not bufferizable or excluded by the OpFilter.
+ if (config.alwaysIncludeLeaves) + result.insert(value); + continue; + } - // Stop iterating in either one of these cases: - // * The current op is not bufferizable or excluded in the filter. - // * There are no OpOperands to follow. - if (!bufferizableOp || aliases.getNumAliases() == 0) { - if (alwaysIncludeLeaves) + AliasingOpOperandList aliases = getAliasingOpOperands(opResult); + if (aliases.getNumAliases() == 0) { + // The traversal ends naturally if there are no more OpOperands that + // could be followed. + if (config.alwaysIncludeLeaves) result.insert(value); continue; } for (AliasingOpOperand a : aliases) { - if (followEquivalentOnly && a.relation != BufferRelation::Equivalent) { + if (config.followEquivalentOnly && + a.relation != BufferRelation::Equivalent) { // Stop iterating if `followEquivalentOnly` is set but the alias is not // equivalent. - result.insert(value); - } else { - workingSet.insert(a.opOperand->get()); + if (config.alwaysIncludeLeaves) + result.insert(value); + continue; + } + + if (config.followInPlaceOnly && !isInPlace(*a.opOperand)) { + // Stop iterating if `followInPlaceOnly` is set but the alias is + // out-of-place. + if (config.alwaysIncludeLeaves) + result.insert(value); + continue; } + + workingSet.insert(a.opOperand->get()); } } @@ -517,9 +532,10 @@ // Find the values that define the contents of the given value.
llvm::SetVector AnalysisState::findDefinitions(Value value) const { + TraversalConfig config; + config.alwaysIncludeLeaves = false; return findValueInReverseUseDefChain( - value, [&](Value v) { return this->bufferizesToMemoryWrite(v); }, - /*followEquivalentOnly=*/false, /*alwaysIncludeLeaves=*/false); + value, [&](Value v) { return this->bufferizesToMemoryWrite(v); }, config); } AnalysisState::AnalysisState(const BufferizationOptions &options) @@ -895,8 +911,7 @@ for (AliasingOpOperand alias : opOperands) { if (!state .findValueInReverseUseDefChain(alias.opOperand->get(), - isMemoryWriteInsideOp, - /*followEquivalentOnly=*/false) + isMemoryWriteInsideOp) .empty()) return true; } diff --git a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp @@ -132,10 +132,13 @@ // Find tensor.empty ops on the reverse SSA use-def chain. Only follow // equivalent tensors. I.e., stop when there are ops such as extract_slice // on the path. + TraversalConfig config; + config.followEquivalentOnly = true; + config.alwaysIncludeLeaves = false; SetVector emptyTensors = state.findValueInReverseUseDefChain( operand.get(), /*condition=*/ [&](Value val) { return val.getDefiningOp(); }, - /*followEquivalentOnly=*/true, /*alwaysIncludeLeaves=*/false); + config); for (Value v : emptyTensors) { Operation *emptyTensorOp = v.getDefiningOp(); diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp @@ -775,11 +775,8 @@ // Find the values that define the contents of the given value.
const llvm::SetVector & OneShotAnalysisState::findDefinitionsCached(Value value) { - if (!cachedDefinitions.count(value)) { - cachedDefinitions[value] = findValueInReverseUseDefChain( - value, [&](Value v) { return this->bufferizesToMemoryWrite(v); }, - /*followEquivalentOnly=*/false, /*alwaysIncludeLeaves=*/false); - } + if (!cachedDefinitions.count(value)) + cachedDefinitions[value] = findDefinitions(value); return cachedDefinitions[value]; }