diff --git a/mlir/include/mlir/IR/OpAsmInterface.td b/mlir/include/mlir/IR/OpAsmInterface.td
--- a/mlir/include/mlir/IR/OpAsmInterface.td
+++ b/mlir/include/mlir/IR/OpAsmInterface.td
@@ -47,8 +47,26 @@
           %first_result, %middle_results:2, %0 = "my.op" ...
         ```
       }],
       "void", "getAsmResultNames", (ins "::mlir::OpAsmSetValueNameFn":$setNameFn)
     >,
+  ];
+}
+
+/// Interface for operations that set the dialect used to print and parse
+/// operation names without a dialect prefix inside their nested regions.
+def DefaultDialectAsmOpInterface : OpInterface<"DefaultDialectAsmOpInterface"> {
+  let description = [{
+    This interface lets an operation define a "default dialect" for the
+    operations nested in its regions: the AsmPrinter elides that dialect's
+    prefix from their names, and the AsmParser resolves unprefixed names in it.
+  }];
+  let cppNamespace = "::mlir";
+
+  let methods = [
+    StaticInterfaceMethod<[{
+        Return the namespace of the default dialect for nested regions.
+      }],
+      "StringRef", "getDefaultDialect">,
   ];
 }
 
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -2296,12 +2296,72 @@
 //===----------------------------------------------------------------------===//
 
 namespace {
+
+/// A raw_ostream that can drop a "<dialect>." prefix from the next characters
+/// written to it. It is used to elide the default dialect from the names of
+/// operations printed in their custom form.
+///
+/// The stream is unbuffered, so the prefix may arrive split across several
+/// write_impl calls. Matching characters are held back until the whole prefix
+/// has been seen (they are then dropped) or a mismatch is found (they are then
+/// emitted unchanged).
+class operation_name_filtering_stream : public raw_ostream {
+  /// The prefix to drop, empty when no filtering is pending.
+  std::string filter;
+  /// Number of leading characters of `filter` matched and held back so far.
+  size_t matched = 0;
+  /// The stream receiving the filtered output.
+  raw_ostream &os;
+
+  /// See raw_ostream::write_impl.
+  void write_impl(const char *ptr, size_t size) override {
+    StringRef toWrite(ptr, size);
+    if (!filter.empty()) {
+      StringRef pending = StringRef(filter).drop_front(matched);
+      size_t len = std::min(pending.size(), toWrite.size());
+      if (toWrite.take_front(len) != pending.take_front(len)) {
+        // Not the filtered prefix: release the held-back characters.
+        os.write(filter.data(), matched);
+      } else {
+        // Still matching: hold these characters back.
+        matched += len;
+        toWrite = toWrite.drop_front(len);
+        // On a partial match, wait for more output before deciding.
+        if (matched != filter.size())
+          return;
+      }
+      // The prefix was either fully dropped or ruled out: stop filtering.
+      filter.clear();
+      matched = 0;
+    }
+    os.write(toWrite.data(), toWrite.size());
+  }
+
+  /// Return the current position of the underlying stream; characters held
+  /// back by a partial prefix match are not counted.
+  uint64_t current_pos() const override { return os.tell(); }
+
+public:
+  explicit operation_name_filtering_stream(raw_ostream &os) : os(os) {
+    SetUnbuffered();
+  }
+
+  /// Drop `dialect` followed by '.' if it prefixes the next characters
+  /// written to this stream. An empty `dialect` disables filtering.
+  void setFilter(StringRef dialect) {
+    // Release anything still held back by a previous, undecided match.
+    os.write(filter.data(), matched);
+    matched = 0;
+    filter = dialect.empty() ? std::string() : (dialect + ".").str();
+  }
+};
+
 /// This class contains the logic for printing operations, regions, and blocks.
 class OperationPrinter : public ModulePrinter, private OpAsmPrinter {
 public:
   explicit OperationPrinter(raw_ostream &os, OpPrintingFlags flags,
                             AsmStateImpl &state)
-      : ModulePrinter(os, flags, &state) {}
+      : ModulePrinter(filteredOS, flags, &state), filteredOS(os) {}
 
   /// Print the given top-level operation.
   void printTopLevelOperation(Operation *op);
@@ -2417,6 +2477,12 @@
   }
 
 private:
+  /// Stack of default dialects for the regions being printed; the innermost
+  /// one is elided from the names of operations printed in custom form.
+  SmallVector<std::string> defaultDialectStack{"builtin"};
+  /// Output stream that elides the default dialect from operation names.
+  operation_name_filtering_stream filteredOS;
+
   /// The number of spaces used for indenting nested operations.
   const static unsigned indentWidth = 2;
 
@@ -2496,6 +2562,7 @@
     // Check to see if this is a known operation. If so, use the registered
     // custom printer hook.
     if (auto *opInfo = op->getAbstractOperation()) {
+      filteredOS.setFilter(defaultDialectStack.back());
       opInfo->printAssembly(op, *this);
       return;
     }
@@ -2641,6 +2708,25 @@
                                    bool printEmptyBlock) {
   os << " {" << newLine;
   if (!region.empty()) {
+    // Operations in this region use the default dialect of the region's parent
+    // op if it defines one, and inherit the current default otherwise. This
+    // mirrors the parser, so that the printed output can be parsed back.
+    StringRef defaultDialect = defaultDialectStack.back();
+    Operation *parentOp = region.getParentOp();
+    if (auto iface = dyn_cast_or_null<DefaultDialectAsmOpInterface>(parentOp))
+      defaultDialect = iface.getDefaultDialect();
+    defaultDialectStack.push_back(defaultDialect.str());
+    // Pop the default dialect pushed above when leaving this scope.
+    struct RestoreDefaultDialect {
+      explicit RestoreDefaultDialect(SmallVector<std::string> &stack)
+          : stack(stack) {}
+      ~RestoreDefaultDialect() { stack.pop_back(); }
+
+    private:
+      SmallVector<std::string> &stack;
+    };
+    RestoreDefaultDialect restoreDefaultDialect(defaultDialectStack);
+
     auto *entryBlock = &region.front();
     // Force printing the block header if printEmptyBlock is set and the block
     // is empty or if printEntryBlockArgs is set and there are arguments to
diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -272,7 +272,7 @@
   /// isolated from those above.
   ParseResult parseRegion(Region &region,
                           ArrayRef<std::pair<SSAUseInfo, Type>> entryArguments,
-                          bool isIsolatedNameScope = false);
+                          bool isIsolatedNameScope);
 
   /// Parse a region body into 'region'.
   ParseResult
@@ -954,7 +954,8 @@
     do {
       // Create temporary regions with the top level region as parent.
       result.regions.emplace_back(new Region(topLevelOp));
-      if (parseRegion(*result.regions.back(), /*entryArguments=*/{}))
+      if (parseRegion(*result.regions.back(), /*entryArguments=*/{},
+                      /*isIsolatedNameScope=*/false))
         return nullptr;
     } while (consumeIf(Token::comma));
     if (parseToken(Token::r_paren, "expected ')' to end region list"))
@@ -1660,10 +1661,10 @@
   /// Parses a region if present. If the region is present, a new region is
   /// allocated and placed in `region`. If no region is present, `region`
   /// remains untouched.
-  OptionalParseResult
-  parseOptionalRegion(std::unique_ptr<Region> &region,
-                      ArrayRef<OperandType> arguments, ArrayRef<Type> argTypes,
-                      bool enableNameShadowing = false) override {
+  OptionalParseResult parseOptionalRegion(std::unique_ptr<Region> &region,
+                                          ArrayRef<OperandType> arguments,
+                                          ArrayRef<Type> argTypes,
+                                          bool enableNameShadowing) override {
     if (parser.getToken().isNot(Token::l_brace))
       return llvm::None;
     std::unique_ptr<Region> newRegion = std::make_unique<Region>();
@@ -1860,8 +1861,14 @@
-      // If the operation name has no namespace prefix we treat it as a builtin
-      // or standard operation and prefix it with "builtin" or "std".
+      // If the operation name has no namespace prefix we look it up in the
+      // current default dialect, then fall back to the "builtin" and "std"
+      // dialects.
       // TODO: Remove the special casing here.
-      opDefinition = AbstractOperation::lookup(Twine("builtin." + opName).str(),
-                                               getContext());
+      StringRef currentDefault = getState().defaultDialectStack.back();
+      opDefinition = AbstractOperation::lookup(
+          (currentDefault + "." + opName).str(), getContext());
+      // Unprefixed names outside the default dialect keep resolving to builtin
+      // operations, as they did before default dialects were introduced.
+      if (!opDefinition && currentDefault != "builtin")
+        opDefinition = AbstractOperation::lookup(
+            Twine("builtin." + opName).str(), getContext());
       if (!opDefinition && getContext()->getOrLoadDialect("std")) {
         opDefinition = AbstractOperation::lookup(Twine("std." + opName).str(),
                                                  getContext());
@@ -1877,10 +1884,16 @@
   function_ref<ParseResult(OpAsmParser &, OperationState &)> parseAssemblyFn;
   bool isIsolatedFromAbove = false;
 
+  // Operations nested in this op's regions use its default dialect if it
+  // defines one, and inherit the current default otherwise.
+  StringRef defaultDialect = getState().defaultDialectStack.back();
   if (opDefinition) {
     parseAssemblyFn = opDefinition->getParseAssemblyFn();
     isIsolatedFromAbove =
         opDefinition->hasTrait<OpTrait::IsIsolatedFromAbove>();
+    auto *iface = opDefinition->getInterface<DefaultDialectAsmOpInterface>();
+    if (iface)
+      defaultDialect = iface->getDefaultDialect();
   } else {
     Optional<Dialect::ParseOpHook> dialectHook;
     if (dialect)
@@ -1891,6 +1904,17 @@
     }
     parseAssemblyFn = *dialectHook;
   }
+  getState().defaultDialectStack.push_back(defaultDialect.str());
+  // Pop the default dialect pushed above when leaving this function.
+  struct RestoreDefaultDialect {
+    explicit RestoreDefaultDialect(SmallVector<std::string> &stack)
+        : stack(stack) {}
+    ~RestoreDefaultDialect() { stack.pop_back(); }
+
+  private:
+    SmallVector<std::string> &stack;
+  };
+  RestoreDefaultDialect restoreDefaultDialect(getState().defaultDialectStack);
 
   consumeToken();
 
diff --git a/mlir/lib/Parser/ParserState.h b/mlir/lib/Parser/ParserState.h
--- a/mlir/lib/Parser/ParserState.h
+++ b/mlir/lib/Parser/ParserState.h
@@ -82,6 +82,10 @@
   /// An optional pointer to a struct containing high level parser state to be
   /// populated during parsing.
   AsmParserState *asmState;
+
+  /// Stack of the default dialects in effect while parsing nested regions;
+  /// the innermost one is used to resolve operation names without a prefix.
+  SmallVector<std::string> defaultDialectStack{"builtin"};
 };
 
 } // end namespace detail
diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir
--- a/mlir/test/IR/parser.mlir
+++ b/mlir/test/IR/parser.mlir
@@ -1307,6 +1307,25 @@
   return
 }
 
+// Operations like `test.default_dialect` can define a default dialect used
+// for the operations nested in their regions: names from that dialect are
+// printed without the dialect prefix, and are parsed with or without it.
+// CHECK-LABEL: func @default_dialect
+func @default_dialect() {
+  test.default_dialect {
+    // The test dialect is the default in this region, the following two
+    // operations are parsed identically and printed without the prefix.
+    // CHECK: {{[^.]}}parse_integer_literal : 5
+    // CHECK-NEXT: {{[^.]}}parse_integer_literal : 5
+    parse_integer_literal : 5
+    test.parse_integer_literal : 5
+    // The generic form always prints the fully qualified name.
+    // CHECK-NEXT: "test.terminator"()
+    "test.terminator"() : ()->()
+  }
+  return
+}
+
 // CHECK-LABEL: func @unreachable_dominance_violation_ok
 func @unreachable_dominance_violation_ok() -> i1 {
   // CHECK: [[VAL:%.*]] = constant false
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -632,6 +632,18 @@
   let parser = [{ return ::parse$cppClass(parser, result); }];
 }
 
+// This is used to test the "default dialect" handling: operations nested in
+// its region can omit the `test.` prefix.
+def DefaultDialectOp
+    : TEST_Op<"default_dialect", [DefaultDialectAsmOpInterface]> {
+  let regions = (region AnyRegion:$body);
+  let extraClassDeclaration = [{
+    // Implements DefaultDialectAsmOpInterface.
+    static StringRef getDefaultDialect() { return "test"; }
+  }];
+  let assemblyFormat = "regions attr-dict-with-keyword";
+}
+
 //===----------------------------------------------------------------------===//
 // Test Locations
 //===----------------------------------------------------------------------===//