diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -411,160 +411,511 @@ } //===----------------------------------------------------------------------===// -// ModuleState +// SSANameState //===----------------------------------------------------------------------===// namespace { -class ModuleState { +/// This class manages the state of SSA value names. +class SSANameState { public: - explicit ModuleState(MLIRContext *context) : interfaces(context) {} + /// A sentinal value used for values without names set. + enum : unsigned { NameSentinel = ~0U }; - /// Initialize the alias state to enable the printing of aliases. - void initializeAliases(Operation *op); + SSANameState(Operation *op, + DialectInterfaceCollection &interfaces); - /// Get an instance of the OpAsmDialectInterface for the given dialect, or - /// null if one wasn't registered. - const OpAsmDialectInterface *getOpAsmInterface(Dialect *dialect) { - return interfaces.getInterfaceFor(dialect); - } + /// Print the ID for the given value to 'stream'. + void printValueID(Value value, bool printResultNo, raw_ostream &stream) const; - /// Get the state used for aliases. - AliasState &getAliasState() { return aliasState; } + /// Return the result groups registered by this operation, or empty if none + /// exist. + ArrayRef getOpResultGroups(Operation *op); + + /// Get the ID for the given block. + unsigned getBlockID(Block *block); + + /// Renumber the arguments for the specified region to the same names as the + /// SSA values in namesToUse. See OperationPrinter::shadowRegionArgs for + /// details. + void shadowRegionArgs(Region ®ion, ValueRange namesToUse); private: - /// Collection of OpAsm interfaces implemented in the context. - DialectInterfaceCollection interfaces; + /// Number the SSA values within the given IR unit. + void numberValuesInRegion( + Region ®ion, + DialectInterfaceCollection &interfaces); + void numberValuesInBlock( + Block &block, + DialectInterfaceCollection &interfaces); + void numberValuesInOp( + Operation &op, + DialectInterfaceCollection &interfaces); - /// The state used for attribute and type aliases. - AliasState aliasState; + /// Given a result of an operation 'result', find the result group head + /// 'lookupValue' and the result of 'result' within that group in + /// 'lookupResultNo'. 'lookupResultNo' is only filled in if the result group + /// has more than 1 result. + void getResultIDAndNumber(OpResult result, Value &lookupValue, + int &lookupResultNo) const; + + /// Set a special value name for the given value. + void setValueName(Value value, StringRef name); + + /// Uniques the given value name within the printer. If the given name + /// conflicts, it is automatically renamed. + StringRef uniqueValueName(StringRef name); + + /// This is the value ID for each SSA value. If this returns ~0, then the + /// valueID has an entry in valueNames. + DenseMap valueIDs; + DenseMap valueNames; + + /// This is a map of operations that contain multiple named result groups, + /// i.e. there may be multiple names for the results of the operation. The key + /// of this map are the result numbers that start a result group. + DenseMap> opResultGroups; + + /// This is the block ID for each block in the current. + DenseMap blockIDs; + + /// This keeps track of all of the non-numeric names that are in flight, + /// allowing us to check for duplicates. + /// Note: the value of the map is unused. + llvm::ScopedHashTable usedNames; + llvm::BumpPtrAllocator usedNameAllocator; + + /// This is the next value ID to assign in numbering. + unsigned nextValueID = 0; + /// This is the next ID to assign to a region entry block argument. + unsigned nextArgumentID = 0; + /// This is the next ID to assign when a name conflict is detected. + unsigned nextConflictID = 0; }; } // end anonymous namespace -/// Initialize the alias state to enable the printing of aliases. -void ModuleState::initializeAliases(Operation *op) { - aliasState.initialize(op, interfaces); +SSANameState::SSANameState( + Operation *op, + DialectInterfaceCollection &interfaces) { + llvm::ScopedHashTable::ScopeTy usedNamesScope(usedNames); + numberValuesInOp(*op, interfaces); + + for (auto ®ion : op->getRegions()) + numberValuesInRegion(region, interfaces); } -//===----------------------------------------------------------------------===// -// ModulePrinter -//===----------------------------------------------------------------------===// +void SSANameState::printValueID(Value value, bool printResultNo, + raw_ostream &stream) const { + if (!value) { + stream << "<>"; + return; + } -namespace { -class ModulePrinter { -public: - ModulePrinter(raw_ostream &os, OpPrintingFlags flags = llvm::None, - ModuleState *state = nullptr) - : os(os), printerFlags(flags), state(state) {} - explicit ModulePrinter(ModulePrinter &printer) - : os(printer.os), printerFlags(printer.printerFlags), - state(printer.state) {} + int resultNo = -1; + auto lookupValue = value; - /// Returns the output stream of the printer. - raw_ostream &getStream() { return os; } + // If this is a reference to the result of a multi-result operation or + // operation, print out the # identifier and make sure to map our lookup + // to the first result of the operation. + if (OpResult result = value.dyn_cast()) + getResultIDAndNumber(result, lookupValue, resultNo); - template - inline void interleaveComma(const Container &c, UnaryFunctor each_fn) const { - mlir::interleaveComma(c, os, each_fn); + auto it = valueIDs.find(lookupValue); + if (it == valueIDs.end()) { + stream << "<>"; + return; } - void print(ModuleOp module); + stream << '%'; + if (it->second != NameSentinel) { + stream << it->second; + } else { + auto nameIt = valueNames.find(lookupValue); + assert(nameIt != valueNames.end() && "Didn't have a name entry?"); + stream << nameIt->second; + } - /// Print the given attribute. If 'mayElideType' is true, some attributes are - /// printed without the type when the type matches the default used in the - /// parser (for example i64 is the default for integer attributes). - void printAttribute(Attribute attr, bool mayElideType = false); + if (resultNo != -1 && printResultNo) + stream << '#' << resultNo; +} - void printType(Type type); - void printLocation(LocationAttr loc); +ArrayRef SSANameState::getOpResultGroups(Operation *op) { + auto it = opResultGroups.find(op); + return it == opResultGroups.end() ? ArrayRef() : it->second; +} - void printAffineMap(AffineMap map); - void - printAffineExpr(AffineExpr expr, - function_ref printValueName = nullptr); - void printAffineConstraint(AffineExpr expr, bool isEq); - void printIntegerSet(IntegerSet set); +unsigned SSANameState::getBlockID(Block *block) { + auto it = blockIDs.find(block); + return it != blockIDs.end() ? it->second : NameSentinel; +} -protected: - void printOptionalAttrDict(ArrayRef attrs, - ArrayRef elidedAttrs = {}, - bool withKeyword = false); - void printTrailingLocation(Location loc); - void printLocationInternal(LocationAttr loc, bool pretty = false); - void printDenseElementsAttr(DenseElementsAttr attr); +void SSANameState::shadowRegionArgs(Region ®ion, ValueRange namesToUse) { + assert(!region.empty() && "cannot shadow arguments of an empty region"); + assert(region.front().getNumArguments() == namesToUse.size() && + "incorrect number of names passed in"); + assert(region.getParentOp()->isKnownIsolatedFromAbove() && + "only KnownIsolatedFromAbove ops can shadow names"); - void printDialectAttribute(Attribute attr); - void printDialectType(Type type); + SmallVector nameStr; + for (unsigned i = 0, e = namesToUse.size(); i != e; ++i) { + auto nameToUse = namesToUse[i]; + if (nameToUse == nullptr) + continue; + auto nameToReplace = region.front().getArgument(i); - /// This enum is used to represent the binding strength of the enclosing - /// context that an AffineExprStorage is being printed in, so we can - /// intelligently produce parens. - enum class BindingStrength { - Weak, // + and - - Strong, // All other binary operators. - }; - void printAffineExprInternal( - AffineExpr expr, BindingStrength enclosingTightness, - function_ref printValueName = nullptr); + nameStr.clear(); + llvm::raw_svector_ostream nameStream(nameStr); + printValueID(nameToUse, /*printResultNo=*/true, nameStream); - /// The output stream for the printer. - raw_ostream &os; + // Entry block arguments should already have a pretty "arg" name. + assert(valueIDs[nameToReplace] == NameSentinel); - /// A set of flags to control the printer's behavior. - OpPrintingFlags printerFlags; + // Use the name without the leading %. + auto name = StringRef(nameStream.str()).drop_front(); - /// An optional printer state for the module. - ModuleState *state; -}; -} // end anonymous namespace + // Overwrite the name. + valueNames[nameToReplace] = name.copy(usedNameAllocator); + } +} -void ModulePrinter::printTrailingLocation(Location loc) { - // Check to see if we are printing debug information. - if (!printerFlags.shouldPrintDebugInfo()) - return; +void SSANameState::numberValuesInRegion( + Region ®ion, + DialectInterfaceCollection &interfaces) { + // Save the current value ids to allow for numbering values in sibling regions + // the same. + unsigned curValueID = nextValueID; + unsigned curArgumentID = nextArgumentID; + unsigned curConflictID = nextConflictID; - os << " "; - printLocation(loc); + // Push a new used names scope. + llvm::ScopedHashTable::ScopeTy usedNamesScope(usedNames); + + // Number the values within this region in a breadth-first order. + unsigned nextBlockID = 0; + for (auto &block : region) { + // Each block gets a unique ID, and all of the operations within it get + // numbered as well. + blockIDs[&block] = nextBlockID++; + numberValuesInBlock(block, interfaces); + } + + // After that we traverse the nested regions. + // TODO: Rework this loop to not use recursion. + for (auto &block : region) { + for (auto &op : block) + for (auto &nestedRegion : op.getRegions()) + numberValuesInRegion(nestedRegion, interfaces); + } + + // Restore the original value ids. + nextValueID = curValueID; + nextArgumentID = curArgumentID; + nextConflictID = curConflictID; } -void ModulePrinter::printLocationInternal(LocationAttr loc, bool pretty) { - switch (loc.getKind()) { - case StandardAttributes::OpaqueLocation: - printLocationInternal(loc.cast().getFallbackLocation(), pretty); - break; - case StandardAttributes::UnknownLocation: - if (pretty) - os << "[unknown]"; - else - os << "unknown"; - break; - case StandardAttributes::FileLineColLocation: { - auto fileLoc = loc.cast(); - auto mayQuote = pretty ? "" : "\""; - os << mayQuote << fileLoc.getFilename() << mayQuote << ':' - << fileLoc.getLine() << ':' << fileLoc.getColumn(); - break; +void SSANameState::numberValuesInBlock( + Block &block, + DialectInterfaceCollection &interfaces) { + auto setArgNameFn = [&](Value arg, StringRef name) { + assert(!valueIDs.count(arg) && "arg numbered multiple times"); + assert(arg.cast()->getOwner() == &block && + "arg not defined in 'block'"); + setValueName(arg, name); + }; + + bool isEntryBlock = block.isEntryBlock(); + if (isEntryBlock) { + if (auto *op = block.getParentOp()) { + if (auto asmInterface = interfaces.getInterfaceFor(op->getDialect())) + asmInterface->getAsmBlockArgumentNames(&block, setArgNameFn); + } } - case StandardAttributes::NameLocation: { - auto nameLoc = loc.cast(); - os << '\"' << nameLoc.getName() << '\"'; - // Print the child if it isn't unknown. - auto childLoc = nameLoc.getChildLoc(); - if (!childLoc.isa()) { - os << '('; - printLocationInternal(childLoc, pretty); - os << ')'; + // Number the block arguments. We give entry block arguments a special name + // 'arg'. + SmallString<32> specialNameBuffer(isEntryBlock ? "arg" : ""); + llvm::raw_svector_ostream specialName(specialNameBuffer); + for (auto arg : block.getArguments()) { + if (valueIDs.count(arg)) + continue; + if (isEntryBlock) { + specialNameBuffer.resize(strlen("arg")); + specialName << nextArgumentID++; } - break; + setValueName(arg, specialName.str()); } - case StandardAttributes::CallSiteLocation: { - auto callLocation = loc.cast(); - auto caller = callLocation.getCaller(); - auto callee = callLocation.getCallee(); - if (!pretty) - os << "callsite("; - printLocationInternal(callee, pretty); - if (pretty) { + + // Number the operations in this block. + for (auto &op : block) + numberValuesInOp(op, interfaces); +} + +void SSANameState::numberValuesInOp( + Operation &op, + DialectInterfaceCollection &interfaces) { + unsigned numResults = op.getNumResults(); + if (numResults == 0) + return; + Value resultBegin = op.getResult(0); + + // Function used to set the special result names for the operation. + SmallVector resultGroups(/*Size=*/1, /*Value=*/0); + auto setResultNameFn = [&](Value result, StringRef name) { + assert(!valueIDs.count(result) && "result numbered multiple times"); + assert(result->getDefiningOp() == &op && "result not defined by 'op'"); + setValueName(result, name); + + // Record the result number for groups not anchored at 0. + if (int resultNo = result.cast()->getResultNumber()) + resultGroups.push_back(resultNo); + }; + if (OpAsmOpInterface asmInterface = dyn_cast(&op)) + asmInterface.getAsmResultNames(setResultNameFn); + else if (auto *asmInterface = interfaces.getInterfaceFor(op.getDialect())) + asmInterface->getAsmResultNames(&op, setResultNameFn); + + // If the first result wasn't numbered, give it a default number. + if (valueIDs.try_emplace(resultBegin, nextValueID).second) + ++nextValueID; + + // If this operation has multiple result groups, mark it. + if (resultGroups.size() != 1) { + llvm::array_pod_sort(resultGroups.begin(), resultGroups.end()); + opResultGroups.try_emplace(&op, std::move(resultGroups)); + } +} + +void SSANameState::getResultIDAndNumber(OpResult result, Value &lookupValue, + int &lookupResultNo) const { + Operation *owner = result->getOwner(); + if (owner->getNumResults() == 1) + return; + int resultNo = result->getResultNumber(); + + // If this operation has multiple result groups, we will need to find the + // one corresponding to this result. + auto resultGroupIt = opResultGroups.find(owner); + if (resultGroupIt == opResultGroups.end()) { + // If not, just use the first result. + lookupResultNo = resultNo; + lookupValue = owner->getResult(0); + return; + } + + // Find the correct index using a binary search, as the groups are ordered. + ArrayRef resultGroups = resultGroupIt->second; + auto it = llvm::upper_bound(resultGroups, resultNo); + int groupResultNo = 0, groupSize = 0; + + // If there are no smaller elements, the last result group is the lookup. + if (it == resultGroups.end()) { + groupResultNo = resultGroups.back(); + groupSize = static_cast(owner->getNumResults()) - resultGroups.back(); + } else { + // Otherwise, the previous element is the lookup. + groupResultNo = *std::prev(it); + groupSize = *it - groupResultNo; + } + + // We only record the result number for a group of size greater than 1. + if (groupSize != 1) + lookupResultNo = resultNo - groupResultNo; + lookupValue = owner->getResult(groupResultNo); +} + +void SSANameState::setValueName(Value value, StringRef name) { + // If the name is empty, the value uses the default numbering. + if (name.empty()) { + valueIDs[value] = nextValueID++; + return; + } + + valueIDs[value] = NameSentinel; + valueNames[value] = uniqueValueName(name); +} + +StringRef SSANameState::uniqueValueName(StringRef name) { + // Check to see if this name is already unique. + if (!usedNames.count(name)) { + name = name.copy(usedNameAllocator); + } else { + // Otherwise, we had a conflict - probe until we find a unique name. This + // is guaranteed to terminate (and usually in a single iteration) because it + // generates new names by incrementing nextConflictID. + SmallString<64> probeName(name); + probeName.push_back('_'); + while (true) { + probeName.resize(name.size() + 1); + probeName += llvm::utostr(nextConflictID++); + if (!usedNames.count(probeName)) { + name = StringRef(probeName).copy(usedNameAllocator); + break; + } + } + } + + usedNames.insert(name, char()); + return name; +} + +//===----------------------------------------------------------------------===// +// ModuleState +//===----------------------------------------------------------------------===// + +namespace { +class ModuleState { +public: + explicit ModuleState(Operation *op) + : interfaces(op->getContext()), nameState(op, interfaces) {} + + /// Initialize the alias state to enable the printing of aliases. + void initializeAliases(Operation *op) { + aliasState.initialize(op, interfaces); + } + + /// Get an instance of the OpAsmDialectInterface for the given dialect, or + /// null if one wasn't registered. + const OpAsmDialectInterface *getOpAsmInterface(Dialect *dialect) { + return interfaces.getInterfaceFor(dialect); + } + + /// Get the state used for aliases. + AliasState &getAliasState() { return aliasState; } + + /// Get the state used for SSA names. + SSANameState &getSSANameState() { return nameState; } + +private: + /// Collection of OpAsm interfaces implemented in the context. + DialectInterfaceCollection interfaces; + + /// The state used for attribute and type aliases. + AliasState aliasState; + + /// The state used for SSA value names. + SSANameState nameState; +}; +} // end anonymous namespace + +//===----------------------------------------------------------------------===// +// ModulePrinter +//===----------------------------------------------------------------------===// + +namespace { +class ModulePrinter { +public: + ModulePrinter(raw_ostream &os, OpPrintingFlags flags = llvm::None, + ModuleState *state = nullptr) + : os(os), printerFlags(flags), state(state) {} + explicit ModulePrinter(ModulePrinter &printer) + : os(printer.os), printerFlags(printer.printerFlags), + state(printer.state) {} + + /// Returns the output stream of the printer. + raw_ostream &getStream() { return os; } + + template + inline void interleaveComma(const Container &c, UnaryFunctor each_fn) const { + mlir::interleaveComma(c, os, each_fn); + } + + void print(ModuleOp module); + + /// Print the given attribute. If 'mayElideType' is true, some attributes are + /// printed without the type when the type matches the default used in the + /// parser (for example i64 is the default for integer attributes). + void printAttribute(Attribute attr, bool mayElideType = false); + + void printType(Type type); + void printLocation(LocationAttr loc); + + void printAffineMap(AffineMap map); + void + printAffineExpr(AffineExpr expr, + function_ref printValueName = nullptr); + void printAffineConstraint(AffineExpr expr, bool isEq); + void printIntegerSet(IntegerSet set); + +protected: + void printOptionalAttrDict(ArrayRef attrs, + ArrayRef elidedAttrs = {}, + bool withKeyword = false); + void printTrailingLocation(Location loc); + void printLocationInternal(LocationAttr loc, bool pretty = false); + void printDenseElementsAttr(DenseElementsAttr attr); + + void printDialectAttribute(Attribute attr); + void printDialectType(Type type); + + /// This enum is used to represent the binding strength of the enclosing + /// context that an AffineExprStorage is being printed in, so we can + /// intelligently produce parens. + enum class BindingStrength { + Weak, // + and - + Strong, // All other binary operators. + }; + void printAffineExprInternal( + AffineExpr expr, BindingStrength enclosingTightness, + function_ref printValueName = nullptr); + + /// The output stream for the printer. + raw_ostream &os; + + /// A set of flags to control the printer's behavior. + OpPrintingFlags printerFlags; + + /// An optional printer state for the module. + ModuleState *state; +}; +} // end anonymous namespace + +void ModulePrinter::printTrailingLocation(Location loc) { + // Check to see if we are printing debug information. + if (!printerFlags.shouldPrintDebugInfo()) + return; + + os << " "; + printLocation(loc); +} + +void ModulePrinter::printLocationInternal(LocationAttr loc, bool pretty) { + switch (loc.getKind()) { + case StandardAttributes::OpaqueLocation: + printLocationInternal(loc.cast().getFallbackLocation(), pretty); + break; + case StandardAttributes::UnknownLocation: + if (pretty) + os << "[unknown]"; + else + os << "unknown"; + break; + case StandardAttributes::FileLineColLocation: { + auto fileLoc = loc.cast(); + auto mayQuote = pretty ? "" : "\""; + os << mayQuote << fileLoc.getFilename() << mayQuote << ':' + << fileLoc.getLine() << ':' << fileLoc.getColumn(); + break; + } + case StandardAttributes::NameLocation: { + auto nameLoc = loc.cast(); + os << '\"' << nameLoc.getName() << '\"'; + + // Print the child if it isn't unknown. + auto childLoc = nameLoc.getChildLoc(); + if (!childLoc.isa()) { + os << '('; + printLocationInternal(childLoc, pretty); + os << ')'; + } + break; + } + case StandardAttributes::CallSiteLocation: { + auto callLocation = loc.cast(); + auto caller = callLocation.getCaller(); + auto callee = callLocation.getCallee(); + if (!pretty) + os << "callsite("; + printLocationInternal(callee, pretty); + if (pretty) { if (callee.isa()) { if (caller.isa()) { os << " at "; @@ -1152,8 +1503,44 @@ } } -//===----------------------------------------------------------------------===// -// CustomDialectAsmPrinter +void ModulePrinter::printOptionalAttrDict(ArrayRef attrs, + ArrayRef elidedAttrs, + bool withKeyword) { + // If there are no attributes, then there is nothing to be done. + if (attrs.empty()) + return; + + // Filter out any attributes that shouldn't be included. + SmallVector filteredAttrs( + llvm::make_filter_range(attrs, [&](NamedAttribute attr) { + return !llvm::is_contained(elidedAttrs, attr.first.strref()); + })); + + // If there are no attributes left to print after filtering, then we're done. + if (filteredAttrs.empty()) + return; + + // Print the 'attributes' keyword if necessary. + if (withKeyword) + os << " attributes"; + + // Otherwise, print them all out in braces. + os << " {"; + interleaveComma(filteredAttrs, [&](NamedAttribute attr) { + os << attr.first; + + // Pretty printing elides the attribute value for unit attributes. + if (attr.second.isa()) + return; + + os << " = "; + printAttribute(attr.second); + }); + os << '}'; +} + +//===----------------------------------------------------------------------===// +// CustomDialectAsmPrinter //===----------------------------------------------------------------------===// namespace { @@ -1347,593 +1734,175 @@ printAffineExprInternal(lhsExpr, BindingStrength::Weak, printValueName); os << " + "; - printAffineExprInternal(rhsExpr, BindingStrength::Weak, printValueName); - - if (enclosingTightness == BindingStrength::Strong) - os << ')'; -} - -void ModulePrinter::printAffineConstraint(AffineExpr expr, bool isEq) { - printAffineExprInternal(expr, BindingStrength::Weak); - isEq ? os << " == 0" : os << " >= 0"; -} - -void ModulePrinter::printAffineMap(AffineMap map) { - // Dimension identifiers. - os << '('; - for (int i = 0; i < (int)map.getNumDims() - 1; ++i) - os << 'd' << i << ", "; - if (map.getNumDims() >= 1) - os << 'd' << map.getNumDims() - 1; - os << ')'; - - // Symbolic identifiers. - if (map.getNumSymbols() != 0) { - os << '['; - for (unsigned i = 0; i < map.getNumSymbols() - 1; ++i) - os << 's' << i << ", "; - if (map.getNumSymbols() >= 1) - os << 's' << map.getNumSymbols() - 1; - os << ']'; - } - - // Result affine expressions. - os << " -> ("; - interleaveComma(map.getResults(), - [&](AffineExpr expr) { printAffineExpr(expr); }); - os << ')'; -} - -void ModulePrinter::printIntegerSet(IntegerSet set) { - // Dimension identifiers. - os << '('; - for (unsigned i = 1; i < set.getNumDims(); ++i) - os << 'd' << i - 1 << ", "; - if (set.getNumDims() >= 1) - os << 'd' << set.getNumDims() - 1; - os << ')'; - - // Symbolic identifiers. - if (set.getNumSymbols() != 0) { - os << '['; - for (unsigned i = 0; i < set.getNumSymbols() - 1; ++i) - os << 's' << i << ", "; - if (set.getNumSymbols() >= 1) - os << 's' << set.getNumSymbols() - 1; - os << ']'; - } - - // Print constraints. - os << " : ("; - int numConstraints = set.getNumConstraints(); - for (int i = 1; i < numConstraints; ++i) { - printAffineConstraint(set.getConstraint(i - 1), set.isEq(i - 1)); - os << ", "; - } - if (numConstraints >= 1) - printAffineConstraint(set.getConstraint(numConstraints - 1), - set.isEq(numConstraints - 1)); - os << ')'; -} - -//===----------------------------------------------------------------------===// -// Operation printing -//===----------------------------------------------------------------------===// - -void ModulePrinter::printOptionalAttrDict(ArrayRef attrs, - ArrayRef elidedAttrs, - bool withKeyword) { - // If there are no attributes, then there is nothing to be done. - if (attrs.empty()) - return; - - // Filter out any attributes that shouldn't be included. - SmallVector filteredAttrs( - llvm::make_filter_range(attrs, [&](NamedAttribute attr) { - return !llvm::is_contained(elidedAttrs, attr.first.strref()); - })); - - // If there are no attributes left to print after filtering, then we're done. - if (filteredAttrs.empty()) - return; - - // Print the 'attributes' keyword if necessary. - if (withKeyword) - os << " attributes"; - - // Otherwise, print them all out in braces. - os << " {"; - interleaveComma(filteredAttrs, [&](NamedAttribute attr) { - os << attr.first; - - // Pretty printing elides the attribute value for unit attributes. - if (attr.second.isa()) - return; - - os << " = "; - printAttribute(attr.second); - }); - os << '}'; -} - -namespace { - -// OperationPrinter contains common functionality for printing operations. -class OperationPrinter : public ModulePrinter, private OpAsmPrinter { -public: - OperationPrinter(Operation *op, ModulePrinter &other); - OperationPrinter(Region *region, ModulePrinter &other); - - // Methods to print operations. - void print(Operation *op); - void print(Block *block, bool printBlockArgs = true, - bool printBlockTerminator = true); - - void printOperation(Operation *op); - void printGenericOp(Operation *op) override; - - // Implement OpAsmPrinter. - raw_ostream &getStream() const override { return os; } - void printType(Type type) override { ModulePrinter::printType(type); } - void printAttribute(Attribute attr) override { - ModulePrinter::printAttribute(attr); - } - void printOperand(Value value) override { printValueID(value); } - - void printOptionalAttrDict(ArrayRef attrs, - ArrayRef elidedAttrs = {}) override { - ModulePrinter::printOptionalAttrDict(attrs, elidedAttrs); - } - void printOptionalAttrDictWithKeyword( - ArrayRef attrs, - ArrayRef elidedAttrs = {}) override { - ModulePrinter::printOptionalAttrDict(attrs, elidedAttrs, - /*withKeyword=*/true); - } - - enum { nameSentinel = ~0U }; - - void printBlockName(Block *block) { - auto id = getBlockID(block); - if (id != ~0U) - os << "^bb" << id; - else - os << "^INVALIDBLOCK"; - } - - unsigned getBlockID(Block *block) { - auto it = blockIDs.find(block); - return it != blockIDs.end() ? it->second : ~0U; - } - - void printSuccessorAndUseList(Operation *term, unsigned index) override; - - /// Print a region. - void printRegion(Region &blocks, bool printEntryBlockArgs, - bool printBlockTerminators) override { - os << " {\n"; - if (!blocks.empty()) { - auto *entryBlock = &blocks.front(); - print(entryBlock, - printEntryBlockArgs && entryBlock->getNumArguments() != 0, - printBlockTerminators); - for (auto &b : llvm::drop_begin(blocks.getBlocks(), 1)) - print(&b); - } - os.indent(currentIndent) << "}"; - } - - /// Renumber the arguments for the specified region to the same names as the - /// SSA values in namesToUse. This may only be used for IsolatedFromAbove - /// operations. If any entry in namesToUse is null, the corresponding - /// argument name is left alone. - void shadowRegionArgs(Region ®ion, ValueRange namesToUse) override; - - void printAffineMapOfSSAIds(AffineMapAttr mapAttr, - ValueRange operands) override { - AffineMap map = mapAttr.getValue(); - unsigned numDims = map.getNumDims(); - auto printValueName = [&](unsigned pos, bool isSymbol) { - unsigned index = isSymbol ? numDims + pos : pos; - assert(index < operands.size()); - if (isSymbol) - os << "symbol("; - printValueID(operands[index]); - if (isSymbol) - os << ')'; - }; - - interleaveComma(map.getResults(), [&](AffineExpr expr) { - printAffineExpr(expr, printValueName); - }); - } - - /// Print the given string as a symbol reference. - void printSymbolName(StringRef symbolRef) override { - ::printSymbolReference(symbolRef, os); - } - - // Number of spaces used for indenting nested operations. - const static unsigned indentWidth = 2; - -protected: - void numberValuesInRegion(Region ®ion); - void numberValuesInBlock(Block &block); - void numberValuesInOp(Operation &op); - void printValueID(Value value, bool printResultNo = true) const { - printValueIDImpl(value, printResultNo, os); - } - -private: - /// Given a result of an operation 'result', find the result group head - /// 'lookupValue' and the result of 'result' within that group in - /// 'lookupResultNo'. 'lookupResultNo' is only filled in if the result group - /// has more than 1 result. - void getResultIDAndNumber(OpResult result, Value &lookupValue, - int &lookupResultNo) const; - void printValueIDImpl(Value value, bool printResultNo, - raw_ostream &stream) const; - - /// Set a special value name for the given value. - void setValueName(Value value, StringRef name); - - /// Uniques the given value name within the printer. If the given name - /// conflicts, it is automatically renamed. - StringRef uniqueValueName(StringRef name); - - /// This is the value ID for each SSA value. If this returns ~0, then the - /// valueID has an entry in valueNames. - DenseMap valueIDs; - DenseMap valueNames; - - /// This is a map of operations that contain multiple named result groups, - /// i.e. there may be multiple names for the results of the operation. The key - /// of this map are the result numbers that start a result group. - DenseMap> opResultGroups; - - /// This is the block ID for each block in the current. - DenseMap blockIDs; - - /// This keeps track of all of the non-numeric names that are in flight, - /// allowing us to check for duplicates. - /// Note: the value of the map is unused. - llvm::ScopedHashTable usedNames; - llvm::BumpPtrAllocator usedNameAllocator; - - // This is the current indentation level for nested structures. - unsigned currentIndent = 0; - - /// This is the next value ID to assign in numbering. - unsigned nextValueID = 0; - /// This is the next ID to assign to a region entry block argument. - unsigned nextArgumentID = 0; - /// This is the next ID to assign when a name conflict is detected. - unsigned nextConflictID = 0; -}; -} // end anonymous namespace - -OperationPrinter::OperationPrinter(Operation *op, ModulePrinter &other) - : ModulePrinter(other) { - llvm::ScopedHashTable::ScopeTy usedNamesScope(usedNames); - numberValuesInOp(*op); - - for (auto ®ion : op->getRegions()) - numberValuesInRegion(region); -} - -OperationPrinter::OperationPrinter(Region *region, ModulePrinter &other) - : ModulePrinter(other) { - numberValuesInRegion(*region); -} - -void OperationPrinter::numberValuesInRegion(Region ®ion) { - // Save the current value ids to allow for numbering values in sibling regions - // the same. - unsigned curValueID = nextValueID; - unsigned curArgumentID = nextArgumentID; - unsigned curConflictID = nextConflictID; - - // Push a new used names scope. - llvm::ScopedHashTable::ScopeTy usedNamesScope(usedNames); - - // Number the values within this region in a breadth-first order. - unsigned nextBlockID = 0; - for (auto &block : region) { - // Each block gets a unique ID, and all of the operations within it get - // numbered as well. - blockIDs[&block] = nextBlockID++; - numberValuesInBlock(block); - } - - // After that we traverse the nested regions. - // TODO: Rework this loop to not use recursion. - for (auto &block : region) { - for (auto &op : block) - for (auto &nestedRegion : op.getRegions()) - numberValuesInRegion(nestedRegion); - } - - // Restore the original value ids. - nextValueID = curValueID; - nextArgumentID = curArgumentID; - nextConflictID = curConflictID; -} - -void OperationPrinter::numberValuesInBlock(Block &block) { - auto setArgNameFn = [&](Value arg, StringRef name) { - assert(!valueIDs.count(arg) && "arg numbered multiple times"); - assert(arg.cast()->getOwner() == &block && - "arg not defined in 'block'"); - setValueName(arg, name); - }; - - bool isEntryBlock = block.isEntryBlock(); - if (isEntryBlock && state) { - if (auto *op = block.getParentOp()) { - if (auto dialectAsmInterface = state->getOpAsmInterface(op->getDialect())) - dialectAsmInterface->getAsmBlockArgumentNames(&block, setArgNameFn); - } - } - - // Number the block arguments. We give entry block arguments a special name - // 'arg'. - SmallString<32> specialNameBuffer(isEntryBlock ? "arg" : ""); - llvm::raw_svector_ostream specialName(specialNameBuffer); - for (auto arg : block.getArguments()) { - if (valueIDs.count(arg)) - continue; - if (isEntryBlock) { - specialNameBuffer.resize(strlen("arg")); - specialName << nextArgumentID++; - } - setValueName(arg, specialName.str()); - } - - // Number the operations in this block. - for (auto &op : block) - numberValuesInOp(op); -} - -void OperationPrinter::numberValuesInOp(Operation &op) { - unsigned numResults = op.getNumResults(); - if (numResults == 0) - return; - Value resultBegin = op.getResult(0); - - // Function used to set the special result names for the operation. - SmallVector resultGroups(/*Size=*/1, /*Value=*/0); - auto setResultNameFn = [&](Value result, StringRef name) { - assert(!valueIDs.count(result) && "result numbered multiple times"); - assert(result->getDefiningOp() == &op && "result not defined by 'op'"); - setValueName(result, name); - - // Record the result number for groups not anchored at 0. - if (int resultNo = result.cast()->getResultNumber()) - resultGroups.push_back(resultNo); - }; - - if (OpAsmOpInterface asmInterface = dyn_cast(&op)) { - asmInterface.getAsmResultNames(setResultNameFn); - } else if (auto *dialectAsmInterface = - state ? state->getOpAsmInterface(op.getDialect()) : nullptr) { - dialectAsmInterface->getAsmResultNames(&op, setResultNameFn); - } - - // If the first result wasn't numbered, give it a default number. - if (valueIDs.try_emplace(resultBegin, nextValueID).second) - ++nextValueID; - - // If this operation has multiple result groups, mark it. - if (resultGroups.size() != 1) { - llvm::array_pod_sort(resultGroups.begin(), resultGroups.end()); - opResultGroups.try_emplace(&op, std::move(resultGroups)); - } -} - -/// Set a special value name for the given value. -void OperationPrinter::setValueName(Value value, StringRef name) { - // If the name is empty, the value uses the default numbering. - if (name.empty()) { - valueIDs[value] = nextValueID++; - return; - } - - valueIDs[value] = nameSentinel; - valueNames[value] = uniqueValueName(name); -} - -/// Uniques the given value name within the printer. If the given name -/// conflicts, it is automatically renamed. -StringRef OperationPrinter::uniqueValueName(StringRef name) { - // Check to see if this name is already unique. - if (!usedNames.count(name)) { - name = name.copy(usedNameAllocator); - } else { - // Otherwise, we had a conflict - probe until we find a unique name. This - // is guaranteed to terminate (and usually in a single iteration) because it - // generates new names by incrementing nextConflictID. - SmallString<64> probeName(name); - probeName.push_back('_'); - while (true) { - probeName.resize(name.size() + 1); - probeName += llvm::utostr(nextConflictID++); - if (!usedNames.count(probeName)) { - name = StringRef(probeName).copy(usedNameAllocator); - break; - } - } - } - - usedNames.insert(name, char()); - return name; -} - -void OperationPrinter::print(Block *block, bool printBlockArgs, - bool printBlockTerminator) { - // Print the block label and argument list if requested. - if (printBlockArgs) { - os.indent(currentIndent); - printBlockName(block); - - // Print the argument list if non-empty. - if (!block->args_empty()) { - os << '('; - interleaveComma(block->getArguments(), [&](BlockArgument arg) { - printValueID(arg); - os << ": "; - printType(arg->getType()); - }); - os << ')'; - } - os << ':'; - - // Print out some context information about the predecessors of this block. - if (!block->getParent()) { - os << "\t// block is not in a region!"; - } else if (block->hasNoPredecessors()) { - os << "\t// no predecessors"; - } else if (auto *pred = block->getSinglePredecessor()) { - os << "\t// pred: "; - printBlockName(pred); - } else { - // We want to print the predecessors in increasing numeric order, not in - // whatever order the use-list is in, so gather and sort them. - SmallVector, 4> predIDs; - for (auto *pred : block->getPredecessors()) - predIDs.push_back({getBlockID(pred), pred}); - llvm::array_pod_sort(predIDs.begin(), predIDs.end()); - - os << "\t// " << predIDs.size() << " preds: "; - - interleaveComma(predIDs, [&](std::pair pred) { - printBlockName(pred.second); - }); - } - os << '\n'; - } - - currentIndent += indentWidth; - auto range = llvm::make_range( - block->getOperations().begin(), - std::prev(block->getOperations().end(), printBlockTerminator ? 0 : 1)); - for (auto &op : range) { - print(&op); - os << '\n'; - } - currentIndent -= indentWidth; + printAffineExprInternal(rhsExpr, BindingStrength::Weak, printValueName); + + if (enclosingTightness == BindingStrength::Strong) + os << ')'; } -void OperationPrinter::print(Operation *op) { - os.indent(currentIndent); - printOperation(op); - printTrailingLocation(op->getLoc()); +void ModulePrinter::printAffineConstraint(AffineExpr expr, bool isEq) { + printAffineExprInternal(expr, BindingStrength::Weak); + isEq ? os << " == 0" : os << " >= 0"; } -void OperationPrinter::getResultIDAndNumber(OpResult result, Value &lookupValue, - int &lookupResultNo) const { - Operation *owner = result->getOwner(); - if (owner->getNumResults() == 1) - return; - int resultNo = result->getResultNumber(); +void ModulePrinter::printAffineMap(AffineMap map) { + // Dimension identifiers. + os << '('; + for (int i = 0; i < (int)map.getNumDims() - 1; ++i) + os << 'd' << i << ", "; + if (map.getNumDims() >= 1) + os << 'd' << map.getNumDims() - 1; + os << ')'; - // If this operation has multiple result groups, we will need to find the - // one corresponding to this result. - auto resultGroupIt = opResultGroups.find(owner); - if (resultGroupIt == opResultGroups.end()) { - // If not, just use the first result. - lookupResultNo = resultNo; - lookupValue = owner->getResult(0); - return; + // Symbolic identifiers. + if (map.getNumSymbols() != 0) { + os << '['; + for (unsigned i = 0; i < map.getNumSymbols() - 1; ++i) + os << 's' << i << ", "; + if (map.getNumSymbols() >= 1) + os << 's' << map.getNumSymbols() - 1; + os << ']'; } - // Find the correct index using a binary search, as the groups are ordered. - ArrayRef resultGroups = resultGroupIt->second; - auto it = llvm::upper_bound(resultGroups, resultNo); - int groupResultNo = 0, groupSize = 0; + // Result affine expressions. + os << " -> ("; + interleaveComma(map.getResults(), + [&](AffineExpr expr) { printAffineExpr(expr); }); + os << ')'; +} - // If there are no smaller elements, the last result group is the lookup. - if (it == resultGroups.end()) { - groupResultNo = resultGroups.back(); - groupSize = static_cast(owner->getNumResults()) - resultGroups.back(); - } else { - // Otherwise, the previous element is the lookup. - groupResultNo = *std::prev(it); - groupSize = *it - groupResultNo; +void ModulePrinter::printIntegerSet(IntegerSet set) { + // Dimension identifiers. + os << '('; + for (unsigned i = 1; i < set.getNumDims(); ++i) + os << 'd' << i - 1 << ", "; + if (set.getNumDims() >= 1) + os << 'd' << set.getNumDims() - 1; + os << ')'; + + // Symbolic identifiers. + if (set.getNumSymbols() != 0) { + os << '['; + for (unsigned i = 0; i < set.getNumSymbols() - 1; ++i) + os << 's' << i << ", "; + if (set.getNumSymbols() >= 1) + os << 's' << set.getNumSymbols() - 1; + os << ']'; } - // We only record the result number for a group of size greater than 1. - if (groupSize != 1) - lookupResultNo = resultNo - groupResultNo; - lookupValue = owner->getResult(groupResultNo); + // Print constraints. + os << " : ("; + int numConstraints = set.getNumConstraints(); + for (int i = 1; i < numConstraints; ++i) { + printAffineConstraint(set.getConstraint(i - 1), set.isEq(i - 1)); + os << ", "; + } + if (numConstraints >= 1) + printAffineConstraint(set.getConstraint(numConstraints - 1), + set.isEq(numConstraints - 1)); + os << ')'; } -void OperationPrinter::printValueIDImpl(Value value, bool printResultNo, - raw_ostream &stream) const { - if (!value) { - stream << "<>"; - return; +//===----------------------------------------------------------------------===// +// OperationPrinter +//===----------------------------------------------------------------------===// + +namespace { +/// This class contains the logic for printing operations, regions, and blocks. +class OperationPrinter : public ModulePrinter, private OpAsmPrinter { +public: + explicit OperationPrinter(ModulePrinter &other) : ModulePrinter(other) { + assert(state && "expected valid state when printing operation"); } - int resultNo = -1; - auto lookupValue = value; + /// Print the given operation with its indent and location. + void print(Operation *op); + /// Print the bare location, not including indentation/location/etc. + void printOperation(Operation *op); + /// Print the given operation in the generic form. + void printGenericOp(Operation *op) override; - // If this is a reference to the result of a multi-result operation or - // operation, print out the # identifier and make sure to map our lookup - // to the first result of the operation. - if (OpResult result = value.dyn_cast()) - getResultIDAndNumber(result, lookupValue, resultNo); + /// Print the name of the given block. + void printBlockName(Block *block); - auto it = valueIDs.find(lookupValue); - if (it == valueIDs.end()) { - stream << "<>"; - return; + /// Print the given block. If 'printBlockArgs' is false, the arguments of the + /// block are not printed. If 'printBlockTerminator' is false, the terminator + /// operation of the block is not printed. + void print(Block *block, bool printBlockArgs = true, + bool printBlockTerminator = true); + + /// Print the ID of the given value, optionally with its result number. + void printValueID(Value value, bool printResultNo = true) const; + + //===--------------------------------------------------------------------===// + // OpAsmPrinter methods + //===--------------------------------------------------------------------===// + + /// Return the current stream of the printer. + raw_ostream &getStream() const override { return os; } + + /// Print the given type. + void printType(Type type) override { ModulePrinter::printType(type); } + + /// Print the given attribute. + void printAttribute(Attribute attr) override { + ModulePrinter::printAttribute(attr); } - stream << '%'; - if (it->second != nameSentinel) { - stream << it->second; - } else { - auto nameIt = valueNames.find(lookupValue); - assert(nameIt != valueNames.end() && "Didn't have a name entry?"); - stream << nameIt->second; + /// Print the ID for the given value. + void printOperand(Value value) override { printValueID(value); } + + /// Print an optional attribute dictionary with a given set of elided values. + void printOptionalAttrDict(ArrayRef attrs, + ArrayRef elidedAttrs = {}) override { + ModulePrinter::printOptionalAttrDict(attrs, elidedAttrs); + } + void printOptionalAttrDictWithKeyword( + ArrayRef attrs, + ArrayRef elidedAttrs = {}) override { + ModulePrinter::printOptionalAttrDict(attrs, elidedAttrs, + /*withKeyword=*/true); } - if (resultNo != -1 && printResultNo) - stream << '#' << resultNo; -} + /// Print an operation successor with the operands used for the block + /// arguments. + void printSuccessorAndUseList(Operation *term, unsigned index) override; -/// Renumber the arguments for the specified region to the same names as the -/// SSA values in namesToUse. This may only be used for IsolatedFromAbove -/// operations. If any entry in namesToUse is null, the corresponding -/// argument name is left alone. -void OperationPrinter::shadowRegionArgs(Region ®ion, ValueRange namesToUse) { - assert(!region.empty() && "cannot shadow arguments of an empty region"); - assert(region.front().getNumArguments() == namesToUse.size() && - "incorrect number of names passed in"); - assert(region.getParentOp()->isKnownIsolatedFromAbove() && - "only KnownIsolatedFromAbove ops can shadow names"); + /// Print the given region. + void printRegion(Region ®ion, bool printEntryBlockArgs, + bool printBlockTerminators) override; - SmallVector nameStr; - for (unsigned i = 0, e = namesToUse.size(); i != e; ++i) { - auto nameToUse = namesToUse[i]; - if (nameToUse == nullptr) - continue; + /// Renumber the arguments for the specified region to the same names as the + /// SSA values in namesToUse. This may only be used for IsolatedFromAbove + /// operations. If any entry in namesToUse is null, the corresponding + /// argument name is left alone. + void shadowRegionArgs(Region ®ion, ValueRange namesToUse) override { + state->getSSANameState().shadowRegionArgs(region, namesToUse); + } - auto nameToReplace = region.front().getArgument(i); + /// Print the given affine map with the smybol and dimension operands printed + /// inline with the map. + void printAffineMapOfSSAIds(AffineMapAttr mapAttr, + ValueRange operands) override; - nameStr.clear(); - llvm::raw_svector_ostream nameStream(nameStr); - printValueIDImpl(nameToUse, /*printResultNo=*/true, nameStream); + /// Print the given string as a symbol reference. + void printSymbolName(StringRef symbolRef) override { + ::printSymbolReference(symbolRef, os); + } - // Entry block arguments should already have a pretty "arg" name. - assert(valueIDs[nameToReplace] == nameSentinel); +private: + /// The number of spaces used for indenting nested operations. + const static unsigned indentWidth = 2; - // Use the name without the leading %. - auto name = StringRef(nameStream.str()).drop_front(); + // This is the current indentation level for nested structures. + unsigned currentIndent = 0; +}; +} // end anonymous namespace - // Overwrite the name. - valueNames[nameToReplace] = name.copy(usedNameAllocator); - } +void OperationPrinter::print(Operation *op) { + os.indent(currentIndent); + printOperation(op); + printTrailingLocation(op->getLoc()); } void OperationPrinter::printOperation(Operation *op) { @@ -1945,9 +1914,8 @@ }; // Check to see if this operation has multiple result groups. - auto resultGroupIt = opResultGroups.find(op); - if (resultGroupIt != opResultGroups.end()) { - ArrayRef resultGroups = resultGroupIt->second; + ArrayRef resultGroups = state->getSSANameState().getOpResultGroups(op); + if (!resultGroups.empty()) { // Interleave the groups excluding the last one, this one will be handled // separately. interleaveComma(llvm::seq(0, resultGroups.size() - 1), [&](int i) { @@ -1991,21 +1959,16 @@ for (unsigned i = 0; i < numSuccessors; ++i) totalNumSuccessorOperands += op->getNumSuccessorOperands(i); unsigned numProperOperands = op->getNumOperands() - totalNumSuccessorOperands; - SmallVector properOperands( - op->operand_begin(), std::next(op->operand_begin(), numProperOperands)); - - interleaveComma(properOperands, [&](Value value) { printValueID(value); }); + interleaveComma(op->getOperands().take_front(numProperOperands), + [&](Value value) { printValueID(value); }); os << ')'; // For terminators, print the list of successors and their operands. if (numSuccessors != 0) { os << '['; - for (unsigned i = 0; i < numSuccessors; ++i) { - if (i != 0) - os << ", "; - printSuccessorAndUseList(op, i); - } + interleaveComma(llvm::seq(0, numSuccessors), + [&](unsigned i) { printSuccessorAndUseList(op, i); }); os << ']'; } @@ -2027,6 +1990,73 @@ printFunctionalType(op); } +void OperationPrinter::printBlockName(Block *block) { + auto id = state->getSSANameState().getBlockID(block); + if (id != SSANameState::NameSentinel) + os << "^bb" << id; + else + os << "^INVALIDBLOCK"; +} + +void OperationPrinter::print(Block *block, bool printBlockArgs, + bool printBlockTerminator) { + // Print the block label and argument list if requested. + if (printBlockArgs) { + os.indent(currentIndent); + printBlockName(block); + + // Print the argument list if non-empty. + if (!block->args_empty()) { + os << '('; + interleaveComma(block->getArguments(), [&](BlockArgument arg) { + printValueID(arg); + os << ": "; + printType(arg->getType()); + }); + os << ')'; + } + os << ':'; + + // Print out some context information about the predecessors of this block. + if (!block->getParent()) { + os << "\t// block is not in a region!"; + } else if (block->hasNoPredecessors()) { + os << "\t// no predecessors"; + } else if (auto *pred = block->getSinglePredecessor()) { + os << "\t// pred: "; + printBlockName(pred); + } else { + // We want to print the predecessors in increasing numeric order, not in + // whatever order the use-list is in, so gather and sort them. + SmallVector, 4> predIDs; + for (auto *pred : block->getPredecessors()) + predIDs.push_back({state->getSSANameState().getBlockID(pred), pred}); + llvm::array_pod_sort(predIDs.begin(), predIDs.end()); + + os << "\t// " << predIDs.size() << " preds: "; + + interleaveComma(predIDs, [&](std::pair pred) { + printBlockName(pred.second); + }); + } + os << '\n'; + } + + currentIndent += indentWidth; + auto range = llvm::make_range( + block->getOperations().begin(), + std::prev(block->getOperations().end(), printBlockTerminator ? 0 : 1)); + for (auto &op : range) { + print(&op); + os << '\n'; + } + currentIndent -= indentWidth; +} + +void OperationPrinter::printValueID(Value value, bool printResultNo) const { + state->getSSANameState().printValueID(value, printResultNo, os); +} + void OperationPrinter::printSuccessorAndUseList(Operation *term, unsigned index) { printBlockName(term->getSuccessor(index)); @@ -2044,15 +2074,47 @@ os << ')'; } +void OperationPrinter::printRegion(Region ®ion, bool printEntryBlockArgs, + bool printBlockTerminators) { + os << " {\n"; + if (!region.empty()) { + auto *entryBlock = ®ion.front(); + print(entryBlock, printEntryBlockArgs && entryBlock->getNumArguments() != 0, + printBlockTerminators); + for (auto &b : llvm::drop_begin(region.getBlocks(), 1)) + print(&b); + } + os.indent(currentIndent) << "}"; +} + +void OperationPrinter::printAffineMapOfSSAIds(AffineMapAttr mapAttr, + ValueRange operands) { + AffineMap map = mapAttr.getValue(); + unsigned numDims = map.getNumDims(); + auto printValueName = [&](unsigned pos, bool isSymbol) { + unsigned index = isSymbol ? numDims + pos : pos; + assert(index < operands.size()); + if (isSymbol) + os << "symbol("; + printValueID(operands[index]); + if (isSymbol) + os << ')'; + }; + + interleaveComma(map.getResults(), [&](AffineExpr expr) { + printAffineExpr(expr, printValueName); + }); +} + void ModulePrinter::print(ModuleOp module) { + assert(state && "expected valid state when printing an operation"); + // Output the aliases at the top level. - if (state) { - state->getAliasState().printAttributeAliases(os); - state->getAliasState().printTypeAliases(os); - } + state->getAliasState().printAttributeAliases(os); + state->getAliasState().printTypeAliases(os); // Print the module. - OperationPrinter(module, *this).print(module); + OperationPrinter(*this).print(module); os << '\n'; } @@ -2124,25 +2186,24 @@ void Operation::print(raw_ostream &os, OpPrintingFlags flags) { // Handle top-level operations or local printing. if (!getParent() || flags.shouldUseLocalScope()) { - ModuleState state(getContext()); + ModuleState state(this); ModulePrinter modulePrinter(os, flags, &state); - OperationPrinter(this, modulePrinter).print(this); + OperationPrinter(modulePrinter).print(this); return; } - auto region = getParentRegion(); - if (!region) { - os << "<>\n"; + Operation *parentOp = getParentOp(); + if (!parentOp) { + os << "<>\n"; return; } + // Get the top-level op. + while (auto *nextOp = parentOp->getParentOp()) + parentOp = nextOp; - // Get the top-level region. - while (auto *nextRegion = region->getParentRegion()) - region = nextRegion; - - ModuleState state(getContext()); + ModuleState state(parentOp); ModulePrinter modulePrinter(os, flags, &state); - OperationPrinter(region, modulePrinter).print(this); + OperationPrinter(modulePrinter).print(this); } void Operation::dump() { @@ -2151,41 +2212,41 @@ } void Block::print(raw_ostream &os) { - auto region = getParent(); - if (!region) { + Operation *parentOp = getParentOp(); + if (!parentOp) { os << "<>\n"; return; } + // Get the top-level op. + while (auto *nextOp = parentOp->getParentOp()) + parentOp = nextOp; - // Get the top-level region. - while (auto *nextRegion = region->getParentRegion()) - region = nextRegion; - - ModuleState state(region->getContext()); + ModuleState state(parentOp); ModulePrinter modulePrinter(os, /*flags=*/llvm::None, &state); - OperationPrinter(region, modulePrinter).print(this); + OperationPrinter(modulePrinter).print(this); } void Block::dump() { print(llvm::errs()); } /// Print out the name of the block without printing its body. void Block::printAsOperand(raw_ostream &os, bool printType) { - auto region = getParent(); - if (!region) { + Operation *parentOp = getParentOp(); + if (!parentOp) { os << "<>\n"; return; } + // Get the top-level op. + while (auto *nextOp = parentOp->getParentOp()) + parentOp = nextOp; - // Get the top-level region. - while (auto *nextRegion = region->getParentRegion()) - region = nextRegion; - - ModulePrinter modulePrinter(os); - OperationPrinter(region, modulePrinter).printBlockName(this); + ModuleState state(parentOp); + ModulePrinter modulePrinter(os, /*flags=*/llvm::None, &state); + OperationPrinter(modulePrinter).printBlockName(this); } void ModuleOp::print(raw_ostream &os, OpPrintingFlags flags) { - ModuleState state(getContext()); + ModuleState state(*this); + // Don't populate aliases when printing at local scope. if (!flags.shouldUseLocalScope()) state.initializeAliases(*this);