diff --git a/mlir/examples/toy/Ch2/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch2/mlir/MLIRGen.cpp --- a/mlir/examples/toy/Ch2/mlir/MLIRGen.cpp +++ b/mlir/examples/toy/Ch2/mlir/MLIRGen.cpp @@ -93,7 +93,7 @@ /// Helper conversion for a Toy AST location to an MLIR location. mlir::Location loc(Location loc) { - return builder.getFileLineColLoc(builder.getIdentifier(*loc.file), loc.line, + return mlir::FileLineColLoc::get(builder.getIdentifier(*loc.file), loc.line, loc.col); } diff --git a/mlir/examples/toy/Ch3/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch3/mlir/MLIRGen.cpp --- a/mlir/examples/toy/Ch3/mlir/MLIRGen.cpp +++ b/mlir/examples/toy/Ch3/mlir/MLIRGen.cpp @@ -93,7 +93,7 @@ /// Helper conversion for a Toy AST location to an MLIR location. mlir::Location loc(Location loc) { - return builder.getFileLineColLoc(builder.getIdentifier(*loc.file), loc.line, + return mlir::FileLineColLoc::get(builder.getIdentifier(*loc.file), loc.line, loc.col); } diff --git a/mlir/examples/toy/Ch4/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch4/mlir/MLIRGen.cpp --- a/mlir/examples/toy/Ch4/mlir/MLIRGen.cpp +++ b/mlir/examples/toy/Ch4/mlir/MLIRGen.cpp @@ -93,7 +93,7 @@ /// Helper conversion for a Toy AST location to an MLIR location. mlir::Location loc(Location loc) { - return builder.getFileLineColLoc(builder.getIdentifier(*loc.file), loc.line, + return mlir::FileLineColLoc::get(builder.getIdentifier(*loc.file), loc.line, loc.col); } diff --git a/mlir/examples/toy/Ch5/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch5/mlir/MLIRGen.cpp --- a/mlir/examples/toy/Ch5/mlir/MLIRGen.cpp +++ b/mlir/examples/toy/Ch5/mlir/MLIRGen.cpp @@ -93,7 +93,7 @@ /// Helper conversion for a Toy AST location to an MLIR location. 
mlir::Location loc(Location loc) { - return builder.getFileLineColLoc(builder.getIdentifier(*loc.file), loc.line, + return mlir::FileLineColLoc::get(builder.getIdentifier(*loc.file), loc.line, loc.col); } diff --git a/mlir/examples/toy/Ch6/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch6/mlir/MLIRGen.cpp --- a/mlir/examples/toy/Ch6/mlir/MLIRGen.cpp +++ b/mlir/examples/toy/Ch6/mlir/MLIRGen.cpp @@ -93,7 +93,7 @@ /// Helper conversion for a Toy AST location to an MLIR location. mlir::Location loc(Location loc) { - return builder.getFileLineColLoc(builder.getIdentifier(*loc.file), loc.line, + return mlir::FileLineColLoc::get(builder.getIdentifier(*loc.file), loc.line, loc.col); } diff --git a/mlir/examples/toy/Ch7/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch7/mlir/MLIRGen.cpp --- a/mlir/examples/toy/Ch7/mlir/MLIRGen.cpp +++ b/mlir/examples/toy/Ch7/mlir/MLIRGen.cpp @@ -113,7 +113,7 @@ /// Helper conversion for a Toy AST location to an MLIR location. mlir::Location loc(Location loc) { - return builder.getFileLineColLoc(builder.getIdentifier(*loc.file), loc.line, + return mlir::FileLineColLoc::get(builder.getIdentifier(*loc.file), loc.line, loc.col); } diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -56,8 +56,6 @@ // Locations. Location getUnknownLoc(); - Location getFileLineColLoc(Identifier filename, unsigned line, - unsigned column); Location getFusedLoc(ArrayRef<Location> locs, Attribute metadata = Attribute()); diff --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h --- a/mlir/include/mlir/IR/BuiltinAttributes.h +++ b/mlir/include/mlir/IR/BuiltinAttributes.h @@ -296,8 +296,7 @@ using Base::getChecked; /// Get or create a new OpaqueAttr with the provided dialect and string data. 
- static OpaqueAttr get(MLIRContext *context, Identifier dialect, - StringRef attrData, Type type); + static OpaqueAttr get(Identifier dialect, StringRef attrData, Type type); /// Get or create a new OpaqueAttr with the provided dialect and string data. /// If the given identifier is not a valid namespace for a dialect, then a diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td --- a/mlir/include/mlir/IR/BuiltinTypes.td +++ b/mlir/include/mlir/IR/BuiltinTypes.td @@ -293,6 +293,15 @@ "Identifier":$dialectNamespace, StringRefParameter<"">:$typeData ); + + let builders = [ + TypeBuilderWithInferredContext<(ins + "Identifier":$dialectNamespace, CArg<"StringRef", "{}">:$typeData + ), [{ + return $_get(dialectNamespace.getContext(), dialectNamespace, typeData); + }]> + ]; + let skipDefaultBuilders = 1; let genVerifyDecl = 1; } diff --git a/mlir/include/mlir/IR/Location.h b/mlir/include/mlir/IR/Location.h --- a/mlir/include/mlir/IR/Location.h +++ b/mlir/include/mlir/IR/Location.h @@ -129,8 +129,7 @@ using Base::Base; /// Return a uniqued FileLineCol location object. - static Location get(Identifier filename, unsigned line, unsigned column, - MLIRContext *context); + static Location get(Identifier filename, unsigned line, unsigned column); static Location get(StringRef filename, unsigned line, unsigned column, MLIRContext *context); @@ -174,7 +173,7 @@ static Location get(Identifier name, Location child); /// Return a uniqued name location object with an unknown child. - static Location get(Identifier name, MLIRContext *context); + static Location get(Identifier name); /// Return the name identifier. 
Identifier getName() const; diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -491,7 +491,7 @@ class OpaqueType<string dialect, string name, string summary> : Type<CPred<"isOpaqueTypeWithName($_self, \""#dialect#"\", \""#name#"\")">, summary, "::mlir::OpaqueType">, - BuildableType<"::mlir::OpaqueType::get($_builder.getContext(), " + BuildableType<"::mlir::OpaqueType::get(" "$_builder.getIdentifier(\"" # dialect # "\"), \"" # name # "\")">; diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h --- a/mlir/include/mlir/IR/OperationSupport.h +++ b/mlir/include/mlir/IR/OperationSupport.h @@ -314,7 +314,15 @@ OperationName(StringRef name, MLIRContext *context); /// Return the name of the dialect this operation is registered to. - StringRef getDialect() const; + StringRef getDialectNamespace() const; + + /// Return the Dialect this operation is registered to if it is loaded in the + /// context, or nullptr if the dialect isn't loaded. + Dialect *getDialect() const { + if (const auto *abstractOp = getAbstractOperation()) + return &abstractOp->dialect; + return representation.get<Identifier>().getDialect(); + } /// Return the operation name with dialect name stripped, if it has one. 
StringRef stripDialect() const; diff --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp --- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp @@ -163,9 +163,9 @@ MlirAttribute mlirOpaqueAttrGet(MlirContext ctx, MlirStringRef dialectNamespace, intptr_t dataLength, const char *data, MlirType type) { - return wrap(OpaqueAttr::get( - unwrap(ctx), Identifier::get(unwrap(dialectNamespace), unwrap(ctx)), - StringRef(data, dataLength), unwrap(type))); + return wrap( + OpaqueAttr::get(Identifier::get(unwrap(dialectNamespace), unwrap(ctx)), + StringRef(data, dataLength), unwrap(type))); } MlirStringRef mlirOpaqueAttrGetDialectNamespace(MlirAttribute attr) { diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -29,11 +29,6 @@ Location Builder::getUnknownLoc() { return UnknownLoc::get(context); } -Location Builder::getFileLineColLoc(Identifier filename, unsigned line, - unsigned column) { - return FileLineColLoc::get(filename, line, column, context); -} - Location Builder::getFusedLoc(ArrayRef<Location> locs, Attribute metadata) { return FusedLoc::get(locs, metadata, context); } diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp --- a/mlir/lib/IR/BuiltinAttributes.cpp +++ b/mlir/lib/IR/BuiltinAttributes.cpp @@ -382,9 +382,8 @@ // OpaqueAttr //===----------------------------------------------------------------------===// -OpaqueAttr OpaqueAttr::get(MLIRContext *context, Identifier dialect, - StringRef attrData, Type type) { - return Base::get(context, dialect, attrData, type); +OpaqueAttr OpaqueAttr::get(Identifier dialect, StringRef attrData, Type type) { + return Base::get(dialect.getContext(), dialect, attrData, type); } OpaqueAttr OpaqueAttr::getChecked(function_ref<InFlightDiagnostic()> emitError, diff --git a/mlir/lib/IR/Dialect.cpp b/mlir/lib/IR/Dialect.cpp --- a/mlir/lib/IR/Dialect.cpp +++ b/mlir/lib/IR/Dialect.cpp @@ 
-127,8 +127,8 @@ Type Dialect::parseType(DialectAsmParser &parser) const { // If this dialect allows unknown types, then represent this with OpaqueType. if (allowsUnknownTypes()) { - auto ns = Identifier::get(getNamespace(), getContext()); - return OpaqueType::get(getContext(), ns, parser.getFullSymbolSpec()); + Identifier ns = Identifier::get(getNamespace(), getContext()); + return OpaqueType::get(ns, parser.getFullSymbolSpec()); } parser.emitError(parser.getNameLoc()) diff --git a/mlir/lib/IR/Location.cpp b/mlir/lib/IR/Location.cpp --- a/mlir/lib/IR/Location.cpp +++ b/mlir/lib/IR/Location.cpp @@ -48,14 +48,14 @@ //===----------------------------------------------------------------------===// Location FileLineColLoc::get(Identifier filename, unsigned line, - unsigned column, MLIRContext *context) { - return Base::get(context, filename, line, column); + unsigned column) { + return Base::get(filename.getContext(), filename, line, column); } Location FileLineColLoc::get(StringRef filename, unsigned line, unsigned column, MLIRContext *context) { return get(Identifier::get(filename.empty() ? "-" : filename, context), line, - column, context); + column); } StringRef FileLineColLoc::getFilename() const { return getImpl()->filename; } @@ -112,8 +112,8 @@ return Base::get(child->getContext(), name, child); } -Location NameLoc::get(Identifier name, MLIRContext *context) { - return get(name, UnknownLoc::get(context)); +Location NameLoc::get(Identifier name) { + return get(name, UnknownLoc::get(name.getContext())); } /// Return the name identifier. diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -520,9 +520,11 @@ // Refresh all the identifiers dialect field, this catches cases where a // dialect may be loaded after identifier prefixed with this dialect name // were already created. 
+ llvm::SmallString<32> dialectPrefix(dialectNamespace); + dialectPrefix.push_back('.'); for (auto &identifierEntry : impl.identifiers) - if (!identifierEntry.second && - identifierEntry.first().startswith(dialectNamespace)) + if (identifierEntry.second.is<MLIRContext *>() && + identifierEntry.first().startswith(dialectPrefix)) identifierEntry.second = dialect.get(); // Actually register the interfaces with delayed registration. diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -35,8 +35,10 @@ } /// Return the name of the dialect this operation is registered to. -StringRef OperationName::getDialect() const { - return getStringRef().split('.').first; +StringRef OperationName::getDialectNamespace() const { + if (Dialect *dialect = getDialect()) + return dialect->getNamespace(); + return representation.get<Identifier>().strref().split('.').first; } /// Return the operation name with dialect name stripped, if it has one. @@ -213,14 +215,7 @@ /// Return the dialect this operation is associated with, or nullptr if the /// associated dialect is not registered. -Dialect *Operation::getDialect() { - if (auto *abstractOp = getAbstractOperation()) - return &abstractOp->dialect; - - // If this operation hasn't been registered or doesn't have abstract - // operation, try looking up the dialect name in the context. - return getContext()->getLoadedDialect(getName().getDialect()); -} +Dialect *Operation::getDialect() { return getName().getDialect(); } Region *Operation::getParentRegion() { return block ? block->getParent() : nullptr; diff --git a/mlir/lib/IR/Verifier.cpp b/mlir/lib/IR/Verifier.cpp --- a/mlir/lib/IR/Verifier.cpp +++ b/mlir/lib/IR/Verifier.cpp @@ -46,13 +46,6 @@ /// Verify the given operation. LogicalResult verify(Operation &op); - /// Returns the registered dialect for a dialect-specific attribute. 
- Dialect *getDialectForAttribute(const NamedAttribute &attr) { - assert(attr.first.strref().contains('.') && "expected dialect attribute"); - auto dialectNamePair = attr.first.strref().split('.'); - return ctx->getLoadedDialect(dialectNamePair.first); - } - private: /// Verify the given potentially nested region or block. LogicalResult verifyRegion(Region &region); @@ -81,10 +74,6 @@ /// Dominance information for this operation, when checking dominance. DominanceInfo *domInfo = nullptr; - - /// Mapping between dialect namespace and if that dialect supports - /// unregistered operations. - llvm::StringMap<bool> dialectAllowsUnknownOps; }; } // end anonymous namespace @@ -170,15 +159,14 @@ /// Verify that all of the attributes are okay. for (auto attr : op.getAttrs()) { // Check for any optional dialect specific attributes. - if (!attr.first.strref().contains('.')) - continue; - if (auto *dialect = getDialectForAttribute(attr)) + if (auto *dialect = attr.first.getDialect()) if (failed(dialect->verifyOperationAttribute(&op, attr))) return failure(); } // If we can get operation info for this, check the custom hook. - auto *opInfo = op.getAbstractOperation(); + OperationName opName = op.getName(); + auto *opInfo = opName.getAbstractOperation(); if (opInfo && failed(opInfo->verifyInvariants(&op))) return failure(); @@ -213,33 +201,21 @@ return success(); // Otherwise, verify that the parent dialect allows un-registered operations. - auto dialectPrefix = op.getName().getDialect(); - - // Check for an existing answer for the operation dialect. - auto it = dialectAllowsUnknownOps.find(dialectPrefix); - if (it == dialectAllowsUnknownOps.end()) { - // If the operation dialect is registered, query it directly. 
- if (auto *dialect = ctx->getLoadedDialect(dialectPrefix)) - it = dialectAllowsUnknownOps - .try_emplace(dialectPrefix, dialect->allowsUnknownOperations()) - .first; - // Otherwise, unregistered dialects (when allowed by the context) - // conservatively allow unknown operations. - else { - if (!op.getContext()->allowsUnregisteredDialects() && !op.getDialect()) - return op.emitOpError() - << "created with unregistered dialect. If this is " - "intended, please call allowUnregisteredDialects() on the " - "MLIRContext, or use -allow-unregistered-dialect with " - "mlir-opt"; - - it = dialectAllowsUnknownOps.try_emplace(dialectPrefix, true).first; + Dialect *dialect = opName.getDialect(); + if (!dialect) { + if (!ctx->allowsUnregisteredDialects()) { + return op.emitOpError() + << "created with unregistered dialect. If this is " + "intended, please call allowUnregisteredDialects() on the " + "MLIRContext, or use -allow-unregistered-dialect with " + "mlir-opt"; } + return success(); } - if (!it->second) { + if (!dialect->allowsUnknownOperations()) { return op.emitError("unregistered operation '") - << op.getName() << "' found in dialect ('" << dialectPrefix + << op.getName() << "' found in dialect ('" << dialect->getNamespace() << "') that does not allow unknown operations"; } diff --git a/mlir/lib/Parser/DialectSymbolParser.cpp b/mlir/lib/Parser/DialectSymbolParser.cpp --- a/mlir/lib/Parser/DialectSymbolParser.cpp +++ b/mlir/lib/Parser/DialectSymbolParser.cpp @@ -563,7 +563,7 @@ // Otherwise, form a new opaque type. 
return OpaqueType::getChecked( - getEncodedSourceLocation(loc), state.context, + getEncodedSourceLocation(loc), Identifier::get(dialectName, state.context), symbolData); }); } diff --git a/mlir/lib/Parser/LocationParser.cpp b/mlir/lib/Parser/LocationParser.cpp --- a/mlir/lib/Parser/LocationParser.cpp +++ b/mlir/lib/Parser/LocationParser.cpp @@ -145,7 +145,7 @@ "expected ')' after child location of NameLoc")) return failure(); } else { - loc = NameLoc::get(Identifier::get(str, ctx), ctx); + loc = NameLoc::get(Identifier::get(str, ctx)); } return success(); diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp --- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp +++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp @@ -1944,8 +1944,8 @@ auto fileName = debugInfoMap.lookup(debugLine->fileID).str(); if (fileName.empty()) fileName = "<unknown>"; - return opBuilder.getFileLineColLoc(opBuilder.getIdentifier(fileName), - debugLine->line, debugLine->col); + return FileLineColLoc::get(opBuilder.getIdentifier(fileName), debugLine->line, + debugLine->col); } LogicalResult diff --git a/mlir/lib/Transforms/LocationSnapshot.cpp b/mlir/lib/Transforms/LocationSnapshot.cpp --- a/mlir/lib/Transforms/LocationSnapshot.cpp +++ b/mlir/lib/Transforms/LocationSnapshot.cpp @@ -44,8 +44,7 @@ if (it == opToLineCol.end()) return; const std::pair<unsigned, unsigned> &lineCol = it->second; - auto newLoc = - builder.getFileLineColLoc(file, lineCol.first, lineCol.second); + auto newLoc = FileLineColLoc::get(file, lineCol.first, lineCol.second); // If we don't have a tag, set the location directly if (!tagIdentifier) { diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -2702,10 +2702,10 @@ if (it != legalOperations.end()) return it->second; // Check for info for the parent 
dialect. - auto dialectIt = legalDialects.find(op.getDialect()); + auto dialectIt = legalDialects.find(op.getDialectNamespace()); if (dialectIt != legalDialects.end()) { Optional<DynamicLegalityCallbackFn> callback; - auto dialectFn = dialectLegalityFns.find(op.getDialect()); + auto dialectFn = dialectLegalityFns.find(op.getDialectNamespace()); if (dialectFn != dialectLegalityFns.end()) callback = dialectFn->second; return LegalizationInfo{dialectIt->second, /*isRecursivelyLegal=*/false, diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp --- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp +++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp @@ -862,8 +862,7 @@ } genContext.setLoc(NameLoc::get( - Identifier::get(opConfig.metadata->cppOpName, &mlirContext), - &mlirContext)); + Identifier::get(opConfig.metadata->cppOpName, &mlirContext))); if (failed(generateOp(opConfig, genContext))) { return 1; } diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp --- a/mlir/tools/mlir-tblgen/RewriterGen.cpp +++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp @@ -842,8 +842,7 @@ if (tree.getNumArgs() == 1) { DagLeaf leaf = tree.getArgAsLeaf(0); if (leaf.isStringAttr()) - return formatv("::mlir::NameLoc::get(rewriter.getIdentifier(\"{0}\"), " - "rewriter.getContext())", + return formatv("::mlir::NameLoc::get(rewriter.getIdentifier(\"{0}\"))", leaf.getStringAttr()) .str(); return lookUpArgLoc(0); diff --git a/mlir/unittests/IR/AttributeTest.cpp b/mlir/unittests/IR/AttributeTest.cpp --- a/mlir/unittests/IR/AttributeTest.cpp +++ b/mlir/unittests/IR/AttributeTest.cpp @@ -151,7 +151,7 @@ TEST(DenseSplatTest, StringSplat) { MLIRContext context; Type stringType = - OpaqueType::get(&context, Identifier::get("test", &context), "string"); + OpaqueType::get(Identifier::get("test", &context), "string"); StringRef value = "test-string"; testSplat(stringType, value); } @@ -159,7 
+159,7 @@ TEST(DenseSplatTest, StringAttrSplat) { MLIRContext context; Type stringType = - OpaqueType::get(&context, Identifier::get("test", &context), "string"); + OpaqueType::get(Identifier::get("test", &context), "string"); Attribute stringAttr = StringAttr::get("test-string", stringType); testSplat(stringType, stringAttr); }