diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp --- a/mlir/lib/IR/BuiltinAttributes.cpp +++ b/mlir/lib/IR/BuiltinAttributes.cpp @@ -302,6 +302,19 @@ Type type) { if (!Dialect::isValidNamespace(dialect.strref())) return emitError() << "invalid dialect namespace '" << dialect << "'"; + + // Check that the dialect is actually registered. + MLIRContext *context = dialect.getContext(); + if (!context->allowsUnregisteredDialects() && + !context->getLoadedDialect(dialect.strref())) { + return emitError() + << "#" << dialect << "<\"" << attrData << "\"> : " << type + << " attribute created with unregistered dialect. If this is " + "intended, please call allowUnregisteredDialects() on the " + "MLIRContext, or use -allow-unregistered-dialect with " + "mlir-opt"; + } + return success(); } diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp --- a/mlir/lib/IR/BuiltinTypes.cpp +++ b/mlir/lib/IR/BuiltinTypes.cpp @@ -201,6 +201,19 @@ Identifier dialect, StringRef typeData) { if (!Dialect::isValidNamespace(dialect.strref())) return emitError() << "invalid dialect namespace '" << dialect << "'"; + + // Check that the dialect is actually registered. + MLIRContext *context = dialect.getContext(); + if (!context->allowsUnregisteredDialects() && + !context->getLoadedDialect(dialect.strref())) { + return emitError() + << "`!" << dialect << "<\"" << typeData << "\">" + << "` type created with unregistered dialect. If this is " + "intended, please call allowUnregisteredDialects() on the " + "MLIRContext, or use -allow-unregistered-dialect with " + "mlir-opt"; + } + return success(); } diff --git a/mlir/test/IR/attribute.mlir b/mlir/test/IR/attribute.mlir --- a/mlir/test/IR/attribute.mlir +++ b/mlir/test/IR/attribute.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -split-input-file -verify-diagnostics | FileCheck %s +// RUN: mlir-opt %s -split-input-file -allow-unregistered-dialect -verify-diagnostics | FileCheck %s //===----------------------------------------------------------------------===// // Test integer attributes diff --git a/mlir/test/IR/invalid-unregistered.mlir b/mlir/test/IR/invalid-unregistered.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/IR/invalid-unregistered.mlir @@ -0,0 +1,14 @@ +// RUN: mlir-opt %s -split-input-file -verify-diagnostics + +// expected-error @below {{op created with unregistered dialect}} +"unregistered_dialect.op"() : () -> () + +// ----- + +// expected-error @below {{attribute created with unregistered dialect}} +#attr = #unregistered_dialect.attribute + +// ----- + +// expected-error @below {{type created with unregistered dialect}} +!type = type !unregistered_dialect.type diff --git a/mlir/unittests/IR/AttributeTest.cpp b/mlir/unittests/IR/AttributeTest.cpp --- a/mlir/unittests/IR/AttributeTest.cpp +++ b/mlir/unittests/IR/AttributeTest.cpp @@ -150,6 +150,7 @@ TEST(DenseSplatTest, StringSplat) { MLIRContext context; + context.allowUnregisteredDialects(); Type stringType = OpaqueType::get(Identifier::get("test", &context), "string"); StringRef value = "test-string"; @@ -158,6 +159,7 @@ TEST(DenseSplatTest, StringAttrSplat) { MLIRContext context; + context.allowUnregisteredDialects(); Type stringType = OpaqueType::get(Identifier::get("test", &context), "string"); Attribute stringAttr = StringAttr::get("test-string", stringType);