diff --git a/llvm/include/llvm/Transforms/Utils/KnowledgeRetention.h b/llvm/include/llvm/Transforms/Utils/KnowledgeRetention.h
--- a/llvm/include/llvm/Transforms/Utils/KnowledgeRetention.h
+++ b/llvm/include/llvm/Transforms/Utils/KnowledgeRetention.h
@@ -16,6 +16,7 @@
 #ifndef LLVM_TRANSFORMS_UTILS_ASSUMEBUILDER_H
 #define LLVM_TRANSFORMS_UTILS_ASSUMEBUILDER_H
 
+#include "llvm/IR/Attributes.h"
 #include "llvm/IR/Instruction.h"
 #include "llvm/IR/PassManager.h"
 
@@ -30,6 +31,41 @@
   return BuildAssumeFromInst(I, I->getModule());
 }
 
+/// It is possible to have multiple values for the argument of an attribute in
+/// the same llvm.assume on the same llvm::Value. This is rare but needs to be
+/// dealt with.
+enum class AssumeQuery {
+  Highest, ///< Take the highest value available.
+  Lowest,  ///< Take the lowest value available.
+};
+
+/// Query the operand bundle of an llvm.assume to find a single attribute of
+/// the specified kind applied on IsOn; if IsOn is null, any bundle matches.
+///
+/// This has a non-constant complexity. It should only be used when a single
+/// attribute is going to be queried.
+///
+/// Return true iff the queried attribute was found.
+/// If ArgVal is non-null, the argument (when present) is stored to *ArgVal.
+bool hasAttributeInAssume(CallInst &AssumeCI, Value *IsOn, StringRef AttrName,
+                          uint64_t *ArgVal = nullptr,
+                          AssumeQuery AQR = AssumeQuery::Highest);
+inline bool hasAttributeInAssume(CallInst &AssumeCI, Value *IsOn,
+                                 Attribute::AttrKind Kind,
+                                 uint64_t *ArgVal = nullptr,
+                                 AssumeQuery AQR = AssumeQuery::Highest) {
+  return hasAttributeInAssume(
+      AssumeCI, IsOn, Attribute::getNameFromAttrKind(Kind), ArgVal, AQR);
+}
+
+/// TODO: Add a function to create/fill a map from the bundle when users intend
+/// to make many different queries on the same bundles. To be used for example
+/// in the Attributor.
+
+//===----------------------------------------------------------------------===//
+// Utilities for testing
+//===----------------------------------------------------------------------===//
+
 /// This pass will try to build an llvm.assume for every instruction in the
 /// function. Its main purpose is testing.
 struct AssumeBuilderPass : public PassInfoMixin<AssumeBuilderPass> {
diff --git a/llvm/lib/Transforms/Utils/KnowledgeRetention.cpp b/llvm/lib/Transforms/Utils/KnowledgeRetention.cpp
--- a/llvm/lib/Transforms/Utils/KnowledgeRetention.cpp
+++ b/llvm/lib/Transforms/Utils/KnowledgeRetention.cpp
@@ -15,13 +15,13 @@
 
 using namespace llvm;
 
-namespace {
-
 cl::opt<bool> ShouldPreserveAllAttributes(
     "assume-preserve-all", cl::init(false), cl::Hidden,
     cl::desc("enable preservation of all attrbitues. even those that are "
              "unlikely to be usefull"));
 
+namespace {
+
 struct AssumedKnowledge {
   const char *Name;
   Value *Argument;
@@ -59,22 +59,33 @@
 
 namespace {
 
+/// Index of elements in the operand bundle.
+/// If an element exists it is guaranteed to be what this enum specifies, but
+/// it may not exist (the bundle can have fewer operands).
+enum BundleOpInfoElem {
+  BOIE_WasOn = 0,
+  BOIE_Argument = 1,
+};
+
 /// Deterministically compare OperandBundleDef.
 /// The ordering is:
-/// - by the name of the attribute, (doesn't change)
-/// - then by the Value of the argument, (doesn't change)
+/// - by the attribute's name aka operand bundle tag, (doesn't change)
+/// - then by the numeric Value of the argument, (doesn't change)
 /// - lastly by the Name of the current Value it WasOn. (may change)
 /// This order is deterministic and allows looking for the right kind of
 /// attribute with binary search. However finding the right WasOn needs to be
-/// done via linear search because values can get remplaced.
+/// done via linear search because values can get replaced.
 bool isLowerOpBundle(const OperandBundleDef &LHS, const OperandBundleDef &RHS) {
   auto getTuple = [](const OperandBundleDef &Op) {
     return std::make_tuple(
         Op.getTag(),
-        Op.input_size() < 2
+        Op.input_size() <= BOIE_Argument
             ? 0
-            : cast<ConstantInt>(*std::next(Op.input_begin()))->getZExtValue(),
-        Op.input_size() < 1 ? StringRef("") : (*Op.input_begin())->getName());
+            : cast<ConstantInt>(*(Op.input_begin() + BOIE_Argument))
+                  ->getZExtValue(),
+        Op.input_size() <= BOIE_WasOn
+            ? StringRef("")
+            : (*(Op.input_begin() + BOIE_WasOn))->getName());
   };
   return getTuple(LHS) < getTuple(RHS);
 }
@@ -160,6 +171,101 @@
   return Builder.build();
 }
 
+#ifndef NDEBUG
+
+static bool isExistingAttribute(StringRef Name) {
+  return StringSwitch<bool>(Name)
+#define GET_ATTR_NAMES
+#define ATTRIBUTE_ALL(ENUM_NAME, DISPLAY_NAME) .Case(#DISPLAY_NAME, true)
+#include "llvm/IR/Attributes.inc"
+      .Default(false);
+}
+
+#endif
+
+/// Implementation note: the lookup relies on the bundles being ordered as
+/// isLowerOpBundle describes (tag, then argument value), so the tag is found
+/// by binary search and the bundle on IsOn by a linear walk: upward for
+/// AssumeQuery::Lowest, downward for AssumeQuery::Highest.
+bool llvm::hasAttributeInAssume(CallInst &AssumeCI, Value *IsOn,
+                                StringRef AttrName, uint64_t *ArgVal,
+                                AssumeQuery AQR) {
+  IntrinsicInst &Assume = cast<IntrinsicInst>(AssumeCI);
+  assert(Assume.getIntrinsicID() == Intrinsic::assume &&
+         "this function is intended to be used on llvm.assume");
+  assert(isExistingAttribute(AttrName) && "this attribute doesn't exist");
+  assert((ArgVal == nullptr || Attribute::doesAttrKindHaveArgument(
+                                   Attribute::getAttrKindFromName(AttrName))) &&
+         "requested value for an attribute that has no argument");
+  if (Assume.bundle_op_infos().empty())
+    return false;
+
+  CallInst::bundle_op_iterator Lookup;
+
+  /// The right attribute can be found by binary search. After this, finding the
+  /// right WasOn needs to be done via linear search.
+  /// Elements have been ordered by argument value, so the first match found
+  /// in the search direction is the one we need.
+  if (AQR == AssumeQuery::Lowest) {
+    Lookup =
+        llvm::lower_bound(Assume.bundle_op_infos(), AttrName,
+                          [](const CallBase::BundleOpInfo &BOI, StringRef RHS) {
+                            assert(isExistingAttribute(BOI.Tag->getKey()) &&
+                                   "this attribute doesn't exist");
+                            return BOI.Tag->getKey() < RHS;
+                          });
+  } else {
+    CallInst::bundle_op_iterator UpperBound =
+        llvm::upper_bound(Assume.bundle_op_infos(), AttrName,
+                          [](StringRef LHS, const CallBase::BundleOpInfo &BOI) {
+                            assert(isExistingAttribute(BOI.Tag->getKey()) &&
+                                   "this attribute doesn't exist");
+                            return LHS < BOI.Tag->getKey();
+                          });
+    // Every tag sorts after AttrName: there is no candidate, and stepping back
+    // from begin() would be undefined behavior.
+    if (UpperBound == Assume.bundle_op_info_begin())
+      return false;
+    Lookup = std::prev(UpperBound);
+  }
+
+  auto getValueFromBundleOpInfo = [&Assume](const CallBase::BundleOpInfo &BOI,
+                                            unsigned Idx) {
+    assert(BOI.End - BOI.Begin > Idx && "index out of range");
+    return (Assume.op_begin() + BOI.Begin + Idx)->get();
+  };
+
+  if (Lookup == Assume.bundle_op_info_end() ||
+      Lookup->Tag->getKey() != AttrName)
+    return false;
+  if (IsOn) {
+    while (true) {
+      if (Lookup == Assume.bundle_op_info_end() ||
+          Lookup->Tag->getKey() != AttrName)
+        return false;
+      // A bundle without a WasOn operand cannot be on IsOn; skip it rather
+      // than read past the end of its operands.
+      if (Lookup->End - Lookup->Begin > BOIE_WasOn &&
+          getValueFromBundleOpInfo(*Lookup, BOIE_WasOn) == IsOn)
+        break;
+      if (AQR == AssumeQuery::Highest &&
+          Lookup == Assume.bundle_op_info_begin())
+        return false;
+      Lookup = Lookup + (AQR == AssumeQuery::Lowest ? 1 : -1);
+    }
+  }
+
+  // The argument is present only if there are more than BOIE_Argument
+  // elements, matching the convention used by isLowerOpBundle.
+  if (Lookup->End - Lookup->Begin <= BOIE_Argument)
+    return true;
+  if (ArgVal)
+    *ArgVal =
+        cast<ConstantInt>(getValueFromBundleOpInfo(*Lookup, BOIE_Argument))
+            ->getZExtValue();
+  return true;
+}
+
 PreservedAnalyses AssumeBuilderPass::run(Function &F,
                                          FunctionAnalysisManager &AM) {
   for (Instruction &I : instructions(F))
diff --git a/llvm/unittests/Transforms/Utils/CMakeLists.txt b/llvm/unittests/Transforms/Utils/CMakeLists.txt
--- a/llvm/unittests/Transforms/Utils/CMakeLists.txt
+++ b/llvm/unittests/Transforms/Utils/CMakeLists.txt
@@ -15,6 +15,7 @@
   CodeMoverUtilsTest.cpp
   FunctionComparatorTest.cpp
   IntegerDivisionTest.cpp
+  KnowledgeRetentionTest.cpp
   LocalTest.cpp
   LoopRotationUtilsTest.cpp
   LoopUtilsTest.cpp
diff --git a/llvm/unittests/Transforms/Utils/KnowledgeRetentionTest.cpp b/llvm/unittests/Transforms/Utils/KnowledgeRetentionTest.cpp
new file mode 100644
--- /dev/null
+++ b/llvm/unittests/Transforms/Utils/KnowledgeRetentionTest.cpp
@@ -0,0 +1,215 @@
+//===- KnowledgeRetentionTest.cpp - KnowledgeRetention unit tests ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/Utils/KnowledgeRetention.h"
+#include "llvm/AsmParser/Parser.h"
+#include "llvm/IR/CallSite.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Regex.h"
+#include "llvm/Support/SourceMgr.h"
+#include "gtest/gtest.h"
+
+using namespace llvm;
+
+extern cl::opt<bool> ShouldPreserveAllAttributes;
+
+// Calls each test's check on the instruction its IR line produced.
+static void RunTest(
+    StringRef Head, StringRef Tail,
+    std::vector<std::pair<StringRef, void (*)(Instruction *)>> &Tests) {
+  std::string IR = Head.str();
+  for (auto &Elem : Tests)
+    IR.append(Elem.first.begin(), Elem.first.end());
+  IR.append(Tail.begin(), Tail.end());
+  LLVMContext C;
+  SMDiagnostic Err;
+  std::unique_ptr<Module> Mod = parseAssemblyString(IR, Err, C);
+  if (!Mod)
+    Err.print("AssumeQueryAPI", errs());
+  ASSERT_TRUE(Mod != nullptr) << "the test IR failed to parse";
+  unsigned Idx = 0;
+  for (Instruction &I : (*Mod->getFunction("test")->begin())) {
+    if (Idx < Tests.size())
+      Tests[Idx].second(&I);
+    Idx++;
+  }
+}
+
+void AssertMatchesExactlyAttributes(CallInst *Assume, Value *WasOn,
+                                    StringRef AttrToMatch) {
+  Regex Reg(AttrToMatch);
+  SmallVector<StringRef, 1> Matches;
+  for (StringRef Attr : {
+#define GET_ATTR_NAMES
+#define ATTRIBUTE_ALL(ENUM_NAME, DISPLAY_NAME) StringRef(#DISPLAY_NAME),
+#include "llvm/IR/Attributes.inc"
+       }) {
+    bool ShouldHaveAttr = Reg.match(Attr, &Matches) && Matches[0] == Attr;
+    ASSERT_EQ(ShouldHaveAttr, hasAttributeInAssume(*Assume, WasOn, Attr))
+        << "unexpected query result for attribute " << Attr.str();
+  }
+}
+
+void AssertHasTheRightValue(CallInst *Assume, Value *WasOn,
+                            Attribute::AttrKind Kind, unsigned Value, bool Both,
+                            AssumeQuery AQ = AssumeQuery::Highest) {
+  if (!Both) {
+    uint64_t ArgVal = 0;
+    ASSERT_TRUE(hasAttributeInAssume(*Assume, WasOn, Kind, &ArgVal, AQ));
+    ASSERT_EQ(ArgVal, Value);
+    return;
+  }
+  uint64_t ArgValLow = 0;
+  uint64_t ArgValHigh = 0;
+  bool ResultLow = hasAttributeInAssume(*Assume, WasOn, Kind, &ArgValLow,
+                                        AssumeQuery::Lowest);
+  bool ResultHigh = hasAttributeInAssume(*Assume, WasOn, Kind, &ArgValHigh,
+                                         AssumeQuery::Highest);
+  ASSERT_TRUE(ResultLow);
+  ASSERT_TRUE(ResultHigh);
+  ASSERT_EQ(ArgValLow, Value);
+  ASSERT_EQ(ArgValHigh, Value);
+}
+
+TEST(AssumeQueryAPI, Basic) {
+  StringRef Head =
+      "declare void @llvm.assume(i1)\n"
+      "declare void @func(i32*, i32*)\n"
+      "declare void @func1(i32*, i32*, i32*, i32*)\n"
+      "declare void @func_many(i32*) \"no-jump-tables\" nounwind "
+      "\"less-precise-fpmad\" willreturn norecurse\n"
+      "define void @test(i32* %P, i32* %P1, i32* %P2, i32* %P3) {\n";
+  StringRef Tail = "ret void\n"
+                   "}";
+  // Captureless lambdas convert to function pointers, so nothing dangles.
+  std::vector<std::pair<StringRef, void (*)(Instruction *)>> Tests;
+  Tests.push_back(std::make_pair(
+      "call void @func(i32* nonnull align 4 dereferenceable(16) %P, i32* align "
+      "8 noalias %P1)\n",
+      [](Instruction *I) {
+        CallInst *Assume = BuildAssumeFromInst(I);
+        Assume->insertBefore(I);
+        AssertMatchesExactlyAttributes(Assume, I->getOperand(0),
+                                       "(nonnull|align|dereferenceable)");
+        AssertMatchesExactlyAttributes(Assume, I->getOperand(1),
+                                       "(noalias|align)");
+        AssertHasTheRightValue(Assume, I->getOperand(0),
+                               Attribute::AttrKind::Dereferenceable, 16, true);
+        AssertHasTheRightValue(Assume, I->getOperand(0),
+                               Attribute::AttrKind::Alignment, 4, true);
+        AssertHasTheRightValue(Assume, I->getOperand(1),
+                               Attribute::AttrKind::Alignment, 8, true);
+      }));
+  Tests.push_back(std::make_pair(
+      "call void @func1(i32* nonnull align 32 dereferenceable(48) %P, i32* "
+      "nonnull "
+      "align 8 dereferenceable(28) %P, i32* nonnull align 64 "
+      "dereferenceable(4) "
+      "%P, i32* nonnull align 16 dereferenceable(12) %P)\n",
+      [](Instruction *I) {
+        CallInst *Assume = BuildAssumeFromInst(I);
+        Assume->insertBefore(I);
+        AssertMatchesExactlyAttributes(Assume, I->getOperand(0),
+                                       "(nonnull|align|dereferenceable)");
+        AssertMatchesExactlyAttributes(Assume, I->getOperand(1),
+                                       "(nonnull|align|dereferenceable)");
+        AssertMatchesExactlyAttributes(Assume, I->getOperand(2),
+                                       "(nonnull|align|dereferenceable)");
+        AssertMatchesExactlyAttributes(Assume, I->getOperand(3),
+                                       "(nonnull|align|dereferenceable)");
+        AssertHasTheRightValue(Assume, I->getOperand(0),
+                               Attribute::AttrKind::Dereferenceable, 48, false,
+                               AssumeQuery::Highest);
+        AssertHasTheRightValue(Assume, I->getOperand(0),
+                               Attribute::AttrKind::Alignment, 64, false,
+                               AssumeQuery::Highest);
+        AssertHasTheRightValue(Assume, I->getOperand(1),
+                               Attribute::AttrKind::Alignment, 64, false,
+                               AssumeQuery::Highest);
+        AssertHasTheRightValue(Assume, I->getOperand(0),
+                               Attribute::AttrKind::Dereferenceable, 4, false,
+                               AssumeQuery::Lowest);
+        AssertHasTheRightValue(Assume, I->getOperand(0),
+                               Attribute::AttrKind::Alignment, 8, false,
+                               AssumeQuery::Lowest);
+        AssertHasTheRightValue(Assume, I->getOperand(1),
+                               Attribute::AttrKind::Alignment, 8, false,
+                               AssumeQuery::Lowest);
+      }));
+  Tests.push_back(std::make_pair(
+      "call void @func_many(i32* align 8 %P1) cold\n", [](Instruction *I) {
+        ShouldPreserveAllAttributes.setValue(true);
+        CallInst *Assume = BuildAssumeFromInst(I);
+        Assume->insertBefore(I);
+        AssertMatchesExactlyAttributes(
+            Assume, nullptr,
+            "(align|no-jump-tables|less-precise-fpmad|"
+            "nounwind|norecurse|willreturn|cold)");
+        ShouldPreserveAllAttributes.setValue(false);
+      }));
+  Tests.push_back(
+      std::make_pair("call void @llvm.assume(i1 true)\n", [](Instruction *I) {
+        CallInst *Assume = cast<CallInst>(I);
+        AssertMatchesExactlyAttributes(Assume, nullptr, "");
+      }));
+  Tests.push_back(std::make_pair(
+      "call void @func1(i32* readnone align 32 "
+      "dereferenceable(48) noalias %P, i32* "
+      "align 8 dereferenceable(28) %P1, i32* align 64 "
+      "dereferenceable(4) "
+      "%P2, i32* nonnull align 16 dereferenceable(12) %P3)\n",
+      [](Instruction *I) {
+        CallInst *Assume = BuildAssumeFromInst(I);
+        Assume->insertBefore(I);
+        AssertMatchesExactlyAttributes(
+            Assume, I->getOperand(0),
+            "(readnone|align|dereferenceable|noalias)");
+        AssertMatchesExactlyAttributes(Assume, I->getOperand(1),
+                                       "(align|dereferenceable)");
+        AssertMatchesExactlyAttributes(Assume, I->getOperand(2),
+                                       "(align|dereferenceable)");
+        AssertMatchesExactlyAttributes(Assume, I->getOperand(3),
+                                       "(nonnull|align|dereferenceable)");
+        AssertHasTheRightValue(Assume, I->getOperand(0),
+                               Attribute::AttrKind::Alignment, 32, true);
+        AssertHasTheRightValue(Assume, I->getOperand(0),
+                               Attribute::AttrKind::Dereferenceable, 48, true);
+        AssertHasTheRightValue(Assume, I->getOperand(1),
+                               Attribute::AttrKind::Dereferenceable, 28, true);
+        AssertHasTheRightValue(Assume, I->getOperand(1),
+                               Attribute::AttrKind::Alignment, 8, true);
+        AssertHasTheRightValue(Assume, I->getOperand(2),
+                               Attribute::AttrKind::Alignment, 64, true);
+        AssertHasTheRightValue(Assume, I->getOperand(2),
+                               Attribute::AttrKind::Dereferenceable, 4, true);
+        AssertHasTheRightValue(Assume, I->getOperand(3),
+                               Attribute::AttrKind::Alignment, 16, true);
+        AssertHasTheRightValue(Assume, I->getOperand(3),
+                               Attribute::AttrKind::Dereferenceable, 12, true);
+      }));
+
+  /// Keep this test last as it modifies the function.
+  Tests.push_back(std::make_pair(
+      "call void @func(i32* nonnull align 4 dereferenceable(16) %P, i32* align "
+      "8 noalias %P1)\n",
+      [](Instruction *I) {
+        CallInst *Assume = BuildAssumeFromInst(I);
+        Assume->insertBefore(I);
+        Value *New = I->getFunction()->getArg(3);
+        Value *Old = I->getOperand(0);
+        AssertMatchesExactlyAttributes(Assume, New, "");
+        AssertMatchesExactlyAttributes(Assume, Old,
+                                       "(nonnull|align|dereferenceable)");
+        Old->replaceAllUsesWith(New);
+        AssertMatchesExactlyAttributes(Assume, New,
+                                       "(nonnull|align|dereferenceable)");
+        AssertMatchesExactlyAttributes(Assume, Old, "");
+      }));
+  RunTest(Head, Tail, Tests);
+}