diff --git a/llvm/lib/TableGen/Record.cpp b/llvm/lib/TableGen/Record.cpp --- a/llvm/lib/TableGen/Record.cpp +++ b/llvm/lib/TableGen/Record.cpp @@ -2279,34 +2279,6 @@ "Record `" + getName() + "', field `" + FieldName + "' exists but does not have a code initializer!"); } -llvm::Optional -Record::getValueAsOptionalString(StringRef FieldName) const { - const RecordVal *R = getValue(FieldName); - if (!R || !R->getValue()) - return llvm::Optional(); - if (isa(R->getValue())) - return llvm::Optional(); - - if (StringInit *SI = dyn_cast(R->getValue())) - return SI->getValue(); - if (CodeInit *CI = dyn_cast(R->getValue())) - return CI->getValue(); - - PrintFatalError(getLoc(), "Record `" + getName() + "', field `" + FieldName + - "' does not have a string initializer!"); -} -llvm::Optional -Record::getValueAsOptionalCode(StringRef FieldName) const { - const RecordVal *R = getValue(FieldName); - if (!R || !R->getValue()) - return llvm::Optional(); - - if (CodeInit *CI = dyn_cast(R->getValue())) - return CI->getValue(); - - PrintFatalError(getLoc(), "Record `" + getName() + "', field `" + FieldName + - "' does not have a code initializer!"); -} BitsInit *Record::getValueAsBitsInit(StringRef FieldName) const { const RecordVal *R = getValue(FieldName); diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -2380,15 +2380,16 @@ dag parameters = (ins); // Use the lowercased name as the keyword for parsing/printing. Specify only - // if you want tblgen to automatically generate the printer/parser for this + // if you want tblgen to generate decls and/or defs of printer/parser for this // type. string mnemonic = ?; - // If null, generate just the declarations. - // If an empty code block, generate print/parser methods only if 'mnemonic' is specified. - // If a non-empty code block, just use that code as the definition code. - code printer = [{}]; - code parser = [{}]; + // If 'mnemonic' specified, + // If null, generate just the declarations. + // Error if an empty code block. + // If a non-empty code block, just use that code as the definition code. + code printer = ?; + code parser = ?; // If set, generate accessors for each Type parameter. bit genAccessors = 1; @@ -2414,13 +2415,11 @@ // For StringRefs, which require allocation class StringRefParameter : TypeParameter<"::llvm::StringRef", desc> { let allocator = [{$_dst = $_allocator.copyInto($_self);}]; - let syntax = "\"foo bar\""; } // For standard ArrayRefs, which require allocation class ArrayRefParameter : TypeParameter<"::llvm::ArrayRef<" # arrayOf # ">", desc> { let allocator = [{$_dst = $_allocator.copyInto($_self);}]; - let syntax = "[ " # arrayOf # ", " # arrayOf # ", ... ]"; } // For classes which require allocation and have their own allocateInto method @@ -2431,13 +2430,12 @@ // For ArrayRefs which contain things which allocate themselves class ArrayRefOfSelfAllocationParameter : TypeParameter<"::llvm::ArrayRef<" # arrayOf # ">", desc> { let allocator = [{ - llvm::SmallVector<}] # arrayOf # [{, 4> tmpFields($_self.size()); + llvm::SmallVector<}] # arrayOf # [{, 4> tmpFields; for (size_t i = 0; i < $_self.size(); i++) { - tmpFields[i] = $_self[i].allocateInto($_allocator); + tmpFields.push_back($_self[i].allocateInto($_allocator)); } $_dst = $_allocator.copyInto(ArrayRef<}] # arrayOf # [{>(tmpFields)); }]; - let syntax = "[ " # arrayOf # ", " # arrayOf # ", ... ]"; } diff --git a/mlir/include/mlir/TableGen/TypeDef.h b/mlir/include/mlir/TableGen/TypeDef.h --- a/mlir/include/mlir/TableGen/TypeDef.h +++ b/mlir/include/mlir/TableGen/TypeDef.h @@ -69,14 +69,12 @@ // supposed to auto-generate them llvm::Optional getMnemonic() const; - // Returns the code to use as the types printer method. If empty, generate - // just the declaration. If null and mnemonic is non-null, generate the - // declaration and definition. + // Returns the code to use as the types printer method. If not specified, + // return a non-value. Otherwise, return the contents of that code block. llvm::Optional getPrinterCode() const; - // Returns the code to use as the types parser method. If empty, generate - // just the declaration. If null and mnemonic is non-null, generate the - // declaration and definition. + // Returns the code to use as the types parser method. If not specified, + // return a non-value. Otherwise, return the contents of that code block. llvm::Optional getParserCode() const; // Should we generate accessors based on the types parameters? @@ -89,6 +87,9 @@ // Returns the dialects extra class declaration code. llvm::Optional getExtraDecls() const; + // Get the code location (for error printing) + llvm::ArrayRef getLoc() const; + // Returns whether two TypeDefs are equal by checking the equality of the // underlying record. bool operator==(const TypeDef &other) const; diff --git a/mlir/include/mlir/TableGen/TypeDefGenHelpers.h b/mlir/include/mlir/TableGen/TypeDefGenHelpers.h deleted file mode 100644 --- a/mlir/include/mlir/TableGen/TypeDefGenHelpers.h +++ /dev/null @@ -1,239 +0,0 @@ -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// Accessory functions / templates to assist autogenerated code. The print/parse -// struct templates define standard serializations which can be overridden with -// custom printers/parsers. These structs can be used for temporary stack -// storage also. -// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_TABLEGEN_PARSER_HELPERS_H -#define MLIR_TABLEGEN_PARSER_HELPERS_H - -#include "mlir/IR/Attributes.h" -#include "mlir/IR/DialectImplementation.h" -#include - -namespace mlir { -namespace tblgen { -namespace parser_helpers { - -//===----------------------------------------------------------------------===// -// -// Template enables identify various types for which we have specializations -// -//===----------------------------------------------------------------------===// - -template -using void_t = void; - -template -using remove_constref = - typename std::remove_const::type>::type; - -template -using enable_if_type = typename std::enable_if< - std::is_same, TestType>::value>::type; - -template -using is_not_type = - std::is_same, TestType>::type, - typename std::false_type::type>; - -template -using get_indexable_type = remove_constref()[0])>; - -template -using enable_if_arrayref = - enable_if_type>>; - -//===----------------------------------------------------------------------===// -// -// These structs handle Type parameters' parsing for common types -// -//===----------------------------------------------------------------------===// - -template -struct Parse { - ParseResult go(MLIRContext *ctxt, // The context, should it be needed - DialectAsmParser &parser, // The parser - StringRef parameterName, // Type parameter name, for error - // printing (if necessary) - T &result); // Put the parsed value here -}; - -// Int specialization -template -using enable_if_integral_type = - typename std::enable_if::value && - is_not_type::value>::type; -template -struct Parse> { - ParseResult go(MLIRContext *ctxt, DialectAsmParser &parser, - StringRef parameterName, T &result) { - return parser.parseInteger(result); - } -}; - -// Bool specialization -- 'true' / 'false' instead of 0/1 -template -struct Parse> { - ParseResult go(MLIRContext *ctxt, DialectAsmParser &parser, - StringRef parameterName, bool &result) { - StringRef boolStr; - if (parser.parseKeyword(&boolStr)) - return mlir::failure(); - if (!boolStr.compare_lower("false")) { - result = false; - return mlir::success(); - } - if (!boolStr.compare_lower("true")) { - result = true; - return mlir::success(); - } - llvm::errs() << "Parser expected true/false, not '" << boolStr << "'\n"; - return mlir::failure(); - } -}; - -// Float specialization -template -using enable_if_float_type = - typename std::enable_if::value>::type; -template -struct Parse> { - ParseResult go(MLIRContext *ctxt, DialectAsmParser &parser, - StringRef parameterName, T &result) { - double d; - if (parser.parseFloat(d)) - return mlir::failure(); - result = d; - return mlir::success(); - } -}; - -// mlir::Type specialization -template -using enable_if_mlir_type = - typename std::enable_if::value>::type; -template -struct Parse> { - ParseResult go(MLIRContext *ctxt, DialectAsmParser &parser, - StringRef parameterName, T &result) { - Type type; - auto loc = parser.getCurrentLocation(); - if (parser.parseType(type)) - return mlir::failure(); - if ((result = type.dyn_cast_or_null()) == nullptr) { - parser.emitError(loc, "expected type '" + parameterName + "'"); - return mlir::failure(); - } - return mlir::success(); - } -}; - -// StringRef specialization -template -struct Parse> { - ParseResult go(MLIRContext *ctxt, DialectAsmParser &parser, - StringRef parameterName, StringRef &result) { - StringAttr a; - if (parser.parseAttribute(a)) - return mlir::failure(); - result = a.getValue(); - return mlir::success(); - } -}; - -// ArrayRef specialization -template -struct Parse> { - using inner_t = get_indexable_type; - Parse innerParser; - llvm::SmallVector parameters; - - ParseResult go(MLIRContext *ctxt, DialectAsmParser &parser, - StringRef parameterName, ArrayRef &result) { - if (parser.parseLSquare()) - return mlir::failure(); - if (failed(parser.parseOptionalRSquare())) { - do { - inner_t parameter; // = std::declval(); - innerParser.go(ctxt, parser, parameterName, parameter); - parameters.push_back(parameter); - } while (succeeded(parser.parseOptionalComma())); - if (parser.parseRSquare()) - return mlir::failure(); - } - result = ArrayRef(parameters); - return mlir::success(); - } -}; - -//===----------------------------------------------------------------------===// -// -// These structs handle Type parameters' printing for common types -// -//===----------------------------------------------------------------------===// - -template -struct Print { - static void go(DialectAsmPrinter &printer, const T &obj); -}; - -// Several C++ types can just be piped into the printer -template -using enable_if_trivial_print = - typename std::enable_if::value || - (std::is_integral::value && - is_not_type::value) || - std::is_floating_point::value>::type; -template -struct Print>> { - static void go(DialectAsmPrinter &printer, const T &obj) { printer << obj; } -}; - -// StringRef has to be quoted to match the parse specialization above -template -struct Print> { - static void go(DialectAsmPrinter &printer, const T &obj) { - printer << "\"" << obj << "\""; - } -}; - -// bool specialization -template -struct Print> { - static void go(DialectAsmPrinter &printer, const bool &obj) { - if (obj) - printer << "true"; - else - printer << "false"; - } -}; - -// ArrayRef specialization -template -struct Print> { - static void go(DialectAsmPrinter &printer, - const ArrayRef> &obj) { - printer << "["; - for (size_t i = 0; i < obj.size(); i++) { - Print>::go(printer, obj[i]); - if (i < obj.size() - 1) - printer << ", "; - } - printer << "]"; - } -}; - -} // end namespace parser_helpers -} // end namespace tblgen -} // end namespace mlir - -#endif // MLIR_TABLEGEN_PARSER_HELPERS_H diff --git a/mlir/lib/TableGen/TypeDef.cpp b/mlir/lib/TableGen/TypeDef.cpp --- a/mlir/lib/TableGen/TypeDef.cpp +++ b/mlir/lib/TableGen/TypeDef.cpp @@ -90,12 +90,11 @@ bool TypeDef::genVerifyInvariantsDecl() const { return def->getValueAsBit("genVerifyInvariantsDecl"); } - llvm::Optional TypeDef::getExtraDecls() const { auto value = def->getValueAsString("extraClassDeclaration"); return value.empty() ? llvm::Optional() : value; } - +llvm::ArrayRef TypeDef::getLoc() const { return def->getLoc(); } bool TypeDef::operator==(const TypeDef &other) const { return def == other.def; } diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td --- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td +++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td @@ -16,6 +16,9 @@ def SimpleTypeA : Test_Type<"SimpleA"> { let mnemonic = "smpla"; + + let printer = [{ $_printer << "smpla"; }]; + let parser = [{ return get($_ctxt); }]; } // A more complex parameterized type @@ -27,11 +30,8 @@ let parameters = ( ins "int":$widthOfSomething, - "::mlir::SimpleAType": $exampleTdType, - ArrayRefParameter<"int", "">: $arrayOfInts, - ArrayRefParameter<"Type", "An example of an array of types as a type parameter">: $arrayOfTypes, - "::llvm::StringRef": $simpleString, - ArrayRefParameter<"::llvm::StringRef", "">: $arrayOfStrings + "::mlir::Type":$oneType, + ArrayRefParameter<"int", "An example of an array of ints">: $arrayOfInts ); let extraClassDeclaration = [{ @@ -40,15 +40,32 @@ } def IntegerType : Test_Type<"TestInteger"> { - let mnemonic = "int"; - let genVerifyInvariantsDecl = 1; - let parameters = ( - ins - "::mlir::TestIntegerType::SignednessSemantics":$signedness, - "unsigned":$width - ); - - let extraClassDeclaration = [{ + let mnemonic = "int"; + let genVerifyInvariantsDecl = 1; + let parameters = ( + ins + "::mlir::TestIntegerType::SignednessSemantics":$signedness, + "unsigned":$width + ); + + let printer = [{ + $_printer << "int<"; + Print($_printer, getImpl()->signedness); + $_printer << ", " << getImpl()->width << ">"; + }]; + + let parser = [{ + if (parser.parseLess()) return Type(); + SignednessSemantics signedness; + if (Parse($_parser, signedness)) return mlir::Type(); + if ($_parser.parseComma()) return Type(); + int width; + if ($_parser.parseInteger(width)) return Type(); + if ($_parser.parseGreater()) return Type(); + return get(ctxt, signedness, width); + }]; + + let extraClassDeclaration = [{ /// Signedness semantics. enum SignednessSemantics { Signless, /// No signedness semantics @@ -69,38 +86,38 @@ } class FieldInfo_Type : Test_Type { -let parameters = ( - ins - ArrayRefOfSelfAllocationParameter<"::mlir::FieldInfo", "Models struct fields">: $fields -); - -let printer = [{ - printer << "struct" << "<"; - for (size_t i=0; ifields.size(); i++) { - const auto& field = getImpl()->fields[i]; - printer << "{" << field.name << "," << field.type << "}"; - if (i < getImpl()->fields.size() - 1) - printer << ","; - } - printer << ">"; -}]; - -let parser = [{ - llvm::SmallVector parameters; - if (parser.parseLess()) return Type(); - while (mlir::succeeded(parser.parseOptionalLBrace())) { - StringRef name; - if (parser.parseKeyword(&name)) return Type(); - if (parser.parseComma()) return Type(); - Type type; - if (parser.parseType(type)) return Type(); - if (parser.parseRBrace()) return Type(); - parameters.push_back(FieldInfo {name, type}); - if (parser.parseOptionalComma()) break; - } - if (parser.parseGreater()) return Type(); - return get(ctxt, parameters); -}]; + let parameters = ( + ins + ArrayRefOfSelfAllocationParameter<"::mlir::FieldInfo", "Models struct fields">: $fields + ); + + let printer = [{ + $_printer << "struct" << "<"; + for (size_t i=0; ifields.size(); i++) { + const auto& field = getImpl()->fields[i]; + $_printer << "{" << field.name << "," << field.type << "}"; + if (i < getImpl()->fields.size() - 1) + $_printer << ","; + } + $_printer << ">"; + }]; + + let parser = [{ + llvm::SmallVector parameters; + if ($_parser.parseLess()) return Type(); + while (mlir::succeeded($_parser.parseOptionalLBrace())) { + StringRef name; + if ($_parser.parseKeyword(&name)) return Type(); + if ($_parser.parseComma()) return Type(); + Type type; + if ($_parser.parseType(type)) return Type(); + if ($_parser.parseRBrace()) return Type(); + parameters.push_back(FieldInfo {name, type}); + if ($_parser.parseOptionalComma()) break; + } + if ($_parser.parseGreater()) return Type(); + return get($_ctxt, parameters); + }]; } def StructType : FieldInfo_Type<"Struct"> { diff --git a/mlir/test/lib/Dialect/Test/TestTypes.cpp b/mlir/test/lib/Dialect/Test/TestTypes.cpp --- a/mlir/test/lib/Dialect/Test/TestTypes.cpp +++ b/mlir/test/lib/Dialect/Test/TestTypes.cpp @@ -14,59 +14,89 @@ #include "TestTypes.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/Types.h" -#include "mlir/TableGen/TypeDefGenHelpers.h" #include "llvm/ADT/Hashing.h" #include "llvm/ADT/TypeSwitch.h" namespace mlir { -namespace tblgen { -namespace parser_helpers { // Custom parser for SignednessSemantics -template <> -struct Parse { - static ParseResult go(MLIRContext *ctxt, DialectAsmParser &parser, - StringRef parameterName, - TestIntegerType::SignednessSemantics &result) { - StringRef signStr; - auto loc = parser.getCurrentLocation(); - if (parser.parseKeyword(&signStr)) - return mlir::failure(); - if (signStr.compare_lower("u") || signStr.compare_lower("unsigned")) - result = TestIntegerType::SignednessSemantics::Unsigned; - else if (signStr.compare_lower("s") || signStr.compare_lower("signed")) - result = TestIntegerType::SignednessSemantics::Signed; - else if (signStr.compare_lower("n") || signStr.compare_lower("none")) - result = TestIntegerType::SignednessSemantics::Signless; - else { - parser.emitError(loc, "expected signed, unsigned, or none"); - return mlir::failure(); - } - return mlir::success(); +static ParseResult Parse(DialectAsmParser &parser, + TestIntegerType::SignednessSemantics &result) { + StringRef signStr; + auto loc = parser.getCurrentLocation(); + if (parser.parseKeyword(&signStr)) + return mlir::failure(); + if (signStr.compare_lower("u") || signStr.compare_lower("unsigned")) + result = TestIntegerType::SignednessSemantics::Unsigned; + else if (signStr.compare_lower("s") || signStr.compare_lower("signed")) + result = TestIntegerType::SignednessSemantics::Signed; + else if (signStr.compare_lower("n") || signStr.compare_lower("none")) + result = TestIntegerType::SignednessSemantics::Signless; + else { + parser.emitError(loc, "expected signed, unsigned, or none"); + return mlir::failure(); } -}; + return mlir::success(); +} // Custom printer for SignednessSemantics -template <> -struct Print { - static void go(DialectAsmPrinter &printer, - const TestIntegerType::SignednessSemantics &ss) { - switch (ss) { - case TestIntegerType::SignednessSemantics::Unsigned: - printer << "unsigned"; - break; - case TestIntegerType::SignednessSemantics::Signed: - printer << "signed"; - break; - case TestIntegerType::SignednessSemantics::Signless: - printer << "none"; +static void Print(DialectAsmPrinter &printer, + const TestIntegerType::SignednessSemantics &ss) { + switch (ss) { + case TestIntegerType::SignednessSemantics::Unsigned: + printer << "unsigned"; + break; + case TestIntegerType::SignednessSemantics::Signed: + printer << "signed"; + break; + case TestIntegerType::SignednessSemantics::Signless: + printer << "none"; + break; + } +} + +Type CompoundAType::parse(::mlir::MLIRContext *ctxt, + ::mlir::DialectAsmParser &parser) { + int widthOfSomething; + Type oneType; + SmallVector arrayOfInts; + if (parser.parseLess()) + return Type(); + if (parser.parseInteger(widthOfSomething)) + return Type(); + if (parser.parseComma()) + return Type(); + if (parser.parseType(oneType)) + return Type(); + if (parser.parseComma()) + return Type(); + + if (parser.parseLSquare()) + return Type(); + int i; + while (!*parser.parseOptionalInteger(i)) { + arrayOfInts.push_back(i); + if (parser.parseOptionalComma()) break; - } } -}; + if (parser.parseRSquare()) + return Type(); + if (parser.parseGreater()) + return Type(); -} // namespace parser_helpers -} // namespace tblgen + return get(ctxt, widthOfSomething, oneType, arrayOfInts); +} +void CompoundAType::print(::mlir::DialectAsmPrinter &printer) const { + printer << "cmpnd_a<" << getWidthOfSomething() << ", " << getOneType() + << ", ["; + auto intArray = getArrayOfInts(); + for (size_t idx = 0; idx < intArray.size(); idx++) { + printer << intArray[idx]; + if (idx < intArray.size() - 1) + printer << ", "; + } + printer << "]>"; +} bool operator==(const FieldInfo &a, const FieldInfo &b) { return a.name == b.name && a.type == b.type; @@ -86,7 +116,6 @@ return mlir::success(); } -struct TestType; } // end namespace mlir #define GET_TYPEDEF_CLASSES diff --git a/mlir/test/mlir-tblgen/testdialect-typedefs.mlir b/mlir/test/mlir-tblgen/testdialect-typedefs.mlir --- a/mlir/test/mlir-tblgen/testdialect-typedefs.mlir +++ b/mlir/test/mlir-tblgen/testdialect-typedefs.mlir @@ -8,8 +8,8 @@ return } -// CHECK: @compoundA(%arg0: !test.cmpnd_a<1, !test.smpla, [5, 6], [i1, i2], "example str", ["array", "of", "strings"]>) -func @compoundA(%A : !test.cmpnd_a<1, !test.smpla, [5, 6], [i1, i2], "example str", ["array","of","strings"]>) -> () { +// CHECK: @compoundA(%arg0: !test.cmpnd_a<1, !test.smpla, [5, 6]>) +func @compoundA(%A : !test.cmpnd_a<1, !test.smpla, [5, 6]>)-> () { return } @@ -20,5 +20,5 @@ // CHECK: @structTest(%arg0: !test.struct<{field1,!test.smpla},{field2,!test.int}>) func @structTest (%A : !test.struct< {field1, !test.smpla}, {field2, !test.int} > ) { - return + return } diff --git a/mlir/test/mlir-tblgen/typedefs.td b/mlir/test/mlir-tblgen/typedefs.td --- a/mlir/test/mlir-tblgen/typedefs.td +++ b/mlir/test/mlir-tblgen/typedefs.td @@ -32,8 +32,6 @@ def A_SimpleTypeA : TestType<"SimpleA"> { // DECL: class SimpleAType: public ::mlir::Type -// DECL: static ::mlir::Type parse(::mlir::MLIRContext* ctxt, ::mlir::DialectAsmParser& parser); -// DECL: void print(::mlir::DialectAsmPrinter& printer) const; } // A more complex parameterized type @@ -55,6 +53,8 @@ // DECL: static ::mlir::LogicalResult verifyConstructionInvariants(Location loc, int widthOfSomething, ::mlir::test::SimpleTypeA exampleTdType, SomeCppStruct exampleCppType, ::llvm::ArrayRef dims); // DECL: static CompoundAType getChecked(Location loc, int widthOfSomething, ::mlir::test::SimpleTypeA exampleTdType, SomeCppStruct exampleCppType, ::llvm::ArrayRef dims); // DECL: static ::llvm::StringRef getMnemonic() { return "cmpnd_a"; } +// DECL: static ::mlir::Type parse(::mlir::MLIRContext* ctxt, ::mlir::DialectAsmParser& parser); +// DECL: void print(::mlir::DialectAsmPrinter& printer) const; // DECL: int getWidthOfSomething() const; // DECL: ::mlir::test::SimpleTypeA getExampleTdType() const; // DECL: SomeCppStruct getExampleCppType() const; @@ -69,9 +69,9 @@ ); // DECL-LABEL: class IndexType: public ::mlir::Type +// DECL: static ::llvm::StringRef getMnemonic() { return "index"; } // DECL: static ::mlir::Type parse(::mlir::MLIRContext* ctxt, ::mlir::DialectAsmParser& parser); // DECL: void print(::mlir::DialectAsmPrinter& printer) const; -// DECL: static ::llvm::StringRef getMnemonic() { return "index"; } } def D_SingleParameterType : TestType<"SingleParameter"> { @@ -82,18 +82,14 @@ // DECL-LABEL: struct SingleParameterTypeStorage; // DECL-LABEL: class SingleParameterType // DECL-NEXT: detail::SingleParameterTypeStorage -// DECL: static ::mlir::Type parse(::mlir::MLIRContext* ctxt, ::mlir::DialectAsmParser& parser); -// DECL: void print(::mlir::DialectAsmPrinter& printer) const; } def E_IntegerType : TestType<"Integer"> { - let parser = [{}]; - let printer = [{}]; let mnemonic = "int"; let genVerifyInvariantsDecl = 1; let parameters = ( ins - "SignednessSemantics":$signedness, + "SignednessSemantics":$signedness, TypeParameter<"unsigned", "Bitwdith of integer">:$width ); diff --git a/mlir/tools/mlir-tblgen/TypeDefGen.cpp b/mlir/tools/mlir-tblgen/TypeDefGen.cpp --- a/mlir/tools/mlir-tblgen/TypeDefGen.cpp +++ b/mlir/tools/mlir-tblgen/TypeDefGen.cpp @@ -132,6 +132,7 @@ /// {0}: The name of the typeDef class. /// {1}: The typeDef storage class namespace. /// {2}: The storage class name +/// {3}: The list of parameters with types static const char *const typeDefDeclParametricBeginStr = R"( namespace {1} { struct {2}; @@ -141,6 +142,7 @@ public: /// Inherit some necessary constructors from 'TypeBase'. using Base::Base; + )"; // snippet for print/parse @@ -176,10 +178,13 @@ if (llvm::Optional extraDecl = typeDef.getExtraDecls()) os << *extraDecl; + // Get the CppType1 param1, CppType2 param2 argument list std::string parameterParameters = constructParameterParameters(typeDef, true); - // parse/print - os << typeDefParsePrint; + os << llvm::formatv(" static {0} get(::mlir::MLIRContext* ctxt{1});\n", + typeDef.getCppClassName(), parameterParameters); + + // verify invariants if (typeDef.genVerifyInvariantsDecl()) @@ -190,6 +195,9 @@ if (auto mnenomic = typeDef.getMnemonic()) { os << " static ::llvm::StringRef getMnemonic() { return \"" << mnenomic << "\"; }\n"; + + // if mnemonic specified, emit print/parse declarations + os << typeDefParsePrint; } if (typeDef.genAccessors()) { @@ -396,84 +404,6 @@ return mlir::success(); } -/// Emit the body of an autogenerated printer -static mlir::LogicalResult emitPrinterAutogen(TypeDef typeDef, - raw_ostream &os) { - if (auto mnemonic = typeDef.getMnemonic()) { - SmallVector parameters; - typeDef.getParameters(parameters); - - os << " printer << \"" << *mnemonic << "\";\n"; - - // if non-parametric, we're done - if (parameters.size() > 0) { - os << " printer << \"<\";\n"; - - // emit a printer for each parameter separated by ','. - // printer structs for common C++ types are defined in - // TypeDefGenHelpers.h, which must be #included by the consuming code. - for (auto *parameterIter = parameters.begin(); - parameterIter < parameters.end(); parameterIter++) { - // Each printer struct must be put on the stack then 'go' called - os << " ::mlir::tblgen::parser_helpers::Print<" - << parameterIter->getCppType() << ">::go(printer, getImpl()->" - << parameterIter->getName() << ");\n"; - - // emit the comma unless we're the last parameter - if (parameterIter < parameters.end() - 1) { - os << " printer << \", \";\n"; - } - } - os << " printer << \">\";\n"; - } - } - return mlir::success(); -} - -/// Emit the body of an autogenerated parser -static mlir::LogicalResult emitParserAutogen(TypeDef typeDef, raw_ostream &os) { - SmallVector parameters; - typeDef.getParameters(parameters); - - // by the time we get to this function, the mnenomic has already been parsed - if (parameters.size() > 0) { - os << " if (parser.parseLess()) return ::mlir::Type();\n"; - - // emit a parser for each parameter separated by ','. - // parse structs for common C++ types are defined in - // TypeDefGenHelpers.h, which must be #included by the consuming code. - for (auto *parameterIter = parameters.begin(); - parameterIter < parameters.end(); parameterIter++) { - os << " " << parameterIter->getCppType() << " " - << parameterIter->getName() << ";\n"; - os << llvm::formatv( - " ::mlir::tblgen::parser_helpers::Parse<{0}> {1}Parser;\n", - parameterIter->getCppType(), parameterIter->getName()); - os << llvm::formatv(" if ({0}Parser.go(ctxt, parser, \"{1}\", {0})) " - "return ::mlir::Type();\n", - parameterIter->getName(), - parameterIter->getCppType()); - - // parse a comma unless we're the last parameter - if (parameterIter < parameters.end() - 1) { - os << " if (parser.parseComma()) return ::mlir::Type();\n"; - } - } - os << " if (parser.parseGreater()) return ::mlir::Type();\n"; - // done with the parsing - - // all the parameters are now in variables named the same as the parameters - auto parameterNames = - llvm::map_range(parameters, [](TypeParameter parameter) { - return parameter.getName(); - }); - os << " return get(ctxt, " << llvm::join(parameterNames, ", ") << ");\n"; - } else { - os << " return get(ctxt);\n"; - } - return mlir::success(); -} - /// Print all the typedef-specific definition code static mlir::LogicalResult emitTypeDefDef(TypeDef typeDef, raw_ostream &os) { NamespaceEmitter ns(os, typeDef.getDialect()); @@ -485,6 +415,16 @@ if (mlir::failed(emitStorageClass(typeDef, os))) return mlir::failure(); + std::string paramFuncParams = constructParameterParameters(typeDef, true); + SmallVector paramNames; + paramNames.push_back(""); + typeDef.getParametersAs( + paramNames, [](TypeParameter param) { return param.getName(); }); + os << llvm::formatv("{0} {0}::get(::mlir::MLIRContext* ctxt{1}) {{\n" + " return Base::get(ctxt{2});\n" + "}\n", + typeDef.getCppClassName(), paramFuncParams, + llvm::join(paramNames, ",")); // emit the accessors if (typeDef.genAccessors()) { for (auto parameter : parameters) { @@ -497,39 +437,56 @@ } } - // emit the printer code, if appropriate - auto printerCode = typeDef.getPrinterCode(); - if (printerCode && typeDef.getMnemonic()) { - // Both the mnenomic and printerCode must be defined (for parity with - // parserCode) - os << "void " << typeDef.getCppClassName() - << "::print(mlir::DialectAsmPrinter& printer) const {\n"; - if (*printerCode == "") { - // if no code specified, autogenerate a parser - if (mlir::failed(emitPrinterAutogen(typeDef, os))) + // If mnemonic is specified, maybe print a def + if (typeDef.getMnemonic()) { + // emit the printer code, if appropriate + auto printerCode = typeDef.getPrinterCode(); + if (printerCode) { + // Both the mnenomic and printerCode must be defined (for parity with + // parserCode) + os << "void " << typeDef.getCppClassName() + << "::print(mlir::DialectAsmPrinter& printer) const {\n"; + if (*printerCode == "") { + // if no code specified, emit error + llvm::PrintError( + typeDef.getLoc(), + typeDef.getName() + + ": printer (if specified) must have non-empty code"); return mlir::failure(); - } else { - os << *printerCode << "\n"; + } else { + auto fmtCtxt = FmtContext().addSubst("_printer", "printer"); + auto fmtObj = tgfmt(*printerCode, &fmtCtxt); + fmtObj.format(os); + os << "\n"; + } + os << "}\n"; } - os << "}\n"; - } - // emit a parser, if appropriate - auto parserCode = typeDef.getParserCode(); - if (parserCode && typeDef.getMnemonic()) { - // The mnenomic must be defined so the dispatcher knows how to dispatch - os << "::mlir::Type " << typeDef.getCppClassName() - << "::parse(::mlir::MLIRContext* ctxt, ::mlir::DialectAsmParser& " - "parser) " - "{\n"; - if (*parserCode == "") { - if (mlir::failed(emitParserAutogen(typeDef, os))) + // emit a parser, if appropriate + auto parserCode = typeDef.getParserCode(); + if (parserCode) { + // The mnenomic must be defined so the dispatcher knows how to dispatch + os << "::mlir::Type " << typeDef.getCppClassName() + << "::parse(::mlir::MLIRContext* ctxt, ::mlir::DialectAsmParser& " + "parser) " + "{\n"; + if (*parserCode == "") { + // if no code specified, emit error + llvm::PrintError( + typeDef.getLoc(), + typeDef.getName() + + ": parser (if specified) must have non-empty code"); return mlir::failure(); - } else - os << *parserCode << "\n"; - os << "}\n"; - } + } else { + auto fmtCtxt = FmtContext().addSubst("_parser", "parser").addSubst("_ctxt", "ctxt"); + auto fmtObj = tgfmt(*parserCode, &fmtCtxt); + fmtObj.format(os); + os << "\n"; + } + os << "}\n"; + } + } // typeDef.getMnemonic() return mlir::success(); }