diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h --- a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h @@ -9,7 +9,6 @@ #ifndef MLIR_DIALECT_TRANSFORM_IR_TRANSFORMDIALECT_H #define MLIR_DIALECT_TRANSFORM_IR_TRANSFORMDIALECT_H -#include "mlir/Dialect/Transform/IR/TransformInterfaces.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Support/LLVM.h" @@ -23,21 +22,7 @@ -/// Asserts that the operations provided as template arguments implement the -/// TransformOpInterface and MemoryEffectsOpInterface. This must be a dynamic -/// assertion since interface implementations may be registered at runtime. -template -static inline void checkImplementsTransformInterface(MLIRContext *context) { - // Since the operation is being inserted into the Transform dialect and the - // dialect does not implement the interface fallback, only check for the op - // itself having the interface implementation. - RegisteredOperationName opName = - *RegisteredOperationName::lookup(OpTy::getOperationName(), context); - assert((opName.hasInterface() || - opName.hasTrait()) && - "non-terminator ops injected into the transform dialect must " - "implement TransformOpInterface"); - assert(opName.hasInterface() && - "ops injected into the transform dialect must implement " - "MemoryEffectsOpInterface"); -} +/// Asserts that the operation registered as `name` in `context` implements +/// TransformOpInterface (terminators exempt) and MemoryEffectOpInterface. +/// This is checked dynamically since interfaces may be registered at runtime. +void checkImplementsTransformOpInterface(StringRef name, MLIRContext *context); /// Asserts that the type provided as template argument implements the /// TransformTypeInterface.
This must be a dynamic assertion since interface @@ -200,6 +185,25 @@ bool buildOnly; }; +template +void TransformDialect::addOperationIfNotRegistered() { + StringRef name = OpTy::getOperationName(); + Optional opName = + RegisteredOperationName::lookup(name, getContext()); + if (!opName) { + addOperations(); +#ifndef NDEBUG + detail::checkImplementsTransformOpInterface(name, getContext()); +#endif // NDEBUG + return; + } + + if (opName->getTypeID() == TypeID::get()) + return; + + reportDuplicateOpRegistration(name); +} + template void TransformDialect::addTypeIfNotRegistered() { // Use the address of the parse method as a proxy for identifying whether we @@ -210,6 +214,8 @@ const ExtensionTypeParsingHook &parsingHook = it->getValue(); if (*parsingHook.target() != &Type::parse) reportDuplicateTypeRegistration(mnemonic); + else + return; } typePrintingHooks.try_emplace( TypeID::get(), +[](mlir::Type type, AsmPrinter &printer) { @@ -217,6 +223,11 @@ cast(type).print(printer); }); addTypes(); + +#ifndef NDEBUG + detail::checkImplementsTransformTypeInterface(TypeID::get(), + getContext()); +#endif // NDEBUG } /// A wrapper for transform dialect extensions that forces them to be diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td --- a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td @@ -322,33 +322,21 @@ std::function; private: - template - void addOperationIfNotRegistered() { - Optional opName = - RegisteredOperationName::lookup(OpTy::getOperationName(), - getContext()); - if (!opName) - return addOperations(); - - if (opName->getTypeID() == TypeID::get()) - return; - - llvm::errs() << "error: extensible dialect operation '" - << OpTy::getOperationName() - << "' is already registered with a mismatching TypeID"; - abort(); - } - /// Registers operations specified as template parameters with this /// dialect.
Checks that they implement the required interfaces. template void addOperationsChecked() { (addOperationIfNotRegistered(),...); - - #ifndef NDEBUG - (detail::checkImplementsTransformInterface(getContext()),...); - #endif // NDEBUG } + + /// Registers a single operation unless one with the same name and TypeID + /// is already registered; reports a fatal error on a TypeID mismatch. Use + /// addOperationsChecked instead of calling this directly. + template + void addOperationIfNotRegistered(); + + /// Reports a repeated registration error of an op with the given name. + [[noreturn]] void reportDuplicateOpRegistration(StringRef opName); /// Registers the types specified as template parameters with the /// Transform dialect. Checks that they meet the requirements for @@ -356,15 +344,10 @@ template void addTypesChecked() { (addTypeIfNotRegistered(), ...); - - #ifndef NDEBUG - (detail::checkImplementsTransformTypeInterface( - TypeID::get(), getContext()), ...); - #endif // NDEBUG } /// Implementation of the type registration for a single type, should /// not be called directly, use addTypesChecked instead. template void addTypeIfNotRegistered(); diff --git a/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp b/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp @@ -9,6 +9,7 @@ #include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Dialect/PDL/IR/PDL.h" #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" +#include "mlir/Dialect/Transform/IR/TransformInterfaces.h" #include "mlir/Dialect/Transform/IR/TransformOps.h" #include "mlir/Dialect/Transform/IR/TransformTypes.h" #include "mlir/IR/DialectImplementation.h" @@ -18,6 +19,22 @@ #include "mlir/Dialect/Transform/IR/TransformDialect.cpp.inc" #ifndef NDEBUG +void transform::detail::checkImplementsTransformOpInterface( + StringRef name, MLIRContext *context) { + // Since the operation is being inserted into the Transform dialect and the + // dialect does not implement the interface fallback, only check for the op + // itself having the interface implementation.
+ RegisteredOperationName opName = + *RegisteredOperationName::lookup(name, context); + assert((opName.hasInterface() || + opName.hasTrait()) && + "non-terminator ops injected into the transform dialect must " + "implement TransformOpInterface"); + assert(opName.hasInterface() && + "ops injected into the transform dialect must implement " + "MemoryEffectsOpInterface"); +} + void transform::detail::checkImplementsTransformTypeInterface( TypeID typeID, MLIRContext *context) { const auto &abstractType = AbstractType::lookup(typeID, context); @@ -76,10 +93,20 @@ StringRef mnemonic) { std::string buffer; llvm::raw_string_ostream msg(buffer); - msg << "error: extensible dialect type '" << mnemonic + msg << "extensible dialect type '" << mnemonic << "' is already registered with a different implementation"; msg.flush(); llvm::report_fatal_error(StringRef(buffer)); } +void transform::TransformDialect::reportDuplicateOpRegistration( + StringRef opName) { + std::string buffer; + llvm::raw_string_ostream msg(buffer); + msg << "extensible dialect operation '" << opName + << "' is already registered with a mismatching TypeID"; + msg.flush(); + llvm::report_fatal_error(StringRef(buffer)); +} + #include "mlir/Dialect/Transform/IR/TransformDialectEnums.cpp.inc"