diff --git a/mlir/include/mlir/IR/OpAsmInterface.td b/mlir/include/mlir/IR/OpAsmInterface.td
--- a/mlir/include/mlir/IR/OpAsmInterface.td
+++ b/mlir/include/mlir/IR/OpAsmInterface.td
@@ -52,4 +52,26 @@
   ];
 }
 
+/// Interface for hooking a default dialect into the OpAsmPrinter and
+/// OpAsmParser.
+def DefaultDialectAsmOpInterface : OpInterface<"DefaultDialectAsmOpInterface"> {
+  let description = [{
+    This interface provides hooks to plug a default dialect into the AsmPrinter
+    and AsmParser classes.
+  }];
+  let cppNamespace = "::mlir";
+
+  let methods = [
+    StaticInterfaceMethod<[{
+      Return the default dialect used when printing/parsing operations in regions
+      nested under a given operation. This allows eliding the dialect prefix from
+      operation names: for example, the `spv.` prefix could be omitted from all
+      operations in a SPIR-V module if the module implemented this interface and
+      returned `spv`.
+    }],
+    "StringRef", "getDefaultDialect"
+    >,
+  ];
+}
+
 #endif // MLIR_OPASMINTERFACE
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -961,4 +961,20 @@
 /// The OpAsmOpInterface, see OpAsmInterface.td for more details.
 #include "mlir/IR/OpAsmInterface.h.inc"
 
+namespace mlir {
+/// Return the default dialect to use when printing/parsing the regions nested
+/// under `op`: the dialect returned by the closest operation implementing
+/// `DefaultDialectAsmOpInterface`, starting with `op` itself. The outermost
+/// operation (the one without a parent) is never consulted. Returns "builtin"
+/// if no such operation is found.
+inline StringRef getDefaultDialectFromEnclosingOps(Operation *op) {
+  while (op && op->getParentOp()) {
+    if (auto iface = dyn_cast<DefaultDialectAsmOpInterface>(op))
+      return iface.getDefaultDialect();
+    op = op->getParentOp();
+  }
+  return "builtin";
+}
+} // end namespace mlir
+
 #endif
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -2296,12 +2296,73 @@
 //===----------------------------------------------------------------------===//
 
 namespace {
+
+/// A raw_ostream wrapper that can strip a dialect prefix (e.g. "test.")
+/// from the operation name written right after setDialectFilter().
+class operation_name_filtering_stream : public raw_ostream {
+public:
+  /// Construct a new stream that wraps the provided stream.
+  explicit operation_name_filtering_stream(raw_ostream &os) : os(os) {}
+  ~operation_name_filtering_stream() { flush(); }
+
+  /// Set the dialect whose prefix is stripped from the next data written
+  /// to this stream: if that data starts with "<dialect>.", the prefix is
+  /// dropped. It must be set right before printing each operation, as the
+  /// operation name is the first thing the custom printer emits. Data
+  /// already pending when this is called is never matched against it.
+  void setDialectFilter(StringRef dialect) {
+    // Size the buffer to exactly the filter length: raw_ostream then only
+    // calls write_impl() with at least that many bytes (unless explicitly
+    // flushed), enough to decide whether the filter matches. The buffer is
+    // reallocated only when that length changes, not for every operation.
+    // Pending data is flushed first so that the new filter never sees it.
+    flush();
+    if (GetBufferSize() != dialect.size() + 1)
+      SetBufferSize(dialect.size() + 1);
+    // Install the new filter only now, so that data flushed above was
+    // matched against the previous filter (if any), not this one.
+    filter = (dialect + ".").str();
+  }
+
+private:
+  /// See raw_ostream::write_impl.
+  void write_impl(const char *Ptr, size_t Size) override {
+    // No filter set (usual case), passthrough mode.
+    if (LLVM_LIKELY(filter.empty())) {
+      os.write(Ptr, Size);
+      return;
+    }
+    // This is the first chunk written since the filter was set. A chunk
+    // shorter than the filter can only come from an explicit flush (e.g. on
+    // destruction): it is treated as the end of the input and cannot match.
+    // The filter applies to this chunk only and is reset afterwards.
+    StringRef toWrite(Ptr, Size);
+    if (toWrite.startswith(filter))
+      toWrite = toWrite.drop_front(filter.size());
+    os << toWrite;
+    filter.clear();
+    // Buffering is no longer needed from here on, but calling
+    // `SetUnbuffered()` from within write_impl() would break raw_ostream
+    // invariants. Keeping the small buffer is harmless: it is resized by
+    // the next setDialectFilter() call if the filter length changes.
+  }
+
+  /// Return the current position within the stream, not counting the bytes
+  /// currently in the buffer.
+  uint64_t current_pos() const override { return os.tell(); }
+
+  /// The underlying stream that is wrapped here.
+  raw_ostream &os;
+  /// The filter to apply to the next data written to the stream.
+  std::string filter;
+};
+
 /// This class contains the logic for printing operations, regions, and blocks.
 class OperationPrinter : public ModulePrinter, private OpAsmPrinter {
 public:
   explicit OperationPrinter(raw_ostream &os, OpPrintingFlags flags,
                             AsmStateImpl &state)
-      : ModulePrinter(os, flags, &state) {}
+      : ModulePrinter(filteredOS, flags, &state), filteredOS(os) {}
 
   /// Print the given top-level operation.
   void printTopLevelOperation(Operation *op);
@@ -2417,6 +2478,16 @@
   }
 
 private:
+  // Stack of default dialects used to elide the dialect prefix from printed
+  // operation names. Before printing a non-empty region, the dialect returned
+  // by getDefaultDialectFromEnclosingOps() for the region's parent operation
+  // is pushed, and it is popped when done. The stack starts with "builtin" so
+  // that the top-level `module` operation prints as-is.
+  SmallVector<std::string> defaultDialectStack{"builtin"};
+  // Wraps the output stream so that the default dialect prefix can be stripped
+  // from the names of operations printed in custom form.
+  operation_name_filtering_stream filteredOS;
+
   /// The number of spaces used for indenting nested operations.
   const static unsigned indentWidth = 2;
 
@@ -2496,6 +2567,9 @@
     // Check to see if this is a known operation. If so, use the registered
     // custom printer hook.
     if (auto *opInfo = op->getAbstractOperation()) {
+      // Only the custom form elides the default dialect: the parser consults
+      // the default dialect solely for custom-form operation names.
+      filteredOS.setDialectFilter(defaultDialectStack.back());
       opInfo->printAssembly(op, *this);
       return;
     }
@@ -2641,6 +2715,20 @@
                                   bool printEmptyBlock) {
   os << " {" << newLine;
   if (!region.empty()) {
+    // RAII helper popping the default dialect pushed below once this region
+    // has been printed.
+    struct RestoreDefaultDialect {
+      explicit RestoreDefaultDialect(SmallVector<std::string> &dialects)
+          : stack(dialects) {}
+      ~RestoreDefaultDialect() { stack.pop_back(); }
+
+    private:
+      SmallVector<std::string> &stack;
+    };
+    defaultDialectStack.push_back(
+        getDefaultDialectFromEnclosingOps(region.getParentOp()).str());
+    RestoreDefaultDialect restore(defaultDialectStack);
+
     auto *entryBlock = &region.front();
     // Force printing the block header if printEmptyBlock is set and the block
     // is empty or if printEntryBlockArgs is set and there are arguments to
diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -1860,8 +1860,14 @@
-      // If the operation name has no namespace prefix we treat it as a builtin
-      // or standard operation and prefix it with "builtin" or "std".
+      // If the operation name has no namespace prefix we look it up in the
+      // current default dialect, then in "builtin", and finally in "std".
       // TODO: Remove the special casing here.
-      opDefinition = AbstractOperation::lookup(Twine("builtin." + opName).str(),
-                                               getContext());
+      StringRef currentDefault = getState().defaultDialectStack.back();
+      opDefinition = AbstractOperation::lookup(
+          Twine(currentDefault + "." + opName).str(), getContext());
+      // Keep resolving unprefixed builtin operation names (the behavior prior
+      // to default dialects) when another default dialect is active.
+      if (!opDefinition && currentDefault != "builtin")
+        opDefinition = AbstractOperation::lookup(
+            Twine("builtin." + opName).str(), getContext());
       if (!opDefinition && getContext()->getOrLoadDialect("std")) {
         opDefinition = AbstractOperation::lookup(Twine("std." + opName).str(),
                                                  getContext());
@@ -1877,10 +1883,14 @@
   function_ref<ParseResult(OpAsmParser &, OperationState &)> parseAssemblyFn;
   bool isIsolatedFromAbove = false;
 
+  StringRef defaultDialect = getState().defaultDialectStack.back();
   if (opDefinition) {
     parseAssemblyFn = opDefinition->getParseAssemblyFn();
     isIsolatedFromAbove =
         opDefinition->hasTrait<OpTrait::IsIsolatedFromAbove>();
+    if (auto *iface =
+            opDefinition->getInterface<DefaultDialectAsmOpInterface>())
+      defaultDialect = iface->getDefaultDialect();
   } else {
     Optional<Dialect::ParseOpHook> dialectHook;
     if (dialect)
@@ -1891,6 +1901,18 @@
     }
     parseAssemblyFn = *dialectHook;
   }
+  // Everything parsed for this operation (including its regions) uses
+  // `defaultDialect`; the enclosing default is restored when this returns.
+  struct RestoreDefaultDialect {
+    explicit RestoreDefaultDialect(SmallVector<std::string> &dialects)
+        : stack(dialects) {}
+    ~RestoreDefaultDialect() { stack.pop_back(); }
+
+  private:
+    SmallVector<std::string> &stack;
+  };
+  getState().defaultDialectStack.push_back(defaultDialect.str());
+  RestoreDefaultDialect restore(getState().defaultDialectStack);
 
   consumeToken();
 
diff --git a/mlir/lib/Parser/ParserState.h b/mlir/lib/Parser/ParserState.h
--- a/mlir/lib/Parser/ParserState.h
+++ b/mlir/lib/Parser/ParserState.h
@@ -82,6 +82,13 @@
   /// An optional pointer to a struct containing high level parser state to be
   /// populated during parsing.
   AsmParserState *asmState;
+
+  // Stack of default dialects used to resolve operation names written
+  // without a dialect prefix. Parsing a custom-form operation pushes the
+  // dialect from its `DefaultDialectAsmOpInterface` (or the enclosing default
+  // otherwise) and pops it when done. It starts with "builtin" so that the
+  // top-level `module` operation parses as-is.
+  SmallVector<std::string> defaultDialectStack{"builtin"};
 };
 
 } // end namespace detail
diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir
--- a/mlir/test/IR/parser.mlir
+++ b/mlir/test/IR/parser.mlir
@@ -1307,6 +1307,24 @@
   return
 }
 
+// This tests the behavior of "default dialect":
+// operations like `test.default_dialect` can define a default dialect
+// used in nested regions.
+// CHECK-LABEL: func @default_dialect
+func @default_dialect() {
+  test.default_dialect {
+    // The test dialect is the default in this region: the following two
+    // operations are parsed identically and both print without `test.`.
+    // CHECK: test.default_dialect {
+    // CHECK-NEXT: {{[^.]}}parse_integer_literal : 5
+    // CHECK-NEXT: {{[^.]}}parse_integer_literal : 5
+    parse_integer_literal : 5
+    test.parse_integer_literal : 5
+    "test.terminator"() : ()->()
+  }
+  return
+}
+
 // CHECK-LABEL: func @unreachable_dominance_violation_ok
 func @unreachable_dominance_violation_ok() -> i1 {
   // CHECK: [[VAL:%.*]] = constant false
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -632,6 +632,18 @@
   let parser = [{ return ::parse$cppClass(parser, result); }];
 }
 
+// Used to test `DefaultDialectAsmOpInterface`: operations in the region of
+// this op are printed/parsed with `test` as the default dialect.
+def DefaultDialectOp : TEST_Op<"default_dialect", [DefaultDialectAsmOpInterface]> {
+  let regions = (region AnyRegion:$body);
+  let extraClassDeclaration = [{
+    static StringRef getDefaultDialect() {
+      return "test";
+    }
+  }];
+  let assemblyFormat = "regions attr-dict-with-keyword";
+}
+
 //===----------------------------------------------------------------------===//
 // Test Locations
 //===----------------------------------------------------------------------===//