diff --git a/llvm/include/llvm/Analysis/ValueTracking.h b/llvm/include/llvm/Analysis/ValueTracking.h --- a/llvm/include/llvm/Analysis/ValueTracking.h +++ b/llvm/include/llvm/Analysis/ValueTracking.h @@ -226,6 +226,18 @@ /// return undef. Value *isBytewiseValue(Value *V, const DataLayout &DL); + // Result type for |getMostFrequentByte|. + struct MostFrequentByte { + MostFrequentByte(Value *Value, size_t Count, size_t Other) + : Value(Value), Count(Count), Other(Other) {} + Value *Value = nullptr; + size_t Count = 0; + size_t Other = 0; + }; + /// Find most frequent byte in specified value, approximately. The current + /// version ignores some types and non-splat integers. + MostFrequentByte getMostFrequentByte(Value *V, const DataLayout &DL); + /// Given an aggregrate and an sequence of indices, see if the scalar value /// indexed is already around as a register, for example if it were inserted /// directly into the aggregrate. diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp --- a/llvm/lib/Analysis/ValueTracking.cpp +++ b/llvm/lib/Analysis/ValueTracking.cpp @@ -3166,26 +3166,29 @@ return true; } -Value *llvm::isBytewiseValue(Value *V, const DataLayout &DL) { +namespace { - // All byte-wide stores are splatable, even of arbitrary variables. - if (V->getType()->isIntegerTy(8)) - return V; +struct ValueHistogram { + uint64_t Values[256] = {}; + uint64_t Other = 0; +}; +static void addToValueHistogram(ValueHistogram &H, Value *V, + const DataLayout &DL) { LLVMContext &Ctx = V->getContext(); // Undef don't care. 
- auto *UndefInt8 = UndefValue::get(Type::getInt8Ty(Ctx)); if (isa(V)) - return UndefInt8; + return; if (auto *T = dyn_cast(V->getType())) if (!T->getNumElements()) - return UndefInt8; + return; if (auto *T = dyn_cast(V->getType())) if (!T->getNumElements()) - return UndefInt8; + return; + const uint64_t Size = DL.getTypeStoreSize(V->getType()); Constant *C = dyn_cast(V); if (!C) { // Conceptually, we could handle things like: @@ -3194,12 +3197,15 @@ // %c = or i16 %a, %b // but until there is an example that actually needs this, it doesn't seem // worth worrying about. - return nullptr; + H.Other += Size; + return; } // Handle 'null' ConstantArrayZero etc. - if (C->isNullValue()) - return Constant::getNullValue(Type::getInt8Ty(Ctx)); + if (C->isNullValue()) { + H.Values[0] += Size; + return; + } // Constant floating-point values can be handled as integer values if the // corresponding integer value is "byteable". An important case is 0.0. @@ -3212,61 +3218,83 @@ else if (CFP->getType()->isDoubleTy()) Ty = Type::getInt64Ty(Ctx); // Don't handle long double formats, which have strange constraints. - return Ty ? isBytewiseValue(ConstantExpr::getBitCast(CFP, Ty), DL) - : nullptr; + if (Ty) + addToValueHistogram(H, ConstantExpr::getBitCast(CFP, Ty), DL); + else + H.Other += Size; + return; } // We can handle constant integers that are multiple of 8 bits. 
if (ConstantInt *CI = dyn_cast(C)) { - if (CI->getBitWidth() % 8 == 0) { - assert(CI->getBitWidth() > 8 && "8 bits should be handled above!"); - if (!CI->getValue().isSplat(8)) - return nullptr; - return ConstantInt::get(Ctx, CI->getValue().trunc(8)); - } + if (CI->getBitWidth() % 8 == 0 && CI->getValue().isSplat(8)) + H.Values[CI->getValue().getRawData()[0] % 256] += Size; + else + H.Other += Size; + return; } if (auto *CE = dyn_cast(C)) { if (CE->getOpcode() == Instruction::IntToPtr) { auto PS = DL.getPointerSizeInBits( cast(CE->getType())->getAddressSpace()); - return isBytewiseValue( - ConstantExpr::getIntegerCast(CE->getOperand(0), - Type::getIntNTy(Ctx, PS), false), - DL); + addToValueHistogram(H, + ConstantExpr::getIntegerCast(CE->getOperand(0), + Type::getIntNTy(Ctx, PS), + false), + DL); + } else { + H.Other += Size; } + return; } - auto Merge = [&](Value *LHS, Value *RHS) -> Value * { - if (LHS == RHS) - return LHS; - if (!LHS || !RHS) - return nullptr; - if (LHS == UndefInt8) - return RHS; - if (RHS == UndefInt8) - return LHS; - return nullptr; - }; - if (ConstantDataSequential *CA = dyn_cast(C)) { - Value *Val = UndefInt8; + assert(CA->getNumElements()); for (unsigned I = 0, E = CA->getNumElements(); I != E; ++I) - if (!(Val = Merge(Val, isBytewiseValue(CA->getElementAsConstant(I), DL)))) - return nullptr; - return Val; + addToValueHistogram(H, CA->getElementAsConstant(I), DL); + return; } if (isa(C)) { - Value *Val = UndefInt8; + assert(C->getNumOperands()); for (unsigned I = 0, E = C->getNumOperands(); I != E; ++I) - if (!(Val = Merge(Val, isBytewiseValue(C->getOperand(I), DL)))) - return nullptr; - return Val; + addToValueHistogram(H, C->getOperand(I), DL); + return; } // Don't try to handle the handful of other constants. - return nullptr; + H.Other += Size; +} + +} // namespace + +Value *llvm::isBytewiseValue(Value *V, const DataLayout &DL) { + MostFrequentByte MFB = getMostFrequentByte(V, DL); + return MFB.Other ? 
nullptr : MFB.Value; +} + +MostFrequentByte llvm::getMostFrequentByte(Value *V, const DataLayout &DL) { + if (V->getType()->isIntegerTy(8)) + return {V, 1, 0}; + + ValueHistogram H; + addToValueHistogram(H, V, DL); + + size_t Total = + std::accumulate(std::begin(H.Values), std::end(H.Values), H.Other); + auto I = std::max_element(std::begin(H.Values), std::end(H.Values)); + auto Ty = Type::getInt8Ty(V->getContext()); + if (*I) { + return { + ConstantInt::get(Ty, std::distance(std::begin(H.Values), I)), + *I, + Total - *I, + }; + } + if (!Total) + return {UndefValue::get(Ty), 0, 0}; + return {nullptr, 0, Total}; } // This is the recursive version of BuildSubAggregate. It takes a few different diff --git a/llvm/unittests/Analysis/ValueTrackingTest.cpp b/llvm/unittests/Analysis/ValueTrackingTest.cpp --- a/llvm/unittests/Analysis/ValueTrackingTest.cpp +++ b/llvm/unittests/Analysis/ValueTrackingTest.cpp @@ -694,239 +694,329 @@ expectKnownBits(/*zero*/ 2u, /*one*/ 0u); } -class IsBytewiseValueTest : public ValueTrackingTest, - public ::testing::WithParamInterface< - std::pair> { +struct BytewiseValueTestInput { + const char *Result; + struct MFBResultType { + MFBResultType(uint64_t Count = 0, uint64_t Other = 0, + const char *AltResult = nullptr) + : Count(Count), Other(Other), AltResult(AltResult) {} + uint64_t Count; + uint64_t Other; + const char *AltResult; // use |Result| if nullptr. 
+ } MFBResult; + const char *Input; +}; + +std::ostream &operator<<(std::ostream &os, + const BytewiseValueTestInput &Input) { + return os << '"' << Input.Input << '"'; +} + +class BytewiseValueTest + : public ValueTrackingTest, + public ::testing::WithParamInterface { protected: + void SetUp() override { + M = parseModule(std::string("@test = global ") + GetParam().Input); + GV = dyn_cast(M->getNamedValue("test")); + } + + static std::string toString(Value *V) { + if (!V) + return ""; + std::string Buff; + raw_string_ostream S(Buff); + S << *V; + return S.str(); + } + + std::unique_ptr M; + GlobalVariable *GV = nullptr; }; -const std::pair IsBytewiseValueTests[] = { +const BytewiseValueTestInput BytewiseValueTests[] = { { "i8 0", + {8}, "i48* null", }, { "i8 undef", + {0}, "i48* undef", }, { "i8 0", + {1}, "i8 zeroinitializer", }, { "i8 0", + {1}, "i8 0", }, { "i8 -86", + {1}, "i8 -86", }, { "i8 -1", + {1}, "i8 -1", }, { "i8 undef", + {0}, "i16 undef", }, { "i8 0", + {2}, "i16 0", }, { "", + {0, 2}, "i16 7", }, { "i8 -86", + {2}, "i16 -21846", }, { "i8 -1", + {2}, "i16 -1", }, { "i8 0", + {6}, "i48 0", }, { "i8 -1", + {6}, "i48 -1", }, { "i8 0", + {7}, "i49 0", }, { "", + {0, 7}, "i49 -1", }, { "i8 0", + {2}, "half 0xH0000", }, { "i8 -85", + {2}, "half 0xHABAB", }, { "i8 0", + {4}, "float 0.0", }, { "i8 -1", + {4}, "float 0xFFFFFFFFE0000000", }, { "i8 0", + {8}, "double 0.0", }, { "i8 -15", + {8}, "double 0xF1F1F1F1F1F1F1F1", }, { "i8 undef", + {}, "i16* undef", }, { "i8 0", + {8}, "i16* inttoptr (i64 0 to i16*)", }, { "i8 -1", + {8}, "i16* inttoptr (i64 -1 to i16*)", }, { "i8 -86", + {8}, "i16* inttoptr (i64 -6148914691236517206 to i16*)", }, { "", + {0, 8}, "i16* inttoptr (i48 -1 to i16*)", }, { "i8 -1", + {8}, "i16* inttoptr (i96 -1 to i16*)", }, { "i8 undef", + {}, "[0 x i8] zeroinitializer", }, { "i8 undef", + {}, "[0 x i8] undef", }, { "i8 0", + {6}, "[6 x i8] zeroinitializer", }, { "i8 undef", + {}, "[6 x i8] undef", }, { "i8 1", + {5}, "[5 x i8] [i8 1, i8 
1, i8 1, i8 1, i8 1]", }, { "", + {0, 40}, "[5 x i64] [i64 1, i64 1, i64 1, i64 1, i64 1]", }, { "i8 -1", + {40}, "[5 x i64] [i64 -1, i64 -1, i64 -1, i64 -1, i64 -1]", }, { "", + {3, 1, "i8 1"}, "[4 x i8] [i8 1, i8 2, i8 1, i8 1]", }, { "i8 1", + {3}, "[4 x i8] [i8 1, i8 undef, i8 1, i8 1]", }, { "i8 0", + {6}, "<6 x i8> zeroinitializer", }, { "i8 undef", + {}, "<6 x i8> undef", }, { "i8 1", + {5}, "<5 x i8> ", }, { "", + {0, 40}, "<5 x i64> ", }, { "i8 -1", + {40}, "<5 x i64> ", }, { "", + {3, 1, "i8 1"}, "<4 x i8> ", }, { "i8 5", + {1}, "<2 x i8> < i8 5, i8 undef >", }, { "i8 0", + {8}, "[2 x [2 x i16]] zeroinitializer", }, { "i8 undef", + {}, "[2 x [2 x i16]] undef", }, { "i8 -86", + {8}, "[2 x [2 x i16]] [[2 x i16] [i16 -21846, i16 -21846], " "[2 x i16] [i16 -21846, i16 -21846]]", }, { "", + {6, 2, "i8 -86"}, "[2 x [2 x i16]] [[2 x i16] [i16 -21846, i16 -21846], " "[2 x i16] [i16 -21836, i16 -21846]]", }, { "i8 undef", + {}, "{ } zeroinitializer", }, { "i8 undef", + {}, "{ } undef", }, { "i8 0", + {24}, "{i8, i64, i16*} zeroinitializer", }, { "i8 undef", + {}, "{i8, i64, i16*} undef", }, { "i8 -86", + {9}, "{i8, i64, i16*} {i8 -86, i64 -6148914691236517206, i16* undef}", }, { "", + {8, 1, "i8 -86"}, "{i8, i64, i16*} {i8 86, i64 -6148914691236517206, i16* undef}", }, }; -INSTANTIATE_TEST_CASE_P(IsBytewiseValueParamTests, IsBytewiseValueTest, - ::testing::ValuesIn(IsBytewiseValueTests)); - -TEST_P(IsBytewiseValueTest, IsBytewiseValue) { - auto M = parseModule(std::string("@test = global ") + GetParam().second); - GlobalVariable *GV = dyn_cast(M->getNamedValue("test")); - Value *Actual = isBytewiseValue(GV->getInitializer(), M->getDataLayout()); - std::string Buff; - raw_string_ostream S(Buff); - if (Actual) - S << *Actual; - EXPECT_EQ(GetParam().first, S.str()); +INSTANTIATE_TEST_CASE_P(BytewiseValueParamTests, BytewiseValueTest, + ::testing::ValuesIn(BytewiseValueTests)); + +TEST_P(BytewiseValueTest, IsBytewiseValue) { + EXPECT_EQ(GetParam().Result, 
toString(isBytewiseValue(GV->getInitializer(), + M->getDataLayout()))); +} + +TEST_P(BytewiseValueTest, GetMostFrequentByte) { + auto Actual = getMostFrequentByte(GV->getInitializer(), M->getDataLayout()); + auto Expected = GetParam().MFBResult; + EXPECT_EQ(Expected.Count, Actual.Count); + EXPECT_EQ(Expected.Other, Actual.Other); + EXPECT_EQ(Expected.AltResult ? Expected.AltResult : GetParam().Result, + toString(Actual.Value)); }