diff --git a/mlir/lib/TableGen/TypeDef.cpp b/mlir/lib/TableGen/TypeDef.cpp --- a/mlir/lib/TableGen/TypeDef.cpp +++ b/mlir/lib/TableGen/TypeDef.cpp @@ -112,6 +112,8 @@ if (auto *typeParameter = dyn_cast(parameterType)) { llvm::RecordVal *code = typeParameter->getDef()->getValue("allocator"); + if (!code) + return llvm::Optional(); if (llvm::CodeInit *ci = dyn_cast(code->getValue())) return ci->getValue(); if (isa(code->getValue())) diff --git a/mlir/test/mlir-tblgen/typedefs.td b/mlir/test/mlir-tblgen/typedefs.td --- a/mlir/test/mlir-tblgen/typedefs.td +++ b/mlir/test/mlir-tblgen/typedefs.td @@ -6,6 +6,11 @@ // DECL: #ifdef GET_TYPEDEF_CLASSES // DECL: #undef GET_TYPEDEF_CLASSES +// DECL: namespace mlir { +// DECL: class DialectAsmParser; +// DECL: class DialectAsmPrinter; +// DECL: } // namespace mlir + // DECL: ::mlir::Type generatedTypeParser(::mlir::MLIRContext* ctxt, ::mlir::DialectAsmParser& parser, ::llvm::StringRef mnenomic); // DECL: ::mlir::LogicalResult generatedTypePrinter(::mlir::Type type, ::mlir::DialectAsmPrinter& printer); @@ -34,6 +39,10 @@ // DECL: class SimpleAType: public ::mlir::Type } +def RTLValueType : Type, "Type"> { + string cppType = "::mlir::Type"; +} + // A more complex parameterized type def B_CompoundTypeA : TestType<"CompoundA"> { let summary = "A more complex parameterized type"; @@ -44,14 +53,15 @@ "int":$widthOfSomething, "::mlir::test::SimpleTypeA": $exampleTdType, "SomeCppStruct": $exampleCppType, - ArrayRefParameter<"int", "Matrix dimensions">:$dims + ArrayRefParameter<"int", "Matrix dimensions">:$dims, + RTLValueType:$inner ); let genVerifyInvariantsDecl = 1; // DECL-LABEL: class CompoundAType: public ::mlir::Type -// DECL: static ::mlir::LogicalResult verifyConstructionInvariants(Location loc, int widthOfSomething, ::mlir::test::SimpleTypeA exampleTdType, SomeCppStruct exampleCppType, ::llvm::ArrayRef dims); -// DECL: static ::mlir::Type getChecked(Location loc, int widthOfSomething, ::mlir::test::SimpleTypeA exampleTdType, 
SomeCppStruct exampleCppType, ::llvm::ArrayRef dims); +// DECL: static ::mlir::LogicalResult verifyConstructionInvariants(::mlir::Location loc, int widthOfSomething, ::mlir::test::SimpleTypeA exampleTdType, SomeCppStruct exampleCppType, ::llvm::ArrayRef dims, ::mlir::Type inner); +// DECL: static ::mlir::Type getChecked(::mlir::Location loc, int widthOfSomething, ::mlir::test::SimpleTypeA exampleTdType, SomeCppStruct exampleCppType, ::llvm::ArrayRef dims, ::mlir::Type inner); // DECL: static ::llvm::StringRef getMnemonic() { return "cmpnd_a"; } // DECL: static ::mlir::Type parse(::mlir::MLIRContext* ctxt, ::mlir::DialectAsmParser& parser); // DECL: void print(::mlir::DialectAsmPrinter& printer) const; diff --git a/mlir/tools/mlir-tblgen/TypeDefGen.cpp b/mlir/tools/mlir-tblgen/TypeDefGen.cpp --- a/mlir/tools/mlir-tblgen/TypeDefGen.cpp +++ b/mlir/tools/mlir-tblgen/TypeDefGen.cpp @@ -133,6 +133,15 @@ // GEN: TypeDef declarations //===----------------------------------------------------------------------===// +/// Emitted above all other typeDef declarations: forward-declares the +/// DialectAsmParser/DialectAsmPrinter classes used by the generated code. +static const char *const typeDefDeclHeader = R"( +namespace mlir { +class DialectAsmParser; +class DialectAsmPrinter; +} // namespace mlir +)"; + /// The code block for the start of a typeDef class declaration -- singleton /// case. /// @@ -174,8 +183,8 @@ /// /// {0}: List of parameters, parameters style. static const char *const typeDefDeclVerifyStr = R"( - static ::mlir::LogicalResult verifyConstructionInvariants(Location loc{0}); - static ::mlir::Type getChecked(Location loc{0}); + static ::mlir::LogicalResult verifyConstructionInvariants(::mlir::Location loc{0}); + static ::mlir::Type getChecked(::mlir::Location loc{0}); )"; /// Generate the declaration for the given typeDef class. @@ -239,6 +248,10 @@ findAllTypeDefs(recordKeeper, typeDefs); IfDefScope scope("GET_TYPEDEF_CLASSES", os); + + // Output the common "header".
+ os << typeDefDeclHeader; + if (typeDefs.size() > 0) { NamespaceEmitter nsEmitter(os, typeDefs.begin()->getDialect());