diff --git a/flang/include/flang/Optimizer/Dialect/FIRTypes.td b/flang/include/flang/Optimizer/Dialect/FIRTypes.td --- a/flang/include/flang/Optimizer/Dialect/FIRTypes.td +++ b/flang/include/flang/Optimizer/Dialect/FIRTypes.td @@ -39,6 +39,7 @@ let parameters = (ins "KindTy":$kind); let genAccessors = 1; + let hasCustomAssemblyFormat = 1; let extraClassDeclaration = [{ using KindTy = unsigned; @@ -62,6 +63,7 @@ let parameters = (ins "mlir::Type":$eleTy); let genVerifyDecl = 1; + let hasCustomAssemblyFormat = 1; } def fir_BoxType : FIR_Type<"Box", "box"> { @@ -91,6 +93,7 @@ }]; let genVerifyDecl = 1; + let hasCustomAssemblyFormat = 1; } def fir_CharacterType : FIR_Type<"Character", "char"> { @@ -104,6 +107,7 @@ }]; let parameters = (ins "KindTy":$FKind, "CharacterType::LenType":$len); + let hasCustomAssemblyFormat = 1; let extraClassDeclaration = [{ using KindTy = unsigned; @@ -143,6 +147,7 @@ }]; let parameters = (ins "KindTy":$fKind); + let hasCustomAssemblyFormat = 1; let extraClassDeclaration = [{ using KindTy = unsigned; @@ -174,6 +179,7 @@ let parameters = (ins "mlir::Type":$eleTy); let genVerifyDecl = 1; + let hasCustomAssemblyFormat = 1; let skipDefaultBuilders = 1; @@ -194,6 +200,7 @@ }]; let parameters = (ins "KindTy":$fKind); + let hasCustomAssemblyFormat = 1; let extraClassDeclaration = [{ using KindTy = unsigned; @@ -219,6 +226,7 @@ }]; let parameters = (ins "KindTy":$fKind); + let hasCustomAssemblyFormat = 1; let extraClassDeclaration = [{ using KindTy = unsigned; @@ -259,6 +267,7 @@ let parameters = (ins "mlir::Type":$eleTy); let genVerifyDecl = 1; + let hasCustomAssemblyFormat = 1; let skipDefaultBuilders = 1; @@ -283,6 +292,7 @@ }]; let parameters = (ins "KindTy":$fKind); + let hasCustomAssemblyFormat = 1; let extraClassDeclaration = [{ using KindTy = unsigned; @@ -304,6 +314,7 @@ let genVerifyDecl = 1; let genStorageClass = 0; + let hasCustomAssemblyFormat = 1; let extraClassDeclaration = [{ using TypePair = std::pair; @@ -351,6 +362,7 @@ }]; let 
genVerifyDecl = 1; + let hasCustomAssemblyFormat = 1; } def fir_ShapeType : FIR_Type<"Shape", "shape"> { @@ -363,6 +375,7 @@ }]; let parameters = (ins "unsigned":$rank); + let hasCustomAssemblyFormat = 1; } def fir_ShapeShiftType : FIR_Type<"ShapeShift", "shapeshift"> { @@ -376,6 +389,7 @@ }]; let parameters = (ins "unsigned":$rank); + let hasCustomAssemblyFormat = 1; } def fir_ShiftType : FIR_Type<"Shift", "shift"> { @@ -388,6 +402,7 @@ }]; let parameters = (ins "unsigned":$rank); + let hasCustomAssemblyFormat = 1; let extraClassDeclaration = [{ using KindTy = unsigned; @@ -417,6 +432,7 @@ ); let genVerifyDecl = 1; + let hasCustomAssemblyFormat = 1; let builders = [ TypeBuilderWithInferredContext<(ins @@ -470,6 +486,7 @@ }]; let parameters = (ins "unsigned":$rank); + let hasCustomAssemblyFormat = 1; } def fir_TypeDescType : FIR_Type<"TypeDesc", "tdesc"> { @@ -483,6 +500,7 @@ let parameters = (ins "mlir::Type":$ofTy); let genVerifyDecl = 1; + let hasCustomAssemblyFormat = 1; let skipDefaultBuilders = 1; @@ -505,6 +523,7 @@ let parameters = (ins "uint64_t":$len, "mlir::Type":$eleTy); let genVerifyDecl = 1; + let hasCustomAssemblyFormat = 1; let extraClassDeclaration = [{ static bool isValidElementType(mlir::Type t); diff --git a/mlir/include/mlir/Dialect/Async/IR/AsyncTypes.td b/mlir/include/mlir/Dialect/Async/IR/AsyncTypes.td --- a/mlir/include/mlir/Dialect/Async/IR/AsyncTypes.td +++ b/mlir/include/mlir/Dialect/Async/IR/AsyncTypes.td @@ -47,6 +47,7 @@ return $_get(valueType.getContext(), valueType); }]> ]; + let hasCustomAssemblyFormat = 1; let skipDefaultBuilders = 1; } diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitCAttributes.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitCAttributes.td --- a/mlir/include/mlir/Dialect/EmitC/IR/EmitCAttributes.td +++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitCAttributes.td @@ -41,6 +41,8 @@ }]; let parameters = (ins StringRefParameter<"the opaque value">:$value); + + let hasCustomAssemblyFormat = 1; } #endif // 
MLIR_DIALECT_EMITC_IR_EMITCATTRIBUTES diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td --- a/mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td +++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td @@ -42,6 +42,7 @@ }]; let parameters = (ins StringRefParameter<"the opaque value">:$value); + let hasCustomAssemblyFormat = 1; } def EmitC_PointerType : EmitC_Type<"Pointer", "ptr"> { diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td @@ -23,6 +23,7 @@ let parameters = (ins "FastmathFlags":$flags ); + let hasCustomAssemblyFormat = 1; } // Attribute definition for the LLVM Linkage enum. @@ -31,6 +32,7 @@ let parameters = (ins "linkage::Linkage":$linkage ); + let hasCustomAssemblyFormat = 1; } def LoopOptionsAttr : LLVM_Attr<"LoopOptions"> { @@ -63,6 +65,7 @@ AttrBuilder<(ins "ArrayRef>":$sortedOptions)>, AttrBuilder<(ins "LoopOptionsAttrBuilder &":$optionBuilders)> ]; + let hasCustomAssemblyFormat = 1; let skipDefaultBuilders = 1; } diff --git a/mlir/include/mlir/Dialect/PDL/IR/PDLTypes.td b/mlir/include/mlir/Dialect/PDL/IR/PDLTypes.td --- a/mlir/include/mlir/Dialect/PDL/IR/PDLTypes.td +++ b/mlir/include/mlir/Dialect/PDL/IR/PDLTypes.td @@ -67,6 +67,7 @@ }]>, ]; let genVerifyDecl = 1; + let hasCustomAssemblyFormat = 1; let skipDefaultBuilders = 1; } diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td @@ -80,6 +80,7 @@ ); let genVerifyDecl = 1; + let hasCustomAssemblyFormat = 1; let extraClassDeclaration = [{ // Dimension level types that define sparse tensors: diff --git 
a/mlir/include/mlir/IR/AttrTypeBase.td b/mlir/include/mlir/IR/AttrTypeBase.td --- a/mlir/include/mlir/IR/AttrTypeBase.td +++ b/mlir/include/mlir/IR/AttrTypeBase.td @@ -188,20 +188,20 @@ // Use the lowercased name as the keyword for parsing/printing. Specify only // if you want tblgen to generate declarations and/or definitions of - // the printer/parser. + // the printer/parser. If specified and the Attribute or Type contains + // parameters, `assemblyFormat` or `hasCustomAssemblyFormat` must also be + // specified. string mnemonic = ?; - // If 'mnemonic' specified, - // If null, generate just the declarations. - // If a non-empty code block, just use that code as the definition code. - // Error if an empty code block. - code printer = ?; - code parser = ?; - // Custom assembly format. Requires 'mnemonic' to be specified. Cannot be - // specified at the same time as either 'printer' or 'parser'. The generated + // specified at the same time as 'hasCustomAssemblyFormat'. The generated // printer requires 'genAccessors' to be true. string assemblyFormat = ?; + // This field indicates that the attribute or type has a custom assembly format + // implemented in C++. When set to `1`, `parse` and `print` methods are + // declared on the generated class. The attribute or type must implement these + // methods to support the custom format. + bit hasCustomAssemblyFormat = 0; // If set, generate accessors for each parameter. bit genAccessors = 1; diff --git a/mlir/include/mlir/TableGen/AttrOrTypeDef.h b/mlir/include/mlir/TableGen/AttrOrTypeDef.h --- a/mlir/include/mlir/TableGen/AttrOrTypeDef.h +++ b/mlir/include/mlir/TableGen/AttrOrTypeDef.h @@ -175,30 +175,13 @@ /// supposed to auto-generate them. Optional getMnemonic() const; - /// Returns the code to use as the types printer method. If not specified, - /// return a non-value. Otherwise, return the contents of that code block. - Optional getPrinterCode() const; - - /// Returns the code to use as the parser method.
If not specified, returns - /// None. Otherwise, returns the contents of that code block. - Optional getParserCode() const; + /// Returns true if the attribute or type has a custom assembly format + /// implemented in C++. Corresponds to the `hasCustomAssemblyFormat` field. + bool hasCustomAssemblyFormat() const; /// Returns the custom assembly format, if one was specified. Optional getAssemblyFormat() const; - /// An attribute or type with parameters needs a parser. - bool needsParserPrinter() const { return getNumParameters() != 0; } - - /// Returns true if this attribute or type has a generated parser. - bool hasGeneratedParser() const { - return getParserCode() || getAssemblyFormat(); - } - - /// Returns true if this attribute or type has a generated printer. - bool hasGeneratedPrinter() const { - return getPrinterCode() || getAssemblyFormat(); - } - /// Returns true if the accessors based on the parameters should be generated. bool genAccessors() const; diff --git a/mlir/lib/TableGen/AttrOrTypeDef.cpp b/mlir/lib/TableGen/AttrOrTypeDef.cpp --- a/mlir/lib/TableGen/AttrOrTypeDef.cpp +++ b/mlir/lib/TableGen/AttrOrTypeDef.cpp @@ -62,6 +62,30 @@ for (unsigned i = 0, e = parametersDag->getNumArgs(); i < e; ++i) parameters.push_back(AttrOrTypeParameter(parametersDag, i)); } + + // Verify the use of the mnemonic field.
+ bool hasCppFormat = hasCustomAssemblyFormat(); + bool hasDeclarativeFormat = getAssemblyFormat().hasValue(); + if (getMnemonic()) { + if (hasCppFormat && hasDeclarativeFormat) { + PrintFatalError(getLoc(), "cannot specify both 'assemblyFormat' " + "and 'hasCustomAssemblyFormat'"); + } + if (!parameters.empty() && !hasCppFormat && !hasDeclarativeFormat) { + PrintFatalError(getLoc(), + "must specify either 'assemblyFormat' or " + "'hasCustomAssemblyFormat' when 'mnemonic' is set"); + } + } else if (hasCppFormat || hasDeclarativeFormat) { + PrintFatalError(getLoc(), + "'assemblyFormat' or 'hasCustomAssemblyFormat' can only be " + "used when 'mnemonic' is set"); + } + // Assembly format requires accessors to be generated. + if (hasDeclarativeFormat && !genAccessors()) { + PrintFatalError(getLoc(), + "'assemblyFormat' requires 'genAccessors' to be true"); + } } Dialect AttrOrTypeDef::getDialect() const { @@ -122,12 +146,8 @@ return def->getValueAsOptionalString("mnemonic"); } -Optional AttrOrTypeDef::getPrinterCode() const { - return def->getValueAsOptionalString("printer"); -} - -Optional AttrOrTypeDef::getParserCode() const { - return def->getValueAsOptionalString("parser"); +bool AttrOrTypeDef::hasCustomAssemblyFormat() const { + return def->getValueAsBit("hasCustomAssemblyFormat"); } Optional AttrOrTypeDef::getAssemblyFormat() const { diff --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td --- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td +++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td @@ -43,6 +43,7 @@ "An example of an array of ints" // Parameter description. 
>: $arrayOfInts ); + let hasCustomAssemblyFormat = 1; } def CompoundAttrNested : Test_Attr<"CompoundAttrNested"> { let mnemonic = "cmpnd_nested"; @@ -54,6 +55,7 @@ def AttrWithSelfTypeParam : Test_Attr<"AttrWithSelfTypeParam"> { let mnemonic = "attr_with_self_type_param"; let parameters = (ins AttributeSelfTypeParameter<"">:$type); + let hasCustomAssemblyFormat = 1; } // An attribute testing AttributeSelfTypeParameter. @@ -61,6 +63,7 @@ let mnemonic = "attr_with_type_builder"; let parameters = (ins "::mlir::IntegerAttr":$attr); let typeBuilder = "$_attr.getType()"; + let hasCustomAssemblyFormat = 1; } def TestAttrTrait : NativeAttrTrait<"TestAttrTrait">; @@ -68,7 +71,6 @@ // The definition of a singleton attribute that has a trait. def AttrWithTrait : Test_Attr<"AttrWithTrait", [TestAttrTrait]> { let mnemonic = "attr_with_trait"; - let parameters = (ins ); } // Test support for ElementsAttrInterface. @@ -106,6 +108,7 @@ } }]; let genVerifyDecl = 1; + let hasCustomAssemblyFormat = 1; } def TestSubElementsAccessAttr : Test_Attr<"TestSubElementsAccess", [ @@ -120,6 +123,7 @@ "::mlir::Attribute":$second, "::mlir::Attribute":$third ); + let hasCustomAssemblyFormat = 1; } // A more complex parameterized attribute with multiple level of nesting. diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td --- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td +++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td @@ -48,6 +48,7 @@ let extraClassDeclaration = [{ struct SomeCppStruct {}; }]; + let hasCustomAssemblyFormat = 1; } // A more complex and nested parameterized type. @@ -92,12 +93,8 @@ "::test::TestIntegerType::SignednessSemantics":$signedness ); - // We define the printer inline. - let printer = [{ - $_printer << "<"; - printSignedness($_printer, getImpl()->signedness); - $_printer << ", " << getImpl()->width << ">"; - }]; + // Indicate we use a custom format. + let hasCustomAssemblyFormat = 1; // Define custom builder methods. 
let builders = [ @@ -108,19 +105,6 @@ ]; let skipDefaultBuilders = 1; - // The parser is defined here also. - let parser = [{ - if ($_parser.parseLess()) return Type(); - SignednessSemantics signedness; - if (parseSignedness($_parser, signedness)) return Type(); - if ($_parser.parseComma()) return Type(); - int width; - if ($_parser.parseInteger(width)) return Type(); - if ($_parser.parseGreater()) return Type(); - Location loc = $_parser.getEncodedSourceLoc($_parser.getNameLoc()); - return getChecked(loc, loc.getContext(), width, signedness); - }]; - // Any extra code one wants in the type's class declaration. let extraClassDeclaration = [{ /// Signedness semantics. @@ -150,37 +134,7 @@ "::test::FieldInfo", // FieldInfo is defined/declared in TestTypes.h. "Models struct fields">: $fields ); - - // Prints the type in this format: - // struct<[{field1Name, field1Type}, {field2Name, field2Type}] - let printer = [{ - $_printer << "<"; - for (size_t i=0, e = getImpl()->fields.size(); i < e; i++) { - const auto& field = getImpl()->fields[i]; - $_printer << "{" << field.name << "," << field.type << "}"; - if (i < getImpl()->fields.size() - 1) - $_printer << ","; - } - $_printer << ">"; - }]; - - // Parses the above format - let parser = [{ - llvm::SmallVector parameters; - if ($_parser.parseLess()) return Type(); - while (mlir::succeeded($_parser.parseOptionalLBrace())) { - llvm::StringRef name; - if ($_parser.parseKeyword(&name)) return Type(); - if ($_parser.parseComma()) return Type(); - Type type; - if ($_parser.parseType(type)) return Type(); - if ($_parser.parseRBrace()) return Type(); - parameters.push_back(FieldInfo {name, type}); - if ($_parser.parseOptionalComma()) break; - } - if ($_parser.parseGreater()) return Type(); - return get($_ctxt, parameters); - }]; + let hasCustomAssemblyFormat = 1; } def StructType : FieldInfo_Type<"Struct"> { @@ -208,6 +162,7 @@ public: }]; + let hasCustomAssemblyFormat = 1; } def TestMemRefElementType : 
Test_Type<"TestMemRefElementType", diff --git a/mlir/test/lib/Dialect/Test/TestTypes.cpp b/mlir/test/lib/Dialect/Test/TestTypes.cpp --- a/mlir/test/lib/Dialect/Test/TestTypes.cpp +++ b/mlir/test/lib/Dialect/Test/TestTypes.cpp @@ -83,6 +83,65 @@ return llvm::hash_combine(fi.name, fi.type); } +//===----------------------------------------------------------------------===// +// TestCustomType +//===----------------------------------------------------------------------===// + +static LogicalResult parseCustomTypeA(AsmParser &parser, + FailureOr &a_result) { + a_result.emplace(); + return parser.parseInteger(*a_result); +} + +static void printCustomTypeA(AsmPrinter &printer, int a) { printer << a; } + +static LogicalResult parseCustomTypeB(AsmParser &parser, int a, + FailureOr> &b_result) { + if (a < 0) + return success(); + for (int i : llvm::seq(0, a)) + if (failed(parser.parseInteger(i))) + return failure(); + b_result.emplace(0); + return parser.parseInteger(**b_result); +} + +static void printCustomTypeB(AsmPrinter &printer, int a, Optional b) { + if (a < 0) + return; + printer << ' '; + for (int i : llvm::seq(0, a)) + printer << i << ' '; + printer << *b; +} + +static LogicalResult parseFooString(AsmParser &parser, + FailureOr &foo) { + std::string result; + if (parser.parseString(&result)) + return failure(); + foo = std::move(result); + return success(); +} + +static void printFooString(AsmPrinter &printer, StringRef foo) { + printer << '"' << foo << '"'; +} + +static LogicalResult parseBarString(AsmParser &parser, StringRef foo) { + return parser.parseKeyword(foo); +} + +static void printBarString(AsmPrinter &printer, StringRef foo) { + printer << ' ' << foo; +} +//===----------------------------------------------------------------------===// +// Tablegen Generated Definitions +//===----------------------------------------------------------------------===// + +#define GET_TYPEDEF_CLASSES +#include "TestTypeDefs.cpp.inc" + 
//===----------------------------------------------------------------------===// // CompoundAType //===----------------------------------------------------------------------===// @@ -129,6 +188,54 @@ return success(); } +Type TestIntegerType::parse(AsmParser &parser) { + SignednessSemantics signedness; + int width; + if (parser.parseLess() || parseSignedness(parser, signedness) || + parser.parseComma() || parser.parseInteger(width) || + parser.parseGreater()) + return Type(); + Location loc = parser.getEncodedSourceLoc(parser.getNameLoc()); + return getChecked(loc, loc.getContext(), width, signedness); +} + +void TestIntegerType::print(AsmPrinter &p) const { + p << "<"; + printSignedness(p, getSignedness()); + p << ", " << getWidth() << ">"; +} + +//===----------------------------------------------------------------------===// +// TestStructType +//===----------------------------------------------------------------------===// + +Type StructType::parse(AsmParser &p) { + SmallVector parameters; + if (p.parseLess()) + return Type(); + while (succeeded(p.parseOptionalLBrace())) { + Type type; + StringRef name; + if (p.parseKeyword(&name) || p.parseComma() || p.parseType(type) || + p.parseRBrace()) + return Type(); + parameters.push_back(FieldInfo{name, type}); + if (p.parseOptionalComma()) + break; + } + if (p.parseGreater()) + return Type(); + return get(p.getContext(), parameters); +} + +void StructType::print(AsmPrinter &p) const { + p << "<"; + llvm::interleaveComma(getFields(), p, [&](const FieldInfo &field) { + p << "{" << field.name << "," << field.type << "}"; + }); + p << ">"; +} + //===----------------------------------------------------------------------===// // TestType //===----------------------------------------------------------------------===// @@ -208,66 +315,6 @@ return 1; } -//===----------------------------------------------------------------------===// -// TestCustomType 
-//===----------------------------------------------------------------------===// - -static LogicalResult parseCustomTypeA(AsmParser &parser, - FailureOr &a_result) { - a_result.emplace(); - return parser.parseInteger(*a_result); -} - -static void printCustomTypeA(AsmPrinter &printer, int a) { printer << a; } - -static LogicalResult parseCustomTypeB(AsmParser &parser, int a, - FailureOr> &b_result) { - if (a < 0) - return success(); - for (int i : llvm::seq(0, a)) - if (failed(parser.parseInteger(i))) - return failure(); - b_result.emplace(0); - return parser.parseInteger(**b_result); -} - -static void printCustomTypeB(AsmPrinter &printer, int a, Optional b) { - if (a < 0) - return; - printer << ' '; - for (int i : llvm::seq(0, a)) - printer << i << ' '; - printer << *b; -} - -static LogicalResult parseFooString(AsmParser &parser, - FailureOr &foo) { - std::string result; - if (parser.parseString(&result)) - return failure(); - foo = std::move(result); - return success(); -} - -static void printFooString(AsmPrinter &printer, StringRef foo) { - printer << '"' << foo << '"'; -} - -static LogicalResult parseBarString(AsmParser &parser, StringRef foo) { - return parser.parseKeyword(foo); -} - -static void printBarString(AsmPrinter &printer, StringRef foo) { - printer << ' ' << foo; -} - -//===----------------------------------------------------------------------===// -// Tablegen Generated Definitions -//===----------------------------------------------------------------------===// - -#define GET_TYPEDEF_CLASSES -#include "TestTypeDefs.cpp.inc" - //===----------------------------------------------------------------------===// // TestDialect //===----------------------------------------------------------------------===// diff --git a/mlir/test/mlir-tblgen/attrdefs.td b/mlir/test/mlir-tblgen/attrdefs.td --- a/mlir/test/mlir-tblgen/attrdefs.td +++ b/mlir/test/mlir-tblgen/attrdefs.td @@ -60,6 +60,7 @@ ); let genVerifyDecl = 1; + let hasCustomAssemblyFormat = 1; // 
DECL-LABEL: class CompoundAAttr : public ::mlir::Attribute // DECL: static CompoundAAttr getChecked(::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, ::mlir::MLIRContext *context, int widthOfSomething, ::test::SimpleTypeA exampleTdType, ::llvm::APFloat apFloat, ::llvm::ArrayRef dims, ::mlir::Type inner); @@ -102,6 +103,7 @@ ins StringRefParameter<"Label for index">:$label ); + let hasCustomAssemblyFormat = 1; // DECL-LABEL: class IndexAttr : public ::mlir::Attribute // DECL: static constexpr ::llvm::StringLiteral getMnemonic() { @@ -127,6 +129,7 @@ let mnemonic = "attr_with_type_builder"; let parameters = (ins "::mlir::IntegerAttr":$attr); let typeBuilder = "$_attr.getType()"; + let hasCustomAssemblyFormat = 1; } // DEF-LABEL: struct AttrWithTypeBuilderAttrStorage diff --git a/mlir/test/mlir-tblgen/testdialect-typedefs.mlir b/mlir/test/mlir-tblgen/testdialect-typedefs.mlir --- a/mlir/test/mlir-tblgen/testdialect-typedefs.mlir +++ b/mlir/test/mlir-tblgen/testdialect-typedefs.mlir @@ -38,7 +38,7 @@ return } -// CHECK: @structTest(%arg0: !test.struct<{field1,!test.smpla},{field2,!test.int}>) +// CHECK: @structTest(%arg0: !test.struct<{field1,!test.smpla}, {field2,!test.int}>) func @structTest (%A : !test.struct< {field1, !test.smpla}, {field2, !test.int} > ) { return } diff --git a/mlir/test/mlir-tblgen/typedefs.td b/mlir/test/mlir-tblgen/typedefs.td --- a/mlir/test/mlir-tblgen/typedefs.td +++ b/mlir/test/mlir-tblgen/typedefs.td @@ -35,8 +35,8 @@ def Test_Dialect: Dialect { // DECL-NOT: TestDialect - let name = "TestDialect"; - let cppNamespace = "::test"; + let name = "TestDialect"; + let cppNamespace = "::test"; } class TestType : TypeDef { } @@ -54,16 +54,16 @@ let summary = "A more complex parameterized type"; let description = "This type is to test a reasonably complex type"; let mnemonic = "cmpnd_a"; - let parameters = ( - ins - "int":$widthOfSomething, - "::test::SimpleTypeA": $exampleTdType, - "SomeCppStruct": $exampleCppType, - 
ArrayRefParameter<"int", "Matrix dimensions">:$dims, - RTLValueType:$inner + let parameters = (ins + "int":$widthOfSomething, + "::test::SimpleTypeA": $exampleTdType, + "SomeCppStruct": $exampleCppType, + ArrayRefParameter<"int", "Matrix dimensions">:$dims, + RTLValueType:$inner ); let genVerifyDecl = 1; + let hasCustomAssemblyFormat = 1; // DECL-LABEL: class CompoundAType : public ::mlir::Type // DECL: static CompoundAType getChecked(::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, ::mlir::MLIRContext *context, int widthOfSomething, ::test::SimpleTypeA exampleTdType, SomeCppStruct exampleCppType, ::llvm::ArrayRef dims, ::mlir::Type inner); @@ -79,12 +79,12 @@ } def C_IndexType : TestType<"Index"> { - let mnemonic = "index"; + let mnemonic = "index"; - let parameters = ( - ins - StringRefParameter<"Label for index">:$label - ); + let parameters = (ins + StringRefParameter<"Label for index">:$label + ); + let hasCustomAssemblyFormat = 1; // DECL-LABEL: class IndexType : public ::mlir::Type // DECL: static constexpr ::llvm::StringLiteral getMnemonic() { @@ -95,8 +95,7 @@ } def D_SingleParameterType : TestType<"SingleParameter"> { - let parameters = ( - ins + let parameters = (ins "int": $num ); // DECL-LABEL: struct SingleParameterTypeStorage; @@ -105,17 +104,17 @@ } def E_IntegerType : TestType<"Integer"> { - let mnemonic = "int"; - let genVerifyDecl = 1; - let parameters = ( - ins - "SignednessSemantics":$signedness, - TypeParameter<"unsigned", "Bitwidth of integer">:$width - ); + let mnemonic = "int"; + let genVerifyDecl = 1; + let hasCustomAssemblyFormat = 1; + let parameters = (ins + "SignednessSemantics":$signedness, + TypeParameter<"unsigned", "Bitwidth of integer">:$width + ); // DECL-LABEL: IntegerType : public ::mlir::Type - let extraClassDeclaration = [{ + let extraClassDeclaration = [{ /// Signedness semantics. 
enum SignednessSemantics { Signless, /// No signedness semantics @@ -132,7 +131,7 @@ bool isSigned() const { return getSignedness() == Signed; } /// Return true if this is an unsigned integer type. bool isUnsigned() const { return getSignedness() == Unsigned; } - }]; + }]; // DECL: /// Signedness semantics. // DECL-NEXT: enum SignednessSemantics { diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp --- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp +++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp @@ -121,10 +121,6 @@ /// Emit a checked custom builder. void emitCheckedCustomBuilder(const AttrOrTypeBuilder &builder); - //===--------------------------------------------------------------------===// - // Parser and Printer Emission - void emitParserPrinterBody(MethodBody &parser, MethodBody &printer); - //===--------------------------------------------------------------------===// // Interface Method Emission @@ -264,9 +260,10 @@ auto *mnemonic = defCls.addStaticMethod( "::llvm::StringLiteral", "getMnemonic"); mnemonic->body().indent() << strfmt("return {\"{0}\"};", *def.getMnemonic()); + // Declare the parser and printer, if needed. - if (!def.needsParserPrinter() && !def.hasGeneratedParser() && - !def.hasGeneratedPrinter()) + bool hasAssemblyFormat = def.getAssemblyFormat().hasValue(); + if (!def.hasCustomAssemblyFormat() && !hasAssemblyFormat) return; // Declare the parser. @@ -274,18 +271,18 @@ parserParams.emplace_back("::mlir::AsmParser &", "odsParser"); if (isa(&def)) parserParams.emplace_back("::mlir::Type", "odsType"); - auto *parser = defCls.addMethod( - strfmt("::mlir::{0}", valueType), "parse", - def.hasGeneratedParser() ? Method::Static : Method::StaticDeclaration, - std::move(parserParams)); + auto *parser = defCls.addMethod(strfmt("::mlir::{0}", valueType), "parse", + hasAssemblyFormat ? Method::Static + : Method::StaticDeclaration, + std::move(parserParams)); // Declare the printer. 
- auto props = - def.hasGeneratedPrinter() ? Method::Const : Method::ConstDeclaration; + auto props = hasAssemblyFormat ? Method::Const : Method::ConstDeclaration; Method *printer = defCls.addMethod("void", "print", props, MethodParameter("::mlir::AsmPrinter &", "odsPrinter")); - // Emit the bodies. - emitParserPrinterBody(parser->body(), printer->body()); + // Emit the bodies if we are using the declarative format. + if (hasAssemblyFormat) + return generateAttrOrTypeFormat(def, parser->body(), printer->body()); } void DefGen::emitAccessors() { @@ -406,50 +403,6 @@ m->body().indent().getStream().printReindented(bodyStr); } -//===----------------------------------------------------------------------===// -// Parser and Printer Emission - -void DefGen::emitParserPrinterBody(MethodBody &parser, MethodBody &printer) { - Optional parserCode = def.getParserCode(); - Optional printerCode = def.getPrinterCode(); - Optional asmFormat = def.getAssemblyFormat(); - // Verify the parser-printer specification first. - if (asmFormat && (parserCode || printerCode)) { - PrintFatalError(def.getLoc(), - def.getName() + ": assembly format cannot be specified at " - "the same time as printer or parser code"); - } - // Specified code cannot be empty. - if (parserCode && parserCode->empty()) - PrintFatalError(def.getLoc(), def.getName() + ": parser cannot be empty"); - if (printerCode && printerCode->empty()) - PrintFatalError(def.getLoc(), def.getName() + ": printer cannot be empty"); - // Assembly format requires accessors to be generated. - if (asmFormat && !def.genAccessors()) { - PrintFatalError(def.getLoc(), - def.getName() + - ": the generated printer from 'assemblyFormat' " - "requires 'genAccessors' to be true"); - } - - // Generate the parser and printer bodies. 
- if (asmFormat) - return generateAttrOrTypeFormat(def, parser, printer); - - FmtContext ctx = FmtContext({{"_parser", "odsParser"}, - {"_printer", "odsPrinter"}, - {"_type", "odsType"}}); - if (parserCode) { - ctx.addSubst("_ctxt", "odsParser.getContext()"); - parser.indent().getStream().printReindented(tgfmt(*parserCode, &ctx).str()); - } - if (printerCode) { - ctx.addSubst("_ctxt", "odsPrinter.getContext()"); - printer.indent().getStream().printReindented( - tgfmt(*printerCode, &ctx).str()); - } -} - //===----------------------------------------------------------------------===// // Interface Method Emission @@ -829,18 +782,21 @@ for (auto &def : defs) { if (!def.getMnemonic()) continue; + bool hasParserPrinterDecl = + def.hasCustomAssemblyFormat() || def.getAssemblyFormat(); std::string defClass = strfmt( "{0}::{1}", def.getDialect().getCppNamespace(), def.getCppClassName()); + // If the def has no parameters or parser code, invoke a normal `get`. std::string parseOrGet = - def.needsParserPrinter() || def.hasGeneratedParser() + hasParserPrinterDecl ? strfmt("parse(parser{0})", isAttrGenerator ? ", type" : "") : "get(parser.getContext())"; parse.body() << llvm::formatv(getValueForMnemonic, defClass, parseOrGet); // If the def has no parameters and no printer, just print the mnemonic. StringRef printDef = ""; - if (def.needsParserPrinter() || def.hasGeneratedPrinter()) + if (hasParserPrinterDecl) printDef = "\nt.print(printer);"; printer.body() << llvm::formatv(printValue, defClass, printDef); } diff --git a/mlir/tools/mlir-tblgen/OpDocGen.cpp b/mlir/tools/mlir-tblgen/OpDocGen.cpp --- a/mlir/tools/mlir-tblgen/OpDocGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDocGen.cpp @@ -253,8 +253,7 @@ os << "\n" << def.getSummary() << "\n"; // Emit the syntax if present. 
- if (def.getMnemonic() && def.getPrinterCode() == StringRef() && - def.getParserCode() == StringRef()) + if (def.getMnemonic() && !def.hasCustomAssemblyFormat()) emitAttrOrTypeDefAssemblyFormat(def, os); // Emit the description if present.