Index: llvm/trunk/include/llvm/InitializePasses.h =================================================================== --- llvm/trunk/include/llvm/InitializePasses.h +++ llvm/trunk/include/llvm/InitializePasses.h @@ -319,7 +319,7 @@ void initializeSCEVAAWrapperPassPass(PassRegistry&); void initializeSLPVectorizerPass(PassRegistry&); void initializeSROALegacyPassPass(PassRegistry&); -void initializeSafeStackPass(PassRegistry&); +void initializeSafeStackLegacyPassPass(PassRegistry&); void initializeSampleProfileLoaderLegacyPassPass(PassRegistry&); void initializeSanitizerCoverageModulePass(PassRegistry&); void initializeScalarEvolutionWrapperPassPass(PassRegistry&); Index: llvm/trunk/lib/CodeGen/CodeGen.cpp =================================================================== --- llvm/trunk/lib/CodeGen/CodeGen.cpp +++ llvm/trunk/lib/CodeGen/CodeGen.cpp @@ -79,7 +79,7 @@ initializeRAGreedyPass(Registry); initializeRegisterCoalescerPass(Registry); initializeRenameIndependentSubregsPass(Registry); - initializeSafeStackPass(Registry); + initializeSafeStackLegacyPassPass(Registry); initializeShrinkWrapPass(Registry); initializeSlotIndexesPass(Registry); initializeStackColoringPass(Registry); Index: llvm/trunk/lib/CodeGen/SafeStack.cpp =================================================================== --- llvm/trunk/lib/CodeGen/SafeStack.cpp +++ llvm/trunk/lib/CodeGen/SafeStack.cpp @@ -19,6 +19,7 @@ #include "SafeStackLayout.h" #include "llvm/ADT/Statistic.h" #include "llvm/ADT/Triple.h" +#include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/BranchProbabilityInfo.h" #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/ScalarEvolutionExpressions.h" @@ -92,11 +93,11 @@ /// determined statically), and the unsafe stack, which contains all /// local variables that are accessed in ways that we can't prove to /// be safe. -class SafeStack : public FunctionPass { - const TargetMachine *TM; - const TargetLoweringBase *TL; - const DataLayout *DL; - ScalarEvolution *SE; +class SafeStack { + Function &F; + const TargetLoweringBase &TL; + const DataLayout &DL; + ScalarEvolution &SE; Type *StackPtrTy; Type *IntPtrTy; @@ -171,33 +172,21 @@ uint64_t AllocaSize); public: - static char ID; // Pass identification, replacement for typeid. - SafeStack(const TargetMachine *TM) - : FunctionPass(ID), TM(TM), TL(nullptr), DL(nullptr) { - initializeSafeStackPass(*PassRegistry::getPassRegistry()); - } - SafeStack() : SafeStack(nullptr) {} - - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.addRequired(); - } - - bool doInitialization(Module &M) override { - DL = &M.getDataLayout(); - - StackPtrTy = Type::getInt8PtrTy(M.getContext()); - IntPtrTy = DL->getIntPtrType(M.getContext()); - Int32Ty = Type::getInt32Ty(M.getContext()); - Int8Ty = Type::getInt8Ty(M.getContext()); - - return false; - } - - bool runOnFunction(Function &F) override; -}; // class SafeStack + SafeStack(Function &F, const TargetLoweringBase &TL, const DataLayout &DL, + ScalarEvolution &SE) + : F(F), TL(TL), DL(DL), SE(SE), + StackPtrTy(Type::getInt8PtrTy(F.getContext())), + IntPtrTy(DL.getIntPtrType(F.getContext())), + Int32Ty(Type::getInt32Ty(F.getContext())), + Int8Ty(Type::getInt8Ty(F.getContext())) {} + + // Run the transformation on the associated function. + // Returns whether the function was changed. + bool run(); +}; uint64_t SafeStack::getStaticAllocaAllocationSize(const AllocaInst* AI) { - uint64_t Size = DL->getTypeAllocSize(AI->getAllocatedType()); + uint64_t Size = DL.getTypeAllocSize(AI->getAllocatedType()); if (AI->isArrayAllocation()) { auto C = dyn_cast(AI->getArraySize()); if (!C) @@ -209,11 +198,11 @@ bool SafeStack::IsAccessSafe(Value *Addr, uint64_t AccessSize, const Value *AllocaPtr, uint64_t AllocaSize) { - AllocaOffsetRewriter Rewriter(*SE, AllocaPtr); - const SCEV *Expr = Rewriter.visit(SE->getSCEV(Addr)); + AllocaOffsetRewriter Rewriter(SE, AllocaPtr); + const SCEV *Expr = Rewriter.visit(SE.getSCEV(Addr)); - uint64_t BitWidth = SE->getTypeSizeInBits(Expr->getType()); - ConstantRange AccessStartRange = SE->getUnsignedRange(Expr); + uint64_t BitWidth = SE.getTypeSizeInBits(Expr->getType()); + ConstantRange AccessStartRange = SE.getUnsignedRange(Expr); ConstantRange SizeRange = ConstantRange(APInt(BitWidth, 0), APInt(BitWidth, AccessSize)); ConstantRange AccessRange = AccessStartRange.add(SizeRange); @@ -226,8 +215,8 @@ << *AllocaPtr << "\n" << " Access " << *Addr << "\n" << " SCEV " << *Expr - << " U: " << SE->getUnsignedRange(Expr) - << ", S: " << SE->getSignedRange(Expr) << "\n" + << " U: " << SE.getUnsignedRange(Expr) + << ", S: " << SE.getSignedRange(Expr) << "\n" << " Range " << AccessRange << "\n" << " AllocaRange " << AllocaRange << "\n" << " " << (Safe ? "safe" : "unsafe") << "\n"); @@ -266,7 +255,7 @@ switch (I->getOpcode()) { case Instruction::Load: { - if (!IsAccessSafe(UI, DL->getTypeStoreSize(I->getType()), AllocaPtr, + if (!IsAccessSafe(UI, DL.getTypeStoreSize(I->getType()), AllocaPtr, AllocaSize)) return false; break; @@ -282,7 +271,7 @@ return false; } - if (!IsAccessSafe(UI, DL->getTypeStoreSize(I->getOperand(0)->getType()), + if (!IsAccessSafe(UI, DL.getTypeStoreSize(I->getOperand(0)->getType()), AllocaPtr, AllocaSize)) return false; break; @@ -343,7 +332,7 @@ } Value *SafeStack::getStackGuard(IRBuilder<> &IRB, Function &F) { - Value *StackGuardVar = TL->getIRStackGuard(IRB); + Value *StackGuardVar = TL.getIRStackGuard(IRB); if (!StackGuardVar) StackGuardVar = F.getParent()->getOrInsertGlobal("__stack_chk_guard", StackPtrTy); @@ -390,7 +379,7 @@ if (!Arg.hasByValAttr()) continue; uint64_t Size = - DL->getTypeStoreSize(Arg.getType()->getPointerElementType()); + DL.getTypeStoreSize(Arg.getType()->getPointerElementType()); if (IsSafeStackAlloca(&Arg, Size)) continue; @@ -476,19 +465,19 @@ if (StackGuardSlot) { Type *Ty = StackGuardSlot->getAllocatedType(); unsigned Align = - std::max(DL->getPrefTypeAlignment(Ty), StackGuardSlot->getAlignment()); + std::max(DL.getPrefTypeAlignment(Ty), StackGuardSlot->getAlignment()); SSL.addObject(StackGuardSlot, getStaticAllocaAllocationSize(StackGuardSlot), Align, SSC.getFullLiveRange()); } for (Argument *Arg : ByValArguments) { Type *Ty = Arg->getType()->getPointerElementType(); - uint64_t Size = DL->getTypeStoreSize(Ty); + uint64_t Size = DL.getTypeStoreSize(Ty); if (Size == 0) Size = 1; // Don't create zero-sized stack objects. // Ensure the object is properly aligned. - unsigned Align = std::max((unsigned)DL->getPrefTypeAlignment(Ty), + unsigned Align = std::max((unsigned)DL.getPrefTypeAlignment(Ty), Arg->getParamAlignment()); SSL.addObject(Arg, Size, Align, SSC.getFullLiveRange()); } @@ -501,7 +490,7 @@ // Ensure the object is properly aligned. unsigned Align = - std::max((unsigned)DL->getPrefTypeAlignment(Ty), AI->getAlignment()); + std::max((unsigned)DL.getPrefTypeAlignment(Ty), AI->getAlignment()); SSL.addObject(AI, Size, Align, SSC.getLiveRange(AI)); } @@ -539,7 +528,7 @@ unsigned Offset = SSL.getObjectOffset(Arg); Type *Ty = Arg->getType()->getPointerElementType(); - uint64_t Size = DL->getTypeStoreSize(Ty); + uint64_t Size = DL.getTypeStoreSize(Ty); if (Size == 0) Size = 1; // Don't create zero-sized stack objects. @@ -630,7 +619,7 @@ ArraySize = IRB.CreateIntCast(ArraySize, IntPtrTy, false); Type *Ty = AI->getAllocatedType(); - uint64_t TySize = DL->getTypeAllocSize(Ty); + uint64_t TySize = DL.getTypeAllocSize(Ty); Value *Size = IRB.CreateMul(ArraySize, ConstantInt::get(IntPtrTy, TySize)); Value *SP = IRB.CreatePtrToInt(IRB.CreateLoad(UnsafeStackPtr), IntPtrTy); @@ -638,7 +627,7 @@ // Align the SP value to satisfy the AllocaInst, type and stack alignments. unsigned Align = std::max( - std::max((unsigned)DL->getPrefTypeAlignment(Ty), AI->getAlignment()), + std::max((unsigned)DL.getPrefTypeAlignment(Ty), AI->getAlignment()), (unsigned)StackAlignment); assert(isPowerOf2_32(Align)); @@ -685,25 +674,10 @@ } } -bool SafeStack::runOnFunction(Function &F) { - DEBUG(dbgs() << "[SafeStack] Function: " << F.getName() << "\n"); - - if (!F.hasFnAttribute(Attribute::SafeStack)) { - DEBUG(dbgs() << "[SafeStack] safestack is not requested" - " for this function\n"); - return false; - } - - if (F.isDeclaration()) { - DEBUG(dbgs() << "[SafeStack] function definition" - " is not available\n"); - return false; - } - - if (!TM) - report_fatal_error("Target machine is required"); - TL = TM->getSubtargetImpl(F)->getTargetLowering(); - SE = &getAnalysis().getSE(); +bool SafeStack::run() { + assert(F.hasFnAttribute(Attribute::SafeStack) && + "Can't run SafeStack on a function without the attribute"); + assert(!F.isDeclaration() && "Can't run SafeStack on a function declaration"); ++NumFunctions; @@ -736,7 +710,7 @@ ++NumUnsafeStackRestorePointsFunctions; IRBuilder<> IRB(&F.front(), F.begin()->getFirstInsertionPt()); - UnsafeStackPtr = TL->getSafeStackPointerLocation(IRB); + UnsafeStackPtr = TL.getSafeStackPointerLocation(IRB); // Load the current stack pointer (we'll also use it as a base pointer). // FIXME: use a dedicated register for it ? @@ -788,14 +762,57 @@ return true; } +class SafeStackLegacyPass : public FunctionPass { + const TargetMachine *TM; + +public: + static char ID; // Pass identification, replacement for typeid.. + SafeStackLegacyPass(const TargetMachine *TM) : FunctionPass(ID), TM(TM) { + initializeSafeStackLegacyPassPass(*PassRegistry::getPassRegistry()); + } + + SafeStackLegacyPass() : SafeStackLegacyPass(nullptr) {} + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired(); + } + + bool runOnFunction(Function &F) override { + DEBUG(dbgs() << "[SafeStack] Function: " << F.getName() << "\n"); + + if (!F.hasFnAttribute(Attribute::SafeStack)) { + DEBUG(dbgs() << "[SafeStack] safestack is not requested" + " for this function\n"); + return false; + } + + if (F.isDeclaration()) { + DEBUG(dbgs() << "[SafeStack] function definition" + " is not available\n"); + return false; + } + + if (!TM) + report_fatal_error("Target machine is required"); + auto *TL = TM->getSubtargetImpl(F)->getTargetLowering(); + if (!TL) + report_fatal_error("TargetLowering instance is required"); + + auto *DL = &F.getParent()->getDataLayout(); + auto &SE = getAnalysis().getSE(); + + return SafeStack(F, *TL, *DL, SE).run(); + } +}; + } // anonymous namespace -char SafeStack::ID = 0; -INITIALIZE_TM_PASS_BEGIN(SafeStack, "safe-stack", +char SafeStackLegacyPass::ID = 0; +INITIALIZE_TM_PASS_BEGIN(SafeStackLegacyPass, "safe-stack", "Safe Stack instrumentation pass", false, false) -INITIALIZE_TM_PASS_END(SafeStack, "safe-stack", +INITIALIZE_TM_PASS_END(SafeStackLegacyPass, "safe-stack", "Safe Stack instrumentation pass", false, false) FunctionPass *llvm::createSafeStackPass(const llvm::TargetMachine *TM) { - return new SafeStack(TM); + return new SafeStackLegacyPass(TM); } Index: llvm/trunk/tools/opt/opt.cpp =================================================================== --- llvm/trunk/tools/opt/opt.cpp +++ llvm/trunk/tools/opt/opt.cpp @@ -390,7 +390,7 @@ initializeRewriteSymbolsLegacyPassPass(Registry); initializeWinEHPreparePass(Registry); initializeDwarfEHPreparePass(Registry); - initializeSafeStackPass(Registry); + initializeSafeStackLegacyPassPass(Registry); initializeSjLjEHPreparePass(Registry); initializePreISelIntrinsicLoweringLegacyPassPass(Registry); initializeGlobalMergePass(Registry);