diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -35,6 +35,7 @@ #include "llvm/ADT/StringSet.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Regex.h" +#include "llvm/Support/SaveAndRestore.h" using namespace mlir; void Identifier::print(raw_ostream &os) const { os << str(); } @@ -412,453 +413,803 @@ } //===----------------------------------------------------------------------===// -// ModuleState +// SSANameState //===----------------------------------------------------------------------===// namespace { -class ModuleState { +/// This class manages the state of SSA value names. +class SSANameState { public: - explicit ModuleState(MLIRContext *context) : interfaces(context) {} - - /// Initialize the alias state to enable the printing of aliases. - void initializeAliases(Operation *op) { - aliasState.initialize(op, interfaces); - } - - /// Get an instance of the OpAsmDialectInterface for the given dialect, or - /// null if one wasn't registered. - const OpAsmDialectInterface *getOpAsmInterface(Dialect *dialect) { - return interfaces.getInterfaceFor(dialect); - } - - /// Get the state used for aliases. - AliasState &getAliasState() { return aliasState; } + /// A sentinal value used for values with names set. + enum : unsigned { NameSentinel = ~0U }; -private: - /// Collection of OpAsm interfaces implemented in the context. - DialectInterfaceCollection interfaces; - - /// The state used for attribute and type aliases. - AliasState aliasState; -}; -} // end anonymous namespace + SSANameState(Operation *op, + DialectInterfaceCollection &interfaces); -//===----------------------------------------------------------------------===// -// ModulePrinter -//===----------------------------------------------------------------------===// - -namespace { -class ModulePrinter { -public: - ModulePrinter(raw_ostream &os, OpPrintingFlags flags = llvm::None, - ModuleState *state = nullptr) - : os(os), printerFlags(flags), state(state) {} - explicit ModulePrinter(ModulePrinter &printer) - : os(printer.os), printerFlags(printer.printerFlags), - state(printer.state) {} + /// Print the SSA identifier for the given value to 'stream'. If + /// 'printResultNo' is true, it also presents the result number ('#' number) + /// of this value. + void printValueID(Value value, bool printResultNo, raw_ostream &stream) const; - /// Returns the output stream of the printer. - raw_ostream &getStream() { return os; } + /// Return the result indices for each of the result groups registered by this + /// operation, or empty if none exist. + ArrayRef getOpResultGroups(Operation *op); - template - inline void interleaveComma(const Container &c, UnaryFunctor each_fn) const { - mlir::interleaveComma(c, os, each_fn); - } + /// Get the ID for the given block. + unsigned getBlockID(Block *block); - void print(ModuleOp module); + /// Renumber the arguments for the specified region to the same names as the + /// SSA values in namesToUse. See OperationPrinter::shadowRegionArgs for + /// details. + void shadowRegionArgs(Region ®ion, ValueRange namesToUse); - /// Print the given attribute. If 'mayElideType' is true, some attributes are - /// printed without the type when the type matches the default used in the - /// parser (for example i64 is the default for integer attributes). - void printAttribute(Attribute attr, bool mayElideType = false); +private: + /// Number the SSA values within the given IR unit. + void numberValuesInRegion( + Region ®ion, + DialectInterfaceCollection &interfaces); + void numberValuesInBlock( + Block &block, + DialectInterfaceCollection &interfaces); + void numberValuesInOp( + Operation &op, + DialectInterfaceCollection &interfaces); - void printType(Type type); - void printLocation(LocationAttr loc); + /// Given a result of an operation 'result', find the result group head + /// 'lookupValue' and the result of 'result' within that group in + /// 'lookupResultNo'. 'lookupResultNo' is only filled in if the result group + /// has more than 1 result. + void getResultIDAndNumber(OpResult result, Value &lookupValue, + Optional &lookupResultNo) const; - void printAffineMap(AffineMap map); - void - printAffineExpr(AffineExpr expr, - function_ref printValueName = nullptr); - void printAffineConstraint(AffineExpr expr, bool isEq); - void printIntegerSet(IntegerSet set); + /// Set a special value name for the given value. + void setValueName(Value value, StringRef name); -protected: - void printOptionalAttrDict(ArrayRef attrs, - ArrayRef elidedAttrs = {}, - bool withKeyword = false); - void printTrailingLocation(Location loc); - void printLocationInternal(LocationAttr loc, bool pretty = false); - void printDenseElementsAttr(DenseElementsAttr attr); + /// Uniques the given value name within the printer. If the given name + /// conflicts, it is automatically renamed. + StringRef uniqueValueName(StringRef name); - void printDialectAttribute(Attribute attr); - void printDialectType(Type type); + /// This is the value ID for each SSA value. If this returns NameSentinel, + /// then the valueID has an entry in valueNames. + DenseMap valueIDs; + DenseMap valueNames; - /// This enum is used to represent the binding strength of the enclosing - /// context that an AffineExprStorage is being printed in, so we can - /// intelligently produce parens. - enum class BindingStrength { - Weak, // + and - - Strong, // All other binary operators. - }; - void printAffineExprInternal( - AffineExpr expr, BindingStrength enclosingTightness, - function_ref printValueName = nullptr); + /// This is a map of operations that contain multiple named result groups, + /// i.e. there may be multiple names for the results of the operation. The + /// value of this map are the result numbers that start a result group. + DenseMap> opResultGroups; - /// The output stream for the printer. - raw_ostream &os; + /// This is the block ID for each block in the current. + DenseMap blockIDs; - /// A set of flags to control the printer's behavior. - OpPrintingFlags printerFlags; + /// This keeps track of all of the non-numeric names that are in flight, + /// allowing us to check for duplicates. + /// Note: the value of the map is unused. + llvm::ScopedHashTable usedNames; + llvm::BumpPtrAllocator usedNameAllocator; - /// An optional printer state for the module. - ModuleState *state; + /// This is the next value ID to assign in numbering. + unsigned nextValueID = 0; + /// This is the next ID to assign to a region entry block argument. + unsigned nextArgumentID = 0; + /// This is the next ID to assign when a name conflict is detected. + unsigned nextConflictID = 0; }; } // end anonymous namespace -void ModulePrinter::printTrailingLocation(Location loc) { - // Check to see if we are printing debug information. - if (!printerFlags.shouldPrintDebugInfo()) - return; +SSANameState::SSANameState( + Operation *op, + DialectInterfaceCollection &interfaces) { + llvm::ScopedHashTable::ScopeTy usedNamesScope(usedNames); + numberValuesInOp(*op, interfaces); - os << " "; - printLocation(loc); + for (auto ®ion : op->getRegions()) + numberValuesInRegion(region, interfaces); } -void ModulePrinter::printLocationInternal(LocationAttr loc, bool pretty) { - switch (loc.getKind()) { - case StandardAttributes::OpaqueLocation: - printLocationInternal(loc.cast().getFallbackLocation(), pretty); - break; - case StandardAttributes::UnknownLocation: - if (pretty) - os << "[unknown]"; - else - os << "unknown"; - break; - case StandardAttributes::FileLineColLocation: { - auto fileLoc = loc.cast(); - auto mayQuote = pretty ? "" : "\""; - os << mayQuote << fileLoc.getFilename() << mayQuote << ':' - << fileLoc.getLine() << ':' << fileLoc.getColumn(); - break; - } - case StandardAttributes::NameLocation: { - auto nameLoc = loc.cast(); - os << '\"' << nameLoc.getName() << '\"'; - - // Print the child if it isn't unknown. - auto childLoc = nameLoc.getChildLoc(); - if (!childLoc.isa()) { - os << '('; - printLocationInternal(childLoc, pretty); - os << ')'; - } - break; - } - case StandardAttributes::CallSiteLocation: { - auto callLocation = loc.cast(); - auto caller = callLocation.getCaller(); - auto callee = callLocation.getCallee(); - if (!pretty) - os << "callsite("; - printLocationInternal(callee, pretty); - if (pretty) { - if (callee.isa()) { - if (caller.isa()) { - os << " at "; - } else { - os << "\n at "; - } - } else { - os << "\n at "; - } - } else { - os << " at "; - } - printLocationInternal(caller, pretty); - if (!pretty) - os << ")"; - break; - } - case StandardAttributes::FusedLocation: { - auto fusedLoc = loc.cast(); - if (!pretty) - os << "fused"; - if (auto metadata = fusedLoc.getMetadata()) - os << '<' << metadata << '>'; - os << '['; - interleave( - fusedLoc.getLocations(), - [&](Location loc) { printLocationInternal(loc, pretty); }, - [&]() { os << ", "; }); - os << ']'; - break; - } +void SSANameState::printValueID(Value value, bool printResultNo, + raw_ostream &stream) const { + if (!value) { + stream << "<>"; + return; } -} -/// Print a floating point value in a way that the parser will be able to -/// round-trip losslessly. -static void printFloatValue(const APFloat &apValue, raw_ostream &os) { - // We would like to output the FP constant value in exponential notation, - // but we cannot do this if doing so will lose precision. Check here to - // make sure that we only output it in exponential format if we can parse - // the value back and get the same value. - bool isInf = apValue.isInfinity(); - bool isNaN = apValue.isNaN(); - if (!isInf && !isNaN) { - SmallString<128> strValue; - apValue.toString(strValue, 6, 0, false); + Optional resultNo; + auto lookupValue = value; - // Check to make sure that the stringized number is not some string like - // "Inf" or NaN, that atof will accept, but the lexer will not. Check - // that the string matches the "[-+]?[0-9]" regex. - assert(((strValue[0] >= '0' && strValue[0] <= '9') || - ((strValue[0] == '-' || strValue[0] == '+') && - (strValue[1] >= '0' && strValue[1] <= '9'))) && - "[-+]?[0-9] regex does not match!"); + // If this is an operation result, collect the head lookup value of the result + // group and the result number of 'result' within that group. + if (OpResult result = value.dyn_cast()) + getResultIDAndNumber(result, lookupValue, resultNo); - // Parse back the stringized version and check that the value is equal - // (i.e., there is no precision loss). If it is not, use the default format - // of APFloat instead of the exponential notation. - if (!APFloat(apValue.getSemantics(), strValue).bitwiseIsEqual(apValue)) { - strValue.clear(); - apValue.toString(strValue); - } - os << strValue; + auto it = valueIDs.find(lookupValue); + if (it == valueIDs.end()) { + stream << "<>"; return; } - // Print special values in hexadecimal format. The sign bit should be - // included in the literal. - SmallVector str; - APInt apInt = apValue.bitcastToAPInt(); - apInt.toString(str, /*Radix=*/16, /*Signed=*/false, - /*formatAsCLiteral=*/true); - os << str; -} - -void ModulePrinter::printLocation(LocationAttr loc) { - if (printerFlags.shouldPrintDebugInfoPrettyForm()) { - printLocationInternal(loc, /*pretty=*/true); + stream << '%'; + if (it->second != NameSentinel) { + stream << it->second; } else { - os << "loc("; - printLocationInternal(loc); - os << ')'; + auto nameIt = valueNames.find(lookupValue); + assert(nameIt != valueNames.end() && "Didn't have a name entry?"); + stream << nameIt->second; } + + if (resultNo.hasValue() && printResultNo) + stream << '#' << resultNo; } -/// Returns if the given dialect symbol data is simple enough to print in the -/// pretty form, i.e. without the enclosing "". -static bool isDialectSymbolSimpleEnoughForPrettyForm(StringRef symName) { - // The name must start with an identifier. - if (symName.empty() || !isalpha(symName.front())) - return false; +ArrayRef SSANameState::getOpResultGroups(Operation *op) { + auto it = opResultGroups.find(op); + return it == opResultGroups.end() ? ArrayRef() : it->second; +} - // Ignore all the characters that are valid in an identifier in the symbol - // name. - symName = symName.drop_while( - [](char c) { return llvm::isAlnum(c) || c == '.' || c == '_'; }); - if (symName.empty()) - return true; - - // If we got to an unexpected character, then it must be a <>. Check those - // recursively. - if (symName.front() != '<' || symName.back() != '>') - return false; - - SmallVector nestedPunctuation; - do { - // If we ran out of characters, then we had a punctuation mismatch. - if (symName.empty()) - return false; +unsigned SSANameState::getBlockID(Block *block) { + auto it = blockIDs.find(block); + return it != blockIDs.end() ? it->second : NameSentinel; +} - auto c = symName.front(); - symName = symName.drop_front(); +void SSANameState::shadowRegionArgs(Region ®ion, ValueRange namesToUse) { + assert(!region.empty() && "cannot shadow arguments of an empty region"); + assert(region.front().getNumArguments() == namesToUse.size() && + "incorrect number of names passed in"); + assert(region.getParentOp()->isKnownIsolatedFromAbove() && + "only KnownIsolatedFromAbove ops can shadow names"); - switch (c) { - // We never allow null characters. This is an EOF indicator for the lexer - // which we could handle, but isn't important for any known dialect. - case '\0': - return false; - case '<': - case '[': - case '(': - case '{': - nestedPunctuation.push_back(c); - continue; - case '-': - // Treat `->` as a special token. - if (!symName.empty() && symName.front() == '>') { - symName = symName.drop_front(); - continue; - } - break; - // Reject types with mismatched brackets. - case '>': - if (nestedPunctuation.pop_back_val() != '<') - return false; - break; - case ']': - if (nestedPunctuation.pop_back_val() != '[') - return false; - break; - case ')': - if (nestedPunctuation.pop_back_val() != '(') - return false; - break; - case '}': - if (nestedPunctuation.pop_back_val() != '{') - return false; - break; - default: + SmallVector nameStr; + for (unsigned i = 0, e = namesToUse.size(); i != e; ++i) { + auto nameToUse = namesToUse[i]; + if (nameToUse == nullptr) continue; - } + auto nameToReplace = region.front().getArgument(i); - // We're done when the punctuation is fully matched. - } while (!nestedPunctuation.empty()); + nameStr.clear(); + llvm::raw_svector_ostream nameStream(nameStr); + printValueID(nameToUse, /*printResultNo=*/true, nameStream); - // If there were extra characters, then we failed. - return symName.empty(); -} + // Entry block arguments should already have a pretty "arg" name. + assert(valueIDs[nameToReplace] == NameSentinel); -/// Print the given dialect symbol to the stream. -static void printDialectSymbol(raw_ostream &os, StringRef symPrefix, - StringRef dialectName, StringRef symString) { - os << symPrefix << dialectName; + // Use the name without the leading %. + auto name = StringRef(nameStream.str()).drop_front(); - // If this symbol name is simple enough, print it directly in pretty form, - // otherwise, we print it as an escaped string. - if (isDialectSymbolSimpleEnoughForPrettyForm(symString)) { - os << '.' << symString; - return; + // Overwrite the name. + valueNames[nameToReplace] = name.copy(usedNameAllocator); } - - // TODO: escape the symbol name, it could contain " characters. - os << "<\"" << symString << "\">"; } -/// Returns if the given string can be represented as a bare identifier. -static bool isBareIdentifier(StringRef name) { - assert(!name.empty() && "invalid name"); - - // By making this unsigned, the value passed in to isalnum will always be - // in the range 0-255. This is important when building with MSVC because - // its implementation will assert. This situation can arise when dealing - // with UTF-8 multibyte characters. - unsigned char firstChar = static_cast(name[0]); - if (!isalpha(firstChar) && firstChar != '_') - return false; - return llvm::all_of(name.drop_front(), [](unsigned char c) { - return isalnum(c) || c == '_' || c == '$' || c == '.'; - }); -} +void SSANameState::numberValuesInRegion( + Region ®ion, + DialectInterfaceCollection &interfaces) { + // Save the current value ids to allow for numbering values in sibling regions + // the same. + llvm::SaveAndRestore valueIDSaver(nextValueID); + llvm::SaveAndRestore argumentIDSaver(nextArgumentID); + llvm::SaveAndRestore conflictIDSaver(nextConflictID); -/// Print the given string as a symbol reference. A symbol reference is -/// represented as a string prefixed with '@'. The reference is surrounded with -/// ""'s and escaped if it has any special or non-printable characters in it. -static void printSymbolReference(StringRef symbolRef, raw_ostream &os) { - assert(!symbolRef.empty() && "expected valid symbol reference"); + // Push a new used names scope. + llvm::ScopedHashTable::ScopeTy usedNamesScope(usedNames); - // If the symbol can be represented as a bare identifier, write it directly. - if (isBareIdentifier(symbolRef)) { - os << '@' << symbolRef; - return; + // Number the values within this region in a breadth-first order. + unsigned nextBlockID = 0; + for (auto &block : region) { + // Each block gets a unique ID, and all of the operations within it get + // numbered as well. + blockIDs[&block] = nextBlockID++; + numberValuesInBlock(block, interfaces); } - // Otherwise, output the reference wrapped in quotes with proper escaping. - os << "@\""; - printEscapedString(symbolRef, os); - os << '"'; + // After that we traverse the nested regions. + // TODO: Rework this loop to not use recursion. + for (auto &block : region) { + for (auto &op : block) + for (auto &nestedRegion : op.getRegions()) + numberValuesInRegion(nestedRegion, interfaces); + } } -// Print out a valid ElementsAttr that is succinct and can represent any -// potential shape/type, for use when eliding a large ElementsAttr. -// -// We choose to use an opaque ElementsAttr literal with conspicuous content to -// hopefully alert readers to the fact that this has been elided. -// -// Unfortunately, neither of the strings of an opaque ElementsAttr literal will -// accept the string "elided". The first string must be a registered dialect -// name and the latter must be a hex constant. -static void printElidedElementsAttr(raw_ostream &os) { - os << R"(opaque<"", "0xDEADBEEF">)"; -} +void SSANameState::numberValuesInBlock( + Block &block, + DialectInterfaceCollection &interfaces) { + auto setArgNameFn = [&](Value arg, StringRef name) { + assert(!valueIDs.count(arg) && "arg numbered multiple times"); + assert(arg.cast()->getOwner() == &block && + "arg not defined in 'block'"); + setValueName(arg, name); + }; -void ModulePrinter::printAttribute(Attribute attr, bool mayElideType) { - if (!attr) { - os << "<>"; - return; + bool isEntryBlock = block.isEntryBlock(); + if (isEntryBlock) { + if (auto *op = block.getParentOp()) { + if (auto asmInterface = interfaces.getInterfaceFor(op->getDialect())) + asmInterface->getAsmBlockArgumentNames(&block, setArgNameFn); + } } - // Check for an alias for this attribute. - if (state) { - Twine alias = state->getAliasState().getAttributeAlias(attr); - if (!alias.isTriviallyEmpty()) { - os << '#' << alias; - return; + // Number the block arguments. We give entry block arguments a special name + // 'arg'. + SmallString<32> specialNameBuffer(isEntryBlock ? "arg" : ""); + llvm::raw_svector_ostream specialName(specialNameBuffer); + for (auto arg : block.getArguments()) { + if (valueIDs.count(arg)) + continue; + if (isEntryBlock) { + specialNameBuffer.resize(strlen("arg")); + specialName << nextArgumentID++; } + setValueName(arg, specialName.str()); } - switch (attr.getKind()) { - default: - return printDialectAttribute(attr); - - case StandardAttributes::Opaque: { - auto opaqueAttr = attr.cast(); - printDialectSymbol(os, "#", opaqueAttr.getDialectNamespace(), - opaqueAttr.getAttrData()); - break; - } - case StandardAttributes::Unit: - os << "unit"; - break; - case StandardAttributes::Bool: - os << (attr.cast().getValue() ? "true" : "false"); + // Number the operations in this block. + for (auto &op : block) + numberValuesInOp(op, interfaces); +} - // BoolAttr always elides the type. +void SSANameState::numberValuesInOp( + Operation &op, + DialectInterfaceCollection &interfaces) { + unsigned numResults = op.getNumResults(); + if (numResults == 0) return; - case StandardAttributes::Dictionary: - os << '{'; - interleaveComma(attr.cast().getValue(), - [&](NamedAttribute attr) { - os << attr.first; + Value resultBegin = op.getResult(0); - // The value of a UnitAttr is elided within a dictionary. - if (attr.second.isa()) - return; + // Function used to set the special result names for the operation. + SmallVector resultGroups(/*Size=*/1, /*Value=*/0); + auto setResultNameFn = [&](Value result, StringRef name) { + assert(!valueIDs.count(result) && "result numbered multiple times"); + assert(result->getDefiningOp() == &op && "result not defined by 'op'"); + setValueName(result, name); - os << " = "; - printAttribute(attr.second); - }); - os << '}'; - break; - case StandardAttributes::Integer: { - auto intAttr = attr.cast(); - // Print all integer attributes as signed unless i1. - bool isSigned = intAttr.getType().isIndex() || - intAttr.getType().getIntOrFloatBitWidth() != 1; - intAttr.getValue().print(os, isSigned); + // Record the result number for groups not anchored at 0. + if (int resultNo = result.cast()->getResultNumber()) + resultGroups.push_back(resultNo); + }; + if (OpAsmOpInterface asmInterface = dyn_cast(&op)) + asmInterface.getAsmResultNames(setResultNameFn); + else if (auto *asmInterface = interfaces.getInterfaceFor(op.getDialect())) + asmInterface->getAsmResultNames(&op, setResultNameFn); - // IntegerAttr elides the type if I64. - if (mayElideType && intAttr.getType().isInteger(64)) - return; - break; - } - case StandardAttributes::Float: { - auto floatAttr = attr.cast(); - printFloatValue(floatAttr.getValue(), os); + // If the first result wasn't numbered, give it a default number. + if (valueIDs.try_emplace(resultBegin, nextValueID).second) + ++nextValueID; - // FloatAttr elides the type if F64. - if (mayElideType && floatAttr.getType().isF64()) - return; - break; + // If this operation has multiple result groups, mark it. + if (resultGroups.size() != 1) { + llvm::array_pod_sort(resultGroups.begin(), resultGroups.end()); + opResultGroups.try_emplace(&op, std::move(resultGroups)); } - case StandardAttributes::String: - os << '"'; - printEscapedString(attr.cast().getValue(), os); - os << '"'; - break; +} + +void SSANameState::getResultIDAndNumber(OpResult result, Value &lookupValue, + Optional &lookupResultNo) const { + Operation *owner = result->getOwner(); + if (owner->getNumResults() == 1) + return; + int resultNo = result->getResultNumber(); + + // If this operation has multiple result groups, we will need to find the + // one corresponding to this result. + auto resultGroupIt = opResultGroups.find(owner); + if (resultGroupIt == opResultGroups.end()) { + // If not, just use the first result. + lookupResultNo = resultNo; + lookupValue = owner->getResult(0); + return; + } + + // Find the correct index using a binary search, as the groups are ordered. + ArrayRef resultGroups = resultGroupIt->second; + auto it = llvm::upper_bound(resultGroups, resultNo); + int groupResultNo = 0, groupSize = 0; + + // If there are no smaller elements, the last result group is the lookup. + if (it == resultGroups.end()) { + groupResultNo = resultGroups.back(); + groupSize = static_cast(owner->getNumResults()) - resultGroups.back(); + } else { + // Otherwise, the previous element is the lookup. + groupResultNo = *std::prev(it); + groupSize = *it - groupResultNo; + } + + // We only record the result number for a group of size greater than 1. + if (groupSize != 1) + lookupResultNo = resultNo - groupResultNo; + lookupValue = owner->getResult(groupResultNo); +} + +void SSANameState::setValueName(Value value, StringRef name) { + // If the name is empty, the value uses the default numbering. + if (name.empty()) { + valueIDs[value] = nextValueID++; + return; + } + + valueIDs[value] = NameSentinel; + valueNames[value] = uniqueValueName(name); +} + +StringRef SSANameState::uniqueValueName(StringRef name) { + // Check to see if this name is already unique. + if (!usedNames.count(name)) { + name = name.copy(usedNameAllocator); + } else { + // Otherwise, we had a conflict - probe until we find a unique name. This + // is guaranteed to terminate (and usually in a single iteration) because it + // generates new names by incrementing nextConflictID. + SmallString<64> probeName(name); + probeName.push_back('_'); + while (true) { + probeName.resize(name.size() + 1); + probeName += llvm::utostr(nextConflictID++); + if (!usedNames.count(probeName)) { + name = StringRef(probeName).copy(usedNameAllocator); + break; + } + } + } + + usedNames.insert(name, char()); + return name; +} + +//===----------------------------------------------------------------------===// +// ModuleState +//===----------------------------------------------------------------------===// + +namespace { +class ModuleState { +public: + explicit ModuleState(Operation *op) + : interfaces(op->getContext()), nameState(op, interfaces) {} + + /// Initialize the alias state to enable the printing of aliases. + void initializeAliases(Operation *op) { + aliasState.initialize(op, interfaces); + } + + /// Get an instance of the OpAsmDialectInterface for the given dialect, or + /// null if one wasn't registered. + const OpAsmDialectInterface *getOpAsmInterface(Dialect *dialect) { + return interfaces.getInterfaceFor(dialect); + } + + /// Get the state used for aliases. + AliasState &getAliasState() { return aliasState; } + + /// Get the state used for SSA names. + SSANameState &getSSANameState() { return nameState; } + +private: + /// Collection of OpAsm interfaces implemented in the context. + DialectInterfaceCollection interfaces; + + /// The state used for attribute and type aliases. + AliasState aliasState; + + /// The state used for SSA value names. + SSANameState nameState; +}; +} // end anonymous namespace + +//===----------------------------------------------------------------------===// +// ModulePrinter +//===----------------------------------------------------------------------===// + +namespace { +class ModulePrinter { +public: + ModulePrinter(raw_ostream &os, OpPrintingFlags flags = llvm::None, + ModuleState *state = nullptr) + : os(os), printerFlags(flags), state(state) {} + explicit ModulePrinter(ModulePrinter &printer) + : os(printer.os), printerFlags(printer.printerFlags), + state(printer.state) {} + + /// Returns the output stream of the printer. + raw_ostream &getStream() { return os; } + + template + inline void interleaveComma(const Container &c, UnaryFunctor each_fn) const { + mlir::interleaveComma(c, os, each_fn); + } + + void print(ModuleOp module); + + /// Print the given attribute. If 'mayElideType' is true, some attributes are + /// printed without the type when the type matches the default used in the + /// parser (for example i64 is the default for integer attributes). + void printAttribute(Attribute attr, bool mayElideType = false); + + void printType(Type type); + void printLocation(LocationAttr loc); + + void printAffineMap(AffineMap map); + void + printAffineExpr(AffineExpr expr, + function_ref printValueName = nullptr); + void printAffineConstraint(AffineExpr expr, bool isEq); + void printIntegerSet(IntegerSet set); + +protected: + void printOptionalAttrDict(ArrayRef attrs, + ArrayRef elidedAttrs = {}, + bool withKeyword = false); + void printTrailingLocation(Location loc); + void printLocationInternal(LocationAttr loc, bool pretty = false); + void printDenseElementsAttr(DenseElementsAttr attr); + + void printDialectAttribute(Attribute attr); + void printDialectType(Type type); + + /// This enum is used to represent the binding strength of the enclosing + /// context that an AffineExprStorage is being printed in, so we can + /// intelligently produce parens. + enum class BindingStrength { + Weak, // + and - + Strong, // All other binary operators. + }; + void printAffineExprInternal( + AffineExpr expr, BindingStrength enclosingTightness, + function_ref printValueName = nullptr); + + /// The output stream for the printer. + raw_ostream &os; + + /// A set of flags to control the printer's behavior. + OpPrintingFlags printerFlags; + + /// An optional printer state for the module. + ModuleState *state; +}; +} // end anonymous namespace + +void ModulePrinter::printTrailingLocation(Location loc) { + // Check to see if we are printing debug information. + if (!printerFlags.shouldPrintDebugInfo()) + return; + + os << " "; + printLocation(loc); +} + +void ModulePrinter::printLocationInternal(LocationAttr loc, bool pretty) { + switch (loc.getKind()) { + case StandardAttributes::OpaqueLocation: + printLocationInternal(loc.cast().getFallbackLocation(), pretty); + break; + case StandardAttributes::UnknownLocation: + if (pretty) + os << "[unknown]"; + else + os << "unknown"; + break; + case StandardAttributes::FileLineColLocation: { + auto fileLoc = loc.cast(); + auto mayQuote = pretty ? "" : "\""; + os << mayQuote << fileLoc.getFilename() << mayQuote << ':' + << fileLoc.getLine() << ':' << fileLoc.getColumn(); + break; + } + case StandardAttributes::NameLocation: { + auto nameLoc = loc.cast(); + os << '\"' << nameLoc.getName() << '\"'; + + // Print the child if it isn't unknown. + auto childLoc = nameLoc.getChildLoc(); + if (!childLoc.isa()) { + os << '('; + printLocationInternal(childLoc, pretty); + os << ')'; + } + break; + } + case StandardAttributes::CallSiteLocation: { + auto callLocation = loc.cast(); + auto caller = callLocation.getCaller(); + auto callee = callLocation.getCallee(); + if (!pretty) + os << "callsite("; + printLocationInternal(callee, pretty); + if (pretty) { + if (callee.isa()) { + if (caller.isa()) { + os << " at "; + } else { + os << "\n at "; + } + } else { + os << "\n at "; + } + } else { + os << " at "; + } + printLocationInternal(caller, pretty); + if (!pretty) + os << ")"; + break; + } + case StandardAttributes::FusedLocation: { + auto fusedLoc = loc.cast(); + if (!pretty) + os << "fused"; + if (auto metadata = fusedLoc.getMetadata()) + os << '<' << metadata << '>'; + os << '['; + interleave( + fusedLoc.getLocations(), + [&](Location loc) { printLocationInternal(loc, pretty); }, + [&]() { os << ", "; }); + os << ']'; + break; + } + } +} + +/// Print a floating point value in a way that the parser will be able to +/// round-trip losslessly. +static void printFloatValue(const APFloat &apValue, raw_ostream &os) { + // We would like to output the FP constant value in exponential notation, + // but we cannot do this if doing so will lose precision. Check here to + // make sure that we only output it in exponential format if we can parse + // the value back and get the same value. + bool isInf = apValue.isInfinity(); + bool isNaN = apValue.isNaN(); + if (!isInf && !isNaN) { + SmallString<128> strValue; + apValue.toString(strValue, 6, 0, false); + + // Check to make sure that the stringized number is not some string like + // "Inf" or NaN, that atof will accept, but the lexer will not. Check + // that the string matches the "[-+]?[0-9]" regex. + assert(((strValue[0] >= '0' && strValue[0] <= '9') || + ((strValue[0] == '-' || strValue[0] == '+') && + (strValue[1] >= '0' && strValue[1] <= '9'))) && + "[-+]?[0-9] regex does not match!"); + + // Parse back the stringized version and check that the value is equal + // (i.e., there is no precision loss). If it is not, use the default format + // of APFloat instead of the exponential notation. + if (!APFloat(apValue.getSemantics(), strValue).bitwiseIsEqual(apValue)) { + strValue.clear(); + apValue.toString(strValue); + } + os << strValue; + return; + } + + // Print special values in hexadecimal format. The sign bit should be + // included in the literal. + SmallVector str; + APInt apInt = apValue.bitcastToAPInt(); + apInt.toString(str, /*Radix=*/16, /*Signed=*/false, + /*formatAsCLiteral=*/true); + os << str; +} + +void ModulePrinter::printLocation(LocationAttr loc) { + if (printerFlags.shouldPrintDebugInfoPrettyForm()) { + printLocationInternal(loc, /*pretty=*/true); + } else { + os << "loc("; + printLocationInternal(loc); + os << ')'; + } +} + +/// Returns if the given dialect symbol data is simple enough to print in the +/// pretty form, i.e. without the enclosing "". +static bool isDialectSymbolSimpleEnoughForPrettyForm(StringRef symName) { + // The name must start with an identifier. + if (symName.empty() || !isalpha(symName.front())) + return false; + + // Ignore all the characters that are valid in an identifier in the symbol + // name. + symName = symName.drop_while( + [](char c) { return llvm::isAlnum(c) || c == '.' || c == '_'; }); + if (symName.empty()) + return true; + + // If we got to an unexpected character, then it must be a <>. Check those + // recursively. + if (symName.front() != '<' || symName.back() != '>') + return false; + + SmallVector nestedPunctuation; + do { + // If we ran out of characters, then we had a punctuation mismatch. + if (symName.empty()) + return false; + + auto c = symName.front(); + symName = symName.drop_front(); + + switch (c) { + // We never allow null characters. This is an EOF indicator for the lexer + // which we could handle, but isn't important for any known dialect. + case '\0': + return false; + case '<': + case '[': + case '(': + case '{': + nestedPunctuation.push_back(c); + continue; + case '-': + // Treat `->` as a special token. + if (!symName.empty() && symName.front() == '>') { + symName = symName.drop_front(); + continue; + } + break; + // Reject types with mismatched brackets. + case '>': + if (nestedPunctuation.pop_back_val() != '<') + return false; + break; + case ']': + if (nestedPunctuation.pop_back_val() != '[') + return false; + break; + case ')': + if (nestedPunctuation.pop_back_val() != '(') + return false; + break; + case '}': + if (nestedPunctuation.pop_back_val() != '{') + return false; + break; + default: + continue; + } + + // We're done when the punctuation is fully matched. + } while (!nestedPunctuation.empty()); + + // If there were extra characters, then we failed. + return symName.empty(); +} + +/// Print the given dialect symbol to the stream. +static void printDialectSymbol(raw_ostream &os, StringRef symPrefix, + StringRef dialectName, StringRef symString) { + os << symPrefix << dialectName; + + // If this symbol name is simple enough, print it directly in pretty form, + // otherwise, we print it as an escaped string. + if (isDialectSymbolSimpleEnoughForPrettyForm(symString)) { + os << '.' << symString; + return; + } + + // TODO: escape the symbol name, it could contain " characters. + os << "<\"" << symString << "\">"; +} + +/// Returns if the given string can be represented as a bare identifier. +static bool isBareIdentifier(StringRef name) { + assert(!name.empty() && "invalid name"); + + // By making this unsigned, the value passed in to isalnum will always be + // in the range 0-255. This is important when building with MSVC because + // its implementation will assert. This situation can arise when dealing + // with UTF-8 multibyte characters. + unsigned char firstChar = static_cast(name[0]); + if (!isalpha(firstChar) && firstChar != '_') + return false; + return llvm::all_of(name.drop_front(), [](unsigned char c) { + return isalnum(c) || c == '_' || c == '$' || c == '.'; + }); +} + +/// Print the given string as a symbol reference. A symbol reference is +/// represented as a string prefixed with '@'. The reference is surrounded with +/// ""'s and escaped if it has any special or non-printable characters in it. +static void printSymbolReference(StringRef symbolRef, raw_ostream &os) { + assert(!symbolRef.empty() && "expected valid symbol reference"); + + // If the symbol can be represented as a bare identifier, write it directly. + if (isBareIdentifier(symbolRef)) { + os << '@' << symbolRef; + return; + } + + // Otherwise, output the reference wrapped in quotes with proper escaping. + os << "@\""; + printEscapedString(symbolRef, os); + os << '"'; +} + +// Print out a valid ElementsAttr that is succinct and can represent any +// potential shape/type, for use when eliding a large ElementsAttr. +// +// We choose to use an opaque ElementsAttr literal with conspicuous content to +// hopefully alert readers to the fact that this has been elided. +// +// Unfortunately, neither of the strings of an opaque ElementsAttr literal will +// accept the string "elided". The first string must be a registered dialect +// name and the latter must be a hex constant. +static void printElidedElementsAttr(raw_ostream &os) { + os << R"(opaque<"", "0xDEADBEEF">)"; +} + +void ModulePrinter::printAttribute(Attribute attr, bool mayElideType) { + if (!attr) { + os << "<>"; + return; + } + + // Check for an alias for this attribute. + if (state) { + Twine alias = state->getAliasState().getAttributeAlias(attr); + if (!alias.isTriviallyEmpty()) { + os << '#' << alias; + return; + } + } + + switch (attr.getKind()) { + default: + return printDialectAttribute(attr); + + case StandardAttributes::Opaque: { + auto opaqueAttr = attr.cast(); + printDialectSymbol(os, "#", opaqueAttr.getDialectNamespace(), + opaqueAttr.getAttrData()); + break; + } + case StandardAttributes::Unit: + os << "unit"; + break; + case StandardAttributes::Bool: + os << (attr.cast().getValue() ? "true" : "false"); + + // BoolAttr always elides the type. + return; + case StandardAttributes::Dictionary: + os << '{'; + interleaveComma(attr.cast().getValue(), + [&](NamedAttribute attr) { + os << attr.first; + + // The value of a UnitAttr is elided within a dictionary. + if (attr.second.isa()) + return; + + os << " = "; + printAttribute(attr.second); + }); + os << '}'; + break; + case StandardAttributes::Integer: { + auto intAttr = attr.cast(); + // Print all integer attributes as signed unless i1. + bool isSigned = intAttr.getType().isIndex() || + intAttr.getType().getIntOrFloatBitWidth() != 1; + intAttr.getValue().print(os, isSigned); + + // IntegerAttr elides the type if I64. + if (mayElideType && intAttr.getType().isInteger(64)) + return; + break; + } + case StandardAttributes::Float: { + auto floatAttr = attr.cast(); + printFloatValue(floatAttr.getValue(), os); + + // FloatAttr elides the type if F64. + if (mayElideType && floatAttr.getType().isF64()) + return; + break; + } + case StandardAttributes::String: + os << '"'; + printEscapedString(attr.cast().getValue(), os); + os << '"'; + break; case StandardAttributes::Array: os << '['; interleaveComma(attr.cast().getValue(), [&](Attribute attr) { @@ -1112,312 +1463,44 @@ os << '?'; else os << dim; - os << 'x'; - } - printType(v.getElementType()); - for (auto map : v.getAffineMaps()) { - os << ", "; - printAttribute(AffineMapAttr::get(map)); - } - // Only print the memory space if it is the non-default one. - if (v.getMemorySpace()) - os << ", " << v.getMemorySpace(); - os << '>'; - return; - } - case StandardTypes::UnrankedMemRef: { - auto v = type.cast(); - os << "memref<*x"; - printType(v.getElementType()); - os << '>'; - return; - } - case StandardTypes::Complex: - os << "complex<"; - printType(type.cast().getElementType()); - os << '>'; - return; - case StandardTypes::Tuple: { - auto tuple = type.cast(); - os << "tuple<"; - interleaveComma(tuple.getTypes(), [&](Type type) { printType(type); }); - os << '>'; - return; - } - case StandardTypes::None: - os << "none"; - return; - } -} - -//===----------------------------------------------------------------------===// -// CustomDialectAsmPrinter -//===----------------------------------------------------------------------===// - -namespace { -/// This class provides the main specialization of the DialectAsmPrinter that is -/// used to provide support for print attributes and types. This hooks allows -/// for dialects to hook into the main ModulePrinter. -struct CustomDialectAsmPrinter : public DialectAsmPrinter { -public: - CustomDialectAsmPrinter(ModulePrinter &printer) : printer(printer) {} - ~CustomDialectAsmPrinter() override {} - - raw_ostream &getStream() const override { return printer.getStream(); } - - /// Print the given attribute to the stream. - void printAttribute(Attribute attr) override { printer.printAttribute(attr); } - - /// Print the given floating point value in a stablized form. - void printFloat(const APFloat &value) override { - printFloatValue(value, getStream()); - } - - /// Print the given type to the stream. - void printType(Type type) override { printer.printType(type); } - - /// The main module printer. - ModulePrinter &printer; -}; -} // end anonymous namespace - -void ModulePrinter::printDialectAttribute(Attribute attr) { - auto &dialect = attr.getDialect(); - - // Ask the dialect to serialize the attribute to a string. - std::string attrName; - { - llvm::raw_string_ostream attrNameStr(attrName); - ModulePrinter subPrinter(attrNameStr, printerFlags, state); - CustomDialectAsmPrinter printer(subPrinter); - dialect.printAttribute(attr, printer); - } - printDialectSymbol(os, "#", dialect.getNamespace(), attrName); -} - -void ModulePrinter::printDialectType(Type type) { - auto &dialect = type.getDialect(); - - // Ask the dialect to serialize the type to a string. - std::string typeName; - { - llvm::raw_string_ostream typeNameStr(typeName); - ModulePrinter subPrinter(typeNameStr, printerFlags, state); - CustomDialectAsmPrinter printer(subPrinter); - dialect.printType(type, printer); - } - printDialectSymbol(os, "!", dialect.getNamespace(), typeName); -} - -//===----------------------------------------------------------------------===// -// Affine expressions and maps -//===----------------------------------------------------------------------===// - -void ModulePrinter::printAffineExpr( - AffineExpr expr, function_ref printValueName) { - printAffineExprInternal(expr, BindingStrength::Weak, printValueName); -} - -void ModulePrinter::printAffineExprInternal( - AffineExpr expr, BindingStrength enclosingTightness, - function_ref printValueName) { - const char *binopSpelling = nullptr; - switch (expr.getKind()) { - case AffineExprKind::SymbolId: { - unsigned pos = expr.cast().getPosition(); - if (printValueName) - printValueName(pos, /*isSymbol=*/true); - else - os << 's' << pos; - return; - } - case AffineExprKind::DimId: { - unsigned pos = expr.cast().getPosition(); - if (printValueName) - printValueName(pos, /*isSymbol=*/false); - else - os << 'd' << pos; - return; - } - case AffineExprKind::Constant: - os << expr.cast().getValue(); - return; - case AffineExprKind::Add: - binopSpelling = " + "; - break; - case AffineExprKind::Mul: - binopSpelling = " * "; - break; - case AffineExprKind::FloorDiv: - binopSpelling = " floordiv "; - break; - case AffineExprKind::CeilDiv: - binopSpelling = " ceildiv "; - break; - case AffineExprKind::Mod: - binopSpelling = " mod "; - break; - } - - auto binOp = expr.cast(); - AffineExpr lhsExpr = binOp.getLHS(); - AffineExpr rhsExpr = binOp.getRHS(); - - // Handle tightly binding binary operators. - if (binOp.getKind() != AffineExprKind::Add) { - if (enclosingTightness == BindingStrength::Strong) - os << '('; - - // Pretty print multiplication with -1. - auto rhsConst = rhsExpr.dyn_cast(); - if (rhsConst && binOp.getKind() == AffineExprKind::Mul && - rhsConst.getValue() == -1) { - os << "-"; - printAffineExprInternal(lhsExpr, BindingStrength::Strong, printValueName); - if (enclosingTightness == BindingStrength::Strong) - os << ')'; - return; - } - - printAffineExprInternal(lhsExpr, BindingStrength::Strong, printValueName); - - os << binopSpelling; - printAffineExprInternal(rhsExpr, BindingStrength::Strong, printValueName); - - if (enclosingTightness == BindingStrength::Strong) - os << ')'; - return; - } - - // Print out special "pretty" forms for add. - if (enclosingTightness == BindingStrength::Strong) - os << '('; - - // Pretty print addition to a product that has a negative operand as a - // subtraction. - if (auto rhs = rhsExpr.dyn_cast()) { - if (rhs.getKind() == AffineExprKind::Mul) { - AffineExpr rrhsExpr = rhs.getRHS(); - if (auto rrhs = rrhsExpr.dyn_cast()) { - if (rrhs.getValue() == -1) { - printAffineExprInternal(lhsExpr, BindingStrength::Weak, - printValueName); - os << " - "; - if (rhs.getLHS().getKind() == AffineExprKind::Add) { - printAffineExprInternal(rhs.getLHS(), BindingStrength::Strong, - printValueName); - } else { - printAffineExprInternal(rhs.getLHS(), BindingStrength::Weak, - printValueName); - } - - if (enclosingTightness == BindingStrength::Strong) - os << ')'; - return; - } - - if (rrhs.getValue() < -1) { - printAffineExprInternal(lhsExpr, BindingStrength::Weak, - printValueName); - os << " - "; - printAffineExprInternal(rhs.getLHS(), BindingStrength::Strong, - printValueName); - os << " * " << -rrhs.getValue(); - if (enclosingTightness == BindingStrength::Strong) - os << ')'; - return; - } - } + os << 'x'; } - } - - // Pretty print addition to a negative number as a subtraction. - if (auto rhsConst = rhsExpr.dyn_cast()) { - if (rhsConst.getValue() < 0) { - printAffineExprInternal(lhsExpr, BindingStrength::Weak, printValueName); - os << " - " << -rhsConst.getValue(); - if (enclosingTightness == BindingStrength::Strong) - os << ')'; - return; + printType(v.getElementType()); + for (auto map : v.getAffineMaps()) { + os << ", "; + printAttribute(AffineMapAttr::get(map)); } + // Only print the memory space if it is the non-default one. + if (v.getMemorySpace()) + os << ", " << v.getMemorySpace(); + os << '>'; + return; } - - printAffineExprInternal(lhsExpr, BindingStrength::Weak, printValueName); - - os << " + "; - printAffineExprInternal(rhsExpr, BindingStrength::Weak, printValueName); - - if (enclosingTightness == BindingStrength::Strong) - os << ')'; -} - -void ModulePrinter::printAffineConstraint(AffineExpr expr, bool isEq) { - printAffineExprInternal(expr, BindingStrength::Weak); - isEq ? os << " == 0" : os << " >= 0"; -} - -void ModulePrinter::printAffineMap(AffineMap map) { - // Dimension identifiers. - os << '('; - for (int i = 0; i < (int)map.getNumDims() - 1; ++i) - os << 'd' << i << ", "; - if (map.getNumDims() >= 1) - os << 'd' << map.getNumDims() - 1; - os << ')'; - - // Symbolic identifiers. - if (map.getNumSymbols() != 0) { - os << '['; - for (unsigned i = 0; i < map.getNumSymbols() - 1; ++i) - os << 's' << i << ", "; - if (map.getNumSymbols() >= 1) - os << 's' << map.getNumSymbols() - 1; - os << ']'; + case StandardTypes::UnrankedMemRef: { + auto v = type.cast(); + os << "memref<*x"; + printType(v.getElementType()); + os << '>'; + return; } - - // Result affine expressions. - os << " -> ("; - interleaveComma(map.getResults(), - [&](AffineExpr expr) { printAffineExpr(expr); }); - os << ')'; -} - -void ModulePrinter::printIntegerSet(IntegerSet set) { - // Dimension identifiers. - os << '('; - for (unsigned i = 1; i < set.getNumDims(); ++i) - os << 'd' << i - 1 << ", "; - if (set.getNumDims() >= 1) - os << 'd' << set.getNumDims() - 1; - os << ')'; - - // Symbolic identifiers. - if (set.getNumSymbols() != 0) { - os << '['; - for (unsigned i = 0; i < set.getNumSymbols() - 1; ++i) - os << 's' << i << ", "; - if (set.getNumSymbols() >= 1) - os << 's' << set.getNumSymbols() - 1; - os << ']'; + case StandardTypes::Complex: + os << "complex<"; + printType(type.cast().getElementType()); + os << '>'; + return; + case StandardTypes::Tuple: { + auto tuple = type.cast(); + os << "tuple<"; + interleaveComma(tuple.getTypes(), [&](Type type) { printType(type); }); + os << '>'; + return; } - - // Print constraints. - os << " : ("; - int numConstraints = set.getNumConstraints(); - for (int i = 1; i < numConstraints; ++i) { - printAffineConstraint(set.getConstraint(i - 1), set.isEq(i - 1)); - os << ", "; + case StandardTypes::None: + os << "none"; + return; } - if (numConstraints >= 1) - printAffineConstraint(set.getConstraint(numConstraints - 1), - set.isEq(numConstraints - 1)); - os << ')'; } -//===----------------------------------------------------------------------===// -// Operation printing -//===----------------------------------------------------------------------===// - void ModulePrinter::printOptionalAttrDict(ArrayRef attrs, ArrayRef elidedAttrs, bool withKeyword) { @@ -1454,484 +1537,370 @@ os << '}'; } -namespace { +//===----------------------------------------------------------------------===// +// CustomDialectAsmPrinter +//===----------------------------------------------------------------------===// -// OperationPrinter contains common functionality for printing operations. -class OperationPrinter : public ModulePrinter, private OpAsmPrinter { +namespace { +/// This class provides the main specialization of the DialectAsmPrinter that is +/// used to provide support for print attributes and types. This hooks allows +/// for dialects to hook into the main ModulePrinter. +struct CustomDialectAsmPrinter : public DialectAsmPrinter { public: - OperationPrinter(Operation *op, ModulePrinter &other); - OperationPrinter(Region *region, ModulePrinter &other); - - // Methods to print operations. - void print(Operation *op); - void print(Block *block, bool printBlockArgs = true, - bool printBlockTerminator = true); - - void printOperation(Operation *op); - void printGenericOp(Operation *op) override; - - // Implement OpAsmPrinter. - raw_ostream &getStream() const override { return os; } - void printType(Type type) override { ModulePrinter::printType(type); } - void printAttribute(Attribute attr) override { - ModulePrinter::printAttribute(attr); - } - void printOperand(Value value) override { printValueID(value); } - - void printOptionalAttrDict(ArrayRef attrs, - ArrayRef elidedAttrs = {}) override { - ModulePrinter::printOptionalAttrDict(attrs, elidedAttrs); - } - void printOptionalAttrDictWithKeyword( - ArrayRef attrs, - ArrayRef elidedAttrs = {}) override { - ModulePrinter::printOptionalAttrDict(attrs, elidedAttrs, - /*withKeyword=*/true); - } - - enum { nameSentinel = ~0U }; - - void printBlockName(Block *block) { - auto id = getBlockID(block); - if (id != ~0U) - os << "^bb" << id; - else - os << "^INVALIDBLOCK"; - } - - unsigned getBlockID(Block *block) { - auto it = blockIDs.find(block); - return it != blockIDs.end() ? it->second : ~0U; - } - - void printSuccessorAndUseList(Operation *term, unsigned index) override; - - /// Print a region. - void printRegion(Region &blocks, bool printEntryBlockArgs, - bool printBlockTerminators) override { - os << " {\n"; - if (!blocks.empty()) { - auto *entryBlock = &blocks.front(); - print(entryBlock, - printEntryBlockArgs && entryBlock->getNumArguments() != 0, - printBlockTerminators); - for (auto &b : llvm::drop_begin(blocks.getBlocks(), 1)) - print(&b); - } - os.indent(currentIndent) << "}"; - } - - /// Renumber the arguments for the specified region to the same names as the - /// SSA values in namesToUse. This may only be used for IsolatedFromAbove - /// operations. If any entry in namesToUse is null, the corresponding - /// argument name is left alone. - void shadowRegionArgs(Region ®ion, ValueRange namesToUse) override; - - void printAffineMapOfSSAIds(AffineMapAttr mapAttr, - ValueRange operands) override { - AffineMap map = mapAttr.getValue(); - unsigned numDims = map.getNumDims(); - auto printValueName = [&](unsigned pos, bool isSymbol) { - unsigned index = isSymbol ? numDims + pos : pos; - assert(index < operands.size()); - if (isSymbol) - os << "symbol("; - printValueID(operands[index]); - if (isSymbol) - os << ')'; - }; - - interleaveComma(map.getResults(), [&](AffineExpr expr) { - printAffineExpr(expr, printValueName); - }); - } - - /// Print the given string as a symbol reference. - void printSymbolName(StringRef symbolRef) override { - ::printSymbolReference(symbolRef, os); - } - - // Number of spaces used for indenting nested operations. - const static unsigned indentWidth = 2; - -protected: - void numberValuesInRegion(Region ®ion); - void numberValuesInBlock(Block &block); - void numberValuesInOp(Operation &op); - void printValueID(Value value, bool printResultNo = true) const { - printValueIDImpl(value, printResultNo, os); - } - -private: - /// Given a result of an operation 'result', find the result group head - /// 'lookupValue' and the result of 'result' within that group in - /// 'lookupResultNo'. 'lookupResultNo' is only filled in if the result group - /// has more than 1 result. - void getResultIDAndNumber(OpResult result, Value &lookupValue, - int &lookupResultNo) const; - void printValueIDImpl(Value value, bool printResultNo, - raw_ostream &stream) const; - - /// Set a special value name for the given value. - void setValueName(Value value, StringRef name); - - /// Uniques the given value name within the printer. If the given name - /// conflicts, it is automatically renamed. - StringRef uniqueValueName(StringRef name); - - /// This is the value ID for each SSA value. If this returns ~0, then the - /// valueID has an entry in valueNames. - DenseMap valueIDs; - DenseMap valueNames; - - /// This is a map of operations that contain multiple named result groups, - /// i.e. there may be multiple names for the results of the operation. The key - /// of this map are the result numbers that start a result group. - DenseMap> opResultGroups; - - /// This is the block ID for each block in the current. - DenseMap blockIDs; - - /// This keeps track of all of the non-numeric names that are in flight, - /// allowing us to check for duplicates. - /// Note: the value of the map is unused. - llvm::ScopedHashTable usedNames; - llvm::BumpPtrAllocator usedNameAllocator; - - // This is the current indentation level for nested structures. - unsigned currentIndent = 0; - - /// This is the next value ID to assign in numbering. - unsigned nextValueID = 0; - /// This is the next ID to assign to a region entry block argument. - unsigned nextArgumentID = 0; - /// This is the next ID to assign when a name conflict is detected. - unsigned nextConflictID = 0; -}; -} // end anonymous namespace - -OperationPrinter::OperationPrinter(Operation *op, ModulePrinter &other) - : ModulePrinter(other) { - llvm::ScopedHashTable::ScopeTy usedNamesScope(usedNames); - numberValuesInOp(*op); - - for (auto ®ion : op->getRegions()) - numberValuesInRegion(region); -} - -OperationPrinter::OperationPrinter(Region *region, ModulePrinter &other) - : ModulePrinter(other) { - numberValuesInRegion(*region); -} - -void OperationPrinter::numberValuesInRegion(Region ®ion) { - // Save the current value ids to allow for numbering values in sibling regions - // the same. - unsigned curValueID = nextValueID; - unsigned curArgumentID = nextArgumentID; - unsigned curConflictID = nextConflictID; - - // Push a new used names scope. - llvm::ScopedHashTable::ScopeTy usedNamesScope(usedNames); - - // Number the values within this region in a breadth-first order. - unsigned nextBlockID = 0; - for (auto &block : region) { - // Each block gets a unique ID, and all of the operations within it get - // numbered as well. - blockIDs[&block] = nextBlockID++; - numberValuesInBlock(block); - } - - // After that we traverse the nested regions. - // TODO: Rework this loop to not use recursion. - for (auto &block : region) { - for (auto &op : block) - for (auto &nestedRegion : op.getRegions()) - numberValuesInRegion(nestedRegion); - } + CustomDialectAsmPrinter(ModulePrinter &printer) : printer(printer) {} + ~CustomDialectAsmPrinter() override {} - // Restore the original value ids. - nextValueID = curValueID; - nextArgumentID = curArgumentID; - nextConflictID = curConflictID; -} + raw_ostream &getStream() const override { return printer.getStream(); } -void OperationPrinter::numberValuesInBlock(Block &block) { - auto setArgNameFn = [&](Value arg, StringRef name) { - assert(!valueIDs.count(arg) && "arg numbered multiple times"); - assert(arg.cast()->getOwner() == &block && - "arg not defined in 'block'"); - setValueName(arg, name); - }; + /// Print the given attribute to the stream. + void printAttribute(Attribute attr) override { printer.printAttribute(attr); } - bool isEntryBlock = block.isEntryBlock(); - if (isEntryBlock && state) { - if (auto *op = block.getParentOp()) { - if (auto dialectAsmInterface = state->getOpAsmInterface(op->getDialect())) - dialectAsmInterface->getAsmBlockArgumentNames(&block, setArgNameFn); - } + /// Print the given floating point value in a stablized form. + void printFloat(const APFloat &value) override { + printFloatValue(value, getStream()); } - // Number the block arguments. We give entry block arguments a special name - // 'arg'. - SmallString<32> specialNameBuffer(isEntryBlock ? "arg" : ""); - llvm::raw_svector_ostream specialName(specialNameBuffer); - for (auto arg : block.getArguments()) { - if (valueIDs.count(arg)) - continue; - if (isEntryBlock) { - specialNameBuffer.resize(strlen("arg")); - specialName << nextArgumentID++; - } - setValueName(arg, specialName.str()); - } + /// Print the given type to the stream. + void printType(Type type) override { printer.printType(type); } - // Number the operations in this block. - for (auto &op : block) - numberValuesInOp(op); -} + /// The main module printer. + ModulePrinter &printer; +}; +} // end anonymous namespace -void OperationPrinter::numberValuesInOp(Operation &op) { - unsigned numResults = op.getNumResults(); - if (numResults == 0) - return; - Value resultBegin = op.getResult(0); +void ModulePrinter::printDialectAttribute(Attribute attr) { + auto &dialect = attr.getDialect(); - // Function used to set the special result names for the operation. - SmallVector resultGroups(/*Size=*/1, /*Value=*/0); - auto setResultNameFn = [&](Value result, StringRef name) { - assert(!valueIDs.count(result) && "result numbered multiple times"); - assert(result->getDefiningOp() == &op && "result not defined by 'op'"); - setValueName(result, name); + // Ask the dialect to serialize the attribute to a string. + std::string attrName; + { + llvm::raw_string_ostream attrNameStr(attrName); + ModulePrinter subPrinter(attrNameStr, printerFlags, state); + CustomDialectAsmPrinter printer(subPrinter); + dialect.printAttribute(attr, printer); + } + printDialectSymbol(os, "#", dialect.getNamespace(), attrName); +} - // Record the result number for groups not anchored at 0. - if (int resultNo = result.cast()->getResultNumber()) - resultGroups.push_back(resultNo); - }; +void ModulePrinter::printDialectType(Type type) { + auto &dialect = type.getDialect(); - if (OpAsmOpInterface asmInterface = dyn_cast(&op)) { - asmInterface.getAsmResultNames(setResultNameFn); - } else if (auto *dialectAsmInterface = - state ? state->getOpAsmInterface(op.getDialect()) : nullptr) { - dialectAsmInterface->getAsmResultNames(&op, setResultNameFn); + // Ask the dialect to serialize the type to a string. + std::string typeName; + { + llvm::raw_string_ostream typeNameStr(typeName); + ModulePrinter subPrinter(typeNameStr, printerFlags, state); + CustomDialectAsmPrinter printer(subPrinter); + dialect.printType(type, printer); } + printDialectSymbol(os, "!", dialect.getNamespace(), typeName); +} - // If the first result wasn't numbered, give it a default number. - if (valueIDs.try_emplace(resultBegin, nextValueID).second) - ++nextValueID; +//===----------------------------------------------------------------------===// +// Affine expressions and maps +//===----------------------------------------------------------------------===// - // If this operation has multiple result groups, mark it. - if (resultGroups.size() != 1) { - llvm::array_pod_sort(resultGroups.begin(), resultGroups.end()); - opResultGroups.try_emplace(&op, std::move(resultGroups)); - } +void ModulePrinter::printAffineExpr( + AffineExpr expr, function_ref printValueName) { + printAffineExprInternal(expr, BindingStrength::Weak, printValueName); } -/// Set a special value name for the given value. -void OperationPrinter::setValueName(Value value, StringRef name) { - // If the name is empty, the value uses the default numbering. - if (name.empty()) { - valueIDs[value] = nextValueID++; +void ModulePrinter::printAffineExprInternal( + AffineExpr expr, BindingStrength enclosingTightness, + function_ref printValueName) { + const char *binopSpelling = nullptr; + switch (expr.getKind()) { + case AffineExprKind::SymbolId: { + unsigned pos = expr.cast().getPosition(); + if (printValueName) + printValueName(pos, /*isSymbol=*/true); + else + os << 's' << pos; return; } + case AffineExprKind::DimId: { + unsigned pos = expr.cast().getPosition(); + if (printValueName) + printValueName(pos, /*isSymbol=*/false); + else + os << 'd' << pos; + return; + } + case AffineExprKind::Constant: + os << expr.cast().getValue(); + return; + case AffineExprKind::Add: + binopSpelling = " + "; + break; + case AffineExprKind::Mul: + binopSpelling = " * "; + break; + case AffineExprKind::FloorDiv: + binopSpelling = " floordiv "; + break; + case AffineExprKind::CeilDiv: + binopSpelling = " ceildiv "; + break; + case AffineExprKind::Mod: + binopSpelling = " mod "; + break; + } - valueIDs[value] = nameSentinel; - valueNames[value] = uniqueValueName(name); -} + auto binOp = expr.cast(); + AffineExpr lhsExpr = binOp.getLHS(); + AffineExpr rhsExpr = binOp.getRHS(); -/// Uniques the given value name within the printer. If the given name -/// conflicts, it is automatically renamed. -StringRef OperationPrinter::uniqueValueName(StringRef name) { - // Check to see if this name is already unique. - if (!usedNames.count(name)) { - name = name.copy(usedNameAllocator); - } else { - // Otherwise, we had a conflict - probe until we find a unique name. This - // is guaranteed to terminate (and usually in a single iteration) because it - // generates new names by incrementing nextConflictID. - SmallString<64> probeName(name); - probeName.push_back('_'); - while (true) { - probeName.resize(name.size() + 1); - probeName += llvm::utostr(nextConflictID++); - if (!usedNames.count(probeName)) { - name = StringRef(probeName).copy(usedNameAllocator); - break; - } + // Handle tightly binding binary operators. + if (binOp.getKind() != AffineExprKind::Add) { + if (enclosingTightness == BindingStrength::Strong) + os << '('; + + // Pretty print multiplication with -1. + auto rhsConst = rhsExpr.dyn_cast(); + if (rhsConst && binOp.getKind() == AffineExprKind::Mul && + rhsConst.getValue() == -1) { + os << "-"; + printAffineExprInternal(lhsExpr, BindingStrength::Strong, printValueName); + if (enclosingTightness == BindingStrength::Strong) + os << ')'; + return; } - } - usedNames.insert(name, char()); - return name; -} + printAffineExprInternal(lhsExpr, BindingStrength::Strong, printValueName); -void OperationPrinter::print(Block *block, bool printBlockArgs, - bool printBlockTerminator) { - // Print the block label and argument list if requested. - if (printBlockArgs) { - os.indent(currentIndent); - printBlockName(block); + os << binopSpelling; + printAffineExprInternal(rhsExpr, BindingStrength::Strong, printValueName); - // Print the argument list if non-empty. - if (!block->args_empty()) { - os << '('; - interleaveComma(block->getArguments(), [&](BlockArgument arg) { - printValueID(arg); - os << ": "; - printType(arg->getType()); - }); + if (enclosingTightness == BindingStrength::Strong) os << ')'; - } - os << ':'; + return; + } - // Print out some context information about the predecessors of this block. - if (!block->getParent()) { - os << "\t// block is not in a region!"; - } else if (block->hasNoPredecessors()) { - os << "\t// no predecessors"; - } else if (auto *pred = block->getSinglePredecessor()) { - os << "\t// pred: "; - printBlockName(pred); - } else { - // We want to print the predecessors in increasing numeric order, not in - // whatever order the use-list is in, so gather and sort them. - SmallVector, 4> predIDs; - for (auto *pred : block->getPredecessors()) - predIDs.push_back({getBlockID(pred), pred}); - llvm::array_pod_sort(predIDs.begin(), predIDs.end()); + // Print out special "pretty" forms for add. + if (enclosingTightness == BindingStrength::Strong) + os << '('; - os << "\t// " << predIDs.size() << " preds: "; + // Pretty print addition to a product that has a negative operand as a + // subtraction. + if (auto rhs = rhsExpr.dyn_cast()) { + if (rhs.getKind() == AffineExprKind::Mul) { + AffineExpr rrhsExpr = rhs.getRHS(); + if (auto rrhs = rrhsExpr.dyn_cast()) { + if (rrhs.getValue() == -1) { + printAffineExprInternal(lhsExpr, BindingStrength::Weak, + printValueName); + os << " - "; + if (rhs.getLHS().getKind() == AffineExprKind::Add) { + printAffineExprInternal(rhs.getLHS(), BindingStrength::Strong, + printValueName); + } else { + printAffineExprInternal(rhs.getLHS(), BindingStrength::Weak, + printValueName); + } - interleaveComma(predIDs, [&](std::pair pred) { - printBlockName(pred.second); - }); + if (enclosingTightness == BindingStrength::Strong) + os << ')'; + return; + } + + if (rrhs.getValue() < -1) { + printAffineExprInternal(lhsExpr, BindingStrength::Weak, + printValueName); + os << " - "; + printAffineExprInternal(rhs.getLHS(), BindingStrength::Strong, + printValueName); + os << " * " << -rrhs.getValue(); + if (enclosingTightness == BindingStrength::Strong) + os << ')'; + return; + } + } } - os << '\n'; } - currentIndent += indentWidth; - auto range = llvm::make_range( - block->getOperations().begin(), - std::prev(block->getOperations().end(), printBlockTerminator ? 0 : 1)); - for (auto &op : range) { - print(&op); - os << '\n'; + // Pretty print addition to a negative number as a subtraction. + if (auto rhsConst = rhsExpr.dyn_cast()) { + if (rhsConst.getValue() < 0) { + printAffineExprInternal(lhsExpr, BindingStrength::Weak, printValueName); + os << " - " << -rhsConst.getValue(); + if (enclosingTightness == BindingStrength::Strong) + os << ')'; + return; + } } - currentIndent -= indentWidth; + + printAffineExprInternal(lhsExpr, BindingStrength::Weak, printValueName); + + os << " + "; + printAffineExprInternal(rhsExpr, BindingStrength::Weak, printValueName); + + if (enclosingTightness == BindingStrength::Strong) + os << ')'; } -void OperationPrinter::print(Operation *op) { - os.indent(currentIndent); - printOperation(op); - printTrailingLocation(op->getLoc()); +void ModulePrinter::printAffineConstraint(AffineExpr expr, bool isEq) { + printAffineExprInternal(expr, BindingStrength::Weak); + isEq ? os << " == 0" : os << " >= 0"; } -void OperationPrinter::getResultIDAndNumber(OpResult result, Value &lookupValue, - int &lookupResultNo) const { - Operation *owner = result->getOwner(); - if (owner->getNumResults() == 1) - return; - int resultNo = result->getResultNumber(); +void ModulePrinter::printAffineMap(AffineMap map) { + // Dimension identifiers. + os << '('; + for (int i = 0; i < (int)map.getNumDims() - 1; ++i) + os << 'd' << i << ", "; + if (map.getNumDims() >= 1) + os << 'd' << map.getNumDims() - 1; + os << ')'; - // If this operation has multiple result groups, we will need to find the - // one corresponding to this result. - auto resultGroupIt = opResultGroups.find(owner); - if (resultGroupIt == opResultGroups.end()) { - // If not, just use the first result. - lookupResultNo = resultNo; - lookupValue = owner->getResult(0); - return; + // Symbolic identifiers. + if (map.getNumSymbols() != 0) { + os << '['; + for (unsigned i = 0; i < map.getNumSymbols() - 1; ++i) + os << 's' << i << ", "; + if (map.getNumSymbols() >= 1) + os << 's' << map.getNumSymbols() - 1; + os << ']'; } - // Find the correct index using a binary search, as the groups are ordered. - ArrayRef resultGroups = resultGroupIt->second; - auto it = llvm::upper_bound(resultGroups, resultNo); - int groupResultNo = 0, groupSize = 0; + // Result affine expressions. + os << " -> ("; + interleaveComma(map.getResults(), + [&](AffineExpr expr) { printAffineExpr(expr); }); + os << ')'; +} - // If there are no smaller elements, the last result group is the lookup. - if (it == resultGroups.end()) { - groupResultNo = resultGroups.back(); - groupSize = static_cast(owner->getNumResults()) - resultGroups.back(); - } else { - // Otherwise, the previous element is the lookup. - groupResultNo = *std::prev(it); - groupSize = *it - groupResultNo; +void ModulePrinter::printIntegerSet(IntegerSet set) { + // Dimension identifiers. + os << '('; + for (unsigned i = 1; i < set.getNumDims(); ++i) + os << 'd' << i - 1 << ", "; + if (set.getNumDims() >= 1) + os << 'd' << set.getNumDims() - 1; + os << ')'; + + // Symbolic identifiers. + if (set.getNumSymbols() != 0) { + os << '['; + for (unsigned i = 0; i < set.getNumSymbols() - 1; ++i) + os << 's' << i << ", "; + if (set.getNumSymbols() >= 1) + os << 's' << set.getNumSymbols() - 1; + os << ']'; } - // We only record the result number for a group of size greater than 1. - if (groupSize != 1) - lookupResultNo = resultNo - groupResultNo; - lookupValue = owner->getResult(groupResultNo); + // Print constraints. + os << " : ("; + int numConstraints = set.getNumConstraints(); + for (int i = 1; i < numConstraints; ++i) { + printAffineConstraint(set.getConstraint(i - 1), set.isEq(i - 1)); + os << ", "; + } + if (numConstraints >= 1) + printAffineConstraint(set.getConstraint(numConstraints - 1), + set.isEq(numConstraints - 1)); + os << ')'; } -void OperationPrinter::printValueIDImpl(Value value, bool printResultNo, - raw_ostream &stream) const { - if (!value) { - stream << "<>"; - return; +//===----------------------------------------------------------------------===// +// OperationPrinter +//===----------------------------------------------------------------------===// + +namespace { +/// This class contains the logic for printing operations, regions, and blocks. +class OperationPrinter : public ModulePrinter, private OpAsmPrinter { +public: + explicit OperationPrinter(ModulePrinter &other) : ModulePrinter(other) { + assert(state && "expected valid state when printing operation"); } - int resultNo = -1; - auto lookupValue = value; + /// Print the given operation with its indent and location. + void print(Operation *op); + /// Print the bare location, not including indentation/location/etc. + void printOperation(Operation *op); + /// Print the given operation in the generic form. + void printGenericOp(Operation *op) override; - // If this is a reference to the result of a multi-result operation or - // operation, print out the # identifier and make sure to map our lookup - // to the first result of the operation. - if (OpResult result = value.dyn_cast()) - getResultIDAndNumber(result, lookupValue, resultNo); + /// Print the name of the given block. + void printBlockName(Block *block); - auto it = valueIDs.find(lookupValue); - if (it == valueIDs.end()) { - stream << "<>"; - return; + /// Print the given block. If 'printBlockArgs' is false, the arguments of the + /// block are not printed. If 'printBlockTerminator' is false, the terminator + /// operation of the block is not printed. + void print(Block *block, bool printBlockArgs = true, + bool printBlockTerminator = true); + + /// Print the ID of the given value, optionally with its result number. + void printValueID(Value value, bool printResultNo = true) const; + + //===--------------------------------------------------------------------===// + // OpAsmPrinter methods + //===--------------------------------------------------------------------===// + + /// Return the current stream of the printer. + raw_ostream &getStream() const override { return os; } + + /// Print the given type. + void printType(Type type) override { ModulePrinter::printType(type); } + + /// Print the given attribute. + void printAttribute(Attribute attr) override { + ModulePrinter::printAttribute(attr); } - stream << '%'; - if (it->second != (unsigned)nameSentinel) { - stream << it->second; - } else { - auto nameIt = valueNames.find(lookupValue); - assert(nameIt != valueNames.end() && "Didn't have a name entry?"); - stream << nameIt->second; + /// Print the ID for the given value. + void printOperand(Value value) override { printValueID(value); } + + /// Print an optional attribute dictionary with a given set of elided values. + void printOptionalAttrDict(ArrayRef attrs, + ArrayRef elidedAttrs = {}) override { + ModulePrinter::printOptionalAttrDict(attrs, elidedAttrs); + } + void printOptionalAttrDictWithKeyword( + ArrayRef attrs, + ArrayRef elidedAttrs = {}) override { + ModulePrinter::printOptionalAttrDict(attrs, elidedAttrs, + /*withKeyword=*/true); } - if (resultNo != -1 && printResultNo) - stream << '#' << resultNo; -} + /// Print an operation successor with the operands used for the block + /// arguments. + void printSuccessorAndUseList(Operation *term, unsigned index) override; -/// Renumber the arguments for the specified region to the same names as the -/// SSA values in namesToUse. This may only be used for IsolatedFromAbove -/// operations. If any entry in namesToUse is null, the corresponding -/// argument name is left alone. -void OperationPrinter::shadowRegionArgs(Region ®ion, ValueRange namesToUse) { - assert(!region.empty() && "cannot shadow arguments of an empty region"); - assert(region.front().getNumArguments() == namesToUse.size() && - "incorrect number of names passed in"); - assert(region.getParentOp()->isKnownIsolatedFromAbove() && - "only KnownIsolatedFromAbove ops can shadow names"); + /// Print the given region. + void printRegion(Region ®ion, bool printEntryBlockArgs, + bool printBlockTerminators) override; - SmallVector nameStr; - for (unsigned i = 0, e = namesToUse.size(); i != e; ++i) { - auto nameToUse = namesToUse[i]; - if (nameToUse == nullptr) - continue; + /// Renumber the arguments for the specified region to the same names as the + /// SSA values in namesToUse. This may only be used for IsolatedFromAbove + /// operations. If any entry in namesToUse is null, the corresponding + /// argument name is left alone. + void shadowRegionArgs(Region ®ion, ValueRange namesToUse) override { + state->getSSANameState().shadowRegionArgs(region, namesToUse); + } - auto nameToReplace = region.front().getArgument(i); + /// Print the given affine map with the smybol and dimension operands printed + /// inline with the map. + void printAffineMapOfSSAIds(AffineMapAttr mapAttr, + ValueRange operands) override; - nameStr.clear(); - llvm::raw_svector_ostream nameStream(nameStr); - printValueIDImpl(nameToUse, /*printResultNo=*/true, nameStream); + /// Print the given string as a symbol reference. + void printSymbolName(StringRef symbolRef) override { + ::printSymbolReference(symbolRef, os); + } - // Entry block arguments should already have a pretty "arg" name. - assert(valueIDs[nameToReplace] == (unsigned)nameSentinel); +private: + /// The number of spaces used for indenting nested operations. + const static unsigned indentWidth = 2; - // Use the name without the leading %. - auto name = StringRef(nameStream.str()).drop_front(); + // This is the current indentation level for nested structures. + unsigned currentIndent = 0; +}; +} // end anonymous namespace - // Overwrite the name. - valueNames[nameToReplace] = name.copy(usedNameAllocator); - } +void OperationPrinter::print(Operation *op) { + os.indent(currentIndent); + printOperation(op); + printTrailingLocation(op->getLoc()); } void OperationPrinter::printOperation(Operation *op) { @@ -1943,9 +1912,8 @@ }; // Check to see if this operation has multiple result groups. - auto resultGroupIt = opResultGroups.find(op); - if (resultGroupIt != opResultGroups.end()) { - ArrayRef resultGroups = resultGroupIt->second; + ArrayRef resultGroups = state->getSSANameState().getOpResultGroups(op); + if (!resultGroups.empty()) { // Interleave the groups excluding the last one, this one will be handled // separately. interleaveComma(llvm::seq(0, resultGroups.size() - 1), [&](int i) { @@ -1989,21 +1957,16 @@ for (unsigned i = 0; i < numSuccessors; ++i) totalNumSuccessorOperands += op->getNumSuccessorOperands(i); unsigned numProperOperands = op->getNumOperands() - totalNumSuccessorOperands; - SmallVector properOperands( - op->operand_begin(), std::next(op->operand_begin(), numProperOperands)); - - interleaveComma(properOperands, [&](Value value) { printValueID(value); }); + interleaveComma(op->getOperands().take_front(numProperOperands), + [&](Value value) { printValueID(value); }); os << ')'; // For terminators, print the list of successors and their operands. if (numSuccessors != 0) { os << '['; - for (unsigned i = 0; i < numSuccessors; ++i) { - if (i != 0) - os << ", "; - printSuccessorAndUseList(op, i); - } + interleaveComma(llvm::seq(0, numSuccessors), + [&](unsigned i) { printSuccessorAndUseList(op, i); }); os << ']'; } @@ -2025,6 +1988,73 @@ printFunctionalType(op); } +void OperationPrinter::printBlockName(Block *block) { + auto id = state->getSSANameState().getBlockID(block); + if (id != SSANameState::NameSentinel) + os << "^bb" << id; + else + os << "^INVALIDBLOCK"; +} + +void OperationPrinter::print(Block *block, bool printBlockArgs, + bool printBlockTerminator) { + // Print the block label and argument list if requested. + if (printBlockArgs) { + os.indent(currentIndent); + printBlockName(block); + + // Print the argument list if non-empty. + if (!block->args_empty()) { + os << '('; + interleaveComma(block->getArguments(), [&](BlockArgument arg) { + printValueID(arg); + os << ": "; + printType(arg->getType()); + }); + os << ')'; + } + os << ':'; + + // Print out some context information about the predecessors of this block. + if (!block->getParent()) { + os << "\t// block is not in a region!"; + } else if (block->hasNoPredecessors()) { + os << "\t// no predecessors"; + } else if (auto *pred = block->getSinglePredecessor()) { + os << "\t// pred: "; + printBlockName(pred); + } else { + // We want to print the predecessors in increasing numeric order, not in + // whatever order the use-list is in, so gather and sort them. + SmallVector, 4> predIDs; + for (auto *pred : block->getPredecessors()) + predIDs.push_back({state->getSSANameState().getBlockID(pred), pred}); + llvm::array_pod_sort(predIDs.begin(), predIDs.end()); + + os << "\t// " << predIDs.size() << " preds: "; + + interleaveComma(predIDs, [&](std::pair pred) { + printBlockName(pred.second); + }); + } + os << '\n'; + } + + currentIndent += indentWidth; + auto range = llvm::make_range( + block->getOperations().begin(), + std::prev(block->getOperations().end(), printBlockTerminator ? 0 : 1)); + for (auto &op : range) { + print(&op); + os << '\n'; + } + currentIndent -= indentWidth; +} + +void OperationPrinter::printValueID(Value value, bool printResultNo) const { + state->getSSANameState().printValueID(value, printResultNo, os); +} + void OperationPrinter::printSuccessorAndUseList(Operation *term, unsigned index) { printBlockName(term->getSuccessor(index)); @@ -2042,15 +2072,47 @@ os << ')'; } +void OperationPrinter::printRegion(Region ®ion, bool printEntryBlockArgs, + bool printBlockTerminators) { + os << " {\n"; + if (!region.empty()) { + auto *entryBlock = ®ion.front(); + print(entryBlock, printEntryBlockArgs && entryBlock->getNumArguments() != 0, + printBlockTerminators); + for (auto &b : llvm::drop_begin(region.getBlocks(), 1)) + print(&b); + } + os.indent(currentIndent) << "}"; +} + +void OperationPrinter::printAffineMapOfSSAIds(AffineMapAttr mapAttr, + ValueRange operands) { + AffineMap map = mapAttr.getValue(); + unsigned numDims = map.getNumDims(); + auto printValueName = [&](unsigned pos, bool isSymbol) { + unsigned index = isSymbol ? numDims + pos : pos; + assert(index < operands.size()); + if (isSymbol) + os << "symbol("; + printValueID(operands[index]); + if (isSymbol) + os << ')'; + }; + + interleaveComma(map.getResults(), [&](AffineExpr expr) { + printAffineExpr(expr, printValueName); + }); +} + void ModulePrinter::print(ModuleOp module) { + assert(state && "expected valid state when printing an operation"); + // Output the aliases at the top level. - if (state) { - state->getAliasState().printAttributeAliases(os); - state->getAliasState().printTypeAliases(os); - } + state->getAliasState().printAttributeAliases(os); + state->getAliasState().printTypeAliases(os); // Print the module. - OperationPrinter(module, *this).print(module); + OperationPrinter(*this).print(module); os << '\n'; } @@ -2122,25 +2184,24 @@ void Operation::print(raw_ostream &os, OpPrintingFlags flags) { // Handle top-level operations or local printing. if (!getParent() || flags.shouldUseLocalScope()) { - ModuleState state(getContext()); + ModuleState state(this); ModulePrinter modulePrinter(os, flags, &state); - OperationPrinter(this, modulePrinter).print(this); + OperationPrinter(modulePrinter).print(this); return; } - auto region = getParentRegion(); - if (!region) { - os << "<>\n"; + Operation *parentOp = getParentOp(); + if (!parentOp) { + os << "<>\n"; return; } + // Get the top-level op. + while (auto *nextOp = parentOp->getParentOp()) + parentOp = nextOp; - // Get the top-level region. - while (auto *nextRegion = region->getParentRegion()) - region = nextRegion; - - ModuleState state(getContext()); + ModuleState state(parentOp); ModulePrinter modulePrinter(os, flags, &state); - OperationPrinter(region, modulePrinter).print(this); + OperationPrinter(modulePrinter).print(this); } void Operation::dump() { @@ -2149,41 +2210,41 @@ } void Block::print(raw_ostream &os) { - auto region = getParent(); - if (!region) { + Operation *parentOp = getParentOp(); + if (!parentOp) { os << "<>\n"; return; } + // Get the top-level op. + while (auto *nextOp = parentOp->getParentOp()) + parentOp = nextOp; - // Get the top-level region. - while (auto *nextRegion = region->getParentRegion()) - region = nextRegion; - - ModuleState state(region->getContext()); + ModuleState state(parentOp); ModulePrinter modulePrinter(os, /*flags=*/llvm::None, &state); - OperationPrinter(region, modulePrinter).print(this); + OperationPrinter(modulePrinter).print(this); } void Block::dump() { print(llvm::errs()); } /// Print out the name of the block without printing its body. void Block::printAsOperand(raw_ostream &os, bool printType) { - auto region = getParent(); - if (!region) { + Operation *parentOp = getParentOp(); + if (!parentOp) { os << "<>\n"; return; } + // Get the top-level op. + while (auto *nextOp = parentOp->getParentOp()) + parentOp = nextOp; - // Get the top-level region. - while (auto *nextRegion = region->getParentRegion()) - region = nextRegion; - - ModulePrinter modulePrinter(os); - OperationPrinter(region, modulePrinter).printBlockName(this); + ModuleState state(parentOp); + ModulePrinter modulePrinter(os, /*flags=*/llvm::None, &state); + OperationPrinter(modulePrinter).printBlockName(this); } void ModuleOp::print(raw_ostream &os, OpPrintingFlags flags) { - ModuleState state(getContext()); + ModuleState state(*this); + // Don't populate aliases when printing at local scope. if (!flags.shouldUseLocalScope()) state.initializeAliases(*this);