Index: lib/Transforms/IPO/WholeProgramDevirt.cpp =================================================================== --- lib/Transforms/IPO/WholeProgramDevirt.cpp +++ lib/Transforms/IPO/WholeProgramDevirt.cpp @@ -389,22 +389,33 @@ if (!TM.Bits->GV->isConstant()) return false; - auto Init = dyn_cast(TM.Bits->GV->getInitializer()); - if (!Init) - return false; - ArrayType *VTableTy = Init->getType(); + const DataLayout &DL = M.getDataLayout(); + const Constant *I = TM.Bits->GV->getInitializer(); + const uint64_t GlobalSlotOffset = TM.Offset + ByteOffset; + unsigned Op; - uint64_t ElemSize = - M.getDataLayout().getTypeAllocSize(VTableTy->getElementType()); - uint64_t GlobalSlotOffset = TM.Offset + ByteOffset; - if (GlobalSlotOffset % ElemSize != 0) - return false; + if (auto *C = dyn_cast(I)) { + const StructLayout *SL = DL.getStructLayout(C->getType()); + + if (GlobalSlotOffset >= SL->getSizeInBytes()) + return false; + + Op = SL->getElementContainingOffset(GlobalSlotOffset); + } else if (auto *C = dyn_cast(I)) { + ArrayType *VTableTy = C->getType(); + const uint64_t ElemSize = DL.getTypeAllocSize(VTableTy->getElementType()); + + if (GlobalSlotOffset % ElemSize != 0) + return false; + + Op = GlobalSlotOffset / ElemSize; - unsigned Op = GlobalSlotOffset / ElemSize; - if (Op >= Init->getNumOperands()) + if (Op >= C->getNumOperands()) + return false; + } else return false; - auto Fn = dyn_cast(Init->getOperand(Op)->stripPointerCasts()); + auto Fn = dyn_cast(I->getOperand(Op)->stripPointerCasts()); if (!Fn) return false; Index: test/Transforms/WholeProgramDevirt/non-array-vtable.ll =================================================================== --- test/Transforms/WholeProgramDevirt/non-array-vtable.ll +++ test/Transforms/WholeProgramDevirt/non-array-vtable.ll @@ -3,18 +3,21 @@ target datalayout = "e-p:64:64" target triple = "x86_64-unknown-linux-gnu" -@vt = constant i8* bitcast (void (i8*)* @vf to i8*), !type !0 +%vtTy = type { void (i8*)* } + +@vt1 = constant i8* bitcast (void (i8*)* @vf to i8*), !type !0 +@vt2 = constant %vtTy { void (i8*)* @vf }, !type !1 define void @vf(i8* %this) { ret void } -; CHECK: define void @call -define void @call(i8* %obj) { +; CHECK: define void @call1 +define void @call1(i8* %obj) { %vtableptr = bitcast i8* %obj to [1 x i8*]** %vtable = load [1 x i8*]*, [1 x i8*]** %vtableptr %vtablei8 = bitcast [1 x i8*]* %vtable to i8* - %p = call i1 @llvm.type.test(i8* %vtablei8, metadata !"typeid") + %p = call i1 @llvm.type.test(i8* %vtablei8, metadata !"typeid1") call void @llvm.assume(i1 %p) %fptrptr = getelementptr [1 x i8*], [1 x i8*]* %vtable, i32 0, i32 0 %fptr = load i8*, i8** %fptrptr @@ -24,7 +27,23 @@ ret void } +; CHECK: define void @call2 +define void @call2(i8* %obj) { + %vtableptr = bitcast i8* %obj to [1 x i8*]** + %vtable = load [1 x i8*]*, [1 x i8*]** %vtableptr + %vtablei8 = bitcast [1 x i8*]* %vtable to i8* + %p = call i1 @llvm.type.test(i8* %vtablei8, metadata !"typeid2") + call void @llvm.assume(i1 %p) + %fptrptr = getelementptr [1 x i8*], [1 x i8*]* %vtable, i32 0, i32 0 + %fptr = load i8*, i8** %fptrptr + %fptr_casted = bitcast i8* %fptr to void (i8*)* + ; CHECK: call void @vf( + call void %fptr_casted(i8* %obj) + ret void +} + declare i1 @llvm.type.test(i8*, metadata) declare void @llvm.assume(i1) -!0 = !{i32 0, !"typeid"} +!0 = !{i32 0, !"typeid1"} +!1 = !{i32 0, !"typeid2"}