diff --git a/llvm/include/llvm/IR/InstrTypes.h b/llvm/include/llvm/IR/InstrTypes.h --- a/llvm/include/llvm/IR/InstrTypes.h +++ b/llvm/include/llvm/IR/InstrTypes.h @@ -2097,16 +2097,14 @@ op_iterator populateBundleOperandInfos(ArrayRef Bundles, const unsigned BeginIndex); +public: /// Return the BundleOpInfo for the operand at index OpIdx. /// /// It is an error to call this with an OpIdx that does not correspond to an /// bundle operand. + BundleOpInfo &getBundleOpInfoForOperand(unsigned OpIdx); const BundleOpInfo &getBundleOpInfoForOperand(unsigned OpIdx) const { - for (auto &BOI : bundle_op_infos()) - if (BOI.Begin <= OpIdx && OpIdx < BOI.End) - return BOI; - - llvm_unreachable("Did not find operand bundle for operand!"); + return const_cast(this)->getBundleOpInfoForOperand(OpIdx); } protected: diff --git a/llvm/include/llvm/Transforms/Utils/KnowledgeRetention.h b/llvm/include/llvm/Transforms/Utils/KnowledgeRetention.h --- a/llvm/include/llvm/Transforms/Utils/KnowledgeRetention.h +++ b/llvm/include/llvm/Transforms/Utils/KnowledgeRetention.h @@ -93,6 +93,28 @@ /// If the IR changes the map will be outdated. void fillMapFromAssume(CallInst &AssumeCI, RetainedKnowledgeMap &Result); +/// Represents one fact held inside an operand bundle of an llvm.assume. +/// AttrKind is the property that holds. +/// WasOn, if not null, is the Value for which AttrKind holds. +/// ArgValue is an optional integer argument of the property (0 if absent). +struct RetainedKnowledge { + Attribute::AttrKind AttrKind = Attribute::None; + Value *WasOn = nullptr; + unsigned ArgValue = 0; +}; + +/// Retrieve the information held by Assume on the operand at index Idx. +/// Assume should be an llvm.assume and Idx should be in the operand bundle. +RetainedKnowledge getKnowledgeFromOperandInAssume(CallInst &Assume, + unsigned Idx); + +/// Retrieve the information held by the Use U of an llvm.assume. The use +/// should be in the operand bundle. 
+inline RetainedKnowledge getKnowledgeFromUseInAssume(const Use *U) { + return getKnowledgeFromOperandInAssume(*cast(U->getUser()), + U->getOperandNo()); +} + //===----------------------------------------------------------------------===// // Utilities for testing //===----------------------------------------------------------------------===// diff --git a/llvm/lib/IR/Instructions.cpp b/llvm/lib/IR/Instructions.cpp --- a/llvm/lib/IR/Instructions.cpp +++ b/llvm/lib/IR/Instructions.cpp @@ -384,6 +384,60 @@ return It; } +CallBase::BundleOpInfo &CallBase::getBundleOpInfoForOperand(unsigned OpIdx) { + // With fewer than 8 bundles, do a simple linear search. Otherwise fall back + // to a binary search that uses the fact that bundles usually have a similar + // number of operands to guess where OpIdx lies and converge faster. + if (bundle_op_info_end() - bundle_op_info_begin() < 8) { + for (auto &BOI : bundle_op_infos()) + if (BOI.Begin <= OpIdx && OpIdx < BOI.End) + return BOI; + + llvm_unreachable("Did not find operand bundle for operand!"); + } + + assert(OpIdx >= arg_size() && "the Idx is not in the operand bundles"); + assert(bundle_op_info_end() - bundle_op_info_begin() > 0 && + OpIdx < std::prev(bundle_op_info_end())->End && + "The Idx isn't in the operand bundle"); + + // The guess below needs a fractional operands-per-bundle ratio; to avoid + // floating point, it is kept as an integer scaled by this constant.
+ constexpr unsigned NumberScaling = 1024; + + bundle_op_iterator Begin = bundle_op_info_begin(); + bundle_op_iterator End = bundle_op_info_end(); + bundle_op_iterator Current; + + while (Begin != End) { + // Guess the bundle holding OpIdx by assuming the operands are spread + // evenly over the bundles in [Begin, End). + unsigned ScaledOperandPerBundle = + NumberScaling * (std::prev(End)->End - Begin->Begin) / (End - Begin); + // This rounds down to 0 with more than NumberScaling bundles per operand + // (mostly empty bundles); clamp it so the division below is well defined. + if (ScaledOperandPerBundle == 0) + ScaledOperandPerBundle = 1; + Current = Begin + (((OpIdx - Begin->Begin) * NumberScaling) / + ScaledOperandPerBundle); + if (Current >= End) + Current = std::prev(End); + assert(Current < End && Current >= Begin && + "the operand bundle doesn't cover every value in the range"); + if (OpIdx >= Current->Begin && OpIdx < Current->End) + break; + // Each miss drops Current from the range, so the loop always terminates. + if (OpIdx >= Current->End) + Begin = Current + 1; + else + End = Current; + } + + assert(OpIdx >= Current->Begin && OpIdx < Current->End && + "the operand bundle doesn't cover every value in the range"); + return *Current; +} + //===----------------------------------------------------------------------===// // CallInst Implementation //===----------------------------------------------------------------------===// diff --git a/llvm/lib/Transforms/Utils/KnowledgeRetention.cpp b/llvm/lib/Transforms/Utils/KnowledgeRetention.cpp --- a/llvm/lib/Transforms/Utils/KnowledgeRetention.cpp +++ b/llvm/lib/Transforms/Utils/KnowledgeRetention.cpp @@ -288,6 +288,23 @@ } } +RetainedKnowledge llvm::getKnowledgeFromOperandInAssume(CallInst &AssumeCI, + unsigned Idx) { + IntrinsicInst &Assume = cast(AssumeCI); + assert(Assume.getIntrinsicID() == Intrinsic::assume && + "this function is intended to be used on llvm.assume"); + CallBase::BundleOpInfo BOI = Assume.getBundleOpInfoForOperand(Idx); + RetainedKnowledge Result; + Result.AttrKind = Attribute::getAttrKindFromName(BOI.Tag->getKey()); + Result.WasOn = getValueFromBundleOpInfo(Assume, BOI, BOIE_WasOn); + if (BOI.End - BOI.Begin > BOIE_Argument) + Result.ArgValue = + cast(getValueFromBundleOpInfo(Assume, BOI, BOIE_Argument)) + ->getZExtValue(); + + return Result; +} + PreservedAnalyses
AssumeBuilderPass::run(Function &F, FunctionAnalysisManager &AM) { for (Instruction &I : instructions(F)) diff --git a/llvm/unittests/Transforms/Utils/KnowledgeRetentionTest.cpp b/llvm/unittests/Transforms/Utils/KnowledgeRetentionTest.cpp --- a/llvm/unittests/Transforms/Utils/KnowledgeRetentionTest.cpp +++ b/llvm/unittests/Transforms/Utils/KnowledgeRetentionTest.cpp @@ -10,10 +10,12 @@ #include "llvm/AsmParser/Parser.h" #include "llvm/IR/CallSite.h" #include "llvm/IR/LLVMContext.h" +#include "llvm/IR/IntrinsicInst.h" #include "llvm/Support/Regex.h" #include "llvm/Support/SourceMgr.h" #include "llvm/Support/CommandLine.h" #include "gtest/gtest.h" +#include using namespace llvm; @@ -387,3 +389,105 @@ })); RunTest(Head, Tail, Tests); } + +/// Build an llvm.assume with Size random operand bundles and check that +/// getKnowledgeFromUseInAssume agrees with fillMapFromAssume. A bundle gets no +/// operands if its random count is <= 0, one function argument if it is 1, +/// and that argument plus an integer in [0, MaxValue] if it is >= 2. +static void RunRandTest(uint64_t Seed, int Size, int MinCount, int MaxCount, + unsigned MaxValue) { + LLVMContext C; + SMDiagnostic Err; + + std::mt19937 Rng(Seed); + std::uniform_int_distribution DistCount(MinCount, MaxCount); + std::uniform_int_distribution DistValue(0, MaxValue); + std::uniform_int_distribution DistAttr(0, + Attribute::EndAttrKinds - 1); + + std::unique_ptr Mod = std::make_unique("AssumeQueryAPI", C); + if (!Mod) + Err.print("AssumeQueryAPI", errs()); + + std::vector TypeArgs; + for (int i = 0; i < (Size * 2); i++) + TypeArgs.push_back(Type::getInt32PtrTy(C)); + FunctionType *FuncType = + FunctionType::get(Type::getVoidTy(C), TypeArgs, false); + + Function *F = + Function::Create(FuncType, GlobalValue::ExternalLinkage, "test", &*Mod); + BasicBlock *BB = BasicBlock::Create(C); + BB->insertInto(F); + Instruction *Ret = ReturnInst::Create(C); + BB->getInstList().insert(BB->begin(), Ret); + Function *FnAssume = Intrinsic::getDeclaration(Mod.get(), Intrinsic::assume); + + std::vector ShuffledArgs; + std::vector HasArg; + for (auto &Arg : F->args()) { + ShuffledArgs.push_back(&Arg); + HasArg.push_back(false); + } + + std::shuffle(ShuffledArgs.begin(), ShuffledArgs.end(), Rng); + + std::vector
OpBundle; + OpBundle.reserve(Size); + std::vector Args; + Args.reserve(2); + for (int i = 0; i < Size; i++) { + int count = DistCount(Rng); + int value = DistValue(Rng); + int attr = DistAttr(Rng); + std::string str; + raw_string_ostream ss(str); + ss << Attribute::getNameFromAttrKind( + static_cast(attr)); + Args.clear(); + + if (count > 0) { + Args.push_back(ShuffledArgs[i]); + HasArg[i] = true; + } + if (count > 1) + Args.push_back(ConstantInt::get(Type::getInt32Ty(C), value)); + + OpBundle.push_back(OperandBundleDef{ss.str().c_str(), std::move(Args)}); + } + + Instruction *Assume = + CallInst::Create(FnAssume, ArrayRef({ConstantInt::getTrue(C)}), + std::move(OpBundle)); + Assume->insertBefore(&F->begin()->front()); + RetainedKnowledgeMap Map; + fillMapFromAssume(*cast(Assume), Map); + for (int i = 0; i < (Size * 2); i++) { + if (!HasArg[i]) + continue; + RetainedKnowledge K = + getKnowledgeFromUseInAssume(&*ShuffledArgs[i]->use_begin()); + auto LookupIt = Map.find(RetainedKnowledgeKey{K.WasOn, K.AttrKind}); + ASSERT_TRUE(LookupIt != Map.end()); + MinMax MM = LookupIt->second; + ASSERT_EQ(MM.Min, MM.Max); + ASSERT_EQ(MM.Min, K.ArgValue); + } +} + +TEST(AssumeQueryAPI, getKnowledgeFromUseInAssume) { + // // For Fuzzing + // std::random_device dev; + // std::mt19937 Rng(dev()); + // while (true) { + // unsigned Seed = Rng(); + // dbgs() << Seed << "\n"; + // RunRandTest(Seed, 100000, 0, 2, 100); + // } + RunRandTest(23456, 4, 0, 2, 100); + RunRandTest(560987, 25, -3, 2, 100); + + // Large bundles can lead to special cases, which is why this test is so + // large. + RunRandTest(9876789, 100000, 0, 7, 100); +}