diff --git a/llvm/lib/Analysis/StackSafetyAnalysis.cpp b/llvm/lib/Analysis/StackSafetyAnalysis.cpp
--- a/llvm/lib/Analysis/StackSafetyAnalysis.cpp
+++ b/llvm/lib/Analysis/StackSafetyAnalysis.cpp
@@ -14,12 +14,14 @@
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/Statistic.h"
 #include "llvm/Analysis/ModuleSummaryAnalysis.h"
+#include "llvm/Analysis/ScalarEvolution.h"
 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
 #include "llvm/Analysis/StackLifetime.h"
 #include "llvm/IR/ConstantRange.h"
 #include "llvm/IR/DerivedTypes.h"
 #include "llvm/IR/GlobalValue.h"
 #include "llvm/IR/InstIterator.h"
+#include "llvm/IR/Instruction.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/ModuleSummaryIndex.h"
@@ -117,7 +119,7 @@
   // Access range if the address (alloca or parameters).
   // It is allowed to be empty-set when there are no known accesses.
   ConstantRange Range;
-  std::map<const Instruction *, ConstantRange> Accesses;
+  std::set<const Instruction *> UnsafeAccesses;
 
   // List of calls which pass address as an argument.
   // Value is offset range of address from base address (alloca or calling
@@ -131,10 +133,9 @@
   UseInfo(unsigned PointerSize) : Range{PointerSize, false} {}
 
   void updateRange(const ConstantRange &R) { Range = unionNoWrap(Range, R); }
-  void addRange(const Instruction *I, const ConstantRange &R) {
-    auto Ins = Accesses.emplace(I, R);
-    if (!Ins.second)
-      Ins.first->second = unionNoWrap(Ins.first->second, R);
+  void addRange(const Instruction *I, const ConstantRange &R, bool IsSafe) {
+    if (!IsSafe)
+      UnsafeAccesses.insert(I);
     updateRange(R);
   }
 };
@@ -230,7 +231,7 @@
 struct StackSafetyGlobalInfo::InfoTy {
   GVToSSI Info;
   SmallPtrSet<const AllocaInst *, 8> SafeAllocas;
-  std::map<const Instruction *, bool> AccessIsUnsafe;
+  std::set<const Instruction *> UnsafeAccesses;
 };
 
 namespace {
@@ -253,6 +254,11 @@
   void analyzeAllUses(Value *Ptr, UseInfo<GlobalValue> &AS,
                       const StackLifetime &SL);
+
+  bool isSafeAccess(const Use &U, AllocaInst *AI, const SCEV *AccessSize);
+  bool isSafeAccess(const Use &U, AllocaInst *AI, Value *V);
+  bool isSafeAccess(const Use &U, AllocaInst *AI, TypeSize AccessSize);
+
 public:
   StackSafetyLocalAnalysis(Function &F, ScalarEvolution &SE)
       : F(F), DL(F.getParent()->getDataLayout()), SE(SE),
@@ -333,6 +339,56 @@
   return getAccessRange(U, Base, SizeRange);
 }
 
+bool StackSafetyLocalAnalysis::isSafeAccess(const Use &U, AllocaInst *AI,
+                                            Value *V) {
+  return isSafeAccess(U, AI, SE.getSCEV(V));
+}
+
+bool StackSafetyLocalAnalysis::isSafeAccess(const Use &U, AllocaInst *AI,
+                                            TypeSize TS) {
+  if (TS.isScalable())
+    return false;
+  auto *CalculationTy = IntegerType::getIntNTy(SE.getContext(), PointerSize);
+  const SCEV *SV = SE.getConstant(CalculationTy, TS.getFixedSize());
+  return isSafeAccess(U, AI, SV);
+}
+
+bool StackSafetyLocalAnalysis::isSafeAccess(const Use &U, AllocaInst *AI,
+                                            const SCEV *AccessSize) {
+
+  if (!AI)
+    return true;
+  if (isa<SCEVCouldNotCompute>(AccessSize))
+    return false;
+
+  const auto *I = cast<Instruction>(U.getUser());
+
+  auto ToCharPtr = [&](const SCEV *V) {
+    auto *PtrTy = IntegerType::getInt8PtrTy(SE.getContext());
+    return SE.getTruncateOrZeroExtend(V, PtrTy);
+  };
+
+  const SCEV *AddrExp = ToCharPtr(SE.getSCEV(U.get()));
+  const SCEV *BaseExp = ToCharPtr(SE.getSCEV(AI));
+  const SCEV *Diff = SE.getMinusSCEV(AddrExp, BaseExp);
+  if (isa<SCEVCouldNotCompute>(Diff))
+    return false;
+
+  auto Size = getStaticAllocaSizeRange(*AI);
+
+  auto *CalculationTy = IntegerType::getIntNTy(SE.getContext(), PointerSize);
+  auto ToDiffTy = [&](const SCEV *V) {
+    return SE.getTruncateOrZeroExtend(V, CalculationTy);
+  };
+  const SCEV *Min = ToDiffTy(SE.getConstant(Size.getLower()));
+  const SCEV *Max = SE.getMinusSCEV(ToDiffTy(SE.getConstant(Size.getUpper())),
+                                    ToDiffTy(AccessSize));
+  return SE.evaluatePredicateAt(ICmpInst::Predicate::ICMP_SGE, Diff, Min, I)
+             .getValueOr(false) &&
+         SE.evaluatePredicateAt(ICmpInst::Predicate::ICMP_SLE, Diff, Max, I)
+             .getValueOr(false);
+}
+
 /// The function analyzes all local uses of Ptr (alloca or argument) and
 /// calculates local access range and all function calls where it was used.
 void StackSafetyLocalAnalysis::analyzeAllUses(Value *Ptr,
@@ -341,7 +397,7 @@
   SmallPtrSet<const Value *, 16> Visited;
   SmallVector<const Value *, 8> WorkList;
   WorkList.push_back(Ptr);
-  const AllocaInst *AI = dyn_cast<AllocaInst>(Ptr);
+  AllocaInst *AI = dyn_cast<AllocaInst>(Ptr);
 
   // A DFS search through all uses of the alloca in bitcasts/PHI/GEPs/etc.
   while (!WorkList.empty()) {
@@ -356,11 +412,13 @@
     switch (I->getOpcode()) {
     case Instruction::Load: {
       if (AI && !SL.isAliveAfter(AI, I)) {
-        US.addRange(I, UnknownRange);
+        US.addRange(I, UnknownRange, /*IsSafe=*/false);
         break;
       }
-      US.addRange(I,
-                  getAccessRange(UI, Ptr, DL.getTypeStoreSize(I->getType())));
+      auto TypeSize = DL.getTypeStoreSize(I->getType());
+      auto AccessRange = getAccessRange(UI, Ptr, TypeSize);
+      bool Safe = isSafeAccess(UI, AI, TypeSize);
+      US.addRange(I, AccessRange, Safe);
       break;
     }
 
@@ -370,16 +428,17 @@
     case Instruction::Store: {
       if (V == I->getOperand(0)) {
         // Stored the pointer - conservatively assume it may be unsafe.
-        US.addRange(I, UnknownRange);
+        US.addRange(I, UnknownRange, /*IsSafe=*/false);
         break;
       }
       if (AI && !SL.isAliveAfter(AI, I)) {
-        US.addRange(I, UnknownRange);
+        US.addRange(I, UnknownRange, /*IsSafe=*/false);
        break;
       }
-      US.addRange(
-          I, getAccessRange(
-                 UI, Ptr, DL.getTypeStoreSize(I->getOperand(0)->getType())));
+      auto TypeSize = DL.getTypeStoreSize(I->getOperand(0)->getType());
+      auto AccessRange = getAccessRange(UI, Ptr, TypeSize);
+      bool Safe = isSafeAccess(UI, AI, TypeSize);
+      US.addRange(I, AccessRange, Safe);
       break;
     }
 
@@ -387,7 +446,7 @@
       // Information leak.
       // FIXME: Process parameters correctly. This is a leak only if we return
       // alloca.
-      US.addRange(I, UnknownRange);
+      US.addRange(I, UnknownRange, /*IsSafe=*/false);
       break;
 
     case Instruction::Call:
@@ -396,12 +455,20 @@
       if (I->isLifetimeStartOrEnd())
         break;
 
       if (AI && !SL.isAliveAfter(AI, I)) {
-        US.addRange(I, UnknownRange);
+        US.addRange(I, UnknownRange, /*IsSafe=*/false);
        break;
       }
       if (const MemIntrinsic *MI = dyn_cast<MemIntrinsic>(I)) {
-        US.addRange(I, getMemIntrinsicAccessRange(MI, UI, Ptr));
+        auto AccessRange = getMemIntrinsicAccessRange(MI, UI, Ptr);
+        bool Safe = false;
+        if (const auto *MTI = dyn_cast<MemTransferInst>(MI)) {
+          if (MTI->getRawSource() != UI && MTI->getRawDest() != UI)
+            Safe = true;
+        } else if (MI->getRawDest() != UI) {
+          Safe = true;
+        }
+        Safe = Safe || isSafeAccess(UI, AI, MI->getLength());
+        US.addRange(I, AccessRange, Safe);
         break;
       }
 
@@ -412,15 +479,16 @@
       }
 
       if (!CB.isArgOperand(&UI)) {
-        US.addRange(I, UnknownRange);
+        US.addRange(I, UnknownRange, /*IsSafe=*/false);
         break;
       }
 
       unsigned ArgNo = CB.getArgOperandNo(&UI);
       if (CB.isByValArgument(ArgNo)) {
-        US.addRange(I, getAccessRange(
-                           UI, Ptr,
-                           DL.getTypeStoreSize(CB.getParamByValType(ArgNo))));
+        auto TypeSize = DL.getTypeStoreSize(CB.getParamByValType(ArgNo));
+        auto AccessRange = getAccessRange(UI, Ptr, TypeSize);
+        bool Safe = isSafeAccess(UI, AI, TypeSize);
+        US.addRange(I, AccessRange, Safe);
         break;
       }
 
@@ -430,7 +498,7 @@
       const GlobalValue *Callee =
           dyn_cast<GlobalValue>(CB.getCalledOperand()->stripPointerCasts());
       if (!Callee) {
-        US.addRange(I, UnknownRange);
+        US.addRange(I, UnknownRange, /*IsSafe=*/false);
         break;
       }
 
@@ -827,8 +895,8 @@
         Info->SafeAllocas.insert(AI);
         ++NumAllocaStackSafe;
       }
-      for (const auto &A : KV.second.Accesses)
-        Info->AccessIsUnsafe[A.first] |= !AIRange.contains(A.second);
+      Info->UnsafeAccesses.insert(KV.second.UnsafeAccesses.begin(),
+                                  KV.second.UnsafeAccesses.end());
     }
   }
 
@@ -903,11 +971,7 @@
 
 bool StackSafetyGlobalInfo::stackAccessIsSafe(const Instruction &I) const {
   const auto &Info = getInfo();
-  auto It = Info.AccessIsUnsafe.find(&I);
-  if (It == Info.AccessIsUnsafe.end()) {
-    return true;
-  }
-  return !It->second;
+  return Info.UnsafeAccesses.find(&I) == Info.UnsafeAccesses.end();
 }
 
 void StackSafetyGlobalInfo::print(raw_ostream &O) const {
diff --git a/llvm/test/Analysis/StackSafetyAnalysis/local.ll b/llvm/test/Analysis/StackSafetyAnalysis/local.ll
--- a/llvm/test/Analysis/StackSafetyAnalysis/local.ll
+++ b/llvm/test/Analysis/StackSafetyAnalysis/local.ll
@@ -44,6 +44,53 @@
   ret void
 }
 
+define void @StoreInBoundsCond(i64 %i) {
+; CHECK-LABEL: @StoreInBoundsCond dso_preemptable{{$}}
+; CHECK-NEXT: args uses:
+; CHECK-NEXT: allocas uses:
+; CHECK-NEXT: x[4]: full-set{{$}}
+; GLOBAL-NEXT: safe accesses:
+; GLOBAL-NEXT: store i8 0, i8* %x2, align 1
+; CHECK-EMPTY:
+entry:
+  %x = alloca i32, align 4
+  %x1 = bitcast i32* %x to i8*
+  %c1 = icmp sge i64 %i, 0
+  %c2 = icmp slt i64 %i, 4
+  br i1 %c1, label %c1.true, label %false
+
+c1.true:
+  br i1 %c2, label %c2.true, label %false
+
+c2.true:
+  %x2 = getelementptr i8, i8* %x1, i64 %i
+  store i8 0, i8* %x2, align 1
+  br label %false
+
+false:
+  ret void
+}
+
+define void @StoreInBoundsMinMax(i64 %i) {
+; CHECK-LABEL: @StoreInBoundsMinMax dso_preemptable{{$}}
+; CHECK-NEXT: args uses:
+; CHECK-NEXT: allocas uses:
+; CHECK-NEXT: x[4]: [0,4){{$}}
+; GLOBAL-NEXT: safe accesses:
+; GLOBAL-NEXT: store i8 0, i8* %x2, align 1
+; CHECK-EMPTY:
+entry:
+  %x = alloca i32, align 4
+  %x1 = bitcast i32* %x to i8*
+  %c1 = icmp sge i64 %i, 0
+  %i1 = select i1 %c1, i64 %i, i64 0
+  %c2 = icmp slt i64 %i1, 3
+  %i2 = select i1 %c2, i64 %i1, i64 3
+  %x2 = getelementptr i8, i8* %x1, i64 %i2
+  store i8 0, i8* %x2, align 1
+  ret void
+}
+
 define void @StoreInBounds2() {
 ; CHECK-LABEL: @StoreInBounds2 dso_preemptable{{$}}
 ; CHECK-NEXT: args uses:
@@ -157,6 +204,54 @@
   ret void
 }
 
+define void @StoreOutOfBoundsCond(i64 %i) {
+; CHECK-LABEL: @StoreOutOfBoundsCond dso_preemptable{{$}}
+; CHECK-NEXT: args uses:
+; CHECK-NEXT: allocas uses:
+; CHECK-NEXT: x[4]: full-set{{$}}
+; GLOBAL-NEXT: safe accesses:
+; CHECK-EMPTY:
+entry:
+  %x = alloca i32, align 4
+  %x1 = bitcast i32* %x to i8*
+  %c1 = icmp sge i64 %i, 0
+  %c2 = icmp slt i64 %i, 5
+  br i1 %c1, label %c1.true, label %false
+
+c1.true:
+  br i1 %c2, label %c2.true, label %false
+
+c2.true:
+  %x2 = getelementptr i8, i8* %x1, i64 %i
+  store i8 0, i8* %x2, align 1
+  br label %false
+
+false:
+  ret void
+}
+
+define void @StoreOutOfBoundsCond2(i64 %i) {
+; CHECK-LABEL: @StoreOutOfBoundsCond2 dso_preemptable{{$}}
+; CHECK-NEXT: args uses:
+; CHECK-NEXT: allocas uses:
+; CHECK-NEXT: x[4]: full-set{{$}}
+; GLOBAL-NEXT: safe accesses:
+; CHECK-EMPTY:
+entry:
+  %x = alloca i32, align 4
+  %x1 = bitcast i32* %x to i8*
+  %c2 = icmp slt i64 %i, 5
+  br i1 %c2, label %c2.true, label %false
+
+c2.true:
+  %x2 = getelementptr i8, i8* %x1, i64 %i
+  store i8 0, i8* %x2, align 1
+  br label %false
+
+false:
+  ret void
+}
+
 define void @StoreOutOfBounds2() {
 ; CHECK-LABEL: @StoreOutOfBounds2 dso_preemptable{{$}}
 ; CHECK-NEXT: args uses:
diff --git a/llvm/test/Analysis/StackSafetyAnalysis/memintrin.ll b/llvm/test/Analysis/StackSafetyAnalysis/memintrin.ll
--- a/llvm/test/Analysis/StackSafetyAnalysis/memintrin.ll
+++ b/llvm/test/Analysis/StackSafetyAnalysis/memintrin.ll
@@ -233,3 +233,42 @@
   call void @llvm.memmove.p0i8.p0i8.i32(i8* %x1, i8* %x2, i32 9, i1 false)
   ret void
 }
+
+define void @MemsetInBoundsCast() {
+; CHECK-LABEL: MemsetInBoundsCast dso_preemptable{{$}}
+; CHECK-NEXT: args uses:
+; CHECK-NEXT: allocas uses:
+; CHECK-NEXT: x[4]: [0,4){{$}}
+; CHECK-NEXT: y[1]: empty-set{{$}}
+; GLOBAL-NEXT: safe accesses:
+; GLOBAL-NEXT: call void @llvm.memset.p0i8.i32(i8* %x1, i8 %yint, i32 4, i1 false)
+; CHECK-EMPTY:
+entry:
+  %x = alloca i32, align 4
+  %y = alloca i8, align 1
+  %x1 = bitcast i32* %x to i8*
+  %yint = ptrtoint i8* %y to i8
+  call void @llvm.memset.p0i8.i32(i8* %x1, i8 %yint, i32 4, i1 false)
+  ret void
+}
+
+define void @MemcpyInBoundsCast2(i8 %zint8) {
+; CHECK-LABEL: MemcpyInBoundsCast2 dso_preemptable{{$}}
+; CHECK-NEXT: args uses:
+; CHECK-NEXT: allocas uses:
+; CHECK-NEXT: x[256]: [0,255){{$}}
+; CHECK-NEXT: y[256]: [0,255){{$}}
+; CHECK-NEXT: z[1]: empty-set{{$}}
+; GLOBAL-NEXT: safe accesses:
+; GLOBAL-NEXT: call void @llvm.memcpy.p0i8.p0i8.i32(i8* %x1, i8* %y1, i32 %zint32, i1 false)
+; CHECK-EMPTY:
+entry:
+  %x = alloca [256 x i8], align 4
+  %y = alloca [256 x i8], align 4
+  %z = alloca i8, align 1
+  %x1 = bitcast [256 x i8]* %x to i8*
+  %y1 = bitcast [256 x i8]* %y to i8*
+  %zint32 = zext i8 %zint8 to i32
+  call void @llvm.memcpy.p0i8.p0i8.i32(i8* %x1, i8* %y1, i32 %zint32, i1 false)
+  ret void
+}
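
For context on the tests above: for a static alloca, the predicate that isSafeAccess asks ScalarEvolution to prove at the use site is Min <= Diff && Diff <= Max, where Diff is the byte offset of the access from the alloca base, Min is Size.getLower() of getStaticAllocaSizeRange (0 for a static alloca), and Max is Size.getUpper() minus the access size. The patch proves both comparisons symbolically with evaluatePredicateAt, which may use conditions dominating the access instruction. Below is a minimal standalone sketch of that arithmetic with plain integers; the constants stand in for SCEV expressions and the function name is illustrative only, not part of the patch.

// Sketch of the bounds condition proved by isSafeAccess: the access
// [Diff, Diff + AccessSize) must stay inside [0, AllocaSize).
#include <cstdint>
#include <iostream>

static bool isAccessInBounds(int64_t Diff, int64_t AccessSize,
                             int64_t AllocaSize) {
  int64_t Min = 0;                       // Size.getLower() for a static alloca.
  int64_t Max = AllocaSize - AccessSize; // Size.getUpper() - AccessSize.
  return Diff >= Min && Diff <= Max;     // ICMP_SGE and ICMP_SLE halves.
}

int main() {
  // @StoreInBoundsCond: the branches establish 0 <= %i < 4, so a 1-byte
  // store into the 4-byte alloca is in bounds for every reachable %i.
  std::cout << isAccessInBounds(3, 1, 4) << '\n'; // 1 (safe)
  // @StoreOutOfBoundsCond: 0 <= %i < 5 admits %i == 4, one byte past the
  // end, so the store cannot be proven safe.
  std::cout << isAccessInBounds(4, 1, 4) << '\n'; // 0 (unsafe)
  return 0;
}

This is also why @StoreOutOfBoundsCond2 reports no safe accesses: with only %i < 5 known at the store, neither the upper bound nor the ICMP_SGE lower bound can be established, so the access is conservatively recorded as unsafe.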