diff --git a/mlir/include/mlir/IR/Dialect.h b/mlir/include/mlir/IR/Dialect.h
--- a/mlir/include/mlir/IR/Dialect.h
+++ b/mlir/include/mlir/IR/Dialect.h
@@ -39,6 +39,12 @@
 ///
 class Dialect {
 public:
+  /// Type for a callback provided by the dialect to parse a custom operation.
+  /// The parser uses it for operations of this dialect that have no registered
+  /// AbstractOperation, letting the dialect give them a custom assembly form.
+  using ParseOpHook =
+      std::function<ParseResult(OpAsmParser &parser, OperationState &result)>;
+
   virtual ~Dialect();
 
   /// Utility function that returns if the given string is a valid dialect
@@ -97,6 +103,18 @@
     llvm_unreachable("dialect has no registered type printing hook");
   }
 
+  /// Return the hook to parse the custom operation `opName`, if any. The
+  /// parser only queries this for operations of this dialect that have no
+  /// registered AbstractOperation; the default implementation returns None.
+  /// Dialects can override this to parse such unregistered operations.
+  virtual Optional<ParseOpHook> getParseOperationHook(StringRef opName) const;
+
+  /// Print an operation of this dialect that has no registered
+  /// AbstractOperation (and hence no `print()` hook). Return failure to fall
+  /// back to the generic form; the default implementation always fails.
+  virtual LogicalResult printOperation(Operation *op,
+                                       OpAsmPrinter &printer) const;
+
   //===--------------------------------------------------------------------===//
   // Verification Hooks
   //===--------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -86,6 +86,9 @@
   /// Use the specified object to parse this ops custom assembly format.
   ParseResult parseAssembly(OpAsmParser &parser, OperationState &result) const;
 
+  /// Return the static hook for parsing this operation assembly.
+  ParseAssemblyFn getParseAssemblyFn() const { return parseAssemblyFn; }
+
   /// This hook implements the AsmPrinter for this operation.
   void printAssembly(Operation *op, OpAsmPrinter &p) const {
     return printAssemblyFn(op, p);
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -2409,6 +2409,11 @@
       opInfo->printAssembly(op, *this);
       return;
     }
+    // Otherwise try to dispatch to the dialect, if available.
+    if (Dialect *dialect = op->getDialect()) {
+      if (succeeded(dialect->printOperation(op, *this)))
+        return;
+    }
   }
 
   // Otherwise print with the generic assembly form.
diff --git a/mlir/lib/IR/Dialect.cpp b/mlir/lib/IR/Dialect.cpp
--- a/mlir/lib/IR/Dialect.cpp
+++ b/mlir/lib/IR/Dialect.cpp
@@ -136,6 +136,18 @@
   return Type();
 }
 
+Optional<Dialect::ParseOpHook>
+Dialect::getParseOperationHook(StringRef opName) const {
+  return None;
+}
+
+LogicalResult Dialect::printOperation(Operation *op,
+                                      OpAsmPrinter &printer) const {
+  assert(op->getDialect() == this &&
+         "Dialect hook invoked on non-dialect owned operation");
+  return failure();
+}
+
 /// Utility function that returns if the given string is a valid dialect
 /// namespace.
 bool Dialect::isValidNamespace(StringRef str) {
diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -890,17 +890,19 @@
 namespace {
 class CustomOpAsmParser : public OpAsmParser {
 public:
-  CustomOpAsmParser(SMLoc nameLoc,
-                    ArrayRef<OperationParser::ResultRecord> resultIDs,
-                    const AbstractOperation *opDefinition,
-                    OperationParser &parser)
-      : nameLoc(nameLoc), resultIDs(resultIDs), opDefinition(opDefinition),
+  CustomOpAsmParser(
+      SMLoc nameLoc, ArrayRef<OperationParser::ResultRecord> resultIDs,
+      llvm::function_ref<ParseResult(OpAsmParser &, OperationState &)>
+          parseAssembly,
+      bool isIsolatedFromAbove, StringRef opName, OperationParser &parser)
+      : nameLoc(nameLoc), resultIDs(resultIDs), parseAssembly(parseAssembly),
+        isIsolatedFromAbove(isIsolatedFromAbove), opName(opName),
         parser(parser) {}
 
-  /// Parse an instance of the operation described by 'opDefinition' into the
+  /// Parse an instance of the operation using 'parseAssembly' into the
   /// provided operation state.
   ParseResult parseOperation(OperationState &opState) {
-    if (opDefinition->parseAssembly(*this, opState))
+    if (parseAssembly(*this, opState))
       return failure();
     // Verify that the parsed attributes does not have duplicate attributes.
     // This can happen if an attribute set during parsing is also specified in
@@ -929,8 +931,7 @@
   /// Emit a diagnostic at the specified location and return failure.
   InFlightDiagnostic emitError(llvm::SMLoc loc, const Twine &message) override {
     emittedError = true;
-    return parser.emitError(loc, "custom op '" + opDefinition->name.strref() +
-                                     "' " + message);
+    return parser.emitError(loc, "custom op '" + opName + "' " + message);
   }
 
   llvm::SMLoc getCurrentLocation() override {
@@ -1455,8 +1456,7 @@
     }
 
     // Try to parse the region.
-    assert((!enableNameShadowing ||
-            opDefinition->hasTrait<OpTrait::IsIsolatedFromAbove>()) &&
+    assert((!enableNameShadowing || isIsolatedFromAbove) &&
            "name shadowing is only allowed on isolated regions");
     if (parser.parseRegion(region, regionArguments, enableNameShadowing))
       return failure();
@@ -1621,7 +1621,9 @@
   ArrayRef<OperationParser::ResultRecord> resultIDs;
 
-  /// The abstract information of the operation.
-  const AbstractOperation *opDefinition;
+  /// The parse hook, isolation trait, and name of the operation being parsed.
+  function_ref<ParseResult(OpAsmParser &, OperationState &)> parseAssembly;
+  bool isIsolatedFromAbove;
+  StringRef opName;
 
   /// The main operation parser.
   OperationParser &parser;
@@ -1635,31 +1637,51 @@
 OperationParser::parseCustomOperation(ArrayRef<ResultRecord> resultIDs) {
   llvm::SMLoc opLoc = getToken().getLoc();
   StringRef opName = getTokenSpelling();
-
   auto *opDefinition = AbstractOperation::lookup(opName, getContext());
-  if (!opDefinition) {
+  Dialect *dialect = nullptr;
+  if (opDefinition) {
+    dialect = &opDefinition->dialect;
+  } else {
     if (opName.contains('.')) {
       // This op has a dialect, we try to check if we can register it in the
       // context on the fly.
       StringRef dialectName = opName.split('.').first;
-      if (!getContext()->getLoadedDialect(dialectName) &&
-          getContext()->getOrLoadDialect(dialectName)) {
+      dialect = getContext()->getLoadedDialect(dialectName);
+      if (!dialect && (dialect = getContext()->getOrLoadDialect(dialectName)))
         opDefinition = AbstractOperation::lookup(opName, getContext());
-      }
     } else {
       // If the operation name has no namespace prefix we treat it as a standard
       // operation and prefix it with "std".
       // TODO: Would it be better to just build a mapping of the registered
       // operations in the standard dialect?
-      if (getContext()->getOrLoadDialect("std"))
+      if (getContext()->getOrLoadDialect("std")) {
         opDefinition = AbstractOperation::lookup(Twine("std." + opName).str(),
                                                  getContext());
+        if (opDefinition)
+          opName = opDefinition->name.strref();
+      }
     }
   }
 
-  if (!opDefinition) {
-    emitError(opLoc) << "custom op '" << opName << "' is unknown";
-    return nullptr;
+  // This is the actual hook for the custom op parsing, usually implemented by
+  // the op itself (`Op::parse()`). We retrieve it either from the
+  // AbstractOperation or from the Dialect.
+  std::function<ParseResult(OpAsmParser &, OperationState &)> parseAssemblyFn;
+  bool isIsolatedFromAbove = false;
+
+  if (opDefinition) {
+    parseAssemblyFn = opDefinition->getParseAssemblyFn();
+    isIsolatedFromAbove =
+        opDefinition->hasTrait<OpTrait::IsIsolatedFromAbove>();
+  } else {
+    Optional<Dialect::ParseOpHook> dialectHook;
+    if (dialect)
+      dialectHook = dialect->getParseOperationHook(opName);
+    if (!dialectHook || !*dialectHook) {
+      emitError(opLoc) << "custom op '" << opName << "' is unknown";
+      return nullptr;
+    }
+    parseAssemblyFn = *dialectHook;
   }
 
   consumeToken();
@@ -1674,9 +1696,10 @@
   auto srcLocation = getEncodedSourceLocation(opLoc);
 
   // Have the op implementation take a crack and parsing this.
-  OperationState opState(srcLocation, opDefinition->name);
+  OperationState opState(srcLocation, opName);
   CleanupOpStateRegions guard{opState};
-  CustomOpAsmParser opAsmParser(opLoc, resultIDs, opDefinition, *this);
+  CustomOpAsmParser opAsmParser(opLoc, resultIDs, parseAssemblyFn,
+                                isIsolatedFromAbove, opName, *this);
   if (opAsmParser.parseOperation(opState))
     return nullptr;
 
diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir
--- a/mlir/test/IR/parser.mlir
+++ b/mlir/test/IR/parser.mlir
@@ -1411,3 +1411,7 @@
   %2 = "bar"(%1) : (i64) -> i64
   "unregistered_terminator"() : () -> ()
 }) {sym_name = "unregistered_op_dominance_violation_ok", type = () -> i1} : () -> ()
+
+// This is an unregistered operation; its printing/parsing is handled by the dialect.
+// CHECK: test.dialect_custom_printer custom_format
+test.dialect_custom_printer custom_format
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -207,6 +207,25 @@
   return success();
 }
 
+Optional<Dialect::ParseOpHook>
+TestDialect::getParseOperationHook(StringRef opName) const {
+  if (opName == "test.dialect_custom_printer")
+    return ParseOpHook{[](OpAsmParser &parser, OperationState &state) {
+      return parser.parseKeyword("custom_format");
+    }};
+  return None;
+}
+
+LogicalResult TestDialect::printOperation(Operation *op,
+                                          OpAsmPrinter &printer) const {
+  StringRef opName = op->getName().getStringRef();
+  if (opName == "test.dialect_custom_printer") {
+    printer.getStream() << opName << " custom_format";
+    return success();
+  }
+  return failure();
+}
+
 //===----------------------------------------------------------------------===//
 // TestBranchOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -39,6 +39,12 @@
                              Type type) const override;
     void printAttribute(Attribute attr,
                         DialectAsmPrinter &printer) const override;
+
+    // Provides a custom printing/parsing for some operations.
+    Optional<Dialect::ParseOpHook>
+      getParseOperationHook(StringRef opName) const override;
+    LogicalResult printOperation(Operation *op,
+                                 OpAsmPrinter &printer) const override;
   }];
 }
 