Index: llvm/include/llvm/Transforms/Utils/KnowledgeRetention.h
===================================================================
--- llvm/include/llvm/Transforms/Utils/KnowledgeRetention.h
+++ llvm/include/llvm/Transforms/Utils/KnowledgeRetention.h
@@ -16,6 +16,7 @@
 #ifndef LLVM_TRANSFORMS_UTILS_ASSUMEBUILDER_H
 #define LLVM_TRANSFORMS_UTILS_ASSUMEBUILDER_H
 
+#include "llvm/IR/Attributes.h"
 #include "llvm/IR/Instruction.h"
 #include "llvm/IR/PassManager.h"
 
@@ -26,16 +27,71 @@
 /// If no information derived from \p I, this call returns null.
 /// The returned instruction is not inserted anywhere.
 CallInst *BuildAssumeFromInst(const Instruction *I, Module *M);
-CallInst *BuildAssumeFromInst(Instruction *I) {
+inline CallInst *BuildAssumeFromInst(Instruction *I) {
   return BuildAssumeFromInst(I, I->getModule());
 }
 
+/// Selects which argument value hasAttributeInAssume reports when the same
+/// attribute is present several times on the same Value in one llvm.assume.
+enum class AssumeQuery {
+  Highest, ///< Take the highest value available.
+  Lowest,  ///< Take the lowest value available.
+};
+
+/// Query the operand bundles of an llvm.assume to find a single attribute of
+/// the specified kind applied on a specified Value.
+///
+/// This has a log(n) complexity, n being the number of operand bundles. It
+/// should only be used when a single attribute is going to be queried.
+///
+/// \p IsOn is the Value the attribute applies to, or null for an attribute
+/// that applies to no Value (e.g. a function attribute).
+/// Return true iff the queried attribute was found. If \p ArgVal is non-null
+/// and the attribute has an argument, the argument is stored to \p ArgVal.
+///
+/// It is possible to have multiple values for the argument of an attribute
+/// in the same llvm.assume on the same llvm::Value. This is rare but needs to
+/// be dealt with: \p AQR selects which one is reported.
+///
+/// Bundles are looked up by the name of \p IsOn, so this may conservatively
+/// return false when distinct values share that name (e.g. unnamed values).
+bool hasAttributeInAssume(
+    CallInst &AssumeCI, Value *IsOn, StringRef AttrName,
+    uint64_t *ArgVal = nullptr,
+    AssumeQuery AQR = AssumeQuery::Highest);
+inline bool hasAttributeInAssume(
+    CallInst &AssumeCI, Value *IsOn, Attribute::AttrKind Kind,
+    uint64_t *ArgVal = nullptr,
+    AssumeQuery AQR = AssumeQuery::Highest) {
+  return hasAttributeInAssume(
+      AssumeCI, IsOn, Attribute::getNameFromAttrKind(Kind), ArgVal, AQR);
+}
+
+// TODO: Add a function to create/fill a map from the bundle when users intend
+// to make many different queries on the same bundles, to be used for example
+// in the Attributor.
+
+//===----------------------------------------------------------------------===//
+// Utilities for testing
+//===----------------------------------------------------------------------===//
+
 /// This pass will try to build an llvm.assume for every instruction in the
 /// function. Its main purpose is testing.
 struct AssumeBuilderPass : public PassInfoMixin<AssumeBuilderPass> {
   PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);
 };
 
+/// RAII helper: sets the ShouldPreserveAllAttributes setting to \p Value for
+/// its lifetime and restores the previous value when destroyed.
+struct SetPreserveAllScope {
+  bool PreviousValue;
+  explicit SetPreserveAllScope(bool Value);
+  ~SetPreserveAllScope();
+  // Copying would restore the previous value more than once.
+  SetPreserveAllScope(const SetPreserveAllScope &) = delete;
+  SetPreserveAllScope &operator=(const SetPreserveAllScope &) = delete;
+};
+
 } // namespace llvm
 
 #endif
Index: llvm/lib/Transforms/Utils/KnowledgeRetention.cpp
===================================================================
--- llvm/lib/Transforms/Utils/KnowledgeRetention.cpp
+++ llvm/lib/Transforms/Utils/KnowledgeRetention.cpp
@@ -149,6 +149,96 @@
   return Builder.build();
 }
 
+#ifndef NDEBUG
+
+static bool isExistingAttribute(StringRef Name) {
+  return StringSwitch<bool>(Name)
+#define GET_ATTR_NAMES
+#define ATTRIBUTE_ALL(ENUM_NAME, DISPLAY_NAME) .Case(#DISPLAY_NAME, true)
+#include "llvm/IR/Attributes.inc"
+      .Default(false);
+}
+
+#endif
+
+bool llvm::hasAttributeInAssume(CallInst &AssumeCI, Value *IsOn,
+                                StringRef AttrName, uint64_t *ArgVal,
+                                AssumeQuery AQR) {
+  IntrinsicInst &Assume = cast<IntrinsicInst>(AssumeCI);
+  assert(Assume.getIntrinsicID() == Intrinsic::assume &&
+         "this function is intended to be used on llvm.assume");
+  assert(isExistingAttribute(AttrName) && "this attribute doesn't exist");
+  assert((ArgVal == nullptr ||
+          Attribute::getAttrKindFromName(AttrName) == Attribute::None ||
+          Attribute::doesAttrKindHaveArgument(
+              Attribute::getAttrKindFromName(AttrName))) &&
+         "requested value for an attribute that has no argument");
+  if (Assume.bundle_op_infos().empty())
+    return false;
+
+  // Bundles are keyed by (attribute name, name of the Value the attribute is
+  // on, or "" when it is on no Value).
+  using Pair = std::pair<StringRef, StringRef>;
+  Pair ToFind =
+      std::make_pair(AttrName, IsOn ? IsOn->getName() : StringRef(""));
+  auto getPair = [&Assume](const CallBase::BundleOpInfo &BOI) {
+    assert(BOI.Begin <= BOI.End);
+    assert(isExistingAttribute(BOI.Tag->getKey()));
+    return std::make_pair(
+        BOI.Tag->getKey(),
+        BOI.Begin == BOI.End
+            ? StringRef("")
+            : (Assume.op_begin() + BOI.Begin)->get()->getName());
+  };
+
+  CallInst::bundle_op_iterator Lookup;
+
+  // The elements in the operand bundle are sorted by Tag then the Name of the
+  // value they are on then by increasing order of argument value, so the
+  // first match holds the lowest argument and the last one the highest.
+  if (AQR == AssumeQuery::Lowest) {
+    Lookup = llvm::lower_bound(
+        Assume.bundle_op_infos(), ToFind,
+        [&getPair](const CallBase::BundleOpInfo &BOI, const Pair &RHS) {
+          return getPair(BOI) < RHS;
+        });
+  } else {
+    auto UpperBound = llvm::upper_bound(
+        Assume.bundle_op_infos(), ToFind,
+        [&getPair](const Pair &LHS, const CallBase::BundleOpInfo &BOI) {
+          return LHS < getPair(BOI);
+        });
+    // Every bundle sorts after ToFind, so none can match; std::prev on the
+    // first iterator would be undefined behavior.
+    if (UpperBound == Assume.bundle_op_info_begin())
+      return false;
+    Lookup = std::prev(UpperBound);
+  }
+
+  if (Lookup == Assume.bundle_op_info_end() ||
+      Lookup->Tag->getKey() != AttrName)
+    return false;
+  // The lookup only compares names, so check the actual Value: a bundle on a
+  // different Value with the same name (e.g. two unnamed values) must not
+  // match, and neither may a bundle on a Value when querying no Value.
+  if (IsOn) {
+    if (Lookup->Begin == Lookup->End ||
+        (Assume.op_begin() + Lookup->Begin)->get() != IsOn)
+      return false;
+  } else if (Lookup->Begin != Lookup->End) {
+    return false;
+  }
+  assert(Lookup->Begin <= Lookup->End);
+  // No argument operand follows the Value: there is nothing to store.
+  if (Lookup->Begin + 2 > Lookup->End)
+    return true;
+  assert(Lookup->Begin + 2 == Lookup->End);
+  if (ArgVal)
+    *ArgVal = cast<ConstantInt>((Assume.op_begin() + Lookup->Begin + 1)->get())
+                  ->getZExtValue();
+  return true;
+}
+
 PreservedAnalyses AssumeBuilderPass::run(Function &F,
                                          FunctionAnalysisManager &AM) {
   for (Instruction &I : instructions(F))
@@ -156,3 +246,11 @@
       Assume->insertBefore(&I);
   return PreservedAnalyses::all();
 }
+
+SetPreserveAllScope::SetPreserveAllScope(bool Value) {
+  PreviousValue = ShouldPreserveAllAttributes;
+  ShouldPreserveAllAttributes.setValue(Value);
+}
+SetPreserveAllScope::~SetPreserveAllScope() {
+  ShouldPreserveAllAttributes.setValue(PreviousValue);
+}
Index: llvm/unittests/Transforms/Utils/CMakeLists.txt
===================================================================
--- llvm/unittests/Transforms/Utils/CMakeLists.txt
+++ llvm/unittests/Transforms/Utils/CMakeLists.txt
@@ -14,6 +14,7 @@
   CodeMoverUtilsTest.cpp
   FunctionComparatorTest.cpp
   IntegerDivisionTest.cpp
+  KnowledgeRetentionTest.cpp
   LocalTest.cpp
   LoopUtilsTest.cpp
   SizeOptsTest.cpp
Index: llvm/unittests/Transforms/Utils/KnowledgeRetentionTest.cpp
===================================================================
--- /dev/null
+++ llvm/unittests/Transforms/Utils/KnowledgeRetentionTest.cpp
@@ -0,0 +1,246 @@
+//===- KnowledgeRetentionTest.cpp - KnowledgeRetention unit tests ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/Utils/KnowledgeRetention.h"
+#include "llvm/AsmParser/Parser.h"
+#include "llvm/IR/CallSite.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/Support/SourceMgr.h"
+#include "gtest/gtest.h"
+#include <functional>
+
+using namespace llvm;
+
+static void RunTest(
+    StringRef Head, StringRef Tail,
+    std::vector<std::pair<StringRef, std::function<void(Instruction *)>>>
+        &Tests) {
+  std::string IR;
+  IR.append(Head.begin(), Head.end());
+  for (auto &Elem : Tests)
+    IR.append(Elem.first.begin(), Elem.first.end());
+  IR.append(Tail.begin(), Tail.end());
+  LLVMContext C;
+  SMDiagnostic Err;
+  std::unique_ptr<Module> Mod = parseAssemblyString(IR, Err, C);
+  if (!Mod)
+    Err.print("AssumeQueryAPI", errs());
+  // Stop here rather than dereference a null module below.
+  ASSERT_TRUE(Mod != nullptr) << "the test IR failed to parse";
+  Function *F = Mod->getFunction("test");
+  ASSERT_TRUE(F != nullptr) << "the test IR has no @test function";
+  // Tests[Idx] runs on the Idx-th instruction of the entry block; the assumes
+  // the callbacks insert before that instruction are not visited by the loop.
+  unsigned Idx = 0;
+  for (Instruction &I : *F->begin()) {
+    if (Idx < Tests.size())
+      Tests[Idx].second(&I);
+    Idx++;
+  }
+}
+
+TEST(AssumeQueryAPI, Basic) {
+  StringRef Head =
+      "declare void @llvm.assume(i1)\n"
+      "declare void @func(i32*, i32*)\n"
+      "declare void @func1(i32*, i32*, i32*, i32*)\n"
+      "declare void @func_many(i32*) \"no-jump-tables\" nounwind "
+      "\"less-precise-fpmad\" willreturn norecurse\n"
+      "define void @test(i32* %P, i32* %P1, i32* %P2, i32* %P3) {\n";
+  StringRef Tail = "ret void\n"
+                   "}";
+  std::vector<std::pair<StringRef, std::function<void(Instruction *)>>>
+      Tests;
+  Tests.push_back(std::make_pair(
+      "call void @func(i32* nonnull align 4 dereferenceable(16) %P, i32* align 8 noalias %P1)\n",
+      [](Instruction *I) {
+        CallInst *Assume = BuildAssumeFromInst(I);
+        Assume->insertBefore(I);
+        ASSERT_TRUE(hasAttributeInAssume(*Assume, I->getOperand(0),
+                                         Attribute::AttrKind::NonNull));
+        uint64_t ArgVal = 0;
+        ASSERT_TRUE(hasAttributeInAssume(
+            *Assume, I->getOperand(0), Attribute::AttrKind::Dereferenceable));
+        ASSERT_TRUE(hasAttributeInAssume(
+            *Assume, I->getOperand(0), Attribute::AttrKind::Dereferenceable,
+            &ArgVal, AssumeQuery::Highest));
+        ASSERT_EQ(ArgVal, 16U);
+        ArgVal = 0;
+        ASSERT_TRUE(hasAttributeInAssume(
+            *Assume, I->getOperand(0), Attribute::AttrKind::Alignment));
+        ASSERT_TRUE(hasAttributeInAssume(
+            *Assume, I->getOperand(0), Attribute::AttrKind::Alignment,
+            &ArgVal, AssumeQuery::Highest));
+        ASSERT_EQ(ArgVal, 4U);
+        ArgVal = 0;
+        ASSERT_TRUE(hasAttributeInAssume(
+            *Assume, I->getOperand(1), Attribute::AttrKind::Alignment));
+        ASSERT_TRUE(hasAttributeInAssume(
+            *Assume, I->getOperand(1), Attribute::AttrKind::Alignment,
+            &ArgVal, AssumeQuery::Highest));
+        ASSERT_EQ(ArgVal, 8U);
+        ASSERT_FALSE(hasAttributeInAssume(
+            *Assume, I->getOperand(1), Attribute::AttrKind::NonNull));
+        ASSERT_FALSE(hasAttributeInAssume(
+            *Assume, I->getOperand(0), Attribute::AttrKind::NoAlias));
+        ASSERT_FALSE(hasAttributeInAssume(
+            *Assume, I->getOperand(1), Attribute::AttrKind::Dereferenceable));
+        // ("align", "") sorts before every bundle here: the Highest lookup
+        // must not step before the first bundle.
+        ASSERT_FALSE(hasAttributeInAssume(*Assume, nullptr,
+                                          Attribute::AttrKind::Alignment));
+      }));
+  Tests.push_back(std::make_pair(
+      "call void @func1(i32* nonnull align 32 dereferenceable(48) %P, i32* nonnull "
+      "align 8 dereferenceable(28) %P, i32* nonnull align 64 dereferenceable(4) "
+      "%P, i32* nonnull align 16 dereferenceable(12) %P)\n",
+      [](Instruction *I) {
+        CallInst *Assume = BuildAssumeFromInst(I);
+        Assume->insertBefore(I);
+        uint64_t ArgVal = 0;
+        ASSERT_TRUE(hasAttributeInAssume(*Assume, I->getOperand(0),
+                                         Attribute::AttrKind::NonNull));
+        ASSERT_EQ(I->getOperand(0), I->getOperand(1));
+        ArgVal = 0;
+        ASSERT_TRUE(hasAttributeInAssume(
+            *Assume, I->getOperand(0), Attribute::AttrKind::Dereferenceable,
+            &ArgVal, AssumeQuery::Highest));
+        ASSERT_EQ(ArgVal, 48U);
+        ArgVal = 0;
+        ASSERT_TRUE(hasAttributeInAssume(
+            *Assume, I->getOperand(0), Attribute::AttrKind::Alignment,
+            &ArgVal, AssumeQuery::Highest));
+        ASSERT_EQ(ArgVal, 64U);
+        ArgVal = 0;
+        ASSERT_TRUE(hasAttributeInAssume(
+            *Assume, I->getOperand(1), Attribute::AttrKind::Alignment,
+            &ArgVal, AssumeQuery::Highest));
+        ASSERT_EQ(ArgVal, 64U);
+        ArgVal = 0;
+        ASSERT_TRUE(hasAttributeInAssume(
+            *Assume, I->getOperand(0), Attribute::AttrKind::Dereferenceable,
+            &ArgVal, AssumeQuery::Lowest));
+        ASSERT_EQ(ArgVal, 4U);
+        ArgVal = 0;
+        ASSERT_TRUE(hasAttributeInAssume(
+            *Assume, I->getOperand(0), Attribute::AttrKind::Alignment,
+            &ArgVal, AssumeQuery::Lowest));
+        ASSERT_EQ(ArgVal, 8U);
+        ArgVal = 0;
+        ASSERT_TRUE(hasAttributeInAssume(
+            *Assume, I->getOperand(1), Attribute::AttrKind::Alignment,
+            &ArgVal, AssumeQuery::Lowest));
+        ASSERT_EQ(ArgVal, 8U);
+      }));
+  Tests.push_back(std::make_pair(
+      "call void @func_many(i32* align 8 %P1) cold\n", [](Instruction *I) {
+        SetPreserveAllScope S(true);
+        CallInst *Assume = BuildAssumeFromInst(I);
+        Assume->insertBefore(I);
+        ASSERT_TRUE(hasAttributeInAssume(*Assume, nullptr,
+                                         Attribute::AttrKind::NoUnwind));
+        ASSERT_TRUE(hasAttributeInAssume(*Assume, nullptr,
+                                         Attribute::AttrKind::NoRecurse));
+        ASSERT_TRUE(hasAttributeInAssume(*Assume, nullptr,
+                                         Attribute::AttrKind::WillReturn));
+        ASSERT_TRUE(
+            hasAttributeInAssume(*Assume, nullptr, Attribute::AttrKind::Cold));
+        ASSERT_TRUE(hasAttributeInAssume(*Assume, nullptr, "no-jump-tables"));
+        ASSERT_TRUE(hasAttributeInAssume(*Assume, nullptr, "less-precise-fpmad"));
+      }));
+  Tests.push_back(
+      std::make_pair("call void @llvm.assume(i1 true)\n", [](Instruction *I) {
+        CallInst *Assume = cast<CallInst>(I);
+        ASSERT_FALSE(hasAttributeInAssume(*Assume, nullptr,
+                                          Attribute::AttrKind::NoUnwind));
+        ASSERT_FALSE(hasAttributeInAssume(*Assume, nullptr,
+                                          Attribute::AttrKind::NoRecurse));
+        ASSERT_FALSE(hasAttributeInAssume(*Assume, nullptr,
+                                          Attribute::AttrKind::WillReturn));
+        ASSERT_FALSE(
+            hasAttributeInAssume(*Assume, nullptr, Attribute::AttrKind::Cold));
+      }));
+  Tests.push_back(std::make_pair(
+      "call void @func1(i32* readnone align 32 dereferenceable(48) noalias %P, i32* "
+      "align 8 dereferenceable(28) %P1, i32* align 64 "
+      "dereferenceable(4) "
+      "%P2, i32* nonnull align 16 dereferenceable(12) %P3)\n",
+      [](Instruction *I) {
+        CallInst *Assume = BuildAssumeFromInst(I);
+        Assume->insertBefore(I);
+        uint64_t ArgVal;
+        ArgVal = 0;
+        ASSERT_TRUE(hasAttributeInAssume(
+            *Assume, I->getOperand(0), Attribute::AttrKind::Alignment, &ArgVal));
+        ASSERT_EQ(ArgVal, 32U);
+        ArgVal = 0;
+        ASSERT_TRUE(hasAttributeInAssume(
+            *Assume, I->getOperand(0), Attribute::AttrKind::Dereferenceable, &ArgVal));
+        ASSERT_EQ(ArgVal, 48U);
+        ArgVal = 0;
+        ASSERT_TRUE(hasAttributeInAssume(*Assume, I->getOperand(0),
+                                         Attribute::AttrKind::Dereferenceable,
+                                         &ArgVal));
+        ASSERT_EQ(ArgVal, 48U);
+        ArgVal = 0;
+        ASSERT_TRUE(hasAttributeInAssume(*Assume, I->getOperand(1),
+                                         Attribute::AttrKind::Dereferenceable,
+                                         &ArgVal));
+        ASSERT_EQ(ArgVal, 28U);
+        ArgVal = 0;
+        ASSERT_TRUE(hasAttributeInAssume(*Assume, I->getOperand(1),
+                                         Attribute::AttrKind::Alignment,
+                                         &ArgVal));
+        ASSERT_EQ(ArgVal, 8U);
+        ArgVal = 0;
+        ASSERT_TRUE(hasAttributeInAssume(*Assume, I->getOperand(2),
+                                         Attribute::AttrKind::Alignment,
+                                         &ArgVal));
+        ASSERT_EQ(ArgVal, 64U);
+        ArgVal = 0;
+        ASSERT_TRUE(hasAttributeInAssume(*Assume, I->getOperand(2),
+                                         Attribute::AttrKind::Dereferenceable,
+                                         &ArgVal));
+        ASSERT_EQ(ArgVal, 4U);
+        ArgVal = 0;
+        ASSERT_TRUE(hasAttributeInAssume(*Assume, I->getOperand(3),
+                                         Attribute::AttrKind::Alignment,
+                                         &ArgVal));
+        ASSERT_EQ(ArgVal, 16U);
+        ArgVal = 0;
+        ASSERT_TRUE(hasAttributeInAssume(*Assume, I->getOperand(3),
+                                         Attribute::AttrKind::Dereferenceable,
+                                         &ArgVal));
+        ASSERT_EQ(ArgVal, 12U);
+        ASSERT_TRUE(hasAttributeInAssume(*Assume, I->getOperand(0),
+                                         Attribute::AttrKind::ReadNone));
+        ASSERT_TRUE(hasAttributeInAssume(*Assume, I->getOperand(0),
+                                         Attribute::AttrKind::NoAlias));
+        ASSERT_TRUE(hasAttributeInAssume(*Assume, I->getOperand(0),
+                                         Attribute::AttrKind::Dereferenceable));
+        ASSERT_TRUE(hasAttributeInAssume(*Assume, I->getOperand(1),
+                                         Attribute::AttrKind::Dereferenceable));
+        ASSERT_TRUE(hasAttributeInAssume(*Assume, I->getOperand(2),
+                                         Attribute::AttrKind::Dereferenceable));
+        ASSERT_TRUE(hasAttributeInAssume(*Assume, I->getOperand(3),
+                                         Attribute::AttrKind::Dereferenceable));
+        ASSERT_FALSE(hasAttributeInAssume(*Assume, I->getOperand(1),
+                                          Attribute::AttrKind::ReadNone));
+        ASSERT_FALSE(hasAttributeInAssume(*Assume, I->getOperand(1),
+                                          Attribute::AttrKind::NoAlias));
+        ASSERT_FALSE(hasAttributeInAssume(*Assume, I->getOperand(2),
+                                          Attribute::AttrKind::ReadNone));
+        ASSERT_FALSE(hasAttributeInAssume(*Assume, I->getOperand(2),
+                                          Attribute::AttrKind::NoAlias));
+        ASSERT_FALSE(hasAttributeInAssume(*Assume, I->getOperand(3),
+                                          Attribute::AttrKind::ReadNone));
+        ASSERT_FALSE(hasAttributeInAssume(*Assume, I->getOperand(3),
+                                          Attribute::AttrKind::NoAlias));
+      }));
+  RunTest(Head, Tail, Tests);
+}