diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.h b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.h --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.h +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.h @@ -39,9 +39,6 @@ // `EnumClass`. // template StringRef attributeName(); // -// Get the function that can be used to symbolize an enum value. -// template -// Optional (*)(StringRef) symbolizeEnum(); #include "mlir/Dialect/SPIRV/SPIRVOpUtils.inc" } // end namespace spirv diff --git a/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp --- a/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp @@ -311,7 +311,7 @@ return llvm::None; } - auto val = spirv::symbolizeEnum()(enumSpec); + auto val = spirv::symbolizeEnum(enumSpec); if (!val) parser.emitError(enumLoc, "unknown attribute: '") << enumSpec << "'"; return val; diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp --- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp @@ -116,7 +116,7 @@ << attrName << " attribute specified as string"; } auto attrOptional = - spirv::symbolizeEnum()(attrVal.cast().getValue()); + spirv::symbolizeEnum(attrVal.cast().getValue()); if (!attrOptional) { return parser.emitError(loc, "invalid ") << attrName << " attribute specification: " << attrVal; @@ -151,7 +151,7 @@ auto loc = parser.getCurrentLocation(); if (parser.parseKeyword(&keyword)) return failure(); - if (Optional attr = spirv::symbolizeEnum()(keyword)) { + if (Optional attr = spirv::symbolizeEnum(keyword)) { value = attr.getValue(); return success(); } diff --git a/mlir/tools/mlir-tblgen/EnumsGen.cpp b/mlir/tools/mlir-tblgen/EnumsGen.cpp --- a/mlir/tools/mlir-tblgen/EnumsGen.cpp +++ b/mlir/tools/mlir-tblgen/EnumsGen.cpp @@ -370,6 +370,28 @@ emitMaxValueFn(enumDef, os); } + // Generate a generic `stringifyEnum` overload for this enum that forwards + // to the user-specified symbol-to-string function.
+ const char *const stringifyEnumStr = R"( +inline {0} stringifyEnum({1} enumValue) {{ + return {2}(enumValue); +} +)"; + os << formatv(stringifyEnumStr, symToStrFnRetType, enumName, symToStrFnName); + + // Generate a generic `symbolizeEnum` specialization forwarding to the + // string-to-symbol function; a literal '{' is written '{{' for formatv. + const char *const symbolizeEnumStr = R"( +template <typename EnumType> +llvm::Optional<EnumType> symbolizeEnum(llvm::StringRef); + +template <> +inline llvm::Optional<{0}> symbolizeEnum<{0}>(llvm::StringRef str) {{ + return {1}(str); +} +)"; + os << formatv(symbolizeEnumStr, enumName, strToSymFnName); + for (auto ns : llvm::reverse(namespaces)) os << "} // namespace " << ns << "\n"; diff --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp --- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp +++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp @@ -1105,13 +1105,6 @@ "attributeName();\n"); } -static void emitEnumGetSymbolizeFnDecl(raw_ostream &os) { - os << "template using SymbolizeFnTy = " - "llvm::Optional (*)(StringRef);\n"; - os << "template inline constexpr " - "SymbolizeFnTy symbolizeEnum();\n"; -} - static void emitEnumGetAttrNameFnDefn(const EnumAttr &enumAttr, raw_ostream &os) { auto enumName = enumAttr.getEnumClassName(); @@ -1124,17 +1117,6 @@ os << "}\n"; } -static void emitEnumGetSymbolizeFnDefn(const EnumAttr &enumAttr, - raw_ostream &os) { - auto enumName = enumAttr.getEnumClassName(); - auto strToSymFnName = enumAttr.getStringToSymbolFnName(); - os << formatv( - "template <> inline SymbolizeFnTy<{0}> symbolizeEnum<{0}>() {{\n", - enumName); - os << " return " << strToSymFnName << ";\n"; - os << "}\n"; -} - static bool emitOpUtils(const RecordKeeper &recordKeeper, raw_ostream &os) { llvm::emitSourceFileHeader("SPIR-V Op Utilities", os); @@ -1142,11 +1124,9 @@ os << "#ifndef SPIRV_OP_UTILS_H_\n"; os << "#define SPIRV_OP_UTILS_H_\n"; emitEnumGetAttrNameFnDecl(os); - emitEnumGetSymbolizeFnDecl(os); for (const auto *def : defs) { EnumAttr enumAttr(*def);
emitEnumGetAttrNameFnDefn(enumAttr, os); - emitEnumGetSymbolizeFnDefn(enumAttr, os); } - os << "#endif // SPIRV_OP_UTILS_H\n"; + os << "#endif // SPIRV_OP_UTILS_H_\n"; return false;