diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h @@ -16,7 +16,6 @@ namespace bufferization { struct OneShotBufferizationOptions; -class BufferizationAliasInfo; struct BufferizationStatistics; class OneShotAnalysisState; @@ -35,108 +34,8 @@ AnalysisHeuristic analysisHeuristic = AnalysisHeuristic::BottomUp; }; -/// The BufferizationAliasInfo class maintains a list of buffer aliases and -/// equivalence classes to support bufferization. -class BufferizationAliasInfo { -public: - explicit BufferizationAliasInfo(Operation *rootOp); - - // BufferizationAliasInfo should be passed as a reference. - BufferizationAliasInfo(const BufferizationAliasInfo &) = delete; - - /// Add a new entry for `v` in the `aliasInfo` and `equivalentInfo`. In the - /// beginning the alias and equivalence sets only contain `v` itself. - void createAliasInfoEntry(Value v); - - /// Insert an info entry for `newValue` and merge its alias set with that of - /// `alias`. - void insertNewBufferAlias(Value newValue, Value alias); - - /// Insert an info entry for `newValue` and merge its alias set with that of - /// `alias`. Additionally, merge their equivalence classes. - void insertNewBufferEquivalence(Value newValue, Value alias); - - /// Set the inPlace bufferization spec to true. - /// Merge result's and operand's aliasing sets and iterate to a fixed point. - void bufferizeInPlace(OpOperand &operand, AnalysisState &state); - - /// Set the inPlace bufferization spec to false. - void bufferizeOutOfPlace(OpOperand &operand); - - /// Return true if `v1` and `v2` may bufferize to aliasing buffers. 
- bool areAliasingBufferizedValues(Value v1, Value v2) const { - return aliasInfo.isEquivalent(v1, v2); - } - - /// Return true if `v1` and `v2` bufferize to equivalent buffers. - bool areEquivalentBufferizedValues(Value v1, Value v2) const { - return equivalentInfo.isEquivalent(v1, v2); - } - - /// Union the alias sets of `v1` and `v2`. - void unionAliasSets(Value v1, Value v2) { aliasInfo.unionSets(v1, v2); } - - /// Union the equivalence classes of `v1` and `v2`. - void unionEquivalenceClasses(Value v1, Value v2) { - equivalentInfo.unionSets(v1, v2); - } - - /// Apply `fun` to all the members of the equivalence class of `v`. - void applyOnEquivalenceClass(Value v, function_ref fun) const; - - /// Apply `fun` to all aliases of `v`. - void applyOnAliases(Value v, function_ref fun) const; - - /// Mark a value as in-place bufferized. - void markInPlace(OpOperand &o) { inplaceBufferized.insert(&o); } - - /// Return `true` if a value was marked as in-place bufferized. - bool isInPlace(OpOperand &opOperand) const; - - int64_t getStatNumTensorOutOfPlace() const { return statNumTensorOutOfPlace; } - int64_t getStatNumTensorInPlace() const { return statNumTensorInPlace; } - -private: - /// llvm::EquivalenceClasses wants comparable elements. This comparator uses - /// uses pointer comparison on the defining op. This is a poor man's - /// comparison but it's not like UnionFind needs ordering anyway. - struct ValueComparator { - bool operator()(const Value &lhs, const Value &rhs) const { - return lhs.getImpl() < rhs.getImpl(); - } - }; - - using EquivalenceClassRangeType = llvm::iterator_range< - llvm::EquivalenceClasses::member_iterator>; - /// Check that aliasInfo for `v` exists and return a reference to it. - EquivalenceClassRangeType getAliases(Value v) const; - - /// Set of all OpResults that were decided to bufferize in-place. - llvm::DenseSet inplaceBufferized; - - /// Auxiliary structure to store all the values a given value may alias with. 
- /// Alias information is "may be" conservative: In the presence of branches, a - /// value may alias with one of multiple other values. The concrete aliasing - /// value may not even be known at compile time. All such values are - /// considered to be aliases. - llvm::EquivalenceClasses aliasInfo; - - /// Auxiliary structure to store all the equivalent buffer classes. Equivalent - /// buffer information is "must be" conservative: Only if two values are - /// guaranteed to be equivalent at runtime, they said to be equivalent. It is - /// possible that, in the presence of branches, it cannot be determined - /// statically if two values are equivalent. In that case, the values are - /// considered to be not equivalent. - llvm::EquivalenceClasses equivalentInfo; - - // Bufferization statistics. - int64_t statNumTensorOutOfPlace = 0; - int64_t statNumTensorInPlace = 0; -}; - /// State for analysis-enabled bufferization. This class keeps track of alias -/// (via BufferizationAliasInfo) to decide if tensor OpOperands should bufferize -/// in-place. +/// to decide if tensor OpOperands should bufferize in-place. class OneShotAnalysisState : public AnalysisState { public: OneShotAnalysisState(Operation *op, @@ -156,11 +55,11 @@ AnalysisState::getOptions()); } - /// Return a reference to the BufferizationAliasInfo. - BufferizationAliasInfo &getAliasInfo() { return aliasInfo; } + /// Apply `fun` to all the members of the equivalence class of `v`. + void applyOnEquivalenceClass(Value v, function_ref fun) const; - /// Return `true` if the given OpResult has been decided to bufferize inplace. - bool isInPlace(OpOperand &opOperand) const override; + /// Apply `fun` to all aliases of `v`. + void applyOnAliases(Value v, function_ref fun) const; /// Return true if `v1` and `v2` bufferize to equivalent buffers. bool areEquivalentBufferizedValues(Value v1, Value v2) const override; @@ -168,12 +67,16 @@ /// Return true if `v1` and `v2` may bufferize to aliasing buffers. 
bool areAliasingBufferizedValues(Value v1, Value v2) const override; - /// Return `true` if the given tensor has undefined contents. - bool hasUndefinedContents(OpOperand *opOperand) const override; + /// Mark the given OpOperand as in-place and merge the results' and operand's + /// aliasing sets. + void bufferizeInPlace(OpOperand &operand); - /// Return true if the given tensor (or an aliasing tensor) is yielded from - /// the containing block. Also include all aliasing tensors in the same block. - bool isTensorYielded(Value tensor) const override; + /// Mark the given OpOperand as out-of-place. + void bufferizeOutOfPlace(OpOperand &operand); + + /// Add a new entry for `v` in the `aliasInfo` and `equivalentInfo`. In the + /// beginning the alias and equivalence sets only contain `v` itself. + void createAliasInfoEntry(Value v); /// Find all tensor values in the given operation that have undefined contents /// and store them in `undefinedTensorUses`. @@ -183,6 +86,19 @@ /// `yieldedTensors`. Also include all aliasing tensors in the same block. void gatherYieldedTensors(Operation *op); + int64_t getStatNumTensorOutOfPlace() const { return statNumTensorOutOfPlace; } + int64_t getStatNumTensorInPlace() const { return statNumTensorInPlace; } + + /// Return `true` if the given tensor has undefined contents. + bool hasUndefinedContents(OpOperand *opOperand) const override; + + /// Return `true` if the given OpOperand was decided to bufferize in-place. + bool isInPlace(OpOperand &opOperand) const override; + + /// Return true if the given tensor (or an aliasing tensor) is yielded from + /// the containing block. Also include all aliasing tensors in the same block. + bool isTensorYielded(Value tensor) const override; + /// Return true if the buffer of the given tensor value is written to. Must /// not be called for values inside not yet analyzed functions.
bool isValueWritten(Value value) const; @@ -190,6 +106,12 @@ /// Return true if the buffer of the given tensor value is writable. bool isWritable(Value value) const; + /// Union the alias sets of `v1` and `v2`. + void unionAliasSets(Value v1, Value v2); + + /// Union the equivalence classes of `v1` and `v2`. + void unionEquivalenceClasses(Value v1, Value v2); + /// Base class for OneShotAnalysisState extensions that allow /// OneShotAnalysisState to contain user-specified information in the state /// object. Clients are expected to derive this class, add the desired fields, @@ -274,9 +196,41 @@ } private: - /// `aliasInfo` keeps track of aliasing and equivalent values. Only internal - /// functions and `runOneShotBufferize` may access this object. - BufferizationAliasInfo aliasInfo; + /// llvm::EquivalenceClasses wants comparable elements. This comparator + /// uses pointer comparison on the value's impl. This is a poor man's + /// comparison but it's not like UnionFind needs ordering anyway. + struct ValueComparator { + bool operator()(const Value &lhs, const Value &rhs) const { + return lhs.getImpl() < rhs.getImpl(); + } + }; + + using EquivalenceClassRangeType = llvm::iterator_range< + llvm::EquivalenceClasses::member_iterator>; + /// Check that aliasInfo for `v` exists and return a reference to it. + EquivalenceClassRangeType getAliases(Value v) const; + + /// Set of all OpOperands that were decided to bufferize in-place. + llvm::DenseSet inplaceBufferized; + + /// Auxiliary structure to store all the values a given value may alias with. + /// Alias information is "may be" conservative: In the presence of branches, a + /// value may alias with one of multiple other values. The concrete aliasing + /// value may not even be known at compile time. All such values are + /// considered to be aliases. + llvm::EquivalenceClasses aliasInfo; + + /// Auxiliary structure to store all the equivalent buffer classes.
Equivalent + /// buffer information is "must be" conservative: Only if two values are + /// guaranteed to be equivalent at runtime, they are deemed equivalent. It is + /// possible that, in the presence of branches, it cannot be determined + /// statically if two values are equivalent. In that case, the values are + /// considered to be not equivalent. + llvm::EquivalenceClasses equivalentInfo; + + // Bufferization statistics. + int64_t statNumTensorOutOfPlace = 0; + int64_t statNumTensorInPlace = 0; /// A set of all tensors (and maybe aliasing tensors) that yielded from a /// block. diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp @@ -71,7 +71,7 @@ //===----------------------------------------------------------------------===// // Bufferization-specific attribute manipulation. // These are for testing and debugging only. Bufferization information is stored -// in BufferizationAliasInfo. When run with `testAnalysisOnly`, the IR is +// in OneShotAnalysisState. When run with `testAnalysisOnly`, the IR is // annotated with the results of the analysis, so that they can be checked in // tests. //===----------------------------------------------------------------------===// @@ -98,11 +98,14 @@ } //===----------------------------------------------------------------------===// -// BufferizationAliasInfo +// OneShotAnalysisState //===----------------------------------------------------------------------===// -BufferizationAliasInfo::BufferizationAliasInfo(Operation *rootOp) { - rootOp->walk([&](Operation *op) { +OneShotAnalysisState::OneShotAnalysisState( + Operation *op, const OneShotBufferizationOptions &options) + : AnalysisState(options, TypeID::get()) { + // Set up alias sets.
+ op->walk([&](Operation *op) { for (Value v : op->getResults()) if (v.getType().isa()) createAliasInfoEntry(v); @@ -112,55 +115,20 @@ if (bbArg.getType().isa()) createAliasInfoEntry(bbArg); }); -} - -/// Add a new entry for `v` in the `aliasInfo` and `equivalentInfo`. In the -/// beginning the alias and equivalence sets only contain `v` itself. -void BufferizationAliasInfo::createAliasInfoEntry(Value v) { - aliasInfo.insert(v); - equivalentInfo.insert(v); -} - -/// Insert an info entry for `newValue` and merge its alias set with that of -/// `alias`. -void BufferizationAliasInfo::insertNewBufferAlias(Value newValue, Value alias) { - createAliasInfoEntry(newValue); - aliasInfo.unionSets(newValue, alias); -} - -/// Insert an info entry for `newValue` and merge its alias set with that of -/// `alias`. Additionally, merge their equivalence classes. -void BufferizationAliasInfo::insertNewBufferEquivalence(Value newValue, - Value alias) { - insertNewBufferAlias(newValue, alias); - equivalentInfo.unionSets(newValue, alias); -} - -/// Return `true` if a value was marked as in-place bufferized. -bool BufferizationAliasInfo::isInPlace(OpOperand &operand) const { - return inplaceBufferized.contains(&operand); -} - -/// Set the inPlace bufferization spec to true. -void BufferizationAliasInfo::bufferizeInPlace(OpOperand &operand, - AnalysisState &state) { - if (inplaceBufferized.contains(&operand)) - return; - markInPlace(operand); - for (OpResult result : state.getAliasingOpResults(operand)) - aliasInfo.unionSets(result, operand.get()); - ++statNumTensorInPlace; -} -/// Set the inPlace bufferization spec to false. -void BufferizationAliasInfo::bufferizeOutOfPlace(OpOperand &operand) { - assert(!inplaceBufferized.contains(&operand) && - "OpOperand was already decided to bufferize inplace"); - ++statNumTensorOutOfPlace; + // Mark OpOperands in-place that must bufferize in-place. 
+ op->walk([&](BufferizableOpInterface bufferizableOp) { + if (!options.isOpAllowed(bufferizableOp)) + return WalkResult::skip(); + for (OpOperand &opOperand : bufferizableOp->getOpOperands()) + if (opOperand.get().getType().isa()) + if (bufferizableOp.mustBufferizeInPlace(opOperand, *this)) + bufferizeInPlace(opOperand); + return WalkResult::advance(); + }); } -/// Apply `fun` to all the members of the equivalence class of `v`. -void BufferizationAliasInfo::applyOnEquivalenceClass( +void OneShotAnalysisState::applyOnEquivalenceClass( Value v, function_ref fun) const { auto leaderIt = equivalentInfo.findLeader(v); for (auto mit = leaderIt, meit = equivalentInfo.member_end(); mit != meit; @@ -169,66 +137,48 @@ } } -/// Apply `fun` to all aliases of `v`. -void BufferizationAliasInfo::applyOnAliases( - Value v, function_ref fun) const { +void OneShotAnalysisState::applyOnAliases(Value v, + function_ref fun) const { auto leaderIt = aliasInfo.findLeader(v); for (auto mit = leaderIt, meit = aliasInfo.member_end(); mit != meit; ++mit) { fun(*mit); } } -BufferizationAliasInfo::EquivalenceClassRangeType -BufferizationAliasInfo::getAliases(Value v) const { - DenseSet res; - auto it = aliasInfo.findValue(aliasInfo.getLeaderValue(v)); - for (auto mit = aliasInfo.member_begin(it), meit = aliasInfo.member_end(); - mit != meit; ++mit) { - res.insert(static_cast(*mit)); - } - return BufferizationAliasInfo::EquivalenceClassRangeType( - aliasInfo.member_begin(it), aliasInfo.member_end()); +bool OneShotAnalysisState::areEquivalentBufferizedValues(Value v1, + Value v2) const { + return equivalentInfo.isEquivalent(v1, v2); } -//===----------------------------------------------------------------------===// -// OneShotAnalysisState -//===----------------------------------------------------------------------===// - -OneShotAnalysisState::OneShotAnalysisState( - Operation *op, const OneShotBufferizationOptions &options) - : AnalysisState(options, TypeID::get()), - aliasInfo(op) { - // Set 
up alias sets for OpResults that must bufferize in-place. This should - // be done before making any other bufferization decisions. - op->walk([&](BufferizableOpInterface bufferizableOp) { - if (!options.isOpAllowed(bufferizableOp)) - return WalkResult::skip(); - for (OpOperand &opOperand : bufferizableOp->getOpOperands()) - if (opOperand.get().getType().isa()) - if (bufferizableOp.mustBufferizeInPlace(opOperand, *this)) - aliasInfo.bufferizeInPlace(opOperand, *this); - return WalkResult::advance(); - }); +bool OneShotAnalysisState::areAliasingBufferizedValues(Value v1, + Value v2) const { + return aliasInfo.isEquivalent(v1, v2); } -bool OneShotAnalysisState::isInPlace(OpOperand &opOperand) const { - return aliasInfo.isInPlace(opOperand); +void OneShotAnalysisState::bufferizeInPlace(OpOperand &operand) { + if (inplaceBufferized.contains(&operand)) + return; + inplaceBufferized.insert(&operand); + for (OpResult result : getAliasingOpResults(operand)) + aliasInfo.unionSets(result, operand.get()); + ++statNumTensorInPlace; } -bool OneShotAnalysisState::areEquivalentBufferizedValues(Value v1, - Value v2) const { - return aliasInfo.areEquivalentBufferizedValues(v1, v2); +void OneShotAnalysisState::bufferizeOutOfPlace(OpOperand &operand) { + assert(!inplaceBufferized.contains(&operand) && + "OpOperand was already decided to bufferize inplace"); + ++statNumTensorOutOfPlace; } -bool OneShotAnalysisState::areAliasingBufferizedValues(Value v1, - Value v2) const { - return aliasInfo.areAliasingBufferizedValues(v1, v2); +void OneShotAnalysisState::createAliasInfoEntry(Value v) { + aliasInfo.insert(v); + equivalentInfo.insert(v); } // Gather yielded tensors in `yieldedTensors` by querying all aliases. This is // to ensure that such information is available during bufferization time. -// Alias information can no longer be queried through BufferizationAliasInfo -// once we have started modifying the IR. 
+// Alias information can no longer be queried once we have started modifying +// the IR. void OneShotAnalysisState::gatherYieldedTensors(Operation *op) { op->walk([&](Operation *returnOp) { if (!isRegionReturnLike(returnOp) || !getOptions().isOpAllowed(returnOp)) @@ -242,7 +192,7 @@ // Add all aliases of the returned value. But only the ones that are in // the same block. - aliasInfo.applyOnAliases(returnVal, [&](Value v) { + applyOnAliases(returnVal, [&](Value v) { if (auto bbArg = v.dyn_cast()) { if (bbArg.getOwner()->getParentOp() == returnOp->getParentOp()) yieldedTensors.insert(bbArg); @@ -285,13 +235,17 @@ return undefinedTensorUses.contains(opOperand); } +bool OneShotAnalysisState::isInPlace(OpOperand &opOperand) const { + return inplaceBufferized.contains(&opOperand); +} + bool OneShotAnalysisState::isTensorYielded(Value tensor) const { return yieldedTensors.contains(tensor); } bool OneShotAnalysisState::isValueWritten(Value value) const { bool isWritten = false; - aliasInfo.applyOnAliases(value, [&](Value val) { + applyOnAliases(value, [&](Value val) { for (OpOperand &use : val.getUses()) if (isInPlace(use) && bufferizesToMemoryWrite(use)) isWritten = true; @@ -314,6 +268,14 @@ return false; } +void OneShotAnalysisState::unionAliasSets(Value v1, Value v2) { + aliasInfo.unionSets(v1, v2); +} + +void OneShotAnalysisState::unionEquivalenceClasses(Value v1, Value v2) { + equivalentInfo.unionSets(v1, v2); +} + OneShotAnalysisState::Extension::~Extension() = default; //===----------------------------------------------------------------------===// @@ -322,13 +284,12 @@ /// Return true if opOperand has been decided to bufferize in-place. static bool isInplaceMemoryWrite(OpOperand &opOperand, - const BufferizationAliasInfo &aliasInfo, - const AnalysisState &state) { + const OneShotAnalysisState &state) { // OpOperands that do not bufferize to a memory write do not write in-place. 
if (!state.bufferizesToMemoryWrite(opOperand)) return false; // Check current bufferization decisions. - return aliasInfo.isInPlace(opOperand); + return state.isInPlace(opOperand); } /// Return true if `a` happens before `b`, i.e., `a` or one of its ancestors @@ -501,10 +462,11 @@ /// A conflict is: According to SSA use-def chains, a read R is supposed to read /// the result of a definition W1. But because of bufferization decisions, R /// actually reads another definition W2. -static bool hasReadAfterWriteInterference( - const DenseSet &usesRead, - const DenseSet &usesWrite, const DominanceInfo &domInfo, - AnalysisState &state, const BufferizationAliasInfo &aliasInfo) { +static bool +hasReadAfterWriteInterference(const DenseSet &usesRead, + const DenseSet &usesWrite, + const DominanceInfo &domInfo, + OneShotAnalysisState &state) { const BufferizationOptions &options = state.getOptions(); // Check if op dominance can be used to rule out read-after-write conflicts. @@ -659,21 +621,19 @@ // Helper function to iterate on aliases of `root` and capture the writes. static void getAliasingInplaceWrites(DenseSet &res, Value root, - const BufferizationAliasInfo &aliasInfo, - const AnalysisState &state) { - aliasInfo.applyOnAliases(root, [&](Value alias) { + const OneShotAnalysisState &state) { + state.applyOnAliases(root, [&](Value alias) { for (auto &use : alias.getUses()) // Inplace write to a value that aliases root. - if (isInplaceMemoryWrite(use, aliasInfo, state)) + if (isInplaceMemoryWrite(use, state)) res.insert(&use); }); } // Helper function to iterate on aliases of `root` and capture the reads. static void getAliasingReads(DenseSet &res, Value root, - const BufferizationAliasInfo &aliasInfo, - const AnalysisState &state) { - aliasInfo.applyOnAliases(root, [&](Value alias) { + const OneShotAnalysisState &state) { + state.applyOnAliases(root, [&](Value alias) { for (auto &use : alias.getUses()) { // Read of a value that aliases root. 
if (state.bufferizesToMemoryRead(use)) { @@ -736,22 +696,20 @@ /// OpResult. In that case, only the consistency of bufferization decisions /// involving aliases of the given OpOperand are checked. static bool wouldCreateReadAfterWriteInterference( - OpOperand &operand, const DominanceInfo &domInfo, AnalysisState &state, - const BufferizationAliasInfo &aliasInfo, - bool checkConsistencyOnly = false) { + OpOperand &operand, const DominanceInfo &domInfo, + OneShotAnalysisState &state, bool checkConsistencyOnly = false) { // Collect reads and writes of all aliases of OpOperand and OpResult. DenseSet usesRead, usesWrite; - getAliasingReads(usesRead, operand.get(), aliasInfo, state); - getAliasingInplaceWrites(usesWrite, operand.get(), aliasInfo, state); + getAliasingReads(usesRead, operand.get(), state); + getAliasingInplaceWrites(usesWrite, operand.get(), state); for (OpResult result : state.getAliasingOpResults(operand)) { - getAliasingReads(usesRead, result, aliasInfo, state); - getAliasingInplaceWrites(usesWrite, result, aliasInfo, state); + getAliasingReads(usesRead, result, state); + getAliasingInplaceWrites(usesWrite, result, state); } if (!checkConsistencyOnly && state.bufferizesToMemoryWrite(operand)) usesWrite.insert(&operand); - return hasReadAfterWriteInterference(usesRead, usesWrite, domInfo, state, - aliasInfo); + return hasReadAfterWriteInterference(usesRead, usesWrite, domInfo, state); } /// Annotate IR with details about the detected non-writability conflict. @@ -778,7 +736,6 @@ /// materialized in `aliasInfo` yet. 
static bool hasPrecedingAliasingNonWritableTensor(Value value, OpOperand *currentOpOperand, - const BufferizationAliasInfo &aliasInfo, const OneShotAnalysisState &state) { SmallVector worklist; worklist.push_back(value); @@ -799,7 +756,7 @@ AliasingOpOperandList aliasingOpOperands = state.getAliasingOpOperands(opResult); for (OpOperand *opOperand : aliasingOpOperands) - if (aliasInfo.isInPlace(*opOperand) || currentOpOperand == opOperand) + if (state.isInPlace(*opOperand) || currentOpOperand == opOperand) worklist.push_back(opOperand->get()); } return false; @@ -807,14 +764,15 @@ /// Return true if bufferizing `operand` inplace would create a write to a /// non-writable buffer. -static bool wouldCreateWriteToNonWritableBuffer( - OpOperand &operand, const BufferizationAliasInfo &aliasInfo, - OneShotAnalysisState &state, bool checkConsistencyOnly = false) { +static bool +wouldCreateWriteToNonWritableBuffer(OpOperand &operand, + OneShotAnalysisState &state, + bool checkConsistencyOnly = false) { // Collect writes of all aliases of OpOperand and OpResult. DenseSet usesWrite; - getAliasingInplaceWrites(usesWrite, operand.get(), aliasInfo, state); + getAliasingInplaceWrites(usesWrite, operand.get(), state); for (OpResult result : state.getAliasingOpResults(operand)) { - getAliasingInplaceWrites(usesWrite, result, aliasInfo, state); + getAliasingInplaceWrites(usesWrite, result, state); } if (!checkConsistencyOnly && state.bufferizesToMemoryWrite(operand)) usesWrite.insert(&operand); @@ -823,8 +781,7 @@ // alias), check if there is a non-writable tensor in the reverse SSA use-def // chain. 
for (OpOperand *uWrite : usesWrite) { - if (hasPrecedingAliasingNonWritableTensor(uWrite->get(), &operand, - aliasInfo, state)) { + if (hasPrecedingAliasingNonWritableTensor(uWrite->get(), &operand, state)) { LLVM_DEBUG(llvm::dbgs() << "=> NOT WRITABLE\n"); return true; } @@ -838,22 +795,22 @@ //===----------------------------------------------------------------------===// /// Determine if `operand` can be bufferized in-place. -static LogicalResult bufferizableInPlaceAnalysisImpl( - OpOperand &operand, BufferizationAliasInfo &aliasInfo, - OneShotAnalysisState &state, const DominanceInfo &domInfo) { +static LogicalResult +bufferizableInPlaceAnalysisImpl(OpOperand &operand, OneShotAnalysisState &state, + const DominanceInfo &domInfo) { LLVM_DEBUG( llvm::dbgs() << "//===-------------------------------------------===//\n" << "Analyzing operand #" << operand.getOperandNumber() << " of " << *operand.getOwner() << "\n"); bool foundInterference = - wouldCreateWriteToNonWritableBuffer(operand, aliasInfo, state) || - wouldCreateReadAfterWriteInterference(operand, domInfo, state, aliasInfo); + wouldCreateWriteToNonWritableBuffer(operand, state) || + wouldCreateReadAfterWriteInterference(operand, domInfo, state); if (foundInterference) - aliasInfo.bufferizeOutOfPlace(operand); + state.bufferizeOutOfPlace(operand); else - aliasInfo.bufferizeInPlace(operand, state); + state.bufferizeInPlace(operand); LLVM_DEBUG(llvm::dbgs() << "//===-------------------------------------------===//\n"); @@ -879,7 +836,6 @@ /// An analysis is required to ensure inplace bufferization would not result in /// RaW dependence violations. 
static LogicalResult inPlaceAnalysis(SmallVector &ops, - BufferizationAliasInfo &aliasInfo, OneShotAnalysisState &state, const DominanceInfo &domInfo, unsigned analysisFuzzerSeed = 0) { @@ -895,8 +851,7 @@ auto analyzeOp = [&](Operation *op) { for (OpOperand &opOperand : op->getOpOperands()) if (opOperand.get().getType().isa()) - if (failed(bufferizableInPlaceAnalysisImpl(opOperand, aliasInfo, state, - domInfo))) + if (failed(bufferizableInPlaceAnalysisImpl(opOperand, state, domInfo))) return failure(); return success(); }; @@ -929,7 +884,6 @@ /// Analyze all ops that are contained in `op`. static LogicalResult inPlaceAnalysis(Operation *op, - BufferizationAliasInfo &aliasInfo, OneShotAnalysisState &state, const DominanceInfo &domInfo, unsigned analysisFuzzerSeed = 0) { @@ -942,13 +896,12 @@ ops.push_back(op); }); - return inPlaceAnalysis(ops, aliasInfo, state, domInfo, analysisFuzzerSeed); + return inPlaceAnalysis(ops, state, domInfo, analysisFuzzerSeed); } /// Analyze equivalence of tied OpResult/OpOperand pairs of the given ops. static void equivalenceAnalysis(SmallVector &ops, - BufferizationAliasInfo &aliasInfo, - AnalysisState &state) { + OneShotAnalysisState &state) { for (Operation *op : ops) if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op)) for (OpResult opResult : op->getOpResults()) @@ -958,14 +911,12 @@ if (state.isInPlace(*opOperand)) if (bufferizableOp.bufferRelation(opResult, state) == BufferRelation::Equivalent) - aliasInfo.unionEquivalenceClasses(opResult, opOperand->get()); + state.unionEquivalenceClasses(opResult, opOperand->get()); } /// Analyze equivalence of tied OpResult/OpOperand pairs of all ops contained /// in `op`. -static void equivalenceAnalysis(Operation *op, - BufferizationAliasInfo &aliasInfo, - AnalysisState &state) { +static void equivalenceAnalysis(Operation *op, OneShotAnalysisState &state) { // Traverse ops in PostOrder: Nested ops first, then enclosing ops. 
SmallVector ops; op->walk([&](Operation *op) { @@ -975,14 +926,13 @@ ops.push_back(op); }); - equivalenceAnalysis(ops, aliasInfo, state); + equivalenceAnalysis(ops, state); } /// Assert that the current bufferization decisions are consistent. -static LogicalResult -checkAliasInfoConsistency(Operation *op, const DominanceInfo &domInfo, - AnalysisState &state, - const BufferizationAliasInfo &aliasInfo) { +static LogicalResult checkAliasInfoConsistency(Operation *op, + const DominanceInfo &domInfo, + OneShotAnalysisState &state) { const BufferizationOptions &options = state.getOptions(); WalkResult walkResult = op->walk([&](BufferizableOpInterface op) { @@ -1000,7 +950,7 @@ for (OpOperand &opOperand : op->getOpOperands()) { if (opOperand.get().getType().isa()) { if (wouldCreateReadAfterWriteInterference( - opOperand, domInfo, state, aliasInfo, + opOperand, domInfo, state, /*checkConsistencyOnly=*/true)) { // This error can happen if certain "mustBufferizeInPlace" interface // methods are implemented incorrectly, such that the IR already has @@ -1020,13 +970,12 @@ /// Annotate the IR with the result of the analysis. For testing/debugging only. static void annotateOpsWithBufferizationMarkers(Operation *op, - const BufferizationAliasInfo &aliasInfo, - const BufferizationOptions &options) { + const OneShotAnalysisState &state) { // Add __inplace_operands_attr__. op->walk([&](Operation *op) { for (OpOperand &opOperand : op->getOpOperands()) if (opOperand.get().getType().isa()) - setInPlaceOpOperand(opOperand, aliasInfo.isInPlace(opOperand)); + setInPlaceOpOperand(opOperand, state.isInPlace(opOperand)); }); } @@ -1061,12 +1010,12 @@ // TODO: Remove buffer deallocation from One-Shot Bufferize and fix the buffer // deallocation pass. 
static LogicalResult assertNoAllocsReturned(Operation *op, - const BufferizationOptions &options, - BufferizationAliasInfo &aliasInfo) { + const OneShotAnalysisState &state) { LogicalResult status = success(); DominanceInfo domInfo(op); op->walk([&](Operation *returnOp) { - if (!isRegionReturnLike(returnOp) || !options.isOpAllowed(returnOp)) + if (!isRegionReturnLike(returnOp) || + !state.getOptions().isOpAllowed(returnOp)) return WalkResult::advance(); for (OpOperand &returnValOperand : returnOp->getOpOperands()) { @@ -1076,7 +1025,7 @@ continue; bool foundEquivValue = false; - aliasInfo.applyOnEquivalenceClass(returnVal, [&](Value equivVal) { + state.applyOnEquivalenceClass(returnVal, [&](Value equivVal) { if (auto bbArg = equivVal.dyn_cast()) { Operation *definingOp = bbArg.getOwner()->getParentOp(); if (definingOp->isProperAncestor(returnOp)) @@ -1110,27 +1059,25 @@ OneShotAnalysisState &state, BufferizationStatistics *statistics) { DominanceInfo domInfo(op); - BufferizationAliasInfo &aliasInfo = state.getAliasInfo(); const OneShotBufferizationOptions &options = state.getOptions(); - if (failed(checkAliasInfoConsistency(op, domInfo, state, aliasInfo))) + if (failed(checkAliasInfoConsistency(op, domInfo, state))) return failure(); // If the analysis fails, just return. 
- if (failed(inPlaceAnalysis(op, aliasInfo, state, domInfo, - options.analysisFuzzerSeed))) + if (failed(inPlaceAnalysis(op, state, domInfo, options.analysisFuzzerSeed))) return failure(); if (statistics) { - statistics->numTensorInPlace = aliasInfo.getStatNumTensorInPlace(); - statistics->numTensorOutOfPlace = aliasInfo.getStatNumTensorOutOfPlace(); + statistics->numTensorInPlace = state.getStatNumTensorInPlace(); + statistics->numTensorOutOfPlace = state.getStatNumTensorOutOfPlace(); } - equivalenceAnalysis(op, aliasInfo, state); + equivalenceAnalysis(op, state); bool failedAnalysis = false; if (!options.allowReturnAllocs) - failedAnalysis |= failed(assertNoAllocsReturned(op, options, aliasInfo)); + failedAnalysis |= failed(assertNoAllocsReturned(op, state)); // Gather some extra analysis data. state.gatherYieldedTensors(op); @@ -1147,7 +1094,7 @@ // Annotate operations if we only want to report the analysis. if (options.testAnalysisOnly) - annotateOpsWithBufferizationMarkers(op, aliasInfo, options); + annotateOpsWithBufferizationMarkers(op, state); return success(!failedAnalysis); } diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp @@ -250,7 +250,6 @@ /// analyzed. // TODO: This does not handle cyclic function call graphs etc. 
static void equivalenceAnalysis(func::FuncOp funcOp, - BufferizationAliasInfo &aliasInfo, OneShotAnalysisState &state, FuncAnalysisState &funcState) { funcOp->walk([&](func::CallOp callOp) { @@ -268,7 +267,7 @@ continue; Value returnVal = callOp.getResult(returnIdx); Value argVal = callOp->getOperand(bbargIdx); - aliasInfo.unionEquivalenceClasses(returnVal, argVal); + state.unionEquivalenceClasses(returnVal, argVal); } return WalkResult::advance(); @@ -365,7 +364,6 @@ assert(state.getOptions().bufferizeFunctionBoundaries && "expected that function boundary bufferization is activated"); FuncAnalysisState &funcState = getOrCreateFuncAnalysisState(state); - BufferizationAliasInfo &aliasInfo = state.getAliasInfo(); // A list of functions in the order in which they are analyzed + bufferized. SmallVector orderedFuncOps; @@ -385,7 +383,7 @@ funcState.startFunctionAnalysis(funcOp); // Gather equivalence info for CallOps. - equivalenceAnalysis(funcOp, aliasInfo, state, funcState); + equivalenceAnalysis(funcOp, state, funcState); // Analyze funcOp. if (failed(analyzeOp(funcOp, state, statistics)))