diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -448,16 +448,21 @@ /// This class represents a specific instance of a symbol Alias. class SymbolAlias { public: - SymbolAlias(StringRef name, uint32_t suffixIndex, bool isDeferrable) - : name(name), suffixIndex(suffixIndex), isDeferrable(isDeferrable) {} + SymbolAlias(StringRef name, uint32_t suffixIndex, bool isType, + bool isDeferrable) + : name(name), suffixIndex(suffixIndex), isType(isType), + isDeferrable(isDeferrable) {} /// Print this alias to the given stream. void print(raw_ostream &os) const { - os << name; + os << (isType ? "!" : "#") << name; if (suffixIndex) os << suffixIndex; } + /// Returns true if this is a type alias. + bool isTypeAlias() const { return isType; } + /// Returns true if this alias supports deferred resolution when parsing. bool canBeDeferred() const { return isDeferrable; } @@ -465,7 +470,9 @@ /// The main name of the alias. StringRef name; /// The suffix index of the alias. - uint32_t suffixIndex : 31; + uint32_t suffixIndex : 30; + /// A flag indicating whether this alias is for a type. + bool isType : 1; /// A flag indicating whether this alias may be deferred or not. bool isDeferrable : 1; }; @@ -482,31 +489,34 @@ aliasOS(aliasBuffer) {} void initialize(Operation *op, const OpPrintingFlags &printerFlags, - llvm::MapVector &attrToAlias, - llvm::MapVector &typeToAlias); + llvm::MapVector &attrTypeToAlias); /// Visit the given attribute to see if it has an alias. `canBeDeferred` is /// set to true if the originator of this attribute can resolve the alias /// after parsing has completed (e.g. in the case of operation locations). /// Returns the maximum alias depth of the attribute. size_t visit(Attribute attr, bool canBeDeferred = false) { - return visitImpl(attr, attrAliases, canBeDeferred); + return visitImpl(attr, aliases, canBeDeferred); } /// Visit the given type to see if it has an alias. Returns the maximum alias /// depth of the type. - size_t visit(Type type) { return visitImpl(type, typeAliases); } + size_t visit(Type type) { return visitImpl(type, aliases); } private: struct InProgressAliasInfo { - InProgressAliasInfo() : aliasDepth(0), canBeDeferred(false) {} - InProgressAliasInfo(StringRef alias, bool canBeDeferred) - : alias(alias), aliasDepth(0), canBeDeferred(canBeDeferred) {} + InProgressAliasInfo() + : aliasDepth(0), isType(false), canBeDeferred(false) {} + InProgressAliasInfo(StringRef alias, bool isType, bool canBeDeferred) + : alias(alias), aliasDepth(1), isType(isType), + canBeDeferred(canBeDeferred) {} bool operator<(const InProgressAliasInfo &rhs) const { - // Order first by depth, and then by name. + // Order first by depth, then by attr/type kind, and then by name. if (aliasDepth != rhs.aliasDepth) return aliasDepth < rhs.aliasDepth; + if (isType != rhs.isType) + return isType; return alias < rhs.alias; } @@ -514,7 +524,9 @@ Optional alias; /// The alias depth of this attribute or type, i.e. an indication of the /// relative ordering of when to print this alias. - unsigned aliasDepth : 31; + unsigned aliasDepth : 30; + /// If this alias represents a type or an attribute. + bool isType : 1; /// If this alias can be deferred or not. bool canBeDeferred : 1; }; @@ -524,22 +536,20 @@ /// the alias after parsing has completed (e.g. in the case of operation /// locations). Returns the maximum alias depth of the value. template - size_t visitImpl(T value, llvm::MapVector &aliases, + size_t visitImpl(T value, + llvm::MapVector &aliases, bool canBeDeferred = false); /// Try to generate an alias for the provided symbol. If an alias is /// generated, the provided alias mapping and reverse mapping are updated. - /// Returns success if an alias was generated, failure otherwise. template - LogicalResult generateAlias(T symbol, InProgressAliasInfo &alias, - bool canBeDeferred); + void generateAlias(T symbol, InProgressAliasInfo &alias, bool canBeDeferred); /// Given a collection of aliases and symbols, initialize a mapping from a /// symbol to a given alias. - template - static void - initializeAliases(llvm::MapVector &visitedSymbols, - llvm::MapVector &symbolToAlias); + static void initializeAliases( + llvm::MapVector &visitedSymbols, + llvm::MapVector &symbolToAlias); /// The set of asm interfaces within the context. DialectInterfaceCollection &interfaces; @@ -548,8 +558,7 @@ llvm::BumpPtrAllocator &aliasAllocator; /// The set of built aliases. - llvm::MapVector attrAliases; - llvm::MapVector typeAliases; + llvm::MapVector aliases; /// Storage and stream used when generating an alias. SmallString<32> aliasBuffer; @@ -792,11 +801,10 @@ /// Given a collection of aliases and symbols, initialize a mapping from a /// symbol to a given alias. -template void AliasInitializer::initializeAliases( - llvm::MapVector &visitedSymbols, - llvm::MapVector &symbolToAlias) { - std::vector> unprocessedAliases = + llvm::MapVector &visitedSymbols, + llvm::MapVector &symbolToAlias) { + std::vector> unprocessedAliases = visitedSymbols.takeVector(); llvm::stable_sort(unprocessedAliases, [](const auto &lhs, const auto &rhs) { return lhs.second < rhs.second; @@ -809,31 +817,30 @@ StringRef alias = *aliasInfo.alias; unsigned nameIndex = nameCounts[alias]++; symbolToAlias.insert( - {symbol, SymbolAlias(alias, nameIndex, aliasInfo.canBeDeferred)}); + {symbol, SymbolAlias(alias, nameIndex, aliasInfo.isType, + aliasInfo.canBeDeferred)}); } } void AliasInitializer::initialize( Operation *op, const OpPrintingFlags &printerFlags, - llvm::MapVector &attrToAlias, - llvm::MapVector &typeToAlias) { + llvm::MapVector &attrTypeToAlias) { // Use a dummy printer when walking the IR so that we can collect the // attributes/types that will actually be used during printing when // considering aliases. DummyAliasOperationPrinter aliasPrinter(printerFlags, *this); aliasPrinter.printCustomOrGenericOp(op); - // Initialize the aliases sorted by name. - initializeAliases(attrAliases, attrToAlias); - initializeAliases(typeAliases, typeToAlias); + // Initialize the aliases. + initializeAliases(aliases, attrTypeToAlias); } template -size_t -AliasInitializer::visitImpl(T value, - llvm::MapVector &aliases, - bool canBeDeferred) { - auto [it, inserted] = aliases.insert({value, InProgressAliasInfo()}); +size_t AliasInitializer::visitImpl( + T value, llvm::MapVector &aliases, + bool canBeDeferred) { + auto [it, inserted] = + aliases.insert({value.getAsOpaquePointer(), InProgressAliasInfo()}); if (!inserted) { // Make sure that the alias isn't deferred if we don't permit it. if (!canBeDeferred) @@ -842,7 +849,7 @@ } // Try to generate an alias for this attribute. - bool hasAlias = succeeded(generateAlias(value, it->second, canBeDeferred)); + generateAlias(value, it->second, canBeDeferred); size_t aliasIndex = std::distance(aliases.begin(), it); // Check for any sub elements. @@ -852,17 +859,19 @@ if (auto subElementInterface = dyn_cast(value)) { size_t maxAliasDepth = 0; auto visitSubElement = [&](auto element) { - if (Optional depth = visit(element)) - maxAliasDepth = std::max(maxAliasDepth, *depth + 1); + if (!element) + return; + if (size_t depth = visit(element)) + maxAliasDepth = std::max(maxAliasDepth, depth + 1); }; - subElementInterface.walkSubElements(visitSubElement, visitSubElement); + subElementInterface.walkImmediateSubElements(visitSubElement, + visitSubElement); // Make sure to recompute `it` in case the map was reallocated. it = std::next(aliases.begin(), aliasIndex); - // If we had sub elements and an alias, update our main alias to account for - // the depth. - if (maxAliasDepth && hasAlias) + // If we had sub elements, update to account for the depth. + if (maxAliasDepth) it->second.aliasDepth = maxAliasDepth; } @@ -871,9 +880,8 @@ } template -LogicalResult AliasInitializer::generateAlias(T symbol, - InProgressAliasInfo &alias, - bool canBeDeferred) { +void AliasInitializer::generateAlias(T symbol, InProgressAliasInfo &alias, + bool canBeDeferred) { SmallString<32> nameBuffer; for (const auto &interface : interfaces) { OpAsmDialectInterface::AliasResult result = @@ -887,15 +895,15 @@ } if (nameBuffer.empty()) - return failure(); + return; SmallString<16> tempBuffer; StringRef name = sanitizeIdentifier(nameBuffer, tempBuffer, /*allowedPunctChars=*/"$_-", /*allowTrailingDigit=*/false); name = name.copy(aliasAllocator); - alias = InProgressAliasInfo(name, canBeDeferred); - return success(); + alias = InProgressAliasInfo(name, /*isType=*/std::is_base_of_v, + canBeDeferred); } //===----------------------------------------------------------------------===// @@ -936,10 +944,8 @@ void printAliases(AsmPrinter::Impl &p, NewLineCounter &newLine, bool isDeferred); - /// Mapping between attribute and alias. - llvm::MapVector attrToAlias; - /// Mapping between type and alias. - llvm::MapVector typeToAlias; + /// Mapping between attribute/type and alias. + llvm::MapVector attrTypeToAlias; /// An allocator used for alias names. llvm::BumpPtrAllocator aliasAllocator; @@ -950,23 +956,23 @@ Operation *op, const OpPrintingFlags &printerFlags, DialectInterfaceCollection &interfaces) { AliasInitializer initializer(interfaces, aliasAllocator); - initializer.initialize(op, printerFlags, attrToAlias, typeToAlias); + initializer.initialize(op, printerFlags, attrTypeToAlias); } LogicalResult AliasState::getAlias(Attribute attr, raw_ostream &os) const { - auto it = attrToAlias.find(attr); - if (it == attrToAlias.end()) + auto it = attrTypeToAlias.find(attr.getAsOpaquePointer()); + if (it == attrTypeToAlias.end()) return failure(); - it->second.print(os << '#'); + it->second.print(os); return success(); } LogicalResult AliasState::getAlias(Type ty, raw_ostream &os) const { - auto it = typeToAlias.find(ty); - if (it == typeToAlias.end()) + auto it = attrTypeToAlias.find(ty.getAsOpaquePointer()); + if (it == attrTypeToAlias.end()) return failure(); - it->second.print(os << '!'); + it->second.print(os); return success(); } @@ -975,27 +981,26 @@ auto filterFn = [=](const auto &aliasIt) { return aliasIt.second.canBeDeferred() == isDeferred; }; - for (auto &[attr, alias] : llvm::make_filter_range(attrToAlias, filterFn)) { - alias.print(p.getStream() << '#'); - p.getStream() << " = "; - - // TODO: Support nested aliases in mutable attributes. - if (attr.hasTrait()) - p.getStream() << attr; - else - p.printAttributeImpl(attr); - - p.getStream() << newLine; - } - for (auto &[type, alias] : llvm::make_filter_range(typeToAlias, filterFn)) { - alias.print(p.getStream() << '!'); + for (auto &[opaqueSymbol, alias] : + llvm::make_filter_range(attrTypeToAlias, filterFn)) { + alias.print(p.getStream()); p.getStream() << " = "; - // TODO: Support nested aliases in mutable types. - if (type.hasTrait()) - p.getStream() << type; - else - p.printTypeImpl(type); + if (alias.isTypeAlias()) { + // TODO: Support nested aliases in mutable types. + Type type = Type::getFromOpaquePointer(opaqueSymbol); + if (type.hasTrait()) + p.getStream() << type; + else + p.printTypeImpl(type); + } else { + // TODO: Support nested aliases in mutable attributes. + Attribute attr = Attribute::getFromOpaquePointer(opaqueSymbol); + if (attr.hasTrait()) + p.getStream() << attr; + else + p.printAttributeImpl(attr); + } p.getStream() << newLine; } diff --git a/mlir/test/IR/print-attr-type-aliases.mlir b/mlir/test/IR/print-attr-type-aliases.mlir --- a/mlir/test/IR/print-attr-type-aliases.mlir +++ b/mlir/test/IR/print-attr-type-aliases.mlir @@ -1,6 +1,6 @@ -// RUN: mlir-opt %s | FileCheck %s +// RUN: mlir-opt %s -split-input-file | FileCheck %s // Verify printer of type & attr aliases. -// RUN: mlir-opt %s | mlir-opt | FileCheck %s +// RUN: mlir-opt %s -split-input-file | mlir-opt -split-input-file | FileCheck %s // CHECK-DAG: #test2Ealias = "alias_test:dot_in_name" "test.op"() {alias_test = "alias_test:dot_in_name"} : () -> () @@ -32,3 +32,10 @@ // CHECK-DAG: #loc2 = loc("nested") // CHECK-DAG: #loc3 = loc(fused<#loc2>["test.mlir":10:8]) "test.op"() {alias_test = loc(fused["test.mlir":10:8])} : () -> () + +// ----- + +// Check proper ordering of intermixed attribute/type aliases. +// CHECK: !tuple = tuple< +// CHECK: #loc1 = loc(fused>["test.mlir":10:8])} : () -> ()