diff --git a/mlir/include/mlir/IR/MLIRContext.h b/mlir/include/mlir/IR/MLIRContext.h
--- a/mlir/include/mlir/IR/MLIRContext.h
+++ b/mlir/include/mlir/IR/MLIRContext.h
@@ -176,10 +176,9 @@
   /// emitting diagnostics.
   void printStackTraceOnDiagnostic(bool enable);
 
-  /// Return information about all registered operations. This isn't very
-  /// efficient: typically you should ask the operations about their properties
-  /// directly.
-  std::vector<RegisteredOperationName> getRegisteredOperations();
+  /// Return an array of all registered operations, sorted by operation name.
+  /// The array is owned by the context and may be invalidated when further
+  /// operations are registered.
+  ArrayRef<RegisteredOperationName> getRegisteredOperations();
 
   /// Return true if this operation name is registered in this context.
   bool isOperationRegistered(StringRef name);
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -184,6 +184,10 @@
   /// A vector of operation info specifically for registered operations.
   llvm::StringMap<RegisteredOperationName> registeredOperations;
 
+  /// This is a sorted container of registered operations for a deterministic
+  /// and efficient `getRegisteredOperations` implementation.
+  SmallVector<RegisteredOperationName, 0> sortedRegisteredOperations;
+
   /// A mutex used when accessing operation information.
   llvm::sys::SmartRWMutex<true> operationInfoMutex;
 
@@ -569,24 +573,9 @@
   impl->printStackTraceOnDiagnostic = enable;
 }
 
-/// Return information about all registered operations. This isn't very
-/// efficient, typically you should ask the operations about their properties
-/// directly.
-std::vector<RegisteredOperationName> MLIRContext::getRegisteredOperations() {
-  // We just have the operations in a non-deterministic hash table order. Dump
-  // into a temporary array, then sort it by operation name to get a stable
-  // ordering.
-  auto unwrappedNames = llvm::make_second_range(impl->registeredOperations);
-  std::vector<RegisteredOperationName> result(unwrappedNames.begin(),
-                                              unwrappedNames.end());
-  llvm::array_pod_sort(result.begin(), result.end(),
-                       [](const RegisteredOperationName *lhs,
-                          const RegisteredOperationName *rhs) {
-                         return lhs->getIdentifier().compare(
-                             rhs->getIdentifier());
-                       });
-
-  return result;
+/// Return information about all registered operations.
+ArrayRef<RegisteredOperationName> MLIRContext::getRegisteredOperations() {
+  return impl->sortedRegisteredOperations;
 }
 
 bool MLIRContext::isOperationRegistered(StringRef name) {
@@ -736,8 +725,20 @@
                  << "' is already registered.\n";
     abort();
   }
-  ctxImpl.registeredOperations.try_emplace(name,
-                                           RegisteredOperationName(&impl));
+  auto emplaced = ctxImpl.registeredOperations.try_emplace(
+      name, RegisteredOperationName(&impl));
+  assert(emplaced.second && "operation name registration must be successful");
+
+  // Add emplaced operation name to the sorted operations container. The
+  // predicate must be a strict "less than"; compare() is three-way.
+  RegisteredOperationName &value = emplaced.first->getValue();
+  ctxImpl.sortedRegisteredOperations.insert(
+      llvm::upper_bound(ctxImpl.sortedRegisteredOperations, value,
+                        [](auto &lhs, auto &rhs) {
+                          return lhs.getIdentifier().compare(
+                                     rhs.getIdentifier()) < 0;
+                        }),
+      value);
 
   // Update the registered info for this operation.
   impl.dialect = &dialect;