Index: include/llvm/ADT/FlagsEnum.h
===================================================================
--- /dev/null
+++ include/llvm/ADT/FlagsEnum.h
@@ -0,0 +1,148 @@
+//===-- llvm/ADT/FlagsEnum.h ------------------------------------*- C++ -*-===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_ADT_FLAGSENUM_H
+#define LLVM_ADT_FLAGSENUM_H
+
+#include <cassert>
+#include <cstdint>
+#include <type_traits>
+#include <utility>
+
+#include "llvm/Support/MathExtras.h"
+
+namespace llvm {
+
+/// \brief ENABLE_FLAGS_ENUM lets you opt in an individual enum type so you can
+/// perform bitwise operations on it without putting static_cast everywhere.
+///
+/// \code
+///   enum MyEnum { E1 = 1, E2 = 2, E3 = 4, E4 = 8};
+///   ENABLE_FLAGS_ENUM(MyEnum, /* LargestValue = */ E4);
+///
+///   void Foo() {
+///     MyEnum A = (E1 | E2) & E3 ^ ~E4; // Look, ma: No static_cast!
+///   }
+/// \endcode
+///
+/// Normally when you do a bitwise operation on an enum value, you get back an
+/// instance of the underlying type (e.g. int). But using this macro, bitwise
+/// ops on your enum will return you back instances of the enum. This is
+/// particularly useful for enums which represent a combination of flags.
+///
+/// The second parameter to ENABLE_FLAGS_ENUM should be the largest individual
+/// value in your enum. All of the enum's values must be non-negative.
+/// Operator ~ only flips the bits up to and including the high-order bit of
+/// that largest value, so its result never has any higher bits set.
+///
+/// Both unscoped enums and scoped enums (enum class) are supported.
+///
+/// ENABLE_FLAGS_ENUM must appear in the same namespace as the enum itself;
+/// otherwise, the using-declarations inside the macro won't pull the operators
+/// into the right namespace.
+#define ENABLE_FLAGS_ENUM(EnumType, LargestValue)                              \
+  static_assert(std::is_enum<EnumType>::value, "EnumType must be an enum.");   \
+  inline EnumType llvm_FlagsEnum_GetLargestValue(EnumType) {                   \
+    return LargestValue;                                                       \
+  }                                                                            \
+  using ::llvm::FlagsEnumDetail::operator|;                                    \
+  using ::llvm::FlagsEnumDetail::operator&;                                    \
+  using ::llvm::FlagsEnumDetail::operator^;                                    \
+  using ::llvm::FlagsEnumDetail::operator~;                                    \
+  using ::llvm::FlagsEnumDetail::operator|=;                                   \
+  using ::llvm::FlagsEnumDetail::operator&=;                                   \
+  using ::llvm::FlagsEnumDetail::operator^=;
+
+namespace FlagsEnumDetail {
+
+// Traits class to determine whether we can find an
+// llvm_FlagsEnum_GetLargestValue overload for type E via ADL.
+template <typename E, typename Enable = void>
+struct is_flags_enum : std::false_type {};
+
+template <typename E>
+struct is_flags_enum<E, decltype(void(llvm_FlagsEnum_GetLargestValue(
+                            std::declval<E>())))> : std::true_type {};
+
+// Get a bitmask with 1s in all places up to the high-order bit of E's largest
+// value.
+template <typename E> typename std::underlying_type<E>::type Mask() {
+  // Scoped enums don't convert to integers implicitly, so widen the largest
+  // value to the uint64_t that NextPowerOf2 takes.
+  auto Largest = static_cast<uint64_t>(llvm_FlagsEnum_GetLargestValue(E{}));
+  // On overflow, NextPowerOf2 returns uint64_t{0}, so subtracting 1 gives us
+  // the mask with all bits set, like we want. The explicit cast documents the
+  // intentional narrowing back to E's underlying type.
+  return static_cast<typename std::underlying_type<E>::type>(
+      NextPowerOf2(Largest) - 1);
+}
+
+// Check that Val is in range for E, and return Val cast to E's underlying type.
+template <typename E> typename std::underlying_type<E>::type Underlying(E Val) {
+  // Convert first: scoped enums can't be compared with integers directly.
+  auto U = static_cast<typename std::underlying_type<E>::type>(Val);
+  assert(U >= 0 && "Negative enum values are not allowed.");
+  assert(U <= Mask<E>() &&
+         "Enum value too large (or LargestValue too small?)");
+  return U;
+}
+
+// Only the bits covered by Mask<E>() are flipped.
+template <typename E,
+          typename = typename std::enable_if<is_flags_enum<E>::value>::type>
+E operator~(E Val) {
+  return static_cast<E>(~Underlying(Val) & Mask<E>());
+}
+
+template <typename E,
+          typename = typename std::enable_if<is_flags_enum<E>::value>::type>
+E operator|(E Lhs, E Rhs) {
+  return static_cast<E>(Underlying(Lhs) | Underlying(Rhs));
+}
+
+template <typename E,
+          typename = typename std::enable_if<is_flags_enum<E>::value>::type>
+E operator&(E Lhs, E Rhs) {
+  return static_cast<E>(Underlying(Lhs) & Underlying(Rhs));
+}
+
+template <typename E,
+          typename = typename std::enable_if<is_flags_enum<E>::value>::type>
+E operator^(E Lhs, E Rhs) {
+  return static_cast<E>(Underlying(Lhs) ^ Underlying(Rhs));
+}
+
+// |=, &=, and ^= return a reference to LHS, to match the behavior of the
+// operators on builtin types.
+
+template <typename E,
+          typename = typename std::enable_if<is_flags_enum<E>::value>::type>
+E &operator|=(E &Lhs, E Rhs) {
+  Lhs = Lhs | Rhs;
+  return Lhs;
+}
+
+template <typename E,
+          typename = typename std::enable_if<is_flags_enum<E>::value>::type>
+E &operator&=(E &Lhs, E Rhs) {
+  Lhs = Lhs & Rhs;
+  return Lhs;
+}
+
+template <typename E,
+          typename = typename std::enable_if<is_flags_enum<E>::value>::type>
+E &operator^=(E &Lhs, E Rhs) {
+  Lhs = Lhs ^ Rhs;
+  return Lhs;
+}
+
+} // namespace FlagsEnumDetail
+} // namespace llvm
+
+#endif
Index: unittests/ADT/CMakeLists.txt
===================================================================
--- unittests/ADT/CMakeLists.txt
+++ unittests/ADT/CMakeLists.txt
@@ -12,6 +12,7 @@
   DeltaAlgorithmTest.cpp
   DenseMapTest.cpp
   DenseSetTest.cpp
+  FlagsEnumTest.cpp
   FoldingSet.cpp
   FunctionRefTest.cpp
   HashingTest.cpp
Index: unittests/ADT/FlagsEnumTest.cpp
===================================================================
--- /dev/null
+++ unittests/ADT/FlagsEnumTest.cpp
@@ -0,0 +1,117 @@
+//===- llvm/unittest/ADT/FlagsEnumTest.cpp - FlagsEnum unit tests ---------===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/ADT/FlagsEnum.h"
+#include "gtest/gtest.h"
+
+using namespace llvm;
+
+namespace llvm {
+namespace {
+enum Flags {
+  F0 = 0,
+  F1 = 1,
+  F2 = 2,
+  F3 = 4,
+  F4 = 8,
+};
+
+ENABLE_FLAGS_ENUM(Flags, F4);
+
+// A scoped enum with an explicit unsigned underlying type, to check that the
+// operators don't rely on implicit conversions between enums and integers.
+enum class ScopedFlags : unsigned {
+  None = 0,
+  A = 1,
+  B = 2,
+  C = 4,
+};
+
+ENABLE_FLAGS_ENUM(ScopedFlags, ScopedFlags::C);
+} // namespace
+
+namespace {
+
+TEST(FlagsEnumTest, BitwiseOr) {
+  Flags f = F1 | F2;
+  EXPECT_EQ(3, f);
+
+  f = f | F3;
+  EXPECT_EQ(7, f);
+}
+
+TEST(FlagsEnumTest, BitwiseOrEquals) {
+  Flags f = F1;
+  f |= F3;
+  EXPECT_EQ(5, f);
+
+  // |= should return a reference to the LHS.
+  f = F2;
+  (f |= F3) = F1;
+  EXPECT_EQ(F1, f);
+}
+
+TEST(FlagsEnumTest, BitwiseAnd) {
+  Flags f = static_cast<Flags>(3) & F2;
+  EXPECT_EQ(F2, f);
+
+  f = (f | F3) & (F1 | F2 | F3);
+  EXPECT_EQ(6, f);
+}
+
+TEST(FlagsEnumTest, BitwiseAndEquals) {
+  Flags f = F1 | F2 | F3;
+  f &= F1 | F2;
+  EXPECT_EQ(3, f);
+
+  // &= should return a reference to the LHS.
+  (f &= F1) = F3;
+  EXPECT_EQ(F3, f);
+}
+
+TEST(FlagsEnumTest, BitwiseXor) {
+  Flags f = (F1 | F2) ^ (F2 | F3);
+  EXPECT_EQ(5, f);
+
+  f = f ^ F1;
+  EXPECT_EQ(4, f);
+}
+
+TEST(FlagsEnumTest, BitwiseXorEquals) {
+  Flags f = (F1 | F2);
+  f ^= (F2 | F4);
+  EXPECT_EQ(9, f);
+
+  // ^= should return a reference to the LHS.
+  (f ^= F4) = F3;
+  EXPECT_EQ(F3, f);
+}
+
+TEST(FlagsEnumTest, BitwiseNot) {
+  Flags f = ~F1;
+  EXPECT_EQ(14, f); // Largest value for f is 15.
+  EXPECT_EQ(15, ~F0);
+}
+
+TEST(FlagsEnumTest, ScopedEnum) {
+  ScopedFlags f = ScopedFlags::A | ScopedFlags::B;
+  EXPECT_EQ(3u, static_cast<unsigned>(f));
+
+  f &= ScopedFlags::B;
+  EXPECT_EQ(2u, static_cast<unsigned>(f));
+
+  f ^= ScopedFlags::C;
+  EXPECT_EQ(6u, static_cast<unsigned>(f));
+
+  // ~ only flips the bits up to the high-order bit of ScopedFlags::C.
+  EXPECT_EQ(1u, static_cast<unsigned>(~f));
+}
+
+} // namespace
+} // namespace llvm