diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h b/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h @@ -139,23 +139,8 @@ void dumpEquivalences() const; private: - /// llvm::EquivalenceClasses wants comparable elements because it uses - /// std::set as the underlying impl. - /// ValueWrapper wraps Value and uses pointer comparison on the defining op. - /// This is a poor man's comparison but it's not like UnionFind needs ordering - /// anyway .. - struct ValueWrapper { - ValueWrapper(Value val) : v(val) {} - operator Value() const { return v; } - bool operator<(const ValueWrapper &wrap) const { - return v.getImpl() < wrap.v.getImpl(); - } - bool operator==(const ValueWrapper &wrap) const { return v == wrap.v; } - Value v; - }; - - using EquivalenceClassRangeType = llvm::iterator_range< - llvm::EquivalenceClasses::member_iterator>; + using EquivalenceClassRangeType = + llvm::iterator_range::member_iterator>; /// Check that aliasInfo for `v` exists and return a reference to it. EquivalenceClassRangeType getAliases(Value v) const; @@ -222,10 +207,10 @@ /// Auxiliary structure to store all the values a given value aliases with. /// These are the conservative cases that can further decompose into /// "equivalent" buffer relationships. - llvm::EquivalenceClasses aliasInfo; + llvm::EquivalenceClasses aliasInfo; /// Auxiliary structure to store all the equivalent buffer classes. - llvm::EquivalenceClasses equivalentInfo; + llvm::EquivalenceClasses equivalentInfo; }; /// Analyze the `ops` to determine which OpResults are inplaceable: diff --git a/mlir/include/mlir/IR/Value.h b/mlir/include/mlir/IR/Value.h --- a/mlir/include/mlir/IR/Value.h +++ b/mlir/include/mlir/IR/Value.h @@ -114,6 +114,7 @@ explicit operator bool() const { return impl; } bool operator==(const Value &other) const { return impl == other.impl; } bool operator!=(const Value &other) const { return !(*this == other); } + bool operator<(const Value &other) const { return impl < other.impl; } /// Return the type of this value. Type getType() const { return impl->getType(); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp @@ -976,11 +976,11 @@ for (auto mit = leaderIt, meit = equivalentInfo.member_end(); mit != meit; ++mit) { auto extractSliceOp = - dyn_cast_or_null(mit->v.getDefiningOp()); + dyn_cast_or_null(mit->getDefiningOp()); if (extractSliceOp && areEquivalentExtractSliceOps(extractSliceOp, insertSliceOp) && getInPlace(extractSliceOp.result()) == InPlaceSpec::True) { - LDBG("\tfound: " << *mit->v.getDefiningOp() << '\n'); + LDBG("\tfound: " << *mit->getDefiningOp() << '\n'); return true; } } @@ -994,7 +994,7 @@ auto leaderIt = equivalentInfo.findLeader(v); for (auto mit = leaderIt, meit = equivalentInfo.member_end(); mit != meit; ++mit) { - fun(mit->v); + fun(*mit); } } @@ -1091,10 +1091,10 @@ auto leaderIt = equivalentInfo.findLeader(valueToClobber); for (auto mit = leaderIt, meit = equivalentInfo.member_end(); mit != meit; ++mit) { - Operation *candidateOp = mit->v.getDefiningOp(); + Operation *candidateOp = mit->getDefiningOp(); if (!candidateOp) continue; - auto maybeAliasingOperand = getAliasingOpOperand(mit->v.cast()); + auto maybeAliasingOperand = getAliasingOpOperand(mit->cast()); if (!maybeAliasingOperand || !*maybeAliasingOperand || !bufferizesToMemoryWrite(**maybeAliasingOperand)) continue;