diff --git a/mlir/include/mlir/IR/Dialect.h b/mlir/include/mlir/IR/Dialect.h --- a/mlir/include/mlir/IR/Dialect.h +++ b/mlir/include/mlir/IR/Dialect.h @@ -121,8 +121,8 @@ /// Print an operation registered to this dialect. /// This hook is invoked for registered operation which don't override the /// `print()` method to define their own custom assembly. - virtual LogicalResult printOperation(Operation *op, - OpAsmPrinter &printer) const; + virtual llvm::unique_function + getOperationPrinter(Operation *op) const; //===--------------------------------------------------------------------===// // Verification Hooks @@ -297,8 +297,7 @@ public: explicit DialectRegistry(); - template - void insert() { + template void insert() { insert(TypeID::get(), ConcreteDialect::getDialectNamespace(), static_cast(([](MLIRContext *ctx) { @@ -364,8 +363,7 @@ /// Add an external op interface model for an op that belongs to a dialect, /// both provided as template parameters. The dialect must be present in the /// registry. - template - void addOpInterface() { + template void addOpInterface() { StringRef opName = OpTy::getOperationName(); StringRef dialectName = opName.split('.').first; addObjectInterface(dialectName, TypeID::get(), @@ -426,8 +424,7 @@ namespace llvm { /// Provide isa functionality for Dialects. -template -struct isa_impl { +template struct isa_impl { static inline bool doit(const ::mlir::Dialect &dialect) { return mlir::TypeID::get() == dialect.getTypeID(); } diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -2508,8 +2508,10 @@ } // Otherwise try to dispatch to the dialect, if available. if (Dialect *dialect = op->getDialect()) { - if (succeeded(dialect->printOperation(op, *this))) + if (auto opPrinter = dialect->getOperationPrinter(op)) { + opPrinter(op, *this); return; + } } } diff --git a/mlir/lib/IR/Dialect.cpp b/mlir/lib/IR/Dialect.cpp --- a/mlir/lib/IR/Dialect.cpp +++ b/mlir/lib/IR/Dialect.cpp @@ -172,11 +172,11 @@ return None; } -LogicalResult Dialect::printOperation(Operation *op, - OpAsmPrinter &printer) const { +llvm::unique_function +Dialect::getOperationPrinter(Operation *op) const { assert(op->getDialect() == this && "Dialect hook invoked on non-dialect owned operation"); - return failure(); + return nullptr; } /// Utility function that returns if the given string is a valid dialect diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -313,14 +313,15 @@ return None; } -LogicalResult TestDialect::printOperation(Operation *op, - OpAsmPrinter &printer) const { +llvm::unique_function +TestDialect::getOperationPrinter(Operation *op) const { StringRef opName = op->getName().getStringRef(); if (opName == "test.dialect_custom_printer") { - printer.getStream() << opName << " custom_format"; - return success(); + return [](Operation *op, OpAsmPrinter &printer) { + printer.getStream() << op->getName().getStringRef() << " custom_format"; + }; } - return failure(); + return {}; } //===----------------------------------------------------------------------===// diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -39,15 +39,17 @@ void registerTypes(); ::mlir::Attribute parseAttribute(::mlir::DialectAsmParser &parser, - ::mlir::Type type) const override; + ::mlir::Type type) const override; void printAttribute(::mlir::Attribute attr, ::mlir::DialectAsmPrinter &printer) const override; // Provides a custom printing/parsing for some operations. ::llvm::Optional getParseOperationHook(::llvm::StringRef opName) const override; - ::mlir::LogicalResult printOperation(::mlir::Operation *op, - ::mlir::OpAsmPrinter &printer) const override; + ::llvm::unique_function + getOperationPrinter(::mlir::Operation *op) const override; + private: // Storage for a custom fallback interface. void *fallbackEffectOpInterfaces;