Index: include/llvm/Analysis/SparsePropagation.h =================================================================== --- include/llvm/Analysis/SparsePropagation.h +++ include/llvm/Analysis/SparsePropagation.h @@ -23,16 +23,27 @@ namespace llvm { -template class SparseSolver; +/// A template for translating between LLVM Values and LatticeKeys. Clients must +/// provide a specialization of LatticeKeyInfo for their LatticeKey type. +template struct LatticeKeyInfo { + // static inline Value *getValueFromLatticeKey(LatticeKey Key); + // static inline LatticeKey getLatticeKeyFromValue(Value *V); +}; + +template > +class SparseSolver; /// AbstractLatticeFunction - This class is implemented by the dataflow instance /// to specify what the lattice values are and how they handle merges etc. This /// gives the client the power to compute lattice values from instructions, /// constants, etc. The current requirement is that lattice values must be -/// copyable. At the moment, nothing tries to avoid copying. - - -template class AbstractLatticeFunction { +/// copyable. At the moment, nothing tries to avoid copying. Additionally, +/// lattice keys must be able to be used as keys of a mapping data structure. +/// Internally, the generic solver currently uses a DenseMap to map lattice keys +/// to lattice values. If the lattice key is a non-standard type, a +/// specialization of DenseMapInfo must be provided. +template class AbstractLatticeFunction { private: LatticeVal UndefVal, OverdefinedVal, UntrackedVal; @@ -50,35 +61,21 @@ LatticeVal getOverdefinedVal() const { return OverdefinedVal; } LatticeVal getUntrackedVal() const { return UntrackedVal; } - /// IsUntrackedValue - If the specified Value is something that is obviously - /// uninteresting to the analysis (and would always return UntrackedVal), - /// this function can return true to avoid pointless work. 
- virtual bool IsUntrackedValue(Value *V) { return false; } + /// IsUntrackedValue - If the specified LatticeKey is obviously uninteresting + /// to the analysis (i.e., it would always return UntrackedVal), this + /// function can return true to avoid pointless work. + virtual bool IsUntrackedValue(LatticeKey Key) { return false; } - /// ComputeConstant - Given a constant value, compute and return a lattice - /// value corresponding to the specified constant. - virtual LatticeVal ComputeConstant(Constant *C) { - return getOverdefinedVal(); // always safe + /// ComputeLatticeVal - Compute and return a LatticeVal corresponding to the + /// given LatticeKey. + virtual LatticeVal ComputeLatticeVal(LatticeKey Key) { + return getOverdefinedVal(); } /// IsSpecialCasedPHI - Given a PHI node, determine whether this PHI node is /// one that the we want to handle through ComputeInstructionState. virtual bool IsSpecialCasedPHI(PHINode *PN) { return false; } - /// GetConstant - If the specified lattice value is representable as an LLVM - /// constant value, return it. Otherwise return null. The returned value - /// must be in the same LLVM type as Val. - virtual Constant *GetConstant(LatticeVal LV, Value *Val, - SparseSolver &SS) { - return nullptr; - } - - /// ComputeArgument - Given a formal argument value, compute and return a - /// lattice value corresponding to the specified argument. - virtual LatticeVal ComputeArgument(Argument *I) { - return getOverdefinedVal(); // always safe - } - /// MergeValues - Compute and return the merge of the two specified lattice /// values. Merging should only move one direction down the lattice to /// guarantee convergence (toward overdefined). @@ -86,27 +83,40 @@ return getOverdefinedVal(); // always safe, never useful. } - /// ComputeInstructionState - Given an instruction and a vector of its operand - /// values, compute the result value of the instruction. 
- virtual LatticeVal ComputeInstructionState(Instruction &I, - SparseSolver &SS) { - return getOverdefinedVal(); // always safe, never useful. + /// ComputeInstructionState - Compute the LatticeKeys that change as a result + /// of executing instruction \p I. Their associated LatticeVals are stored in + /// \p ChangedValues. + virtual void + ComputeInstructionState(Instruction &I, + DenseMap &ChangedValues, + SparseSolver &SS) = 0; + + /// PrintLatticeVal - Render the given LatticeVal to the specified stream. + virtual void PrintLatticeVal(LatticeVal LV, raw_ostream &OS); + + /// PrintLatticeKey - Render the given LatticeKey to the specified stream. + virtual void PrintLatticeKey(LatticeKey Key, raw_ostream &OS); + + /// GetValueFromLatticeVal - If the given LatticeVal is representable as an + /// LLVM value, return it; otherwise, return nullptr. If a type is given, the + /// returned value must have the same type. This function is used by the + /// generic solver in attempting to resolve branch and switch conditions. + virtual Value *GetValueFromLatticeVal(LatticeVal LV, Type *Ty = nullptr) { + return nullptr; } - - /// PrintValue - Render the specified lattice value to the specified stream. - virtual void PrintValue(LatticeVal V, raw_ostream &OS); }; /// SparseSolver - This class is a general purpose solver for Sparse Conditional /// Propagation with a programmable lattice function. -template class SparseSolver { +template +class SparseSolver { /// LatticeFunc - This is the object that knows the lattice and how to /// compute transfer functions. - AbstractLatticeFunction *LatticeFunc; + AbstractLatticeFunction *LatticeFunc; - /// ValueState - Holds the lattice state associated with LLVM values. - DenseMap ValueState; + /// ValueState - Holds the LatticeVals associated with LatticeKeys. + DenseMap ValueState; /// BBExecutable - Holds the basic blocks that are executable. 
SmallPtrSet BBExecutable; @@ -124,28 +134,29 @@ std::set KnownFeasibleEdges; public: - explicit SparseSolver(AbstractLatticeFunction *Lattice) + explicit SparseSolver( + AbstractLatticeFunction *Lattice) : LatticeFunc(Lattice) {} SparseSolver(const SparseSolver &) = delete; SparseSolver &operator=(const SparseSolver &) = delete; /// Solve - Solve for constants and executable blocks. - void Solve(Function &F); + void Solve(); - void Print(Function &F, raw_ostream &OS) const; + void Print(raw_ostream &OS) const; /// getExistingValueState - Return the LatticeVal object corresponding to the /// given value from the ValueState map. If the value is not in the map, /// UntrackedVal is returned, unlike the getValueState method. - LatticeVal getExistingValueState(Value *V) const { - auto I = ValueState.find(V); + LatticeVal getExistingValueState(LatticeKey Key) const { + auto I = ValueState.find(Key); return I != ValueState.end() ? I->second : LatticeFunc->getUntrackedVal(); } /// getValueState - Return the LatticeVal object corresponding to the given /// value from the ValueState map. If the value is not in the map, its state /// is initialized. - LatticeVal getValueState(Value *V); + LatticeVal getValueState(LatticeKey Key); /// isEdgeFeasible - Return true if the control flow edge from the 'From' /// basic block to the 'To' basic block is currently feasible. If @@ -162,15 +173,16 @@ return BBExecutable.count(BB); } -private: - /// UpdateState - When the state for some instruction is potentially updated, - /// this function notices and adds I to the worklist if needed. - void UpdateState(Instruction &Inst, LatticeVal V); - /// MarkBlockExecutable - This method can be used by clients to mark all of /// the blocks that are known to be intrinsically live in the processed unit. 
void MarkBlockExecutable(BasicBlock *BB); +private: + /// UpdateState - When the state of some LatticeKey is potentially updated to + /// the given LatticeVal, this function notices and adds the LLVM value + /// corresponding to the key to the work list, if needed. + void UpdateState(LatticeKey Key, LatticeVal LV); + /// markEdgeExecutable - Mark a basic block as executable, adding it to the BB /// work list if it is not already executable. void markEdgeExecutable(BasicBlock *Source, BasicBlock *Dest); @@ -189,9 +201,9 @@ // AbstractLatticeFunction Implementation //===----------------------------------------------------------------------===// -template -void AbstractLatticeFunction::PrintValue(LatticeVal V, - raw_ostream &OS) { +template +void AbstractLatticeFunction::PrintLatticeVal( + LatticeVal V, raw_ostream &OS) { if (V == UndefVal) OS << "undefined"; else if (V == OverdefinedVal) @@ -202,57 +214,59 @@ OS << "unknown lattice value"; } +template +void AbstractLatticeFunction::PrintLatticeKey( + LatticeKey Key, raw_ostream &OS) { + OS << "unknown lattice key"; +} + //===----------------------------------------------------------------------===// // SparseSolver Implementation //===----------------------------------------------------------------------===// -template -LatticeVal SparseSolver::getValueState(Value *V) { - auto I = ValueState.find(V); +template +LatticeVal +SparseSolver::getValueState(LatticeKey Key) { + auto I = ValueState.find(Key); if (I != ValueState.end()) return I->second; // Common case, in the map - LatticeVal LV; - if (LatticeFunc->IsUntrackedValue(V)) + if (LatticeFunc->IsUntrackedValue(Key)) return LatticeFunc->getUntrackedVal(); - else if (Constant *C = dyn_cast(V)) - LV = LatticeFunc->ComputeConstant(C); - else if (Argument *A = dyn_cast(V)) - LV = LatticeFunc->ComputeArgument(A); - else if (!isa(V)) - // All other non-instructions are overdefined. 
- LV = LatticeFunc->getOverdefinedVal(); - else - // All instructions are underdefined by default. - LV = LatticeFunc->getUndefVal(); + LatticeVal LV = LatticeFunc->ComputeLatticeVal(Key); // If this value is untracked, don't add it to the map. if (LV == LatticeFunc->getUntrackedVal()) return LV; - return ValueState[V] = LV; + return ValueState[Key] = LV; } -template -void SparseSolver::UpdateState(Instruction &Inst, LatticeVal V) { - auto I = ValueState.find(&Inst); - if (I != ValueState.end() && I->second == V) +template +void SparseSolver::UpdateState(LatticeKey Key, + LatticeVal LV) { + auto I = ValueState.find(Key); + if (I != ValueState.end() && I->second == LV) return; // No change. - // An update. Visit uses of I. - ValueState[&Inst] = V; - ValueWorkList.push_back(&Inst); + // Update the state of the given LatticeKey and add its corresponding LLVM + // value to the work list. + ValueState[Key] = LV; + if (Value *V = KeyInfo::getValueFromLatticeKey(Key)) + ValueWorkList.push_back(V); } -template -void SparseSolver::MarkBlockExecutable(BasicBlock *BB) { +template +void SparseSolver::MarkBlockExecutable( + BasicBlock *BB) { + if (!BBExecutable.insert(BB).second) + return; DEBUG(dbgs() << "Marking Block Executable: " << BB->getName() << "\n"); - BBExecutable.insert(BB); // Basic block is executable! BBWorkList.push_back(BB); // Add the block to the work list! } -template -void SparseSolver::markEdgeExecutable(BasicBlock *Source, - BasicBlock *Dest) { +template +void SparseSolver::markEdgeExecutable( + BasicBlock *Source, BasicBlock *Dest) { if (!KnownFeasibleEdges.insert(Edge(Source, Dest)).second) return; // This edge is already known to be executable! 
@@ -270,8 +284,8 @@ } } -template -void SparseSolver::getFeasibleSuccessors( +template +void SparseSolver::getFeasibleSuccessors( TerminatorInst &TI, SmallVectorImpl &Succs, bool AggressiveUndef) { Succs.resize(TI.getNumSuccessors()); if (TI.getNumSuccessors() == 0) @@ -285,9 +299,11 @@ LatticeVal BCValue; if (AggressiveUndef) - BCValue = getValueState(BI->getCondition()); + BCValue = + getValueState(KeyInfo::getLatticeKeyFromValue(BI->getCondition())); else - BCValue = getExistingValueState(BI->getCondition()); + BCValue = getExistingValueState( + KeyInfo::getLatticeKeyFromValue(BI->getCondition())); if (BCValue == LatticeFunc->getOverdefinedVal() || BCValue == LatticeFunc->getUntrackedVal()) { @@ -300,7 +316,9 @@ if (BCValue == LatticeFunc->getUndefVal()) return; - Constant *C = LatticeFunc->GetConstant(BCValue, BI->getCondition(), *this); + Constant *C = + dyn_cast_or_null(LatticeFunc->GetValueFromLatticeVal( + BCValue, BI->getCondition()->getType())); if (!C || !isa(C)) { // Non-constant values can go either way. Succs[0] = Succs[1] = true; @@ -312,10 +330,8 @@ return; } - if (isa(TI)) { - // Invoke instructions successors are always executable. - // TODO: Could ask the lattice function if the value can throw. 
- Succs[0] = Succs[1] = true; + if (TI.isExceptional()) { + Succs.assign(Succs.size(), true); return; } @@ -327,9 +343,10 @@ SwitchInst &SI = cast(TI); LatticeVal SCValue; if (AggressiveUndef) - SCValue = getValueState(SI.getCondition()); + SCValue = getValueState(KeyInfo::getLatticeKeyFromValue(SI.getCondition())); else - SCValue = getExistingValueState(SI.getCondition()); + SCValue = getExistingValueState( + KeyInfo::getLatticeKeyFromValue(SI.getCondition())); if (SCValue == LatticeFunc->getOverdefinedVal() || SCValue == LatticeFunc->getUntrackedVal()) { @@ -342,7 +359,8 @@ if (SCValue == LatticeFunc->getUndefVal()) return; - Constant *C = LatticeFunc->GetConstant(SCValue, SI.getCondition(), *this); + Constant *C = dyn_cast_or_null(LatticeFunc->GetValueFromLatticeVal( + SCValue, SI.getCondition()->getType())); if (!C || !isa(C)) { // All destinations are executable! Succs.assign(TI.getNumSuccessors(), true); @@ -352,9 +370,9 @@ Succs[Case.getSuccessorIndex()] = true; } -template -bool SparseSolver::isEdgeFeasible(BasicBlock *From, BasicBlock *To, - bool AggressiveUndef) { +template +bool SparseSolver::isEdgeFeasible( + BasicBlock *From, BasicBlock *To, bool AggressiveUndef) { SmallVector SuccFeasible; TerminatorInst *TI = From->getTerminator(); getFeasibleSuccessors(*TI, SuccFeasible, AggressiveUndef); @@ -366,8 +384,9 @@ return false; } -template -void SparseSolver::visitTerminatorInst(TerminatorInst &TI) { +template +void SparseSolver::visitTerminatorInst( + TerminatorInst &TI) { SmallVector SuccFeasible; getFeasibleSuccessors(TI, SuccFeasible, true); @@ -379,19 +398,22 @@ markEdgeExecutable(BB, TI.getSuccessor(i)); } -template -void SparseSolver::visitPHINode(PHINode &PN) { +template +void SparseSolver::visitPHINode(PHINode &PN) { // The lattice function may store more information on a PHINode than could be // computed from its incoming values. For example, SSI form stores its sigma // functions as PHINodes with a single incoming value. 
if (LatticeFunc->IsSpecialCasedPHI(&PN)) { - LatticeVal IV = LatticeFunc->ComputeInstructionState(PN, *this); - if (IV != LatticeFunc->getUntrackedVal()) - UpdateState(PN, IV); + DenseMap ChangedValues; + LatticeFunc->ComputeInstructionState(PN, ChangedValues, *this); + for (auto &ChangedValue : ChangedValues) + if (ChangedValue.second != LatticeFunc->getUntrackedVal()) + UpdateState(ChangedValue.first, ChangedValue.second); return; } - LatticeVal PNIV = getValueState(&PN); + LatticeKey Key = KeyInfo::getLatticeKeyFromValue(&PN); + LatticeVal PNIV = getValueState(Key); LatticeVal Overdefined = LatticeFunc->getOverdefinedVal(); // If this value is already overdefined (common) just return. @@ -401,7 +423,7 @@ // Super-extra-high-degree PHI nodes are unlikely to ever be interesting, // and slow us down a lot. Just mark them overdefined. if (PN.getNumIncomingValues() > 64) { - UpdateState(PN, Overdefined); + UpdateState(Key, Overdefined); return; } @@ -414,7 +436,8 @@ continue; // Merge in this value. - LatticeVal OpVal = getValueState(PN.getIncomingValue(i)); + LatticeVal OpVal = + getValueState(KeyInfo::getLatticeKeyFromValue(PN.getIncomingValue(i))); if (OpVal != PNIV) PNIV = LatticeFunc->MergeValues(PNIV, OpVal); @@ -423,11 +446,11 @@ } // Update the PHI with the compute value, which is the merge of the inputs. - UpdateState(PN, PNIV); + UpdateState(Key, PNIV); } -template -void SparseSolver::visitInst(Instruction &I) { +template +void SparseSolver::visitInst(Instruction &I) { // PHIs are handled by the propagation logic, they are never passed into the // transfer functions. if (PHINode *PN = dyn_cast(&I)) @@ -435,17 +458,18 @@ // Otherwise, ask the transfer function what the result is. If this is // something that we care about, remember it. 
- LatticeVal IV = LatticeFunc->ComputeInstructionState(I, *this); - if (IV != LatticeFunc->getUntrackedVal()) - UpdateState(I, IV); + DenseMap ChangedValues; + LatticeFunc->ComputeInstructionState(I, ChangedValues, *this); + for (auto &ChangedValue : ChangedValues) + if (ChangedValue.second != LatticeFunc->getUntrackedVal()) + UpdateState(ChangedValue.first, ChangedValue.second); if (TerminatorInst *TI = dyn_cast(&I)) visitTerminatorInst(*TI); } -template void SparseSolver::Solve(Function &F) { - MarkBlockExecutable(&F.getEntryBlock()); - +template +void SparseSolver::Solve() { // Process the work lists until they are empty! while (!BBWorkList.empty() || !ValueWorkList.empty()) { // Process the value work list. @@ -478,22 +502,24 @@ } } -template -void SparseSolver::Print(Function &F, raw_ostream &OS) const { - OS << "\nFUNCTION: " << F.getName() << "\n"; - for (auto &BB : F) { - if (!BBExecutable.count(&BB)) - OS << "INFEASIBLE: "; - OS << "\t"; - if (BB.hasName()) - OS << BB.getName() << ":\n"; - else - OS << "; anon bb\n"; - for (auto &I : BB) { - LatticeFunc->PrintValue(getExistingValueState(&I), OS); - OS << I << "\n"; - } +template +void SparseSolver::Print( + raw_ostream &OS) const { + if (ValueState.empty()) + return; + LatticeKey Key; + LatticeVal LV; + + OS << "ValueState:\n"; + for (auto &Entry : ValueState) { + std::tie(Key, LV) = Entry; + if (LV == LatticeFunc->getUntrackedVal()) + continue; + OS << "\t"; + LatticeFunc->PrintLatticeVal(LV, OS); + OS << ": "; + LatticeFunc->PrintLatticeKey(Key, OS); OS << "\n"; } } Index: unittests/Analysis/CMakeLists.txt =================================================================== --- unittests/Analysis/CMakeLists.txt +++ unittests/Analysis/CMakeLists.txt @@ -22,6 +22,7 @@ OrderedBasicBlockTest.cpp ProfileSummaryInfoTest.cpp ScalarEvolutionTest.cpp + SparsePropagation.cpp TargetLibraryInfoTest.cpp TBAATest.cpp UnrollAnalyzer.cpp Index: unittests/Analysis/SparsePropagation.cpp 
=================================================================== --- /dev/null +++ unittests/Analysis/SparsePropagation.cpp @@ -0,0 +1,559 @@ +//===- SparsePropagation.cpp - Unit tests for the generic solver ----------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/SparsePropagation.h" +#include "llvm/ADT/PointerIntPair.h" +#include "llvm/IR/CallSite.h" +#include "llvm/IR/IRBuilder.h" +#include "gtest/gtest.h" +using namespace llvm; + +namespace { +/// To enable interprocedural analysis, we assign LLVM values to the following +/// groups. The register group represents SSA registers, the return group +/// represents the return values of functions, and the memory group represents +/// in-memory values. An LLVM Value can technically be in more than one group. +/// It's necessary to distinguish these groups so we can, for example, track a +/// global variable separately from the value stored at its location. +enum class IPOGrouping { Register, Return, Memory }; + +/// Our LatticeKeys are PointerIntPairs composed of LLVM values and groupings. +/// The PointerIntPair header provides a DenseMapInfo specialization, so using +/// these as LatticeKeys is fine. +using TestLatticeKey = PointerIntPair; +} // namespace + +namespace llvm { +/// A specialization of LatticeKeyInfo for TestLatticeKeys. The generic solver +/// must translate between LatticeKeys and LLVM Values when adding Values to +/// its work list and inspecting the state of control-flow related values. 
+template <> struct LatticeKeyInfo { + static inline Value *getValueFromLatticeKey(TestLatticeKey Key) { + return Key.getPointer(); + } + static inline TestLatticeKey getLatticeKeyFromValue(Value *V) { + return TestLatticeKey(V, IPOGrouping::Register); + } +}; +} // namespace llvm + +namespace { +/// This class defines a simple test lattice value that could be used for +/// solving problems similar to constant propagation. The value is maintained +/// as a PointerIntPair. +class TestLatticeVal { +public: + /// The states of the lattice value. Only the ConstantVal state is + /// interesting; the rest are special states used by the generic solver. The + /// UntrackedVal state differs from the other three in that the generic + /// solver uses it to avoid doing unnecessary work. In particular, when a + /// value moves to the UntrackedVal state, its users are not notified. + enum TestLatticeStateTy { + UndefinedVal, + ConstantVal, + OverdefinedVal, + UntrackedVal + }; + + TestLatticeVal() : LatticeVal(nullptr, UndefinedVal) {} + TestLatticeVal(Constant *C, TestLatticeStateTy State) + : LatticeVal(C, State) {} + + /// Return true if this lattice value is in the Constant state. This is used + /// for checking the solver results. + bool isConstant() const { return LatticeVal.getInt() == ConstantVal; } + + /// Return true if this lattice value is in the Overdefined state. This is + /// used for checking the solver results. + bool isOverdefined() const { return LatticeVal.getInt() == OverdefinedVal; } + + bool operator==(const TestLatticeVal &RHS) const { + return LatticeVal == RHS.LatticeVal; + } + + bool operator!=(const TestLatticeVal &RHS) const { + return LatticeVal != RHS.LatticeVal; + } + +private: + /// A simple lattice value type for problems similar to constant propagation. + /// It holds the constant value and the lattice state. 
+ PointerIntPair LatticeVal; +}; + +/// This class defines a simple test lattice function that could be used for +/// solving problems similar to constant propagation. The test lattice differs +/// from a "real" lattice in a few ways. First, it initializes all return +/// values, values stored in global variables, and arguments in the undefined +/// state. This means that there are no limitations on what we can track +/// interprocedurally. For simplicity, all global values in the tests will be +/// given internal linkage, since this is not something this lattice function +/// tracks. Second, it only handles the few instructions necessary for the +/// tests. +class TestLatticeFunc + : public AbstractLatticeFunction { +public: + /// Construct a new test lattice function with special values for the + /// Undefined, Overdefined, and Untracked states. + TestLatticeFunc() + : AbstractLatticeFunction( + TestLatticeVal(nullptr, TestLatticeVal::UndefinedVal), + TestLatticeVal(nullptr, TestLatticeVal::OverdefinedVal), + TestLatticeVal(nullptr, TestLatticeVal::UntrackedVal)) {} + + /// Compute and return a TestLatticeVal for the given TestLatticeKey. For the + /// test analysis, a LatticeKey will begin in the undefined state, unless it + /// represents an LLVM Constant in the register grouping. + TestLatticeVal ComputeLatticeVal(TestLatticeKey Key) override { + if (Key.getInt() == IPOGrouping::Register) + if (auto *C = dyn_cast(Key.getPointer())) + return TestLatticeVal(C, TestLatticeVal::ConstantVal); + return getUndefVal(); + } + + /// Merge the two given lattice values. This merge should be equivalent to + /// what is done for constant propagation. That is, the resulting lattice + /// value is constant only if the two given lattice values are constant and + /// hold the same value. 
+ TestLatticeVal MergeValues(TestLatticeVal X, TestLatticeVal Y) override { + if (X == getUntrackedVal() || Y == getUntrackedVal()) + return getUntrackedVal(); + if (X == getOverdefinedVal() || Y == getOverdefinedVal()) + return getOverdefinedVal(); + if (X == getUndefVal() && Y == getUndefVal()) + return getUndefVal(); + if (X == getUndefVal()) + return Y; + if (Y == getUndefVal()) + return X; + if (X == Y) + return X; + return getOverdefinedVal(); + } + + /// Compute the lattice values that change as a result of executing the given + /// instruction. We only handle the few instructions needed for the tests. + void ComputeInstructionState( + Instruction &I, DenseMap &ChangedValues, + SparseSolver &SS) override { + switch (I.getOpcode()) { + case Instruction::Call: + return visitCallSite(cast(&I), ChangedValues, SS); + case Instruction::Ret: + return visitReturn(*cast(&I), ChangedValues, SS); + case Instruction::Store: + return visitStore(*cast(&I), ChangedValues, SS); + default: + return visitInst(I, ChangedValues, SS); + } + } + +private: + /// Handle call sites. The state of a called function's argument is the merge + /// of the current formal argument state with the call site's corresponding + /// actual argument state. The call site state is the merge of the call site + /// state with the returned value state of the called function. 
+ void visitCallSite(CallSite CS, + DenseMap &ChangedValues, + SparseSolver &SS) { + Function *F = CS.getCalledFunction(); + Instruction *I = CS.getInstruction(); + auto RegI = TestLatticeKey(I, IPOGrouping::Register); + if (!F) { + ChangedValues[RegI] = getOverdefinedVal(); + return; + } + SS.MarkBlockExecutable(&F->front()); + for (Argument &A : F->args()) { + auto RegFormal = TestLatticeKey(&A, IPOGrouping::Register); + auto RegActual = + TestLatticeKey(CS.getArgument(A.getArgNo()), IPOGrouping::Register); + ChangedValues[RegFormal] = + MergeValues(SS.getValueState(RegFormal), SS.getValueState(RegActual)); + } + auto RetF = TestLatticeKey(F, IPOGrouping::Return); + ChangedValues[RegI] = + MergeValues(SS.getValueState(RegI), SS.getValueState(RetF)); + } + + /// Handle return instructions. The function's return state is the merge of + /// the returned value state and the function's current return state. + void visitReturn(ReturnInst &I, + DenseMap &ChangedValues, + SparseSolver &SS) { + Function *F = I.getParent()->getParent(); + if (F->getReturnType()->isVoidTy()) + return; + auto RegR = TestLatticeKey(I.getReturnValue(), IPOGrouping::Register); + auto RetF = TestLatticeKey(F, IPOGrouping::Return); + ChangedValues[RetF] = + MergeValues(SS.getValueState(RegR), SS.getValueState(RetF)); + } + + /// Handle store instructions. If the pointer operand of the store is a + /// global variable, we attempt to track the value. The global variable state + /// is the merge of the stored value state with the current global variable + /// state. + void visitStore(StoreInst &I, + DenseMap &ChangedValues, + SparseSolver &SS) { + auto *GV = dyn_cast(I.getPointerOperand()); + if (!GV) + return; + auto RegVal = TestLatticeKey(I.getValueOperand(), IPOGrouping::Register); + auto MemPtr = TestLatticeKey(GV, IPOGrouping::Memory); + ChangedValues[MemPtr] = + MergeValues(SS.getValueState(RegVal), SS.getValueState(MemPtr)); + } + + /// Handle all other instructions. 
All other instructions are marked + /// overdefined. + void visitInst(Instruction &I, + DenseMap &ChangedValues, + SparseSolver &SS) { + auto RegI = TestLatticeKey(&I, IPOGrouping::Register); + ChangedValues[RegI] = getOverdefinedVal(); + } +}; + +/// This class defines the common data used for all of the tests. The tests +/// should add code to the module and then run the solver. +class SparsePropagationTest : public testing::Test { +protected: + LLVMContext Context; + Module M; + IRBuilder<> Builder; + TestLatticeFunc Lattice; + SparseSolver Solver; + +public: + SparsePropagationTest() + : M("", Context), Builder(Context), Solver(&Lattice) {} +}; +} // namespace + +/// Test that we mark discovered functions executable. +/// +/// define internal void @f() { +/// call void @g() +/// ret void +/// } +/// +/// define internal void @g() { +/// call void @f() +/// ret void +/// } +/// +/// For this test, we initially mark "f" executable, and the solver discovers +/// "g" because of the call in "f". The mutually recursive call in "g" also +/// tests that we don't add a block to the basic block work list if it is +/// already executable. Doing so would put the solver into an infinite loop. +TEST_F(SparsePropagationTest, MarkBlockExecutable) { + Function *F = Function::Create(FunctionType::get(Builder.getVoidTy(), false), + GlobalValue::InternalLinkage, "f", &M); + Function *G = Function::Create(FunctionType::get(Builder.getVoidTy(), false), + GlobalValue::InternalLinkage, "g", &M); + BasicBlock *FEntry = BasicBlock::Create(Context, "", F); + BasicBlock *GEntry = BasicBlock::Create(Context, "", G); + Builder.SetInsertPoint(FEntry); + Builder.CreateCall(G); + Builder.CreateRetVoid(); + Builder.SetInsertPoint(GEntry); + Builder.CreateCall(F); + Builder.CreateRetVoid(); + + Solver.MarkBlockExecutable(FEntry); + Solver.Solve(); + + EXPECT_TRUE(Solver.isBlockExecutable(GEntry)); +} + +/// Test that we propagate information through global variables. 
+/// +/// @gv = internal global i64 +/// +/// define internal void @f() { +/// store i64 1, i64* @gv +/// ret void +/// } +/// +/// define internal void @g() { +/// store i64 1, i64* @gv +/// ret void +/// } +/// +/// For this test, we initially mark both "f" and "g" executable. The solver +/// computes the lattice state of the global variable as constant. The test +/// ensures that the solver uses the correct state map (i.e., the one for +/// in-memory values) for the store instruction. Thus, "gv" should be constant +/// in the global variable state map but untracked in the value state map. +TEST_F(SparsePropagationTest, GlobalVariableConstant) { + Function *F = Function::Create(FunctionType::get(Builder.getVoidTy(), false), + GlobalValue::InternalLinkage, "f", &M); + Function *G = Function::Create(FunctionType::get(Builder.getVoidTy(), false), + GlobalValue::InternalLinkage, "g", &M); + GlobalVariable *GV = + new GlobalVariable(M, Builder.getInt64Ty(), false, + GlobalValue::InternalLinkage, nullptr, "gv"); + BasicBlock *FEntry = BasicBlock::Create(Context, "", F); + BasicBlock *GEntry = BasicBlock::Create(Context, "", G); + Builder.SetInsertPoint(FEntry); + Builder.CreateStore(Builder.getInt64(1), GV); + Builder.CreateRetVoid(); + Builder.SetInsertPoint(GEntry); + Builder.CreateStore(Builder.getInt64(1), GV); + Builder.CreateRetVoid(); + + Solver.MarkBlockExecutable(FEntry); + Solver.MarkBlockExecutable(GEntry); + Solver.Solve(); + + auto MemGV = TestLatticeKey(GV, IPOGrouping::Memory); + EXPECT_TRUE(Solver.getExistingValueState(MemGV).isConstant()); +} + +/// Test that we propagate information through global variables. +/// +/// @gv = internal global i64 +/// +/// define internal void @f() { +/// store i64 0, i64* @gv +/// ret void +/// } +/// +/// define internal void @g() { +/// store i64 1, i64* @gv +/// ret void +/// } +/// +/// For this test, we initially mark both "f" and "g" executable. 
The solver +/// computes the lattice state of the global variable as overdefined. The test +/// ensures that the solver uses the correct state map (i.e., the one for +/// in-memory values) for the store instruction. Thus, "gv" should be +/// overdefined in the global variable state map but untracked in the value +/// state map. +TEST_F(SparsePropagationTest, GlobalVariableOverDefined) { + Function *F = Function::Create(FunctionType::get(Builder.getVoidTy(), false), + GlobalValue::InternalLinkage, "f", &M); + Function *G = Function::Create(FunctionType::get(Builder.getVoidTy(), false), + GlobalValue::InternalLinkage, "g", &M); + GlobalVariable *GV = + new GlobalVariable(M, Builder.getInt64Ty(), false, + GlobalValue::InternalLinkage, nullptr, "gv"); + BasicBlock *FEntry = BasicBlock::Create(Context, "", F); + BasicBlock *GEntry = BasicBlock::Create(Context, "", G); + Builder.SetInsertPoint(FEntry); + Builder.CreateStore(Builder.getInt64(0), GV); + Builder.CreateRetVoid(); + Builder.SetInsertPoint(GEntry); + Builder.CreateStore(Builder.getInt64(1), GV); + Builder.CreateRetVoid(); + + Solver.MarkBlockExecutable(FEntry); + Solver.MarkBlockExecutable(GEntry); + Solver.Solve(); + + auto MemGV = TestLatticeKey(GV, IPOGrouping::Memory); + EXPECT_TRUE(Solver.getExistingValueState(MemGV).isOverdefined()); +} + +/// Test that we propagate information through function returns. +/// +/// define internal i64 @f(i1* %cond) { +/// if: +/// %0 = load i1, i1* %cond +/// br i1 %0, label %then, label %else +/// +/// then: +/// ret i64 1 +/// +/// else: +/// ret i64 1 +/// } +/// +/// For this test, we initially mark "f" executable. The solver computes the +/// return value of the function as constant. The test ensures that the solver +/// uses the correct state map (i.e., the one for functions) for the return +/// instructions. Thus, "f" should be constant in the function state map, but +/// untracked in the value state map. 
+TEST_F(SparsePropagationTest, FunctionDefined) {
+  // Signature of "f": i64 (i1*).
+  FunctionType *FnTy =
+      FunctionType::get(Builder.getInt64Ty(), {Type::getInt1PtrTy(Context)},
+                        /*isVarArg=*/false);
+  Function *Fn = Function::Create(FnTy, GlobalValue::InternalLinkage, "f", &M);
+  BasicBlock *IfBB = BasicBlock::Create(Context, "if", Fn);
+  BasicBlock *ThenBB = BasicBlock::Create(Context, "then", Fn);
+  BasicBlock *ElseBB = BasicBlock::Create(Context, "else", Fn);
+  auto *CondPtr = Fn->arg_begin();
+  CondPtr->setName("cond");
+
+  // if: load the condition through the pointer argument and branch on it.
+  Builder.SetInsertPoint(IfBB);
+  LoadInst *Loaded = Builder.CreateLoad(CondPtr);
+  Builder.CreateCondBr(Loaded, ThenBB, ElseBB);
+
+  // then/else: each arm consists of a single "ret i64 <RetVal>".
+  auto EmitRet = [&](BasicBlock *BB, uint64_t RetVal) {
+    Builder.SetInsertPoint(BB);
+    Builder.CreateRet(Builder.getInt64(RetVal));
+  };
+  EmitRet(ThenBB, 1);
+  EmitRet(ElseBB, 1);
+
+  Solver.MarkBlockExecutable(IfBB);
+  Solver.Solve();
+
+  // Query the state of "f" under the Return grouping.
+  auto RetF = TestLatticeKey(Fn, IPOGrouping::Return);
+  EXPECT_TRUE(Solver.getExistingValueState(RetF).isConstant());
+}
+
+/// Test that we propagate information through function returns.
+///
+/// define internal i64 @f(i1* %cond) {
+/// if:
+///   %0 = load i1, i1* %cond
+///   br i1 %0, label %then, label %else
+///
+/// then:
+///   ret i64 0
+///
+/// else:
+///   ret i64 1
+/// }
+///
+/// For this test, we initially mark the entry block of "f" executable. The
+/// two return instructions return different values (0 and 1), and the solver
+/// is expected to compute the return value of the function as overdefined.
+/// The test checks the state recorded for the lattice key that pairs "f" with
+/// IPOGrouping::Return; no other grouping of "f" is inspected.
+TEST_F(SparsePropagationTest, FunctionOverDefined) {
+  // Signature of "f": i64 (i1*).
+  FunctionType *FnTy =
+      FunctionType::get(Builder.getInt64Ty(), {Type::getInt1PtrTy(Context)},
+                        /*isVarArg=*/false);
+  Function *Fn = Function::Create(FnTy, GlobalValue::InternalLinkage, "f", &M);
+  BasicBlock *IfBB = BasicBlock::Create(Context, "if", Fn);
+  BasicBlock *ThenBB = BasicBlock::Create(Context, "then", Fn);
+  BasicBlock *ElseBB = BasicBlock::Create(Context, "else", Fn);
+  auto *CondPtr = Fn->arg_begin();
+  CondPtr->setName("cond");
+
+  // if: load the condition through the pointer argument and branch on it.
+  Builder.SetInsertPoint(IfBB);
+  LoadInst *Loaded = Builder.CreateLoad(CondPtr);
+  Builder.CreateCondBr(Loaded, ThenBB, ElseBB);
+
+  // then/else: each arm consists of a single "ret i64 <RetVal>".
+  auto EmitRet = [&](BasicBlock *BB, uint64_t RetVal) {
+    Builder.SetInsertPoint(BB);
+    Builder.CreateRet(Builder.getInt64(RetVal));
+  };
+  EmitRet(ThenBB, 0);
+  EmitRet(ElseBB, 1);
+
+  Solver.MarkBlockExecutable(IfBB);
+  Solver.Solve();
+
+  // Query the state of "f" under the Return grouping.
+  auto RetF = TestLatticeKey(Fn, IPOGrouping::Return);
+  EXPECT_TRUE(Solver.getExistingValueState(RetF).isOverdefined());
+}
+
+/// Test that we propagate information through arguments.
+///
+/// define internal void @f() {
+///   call void @g(i64 0, i64 1)
+///   call void @g(i64 1, i64 1)
+///   ret void
+/// }
+///
+/// define internal void @g(i64 %a, i64 %b) {
+///   ret void
+/// }
+///
+/// For this test, we initially mark "f" executable, and the solver discovers
+/// "g" because of the calls in "f". The solver computes the state of argument
+/// "a" as overdefined and the state of "b" as constant. The test checks the
+/// states recorded for the lattice keys that pair each argument with
+/// IPOGrouping::Register.
+///
+/// In addition, this test demonstrates that ComputeInstructionState can alter
+/// the state of multiple lattice values, in addition to the one associated
+/// with the instruction definition. Each call instruction in this test updates
+/// the state of arguments "a" and "b".
+TEST_F(SparsePropagationTest, ComputeInstructionState) {
+  // Caller "f" has type "void ()"; callee "g" has type "void (i64, i64)".
+  auto *Int64Ty = Builder.getInt64Ty();
+  Function *Caller = Function::Create(
+      FunctionType::get(Builder.getVoidTy(), /*isVarArg=*/false),
+      GlobalValue::InternalLinkage, "f", &M);
+  Function *Callee = Function::Create(
+      FunctionType::get(Builder.getVoidTy(), {Int64Ty, Int64Ty},
+                        /*isVarArg=*/false),
+      GlobalValue::InternalLinkage, "g", &M);
+
+  // Name the two formal arguments of "g".
+  Argument *ArgA = Callee->arg_begin();
+  Argument *ArgB = std::next(ArgA);
+  ArgA->setName("a");
+  ArgB->setName("b");
+
+  BasicBlock *CallerEntry = BasicBlock::Create(Context, "", Caller);
+  BasicBlock *CalleeEntry = BasicBlock::Create(Context, "", Callee);
+
+  // "f" calls "g" twice: the first operand differs between the calls, the
+  // second operand is the same in both.
+  Builder.SetInsertPoint(CallerEntry);
+  auto *Zero = Builder.getInt64(0);
+  auto *One = Builder.getInt64(1);
+  Builder.CreateCall(Callee, {Zero, One});
+  Builder.CreateCall(Callee, {One, One});
+  Builder.CreateRetVoid();
+
+  // "g" just returns.
+  Builder.SetInsertPoint(CalleeEntry);
+  Builder.CreateRetVoid();
+
+  // Only the caller's entry block is seeded as executable.
+  Solver.MarkBlockExecutable(CallerEntry);
+  Solver.Solve();
+
+  // Query the state of each argument under the Register grouping.
+  auto RegA = TestLatticeKey(ArgA, IPOGrouping::Register);
+  auto RegB = TestLatticeKey(ArgB, IPOGrouping::Register);
+  EXPECT_TRUE(Solver.getExistingValueState(RegA).isOverdefined());
+  EXPECT_TRUE(Solver.getExistingValueState(RegB).isConstant());
+}
+
+/// Test that we can handle exceptional terminator instructions.
+///
+/// declare internal void @p()
+///
+/// declare internal void @g()
+///
+/// define internal void @f() personality i8* bitcast (void ()* @p to i8*) {
+/// entry:
+///   invoke void @g()
+///           to label %exit unwind label %catch.pad
+///
+/// catch.pad:
+///   %0 = catchswitch within none [label %catch.body] unwind to caller
+///
+/// catch.body:
+///   %1 = catchpad within %0 []
+///   catchret from %1 to label %exit
+///
+/// exit:
+///   ret void
+/// }
+///
+/// For this test, we initially mark only the entry block executable. The
+/// solver is then expected to discover that the remaining blocks of the
+/// function ("catch.pad", "catch.body", and "exit") are executable as well.
+TEST_F(SparsePropagationTest, ExceptionalTerminatorInsts) {
+  // "p", "g", and "f" are all internal functions of type "void ()".
+  FunctionType *VoidFnTy =
+      FunctionType::get(Builder.getVoidTy(), /*isVarArg=*/false);
+  auto MakeInternalFn = [&](const char *Name) {
+    return Function::Create(VoidFnTy, GlobalValue::InternalLinkage, Name, &M);
+  };
+  Function *PersonalityFn = MakeInternalFn("p");
+  Function *Callee = MakeInternalFn("g");
+  Function *Fn = MakeInternalFn("f");
+
+  // The personality routine of "f" is "p" bitcast to i8*.
+  Fn->setPersonalityFn(ConstantExpr::getCast(
+      Instruction::BitCast, PersonalityFn, Builder.getInt8PtrTy()));
+
+  BasicBlock *Entry = BasicBlock::Create(Context, "entry", Fn);
+  BasicBlock *Pad = BasicBlock::Create(Context, "catch.pad", Fn);
+  BasicBlock *Body = BasicBlock::Create(Context, "catch.body", Fn);
+  BasicBlock *Exit = BasicBlock::Create(Context, "exit", Fn);
+
+  // entry: invoke "g", continuing to "exit" and unwinding to "catch.pad".
+  Builder.SetInsertPoint(Entry);
+  Builder.CreateInvoke(Callee, Exit, Pad);
+
+  // catch.pad: a catchswitch whose single handler is "catch.body".
+  Builder.SetInsertPoint(Pad);
+  CatchSwitchInst *Switch = Builder.CreateCatchSwitch(
+      ConstantTokenNone::get(Context), /*UnwindBB=*/nullptr,
+      /*NumHandlers=*/1);
+  Switch->addHandler(Body);
+
+  // catch.body: a catchpad followed by a catchret to "exit".
+  Builder.SetInsertPoint(Body);
+  CatchPadInst *PadInst = Builder.CreateCatchPad(Switch, {});
+  Builder.CreateCatchRet(PadInst, Exit);
+
+  // exit: return.
+  Builder.SetInsertPoint(Exit);
+  Builder.CreateRetVoid();
+
+  // Only the entry block is seeded; the solver must find the other blocks.
+  Solver.MarkBlockExecutable(Entry);
+  Solver.Solve();
+
+  EXPECT_TRUE(Solver.isBlockExecutable(Pad));
+  EXPECT_TRUE(Solver.isBlockExecutable(Body));
+  EXPECT_TRUE(Solver.isBlockExecutable(Exit));
+}