Index: include/llvm/Analysis/SparsePropagation.h =================================================================== --- include/llvm/Analysis/SparsePropagation.h +++ include/llvm/Analysis/SparsePropagation.h @@ -17,6 +17,7 @@ #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/SmallPtrSet.h" +#include "llvm/IR/Instruction.h" #include #include #include @@ -27,7 +28,6 @@ class BasicBlock; class Constant; class Function; -class Instruction; class PHINode; class raw_ostream; class SparseSolver; @@ -44,7 +44,8 @@ /// class AbstractLatticeFunction { public: - using LatticeVal = void *; + using LatticeVal = const void *; + using StateTy = DenseMap; private: LatticeVal UndefVal, OverdefinedVal, UntrackedVal; @@ -63,11 +64,22 @@ LatticeVal getOverdefinedVal() const { return OverdefinedVal; } LatticeVal getUntrackedVal() const { return UntrackedVal; } - /// IsUntrackedValue - If the specified Value is something that is obviously - /// uninteresting to the analysis (and would always return UntrackedVal), - /// this function can return true to avoid pointless work. + /// IsUntrackedValue - If the specified value is obviously uninteresting to + /// the analysis (i.e., it would always return UntrackedVal), this function + /// can return true to avoid pointless work. virtual bool IsUntrackedValue(Value *V) { return false; } + /// IsUntrackedGlobalVariable - If the specified global variable holds values + /// that are obviously uninteresting to the analysis (i.e., they would always + /// return UntrackedVal), this function can return true to avoid pointless + /// work. + virtual bool IsUntrackedGlobalVariable(GlobalVariable *V) { return false; } + + /// IsUntrackedFunction - If the specified function returns values that are + /// obviously uninteresting to the analysis (i.e., they would always return + /// UntrackedVal), this function can return true to avoid pointless work. 
+ virtual bool IsUntrackedFunction(Function *V) { return false; } + /// ComputeConstant - Given a constant value, compute and return a lattice /// value corresponding to the specified constant. virtual LatticeVal ComputeConstant(Constant *C) { @@ -91,6 +103,18 @@ return getOverdefinedVal(); // always safe } + /// ComputeGlobalVariable - Given a global variable, compute and return a + /// lattice value corresponding to the value stored in the global. + virtual LatticeVal ComputeGlobalVariable(GlobalVariable *GV) { + return getOverdefinedVal(); // always safe + } + + /// ComputeFunction - Given a function, compute and return a lattice value + /// corresponding to the function's return value. + virtual LatticeVal ComputeFunction(Function *F) { + return getOverdefinedVal(); // always safe + } + /// MergeValues - Compute and return the merge of the two specified lattice /// values. Merging should only move one direction down the lattice to /// guarantee convergence (toward overdefined). @@ -98,10 +122,12 @@ return getOverdefinedVal(); // always safe, never useful. } - /// ComputeInstructionState - Given an instruction and a vector of its operand - /// values, compute the result value of the instruction. - virtual LatticeVal ComputeInstructionState(Instruction &I, SparseSolver &SS) { - return getOverdefinedVal(); // always safe, never useful. + /// ComputeInstructionState - Compute the lattice values that change as a + /// result of executing instruction \p I. The changed values are stored in \p + /// CV. + virtual void ComputeInstructionState(Instruction &I, StateTy &CV, + SparseSolver &S) { + CV[&I] = getOverdefinedVal(); // always safe, never useful. } /// PrintValue - Render the specified lattice value to the specified stream. @@ -112,17 +138,40 @@ /// Propagation with a programmable lattice function. 
class SparseSolver { using LatticeVal = AbstractLatticeFunction::LatticeVal; + using StateTy = AbstractLatticeFunction::StateTy; /// LatticeFunc - This is the object that knows the lattice and how to do /// compute transfer functions. AbstractLatticeFunction *LatticeFunc; - DenseMap ValueState; // The state each value is in. - SmallPtrSet BBExecutable; // The bbs that are executable. + /// ValueState - Holds the lattice state associated with LLVM values. + StateTy ValueState; + + /// FunctionState - Holds the lattice state for function return values. We + /// maintain a separate map for returns because functions can return more + /// than one value (i.e., there's no key with which to associate the return + /// lattice state other than the function itself). LLVM functions can also be + /// tracked independently in the ValueState map, so we need a separate one + /// for returns. + StateTy FunctionState; + + /// GlobalVariableState - Holds the lattice state for values stored in global + /// variables. Similar to function returns, we maintain a separate map for + /// global variables because a global variable is the only key with which to + /// associate the values stored at its location. Additionally, LLVM global + /// variables can also be tracked independently in the ValueState map. This + /// map maintains the state of in-memory values, not their addresses. + StateTy GlobalVariableState; + + /// BBWorkList - Holds basic blocks that should be processed. + std::vector BBWorkList; - std::vector InstWorkList; // Worklist of insts to process. + /// ValueWorkList - Holds values (i.e., instructions and global variables) + /// that should be processed. + std::vector ValueWorkList; - std::vector BBWorkList; // The BasicBlock work list + /// BBExecutable - Holds the basic blocks that are executable. + SmallPtrSet BBExecutable; /// KnownFeasibleEdges - Entries in this set are edges which have already had /// PHI nodes retriggered. 
@@ -134,28 +183,43 @@ : LatticeFunc(Lattice) {} SparseSolver(const SparseSolver &) = delete; SparseSolver &operator=(const SparseSolver &) = delete; - ~SparseSolver() { delete LatticeFunc; } /// Solve - Solve for constants and executable blocks. - void Solve(Function &F); + void Solve(); - void Print(Function &F, raw_ostream &OS) const; + /// Print - Print the contents of each of the states maps. + void Print(raw_ostream &OS) const; - /// getLatticeState - Return the LatticeVal object that corresponds to the - /// value. If an value is not in the map, it is returned as untracked, - /// unlike the getOrInitValueState method. - LatticeVal getLatticeState(Value *V) const { - DenseMap::const_iterator I = ValueState.find(V); - return I != ValueState.end() ? I->second : LatticeFunc->getUntrackedVal(); - } + /// getValueState - Return the LatticeVal object corresponding to the given + /// value from the ValueState map. If the value is not in the map, + /// UntrackedVal is returned. + LatticeVal getValueState(Value *V) const; + + /// getGlobalVariableState - Return the LatticeVal object corresponding to + /// the given global variable from the GlobalVariableState map. If the global + /// variable is not in the map, UntrackedVal is returned. + LatticeVal getGlobalVariableState(GlobalVariable *GV) const; + + /// getFunctionState - Return the LatticeVal object corresponding to the + /// given function from the FunctionState map. If the function is not in the + /// map, UntrackedVal is returned. + LatticeVal getFunctionState(Function *F) const; - /// getOrInitValueState - Return the LatticeVal object that corresponds to the - /// value, initializing the value's state if it hasn't been entered into the - /// map yet. This function is necessary because not all values should start - /// out in the underdefined state... Arguments should be overdefined, and - /// constants should be marked as constants. 
+ /// getOrInitValueState - Return the LatticeVal object corresponding to the + /// given value from the ValueState map. If the value is not in the map, its + /// state is initialized. LatticeVal getOrInitValueState(Value *V); + /// getOrInitGlobalVariableState - Return the LatticeVal object corresponding + /// to the given global variable from the GlobalVariableState map. If the + /// global variable is not in the map, its state is initialized. + LatticeVal getOrInitGlobalVariableState(GlobalVariable *GV); + + /// getOrInitFunctionState - Return the LatticeVal object corresponding to + /// the given function from the FunctionState map. If the function is not in + /// the map, its state is initialized. + LatticeVal getOrInitFunctionState(Function *F); + /// isEdgeFeasible - Return true if the control flow edge from the 'From' /// basic block to the 'To' basic block is currently feasible. If /// AggressiveUndef is true, then this treats values with unknown lattice @@ -171,15 +235,22 @@ return BBExecutable.count(BB); } -private: - /// UpdateState - When the state for some instruction is potentially updated, - /// this function notices and adds I to the worklist if needed. - void UpdateState(Instruction &Inst, LatticeVal V); - /// MarkBlockExecutable - This method can be used by clients to mark all of /// the blocks that are known to be intrinsically live in the processed unit. void MarkBlockExecutable(BasicBlock *BB); +private: + /// UpdateState - When the state for some value is potentially updated, this + /// function notices and adds V to the worklist if needed. + void UpdateState(Value &V, LatticeVal LV, StateTy &State); + + /// getStateMapForInstruction - Return a reference to the state map + /// appropriate for the given instruction. + StateTy &getStateMapForInstruction(Instruction &I); + + /// printStateMap - Print the contents of the given state map. 
+ void printStateMap(const StateTy &State, raw_ostream &OS) const; + /// markEdgeExecutable - Mark a basic block as executable, adding it to the BB /// work list if it is not already executable. void markEdgeExecutable(BasicBlock *Source, BasicBlock *Dest); Index: lib/Analysis/SparsePropagation.cpp =================================================================== --- lib/Analysis/SparsePropagation.cpp +++ lib/Analysis/SparsePropagation.cpp @@ -20,8 +20,8 @@ #include "llvm/IR/Constant.h" #include "llvm/IR/Constants.h" #include "llvm/IR/Function.h" +#include "llvm/IR/GlobalVariable.h" #include "llvm/IR/InstrTypes.h" -#include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/User.h" #include "llvm/Support/Casting.h" @@ -54,13 +54,25 @@ // SparseSolver Implementation //===----------------------------------------------------------------------===// -/// getOrInitValueState - Return the LatticeVal object that corresponds to the -/// value, initializing the value's state if it hasn't been entered into the -/// map yet. This function is necessary because not all values should start -/// out in the underdefined state... Arguments should be overdefined, and -/// constants should be marked as constants. +SparseSolver::LatticeVal SparseSolver::getValueState(Value *V) const { + StateTy::const_iterator I = ValueState.find(V); + return I != ValueState.end() ? I->second : LatticeFunc->getUntrackedVal(); +} + +SparseSolver::LatticeVal +SparseSolver::getGlobalVariableState(GlobalVariable *GV) const { + StateTy::const_iterator I = GlobalVariableState.find(GV); + return I != GlobalVariableState.end() ? I->second + : LatticeFunc->getUntrackedVal(); +} + +SparseSolver::LatticeVal SparseSolver::getFunctionState(Function *F) const { + StateTy::const_iterator I = FunctionState.find(F); + return I != FunctionState.end() ? 
I->second : LatticeFunc->getUntrackedVal(); +} + SparseSolver::LatticeVal SparseSolver::getOrInitValueState(Value *V) { - DenseMap::iterator I = ValueState.find(V); + StateTy::iterator I = ValueState.find(V); if (I != ValueState.end()) return I->second; // Common case, in the map LatticeVal LV; @@ -83,23 +95,43 @@ return ValueState[V] = LV; } +SparseSolver::LatticeVal +SparseSolver::getOrInitGlobalVariableState(GlobalVariable *GV) { + StateTy::iterator I = GlobalVariableState.find(GV); + if (I != GlobalVariableState.end()) + return I->second; + if (LatticeFunc->IsUntrackedGlobalVariable(GV)) + return LatticeFunc->getUntrackedVal(); + return GlobalVariableState[GV] = LatticeFunc->ComputeGlobalVariable(GV); +} + +SparseSolver::LatticeVal SparseSolver::getOrInitFunctionState(Function *F) { + StateTy::iterator I = FunctionState.find(F); + if (I != FunctionState.end()) + return I->second; + if (LatticeFunc->IsUntrackedFunction(F)) + return LatticeFunc->getUntrackedVal(); + return FunctionState[F] = LatticeFunc->ComputeFunction(F); +} + /// UpdateState - When the state for some instruction is potentially updated, /// this function notices and adds I to the worklist if needed. -void SparseSolver::UpdateState(Instruction &Inst, LatticeVal V) { - DenseMap::iterator I = ValueState.find(&Inst); - if (I != ValueState.end() && I->second == V) +void SparseSolver::UpdateState(Value &V, LatticeVal LV, StateTy &State) { + StateTy::iterator I = State.find(&V); + if (I != State.end() && I->second == LV) return; // No change. // An update. Visit uses of I. - ValueState[&Inst] = V; - InstWorkList.push_back(&Inst); + State[&V] = LV; + ValueWorkList.push_back(&V); } /// MarkBlockExecutable - This method can be used by clients to mark all of /// the blocks that are known to be intrinsically live in the processed unit. 
void SparseSolver::MarkBlockExecutable(BasicBlock *BB) { + if (!BBExecutable.insert(BB).second) + return; DEBUG(dbgs() << "Marking Block Executable: " << BB->getName() << "\n"); - BBExecutable.insert(BB); // Basic block is executable! BBWorkList.push_back(BB); // Add the block to the work list! } @@ -141,7 +173,7 @@ if (AggressiveUndef) BCValue = getOrInitValueState(BI->getCondition()); else - BCValue = getLatticeState(BI->getCondition()); + BCValue = getValueState(BI->getCondition()); if (BCValue == LatticeFunc->getOverdefinedVal() || BCValue == LatticeFunc->getUntrackedVal()) { @@ -166,10 +198,8 @@ return; } - if (isa(TI)) { - // Invoke instructions successors are always executable. - // TODO: Could ask the lattice function if the value can throw. - Succs[0] = Succs[1] = true; + if (TI.isExceptional()) { + Succs.assign(Succs.size(), true); return; } @@ -183,7 +213,7 @@ if (AggressiveUndef) SCValue = getOrInitValueState(SI.getCondition()); else - SCValue = getLatticeState(SI.getCondition()); + SCValue = getValueState(SI.getCondition()); if (SCValue == LatticeFunc->getOverdefinedVal() || SCValue == LatticeFunc->getUntrackedVal()) { @@ -238,9 +268,12 @@ // computed from its incoming values. For example, SSI form stores its sigma // functions as PHINodes with a single incoming value. if (LatticeFunc->IsSpecialCasedPHI(&PN)) { - LatticeVal IV = LatticeFunc->ComputeInstructionState(PN, *this); - if (IV != LatticeFunc->getUntrackedVal()) - UpdateState(PN, IV); + StateTy ChangedValues; + LatticeFunc->ComputeInstructionState(PN, ChangedValues, *this); + for (auto &ChangedValue : ChangedValues) { + if (ChangedValue.second != LatticeFunc->getUntrackedVal()) + UpdateState(*ChangedValue.first, ChangedValue.second, ValueState); + } return; } @@ -254,7 +287,7 @@ // Super-extra-high-degree PHI nodes are unlikely to ever be interesting, // and slow us down a lot. Just mark them overdefined. 
if (PN.getNumIncomingValues() > 64) { - UpdateState(PN, Overdefined); + UpdateState(PN, Overdefined, ValueState); return; } @@ -276,7 +309,15 @@ } // Update the PHI with the compute value, which is the merge of the inputs. - UpdateState(PN, PNIV); + UpdateState(PN, PNIV, ValueState); +} + +SparseSolver::StateTy &SparseSolver::getStateMapForInstruction(Instruction &I) { + switch (I.getOpcode()) { + case Instruction::Ret: return FunctionState; + case Instruction::Store: return GlobalVariableState; + default: return ValueState; + } } void SparseSolver::visitInst(Instruction &I) { @@ -287,32 +328,35 @@ // Otherwise, ask the transfer function what the result is. If this is // something that we care about, remember it. - LatticeVal IV = LatticeFunc->ComputeInstructionState(I, *this); - if (IV != LatticeFunc->getUntrackedVal()) - UpdateState(I, IV); - + StateTy ChangedValues; + LatticeFunc->ComputeInstructionState(I, ChangedValues, *this); + for (auto &ChangedValue : ChangedValues) { + if (ChangedValue.second != LatticeFunc->getUntrackedVal()) + UpdateState(*ChangedValue.first, ChangedValue.second, + getStateMapForInstruction(I)); + } + if (TerminatorInst *TI = dyn_cast(&I)) visitTerminatorInst(*TI); } -void SparseSolver::Solve(Function &F) { - MarkBlockExecutable(&F.getEntryBlock()); - +void SparseSolver::Solve() { + // Process the work lists until they are empty! - while (!BBWorkList.empty() || !InstWorkList.empty()) { - // Process the instruction work list. - while (!InstWorkList.empty()) { - Instruction *I = InstWorkList.back(); - InstWorkList.pop_back(); + while (!BBWorkList.empty() || !ValueWorkList.empty()) { + // Process the value work list. + while (!ValueWorkList.empty()) { + Value *I = ValueWorkList.back(); + ValueWorkList.pop_back(); DEBUG(dbgs() << "\nPopped off I-WL: " << *I << "\n"); // "I" got into the work list because it made a transition. See if any // users are both live and in need of updating. 
for (User *U : I->users()) { - Instruction *UI = cast(U); - if (BBExecutable.count(UI->getParent())) // Inst is executable? - visitInst(*UI); + if (Instruction *UI = dyn_cast(U)) + if (BBExecutable.count(UI->getParent())) // Inst is executable? + visitInst(*UI); } } @@ -331,21 +375,39 @@ } } -void SparseSolver::Print(Function &F, raw_ostream &OS) const { - OS << "\nFUNCTION: " << F.getName() << "\n"; - for (auto &BB : F) { - if (!BBExecutable.count(&BB)) - OS << "INFEASIBLE: "; +void SparseSolver::printStateMap(const StateTy &State, raw_ostream &OS) const { + if (State.empty()) + return; + + LatticeVal LV; + Value *V; + + if (&State == &FunctionState) + OS << "FunctionState:\n"; + else if (&State == &GlobalVariableState) + OS << "GlobalVariableState:\n"; + else if (&State == &ValueState) + OS << "ValueState:\n"; + else + llvm_unreachable("Unknown state map"); + + for (auto &StateEntry : State) { + std::tie(V, LV) = StateEntry; + if (LV == LatticeFunc->getUntrackedVal()) + continue; OS << "\t"; - if (BB.hasName()) - OS << BB.getName() << ":\n"; + LatticeFunc->PrintValue(LV, OS); + OS << ": "; + if (isa(V)) + OS << V->getName() << "\n"; else - OS << "; anon bb\n"; - for (auto &I : BB) { - LatticeFunc->PrintValue(getLatticeState(&I), OS); - OS << I << "\n"; - } - - OS << "\n"; + OS << *V << "\n"; } + OS << "\n"; +} + +void SparseSolver::Print(raw_ostream &OS) const { + printStateMap(FunctionState, OS); + printStateMap(GlobalVariableState, OS); + printStateMap(ValueState, OS); } Index: unittests/Analysis/CMakeLists.txt =================================================================== --- unittests/Analysis/CMakeLists.txt +++ unittests/Analysis/CMakeLists.txt @@ -21,6 +21,7 @@ OrderedBasicBlockTest.cpp ProfileSummaryInfoTest.cpp ScalarEvolutionTest.cpp + SparsePropagation.cpp TargetLibraryInfoTest.cpp TBAATest.cpp UnrollAnalyzer.cpp Index: unittests/Analysis/SparsePropagation.cpp =================================================================== --- /dev/null +++ 
unittests/Analysis/SparsePropagation.cpp @@ -0,0 +1,509 @@ +//===- SparsePropagation.cpp - Unit tests for the generic solver ----------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/SparsePropagation.h" +#include "llvm/ADT/PointerIntPair.h" +#include "llvm/IR/CallSite.h" +#include "llvm/IR/IRBuilder.h" +#include "gtest/gtest.h" +using namespace llvm; + +namespace { + +/// This class defines a simple test lattice function that could be used for +/// solving problems similar to constant propagation. The test lattice differs +/// from a "real" lattice in a few ways. First, it initializes all return +/// values, values stored in global variables, and arguments in the undefined +/// state. This means that there are no limitations on what we can track +/// interprocedurally. For simplicity, all global values in the tests will be +/// given internal linkage, since this is not something this lattice function +/// tracks. Second, it only handles the few instructions necessary for the +/// tests. +class TestLatticeFunc : public AbstractLatticeFunction { + + /// The states of the lattice values. Only the Constant state is + /// interesting; the rest are special states used by the generic solver. The + /// Untracked state differs from the other three in that the generic solver + /// uses it to avoid doing unnecessary work. In particular, when a value + /// moves to the Untracked state, its users are not notified. + enum class TestLatticeValTy : unsigned { + Undefined, + Constant, + Overdefined, + Untracked + }; + + /// A simple lattice value type for problems similar to constant propagation. + /// It holds the constant value and the lattice state. 
+ using TestLatticeVal = PointerIntPair; + +public: + /// Construct a new test lattice function with special values for the + /// Undefined, Overdefined, and Untracked states. + TestLatticeFunc() + : AbstractLatticeFunction( + TestLatticeVal(nullptr, TestLatticeValTy::Undefined) + .getOpaqueValue(), + TestLatticeVal(nullptr, TestLatticeValTy::Overdefined) + .getOpaqueValue(), + TestLatticeVal(nullptr, TestLatticeValTy::Untracked) + .getOpaqueValue()) {} + + /// Return true if the given lattice value is in the Constant state. This is + /// used for checking the solver results. + bool isConstant(LatticeVal LV) const { + return TestLatticeVal::getFromOpaqueValue(LV).getInt() == + TestLatticeValTy::Constant; + } + + /// Return true if the given lattice value is in the Overdefined state. This + /// is used for checking the solver results. + bool isOverdefined(LatticeVal LV) const { + return TestLatticeVal::getFromOpaqueValue(LV).getInt() == + TestLatticeValTy::Overdefined; + } + + /// Return true if the given lattice value is in the Untracked state. This is + /// used for checking the solver results. + bool isUntracked(LatticeVal LV) const { + return TestLatticeVal::getFromOpaqueValue(LV).getInt() == + TestLatticeValTy::Untracked; + } + + /// Compute a new lattice value for the given argument. We assume all + /// arguments can be tracked. + LatticeVal ComputeArgument(Argument *A) override { return getUndefVal(); } + + /// Compute a new lattice value to track the values stored in the given + /// global variable. We assume all global variables can be tracked. + LatticeVal ComputeGlobalVariable(GlobalVariable *GV) override { + return getUndefVal(); + } + + /// Compute a new lattice value to track the return values of the given + /// function. We assume all functions can be tracked. + LatticeVal ComputeFunction(Function *F) override { return getUndefVal(); } + + /// Compute a new lattice value for the given constant. 
+ LatticeVal ComputeConstant(Constant *C) override { + return TestLatticeVal(C, TestLatticeValTy::Constant).getOpaqueValue(); + } + + /// Merge the two given lattice values. This merge should be equivalent to + /// what is done for constant propagation. That is, the resulting lattice + /// value is constant only if the two given lattice values are constant and + /// hold the same value. + LatticeVal MergeValues(LatticeVal X, LatticeVal Y) override { + if (X == getUntrackedVal() || Y == getUntrackedVal()) + return getUntrackedVal(); + if (X == getOverdefinedVal() || Y == getOverdefinedVal()) + return getOverdefinedVal(); + if (X == getUndefVal() && Y == getUndefVal()) + return getUndefVal(); + if (X == getUndefVal()) + return Y; + if (Y == getUndefVal()) + return X; + if (X == Y) + return X; + return getOverdefinedVal(); + } + + /// Compute the lattice values that change as a result of executing the given + /// instruction. We only handle the few instructions needed for the tests. + void ComputeInstructionState(Instruction &I, StateTy &CV, + SparseSolver &S) override { + switch (I.getOpcode()) { + case Instruction::Call: return visitCallSite(cast(&I), CV, S); + case Instruction::Ret: return visitReturn(*cast(&I), CV, S); + case Instruction::Store: return visitStore(*cast(&I), CV, S); + default: return visitInst(I, CV, S); + } + } + +private: + /// Handle call sites. The state of a called function's argument is the merge + /// of the current formal argument state with the call site's corresponding + /// actual argument state. The call site state is the merge of the call site + /// state with the returned value state of the called function. 
+ void visitCallSite(CallSite CS, StateTy &CV, SparseSolver &S) { + Function *F = CS.getCalledFunction(); + Instruction *I = CS.getInstruction(); + if (!F) { + CV[I] = getOverdefinedVal(); + return; + } + S.MarkBlockExecutable(&F->front()); + for (Argument &A : F->args()) + CV[&A] = MergeValues(S.getOrInitValueState(&A), + S.getOrInitValueState(CS.getArgument(A.getArgNo()))); + CV[I] = MergeValues(S.getOrInitValueState(I), S.getOrInitFunctionState(F)); + } + + /// Handle return instructions. The function's return state is the merge of + /// the returned value state and the function's current return state. + void visitReturn(ReturnInst &I, StateTy &CV, SparseSolver &S) { + Function *F = I.getParent()->getParent(); + if (!F->getReturnType()->isVoidTy()) + CV[F] = MergeValues(S.getOrInitValueState(I.getReturnValue()), + S.getOrInitFunctionState(F)); + } + + /// Handle store instructions. If the pointer operand of the store is a + /// global variable, we attempt to track the value. The global variable state + /// is the merge of the stored value state with the current global variable + /// state. + void visitStore(StoreInst &I, StateTy &CV, SparseSolver &S) { + auto *GV = dyn_cast(I.getPointerOperand()); + if (!GV) + return; + CV[GV] = MergeValues(S.getOrInitValueState(I.getValueOperand()), + S.getOrInitGlobalVariableState(GV)); + } + + /// Handle all other instructions. All other instructions are marked + /// overdefined. + void visitInst(Instruction &I, StateTy &CV, SparseSolver &S) { + CV[&I] = getOverdefinedVal(); + } +}; + +/// This class defines the common data used for all of the tests. The tests +/// should add code to the module and then run the solver. 
+class SparsePropagationTest : public testing::Test { +protected: + LLVMContext Context; + Module M; + IRBuilder<> Builder; + TestLatticeFunc Lattice; + SparseSolver Solver; + +public: + SparsePropagationTest() + : M("", Context), Builder(Context), Solver(&Lattice) {} +}; +} + +/// Test that we mark discovered functions executable. +/// +/// define internal void @f() { +/// call void @g() +/// ret void +/// } +/// +/// define internal void @g() { +/// call void @f() +/// ret void +/// } +/// +/// For this test, we initially mark "f" executable, and the solver discovers +/// "g" because of the call in "f". The mutually recursive call in "g" also +/// tests that we don't add a block to the basic block work list if it is +/// already executable. Doing so would put the solver into an infinite loop. +TEST_F(SparsePropagationTest, MarkBlockExecutable) { + Function *F = Function::Create(FunctionType::get(Builder.getVoidTy(), false), + GlobalValue::InternalLinkage, "f", &M); + Function *G = Function::Create(FunctionType::get(Builder.getVoidTy(), false), + GlobalValue::InternalLinkage, "g", &M); + BasicBlock *FEntry = BasicBlock::Create(Context, "", F); + BasicBlock *GEntry = BasicBlock::Create(Context, "", G); + Builder.SetInsertPoint(FEntry); + Builder.CreateCall(G); + Builder.CreateRetVoid(); + Builder.SetInsertPoint(GEntry); + Builder.CreateCall(F); + Builder.CreateRetVoid(); + + Solver.MarkBlockExecutable(FEntry); + Solver.Solve(); + + EXPECT_TRUE(Solver.isBlockExecutable(GEntry)); +} + +/// Test that we propagate information through global variables. +/// +/// @gv = internal global i64 +/// +/// define internal void @f() { +/// store i64 1, i64* @gv +/// ret void +/// } +/// +/// define internal void @g() { +/// store i64 1, i64* @gv +/// ret void +/// } +/// +/// For this test, we initially mark both "f" and "g" executable. The solver +/// computes the lattice state of the global variable as constant. 
The test +/// ensures that the solver uses the correct state map (i.e., the one for +/// in-memory values) for the store instruction. Thus, "gv" should be constant +/// in the global variable state map but untracked in the value state map. +TEST_F(SparsePropagationTest, GlobalVariableConstant) { + Function *F = Function::Create(FunctionType::get(Builder.getVoidTy(), false), + GlobalValue::InternalLinkage, "f", &M); + Function *G = Function::Create(FunctionType::get(Builder.getVoidTy(), false), + GlobalValue::InternalLinkage, "g", &M); + GlobalVariable *GV = + new GlobalVariable(M, Builder.getInt64Ty(), false, + GlobalValue::InternalLinkage, nullptr, "gv"); + BasicBlock *FEntry = BasicBlock::Create(Context, "", F); + BasicBlock *GEntry = BasicBlock::Create(Context, "", G); + Builder.SetInsertPoint(FEntry); + Builder.CreateStore(Builder.getInt64(1), GV); + Builder.CreateRetVoid(); + Builder.SetInsertPoint(GEntry); + Builder.CreateStore(Builder.getInt64(1), GV); + Builder.CreateRetVoid(); + + Solver.MarkBlockExecutable(FEntry); + Solver.MarkBlockExecutable(GEntry); + Solver.Solve(); + + EXPECT_TRUE(Lattice.isUntracked(Solver.getValueState(GV))); + EXPECT_TRUE(Lattice.isConstant(Solver.getGlobalVariableState(GV))); +} + +/// Test that we propagate information through global variables. +/// +/// @gv = internal global i64 +/// +/// define internal void @f() { +/// store i64 0, i64* @gv +/// ret void +/// } +/// +/// define internal void @g() { +/// store i64 1, i64* @gv +/// ret void +/// } +/// +/// For this test, we initially mark both "f" and "g" executable. The solver +/// computes the lattice state of the global variable as overdefined. The test +/// ensures that the solver uses the correct state map (i.e., the one for +/// in-memory values) for the store instruction. Thus, "gv" should be +/// overdefined in the global variable state map but untracked in the value +/// state map. 
+TEST_F(SparsePropagationTest, GlobalVariableOverDefined) { + Function *F = Function::Create(FunctionType::get(Builder.getVoidTy(), false), + GlobalValue::InternalLinkage, "f", &M); + Function *G = Function::Create(FunctionType::get(Builder.getVoidTy(), false), + GlobalValue::InternalLinkage, "g", &M); + GlobalVariable *GV = + new GlobalVariable(M, Builder.getInt64Ty(), false, + GlobalValue::InternalLinkage, nullptr, "gv"); + BasicBlock *FEntry = BasicBlock::Create(Context, "", F); + BasicBlock *GEntry = BasicBlock::Create(Context, "", G); + Builder.SetInsertPoint(FEntry); + Builder.CreateStore(Builder.getInt64(0), GV); + Builder.CreateRetVoid(); + Builder.SetInsertPoint(GEntry); + Builder.CreateStore(Builder.getInt64(1), GV); + Builder.CreateRetVoid(); + + Solver.MarkBlockExecutable(FEntry); + Solver.MarkBlockExecutable(GEntry); + Solver.Solve(); + + EXPECT_TRUE(Lattice.isUntracked(Solver.getValueState(GV))); + EXPECT_TRUE(Lattice.isOverdefined(Solver.getGlobalVariableState(GV))); +} + +/// Test that we propagate information through function returns. +/// +/// define internal i64 @f(i1* %cond) { +/// if: +/// %0 = load i1, i1* %cond +/// br i1 %0, label %then, label %else +/// +/// then: +/// ret i64 1 +/// +/// else: +/// ret i64 1 +/// } +/// +/// For this test, we initially mark "f" executable. The solver computes the +/// return value of the function as constant. The test ensures that the solver +/// uses the correct state map (i.e., the one for functions) for the return +/// instructions. Thus, "f" should be constant in the function state map, but +/// untracked in the value state map. 
+TEST_F(SparsePropagationTest, FunctionDefined) { + Function *F = + Function::Create(FunctionType::get(Builder.getInt64Ty(), + {Type::getInt1PtrTy(Context)}, false), + GlobalValue::InternalLinkage, "f", &M); + BasicBlock *If = BasicBlock::Create(Context, "if", F); + BasicBlock *Then = BasicBlock::Create(Context, "then", F); + BasicBlock *Else = BasicBlock::Create(Context, "else", F); + F->arg_begin()->setName("cond"); + Builder.SetInsertPoint(If); + LoadInst *Cond = Builder.CreateLoad(F->arg_begin()); + Builder.CreateCondBr(Cond, Then, Else); + Builder.SetInsertPoint(Then); + Builder.CreateRet(Builder.getInt64(1)); + Builder.SetInsertPoint(Else); + Builder.CreateRet(Builder.getInt64(1)); + + Solver.MarkBlockExecutable(If); + Solver.Solve(); + + EXPECT_TRUE(Lattice.isUntracked(Solver.getValueState(F))); + EXPECT_TRUE(Lattice.isConstant(Solver.getFunctionState(F))); +} + +/// Test that we propagate information through function returns. +/// +/// define internal i64 @f(i1* %cond) { +/// if: +/// %0 = load i1, i1* %cond +/// br i1 %0, label %then, label %else +/// +/// then: +/// ret i64 0 +/// +/// else: +/// ret i64 1 +/// } +/// +/// For this test, we initially mark "f" executable. The solver computes the +/// return value of the function as overdefined. The test ensures that the +/// solver uses the correct state map (i.e., the one for functions) for the +/// return instructions. Thus, "f" should be overdefined in the function state +/// map, but untracked in the value state map. 
+TEST_F(SparsePropagationTest, FunctionOverDefined) {
+  // Create "f" as an internal function taking an i1* and returning i64.
+  FunctionType *FnTy = FunctionType::get(
+      Builder.getInt64Ty(), {Type::getInt1PtrTy(Context)}, false);
+  Function *F = Function::Create(FnTy, GlobalValue::InternalLinkage, "f", &M);
+
+  BasicBlock *IfBB = BasicBlock::Create(Context, "if", F);
+  BasicBlock *ThenBB = BasicBlock::Create(Context, "then", F);
+  BasicBlock *ElseBB = BasicBlock::Create(Context, "else", F);
+
+  Argument *CondPtr = F->arg_begin();
+  CondPtr->setName("cond");
+
+  // The entry block branches on the value loaded through "cond".
+  Builder.SetInsertPoint(IfBB);
+  LoadInst *CondVal = Builder.CreateLoad(CondPtr);
+  Builder.CreateCondBr(CondVal, ThenBB, ElseBB);
+
+  // "then" returns 0 ...
+  Builder.SetInsertPoint(ThenBB);
+  Builder.CreateRet(Builder.getInt64(0));
+
+  // ... whereas "else" returns 1.
+  Builder.SetInsertPoint(ElseBB);
+  Builder.CreateRet(Builder.getInt64(1));
+
+  // Only the entry block is seeded; the solver runs from there.
+  Solver.MarkBlockExecutable(IfBB);
+  Solver.Solve();
+
+  // The result must be recorded in the function state map, not in the value
+  // state map.
+  EXPECT_TRUE(Lattice.isUntracked(Solver.getValueState(F)));
+  EXPECT_TRUE(Lattice.isOverdefined(Solver.getFunctionState(F)));
+}
+
+/// Test that information is propagated through call arguments.
+///
+/// define internal void @f() {
+///   call void @g(i64 0, i64 1)
+///   call void @g(i64 1, i64 1)
+///   ret void
+/// }
+///
+/// define internal void @g(i64 %a, i64 %b) {
+///   ret void
+/// }
+///
+/// Only the entry block of "f" is marked executable up front; the solver is
+/// expected to reach "g" through the calls in "f". Argument "a" receives two
+/// different constants and should end up overdefined, while "b" receives the
+/// same constant at both call sites and should stay constant. The test checks
+/// that argument states live in the value state map rather than in the
+/// function or global variable maps.
+///
+/// This test also demonstrates that ComputeInstructionState can change the
+/// state of several lattice values, not just the one associated with the
+/// instruction's own definition: each call instruction here updates the state
+/// of both "a" and "b".
+TEST_F(SparsePropagationTest, ComputeInstructionState) {
+  Type *VoidTy = Builder.getVoidTy();
+  Type *Int64Ty = Builder.getInt64Ty();
+
+  // "f" is the caller: void f().
+  Function *F = Function::Create(FunctionType::get(VoidTy, false),
+                                 GlobalValue::InternalLinkage, "f", &M);
+
+  // "g" is the callee: void g(i64 %a, i64 %b).
+  Function *G =
+      Function::Create(FunctionType::get(VoidTy, {Int64Ty, Int64Ty}, false),
+                       GlobalValue::InternalLinkage, "g", &M);
+  Argument *A = G->arg_begin();
+  Argument *B = std::next(A);
+  A->setName("a");
+  B->setName("b");
+
+  BasicBlock *CallerEntry = BasicBlock::Create(Context, "", F);
+  BasicBlock *CalleeEntry = BasicBlock::Create(Context, "", G);
+
+  // Call "g" twice: the first argument differs between the calls, the second
+  // one does not.
+  Value *Zero = Builder.getInt64(0);
+  Value *One = Builder.getInt64(1);
+  Builder.SetInsertPoint(CallerEntry);
+  Builder.CreateCall(G, {Zero, One});
+  Builder.CreateCall(G, {One, One});
+  Builder.CreateRetVoid();
+
+  Builder.SetInsertPoint(CalleeEntry);
+  Builder.CreateRetVoid();
+
+  // Only the caller's entry block is seeded as executable.
+  Solver.MarkBlockExecutable(CallerEntry);
+  Solver.Solve();
+
+  // Argument states are looked up in the value state map.
+  EXPECT_TRUE(Lattice.isOverdefined(Solver.getValueState(A)));
+  EXPECT_TRUE(Lattice.isConstant(Solver.getValueState(B)));
+}
+
+/// Test that exceptional terminator instructions are handled.
+///
+/// declare internal void @p()
+///
+/// declare internal void @g()
+///
+/// define internal void @f() personality i8* bitcast (void ()* @p to i8*) {
+/// entry:
+///   invoke void @g()
+///           to label %exit unwind label %catch.pad
+///
+/// catch.pad:
+///   %0 = catchswitch within none [label %catch.body] unwind to caller
+///
+/// catch.body:
+///   %1 = catchpad within %0 []
+///   catchret from %1 to label %exit
+///
+/// exit:
+///   ret void
+/// }
+///
+/// Only the entry block is marked executable up front. The solver should then
+/// discover that the remaining blocks of the function are executable as well.
+TEST_F(SparsePropagationTest, ExceptionalTerminatorInsts) {
+  // "p" (used as the personality), "g" (the invoked callee) and "f" all have
+  // type void() and internal linkage; only "f" is given a body.
+  FunctionType *VoidFnTy = FunctionType::get(Builder.getVoidTy(), false);
+  Function *P =
+      Function::Create(VoidFnTy, GlobalValue::InternalLinkage, "p", &M);
+  Function *G =
+      Function::Create(VoidFnTy, GlobalValue::InternalLinkage, "g", &M);
+  Function *F =
+      Function::Create(VoidFnTy, GlobalValue::InternalLinkage, "f", &M);
+
+  // Install "p", bitcast to i8*, as the personality function of "f".
+  F->setPersonalityFn(
+      ConstantExpr::getCast(Instruction::BitCast, P, Builder.getInt8PtrTy()));
+
+  BasicBlock *EntryBB = BasicBlock::Create(Context, "entry", F);
+  BasicBlock *PadBB = BasicBlock::Create(Context, "catch.pad", F);
+  BasicBlock *HandlerBB = BasicBlock::Create(Context, "catch.body", F);
+  BasicBlock *ExitBB = BasicBlock::Create(Context, "exit", F);
+
+  // entry: invoke "g"; continue at "exit", unwind to "catch.pad".
+  Builder.SetInsertPoint(EntryBB);
+  Builder.CreateInvoke(G, ExitBB, PadBB);
+
+  // catch.pad: a catchswitch whose single handler is "catch.body".
+  Builder.SetInsertPoint(PadBB);
+  CatchSwitchInst *Switch =
+      Builder.CreateCatchSwitch(ConstantTokenNone::get(Context), nullptr, 1);
+  Switch->addHandler(HandlerBB);
+
+  // catch.body: catchpad followed by a catchret to "exit".
+  Builder.SetInsertPoint(HandlerBB);
+  CatchPadInst *Handler = Builder.CreateCatchPad(Switch, {});
+  Builder.CreateCatchRet(Handler, ExitBB);
+
+  // exit: plain return.
+  Builder.SetInsertPoint(ExitBB);
+  Builder.CreateRetVoid();
+
+  // Seed only the entry block; every other block must be discovered.
+  Solver.MarkBlockExecutable(EntryBB);
+  Solver.Solve();
+
+  EXPECT_TRUE(Solver.isBlockExecutable(PadBB));
+  EXPECT_TRUE(Solver.isBlockExecutable(HandlerBB));
+  EXPECT_TRUE(Solver.isBlockExecutable(ExitBB));
+}