diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -155,98 +155,46 @@
 bool OpPrintingFlags::shouldUseLocalScope() const { return printLocalScope; }
 
 //===----------------------------------------------------------------------===//
-// ModuleState
+// AliasState
 //===----------------------------------------------------------------------===//
 
 namespace {
-/// A special index constant used for non-kind attribute aliases.
-static constexpr int kNonAttrKindAlias = -1;
-
-class ModuleState {
+/// This class manages the state for type and attribute aliases.
+class AliasState {
 public:
-  explicit ModuleState(MLIRContext *context) : interfaces(context) {}
-  void initialize(Operation *op);
-
-  Twine getAttributeAlias(Attribute attr) const {
-    auto alias = attrToAlias.find(attr);
-    if (alias == attrToAlias.end())
-      return Twine();
-
-    // Return the alias for this attribute, along with the index if this was
-    // generated by a kind alias.
-    int kindIndex = alias->second.second;
-    return alias->second.first +
-           (kindIndex == kNonAttrKindAlias ? Twine() : Twine(kindIndex));
-  }
-
-  void printAttributeAliases(raw_ostream &os) const {
-    auto printAlias = [&](StringRef alias, Attribute attr, int index) {
-      os << '#' << alias;
-      if (index != kNonAttrKindAlias)
-        os << index;
-      os << " = " << attr << '\n';
-    };
-
-    // Print all of the attribute kind aliases.
-    for (auto &kindAlias : attrKindToAlias) {
-      for (unsigned i = 0, e = kindAlias.second.second.size(); i != e; ++i)
-        printAlias(kindAlias.second.first, kindAlias.second.second[i], i);
-      os << "\n";
-    }
+  /// Initialize the internal aliases.
+  void
+  initialize(Operation *op,
+             DialectInterfaceCollection<OpAsmDialectInterface> &interfaces);
 
-    // In a second pass print all of the remaining attribute aliases that aren't
-    // kind aliases.
-    for (Attribute attr : usedAttributes) {
-      auto alias = attrToAlias.find(attr);
-      if (alias != attrToAlias.end() &&
-          alias->second.second == kNonAttrKindAlias)
-        printAlias(alias->second.first, attr, alias->second.second);
-    }
-  }
+  /// Return a name used for an attribute alias, or empty if there is no alias.
+  Twine getAttributeAlias(Attribute attr) const;
 
-  StringRef getTypeAlias(Type ty) const { return typeToAlias.lookup(ty); }
+  /// Print all of the referenced attribute aliases.
+  void printAttributeAliases(raw_ostream &os) const;
 
-  void printTypeAliases(raw_ostream &os) const {
-    for (Type type : usedTypes) {
-      auto alias = typeToAlias.find(type);
-      if (alias != typeToAlias.end())
-        os << '!' << alias->second << " = type " << type << '\n';
-    }
-  }
+  /// Return a string to use as an alias for the given type, or empty if there
+  /// is no alias recorded.
+  StringRef getTypeAlias(Type ty) const;
 
-  /// Get an instance of the OpAsmDialectInterface for the given dialect, or
-  /// null if one wasn't registered.
-  const OpAsmDialectInterface *getOpAsmInterface(Dialect *dialect) {
-    return interfaces.getInterfaceFor(dialect);
-  }
+  /// Print all of the referenced type aliases.
+  void printTypeAliases(raw_ostream &os) const;
 
 private:
-  void recordAttributeReference(Attribute attr) {
-    // Don't recheck attributes that have already been seen or those that
-    // already have an alias.
-    if (!usedAttributes.insert(attr) || attrToAlias.count(attr))
-      return;
+  /// A special index constant used for non-kind attribute aliases.
+  enum { NonAttrKindAlias = -1 };
 
-    // If this attribute kind has an alias, then record one for this attribute.
-    auto alias = attrKindToAlias.find(static_cast<unsigned>(attr.getKind()));
-    if (alias == attrKindToAlias.end())
-      return;
-    std::pair<StringRef, int> attrAlias(alias->second.first,
-                                        alias->second.second.size());
-    attrToAlias.insert({attr, attrAlias});
-    alias->second.second.push_back(attr);
-  }
+  /// Record a reference to the given attribute.
+  void recordAttributeReference(Attribute attr);
 
-  void recordTypeReference(Type ty) { usedTypes.insert(ty); }
+  /// Record a reference to the given type.
+  void recordTypeReference(Type ty);
 
   // Visit functions.
   void visitOperation(Operation *op);
   void visitType(Type type);
   void visitAttribute(Attribute attr);
 
-  // Initialize symbol aliases.
-  void initializeSymbolAliases();
-
   /// Set of attributes known to be used within the module.
   llvm::SetVector<Attribute> usedAttributes;
 
@@ -265,59 +213,9 @@
 
   /// A mapping between a type and a given alias.
   DenseMap<Type, StringRef> typeToAlias;
-
-  /// Collection of OpAsm interfaces implemented in the context.
-  DialectInterfaceCollection<OpAsmDialectInterface> interfaces;
 };
 } // end anonymous namespace
 
-// TODO Support visiting other types/operations when implemented.
-void ModuleState::visitType(Type type) {
-  recordTypeReference(type);
-  if (auto funcType = type.dyn_cast<FunctionType>()) {
-    // Visit input and result types for functions.
-    for (auto input : funcType.getInputs())
-      visitType(input);
-    for (auto result : funcType.getResults())
-      visitType(result);
-    return;
-  }
-  if (auto memref = type.dyn_cast<MemRefType>()) {
-    // Visit affine maps in memref type.
-    for (auto map : memref.getAffineMaps())
-      recordAttributeReference(AffineMapAttr::get(map));
-  }
-  if (auto shapedType = type.dyn_cast<ShapedType>()) {
-    visitType(shapedType.getElementType());
-  }
-}
-
-void ModuleState::visitAttribute(Attribute attr) {
-  recordAttributeReference(attr);
-  if (auto arrayAttr = attr.dyn_cast<ArrayAttr>()) {
-    for (auto elt : arrayAttr.getValue())
-      visitAttribute(elt);
-  } else if (auto typeAttr = attr.dyn_cast<TypeAttr>()) {
-    visitType(typeAttr.getValue());
-  }
-}
-
-void ModuleState::visitOperation(Operation *op) {
-  // Visit all the types used in the operation.
-  for (auto type : op->getOperandTypes())
-    visitType(type);
-  for (auto type : op->getResultTypes())
-    visitType(type);
-  for (auto &region : op->getRegions())
-    for (auto &block : region)
-      for (auto arg : block.getArguments())
-        visitType(arg->getType());
-
-  // Visit each of the attributes.
-  for (auto elt : op->getAttrs())
-    visitAttribute(elt.second);
-}
-
 // Utility to generate a function to register a symbol alias.
 static bool canRegisterAlias(StringRef name, llvm::StringSet<> &usedAliases) {
   assert(!name.empty() && "expected alias name to be non-empty");
@@ -329,7 +227,9 @@
   return !name.contains('.') && usedAliases.insert(name).second;
 }
 
-void ModuleState::initializeSymbolAliases() {
+void AliasState::initialize(
+    Operation *op,
+    DialectInterfaceCollection<OpAsmDialectInterface> &interfaces) {
   // Track the identifiers in use for each symbol so that the same identifier
   // isn't used twice.
   llvm::StringSet<> usedAliases;
@@ -374,7 +274,7 @@
   for (auto &attrAliasPair : attributeAliases) {
     std::tie(attr, alias) = attrAliasPair;
     if (!reservedAttrNames.match(alias) && canRegisterAlias(alias, usedAliases))
-      attrToAlias.insert({attr, {alias, kNonAttrKindAlias}});
+      attrToAlias.insert({attr, {alias, NonAttrKindAlias}});
   }
 
   // Clear the set of used identifiers as types can have the same identifiers as
@@ -385,14 +285,164 @@
   for (auto &typeAliasPair : typeAliases)
     if (canRegisterAlias(typeAliasPair.second, usedAliases))
       typeToAlias.insert(typeAliasPair);
+
+  // Traverse the given IR to generate the set of used attributes/types.
+  op->walk([&](Operation *op) { visitOperation(op); });
 }
 
-void ModuleState::initialize(Operation *op) {
-  // Initialize the symbol aliases.
-  initializeSymbolAliases();
+/// Return a name used for an attribute alias, or empty if there is no alias.
+Twine AliasState::getAttributeAlias(Attribute attr) const {
+  auto alias = attrToAlias.find(attr);
+  if (alias == attrToAlias.end())
+    return Twine();
 
-  // Visit each of the nested operations.
-  op->walk([&](Operation *op) { visitOperation(op); });
+  // Return the alias for this attribute, along with the index if this was
+  // generated by a kind alias.
+  int kindIndex = alias->second.second;
+  return alias->second.first +
+         (kindIndex == NonAttrKindAlias ? Twine() : Twine(kindIndex));
+}
+
+/// Print all of the referenced attribute aliases.
+void AliasState::printAttributeAliases(raw_ostream &os) const {
+  auto printAlias = [&](StringRef alias, Attribute attr, int index) {
+    os << '#' << alias;
+    if (index != NonAttrKindAlias)
+      os << index;
+    os << " = " << attr << '\n';
+  };
+
+  // Print all of the attribute kind aliases.
+  for (auto &kindAlias : attrKindToAlias) {
+    for (unsigned i = 0, e = kindAlias.second.second.size(); i != e; ++i)
+      printAlias(kindAlias.second.first, kindAlias.second.second[i], i);
+    os << "\n";
+  }
+
+  // In a second pass print all of the remaining attribute aliases that aren't
+  // kind aliases.
+  for (Attribute attr : usedAttributes) {
+    auto alias = attrToAlias.find(attr);
+    if (alias != attrToAlias.end() && alias->second.second == NonAttrKindAlias)
+      printAlias(alias->second.first, attr, alias->second.second);
+  }
+}
+
+/// Return a string to use as an alias for the given type, or empty if there
+/// is no alias recorded.
+StringRef AliasState::getTypeAlias(Type ty) const {
+  return typeToAlias.lookup(ty);
+}
+
+/// Print all of the referenced type aliases.
+void AliasState::printTypeAliases(raw_ostream &os) const {
+  for (Type type : usedTypes) {
+    auto alias = typeToAlias.find(type);
+    if (alias != typeToAlias.end())
+      os << '!' << alias->second << " = type " << type << '\n';
+  }
+}
+
+/// Record a reference to the given attribute.
+void AliasState::recordAttributeReference(Attribute attr) {
+  // Don't recheck attributes that have already been seen or those that
+  // already have an alias.
+  if (!usedAttributes.insert(attr) || attrToAlias.count(attr))
+    return;
+
+  // If this attribute kind has an alias, then record one for this attribute.
+  auto alias = attrKindToAlias.find(static_cast<unsigned>(attr.getKind()));
+  if (alias == attrKindToAlias.end())
+    return;
+  std::pair<StringRef, int> attrAlias(alias->second.first,
+                                      alias->second.second.size());
+  attrToAlias.insert({attr, attrAlias});
+  alias->second.second.push_back(attr);
+}
+
+/// Record a reference to the given type.
+void AliasState::recordTypeReference(Type ty) { usedTypes.insert(ty); }
+
+// TODO Support visiting other types/operations when implemented.
+void AliasState::visitType(Type type) {
+  recordTypeReference(type);
+
+  if (auto funcType = type.dyn_cast<FunctionType>()) {
+    // Visit input and result types for functions.
+    for (auto input : funcType.getInputs())
+      visitType(input);
+    for (auto result : funcType.getResults())
+      visitType(result);
+  } else if (auto shapedType = type.dyn_cast<ShapedType>()) {
+    visitType(shapedType.getElementType());
+
+    // Visit affine maps in memref type.
+    if (auto memref = type.dyn_cast<MemRefType>())
+      for (auto map : memref.getAffineMaps())
+        recordAttributeReference(AffineMapAttr::get(map));
+  }
+}
+
+void AliasState::visitAttribute(Attribute attr) {
+  recordAttributeReference(attr);
+
+  if (auto arrayAttr = attr.dyn_cast<ArrayAttr>()) {
+    for (auto elt : arrayAttr.getValue())
+      visitAttribute(elt);
+  } else if (auto typeAttr = attr.dyn_cast<TypeAttr>()) {
+    visitType(typeAttr.getValue());
+  }
+}
+
+void AliasState::visitOperation(Operation *op) {
+  // Visit all the types used in the operation.
+  for (auto type : op->getOperandTypes())
+    visitType(type);
+  for (auto type : op->getResultTypes())
+    visitType(type);
+  for (auto &region : op->getRegions())
+    for (auto &block : region)
+      for (auto arg : block.getArguments())
+        visitType(arg->getType());
+
+  // Visit each of the attributes.
+  for (auto elt : op->getAttrs())
+    visitAttribute(elt.second);
+}
+
+//===----------------------------------------------------------------------===//
+// ModuleState
+//===----------------------------------------------------------------------===//
+
+namespace {
+class ModuleState {
+public:
+  explicit ModuleState(MLIRContext *context) : interfaces(context) {}
+
+  /// Initialize the alias state to enable the printing of aliases.
+  void initializeAliases(Operation *op);
+
+  /// Get an instance of the OpAsmDialectInterface for the given dialect, or
+  /// null if one wasn't registered.
+  const OpAsmDialectInterface *getOpAsmInterface(Dialect *dialect) {
+    return interfaces.getInterfaceFor(dialect);
+  }
+
+  /// Get the state used for aliases.
+  AliasState &getAliasState() { return aliasState; }
+
+private:
+  /// Collection of OpAsm interfaces implemented in the context.
+  DialectInterfaceCollection<OpAsmDialectInterface> interfaces;
+
+  /// The state used for attribute and type aliases.
+  AliasState aliasState;
+};
+} // end anonymous namespace
+
+/// Initialize the alias state to enable the printing of aliases.
+void ModuleState::initializeAliases(Operation *op) {
+  aliasState.initialize(op, interfaces);
 }
 
 //===----------------------------------------------------------------------===//
@@ -745,7 +795,7 @@
 
   // Check for an alias for this attribute.
   if (state) {
-    Twine alias = state->getAttributeAlias(attr);
+    Twine alias = state->getAliasState().getAttributeAlias(attr);
     if (!alias.isTriviallyEmpty()) {
       os << '#' << alias;
       return;
@@ -975,7 +1025,7 @@
 void ModulePrinter::printType(Type type) {
   // Check for an alias for this type.
   if (state) {
-    StringRef alias = state->getTypeAlias(type);
+    StringRef alias = state->getAliasState().getTypeAlias(type);
     if (!alias.empty()) {
       os << '!' << alias;
       return;
@@ -1997,8 +2047,8 @@
 void ModulePrinter::print(ModuleOp module) {
   // Output the aliases at the top level.
   if (state) {
-    state->printAttributeAliases(os);
-    state->printTypeAliases(os);
+    state->getAliasState().printAttributeAliases(os);
+    state->getAliasState().printTypeAliases(os);
   }
 
   // Print the module.
@@ -2136,9 +2186,9 @@
 void ModuleOp::print(raw_ostream &os, OpPrintingFlags flags) {
   ModuleState state(getContext());
 
-  // Skip initializing in local scope to avoid populating aliases.
+  // Don't populate aliases when printing at local scope.
   if (!flags.shouldUseLocalScope())
-    state.initialize(*this);
+    state.initializeAliases(*this);
   ModulePrinter(os, flags, &state).print(*this);
 }
 