diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -773,6 +773,9 @@
   /// Return if the printer should use local scope when dumping the IR.
   bool shouldUseLocalScope() const;
 
+  /// Return if the printer should print users of a result.
+  bool shouldPrintOperationUsers() const;
+
 private:
   /// Elide large elements attributes if the number of elements is larger than
   /// the upper limit.
@@ -790,6 +793,9 @@
 
   /// Print operations with numberings local to the current operation.
   bool printLocalScope : 1;
+
+  /// Print users of results; read back by shouldPrintOperationUsers().
+  bool printOperationUsers : 1;
 };
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -152,6 +152,10 @@
       "mlir-print-local-scope", llvm::cl::init(false),
       llvm::cl::desc("Print with local scope and inline information (eliding "
                      "aliases for attributes, types, and locations")};
+
+  llvm::cl::opt<bool> printOperationUsersOpt{
+      "mlir-print-operation-users", llvm::cl::init(false),
+      llvm::cl::desc("Print users of an operation as a comment")};
 };
 } // namespace
 
@@ -168,7 +172,7 @@
 OpPrintingFlags::OpPrintingFlags()
     : printDebugInfoFlag(false), printDebugInfoPrettyFormFlag(false),
       printGenericOpFormFlag(false), assumeVerifiedFlag(false),
-      printLocalScope(false) {
+      printLocalScope(false), printOperationUsers(false) {
   // Initialize based upon command line options, if they are available.
   if (!clOptions.isConstructed())
     return;
@@ -179,6 +183,7 @@
   printGenericOpFormFlag = clOptions->printGenericOpFormOpt;
   assumeVerifiedFlag = clOptions->assumeVerifiedOpt;
   printLocalScope = clOptions->printLocalScopeOpt;
+  printOperationUsers = clOptions->printOperationUsersOpt;
 }
 
 /// Enable the elision of large elements attributes, by printing a '...'
@@ -254,6 +259,11 @@
 /// Return if the printer should use local scope when dumping the IR.
 bool OpPrintingFlags::shouldUseLocalScope() const { return printLocalScope; }
 
+/// Return if the printer should print users of a result.
+bool OpPrintingFlags::shouldPrintOperationUsers() const {
+  return printOperationUsers;
+}
+
 /// Returns true if an ElementsAttr with the given number of elements should be
 /// printed with hex.
 static bool shouldPrintElementsAttrWithHex(int64_t numElements) {
@@ -831,6 +841,9 @@
   /// of this value.
   void printValueID(Value value, bool printResultNo, raw_ostream &stream) const;
 
+  /// Print the operation identifier.
+  void printOperationID(Operation *op, raw_ostream &stream) const;
+
   /// Return the result indices for each of the result groups registered by this
   /// operation, or empty if none exist.
   ArrayRef getOpResultGroups(Operation *op);
@@ -868,6 +881,10 @@
   DenseMap valueIDs;
   DenseMap valueNames;
 
+  /// When printing users of results, an operation without a result might
+  /// be the user. This map holds ids for such operations.
+  DenseMap<Operation *, unsigned> operationIDs;
+
   /// This is a map of operations that contain multiple named result groups,
   /// i.e. there may be multiple names for the results of the operation. The
   /// value of this map are the result numbers that start a result group.
@@ -990,6 +1007,15 @@
   stream << '#' << resultNo;
 }
 
+void SSANameState::printOperationID(Operation *op, raw_ostream &stream) const {
+  auto it = operationIDs.find(op);
+  if (it == operationIDs.end()) {
+    stream << "<>";
+  } else {
+    stream << '%' << it->second;
+  }
+}
+
 ArrayRef SSANameState::getOpResultGroups(Operation *op) {
   auto it = opResultGroups.find(op);
   return it == opResultGroups.end() ? ArrayRef() : it->second;
@@ -1121,8 +1147,15 @@
   }
 
   unsigned numResults = op.getNumResults();
-  if (numResults == 0)
+  if (numResults == 0) {
+    // If users of operations should be printed, make sure that operations
+    // with no result have an id.
+    if (printerFlags.shouldPrintOperationUsers()) {
+      if (operationIDs.try_emplace(&op, nextValueID).second)
+        ++nextValueID;
+    }
     return;
+  }
   Value resultBegin = op.getResult(0);
 
   // If the first result wasn't numbered, give it a default number.
@@ -2481,6 +2514,10 @@
   void printValueID(Value value, bool printResultNo = true,
                     raw_ostream *streamOverride = nullptr) const;
 
+  /// Print the ID of the given operation.
+  void printOperationID(Operation *op,
+                        raw_ostream *streamOverride = nullptr) const;
+
   //===--------------------------------------------------------------------===//
   // OpAsmPrinter methods
   //===--------------------------------------------------------------------===//
@@ -2549,6 +2586,16 @@
   void printAffineExprOfSSAIds(AffineExpr expr, ValueRange dimOperands,
                                ValueRange symOperands) override;
 
+  /// Print users of this operation or id of this operation if it has no result.
+  void printUsersComment(Operation *op);
+
+  /// Print users of this block arg.
+  void printUsersComment(BlockArgument &arg);
+
+  /// Print either the ids of the result values or the id of the operation if
+  /// the operation has no results.
+  void printUserIDs(Operation *user, bool prefixComma = false);
+
 private:
   // Contains the stack of default dialects to use when printing regions.
   // A new dialect is pushed to the stack before parsing regions nested under an
@@ -2602,6 +2649,8 @@
   os.indent(currentIndent);
   printOperation(op);
   printTrailingLocation(op->getLoc());
+  if (printerFlags.shouldPrintOperationUsers())
+    printUsersComment(op);
 }
 
 void OperationPrinter::printOperation(Operation *op) {
@@ -2657,6 +2706,59 @@
   printGenericOp(op, /*printOpName=*/true);
 }
 
+/// Print a trailing comment for `op`: the id of `op` itself when it has
+/// operands but neither results nor regions; otherwise, if it has any users,
+/// the comma-separated ids of its distinct users.
+void OperationPrinter::printUsersComment(Operation *op) {
+  if (!op->getNumResults() && op->getNumOperands() && !op->getNumRegions()) {
+    // Should print ids only for "last level" operations which have no result
+    // and are not used anywhere else.
+    os << " // id: ";
+    printOperationID(op);
+  } else if (!op->getUsers().empty()) {
+    // Collect every user once, in iteration order, so that the label reflects
+    // the number of distinct users rather than the number of uses.
+    SmallPtrSet<Operation *, 1> seenUsers;
+    SmallVector<Operation *, 1> uniqueUsers;
+    for (Operation *user : op->getUsers()) {
+      if (seenUsers.insert(user).second)
+        uniqueUsers.push_back(user);
+    }
+    os << " // " << (uniqueUsers.size() == 1 ? "user" : "users") << ": ";
+    llvm::interleaveComma(uniqueUsers, os,
+                          [&](Operation *user) { printUserIDs(user); });
+  }
+}
+
+/// Print a trailing comment listing the distinct users of this block arg.
+void OperationPrinter::printUsersComment(BlockArgument &arg) {
+  // Collect every user once, in iteration order, so that the label reflects
+  // the number of distinct users rather than the number of uses.
+  SmallPtrSet<Operation *, 1> seenUsers;
+  SmallVector<Operation *, 1> uniqueUsers;
+  for (Operation *user : arg.getUsers()) {
+    if (seenUsers.insert(user).second)
+      uniqueUsers.push_back(user);
+  }
+  os << " // " << (uniqueUsers.size() == 1 ? "user" : "users") << ": ";
+  llvm::interleaveComma(uniqueUsers, os,
+                        [&](Operation *user) { printUserIDs(user); });
+}
+
+/// Prints either the ids of the result values or the id of the operation if the
+/// operation has no results.
+void OperationPrinter::printUserIDs(Operation *user, bool prefixComma) {
+  if (prefixComma)
+    os << ", ";
+
+  if (!user->getNumResults())
+    printOperationID(user);
+  else
+    // just print the id once from the first result (instead of: %1#0, %1#1,
+    // ...)
+    printValueID(user->getResult(0), false);
+}
+
 void OperationPrinter::printGenericOp(Operation *op, bool printOpName) {
   if (printOpName) {
     os << '"';
@@ -2745,6 +2847,19 @@
   }
 
   currentIndent += indentWidth;
+
+  if (printerFlags.shouldPrintOperationUsers()) {
+    for (auto arg : block->getArguments()) {
+      if (arg.getUsers().empty())
+        continue;
+      os.indent(currentIndent);
+      os << "// ";
+      printValueID(arg);
+      printUsersComment(arg);
+      os << newLine;
+    }
+  }
+
   bool hasTerminator =
       !block->empty() && block->back().hasTrait();
   auto range = llvm::make_range(
@@ -2764,6 +2879,12 @@
                                         streamOverride ? *streamOverride : os);
 }
 
+void OperationPrinter::printOperationID(Operation *op,
+                                        raw_ostream *streamOverride) const {
+  raw_ostream &stream = streamOverride ? *streamOverride : os;
+  state->getSSANameState().printOperationID(op, stream);
+}
+
 void OperationPrinter::printSuccessor(Block *successor) {
   printBlockName(successor);
 }