diff --git a/mlir/include/mlir/Parser.h b/mlir/include/mlir/Parser.h --- a/mlir/include/mlir/Parser.h +++ b/mlir/include/mlir/Parser.h @@ -24,6 +24,8 @@ } // end namespace llvm namespace mlir { +class AsmParserState; + namespace detail { /// Given a block containing operations that have just been parsed, if the block @@ -77,10 +79,14 @@ /// returned. Otherwise, an error message is emitted through the error handler /// registered in the context, and failure is returned. If `sourceFileLoc` is /// non-null, it is populated with a file location representing the start of the -/// source file that is being parsed. +/// source file that is being parsed. If `asmState` is non-null, it is populated +/// with detailed information about the parsed IR (including exact locations for +/// SSA uses and definitions). `asmState` should only be provided if this +/// detailed information is desired. LogicalResult parseSourceFile(const llvm::SourceMgr &sourceMgr, Block *block, MLIRContext *context, - LocationAttr *sourceFileLoc = nullptr); + LocationAttr *sourceFileLoc = nullptr, + AsmParserState *asmState = nullptr); /// This parses the file specified by the indicated filename and appends parsed /// operations to the given block. If the block is non-empty, the operations are @@ -99,11 +105,15 @@ /// parsing is successful, success is returned. Otherwise, an error message is /// emitted through the error handler registered in the context, and failure is /// returned. If `sourceFileLoc` is non-null, it is populated with a file -/// location representing the start of the source file that is being parsed. +/// location representing the start of the source file that is being parsed. If +/// `asmState` is non-null, it is populated with detailed information about the +/// parsed IR (including exact locations for SSA uses and definitions). +/// `asmState` should only be provided if this detailed information is desired.
LogicalResult parseSourceFile(llvm::StringRef filename, llvm::SourceMgr &sourceMgr, Block *block, MLIRContext *context, - LocationAttr *sourceFileLoc = nullptr); + LocationAttr *sourceFileLoc = nullptr, + AsmParserState *asmState = nullptr); /// This parses the IR string and appends parsed operations to the given block. /// If the block is non-empty, the operations are placed before the current
diff --git a/mlir/include/mlir/Parser/AsmParserState.h b/mlir/include/mlir/Parser/AsmParserState.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Parser/AsmParserState.h
@@ -0,0 +1,131 @@
+//===- AsmParserState.h -----------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_PARSER_ASMPARSERSTATE_H
+#define MLIR_PARSER_ASMPARSERSTATE_H
+
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/SMLoc.h"
+#include <memory>
+
+namespace mlir {
+class Block;
+class BlockArgument;
+class FileLineColLoc;
+class Operation;
+class Value;
+
+/// This class represents state from a parsed MLIR textual format string. It is
+/// useful for building additional analysis and language utilities on top of
+/// textual MLIR. This should generally not be used for traditional compilation.
+class AsmParserState {
+public:
+  /// This class represents a definition within the source manager, containing
+  /// its defining location and locations of any uses. SMDefinitions are only
+  /// provided for entities that have uses within an input file, e.g. SSA
+  /// values, Blocks, and Symbols.
+  struct SMDefinition {
+    SMDefinition() = default;
+    explicit SMDefinition(llvm::SMRange loc) : loc(loc) {}
+
+    /// The source location of the definition.
+    llvm::SMRange loc;
+    /// The source location of all uses of the definition.
+    SmallVector<llvm::SMRange> uses;
+  };
+
+  /// This class represents the information for an operation definition within
+  /// an input file.
+  struct OperationDefinition {
+    struct ResultGroupDefinition {
+      /// The result number that starts this group.
+      unsigned startIndex;
+      /// The source definition of the result group.
+      SMDefinition definition;
+    };
+
+    OperationDefinition(Operation *op, llvm::SMRange loc) : op(op), loc(loc) {}
+
+    /// The operation representing this definition.
+    Operation *op;
+
+    /// The source location for the operation, i.e. the location of its name.
+    llvm::SMRange loc;
+
+    /// Source definitions for any result groups of this operation.
+    SmallVector<std::pair<unsigned, SMDefinition>> resultGroups;
+  };
+
+  /// This class represents the information for a block definition within the
+  /// input file.
+  struct BlockDefinition {
+    explicit BlockDefinition(Block *block, llvm::SMRange loc = {})
+        : block(block), definition(loc) {}
+
+    /// The block representing this definition.
+    Block *block;
+
+    /// The source location for the block, i.e. the location of its name, and
+    /// any uses it has.
+    SMDefinition definition;
+
+    /// Source definitions for any arguments of this block.
+    SmallVector<SMDefinition> arguments;
+  };
+
+  AsmParserState();
+  ~AsmParserState();
+
+  //===--------------------------------------------------------------------===//
+  // Access State
+  //===--------------------------------------------------------------------===//
+
+  using BlockDefIterator = llvm::pointee_iterator<
+      ArrayRef<std::unique_ptr<BlockDefinition>>::iterator>;
+  using OperationDefIterator = llvm::pointee_iterator<
+      ArrayRef<std::unique_ptr<OperationDefinition>>::iterator>;
+
+  /// Return a range of the BlockDefinitions held by the current parser state.
+  iterator_range<BlockDefIterator> getBlockDefs() const;
+
+  /// Return a range of the OperationDefinitions held by the current parser
+  /// state.
+  iterator_range<OperationDefIterator> getOpDefs() const;
+
+  //===--------------------------------------------------------------------===//
+  // Populate State
+  //===--------------------------------------------------------------------===//
+
+  /// Add a definition of the given operation, block, or block argument.
+  void addDefinition(
+      Operation *op, llvm::SMRange location,
+      ArrayRef<std::pair<unsigned, llvm::SMLoc>> resultGroups = llvm::None);
+  void addDefinition(Block *block, llvm::SMLoc location);
+  void addDefinition(BlockArgument blockArg, llvm::SMLoc location);
+
+  /// Add source uses of the given value.
+  void addUses(Value value, ArrayRef<llvm::SMLoc> locations);
+  void addUses(Block *block, ArrayRef<llvm::SMLoc> locations);
+
+  /// Refine the `oldValue` to the `newValue`. This is used to indicate that
+  /// `oldValue` was a placeholder, and the uses of it should really refer to
+  /// `newValue`.
+  void refineDefinition(Value oldValue, Value newValue);
+
+private:
+  struct Impl;
+
+  /// A pointer to the internal implementation of this class.
+  std::unique_ptr<Impl> impl;
+};
+
+} // end namespace mlir
+
+#endif // MLIR_PARSER_ASMPARSERSTATE_H
diff --git a/mlir/lib/Parser/AsmParserState.cpp b/mlir/lib/Parser/AsmParserState.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Parser/AsmParserState.cpp
@@ -0,0 +1,168 @@
+//===- AsmParserState.cpp -------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Parser/AsmParserState.h"
+#include "mlir/IR/Operation.h"
+
+using namespace mlir;
+
+/// Given an SMLoc corresponding to an identifier location, return a location
+/// representing the full range of the identifier.
+static llvm::SMRange convertIdLocToRange(llvm::SMLoc loc) {
+  if (!loc.isValid())
+    return llvm::SMRange();
+
+  // Return if the given character is a valid identifier character. `c` is
+  // unsigned because passing a negative `char` to `isalnum` is undefined.
+  auto isIdentifierChar = [](unsigned char c) {
+    return isalnum(c) || c == '$' || c == '.' || c == '_' || c == '-';
+  };
+
+  const char *curPtr = loc.getPointer();
+  while (isIdentifierChar(*(++curPtr)))
+    continue;
+  return llvm::SMRange(loc, llvm::SMLoc::getFromPointer(curPtr));
+}
+
+//===----------------------------------------------------------------------===//
+// AsmParserState::Impl
+//===----------------------------------------------------------------------===//
+
+struct AsmParserState::Impl {
+  /// Return the entry for `block`, creating an empty one (with an empty
+  /// definition range) if the block has not been seen yet.
+  BlockDefinition &getBlockDef(Block *block) {
+    auto insertion = blocksToIdx.try_emplace(block, blocks.size());
+    if (insertion.second)
+      blocks.emplace_back(std::make_unique<BlockDefinition>(block));
+    return *blocks[insertion.first->second];
+  }
+
+  /// Return the entry for `arg`, creating the owner block's entry and growing
+  /// its argument list as necessary.
+  SMDefinition &getArgDef(BlockArgument arg) {
+    BlockDefinition &def = getBlockDef(arg.getOwner());
+    unsigned argIdx = arg.getArgNumber();
+    if (def.arguments.size() <= argIdx)
+      def.arguments.resize(argIdx + 1);
+    return def.arguments[argIdx];
+  }
+
+  /// A mapping from operations in the input source file to their parser state.
+  SmallVector<std::unique_ptr<OperationDefinition>> operations;
+  DenseMap<Operation *, unsigned> operationToIdx;
+
+  /// A mapping from blocks in the input source file to their parser state.
+  SmallVector<std::unique_ptr<BlockDefinition>> blocks;
+  DenseMap<Block *, unsigned> blocksToIdx;
+
+  /// A set of value definitions that are placeholders for forward references.
+  /// This map should be empty if the parser finishes successfully.
+  DenseMap<Value, SmallVector<llvm::SMLoc>> placeholderValueUses;
+};
+
+//===----------------------------------------------------------------------===//
+// AsmParserState
+//===----------------------------------------------------------------------===//
+
+AsmParserState::AsmParserState() : impl(std::make_unique<Impl>()) {}
+AsmParserState::~AsmParserState() = default;
+
+//===----------------------------------------------------------------------===//
+// Access State
+
+auto AsmParserState::getBlockDefs() const -> iterator_range<BlockDefIterator> {
+  return llvm::make_pointee_range(llvm::makeArrayRef(impl->blocks));
+}
+
+auto AsmParserState::getOpDefs() const -> iterator_range<OperationDefIterator> {
+  return llvm::make_pointee_range(llvm::makeArrayRef(impl->operations));
+}
+
+//===----------------------------------------------------------------------===//
+// Populate State
+
+void AsmParserState::addDefinition(
+    Operation *op, llvm::SMRange location,
+    ArrayRef<std::pair<unsigned, llvm::SMLoc>> resultGroups) {
+  std::unique_ptr<OperationDefinition> def =
+      std::make_unique<OperationDefinition>(op, location);
+  for (auto &resultGroup : resultGroups)
+    def->resultGroups.emplace_back(resultGroup.first,
+                                   convertIdLocToRange(resultGroup.second));
+
+  impl->operationToIdx.try_emplace(op, impl->operations.size());
+  impl->operations.emplace_back(std::move(def));
+}
+
+void AsmParserState::addDefinition(Block *block, llvm::SMLoc location) {
+  // An entry already exists if the block was forward referenced; either way,
+  // this records the location of its definition.
+  impl->getBlockDef(block).definition.loc = convertIdLocToRange(location);
+}
+
+void AsmParserState::addDefinition(BlockArgument blockArg,
+                                   llvm::SMLoc location) {
+  // The owner block may not have an entry yet (e.g. a region with named entry
+  // arguments that is erroneously followed by a block label), so create it on
+  // demand instead of asserting and let the parser report the error.
+  impl->getArgDef(blockArg) = SMDefinition(convertIdLocToRange(location));
+}
+
+void AsmParserState::addUses(Value value, ArrayRef<llvm::SMLoc> locations) {
+  // Handle the case where the value is an operation result.
+  if (OpResult result = value.dyn_cast<OpResult>()) {
+    // Check to see if a definition for the parent operation has been recorded.
+    // If one hasn't, we treat the provided value as a placeholder value that
+    // will be refined further later.
+    Operation *parentOp = result.getOwner();
+    auto existingIt = impl->operationToIdx.find(parentOp);
+    if (existingIt == impl->operationToIdx.end()) {
+      impl->placeholderValueUses[value].append(locations.begin(),
+                                               locations.end());
+      return;
+    }
+
+    // If a definition does exist, locate the value's result group and add the
+    // use. The result groups are ordered by increasing start index, so we just
+    // need to find the last group that has a smaller/equal start index.
+    unsigned resultNo = result.getResultNumber();
+    OperationDefinition &def = *impl->operations[existingIt->second];
+    for (auto &resultGroup : llvm::reverse(def.resultGroups)) {
+      if (resultNo >= resultGroup.first) {
+        for (llvm::SMLoc loc : locations)
+          resultGroup.second.uses.push_back(convertIdLocToRange(loc));
+        return;
+      }
+    }
+    llvm_unreachable("expected valid result group for value use");
+  }
+
+  // Otherwise, this is a block argument.
+  SMDefinition &argDef = impl->getArgDef(value.cast<BlockArgument>());
+  for (llvm::SMLoc loc : locations)
+    argDef.uses.emplace_back(convertIdLocToRange(loc));
+}
+
+void AsmParserState::addUses(Block *block, ArrayRef<llvm::SMLoc> locations) {
+  BlockDefinition &def = impl->getBlockDef(block);
+  for (llvm::SMLoc loc : locations)
+    def.definition.uses.push_back(convertIdLocToRange(loc));
+}
+
+void AsmParserState::refineDefinition(Value oldValue, Value newValue) {
+  auto it = impl->placeholderValueUses.find(oldValue);
+  assert(it != impl->placeholderValueUses.end() &&
+         "expected `oldValue` to be a placeholder");
+  // Take the uses and drop the placeholder entry before forwarding them:
+  // `addUses` may insert into `placeholderValueUses`, which can grow the map
+  // and invalidate both `it` and the vector it refers to.
+  SmallVector<llvm::SMLoc> uses = std::move(it->second);
+  impl->placeholderValueUses.erase(it);
+  addUses(newValue, uses);
+}
diff --git a/mlir/lib/Parser/CMakeLists.txt b/mlir/lib/Parser/CMakeLists.txt --- a/mlir/lib/Parser/CMakeLists.txt +++ b/mlir/lib/Parser/CMakeLists.txt @@ -1,5 +1,6 @@ add_mlir_library(MLIRParser AffineParser.cpp + AsmParserState.cpp AttributeParser.cpp DialectSymbolParser.cpp Lexer.cpp
diff --git a/mlir/lib/Parser/DialectSymbolParser.cpp b/mlir/lib/Parser/DialectSymbolParser.cpp --- a/mlir/lib/Parser/DialectSymbolParser.cpp +++ b/mlir/lib/Parser/DialectSymbolParser.cpp @@ -490,7 +490,7 @@ inputStr, /*BufferName=*/"<mlir_parser_buffer>", /*RequiresNullTerminator=*/false); sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc()); - ParserState state(sourceMgr, context, symbolState); + ParserState state(sourceMgr, context, symbolState, /*asmState=*/nullptr); Parser parser(state); Token startTok = parser.getToken();
diff --git a/mlir/lib/Parser/Parser.h b/mlir/lib/Parser/Parser.h --- a/mlir/lib/Parser/Parser.h +++ b/mlir/lib/Parser/Parser.h @@ -268,7 +268,7 @@ function_ref<ParseResult()> parseElement, OpAsmParser::Delimiter delimiter); -private: +protected: /// The Parser is subclassed and reinstantiated. Do not add additional /// non-trivial state here, add it to the ParserState class. ParserState &state;
diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -16,6 +16,7 @@ #include "mlir/IR/Dialect.h" #include "mlir/IR/Verifier.h" #include "mlir/Parser.h" +#include "mlir/Parser/AsmParserState.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/StringSet.h" #include "llvm/ADT/bit.h" @@ -213,8 +214,8 @@ auto &values = isolatedNameScopes.back().values; if (!values.count(name) || number >= values[name].size()) return {}; - if (values[name][number].first) - return values[name][number].second; + if (values[name][number].value) + return values[name][number].loc; return {}; } @@ -278,8 +279,7 @@ ParseResult parseBlockBody(Block *block); /// Parse a (possibly empty) list of block arguments. - ParseResult parseOptionalBlockArgList(SmallVectorImpl<BlockArgument> &results, - Block *owner); + ParseResult parseOptionalBlockArgList(Block *owner); /// Get the block with the specified name, creating it if it doesn't /// already exist. The location specified is the point of use, which allows @@ -291,8 +291,23 @@ Block *defineBlockNamed(StringRef name, SMLoc loc, Block *existing); private: + /// This class represents a definition of a Block. + struct BlockDefinition { + /// A pointer to the defined Block. + Block *block; + /// The location that the Block was defined at. + SMLoc loc; + }; + /// This class represents a definition of a Value. + struct ValueDefinition { + /// A pointer to the defined Value. + Value value; + /// The location that the Value was defined at. + SMLoc loc; + }; + /// Returns the info for a block at the current scope for the given name. - std::pair<Block *, SMLoc> &getBlockInfoByName(StringRef name) { + BlockDefinition &getBlockInfoByName(StringRef name) { return blocksByName.back()[name]; } @@ -308,7 +323,7 @@ void recordDefinition(StringRef def); /// Get the value entry for the given SSA name. - SmallVectorImpl<std::pair<Value, SMLoc>> &getSSAValueEntry(StringRef name); + SmallVectorImpl<ValueDefinition> &getSSAValueEntry(StringRef name); /// Create a forward reference placeholder value with the given location and /// result type. @@ -340,7 +355,7 @@ /// This keeps track of all of the SSA values we are tracking for each name /// scope, indexed by their name. This has one entry per result number. - llvm::StringMap<SmallVector<std::pair<Value, SMLoc>, 1>> values; + llvm::StringMap<SmallVector<ValueDefinition, 1>> values; /// This keeps track of all of the values defined by a specific name scope. SmallVector<llvm::StringSet<>, 2> definitionsPerScope; @@ -352,7 +367,7 @@ /// This keeps track of the block names as well as the location of the first /// reference for each nested name scope. This is used to diagnose invalid /// block references and memorize them. - SmallVector<DenseMap<StringRef, std::pair<Block *, SMLoc>>, 2> blocksByName; + SmallVector<DenseMap<StringRef, BlockDefinition>, 2> blocksByName; SmallVector<DenseMap<Block *, SMLoc>, 2> forwardRef; /// These are all of the placeholders we've made along with the location of @@ -408,7 +423,7 @@ } // Resolve the locations of any deferred operations. - auto &attributeAliases = getState().symbols.attributeAliasDefinitions; + auto &attributeAliases = state.symbols.attributeAliasDefinitions; for (std::pair<Operation *, Token> &it : opsWithDeferredLocs) { llvm::SMLoc tokLoc = it.second.getLoc(); StringRef identifier = it.second.getSpelling().drop_front(); @@ -432,7 +447,7 @@ //===----------------------------------------------------------------------===// void OperationParser::pushSSANameScope(bool isIsolated) { - blocksByName.push_back(DenseMap<StringRef, std::pair<Block *, SMLoc>>()); + blocksByName.push_back(DenseMap<StringRef, BlockDefinition>()); forwardRef.push_back(DenseMap<Block *, SMLoc>()); // Push back a new name definition scope.
@@ -484,11 +499,11 @@ // If we already have an entry for this, check to see if it was a definition // or a forward reference. - if (auto existing = entries[useInfo.number].first) { + if (auto existing = entries[useInfo.number].value) { if (!isForwardRefPlaceholder(existing)) { return emitError(useInfo.loc) .append("redefinition of SSA value '", useInfo.name, "'") - .attachNote(getEncodedSourceLocation(entries[useInfo.number].second)) + .attachNote(getEncodedSourceLocation(entries[useInfo.number].loc)) .append("previously defined here"); } @@ -496,7 +511,7 @@ return emitError(useInfo.loc) .append("definition of SSA value '", useInfo.name, "#", useInfo.number, "' has type ", value.getType()) - .attachNote(getEncodedSourceLocation(entries[useInfo.number].second)) + .attachNote(getEncodedSourceLocation(entries[useInfo.number].loc)) .append("previously used here with type ", existing.getType()); } @@ -506,6 +521,11 @@ existing.replaceAllUsesWith(value); existing.getDefiningOp()->destroy(); forwardRefPlaceholders.erase(existing); + + // The placeholder may have recorded uses in the assembly state; transfer + // them to the real definition. + if (state.asmState) + state.asmState->refineDefinition(existing, value); } /// Record this definition for the current scope. @@ -560,18 +580,26 @@ Value OperationParser::resolveSSAUse(SSAUseInfo useInfo, Type type) { auto &entries = getSSAValueEntry(useInfo.name); + // Functor used to record the use of the given value if the assembly state + // field is populated. + auto maybeRecordUse = [&](Value value) { + if (state.asmState) + state.asmState->addUses(value, useInfo.loc); + return value; + }; + // If we have already seen a value of this name, return it. - if (useInfo.number < entries.size() && entries[useInfo.number].first) { - auto result = entries[useInfo.number].first; + if (useInfo.number < entries.size() && entries[useInfo.number].value) { + Value result = entries[useInfo.number].value; // Check that the type matches the other uses.
if (result.getType() == type) - return result; + return maybeRecordUse(result); emitError(useInfo.loc, "use of value '") .append(useInfo.name, "' expects different type than prior uses: ", type, " vs ", result.getType()) - .attachNote(getEncodedSourceLocation(entries[useInfo.number].second)) + .attachNote(getEncodedSourceLocation(entries[useInfo.number].loc)) .append("prior use here"); return nullptr; } @@ -582,16 +610,15 @@ // If the value has already been defined and this is an overly large result // number, diagnose that. - if (entries[0].first && !isForwardRefPlaceholder(entries[0].first)) + if (entries[0].value && !isForwardRefPlaceholder(entries[0].value)) return (emitError(useInfo.loc, "reference to invalid result number"), nullptr); // Otherwise, this is a forward reference. Create a placeholder and remember // that we did so. auto result = createForwardRefPlaceholder(useInfo.loc, type); - entries[useInfo.number].first = result; - entries[useInfo.number].second = useInfo.loc; - return result; + entries[useInfo.number] = {result, useInfo.loc}; + return maybeRecordUse(result); } /// Parse an SSA use with an associated type. @@ -653,8 +680,8 @@ } /// Get the value entry for the given SSA name.
-SmallVectorImpl> & -OperationParser::getSSAValueEntry(StringRef name) { +auto OperationParser::getSSAValueEntry(StringRef name) + -> SmallVectorImpl & { return isolatedNameScopes.back().values[name]; } @@ -732,9 +759,10 @@ } Operation *op; - if (getToken().is(Token::bare_identifier) || getToken().isKeyword()) + Token nameTok = getToken(); + if (nameTok.is(Token::bare_identifier) || nameTok.isKeyword()) op = parseCustomOperation(resultIDs); - else if (getToken().is(Token::string)) + else if (nameTok.is(Token::string)) op = parseGenericOperation(); else return emitError("expected operation name in quotes"); @@ -752,6 +780,18 @@ << op->getNumResults() << " results but was provided " << numExpectedResults << " to bind"; + // Add this operation to the assembly state if it was provided to populate. + if (state.asmState) { + unsigned resultIt = 0; + SmallVector> asmResultGroups; + asmResultGroups.reserve(resultIDs.size()); + for (ResultRecord &record : resultIDs) { + asmResultGroups.emplace_back(resultIt, std::get<2>(record)); + resultIt += std::get<1>(record); + } + state.asmState->addDefinition(op, nameTok.getLocRange(), asmResultGroups); + } + // Add definitions for each of the result groups. unsigned opResI = 0; for (ResultRecord &resIt : resultIDs) { @@ -761,6 +801,10 @@ return failure(); } } + + // Otherwise, record the operation (with no result groups) if requested. + } else if (state.asmState) { + state.asmState->addDefinition(op, nameTok.getLocRange()); } return success(); @@ -1772,8 +1816,7 @@ } // If this alias can be resolved, do it now. - Attribute attr = - getState().symbols.attributeAliasDefinitions.lookup(identifier); + Attribute attr = state.symbols.attributeAliasDefinitions.lookup(identifier); if (attr) { if (!(directLoc = attr.dyn_cast())) return emitError(tok.getLoc()) @@ -1809,6 +1852,7 @@ ArrayRef> entryArguments, bool isIsolatedNameScope) { // Parse the '{'.
+ Token lBraceTok = getToken(); if (parseToken(Token::l_brace, "expected '{' to begin a region")) return failure(); @@ -1824,10 +1868,17 @@ auto owning_block = std::make_unique(); Block *block = owning_block.get(); + // If this block is not defined in the source file, add a definition for it + // now in the assembly state. Blocks with a name will be defined when the name + // is parsed. + if (state.asmState && getToken().isNot(Token::caret_identifier)) + state.asmState->addDefinition(block, lBraceTok.getLoc()); + // Add arguments to the entry block. if (!entryArguments.empty()) { for (auto &placeholderArgPair : entryArguments) { auto &argInfo = placeholderArgPair.first; + // Ensure that the argument was not already defined. if (auto defLoc = getReferenceLoc(argInfo.name, argInfo.number)) { return emitError(argInfo.loc, "region entry argument '" + argInfo.name + @@ -1835,10 +1886,15 @@ .attachNote(getEncodedSourceLocation(*defLoc)) << "previously referenced here"; } - if (addDefinition(placeholderArgPair.first, - block->addArgument(placeholderArgPair.second))) { + BlockArgument arg = block->addArgument(placeholderArgPair.second); + + // Add a definition of this arg to the assembly state if provided. + if (state.asmState) + state.asmState->addDefinition(arg, argInfo.loc); + + // Record the definition for this argument. + if (addDefinition(argInfo, arg)) return failure(); - } } // If we had named arguments, then don't allow a block name. @@ -1846,9 +1902,8 @@ return emitError("invalid block name in region with named arguments"); } - if (parseBlock(block)) { + if (parseBlock(block)) return failure(); - } // Verify that no other arguments were parsed. if (!entryArguments.empty() && @@ -1915,8 +1970,7 @@ // If an argument list is present, parse it.
if (consumeIf(Token::l_paren)) { - SmallVector bbArgs; - if (parseOptionalBlockArgList(bbArgs, block) || + if (parseOptionalBlockArgList(block) || parseToken(Token::r_paren, "expected ')' to end argument list")) return failure(); } @@ -1943,13 +1997,17 @@ /// exist. The location specified is the point of use, which allows /// us to diagnose references to blocks that are not defined precisely. Block *OperationParser::getBlockNamed(StringRef name, SMLoc loc) { - auto &blockAndLoc = getBlockInfoByName(name); - if (!blockAndLoc.first) { - blockAndLoc = {new Block(), loc}; - insertForwardRef(blockAndLoc.first, loc); + BlockDefinition &blockDef = getBlockInfoByName(name); + if (!blockDef.block) { + blockDef = {new Block(), loc}; + insertForwardRef(blockDef.block, blockDef.loc); } - return blockAndLoc.first; + // Populate the high level assembly state if necessary. + if (state.asmState) + state.asmState->addUses(blockDef.block, loc); + + return blockDef.block; } /// Define the block with the specified name. Returns the Block* or nullptr in @@ -1957,29 +2015,32 @@ Block *OperationParser::defineBlockNamed(StringRef name, SMLoc loc, Block *existing) { auto &blockAndLoc = getBlockInfoByName(name); - if (!blockAndLoc.first) { - // If the caller provided a block, use it. Otherwise create a new one. - if (!existing) - existing = new Block(); - blockAndLoc.first = existing; - blockAndLoc.second = loc; - return blockAndLoc.first; - } - - // Forward declarations are removed once defined, so if we are defining a - // existing block and it is not a forward declaration, then it is a - // redeclaration. - if (!eraseForwardRef(blockAndLoc.first)) + blockAndLoc.loc = loc; + + // If a block has yet to be set, this is a new definition. If the caller + // provided a block, use it. Otherwise create a new one. + if (!blockAndLoc.block) { + blockAndLoc.block = existing ? existing : new Block(); + + // Otherwise, the block has a forward declaration. Forward declarations are
+ // removed once defined, so if we are defining an existing block and it is + // not a forward declaration, then it is a redeclaration. + } else if (!eraseForwardRef(blockAndLoc.block)) { return nullptr; - return blockAndLoc.first; + } + + // Populate the high level assembly state if necessary. + if (state.asmState) + state.asmState->addDefinition(blockAndLoc.block, loc); + + return blockAndLoc.block; } /// Parse a (possibly empty) list of SSA operands with types as block arguments. /// /// ssa-id-and-type-list ::= ssa-id-and-type (`,` ssa-id-and-type)* /// -ParseResult OperationParser::parseOptionalBlockArgList( - SmallVectorImpl &results, Block *owner) { +ParseResult OperationParser::parseOptionalBlockArgList(Block *owner) { if (getToken().is(Token::r_brace)) return success(); @@ -1991,18 +2052,28 @@ return parseCommaSeparatedList([&]() -> ParseResult { return parseSSADefOrUseAndType( [&](SSAUseInfo useInfo, Type type) -> ParseResult { - // If this block did not have existing arguments, define a new one. - if (!definingExistingArgs) - return addDefinition(useInfo, owner->addArgument(type)); - - // Otherwise, ensure that this argument has already been created. - if (nextArgument >= owner->getNumArguments()) - return emitError("too many arguments specified in argument list"); - - // Finally, make sure the existing argument has the correct type. - auto arg = owner->getArgument(nextArgument++); - if (arg.getType() != type) - return emitError("argument and block argument type mismatch"); + BlockArgument arg; + + // If we are defining existing arguments, ensure that the argument + // has already been created with the right type. + if (definingExistingArgs) { + // Ensure that this argument has already been created. + if (nextArgument >= owner->getNumArguments()) + return emitError("too many arguments specified in argument list"); + + // Finally, make sure the existing argument has the correct type.
+ arg = owner->getArgument(nextArgument++); + if (arg.getType() != type) + return emitError("argument and block argument type mismatch"); + } else { + arg = owner->addArgument(type); + } + + // Mark this block argument definition in the parser state if it was + // provided. + if (state.asmState) + state.asmState->addDefinition(arg, useInfo.loc); + return addDefinition(useInfo, arg); }); }); @@ -2040,7 +2111,7 @@ StringRef aliasName = getTokenSpelling().drop_front(); // Check for redefinitions. - if (getState().symbols.attributeAliasDefinitions.count(aliasName) > 0) + if (state.symbols.attributeAliasDefinitions.count(aliasName) > 0) return emitError("redefinition of attribute alias id '" + aliasName + "'"); // Make sure this isn't invading the dialect attribute namespace. @@ -2059,7 +2130,7 @@ if (!attr) return failure(); - getState().symbols.attributeAliasDefinitions[aliasName] = attr; + state.symbols.attributeAliasDefinitions[aliasName] = attr; return success(); } @@ -2072,7 +2143,7 @@ StringRef aliasName = getTokenSpelling().drop_front(); // Check for redefinitions. - if (getState().symbols.typeAliasDefinitions.count(aliasName) > 0) + if (state.symbols.typeAliasDefinitions.count(aliasName) > 0) return emitError("redefinition of type alias id '" + aliasName + "'"); // Make sure this isn't invading the dialect type namespace. @@ -2093,7 +2164,7 @@ return failure(); // Register this alias with the parser state. - getState().symbols.typeAliasDefinitions.try_emplace(aliasName, aliasedType); + state.symbols.typeAliasDefinitions.try_emplace(aliasName, aliasedType); return success(); } @@ -2101,7 +2172,7 @@ Location parserLoc) { // Create a top-level operation to contain the parsed state.
OwningOpRef topLevelOp(ModuleOp::create(parserLoc)); - OperationParser opParser(getState(), topLevelOp.get()); + OperationParser opParser(state, topLevelOp.get()); while (true) { switch (getToken().getKind()) { default: @@ -2153,7 +2224,8 @@ LogicalResult mlir::parseSourceFile(const llvm::SourceMgr &sourceMgr, Block *block, MLIRContext *context, - LocationAttr *sourceFileLoc) { + LocationAttr *sourceFileLoc, + AsmParserState *asmState) { const auto *sourceBuf = sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID()); Location parserLoc = FileLineColLoc::get( @@ -2162,7 +2234,7 @@ *sourceFileLoc = parserLoc; SymbolState aliasState; - ParserState state(sourceMgr, context, aliasState); + ParserState state(sourceMgr, context, aliasState, asmState); return TopLevelOperationParser(state).parse(block, parserLoc); } @@ -2176,7 +2248,8 @@ LogicalResult mlir::parseSourceFile(llvm::StringRef filename, llvm::SourceMgr &sourceMgr, Block *block, MLIRContext *context, - LocationAttr *sourceFileLoc) { + LocationAttr *sourceFileLoc, + AsmParserState *asmState) { if (sourceMgr.getNumBuffers() != 0) { // TODO: Extend to support multiple buffers. return emitError(mlir::UnknownLoc::get(context), @@ -2189,7 +2262,7 @@ // Load the MLIR source file. sourceMgr.AddNewSourceBuffer(std::move(*file_or_err), llvm::SMLoc()); - return parseSourceFile(sourceMgr, block, context, sourceFileLoc); + return parseSourceFile(sourceMgr, block, context, sourceFileLoc, asmState); } LogicalResult mlir::parseSourceString(llvm::StringRef sourceStr, Block *block, diff --git a/mlir/lib/Parser/ParserState.h b/mlir/lib/Parser/ParserState.h --- a/mlir/lib/Parser/ParserState.h +++ b/mlir/lib/Parser/ParserState.h @@ -48,9 +48,10 @@ /// such as the current lexer position etc.
struct ParserState { ParserState(const llvm::SourceMgr &sourceMgr, MLIRContext *ctx, - SymbolState &symbols) + SymbolState &symbols, AsmParserState *asmState) : context(ctx), lex(sourceMgr, ctx), curToken(lex.lexToken()), - symbols(symbols), parserDepth(symbols.nestedParserLocs.size()) { + symbols(symbols), parserDepth(symbols.nestedParserLocs.size()), + asmState(asmState) { // Set the top level lexer for the symbol state if one doesn't exist. if (!symbols.topLevelLexer) symbols.topLevelLexer = &lex; @@ -77,6 +78,10 @@ /// The depth of this parser in the nested parsing stack. size_t parserDepth; + + /// An optional pointer to the high-level assembly parser state to populate + /// during parsing; null when that state is not being tracked. + AsmParserState *asmState; }; } // end namespace detail