diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -225,6 +225,38 @@ //===----------------------------------------------------------------------===// namespace { +/// This class represents a specific instance of a symbol Alias. +class SymbolAlias { +public: + SymbolAlias(StringRef name, bool isDeferrable) + : name(name), suffixIndex(0), hasSuffixIndex(false), + isDeferrable(isDeferrable) {} + SymbolAlias(StringRef name, uint32_t suffixIndex, bool isDeferrable) + : name(name), suffixIndex(suffixIndex), hasSuffixIndex(true), + isDeferrable(isDeferrable) {} + + /// Print this alias to the given stream. + void print(raw_ostream &os) const { + os << name; + if (hasSuffixIndex) + os << suffixIndex; + } + + /// Returns true if this alias supports deferred resolution when parsing. + bool canBeDeferred() const { return isDeferrable; } + +private: + /// The main name of the alias. + StringRef name; + /// The optional suffix index of the alias, if multiple aliases had the same + /// name. + uint32_t suffixIndex : 30; + /// A flag indicating whether this alias has a suffix or not. + bool hasSuffixIndex : 1; + /// A flag indicating whether this alias may be deferred or not. + bool isDeferrable : 1; +}; + /// This class represents a utility that initializes the set of attribute and /// type aliases, without the need to store the extra information within the /// main AliasState class or pass it around via function arguments. @@ -236,14 +268,14 @@ : interfaces(interfaces), aliasAllocator(aliasAllocator), aliasOS(aliasBuffer) {} - void initialize( - Operation *op, const OpPrintingFlags &printerFlags, - llvm::MapVector>> - &attrToAlias, - llvm::MapVector>> &typeToAlias); + void initialize(Operation *op, const OpPrintingFlags &printerFlags, + llvm::MapVector &attrToAlias, + llvm::MapVector &typeToAlias); - /// Visit the given attribute to see if it has an alias. 
- void visit(Attribute attr); + /// Visit the given attribute to see if it has an alias. `canBeDeferred` is + /// set to true if the originator of this attribute can resolve the alias + /// after parsing has completed (e.g. in the case of operation locations). + void visit(Attribute attr, bool canBeDeferred = false); /// Visit the given type to see if it has an alias. void visit(Type type); @@ -251,9 +283,11 @@ private: /// Try to generate an alias for the provided symbol. If an alias is /// generated, the provided alias mapping and reverse mapping are updated. + /// Returns success if an alias was generated, failure otherwise. template - void generateAlias(T symbol, - llvm::MapVector> &aliasToSymbol); + LogicalResult + generateAlias(T symbol, + llvm::MapVector> &aliasToSymbol); /// The set of asm interfaces within the context. DialectInterfaceCollection &interfaces; @@ -268,6 +302,9 @@ /// The set of visited attributes. DenseSet visitedAttributes; + /// The set of attributes that have aliases *and* can be deferred. + DenseSet deferrableAttributes; + /// The set of visited types. DenseSet visitedTypes; @@ -291,7 +328,7 @@ void print(Operation *op) { // Visit the operation location. if (printerFlags.shouldPrintDebugInfo()) - initializer.visit(op->getLoc()); + initializer.visit(op->getLoc(), /*canBeDeferred=*/true); // If requested, always print the generic form. if (!printerFlags.shouldPrintGenericOpForm()) { @@ -464,9 +501,10 @@ /// Given a collection of aliases and symbols, initialize a mapping from a /// symbol to a given alias. 
template -static void initializeAliases( - llvm::MapVector> &aliasToSymbol, - llvm::MapVector>> &symbolToAlias) { +static void +initializeAliases(llvm::MapVector> &aliasToSymbol, + llvm::MapVector &symbolToAlias, + DenseSet *deferrableAliases = nullptr) { std::vector>> aliases = aliasToSymbol.takeVector(); llvm::array_pod_sort(aliases.begin(), aliases.end(), @@ -477,20 +515,24 @@ for (auto &it : aliases) { // If there is only one instance for this alias, use the name directly. if (it.second.size() == 1) { - symbolToAlias.insert({it.second.front(), {it.first, llvm::None}}); + T symbol = it.second.front(); + bool isDeferrable = deferrableAliases && deferrableAliases->count(symbol); + symbolToAlias.insert({symbol, SymbolAlias(it.first, isDeferrable)}); continue; } // Otherwise, add the index to the name. - for (int i = 0, e = it.second.size(); i < e; ++i) - symbolToAlias.insert({it.second[i], {it.first, i}}); + for (int i = 0, e = it.second.size(); i < e; ++i) { + T symbol = it.second[i]; + bool isDeferrable = deferrableAliases && deferrableAliases->count(symbol); + symbolToAlias.insert({symbol, SymbolAlias(it.first, i, isDeferrable)}); + } } } void AliasInitializer::initialize( Operation *op, const OpPrintingFlags &printerFlags, - llvm::MapVector>> - &attrToAlias, - llvm::MapVector>> &typeToAlias) { + llvm::MapVector &attrToAlias, + llvm::MapVector &typeToAlias) { // Use a dummy printer when walking the IR so that we can collect the // attributes/types that will actually be used during printing when // considering aliases. @@ -498,13 +540,25 @@ aliasPrinter.print(op); // Initialize the aliases sorted by name. 
- initializeAliases(aliasToAttr, attrToAlias); + initializeAliases(aliasToAttr, attrToAlias, &deferrableAttributes); initializeAliases(aliasToType, typeToAlias); } -void AliasInitializer::visit(Attribute attr) { - if (!visitedAttributes.insert(attr).second) +void AliasInitializer::visit(Attribute attr, bool canBeDeferred) { + if (!visitedAttributes.insert(attr).second) { + // If this attribute already has an alias and this instance can't be + // deferred, make sure that the alias isn't deferred. + if (!canBeDeferred) + deferrableAttributes.erase(attr); + return; + } + + // Try to generate an alias for this attribute. + if (succeeded(generateAlias(attr, aliasToAttr))) { + if (canBeDeferred) + deferrableAttributes.insert(attr); return; + } if (auto arrayAttr = attr.dyn_cast()) { for (Attribute element : arrayAttr.getValue()) @@ -515,15 +569,16 @@ } else if (auto typeAttr = attr.dyn_cast()) { visit(typeAttr.getValue()); } - - // Try to generate an alias for this attribute. - generateAlias(attr, aliasToAttr); } void AliasInitializer::visit(Type type) { if (!visitedTypes.insert(type).second) return; + // Try to generate an alias for this type. + if (succeeded(generateAlias(type, aliasToType))) + return; + // Visit several subtypes that contain types or atttributes. if (auto funcType = type.dyn_cast()) { // Visit input and result types for functions. @@ -539,13 +594,10 @@ for (auto map : memref.getAffineMaps()) visit(AffineMapAttr::get(map)); } - - // Try to generate an alias for this type. 
- generateAlias(type, aliasToType); } template -void AliasInitializer::generateAlias( +LogicalResult AliasInitializer::generateAlias( T symbol, llvm::MapVector> &aliasToSymbol) { SmallString<16> tempBuffer; for (const auto &interface : interfaces) { @@ -559,8 +611,9 @@ aliasToSymbol[name].push_back(symbol); aliasBuffer.clear(); - break; + return success(); } + return failure(); } //===----------------------------------------------------------------------===// @@ -580,21 +633,31 @@ /// Returns success if an alias was printed, failure otherwise. LogicalResult getAlias(Attribute attr, raw_ostream &os) const; - /// Print all of the referenced attribute aliases. - void printAttributeAliases(raw_ostream &os, NewLineCounter &newLine) const; - /// Get an alias for the given type if it has one and print it in `os`. /// Returns success if an alias was printed, failure otherwise. LogicalResult getAlias(Type ty, raw_ostream &os) const; - /// Print all of the referenced type aliases. - void printTypeAliases(raw_ostream &os, NewLineCounter &newLine) const; + /// Print all of the referenced aliases that cannot be resolved in a deferred + /// manner. + void printNonDeferredAliases(raw_ostream &os, NewLineCounter &newLine) const { + printAliases(os, newLine, /*isDeferred=*/false); + } + + /// Print all of the referenced aliases that support deferred resolution. + void printDeferredAliases(raw_ostream &os, NewLineCounter &newLine) const { + printAliases(os, newLine, /*isDeferred=*/true); + } private: - /// Mapping between attribute and a pair comprised of a base alias name and a - /// count suffix. If the suffix is set to None, it is not displayed. - llvm::MapVector>> attrToAlias; - llvm::MapVector>> typeToAlias; + /// Print all of the referenced aliases that support the provided resolution + /// behavior. + void printAliases(raw_ostream &os, NewLineCounter &newLine, + bool isDeferred) const; + + /// Mapping between attribute and alias.
+ llvm::MapVector attrToAlias; + /// Mapping between type and alias. + llvm::MapVector typeToAlias; /// An allocator used for alias names. llvm::BumpPtrAllocator aliasAllocator; @@ -608,44 +671,34 @@ initializer.initialize(op, printerFlags, attrToAlias, typeToAlias); } -static void printAlias(raw_ostream &os, - const std::pair> &alias, - char prefix) { - os << prefix << alias.first; - if (alias.second) - os << *alias.second; -} - LogicalResult AliasState::getAlias(Attribute attr, raw_ostream &os) const { auto it = attrToAlias.find(attr); if (it == attrToAlias.end()) return failure(); - - printAlias(os, it->second, '#'); + it->second.print(os << '#'); return success(); } -void AliasState::printAttributeAliases(raw_ostream &os, - NewLineCounter &newLine) const { - for (const auto &it : attrToAlias) { - printAlias(os, it.second, '#'); - os << " = " << it.first << newLine; - } -} - LogicalResult AliasState::getAlias(Type ty, raw_ostream &os) const { auto it = typeToAlias.find(ty); if (it == typeToAlias.end()) return failure(); - printAlias(os, it->second, '!'); + it->second.print(os << '!'); return success(); } -void AliasState::printTypeAliases(raw_ostream &os, - NewLineCounter &newLine) const { - for (const auto &it : typeToAlias) { - printAlias(os, it.second, '!'); +void AliasState::printAliases(raw_ostream &os, NewLineCounter &newLine, + bool isDeferred) const { + auto filterFn = [=](const auto &aliasIt) { + return aliasIt.second.canBeDeferred() == isDeferred; + }; + for (const auto &it : llvm::make_filter_range(attrToAlias, filterFn)) { + it.second.print(os << '#'); + os << " = " << it.first << newLine; + } + for (const auto &it : llvm::make_filter_range(typeToAlias, filterFn)) { + it.second.print(os << '!'); os << " = " << it.first << newLine; } } @@ -2237,12 +2290,15 @@ } // end anonymous namespace void OperationPrinter::print(ModuleOp op) { - // Output the aliases at the top level. 
- state->getAliasState().printAttributeAliases(os, newLine); - state->getAliasState().printTypeAliases(os, newLine); + // Output the aliases at the top level that can't be deferred. + state->getAliasState().printNonDeferredAliases(os, newLine); // Print the module. print(op.getOperation()); + os << newLine; + + // Output the aliases at the top level that can be deferred. + state->getAliasState().printDeferredAliases(os, newLine); } void OperationPrinter::print(Operation *op) { diff --git a/mlir/lib/Parser/LocationParser.cpp b/mlir/lib/Parser/LocationParser.cpp --- a/mlir/lib/Parser/LocationParser.cpp +++ b/mlir/lib/Parser/LocationParser.cpp @@ -177,35 +177,3 @@ return emitError("expected location instance"); } - -ParseResult Parser::parseOptionalTrailingLocation(Location &loc) { - // If there is a 'loc' we parse a trailing location. - if (!consumeIf(Token::kw_loc)) - return success(); - if (parseToken(Token::l_paren, "expected '(' in location")) - return failure(); - Token tok = getToken(); - - // Check to see if we are parsing a location alias. - LocationAttr directLoc; - if (tok.is(Token::hash_identifier)) { - // TODO: This should be reworked a bit to allow for resolving operation - // locations to aliases after the operation has already been parsed(i.e. - // allow post parse location fixups). - Attribute attr = parseExtendedAttr(Type()); - if (!attr) - return failure(); - if (!(directLoc = attr.dyn_cast())) - return emitError(tok.getLoc()) << "expected location, but found " << attr; - - // Otherwise, we parse the location directly. - } else if (parseLocationInstance(directLoc)) { - return failure(); - } - - if (parseToken(Token::r_paren, "expected ')' in location")) - return failure(); - - loc = directLoc; - return success(); -} diff --git a/mlir/lib/Parser/Parser.h b/mlir/lib/Parser/Parser.h --- a/mlir/lib/Parser/Parser.h +++ b/mlir/lib/Parser/Parser.h @@ -243,12 +243,6 @@ /// Parse a name or FileLineCol location instance. 
ParseResult parseNameOrFileLineColLocation(LocationAttr &loc); - /// Parse an optional trailing location. - /// - /// trailing-location ::= (`loc` (`(` location `)` | attribute-alias))? - /// - ParseResult parseOptionalTrailingLocation(Location &loc); - //===--------------------------------------------------------------------===// // Affine Parsing //===--------------------------------------------------------------------===// diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -183,9 +183,15 @@ Operation *parseGenericOperation(Block *insertBlock, Block::iterator insertPt); + /// Parse an optional trailing location for the given operation. + /// + /// trailing-location ::= (`loc` `(` (location | attribute-alias) `)`)? + /// + ParseResult parseTrailingOperationLocation(Operation *op); + /// This is the structure of a result specifier in the assembly syntax, /// including the name, number of results, and location. - typedef std::tuple ResultRecord; + using ResultRecord = std::tuple; /// Parse an operation instance that is in the op-defined custom form. /// resultInfo specifies information about the "%name =" specifiers. @@ -297,6 +303,10 @@ /// their first reference, to allow checking for use of undefined values. DenseMap forwardRefPlaceholders; + /// The operations whose locations reference aliases that have yet to be + /// resolved, each paired with the token that names the alias. + SmallVector, 8> opsWithDeferredLocs; + /// The builder used when creating parsed operation instances. OpBuilder opBuilder; @@ -333,6 +343,22 @@ return failure(); } + // Resolve the locations of any deferred operations.
+ auto &attributeAliases = getState().symbols.attributeAliasDefinitions; + for (std::pair &it : opsWithDeferredLocs) { + llvm::SMLoc tokLoc = it.second.getLoc(); + StringRef identifier = it.second.getSpelling().drop_front(); + Attribute attr = attributeAliases.lookup(identifier); + if (!attr) + return emitError(tokLoc) << "operation location alias was never defined"; + + LocationAttr locAttr = attr.dyn_cast(); + if (!locAttr) + return emitError(tokLoc) + << "expected location, but found '" << attr << "'"; + it.first->setLoc(locAttr); + } + return success(); } @@ -817,11 +843,11 @@ return nullptr; } - // Parse a location if one is present. - if (parseOptionalTrailingLocation(result.location)) + // Create the operation and try to parse a location for it. + Operation *op = opBuilder.createOperation(result); + if (parseTrailingOperationLocation(op)) return nullptr; - - return opBuilder.createOperation(result); + return op; } Operation *OperationParser::parseGenericOperation(Block *insertBlock, @@ -1570,12 +1596,56 @@ if (opAsmParser.didEmitError()) return nullptr; - // Parse a location if one is present. - if (parseOptionalTrailingLocation(opState.location)) + // Otherwise, create the operation and try to parse a location for it. + Operation *op = opBuilder.createOperation(opState); + if (parseTrailingOperationLocation(op)) return nullptr; + return op; +} + +ParseResult OperationParser::parseTrailingOperationLocation(Operation *op) { + // If there is a 'loc' we parse a trailing location. + if (!consumeIf(Token::kw_loc)) + return success(); + if (parseToken(Token::l_paren, "expected '(' in location")) + return failure(); + Token tok = getToken(); + + // Check to see if we are parsing a location alias. 
+ LocationAttr directLoc; + if (tok.is(Token::hash_identifier)) { + consumeToken(); + + StringRef identifier = tok.getSpelling().drop_front(); + if (identifier.contains('.')) { + return emitError(tok.getLoc()) + << "expected location, but found dialect attribute: '#" + << identifier << "'"; + } + + // If this alias can be resolved, do it now. + Attribute attr = + getState().symbols.attributeAliasDefinitions.lookup(identifier); + if (attr) { + if (!(directLoc = attr.dyn_cast())) + return emitError(tok.getLoc()) + << "expected location, but found '" << attr << "'"; + } else { + // Otherwise, remember this operation and resolve its location later. + opsWithDeferredLocs.emplace_back(op, tok); + } - // Otherwise, we succeeded. Use the state it parsed as our op information. - return opBuilder.createOperation(opState); + // Otherwise, we parse the location directly. + } else if (parseLocationInstance(directLoc)) { + return failure(); + } + + if (parseToken(Token::r_paren, "expected ')' in location")) + return failure(); + + if (directLoc) + op->setLoc(directLoc); + return success(); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/IR/invalid-locations.mlir b/mlir/test/IR/invalid-locations.mlir --- a/mlir/test/IR/invalid-locations.mlir +++ b/mlir/test/IR/invalid-locations.mlir @@ -102,7 +102,23 @@ // ----- func @location_invalid_alias() { - // expected-error@+1 {{expected location, but found #foo.loc}} + // expected-error@+1 {{expected location, but found dialect attribute: '#foo.loc'}} return loc(#foo.loc) } +// ----- + +func @location_invalid_alias() { + // expected-error@+1 {{operation location alias was never defined}} + return loc(#invalid_alias) +} + +// ----- + +func @location_invalid_alias() { + // expected-error@+1 {{expected location, but found 'true'}} + return loc(#non_loc) +} + +#non_loc = true + diff --git a/mlir/test/IR/locations.mlir b/mlir/test/IR/locations.mlir --- a/mlir/test/IR/locations.mlir +++ 
b/mlir/test/IR/locations.mlir @@ -27,8 +27,8 @@ // CHECK-LABEL: func @loc_attr(i1 {foo.loc_attr = loc(callsite("foo" at "mysource.cc":10:8))}) func @loc_attr(i1 {foo.loc_attr = loc(callsite("foo" at "mysource.cc":10:8))}) -// CHECK-ALIAS: #[[LOC:.*]] = loc("out_of_line_location") -#loc = loc("out_of_line_location") - -// CHECK-ALIAS: "foo.op"() : () -> () loc(#[[LOC]]) +// CHECK-ALIAS: "foo.op"() : () -> () loc(#[[LOC:.*]]) "foo.op"() : () -> () loc(#loc) + +// CHECK-ALIAS: #[[LOC]] = loc("out_of_line_location") +#loc = loc("out_of_line_location")