diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -408,8 +408,8 @@ private: /// Helper for sanity checking preconditions for create* methods below. - void checkHasAbstractOperation(const OperationName &name) { - if (LLVM_UNLIKELY(!name.getAbstractOperation())) + void checkHasRegisteredInfo(const OperationName &name) { + if (LLVM_UNLIKELY(!name.isRegistered())) llvm::report_fatal_error( "Building op `" + name.getStringRef() + "` but it isn't registered in this MLIRContext: the dialect may not " @@ -423,7 +423,7 @@ template OpTy create(Location location, Args &&...args) { OperationState state(location, OpTy::getOperationName()); - checkHasAbstractOperation(state.name); + checkHasRegisteredInfo(state.name); OpTy::build(*this, state, std::forward(args)...); auto *op = createOperation(state); auto result = dyn_cast(op); @@ -440,7 +440,7 @@ // Create the operation without using 'createOperation' as we don't want to // insert it yet. OperationState state(location, OpTy::getOperationName()); - checkHasAbstractOperation(state.name); + checkHasRegisteredInfo(state.name); OpTy::build(*this, state, std::forward(args)...); Operation *op = Operation::create(state); diff --git a/mlir/include/mlir/IR/Dialect.h b/mlir/include/mlir/IR/Dialect.h --- a/mlir/include/mlir/IR/Dialect.h +++ b/mlir/include/mlir/IR/Dialect.h @@ -114,7 +114,7 @@ /// Return the hook to parse an operation registered to this dialect, if any. /// By default this will lookup for registered operations and return the - /// `parse()` method registered on the AbstractOperation. Dialects can + /// `parse()` method registered on the RegisteredOperationName. Dialects can /// override this behavior and handle unregistered operations as well. virtual Optional getParseOperationHook(StringRef opName) const; @@ -194,7 +194,7 @@ /// template void addOperations() { (void)std::initializer_list{ - 0, (AbstractOperation::insert(*this), 0)...}; + 0, (RegisteredOperationName::insert(*this), 0)...}; } /// Register a set of type classes with this dialect. diff --git a/mlir/include/mlir/IR/MLIRContext.h b/mlir/include/mlir/IR/MLIRContext.h --- a/mlir/include/mlir/IR/MLIRContext.h +++ b/mlir/include/mlir/IR/MLIRContext.h @@ -20,7 +20,6 @@ } // end namespace llvm namespace mlir { -class AbstractOperation; class DebugActionManager; class DiagnosticEngine; class Dialect; @@ -28,6 +27,7 @@ class InFlightDiagnostic; class Location; class MLIRContextImpl; +class RegisteredOperationName; class StorageUniquer; /// MLIRContext is the top-level object for a collection of MLIR operations. It @@ -172,7 +172,7 @@ /// Return information about all registered operations. This isn't very /// efficient: typically you should ask the operations about their properties /// directly. - std::vector getRegisteredOperations(); + std::vector getRegisteredOperations(); /// Return true if this operation name is registered in this context. bool isOperationRegistered(StringRef name); diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -191,7 +191,7 @@ Operation *state; /// Allow access to internal hook implementation methods. - friend AbstractOperation; + friend RegisteredOperationName; }; // Allow comparing operators. @@ -1585,8 +1585,8 @@ /// Return true if this "op class" can match against the specified operation. static bool classof(Operation *op) { - if (auto *abstractOp = op->getAbstractOperation()) - return TypeID::get() == abstractOp->typeID; + if (auto info = op->getRegisteredInfo()) + return TypeID::get() == info->getTypeID(); #ifndef NDEBUG if (op->getName().getStringRef() == ConcreteType::getOperationName()) llvm::report_fatal_error( @@ -1628,13 +1628,13 @@ /// for the concrete operation. template static void attachInterface(MLIRContext &context) { - AbstractOperation *abstract = AbstractOperation::lookupMutable( + Optional info = RegisteredOperationName::lookup( ConcreteType::getOperationName(), &context); - if (!abstract) + if (!info) llvm::report_fatal_error( "Attempting to attach an interface to an unregistered operation " + ConcreteType::getOperationName() + "."); - abstract->interfaceMap.insert(); + info->attachInterface(); } private: @@ -1673,10 +1673,10 @@ return detail::InterfaceMap::template get...>(); } - /// Return the internal implementations of each of the AbstractOperation + /// Return the internal implementations of each of the OperationName /// hooks. - /// Implementation of `FoldHookFn` AbstractOperation hook. - static AbstractOperation::FoldHookFn getFoldHookFn() { + /// Implementation of `FoldHookFn` OperationName hook. + static OperationName::FoldHookFn getFoldHookFn() { return getFoldHookFnImpl(); } /// The internal implementation of `getFoldHookFn` above that is invoked if @@ -1685,7 +1685,7 @@ static std::enable_if_t, Traits...>::value && detect_has_single_result_fold::value, - AbstractOperation::FoldHookFn> + OperationName::FoldHookFn> getFoldHookFnImpl() { return [](Operation *op, ArrayRef operands, SmallVectorImpl &results) { @@ -1698,7 +1698,7 @@ static std::enable_if_t, Traits...>::value && detect_has_fold::value, - AbstractOperation::FoldHookFn> + OperationName::FoldHookFn> getFoldHookFnImpl() { return [](Operation *op, ArrayRef operands, SmallVectorImpl &results) { @@ -1710,7 +1710,7 @@ template static std::enable_if_t::value && !detect_has_fold::value, - AbstractOperation::FoldHookFn> + OperationName::FoldHookFn> getFoldHookFnImpl() { return [](Operation *op, ArrayRef operands, SmallVectorImpl &results) { @@ -1754,29 +1754,29 @@ return result; } - /// Implementation of `GetCanonicalizationPatternsFn` AbstractOperation hook. - static AbstractOperation::GetCanonicalizationPatternsFn + /// Implementation of `GetCanonicalizationPatternsFn` OperationName hook. + static OperationName::GetCanonicalizationPatternsFn getGetCanonicalizationPatternsFn() { return &ConcreteType::getCanonicalizationPatterns; } /// Implementation of `GetHasTraitFn` - static AbstractOperation::HasTraitFn getHasTraitFn() { + static OperationName::HasTraitFn getHasTraitFn() { return [](TypeID id) { return op_definition_impl::hasTrait(id); }; } - /// Implementation of `ParseAssemblyFn` AbstractOperation hook. - static AbstractOperation::ParseAssemblyFn getParseAssemblyFn() { + /// Implementation of `ParseAssemblyFn` OperationName hook. + static OperationName::ParseAssemblyFn getParseAssemblyFn() { return &ConcreteType::parse; } - /// Implementation of `PrintAssemblyFn` AbstractOperation hook. - static AbstractOperation::PrintAssemblyFn getPrintAssemblyFn() { + /// Implementation of `PrintAssemblyFn` OperationName hook. + static OperationName::PrintAssemblyFn getPrintAssemblyFn() { return getPrintAssemblyFnImpl(); } /// The internal implementation of `getPrintAssemblyFn` that is invoked when /// the concrete operation does not define a `print` method. template static std::enable_if_t::value, - AbstractOperation::PrintAssemblyFn> + OperationName::PrintAssemblyFn> getPrintAssemblyFnImpl() { return [](Operation *op, OpAsmPrinter &printer, StringRef defaultDialect) { return OpState::print(op, printer); @@ -1786,7 +1786,7 @@ /// the concrete operation defines a `print` method. template static std::enable_if_t::value, - AbstractOperation::PrintAssemblyFn> + OperationName::PrintAssemblyFn> getPrintAssemblyFnImpl() { return &printAssembly; } @@ -1795,8 +1795,8 @@ OpState::printOpName(op, p, defaultDialect); return cast(op).print(p); } - /// Implementation of `VerifyInvariantsFn` AbstractOperation hook. - static AbstractOperation::VerifyInvariantsFn getVerifyInvariantsFn() { + /// Implementation of `VerifyInvariantsFn` OperationName hook. + static OperationName::VerifyInvariantsFn getVerifyInvariantsFn() { return &verifyInvariants; } @@ -1816,7 +1816,7 @@ } /// Allow access to internal implementation methods. - friend AbstractOperation; + friend RegisteredOperationName; }; /// This class represents the base of an operation interface. See the definition @@ -1836,22 +1836,22 @@ protected: /// Returns the impl interface instance for the given operation. static typename InterfaceBase::Concept *getInterfaceFor(Operation *op) { - // Access the raw interface from the abstract operation. - auto *abstractOp = op->getAbstractOperation(); - if (abstractOp) { - if (auto *opIface = abstractOp->getInterface()) + OperationName name = op->getName(); + + // Access the raw interface from the operation info. + if (Optional rInfo = name.getRegisteredInfo()) { + if (auto *opIface = rInfo->getInterface()) return opIface; // Fallback to the dialect to provide it with a chance to implement this // interface for this operation. - return abstractOp->dialect.getRegisteredInterfaceForOp( + return rInfo->getDialect().getRegisteredInterfaceForOp( op->getName()); } // Fallback to the dialect to provide it with a chance to implement this // interface for this operation. - Dialect *dialect = op->getName().getDialect(); - return dialect ? dialect->getRegisteredInterfaceForOp( - op->getName()) - : nullptr; + if (Dialect *dialect = name.getDialect()) + return dialect->getRegisteredInterfaceForOp(name); + return nullptr; } /// Allow access to `getInterfaceFor`. diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h --- a/mlir/include/mlir/IR/Operation.h +++ b/mlir/include/mlir/IR/Operation.h @@ -57,14 +57,14 @@ OperationName getName() { return name; } /// If this operation has a registered operation description, return it. - /// Otherwise return null. - const AbstractOperation *getAbstractOperation() { - return getName().getAbstractOperation(); + /// Otherwise return None. + Optional getRegisteredInfo() { + return getName().getRegisteredInfo(); } /// Returns true if this operation has a registered operation description, /// otherwise false. - bool isRegistered() { return getAbstractOperation(); } + bool isRegistered() { return getName().isRegistered(); } /// Remove this operation from its parent block and delete it. void erase(); @@ -466,16 +466,14 @@ /// Returns true if the operation was registered with a particular trait, e.g. /// hasTrait(). template