diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -380,6 +380,10 @@ printOptionalAttrDictWithKeyword(ArrayRef attrs, ArrayRef elidedAttrs = {}) = 0; + /// Prints the entire operation with the custom assembly form, if available, + /// or the generic assembly form, otherwise. + virtual void printCustomOrGenericOp(Operation *op) = 0; + /// Print the entire operation with the default generic assembly form. /// If `printOpName` is true, then the operation name is printed (the default) /// otherwise it is omitted and the print will start with the operand list. diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -421,8 +421,9 @@ AliasInitializer &initializer) : printerFlags(printerFlags), initializer(initializer) {} - /// Print the given operation. - void print(Operation *op) { + /// Prints the entire operation with the custom assembly form, if available, + /// or the generic assembly form, otherwise. + void printCustomOrGenericOp(Operation *op) override { // Visit the operation location. if (printerFlags.shouldPrintDebugInfo()) initializer.visit(op->getLoc(), /*canBeDeferred=*/true); @@ -489,7 +490,7 @@ std::prev(block->end(), (!hasTerminator || printBlockTerminator) ? 0 : 1)); for (Operation &op : range) - print(&op); + printCustomOrGenericOp(&op); } /// Print the given region. @@ -680,7 +681,7 @@ // attributes/types that will actually be used during printing when // considering aliases. DummyAliasOperationPrinter aliasPrinter(printerFlags, *this); - aliasPrinter.print(op); + aliasPrinter.printCustomOrGenericOp(op); // Initialize the aliases sorted by name. initializeAliases(aliasToAttr, attrToAlias, &deferrableAttributes); @@ -2660,11 +2661,16 @@ /// Print the given top-level operation. 
void printTopLevelOperation(Operation *op); - /// Print the given operation with its indent and location. - void print(Operation *op); - /// Print the bare location, not including indentation/location/etc. - void printOperation(Operation *op); - /// Print the given operation in the generic form. + /// Print the given operation, including its left-hand side and its right-hand + /// side, with its indent and location. + void printFullOpWithIndentAndLoc(Operation *op); + /// Print the given operation, including its left-hand side and its right-hand + /// side, but not including indentation and location. + void printFullOp(Operation *op); + /// Print the right-hand side of the given operation in the custom or generic + /// form. + void printCustomOrGenericOp(Operation *op) override; + /// Print the right-hand side of the given operation in the generic form. void printGenericOp(Operation *op, bool printOpName) override; /// Print the name of the given block. @@ -2838,7 +2844,7 @@ state.getAliasState().printNonDeferredAliases(os, newLine); // Print the module. - print(op); + printFullOpWithIndentAndLoc(op); os << newLine; // Output the aliases at the top level that can be deferred. @@ -2934,18 +2940,18 @@ printTrailingLocation(arg.getLoc(), /*allowAlias*/ false); } -void OperationPrinter::print(Operation *op) { +void OperationPrinter::printFullOpWithIndentAndLoc(Operation *op) { // Track the location of this operation. 
state.registerOperationLocation(op, newLine.curLine, currentIndent); os.indent(currentIndent); - printOperation(op); + printFullOp(op); printTrailingLocation(op->getLoc()); if (printerFlags.shouldPrintValueUsers()) printUsersComment(op); } -void OperationPrinter::printOperation(Operation *op) { +void OperationPrinter::printFullOp(Operation *op) { if (size_t numResults = op->getNumResults()) { auto printResultGroup = [&](size_t resultNo, size_t resultCount) { printValueID(op->getResult(resultNo), /*printResultNo=*/false); @@ -2972,34 +2978,7 @@ os << " = "; } - // If requested, always print the generic form. - if (!printerFlags.shouldPrintGenericOpForm()) { - // Check to see if this is a known operation. If so, use the registered - // custom printer hook. - if (auto opInfo = op->getRegisteredInfo()) { - opInfo->printAssembly(op, *this, defaultDialectStack.back()); - return; - } - // Otherwise try to dispatch to the dialect, if available. - if (Dialect *dialect = op->getDialect()) { - if (auto opPrinter = dialect->getOperationPrinter(op)) { - // Print the op name first. - StringRef name = op->getName().getStringRef(); - // Only drop the default dialect prefix when it cannot lead to - // ambiguities. - if (name.count('.') == 1) - name.consume_front((defaultDialectStack.back() + ".").str()); - os << name; - - // Print the rest of the op now. - opPrinter(op, *this); - return; - } - } - } - - // Otherwise print with the generic assembly form. - printGenericOp(op, /*printOpName=*/true); + printCustomOrGenericOp(op); } void OperationPrinter::printUsersComment(Operation *op) { @@ -3076,6 +3055,37 @@ } } +void OperationPrinter::printCustomOrGenericOp(Operation *op) { + // If requested, always print the generic form. + if (!printerFlags.shouldPrintGenericOpForm()) { + // Check to see if this is a known operation. If so, use the registered + // custom printer hook. 
+ if (auto opInfo = op->getRegisteredInfo()) { + opInfo->printAssembly(op, *this, defaultDialectStack.back()); + return; + } + // Otherwise try to dispatch to the dialect, if available. + if (Dialect *dialect = op->getDialect()) { + if (auto opPrinter = dialect->getOperationPrinter(op)) { + // Print the op name first. + StringRef name = op->getName().getStringRef(); + // Only drop the default dialect prefix when it cannot lead to + // ambiguities. + if (name.count('.') == 1) + name.consume_front((defaultDialectStack.back() + ".").str()); + os << name; + + // Print the rest of the op now. + opPrinter(op, *this); + return; + } + } + } + + // Otherwise print with the generic assembly form. + printGenericOp(op, /*printOpName=*/true); +} + void OperationPrinter::printGenericOp(Operation *op, bool printOpName) { if (printOpName) printEscapedString(op->getName().getStringRef()); @@ -3176,7 +3186,7 @@ std::prev(block->end(), (!hasTerminator || printBlockTerminator) ? 0 : 1)); for (auto &op : range) { - print(&op); + printFullOpWithIndentAndLoc(&op); os << newLine; } currentIndent -= indentWidth; @@ -3418,7 +3428,7 @@ state.getImpl().initializeAliases(this); printer.printTopLevelOperation(this); } else { - printer.print(this); + printer.printFullOpWithIndentAndLoc(this); } } diff --git a/mlir/test/IR/print-op-custom-or-generic.mlir b/mlir/test/IR/print-op-custom-or-generic.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/IR/print-op-custom-or-generic.mlir @@ -0,0 +1,28 @@ +// # RUN: mlir-opt %s -split-input-file | FileCheck %s +// # RUN: mlir-opt %s -mlir-print-op-generic -split-input-file | FileCheck %s --check-prefix=GENERIC + +// Check that `printCustomOrGenericOp` and `printGenericOp` print the right +// assembly format. For operations without custom format, both should print the +// generic format. 
+ +// CHECK-LABEL: func @op_with_custom_printer +// CHECK-GENERIC-LABEL: "func"() +func.func @op_with_custom_printer() { + %x = test.string_attr_pretty_name + // CHECK: %x = test.string_attr_pretty_name + // GENERIC: %0 = "test.string_attr_pretty_name"() + return + // CHECK: return + // GENERIC: "func.return"() +} + +// ----- + +// CHECK-LABEL: func @op_without_custom_printer +// CHECK-GENERIC: "func"() +func.func @op_without_custom_printer() { + // CHECK: "test.result_type_with_trait"() : () -> !test.test_type_with_trait + // GENERIC: "test.result_type_with_trait"() : () -> !test.test_type_with_trait + "test.result_type_with_trait"() : () -> !test.test_type_with_trait + return +} diff --git a/mlir/test/IR/print-op-generic.mlir b/mlir/test/IR/print-op-generic.mlir deleted file mode 100644 --- a/mlir/test/IR/print-op-generic.mlir +++ /dev/null @@ -1,13 +0,0 @@ -// # RUN: mlir-opt %s | FileCheck %s -// # RUN: mlir-opt %s --mlir-print-op-generic | FileCheck %s --check-prefix=GENERIC - -// CHECK-LABEL: func @pretty_names -// CHECK-GENERIC: "func"() -func.func @pretty_names() { - %x = test.string_attr_pretty_name - // CHECK: %x = test.string_attr_pretty_name - // GENERIC: %0 = "test.string_attr_pretty_name"() - return - // CHECK: return - // GENERIC: "func.return"() -}