diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
@@ -42,7 +42,7 @@
 
   let parameters = (ins "Type":$elementType, "unsigned":$numElements);
   let assemblyFormat = [{
-    `<` $numElements `x` ` ` custom<PrettyLLVMType>($elementType) `>`
+    `<` $numElements `x` custom<PrettyLLVMType>($elementType) `>`
   }];
 
   let genVerifyDecl = 1;
@@ -182,7 +182,7 @@
 
   let parameters = (ins "Type":$elementType, "unsigned":$numElements);
   let assemblyFormat = [{
-    `<` $numElements `x` ` ` custom<PrettyLLVMType>($elementType) `>`
+    `<` $numElements `x` custom<PrettyLLVMType>($elementType) `>`
   }];
 
   let genVerifyDecl = 1;
diff --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
--- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
@@ -312,7 +312,7 @@
 def TestCustomAnchor : Test_Attr<"TestCustomAnchor"> {
   let parameters = (ins "int":$a, OptionalParameter<"mlir::Optional<int>">:$b);
   let mnemonic = "custom_anchor";
-  let assemblyFormat = "`<` $a (`>`) : (`,` ` ` custom<TrueFalse>($b)^ `>`)?";
+  let assemblyFormat = "`<` $a (`>`) : (`,` custom<TrueFalse>($b)^ `>`)?";
 }
 
 def Test_IteratorTypeEnum
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -681,6 +681,11 @@
   return success();
 }
 
+static ParseResult parseCustomDirectiveSpacing(OpAsmParser &parser,
+                                               mlir::StringAttr &attr) {
+  return parser.parseAttribute(attr);
+}
+
 static ParseResult parseCustomDirectiveAttrDict(OpAsmParser &parser,
                                                 NamedAttrList &attrs) {
   return parser.parseOptionalAttrDict(attrs);
@@ -760,6 +765,11 @@
     printer << ", " << optAttribute;
 }
 
+static void printCustomDirectiveSpacing(OpAsmPrinter &printer, Operation *op,
+                                        Attribute attribute) {
+  printer << attribute;
+}
+
 static void printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op,
                                          DictionaryAttr attrs) {
   printer.printOptionalAttrDict(attrs.getValue());
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -2315,6 +2315,16 @@
   }];
 }
 
+def FormatCustomDirectiveSpacing
+    : TEST_Op<"format_custom_directive_spacing"> {
+  let arguments = (ins StrAttr:$attr1, StrAttr:$attr2);
+  let assemblyFormat = [{
+    custom<CustomDirectiveSpacing>($attr1)
+    custom<CustomDirectiveSpacing>($attr2)
+    attr-dict
+  }];
+}
+
 def FormatCustomDirectiveAttrDict
     : TEST_Op<"format_custom_directive_attrdict"> {
   let arguments = (ins I64Attr:$attr, OptionalAttr<I64Attr>:$optAttr);
diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
--- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
@@ -319,10 +319,17 @@
 def TestTypeCustom : Test_Type<"TestTypeCustom"> {
   let parameters = (ins "int":$a, OptionalParameter<"mlir::Optional<int>">:$b);
   let mnemonic = "custom_type";
-  let assemblyFormat = [{ `<` custom<CustomTypeA>($a)
+  let assemblyFormat = [{ `<` custom<CustomTypeA>($a) ``
                               custom<CustomTypeB>(ref($a), $b) `>` }];
 }
 
+def TestTypeCustomSpacing : Test_Type<"TestTypeCustomSpacing"> {
+  let parameters = (ins "int":$a, "int":$b);
+  let mnemonic = "custom_type_spacing";
+  let assemblyFormat = [{ `<` custom<CustomTypeA>($a)
+                              custom<CustomTypeA>($b) `>` }];
+}
+
 def TestTypeCustomString : Test_Type<"TestTypeCustomString"> {
   let parameters = (ins StringRefParameter<>:$foo);
   let mnemonic = "custom_type_string";
diff --git a/mlir/test/lib/Dialect/Test/TestTypes.cpp b/mlir/test/lib/Dialect/Test/TestTypes.cpp
--- a/mlir/test/lib/Dialect/Test/TestTypes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestTypes.cpp
@@ -134,7 +134,7 @@
 }
 
 static void printBarString(AsmPrinter &printer, StringRef foo) {
-  printer << ' ' << foo;
+  printer << foo;
 }
 //===----------------------------------------------------------------------===//
 // Tablegen Generated Definitions
diff --git a/mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir b/mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir
--- a/mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir
+++ b/mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir
@@ -62,6 +62,7 @@
 // CHECK: !test.default_valued_type<>
 // CHECK: !test.custom_type<-5>
 // CHECK: !test.custom_type<2 0 1 5>
+// CHECK: !test.custom_type_spacing<1 2>
 // CHECK: !test.custom_type_string<"foo" foo>
 // CHECK: !test.custom_type_string<"bar" bar>
 
@@ -98,6 +99,7 @@
   !test.default_valued_type<>,
   !test.custom_type<-5>,
   !test.custom_type<2 9 9 5>,
+  !test.custom_type_spacing<1 2>,
   !test.custom_type_string<"foo" foo>,
   !test.custom_type_string<"bar" bar>
 )
diff --git a/mlir/test/mlir-tblgen/op-format.mlir b/mlir/test/mlir-tblgen/op-format.mlir
--- a/mlir/test/mlir-tblgen/op-format.mlir
+++ b/mlir/test/mlir-tblgen/op-format.mlir
@@ -398,6 +398,9 @@
   return
 }
 
+// CHECK: test.format_custom_directive_spacing "a" "b"
+test.format_custom_directive_spacing "a" "b"
+
 // CHECK: test.format_literal_following_optional_group(5 : i32) : i32 {a}
 test.format_literal_following_optional_group(5 : i32) : i32 {a}
 
diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
--- a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
@@ -826,6 +826,12 @@
 
 void DefFormat::genCustomPrinter(CustomDirective *el, FmtContext &ctx,
                                  MethodBody &os) {
+  // Insert a space before the custom directive, if necessary.
+  if (shouldEmitSpace || !lastWasPunctuation)
+    os << tgfmt("$_printer << ' ';\n", &ctx);
+  shouldEmitSpace = true;
+  lastWasPunctuation = false;
+
   os << tgfmt("print$0($_printer", &ctx, el->getName());
   os.indent();
   for (FormatElement *arg : el->getArguments()) {