diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -173,13 +173,19 @@ /// back to this one which accepts everything. LogicalResult verify() { return success(); } - /// Unless overridden, the custom assembly form of an op is always rejected. - /// Op implementations should implement this to return failure. - /// On success, they should fill in result with the fields to use. + /// Parse the custom form of an operation. Unless overridden, this method will + /// first try to get an operation parser from the op's dialect. Otherwise the + /// custom assembly form of an op is always rejected. Op implementations + /// should override this and return failure on error. On success, they should + /// fill in result with the fields to use. static ParseResult parse(OpAsmParser &parser, OperationState &result); - // The fallback for the printer is to print it the generic assembly form. - static void print(Operation *op, OpAsmPrinter &p); + /// Print the operation. Unless overridden, this method will first try to get + /// an operation printer from the dialect. Otherwise, it prints the operation + /// in generic form. + static void print(Operation *op, OpAsmPrinter &p, StringRef defaultDialect); + + /// Print an operation name, eliding the dialect prefix if necessary. static void printOpName(Operation *op, OpAsmPrinter &p, StringRef defaultDialect); @@ -1781,7 +1787,7 @@ OperationName::PrintAssemblyFn> getPrintAssemblyFnImpl() { return [](Operation *op, OpAsmPrinter &printer, StringRef defaultDialect) { - return OpState::print(op, printer); + return OpState::print(op, printer, defaultDialect); }; } /// The internal implementation of `getPrintAssemblyFn` that is invoked when diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -580,14 +580,27 @@ // OpState trait class. 
//===----------------------------------------------------------------------===// -// The fallback for the parser is to reject the custom assembly form. +// The fallback for the parser is to try for a dialect operation parser. +// Otherwise, reject the custom assembly form. ParseResult OpState::parse(OpAsmParser &parser, OperationState &result) { + if (auto parseFn = result.name.getDialect()->getParseOperationHook( + result.name.getStringRef())) + return (*parseFn)(parser, result); return parser.emitError(parser.getNameLoc(), "has no custom assembly form"); } -// The fallback for the printer is to print in the generic assembly form. -void OpState::print(Operation *op, OpAsmPrinter &p) { p.printGenericOp(op); } -// The fallback for the printer is to print in the generic assembly form. +// The fallback for the printer is to try for a dialect operation printer. +// Otherwise, it prints the generic form. +void OpState::print(Operation *op, OpAsmPrinter &p, StringRef defaultDialect) { + if (auto printFn = op->getDialect()->getOperationPrinter(op)) { + printOpName(op, p, defaultDialect); + printFn(op, p); + } else { + p.printGenericOp(op); + } +} + +/// Print an operation name, eliding the dialect prefix if necessary. void OpState::printOpName(Operation *op, OpAsmPrinter &p, StringRef defaultDialect) { StringRef name = op->getName().getStringRef(); diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir --- a/mlir/test/IR/parser.mlir +++ b/mlir/test/IR/parser.mlir @@ -1425,3 +1425,8 @@ // This is an unregister operation, the printing/parsing is handled by the dialect. // CHECK: test.dialect_custom_printer custom_format test.dialect_custom_printer custom_format + +// This is a registered operation with no custom parser and printer, and should +// be handled by the dialect. 
+// CHECK: test.dialect_custom_format_fallback custom_format_fallback +test.dialect_custom_format_fallback custom_format_fallback diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -318,6 +318,11 @@ return parser.parseKeyword("custom_format"); }}; } + if (opName == "test.dialect_custom_format_fallback") { + return ParseOpHook{[](OpAsmParser &parser, OperationState &state) { + return parser.parseKeyword("custom_format_fallback"); + }}; + } return None; } @@ -329,6 +334,11 @@ printer.getStream() << " custom_format"; }; } + if (opName == "test.dialect_custom_format_fallback") { + return [](Operation *op, OpAsmPrinter &printer) { + printer.getStream() << " custom_format_fallback"; + }; + } return {}; } diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -625,6 +625,10 @@ ); } +// This is used to test that a registered op without its own parser and printer +// falls back to the dialect parser and printer hooks. +def CustomFormatFallbackOp : TEST_Op<"dialect_custom_format_fallback">; + // This is used to test encoding of a string attribute into an SSA name of a // pretty printed value name. def StringAttrPrettyNameOp