diff --git a/mlir/include/mlir/IR/Dialect.h b/mlir/include/mlir/IR/Dialect.h
--- a/mlir/include/mlir/IR/Dialect.h
+++ b/mlir/include/mlir/IR/Dialect.h
@@ -121,8 +121,10 @@
-  /// Print an operation registered to this dialect.
-  /// This hook is invoked for registered operation which don't override the
-  /// `print()` method to define their own custom assembly.
-  virtual LogicalResult printOperation(Operation *op,
-                                       OpAsmPrinter &printer) const;
+  /// Return a hook to print an operation registered to this dialect, or null
+  /// if the dialect does not provide a custom printer for `op`. This hook is
+  /// invoked for registered operations which don't override the `print()`
+  /// method to define their own custom assembly. The returned callable owns
+  /// its state, so implementations may safely return a (capturing) lambda.
+  virtual llvm::unique_function<void(Operation *, OpAsmPrinter &printer)>
+  getOperationPrinter(Operation *op) const;
 
   //===--------------------------------------------------------------------===//
   // Verification Hooks
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -2508,8 +2508,10 @@
     }
     // Otherwise try to dispatch to the dialect, if available.
     if (Dialect *dialect = op->getDialect()) {
-      if (succeeded(dialect->printOperation(op, *this)))
+      if (auto opPrinter = dialect->getOperationPrinter(op)) {
+        opPrinter(op, *this);
         return;
+      }
     }
   }
 
diff --git a/mlir/lib/IR/Dialect.cpp b/mlir/lib/IR/Dialect.cpp
--- a/mlir/lib/IR/Dialect.cpp
+++ b/mlir/lib/IR/Dialect.cpp
@@ -172,11 +172,12 @@
   return None;
 }
 
-LogicalResult Dialect::printOperation(Operation *op,
-                                      OpAsmPrinter &printer) const {
+llvm::unique_function<void(Operation *, OpAsmPrinter &printer)>
+Dialect::getOperationPrinter(Operation *op) const {
   assert(op->getDialect() == this &&
          "Dialect hook invoked on non-dialect owned operation");
-  return failure();
+  // By default, dialects provide no custom printer.
+  return nullptr;
 }
 
 /// Utility function that returns if the given string is a valid dialect
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -313,14 +313,16 @@
   return None;
 }
 
-LogicalResult TestDialect::printOperation(Operation *op,
-                                          OpAsmPrinter &printer) const {
+llvm::unique_function<void(Operation *, OpAsmPrinter &)>
+TestDialect::getOperationPrinter(Operation *op) const {
   StringRef opName = op->getName().getStringRef();
   if (opName == "test.dialect_custom_printer") {
-    printer.getStream() << opName << " custom_format";
-    return success();
+    // Return an owning callable: it is invoked after this function returns.
+    return [](Operation *op, OpAsmPrinter &printer) {
+      printer.getStream() << op->getName().getStringRef() << " custom_format";
+    };
   }
-  return failure();
+  return {};
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -46,8 +46,10 @@
     // Provides a custom printing/parsing for some operations.
     ::llvm::Optional<ParseOpHook>
       getParseOperationHook(::llvm::StringRef opName) const override;
-    ::mlir::LogicalResult printOperation(::mlir::Operation *op,
-                                         ::mlir::OpAsmPrinter &printer) const override;
+    ::llvm::unique_function<void(::mlir::Operation *,
+                                 ::mlir::OpAsmPrinter &)>
+    getOperationPrinter(::mlir::Operation *op) const override;
+
   private:
     // Storage for a custom fallback interface.
     void *fallbackEffectOpInterfaces;