diff --git a/mlir/docs/Interfaces.md b/mlir/docs/Interfaces.md --- a/mlir/docs/Interfaces.md +++ b/mlir/docs/Interfaces.md @@ -77,8 +77,7 @@ ```c++ Dialect *dialect = ...; -if (DialectInlinerInterface *interface - = dialect->getRegisteredInterface()) { +if (DialectInlinerInterface *interface = dyn_cast(dialect)) { // The dialect has provided an implementation of this interface. ... } diff --git a/mlir/include/mlir/IR/Dialect.h b/mlir/include/mlir/IR/Dialect.h --- a/mlir/include/mlir/IR/Dialect.h +++ b/mlir/include/mlir/IR/Dialect.h @@ -440,11 +440,58 @@ namespace llvm { /// Provide isa functionality for Dialects. -template struct isa_impl { +template +struct isa_impl::value>> { static inline bool doit(const ::mlir::Dialect &dialect) { return mlir::TypeID::get() == dialect.getTypeID(); } }; +template +struct isa_impl< + T, ::mlir::Dialect, + std::enable_if_t::value>> { + static inline bool doit(const ::mlir::Dialect &dialect) { + return const_cast<::mlir::Dialect &>(dialect).getRegisteredInterface(); + } +}; +template +struct cast_retty_impl { + using ret_type = + std::conditional_t::value, T *, + const T *>; +}; +template +struct cast_retty_impl { + using ret_type = + std::conditional_t::value, T &, + const T &>; +}; + +template +struct cast_convert_val { + template + static std::enable_if_t::value, To &> + doitImpl(::mlir::Dialect &dialect) { + return static_cast(dialect); + } + template + static std::enable_if_t::value, + const To &> + doitImpl(::mlir::Dialect &dialect) { + return *dialect.getRegisteredInterface(); + } + + static auto &doit(::mlir::Dialect &dialect) { return doitImpl(dialect); } +}; +template +struct cast_convert_val { + static auto doit(::mlir::Dialect *dialect) { + return &cast_convert_val::doit( + *dialect); + } +}; + } // namespace llvm #endif diff --git a/mlir/lib/Dialect/DLTI/DLTI.cpp b/mlir/lib/Dialect/DLTI/DLTI.cpp --- a/mlir/lib/Dialect/DLTI/DLTI.cpp +++ b/mlir/lib/Dialect/DLTI/DLTI.cpp @@ -231,8 +231,8 @@ // dialect is not loaded for some reason, use the default combinator // that conservatively accepts identical entries only. entriesForID[id] = - dialect ? dialect->getRegisteredInterface() - ->combine(entriesForID[id], kvp.second) + dialect ? cast(dialect)->combine( + entriesForID[id], kvp.second) : DataLayoutDialectInterface::defaultCombine(entriesForID[id], kvp.second); if (!entriesForID[id]) diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp --- a/mlir/lib/IR/BuiltinAttributes.cpp +++ b/mlir/lib/IR/BuiltinAttributes.cpp @@ -1236,8 +1236,7 @@ Dialect *dialect = getContext()->getLoadedDialect(getDialect()); if (!dialect) return true; - auto *interface = - dialect->getRegisteredInterface(); + auto *interface = llvm::dyn_cast(dialect); if (!interface) return true; return failed(interface->decode(*this, result)); diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -506,7 +506,7 @@ if (!dialect) return failure(); - auto *interface = dialect->getRegisteredInterface(); + auto *interface = dyn_cast(dialect); if (!interface) return failure(); diff --git a/mlir/lib/Interfaces/DataLayoutInterfaces.cpp b/mlir/lib/Interfaces/DataLayoutInterfaces.cpp --- a/mlir/lib/Interfaces/DataLayoutInterfaces.cpp +++ b/mlir/lib/Interfaces/DataLayoutInterfaces.cpp @@ -438,8 +438,7 @@ if (!dialect) continue; - const auto *iface = - dialect->getRegisteredInterface(); + const auto *iface = dyn_cast(dialect); if (!iface) { return emitError(loc) << "the '" << dialect->getNamespace() diff --git a/mlir/unittests/IR/DialectTest.cpp b/mlir/unittests/IR/DialectTest.cpp --- a/mlir/unittests/IR/DialectTest.cpp +++ b/mlir/unittests/IR/DialectTest.cpp @@ -68,18 +68,17 @@ MLIRContext context(registry); // Load the TestDialect and check that the interface got registered for it. - auto *testDialect = context.getOrLoadDialect(); + Dialect *testDialect = context.getOrLoadDialect(); ASSERT_TRUE(testDialect != nullptr); - auto *testDialectInterface = - testDialect->getRegisteredInterface(); + auto *testDialectInterface = dyn_cast(testDialect); EXPECT_TRUE(testDialectInterface != nullptr); // Load the SecondTestDialect and check that the interface is not registered // for it. - auto *secondTestDialect = context.getOrLoadDialect(); + Dialect *secondTestDialect = context.getOrLoadDialect(); ASSERT_TRUE(secondTestDialect != nullptr); auto *secondTestDialectInterface = - secondTestDialect->getRegisteredInterface(); + dyn_cast(secondTestDialect); EXPECT_TRUE(secondTestDialectInterface == nullptr); // Use the same mechanism as for delayed registration but for an already @@ -90,7 +89,7 @@ .addDialectInterface(); context.appendDialectRegistry(secondRegistry); secondTestDialectInterface = - secondTestDialect->getRegisteredInterface(); + dyn_cast(secondTestDialect); EXPECT_TRUE(secondTestDialectInterface != nullptr); } @@ -102,10 +101,9 @@ MLIRContext context(registry); // Load the TestDialect and check that the interface got registered for it. - auto *testDialect = context.getOrLoadDialect(); + Dialect *testDialect = context.getOrLoadDialect(); ASSERT_TRUE(testDialect != nullptr); - auto *testDialectInterface = - testDialect->getRegisteredInterface(); + auto *testDialectInterface = dyn_cast(testDialect); EXPECT_TRUE(testDialectInterface != nullptr); // Try adding the same dialect interface again and check that we don't crash @@ -114,8 +112,7 @@ secondRegistry.insert(); secondRegistry.addDialectInterface(); context.appendDialectRegistry(secondRegistry); - testDialectInterface = - testDialect->getRegisteredInterface(); + testDialectInterface = dyn_cast(testDialect); EXPECT_TRUE(testDialectInterface != nullptr); }