diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
--- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
@@ -324,4 +324,12 @@
   let assemblyFormat = "`<` (`(` struct(params)^ `)`) : (`x`)? `>`";
 }
 
+// Exercises the whitespace directives ` ` (space), `\n` (newline), and ``
+// (suppress the automatic space) in a type assembly format.
+def TestTypeSpaces : Test_Type<"TestTypeSpaces"> {
+  let parameters = (ins "int":$a, "int":$b);
+  let mnemonic = "spaces";
+  let assemblyFormat = "`<` ` ` $a `\\n` `(` `)` `` `(` `)` $b `>`";
+}
+
 #endif // TEST_TYPEDEFS
diff --git a/mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir b/mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir
--- a/mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir
+++ b/mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir
@@ -42,6 +42,8 @@
 // CHECK: !test.optional_group_struct
 // CHECK: !test.optional_group_struct<(b = 5)>
 // CHECK: !test.optional_group_struct<(a = 10, b = 5)>
+// CHECK: !test.spaces< 5
+// CHECK-NEXT: ()() 6>
 func private @test_roundtrip_default_parsers_struct(
   !test.no_parser<255, [1, 2, 3, 4, 5], "foobar", 4>
 ) -> (
@@ -67,5 +69,6 @@
   !test.optional_group_params<(5, 6)>,
   !test.optional_group_struct,
   !test.optional_group_struct<(b = 5)>,
-  !test.optional_group_struct<(b = 5, a = 10)>
+  !test.optional_group_struct<(b = 5, a = 10)>,
+  !test.spaces<5 ()() 6>
 )
diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
--- a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
@@ -207,6 +207,9 @@
   /// Generate the printer code for an optional group.
   void genOptionalGroupPrinter(OptionalElement *el, FmtContext &ctx,
                                MethodBody &os);
+  /// Generate a printer (or space eraser) for a whitespace element.
+  void genWhitespacePrinter(WhitespaceElement *el, FmtContext &ctx,
+                            MethodBody &os);
 
   /// The ODS definition of the attribute or type whose format is being used to
   /// generate a parser and printer.
@@ -292,6 +295,9 @@
     return genStructParser(strct, ctx, os);
   if (auto *optional = dyn_cast<OptionalElement>(el))
     return genOptionalGroupParser(optional, ctx, os);
+  // Whitespace only affects printing; there is nothing to parse.
+  if (isa<WhitespaceElement>(el))
+    return;
   llvm_unreachable("unknown format element");
 }
 
@@ -612,6 +618,8 @@
     return genVariablePrinter(var, ctx, os);
   if (auto *optional = dyn_cast<OptionalElement>(el))
     return genOptionalGroupPrinter(optional, ctx, os);
+  if (auto *whitespace = dyn_cast<WhitespaceElement>(el))
+    return genWhitespacePrinter(whitespace, ctx, os);
   llvm::PrintFatalError("unsupported format element");
 }
 
@@ -752,6 +760,22 @@
   os.unindent() << "}\n";
 }
 
+void DefFormat::genWhitespacePrinter(WhitespaceElement *el, FmtContext &ctx,
+                                     MethodBody &os) {
+  if (el->getValue() == "\\n") {
+    // FIXME: The newline should be `printer.printNewLine()`, i.e., handled by
+    // the printer.
+    os << tgfmt("$_printer << '\\n';\n", &ctx);
+  } else if (!el->getValue().empty()) {
+    os << tgfmt("$_printer << \"$0\";\n", &ctx, el->getValue());
+  } else {
+    // The empty literal (``) prints nothing; it only erases the space that
+    // would otherwise be emitted before the next element.
+    lastWasPunctuation = true;
+  }
+  shouldEmitSpace = false;
+}
+
 //===----------------------------------------------------------------------===//
 // DefFormatParser
 //===----------------------------------------------------------------------===//