diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.td b/mlir/include/mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.td
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.td
@@ -77,7 +77,7 @@
       /*defaultImplementation=*/[{
         ConcreteOp op = cast<ConcreteOp>(this->getOperation());
         assert(memref == getMemRef());
-        return {Identifier::get(op.getMapAttrName(), op.getContext()),
+        return {DialectIdentifier::get(op.getMapAttrName(), op.getContext()),
                 op.getAffineMapAttr()};
       }]
     >,
@@ -156,7 +156,7 @@
       /*defaultImplementation=*/[{
         ConcreteOp op = cast<ConcreteOp>(this->getOperation());
         assert(memref == getMemRef());
-        return {Identifier::get(op.getMapAttrName(), op.getContext()),
+        return {DialectIdentifier::get(op.getMapAttrName(), op.getContext()),
                 op.getAffineMapAttr()};
       }]
     >,
diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
@@ -194,14 +194,14 @@
   /// Returns the AffineMapAttr associated with 'memref'.
   NamedAttribute getAffineMapAttrForMemRef(Value memref) {
     if (memref == getSrcMemRef())
-      return {Identifier::get(getSrcMapAttrName(), getContext()),
+      return {DialectIdentifier::get(getSrcMapAttrName(), getContext()),
               getSrcMapAttr()};
     else if (memref == getDstMemRef())
-      return {Identifier::get(getDstMapAttrName(), getContext()),
+      return {DialectIdentifier::get(getDstMapAttrName(), getContext()),
               getDstMapAttr()};
     assert(memref == getTagMemRef() &&
            "DmaStartOp expected source, destination or tag memref");
-    return {Identifier::get(getTagMapAttrName(), getContext()),
+    return {DialectIdentifier::get(getTagMapAttrName(), getContext()),
             getTagMapAttr()};
   }
 
@@ -306,7 +306,7 @@
   /// Returns the AffineMapAttr associated with 'memref'.
   NamedAttribute getAffineMapAttrForMemRef(Value memref) {
     assert(memref == getTagMemRef());
-    return {Identifier::get(getTagMapAttrName(), getContext()),
+    return {DialectIdentifier::get(getTagMapAttrName(), getContext()),
             getTagMapAttr()};
   }
 
diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -755,7 +755,7 @@
     /// Returns the AffineMapAttr associated with 'memref'.
     NamedAttribute getAffineMapAttrForMemRef(Value mref) {
       assert(mref == memref());
-      return {Identifier::get(getMapAttrName(), getContext()),
+      return {DialectIdentifier::get(getMapAttrName(), getContext()),
               getAffineMapAttr()};
     }
 
@@ -795,7 +795,7 @@
     /// Returns the AffineMapAttr associated with 'memref'.
     NamedAttribute getAffineMapAttrForMemRef(Value memref) {
       assert(memref == getMemRef());
-      return {Identifier::get(getMapAttrName(), getContext()),
+      return {DialectIdentifier::get(getMapAttrName(), getContext()),
               getAffineMapAttr()};
     }
 
diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h
--- a/mlir/include/mlir/IR/Attributes.h
+++ b/mlir/include/mlir/IR/Attributes.h
@@ -13,7 +13,7 @@
 #include "llvm/Support/PointerLikeTypeTraits.h"
 
 namespace mlir {
-class Identifier;
+class DialectIdentifier;
 
 /// Attributes are known-constant values of operations.
 ///
@@ -127,7 +127,7 @@
 /// NamedAttribute is combination of a name, represented by an Identifier, and a
 /// value, represented by an Attribute. The attribute pointer should always be
 /// non-null.
-using NamedAttribute = std::pair<Identifier, Attribute>;
+using NamedAttribute = std::pair<DialectIdentifier, Attribute>;
 
 bool operator<(const NamedAttribute &lhs, const NamedAttribute &rhs);
 bool operator<(const NamedAttribute &lhs, StringRef rhs);
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -53,6 +53,7 @@
   MLIRContext *getContext() const { return context; }
 
   Identifier getIdentifier(StringRef str);
+  DialectIdentifier getDialectIdentifier(StringRef str);
 
   // Locations.
   Location getUnknownLoc();
diff --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h
--- a/mlir/include/mlir/IR/BuiltinAttributes.h
+++ b/mlir/include/mlir/IR/BuiltinAttributes.h
@@ -17,6 +17,7 @@
 namespace mlir {
 class AffineMap;
 class FunctionType;
+class Identifier;
 class IntegerSet;
 class Location;
 class ShapedType;
diff --git a/mlir/include/mlir/IR/Identifier.h b/mlir/include/mlir/IR/Identifier.h
--- a/mlir/include/mlir/IR/Identifier.h
+++ b/mlir/include/mlir/IR/Identifier.h
@@ -11,10 +11,12 @@
 
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/DenseMapInfo.h"
+#include "llvm/ADT/PointerUnion.h"
 #include "llvm/ADT/StringMapEntry.h"
 #include "llvm/Support/PointerLikeTypeTraits.h"
 
 namespace mlir {
+class Dialect;
 class MLIRContext;
 
 /// This class represents a uniqued string owned by an MLIRContext. Strings
@@ -70,13 +72,41 @@
   /// Compare the underlying StringRef.
   int compare(Identifier rhs) const { return strref().compare(rhs.strref()); }
 
-private:
+protected:
   /// This contains the bytes of the string, which is guaranteed to be nul
   /// terminated.
   const EntryType *entry;
   explicit Identifier(const EntryType *entry) : entry(entry) {}
 };
 
+/// This class represents a uniqued string owned by an MLIRContext which may be
+/// prefixed with a dialect namespace (e.g. "dialect.name"). It is particularly
+/// useful for dialect attribute names: the dialect is looked up once, when the
+/// DialectIdentifier is constructed, which avoids repeated string manipulation
+/// later. A dialect loaded after construction is not picked up.
+class DialectIdentifier : public Identifier {
+public:
+  /// Return a dialect identifier for the specified string, uniqued in
+  /// `context`.
+  static DialectIdentifier get(StringRef str, MLIRContext *context);
+
+  /// Wrap an existing identifier, resolving its dialect prefix against the
+  /// dialects currently loaded in `context`. Not `explicit`, so that a
+  /// `{identifier, context}` pair can initialize a NamedAttribute's name.
+  DialectIdentifier(Identifier identifier, MLIRContext *context);
+
+  /// Return the dialect loaded in the context for this identifier's prefix,
+  /// or nullptr if no loaded dialect matched when this object was created.
+  Dialect *getDialect();
+
+  /// Return the context this identifier was created in.
+  MLIRContext *getContext();
+
+private:
+  /// The resolved dialect when one was found, otherwise the context.
+  PointerUnion<Dialect *, MLIRContext *> dialectOrContext;
+};
+
 inline raw_ostream &operator<<(raw_ostream &os, Identifier identifier) {
   identifier.print(os);
   return os;
@@ -120,6 +150,31 @@
   }
 };
 
+// DialectIdentifiers hash just like pointers; no need to hash the bytes.
+template <>
+struct DenseMapInfo<mlir::DialectIdentifier> {
+  // NOTE: the empty/tombstone keys wrap sentinel pointers that are not valid
+  // string entries and pass a null context, so the DialectIdentifier
+  // constructor must not dereference either of them for these keys.
+  static mlir::DialectIdentifier getEmptyKey() {
+    auto pointer = llvm::DenseMapInfo<const void *>::getEmptyKey();
+    return mlir::DialectIdentifier(
+        mlir::Identifier::getFromOpaquePointer(pointer), nullptr);
+  }
+  static mlir::DialectIdentifier getTombstoneKey() {
+    auto pointer = llvm::DenseMapInfo<const void *>::getTombstoneKey();
+    return mlir::DialectIdentifier(
+        mlir::Identifier::getFromOpaquePointer(pointer), nullptr);
+  }
+  static unsigned getHashValue(mlir::DialectIdentifier val) {
+    return mlir::hash_value(val);
+  }
+  static bool isEqual(mlir::DialectIdentifier lhs,
+                      mlir::DialectIdentifier rhs) {
+    return lhs == rhs;
+  }
+};
+
 /// The pointer inside of an identifier comes from a StringMap, so its alignment
 /// is always at least 4 and probably 8 (on 64-bit machines). Allow LLVM to
 /// steal the low bits.
diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -240,7 +240,8 @@
 
   /// Add an attribute with the specified name.
   void append(Identifier name, Attribute attr) {
-    append(NamedAttribute(name, attr));
+    assert(attr && "Can't append a null attribute to a NamedAttrList");
+    append(NamedAttribute({name, attr.getContext()}, attr));
   }
 
   /// Append the given named attribute.
diff --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
--- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
@@ -69,7 +69,7 @@
   attributes.reserve(numElements);
   for (intptr_t i = 0; i < numElements; ++i)
     attributes.emplace_back(
-        Identifier::get(unwrap(elements[i].name), unwrap(ctx)),
+        DialectIdentifier::get(unwrap(elements[i].name), unwrap(ctx)),
         unwrap(elements[i].attribute));
   return wrap(DictionaryAttr::get(attributes, unwrap(ctx)));
 }
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -23,6 +23,10 @@
   return Identifier::get(str, context);
 }
 
+DialectIdentifier Builder::getDialectIdentifier(StringRef str) {
+  return DialectIdentifier::get(str, context);
+}
+
 //===----------------------------------------------------------------------===//
 // Locations.
 //===----------------------------------------------------------------------===//
@@ -86,7 +90,8 @@
 //===----------------------------------------------------------------------===//
 
 NamedAttribute Builder::getNamedAttr(StringRef name, Attribute val) {
-  return NamedAttribute(getIdentifier(name), val);
+  assert(val && "Can't get a NamedAttribute with a null value");
+  return NamedAttribute({getIdentifier(name), val.getContext()}, val);
 }
 
 UnitAttr Builder::getUnitAttr() { return UnitAttr::get(context); }
diff --git a/mlir/lib/IR/BuiltinDialect.cpp b/mlir/lib/IR/BuiltinDialect.cpp
--- a/mlir/lib/IR/BuiltinDialect.cpp
+++ b/mlir/lib/IR/BuiltinDialect.cpp
@@ -151,7 +151,7 @@
 /// from this function to dest.
 void FuncOp::cloneInto(FuncOp dest, BlockAndValueMapping &mapper) {
   // Add the attributes of this function to dest.
-  llvm::MapVector<Identifier, Attribute> newAttrs;
+  llvm::MapVector<DialectIdentifier, Attribute> newAttrs;
   for (auto &attr : dest.getAttrs())
     newAttrs.insert(attr);
   for (auto &attr : getAttrs())
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -723,6 +723,45 @@
   return Identifier(localEntry);
 }
 
+/// Return a dialect identifier for the specified string.
+DialectIdentifier DialectIdentifier::get(StringRef str, MLIRContext *context) {
+  return DialectIdentifier{Identifier::get(str, context), context};
+}
+
+DialectIdentifier::DialectIdentifier(Identifier identifier,
+                                     MLIRContext *context)
+    : Identifier(identifier) {
+  // DenseMapInfo<DialectIdentifier> builds its empty/tombstone keys from
+  // sentinel pointers that are not real string entries, with a null context.
+  // Neither may be dereferenced, so skip the dialect lookup for them.
+  if (!context) {
+    dialectOrContext = static_cast<MLIRContext *>(nullptr);
+    return;
+  }
+
+  // Only a "dialect.name" identifier carries a dialect prefix; an unprefixed
+  // name (no '.') must not be looked up as a dialect namespace.
+  StringRef str = identifier.strref();
+  size_t dotPos = str.find('.');
+  if (dotPos != StringRef::npos && dotPos != 0) {
+    if (Dialect *dialect = context->getLoadedDialect(str.take_front(dotPos))) {
+      dialectOrContext = dialect;
+      return;
+    }
+  }
+  dialectOrContext = context;
+}
+
+Dialect *DialectIdentifier::getDialect() {
+  return dialectOrContext.dyn_cast<Dialect *>();
+}
+
+MLIRContext *DialectIdentifier::getContext() {
+  if (Dialect *dialect = getDialect())
+    return dialect->getContext();
+  return dialectOrContext.get<MLIRContext *>();
+}
+
 //===----------------------------------------------------------------------===//
 // Type uniquing
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp
--- a/mlir/lib/IR/OperationSupport.cpp
+++ b/mlir/lib/IR/OperationSupport.cpp
@@ -136,7 +136,7 @@
   // Otherwise, insert the new attribute into its sorted position.
   it = llvm::lower_bound(attrs, name);
   dictionarySorted.setPointer(nullptr);
-  attrs.insert(it, {name, value});
+  attrs.insert(it, {DialectIdentifier{name, value.getContext()}, value});
   return Attribute();
 }
 Attribute NamedAttrList::set(StringRef name, Attribute value) {
diff --git a/mlir/lib/Parser/AttributeParser.cpp b/mlir/lib/Parser/AttributeParser.cpp
--- a/mlir/lib/Parser/AttributeParser.cpp
+++ b/mlir/lib/Parser/AttributeParser.cpp
@@ -271,14 +271,15 @@
     // Try to parse the '=' for the attribute value.
     if (!consumeIf(Token::equal)) {
       // If there is no '=', we treat this as a unit attribute.
-      attributes.push_back({*nameId, builder.getUnitAttr()});
+      attributes.push_back(
+          {{*nameId, builder.getContext()}, builder.getUnitAttr()});
       return success();
     }
 
     auto attr = parseAttribute();
     if (!attr)
       return failure();
-    attributes.push_back({*nameId, attr});
+    attributes.push_back({{*nameId, builder.getContext()}, attr});
     return success();
   };
 
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -759,7 +759,8 @@
   }
 
   auto namesAttr = parser.getBuilder().getStrArrayAttr(names);
-  result.attributes.push_back({Identifier::get("names", context), namesAttr});
+  result.attributes.push_back(
+      {DialectIdentifier::get("names", context), namesAttr});
   return success();
 }
 
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -747,7 +747,7 @@
       derivedAttrs, body,
       [&](const NamedAttribute &namedAttr) {
         auto tmpl = namedAttr.attr.getConvertFromStorageCall();
-        body << " {::mlir::Identifier::get(\"" << namedAttr.name
+        body << " {::mlir::DialectIdentifier::get(\"" << namedAttr.name
              << "\", ctx),\n"
              << tgfmt(tmpl, &fctx.withSelf(namedAttr.name + "()")
                                  .withBuilder("odsBuilder")
diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -1208,7 +1208,7 @@
 
   const char *addAttrCmd =
       "if (auto tmpAttr = {1}) {\n"
-      " tblgen_attrs.emplace_back(rewriter.getIdentifier(\"{0}\"), "
+      " tblgen_attrs.emplace_back(rewriter.getDialectIdentifier(\"{0}\"), "
       "tmpAttr);\n}\n";
   for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) {
     if (resultOp.getArg(argIndex).is<NamedAttribute *>()) {
diff --git a/mlir/tools/mlir-tblgen/StructsGen.cpp b/mlir/tools/mlir-tblgen/StructsGen.cpp
--- a/mlir/tools/mlir-tblgen/StructsGen.cpp
+++ b/mlir/tools/mlir-tblgen/StructsGen.cpp
@@ -131,13 +131,13 @@
   const char *getFieldInfo = R"(
 
   assert({0});
-  auto {0}_id = ::mlir::Identifier::get("{0}", context);
+  auto {0}_id = ::mlir::DialectIdentifier::get("{0}", context);
   fields.emplace_back({0}_id, {0});
 )";
 
   const char *getFieldInfoOptional = R"(
   if ({0}) {
-    auto {0}_id = ::mlir::Identifier::get("{0}", context);
+    auto {0}_id = ::mlir::DialectIdentifier::get("{0}", context);
     fields.emplace_back({0}_id, {0});
   }
 )";
diff --git a/mlir/unittests/TableGen/StructsGenTest.cpp b/mlir/unittests/TableGen/StructsGenTest.cpp
--- a/mlir/unittests/TableGen/StructsGenTest.cpp
+++ b/mlir/unittests/TableGen/StructsGenTest.cpp
@@ -62,7 +62,7 @@
                                 expectedValues.end());
 
   // Add an extra NamedAttribute.
-  auto wrongId = mlir::Identifier::get("wrong", &context);
+  auto wrongId = mlir::DialectIdentifier::get("wrong", &context);
   auto wrongAttr = mlir::NamedAttribute(wrongId, expectedValues[0].second);
   newValues.push_back(wrongAttr);
 
@@ -84,7 +84,7 @@
                                 expectedValues.begin() + 1, expectedValues.end());
 
   // Add a copy of the first attribute with the wrong Identifier.
-  auto wrongId = mlir::Identifier::get("wrong", &context);
+  auto wrongId = mlir::DialectIdentifier::get("wrong", &context);
   auto wrongAttr = mlir::NamedAttribute(wrongId, expectedValues[0].second);
   newValues.push_back(wrongAttr);
 
@@ -109,7 +109,7 @@
   auto elementsType = mlir::RankedTensorType::get({3}, i64Type);
   auto elementsAttr =
       mlir::DenseIntElementsAttr::get(elementsType, ArrayRef<int64_t>{1, 2, 3});
-  mlir::Identifier id = expectedValues.back().first;
+  mlir::DialectIdentifier id = expectedValues.back().first;
   auto wrongAttr = mlir::NamedAttribute(id, elementsAttr);
   newValues.push_back(wrongAttr);
 