diff --git a/llvm/include/llvm/ADT/EquivalenceClasses.h b/llvm/include/llvm/ADT/EquivalenceClasses.h --- a/llvm/include/llvm/ADT/EquivalenceClasses.h +++ b/llvm/include/llvm/ADT/EquivalenceClasses.h @@ -30,7 +30,8 @@ /// /// This implementation is an efficient implementation that only stores one copy /// of the element being indexed per entry in the set, and allows any arbitrary -/// type to be indexed (as long as it can be ordered with operator<). +/// type to be indexed (as long as it can be ordered with operator< or a +/// comparator is provided). /// /// Here is a simple example using integers: /// @@ -54,7 +55,7 @@ /// 4 /// 5 1 2 /// -template +template > class EquivalenceClasses { /// ECValue - The EquivalenceClasses data structure is just a set of these. /// Each of these represents a relation for a value. First it stores the @@ -101,22 +102,40 @@ assert(RHS.isLeader() && RHS.getNext() == nullptr && "Not a singleton!"); } - bool operator<(const ECValue &UFN) const { return Data < UFN.Data; } - bool isLeader() const { return (intptr_t)Next & 1; } const ElemTy &getData() const { return Data; } const ECValue *getNext() const { return (ECValue*)((intptr_t)Next & ~(intptr_t)1); } + }; + + /// A wrapper of the comparator, to be passed to the set. + struct ECValueComparator { + using is_transparent = void; + + ECValueComparator() : compare(Compare()) {} + + bool operator()(const ECValue &lhs, const ECValue &rhs) const { + return compare(lhs.Data, rhs.Data); + } + + template + bool operator()(const T &lhs, const ECValue &rhs) const { + return compare(lhs, rhs.Data); + } + + template + bool operator()(const ECValue &lhs, const T &rhs) const { + return compare(lhs.Data, rhs); + } - template - bool operator<(const T &Val) const { return Data < Val; } + const Compare compare; }; /// TheMapping - This implicitly provides a mapping from ElemTy values to the /// ECValues, it just keeps the key as part of the value. 
- std::set TheMapping; + std::set TheMapping; public: EquivalenceClasses() = default; diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h b/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h @@ -122,23 +122,17 @@ void dumpEquivalences() const; private: - /// llvm::EquivalenceClasses wants comparable elements because it uses - /// std::set as the underlying impl. - /// ValueWrapper wraps Value and uses pointer comparison on the defining op. - /// This is a poor man's comparison but it's not like UnionFind needs ordering - /// anyway .. - struct ValueWrapper { - ValueWrapper(Value val) : v(val) {} - operator Value() const { return v; } - bool operator<(const ValueWrapper &wrap) const { - return v.getImpl() < wrap.v.getImpl(); + /// llvm::EquivalenceClasses wants comparable elements. This comparator uses + /// pointer comparison on the underlying value impl. This is a poor man's + /// comparison but it's not like UnionFind needs ordering anyway. + struct ValueComparator { + bool operator()(const Value &lhs, const Value &rhs) const { + return lhs.getImpl() < rhs.getImpl(); } - bool operator==(const ValueWrapper &wrap) const { return v == wrap.v; } - Value v; }; using EquivalenceClassRangeType = llvm::iterator_range< - llvm::EquivalenceClasses::member_iterator>; + llvm::EquivalenceClasses::member_iterator>; /// Check that aliasInfo for `v` exists and return a reference to it. EquivalenceClassRangeType getAliases(Value v) const; @@ -164,10 +158,10 @@ /// Auxiliary structure to store all the values a given value aliases with. /// These are the conservative cases that can further decompose into /// "equivalent" buffer relationships. - llvm::EquivalenceClasses aliasInfo; + llvm::EquivalenceClasses aliasInfo; /// Auxiliary structure to store all the equivalent buffer classes. 
- llvm::EquivalenceClasses equivalentInfo; + llvm::EquivalenceClasses equivalentInfo; }; /// Analyze the `ops` to determine which OpResults are inplaceable. diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp @@ -1213,11 +1213,11 @@ for (auto mit = leaderIt, meit = equivalentInfo.member_end(); mit != meit; ++mit) { auto extractSliceOp = - dyn_cast_or_null(mit->v.getDefiningOp()); + dyn_cast_or_null(mit->getDefiningOp()); if (extractSliceOp && areEquivalentExtractSliceOps(extractSliceOp, insertSliceOp) && getInPlace(extractSliceOp.result()) == InPlaceSpec::True) { - LDBG("\tfound: " << *mit->v.getDefiningOp() << '\n'); + LDBG("\tfound: " << *mit->getDefiningOp() << '\n'); return true; } } @@ -1231,7 +1231,7 @@ auto leaderIt = equivalentInfo.findLeader(v); for (auto mit = leaderIt, meit = equivalentInfo.member_end(); mit != meit; ++mit) { - fun(mit->v); + fun(*mit); } }