Index: include/llvm/InitializePasses.h =================================================================== --- include/llvm/InitializePasses.h +++ include/llvm/InitializePasses.h @@ -117,6 +117,7 @@ void initializeEarlyCSELegacyPassPass(PassRegistry &); void initializeEarlyIfConverterPass(PassRegistry&); void initializeEdgeBundlesPass(PassRegistry&); +void initializeComprehensiveStaticInstrumentationPass(PassRegistry&); void initializeEfficiencySanitizerPass(PassRegistry&); void initializeEliminateAvailableExternallyLegacyPassPass(PassRegistry &); void initializeExpandISelPseudosPass(PassRegistry&); Index: include/llvm/Transforms/Instrumentation.h =================================================================== --- include/llvm/Transforms/Instrumentation.h +++ include/llvm/Transforms/Instrumentation.h @@ -163,6 +163,9 @@ } #endif +// Insert ComprehensiveStaticInstrumentation instrumentation +ModulePass *createComprehensiveStaticInstrumentationPass(); + // BoundsChecking - This pass instruments the code to perform run-time bounds // checking on loads, stores, and other memory intrinsics. FunctionPass *createBoundsCheckingPass(); Index: include/llvm/Transforms/Utils/ModuleUtils.h =================================================================== --- include/llvm/Transforms/Utils/ModuleUtils.h +++ include/llvm/Transforms/Utils/ModuleUtils.h @@ -40,6 +40,13 @@ void appendToGlobalDtors(Module &M, Function *F, int Priority, Constant *Data = nullptr); +// Validate the result of Module::getOrInsertFunction called for an +// interface function of ComprehensiveStaticInstrumentation. If the +// instrumented module defines a function with the same name, their +// prototypes must match, otherwise getOrInsertFunction returns a +// bitcast. +Function *checkCsiInterfaceFunction(Constant *FuncOrBitcast); + // Validate the result of Module::getOrInsertFunction called for an interface // function of given sanitizer. 
If the instrumented module defines a function // with the same name, their prototypes must match, otherwise Index: lib/Transforms/Instrumentation/CMakeLists.txt =================================================================== --- lib/Transforms/Instrumentation/CMakeLists.txt +++ lib/Transforms/Instrumentation/CMakeLists.txt @@ -1,6 +1,7 @@ add_llvm_library(LLVMInstrumentation AddressSanitizer.cpp BoundsChecking.cpp + ComprehensiveStaticInstrumentation.cpp DataFlowSanitizer.cpp GCOVProfiling.cpp MemorySanitizer.cpp Index: lib/Transforms/Instrumentation/ComprehensiveStaticInstrumentation.cpp =================================================================== --- /dev/null +++ lib/Transforms/Instrumentation/ComprehensiveStaticInstrumentation.cpp @@ -0,0 +1,821 @@ +#include "llvm/ADT/SmallSet.h" +#include "llvm/ADT/SmallString.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Analysis/CallGraph.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/DebugInfo.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Module.h" +#include "llvm/Transforms/Instrumentation.h" +#include "llvm/Transforms/Utils/ModuleUtils.h" + +using namespace llvm; + +namespace { +const char *const CsiRtUnitInitName = "__csirt_unit_init"; +const char *const CsiRtUnitCtorName = "csirt.unit_ctor"; +const char *const CsiFunctionBaseIdName = "__csi_unit_func_base_id"; +const char *const CsiFunctionExitBaseIdName = "__csi_unit_func_exit_base_id"; +const char *const CsiBasicBlockBaseIdName = "__csi_unit_bb_base_id"; +const char *const CsiCallsiteBaseIdName = "__csi_unit_callsite_base_id"; +const char *const CsiLoadBaseIdName = "__csi_unit_load_base_id"; +const char *const CsiStoreBaseIdName = "__csi_unit_store_base_id"; +const char *const CsiUnitFedTableName = "__csi_unit_fed_table"; +const char *const CsiFuncIdVariablePrefix = "__csi_func_id_"; +const char *const CsiUnitFedTableArrayName = "__csi_unit_fed_tables"; +const char *const 
CsiInitCallsiteToFunctionName = + "__csi_init_callsite_to_function"; + +const int64_t CsiCallsiteUnknownTargetId = -1; +// See llvm/tools/clang/lib/CodeGen/CodeGenModule.h: +const int CsiUnitCtorPriority = 65535; + +/// Return the first DILocation in the given basic block, or nullptr +/// if none exists. +DILocation *getFirstDebugLoc(BasicBlock &BB) { + for (Instruction &Inst : BB) { + if (DILocation *Loc = Inst.getDebugLoc()) { + return Loc; + } + } + return nullptr; +} + +/// Set DebugLoc on the call instruction to a CSI hook, based on the +/// debug information of the instrumented instruction. +void setInstrumentationDebugLoc(Instruction *Instrumented, Instruction *Call) { + DISubprogram *Subprog = Instrumented->getFunction()->getSubprogram(); + if (Subprog) { + if (Instrumented->getDebugLoc()) { + Call->setDebugLoc(Instrumented->getDebugLoc()); + } else { + LLVMContext &C = Instrumented->getFunction()->getParent()->getContext(); + Call->setDebugLoc(DILocation::get(C, 0, 0, Subprog)); + } + } +} + +/// Set DebugLoc on the call instruction to a CSI hook, based on the +/// debug information of the instrumented instruction. +void setInstrumentationDebugLoc(BasicBlock &Instrumented, Instruction *Call) { + DISubprogram *Subprog = Instrumented.getParent()->getSubprogram(); + if (Subprog) { + LLVMContext &C = Instrumented.getParent()->getParent()->getContext(); + Call->setDebugLoc(DILocation::get(C, 0, 0, Subprog)); + } +} + +/// Set DebugLoc on the call instruction to a CSI hook, based on the +/// debug information of the instrumented instruction. +void setInstrumentationDebugLoc(Function &Instrumented, Instruction *Call) { + DISubprogram *Subprog = Instrumented.getSubprogram(); + if (Subprog) { + LLVMContext &C = Instrumented.getParent()->getContext(); + Call->setDebugLoc(DILocation::get(C, 0, 0, Subprog)); + } +} + +/// Maintains a mapping from CSI ID to front-end data for that ID. 
+/// +/// The front-end data currently is the source location that a given +/// CSI ID corresponds to. +class FrontEndDataTable { +public: + FrontEndDataTable() : BaseId(nullptr), IdCounter(0) {} + FrontEndDataTable(Module &M, StringRef BaseIdName); + + /// The number of entries in this FED table + uint64_t size() const { return LocalIdToSourceLocationMap.size(); } + + /// The GlobalVariable holding the base ID for this FED table. + GlobalVariable *baseId() const { return BaseId; } + + /// Add the given Function to this FED table. + /// \returns The local ID of the Function. + uint64_t add(Function &F); + + /// Add the given BasicBlock to this FED table. + /// \returns The local ID of the BasicBlock. + uint64_t add(BasicBlock &BB); + + /// Add the given Instruction to this FED table. + /// \returns The local ID of the Instruction. + uint64_t add(Instruction &I); + + /// Get the local ID of the given Value. + uint64_t getId(Value *V); + + /// Converts a local ID to a global ID. + /// + /// This is done by using the given IRBuilder to insert a load to + /// the base ID global variable followed by an add of the base value + /// and the local ID. + /// + /// \returns A Value holding the global ID corresponding to the + /// given local ID. + Value *localToGlobalId(uint64_t LocalId, IRBuilder<> IRB) const; + + /// Get the Type for a pointer to a FED table entry. + /// + /// A FED table entry is just a source location. + static PointerType *getPointerType(LLVMContext &C); + + /// Insert this FED table into the given Module. + /// + /// The FED table is constructed as a ConstantArray indexed by local + /// IDs. The runtime is responsible for performing the mapping that + /// allows the table to be indexed by global ID. + Constant *insertIntoModule(Module &M) const; + +private: + struct SourceLocation { + int32_t Line; + StringRef File; + }; + + /// The GlobalVariable holding the base ID for this FED table.
+ GlobalVariable *BaseId; + /// Counter of local IDs used so far. + uint64_t IdCounter; + /// Map of local ID to SourceLocation. + std::map LocalIdToSourceLocationMap; + /// Map of Value to Local ID. + std::map ValueToLocalIdMap; + + /// Create a struct type to match the "struct SourceLocation" type. + static StructType *getSourceLocStructType(LLVMContext &C); + + /// Append the debug information to the table, assigning it the next + /// available ID. + /// + /// \returns The local ID of the appended information. + /// @{ + uint64_t add(DILocation *Loc); + uint64_t add(DISubprogram *Subprog); + /// @} + + /// Append the line and file information to the table, assigning it + /// the next available ID. + /// + /// \returns The new local ID of the DILocation. + uint64_t add(int32_t Line, StringRef File); +}; + +/// The Comprehensive Static Instrumentation pass. +/// Inserts calls to user-defined hooks at predefined points in the IR. +struct ComprehensiveStaticInstrumentation : public ModulePass { + static char ID; + + ComprehensiveStaticInstrumentation() : ModulePass(ID) {} + const char *getPassName() const override; + bool runOnModule(Module &M) override; + void getAnalysisUsage(AnalysisUsage &AU) const override; + +private: + /// Initialize llvm::Functions for the CSI hooks. + /// @{ + void initializeLoadStoreHooks(Module &M); + void initializeFuncHooks(Module &M); + void initializeBasicBlockHooks(Module &M); + void initializeCallsiteHooks(Module &M); + /// @} + + /// Initialize the front-end data table structures. + void initializeFEDTables(Module &M); + + /// Generate a function that stores global function IDs into a set + /// of externally-visible global variables. + void generateInitCallsiteToFunction(Module &M); + + /// Get the number of bytes accessed via the given address. + int getNumBytesAccessed(Value *Addr, const DataLayout &DL); + + /// Compute CSI properties on the given ordered list of loads and stores. 
+ void computeLoadAndStoreProperties( + SmallVectorImpl> + &LoadAndStoreProperties, + SmallVectorImpl &BBLoadsAndStores); + + /// Insert calls to the instrumentation hooks. + /// @{ + void addLoadStoreInstrumentation(Instruction *I, Function *BeforeFn, + Function *AfterFn, Value *CsiId, + Type *AddrType, Value *Addr, int NumBytes, + uint64_t Prop); + void instrumentLoadOrStore(Instruction *I, uint64_t Prop, + const DataLayout &DL); + void instrumentMemIntrinsic(Instruction *I); + void instrumentCallsite(Instruction *I); + void instrumentBasicBlock(BasicBlock &BB); + void instrumentFunction(Function &F); + /// @} + + /// Return true if the given function should not be instrumented. + bool shouldNotInstrumentFunction(Function &F); + + /// Initialize the CSI pass. + void initializeCsi(Module &M); + /// Finalize the CSI pass. + void finalizeCsi(Module &M); + + FrontEndDataTable FunctionFED, FunctionExitFED, BasicBlockFED, CallsiteFED, + LoadFED, StoreFED; + + Function *CsiBeforeCallsite, *CsiAfterCallsite; + Function *CsiFuncEntry, *CsiFuncExit; + Function *CsiBBEntry, *CsiBBExit; + Function *CsiBeforeRead, *CsiAfterRead; + Function *CsiBeforeWrite, *CsiAfterWrite; + + CallGraph *CG; + Function *MemmoveFn, *MemcpyFn, *MemsetFn; + Function *InitCallsiteToFunction; + Type *IntptrTy; + std::map FuncOffsetMap; +}; // struct ComprehensiveStaticInstrumentation +} // anonymous namespace + +char ComprehensiveStaticInstrumentation::ID = 0; + +INITIALIZE_PASS(ComprehensiveStaticInstrumentation, "csi", + "ComprehensiveStaticInstrumentation pass", false, false) + +const char *ComprehensiveStaticInstrumentation::getPassName() const { + return "ComprehensiveStaticInstrumentation"; +} + +ModulePass *llvm::createComprehensiveStaticInstrumentationPass() { + return new ComprehensiveStaticInstrumentation(); +} + +FrontEndDataTable::FrontEndDataTable(Module &M, StringRef BaseIdName) { + LLVMContext &C = M.getContext(); + IntegerType *Int64Ty = IntegerType::get(C, 64); + IdCounter = 0; + 
BaseId = new GlobalVariable(M, Int64Ty, false, GlobalValue::InternalLinkage, + ConstantInt::get(Int64Ty, 0), BaseIdName); + assert(BaseId); +} + +uint64_t FrontEndDataTable::add(Function &F) { + uint64_t Id = add(F.getSubprogram()); + ValueToLocalIdMap[&F] = Id; + return Id; +} + +uint64_t FrontEndDataTable::add(BasicBlock &BB) { + uint64_t Id = add(getFirstDebugLoc(BB)); + ValueToLocalIdMap[&BB] = Id; + return Id; +} + +uint64_t FrontEndDataTable::add(Instruction &I) { + uint64_t Id = add(I.getDebugLoc()); + ValueToLocalIdMap[&I] = Id; + return Id; +} + +uint64_t FrontEndDataTable::getId(Value *V) { + assert(ValueToLocalIdMap.find(V) != ValueToLocalIdMap.end() && + "Value not in ID map."); + return ValueToLocalIdMap[V]; +} + +Value *FrontEndDataTable::localToGlobalId(uint64_t LocalId, + IRBuilder<> IRB) const { + assert(BaseId); + Value *Base = IRB.CreateLoad(BaseId); + Value *Offset = IRB.getInt64(LocalId); + return IRB.CreateAdd(Base, Offset); +} + +PointerType *FrontEndDataTable::getPointerType(LLVMContext &C) { + return PointerType::get(getSourceLocStructType(C), 0); +} + +StructType *FrontEndDataTable::getSourceLocStructType(LLVMContext &C) { + return StructType::get(IntegerType::get(C, 32), + PointerType::get(IntegerType::get(C, 8), 0), nullptr); +} + +uint64_t FrontEndDataTable::add(DILocation *Loc) { + if (Loc) { + return add((int32_t)Loc->getLine(), Loc->getFilename()); + } else { + return add(-1, ""); + } +} + +uint64_t FrontEndDataTable::add(DISubprogram *Subprog) { + if (Subprog) { + return add((int32_t)Subprog->getLine(), Subprog->getFilename()); + } else { + return add(-1, ""); + } +} + +uint64_t FrontEndDataTable::add(int32_t Line, StringRef File) { + uint64_t Id = IdCounter++; + assert(LocalIdToSourceLocationMap.find(Id) == + LocalIdToSourceLocationMap.end() && + "Id already exists in FED table."); + LocalIdToSourceLocationMap[Id] = {Line, File}; + return Id; +} + +Constant *FrontEndDataTable::insertIntoModule(Module &M) const { + LLVMContext &C = 
M.getContext(); + StructType *FedType = getSourceLocStructType(C); + IntegerType *Int32Ty = IntegerType::get(C, 32); + Constant *Zero = ConstantInt::get(Int32Ty, 0); + Value *GepArgs[] = {Zero, Zero}; + SmallVector FEDEntries; + + for (const auto it : LocalIdToSourceLocationMap) { + const SourceLocation &E = it.second; + Value *Line = ConstantInt::get(Int32Ty, E.Line); + Constant *FileStrConstant = ConstantDataArray::getString(C, E.File); + GlobalVariable *GV = M.getGlobalVariable("__csi_unit_filename", true); + if (GV == NULL) { + GV = new GlobalVariable(M, FileStrConstant->getType(), + true, GlobalValue::PrivateLinkage, + FileStrConstant, "__csi_unit_filename", nullptr, GlobalVariable::NotThreadLocal, 0); + GV->setUnnamedAddr(true); + } + assert(GV); + Constant *File = + ConstantExpr::getGetElementPtr(GV->getValueType(), GV, GepArgs); + + FEDEntries.push_back(ConstantStruct::get(FedType, Line, File, nullptr)); + } + + ArrayType *FedArrayType = ArrayType::get(FedType, FEDEntries.size()); + Constant *Table = ConstantArray::get(FedArrayType, FEDEntries); + GlobalVariable *GV = + new GlobalVariable(M, FedArrayType, false, GlobalValue::InternalLinkage, + Table, CsiUnitFedTableName); + return ConstantExpr::getGetElementPtr(GV->getValueType(), GV, GepArgs); +} + +void ComprehensiveStaticInstrumentation::initializeFuncHooks(Module &M) { + IRBuilder<> IRB(M.getContext()); + CsiFuncEntry = checkCsiInterfaceFunction(M.getOrInsertFunction( + "__csi_func_entry", IRB.getVoidTy(), IRB.getInt64Ty(), nullptr)); + CsiFuncExit = checkCsiInterfaceFunction( + M.getOrInsertFunction("__csi_func_exit", IRB.getVoidTy(), + IRB.getInt64Ty(), IRB.getInt64Ty(), nullptr)); +} + +void ComprehensiveStaticInstrumentation::initializeBasicBlockHooks(Module &M) { + IRBuilder<> IRB(M.getContext()); + CsiBBEntry = checkCsiInterfaceFunction(M.getOrInsertFunction( + "__csi_bb_entry", IRB.getVoidTy(), IRB.getInt64Ty(), nullptr)); + CsiBBExit = checkCsiInterfaceFunction(M.getOrInsertFunction( + 
"__csi_bb_exit", IRB.getVoidTy(), IRB.getInt64Ty(), nullptr)); +} + +void ComprehensiveStaticInstrumentation::initializeCallsiteHooks(Module &M) { + IRBuilder<> IRB(M.getContext()); + CsiBeforeCallsite = checkCsiInterfaceFunction( + M.getOrInsertFunction("__csi_before_call", IRB.getVoidTy(), + IRB.getInt64Ty(), IRB.getInt64Ty(), nullptr)); + CsiAfterCallsite = checkCsiInterfaceFunction( + M.getOrInsertFunction("__csi_after_call", IRB.getVoidTy(), + IRB.getInt64Ty(), IRB.getInt64Ty(), nullptr)); +} + +void ComprehensiveStaticInstrumentation::initializeLoadStoreHooks(Module &M) { + + IRBuilder<> IRB(M.getContext()); + Type *RetType = IRB.getVoidTy(); + Type *AddrType = IRB.getInt8PtrTy(); + Type *NumBytesType = IRB.getInt32Ty(); + + CsiBeforeRead = checkCsiInterfaceFunction( + M.getOrInsertFunction("__csi_before_load", RetType, IRB.getInt64Ty(), + AddrType, NumBytesType, IRB.getInt64Ty(), nullptr)); + CsiAfterRead = checkCsiInterfaceFunction( + M.getOrInsertFunction("__csi_after_load", RetType, IRB.getInt64Ty(), + AddrType, NumBytesType, IRB.getInt64Ty(), nullptr)); + + CsiBeforeWrite = checkCsiInterfaceFunction( + M.getOrInsertFunction("__csi_before_store", RetType, IRB.getInt64Ty(), + AddrType, NumBytesType, IRB.getInt64Ty(), nullptr)); + CsiAfterWrite = checkCsiInterfaceFunction( + M.getOrInsertFunction("__csi_after_store", RetType, IRB.getInt64Ty(), + AddrType, NumBytesType, IRB.getInt64Ty(), nullptr)); + + MemmoveFn = checkCsiInterfaceFunction( + M.getOrInsertFunction("memmove", IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), + IRB.getInt8PtrTy(), IntptrTy, nullptr)); + MemcpyFn = checkCsiInterfaceFunction( + M.getOrInsertFunction("memcpy", IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), + IRB.getInt8PtrTy(), IntptrTy, nullptr)); + MemsetFn = checkCsiInterfaceFunction( + M.getOrInsertFunction("memset", IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), + IRB.getInt32Ty(), IntptrTy, nullptr)); +} + +int ComprehensiveStaticInstrumentation::getNumBytesAccessed( + Value *Addr, const 
DataLayout &DL) { + Type *OrigPtrTy = Addr->getType(); + Type *OrigTy = cast(OrigPtrTy)->getElementType(); + assert(OrigTy->isSized()); + uint32_t TypeSize = DL.getTypeStoreSizeInBits(OrigTy); + if (TypeSize != 8 && TypeSize != 16 && TypeSize != 32 && TypeSize != 64 && + TypeSize != 128) { + return -1; + } + return TypeSize / 8; +} + +void ComprehensiveStaticInstrumentation::addLoadStoreInstrumentation( + Instruction *I, Function *BeforeFn, Function *AfterFn, Value *CsiId, + Type *AddrType, Value *Addr, int NumBytes, uint64_t Prop) { + IRBuilder<> IRB(I); + Instruction *Call = IRB.CreateCall(BeforeFn, {CsiId, IRB.CreatePointerCast(Addr, AddrType), + IRB.getInt32(NumBytes), IRB.getInt64(Prop)}); + setInstrumentationDebugLoc(I, Call); + + BasicBlock::iterator Iter(I); + Iter++; + IRB.SetInsertPoint(&*Iter); + + Call = IRB.CreateCall(AfterFn, {CsiId, IRB.CreatePointerCast(Addr, AddrType), + IRB.getInt32(NumBytes), IRB.getInt64(Prop)}); + setInstrumentationDebugLoc(I, Call); +} + +void ComprehensiveStaticInstrumentation::instrumentLoadOrStore( + Instruction *I, uint64_t Prop, const DataLayout &DL) { + IRBuilder<> IRB(I); + bool IsWrite = isa(I); + Value *Addr = IsWrite ? cast(I)->getPointerOperand() + : cast(I)->getPointerOperand(); + int NumBytes = getNumBytesAccessed(Addr, DL); + Type *AddrType = IRB.getInt8PtrTy(); + + if (NumBytes == -1) + return; // size that we don't recognize + + if (IsWrite) { + uint64_t LocalId = StoreFED.add(*I); + Value *CsiId = StoreFED.localToGlobalId(LocalId, IRB); + addLoadStoreInstrumentation(I, CsiBeforeWrite, CsiAfterWrite, CsiId, + AddrType, Addr, NumBytes, Prop); + } else { // is read + uint64_t LocalId = LoadFED.add(*I); + Value *CsiId = LoadFED.localToGlobalId(LocalId, IRB); + addLoadStoreInstrumentation(I, CsiBeforeRead, CsiAfterRead, CsiId, AddrType, + Addr, NumBytes, Prop); + } +} + +// If a memset intrinsic gets inlined by the code gen, we will miss races on it. 
+// So, we either need to ensure the intrinsic is not inlined, or instrument it. +// We do not instrument memset/memmove/memcpy intrinsics (too complicated), +// instead we simply replace them with regular function calls, which are then +// intercepted by the run-time. +// Since our pass runs after everyone else, the calls should not be +// replaced back with intrinsics. If that becomes wrong at some point, +// we will need to call e.g. __csi_memset to avoid the intrinsics. +void ComprehensiveStaticInstrumentation::instrumentMemIntrinsic( + Instruction *I) { + IRBuilder<> IRB(I); + if (MemSetInst *M = dyn_cast(I)) { + Instruction *Call = IRB.CreateCall( + MemsetFn, + {IRB.CreatePointerCast(M->getArgOperand(0), IRB.getInt8PtrTy()), + IRB.CreateIntCast(M->getArgOperand(1), IRB.getInt32Ty(), false), + IRB.CreateIntCast(M->getArgOperand(2), IntptrTy, false)}); + setInstrumentationDebugLoc(I, Call); + I->eraseFromParent(); + } else if (MemTransferInst *M = dyn_cast(I)) { + Instruction *Call = IRB.CreateCall( + isa(M) ? 
MemcpyFn : MemmoveFn, + {IRB.CreatePointerCast(M->getArgOperand(0), IRB.getInt8PtrTy()), + IRB.CreatePointerCast(M->getArgOperand(1), IRB.getInt8PtrTy()), + IRB.CreateIntCast(M->getArgOperand(2), IntptrTy, false)}); + setInstrumentationDebugLoc(I, Call); + I->eraseFromParent(); + } +} + +void ComprehensiveStaticInstrumentation::instrumentBasicBlock(BasicBlock &BB) { + IRBuilder<> IRB(&*BB.getFirstInsertionPt()); + uint64_t LocalId = BasicBlockFED.add(BB); + Value *CsiId = BasicBlockFED.localToGlobalId(LocalId, IRB); + + Instruction *Call = IRB.CreateCall(CsiBBEntry, {CsiId}); + setInstrumentationDebugLoc(BB, Call); + + TerminatorInst *TI = BB.getTerminator(); + IRB.SetInsertPoint(TI); + Call = IRB.CreateCall(CsiBBExit, {CsiId}); + setInstrumentationDebugLoc(BB, Call); +} + +void ComprehensiveStaticInstrumentation::instrumentCallsite(Instruction *I) { + IRBuilder<> IRB(I); + CallSite CS(I); + Instruction *Inst = CS.getInstruction(); + Module *M = Inst->getParent()->getParent()->getParent(); + Function *Called = CS.getCalledFunction(); + + if (Called && Called->getName().startswith("llvm.dbg")) { + return; + } + + uint64_t LocalId = CallsiteFED.add(*Inst); + Value *CallsiteId = CallsiteFED.localToGlobalId(LocalId, IRB); + + std::string GVName = CsiFuncIdVariablePrefix + Called->getName().str(); + GlobalVariable *FuncIdGV = + dyn_cast(M->getOrInsertGlobal(GVName, IRB.getInt64Ty())); + assert(FuncIdGV); + FuncIdGV->setConstant(false); + FuncIdGV->setLinkage(GlobalValue::WeakAnyLinkage); + FuncIdGV->setInitializer(IRB.getInt64(CsiCallsiteUnknownTargetId)); + + Value *FuncId = IRB.CreateLoad(FuncIdGV); + Instruction *Call = IRB.CreateCall(CsiBeforeCallsite, {CallsiteId, FuncId}); + setInstrumentationDebugLoc(I, Call); + + BasicBlock::iterator Iter(I); + Iter++; + IRB.SetInsertPoint(&*Iter); + Call = IRB.CreateCall(CsiAfterCallsite, {CallsiteId, FuncId}); + setInstrumentationDebugLoc(I, Call); +} + +void ComprehensiveStaticInstrumentation::initializeFEDTables(Module &M) { 
+ FunctionFED = FrontEndDataTable(M, CsiFunctionBaseIdName); + FunctionExitFED = FrontEndDataTable(M, CsiFunctionExitBaseIdName); + BasicBlockFED = FrontEndDataTable(M, CsiBasicBlockBaseIdName); + CallsiteFED = FrontEndDataTable(M, CsiCallsiteBaseIdName); + LoadFED = FrontEndDataTable(M, CsiLoadBaseIdName); + StoreFED = FrontEndDataTable(M, CsiStoreBaseIdName); +} + +void ComprehensiveStaticInstrumentation::generateInitCallsiteToFunction( + Module &M) { + LLVMContext &C = M.getContext(); + BasicBlock *EntryBB = BasicBlock::Create(C, "", InitCallsiteToFunction); + IRBuilder<> IRB(ReturnInst::Create(C, EntryBB)); + + GlobalVariable *Base = FunctionFED.baseId(); + LoadInst *LI = IRB.CreateLoad(Base); + // Traverse the map of function name -> function local id. Generate + // a store of each function's global ID to the corresponding weak + // global variable. + for (const auto &it : FuncOffsetMap) { + std::string GVName = CsiFuncIdVariablePrefix + it.first; + GlobalVariable *GV = nullptr; + if ((GV = M.getGlobalVariable(GVName)) == nullptr) { + GV = new GlobalVariable(M, IRB.getInt64Ty(), false, + GlobalValue::WeakAnyLinkage, + IRB.getInt64(CsiCallsiteUnknownTargetId), GVName); + } + assert(GV); + IRB.CreateStore(IRB.CreateAdd(LI, IRB.getInt64(it.second)), GV); + } +} + +void ComprehensiveStaticInstrumentation::initializeCsi(Module &M) { + initializeFEDTables(M); + initializeFuncHooks(M); + initializeLoadStoreHooks(M); + initializeBasicBlockHooks(M); + initializeCallsiteHooks(M); + + FunctionType *FnType = + FunctionType::get(Type::getVoidTy(M.getContext()), {}, false); + InitCallsiteToFunction = checkCsiInterfaceFunction( + M.getOrInsertFunction(CsiInitCallsiteToFunctionName, FnType)); + assert(InitCallsiteToFunction); + InitCallsiteToFunction->setLinkage(GlobalValue::InternalLinkage); + + CG = &getAnalysis().getCallGraph(); + IntptrTy = M.getDataLayout().getIntPtrType(M.getContext()); +} + +// Create a struct type to match the unit_fed_entry_t type in csirt.c. 
+StructType *getUnitFedTableType(LLVMContext &C, PointerType *EntryPointerType) { + return StructType::get(IntegerType::get(C, 64), + PointerType::get(IntegerType::get(C, 64), 0), + EntryPointerType, nullptr); +} + +Constant *fedTableToUnitFedTable(Module &M, StructType *UnitFedTableType, + FrontEndDataTable &FedTable) { + Constant *NumEntries = + ConstantInt::get(IntegerType::get(M.getContext(), 64), FedTable.size()); + Constant *InsertedTable = FedTable.insertIntoModule(M); + return ConstantStruct::get(UnitFedTableType, NumEntries, FedTable.baseId(), + InsertedTable, nullptr); +} + +void ComprehensiveStaticInstrumentation::finalizeCsi(Module &M) { + LLVMContext &C = M.getContext(); + + // Add CSI global constructor, which calls unit init. + Function *Ctor = + Function::Create(FunctionType::get(Type::getVoidTy(C), false), + GlobalValue::InternalLinkage, CsiRtUnitCtorName, &M); + BasicBlock *CtorBB = BasicBlock::Create(C, "", Ctor); + IRBuilder<> IRB(ReturnInst::Create(C, CtorBB)); + + StructType *UnitFedTableType = + getUnitFedTableType(C, FrontEndDataTable::getPointerType(C)); + + // Lookup __csirt_unit_init + SmallVector InitArgTypes({IRB.getInt8PtrTy(), + PointerType::get(UnitFedTableType, 0), + InitCallsiteToFunction->getType()}); + FunctionType *InitFunctionTy = + FunctionType::get(IRB.getVoidTy(), InitArgTypes, false); + Function *InitFunction = checkCsiInterfaceFunction( + M.getOrInsertFunction(CsiRtUnitInitName, InitFunctionTy)); + assert(InitFunction); + + // Insert __csi_func_id_ weak symbols for all defined functions + // and generate the runtime code that stores to all of them. 
+ generateInitCallsiteToFunction(M); + + SmallVector UnitFedTables({ + fedTableToUnitFedTable(M, UnitFedTableType, BasicBlockFED), + fedTableToUnitFedTable(M, UnitFedTableType, FunctionFED), + fedTableToUnitFedTable(M, UnitFedTableType, FunctionExitFED), + fedTableToUnitFedTable(M, UnitFedTableType, CallsiteFED), + fedTableToUnitFedTable(M, UnitFedTableType, LoadFED), + fedTableToUnitFedTable(M, UnitFedTableType, StoreFED), + }); + + ArrayType *UnitFedTableArrayType = + ArrayType::get(UnitFedTableType, UnitFedTables.size()); + Constant *Table = ConstantArray::get(UnitFedTableArrayType, UnitFedTables); + GlobalVariable *GV = new GlobalVariable(M, UnitFedTableArrayType, false, + GlobalValue::InternalLinkage, Table, + CsiUnitFedTableArrayName); + + Constant *Zero = ConstantInt::get(IRB.getInt32Ty(), 0); + Value *GepArgs[] = {Zero, Zero}; + + // Insert call to __csirt_unit_init + CallInst *Call = IRB.CreateCall( + InitFunction, + {IRB.CreateGlobalStringPtr(M.getName()), + ConstantExpr::getGetElementPtr(GV->getValueType(), GV, GepArgs), + InitCallsiteToFunction}); + + // Add the constructor to the global list + appendToGlobalCtors(M, Ctor, CsiUnitCtorPriority); + + CallGraphNode *CNCtor = CG->getOrInsertFunction(Ctor); + CallGraphNode *CNFunc = CG->getOrInsertFunction(InitFunction); + CNCtor->addCalledFunction(Call, CNFunc); +} + +void ComprehensiveStaticInstrumentation::getAnalysisUsage( + AnalysisUsage &AU) const { + AU.addRequired(); +} + +bool ComprehensiveStaticInstrumentation::shouldNotInstrumentFunction( + Function &F) { + Module &M = *F.getParent(); + // Never instrument the CSI ctor. + if (F.hasName() && F.getName() == CsiRtUnitCtorName) { + return true; + } + // Don't instrument functions that will run before or + // simultaneously with CSI ctors. 
+ GlobalVariable *GV = M.getGlobalVariable("llvm.global_ctors"); + if (GV == nullptr) + return false; + ConstantArray *CA = cast(GV->getInitializer()); + for (Use &OP : CA->operands()) { + if (isa(OP)) + continue; + ConstantStruct *CS = cast(OP); + + if (Function *CF = dyn_cast(CS->getOperand(1))) { + uint64_t Priority = + dyn_cast(CS->getOperand(0))->getLimitedValue(); + if (Priority <= CsiUnitCtorPriority && CF->getName() == F.getName()) { + // Do not instrument F. + return true; + } + } + } + // false means do instrument it. + return false; +} + +void ComprehensiveStaticInstrumentation::computeLoadAndStoreProperties( + SmallVectorImpl> &LoadAndStoreProperties, + SmallVectorImpl &BBLoadsAndStores) { + SmallSet WriteTargets; + + for (SmallVectorImpl::reverse_iterator + It = BBLoadsAndStores.rbegin(), + E = BBLoadsAndStores.rend(); + It != E; ++It) { + Instruction *I = *It; + if (StoreInst *Store = dyn_cast(I)) { + WriteTargets.insert(Store->getPointerOperand()); + LoadAndStoreProperties.push_back(std::make_pair(I, 0)); + } else { + LoadInst *Load = cast(I); + Value *Addr = Load->getPointerOperand(); + bool HasBeenSeen = WriteTargets.count(Addr) > 0; + uint64_t Prop = HasBeenSeen ? 1 : 0; + LoadAndStoreProperties.push_back(std::make_pair(I, Prop)); + } + } + BBLoadsAndStores.clear(); +} + +bool ComprehensiveStaticInstrumentation::runOnModule(Module &M) { + initializeCsi(M); + + for (Function &F : M) { + instrumentFunction(F); + } + + finalizeCsi(M); + return true; // We always insert the unit constructor. +} + +void ComprehensiveStaticInstrumentation::instrumentFunction(Function &F) { + // This is required to prevent instrumenting the call to + // __csirt_unit_init from within the unit constructor.
+ if (F.empty() || shouldNotInstrumentFunction(F)) { + return; + } + + SmallVector, 8> LoadAndStoreProperties; + SmallVector ReturnInstructions; + SmallVector MemIntrinsics; + SmallVector Callsites; + const DataLayout &DL = F.getParent()->getDataLayout(); + + // Traverse all instructions in a function and insert instrumentation + // on load & store + for (BasicBlock &BB : F) { + SmallVector BBLoadsAndStores; + for (Instruction &I : BB) { + if (isa(I) || isa(I)) { + BBLoadsAndStores.push_back(&I); + } else if (isa(I)) { + ReturnInstructions.push_back(&I); + } else if (isa(I) || isa(I)) { + Callsites.push_back(&I); + if (isa(I)) { + MemIntrinsics.push_back(&I); + } + } + } + computeLoadAndStoreProperties(LoadAndStoreProperties, BBLoadsAndStores); + } + + // Do this work in a separate loop after copying the iterators so that we + // aren't modifying the list as we're iterating. + for (std::pair p : LoadAndStoreProperties) { + instrumentLoadOrStore(p.first, p.second, DL); + } + + for (Instruction *I : MemIntrinsics) { + instrumentMemIntrinsic(I); + } + + for (Instruction *I : Callsites) { + instrumentCallsite(I); + } + + // Instrument basic blocks + // Note that we do this before function entry so that we put this at the + // beginning of the basic block, and then the function entry call goes before + // the call to basic block entry. + uint64_t LocalId = FunctionFED.add(F); + FuncOffsetMap[F.getName()] = LocalId; + for (BasicBlock &BB : F) { + instrumentBasicBlock(BB); + } + + // Instrument function entry/exit points. 
+ IRBuilder<> IRB(&*F.getEntryBlock().getFirstInsertionPt()); + + Value *FuncId = FunctionFED.localToGlobalId(LocalId, IRB); + Instruction *Call = IRB.CreateCall(CsiFuncEntry, {FuncId}); + setInstrumentationDebugLoc(F, Call); + + for (Instruction *I : ReturnInstructions) { + IRBuilder<> IRBRet(I); + uint64_t ExitLocalId = FunctionExitFED.add(F); + Value *ExitCsiId = FunctionExitFED.localToGlobalId(ExitLocalId, IRBRet); + Call = IRBRet.CreateCall(CsiFuncExit, {ExitCsiId, FuncId}); + setInstrumentationDebugLoc(F, Call); + } +} Index: lib/Transforms/Instrumentation/Instrumentation.cpp =================================================================== --- lib/Transforms/Instrumentation/Instrumentation.cpp +++ lib/Transforms/Instrumentation/Instrumentation.cpp @@ -68,6 +68,7 @@ initializeThreadSanitizerPass(Registry); initializeSanitizerCoverageModulePass(Registry); initializeDataFlowSanitizerPass(Registry); + initializeComprehensiveStaticInstrumentationPass(Registry); initializeEfficiencySanitizerPass(Registry); } Index: lib/Transforms/Utils/ModuleUtils.cpp =================================================================== --- lib/Transforms/Utils/ModuleUtils.cpp +++ lib/Transforms/Utils/ModuleUtils.cpp @@ -89,6 +89,24 @@ appendToGlobalArray("llvm.global_dtors", M, F, Priority, Data); } +Function *llvm::checkCsiInterfaceFunction(Constant *FuncOrBitcast) { + if (Function *F = dyn_cast(FuncOrBitcast)) { + return F; + } + if (ConstantExpr *CE = dyn_cast(FuncOrBitcast)) { + if (CE->isCast() && CE->getOpcode() == Instruction::BitCast) { + if (Function *F = dyn_cast(CE->getOperand(0))) { + return F; + } + } + } + FuncOrBitcast->dump(); + std::string Err; + raw_string_ostream Stream(Err); + Stream << "ComprehensiveStaticInstrumentation interface function redefined: " << *FuncOrBitcast; + report_fatal_error(Err); +} + Function *llvm::checkSanitizerInterfaceFunction(Constant *FuncOrBitcast) { if (isa(FuncOrBitcast)) return cast(FuncOrBitcast); Index: 
test/Instrumentation/ComprehensiveStaticInstrumentation/basicblock_entry_exit.ll =================================================================== --- /dev/null +++ test/Instrumentation/ComprehensiveStaticInstrumentation/basicblock_entry_exit.ll @@ -0,0 +1,44 @@ +; Test CSI basic block entry/exit instrumentation. +; +; RUN: opt < %s -csi -S | FileCheck %s + +; CHECK: @__csi_unit_bb_base_id = internal global i64 0 + +define i32 @main() #0 { +entry: + %retval = alloca i32, align 4 + store i32 0, i32* %retval, align 4 + %call = call i32 @foo() + ret i32 %call +} + +define internal i32 @foo() #0 { +entry: + ret i32 1 +} + +; CHECK: define i32 @main() +; CHECK-NEXT: entry: +; CHECK: %2 = load i64, i64* @__csi_unit_bb_base_id +; CHECK-NEXT: %3 = add i64 %2, 0 +; CHECK-NEXT: call void @__csi_bb_entry(i64 %3) +; CHECK: %retval = alloca i32, align 4 +; CHECK: store i32 0, i32* %retval, align 4 +; CHECK: %call = call i32 @foo() +; CHECK: call void @__csi_bb_exit(i64 %3) +; CHECK: ret i32 %call + +; CHECK: define internal i32 @foo() +; CHECK-NEXT: entry: +; CHECK: %2 = load i64, i64* @__csi_unit_bb_base_id +; CHECK: %3 = add i64 %2, 1 +; CHECK-NEXT: call void @__csi_bb_entry(i64 %3) +; CHECK-NEXT: call void @__csi_bb_exit(i64 %3) +; CHECK: ret i32 1 + + +;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; +; Top-level: + +; CHECK: declare void @__csi_bb_entry(i64) +; CHECK: declare void @__csi_bb_exit(i64) Index: test/Instrumentation/ComprehensiveStaticInstrumentation/func_entry_exit.ll =================================================================== --- /dev/null +++ test/Instrumentation/ComprehensiveStaticInstrumentation/func_entry_exit.ll @@ -0,0 +1,49 @@ +; Test CSI function entry/exit instrumentation.
+; +; RUN: opt < %s -csi -S | FileCheck %s + +; CHECK: @__csi_unit_func_base_id = internal global i64 0 +; CHECK: @__csi_unit_func_exit_base_id = internal global i64 0 + +define i32 @main() #0 { +entry: + %retval = alloca i32, align 4 + store i32 0, i32* %retval, align 4 + %call = call i32 @foo() + ret i32 %call +} + +define internal i32 @foo() #0 { +entry: + ret i32 1 +} + +; CHECK: define i32 @main() +; CHECK-NEXT: entry: +; CHECK-NEXT: %0 = load i64, i64* @__csi_unit_func_base_id +; CHECK-NEXT: %1 = add i64 %0, 0 +; CHECK-NEXT: call void @__csi_func_entry(i64 %1) +; CHECK: %retval = alloca i32, align 4 +; CHECK: store i32 0, i32* %retval, align 4 +; CHECK: %call = call i32 @foo() +; CHECK: %11 = load i64, i64* @__csi_unit_func_exit_base_id +; CHECK-NEXT: %12 = add i64 %11, 0 +; CHECK-NEXT: call void @__csi_func_exit(i64 %12, i64 %1) +; CHECK-NEXT: ret i32 %call + +; CHECK: define internal i32 @foo() +; CHECK-NEXT: entry: +; CHECK-NEXT: %0 = load i64, i64* @__csi_unit_func_base_id +; CHECK-NEXT: %1 = add i64 %0, 1 +; CHECK-NEXT: call void @__csi_func_entry(i64 %1) +; CHECK: %4 = load i64, i64* @__csi_unit_func_exit_base_id +; CHECK-NEXT: %5 = add i64 %4, 1 +; CHECK-NEXT: call void @__csi_func_exit(i64 %5, i64 %1) +; CHECK-NEXT: ret i32 1 + + +;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; +; Top-level: + +; CHECK: declare void @__csi_func_entry(i64) +; CHECK: declare void @__csi_func_exit(i64, i64) Index: test/Instrumentation/ComprehensiveStaticInstrumentation/interface.ll =================================================================== --- /dev/null +++ test/Instrumentation/ComprehensiveStaticInstrumentation/interface.ll @@ -0,0 +1,44 @@ +; Test CSI interface declarations. 
+; +; RUN: opt < %s -csi -S | FileCheck %s + +define i32 @main() #0 { +entry: + %retval = alloca i32, align 4 + store i32 0, i32* %retval, align 4 + %call = call i32 @foo() + ret i32 %call +} + +define internal i32 @foo() #0 { +entry: + ret i32 1 +} + +; CHECK: @__csi_unit_func_base_id = internal global i64 0 +; CHECK: @__csi_unit_func_exit_base_id = internal global i64 0 +; CHECK: @__csi_unit_bb_base_id = internal global i64 0 +; CHECK: @__csi_unit_callsite_base_id = internal global i64 0 +; CHECK: @__csi_unit_load_base_id = internal global i64 0 +; CHECK: @__csi_unit_store_base_id = internal global i64 0 +; CHECK: @__csi_func_id_foo = weak global i64 -1 +; CHECK: @__csi_func_id_main = weak global i64 -1 +; CHECK: @__csi_unit_filename = private unnamed_addr constant [1 x i8] zeroinitializer +; CHECK: @0 = private unnamed_addr constant [8 x i8] c"\00" +; CHECK: @llvm.global_ctors = appending global [1 x { i32, void ()*, i8* }] [{ i32, void ()*, i8* } { i32 65535, void ()* @csirt.unit_ctor, i8* null }] + +; CHECK: declare void @__csi_func_entry(i64) +; CHECK: declare void @__csi_func_exit(i64, i64) +; CHECK: declare void @__csi_before_load(i64, i8*, i32, i64) +; CHECK: declare void @__csi_after_load(i64, i8*, i32, i64) +; CHECK: declare void @__csi_before_store(i64, i8*, i32, i64) +; CHECK: declare void @__csi_after_store(i64, i8*, i32, i64) +; CHECK: declare void @__csi_bb_entry(i64) +; CHECK: declare void @__csi_bb_exit(i64) +; CHECK: declare void @__csi_before_call(i64, i64) +; CHECK: declare void @__csi_after_call(i64, i64) +; CHECK: define internal void @__csi_init_callsite_to_function() +; CHECK: define internal void @csirt.unit_ctor() +; CHECK-NEXT: call void @__csirt_unit_init(i8* getelementptr inbounds ([8 x i8], [8 x i8]* @0, i32 0, i32 0), { i64, i64*, { i32, i8* }* }* getelementptr inbounds ([6 x { i64, i64*, { i32, i8* }* }], [6 x { i64, i64*, { i32, i8* }* }]* @__csi_unit_fed_tables, i32 0, i32 0), void ()* @__csi_init_callsite_to_function) +; 
CHECK-NEXT: ret void +; CHECK: declare void @__csirt_unit_init(i8*, { i64, i64*, { i32, i8* }* }*, void ()*)