diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
@@ -21,7 +21,7 @@
   let parameters = (ins
     "linkage::Linkage":$linkage
   );
-  let hasCustomAssemblyFormat = 1;
+  let assemblyFormat = "`<` $linkage `>`";
 }
 
 // Attribute definition for the LLVM Linkage enum.
@@ -30,7 +30,7 @@
   let parameters = (ins
     "CConv":$CallingConv
   );
-  let hasCustomAssemblyFormat = 1;
+  let assemblyFormat = "`<` $CallingConv `>`";
 }
 
 def LoopOptionsAttr : LLVM_Attr<"LoopOptions"> {
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -2797,54 +2797,6 @@
          op->hasTrait<OpTrait::IsIsolatedFromAbove>();
 }
 
-void LinkageAttr::print(AsmPrinter &printer) const {
-  printer << "<";
-  if (static_cast<uint64_t>(getLinkage()) <= getMaxEnumValForLinkage())
-    printer << stringifyEnum(getLinkage());
-  else
-    printer << static_cast<uint64_t>(getLinkage());
-  printer << ">";
-}
-
-Attribute LinkageAttr::parse(AsmParser &parser, Type type) {
-  StringRef elemName;
-  if (parser.parseLess() || parser.parseKeyword(&elemName) ||
-      parser.parseGreater())
-    return {};
-  auto elem = linkage::symbolizeLinkage(elemName);
-  if (!elem) {
-    parser.emitError(parser.getNameLoc(), "Unknown linkage: ") << elemName;
-    return {};
-  }
-  Linkage linkage = *elem;
-  return LinkageAttr::get(parser.getContext(), linkage);
-}
-
-void CConvAttr::print(AsmPrinter &printer) const {
-  printer << "<";
-  if (static_cast<uint64_t>(getCallingConv()) <= cconv::getMaxEnumValForCConv())
-    printer << stringifyEnum(getCallingConv());
-  else
-    printer << "INVALID_cc_" << static_cast<uint64_t>(getCallingConv());
-  printer << ">";
-}
-
-Attribute CConvAttr::parse(AsmParser &parser, Type type) {
-  StringRef convName;
-
-  if (parser.parseLess() || parser.parseKeyword(&convName) ||
-      parser.parseGreater())
-    return {};
-  auto cconv = cconv::symbolizeCConv(convName);
-  if (!cconv) {
-    parser.emitError(parser.getNameLoc(), "unknown calling convention: ")
-        << convName;
-    return {};
-  }
-  CConv cconvVal = *cconv;
-  return CConvAttr::get(parser.getContext(), cconvVal);
-}
-
 LoopOptionsAttrBuilder::LoopOptionsAttrBuilder(LoopOptionsAttr attr)
     : options(attr.getOptions().begin(), attr.getOptions().end()) {}
 
diff --git a/mlir/test/Dialect/LLVMIR/func.mlir b/mlir/test/Dialect/LLVMIR/func.mlir
--- a/mlir/test/Dialect/LLVMIR/func.mlir
+++ b/mlir/test/Dialect/LLVMIR/func.mlir
@@ -273,8 +273,9 @@
 // -----
 
 module {
-  // expected-error@+2 {{unknown calling convention: cc_12}}
   "llvm.func"() ({
+  // expected-error @below {{invalid Calling Conventions specification: cc_12}}
+  // expected-error @below {{failed to parse CConvAttr parameter 'CallingConv' which is to be a `CConv`}}
   }) {sym_name = "generic_unknown_calling_convention", CConv = #llvm.cconv<cc_12>, function_type = !llvm.func<i64 ()>} : () -> ()
 }
 
diff --git a/mlir/test/mlir-tblgen/enums-gen.td b/mlir/test/mlir-tblgen/enums-gen.td
--- a/mlir/test/mlir-tblgen/enums-gen.td
+++ b/mlir/test/mlir-tblgen/enums-gen.td
@@ -28,6 +28,24 @@
 
 // DECL: std::string stringifyMyBitEnum(MyBitEnum);
 // DECL: ::llvm::Optional<MyBitEnum> symbolizeMyBitEnum(::llvm::StringRef);
+// DECL: struct FieldParser<::MyBitEnum, ::MyBitEnum> {
+// DECL: template <typename ParserT>
+// DECL: static FailureOr<::MyBitEnum> parse(ParserT &parser) {
+// DECL: // Parse the keyword/string containing the enum.
+// DECL: std::string enumKeyword;
+// DECL: auto loc = parser.getCurrentLocation();
+// DECL: if (failed(parser.parseOptionalKeywordOrString(&enumKeyword)))
+// DECL: return parser.emitError(loc, "expected keyword for An example bit enum");
+// DECL: // Symbolize the keyword.
+// DECL: if (::llvm::Optional<::MyBitEnum> attr = ::symbolizeEnum<::MyBitEnum>(enumKeyword))
+// DECL: return *attr;
+// DECL: return parser.emitError(loc, "invalid An example bit enum specification: ") << enumKeyword;
+// DECL: }
+
+// DECL: inline ::llvm::raw_ostream &operator<<(::llvm::raw_ostream &p, ::MyBitEnum value) {
+// DECL: auto valueStr = stringifyEnum(value);
+// DECL: return p << valueStr;
+
 // DEF-LABEL: std::string stringifyMyBitEnum
 // DEF: auto val = static_cast<uint32_t>
 // DEF: if (val == 0) return "None";
@@ -40,3 +58,34 @@
 // DEF: if (str == "None") return MyBitEnum::None;
 // DEF: .Case("tagged", 1)
 // DEF: .Case("Bit1", 2)
+
+// Test enum printer generation for enums with non-keyword cases.
+
+def NonKeywordBit: I32BitEnumAttrCaseBit<"Bit0", 0, "tag-ged">;
+def MyMixedNonKeywordBitEnum: I32BitEnumAttr<"MyMixedNonKeywordBitEnum", "An example bit enum", [
+    NonKeywordBit,
+    Bit1
+  ]> {
+  let genSpecializedAttr = 0;
+}
+
+def MyNonKeywordBitEnum: I32BitEnumAttr<"MyNonKeywordBitEnum", "An example bit enum", [
+    NonKeywordBit
+  ]> {
+  let genSpecializedAttr = 0;
+}
+
+// DECL: inline ::llvm::raw_ostream &operator<<(::llvm::raw_ostream &p, ::MyMixedNonKeywordBitEnum value) {
+// DECL: auto valueStr = stringifyEnum(value);
+// DECL: switch (value) {
+// DECL: case ::MyMixedNonKeywordBitEnum::Bit1:
+// DECL: break;
+// DECL: default:
+// DECL: return p << '"' << valueStr << '"';
+// DECL: }
+// DECL: return p << valueStr;
+// DECL: }
+
+// DECL: inline ::llvm::raw_ostream &operator<<(::llvm::raw_ostream &p, ::MyNonKeywordBitEnum value) {
+// DECL: auto valueStr = stringifyEnum(value);
+// DECL: return p << '"' << valueStr << '"';
diff --git a/mlir/tools/mlir-tblgen/EnumsGen.cpp b/mlir/tools/mlir-tblgen/EnumsGen.cpp
--- a/mlir/tools/mlir-tblgen/EnumsGen.cpp
+++ b/mlir/tools/mlir-tblgen/EnumsGen.cpp
@@ -10,9 +10,11 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "FormatGen.h"
 #include "mlir/TableGen/Attribute.h"
 #include "mlir/TableGen/Format.h"
 #include "mlir/TableGen/GenInfo.h"
+#include "llvm/ADT/BitVector.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/Support/FormatVariadic.h"
@@ -65,10 +67,95 @@
   os << "};\n\n";
 }
 
-static void emitDenseMapInfo(StringRef enumName, std::string underlyingType,
+static void emitParserPrinter(const EnumAttr &enumAttr, StringRef qualName,
+                              StringRef cppNamespace, raw_ostream &os) {
+  if (enumAttr.getUnderlyingType().empty() ||
+      enumAttr.getConstBuilderTemplate().empty())
+    return;
+  auto cases = enumAttr.getAllCases();
+
+  // Check which cases shouldn't be printed using a keyword.
+  llvm::BitVector nonKeywordCases(cases.size());
+  for (auto [index, caseVal] : llvm::enumerate(cases))
+    if (!mlir::tblgen::canFormatStringAsKeyword(caseVal.getStr()))
+      nonKeywordCases.set(index);
+
+  // For bit enums, print cases that may overlap with other cases as strings;
+  // for simplicity, only zero and single-bit cases are eligible as keywords.
+  if (enumAttr.isBitEnum()) {
+    for (auto [index, caseVal] : llvm::enumerate(cases)) {
+      int64_t value = caseVal.getValue();
+      if (value < 0 || (value != 0 && !llvm::isPowerOf2_64(value)))
+        nonKeywordCases.set(index);
+    }
+  }
+
+  // Generate the parser and the start of the printer for the enum.
+  const char *parserAndPrinterStart = R"(
+namespace mlir {
+template <typename T, typename>
+struct FieldParser;
+
+template<>
+struct FieldParser<{0}, {0}> {{
+  template <typename ParserT>
+  static FailureOr<{0}> parse(ParserT &parser) {{
+    // Parse the keyword/string containing the enum.
+    std::string enumKeyword;
+    auto loc = parser.getCurrentLocation();
+    if (failed(parser.parseOptionalKeywordOrString(&enumKeyword)))
+      return parser.emitError(loc, "expected keyword for {2}");
+
+    // Symbolize the keyword.
+    if (::llvm::Optional<{0}> attr = {1}::symbolizeEnum<{0}>(enumKeyword))
+      return *attr;
+    return parser.emitError(loc, "invalid {2} specification: ") << enumKeyword;
+  }
+};
+} // namespace mlir
+
+namespace llvm {
+inline ::llvm::raw_ostream &operator<<(::llvm::raw_ostream &p, {0} value) {{
+  auto valueStr = stringifyEnum(value);
+)";
+  // The summary is spliced into C++ string literals above, so escape any
+  // quotes, backslashes, or non-printable characters it may contain.
+  std::string summary;
+  llvm::raw_string_ostream(summary).write_escaped(enumAttr.getSummary());
+  os << formatv(parserAndPrinterStart, qualName, cppNamespace, summary);
+
+  // If all cases require a string, always wrap.
+  if (nonKeywordCases.all()) {
+    os << "  return p << '\"' << valueStr << '\"';\n"
+          "}\n"
+          "} // namespace llvm\n";
+    return;
+  }
+
+  // If there are any cases that can't be used with a keyword, switch on the
+  // case value to determine when to print in the string form.
+  if (nonKeywordCases.any()) {
+    os << "  switch (value) {\n";
+    for (auto [index, caseVal] : llvm::enumerate(cases)) {
+      if (nonKeywordCases.test(index))
+        continue;
+      StringRef symbol = caseVal.getSymbol();
+      os << llvm::formatv("  case {0}::{1}:\n", qualName,
+                          llvm::isDigit(symbol.front()) ? ("_" + symbol)
+                                                        : symbol);
+    }
+    os << "  break;\n"
+          "  default:\n"
+          "    return p << '\"' << valueStr << '\"';\n"
+          "  }\n";
+  }
+  os << "  return p << valueStr;\n"
+        "}\n"
+        "} // namespace llvm\n";
+}
+
+static void emitDenseMapInfo(StringRef qualName, std::string underlyingType,
                              StringRef cppNamespace, raw_ostream &os) {
-  std::string qualName =
-      std::string(formatv("{0}::{1}", cppNamespace, enumName));
   if (underlyingType.empty())
     underlyingType =
         std::string(formatv("std::underlying_type_t<{0}>", qualName));
@@ -529,8 +616,13 @@
   for (auto ns : llvm::reverse(namespaces))
     os << "} // namespace " << ns << "\n";
 
+  // Generate a generic parser and printer for the enum.
+  std::string qualName =
+      std::string(formatv("{0}::{1}", cppNamespace, enumName));
+  emitParserPrinter(enumAttr, qualName, cppNamespace, os);
+
   // Emit DenseMapInfo for this enum class
-  emitDenseMapInfo(enumName, underlyingType, cppNamespace, os);
+  emitDenseMapInfo(qualName, underlyingType, cppNamespace, os);
 }
 
 static bool emitEnumDecls(const RecordKeeper &recordKeeper, raw_ostream &os) {
diff --git a/mlir/tools/mlir-tblgen/FormatGen.cpp b/mlir/tools/mlir-tblgen/FormatGen.cpp
--- a/mlir/tools/mlir-tblgen/FormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/FormatGen.cpp
@@ -444,6 +444,11 @@
 
 bool mlir::tblgen::canFormatStringAsKeyword(
     StringRef value, function_ref<void(Twine)> emitError) {
+  if (value.empty()) {
+    if (emitError)
+      emitError("keywords cannot be empty");
+    return false;
+  }
   if (!isalpha(value.front()) && value.front() != '_') {
     if (emitError)
       emitError("valid keyword starts with a letter or '_'");