[mlir] Rename bitEnumContains to bitEnumContainsAll and add bitEnumContainsAny

bitEnumContainsAll(bits, bit) keeps the old bitEnumContains body,
`(bits & bit) == bit`: it is true only when every bit set in `bit` is also set
in `bits`. All existing callers are renamed without a change in semantics.

The new bitEnumContainsAny(bits, bit) is true when at least one bit set in
`bit` is also set in `bits`.

NOTE(review): the ImageOperands check in SPIRVOps.cpp ORs many unsupported
operands together and tests them with ContainsAll. It therefore only fires
when *every* unsupported operand is present. ContainsAny looks like the
intended check, but switching would change verifier behavior, so it is left
for a follow-up.

diff --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md
--- a/mlir/docs/OpDefinitions.md
+++ b/mlir/docs/OpDefinitions.md
@@ -1442,9 +1442,12 @@
   // Ensure only bits that can be present in the enum are set
   return static_cast<MyBitEnum>(~static_cast<uint32_t>(bits) & static_cast<uint32_t>(15u));
 }
-inline constexpr bool bitEnumContains(MyBitEnum bits, MyBitEnum bit) {
+inline constexpr bool bitEnumContainsAll(MyBitEnum bits, MyBitEnum bit) {
   return (bits & bit) == bit;
 }
+inline constexpr bool bitEnumContainsAny(MyBitEnum bits, MyBitEnum bit) {
+  return (static_cast<uint32_t>(bits) & static_cast<uint32_t>(bit)) != 0;
+}
 inline constexpr MyBitEnum bitEnumClear(MyBitEnum bits, MyBitEnum bit) {
   return bits & ~bit;
 }
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -260,7 +260,8 @@
                        kMemoryAccessAttrName))
     return failure();
 
-  if (spirv::bitEnumContains(memoryAccessAttr, spirv::MemoryAccess::Aligned)) {
+  if (spirv::bitEnumContainsAll(memoryAccessAttr,
+                                spirv::MemoryAccess::Aligned)) {
     // Parse integer attribute for alignment.
     Attribute alignmentAttr;
     Type i32Type = parser.getBuilder().getIntegerType(32);
@@ -290,7 +291,8 @@
                        kSourceMemoryAccessAttrName))
     return failure();
 
-  if (spirv::bitEnumContains(memoryAccessAttr, spirv::MemoryAccess::Aligned)) {
+  if (spirv::bitEnumContainsAll(memoryAccessAttr,
+                                spirv::MemoryAccess::Aligned)) {
     // Parse integer attribute for alignment.
     Attribute alignmentAttr;
     Type i32Type = parser.getBuilder().getIntegerType(32);
@@ -316,7 +318,7 @@
 
     printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";
 
-    if (spirv::bitEnumContains(*memAccess, spirv::MemoryAccess::Aligned)) {
+    if (spirv::bitEnumContainsAll(*memAccess, spirv::MemoryAccess::Aligned)) {
       // Print integer alignment attribute.
       if (auto alignment = (alignmentAttrValue ? alignmentAttrValue
                                                : memoryOp.alignment())) {
@@ -349,7 +351,7 @@
 
     printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";
 
-    if (spirv::bitEnumContains(*memAccess, spirv::MemoryAccess::Aligned)) {
+    if (spirv::bitEnumContainsAll(*memAccess, spirv::MemoryAccess::Aligned)) {
       // Print integer alignment attribute.
       if (auto alignment = (alignmentAttrValue ? alignmentAttrValue
                                                : memoryOp.alignment())) {
@@ -407,7 +409,7 @@
       spirv::ImageOperands::MakeTexelVisible |
       spirv::ImageOperands::SignExtend | spirv::ImageOperands::ZeroExtend;
 
-  if (spirv::bitEnumContains(attr.getValue(), noSupportOperands))
+  if (spirv::bitEnumContainsAll(attr.getValue(), noSupportOperands))
     llvm_unreachable("unimplemented operands of Image Operands");
 
   return success();
@@ -491,8 +493,8 @@
            << memAccessAttr;
   }
 
-  if (spirv::bitEnumContains(memAccess.getValue(),
-                             spirv::MemoryAccess::Aligned)) {
+  if (spirv::bitEnumContainsAll(memAccess.getValue(),
+                                spirv::MemoryAccess::Aligned)) {
     if (!op->getAttr(kAlignmentAttrName)) {
       return memoryOp.emitOpError("missing alignment value");
     }
@@ -535,8 +537,8 @@
            << memAccess;
   }
 
-  if (spirv::bitEnumContains(memAccess.getValue(),
-                             spirv::MemoryAccess::Aligned)) {
+  if (spirv::bitEnumContainsAll(memAccess.getValue(),
+                                spirv::MemoryAccess::Aligned)) {
     if (!op->getAttr(kSourceAlignmentAttrName)) {
       return memoryOp.emitOpError("missing alignment value");
     }
diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
--- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
@@ -162,7 +162,7 @@
   llvm::FastMathFlags ret;
   auto fmf = op.getFastmathFlags();
   for (auto it : handlers)
-    if (bitEnumContains(fmf, it.first))
+    if (bitEnumContainsAll(fmf, it.first))
       (ret.*(it.second))(true);
   return ret;
 }
diff --git a/mlir/tools/mlir-tblgen/EnumsGen.cpp b/mlir/tools/mlir-tblgen/EnumsGen.cpp
--- a/mlir/tools/mlir-tblgen/EnumsGen.cpp
+++ b/mlir/tools/mlir-tblgen/EnumsGen.cpp
@@ -138,7 +138,8 @@
 // inline constexpr <enum-type> operator&(<enum-type> a, <enum-type> b);
 // inline constexpr <enum-type> operator^(<enum-type> a, <enum-type> b);
 // inline constexpr <enum-type> operator~(<enum-type> bits);
-// inline constexpr bool bitEnumContains(<enum-type> bits, <enum-type> bit);
+// inline constexpr bool bitEnumContainsAll(<enum-type> bits, <enum-type> bit);
+// inline constexpr bool bitEnumContainsAny(<enum-type> bits, <enum-type> bit);
 // inline constexpr <enum-type> bitEnumClear(<enum-type> bits, <enum-type> bit);
 // inline constexpr <enum-type> bitEnumSet(<enum-type> bits, <enum-type> bit,
 // bool value=true);
@@ -161,9 +162,12 @@
   // Ensure only bits that can be present in the enum are set
   return static_cast<{0}>(~static_cast<{1}>(bits) & static_cast<{1}>({2}u));
 }
-inline constexpr bool bitEnumContains({0} bits, {0} bit) {{
+inline constexpr bool bitEnumContainsAll({0} bits, {0} bit) {{
   return (bits & bit) == bit;
 }
+inline constexpr bool bitEnumContainsAny({0} bits, {0} bit) {{
+  return (static_cast<{1}>(bits) & static_cast<{1}>(bit)) != 0;
+}
 inline constexpr {0} bitEnumClear({0} bits, {0} bit) {{
   return bits & ~bit;
 }
diff --git a/mlir/unittests/TableGen/EnumsGenTest.cpp b/mlir/unittests/TableGen/EnumsGenTest.cpp
--- a/mlir/unittests/TableGen/EnumsGenTest.cpp
+++ b/mlir/unittests/TableGen/EnumsGenTest.cpp
@@ -142,10 +142,18 @@
 }
 
 TEST(EnumsGenTest, GeneratedOperator) {
-  EXPECT_TRUE(bitEnumContains(BitEnumWithNone::Bit0 | BitEnumWithNone::Bit3,
-                              BitEnumWithNone::Bit0));
-  EXPECT_FALSE(bitEnumContains(BitEnumWithNone::Bit0 & BitEnumWithNone::Bit3,
-                               BitEnumWithNone::Bit0));
+  EXPECT_TRUE(bitEnumContainsAll(BitEnumWithNone::Bit0 | BitEnumWithNone::Bit3,
+                                 BitEnumWithNone::Bit0));
+  EXPECT_FALSE(bitEnumContainsAll(BitEnumWithNone::Bit0 & BitEnumWithNone::Bit3,
+                                  BitEnumWithNone::Bit0));
+  EXPECT_TRUE(bitEnumContainsAny(BitEnumWithNone::Bit0 | BitEnumWithNone::Bit3,
+                                 BitEnumWithNone::Bit0));
+  EXPECT_TRUE(bitEnumContainsAny(
+      BitEnumWithNone::Bit0, BitEnumWithNone::Bit0 | BitEnumWithNone::Bit3));
+  EXPECT_FALSE(bitEnumContainsAll(
+      BitEnumWithNone::Bit0, BitEnumWithNone::Bit0 | BitEnumWithNone::Bit3));
+  EXPECT_FALSE(
+      bitEnumContainsAny(BitEnumWithNone::Bit3, BitEnumWithNone::Bit0));
 }
 
 TEST(EnumsGenTest, GeneratedSymbolToCustomStringFn) {