diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -380,6 +380,12 @@
   printOptionalAttrDictWithKeyword(ArrayRef<NamedAttribute> attrs,
                                    ArrayRef<StringRef> elidedAttrs = {}) = 0;
 
+  /// Print the entire operation with its custom assembly form when one is
+  /// available (and the generic form was not requested), otherwise with the
+  /// generic assembly form. Only the operation itself is printed: the leading
+  /// result list (e.g. `%0 = `) is left to the caller.
+  virtual void printOp(Operation *op) = 0;
+
   /// Print the entire operation with the default generic assembly form.
   /// If `printOpName` is true, then the operation name is printed (the default)
   /// otherwise it is omitted and the print will start with the operand list.
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -422,7 +422,7 @@
       : printerFlags(printerFlags), initializer(initializer) {}
 
   /// Print the given operation.
-  void print(Operation *op) {
+  void printOp(Operation *op) override {
     // Visit the operation location.
     if (printerFlags.shouldPrintDebugInfo())
       initializer.visit(op->getLoc(), /*canBeDeferred=*/true);
@@ -489,7 +489,7 @@
         std::prev(block->end(),
                   (!hasTerminator || printBlockTerminator) ? 0 : 1));
     for (Operation &op : range)
-      print(&op);
+      printOp(&op);
   }
 
   /// Print the given region.
@@ -680,7 +680,7 @@
   // attributes/types that will actually be used during printing when
   // considering aliases.
   DummyAliasOperationPrinter aliasPrinter(printerFlags, *this);
-  aliasPrinter.print(op);
+  aliasPrinter.printOp(op);
 
   // Initialize the aliases sorted by name.
   initializeAliases(aliasToAttr, attrToAlias, &deferrableAttributes);
@@ -2663,6 +2663,8 @@
   void print(Operation *op);
   /// Print the bare location, not including indentation/location/etc.
   void printOperation(Operation *op);
+  /// Print the given operation in the custom or generic form.
+  void printOp(Operation *op) override;
   /// Print the given operation in the generic form.
   void printGenericOp(Operation *op, bool printOpName) override;
 
@@ -2971,34 +2973,7 @@
     os << " = ";
   }
 
-  // If requested, always print the generic form.
-  if (!printerFlags.shouldPrintGenericOpForm()) {
-    // Check to see if this is a known operation. If so, use the registered
-    // custom printer hook.
-    if (auto opInfo = op->getRegisteredInfo()) {
-      opInfo->printAssembly(op, *this, defaultDialectStack.back());
-      return;
-    }
-    // Otherwise try to dispatch to the dialect, if available.
-    if (Dialect *dialect = op->getDialect()) {
-      if (auto opPrinter = dialect->getOperationPrinter(op)) {
-        // Print the op name first.
-        StringRef name = op->getName().getStringRef();
-        // Only drop the default dialect prefix when it cannot lead to
-        // ambiguities.
-        if (name.count('.') == 1)
-          name.consume_front((defaultDialectStack.back() + ".").str());
-        os << name;
-
-        // Print the rest of the op now.
-        opPrinter(op, *this);
-        return;
-      }
-    }
-  }
-
-  // Otherwise print with the generic assembly form.
-  printGenericOp(op, /*printOpName=*/true);
+  printOp(op);
 }
 
 void OperationPrinter::printUsersComment(Operation *op) {
@@ -3075,6 +3050,37 @@
   }
 }
 
+void OperationPrinter::printOp(Operation *op) {
+  // If requested, always print the generic form.
+  if (!printerFlags.shouldPrintGenericOpForm()) {
+    // Check to see if this is a known operation. If so, use the registered
+    // custom printer hook.
+    if (auto opInfo = op->getRegisteredInfo()) {
+      opInfo->printAssembly(op, *this, defaultDialectStack.back());
+      return;
+    }
+    // Otherwise try to dispatch to the dialect, if available.
+    if (Dialect *dialect = op->getDialect()) {
+      if (auto opPrinter = dialect->getOperationPrinter(op)) {
+        // Print the op name first.
+        StringRef name = op->getName().getStringRef();
+        // Only drop the default dialect prefix when it cannot lead to
+        // ambiguities.
+        if (name.count('.') == 1)
+          name.consume_front((defaultDialectStack.back() + ".").str());
+        os << name;
+
+        // Print the rest of the op now.
+        opPrinter(op, *this);
+        return;
+      }
+    }
+  }
+
+  // Otherwise print with the generic assembly form.
+  printGenericOp(op, /*printOpName=*/true);
+}
+
 void OperationPrinter::printGenericOp(Operation *op, bool printOpName) {
   if (printOpName)
     printEscapedString(op->getName().getStringRef());