diff --git a/mlir/include/mlir/IR/Dialect.h b/mlir/include/mlir/IR/Dialect.h
--- a/mlir/include/mlir/IR/Dialect.h
+++ b/mlir/include/mlir/IR/Dialect.h
@@ -258,10 +258,12 @@
 };
 /// Registers all dialects and hooks from the global registries with the
 /// specified MLIRContext.
+/// Note: This method is not thread-safe.
 void registerAllDialects(MLIRContext *context);
 
 /// Utility to register a dialect. Client can register their dialect with the
 /// global registry by calling registerDialect<MyDialect>();
+/// Note: This method is not thread-safe.
 template <typename ConcreteDialect> void registerDialect() {
   Dialect::registerDialectAllocator(TypeID::get<ConcreteDialect>(),
                                     [](MLIRContext *ctx) {
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -270,10 +270,6 @@
   // Other
   //===--------------------------------------------------------------------===//
 
-  /// A general purpose mutex to lock access to parts of the context that do not
-  /// have a more specific mutex, e.g. registry operations.
-  llvm::sys::SmartRWMutex<true> contextMutex;
-
   /// This is a list of dialects that are created referring to this context.
   /// The MLIRContext owns the objects.
   std::vector<std::unique_ptr<Dialect>> dialects;
@@ -425,8 +421,6 @@
 
 /// Return information about all registered IR dialects.
 std::vector<Dialect *> MLIRContext::getRegisteredDialects() {
-  // Lock access to the context registry.
-  ScopedReaderLock registryLock(impl->contextMutex, impl->threadingIsEnabled);
   std::vector<Dialect *> result;
   result.reserve(impl->dialects.size());
   for (auto &dialect : impl->dialects)
@@ -437,9 +431,6 @@
 /// Get a registered IR dialect with the given namespace. If none is found,
 /// then return nullptr.
 Dialect *MLIRContext::getRegisteredDialect(StringRef name) {
-  // Lock access to the context registry.
-  ScopedReaderLock registryLock(impl->contextMutex, impl->threadingIsEnabled);
-
   // Dialects are sorted by name, so we can use binary search for lookup.
   auto it = llvm::lower_bound(
       impl->dialects, name,
@@ -455,9 +446,6 @@
   auto &impl = context->getImpl();
   std::unique_ptr<Dialect> dialect(this);
 
-  // Lock access to the context registry.
-  ScopedWriterLock registryLock(impl.contextMutex, impl.threadingIsEnabled);
-
   // Get the correct insertion position sorted by namespace.
   auto insertPt = llvm::lower_bound(
       impl.dialects, dialect, [](const auto &lhs, const auto &rhs) {
@@ -524,35 +512,26 @@
 /// efficient, typically you should ask the operations about their properties
 /// directly.
 std::vector<AbstractOperation *> MLIRContext::getRegisteredOperations() {
-  std::vector<std::pair<StringRef, AbstractOperation *>> opsToSort;
-
-  { // Lock access to the context registry.
-    ScopedReaderLock registryLock(impl->contextMutex, impl->threadingIsEnabled);
-
-    // We just have the operations in a non-deterministic hash table order. Dump
-    // into a temporary array, then sort it by operation name to get a stable
-    // ordering.
-    llvm::StringMap<AbstractOperation> &registeredOps =
-        impl->registeredOperations;
-
-    opsToSort.reserve(registeredOps.size());
-    for (auto &elt : registeredOps)
-      opsToSort.push_back({elt.first(), &elt.second});
-  }
-
-  llvm::array_pod_sort(opsToSort.begin(), opsToSort.end());
+  // We just have the operations in a non-deterministic hash table order. Dump
+  // them into the result array, then sort it by operation name to get a stable
+  // ordering.
+  llvm::StringMap<AbstractOperation> &registeredOps =
+      impl->registeredOperations;
 
   std::vector<AbstractOperation *> result;
-  result.reserve(opsToSort.size());
-  for (auto &elt : opsToSort)
-    result.push_back(elt.second);
+  result.reserve(registeredOps.size());
+  for (auto &elt : registeredOps)
+    result.push_back(&elt.second);
+  llvm::array_pod_sort(
+      result.begin(), result.end(),
+      [](AbstractOperation *const *lhs, AbstractOperation *const *rhs) {
+        return (*lhs)->name.compare((*rhs)->name);
+      });
+
   return result;
 }
 
 bool MLIRContext::isOperationRegistered(StringRef name) {
-  // Lock access to the context registry.
-  ScopedReaderLock registryLock(impl->contextMutex, impl->threadingIsEnabled);
-
   return impl->registeredOperations.count(name);
 }
 
@@ -561,12 +540,10 @@
          "op name doesn't start with dialect namespace");
   assert(&opInfo.dialect == this && "Dialect object mismatch");
   auto &impl = context->getImpl();
-
-  // Lock access to the context registry.
+  // Read the name before `opInfo` is moved into the registry below; the
+  // error path must not touch the moved-from `opInfo`.
   StringRef opName = opInfo.name;
-  ScopedWriterLock registryLock(impl.contextMutex, impl.threadingIsEnabled);
   if (!impl.registeredOperations.insert({opName, std::move(opInfo)}).second) {
     llvm::errs() << "error: operation named '" << opName
                  << "' is already registered.\n";
     abort();
   }
@@ -574,9 +551,6 @@
 
 void Dialect::addType(TypeID typeID, AbstractType &&typeInfo) {
   auto &impl = context->getImpl();
-
-  // Lock access to the context registry.
-  ScopedWriterLock registryLock(impl.contextMutex, impl.threadingIsEnabled);
   auto *newInfo =
       new (impl.abstractDialectSymbolAllocator.Allocate<AbstractType>())
           AbstractType(std::move(typeInfo));
@@ -586,9 +560,6 @@
 
 void Dialect::addAttribute(TypeID typeID, AbstractAttribute &&attrInfo) {
   auto &impl = context->getImpl();
-
-  // Lock access to the context registry.
-  ScopedWriterLock registryLock(impl.contextMutex, impl.threadingIsEnabled);
   auto *newInfo =
       new (impl.abstractDialectSymbolAllocator.Allocate<AbstractAttribute>())
           AbstractAttribute(std::move(attrInfo));
@@ -612,9 +583,6 @@
 const AbstractOperation *AbstractOperation::lookup(StringRef opName,
                                                    MLIRContext *context) {
   auto &impl = context->getImpl();
-
-  // Lock access to the context registry.
-  ScopedReaderLock registryLock(impl.contextMutex, impl.threadingIsEnabled);
   auto it = impl.registeredOperations.find(opName);
   if (it != impl.registeredOperations.end())
     return &it->second;