diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td --- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td +++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td @@ -75,7 +75,8 @@ int width; if ($_parser.parseInteger(width)) return Type(); if ($_parser.parseGreater()) return Type(); - return get(ctxt, signedness, width); + Location loc = $_parser.getEncodedSourceLoc($_parser.getNameLoc()); + return getChecked(loc, signedness, width); }]; // Any extra code one wants in the type's class declaration. diff --git a/mlir/test/mlir-tblgen/typedefs.td b/mlir/test/mlir-tblgen/typedefs.td --- a/mlir/test/mlir-tblgen/typedefs.td +++ b/mlir/test/mlir-tblgen/typedefs.td @@ -51,7 +51,7 @@ // DECL-LABEL: class CompoundAType: public ::mlir::Type // DECL: static ::mlir::LogicalResult verifyConstructionInvariants(Location loc, int widthOfSomething, ::mlir::test::SimpleTypeA exampleTdType, SomeCppStruct exampleCppType, ::llvm::ArrayRef<int> dims); -// DECL: static CompoundAType getChecked(Location loc, int widthOfSomething, ::mlir::test::SimpleTypeA exampleTdType, SomeCppStruct exampleCppType, ::llvm::ArrayRef<int> dims); +// DECL: static ::mlir::Type getChecked(Location loc, int widthOfSomething, ::mlir::test::SimpleTypeA exampleTdType, SomeCppStruct exampleCppType, ::llvm::ArrayRef<int> dims); // DECL: static ::llvm::StringRef getMnemonic() { return "cmpnd_a"; } // DECL: static ::mlir::Type parse(::mlir::MLIRContext* ctxt, ::mlir::DialectAsmParser& parser); // DECL: void print(::mlir::DialectAsmPrinter& printer) const; diff --git a/mlir/tools/mlir-tblgen/TypeDefGen.cpp b/mlir/tools/mlir-tblgen/TypeDefGen.cpp --- a/mlir/tools/mlir-tblgen/TypeDefGen.cpp +++ b/mlir/tools/mlir-tblgen/TypeDefGen.cpp @@ -78,25 +78,25 @@ /// [...]". TypeNamePairs, - /// Emit ", parameter1Type parameter1Name, parameter2Type parameter2Name, - /// [...]". - TypeNamePairsPrependComma, - /// Emit "parameter1(parameter1), parameter2(parameter2), [...]".
- TypeNameInitializer + TypeNameInitializer, + + /// Emit "parameter1Name, parameter2Name, [...]". + JustParams, }; - TypeParamCommaFormatter(EmitFormat emitFormat, ArrayRef<TypeParameter> params) - : emitFormat(emitFormat), params(params) {} + TypeParamCommaFormatter(EmitFormat emitFormat, bool prependComma, + ArrayRef<TypeParameter> params) + : emitFormat(emitFormat), prependComma(prependComma), params(params) {} /// llvm::formatv will call this function when using an instance as a /// replacement value. void format(raw_ostream &os, StringRef options) { - if (params.size() && emitFormat == EmitFormat::TypeNamePairsPrependComma) + if (params.size() && prependComma) os << ", "; + switch (emitFormat) { case EmitFormat::TypeNamePairs: - case EmitFormat::TypeNamePairsPrependComma: interleaveComma(params, os, [&](const TypeParameter &p) { emitTypeNamePair(p, os); }); break; @@ -105,6 +105,10 @@ emitTypeNameInitializer(p, os); }); break; + case EmitFormat::JustParams: + interleaveComma(params, os, + [&](const TypeParameter &p) { os << p.getName(); }); + break; } } @@ -119,6 +123,7 @@ } EmitFormat emitFormat; + bool prependComma; ArrayRef<TypeParameter> params; }; @@ -168,10 +173,9 @@ /// The code block for the verifyConstructionInvariants and getChecked. /// /// {0}: List of parameters, parameters style. -/// {1}: C++ type class name. static const char *const typeDefDeclVerifyStr = R"( static ::mlir::LogicalResult verifyConstructionInvariants(Location loc{0}); - static {1} getChecked(Location loc{0}); + static ::mlir::Type getChecked(Location loc{0}); )"; /// Generate the declaration for the given typeDef class. @@ -194,14 +198,13 @@ os << *extraDecl << "\n"; TypeParamCommaFormatter emitTypeNamePairsAfterComma( - TypeParamCommaFormatter::EmitFormat::TypeNamePairsPrependComma, params); + TypeParamCommaFormatter::EmitFormat::TypeNamePairs, true, params); os << llvm::formatv(" static {0} get(::mlir::MLIRContext* ctxt{1});\n", typeDef.getCppClassName(), emitTypeNamePairsAfterComma); // Emit the verify invariants declaration.
if (typeDef.genVerifyInvariantsDecl()) - os << llvm::formatv(typeDefDeclVerifyStr, emitTypeNamePairsAfterComma, - typeDef.getCppClassName()); + os << llvm::formatv(typeDefDeclVerifyStr, emitTypeNamePairsAfterComma); // Emit the mnenomic, if specified. if (auto mnenomic = typeDef.getMnemonic()) { @@ -317,6 +320,17 @@ } )"; +/// The code block for the getChecked definition. +/// +/// {0}: List of parameters, parameters style. +/// {1}: C++ type class name. +/// {2}: List of parameter names, each preceded by ", " (empty if none). +static const char *const typeDefDefGetCheckedStr = R"( + ::mlir::Type {1}::getChecked(Location loc{0}) {{ + return Base::getChecked(loc{2}); + } +)"; + /// Use tgfmt to emit custom allocation code for each parameter, if necessary. static void emitParameterAllocationCode(TypeDef &typeDef, raw_ostream &os) { SmallVector<TypeParameter, 4> parameters; @@ -355,27 +369,28 @@ auto parameterTypeList = join(parameterTypes, ", "); // 1) Emit most of the storage class up until the hashKey body. - os << formatv( - typeDefStorageClassBegin, typeDef.getStorageNamespace(), - typeDef.getStorageClassName(), - TypeParamCommaFormatter( - TypeParamCommaFormatter::EmitFormat::TypeNamePairs, parameters), - TypeParamCommaFormatter( - TypeParamCommaFormatter::EmitFormat::TypeNameInitializer, parameters), - parameterList, parameterTypeList); + os << formatv(typeDefStorageClassBegin, typeDef.getStorageNamespace(), + typeDef.getStorageClassName(), + TypeParamCommaFormatter( + TypeParamCommaFormatter::EmitFormat::TypeNamePairs, false, + parameters), + TypeParamCommaFormatter( + TypeParamCommaFormatter::EmitFormat::TypeNameInitializer, + false, parameters), + parameterList, parameterTypeList); // 2) Emit the haskKey method. os << " static ::llvm::hash_code hashKey(const KeyTy &key) {\n"; // Extract each parameter from the key.
for (size_t i = 0, e = parameters.size(); i < e; ++i) - os << formatv(" const auto &{0} = std::get<{1}>(key);\n", - parameters[i].getName(), i); + os << llvm::formatv(" const auto &{0} = std::get<{1}>(key);\n", + parameters[i].getName(), i); // Then combine them all. This requires all the parameters types to have a // hash_value defined. - os << " return ::llvm::hash_combine("; - interleaveComma(parameterNames, os); - os << ");\n"; - os << " }\n"; + os << llvm::formatv( + " return ::llvm::hash_combine({0});\n }\n", + TypeParamCommaFormatter(TypeParamCommaFormatter::EmitFormat::JustParams, + false, parameters)); // 3) Emit the construct method. if (typeDef.hasStorageCustomConstructor()) @@ -462,14 +477,12 @@ os << llvm::formatv( "{0} {0}::get(::mlir::MLIRContext* ctxt{1}) {{\n" - " return Base::get(ctxt", + " return Base::get(ctxt{2});\n}\n", typeDef.getCppClassName(), TypeParamCommaFormatter( - TypeParamCommaFormatter::EmitFormat::TypeNamePairsPrependComma, - parameters)); - for (TypeParameter &param : parameters) - os << ", " << param.getName(); - os << ");\n}\n"; + TypeParamCommaFormatter::EmitFormat::TypeNamePairs, true, parameters), + TypeParamCommaFormatter(TypeParamCommaFormatter::EmitFormat::JustParams, + true, parameters)); // Emit the parameter accessors. if (typeDef.genAccessors()) @@ -481,6 +494,17 @@ typeDef.getCppClassName()); } + // Generate getChecked() method. + if (typeDef.genVerifyInvariantsDecl()) + os << llvm::formatv( + typeDefDefGetCheckedStr, + TypeParamCommaFormatter( + TypeParamCommaFormatter::EmitFormat::TypeNamePairs, true, + parameters), + typeDef.getCppClassName(), + TypeParamCommaFormatter(TypeParamCommaFormatter::EmitFormat::JustParams, + true, parameters)); + // If mnemonic is specified maybe print definitions for the parser and printer // code, if they're specified. if (typeDef.getMnemonic())