diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -868,8 +868,7 @@ }; } // namespace -SSANameState::SSANameState( - Operation *op, const OpPrintingFlags &printerFlags) +SSANameState::SSANameState(Operation *op, const OpPrintingFlags &printerFlags) : printerFlags(printerFlags) { llvm::SaveAndRestore valueIDSaver(nextValueID); llvm::SaveAndRestore argumentIDSaver(nextArgumentID); @@ -1532,13 +1531,17 @@ } /// Print the given dialect symbol to the stream. +/// This unifies the printing API for types and attributes, which +/// differ based on a prefix (! vs. #) but little else. static void printDialectSymbol(raw_ostream &os, StringRef symPrefix, - StringRef dialectName, StringRef symString) { + StringRef dialectName, StringRef symString, + bool shouldPrintGenericOp) { os << symPrefix << dialectName; // If this symbol name is simple enough, print it directly in pretty form, // otherwise, we print it as an escaped string. - if (isDialectSymbolSimpleEnoughForPrettyForm(symString)) { + if (!shouldPrintGenericOp && + isDialectSymbolSimpleEnoughForPrettyForm(symString)) { os << '.'
<< symString; return; } @@ -1623,7 +1626,8 @@ auto attrType = attr.getType(); if (auto opaqueAttr = attr.dyn_cast()) { printDialectSymbol(os, "#", opaqueAttr.getDialectNamespace(), - opaqueAttr.getAttrData()); + opaqueAttr.getAttrData(), + printerFlags.shouldPrintGenericOpForm()); } else if (attr.isa()) { os << "unit"; return; @@ -1911,7 +1915,8 @@ TypeSwitch(type) .Case([&](OpaqueType opaqueTy) { printDialectSymbol(os, "!", opaqueTy.getDialectNamespace(), - opaqueTy.getTypeData()); + opaqueTy.getTypeData(), + printerFlags.shouldPrintGenericOpForm()); }) .Case([&](Type) { os << "index"; }) .Case([&](Type) { os << "bf16"; }) @@ -2083,7 +2088,8 @@ DialectAsmPrinter printer(subPrinter); dialect.printAttribute(attr, printer); } - printDialectSymbol(os, "#", dialect.getNamespace(), attrName); + printDialectSymbol(os, "#", dialect.getNamespace(), attrName, + printerFlags.shouldPrintGenericOpForm()); } void AsmPrinter::Impl::printDialectType(Type type) { @@ -2097,7 +2103,8 @@ DialectAsmPrinter printer(subPrinter); dialect.printType(type, printer); } - printDialectSymbol(os, "!", dialect.getNamespace(), typeName); + printDialectSymbol(os, "!", dialect.getNamespace(), typeName, + printerFlags.shouldPrintGenericOpForm()); } //===--------------------------------------------------------------------===//