diff --git a/llvm/include/llvm/Transforms/Utils/CallPromotionUtils.h b/llvm/include/llvm/Transforms/Utils/CallPromotionUtils.h
--- a/llvm/include/llvm/Transforms/Utils/CallPromotionUtils.h
+++ b/llvm/include/llvm/Transforms/Utils/CallPromotionUtils.h
@@ -48,6 +48,34 @@
 Instruction *promoteCallWithIfThenElse(CallSite CS, Function *Callee,
                                        MDNode *BranchWeights = nullptr);
 
+/// Try to promote (devirtualize) a virtual call on an Alloca. Return true on
+/// success.
+///
+/// \p CS must be an indirect call. On success it is rewritten in place to
+/// call the function found in the vtable. On failure -- the pattern below is
+/// not matched, or the resolved function is not legal to promote (see
+/// isLegalToPromote) -- \p CS is left unchanged.
+///
+/// Look for a pattern like:
+///
+///  %o = alloca %class.Impl
+///  %1 = getelementptr %class.Impl, %class.Impl* %o, i64 0, i32 0, i32 0
+///  store i32 (...)** bitcast (i8** getelementptr inbounds
+///    ({ [3 x i8*] }, { [3 x i8*] }* @_ZTV4Impl, i64 0, inrange i32 0, i64 2)
+///    to i32 (...)**), i32 (...)*** %1
+///  %2 = getelementptr inbounds %class.Impl, %class.Impl* %o, i64 0, i32 0
+///  %3 = bitcast %class.Interface* %2 to void (%class.Interface*)***
+///  %vtable.i = load void (%class.Interface*)**, void (%class.Interface*)*** %3
+///  %4 = load void (%class.Interface*)*, void (%class.Interface*)** %vtable.i
+///  call void %4(%class.Interface* nonnull %2)
+///
+/// @_ZTV4Impl = linkonce_odr dso_local unnamed_addr constant { [3 x i8*] }
+///   { [3 x i8*]
+///     [i8* null, i8* bitcast ({ i8*, i8*, i8* }* @_ZTI4Impl to i8*),
+///      i8* bitcast (void (%class.Impl*)* @_ZN4Impl3RunEv to i8*)] }
+///
+bool tryPromoteCall(CallSite &CS);
+
 } // end namespace llvm
 
 #endif // LLVM_TRANSFORMS_UTILS_CALLPROMOTIONUTILS_H
diff --git a/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp b/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp
--- a/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp
+++ b/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp
@@ -12,6 +12,8 @@
 //===----------------------------------------------------------------------===//
 
 #include "llvm/Transforms/Utils/CallPromotionUtils.h"
+#include "llvm/Analysis/Loads.h"
+#include "llvm/Analysis/TypeMetadataUtils.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
 
@@ -458,4 +460,74 @@
   return promoteCall(CallSite(NewInst), Callee);
 }
 
+bool llvm::tryPromoteCall(CallSite &CS) {
+  assert(!CS.getCalledFunction() && "Only indirect call sites can be promoted");
+  Module *M = CS.getCaller()->getParent();
+  const DataLayout &DL = M->getDataLayout();
+  Value *Callee = CS.getCalledValue();
+
+  LoadInst *VTableEntryLoad = dyn_cast<LoadInst>(Callee);
+  if (!VTableEntryLoad)
+    return false; // Not a vtable entry load.
+  Value *VTableEntryPtr = VTableEntryLoad->getPointerOperand();
+  // stripAndAccumulateConstantOffsets requires each offset accumulator to be
+  // exactly as wide as the GEP index type of the pointer being stripped, which
+  // the DataLayout may make narrower than the pointer itself.
+  APInt VTableOffset(DL.getIndexTypeSizeInBits(VTableEntryPtr->getType()), 0);
+  Value *VTableBasePtr = VTableEntryPtr->stripAndAccumulateConstantOffsets(
+      DL, VTableOffset, /* AllowNonInbounds */ true);
+  LoadInst *VTablePtrLoad = dyn_cast<LoadInst>(VTableBasePtr);
+  if (!VTablePtrLoad)
+    return false; // Not a vtable load.
+  Value *Object = VTablePtrLoad->getPointerOperand();
+  APInt ObjectOffset(DL.getIndexTypeSizeInBits(Object->getType()), 0);
+  Value *ObjectBase = Object->stripAndAccumulateConstantOffsets(
+      DL, ObjectOffset, /* AllowNonInbounds */ true);
+  if (!(isa<AllocaInst>(ObjectBase) && ObjectOffset == 0))
+    // Not an Alloca or the offset isn't zero.
+    return false;
+
+  // Look for the vtable pointer store into the object by the ctor.
+  BasicBlock::iterator BBI(VTablePtrLoad);
+  Value *VTablePtr = FindAvailableLoadedValue(
+      VTablePtrLoad, VTablePtrLoad->getParent(), BBI, 0, nullptr, nullptr);
+  if (!VTablePtr)
+    return false; // No vtable found.
+  APInt VTableOffsetGVBase(DL.getIndexTypeSizeInBits(VTablePtr->getType()), 0);
+  Value *VTableGVBase = VTablePtr->stripAndAccumulateConstantOffsets(
+      DL, VTableOffsetGVBase, /* AllowNonInbounds */ true);
+  GlobalVariable *GV = dyn_cast<GlobalVariable>(VTableGVBase);
+  if (!(GV && GV->isConstant() && GV->hasDefinitiveInitializer()))
+    // Not in the form of a global constant variable with an initializer.
+    return false;
+
+  // The two offsets are added below, which needs equal bit widths; they can
+  // differ if the two pointers are in address spaces with different index
+  // sizes.
+  if (VTableOffsetGVBase.getBitWidth() != VTableOffset.getBitWidth())
+    return false;
+  Constant *VTableGVInitializer = GV->getInitializer();
+  APInt VTableGVOffset = VTableOffsetGVBase + VTableOffset;
+  if (!(VTableGVOffset.getActiveBits() <= 64))
+    return false; // Out of range.
+  Constant *Ptr = getPointerAtOffset(VTableGVInitializer,
+                                     VTableGVOffset.getZExtValue(),
+                                     *M);
+  if (!Ptr)
+    return false; // No constant (function) pointer found.
+  Function *DirectCallee = dyn_cast<Function>(Ptr->stripPointerCasts());
+  if (!DirectCallee)
+    return false; // No function pointer found.
+
+  // The vtable slot may hold a function whose signature does not match the
+  // call site (e.g. a different number of parameters). promoteCall expects a
+  // callee that isLegalToPromote accepts, so bail out otherwise.
+  if (!isLegalToPromote(CS, DirectCallee))
+    return false;
+
+  // Success.
+  promoteCall(CS, DirectCallee);
+  return true;
+}
+
 #undef DEBUG_TYPE
diff --git a/llvm/unittests/Transforms/Utils/CMakeLists.txt b/llvm/unittests/Transforms/Utils/CMakeLists.txt
--- a/llvm/unittests/Transforms/Utils/CMakeLists.txt
+++ b/llvm/unittests/Transforms/Utils/CMakeLists.txt
@@ -9,6 +9,7 @@
 add_llvm_unittest(UtilsTests
   ASanStackFrameLayoutTest.cpp
   BasicBlockUtilsTest.cpp
+  CallPromotionUtilsTest.cpp
   CloningTest.cpp
   CodeExtractorTest.cpp
   CodeMoverUtilsTest.cpp
diff --git a/llvm/unittests/Transforms/Utils/CallPromotionUtilsTest.cpp b/llvm/unittests/Transforms/Utils/CallPromotionUtilsTest.cpp
new file mode 100644
--- /dev/null
+++ b/llvm/unittests/Transforms/Utils/CallPromotionUtilsTest.cpp
@@ -0,0 +1,384 @@
+//===- CallPromotionUtilsTest.cpp - CallPromotionUtils unit tests ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/Utils/CallPromotionUtils.h"
+#include "llvm/AsmParser/Parser.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Module.h"
+#include "llvm/Support/SourceMgr.h"
+#include "gtest/gtest.h"
+
+using namespace llvm;
+
+static std::unique_ptr<Module> parseIR(LLVMContext &C, const char *IR) {
+  SMDiagnostic Err;
+  std::unique_ptr<Module> Mod = parseAssemblyString(IR, Err, C);
+  if (!Mod)
+    Err.print("UtilsTests", errs());
+  return Mod;
+}
+
+TEST(CallPromotionUtilsTest, TryPromoteCall) {
+  LLVMContext C;
+  std::unique_ptr<Module> M = parseIR(C,
+                                      R"IR(
+%class.Impl = type <{ %class.Interface, i32, [4 x i8] }>
+%class.Interface = type { i32 (...)** }
+
+@_ZTV4Impl = constant { [3 x i8*] } { [3 x i8*] [i8* null, i8* null, i8* bitcast (void (%class.Impl*)* @_ZN4Impl3RunEv to i8*)] }
+
+define void @f() {
+entry:
+  %o = alloca %class.Impl
+  %base = getelementptr %class.Impl, %class.Impl* %o, i64 0, i32 0, i32 0
+  store i32 (...)** bitcast (i8** getelementptr inbounds ({ [3 x i8*] }, { [3 x i8*] }* @_ZTV4Impl, i64 0, inrange i32 0, i64 2) to i32 (...)**), i32 (...)*** %base
+  %f = getelementptr inbounds %class.Impl, %class.Impl* %o, i64 0, i32 1
+  store i32 3, i32* %f
+  %base.i = getelementptr inbounds %class.Impl, %class.Impl* %o, i64 0, i32 0
+  %c = bitcast %class.Interface* %base.i to void (%class.Interface*)***
+  %vtable.i = load void (%class.Interface*)**, void (%class.Interface*)*** %c
+  %fp = load void (%class.Interface*)*, void (%class.Interface*)** %vtable.i
+  call void %fp(%class.Interface* nonnull %base.i)
+  ret void
+}
+
+declare void @_ZN4Impl3RunEv(%class.Impl* %this)
+)IR");
+  ASSERT_TRUE(M);
+
+  auto *GV = M->getNamedValue("f");
+  ASSERT_TRUE(GV);
+  auto *F = dyn_cast<Function>(GV);
+  ASSERT_TRUE(F);
+  Instruction *Inst = &F->front().front();
+  auto *AI = dyn_cast<AllocaInst>(Inst);
+  ASSERT_TRUE(AI);
+  Inst = &*++F->front().rbegin();
+  auto *CI = dyn_cast<CallInst>(Inst);
+  ASSERT_TRUE(CI);
+  CallSite CS(CI);
+  ASSERT_FALSE(CS.getCalledFunction());
+  bool IsPromoted = tryPromoteCall(CS);
+  EXPECT_TRUE(IsPromoted);
+  GV = M->getNamedValue("_ZN4Impl3RunEv");
+  ASSERT_TRUE(GV);
+  auto *F1 = dyn_cast<Function>(GV);
+  EXPECT_EQ(F1, CS.getCalledFunction());
+}
+
+TEST(CallPromotionUtilsTest, TryPromoteCall_NoFPLoad) {
+  LLVMContext C;
+  std::unique_ptr<Module> M = parseIR(C,
+                                      R"IR(
+%class.Impl = type <{ %class.Interface, i32, [4 x i8] }>
+%class.Interface = type { i32 (...)** }
+
+define void @f(void (%class.Interface*)* %fp, %class.Interface* nonnull %base.i) {
+entry:
+  call void %fp(%class.Interface* nonnull %base.i)
+  ret void
+}
+)IR");
+  ASSERT_TRUE(M);
+
+  auto *GV = M->getNamedValue("f");
+  ASSERT_TRUE(GV);
+  auto *F = dyn_cast<Function>(GV);
+  ASSERT_TRUE(F);
+  Instruction *Inst = &F->front().front();
+  auto *CI = dyn_cast<CallInst>(Inst);
+  ASSERT_TRUE(CI);
+  CallSite CS(CI);
+  ASSERT_FALSE(CS.getCalledFunction());
+  bool IsPromoted = tryPromoteCall(CS);
+  EXPECT_FALSE(IsPromoted);
+}
+
+TEST(CallPromotionUtilsTest, TryPromoteCall_NoVTablePtrLoad) {
+  LLVMContext C;
+  std::unique_ptr<Module> M = parseIR(C,
+                                      R"IR(
+%class.Impl = type <{ %class.Interface, i32, [4 x i8] }>
+%class.Interface = type { i32 (...)** }
+
+define void @f(void (%class.Interface*)** %vtable.i, %class.Interface* nonnull %base.i) {
+entry:
+  %fp = load void (%class.Interface*)*, void (%class.Interface*)** %vtable.i
+  call void %fp(%class.Interface* nonnull %base.i)
+  ret void
+}
+)IR");
+  ASSERT_TRUE(M);
+
+  auto *GV = M->getNamedValue("f");
+  ASSERT_TRUE(GV);
+  auto *F = dyn_cast<Function>(GV);
+  ASSERT_TRUE(F);
+  Instruction *Inst = &*++F->front().rbegin();
+  auto *CI = dyn_cast<CallInst>(Inst);
+  ASSERT_TRUE(CI);
+  CallSite CS(CI);
+  ASSERT_FALSE(CS.getCalledFunction());
+  bool IsPromoted = tryPromoteCall(CS);
+  EXPECT_FALSE(IsPromoted);
+}
+
+TEST(CallPromotionUtilsTest, TryPromoteCall_NoVTableInitFound) {
+  LLVMContext C;
+  std::unique_ptr<Module> M = parseIR(C,
+                                      R"IR(
+%class.Impl = type <{ %class.Interface, i32, [4 x i8] }>
+%class.Interface = type { i32 (...)** }
+
+define void @f() {
+entry:
+  %o = alloca %class.Impl
+  %f = getelementptr inbounds %class.Impl, %class.Impl* %o, i64 0, i32 1
+  store i32 3, i32* %f
+  %base.i = getelementptr inbounds %class.Impl, %class.Impl* %o, i64 0, i32 0
+  %c = bitcast %class.Interface* %base.i to void (%class.Interface*)***
+  %vtable.i = load void (%class.Interface*)**, void (%class.Interface*)*** %c
+  %fp = load void (%class.Interface*)*, void (%class.Interface*)** %vtable.i
+  call void %fp(%class.Interface* nonnull %base.i)
+  ret void
+}
+
+declare void @_ZN4Impl3RunEv(%class.Impl* %this)
+)IR");
+  ASSERT_TRUE(M);
+
+  auto *GV = M->getNamedValue("f");
+  ASSERT_TRUE(GV);
+  auto *F = dyn_cast<Function>(GV);
+  ASSERT_TRUE(F);
+  Instruction *Inst = &*++F->front().rbegin();
+  auto *CI = dyn_cast<CallInst>(Inst);
+  ASSERT_TRUE(CI);
+  CallSite CS(CI);
+  ASSERT_FALSE(CS.getCalledFunction());
+  bool IsPromoted = tryPromoteCall(CS);
+  EXPECT_FALSE(IsPromoted);
+}
+
+TEST(CallPromotionUtilsTest, TryPromoteCall_EmptyVTable) {
+  LLVMContext C;
+  std::unique_ptr<Module> M = parseIR(C,
+                                      R"IR(
+%class.Impl = type <{ %class.Interface, i32, [4 x i8] }>
+%class.Interface = type { i32 (...)** }
+
+@_ZTV4Impl = external global { [3 x i8*] }
+
+define void @f() {
+entry:
+  %o = alloca %class.Impl
+  %base = getelementptr %class.Impl, %class.Impl* %o, i64 0, i32 0, i32 0
+  store i32 (...)** bitcast (i8** getelementptr inbounds ({ [3 x i8*] }, { [3 x i8*] }* @_ZTV4Impl, i64 0, inrange i32 0, i64 2) to i32 (...)**), i32 (...)*** %base
+  %f = getelementptr inbounds %class.Impl, %class.Impl* %o, i64 0, i32 1
+  store i32 3, i32* %f
+  %base.i = getelementptr inbounds %class.Impl, %class.Impl* %o, i64 0, i32 0
+  %c = bitcast %class.Interface* %base.i to void (%class.Interface*)***
+  %vtable.i = load void (%class.Interface*)**, void (%class.Interface*)*** %c
+  %fp = load void (%class.Interface*)*, void (%class.Interface*)** %vtable.i
+  call void %fp(%class.Interface* nonnull %base.i)
+  ret void
+}
+
+declare void @_ZN4Impl3RunEv(%class.Impl* %this)
+)IR");
+  ASSERT_TRUE(M);
+
+  auto *GV = M->getNamedValue("f");
+  ASSERT_TRUE(GV);
+  auto *F = dyn_cast<Function>(GV);
+  ASSERT_TRUE(F);
+  Instruction *Inst = &F->front().front();
+  auto *AI = dyn_cast<AllocaInst>(Inst);
+  ASSERT_TRUE(AI);
+  Inst = &*++F->front().rbegin();
+  auto *CI = dyn_cast<CallInst>(Inst);
+  ASSERT_TRUE(CI);
+  CallSite CS(CI);
+  ASSERT_FALSE(CS.getCalledFunction());
+  bool IsPromoted = tryPromoteCall(CS);
+  EXPECT_FALSE(IsPromoted);
+}
+
+TEST(CallPromotionUtilsTest, TryPromoteCall_NullFP) {
+  LLVMContext C;
+  std::unique_ptr<Module> M = parseIR(C,
+                                      R"IR(
+%class.Impl = type <{ %class.Interface, i32, [4 x i8] }>
+%class.Interface = type { i32 (...)** }
+
+@_ZTV4Impl = constant { [3 x i8*] } { [3 x i8*] [i8* null, i8* null, i8* null] }
+
+define void @f() {
+entry:
+  %o = alloca %class.Impl
+  %base = getelementptr %class.Impl, %class.Impl* %o, i64 0, i32 0, i32 0
+  store i32 (...)** bitcast (i8** getelementptr inbounds ({ [3 x i8*] }, { [3 x i8*] }* @_ZTV4Impl, i64 0, inrange i32 0, i64 2) to i32 (...)**), i32 (...)*** %base
+  %f = getelementptr inbounds %class.Impl, %class.Impl* %o, i64 0, i32 1
+  store i32 3, i32* %f
+  %base.i = getelementptr inbounds %class.Impl, %class.Impl* %o, i64 0, i32 0
+  %c = bitcast %class.Interface* %base.i to void (%class.Interface*)***
+  %vtable.i = load void (%class.Interface*)**, void (%class.Interface*)*** %c
+  %fp = load void (%class.Interface*)*, void (%class.Interface*)** %vtable.i
+  call void %fp(%class.Interface* nonnull %base.i)
+  ret void
+}
+
+declare void @_ZN4Impl3RunEv(%class.Impl* %this)
+)IR");
+  ASSERT_TRUE(M);
+
+  auto *GV = M->getNamedValue("f");
+  ASSERT_TRUE(GV);
+  auto *F = dyn_cast<Function>(GV);
+  ASSERT_TRUE(F);
+  Instruction *Inst = &F->front().front();
+  auto *AI = dyn_cast<AllocaInst>(Inst);
+  ASSERT_TRUE(AI);
+  Inst = &*++F->front().rbegin();
+  auto *CI = dyn_cast<CallInst>(Inst);
+  ASSERT_TRUE(CI);
+  CallSite CS(CI);
+  ASSERT_FALSE(CS.getCalledFunction());
+  bool IsPromoted = tryPromoteCall(CS);
+  EXPECT_FALSE(IsPromoted);
+}
+
+// Based on clang/test/CodeGenCXX/member-function-pointer-calls.cpp
+TEST(CallPromotionUtilsTest, TryPromoteCall_MemberFunctionCalls) {
+  LLVMContext C;
+  std::unique_ptr<Module> M = parseIR(C,
+                                      R"IR(
+%struct.A = type { i32 (...)** }
+
+@_ZTV1A = linkonce_odr unnamed_addr constant { [4 x i8*] } { [4 x i8*] [i8* null, i8* null, i8* bitcast (i32 (%struct.A*)* @_ZN1A3vf1Ev to i8*), i8* bitcast (i32 (%struct.A*)* @_ZN1A3vf2Ev to i8*)] }, align 8
+
+define i32 @_Z2g1v() {
+entry:
+  %a = alloca %struct.A, align 8
+  %0 = bitcast %struct.A* %a to i8*
+  %1 = getelementptr %struct.A, %struct.A* %a, i64 0, i32 0
+  store i32 (...)** bitcast (i8** getelementptr inbounds ({ [4 x i8*] }, { [4 x i8*] }* @_ZTV1A, i64 0, inrange i32 0, i64 2) to i32 (...)**), i32 (...)*** %1, align 8
+  %2 = bitcast %struct.A* %a to i8*
+  %3 = bitcast i8* %2 to i8**
+  %vtable.i = load i8*, i8** %3, align 8
+  %4 = bitcast i8* %vtable.i to i32 (%struct.A*)**
+  %memptr.virtualfn.i = load i32 (%struct.A*)*, i32 (%struct.A*)** %4, align 8
+  %call.i = call i32 %memptr.virtualfn.i(%struct.A* %a)
+  ret i32 %call.i
+}
+
+define i32 @_Z2g2v() {
+entry:
+  %a = alloca %struct.A, align 8
+  %0 = bitcast %struct.A* %a to i8*
+  %1 = getelementptr %struct.A, %struct.A* %a, i64 0, i32 0
+  store i32 (...)** bitcast (i8** getelementptr inbounds ({ [4 x i8*] }, { [4 x i8*] }* @_ZTV1A, i64 0, inrange i32 0, i64 2) to i32 (...)**), i32 (...)*** %1, align 8
+  %2 = bitcast %struct.A* %a to i8*
+  %3 = bitcast i8* %2 to i8**
+  %vtable.i = load i8*, i8** %3, align 8
+  %4 = getelementptr i8, i8* %vtable.i, i64 8
+  %5 = bitcast i8* %4 to i32 (%struct.A*)**
+  %memptr.virtualfn.i = load i32 (%struct.A*)*, i32 (%struct.A*)** %5, align 8
+  %call.i = call i32 %memptr.virtualfn.i(%struct.A* %a)
+  ret i32 %call.i
+}
+
+declare i32 @_ZN1A3vf1Ev(%struct.A* %this)
+declare i32 @_ZN1A3vf2Ev(%struct.A* %this)
+)IR");
+  ASSERT_TRUE(M);
+
+  auto *GV = M->getNamedValue("_Z2g1v");
+  ASSERT_TRUE(GV);
+  auto *F = dyn_cast<Function>(GV);
+  ASSERT_TRUE(F);
+  Instruction *Inst = &F->front().front();
+  auto *AI = dyn_cast<AllocaInst>(Inst);
+  ASSERT_TRUE(AI);
+  Inst = &*++F->front().rbegin();
+  auto *CI = dyn_cast<CallInst>(Inst);
+  ASSERT_TRUE(CI);
+  CallSite CS1(CI);
+  ASSERT_FALSE(CS1.getCalledFunction());
+  bool IsPromoted1 = tryPromoteCall(CS1);
+  EXPECT_TRUE(IsPromoted1);
+  GV = M->getNamedValue("_ZN1A3vf1Ev");
+  ASSERT_TRUE(GV);
+  F = dyn_cast<Function>(GV);
+  EXPECT_EQ(F, CS1.getCalledFunction());
+
+  GV = M->getNamedValue("_Z2g2v");
+  ASSERT_TRUE(GV);
+  F = dyn_cast<Function>(GV);
+  ASSERT_TRUE(F);
+  Inst = &F->front().front();
+  AI = dyn_cast<AllocaInst>(Inst);
+  ASSERT_TRUE(AI);
+  Inst = &*++F->front().rbegin();
+  CI = dyn_cast<CallInst>(Inst);
+  ASSERT_TRUE(CI);
+  CallSite CS2(CI);
+  ASSERT_FALSE(CS2.getCalledFunction());
+  bool IsPromoted2 = tryPromoteCall(CS2);
+  EXPECT_TRUE(IsPromoted2);
+  GV = M->getNamedValue("_ZN1A3vf2Ev");
+  ASSERT_TRUE(GV);
+  F = dyn_cast<Function>(GV);
+  EXPECT_EQ(F, CS2.getCalledFunction());
+}
+
+// A vtable slot holding a function whose signature does not match the call
+// site (here: an extra i32 parameter) must not be promoted.
+TEST(CallPromotionUtilsTest, TryPromoteCall_Legality) {
+  LLVMContext C;
+  std::unique_ptr<Module> M = parseIR(C,
+                                      R"IR(
+%class.Impl = type <{ %class.Interface, i32, [4 x i8] }>
+%class.Interface = type { i32 (...)** }
+
+@_ZTV4Impl = constant { [3 x i8*] } { [3 x i8*] [i8* null, i8* null, i8* bitcast (void (%class.Impl*, i32)* @_ZN4Impl3RunEi to i8*)] }
+
+define void @f() {
+entry:
+  %o = alloca %class.Impl
+  %base = getelementptr %class.Impl, %class.Impl* %o, i64 0, i32 0, i32 0
+  store i32 (...)** bitcast (i8** getelementptr inbounds ({ [3 x i8*] }, { [3 x i8*] }* @_ZTV4Impl, i64 0, inrange i32 0, i64 2) to i32 (...)**), i32 (...)*** %base
+  %f = getelementptr inbounds %class.Impl, %class.Impl* %o, i64 0, i32 1
+  store i32 3, i32* %f
+  %base.i = getelementptr inbounds %class.Impl, %class.Impl* %o, i64 0, i32 0
+  %c = bitcast %class.Interface* %base.i to void (%class.Interface*)***
+  %vtable.i = load void (%class.Interface*)**, void (%class.Interface*)*** %c
+  %fp = load void (%class.Interface*)*, void (%class.Interface*)** %vtable.i
+  call void %fp(%class.Interface* nonnull %base.i)
+  ret void
+}
+
+declare void @_ZN4Impl3RunEi(%class.Impl* %this, i32 %x)
+)IR");
+  ASSERT_TRUE(M);
+
+  auto *GV = M->getNamedValue("f");
+  ASSERT_TRUE(GV);
+  auto *F = dyn_cast<Function>(GV);
+  ASSERT_TRUE(F);
+  Instruction *Inst = &*++F->front().rbegin();
+  auto *CI = dyn_cast<CallInst>(Inst);
+  ASSERT_TRUE(CI);
+  CallSite CS(CI);
+  ASSERT_FALSE(CS.getCalledFunction());
+  bool IsPromoted = tryPromoteCall(CS);
+  EXPECT_FALSE(IsPromoted);
+  // On failure the call site must be left untouched.
+  EXPECT_FALSE(CS.getCalledFunction());
+}