Index: include/llvm/ADT/BitmaskEnum.h
===================================================================
--- /dev/null
+++ include/llvm/ADT/BitmaskEnum.h
@@ -0,0 +1,138 @@
+//===-- llvm/ADT/BitmaskEnum.h ----------------------------------*- C++ -*-===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_ADT_BITMASKENUM_H
+#define LLVM_ADT_BITMASKENUM_H
+
+#include <cassert>
+#include <type_traits>
+#include <utility>
+
+#include "llvm/Support/MathExtras.h"
+
+/// \brief LLVM_MARK_AS_BITMASK_ENUM lets you opt in an individual enum type so
+/// you can perform bitwise operations on it without putting static_cast
+/// everywhere.
+///
+/// \code
+///   enum MyEnum {
+///     E1 = 1, E2 = 2, E3 = 4, E4 = 8,
+///     LLVM_MARK_AS_BITMASK_ENUM(/* LargestValue = */ E4)
+///   };
+///
+///   void Foo() {
+///     MyEnum A = (E1 | E2) & E3 ^ ~E4; // Look, ma: No static_cast!
+///   }
+/// \endcode
+///
+/// Normally when you do a bitwise operation on an enum value, you get back an
+/// instance of the underlying type (e.g. int).  But using this macro, bitwise
+/// ops on your enum will return you back instances of the enum.  This is
+/// particularly useful for enums which represent a combination of flags.
+///
+/// The parameter to LLVM_MARK_AS_BITMASK_ENUM should be the largest individual
+/// value in your enum.  All of the enum's values must be non-negative.
+#define LLVM_MARK_AS_BITMASK_ENUM(LargestValue)                                \
+  LLVM_BITMASK_LARGEST_ENUMERATOR = LargestValue
+
+namespace llvm {
+namespace BitmaskEnumDetail {
+
+/// Traits class to determine whether an enum has a
+/// LLVM_BITMASK_LARGEST_ENUMERATOR enumerator.
+template <typename E, typename Enable = void>
+struct is_bitmask_enum : std::false_type {};
+
+template <typename E>
+struct is_bitmask_enum<
+    E, typename std::enable_if<sizeof(E::LLVM_BITMASK_LARGEST_ENUMERATOR) >=
+                               0>::type> : std::true_type {};
+
+/// Get a bitmask with 1s in all places up to the high-order bit of E's largest
+/// value.
+template <typename E> typename std::underlying_type<E>::type Mask() {
+  // On overflow, NextPowerOf2 returns zero with the type uint64_t, so
+  // subtracting 1 gives us the mask with all bits set, like we want.
+  return NextPowerOf2(static_cast<typename std::underlying_type<E>::type>(
+             E::LLVM_BITMASK_LARGEST_ENUMERATOR)) -
+         1;
+}
+
+/// Check that Val is in range for E, and return Val cast to E's underlying
+/// type.
+template <typename E> typename std::underlying_type<E>::type Underlying(E Val) {
+  auto U = static_cast<typename std::underlying_type<E>::type>(Val);
+  assert(U >= 0 && "Negative enum values are not allowed.");
+  assert(U <= Mask<E>() && "Enum value too large (or largest val too small?)");
+  return U;
+}
+
+} // namespace BitmaskEnumDetail
+} // namespace llvm
+
+template <typename E,
+          typename = typename std::enable_if<
+              llvm::BitmaskEnumDetail::is_bitmask_enum<E>::value>::type>
+E operator~(E Val) {
+  return static_cast<E>(~llvm::BitmaskEnumDetail::Underlying(Val) &
+                        llvm::BitmaskEnumDetail::Mask<E>());
+}
+
+template <typename E,
+          typename = typename std::enable_if<
+              llvm::BitmaskEnumDetail::is_bitmask_enum<E>::value>::type>
+E operator|(E LHS, E RHS) {
+  return static_cast<E>(llvm::BitmaskEnumDetail::Underlying(LHS) |
+                        llvm::BitmaskEnumDetail::Underlying(RHS));
+}
+
+template <typename E,
+          typename = typename std::enable_if<
+              llvm::BitmaskEnumDetail::is_bitmask_enum<E>::value>::type>
+E operator&(E LHS, E RHS) {
+  return static_cast<E>(llvm::BitmaskEnumDetail::Underlying(LHS) &
+                        llvm::BitmaskEnumDetail::Underlying(RHS));
+}
+
+template <typename E,
+          typename = typename std::enable_if<
+              llvm::BitmaskEnumDetail::is_bitmask_enum<E>::value>::type>
+E operator^(E LHS, E RHS) {
+  return static_cast<E>(llvm::BitmaskEnumDetail::Underlying(LHS) ^
+                        llvm::BitmaskEnumDetail::Underlying(RHS));
+}
+
+// |=, &=, and ^= return a reference to LHS, to match the behavior of the
+// operators on builtin types.
+
+template <typename E,
+          typename = typename std::enable_if<
+              llvm::BitmaskEnumDetail::is_bitmask_enum<E>::value>::type>
+E &operator|=(E &LHS, E RHS) {
+  LHS = LHS | RHS;
+  return LHS;
+}
+
+template <typename E,
+          typename = typename std::enable_if<
+              llvm::BitmaskEnumDetail::is_bitmask_enum<E>::value>::type>
+E &operator&=(E &LHS, E RHS) {
+  LHS = LHS & RHS;
+  return LHS;
+}
+
+template <typename E,
+          typename = typename std::enable_if<
+              llvm::BitmaskEnumDetail::is_bitmask_enum<E>::value>::type>
+E &operator^=(E &LHS, E RHS) {
+  LHS = LHS ^ RHS;
+  return LHS;
+}
+
+#endif
Index: unittests/ADT/BitmaskEnumTest.cpp
===================================================================
--- /dev/null
+++ unittests/ADT/BitmaskEnumTest.cpp
@@ -0,0 +1,122 @@
+//===- llvm/unittest/ADT/BitmaskEnumTest.cpp - BitmaskEnum unit tests -----===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/ADT/BitmaskEnum.h"
+#include "gtest/gtest.h"
+
+using namespace llvm;
+
+namespace {
+enum Flags {
+  F0 = 0,
+  F1 = 1,
+  F2 = 2,
+  F3 = 4,
+  F4 = 8,
+  LLVM_MARK_AS_BITMASK_ENUM(F4)
+};
+
+TEST(BitmaskEnumTest, BitwiseOr) {
+  Flags f = F1 | F2;
+  EXPECT_EQ(3, f);
+
+  f = f | F3;
+  EXPECT_EQ(7, f);
+}
+
+TEST(BitmaskEnumTest, BitwiseOrEquals) {
+  Flags f = F1;
+  f |= F3;
+  EXPECT_EQ(5, f);
+
+  // |= should return a reference to the LHS.
+  f = F2;
+  (f |= F3) = F1;
+  EXPECT_EQ(F1, f);
+}
+
+TEST(BitmaskEnumTest, BitwiseAnd) {
+  Flags f = static_cast<Flags>(3) & F2;
+  EXPECT_EQ(F2, f);
+
+  f = (f | F3) & (F1 | F2 | F3);
+  EXPECT_EQ(6, f);
+}
+
+TEST(BitmaskEnumTest, BitwiseAndEquals) {
+  Flags f = F1 | F2 | F3;
+  f &= F1 | F2;
+  EXPECT_EQ(3, f);
+
+  // &= should return a reference to the LHS.
+  (f &= F1) = F3;
+  EXPECT_EQ(F3, f);
+}
+
+TEST(BitmaskEnumTest, BitwiseXor) {
+  Flags f = (F1 | F2) ^ (F2 | F3);
+  EXPECT_EQ(5, f);
+
+  f = f ^ F1;
+  EXPECT_EQ(4, f);
+}
+
+TEST(BitmaskEnumTest, BitwiseXorEquals) {
+  Flags f = (F1 | F2);
+  f ^= (F2 | F4);
+  EXPECT_EQ(9, f);
+
+  // ^= should return a reference to the LHS.
+  (f ^= F4) = F3;
+  EXPECT_EQ(F3, f);
+}
+
+TEST(BitmaskEnumTest, BitwiseNot) {
+  Flags f = ~F1;
+  EXPECT_EQ(14, f); // Largest value for f is 15.
+  EXPECT_EQ(15, ~F0);
+}
+
+enum class FlagsClass {
+  F0 = 0,
+  F1 = 1,
+  F2 = 2,
+  F3 = 4,
+  LLVM_MARK_AS_BITMASK_ENUM(F3)
+};
+
+TEST(BitmaskEnumTest, ScopedEnum) {
+  FlagsClass f = (FlagsClass::F1 & ~FlagsClass::F0) | FlagsClass::F2;
+  f |= FlagsClass::F3;
+  EXPECT_EQ(7, static_cast<int>(f));
+}
+
+} // namespace
+
+namespace foo {
+namespace bar {
+namespace {
+enum FlagsInNamespace {
+  F0 = 0,
+  F1 = 1,
+  F2 = 2,
+  F3 = 4,
+  LLVM_MARK_AS_BITMASK_ENUM(F3)
+};
+} // namespace
+} // namespace bar
+} // namespace foo
+
+namespace {
+TEST(BitmaskEnumTest, EnumInNamespace) {
+  foo::bar::FlagsInNamespace f = ~foo::bar::F0 & (foo::bar::F1 | foo::bar::F2);
+  f |= foo::bar::F3;
+  EXPECT_EQ(7, f);
+}
+} // namespace
Index: unittests/ADT/CMakeLists.txt
===================================================================
--- unittests/ADT/CMakeLists.txt
+++ unittests/ADT/CMakeLists.txt
@@ -7,6 +7,7 @@
   APIntTest.cpp
   APSIntTest.cpp
   ArrayRefTest.cpp
+  BitmaskEnumTest.cpp
   BitVectorTest.cpp
   DAGDeltaAlgorithmTest.cpp
   DeltaAlgorithmTest.cpp