diff --git a/mlir/include/mlir/AsmParser/AsmParserState.h b/mlir/include/mlir/AsmParser/AsmParserState.h --- a/mlir/include/mlir/AsmParser/AsmParserState.h +++ b/mlir/include/mlir/AsmParser/AsmParserState.h @@ -12,6 +12,7 @@ #include "mlir/Support/LLVM.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" #include "llvm/Support/SMLoc.h" #include @@ -94,6 +95,32 @@ SmallVector arguments; }; + /// This class represents the information for an attribute alias definition + /// within the input file. + struct AttributeAliasDefinition { + AttributeAliasDefinition(StringRef name, SMRange loc = {}) + : name(name), definition(loc) {} + + /// The name of the attribute alias. + StringRef name; + + /// The definition location of the alias and the locations of its uses. + SMDefinition definition; + }; + + /// This class represents the information for a type alias definition within + /// the input file. + struct TypeAliasDefinition { + TypeAliasDefinition(StringRef name, SMRange loc) + : name(name), definition(loc) {} + + /// The name of the type alias. + StringRef name; + + /// The definition location of the alias and the locations of its uses. + SMDefinition definition; + }; + AsmParserState(); ~AsmParserState(); AsmParserState &operator=(AsmParserState &&other); @@ -154,10 +181,14 @@ /// Add a definition of the given entity. void addDefinition(Block *block, SMLoc location); void addDefinition(BlockArgument blockArg, SMLoc location); + void addAttrAliasDefinition(StringRef name, SMRange location); + void addTypeAliasDefinition(StringRef name, SMRange location); /// Add a source uses of the given value. void addUses(Value value, ArrayRef locations); void addUses(Block *block, ArrayRef locations); + void addAttrAliasUses(StringRef name, SMRange locations); + void addTypeAliasUses(StringRef name, SMRange locations); /// Add source uses for all the references nested under `refAttr`. 
The /// provided `locations` should match 1-1 with the number of references in diff --git a/mlir/lib/AsmParser/AsmParserState.cpp b/mlir/lib/AsmParser/AsmParserState.cpp --- a/mlir/lib/AsmParser/AsmParserState.cpp +++ b/mlir/lib/AsmParser/AsmParserState.cpp @@ -47,6 +47,12 @@ SmallVector> blocks; DenseMap blocksToIdx; + /// A mapping from aliases in the input source file to their parser state. + SmallVector> attrAliases; + SmallVector> typeAliases; + llvm::StringMap attrAliasToIdx; + llvm::StringMap typeAliasToIdx; + /// A set of value definitions that are placeholders for forward references. /// This map should be empty if the parser finishes successfully. DenseMap> placeholderValueUses; @@ -271,6 +277,26 @@ def.arguments[argIdx] = SMDefinition(convertIdLocToRange(location)); } +void AsmParserState::addAttrAliasDefinition(StringRef name, SMRange location) { + auto [it, inserted] = + impl->attrAliasToIdx.try_emplace(name, impl->attrAliases.size()); + // Location aliases may be referenced before they are defined. + if (inserted) { + impl->attrAliases.push_back( + std::make_unique(name, location)); + } else { + impl->attrAliases[it->second]->definition.loc = location; + } +} + +void AsmParserState::addTypeAliasDefinition(StringRef name, SMRange location) { + [[maybe_unused]] auto [it, inserted] = + impl->typeAliasToIdx.try_emplace(name, impl->typeAliases.size()); + assert(inserted && "unexpected type alias redefinition"); + impl->typeAliases.push_back( + std::make_unique(name, location)); +} + void AsmParserState::addUses(Value value, ArrayRef locations) { // Handle the case where the value is an operation result. if (OpResult result = dyn_cast(value)) { @@ -335,6 +361,27 @@ locations.end()); } +void AsmParserState::addAttrAliasUses(StringRef name, SMRange location) { + auto it = impl->attrAliasToIdx.find(name); + // Location aliases may be referenced before they are defined. 
+ if (it == impl->attrAliasToIdx.end()) { + it = impl->attrAliasToIdx.try_emplace(name, impl->attrAliases.size()).first; + impl->attrAliases.push_back( + std::make_unique(name)); + } + AttributeAliasDefinition &def = *impl->attrAliases[it->second]; + def.definition.uses.push_back(location); +} + +void AsmParserState::addTypeAliasUses(StringRef name, SMRange location) { + auto it = impl->typeAliasToIdx.find(name); + // Type aliases must be defined before they can be referenced. + assert(it != impl->typeAliasToIdx.end() && + "expected valid type alias definition"); + TypeAliasDefinition &def = *impl->typeAliases[it->second]; + def.definition.uses.push_back(location); +} + void AsmParserState::refineDefinition(Value oldValue, Value newValue) { auto it = impl->placeholderValueUses.find(oldValue); assert(it != impl->placeholderValueUses.end() && diff --git a/mlir/lib/AsmParser/DialectSymbolParser.cpp b/mlir/lib/AsmParser/DialectSymbolParser.cpp --- a/mlir/lib/AsmParser/DialectSymbolParser.cpp +++ b/mlir/lib/AsmParser/DialectSymbolParser.cpp @@ -157,7 +157,8 @@ /// Parse an extended dialect symbol. template -static Symbol parseExtendedSymbol(Parser &p, SymbolAliasMap &aliases, +static Symbol parseExtendedSymbol(Parser &p, AsmParserState *asmState, + SymbolAliasMap &aliases, CreateFn &&createSymbol) { Token tok = p.getToken(); @@ -167,6 +168,7 @@ return p.codeCompleteDialectSymbol(aliases); // Parse the dialect namespace. 
+ SMRange range = p.getToken().getLocRange(); SMLoc loc = p.getToken().getLoc(); p.consumeToken(); @@ -189,6 +191,12 @@ return (p.emitWrongTokenError("undefined symbol alias id '" + identifier + "'"), nullptr); + if (asmState) { + if constexpr (std::is_same_v) + asmState->addTypeAliasUses(identifier, range); + else + asmState->addAttrAliasUses(identifier, range); + } return aliasIt->second; } @@ -232,7 +240,7 @@ Attribute Parser::parseExtendedAttr(Type type) { MLIRContext *ctx = getContext(); Attribute attr = parseExtendedSymbol( - *this, state.symbols.attributeAliasDefinitions, + *this, state.asmState, state.symbols.attributeAliasDefinitions, [&](StringRef dialectName, StringRef symbolData, SMLoc loc) -> Attribute { // Parse an optional trailing colon type. Type attrType = type; @@ -279,7 +287,7 @@ Type Parser::parseExtendedType() { MLIRContext *ctx = getContext(); return parseExtendedSymbol( - *this, state.symbols.typeAliasDefinitions, + *this, state.asmState, state.symbols.typeAliasDefinitions, [&](StringRef dialectName, StringRef symbolData, SMLoc loc) -> Type { // If we found a registered dialect, then ask it to parse the type. if (auto *dialect = ctx->getOrLoadDialect(dialectName)) { diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp --- a/mlir/lib/AsmParser/Parser.cpp +++ b/mlir/lib/AsmParser/Parser.cpp @@ -2020,6 +2020,8 @@ << "expected location, but found dialect attribute: '#" << identifier << "'"; } + if (state.asmState) + state.asmState->addAttrAliasUses(identifier, tok.getLocRange()); // If this alias can be resolved, do it now. Attribute attr = state.symbols.attributeAliasDefinitions.lookup(identifier); @@ -2527,6 +2529,7 @@ return emitError("attribute names with a '.' are reserved for " "dialect-defined names"); + SMRange location = getToken().getLocRange(); consumeToken(Token::hash_identifier); // Parse the '='. @@ -2538,6 +2541,9 @@ if (!attr) return failure(); + // Register this alias with the parser state. 
+ if (state.asmState) + state.asmState->addAttrAliasDefinition(aliasName, location); state.symbols.attributeAliasDefinitions[aliasName] = attr; return success(); } @@ -2554,6 +2560,8 @@ if (aliasName.contains('.')) return emitError("type names with a '.' are reserved for " "dialect-defined names"); + + SMRange location = getToken().getLocRange(); consumeToken(Token::exclamation_identifier); // Parse the '='. @@ -2566,6 +2574,8 @@ return failure(); // Register this alias with the parser state. + if (state.asmState) + state.asmState->addTypeAliasDefinition(aliasName, location); state.symbols.typeAliasDefinitions.try_emplace(aliasName, aliasedType); return success(); }