[mlir] Look up registered operations by name without taking locks

Registered operations are now kept in a StringMap keyed by operation name.
OpBuilder::create/createOrFold resolve the RegisteredOperationName up front
(getCheckRegisteredInfo), OperationName construction consults that map before
acquiring the reader lock when multithreading is enabled, and OperationState
gains a constructor taking an already-resolved OperationName so callers that
hold one (SuperVectorize, VectorUnrollDistribute) no longer pass it through
its string name.

NOTE(review): this patch was recovered from a copy in which all line breaks
and indentation had been collapsed and every `<...>` template-argument list
had been stripped. Line structure was rebuilt to match the hunk line counts;
template arguments and context-line indentation were reconstructed and must
be confirmed against the tree (e.g. `git apply --check`) before applying.

diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -406,22 +406,29 @@
 
 private:
   /// Helper for sanity checking preconditions for create* methods below.
-  void checkHasRegisteredInfo(const OperationName &name) {
-    if (LLVM_UNLIKELY(!name.isRegistered()))
+  /// Returns the registered info for `OpT` in `ctx`, or reports a fatal error
+  /// if `OpT` is not registered in that context.
+  template <typename OpT>
+  RegisteredOperationName getCheckRegisteredInfo(MLIRContext *ctx) {
+    Optional<RegisteredOperationName> opName =
+        RegisteredOperationName::lookup(OpT::getOperationName(), ctx);
+    if (LLVM_UNLIKELY(!opName)) {
       llvm::report_fatal_error(
-          "Building op `" + name.getStringRef() +
+          "Building op `" + OpT::getOperationName() +
           "` but it isn't registered in this MLIRContext: the dialect may not "
           "be loaded or this operation isn't registered by the dialect. See "
           "also https://mlir.llvm.org/getting_started/Faq/"
           "#registered-loaded-dependent-whats-up-with-dialects-management");
+    }
+    return *opName;
   }
 
 public:
   /// Create an operation of specific op type at the current insertion point.
   template <typename OpTy, typename... Args>
   OpTy create(Location location, Args &&...args) {
-    OperationState state(location, OpTy::getOperationName());
-    checkHasRegisteredInfo(state.name);
+    OperationState state(location,
+                         getCheckRegisteredInfo<OpTy>(location.getContext()));
     OpTy::build(*this, state, std::forward<Args>(args)...);
     auto *op = createOperation(state);
     auto result = dyn_cast<OpTy>(op);
@@ -437,8 +444,8 @@
                     Args &&...args) {
     // Create the operation without using 'createOperation' as we don't want to
     // insert it yet.
-    OperationState state(location, OpTy::getOperationName());
-    checkHasRegisteredInfo(state.name);
+    OperationState state(location,
+                         getCheckRegisteredInfo<OpTy>(location.getContext()));
     OpTy::build(*this, state, std::forward<Args>(args)...);
     Operation *op = Operation::create(state);
 
diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -231,9 +231,7 @@
   /// Lookup the registered operation information for the given operation.
   /// Returns None if the operation isn't registered.
   static Optional<RegisteredOperationName> lookup(StringRef name,
-                                                  MLIRContext *ctx) {
-    return OperationName(name, ctx).getRegisteredInfo();
-  }
+                                                  MLIRContext *ctx);
 
   /// Register a new operation in a Dialect object.
   /// This constructor is used by Dialect objects when they register the list of
@@ -582,9 +580,12 @@
 
 public:
   OperationState(Location location, StringRef name);
-
   OperationState(Location location, OperationName name);
 
+  OperationState(Location location, OperationName name, ValueRange operands,
+                 TypeRange types, ArrayRef<NamedAttribute> attributes,
+                 BlockRange successors = {},
+                 MutableArrayRef<std::unique_ptr<Region>> regions = {});
   OperationState(Location location, StringRef name, ValueRange operands,
                  TypeRange types, ArrayRef<NamedAttribute> attributes,
                  BlockRange successors = {},
diff --git a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
--- a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
@@ -1406,9 +1406,9 @@
   // name that works both in scalar mode and vector mode.
   // TODO: Is it worth considering an Operation.clone operation which
   // changes the type so we can promote an Operation with less boilerplate?
-  OperationState vecOpState(op->getLoc(), op->getName().getStringRef(),
-                            vectorOperands, vectorTypes, op->getAttrs(),
-                            /*successors=*/{}, /*regions=*/{});
+  OperationState vecOpState(op->getLoc(), op->getName(), vectorOperands,
+                            vectorTypes, op->getAttrs(), /*successors=*/{},
+                            /*regions=*/{});
   Operation *vecOp = state.builder.createOperation(vecOpState);
   state.registerOpVectorReplacement(op, vecOp);
   return vecOp;
diff --git a/mlir/lib/Dialect/Vector/VectorUnrollDistribute.cpp b/mlir/lib/Dialect/Vector/VectorUnrollDistribute.cpp
--- a/mlir/lib/Dialect/Vector/VectorUnrollDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/VectorUnrollDistribute.cpp
@@ -70,8 +70,7 @@
                                               Operation *op,
                                               ArrayRef<Value> operands,
                                               ArrayRef<Type> resultTypes) {
-  OperationState res(loc, op->getName().getStringRef(), operands, resultTypes,
-                     op->getAttrs());
+  OperationState res(loc, op->getName(), operands, resultTypes, op->getAttrs());
   return builder.createOperation(res);
 }
 
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -182,7 +182,7 @@
   llvm::StringMap<OperationName::Impl> operations;
 
-  /// A vector of operation info specifically for registered operations.
-  SmallVector<RegisteredOperationName> registeredOperations;
+  /// A map from name to operation info for registered operations only.
+  llvm::StringMap<RegisteredOperationName> registeredOperations;
 
   /// A mutex used when accessing operation information.
   llvm::sys::SmartRWMutex<true> operationInfoMutex;
@@ -576,8 +576,9 @@
   // We just have the operations in a non-deterministic hash table order. Dump
   // into a temporary array, then sort it by operation name to get a stable
   // ordering.
-  std::vector<RegisteredOperationName> result(
-      impl->registeredOperations.begin(), impl->registeredOperations.end());
+  auto unwrappedNames = llvm::make_second_range(impl->registeredOperations);
+  std::vector<RegisteredOperationName> result(unwrappedNames.begin(),
+                                              unwrappedNames.end());
   llvm::array_pod_sort(result.begin(), result.end(),
                        [](const RegisteredOperationName *lhs,
                           const RegisteredOperationName *rhs) {
@@ -589,7 +590,7 @@
 }
 
 bool MLIRContext::isOperationRegistered(StringRef name) {
-  return OperationName(name, this).isRegistered();
+  return RegisteredOperationName::lookup(name, this).hasValue();
 }
 
 void Dialect::addType(TypeID typeID, AbstractType &&typeInfo) {
@@ -649,6 +650,15 @@
   // Check for an existing name in read-only mode.
   bool isMultithreadingEnabled = context->isMultithreadingEnabled();
   if (isMultithreadingEnabled) {
+    // Check the registered info map first. In the overwhelmingly common case,
+    // the entry will be in here and it also removes the need to acquire any
+    // locks.
+    auto registeredIt = ctxImpl.registeredOperations.find(name);
+    if (LLVM_LIKELY(registeredIt != ctxImpl.registeredOperations.end())) {
+      impl = registeredIt->second.impl;
+      return;
+    }
+
     llvm::sys::SmartScopedReader<true> contextLock(ctxImpl.operationInfoMutex);
     auto it = ctxImpl.operations.find(name);
     if (it != ctxImpl.operations.end()) {
@@ -676,6 +686,15 @@
 // RegisteredOperationName
 //===----------------------------------------------------------------------===//
 
+Optional<RegisteredOperationName>
+RegisteredOperationName::lookup(StringRef name, MLIRContext *ctx) {
+  auto &impl = ctx->getImpl();
+  auto it = impl.registeredOperations.find(name);
+  if (it != impl.registeredOperations.end())
+    return it->getValue();
+  return llvm::None;
+}
+
 ParseResult
 RegisteredOperationName::parseAssembly(OpAsmParser &parser,
                                        OperationState &result) const {
@@ -717,7 +736,8 @@
                  << "' is already registered.\n";
     abort();
   }
-  ctxImpl.registeredOperations.push_back(RegisteredOperationName(&impl));
+  ctxImpl.registeredOperations.try_emplace(name,
+                                           RegisteredOperationName(&impl));
 
   // Update the registered info for this operation.
   impl.dialect = &dialect;
diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp
--- a/mlir/lib/IR/OperationSupport.cpp
+++ b/mlir/lib/IR/OperationSupport.cpp
@@ -170,12 +170,12 @@
 OperationState::OperationState(Location location, OperationName name)
     : location(location), name(name) {}
 
-OperationState::OperationState(Location location, StringRef name,
+OperationState::OperationState(Location location, OperationName name,
                                ValueRange operands, TypeRange types,
                                ArrayRef<NamedAttribute> attributes,
                                BlockRange successors,
                                MutableArrayRef<std::unique_ptr<Region>> regions)
-    : location(location), name(name, location->getContext()),
+    : location(location), name(name),
       operands(operands.begin(), operands.end()),
       types(types.begin(), types.end()),
       attributes(attributes.begin(), attributes.end()),
@@ -183,6 +183,14 @@
   for (std::unique_ptr<Region> &r : regions)
     this->regions.push_back(std::move(r));
 }
+
+OperationState::OperationState(Location location, StringRef name,
+                               ValueRange operands, TypeRange types,
+                               ArrayRef<NamedAttribute> attributes,
+                               BlockRange successors,
+                               MutableArrayRef<std::unique_ptr<Region>> regions)
+    : OperationState(location, OperationName(name, location.getContext()),
+                     operands, types, attributes, successors, regions) {}
 
 void OperationState::addOperands(ValueRange newOperands) {
   operands.append(newOperands.begin(), newOperands.end());