Index: llvm/include/llvm/ADT/STLExtras.h =================================================================== --- llvm/include/llvm/ADT/STLExtras.h +++ llvm/include/llvm/ADT/STLExtras.h @@ -1539,35 +1539,50 @@ /// Return true if the sequence [Begin, End) has exactly N items. Runs in O(N) /// time. Not meant for use with random-access iterators. -template +/// Can optionally take a predicate to lazily filter which items are counted. +template ()) &)> bool hasNItems( IterTy &&Begin, IterTy &&End, unsigned N, + Pred &&ShouldBeCounted = + [](const decltype(*std::declval()) &) { return true; }, typename std::enable_if< !std::is_same< typename std::iterator_traits::type>::iterator_category, std::random_access_iterator_tag>::value, void>::type * = nullptr) { - for (; N; --N, ++Begin) + for (; N; ++Begin) { if (Begin == End) return false; // Too few. - return Begin == End; + N -= ShouldBeCounted(*Begin); + } + for (; Begin != End; ++Begin) + if (ShouldBeCounted(*Begin)) + return false; // Too many. + return true; } /// Return true if the sequence [Begin, End) has N or more items. Runs in O(N) /// time. Not meant for use with random-access iterators. -template +/// Can optionally take a predicate to lazily filter which items are counted. +template ()) &)> bool hasNItemsOrMore( IterTy &&Begin, IterTy &&End, unsigned N, + Pred &&ShouldBeCounted = + [](const decltype(*std::declval()) &) { return true; }, typename std::enable_if< !std::is_same< typename std::iterator_traits::type>::iterator_category, std::random_access_iterator_tag>::value, void>::type * = nullptr) { - for (; N; --N, ++Begin) + for (; N; ++Begin) { if (Begin == End) return false; // Too few. + N -= ShouldBeCounted(*Begin); + } return true; } Index: llvm/include/llvm/IR/User.h =================================================================== --- llvm/include/llvm/IR/User.h +++ llvm/include/llvm/IR/User.h @@ -218,6 +218,11 @@ NumUserOperands = NumOps; } + /// A droppable user is a user for which uses can be dropped without affecting + /// correctness and should be dropped rather than preventing a transformation + /// from happening. 
+ bool isDroppable() const; + // --------------------------------------------------------------------------- // Operand Iterator interface... // Index: llvm/include/llvm/IR/Value.h =================================================================== --- llvm/include/llvm/IR/Value.h +++ llvm/include/llvm/IR/Value.h @@ -444,6 +444,34 @@ /// This is logically equivalent to getNumUses() >= N. bool hasNUsesOrMore(unsigned N) const; + /// Return the only use of this value whose user cannot be dropped, or null + /// if there are zero or several such uses. + /// + /// This is specialized because it is a common request; scanning stops as + /// soon as a second undroppable use is found. + Use *getSingleUndropableUse(); + + /// Return true if this value has exactly N uses whose user cannot be dropped. + /// + /// This is specialized because it is a common request and can stop early + /// instead of always traversing the whole use list. + bool hasNUndropableUses(unsigned N) const; + + /// Return true if this value has N or more uses whose user cannot be dropped. + /// + /// This is like hasNUsesOrMore(N), except that droppable users are ignored. + bool hasNUndropableUsesOrMore(unsigned N) const; + + /// Remove every use of this value whose user is droppable. + /// + /// For example, this removes uses in llvm.assume. + /// This should be used when a transformation is desired but some droppable + /// uses prevent it. + /// The optional \p ShouldDrop filter restricts which droppable uses are + /// removed. + void dropDroppableUses(llvm::function_ref ShouldDrop = + [](const Use *) { return true; }); + /// Check if this value is used in the specified basic block. 
bool isUsedInBasicBlock(const BasicBlock *BB) const; Index: llvm/include/llvm/Transforms/Utils/KnowledgeRetention.h =================================================================== --- llvm/include/llvm/Transforms/Utils/KnowledgeRetention.h +++ llvm/include/llvm/Transforms/Utils/KnowledgeRetention.h @@ -16,6 +16,7 @@ #ifndef LLVM_TRANSFORMS_UTILS_ASSUMEBUILDER_H #define LLVM_TRANSFORMS_UTILS_ASSUMEBUILDER_H +#include "llvm/IR/Attributes.h" #include "llvm/IR/Instruction.h" #include "llvm/IR/PassManager.h" @@ -26,16 +27,64 @@ /// If no information derived from \p I, this call returns null. /// The returned instruction is not inserted anywhere. CallInst *BuildAssumeFromInst(const Instruction *I, Module *M); -CallInst *BuildAssumeFromInst(Instruction *I) { +inline CallInst *BuildAssumeFromInst(Instruction *I) { return BuildAssumeFromInst(I, I->getModule()); } +/// Query the operand bundle of an llvm.assume to find a single attribute of +/// the specified kind applied on a specified Value. +/// +/// This has a non-constant complexity. It should only be used when a single +/// attribute is going to be queried. +/// +/// Return true iff the queried attribute was found. +/// If ArgVal is non-null, the argument of the attribute is stored in *ArgVal. +/// +/// It is possible to have multiple values for the argument of an attribute in +/// the same llvm.assume on the same llvm::Value. This is rare but needs to be +/// dealt with; AssumeQuery selects which of those values is reported. +enum class AssumeQuery { + Highest, ///< Take the highest value available. + Lowest, ///< Take the lowest value available. 
+}; bool hasAttributeInAssume( + CallInst &AssumeCI, Value *IsOn, StringRef AttrName, + uint64_t *ArgVal = nullptr, + AssumeQuery AQR = AssumeQuery::Highest); +inline bool hasAttributeInAssume( + CallInst &AssumeCI, Value *IsOn, Attribute::AttrKind Kind, + uint64_t *ArgVal = nullptr, + AssumeQuery AQR = AssumeQuery::Highest) { + return hasAttributeInAssume( + AssumeCI, IsOn, Attribute::getNameFromAttrKind(Kind), ArgVal, AQR); +} + +/// Return true if no operand bundle of \p Assume still holds information. +/// If this returns true, the bundles can be dropped without losing anything. +bool isAssumeWithEmptyBundle(CallInst &Assume); + +/// TODO: Add a function to create/fill a map from the bundle when users intend +/// to make many different queries on the same bundles, to be used for example +/// in the Attributor. + +//===----------------------------------------------------------------------===// +// Utilities for testing +//===----------------------------------------------------------------------===// + /// This pass will try to build an llvm.assume for every instruction in the /// function. Its main purpose is testing. struct AssumeBuilderPass : public PassInfoMixin { PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM); }; +/// This struct is used to unit test the ShouldPreserveAllAttributes flag. +/// It is an RAII struct that sets the flag and restores it on destruction. +struct SetPreserveAllScope { + bool PreviousValue; + SetPreserveAllScope(bool Value); + ~SetPreserveAllScope(); +}; + } // namespace llvm #endif Index: llvm/lib/IR/AsmWriter.cpp =================================================================== --- llvm/lib/IR/AsmWriter.cpp +++ llvm/lib/IR/AsmWriter.cpp @@ -2552,6 +2552,10 @@ for (unsigned i = 0, e = Call->getNumOperandBundles(); i != e; ++i) { OperandBundleUse BU = Call->getOperandBundleAt(i); + // Bundles containing dropped (null) values carry no information; skip them. 
+ if (any_of(BU.Inputs, [](const Use &U) { return !U.get(); })) + continue; + if (!FirstBundle) Out << ", "; FirstBundle = false; Index: llvm/lib/IR/User.cpp =================================================================== --- llvm/lib/IR/User.cpp +++ llvm/lib/IR/User.cpp @@ -9,6 +9,7 @@ #include "llvm/IR/User.h" #include "llvm/IR/Constant.h" #include "llvm/IR/GlobalValue.h" +#include "llvm/IR/IntrinsicInst.h" namespace llvm { class BasicBlock; @@ -105,6 +106,13 @@ reinterpret_cast(DI) - DI->SizeInBytes, DI->SizeInBytes); } +/// The documentation is in the header. +bool User::isDroppable() const { + if (const auto *Intr = dyn_cast(this)) + return Intr->getIntrinsicID() == Intrinsic::assume; + return false; +} + //===----------------------------------------------------------------------===// // User operator new Implementations //===----------------------------------------------------------------------===// Index: llvm/lib/IR/Value.cpp =================================================================== --- llvm/lib/IR/Value.cpp +++ llvm/lib/IR/Value.cpp @@ -137,6 +137,52 @@ return hasNItemsOrMore(use_begin(), use_end(), N); } +/// Predicate for the use-counting helpers: count only undroppable users. +static bool isUnDroppableUser(const User *U) { return !U->isDroppable(); } + +Use *Value::getSingleUndropableUse() { + Use *Result = nullptr; + for (Use &U : uses()) { + if (!U.getUser()->isDroppable()) { + if (Result) + return nullptr; + Result = &U; + } + } + return Result; +} + +bool Value::hasNUndropableUses(unsigned int N) const { + return hasNItems(user_begin(), user_end(), N, isUnDroppableUser); +} + +bool Value::hasNUndropableUsesOrMore(unsigned int N) const { + return hasNItemsOrMore(user_begin(), user_end(), N, isUnDroppableUser); +} + +void Value::dropDroppableUses( + llvm::function_ref ShouldDrop) { + SmallVector ToBeEdited; + for (Use &U : uses()) + if (U.getUser()->isDroppable() && ShouldDrop(&U)) + ToBeEdited.push_back(&U); + for (Use *U : ToBeEdited) { + Value *V = U->get(); + (void)V; + if (auto *Assume = dyn_cast(U->getUser())) { 
+ assert(Assume->getIntrinsicID() == Intrinsic::assume); + unsigned OpNo = U->getOperandNo(); + if (OpNo == 0) + Assume->setOperand(0, ConstantInt::getTrue(Assume->getContext())); + else + Assume->setOperand(OpNo, nullptr); + } else + llvm_unreachable("add special handeling here"); + assert(V != U->get() && "the value should have been changed"); + // setOperand() already updated the use lists; U must not be unlinked again. + } +} + bool Value::isUsedInBasicBlock(const BasicBlock *BB) const { // This can be computed either by scanning the instructions in BB, or by // scanning the use list of this Value. Both lists can be very long, but Index: llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp =================================================================== --- llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp +++ llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp @@ -65,6 +65,7 @@ #include "llvm/Support/MathExtras.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/InstCombine/InstCombineWorklist.h" +#include "llvm/Transforms/Utils/KnowledgeRetention.h" #include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/SimplifyLibCalls.h" #include @@ -4119,7 +4120,7 @@ // then this one is redundant, and should be removed. KnownBits Known(1); computeKnownBits(IIOperand, Known, 0, II); - if (Known.isAllOnes()) + if (Known.isAllOnes() && isAssumeWithEmptyBundle(*II)) return eraseInstFromFunction(*II); // Update the cache of affected values for this assumption (we might be Index: llvm/lib/Transforms/InstCombine/InstructionCombining.cpp =================================================================== --- llvm/lib/Transforms/InstCombine/InstructionCombining.cpp +++ llvm/lib/Transforms/InstCombine/InstructionCombining.cpp @@ -3195,7 +3195,7 @@ /// instruction past all of the instructions between it and the end of its /// block. 
static bool TryToSinkInstruction(Instruction *I, BasicBlock *DestBlock) { - assert(I->hasOneUse() && "Invariants didn't hold!"); + assert(I->getSingleUndropableUse() && "Invariants didn't hold!"); BasicBlock *SrcBlock = I->getParent(); // Cannot move control-flow-involving, volatile loads, vaarg, etc. @@ -3228,6 +3228,16 @@ if (Scan->mayWriteToMemory()) return false; } + + // Droppable uses outside DestBlock might not be dominated after sinking. + I->dropDroppableUses([DestBlock](const Use *U) { + if (auto *UserI = dyn_cast<Instruction>(U->getUser())) + return UserI->getParent() != DestBlock; + return true; + }); + // FIXME: It would be enough to drop only the droppable uses that the new + // position does not dominate. + BasicBlock::iterator InsertPos = DestBlock->getFirstInsertionPt(); I->moveBefore(&*InsertPos); ++NumSunkInst; @@ -3337,44 +3347,46 @@ } // See if we can trivially sink this instruction to a successor basic block. - if (EnableCodeSinking && I->hasOneUse()) { - BasicBlock *BB = I->getParent(); - Instruction *UserInst = cast(*I->user_begin()); - BasicBlock *UserParent; - - // Get the block the use occurs in. - if (PHINode *PN = dyn_cast(UserInst)) - UserParent = PN->getIncomingBlock(*I->use_begin()); - else - UserParent = UserInst->getParent(); - - if (UserParent != BB) { - bool UserIsSuccessor = false; - // See if the user is one of our successors. - for (succ_iterator SI = succ_begin(BB), E = succ_end(BB); SI != E; ++SI) - if (*SI == UserParent) { - UserIsSuccessor = true; - break; - } + if (EnableCodeSinking) + if (Use *SingleUse = I->getSingleUndropableUse()) { + BasicBlock *BB = I->getParent(); + Instruction *UserInst = cast(SingleUse->getUser()); + BasicBlock *UserParent; + + // Get the block the use occurs in. + if (PHINode *PN = dyn_cast(UserInst)) + UserParent = PN->getIncomingBlock(*SingleUse); + else + UserParent = UserInst->getParent(); + + if (UserParent != BB) { + bool UserIsSuccessor = false; + // See if the user is one of our successors. 
+ for (succ_iterator SI = succ_begin(BB), E = succ_end(BB); SI != E; + ++SI) + if (*SI == UserParent) { + UserIsSuccessor = true; + break; + } - // If the user is one of our immediate successors, and if that successor - // only has us as a predecessors (we'd have to split the critical edge - // otherwise), we can keep going. - if (UserIsSuccessor && UserParent->getUniquePredecessor()) { - // Okay, the CFG is simple enough, try to sink this instruction. - if (TryToSinkInstruction(I, UserParent)) { - LLVM_DEBUG(dbgs() << "IC: Sink: " << *I << '\n'); - MadeIRChange = true; - // We'll add uses of the sunk instruction below, but since sinking - // can expose opportunities for it's *operands* add them to the - // worklist - for (Use &U : I->operands()) - if (Instruction *OpI = dyn_cast(U.get())) - Worklist.Add(OpI); + // If the user is one of our immediate successors, and if that + // successor only has us as a predecessor (we'd have to split the + // critical edge otherwise), we can keep going. + if (UserIsSuccessor && UserParent->getUniquePredecessor()) { + // Okay, the CFG is simple enough, try to sink this instruction. + if (TryToSinkInstruction(I, UserParent)) { + LLVM_DEBUG(dbgs() << "IC: Sink: " << *I << '\n'); + MadeIRChange = true; + // We'll add uses of the sunk instruction below, but since sinking + // can expose opportunities for its *operands* add them to the + // worklist + for (Use &U : I->operands()) + if (Instruction *OpI = dyn_cast(U.get())) + Worklist.Add(OpI); + } } } } - } // Now that we have an instruction, try combining it to simplify it. Builder.SetInsertPoint(I); Index: llvm/lib/Transforms/Utils/KnowledgeRetention.cpp =================================================================== --- llvm/lib/Transforms/Utils/KnowledgeRetention.cpp +++ llvm/lib/Transforms/Utils/KnowledgeRetention.cpp @@ -55,10 +55,18 @@ namespace { +/// Index of elements in the operand bundle. 
+/// If an element exists, it is guaranteed to be what is specified in this +/// enum, but it may not exist. +enum BundleOpInfoElem { + BOIE_WasOn = 0, + BOIE_Argument = 1, +}; + /// Deterministically compare OperandBundleDef. /// The ordering is: -/// - by the name of the attribute, (doesn't change) -/// - then by the Value of the argument, (doesn't change) +/// - by the attribute's name aka operand bundle tag, (doesn't change) +/// - then by the numeric Value of the argument, (doesn't change) /// - lastly by the Name of the current Value it WasOn. (may change) /// This order is deterministic and allows looking for the right kind of /// attribute with binary search. However finding the right WasOn needs to be @@ -67,10 +75,13 @@ auto getTuple = [](const OperandBundleDef &Op) { return std::make_tuple( Op.getTag(), - Op.input_size() < 2 + Op.input_size() <= BOIE_Argument ? 0 - : cast(*std::next(Op.input_begin()))->getZExtValue(), - Op.input_size() < 1 ? StringRef("") : (*Op.input_begin())->getName()); + : cast(*(Op.input_begin() + BOIE_Argument)) + ->getZExtValue(), + Op.input_size() <= BOIE_WasOn + ? 
StringRef("") + : (*(Op.input_begin() + BOIE_WasOn))->getName()); }; return getTuple(LHS) < getTuple(RHS); } @@ -156,6 +167,109 @@ return Builder.build(); } +/// Return the \p Idx-th (0-based) input of the bundle described by \p BOI. +static Value *getElemFromBundleOpInfo(CallInst &Assume, + const CallBase::BundleOpInfo &BOI, + unsigned Idx) { + assert(BOI.End - BOI.Begin > Idx && "index out of range"); + return (Assume.op_begin() + BOI.Begin + Idx)->get(); +} + +#ifndef NDEBUG + +static bool isExistingAttribute(StringRef Name) { + return StringSwitch(Name) +#define GET_ATTR_NAMES +#define ATTRIBUTE_ALL(ENUM_NAME, DISPLAY_NAME) .Case(#DISPLAY_NAME, true) +#include "llvm/IR/Attributes.inc" + .Default(false); +} + +#endif + +bool llvm::hasAttributeInAssume(CallInst &AssumeCI, Value *IsOn, + StringRef AttrName, uint64_t *ArgVal, + AssumeQuery AQR) { + IntrinsicInst &Assume = cast(AssumeCI); + assert(Assume.getIntrinsicID() == Intrinsic::assume && + "this function is intended to be used on llvm.assume"); + assert(isExistingAttribute(AttrName) && "this attribute doesn't exist"); + assert((ArgVal == nullptr || + Attribute::getAttrKindFromName(AttrName) == Attribute::None || + Attribute::doesAttrKindHaveArgument( + Attribute::getAttrKindFromName(AttrName))) && + "requested value for an attribute that has no argument"); + if (Assume.bundle_op_infos().empty()) + return false; + + CallInst::bundle_op_iterator Lookup; + + // The right attribute can be found by binary search. After this, finding the + // right WasOn needs to be done via linear search. + // Elements are ordered by argument value, so the first match we find is the + // one we need. 
+ if (AQR == AssumeQuery::Lowest) { + Lookup = + llvm::lower_bound(Assume.bundle_op_infos(), AttrName, + [](const CallBase::BundleOpInfo &BOI, StringRef RHS) { + assert(isExistingAttribute(BOI.Tag->getKey()) && + "this attribute doesn't exist"); + return BOI.Tag->getKey() < RHS; + }); + } else { + auto UpperBound = + llvm::upper_bound(Assume.bundle_op_infos(), AttrName, + [](StringRef LHS, const CallBase::BundleOpInfo &BOI) { + assert(isExistingAttribute(BOI.Tag->getKey()) && + "this attribute doesn't exist"); + return LHS < BOI.Tag->getKey(); + }); + // Every bundle sorts after AttrName: there is no candidate to step back to. + if (UpperBound == Assume.bundle_op_info_begin()) + return false; + Lookup = std::prev(UpperBound); + } + + if (Lookup == Assume.bundle_op_info_end() || + Lookup->Tag->getKey() != AttrName) + return false; + if (IsOn) { + while (true) { + if (Lookup == Assume.bundle_op_info_end() || + Lookup->Tag->getKey() != AttrName) + return false; + // A bundle without a WasOn element cannot be about IsOn. + if (Lookup->End - Lookup->Begin > BOIE_WasOn && + getElemFromBundleOpInfo(Assume, *Lookup, BOIE_WasOn) == IsOn) + break; + if (AQR == AssumeQuery::Highest && + Lookup == Assume.bundle_op_info_begin()) + return false; + Lookup = Lookup + (AQR == AssumeQuery::Lowest ? 
1 : -1); + } + } + + // Bundles without an argument element have nothing to store in ArgVal. + if (Lookup->End - Lookup->Begin <= BOIE_Argument) + return true; + if (ArgVal) + *ArgVal = cast( + getElemFromBundleOpInfo(Assume, *Lookup, BOIE_Argument)) + ->getZExtValue(); + return true; +} + +bool llvm::isAssumeWithEmptyBundle(CallInst &CI) { + IntrinsicInst &Assume = cast(CI); + assert(Assume.getIntrinsicID() == Intrinsic::assume && + "this function is intended to be used on llvm.assume"); + return none_of( + Assume.bundle_op_infos(), [&Assume](const CallBase::BundleOpInfo &BOI) { + return BOI.Begin == BOI.End || + getElemFromBundleOpInfo(Assume, BOI, BOIE_WasOn) != nullptr; + }); +} + PreservedAnalyses AssumeBuilderPass::run(Function &F, FunctionAnalysisManager &AM) { for (Instruction &I : instructions(F)) @@ -163,3 +277,11 @@ Assume->insertBefore(&I); return PreservedAnalyses::all(); } + +SetPreserveAllScope::SetPreserveAllScope(bool Value) { + PreviousValue = ShouldPreserveAllAttributes; + ShouldPreserveAllAttributes.setValue(Value); +} +SetPreserveAllScope::~SetPreserveAllScope() { + ShouldPreserveAllAttributes.setValue(PreviousValue); +} Index: llvm/lib/Transforms/Utils/Local.cpp =================================================================== --- llvm/lib/Transforms/Utils/Local.cpp +++ llvm/lib/Transforms/Utils/Local.cpp @@ -75,6 +75,7 @@ #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/KnownBits.h" #include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Utils/KnowledgeRetention.h" #include "llvm/Transforms/Utils/ValueMapper.h" #include #include @@ -410,7 +411,8 @@ // true are operationally no-ops. In the future we can consider more // sophisticated tradeoffs for guards considering potential for check // widening, but for now we keep things simple. 
- if (II->getIntrinsicID() == Intrinsic::assume || + if ((II->getIntrinsicID() == Intrinsic::assume && + isAssumeWithEmptyBundle(*II)) || II->getIntrinsicID() == Intrinsic::experimental_guard) { if (ConstantInt *Cond = dyn_cast(II->getArgOperand(0))) return !Cond->isZero(); Index: llvm/test/Transforms/InstCombine/assume.ll =================================================================== --- llvm/test/Transforms/InstCombine/assume.ll +++ llvm/test/Transforms/InstCombine/assume.ll @@ -269,9 +269,9 @@ ; CHECK-LABEL: @nonnull3( ; CHECK-NEXT: entry: ; CHECK-NEXT: [[LOAD:%.*]] = load i32*, i32** [[A:%.*]], align 8 +; CHECK-NEXT: [[CMP:%.*]] = icmp ne i32* [[LOAD]], null ; CHECK-NEXT: br i1 [[CONTROL:%.*]], label [[TAKEN:%.*]], label [[NOT_TAKEN:%.*]] ; CHECK: taken: -; CHECK-NEXT: [[CMP:%.*]] = icmp ne i32* [[LOAD]], null ; CHECK-NEXT: tail call void @llvm.assume(i1 [[CMP]]) ; CHECK-NEXT: ret i1 false ; CHECK: not_taken: @@ -398,6 +398,56 @@ ret i32 %t2 } +define i1 @nonnull3A(i32** %a, i1 %control) { +; CHECK-LABEL: @nonnull3A( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[LOAD:%.*]] = load i32*, i32** [[A:%.*]], align 8 +; CHECK-NEXT: br i1 [[CONTROL:%.*]], label [[TAKEN:%.*]], label [[NOT_TAKEN:%.*]] +; CHECK: taken: +; CHECK-NEXT: [[CMP:%.*]] = icmp ne i32* [[LOAD]], null +; CHECK-NEXT: tail call void @llvm.assume(i1 [[CMP]]) +; CHECK-NEXT: ret i1 true +; CHECK: not_taken: +; CHECK-NEXT: [[RVAL_2:%.*]] = icmp sgt i32* [[LOAD]], null +; CHECK-NEXT: ret i1 [[RVAL_2]] +; +entry: + %load = load i32*, i32** %a + %cmp = icmp ne i32* %load, null + br i1 %control, label %taken, label %not_taken +taken: + tail call void @llvm.assume(i1 %cmp) + ret i1 %cmp +not_taken: + tail call void @llvm.assume(i1 %cmp) + %rval.2 = icmp sgt i32* %load, null + ret i1 %rval.2 +} + +define i1 @nonnull3B(i32** %a, i1 %control) { +; CHECK-LABEL: @nonnull3B( +; CHECK-NEXT: entry: +; CHECK-NEXT: br i1 [[CONTROL:%.*]], label [[TAKEN:%.*]], label [[NOT_TAKEN:%.*]] +; CHECK: taken: +; CHECK-NEXT: 
[[LOAD:%.*]] = load i32*, i32** [[A:%.*]], align 8 +; CHECK-NEXT: [[CMP:%.*]] = icmp ne i32* [[LOAD]], null +; CHECK-NEXT: tail call void @llvm.assume(i1 true) [ "nonnull"(i32* [[LOAD]]) ] +; CHECK-NEXT: ret i1 [[CMP]] +; CHECK: not_taken: +; CHECK-NEXT: ret i1 [[CONTROL]] +; +entry: + %load = load i32*, i32** %a + %cmp = icmp ne i32* %load, null + br i1 %control, label %taken, label %not_taken +taken: + tail call void @llvm.assume(i1 true) ["nonnull"(i32* %load)] + ret i1 %cmp +not_taken: + tail call void @llvm.assume(i1 %cmp) + ret i1 %control +} + declare void @llvm.dbg.value(metadata, metadata, metadata) !llvm.dbg.cu = !{!0} Index: llvm/unittests/Transforms/Utils/CMakeLists.txt =================================================================== --- llvm/unittests/Transforms/Utils/CMakeLists.txt +++ llvm/unittests/Transforms/Utils/CMakeLists.txt @@ -14,6 +14,7 @@ CodeMoverUtilsTest.cpp FunctionComparatorTest.cpp IntegerDivisionTest.cpp + KnowledgeRetentionTest.cpp LocalTest.cpp LoopRotationUtilsTest.cpp LoopUtilsTest.cpp Index: llvm/unittests/Transforms/Utils/KnowledgeRetentionTest.cpp =================================================================== --- /dev/null +++ llvm/unittests/Transforms/Utils/KnowledgeRetentionTest.cpp @@ -0,0 +1,240 @@ +//===- KnowledgeRetentionTest.cpp - KnowledgeRetention unit tests ---------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Utils/KnowledgeRetention.h" +#include "llvm/AsmParser/Parser.h" +#include "llvm/IR/CallSite.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/Support/Regex.h" +#include "llvm/Support/SourceMgr.h" +#include "gtest/gtest.h" + +using namespace llvm; + +static void RunTest( + StringRef Head, StringRef Tail, + std::vector>> + &Tests) { + std::string IR = Head.str(); + for (auto &Elem : Tests) + IR.append(Elem.first.begin(), Elem.first.end()); + IR.append(Tail.begin(), Tail.end()); + LLVMContext C; + SMDiagnostic Err; + std::unique_ptr Mod = parseAssemblyString(IR, Err, C); + if (!Mod) + Err.print("AssumeQueryAPI", errs()); + ASSERT_TRUE(Mod != nullptr) << "the test IR failed to parse"; + unsigned Idx = 0; + for (Instruction &I : (*Mod->getFunction("test")->begin())) { + if (Idx < Tests.size()) + Tests[Idx].second(&I); + Idx++; + } +} + +static void AssertMatchesExactlyAttributes(CallInst *Assume, Value *WasOn, + StringRef AttrToMatch) { + Regex Reg(AttrToMatch); + SmallVector Matches; + for (StringRef Attr : { +#define GET_ATTR_NAMES +#define ATTRIBUTE_ALL(ENUM_NAME, DISPLAY_NAME) StringRef(#DISPLAY_NAME), +#include "llvm/IR/Attributes.inc" + }) { + bool ShouldHaveAttr = Reg.match(Attr, &Matches) && Matches[0] == Attr; + ASSERT_EQ(ShouldHaveAttr, hasAttributeInAssume(*Assume, WasOn, Attr)) + << "unexpected result for attribute " << Attr.str(); + } +} + +/// Check the argument of \p Kind on \p WasOn (under both queries if \p Both). +static void AssertHasTheRightValue(CallInst *Assume, Value *WasOn, + Attribute::AttrKind Kind, unsigned Expected, bool Both, + AssumeQuery AQ = AssumeQuery::Highest) { + if (!Both) { + uint64_t ArgVal = 0; + ASSERT_TRUE(hasAttributeInAssume(*Assume, WasOn, Kind, &ArgVal, AQ)); + ASSERT_EQ(ArgVal, Expected); + return; + } + uint64_t ArgValLow = 0, ArgValHigh = 0; + bool ResultLow = hasAttributeInAssume(*Assume, WasOn, Kind, &ArgValLow, + AssumeQuery::Lowest); + bool ResultHigh = hasAttributeInAssume(*Assume, WasOn, Kind, 
&ArgValHigh, + AssumeQuery::Highest); + ASSERT_TRUE(ResultLow); + ASSERT_TRUE(ResultHigh); + ASSERT_EQ(ArgValLow, Expected); + ASSERT_EQ(ArgValHigh, Expected); +} + +TEST(AssumeQueryAPI, Basic) { + StringRef Head = + "declare void @llvm.assume(i1)\n" + "declare void @func(i32*, i32*)\n" + "declare void @func1(i32*, i32*, i32*, i32*)\n" + "declare void @func_many(i32*) \"no-jump-tables\" nounwind " + "\"less-precise-fpmad\" willreturn norecurse\n" + "define void @test(i32* %P, i32* %P1, i32* %P2, i32* %P3) {\n"; + StringRef Tail = "ret void\n" + "}"; + std::vector>> + Tests; + Tests.push_back(std::make_pair( + "call void @func(i32* nonnull align 4 dereferenceable(16) %P, i32* align " + "8 noalias %P1)\n", + [](Instruction *I) { + CallInst *Assume = BuildAssumeFromInst(I); + Assume->insertBefore(I); + AssertMatchesExactlyAttributes(Assume, I->getOperand(0), + "(nonnull|align|dereferenceable)"); + AssertMatchesExactlyAttributes(Assume, I->getOperand(1), + "(noalias|align)"); + AssertHasTheRightValue(Assume, I->getOperand(0), + Attribute::AttrKind::Dereferenceable, 16, true); + AssertHasTheRightValue(Assume, I->getOperand(0), + Attribute::AttrKind::Alignment, 4, true); + AssertHasTheRightValue(Assume, I->getOperand(1), + Attribute::AttrKind::Alignment, 8, true); + })); + Tests.push_back(std::make_pair( + "call void @func1(i32* nonnull align 32 dereferenceable(48) %P, i32* " + "nonnull " + "align 8 dereferenceable(28) %P, i32* nonnull align 64 " + "dereferenceable(4) " + "%P, i32* nonnull align 16 dereferenceable(12) %P)\n", + [](Instruction *I) { + CallInst *Assume = BuildAssumeFromInst(I); + Assume->insertBefore(I); + AssertMatchesExactlyAttributes(Assume, I->getOperand(0), + "(nonnull|align|dereferenceable)"); + AssertMatchesExactlyAttributes(Assume, I->getOperand(1), + "(nonnull|align|dereferenceable)"); + AssertMatchesExactlyAttributes(Assume, I->getOperand(2), + "(nonnull|align|dereferenceable)"); + AssertMatchesExactlyAttributes(Assume, 
I->getOperand(3), + "(nonnull|align|dereferenceable)"); + AssertHasTheRightValue(Assume, I->getOperand(0), + Attribute::AttrKind::Dereferenceable, 48, false, + AssumeQuery::Highest); + AssertHasTheRightValue(Assume, I->getOperand(0), + Attribute::AttrKind::Alignment, 64, false, + AssumeQuery::Highest); + AssertHasTheRightValue(Assume, I->getOperand(1), + Attribute::AttrKind::Alignment, 64, false, + AssumeQuery::Highest); + AssertHasTheRightValue(Assume, I->getOperand(0), + Attribute::AttrKind::Dereferenceable, 4, false, + AssumeQuery::Lowest); + AssertHasTheRightValue(Assume, I->getOperand(0), + Attribute::AttrKind::Alignment, 8, false, + AssumeQuery::Lowest); + AssertHasTheRightValue(Assume, I->getOperand(1), + Attribute::AttrKind::Alignment, 8, false, + AssumeQuery::Lowest); + })); + Tests.push_back(std::make_pair( + "call void @func_many(i32* align 8 %P1) cold\n", [](Instruction *I) { + SetPreserveAllScope S(true); + CallInst *Assume = BuildAssumeFromInst(I); + Assume->insertBefore(I); + AssertMatchesExactlyAttributes( + Assume, nullptr, + "(align|no-jump-tables|less-precise-fpmad|" + "nounwind|norecurse|willreturn|cold)"); + })); + Tests.push_back( + std::make_pair("call void @llvm.assume(i1 true)\n", [](Instruction *I) { + CallInst *Assume = cast(I); + AssertMatchesExactlyAttributes(Assume, nullptr, ""); + })); + Tests.push_back(std::make_pair( + "call void @func1(i32* readnone align 32 " + "dereferenceable(48) noalias %P, i32* " + "align 8 dereferenceable(28) %P1, i32* align 64 " + "dereferenceable(4) " + "%P2, i32* nonnull align 16 dereferenceable(12) %P3)\n", + [](Instruction *I) { + CallInst *Assume = BuildAssumeFromInst(I); + Assume->insertBefore(I); + AssertMatchesExactlyAttributes( + Assume, I->getOperand(0), + "(readnone|align|dereferenceable|noalias)"); + AssertMatchesExactlyAttributes(Assume, I->getOperand(1), + "(align|dereferenceable)"); + AssertMatchesExactlyAttributes(Assume, I->getOperand(2), + "(align|dereferenceable)"); + 
AssertMatchesExactlyAttributes(Assume, I->getOperand(3), + "(nonnull|align|dereferenceable)"); + AssertHasTheRightValue(Assume, I->getOperand(0), + Attribute::AttrKind::Alignment, 32, true); + AssertHasTheRightValue(Assume, I->getOperand(0), + Attribute::AttrKind::Dereferenceable, 48, true); + AssertHasTheRightValue(Assume, I->getOperand(1), + Attribute::AttrKind::Dereferenceable, 28, true); + AssertHasTheRightValue(Assume, I->getOperand(1), + Attribute::AttrKind::Alignment, 8, true); + AssertHasTheRightValue(Assume, I->getOperand(2), + Attribute::AttrKind::Alignment, 64, true); + AssertHasTheRightValue(Assume, I->getOperand(2), + Attribute::AttrKind::Dereferenceable, 4, true); + AssertHasTheRightValue(Assume, I->getOperand(3), + Attribute::AttrKind::Alignment, 16, true); + AssertHasTheRightValue(Assume, I->getOperand(3), + Attribute::AttrKind::Dereferenceable, 12, true); + })); + + // The tests below modify the code; keep them last so they do not + // interfere with the others. + Tests.push_back(std::make_pair( + "call void @func(i32* nonnull align 4 dereferenceable(16) %P, i32* align " + "8 noalias %P1)\n", + [](Instruction *I) { + CallInst *Assume = BuildAssumeFromInst(I); + Assume->insertBefore(I); + Value *Val = I->getOperand(0); + AssertMatchesExactlyAttributes(Assume, Val, + "(nonnull|align|dereferenceable)"); + AssertMatchesExactlyAttributes(Assume, I->getOperand(1), + "(noalias|align)"); + + Val->dropDroppableUses(); + AssertMatchesExactlyAttributes(Assume, Val, ""); + AssertMatchesExactlyAttributes(Assume, I->getOperand(1), + "(noalias|align)"); + })); + Tests.push_back( + std::make_pair("%in_assume = icmp ne i32* %P, null\n", [](Instruction *) {})); + Tests.push_back( + std::make_pair("call void @llvm.assume(i1 %in_assume)\n", [](Instruction *I) { + I->getOperand(0)->dropDroppableUses(); + ASSERT_EQ(I->getOperand(0), ConstantInt::getTrue(I->getContext())); + })); + Tests.push_back(std::make_pair( + "call void @func(i32* nonnull align 4 dereferenceable(16) %P, i32* 
align " + "8 noalias %P1)\n", + [](Instruction *I) { + CallInst *Assume = BuildAssumeFromInst(I); + Assume->insertBefore(I); + Value *New = I->getFunction()->getArg(3); + Value *Old = I->getOperand(0); + AssertMatchesExactlyAttributes(Assume, New, ""); + AssertMatchesExactlyAttributes(Assume, Old, + "(nonnull|align|dereferenceable)"); + AssertMatchesExactlyAttributes(Assume, I->getOperand(1), + "(noalias|align)"); + Old->replaceAllUsesWith(New); + AssertMatchesExactlyAttributes(Assume, New, + "(nonnull|align|dereferenceable)"); + AssertMatchesExactlyAttributes(Assume, Old, ""); + AssertMatchesExactlyAttributes(Assume, I->getOperand(1), + "(noalias|align)"); + })); + RunTest(Head, Tail, Tests); +}