diff --git a/mlir/include/mlir/IR/AsmState.h b/mlir/include/mlir/IR/AsmState.h --- a/mlir/include/mlir/IR/AsmState.h +++ b/mlir/include/mlir/IR/AsmState.h @@ -21,6 +21,7 @@ namespace mlir { class AsmResourcePrinter; +class AsmDialectResourceHandle; class Operation; namespace detail { @@ -455,6 +456,9 @@ AsmState(Operation *op, const OpPrintingFlags &printerFlags = OpPrintingFlags(), LocationMap *locationMap = nullptr); + AsmState(MLIRContext *ctx, + const OpPrintingFlags &printerFlags = OpPrintingFlags(), + LocationMap *locationMap = nullptr); ~AsmState(); /// Get the printer flags. @@ -480,6 +484,11 @@ name, std::forward(printFn))); } + /// Returns a map of dialect resources that were referenced when using this + /// state to print IR. + DenseMap> & + getDialectResources() const; + private: AsmState() = delete; diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h --- a/mlir/include/mlir/IR/Attributes.h +++ b/mlir/include/mlir/IR/Attributes.h @@ -13,6 +13,7 @@ #include "llvm/Support/PointerLikeTypeTraits.h" namespace mlir { +class AsmState; class StringAttr; /// Attributes are known-constant values of operations. @@ -76,6 +77,7 @@ /// Print the attribute. void print(raw_ostream &os) const; + void print(raw_ostream &os, AsmState &state) const; void dump() const; /// Get an opaque pointer to the attribute. diff --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h --- a/mlir/include/mlir/IR/Types.h +++ b/mlir/include/mlir/IR/Types.h @@ -15,6 +15,8 @@ #include "llvm/Support/PointerLikeTypeTraits.h" namespace mlir { +class AsmState; + /// Instances of the Type class are uniqued, have an immutable identifier and an /// optional mutable component. They wrap a pointer to the storage object owned /// by MLIRContext. Therefore, instances of Type are passed around by value. @@ -162,6 +164,7 @@ /// Print the current type. void print(raw_ostream &os) const; + void print(raw_ostream &os, AsmState &state) const; void dump() const; friend ::llvm::hash_code hash_value(Type arg); diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -853,6 +853,7 @@ enum : unsigned { NameSentinel = ~0U }; SSANameState(Operation *op, const OpPrintingFlags &printerFlags); + SSANameState() = default; /// Print the SSA identifier for the given value to 'stream'. If /// 'printResultNo' is true, it also presents the result number ('#' number) @@ -1282,6 +1283,9 @@ AsmState::LocationMap *locationMap) : interfaces(op->getContext()), nameState(op, printerFlags), printerFlags(printerFlags), locationMap(locationMap) {} + explicit AsmStateImpl(MLIRContext *ctx, const OpPrintingFlags &printerFlags, + AsmState::LocationMap *locationMap) + : interfaces(ctx), printerFlags(printerFlags), locationMap(locationMap) {} /// Initialize the alias state to enable the printing of aliases. void initializeAliases(Operation *op) { @@ -1315,6 +1319,12 @@ (*locationMap)[op] = std::make_pair(line, col); } + /// Return the referenced dialect resources within the printer. + DenseMap> & + getDialectResources() { + return dialectResources; + } + private: /// Collection of OpAsm interfaces implemented in the context. DialectInterfaceCollection interfaces; @@ -1322,6 +1332,9 @@ /// A collection of non-dialect resource printers. SmallVector> externalResourcePrinters; + /// A set of dialect resources that were referenced during printing. + DenseMap> dialectResources; + /// The state used for attribute and type aliases. AliasState aliasState; @@ -1379,6 +1392,9 @@ LocationMap *locationMap) : impl(std::make_unique( op, verifyOpAndAdjustFlags(op, printerFlags), locationMap)) {} +AsmState::AsmState(MLIRContext *ctx, const OpPrintingFlags &printerFlags, + LocationMap *locationMap) + : impl(std::make_unique(ctx, printerFlags, locationMap)) {} AsmState::~AsmState() = default; const OpPrintingFlags &AsmState::getPrinterFlags() const { @@ -1390,6 +1406,11 @@ impl->externalResourcePrinters.emplace_back(std::move(printer)); } +DenseMap> & +AsmState::getDialectResources() const { + return impl->getDialectResources(); +} + //===----------------------------------------------------------------------===// // AsmPrinter::Impl //===----------------------------------------------------------------------===// @@ -1397,11 +1418,9 @@ namespace mlir { class AsmPrinter::Impl { public: - Impl(raw_ostream &os, OpPrintingFlags flags = llvm::None, - AsmStateImpl *state = nullptr) - : os(os), printerFlags(flags), state(state) {} - explicit Impl(Impl &other) - : Impl(other.os, other.printerFlags, other.state) {} + Impl(raw_ostream &os, AsmStateImpl &state) + : os(os), state(state), printerFlags(state.getPrinterFlags()) {} + explicit Impl(Impl &other) : Impl(other.os, other.state) {} /// Returns the output stream of the printer. raw_ostream &getStream() { return os; } @@ -1446,7 +1465,7 @@ void printResourceHandle(const AsmDialectResourceHandle &resource) { auto *interface = cast(resource.getDialect()); os << interface->getResourceKey(resource); - dialectResources[resource.getDialect()].insert(resource); + state.getDialectResources()[resource.getDialect()].insert(resource); } void printAffineMap(AffineMap map); @@ -1503,17 +1522,14 @@ /// The output stream for the printer. raw_ostream &os; + /// An underlying assembly printer state. + AsmStateImpl &state; + /// A set of flags to control the printer's behavior. OpPrintingFlags printerFlags; - /// An optional printer state for the module. - AsmStateImpl *state; - /// A tracker for the number of new lines emitted during printing. NewLineCounter newLine; - - /// A set of dialect resources that were referenced during printing. - DenseMap> dialectResources; }; } // namespace mlir @@ -1647,7 +1663,7 @@ return printLocationInternal(loc, /*pretty=*/true); os << "loc("; - if (!allowAlias || !state || failed(state->getAliasState().getAlias(loc, os))) + if (!allowAlias || failed(printAlias(loc))) printLocationInternal(loc); os << ')'; } @@ -1734,11 +1750,11 @@ } LogicalResult AsmPrinter::Impl::printAlias(Attribute attr) { - return success(state && succeeded(state->getAliasState().getAlias(attr, os))); + return state.getAliasState().getAlias(attr, os); } LogicalResult AsmPrinter::Impl::printAlias(Type type) { - return success(state && succeeded(state->getAliasState().getAlias(type, os))); + return state.getAliasState().getAlias(type, os); } void AsmPrinter::Impl::printAttribute(Attribute attr, @@ -2068,7 +2084,7 @@ } // Try to print an alias for this type. - if (state && succeeded(state->getAliasState().getAlias(type, os))) + if (succeeded(printAlias(type))) return; TypeSwitch(type) @@ -2242,14 +2258,9 @@ std::string attrName; { llvm::raw_string_ostream attrNameStr(attrName); - Impl subPrinter(attrNameStr, printerFlags, state); + Impl subPrinter(attrNameStr, state); DialectAsmPrinter printer(subPrinter); dialect.printAttribute(attr, printer); - - // FIXME: Delete this when we no longer require a nested printer. - for (auto &it : subPrinter.dialectResources) - for (const auto &resource : it.second) - dialectResources[it.first].insert(resource); } printDialectSymbol(os, "#", dialect.getNamespace(), attrName); } @@ -2261,14 +2272,9 @@ std::string typeName; { llvm::raw_string_ostream typeNameStr(typeName); - Impl subPrinter(typeNameStr, printerFlags, state); + Impl subPrinter(typeNameStr, state); DialectAsmPrinter printer(subPrinter); dialect.printType(type, printer); - - // FIXME: Delete this when we no longer require a nested printer. - for (auto &it : subPrinter.dialectResources) - for (const auto &resource : it.second) - dialectResources[it.first].insert(resource); } printDialectSymbol(os, "!", dialect.getNamespace(), typeName); } @@ -2561,8 +2567,7 @@ using Impl::printType; explicit OperationPrinter(raw_ostream &os, AsmStateImpl &state) - : Impl(os, state.getPrinterFlags(), &state), - OpAsmPrinter(static_cast(*this)) {} + : Impl(os, state), OpAsmPrinter(static_cast(*this)) {} /// Print the given top-level operation. void printTopLevelOperation(Operation *op); @@ -2646,7 +2651,7 @@ /// operations. If any entry in namesToUse is null, the corresponding /// argument name is left alone. void shadowRegionArgs(Region ®ion, ValueRange namesToUse) override { - state->getSSANameState().shadowRegionArgs(region, namesToUse); + state.getSSANameState().shadowRegionArgs(region, namesToUse); } /// Print the given affine map with the symbol and dimension operands printed @@ -2736,14 +2741,14 @@ void OperationPrinter::printTopLevelOperation(Operation *op) { // Output the aliases at the top level that can't be deferred. - state->getAliasState().printNonDeferredAliases(os, newLine); + state.getAliasState().printNonDeferredAliases(os, newLine); // Print the module. print(op); os << newLine; // Output the aliases at the top level that can be deferred. - state->getAliasState().printDeferredAliases(os, newLine); + state.getAliasState().printDeferredAliases(os, newLine); // Output any file level metadata. printFileMetadataDictionary(op); @@ -2795,7 +2800,8 @@ // Print the `dialect_resources` section if we have any dialects with // resources. - for (const OpAsmDialectInterface &interface : state->getDialectInterfaces()) { + for (const OpAsmDialectInterface &interface : state.getDialectInterfaces()) { + auto &dialectResources = state.getDialectResources(); StringRef name = interface.getDialect()->getNamespace(); auto it = dialectResources.find(interface.getDialect()); if (it != dialectResources.end()) @@ -2810,7 +2816,7 @@ // Print the `external_resources` section if we have any external clients with // resources. hadResource = false; - for (const auto &printer : state->getResourcePrinters()) + for (const auto &printer : state.getResourcePrinters()) processProvider("external", printer.getName(), printer); if (hadResource) os << newLine << " }"; @@ -2836,7 +2842,7 @@ void OperationPrinter::print(Operation *op) { // Track the location of this operation. - state->registerOperationLocation(op, newLine.curLine, currentIndent); + state.registerOperationLocation(op, newLine.curLine, currentIndent); os.indent(currentIndent); printOperation(op); @@ -2854,7 +2860,7 @@ }; // Check to see if this operation has multiple result groups. - ArrayRef resultGroups = state->getSSANameState().getOpResultGroups(op); + ArrayRef resultGroups = state.getSSANameState().getOpResultGroups(op); if (!resultGroups.empty()) { // Interleave the groups excluding the last one, this one will be handled // separately. @@ -3010,7 +3016,7 @@ } void OperationPrinter::printBlockName(Block *block) { - os << state->getSSANameState().getBlockInfo(block).name; + os << state.getSSANameState().getBlockInfo(block).name; } void OperationPrinter::print(Block *block, bool printBlockArgs, @@ -3048,7 +3054,7 @@ // whatever order the use-list is in, so gather and sort them. SmallVector predIDs; for (auto *pred : block->getPredecessors()) - predIDs.push_back(state->getSSANameState().getBlockInfo(pred)); + predIDs.push_back(state.getSSANameState().getBlockInfo(pred)); llvm::sort(predIDs, [](BlockInfo lhs, BlockInfo rhs) { return lhs.ordering < rhs.ordering; }); @@ -3084,14 +3090,14 @@ void OperationPrinter::printValueID(Value value, bool printResultNo, raw_ostream *streamOverride) const { - state->getSSANameState().printValueID(value, printResultNo, - streamOverride ? *streamOverride : os); + state.getSSANameState().printValueID(value, printResultNo, + streamOverride ? *streamOverride : os); } void OperationPrinter::printOperationID(Operation *op, raw_ostream *streamOverride) const { - state->getSSANameState().printOperationID(op, streamOverride ? *streamOverride - : os); + state.getSSANameState().printOperationID(op, streamOverride ? *streamOverride + : os); } void OperationPrinter::printSuccessor(Block *successor) { @@ -3176,7 +3182,16 @@ //===----------------------------------------------------------------------===// void Attribute::print(raw_ostream &os) const { - AsmPrinter::Impl(os).printAttribute(*this); + if (!*this) { + os << "<>"; + return; + } + + AsmState state(getContext()); + print(os, state); +} +void Attribute::print(raw_ostream &os, AsmState &state) const { + AsmPrinter::Impl(os, state.getImpl()).printAttribute(*this); } void Attribute::dump() const { @@ -3185,7 +3200,16 @@ } void Type::print(raw_ostream &os) const { - AsmPrinter::Impl(os).printType(*this); + if (!*this) { + os << "<>"; + return; + } + + AsmState state(getContext()); + print(os, state); +} +void Type::print(raw_ostream &os, AsmState &state) const { + AsmPrinter::Impl(os, state.getImpl()).printType(*this); } void Type::dump() const { print(llvm::errs()); } @@ -3205,7 +3229,8 @@ os << "<>"; return; } - AsmPrinter::Impl(os).printAffineExpr(*this); + AsmState state(getContext()); + AsmPrinter::Impl(os, state.getImpl()).printAffineExpr(*this); } void AffineExpr::dump() const { @@ -3218,11 +3243,13 @@ os << "<>"; return; } - AsmPrinter::Impl(os).printAffineMap(*this); + AsmState state(getContext()); + AsmPrinter::Impl(os, state.getImpl()).printAffineMap(*this); } void IntegerSet::print(raw_ostream &os) const { - AsmPrinter::Impl(os).printIntegerSet(*this); + AsmState state(getContext()); + AsmPrinter::Impl(os, state.getImpl()).printIntegerSet(*this); } void Value::print(raw_ostream &os) { print(os, OpPrintingFlags()); }