Index: lib/Transforms/IPO/WholeProgramDevirt.cpp
===================================================================
--- lib/Transforms/IPO/WholeProgramDevirt.cpp
+++ lib/Transforms/IPO/WholeProgramDevirt.cpp
@@ -31,11 +31,13 @@
 #include "llvm/Transforms/IPO.h"
 #include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/MapVector.h"
+#include "llvm/Analysis/ConstantFolding.h"
 #include "llvm/IR/CallSite.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/DataLayout.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/Module.h"
 #include "llvm/Pass.h"
@@ -173,6 +175,7 @@
 struct VTableSlot {
   Metadata *BitSetID;
   uint64_t ByteOffset;
+  bool IsRelative;
 };
 
 }
@@ -182,19 +185,22 @@
 template <> struct DenseMapInfo<VTableSlot> {
   static VTableSlot getEmptyKey() {
     return {DenseMapInfo<Metadata *>::getEmptyKey(),
-            DenseMapInfo<uint64_t>::getEmptyKey()};
+            DenseMapInfo<uint64_t>::getEmptyKey(),
+            false};
   }
   static VTableSlot getTombstoneKey() {
     return {DenseMapInfo<Metadata *>::getTombstoneKey(),
-            DenseMapInfo<uint64_t>::getTombstoneKey()};
+            DenseMapInfo<uint64_t>::getTombstoneKey(),
+            true};
   }
   static unsigned getHashValue(const VTableSlot &I) {
     return DenseMapInfo<Metadata *>::getHashValue(I.BitSetID) ^
-           DenseMapInfo<uint64_t>::getHashValue(I.ByteOffset);
+           DenseMapInfo<uint64_t>::getHashValue(I.ByteOffset) + I.IsRelative;
   }
   static bool isEqual(const VTableSlot &LHS,
                       const VTableSlot &RHS) {
-    return LHS.BitSetID == RHS.BitSetID && LHS.ByteOffset == RHS.ByteOffset;
+    return LHS.BitSetID == RHS.BitSetID && LHS.ByteOffset == RHS.ByteOffset &&
+           LHS.IsRelative == RHS.IsRelative;
   }
 };
 
@@ -233,13 +239,13 @@
   void findLoadCallsAtConstantOffset(Metadata *BitSet, Value *Ptr,
                                      uint64_t Offset, Value *VTable);
   void findCallsAtConstantOffset(Metadata *BitSet, Value *Ptr, uint64_t Offset,
-                                 Value *VTable);
+                                 bool IsRelative, Value *VTable);
 
   void buildBitSets(std::vector<VTableBits> &Bits,
                     DenseMap<Metadata *, std::set<BitSetInfo>> &BitSets);
   bool tryFindVirtualCallTargets(std::vector<VirtualCallTarget> &TargetsForSlot,
                                  const std::set<BitSetInfo> &BitSetInfos,
-                                 uint64_t ByteOffset);
+                                 uint64_t ByteOffset, bool IsRelative);
   bool trySingleImplDevirt(ArrayRef<VirtualCallTarget> TargetsForSlot,
                            MutableArrayRef<VirtualCallSite> CallSites);
   bool tryEvaluateFunctionsWithArgs(
@@ -279,15 +285,16 @@
 // Search for virtual calls that call FPtr and add them to CallSlots.
 void DevirtModule::findCallsAtConstantOffset(Metadata *BitSet, Value *FPtr,
-                                             uint64_t Offset, Value *VTable) {
+                                             uint64_t Offset, bool IsRelative,
+                                             Value *VTable) {
   for (const Use &U : FPtr->uses()) {
     Value *User = U.getUser();
     if (isa<BitCastInst>(User)) {
-      findCallsAtConstantOffset(BitSet, User, Offset, VTable);
+      findCallsAtConstantOffset(BitSet, User, Offset, IsRelative, VTable);
     } else if (auto CI = dyn_cast<CallInst>(User)) {
-      CallSlots[{BitSet, Offset}].push_back({VTable, CI});
+      CallSlots[{BitSet, Offset, IsRelative}].push_back({VTable, CI});
     } else if (auto II = dyn_cast<InvokeInst>(User)) {
-      CallSlots[{BitSet, Offset}].push_back({VTable, II});
+      CallSlots[{BitSet, Offset, IsRelative}].push_back({VTable, II});
     }
   }
 }
 
@@ -301,7 +308,15 @@
     if (isa<BitCastInst>(User)) {
       findLoadCallsAtConstantOffset(BitSet, User, Offset, VTable);
     } else if (isa<LoadInst>(User)) {
-      findCallsAtConstantOffset(BitSet, User, Offset, VTable);
+      findCallsAtConstantOffset(BitSet, User, Offset, false, VTable);
+    } else if (auto II = dyn_cast<IntrinsicInst>(User)) {
+      if (II->getIntrinsicID() == Intrinsic::load_relative &&
+          isa<ConstantInt>(II->getArgOperand(1))) {
+        uint64_t LoadOffset =
+            cast<ConstantInt>(II->getArgOperand(1))->getZExtValue();
+        findCallsAtConstantOffset(BitSet, User, Offset + LoadOffset, true,
+                                  VTable);
+      }
     } else if (auto GEP = dyn_cast<GetElementPtrInst>(User)) {
       // Take into account the GEP offset.
       if (VPtr == GEP->getPointerOperand() && GEP->hasAllConstantIndices()) {
@@ -353,29 +368,80 @@
   }
 }
 
+namespace {
+
+Constant *getElementAtOffset(Constant *Init, uint64_t Offset, const DataLayout &DL) {
+  if (auto *CA = dyn_cast<ConstantArray>(Init)) {
+    ArrayType *InitTy = CA->getType();
+
+    uint64_t ElemSize = DL.getTypeAllocSize(InitTy->getElementType());
+    if (Offset % ElemSize != 0)
+      return nullptr;
+
+    unsigned Op = Offset / ElemSize;
+    if (Op >= Init->getNumOperands())
+      return nullptr;
+
+    return CA->getOperand(Op);
+  } else if (auto *CS = dyn_cast<ConstantStruct>(Init)) {
+    StructType *InitTy = CS->getType();
+
+    const StructLayout *SL = DL.getStructLayout(InitTy);
+    unsigned Elem = SL->getElementContainingOffset(Offset);
+    if (SL->getElementOffset(Elem) != Offset)
+      return nullptr;
+
+    return CS->getOperand(Elem);
+  } else {
+    return nullptr;
+  }
+}
+
+}
+
 bool DevirtModule::tryFindVirtualCallTargets(
     std::vector<VirtualCallTarget> &TargetsForSlot,
-    const std::set<BitSetInfo> &BitSetInfos, uint64_t ByteOffset) {
+    const std::set<BitSetInfo> &BitSetInfos, uint64_t ByteOffset,
+    bool IsRelative) {
   for (const BitSetInfo &BS : BitSetInfos) {
     if (!BS.Bits->GV->isConstant())
       return false;
 
-    auto Init = dyn_cast<ConstantArray>(BS.Bits->GV->getInitializer());
-    if (!Init)
+    Constant *Elem =
+        getElementAtOffset(BS.Bits->GV->getInitializer(),
+                           BS.Offset + ByteOffset, M.getDataLayout());
+    if (!Elem)
       return false;
 
-    ArrayType *VTableTy = Init->getType();
-    uint64_t ElemSize =
-        M.getDataLayout().getTypeAllocSize(VTableTy->getElementType());
-    uint64_t GlobalSlotOffset = BS.Offset + ByteOffset;
-    if (GlobalSlotOffset % ElemSize != 0)
-      return false;
+    if (IsRelative) {
+      auto *CE = dyn_cast<ConstantExpr>(Elem);
+      if (!CE)
+        return false;
+      if (CE->getOpcode() == Instruction::Trunc) {
+        CE = dyn_cast<ConstantExpr>(CE->getOperand(0));
+        if (!CE)
+          return false;
+      }
 
-    unsigned Op = GlobalSlotOffset / ElemSize;
-    if (Op >= Init->getNumOperands())
-      return false;
+      if (CE->getOpcode() != Instruction::Sub)
+        return false;
+
+      auto *RHS = dyn_cast<ConstantExpr>(CE->getOperand(1));
+      GlobalValue *RHSGV;
+      APInt RHSOffset;
+      if (!RHS || RHS->getOpcode() != Instruction::PtrToInt ||
+          !IsConstantOffsetFromGlobal(RHS->getOperand(0), RHSGV, RHSOffset,
+                                      M.getDataLayout()) ||
+          RHSGV != BS.Bits->GV || RHSOffset != BS.Offset)
+        return false;
+
+      CE = dyn_cast<ConstantExpr>(CE->getOperand(0));
+      if (!CE || CE->getOpcode() != Instruction::PtrToInt)
+        return false;
+      Elem = CE->getOperand(0);
+    }
 
-    auto Fn = dyn_cast<Function>(Init->getOperand(Op)->stripPointerCasts());
+    auto *Fn = dyn_cast<Function>(Elem->stripPointerCasts());
     if (!Fn)
       return false;
 
@@ -709,7 +775,7 @@
     // implementation at offset S.first.ByteOffset, and add to TargetsForSlot.
     std::vector<VirtualCallTarget> TargetsForSlot;
     if (!tryFindVirtualCallTargets(TargetsForSlot, BitSets[S.first.BitSetID],
-                                   S.first.ByteOffset))
+                                   S.first.ByteOffset, S.first.IsRelative))
       continue;
 
     if (trySingleImplDevirt(TargetsForSlot, S.second))
Index: test/Transforms/WholeProgramDevirt/bad-read-from-vtable-struct.ll
===================================================================
--- /dev/null
+++ test/Transforms/WholeProgramDevirt/bad-read-from-vtable-struct.ll
@@ -0,0 +1,64 @@
+; RUN: opt -S -wholeprogramdevirt %s | FileCheck %s
+
+target datalayout = "e-p:64:64"
+target triple = "x86_64-unknown-linux-gnu"
+
+@vt = global { i8*, i8* } { i8* zeroinitializer, i8* bitcast (void (i8*)* @vf to i8*) }
+
+define void @vf(i8* %this) {
+  ret void
+}
+
+; CHECK: define void @unaligned
+define void @unaligned(i8* %obj) {
+  %vtableptr = bitcast i8* %obj to [1 x i8*]**
+  %vtable = load [1 x i8*]*, [1 x i8*]** %vtableptr
+  %vtablei8 = bitcast [1 x i8*]* %vtable to i8*
+  %p = call i1 @llvm.bitset.test(i8* %vtablei8, metadata !"bitset")
+  call void @llvm.assume(i1 %p)
+  %fptrptr = getelementptr i8, i8* %vtablei8, i32 1
+  %fptrptr_casted = bitcast i8* %fptrptr to i8**
+  %fptr = load i8*, i8** %fptrptr_casted
+  %fptr_casted = bitcast i8* %fptr to void (i8*)*
+  ; CHECK: call void %
+  call void %fptr_casted(i8* %obj)
+  ret void
+}
+
+; CHECK: define void @outofbounds
+define void @outofbounds(i8* %obj) {
+  %vtableptr = bitcast i8* %obj to [1 x i8*]**
+  %vtable = load [1 x i8*]*, [1 x i8*]** %vtableptr
+  %vtablei8 = bitcast [1 x i8*]* %vtable to i8*
+  %p = call i1 @llvm.bitset.test(i8* %vtablei8, metadata !"bitset")
+  call void @llvm.assume(i1 %p)
+  %fptrptr = getelementptr i8, i8* %vtablei8, i32 16
+  %fptrptr_casted = bitcast i8* %fptrptr to i8**
+  %fptr = load i8*, i8** %fptrptr_casted
+  %fptr_casted = bitcast i8* %fptr to void (i8*)*
+  ; CHECK: call void %
+  call void %fptr_casted(i8* %obj)
+  ret void
+}
+
+; CHECK: define void @nonfunction
+define void @nonfunction(i8* %obj) {
+  %vtableptr = bitcast i8* %obj to [1 x i8*]**
+  %vtable = load [1 x i8*]*, [1 x i8*]** %vtableptr
+  %vtablei8 = bitcast [1 x i8*]* %vtable to i8*
+  %p = call i1 @llvm.bitset.test(i8* %vtablei8, metadata !"bitset")
+  call void @llvm.assume(i1 %p)
+  %fptrptr = getelementptr i8, i8* %vtablei8, i32 0
+  %fptrptr_casted = bitcast i8* %fptrptr to i8**
+  %fptr = load i8*, i8** %fptrptr_casted
+  %fptr_casted = bitcast i8* %fptr to void (i8*)*
+  ; CHECK: call void %
+  call void %fptr_casted(i8* %obj)
+  ret void
+}
+
+declare i1 @llvm.bitset.test(i8*, metadata)
+declare void @llvm.assume(i1)
+
+!0 = !{!"bitset", { i8*, i8* }* @vt, i32 0}
+!llvm.bitsets = !{!0}
Index: test/Transforms/WholeProgramDevirt/bad-relative-vtable.ll
===================================================================
--- /dev/null
+++ test/Transforms/WholeProgramDevirt/bad-relative-vtable.ll
@@ -0,0 +1,49 @@
+; RUN: opt -S -wholeprogramdevirt %s | FileCheck %s
+
+target datalayout = "e-p:64:64"
+target triple = "x86_64-unknown-linux-gnu"
+
+@g = external constant i32
+@vt1 = constant { i32, i32, i32 } { i32 0,
+  i32 trunc (i64 sub (i64 ptrtoint (void (i8*)* @vf to i64), i64 ptrtoint (i32* getelementptr ({ i32, i32, i32 }, { i32, i32, i32 }* @vt1, i32 0, i32 1) to i64)) to i32),
+  i32 trunc (i64 sub (i64 ptrtoint (void (i8*)* @vf to i64), i64 ptrtoint (i32* @g to i64)) to i32)
+}
+
+define void @vf(i8* %this) {
+  ret void
+}
+
+; CHECK: define void @call1
+define void @call1(i8* %obj) {
+  %vtableptr = bitcast i8* %obj to [1 x i8*]**
+  %vtable = load [1 x i8*]*, [1 x i8*]** %vtableptr
+  %vtablei8 = bitcast [1 x i8*]* %vtable to i8*
+  %p = call i1 @llvm.bitset.test(i8* %vtablei8, metadata !"bitset")
+  call void @llvm.assume(i1 %p)
+  %fptr = call i8* @llvm.load.relative.i32(i8* %vtablei8, i32 4)
+  %fptr_casted = bitcast i8* %fptr to void (i8*)*
+  ; CHECK: call void %
+  call void %fptr_casted(i8* %obj)
+  ret void
+}
+
+; CHECK: define void @call2
+define void @call2(i8* %obj) {
+  %vtableptr = bitcast i8* %obj to [1 x i8*]**
+  %vtable = load [1 x i8*]*, [1 x i8*]** %vtableptr
+  %vtablei8 = bitcast [1 x i8*]* %vtable to i8*
+  %p = call i1 @llvm.bitset.test(i8* %vtablei8, metadata !"bitset")
+  call void @llvm.assume(i1 %p)
+  %fptr = call i8* @llvm.load.relative.i32(i8* %vtablei8, i32 8)
+  %fptr_casted = bitcast i8* %fptr to void (i8*)*
+  ; CHECK: call void %
+  call void %fptr_casted(i8* %obj)
+  ret void
+}
+
+declare i1 @llvm.bitset.test(i8*, metadata)
+declare i8* @llvm.load.relative.i32(i8*, i32)
+declare void @llvm.assume(i1)
+
+!0 = !{!"bitset", { i32, i32, i32 }* @vt1, i32 0}
+!llvm.bitsets = !{!0}
Index: test/Transforms/WholeProgramDevirt/devirt-single-impl-relative-32.ll
===================================================================
--- /dev/null
+++ test/Transforms/WholeProgramDevirt/devirt-single-impl-relative-32.ll
@@ -0,0 +1,33 @@
+; RUN: opt -S -wholeprogramdevirt %s | FileCheck %s
+
+target datalayout = "e-p:32:32"
+target triple = "i386-unknown-linux-gnu"
+
+@vt1 = constant { i32, i32 } { i32 0, i32 sub (i32 ptrtoint (void (i8*)* @vf to i32), i32 ptrtoint ({ i32, i32 }* @vt1 to i32)) }
+@vt2 = constant { i32, i32 } { i32 0, i32 sub (i32 ptrtoint (void (i8*)* @vf to i32), i32 ptrtoint ({ i32, i32 }* @vt2 to i32)) }
+
+define void @vf(i8* %this) {
+  ret void
+}
+
+; CHECK: define void @call
+define void @call(i8* %obj) {
+  %vtableptr = bitcast i8* %obj to [1 x i8*]**
+  %vtable = load [1 x i8*]*, [1 x i8*]** %vtableptr
+  %vtablei8 = bitcast [1 x i8*]* %vtable to i8*
+  %p = call i1 @llvm.bitset.test(i8* %vtablei8, metadata !"bitset")
+  call void @llvm.assume(i1 %p)
+  %fptr = call i8* @llvm.load.relative.i32(i8* %vtablei8, i32 4)
+  %fptr_casted = bitcast i8* %fptr to void (i8*)*
+  ; CHECK: call void @vf(
+  call void %fptr_casted(i8* %obj)
+  ret void
+}
+
+declare i1 @llvm.bitset.test(i8*, metadata)
+declare i8* @llvm.load.relative.i32(i8*, i32)
+declare void @llvm.assume(i1)
+
+!0 = !{!"bitset", { i32, i32 }* @vt1, i32 0}
+!1 = !{!"bitset", { i32, i32 }* @vt2, i32 0}
+!llvm.bitsets = !{!0, !1}
Index: test/Transforms/WholeProgramDevirt/devirt-single-impl-relative.ll
===================================================================
--- /dev/null
+++ test/Transforms/WholeProgramDevirt/devirt-single-impl-relative.ll
@@ -0,0 +1,33 @@
+; RUN: opt -S -wholeprogramdevirt %s | FileCheck %s
+
+target datalayout = "e-p:64:64"
+target triple = "x86_64-unknown-linux-gnu"
+
+@vt1 = constant { i32, i32 } { i32 0, i32 trunc (i64 sub (i64 ptrtoint (void (i8*)* @vf to i64), i64 ptrtoint ({ i32, i32 }* @vt1 to i64)) to i32) }
+@vt2 = constant { i32, i32 } { i32 0, i32 trunc (i64 sub (i64 ptrtoint (void (i8*)* @vf to i64), i64 ptrtoint ({ i32, i32 }* @vt2 to i64)) to i32) }
+
+define void @vf(i8* %this) {
+  ret void
+}
+
+; CHECK: define void @call
+define void @call(i8* %obj) {
+  %vtableptr = bitcast i8* %obj to [1 x i8*]**
+  %vtable = load [1 x i8*]*, [1 x i8*]** %vtableptr
+  %vtablei8 = bitcast [1 x i8*]* %vtable to i8*
+  %p = call i1 @llvm.bitset.test(i8* %vtablei8, metadata !"bitset")
+  call void @llvm.assume(i1 %p)
+  %fptr = call i8* @llvm.load.relative.i32(i8* %vtablei8, i32 4)
+  %fptr_casted = bitcast i8* %fptr to void (i8*)*
+  ; CHECK: call void @vf(
+  call void %fptr_casted(i8* %obj)
+  ret void
+}
+
+declare i1 @llvm.bitset.test(i8*, metadata)
+declare i8* @llvm.load.relative.i32(i8*, i32)
+declare void @llvm.assume(i1)
+
+!0 = !{!"bitset", { i32, i32 }* @vt1, i32 0}
+!1 = !{!"bitset", { i32, i32 }* @vt2, i32 0}
+!llvm.bitsets = !{!0, !1}
Index: test/Transforms/WholeProgramDevirt/non-relative.ll
===================================================================
--- /dev/null
+++ test/Transforms/WholeProgramDevirt/non-relative.ll
@@ -0,0 +1,33 @@
+; RUN: opt -S -wholeprogramdevirt %s | FileCheck %s
+
+target datalayout = "e-p:64:64"
+target triple = "x86_64-unknown-linux-gnu"
+
+@vt1 = constant { i32, i32 } { i32 0, i32 trunc (i64 sub (i64 ptrtoint (void (i8*)* @vf to i64), i64 ptrtoint ({ i32, i32 }* @vt1 to i64)) to i32) }
+@vt2 = constant { i32, i32 } { i32 0, i32 trunc (i64 sub (i64 ptrtoint (void (i8*)* @vf to i64), i64 ptrtoint ({ i32, i32 }* @vt2 to i64)) to i32) }
+
+define void @vf(i8* %this) {
+  ret void
+}
+
+; CHECK: define void @call
+define void @call(i8* %obj) {
+  %vtableptr = bitcast i8* %obj to { i32, i8* }**
+  %vtable = load { i32, i8* }*, { i32, i8* }** %vtableptr
+  %vtablei8 = bitcast { i32, i8* }* %vtable to i8*
+  %p = call i1 @llvm.bitset.test(i8* %vtablei8, metadata !"bitset")
+  call void @llvm.assume(i1 %p)
+  %fptrptr = getelementptr { i32, i8* }, { i32, i8* }* %vtable, i32 0, i32 1
+  %fptr = load i8*, i8** %fptrptr
+  %fptr_casted = bitcast i8* %fptr to void (i8*)*
+  ; CHECK: call void %
+  call void %fptr_casted(i8* %obj)
+  ret void
+}
+
+declare i1 @llvm.bitset.test(i8*, metadata)
+declare void @llvm.assume(i1)
+
+!0 = !{!"bitset", { i32, i32 }* @vt1, i32 0}
+!1 = !{!"bitset", { i32, i32 }* @vt2, i32 0}
+!llvm.bitsets = !{!0, !1}