Inline OperationName::getAbstractOperation and back InterfaceMap with a
sorted vector instead of a heap-allocated map.

* OperationSupport.h / Operation.cpp: the out-of-line definition of
  OperationName::getAbstractOperation() is removed from Operation.cpp and
  its body (a dyn_cast on `representation`) is defined inline in the
  class declaration.
* InterfaceSupport.h: the std::unique_ptr<llvm::SmallDenseMap<TypeID, void *>>
  member is replaced by a SmallVector<std::pair<TypeID, void *>> that the
  constructor sorts by TypeID::getAsOpaquePointer(). A private
  lookup(TypeID) performs llvm::lower_bound with the same comparator and
  then checks for an exact TypeID match, returning nullptr when the
  interface is absent. The destructor frees every stored model without the
  former null check on the map pointer.

diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -327,7 +327,9 @@
 
   /// If this operation has a registered operation description, return it.
   /// Otherwise return null.
-  const AbstractOperation *getAbstractOperation() const;
+  const AbstractOperation *getAbstractOperation() const {
+    return representation.dyn_cast<const AbstractOperation *>();
+  }
 
   void print(raw_ostream &os) const;
   void dump() const;
diff --git a/mlir/include/mlir/Support/InterfaceSupport.h b/mlir/include/mlir/Support/InterfaceSupport.h
--- a/mlir/include/mlir/Support/InterfaceSupport.h
+++ b/mlir/include/mlir/Support/InterfaceSupport.h
@@ -152,10 +152,8 @@
 public:
   InterfaceMap(InterfaceMap &&) = default;
   ~InterfaceMap() {
-    if (interfaces) {
-      for (auto &it : *interfaces)
-        free(it.second);
-    }
+    for (auto &it : interfaces)
+      free(it.second);
   }
 
   /// Construct an InterfaceMap with the given set of template types. For
@@ -182,15 +180,22 @@
   /// Returns an instance of the concept object for the given interface if it
   /// was registered to this map, null otherwise.
   template <typename T> typename T::Concept *lookup() const {
-    void *inst = interfaces ? interfaces->lookup(T::getInterfaceID()) : nullptr;
-    return reinterpret_cast<typename T::Concept *>(inst);
+    return reinterpret_cast<typename T::Concept *>(lookup(T::getInterfaceID()));
   }
 
 private:
+  /// Compare two TypeID instances by comparing the underlying pointer.
+  static bool compare(TypeID lhs, TypeID rhs) {
+    return lhs.getAsOpaquePointer() < rhs.getAsOpaquePointer();
+  }
+
   InterfaceMap() = default;
   InterfaceMap(MutableArrayRef<std::pair<TypeID, void *>> elements)
-      : interfaces(std::make_unique<llvm::SmallDenseMap<TypeID, void *>>(
-            elements.begin(), elements.end())) {}
+      : interfaces(elements.begin(), elements.end()) {
+    llvm::sort(interfaces, [](const auto &lhs, const auto &rhs) {
+      return compare(lhs.first, rhs.first);
+    });
+  }
 
   template <typename... Ts>
   static InterfaceMap getImpl(std::tuple<Ts...> *) {
@@ -200,9 +205,17 @@
     return InterfaceMap(elements);
   }
 
-  /// The internal map of interfaces. This is constructed statically for each
-  /// set of interfaces.
-  std::unique_ptr<llvm::SmallDenseMap<TypeID, void *>> interfaces;
+  /// Returns an instance of the concept object for the given interface id if it
+  /// was registered to this map, null otherwise.
+  void *lookup(TypeID id) const {
+    auto it = llvm::lower_bound(interfaces, id, [](const auto &it, TypeID id) {
+      return compare(it.first, id);
+    });
+    return (it != interfaces.end() && it->first == id) ? it->second : nullptr;
+  }
+
+  /// A list of interface instances, sorted by TypeID.
+  SmallVector<std::pair<TypeID, void *>> interfaces;
 };
 
 } // end namespace detail
diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -57,10 +57,6 @@
   return representation.get<Identifier>();
 }
 
-const AbstractOperation *OperationName::getAbstractOperation() const {
-  return representation.dyn_cast<const AbstractOperation *>();
-}
-
 OperationName OperationName::getFromOpaquePointer(const void *pointer) {
   return OperationName(
       RepresentationUnion::getFromOpaqueValue(const_cast<void *>(pointer)));