diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -28,14 +28,17 @@ def FMFcontract : I32BitEnumAttrCaseBit<"contract", 4>; def FMFafn : I32BitEnumAttrCaseBit<"afn", 5>; def FMFreassoc : I32BitEnumAttrCaseBit<"reassoc", 6>; -def FMFfast : I32BitEnumAttrCaseBit<"fast", 7>; +def FMFfast : I32BitEnumAttrCaseGroup<"fast", + [ FMFnnan, FMFninf, FMFnsz, FMFarcp, FMFcontract, FMFafn, FMFreassoc]>; def FastmathFlags : I32BitEnumAttr< "FastmathFlags", "LLVM fastmath flags", [FMFnnan, FMFninf, FMFnsz, FMFarcp, FMFcontract, FMFafn, FMFreassoc, FMFfast ]> { + let separator = ", "; let cppNamespace = "::mlir::LLVM"; + let printBitEnumPrimaryGroups = 1; } def LLVM_FMFAttr : DialectAttr< diff --git a/mlir/include/mlir/IR/EnumAttr.td b/mlir/include/mlir/IR/EnumAttr.td --- a/mlir/include/mlir/IR/EnumAttr.td +++ b/mlir/include/mlir/IR/EnumAttr.td @@ -267,13 +267,20 @@ // bits together. let symbolToStringFnRetType = "std::string"; - // The delimiter used to separate bit enum cases in strings. + // The delimiter used to separate bit enum cases in strings. Only "|" and + // "," (along with optional spaces) are supported due to the use of the + // parseSeparatorFn in parameterParser below. + // Spaces in the separator string are used for printing, but will be optional + // for parsing. string separator = "|"; + assert !or(!ge(!find(separator, "|"), 0), !ge(!find(separator, ","), 0)), + "separator must contain '|' or ',' for parameter parsing"; // Parsing function that corresponds to the enum separator. Only // "," and "|" are supported by this definition. - string parseSeparatorFn = !if(!eq(separator,"|"),"parseOptionalVerticalBar", - "parseOptionalComma"); + string parseSeparatorFn = !if(!ge(!find(separator, "|"), 0), + "parseOptionalVerticalBar", + "parseOptionalComma"); // Parse a keyword and pass it to `stringToSymbol`. Emit an error if a the // symbol is not valid. diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -2838,26 +2838,9 @@ op->hasTrait(); } -static constexpr const FastmathFlags fastmathFlagsList[] = { - // clang-format off - FastmathFlags::nnan, - FastmathFlags::ninf, - FastmathFlags::nsz, - FastmathFlags::arcp, - FastmathFlags::contract, - FastmathFlags::afn, - FastmathFlags::reassoc, - FastmathFlags::fast, - // clang-format on -}; - void FMFAttr::print(AsmPrinter &printer) const { printer << "<"; - auto flags = llvm::make_filter_range(fastmathFlagsList, [&](auto flag) { - return bitEnumContains(this->getFlags(), flag); - }); - llvm::interleaveComma(flags, printer, - [&](auto flag) { printer << stringifyEnum(flag); }); + printer << stringifyFastmathFlags(this->getFlags()); printer << ">"; } diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp --- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp @@ -157,7 +157,6 @@ {FastmathFlags::contract, &llvmFMF::setAllowContract}, {FastmathFlags::afn, &llvmFMF::setApproxFunc}, {FastmathFlags::reassoc, &llvmFMF::setAllowReassoc}, - {FastmathFlags::fast, &llvmFMF::setFast}, // clang-format on }; llvm::FastMathFlags ret; diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir --- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir +++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir @@ -445,7 +445,7 @@ // CHECK: {{.*}} = llvm.fadd %arg0, %arg1 : f32 %8 = llvm.fadd %arg0, %arg1 {fastmathFlags = #llvm.fastmath<>} : f32 // CHECK: {{.*}} = llvm.fadd %arg0, %arg1 {fastmathFlags = #llvm.fastmath} : f32 - %9 = llvm.fadd %arg0, %arg1 {fastmathFlags = #llvm.fastmath} : f32 + %9 = llvm.fadd %arg0, %arg1 {fastmathFlags = #llvm.fastmath} : f32 // CHECK: {{.*}} = llvm.fneg %arg0 : f32 %10 = llvm.fneg %arg0 {fastmathFlags = #llvm.fastmath<>} : f32 diff --git a/mlir/test/IR/attribute.mlir b/mlir/test/IR/attribute.mlir --- a/mlir/test/IR/attribute.mlir +++ b/mlir/test/IR/attribute.mlir @@ -413,9 +413,9 @@ // CHECK-LABEL: func @allowed_cases_pass func.func @allowed_cases_pass() { - // CHECK: test.op_with_bit_enum + // CHECK: test.op_with_bit_enum "test.op_with_bit_enum"() {value = #test.bit_enum} : () -> () - // CHECK: test.op_with_bit_enum + // CHECK: test.op_with_bit_enum test.op_with_bit_enum return } @@ -424,11 +424,11 @@ // CHECK-LABEL: func @allowed_cases_pass func.func @allowed_cases_pass() { - // CHECK: test.op_with_bit_enum_vbar + // CHECK: test.op_with_bit_enum_vbar "test.op_with_bit_enum_vbar"() { value = #test.bit_enum_vbar } : () -> () - // CHECK: test.op_with_bit_enum_vbar + // CHECK: test.op_with_bit_enum_vbar test.op_with_bit_enum_vbar return } diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -324,7 +324,7 @@ ]> { let genSpecializedAttr = 0; let cppNamespace = "test"; - let separator = ","; + let separator = ", "; } // Define the enum attribute. @@ -347,7 +347,7 @@ ]> { let genSpecializedAttr = 0; let cppNamespace = "test"; - let separator = "|"; + let separator = " | "; } def TestBitEnumVerticalBarAttr diff --git a/mlir/tools/mlir-tblgen/EnumsGen.cpp b/mlir/tools/mlir-tblgen/EnumsGen.cpp --- a/mlir/tools/mlir-tblgen/EnumsGen.cpp +++ b/mlir/tools/mlir-tblgen/EnumsGen.cpp @@ -277,6 +277,7 @@ std::string underlyingType = std::string(enumAttr.getUnderlyingType()); StringRef strToSymFnName = enumAttr.getStringToSymbolFnName(); StringRef separator = enumDef.getValueAsString("separator"); + StringRef separatorTrimmed = separator.trim(); auto enumerants = enumAttr.getAllCases(); auto allBitsUnsetCase = getAllBitsUnsetCase(enumerants); @@ -292,15 +293,16 @@ // Split the string to get symbols for all the bits. os << " ::llvm::SmallVector<::llvm::StringRef, 2> symbols;\n"; - os << formatv(" str.split(symbols, \"{0}\");\n\n", separator); + // Remove whitespace from the separator string when parsing. + os << formatv(" str.split(symbols, \"{0}\");\n\n", separatorTrimmed); os << formatv(" {0} val = 0;\n", underlyingType); os << " for (auto symbol : symbols) {\n"; // Convert each symbol to the bit ordinal and set the corresponding bit. - os << formatv( - " auto bit = llvm::StringSwitch<::llvm::Optional<{0}>>(symbol)\n", - underlyingType); + os << formatv(" auto bit = " + "llvm::StringSwitch<::llvm::Optional<{0}>>(symbol.trim())\n", + underlyingType); for (const auto &enumerant : enumerants) { // Skip the special enumerant for None. if (auto val = enumerant.getValue()) diff --git a/mlir/unittests/TableGen/EnumsGenTest.cpp b/mlir/unittests/TableGen/EnumsGenTest.cpp --- a/mlir/unittests/TableGen/EnumsGenTest.cpp +++ b/mlir/unittests/TableGen/EnumsGenTest.cpp @@ -80,7 +80,7 @@ EXPECT_EQ(stringifyBitEnumWithNone(BitEnumWithNone::Bit3), "Bit3"); EXPECT_EQ( stringifyBitEnumWithNone(BitEnumWithNone::Bit0 | BitEnumWithNone::Bit3), - "Bit0|Bit3"); + "Bit0 | Bit3"); EXPECT_EQ(stringifyBitEnum64_Test(BitEnum64_Test::Bit1), "Bit1"); EXPECT_EQ( @@ -96,7 +96,7 @@ BitEnumWithNone::Bit3 | BitEnumWithNone::Bit0); EXPECT_EQ(symbolizeBitEnumWithNone("Bit2"), llvm::None); - EXPECT_EQ(symbolizeBitEnumWithNone("Bit3|Bit4"), llvm::None); + EXPECT_EQ(symbolizeBitEnumWithNone("Bit3 | Bit4"), llvm::None); EXPECT_EQ(symbolizeBitEnumWithoutNone("None"), llvm::None); } @@ -129,11 +129,11 @@ EXPECT_EQ(stringifyBitEnumPrimaryGroup(BitEnumPrimaryGroup::Bit0 | BitEnumPrimaryGroup::Bit2 | BitEnumPrimaryGroup::Bit3), - "Bit0,Bit2,Bit3"); + "Bit0, Bit2, Bit3"); EXPECT_EQ(stringifyBitEnumPrimaryGroup(BitEnumPrimaryGroup::Bit0 | BitEnumPrimaryGroup::Bit4 | BitEnumPrimaryGroup::Bit5), - "Bits4And5,Bit0"); + "Bits4And5, Bit0"); EXPECT_EQ(stringifyBitEnumPrimaryGroup( BitEnumPrimaryGroup::Bit0 | BitEnumPrimaryGroup::Bit1 | BitEnumPrimaryGroup::Bit2 | BitEnumPrimaryGroup::Bit3 | diff --git a/mlir/unittests/TableGen/enums.td b/mlir/unittests/TableGen/enums.td --- a/mlir/unittests/TableGen/enums.td +++ b/mlir/unittests/TableGen/enums.td @@ -33,7 +33,9 @@ def Bit5 : I32BitEnumAttrCaseBit<"Bit5", 5>; def BitEnumWithNone : I32BitEnumAttr<"BitEnumWithNone", "A test enum", - [NoBits, Bit0, Bit3]>; + [NoBits, Bit0, Bit3]> { + let separator = " | "; +} def BitEnumWithoutNone : I32BitEnumAttr<"BitEnumWithoutNone", "A test enum", [Bit0, Bit3]>; @@ -46,12 +48,14 @@ [Bits0To3, Bits4And5]>; def BitEnumWithGroup : I32BitEnumAttr<"BitEnumWithGroup", "A test enum", - [Bit0, Bit1, Bit2, Bit3, Bit4, Bits0To3]>; + [Bit0, Bit1, Bit2, Bit3, Bit4, Bits0To3]> { + let separator = "|"; +} def BitEnumPrimaryGroup : I32BitEnumAttr<"BitEnumPrimaryGroup", "test enum", [Bit0, Bit1, Bit2, Bit3, Bit4, Bit5, Bits0To3, Bits4And5, Bits0To5]> { - let separator = ","; + let separator = ", "; let printBitEnumPrimaryGroups = 1; }