Index: include/llvm/Analysis/SparsePropagation.h
===================================================================
--- include/llvm/Analysis/SparsePropagation.h
+++ include/llvm/Analysis/SparsePropagation.h
@@ -48,17 +48,40 @@
   LatticeVal getOverdefinedVal() const { return OverdefinedVal; }
   LatticeVal getUntrackedVal() const { return UntrackedVal; }
 
-  /// IsUntrackedValue - If the specified Value is something that is obviously
-  /// uninteresting to the analysis (and would always return UntrackedVal),
-  /// this function can return true to avoid pointless work.
+  /// IsUntrackedValue - If the specified value is obviously uninteresting to
+  /// the analysis (i.e., it would always return UntrackedVal), this function
+  /// can return true to avoid pointless work.
   virtual bool IsUntrackedValue(Value *V) { return false; }
 
+  /// IsUntrackedGlobalVariable - If the specified global variable holds values
+  /// that are obviously uninteresting to the analysis (i.e., they would always
+  /// return UntrackedVal), this function can return true to avoid pointless
+  /// work.
+  virtual bool IsUntrackedGlobalVariable(GlobalVariable *GV) { return false; }
+
+  /// IsUntrackedFunction - If the specified function returns values that are
+  /// obviously uninteresting to the analysis (i.e., they would always return
+  /// UntrackedVal), this function can return true to avoid pointless work.
+  virtual bool IsUntrackedFunction(Function *F) { return false; }
+
   /// ComputeConstant - Given a constant value, compute and return a lattice
   /// value corresponding to the specified constant.
   virtual LatticeVal ComputeConstant(Constant *C) {
     return getOverdefinedVal(); // always safe
   }
 
+  /// ComputeGlobalVariable - Given a global variable, compute and return a
+  /// lattice value corresponding to the value stored in the global.
+  virtual LatticeVal ComputeGlobalVariable(GlobalVariable *GV) {
+    return getOverdefinedVal(); // always safe
+  }
+
+  /// ComputeFunction - Given a function, compute and return a lattice value
+  /// corresponding to the function's return value.
+  virtual LatticeVal ComputeFunction(Function *F) {
+    return getOverdefinedVal(); // always safe
+  }
+
   /// IsSpecialCasedPHI - Given a PHI node, determine whether this PHI node is
   /// one that the we want to handle through ComputeInstructionState.
   virtual bool IsSpecialCasedPHI(PHINode *PN) { return false; }
@@ -84,11 +107,14 @@
     return getOverdefinedVal(); // always safe, never useful.
   }
 
-  /// ComputeInstructionState - Given an instruction and a vector of its operand
-  /// values, compute the result value of the instruction.
-  virtual LatticeVal ComputeInstructionState(Instruction &I,
-                                             SparseSolver<LatticeVal> &SS) {
-    return getOverdefinedVal(); // always safe, never useful.
+  /// ComputeInstructionState - Compute the lattice values that change as a
+  /// result of executing instruction \p I. The changed values are stored in \p
+  /// ChangedValues.
+  virtual void
+  ComputeInstructionState(Instruction &I,
+                          DenseMap<Value *, LatticeVal> &ChangedValues,
+                          SparseSolver<LatticeVal> &SS) {
+    ChangedValues[&I] = getOverdefinedVal(); // always safe, never useful.
   }
 
   /// PrintValue - Render the specified lattice value to the specified stream.
@@ -112,12 +138,34 @@
   /// compute transfer functions.
   AbstractLatticeFunction<LatticeVal> *LatticeFunc;
 
-  DenseMap<Value *, LatticeVal> ValueState;   // The state each value is in.
-  SmallPtrSet<BasicBlock *, 16> BBExecutable; // The bbs that are executable.
+  /// ValueState - Holds the lattice state associated with LLVM values.
+  DenseMap<Value *, LatticeVal> ValueState;
+
+  /// FunctionState - Holds the lattice state for function return values. We
+  /// maintain a separate map for returns because functions can return more
+  /// than one value (i.e., there's no key with which to associate the return
+  /// lattice state other than the function itself). LLVM functions can also be
+  /// tracked independently in the ValueState map, so we need a separate one
+  /// for returns.
+  DenseMap<Value *, LatticeVal> FunctionState;
+
+  /// GlobalVariableState - Holds the lattice state for values stored in global
+  /// variables. Similar to function returns, we maintain a separate map for
+  /// global variables because a global variable is the only key with which to
+  /// associate the values stored at its location. Additionally, LLVM global
+  /// variables can also be tracked independently in the ValueState map. This
+  /// map maintains the state of in-memory values, not their addresses.
+  DenseMap<Value *, LatticeVal> GlobalVariableState;
+
+  /// BBWorkList - Holds basic blocks that should be processed.
+  SmallVector<BasicBlock *, 64> BBWorkList;
 
-  std::vector<Instruction *> InstWorkList; // Worklist of insts to process.
+  /// ValueWorkList - Holds values (e.g., instructions, arguments, functions,
+  /// and global variables) that should be processed.
+  SmallVector<Value *, 64> ValueWorkList;
 
-  std::vector<BasicBlock *> BBWorkList; // The BasicBlock work list
+  /// BBExecutable - Holds the basic blocks that are executable.
+  SmallPtrSet<BasicBlock *, 16> BBExecutable;
 
   /// KnownFeasibleEdges - Entries in this set are edges which have already had
   /// PHI nodes retriggered.
@@ -131,25 +179,23 @@
   SparseSolver &operator=(const SparseSolver &) = delete;
 
   /// Solve - Solve for constants and executable blocks.
-  void Solve(Function &F) {
-    MarkBlockExecutable(&F.getEntryBlock());
+  void Solve() {
 
     // Process the work lists until they are empty!
-    while (!BBWorkList.empty() || !InstWorkList.empty()) {
-      // Process the instruction work list.
-      while (!InstWorkList.empty()) {
-        Instruction *I = InstWorkList.back();
-        InstWorkList.pop_back();
+    while (!BBWorkList.empty() || !ValueWorkList.empty()) {
+      // Process the value work list.
+      while (!ValueWorkList.empty()) {
+        Value *I = ValueWorkList.back();
+        ValueWorkList.pop_back();
 
         DEBUG(dbgs() << "\nPopped off I-WL: " << *I << "\n");
 
         // "I" got into the work list because it made a transition. See if any
         // users are both live and in need of updating.
-        for (User *U : I->users()) {
-          Instruction *UI = cast<Instruction>(U);
-          if (BBExecutable.count(UI->getParent())) // Inst is executable?
-            visitInst(*UI);
-        }
+        for (User *U : I->users())
+          if (Instruction *UI = dyn_cast<Instruction>(U))
+            if (BBExecutable.count(UI->getParent())) // Inst is executable?
+              visitInst(*UI);
       }
 
       // Process the basic block work list.
@@ -167,38 +213,38 @@
     }
   }
 
-  void Print(Function &F, raw_ostream &OS) const {
-    OS << "\nFUNCTION: " << F.getName() << "\n";
-    for (auto &BB : F) {
-      if (!BBExecutable.count(&BB))
-        OS << "INFEASIBLE: ";
-      OS << "\t";
-      if (BB.hasName())
-        OS << BB.getName() << ":\n";
-      else
-        OS << "; anon bb\n";
-      for (auto &I : BB) {
-        LatticeFunc->PrintValue(getLatticeState(&I), OS);
-        OS << I << "\n";
-      }
+  /// Print - Print the contents of each of the state maps.
+  void Print(raw_ostream &OS) const {
+    printStateMap(FunctionState, OS);
+    printStateMap(GlobalVariableState, OS);
+    printStateMap(ValueState, OS);
+  }
 
-      OS << "\n";
-    }
+  /// getExistingValueState - Return the LatticeVal object corresponding to the
+  /// given value from the ValueState map. If the value is not in the map,
+  /// UntrackedVal is returned, unlike the getValueState method.
+  LatticeVal getExistingValueState(Value *V) const {
+    return findOrGetUntracked(ValueState, V);
   }
 
-  /// getLatticeState - Return the LatticeVal object that corresponds to the
-  /// value. If an value is not in the map, it is returned as untracked,
-  /// unlike the getValueState method.
-  LatticeVal getLatticeState(Value *V) const {
-    auto I = ValueState.find(V);
-    return I != ValueState.end() ? I->second : LatticeFunc->getUntrackedVal();
+  /// getExistingGlobalVariableState - Return the LatticeVal object
+  /// corresponding to the given global variable from the GlobalVariableState
+  /// map. If the global variable is not in the map, UntrackedVal is returned,
+  /// unlike the getGlobalVariableState method.
+  LatticeVal getExistingGlobalVariableState(GlobalVariable *GV) const {
+    return findOrGetUntracked(GlobalVariableState, GV);
   }
 
-  /// getValueState - Return the LatticeVal object that corresponds to the
-  /// value, initializing the value's state if it hasn't been entered into the
-  /// map yet. This function is necessary because not all values should start
-  /// out in the underdefined state... Arguments should be overdefined, and
-  /// constants should be marked as constants.
+  /// getExistingFunctionState - Return the LatticeVal object corresponding to
+  /// the given function from the FunctionState map. If the function is not in
+  /// the map, UntrackedVal is returned, unlike the getFunctionState method.
+  LatticeVal getExistingFunctionState(Function *F) const {
+    return findOrGetUntracked(FunctionState, F);
+  }
+
+  /// getValueState - Return the LatticeVal object corresponding to the given
+  /// value from the ValueState map. If the value is not in the map, its state
+  /// is initialized.
   LatticeVal getValueState(Value *V) {
     auto I = ValueState.find(V);
     if (I != ValueState.end())
@@ -224,6 +270,30 @@
     return ValueState[V] = LV;
   }
 
+  /// getGlobalVariableState - Return the LatticeVal object corresponding to
+  /// the given global variable from the GlobalVariableState map. If the global
+  /// variable is not in the map, its state is initialized.
+  LatticeVal getGlobalVariableState(GlobalVariable *GV) {
+    auto I = GlobalVariableState.find(GV);
+    if (I != GlobalVariableState.end())
+      return I->second;
+    if (LatticeFunc->IsUntrackedGlobalVariable(GV))
+      return LatticeFunc->getUntrackedVal();
+    return GlobalVariableState[GV] = LatticeFunc->ComputeGlobalVariable(GV);
+  }
+
+  /// getFunctionState - Return the LatticeVal object corresponding to the
+  /// given function from the FunctionState map. If the function is not in the
+  /// map, its state is initialized.
+  LatticeVal getFunctionState(Function *F) {
+    auto I = FunctionState.find(F);
+    if (I != FunctionState.end())
+      return I->second;
+    if (LatticeFunc->IsUntrackedFunction(F))
+      return LatticeFunc->getUntrackedVal();
+    return FunctionState[F] = LatticeFunc->ComputeFunction(F);
+  }
+
   /// isEdgeFeasible - Return true if the control flow edge from the 'From'
   /// basic block to the 'To' basic block is currently feasible. If
   /// AggressiveUndef is true, then this treats values with unknown lattice
@@ -249,25 +319,73 @@
     return BBExecutable.count(BB);
   }
 
+  /// MarkBlockExecutable - Clients use this to mark blocks known to be
+  /// intrinsically live. Already-executable blocks are not requeued.
+  void MarkBlockExecutable(BasicBlock *BB) {
+    if (!BBExecutable.insert(BB).second)
+      return;
+    DEBUG(dbgs() << "Marking Block Executable: " << BB->getName() << "\n");
+    BBWorkList.push_back(BB); // Add the block to the work list!
+  }
+
 private:
-  /// UpdateState - When the state for some instruction is potentially updated,
-  /// this function notices and adds I to the worklist if needed.
-  void UpdateState(Instruction &Inst, LatticeVal V) {
-    auto I = ValueState.find(&Inst);
-    if (I != ValueState.end() && I->second == V)
+  /// UpdateState - When the state for some value is potentially updated, this
+  /// function notices and adds \p V to the worklist if needed.
+  void UpdateState(Value &V, LatticeVal LV,
+                   DenseMap<Value *, LatticeVal> &StateMap) {
+    auto I = StateMap.find(&V);
+    if (I != StateMap.end() && I->second == LV)
       return; // No change.
 
     // An update. Visit uses of I.
-    ValueState[&Inst] = V;
-    InstWorkList.push_back(&Inst);
+    StateMap[&V] = LV;
+    ValueWorkList.push_back(&V);
   }
 
-  /// MarkBlockExecutable - This method can be used by clients to mark all of
-  /// the blocks that are known to be intrinsically live in the processed unit.
-  void MarkBlockExecutable(BasicBlock *BB) {
-    DEBUG(dbgs() << "Marking Block Executable: " << BB->getName() << "\n");
-    BBExecutable.insert(BB);   // Basic block is executable!
-    BBWorkList.push_back(BB);  // Add the block to the work list!
+  /// getStateMapForInstruction - Return FunctionState for returns,
+  /// GlobalVariableState for stores, and ValueState for anything else.
+  DenseMap<Value *, LatticeVal> &getStateMapForInstruction(Instruction &I) {
+    switch (I.getOpcode()) {
+    case Instruction::Ret:
+      return FunctionState;
+    case Instruction::Store:
+      return GlobalVariableState;
+    default:
+      return ValueState;
+    }
+  }
+
+  /// printStateMap - Print the given state map, skipping untracked entries.
+  void printStateMap(const DenseMap<Value *, LatticeVal> &StateMap,
+                     raw_ostream &OS) const {
+    if (StateMap.empty())
+      return;
+
+    LatticeVal LV;
+    Value *V;
+
+    if (&StateMap == &FunctionState)
+      OS << "FunctionState:\n";
+    else if (&StateMap == &GlobalVariableState)
+      OS << "GlobalVariableState:\n";
+    else if (&StateMap == &ValueState)
+      OS << "ValueState:\n";
+    else
+      llvm_unreachable("Unknown state map");
+
+    for (auto &StateEntry : StateMap) {
+      std::tie(V, LV) = StateEntry;
+      if (LV == LatticeFunc->getUntrackedVal())
+        continue;
+      OS << "\t";
+      LatticeFunc->PrintValue(LV, OS);
+      OS << ": ";
+      if (isa<Function>(V))
+        OS << V->getName() << "\n";
+      else
+        OS << *V << "\n";
+    }
+    OS << "\n";
   }
 
   /// markEdgeExecutable - Mark a basic block as executable, adding it to the BB
@@ -308,7 +426,7 @@
       if (AggressiveUndef)
         BCValue = getValueState(BI->getCondition());
       else
-        BCValue = getLatticeState(BI->getCondition());
+        BCValue = getExistingValueState(BI->getCondition());
 
       if (BCValue == LatticeFunc->getOverdefinedVal() ||
           BCValue == LatticeFunc->getUntrackedVal()) {
@@ -334,10 +452,8 @@
       return;
     }
 
-    if (isa<InvokeInst>(TI)) {
-      // Invoke instructions successors are always executable.
-      // TODO: Could ask the lattice function if the value can throw.
-      Succs[0] = Succs[1] = true;
+    if (TI.isExceptional()) {
+      Succs.assign(Succs.size(), true); // Assume all successors feasible.
       return;
     }
 
@@ -351,7 +467,7 @@
     if (AggressiveUndef)
       SCValue = getValueState(SI.getCondition());
     else
-      SCValue = getLatticeState(SI.getCondition());
+      SCValue = getExistingValueState(SI.getCondition());
 
     if (SCValue == LatticeFunc->getOverdefinedVal() ||
         SCValue == LatticeFunc->getUntrackedVal()) {
@@ -382,9 +498,13 @@
 
     // Otherwise, ask the transfer function what the result is. If this is
     // something that we care about, remember it.
-    LatticeVal IV = LatticeFunc->ComputeInstructionState(I, *this);
-    if (IV != LatticeFunc->getUntrackedVal())
-      UpdateState(I, IV);
+    DenseMap<Value *, LatticeVal> ChangedValues;
+    LatticeFunc->ComputeInstructionState(I, ChangedValues, *this);
+    for (auto &ChangedValue : ChangedValues) {
+      if (ChangedValue.second != LatticeFunc->getUntrackedVal())
+        UpdateState(*ChangedValue.first, ChangedValue.second,
+                    getStateMapForInstruction(I));
+    }
 
     if (TerminatorInst *TI = dyn_cast<TerminatorInst>(&I))
       visitTerminatorInst(*TI);
@@ -395,9 +515,12 @@
     // be computed from its incoming values. For example, SSI form stores its
     // sigma functions as PHINodes with a single incoming value.
     if (LatticeFunc->IsSpecialCasedPHI(&PN)) {
-      LatticeVal IV = LatticeFunc->ComputeInstructionState(PN, *this);
-      if (IV != LatticeFunc->getUntrackedVal())
-        UpdateState(PN, IV);
+      DenseMap<Value *, LatticeVal> ChangedValues;
+      LatticeFunc->ComputeInstructionState(PN, ChangedValues, *this);
+      for (auto &ChangedValue : ChangedValues) {
+        if (ChangedValue.second != LatticeFunc->getUntrackedVal())
+          UpdateState(*ChangedValue.first, ChangedValue.second, ValueState);
+      }
       return;
     }
 
@@ -411,7 +534,7 @@
     // Super-extra-high-degree PHI nodes are unlikely to ever be interesting,
     // and slow us down a lot. Just mark them overdefined.
     if (PN.getNumIncomingValues() > 64) {
-      UpdateState(PN, Overdefined);
+      UpdateState(PN, Overdefined, ValueState);
       return;
     }
 
@@ -433,7 +556,7 @@
     }
 
     // Update the PHI with the compute value, which is the merge of the inputs.
-    UpdateState(PN, PNIV);
+    UpdateState(PN, PNIV, ValueState);
   }
 
   void visitTerminatorInst(TerminatorInst &TI) {
@@ -447,8 +570,18 @@
       if (SuccFeasible[i])
         markEdgeExecutable(BB, TI.getSuccessor(i));
   }
+
+  /// findOrGetUntracked - Return the lattice value for \p V in \p StateMap, or
+  /// UntrackedVal if \p V has no entry in the map.
+  LatticeVal findOrGetUntracked(const DenseMap<Value *, LatticeVal> &StateMap,
+                                Value *V) const {
+    auto I = StateMap.find(V);
+    return I != StateMap.end() ? I->second : LatticeFunc->getUntrackedVal();
+  }
 };
 
 } // end namespace llvm
 
+#undef DEBUG_TYPE
+
 #endif // LLVM_ANALYSIS_SPARSEPROPAGATION_H
Index: unittests/Analysis/CMakeLists.txt
===================================================================
--- unittests/Analysis/CMakeLists.txt
+++ unittests/Analysis/CMakeLists.txt
@@ -22,6 +22,7 @@
   OrderedBasicBlockTest.cpp
   ProfileSummaryInfoTest.cpp
   ScalarEvolutionTest.cpp
+  SparsePropagation.cpp
   TargetLibraryInfoTest.cpp
   TBAATest.cpp
   UnrollAnalyzer.cpp
Index: unittests/Analysis/SparsePropagation.cpp
===================================================================
--- /dev/null
+++ unittests/Analysis/SparsePropagation.cpp
@@ -0,0 +1,528 @@
+//===- SparsePropagation.cpp - Unit tests for the generic solver ----------===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/SparsePropagation.h"
+#include "llvm/ADT/PointerIntPair.h"
+#include "llvm/IR/CallSite.h"
+#include "llvm/IR/IRBuilder.h"
+#include "gtest/gtest.h"
+using namespace llvm;
+
+namespace {
+
+/// This class defines a simple test lattice value that could be used for
+/// solving problems similar to constant propagation. The value is maintained
+/// as a PointerIntPair.
+class TestLatticeVal {
+public:
+  /// The states of the lattice value. Only the ConstantVal state is
+  /// interesting; the rest are special states used by the generic solver. The
+  /// UntrackedVal state differs from the other three in that the generic
+  /// solver uses it to avoid doing unnecessary work. In particular, when a
+  /// value moves to the UntrackedVal state, its users are not notified.
+  enum TestLatticeStateTy {
+    UndefinedVal,
+    ConstantVal,
+    OverdefinedVal,
+    UntrackedVal
+  };
+
+  TestLatticeVal() : LatticeVal(nullptr, UndefinedVal) {}
+  TestLatticeVal(Constant *C, TestLatticeStateTy State)
+      : LatticeVal(C, State) {}
+
+  /// Return true if this lattice value is in the Constant state. This is used
+  /// for checking the solver results.
+  bool isConstant() const { return LatticeVal.getInt() == ConstantVal; }
+
+  /// Return true if this lattice value is in the Overdefined state. This is
+  /// used for checking the solver results.
+  bool isOverdefined() const { return LatticeVal.getInt() == OverdefinedVal; }
+
+  /// Return true if this lattice value is in the Untracked state. This is used
+  /// for checking the solver results.
+  bool isUntracked() const { return LatticeVal.getInt() == UntrackedVal; }
+
+  bool operator==(const TestLatticeVal &RHS) const {
+    return LatticeVal == RHS.LatticeVal;
+  }
+
+  bool operator!=(const TestLatticeVal &RHS) const {
+    return LatticeVal != RHS.LatticeVal;
+  }
+
+private:
+  /// A simple lattice value type for problems similar to constant propagation.
+  /// It holds the constant value and the lattice state.
+  PointerIntPair<Constant *, 2, TestLatticeStateTy> LatticeVal;
+};
+
+/// This class defines a simple test lattice function that could be used for
+/// solving problems similar to constant propagation. The test lattice differs
+/// from a "real" lattice in a few ways. First, it initializes all return
+/// values, values stored in global variables, and arguments in the undefined
+/// state. This means that there are no limitations on what we can track
+/// interprocedurally. For simplicity, all global values in the tests will be
+/// given internal linkage, since this is not something this lattice function
+/// tracks. Second, it only handles the few instructions necessary for the
+/// tests.
+class TestLatticeFunc : public AbstractLatticeFunction<TestLatticeVal> {
+public:
+  /// Construct a new test lattice function with special values for the
+  /// Undefined, Overdefined, and Untracked states.
+  TestLatticeFunc()
+      : AbstractLatticeFunction(
+            TestLatticeVal(nullptr, TestLatticeVal::UndefinedVal),
+            TestLatticeVal(nullptr, TestLatticeVal::OverdefinedVal),
+            TestLatticeVal(nullptr, TestLatticeVal::UntrackedVal)) {}
+
+  /// Compute a new lattice value for the given argument. We assume all
+  /// arguments can be tracked.
+  TestLatticeVal ComputeArgument(Argument *A) override { return getUndefVal(); }
+
+  /// Compute a new lattice value to track the values stored in the given
+  /// global variable. We assume all global variables can be tracked.
+  TestLatticeVal ComputeGlobalVariable(GlobalVariable *GV) override {
+    return getUndefVal();
+  }
+
+  /// Compute a new lattice value to track the return values of the given
+  /// function. We assume all functions can be tracked.
+  TestLatticeVal ComputeFunction(Function *F) override { return getUndefVal(); }
+
+  /// Compute a new lattice value for the given constant.
+  TestLatticeVal ComputeConstant(Constant *C) override {
+    return TestLatticeVal(C, TestLatticeVal::ConstantVal);
+  }
+
+  /// Merge the two given lattice values. This merge should be equivalent to
+  /// what is done for constant propagation. That is, the resulting lattice
+  /// value is constant only if the two given lattice values are constant and
+  /// hold the same value.
+  TestLatticeVal MergeValues(TestLatticeVal X, TestLatticeVal Y) override {
+    if (X == getUntrackedVal() || Y == getUntrackedVal())
+      return getUntrackedVal();
+    if (X == getOverdefinedVal() || Y == getOverdefinedVal())
+      return getOverdefinedVal();
+    if (X == getUndefVal() && Y == getUndefVal())
+      return getUndefVal();
+    if (X == getUndefVal())
+      return Y;
+    if (Y == getUndefVal())
+      return X;
+    if (X == Y)
+      return X;
+    return getOverdefinedVal();
+  }
+
+  /// Compute the lattice values that change as a result of executing the given
+  /// instruction. We only handle the few instructions needed for the tests.
+  void ComputeInstructionState(Instruction &I,
+                               DenseMap<Value *, TestLatticeVal> &ChangedValues,
+                               SparseSolver<TestLatticeVal> &SS) override {
+    switch (I.getOpcode()) {
+    case Instruction::Call:
+      return visitCallSite(cast<CallInst>(&I), ChangedValues, SS);
+    case Instruction::Ret:
+      return visitReturn(*cast<ReturnInst>(&I), ChangedValues, SS);
+    case Instruction::Store:
+      return visitStore(*cast<StoreInst>(&I), ChangedValues, SS);
+    default:
+      return visitInst(I, ChangedValues, SS);
+    }
+  }
+
+private:
+  /// Handle call sites. Indirect calls and calls to declarations (which have
+  /// no body to mark executable) are overdefined. Otherwise, each formal
+  /// argument state is merged with the corresponding actual argument state,
+  /// and the call site state is merged with the callee's return state.
+  void visitCallSite(CallSite CS,
+                     DenseMap<Value *, TestLatticeVal> &ChangedValues,
+                     SparseSolver<TestLatticeVal> &SS) {
+    Function *F = CS.getCalledFunction();
+    Instruction *I = CS.getInstruction();
+    if (!F || F->isDeclaration()) {
+      ChangedValues[I] = getOverdefinedVal();
+      return;
+    }
+    SS.MarkBlockExecutable(&F->front());
+    for (Argument &A : F->args())
+      ChangedValues[&A] = MergeValues(
+          SS.getValueState(&A), SS.getValueState(CS.getArgument(A.getArgNo())));
+    ChangedValues[I] = MergeValues(SS.getValueState(I), SS.getFunctionState(F));
+  }
+
+  /// Handle return instructions. The function's return state is the merge of
+  /// the returned value state and the function's current return state.
+  void visitReturn(ReturnInst &I,
+                   DenseMap<Value *, TestLatticeVal> &ChangedValues,
+                   SparseSolver<TestLatticeVal> &SS) {
+    Function *F = I.getParent()->getParent();
+    if (!F->getReturnType()->isVoidTy())
+      ChangedValues[F] = MergeValues(SS.getValueState(I.getReturnValue()),
+                                     SS.getFunctionState(F));
+  }
+
+  /// Handle store instructions. If the pointer operand of the store is a
+  /// global variable, we attempt to track the value. The global variable state
+  /// is the merge of the stored value state with the current global variable
+  /// state.
+  void visitStore(StoreInst &I,
+                  DenseMap<Value *, TestLatticeVal> &ChangedValues,
+                  SparseSolver<TestLatticeVal> &SS) {
+    auto *GV = dyn_cast<GlobalVariable>(I.getPointerOperand());
+    if (!GV)
+      return;
+    ChangedValues[GV] = MergeValues(SS.getValueState(I.getValueOperand()),
+                                    SS.getGlobalVariableState(GV));
+  }
+
+  /// Handle all other instructions. All other instructions are marked
+  /// overdefined.
+  void visitInst(Instruction &I,
+                 DenseMap<Value *, TestLatticeVal> &ChangedValues,
+                 SparseSolver<TestLatticeVal> &SS) {
+    ChangedValues[&I] = getOverdefinedVal();
+  }
+};
+
+/// This class defines the common data used for all of the tests. The tests
+/// should add code to the module and then run the solver.
+class SparsePropagationTest : public testing::Test {
+protected:
+  LLVMContext Context;
+  Module M;
+  IRBuilder<> Builder;
+  TestLatticeFunc Lattice;
+  SparseSolver<TestLatticeVal> Solver;
+
+public:
+  SparsePropagationTest()
+      : M("", Context), Builder(Context), Solver(&Lattice) {}
+};
+} // namespace
+
+/// Test that we mark discovered functions executable.
+///
+/// define internal void @f() {
+///   call void @g()
+///   ret void
+/// }
+///
+/// define internal void @g() {
+///   call void @f()
+///   ret void
+/// }
+///
+/// For this test, we initially mark "f" executable, and the solver discovers
+/// "g" because of the call in "f". The mutually recursive call in "g" also
+/// tests that we don't add a block to the basic block work list if it is
+/// already executable. Doing so would put the solver into an infinite loop.
+TEST_F(SparsePropagationTest, MarkBlockExecutable) {
+  Function *F = Function::Create(FunctionType::get(Builder.getVoidTy(), false),
+                                 GlobalValue::InternalLinkage, "f", &M);
+  Function *G = Function::Create(FunctionType::get(Builder.getVoidTy(), false),
+                                 GlobalValue::InternalLinkage, "g", &M);
+  BasicBlock *FEntry = BasicBlock::Create(Context, "", F);
+  BasicBlock *GEntry = BasicBlock::Create(Context, "", G);
+  Builder.SetInsertPoint(FEntry);
+  Builder.CreateCall(G);
+  Builder.CreateRetVoid();
+  Builder.SetInsertPoint(GEntry);
+  Builder.CreateCall(F);
+  Builder.CreateRetVoid();
+
+  Solver.MarkBlockExecutable(FEntry);
+  Solver.Solve();
+
+  EXPECT_TRUE(Solver.isBlockExecutable(GEntry));
+}
+
+/// Test that we propagate information through global variables.
+///
+/// @gv = internal global i64
+///
+/// define internal void @f() {
+///   store i64 1, i64* @gv
+///   ret void
+/// }
+///
+/// define internal void @g() {
+///   store i64 1, i64* @gv
+///   ret void
+/// }
+///
+/// For this test, we initially mark both "f" and "g" executable. The solver
+/// computes the lattice state of the global variable as constant. The test
+/// ensures that the solver uses the correct state map (i.e., the one for
+/// in-memory values) for the store instruction. Thus, "gv" should be constant
+/// in the global variable state map but untracked in the value state map.
+TEST_F(SparsePropagationTest, GlobalVariableConstant) {
+  Function *F = Function::Create(FunctionType::get(Builder.getVoidTy(), false),
+                                 GlobalValue::InternalLinkage, "f", &M);
+  Function *G = Function::Create(FunctionType::get(Builder.getVoidTy(), false),
+                                 GlobalValue::InternalLinkage, "g", &M);
+  GlobalVariable *GV =
+      new GlobalVariable(M, Builder.getInt64Ty(), false,
+                         GlobalValue::InternalLinkage, nullptr, "gv");
+  BasicBlock *FEntry = BasicBlock::Create(Context, "", F);
+  BasicBlock *GEntry = BasicBlock::Create(Context, "", G);
+  Builder.SetInsertPoint(FEntry);
+  Builder.CreateStore(Builder.getInt64(1), GV);
+  Builder.CreateRetVoid();
+  Builder.SetInsertPoint(GEntry);
+  Builder.CreateStore(Builder.getInt64(1), GV);
+  Builder.CreateRetVoid();
+
+  Solver.MarkBlockExecutable(FEntry);
+  Solver.MarkBlockExecutable(GEntry);
+  Solver.Solve();
+
+  EXPECT_TRUE(Solver.getExistingValueState(GV).isUntracked());
+  EXPECT_TRUE(Solver.getExistingGlobalVariableState(GV).isConstant());
+}
+
+/// Test that we propagate information through global variables.
+///
+/// @gv = internal global i64
+///
+/// define internal void @f() {
+///   store i64 0, i64* @gv
+///   ret void
+/// }
+///
+/// define internal void @g() {
+///   store i64 1, i64* @gv
+///   ret void
+/// }
+///
+/// For this test, we initially mark both "f" and "g" executable. The solver
+/// computes the lattice state of the global variable as overdefined. The test
+/// ensures that the solver uses the correct state map (i.e., the one for
+/// in-memory values) for the store instruction. Thus, "gv" should be
+/// overdefined in the global variable state map but untracked in the value
+/// state map.
+TEST_F(SparsePropagationTest, GlobalVariableOverDefined) {
+  Function *F = Function::Create(FunctionType::get(Builder.getVoidTy(), false),
+                                 GlobalValue::InternalLinkage, "f", &M);
+  Function *G = Function::Create(FunctionType::get(Builder.getVoidTy(), false),
+                                 GlobalValue::InternalLinkage, "g", &M);
+  GlobalVariable *GV =
+      new GlobalVariable(M, Builder.getInt64Ty(), false,
+                         GlobalValue::InternalLinkage, nullptr, "gv");
+  BasicBlock *FEntry = BasicBlock::Create(Context, "", F);
+  BasicBlock *GEntry = BasicBlock::Create(Context, "", G);
+  Builder.SetInsertPoint(FEntry);
+  Builder.CreateStore(Builder.getInt64(0), GV);
+  Builder.CreateRetVoid();
+  Builder.SetInsertPoint(GEntry);
+  Builder.CreateStore(Builder.getInt64(1), GV);
+  Builder.CreateRetVoid();
+
+  Solver.MarkBlockExecutable(FEntry);
+  Solver.MarkBlockExecutable(GEntry);
+  Solver.Solve();
+
+  EXPECT_TRUE(Solver.getExistingValueState(GV).isUntracked());
+  EXPECT_TRUE(Solver.getExistingGlobalVariableState(GV).isOverdefined());
+}
+
+/// Test that we propagate information through function returns.
+///
+/// define internal i64 @f(i1* %cond) {
+/// if:
+///   %0 = load i1, i1* %cond
+///   br i1 %0, label %then, label %else
+///
+/// then:
+///   ret i64 1
+///
+/// else:
+///   ret i64 1
+/// }
+///
+/// For this test, we initially mark "f" executable. The solver computes the
+/// return value of the function as constant. The test ensures that the solver
+/// uses the correct state map (i.e., the one for functions) for the return
+/// instructions. Thus, "f" should be constant in the function state map, but
+/// untracked in the value state map.
+TEST_F(SparsePropagationTest, FunctionDefined) {
+  Function *F =
+      Function::Create(FunctionType::get(Builder.getInt64Ty(),
+                                         {Type::getInt1PtrTy(Context)}, false),
+                       GlobalValue::InternalLinkage, "f", &M);
+  BasicBlock *If = BasicBlock::Create(Context, "if", F);
+  BasicBlock *Then = BasicBlock::Create(Context, "then", F);
+  BasicBlock *Else = BasicBlock::Create(Context, "else", F);
+  F->arg_begin()->setName("cond");
+  Builder.SetInsertPoint(If);
+  LoadInst *Cond = Builder.CreateLoad(F->arg_begin());
+  Builder.CreateCondBr(Cond, Then, Else);
+  Builder.SetInsertPoint(Then);
+  Builder.CreateRet(Builder.getInt64(1));
+  Builder.SetInsertPoint(Else);
+  Builder.CreateRet(Builder.getInt64(1));
+
+  Solver.MarkBlockExecutable(If);
+  Solver.Solve();
+
+  EXPECT_TRUE(Solver.getExistingValueState(F).isUntracked());
+  EXPECT_TRUE(Solver.getExistingFunctionState(F).isConstant());
+}
+
+/// Test that we propagate information through function returns.
+///
+/// define internal i64 @f(i1* %cond) {
+/// if:
+///   %0 = load i1, i1* %cond
+///   br i1 %0, label %then, label %else
+///
+/// then:
+///   ret i64 0
+///
+/// else:
+///   ret i64 1
+/// }
+///
+/// For this test, we initially mark "f" executable. The solver computes the
+/// return value of the function as overdefined. The test ensures that the
+/// solver uses the correct state map (i.e., the one for functions) for the
+/// return instructions. Thus, "f" should be overdefined in the function state
+/// map, but untracked in the value state map.
+TEST_F(SparsePropagationTest, FunctionOverDefined) {
+  Function *F =
+      Function::Create(FunctionType::get(Builder.getInt64Ty(),
+                                         {Type::getInt1PtrTy(Context)}, false),
+                       GlobalValue::InternalLinkage, "f", &M);
+  BasicBlock *If = BasicBlock::Create(Context, "if", F);
+  BasicBlock *Then = BasicBlock::Create(Context, "then", F);
+  BasicBlock *Else = BasicBlock::Create(Context, "else", F);
+  F->arg_begin()->setName("cond");
+  Builder.SetInsertPoint(If);
+  LoadInst *Cond = Builder.CreateLoad(F->arg_begin());
+  Builder.CreateCondBr(Cond, Then, Else);
+  Builder.SetInsertPoint(Then);
+  Builder.CreateRet(Builder.getInt64(0));
+  Builder.SetInsertPoint(Else);
+  Builder.CreateRet(Builder.getInt64(1));
+
+  Solver.MarkBlockExecutable(If);
+  Solver.Solve();
+
+  EXPECT_TRUE(Solver.getExistingValueState(F).isUntracked());
+  EXPECT_TRUE(Solver.getExistingFunctionState(F).isOverdefined());
+}
+
+/// Test that we propagate information through arguments.
+///
+/// define internal void @f() {
+///   call void @g(i64 0, i64 1)
+///   call void @g(i64 1, i64 1)
+///   ret void
+/// }
+///
+/// define internal void @g(i64 %a, i64 %b) {
+///   ret void
+/// }
+///
+/// For this test, we initially mark "f" executable, and the solver discovers
+/// "g" because of the calls in "f". The solver computes the state of argument
+/// "a" as overdefined and the state of "b" as constant. The test ensures that
+/// the solver uses the value state map for the arguments, instead of the
+/// function or global variable maps.
+///
+/// This test also demonstrates that ComputeInstructionState can alter the
+/// state of multiple lattice values, not just the one associated with the
+/// instruction definition. Each call instruction in this test updates
+/// the state of arguments "a" and "b".
+TEST_F(SparsePropagationTest, ComputeInstructionState) {
+  Function *F = Function::Create(FunctionType::get(Builder.getVoidTy(), false),
+                                 GlobalValue::InternalLinkage, "f", &M);
+  Function *G = Function::Create(
+      FunctionType::get(Builder.getVoidTy(),
+                        {Builder.getInt64Ty(), Builder.getInt64Ty()}, false),
+      GlobalValue::InternalLinkage, "g", &M);
+  Argument *A = G->arg_begin();
+  Argument *B = std::next(G->arg_begin());
+  A->setName("a");
+  B->setName("b");
+  BasicBlock *FEntry = BasicBlock::Create(Context, "", F);
+  BasicBlock *GEntry = BasicBlock::Create(Context, "", G);
+  Builder.SetInsertPoint(FEntry);
+  Builder.CreateCall(G, {Builder.getInt64(0), Builder.getInt64(1)});
+  Builder.CreateCall(G, {Builder.getInt64(1), Builder.getInt64(1)});
+  Builder.CreateRetVoid();
+  Builder.SetInsertPoint(GEntry);
+  Builder.CreateRetVoid();
+
+  Solver.MarkBlockExecutable(FEntry);
+  Solver.Solve();
+
+  EXPECT_TRUE(Solver.getExistingValueState(A).isOverdefined());
+  EXPECT_TRUE(Solver.getExistingValueState(B).isConstant());
+}
+
+/// Test that we can handle exceptional terminator instructions.
+///
+/// declare internal void @p()
+///
+/// declare internal void @g()
+///
+/// define internal void @f() personality i8* bitcast (void ()* @p to i8*) {
+/// entry:
+///   invoke void @g()
+///           to label %exit unwind label %catch.pad
+///
+/// catch.pad:
+///   %0 = catchswitch within none [label %catch.body] unwind to caller
+///
+/// catch.body:
+///   %1 = catchpad within %0 []
+///   catchret from %1 to label %exit
+///
+/// exit:
+///   ret void
+/// }
+///
+/// For this test, we initially mark the entry block executable. The solver
+/// then discovers that the rest of the blocks in the function are executable.
+TEST_F(SparsePropagationTest, ExceptionalTerminatorInsts) {
+  Function *P = Function::Create(FunctionType::get(Builder.getVoidTy(), false),
+                                 GlobalValue::InternalLinkage, "p", &M);
+  Function *G = Function::Create(FunctionType::get(Builder.getVoidTy(), false),
+                                 GlobalValue::InternalLinkage, "g", &M);
+  Function *F = Function::Create(FunctionType::get(Builder.getVoidTy(), false),
+                                 GlobalValue::InternalLinkage, "f", &M);
+  Constant *C =
+      ConstantExpr::getCast(Instruction::BitCast, P, Builder.getInt8PtrTy());
+  F->setPersonalityFn(C);
+  BasicBlock *Entry = BasicBlock::Create(Context, "entry", F);
+  BasicBlock *Pad = BasicBlock::Create(Context, "catch.pad", F);
+  BasicBlock *Body = BasicBlock::Create(Context, "catch.body", F);
+  BasicBlock *Exit = BasicBlock::Create(Context, "exit", F);
+  Builder.SetInsertPoint(Entry);
+  Builder.CreateInvoke(G, Exit, Pad);
+  Builder.SetInsertPoint(Pad);
+  CatchSwitchInst *CatchSwitch =
+      Builder.CreateCatchSwitch(ConstantTokenNone::get(Context), nullptr, 1);
+  CatchSwitch->addHandler(Body);
+  Builder.SetInsertPoint(Body);
+  CatchPadInst *CatchPad = Builder.CreateCatchPad(CatchSwitch, {});
+  Builder.CreateCatchRet(CatchPad, Exit);
+  Builder.SetInsertPoint(Exit);
+  Builder.CreateRetVoid();
+
+  Solver.MarkBlockExecutable(Entry);
+  Solver.Solve();
+
+  EXPECT_TRUE(Solver.isBlockExecutable(Pad));
+  EXPECT_TRUE(Solver.isBlockExecutable(Body));
+  EXPECT_TRUE(Solver.isBlockExecutable(Exit));
+}