diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h --- a/mlir/include/mlir/IR/Operation.h +++ b/mlir/include/mlir/IR/Operation.h @@ -25,11 +25,10 @@ /// operations are organized into operation blocks represented by a 'Block' /// class. class Operation final - : public IRMultiObjectWithUseList, - public llvm::ilist_node_with_parent, - private llvm::TrailingObjects { + : public llvm::ilist_node_with_parent, + private llvm::TrailingObjects { public: /// Create a new Operation with the specific fields. static Operation *create(Location location, OperationName name, @@ -490,6 +489,74 @@ } //===--------------------------------------------------------------------===// + // Uses + //===--------------------------------------------------------------------===// + + /// Drop all uses of results of this operation. + void dropAllUses() { + for (OpResult result : getOpResults()) + result.dropAllUses(); + } + + /// This class implements a use iterator for the Operation. This iterates over + /// all uses of all results. + class UseIterator final + : public llvm::iterator_facade_base< + UseIterator, std::forward_iterator_tag, OpOperand> { + public: + /// Initialize UseIterator for op, specify end to return iterator to last + /// use. + explicit UseIterator(Operation *op, bool end = false); + + UseIterator &operator++(); + OpOperand *operator->() const { return use.getOperand(); } + OpOperand &operator*() const { return *use.getOperand(); } + + bool operator==(const UseIterator &rhs) const { return use == rhs.use; } + bool operator!=(const UseIterator &rhs) const { return !(*this == rhs); } + + private: + void skipOverResultsWithNoUsers(); + + /// The operation whose uses are being iterated over. + Operation *op; + /// The result of op who's uses are being iterated over. + Operation::result_iterator res; + /// The use of the result. + Value::use_iterator use; + }; + using use_iterator = UseIterator; + using use_range = iterator_range; + + use_iterator use_begin() { return use_iterator(this); } + use_iterator use_end() { return use_iterator(this, /*end=*/true); } + + /// Returns a range of all uses, which is useful for iterating over all uses. + use_range getUses() { return {use_begin(), use_end()}; } + + /// Returns true if this operation has exactly one use. + bool hasOneUse() { return llvm::hasSingleElement(getUses()); } + + /// Returns true if this operation has no uses. + bool use_empty() { + return llvm::all_of(getOpResults(), + [](OpResult result) { return result.use_empty(); }); + } + + //===--------------------------------------------------------------------===// + // Users + //===--------------------------------------------------------------------===// + + using user_iterator = ValueUserIterator; + using user_range = iterator_range; + + user_iterator user_begin() { return user_iterator(use_begin()); } + user_iterator user_end() { return user_iterator(use_end()); } + + /// Returns a range of all users. + user_range getUsers() { return {user_begin(), user_end()}; } + + //===--------------------------------------------------------------------===// // Other //===--------------------------------------------------------------------===// @@ -543,13 +610,14 @@ return *getTrailingObjects(); } - /// Returns a raw pointer to the storage for the given trailing result. The - /// given result number should be 0-based relative to the trailing results, - /// and not all of the results of the operation. This method should generally - /// only be used by the 'Value' classes. - detail::TrailingOpResult *getTrailingResult(unsigned trailingResultNumber) { - return getTrailingObjects() + - trailingResultNumber; + /// Returns a pointer to the use list for the given trailing result. + detail::TrailingOpResult *getTrailingResult(unsigned resultNumber) { + return getTrailingObjects() + resultNumber; + } + + /// Returns a pointer to the use list for the given inline result. + detail::InLineOpResult *getInlineResult(unsigned resultNumber) { + return getTrailingObjects() + resultNumber; } /// Provide a 'getParent' method for ilist_node_with_parent methods. @@ -595,15 +663,20 @@ // allow block to access the 'orderIndex' field. friend class Block; - // allow value to access the 'getTrailingResult' method. + // allow value to access the 'ResultStorage' methods. friend class Value; // allow ilist_node_with_parent to access the 'getParent' method. friend class llvm::ilist_node_with_parent; // This stuff is used by the TrailingObjects template. - friend llvm::TrailingObjects; + friend llvm::TrailingObjects; + size_t numTrailingObjects(OverloadToken) const { + return OpResult::getNumInline( + const_cast(this)->getNumResults()); + } size_t numTrailingObjects(OverloadToken) const { return OpResult::getNumTrailing( const_cast(this)->getNumResults()); diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h --- a/mlir/include/mlir/IR/OperationSupport.h +++ b/mlir/include/mlir/IR/OperationSupport.h @@ -467,14 +467,31 @@ } // end namespace detail //===----------------------------------------------------------------------===// -// TrailingOpResult +// ResultStorage //===----------------------------------------------------------------------===// namespace detail { -/// This class provides the implementation for a trailing operation result. -struct TrailingOpResult { - /// The only element is the trailing result number, or the offset from the - /// beginning of the trailing array. +/// This class provides the implementation for an in-line operation result. This +/// is an operation result whose number can be stored inline inside of the bits +/// of an Operation*. +struct InLineOpResult : public IRObjectWithUseList {}; +/// This class provides the implementation for an out-of-line operation result. +/// This is an operation result whose number cannot be stored inline inside of +/// the bits of an Operation*. +struct TrailingOpResult : public IRObjectWithUseList { + TrailingOpResult(uint64_t trailingResultNumber) + : trailingResultNumber(trailingResultNumber) {} + + /// Returns the parent operation of this trailing result. + Operation *getOwner(); + + /// Return the proper result number of this op result. + unsigned getResultNumber() { + return trailingResultNumber + OpResult::getMaxInlineResults(); + } + + /// The trailing result number, or the offset from the beginning of the + /// trailing array. uint64_t trailingResultNumber; }; } // end namespace detail diff --git a/mlir/include/mlir/IR/UseDefLists.h b/mlir/include/mlir/IR/UseDefLists.h --- a/mlir/include/mlir/IR/UseDefLists.h +++ b/mlir/include/mlir/IR/UseDefLists.h @@ -100,86 +100,6 @@ }; //===----------------------------------------------------------------------===// -// IRMultiObjectWithUseList -//===----------------------------------------------------------------------===// - -/// This class represents multiple IR objects with a single use list. This class -/// provides wrapper functionality for manipulating the uses of a single object. -template -class IRMultiObjectWithUseList : public IRObjectWithUseList { -public: - using BaseType = IRObjectWithUseList; - using ValueType = typename OperandType::ValueType; - - /// Drop all uses of `value` from their respective owners. - void dropAllUses(ValueType value) { - assert(this == OperandType::getUseList(value) && - "value not attached to this use list"); - for (OperandType &use : llvm::make_early_inc_range(getUses(value))) - use.drop(); - } - using BaseType::dropAllUses; - - /// Replace all uses of `oldValue` with the new value, updating anything in - /// the IR that uses 'this' to use the other value instead. When this returns - /// there are zero uses of 'this'. - void replaceAllUsesWith(ValueType oldValue, ValueType newValue) { - assert(this == OperandType::getUseList(oldValue) && - "value not attached to this use list"); - assert((!newValue || this != OperandType::getUseList(newValue)) && - "cannot RAUW a value with itself"); - for (OperandType &use : llvm::make_early_inc_range(getUses(oldValue))) - use.set(newValue); - } - using BaseType::replaceAllUsesWith; - - //===--------------------------------------------------------------------===// - // Uses - //===--------------------------------------------------------------------===// - - using filtered_use_iterator = FilteredValueUseIterator; - using filtered_use_range = iterator_range; - - filtered_use_iterator use_begin(ValueType value) const { - return filtered_use_iterator(this->getFirstUse(), value); - } - filtered_use_iterator use_end(ValueType) const { return use_end(); } - filtered_use_range getUses(ValueType value) const { - return {use_begin(value), use_end(value)}; - } - bool hasOneUse(ValueType value) const { - return llvm::hasSingleElement(getUses(value)); - } - bool use_empty(ValueType value) const { - return use_begin(value) == use_end(value); - } - using BaseType::getUses; - using BaseType::hasOneUse; - using BaseType::use_begin; - using BaseType::use_empty; - using BaseType::use_end; - - //===--------------------------------------------------------------------===// - // Users - //===--------------------------------------------------------------------===// - - using filtered_user_iterator = - ValueUserIterator; - using filtered_user_range = iterator_range; - - filtered_user_iterator user_begin(ValueType value) const { - return {use_begin(value)}; - } - filtered_user_iterator user_end(ValueType value) const { return {use_end()}; } - filtered_user_range getUsers(ValueType value) const { - return {user_begin(value), user_end(value)}; - } - using BaseType::getUsers; - using BaseType::user_begin; - using BaseType::user_end; -}; - -//===----------------------------------------------------------------------===// // IROperand //===----------------------------------------------------------------------===// @@ -209,6 +129,9 @@ insertIntoCurrent(); } + /// Returns true if this operand contains the given value. + bool is(ValueType other) const { return value == other; } + /// Return the owner of this operand. Operation *getOwner() { return owner; } Operation *getOwner() const { return owner; } @@ -310,6 +233,7 @@ OpaqueValue(std::nullptr_t = nullptr) : impl(nullptr) {} OpaqueValue(const OpaqueValue &) = default; OpaqueValue &operator=(const OpaqueValue &) = default; + bool operator==(const OpaqueValue &other) const { return impl == other.impl; } operator bool() const { return impl; } /// Implicit conversion back to 'Value'. @@ -320,7 +244,8 @@ }; } // namespace detail -/// A reference to a value, suitable for use as an operand of an operation. +/// This class represents an operand of an operation. Instances of this class +/// contain a reference to a specific `Value`. class OpOperand : public IROperand { public: using IROperand::IROperand; @@ -342,34 +267,33 @@ // ValueUseIterator //===----------------------------------------------------------------------===// -namespace detail { -/// A base iterator class that allows for iterating over the uses of a value. -/// This is templated to allow for derived iterators to override specific -/// iterator methods. -template -class ValueUseIteratorImpl - : public llvm::iterator_facade_base +class ValueUseIterator + : public llvm::iterator_facade_base, + std::forward_iterator_tag, OperandType> { public: - template - ValueUseIteratorImpl(const ValueUseIteratorImpl &other) - : current(other.getOperand()) {} - ValueUseIteratorImpl(OperandType *current = nullptr) : current(current) {} + ValueUseIterator(OperandType *current = nullptr) : current(current) {} + /// Returns the user that owns this use. Operation *getUser() const { return current->getOwner(); } - OperandType *getOperand() const { return current; } + /// Returns the current operands. + OperandType *getOperand() const { return current; } OperandType &operator*() const { return *current; } - using llvm::iterator_facade_base, + std::forward_iterator_tag, OperandType>::operator++; - ValueUseIteratorImpl &operator++() { + ValueUseIterator &operator++() { assert(current && "incrementing past end()!"); current = (OperandType *)current->getNextOperandUsingThisValue(); return *this; } - bool operator==(const ValueUseIteratorImpl &rhs) const { + bool operator==(const ValueUseIterator &rhs) const { return current == rhs.current; } @@ -377,63 +301,12 @@ OperandType *current; }; -} // end namespace detail - -/// An iterator over all of the uses of an IR object. -template -class ValueUseIterator - : public detail::ValueUseIteratorImpl, - OperandType> { -public: - using detail::ValueUseIteratorImpl, - OperandType>::ValueUseIteratorImpl; -}; - -/// This class represents an iterator of the uses of a IR object that optionally -/// filters on a specific sub-value. This allows for filtering the uses of an -/// IRMultiObjectWithUseList. -template -class FilteredValueUseIterator - : public detail::ValueUseIteratorImpl, - OperandType> { -public: - using BaseT = - detail::ValueUseIteratorImpl, - OperandType>; - - FilteredValueUseIterator() = default; - FilteredValueUseIterator(const ValueUseIterator &it) - : BaseT(it), filterVal(nullptr) {} - FilteredValueUseIterator(OperandType *current, - typename OperandType::ValueType filterVal) - : BaseT(current), filterVal(filterVal) { - findNextValid(); - } - - using BaseT::operator++; - FilteredValueUseIterator &operator++() { - BaseT::operator++(); - findNextValid(); - return *this; - } - -private: - void findNextValid() { - if (!filterVal) - return; - while (this->current && ((OperandType *)this->current)->get() != filterVal) - BaseT::operator++(); - } - - /// An optional value to use to filter specific uses. - typename OperandType::ValueType filterVal; -}; - //===----------------------------------------------------------------------===// // ValueUserIterator //===----------------------------------------------------------------------===// -/// An iterator over all users of a ValueBase. +/// An iterator over the users of an IRObject. This is a wrapper iterator around +/// a specific use iterator. template class ValueUserIterator final : public llvm::mapped_iterator( it, &unwrap) {} diff --git a/mlir/include/mlir/IR/Value.h b/mlir/include/mlir/IR/Value.h --- a/mlir/include/mlir/IR/Value.h +++ b/mlir/include/mlir/IR/Value.h @@ -149,7 +149,7 @@ // Uses /// This class implements an iterator over the uses of a value. - using use_iterator = FilteredValueUseIterator; + using use_iterator = ValueUseIterator; using use_range = iterator_range; use_iterator use_begin() const; @@ -300,12 +300,21 @@ /// Returns the number of this result. unsigned getResultNumber() const; + /// Returns the maximum number of results that can be stored inline. + static unsigned getMaxInlineResults() { + return static_cast(Kind::TrailingOpResult); + } + private: /// Given a number of operation results, returns the number that need to be + /// stored inline. + static unsigned getNumInline(unsigned numResults); + + /// Given a number of operation results, returns the number that need to be /// stored as trailing. static unsigned getNumTrailing(unsigned numResults); - /// Allow access to `create` and `destroy`. + /// Allow access to constructor. friend Operation; }; diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -103,14 +103,16 @@ bool resizableOperandList) { // We only need to allocate additional memory for a subset of results. unsigned numTrailingResults = OpResult::getNumTrailing(resultTypes.size()); + unsigned numInlineResults = OpResult::getNumInline(resultTypes.size()); unsigned numSuccessors = successors.size(); unsigned numOperands = operands.size(); // Compute the byte size for the operation and the operand storage. - auto byteSize = totalSizeToAlloc( - numTrailingResults, numSuccessors, numRegions, - /*detail::OperandStorage*/ 1); + auto byteSize = + totalSizeToAlloc( + numInlineResults, numTrailingResults, numSuccessors, numRegions, + /*detail::OperandStorage*/ 1); byteSize += llvm::alignTo(detail::OperandStorage::additionalAllocSize( numOperands, resizableOperandList), alignof(Operation)); @@ -123,16 +125,11 @@ assert((numSuccessors == 0 || !op->isKnownNonTerminator()) && "unexpected successors in a non-terminator operation"); - // Initialize the trailing results. - if (LLVM_UNLIKELY(numTrailingResults > 0)) { - // We initialize the trailing results with their result number. This makes - // 'getResultNumber' checks much more efficient. The main purpose for these - // results is to give an anchor to the main operation anyways, so this is - // purely an optimization. - auto *trailingResultIt = op->getTrailingObjects(); - for (unsigned i = 0; i != numTrailingResults; ++i, ++trailingResultIt) - trailingResultIt->trailingResultNumber = i; - } + // Initialize the results. + for (unsigned i = 0; i < numInlineResults; ++i) + new (op->getInlineResult(i)) detail::InLineOpResult(); + for (unsigned i = 0; i < numTrailingResults; ++i) + new (op->getTrailingResult(i)) detail::TrailingOpResult(i); // Initialize the regions. for (unsigned i = 0; i != numRegions; ++i) @@ -1072,3 +1069,38 @@ block.push_back(buildTerminatorOp()); } + +//===----------------------------------------------------------------------===// +// UseIterator +//===----------------------------------------------------------------------===// + +Operation::UseIterator::UseIterator(Operation *op, bool end) + : op(op), res(end ? op->result_end() : op->result_begin()) { + // Only initialize current use if there are results/can be uses. + if (op->getNumResults()) + skipOverResultsWithNoUsers(); +} + +Operation::UseIterator &Operation::UseIterator::operator++() { + // We increment over uses, if we reach the last use then move to next + // result. + if (use != (*res).use_end()) + ++use; + if (use == (*res).use_end()) { + ++res; + skipOverResultsWithNoUsers(); + } + return *this; +} + +void Operation::UseIterator::skipOverResultsWithNoUsers() { + while (res != op->result_end() && (*res).use_empty()) + ++res; + + // If we are at the last result, then set use to first use of + // first result (sentinel value used for end). + if (res == op->result_end()) + use = {}; + else + use = (*res).use_begin(); +} diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp --- a/mlir/lib/IR/OperationSupport.cpp +++ b/mlir/lib/IR/OperationSupport.cpp @@ -133,6 +133,22 @@ } //===----------------------------------------------------------------------===// +// ResultStorage +//===----------------------------------------------------------------------===// + +/// Returns the parent operation of this trailing result. +Operation *detail::TrailingOpResult::getOwner() { + // We need to do some arithmetic to get the operation pointer. Move the + // trailing owner to the start of the array. + TrailingOpResult *trailingIt = this - trailingResultNumber; + + // Move the owner past the inline op results to get to the operation. + auto *inlineResultIt = reinterpret_cast(trailingIt) - + OpResult::getMaxInlineResults(); + return reinterpret_cast(inlineResultIt) - 1; +} + +//===----------------------------------------------------------------------===// // Operation Value-Iterators //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/Value.cpp b/mlir/lib/IR/Value.cpp --- a/mlir/lib/IR/Value.cpp +++ b/mlir/lib/IR/Value.cpp @@ -11,10 +11,12 @@ #include "mlir/IR/Operation.h" #include "mlir/IR/StandardTypes.h" #include "llvm/ADT/SmallPtrSet.h" + using namespace mlir; +using namespace mlir::detail; /// Construct a value. -Value::Value(detail::BlockArgumentImpl *impl) +Value::Value(BlockArgumentImpl *impl) : ownerAndKind(impl, Kind::BlockArgument) {} Value::Value(Operation *op, unsigned resultNo) { assert(op->getNumResults() > resultNo && "invalid result number"); @@ -23,12 +25,9 @@ return; } - // If we can't pack the result directly, we need to represent this as a - // trailing result. - unsigned trailingResultNo = - resultNo - static_cast(Kind::TrailingOpResult); - ownerAndKind = {op->getTrailingResult(trailingResultNo), - Kind::TrailingOpResult}; + // If we can't pack the result directly, grab the use list from the parent op. + unsigned trailingNo = resultNo - OpResult::getMaxInlineResults(); + ownerAndKind = {op->getTrailingResult(trailingNo), Kind::TrailingOpResult}; } /// Return the type of this value. @@ -96,30 +95,23 @@ IRObjectWithUseList *Value::getUseList() const { if (BlockArgument arg = dyn_cast()) return arg.getImpl(); - return cast().getOwner(); + if (getKind() != Kind::TrailingOpResult) { + OpResult result = cast(); + return result.getOwner()->getInlineResult(result.getResultNumber()); + } + + // Otherwise this is a trailing operation result, which contains a use list. + return reinterpret_cast(ownerAndKind.getPointer()); } /// Drop all uses of this object from their respective owners. -void Value::dropAllUses() const { - if (BlockArgument arg = dyn_cast()) - return arg.getImpl()->dropAllUses(); - Operation *owner = cast().getOwner(); - if (owner->hasSingleResult) - return owner->dropAllUses(); - return owner->dropAllUses(*this); -} +void Value::dropAllUses() const { return getUseList()->dropAllUses(); } /// Replace all uses of 'this' value with the new value, updating anything in /// the IR that uses 'this' to use the other value instead. When this returns /// there are zero uses of 'this'. void Value::replaceAllUsesWith(Value newValue) const { - if (BlockArgument arg = dyn_cast()) - return arg.getImpl()->replaceAllUsesWith(newValue); - Operation *owner = cast().getOwner(); - IRMultiObjectWithUseList *useList = owner; - if (owner->hasSingleResult) - return useList->replaceAllUsesWith(newValue); - useList->replaceAllUsesWith(*this, newValue); + return getUseList()->replaceAllUsesWith(newValue); } /// Replace all uses of 'this' value with the new value, updating anything in @@ -137,28 +129,14 @@ // Uses auto Value::use_begin() const -> use_iterator { - if (BlockArgument arg = dyn_cast()) - return arg.getImpl()->use_begin(); - Operation *owner = cast().getOwner(); - return owner->hasSingleResult ? use_iterator(owner->use_begin()) - : owner->use_begin(*this); + return getUseList()->use_begin(); } /// Returns true if this value has exactly one use. -bool Value::hasOneUse() const { - if (BlockArgument arg = dyn_cast()) - return arg.getImpl()->hasOneUse(); - Operation *owner = cast().getOwner(); - return owner->hasSingleResult ? owner->hasOneUse() : owner->hasOneUse(*this); -} +bool Value::hasOneUse() const { return getUseList()->hasOneUse(); } /// Returns true if this value has no uses. -bool Value::use_empty() const { - if (BlockArgument arg = dyn_cast()) - return arg.getImpl()->use_empty(); - Operation *owner = cast().getOwner(); - return owner->hasSingleResult ? owner->use_empty() : owner->use_empty(*this); -} +bool Value::use_empty() const { return getUseList()->use_empty(); } //===----------------------------------------------------------------------===// // OpResult @@ -167,18 +145,12 @@ /// Returns the operation that owns this result. Operation *OpResult::getOwner() const { // If the result is in-place, the `owner` is the operation. + void *owner = ownerAndKind.getPointer(); if (LLVM_LIKELY(getKind() != Kind::TrailingOpResult)) - return reinterpret_cast(ownerAndKind.getPointer()); + return static_cast(owner); - // Otherwise, we need to do some arithmetic to get the operation pointer. - // Move the trailing owner to the start of the array. - auto *trailingIt = - static_cast(ownerAndKind.getPointer()); - trailingIt -= trailingIt->trailingResultNumber; - - // This point is the first trailing object after the operation. So all we need - // to do here is adjust for the operation size. - return reinterpret_cast(trailingIt) - 1; + // Otherwise, query the trailing result for the owner. + return static_cast(owner)->getOwner(); } /// Return the result number of this result. @@ -186,20 +158,23 @@ // If the result is in-place, we can use the kind directly. if (LLVM_LIKELY(getKind() != Kind::TrailingOpResult)) return static_cast(ownerAndKind.getInt()); - // Otherwise, we add the number of inline results to the trailing owner. - auto *trailingIt = - static_cast(ownerAndKind.getPointer()); - unsigned trailingNumber = trailingIt->trailingResultNumber; - return trailingNumber + static_cast(Kind::TrailingOpResult); + // Otherwise, query the trailing result. + auto *result = static_cast(ownerAndKind.getPointer()); + return result->getResultNumber(); +} + +/// Given a number of operation results, returns the number that need to be +/// stored inline. +unsigned OpResult::getNumInline(unsigned numResults) { + return std::min(numResults, getMaxInlineResults()); } /// Given a number of operation results, returns the number that need to be /// stored as trailing. unsigned OpResult::getNumTrailing(unsigned numResults) { // If we can pack all of the results, there is no need for additional storage. - if (numResults <= static_cast(Kind::TrailingOpResult)) - return 0; - return numResults - static_cast(Kind::TrailingOpResult); + unsigned maxInline = getMaxInlineResults(); + return numResults <= maxInline ? 0 : numResults - maxInline; } //===----------------------------------------------------------------------===// @@ -227,12 +202,12 @@ /// Return the current value being used by this operand. Value OpOperand::get() const { - return IROperand::get(); + return IROperand::get(); } /// Set the operand to the given value. void OpOperand::set(Value value) { - IROperand::set(value); + IROperand::set(value); } /// Return which operand this is in the operand list. @@ -241,14 +216,13 @@ } //===----------------------------------------------------------------------===// -// detail::OpaqueValue +// OpaqueValue //===----------------------------------------------------------------------===// /// Implicit conversion from 'Value'. -detail::OpaqueValue::OpaqueValue(Value value) - : impl(value.getAsOpaquePointer()) {} +OpaqueValue::OpaqueValue(Value value) : impl(value.getAsOpaquePointer()) {} /// Implicit conversion back to 'Value'. -detail::OpaqueValue::operator Value() const { +OpaqueValue::operator Value() const { return Value::getFromOpaquePointer(impl); }