diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h --- a/mlir/include/mlir/IR/Operation.h +++ b/mlir/include/mlir/IR/Operation.h @@ -532,52 +532,20 @@ result.dropAllUses(); } - /// This class implements a use iterator for the Operation. This iterates over - /// all uses of all results. - class UseIterator final - : public llvm::iterator_facade_base< - UseIterator, std::forward_iterator_tag, OpOperand> { - public: - /// Initialize UseIterator for op, specify end to return iterator to last - /// use. - explicit UseIterator(Operation *op, bool end = false); - - using llvm::iterator_facade_base<UseIterator, std::forward_iterator_tag, - OpOperand>::operator++; - UseIterator &operator++(); - OpOperand *operator->() const { return use.getOperand(); } - OpOperand &operator*() const { return *use.getOperand(); } - - bool operator==(const UseIterator &rhs) const { return use == rhs.use; } - bool operator!=(const UseIterator &rhs) const { return !(*this == rhs); } - - private: - void skipOverResultsWithNoUsers(); - - /// The operation whose uses are being iterated over. - Operation *op; - /// The result of op who's uses are being iterated over. - Operation::result_iterator res; - /// The use of the result. - Value::use_iterator use; - }; - using use_iterator = UseIterator; - using use_range = iterator_range<use_iterator>; + using use_iterator = result_range::use_iterator; + using use_range = result_range::use_range; - use_iterator use_begin() { return use_iterator(this); } - use_iterator use_end() { return use_iterator(this, /*end=*/true); } + use_iterator use_begin() { return getResults().use_begin(); } + use_iterator use_end() { return getResults().use_end(); } /// Returns a range of all uses, which is useful for iterating over all uses. - use_range getUses() { return {use_begin(), use_end()}; } + use_range getUses() { return getResults().getUses(); } /// Returns true if this operation has exactly one use.
bool hasOneUse() { return llvm::hasSingleElement(getUses()); } /// Returns true if this operation has no uses. - bool use_empty() { - return llvm::all_of(getOpResults(), - [](OpResult result) { return result.use_empty(); }); - } + bool use_empty() { return getResults().use_empty(); } /// Returns true if the results of this operation are used outside of the /// given block. diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h --- a/mlir/include/mlir/IR/OperationSupport.h +++ b/mlir/include/mlir/IR/OperationSupport.h @@ -903,12 +903,48 @@ public: using RangeBaseT::RangeBaseT; + //===--------------------------------------------------------------------===// + // Types + //===--------------------------------------------------------------------===// + /// Returns the types of the values within this range. using type_iterator = ValueTypeIterator; using type_range = ValueTypeRange; type_range getTypes() const { return {begin(), end()}; } auto getType() const { return getTypes(); } + //===--------------------------------------------------------------------===// + // Uses + //===--------------------------------------------------------------------===// + + class UseIterator; + using use_iterator = UseIterator; + using use_range = iterator_range<use_iterator>; + + /// Returns a range of all uses of results within this range, which is useful + /// for iterating over all uses. + use_range getUses() const; + use_iterator use_begin() const; + use_iterator use_end() const; + + /// Returns true if no results in this range have uses. + bool use_empty() const { + return llvm::all_of(*this, + [](OpResult result) { return result.use_empty(); }); + } + + //===--------------------------------------------------------------------===// + // Users + //===--------------------------------------------------------------------===// + + using user_iterator = ValueUserIterator<use_iterator, OpOperand>; + using user_range = iterator_range<user_iterator>; + + /// Returns a range of all users of results within this range.
+ user_range getUsers() const; + user_iterator user_begin() const; + user_iterator user_end() const; + private: /// See `llvm::detail::indexed_accessor_range_base` for details. static detail::OpResultImpl *offset_base(detail::OpResultImpl *object, @@ -925,6 +961,34 @@ friend RangeBaseT; }; +/// This class implements a use iterator for a range of operation results. +/// This iterates over all uses of all results within the given result range. +class ResultRange::UseIterator final + : public llvm::iterator_facade_base<UseIterator, std::forward_iterator_tag, + OpOperand> { +public: + /// Initialize the UseIterator. Specify `end` to return the past-the-end + /// iterator, otherwise this is an iterator to the first use. + explicit UseIterator(ResultRange results, bool end = false); + + using llvm::iterator_facade_base<UseIterator, std::forward_iterator_tag, + OpOperand>::operator++; + UseIterator &operator++(); + OpOperand *operator->() const { return use.getOperand(); } + OpOperand &operator*() const { return *use.getOperand(); } + + bool operator==(const UseIterator &rhs) const { return use == rhs.use; } + bool operator!=(const UseIterator &rhs) const { return !(*this == rhs); } + +private: + void skipOverResultsWithNoUsers(); + + /// The range of results being iterated over. + ResultRange::iterator it, endIt; + /// The use of the result. + Value::use_iterator use; +}; + //===----------------------------------------------------------------------===// // ValueRange diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -1314,38 +1314,3 @@ OpBuilder opBuilder(builder.getContext()); ensureRegionTerminator(region, opBuilder, loc, buildTerminatorOp); } - -//===----------------------------------------------------------------------===// -// UseIterator -//===----------------------------------------------------------------------===// - -Operation::UseIterator::UseIterator(Operation *op, bool end) - : op(op), res(end ? op->result_end() : op->result_begin()) { - // Only initialize current use if there are results/can be uses.
- if (op->getNumResults()) - skipOverResultsWithNoUsers(); -} - -Operation::UseIterator &Operation::UseIterator::operator++() { - // We increment over uses, if we reach the last use then move to next - // result. - if (use != (*res).use_end()) - ++use; - if (use == (*res).use_end()) { - ++res; - skipOverResultsWithNoUsers(); - } - return *this; -} - -void Operation::UseIterator::skipOverResultsWithNoUsers() { - while (res != op->result_end() && (*res).use_empty()) - ++res; - - // If we are at the last result, then set use to first use of - // first result (sentinel value used for end). - if (res == op->result_end()) - use = {}; - else - use = (*res).use_begin(); -} diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp --- a/mlir/lib/IR/OperationSupport.cpp +++ b/mlir/lib/IR/OperationSupport.cpp @@ -551,6 +551,59 @@ MutableOperandRange::OperandSegment(index, object.second)); } +//===----------------------------------------------------------------------===// +// ResultRange + +ResultRange::use_range ResultRange::getUses() const { + return {use_begin(), use_end()}; +} +ResultRange::use_iterator ResultRange::use_begin() const { + return use_iterator(*this); +} +ResultRange::use_iterator ResultRange::use_end() const { + return use_iterator(*this, /*end=*/true); +} +ResultRange::user_range ResultRange::getUsers() const { + return {user_begin(), user_end()}; +} +ResultRange::user_iterator ResultRange::user_begin() const { + return user_iterator(use_begin()); +} +ResultRange::user_iterator ResultRange::user_end() const { + return user_iterator(use_end()); +} + +ResultRange::UseIterator::UseIterator(ResultRange results, bool end) + : it(end ? results.end() : results.begin()), endIt(results.end()) { + // Only initialize current use if there are results/can be uses.
+ if (it != endIt) + skipOverResultsWithNoUsers(); +} + +ResultRange::UseIterator &ResultRange::UseIterator::operator++() { + // Advance to the next use; once the current result's uses are exhausted, + // move on to the next result that has any uses. + if (use != (*it).use_end()) + ++use; + if (use == (*it).use_end()) { + ++it; + skipOverResultsWithNoUsers(); + } + return *this; +} + +void ResultRange::UseIterator::skipOverResultsWithNoUsers() { + while (it != endIt && (*it).use_empty()) + ++it; + + // If no remaining result has uses, reset `use` to a default-constructed + // iterator; this is the sentinel value that marks the end iterator. + if (it == endIt) + use = {}; + else + use = (*it).use_begin(); +} + //===----------------------------------------------------------------------===// // ValueRange