diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h --- a/mlir/include/mlir/IR/OperationSupport.h +++ b/mlir/include/mlir/IR/OperationSupport.h @@ -773,6 +773,9 @@ /// Return if the printer should use local scope when dumping the IR. bool shouldUseLocalScope() const; + /// Return if the printer should print the users of values as comments. + bool shouldPrintOperationUsers() const; + private: /// Elide large elements attributes if the number of elements is larger than /// the upper limit. @@ -790,6 +793,9 @@ /// Print operations with numberings local to the current operation. bool printLocalScope : 1; + + /// Print the users of values (results, block arguments) as comments. + bool printOperationUsers : 1; }; //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -152,6 +152,10 @@ "mlir-print-local-scope", llvm::cl::init(false), llvm::cl::desc("Print with local scope and inline information (eliding " "aliases for attributes, types, and locations")}; + + llvm::cl::opt<bool> printOperationUsersOpt{ + "mlir-print-operation-users", llvm::cl::init(false), + llvm::cl::desc("Print users of an operation as a comment")}; }; } // namespace @@ -168,7 +172,7 @@ OpPrintingFlags::OpPrintingFlags() : printDebugInfoFlag(false), printDebugInfoPrettyFormFlag(false), printGenericOpFormFlag(false), assumeVerifiedFlag(false), - printLocalScope(false) { + printLocalScope(false), printOperationUsers(false) { // Initialize based upon command line options, if they are available. if (!clOptions.isConstructed()) return; @@ -179,6 +183,7 @@ printGenericOpFormFlag = clOptions->printGenericOpFormOpt; assumeVerifiedFlag = clOptions->assumeVerifiedOpt; printLocalScope = clOptions->printLocalScopeOpt; + printOperationUsers = clOptions->printOperationUsersOpt; } /// Enable the elision of large elements attributes, by printing a '...'
@@ -254,6 +259,11 @@ /// Return if the printer should use local scope when dumping the IR. bool OpPrintingFlags::shouldUseLocalScope() const { return printLocalScope; } +/// Return if the printer should print the users of values as comments. +bool OpPrintingFlags::shouldPrintOperationUsers() const { + return printOperationUsers; +} + /// Returns true if an ElementsAttr with the given number of elements should be /// printed with hex. static bool shouldPrintElementsAttrWithHex(int64_t numElements) { @@ -831,6 +841,9 @@ /// of this value. void printValueID(Value value, bool printResultNo, raw_ostream &stream) const; + /// Print the id assigned to a result-less operation, or a placeholder. + void printOperationID(Operation *op, raw_ostream &stream) const; + /// Return the result indices for each of the result groups registered by this /// operation, or empty if none exist. ArrayRef getOpResultGroups(Operation *op); @@ -868,6 +881,10 @@ DenseMap valueIDs; DenseMap valueNames; + /// When printing users of results, an operation without a result might + /// be the user. This map holds ids for such operations. + DenseMap<Operation *, unsigned> operationIDs; + /// This is a map of operations that contain multiple named result groups, /// i.e. there may be multiple names for the results of the operation. The /// value of this map are the result numbers that start a result group. @@ -990,6 +1007,15 @@ stream << '#' << resultNo; } +void SSANameState::printOperationID(Operation *op, raw_ostream &stream) const { + auto it = operationIDs.find(op); + if (it == operationIDs.end()) { + stream << "<<UNKNOWN OPERATION>>"; + } else { + stream << '%' << it->second; + } +} + ArrayRef SSANameState::getOpResultGroups(Operation *op) { auto it = opResultGroups.find(op); return it == opResultGroups.end() ?
 ArrayRef() : it->second; @@ -1121,8 +1147,15 @@ } unsigned numResults = op.getNumResults(); - if (numResults == 0) + if (numResults == 0) { + // When printing users, give result-less operations that take operands an + // id: they can show up as users. Operations without operands never can. + if (printerFlags.shouldPrintOperationUsers() && op.getNumOperands() != 0) { + if (operationIDs.try_emplace(&op, nextValueID).second) + ++nextValueID; + } return; + } Value resultBegin = op.getResult(0); // If the first result wasn't numbered, give it a default number. @@ -2481,6 +2514,10 @@ void printValueID(Value value, bool printResultNo = true, raw_ostream *streamOverride = nullptr) const; + /// Print the ID of the given operation. + void printOperationID(Operation *op, + raw_ostream *streamOverride = nullptr) const; + //===--------------------------------------------------------------------===// // OpAsmPrinter methods //===--------------------------------------------------------------------===// @@ -2549,6 +2586,19 @@ void printAffineExprOfSSAIds(AffineExpr expr, ValueRange dimOperands, ValueRange symOperands) override; + /// Print users of this operation or id of this operation if it has no result. + void printUsersComment(Operation *op); + + /// Print users of this block arg. + void printUsersComment(BlockArgument arg); + + /// Print the users of a value. + void printValueUsers(Value &value); + + /// Print either the ids of the result values or the id of the operation if + /// the operation has no results. + void printUserIDs(Operation *user, bool prefixComma = false); + private: // Contains the stack of default dialects to use when printing regions.
 // A new dialect is pushed to the stack before parsing regions nested under an @@ -2602,6 +2652,8 @@ os.indent(currentIndent); printOperation(op); printTrailingLocation(op->getLoc()); + if (printerFlags.shouldPrintOperationUsers()) + printUsersComment(op); } void OperationPrinter::printOperation(Operation *op) { @@ -2657,6 +2709,84 @@ printGenericOp(op, /*printOpName=*/true); } +/// Print users of this operation or id of this operation if it has no result. +void OperationPrinter::printUsersComment(Operation *op) { + unsigned numResults = op->getNumResults(); + if (!numResults && op->getNumOperands()) { + os << " // id: "; + printOperationID(op); + } else if (numResults && op->use_empty()) { + os << " // unused"; + } else if (numResults) { + // Print "user" (singular) when every use is in a single operation that + // has at most one result, so only one distinct ID appears in the comment; + // otherwise print "users". + unsigned usedInNResults = 0; + unsigned usedInNOperations = 0; + SmallPtrSet<Operation *, 4> userSet; + for (Operation *user : op->getUsers()) + if (userSet.insert(user).second) { + ++usedInNOperations; + usedInNResults += user->getNumResults(); + } + + // The earlier branches guarantee at least one use here. + bool exactlyOneUniqueUse = usedInNResults <= 1 && usedInNOperations <= 1; + os << " // " << (exactlyOneUniqueUse ? "user" : "users") << ": "; + bool shouldPrintBrackets = numResults > 1; + auto printOpResult = [&](OpResult opResult) { + if (shouldPrintBrackets) + os << "("; + printValueUsers(opResult); + if (shouldPrintBrackets) + os << ")"; + }; + + interleaveComma(op->getResults(), printOpResult); + } +} + +/// Print users of this block arg. +void OperationPrinter::printUsersComment(BlockArgument arg) { + os << "// "; + printValueID(arg); + if (arg.use_empty()) { + os << " is unused"; + } else { + os << " is used by "; + printValueUsers(arg); + } + os << newLine; +} + +/// Print the users of a value.
+void OperationPrinter::printValueUsers(Value &value) { + if (value.use_empty()) + os << "unused"; + + // A value may be used several times by the same operation; print each user + // once and comma-separate every user after the first. + SmallPtrSet<Operation *, 4> userSet; + for (Operation *user : value.getUsers()) { + if (userSet.insert(user).second) + printUserIDs(user, /*prefixComma=*/userSet.size() > 1); + } +} + +/// Prints either the ids of the result values or the id of the operation if the +/// operation has no results. +void OperationPrinter::printUserIDs(Operation *user, bool prefixComma) { + if (prefixComma) + os << ", "; + + if (!user->getNumResults()) { + printOperationID(user); + } else { + interleaveComma(user->getResults(), + [&](Value result) { printValueID(result); }); + } +} + void OperationPrinter::printGenericOp(Operation *op, bool printOpName) { if (printOpName) { os << '"'; @@ -2745,6 +2875,14 @@ } currentIndent += indentWidth; + + if (printerFlags.shouldPrintOperationUsers()) { + for (BlockArgument arg : block->getArguments()) { + os.indent(currentIndent); + printUsersComment(arg); + } + } + bool hasTerminator = !block->empty() && block->back().hasTrait(); auto range = llvm::make_range( @@ -2764,6 +2902,12 @@ streamOverride ? *streamOverride : os); } +void OperationPrinter::printOperationID(Operation *op, + raw_ostream *streamOverride) const { + state->getSSANameState().printOperationID(op, streamOverride ?
 *streamOverride + : os); +} + void OperationPrinter::printSuccessor(Block *successor) { printBlockName(successor); } diff --git a/mlir/test/IR/print-op-users.mlir b/mlir/test/IR/print-op-users.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/IR/print-op-users.mlir @@ -0,0 +1,61 @@ +// RUN: mlir-opt -allow-unregistered-dialect -mlir-print-operation-users -split-input-file %s | FileCheck %s + +module { + // CHECK: %[[ARG0:.+]]: i32, %[[ARG1:.+]]: i32, %[[ARG2:.+]]: i32 + func @foo(%arg0: i32, %arg1: i32, %arg2: i32) -> i32 { + // CHECK-NEXT: // %[[ARG0]] is used by %[[ARG0U1:.+]], %[[ARG0U2:.+]], %[[ARG0U3:.+]] + // CHECK-NEXT: // %[[ARG1]] is used by %[[ARG1U1:.+]], %[[ARG1U2:.+]] + // CHECK-NEXT: // %[[ARG2]] is unused + // CHECK-NEXT: test.noop + // CHECK-NOT: // unused + "test.noop"() : () -> () + // When no result is produced, an id should be printed + // CHECK-NEXT: // id: %[[ARG0U3]] + "test.no_result"(%arg0) {} : (i32) -> () + // Check for unused result + // CHECK-NEXT: %[[ARG0U2]] = + // CHECK-SAME: // unused + %1 = "test.unused_result"(%arg0, %arg1) {} : (i32, i32) -> i32 + // Check that both users are printed + // CHECK-NEXT: %[[ARG0U1]] = + // CHECK-SAME: // users: %[[A:.+]]#0, %[[A]]#1 + %2 = "test.one_result"(%arg0, %arg1) {} : (i32, i32) -> i32 + // For multiple results, users should be grouped per result + // CHECK-NEXT: %[[A]]:2 = + // CHECK-SAME: // users: (%[[B:.+]], %[[C:.+]]), (%[[B]], %[[D:.+]]) + %3:2 = "test.many_results"(%2) {} : (i32) -> (i32, i32) + // CHECK-NEXT: %[[C]] = + // Result is used twice in next operation but it produces only one result. + // CHECK-SAME: // user: + %4 = "test.foo"(%3#0) {} : (i32) -> i32 + // CHECK-NEXT: %[[D]] = + %5 = "test.foo"(%3#1, %4, %4) {} : (i32, i32, i32) -> i32 + // CHECK-NEXT: %[[B]] = + // Result is not used in any other result but in two operations.
+ // CHECK-SAME: // users: + %6 = "test.foo"(%3#0, %3#1) {} : (i32, i32) -> i32 + "test.no_result"(%6) {} : (i32) -> () + return %6: i32 + } +} + +// ----- + +module { + // Check with nested operation + // CHECK: %[[CONSTNAME:.+]] = arith.constant + %0 = arith.constant 42 : i32 + %test = "test.outerop"(%0) ({ + // CHECK: "test.innerop"(%[[CONSTNAME]]) : (i32) -> () // id: % + "test.innerop"(%0) : (i32) -> () + // CHECK: (i32) -> i32 // users: %r, %s, %p, %p_0, %q + }): (i32) -> i32 + + // Check named results + // CHECK-NEXT: // users: (%u, %v), (unused), (%u, %v, %r, %s) + %p:2, %q = "test.custom_result_name"(%test) {names = ["p", "p", "q"]} : (i32) -> (i32, i32, i32) + // CHECK-NEXT: // users: (unused), (%u, %v) + %r, %s = "test.custom_result_name"(%q#0, %q#0, %test) {names = ["r", "s"]} : (i32, i32, i32) -> (i32, i32) + // CHECK-NEXT: // unused + %u, %v = "test.custom_result_name"(%s, %q#0, %p) {names = ["u", "v"]} : (i32, i32, i32) -> (i32, i32) +} diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -1168,6 +1168,15 @@ setNameFn(getResult(i), str.getValue()); } +void CustomResultsNameOp::getAsmResultNames( + function_ref<void(Value, StringRef)> setNameFn) { + // zip stops at the shorter range, so surplus names are ignored safely. + for (auto it : llvm::zip(getNames(), getResults())) + if (auto str = std::get<0>(it).dyn_cast<StringAttr>()) + if (!str.getValue().empty()) + setNameFn(std::get<1>(it), str.getValue()); +} + //===----------------------------------------------------------------------===// // ResultTypeWithTraitOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -732,6 +732,19 @@ let hasCustomAssemblyFormat = 1; } + +// This is used to test encoding of a string
 attribute into an SSA name of a +// pretty printed value name. +def CustomResultsNameOp + : TEST_Op<"custom_result_name", + [DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]> { + let arguments = (ins + Variadic<AnyInteger>:$optional, + StrArrayAttr:$names + ); + let results = (outs Variadic<AnyInteger>:$r); +} + // This is used to test the OpAsmOpInterface::getDefaultDialect() feature: // operations nested in a region under this op will drop the "test." dialect // prefix.