diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp @@ -780,18 +780,25 @@ void dumpEquivalences() const { printEquivalences(llvm::errs()); } private: - /// Check that aliasInfo for `v` exists and return a reference to it. - DenseSet &getAliasInfoRef(Value v); - - const DenseSet &getAliasInfoRef(Value v) const { - return const_cast(this)->getAliasInfoRef(v); - } - - /// Union all the aliasing sets of all aliases of v1 and v2. - bool mergeAliases(Value v1, Value v2); + /// llvm::EquivalenceClasses wants comparable elements because it uses + /// std::set as the underlying impl. + /// ValueWrapper wraps Value and orders values by their ValueImpl pointer. + /// This is a poor man's comparison but it's not like UnionFind needs ordering + /// anyway .. + struct ValueWrapper { + ValueWrapper(Value val) : v(val) {} + operator Value() const { return v; } + bool operator<(const ValueWrapper &wrap) const { + return v.getImpl() < wrap.v.getImpl(); + } + bool operator==(const ValueWrapper &wrap) const { return v == wrap.v; } + Value v; + }; - /// Iteratively merge alias sets until a fixed-point. - void mergeAliasesToFixedPoint(); + using EquivalenceClassRangeType = llvm::iterator_range< + llvm::EquivalenceClasses::member_iterator>; + /// Return all values that alias with `v`, including `v` itself. + EquivalenceClassRangeType getAliases(Value v) const; /// Return true if the (ExtractSliceOp, InsertSliceOp) pair match (i.e. /// equivalent operand / result and same offset/sizes/strides specification). @@ -849,24 +856,10 @@ OpOperand &aliasingWrite, const DominanceInfo &domInfo) const; - /// EquivalenceClasses wants comparable elements because it uses std::set. - /// ValueWrapper wraps Value and uses pointer comparison on the defining op.
- /// This is a poor man's comparison but it's not like UnionFind needs ordering - /// anyway .. - struct ValueWrapper { - ValueWrapper(Value val) : v(val) {} - operator Value() const { return v; } - bool operator<(const ValueWrapper &wrap) const { - return v.getImpl() < wrap.v.getImpl(); - } - bool operator==(const ValueWrapper &wrap) const { return v == wrap.v; } - Value v; - }; - /// Auxiliary structure to store all the values a given value aliases with. /// These are the conservative cases that can further decompose into /// "equivalent" buffer relationships. - DenseMap> aliasInfo; + llvm::EquivalenceClasses aliasInfo; /// Auxiliary structure to store all the equivalent buffer classes. llvm::EquivalenceClasses equivalentInfo; @@ -889,19 +882,17 @@ /// Add a new entry for `v` in the `aliasInfo` and `equivalentInfo`. In the /// beginning the alias and equivalence sets only contain `v` itself. void BufferizationAliasInfo::createAliasInfoEntry(Value v) { - DenseSet selfSet; - selfSet.insert(v); - aliasInfo.try_emplace(v, selfSet); + aliasInfo.insert(v); equivalentInfo.insert(v); } /// Insert an info entry for `newValue` and merge its alias set with that of /// `alias`.
void BufferizationAliasInfo::insertNewBufferAlias(Value newValue, Value alias) { - assert(aliasInfo.find(alias) != aliasInfo.end() && "Missing alias entry"); + assert(aliasInfo.findValue(alias) != aliasInfo.end() && + "Missing alias entry"); createAliasInfoEntry(newValue); - mergeAliases(newValue, alias); - mergeAliasesToFixedPoint(); + aliasInfo.unionSets(newValue, alias); } /// Insert an info entry for `newValue` and merge its alias set with that of @@ -920,7 +911,7 @@ LDBG("----Start aliasesNonWriteableBuffer\n"); LDBG("-------for -> #" << operand.getOperandNumber() << ": " << printOperationInfo(operand.getOwner()) << '\n'); - for (Value v : getAliasInfoRef(operand.get())) { + for (Value v : getAliases(operand.get())) { LDBG("-----------examine: " << printValueInfo(v) << '\n'); if (auto bbArg = v.dyn_cast()) { if (getInPlace(bbArg) == InPlaceSpec::True) { @@ -948,7 +939,7 @@ bool BufferizationAliasInfo::aliasesInPlaceWrite(Value value) const { LDBG("----Start aliasesInPlaceWrite\n"); LDBG("-------for : " << printValueInfo(value) << '\n'); - for (Value v : getAliasInfoRef(value)) { + for (Value v : getAliases(value)) { for (auto &use : v.getUses()) { if (bufferizesToMemoryWrite(use, InPlaceSpec::True)) { LDBG("-----------wants to bufferize to inPlace write: " @@ -967,8 +958,7 @@ OpOperand &operand, BufferRelation bufferRelation) { setInPlaceOpResult(result, InPlaceSpec::True); - if (mergeAliases(result, operand.get())) - mergeAliasesToFixedPoint(); + aliasInfo.unionSets(result, operand.get()); // Dump the updated alias analysis. LLVM_DEBUG(dumpAliases()); if (bufferRelation == BufferRelation::Equivalent) @@ -1009,7 +999,7 @@ // opToBufferize is not yet inplace, we want to determine if it can be inplace // so we also consider all its write uses, not just the inplace ones.
DenseSet usesWrite; - for (Value vWrite : getAliasInfoRef(root)) { + for (Value vWrite : getAliases(root)) { for (auto &uWrite : vWrite.getUses()) { if (!bufferizesToMemoryWrite(uWrite)) continue; @@ -1018,7 +1008,7 @@ usesWrite.insert(&uWrite); } } - for (Value vWrite : getAliasInfoRef(result)) + for (Value vWrite : getAliases(result)) for (auto &uWrite : vWrite.getUses()) if (bufferizesToMemoryWrite(uWrite, InPlaceSpec::True)) usesWrite.insert(&uWrite); @@ -1027,12 +1017,12 @@ // opToBufferize is not yet inplace, we want to determine if it can be inplace // so we also consider all read uses of its result. DenseSet usesRead; - auto &aliasListRead = getAliasInfoRef(root); + auto aliasListRead = getAliases(root); for (Value vRead : aliasListRead) for (auto &uRead : vRead.getUses()) if (bufferizesToMemoryRead(uRead)) usesRead.insert(&uRead); - for (Value vRead : getAliasInfoRef(result)) + for (Value vRead : getAliases(result)) for (auto &uRead : vRead.getUses()) if (bufferizesToMemoryRead(uRead)) usesRead.insert(&uRead); @@ -1116,16 +1106,21 @@ } void BufferizationAliasInfo::printAliases(raw_ostream &os) const { - os << "\n/========================== AliasInfo " - "==========================\n"; - for (auto it : aliasInfo) { - os << "|\n| -- source: " << printValueInfo(it.getFirst(), /*prefix=*/false) + os << "\n/===================== AliasInfo =====================\n"; + for (auto it = aliasInfo.begin(), eit = aliasInfo.end(); it != eit; ++it) { + if (!it->isLeader()) + continue; + Value leader = it->getData(); + os << "|\n| -- leader: " << printValueInfo(leader, /*prefix=*/false) << '\n'; - for (auto v : it.getSecond()) - os << "| ---- target: " << printValueInfo(v, /*prefix=*/false) << '\n'; + for (auto mit = aliasInfo.member_begin(it), meit = aliasInfo.member_end(); + mit != meit; ++mit) { + Value v = static_cast(*mit); + os << "| ---- alias member: " << printValueInfo(v, /*prefix=*/false) + << '\n'; + } } - os << "|\n\\====================== End AliasInfo "
- "======================\n\n"; + os << "|\n\\===================== End AliasInfo =====================\n\n"; } void BufferizationAliasInfo::printEquivalences(raw_ostream &os) const { @@ -1148,37 +1143,11 @@ os << "|\n\\***************** End Equivalent Buffers *****************\n\n"; } -DenseSet &BufferizationAliasInfo::getAliasInfoRef(Value v) { - auto it = aliasInfo.find(v); - if (it == aliasInfo.end()) - llvm_unreachable("Missing alias"); - return it->getSecond(); -} - -/// Union all the aliasing sets of all aliases of v1 and v2. -bool BufferizationAliasInfo::mergeAliases(Value v1, Value v2) { - // Avoid invalidation of iterators by pre unioning the aliases for v1 and v2. - bool changed = set_union(getAliasInfoRef(v1), getAliasInfoRef(v2)) || - set_union(getAliasInfoRef(v2), getAliasInfoRef(v1)); - for (auto v : getAliasInfoRef(v1)) - if (v != v1) - changed |= set_union(getAliasInfoRef(v), getAliasInfoRef(v2)); - for (auto v : getAliasInfoRef(v2)) - if (v != v2) - changed |= set_union(getAliasInfoRef(v), getAliasInfoRef(v1)); - return changed; -} - -/// Iteratively merge alias sets until a fixed-point. -void BufferizationAliasInfo::mergeAliasesToFixedPoint() { - while (true) { - bool changed = false; - for (auto it : aliasInfo) - for (auto v : it.getSecond()) - changed |= mergeAliases(it.getFirst(), v); - if (!changed) - break; - } +BufferizationAliasInfo::EquivalenceClassRangeType +BufferizationAliasInfo::getAliases(Value v) const { + auto it = aliasInfo.findValue(aliasInfo.getLeaderValue(v)); + return BufferizationAliasInfo::EquivalenceClassRangeType( + aliasInfo.member_begin(it), aliasInfo.member_end()); } /// This is one particular type of relationship between ops on tensors that