diff --git a/llvm/include/llvm/Transforms/Utils/CallPromotionUtils.h b/llvm/include/llvm/Transforms/Utils/CallPromotionUtils.h
--- a/llvm/include/llvm/Transforms/Utils/CallPromotionUtils.h
+++ b/llvm/include/llvm/Transforms/Utils/CallPromotionUtils.h
@@ -48,6 +48,33 @@
 Instruction *promoteCallWithIfThenElse(CallSite CS, Function *Callee,
                                        MDNode *BranchWeights = nullptr);
 
+/// Try to promote (devirtualize) a virtual call on an Alloca. Return true on
+/// success, in which case \p CS has been rewritten in place to call the
+/// function found in the vtable directly. Return false, leaving the call
+/// untouched, if the pattern below is not matched or if the call cannot
+/// legally be promoted to that function (see isLegalToPromote), e.g. because
+/// the vtable slot holds __cxa_pure_virtual.
+///
+/// Look for a pattern like:
+///
+///  %o = alloca %class.Impl
+///  %1 = getelementptr %class.Impl, %class.Impl* %o, i64 0, i32 0, i32 0
+///  store i32 (...)** bitcast (i8** getelementptr inbounds
+///      ({ [3 x i8*] }, { [3 x i8*] }* @_ZTV4Impl, i64 0, inrange i32 0, i64 2)
+///      to i32 (...)**), i32 (...)*** %1
+///  %2 = getelementptr inbounds %class.Impl, %class.Impl* %o, i64 0, i32 0
+///  %3 = bitcast %class.Interface* %2 to void (%class.Interface*)***
+///  %vtable.i = load void (%class.Interface*)**, void (%class.Interface*)*** %3
+///  %4 = load void (%class.Interface*)*, void (%class.Interface*)** %vtable.i
+///  call void %4(%class.Interface* nonnull %2)
+///
+///  @_ZTV4Impl = linkonce_odr dso_local unnamed_addr constant { [3 x i8*] }
+///    { [3 x i8*]
+///    [i8* null, i8* bitcast ({ i8*, i8*, i8* }* @_ZTI4Impl to i8*),
+///    i8* bitcast (void (%class.Impl*)* @_ZN4Impl3RunEv to i8*)] }
+///
+bool tryPromoteCall(CallSite &CS);
+
 } // end namespace llvm
 
 #endif // LLVM_TRANSFORMS_UTILS_CALLPROMOTIONUTILS_H
diff --git a/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp b/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp
--- a/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp
+++ b/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp
@@ -12,6 +12,8 @@
 //===----------------------------------------------------------------------===//
 
 #include "llvm/Transforms/Utils/CallPromotionUtils.h"
+#include "llvm/Analysis/Loads.h"
+#include "llvm/Analysis/TypeMetadataUtils.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
 
@@ -458,4 +460,70 @@
   return promoteCall(CallSite(NewInst), Callee);
 }
 
+bool llvm::tryPromoteCall(CallSite &CS) {
+  assert(!CS.getCalledFunction());
+  Module *M = CS.getCaller()->getParent();
+  const DataLayout &DL = M->getDataLayout();
+  Value *Callee = CS.getCalledValue();
+
+  LoadInst *VTableEntryLoad = dyn_cast<LoadInst>(Callee);
+  if (!VTableEntryLoad)
+    return false; // Not a vtable entry load.
+  Value *VTableEntryPtr = VTableEntryLoad->getPointerOperand();
+  // stripAndAccumulateConstantOffsets() requires each offset to be exactly as
+  // wide as the index type of the pointer it walks, which need not equal the
+  // pointer width, so size the offsets with getIndexTypeSizeInBits().
+  APInt VTableOffset(DL.getIndexTypeSizeInBits(VTableEntryPtr->getType()), 0);
+  Value *VTableBasePtr = VTableEntryPtr->stripAndAccumulateConstantOffsets(
+      DL, VTableOffset, /* AllowNonInbounds */ false);
+  LoadInst *VTablePtrLoad = dyn_cast<LoadInst>(VTableBasePtr);
+  if (!VTablePtrLoad)
+    return false; // Not a vtable load.
+  Value *Object = VTablePtrLoad->getPointerOperand();
+  APInt ObjectOffset(DL.getIndexTypeSizeInBits(Object->getType()), 0);
+  Value *ObjectBase = Object->stripAndAccumulateConstantOffsets(
+      DL, ObjectOffset, /* AllowNonInbounds */ false);
+  if (!(isa<AllocaInst>(ObjectBase) && ObjectOffset == 0))
+    // Not an Alloca or the offset isn't zero.
+    return false;
+
+  // Look for the vtable pointer store into the object by the ctor.
+  BasicBlock::iterator BBI(VTablePtrLoad);
+  Value *VTablePtr = FindAvailableLoadedValue(
+      VTablePtrLoad, VTablePtrLoad->getParent(), BBI, 0, nullptr, nullptr);
+  if (!VTablePtr)
+    return false; // No vtable found.
+  APInt VTableOffsetGVBase(DL.getIndexTypeSizeInBits(VTablePtr->getType()), 0);
+  Value *VTableGVBase = VTablePtr->stripAndAccumulateConstantOffsets(
+      DL, VTableOffsetGVBase, /* AllowNonInbounds */ false);
+  GlobalVariable *GV = dyn_cast<GlobalVariable>(VTableGVBase);
+  if (!(GV && GV->isConstant() && GV->hasDefinitiveInitializer()))
+    // Not in the form of a global constant variable with an initializer.
+    return false;
+
+  Constant *VTableGVInitializer = GV->getInitializer();
+  APInt VTableGVOffset = VTableOffsetGVBase + VTableOffset;
+  if (!(VTableGVOffset.getActiveBits() <= 64))
+    return false; // Out of range.
+  Constant *Ptr = getPointerAtOffset(VTableGVInitializer,
+                                     VTableGVOffset.getZExtValue(),
+                                     *M);
+  if (!Ptr)
+    return false; // No constant (function) pointer found.
+  Function *DirectCallee = dyn_cast<Function>(Ptr->stripPointerCasts());
+  if (!DirectCallee)
+    return false; // No function pointer found.
+
+  // The slot may hold a function whose type does not fit this call site,
+  // e.g. __cxa_pure_virtual (void ()) in the vtable of an abstract class.
+  // promoteCall() can only insert bitcasts; it cannot fix an argument-count
+  // mismatch or incompatible argument/return types.
+  if (!isLegalToPromote(CS, DirectCallee))
+    return false; // Promotion would produce invalid IR.
+
+  // Success.
+  promoteCall(CS, DirectCallee);
+  return true;
+}
+
 #undef DEBUG_TYPE
diff --git a/llvm/unittests/Transforms/Utils/CMakeLists.txt b/llvm/unittests/Transforms/Utils/CMakeLists.txt
--- a/llvm/unittests/Transforms/Utils/CMakeLists.txt
+++ b/llvm/unittests/Transforms/Utils/CMakeLists.txt
@@ -9,6 +9,7 @@
 add_llvm_unittest(UtilsTests
   ASanStackFrameLayoutTest.cpp
   BasicBlockUtilsTest.cpp
+  CallPromotionUtilsTest.cpp
   CloningTest.cpp
   CodeExtractorTest.cpp
   CodeMoverUtilsTest.cpp
diff --git a/llvm/unittests/Transforms/Utils/CallPromotionUtilsTest.cpp b/llvm/unittests/Transforms/Utils/CallPromotionUtilsTest.cpp
new file mode 100644
--- /dev/null
+++ b/llvm/unittests/Transforms/Utils/CallPromotionUtilsTest.cpp
@@ -0,0 +1,111 @@
+//===- CallPromotionUtilsTest.cpp - CallPromotionUtils unit tests ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/Utils/CallPromotionUtils.h"
+#include "llvm/AsmParser/Parser.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Module.h"
+#include "llvm/Support/SourceMgr.h"
+#include "gtest/gtest.h"
+
+using namespace llvm;
+
+static std::unique_ptr<Module> parseIR(LLVMContext &C, const char *IR) {
+  SMDiagnostic Err;
+  std::unique_ptr<Module> Mod = parseAssemblyString(IR, Err, C);
+  if (!Mod)
+    Err.print("UtilsTests", errs());
+  return Mod;
+}
+
+TEST(CallPromotionUtilsTest, TryPromoteCall) {
+  LLVMContext C;
+  std::unique_ptr<Module> M = parseIR(C,
+                                      R"IR(
+%class.Impl = type <{ %class.Interface, i32, [4 x i8] }>
+%class.Interface = type { i32 (...)** }
+
+@_ZTV4Impl = constant { [3 x i8*] } { [3 x i8*] [i8* null, i8* null, i8* bitcast (void (%class.Impl*)* @_ZN4Impl3RunEv to i8*)] }
+
+define void @f() {
+entry:
+  %o = alloca %class.Impl
+  %base = getelementptr %class.Impl, %class.Impl* %o, i64 0, i32 0, i32 0
+  store i32 (...)** bitcast (i8** getelementptr inbounds ({ [3 x i8*] }, { [3 x i8*] }* @_ZTV4Impl, i64 0, inrange i32 0, i64 2) to i32 (...)**), i32 (...)*** %base
+  %f = getelementptr inbounds %class.Impl, %class.Impl* %o, i64 0, i32 1
+  store i32 3, i32* %f
+  %base.i = getelementptr inbounds %class.Impl, %class.Impl* %o, i64 0, i32 0
+  %c = bitcast %class.Interface* %base.i to void (%class.Interface*)***
+  %vtable.i = load void (%class.Interface*)**, void (%class.Interface*)*** %c
+  %fp = load void (%class.Interface*)*, void (%class.Interface*)** %vtable.i
+  call void %fp(%class.Interface* nonnull %base.i)
+  ret void
+}
+
+declare void @_ZN4Impl3RunEv(%class.Impl* %this)
+)IR");
+  ASSERT_TRUE(M);
+
+  auto *GV = M->getNamedValue("f");
+  ASSERT_TRUE(GV);
+  auto *F = dyn_cast<Function>(GV);
+  ASSERT_TRUE(F);
+  Instruction *Inst = &F->front().front();
+  auto *AI = dyn_cast<AllocaInst>(Inst);
+  ASSERT_TRUE(AI);
+  Inst = &*++F->front().rbegin();
+  auto *CI = dyn_cast<CallInst>(Inst);
+  ASSERT_TRUE(CI);
+  CallSite CS(CI);
+  ASSERT_FALSE(CS.getCalledFunction());
+  bool IsPromoted = tryPromoteCall(CS);
+  EXPECT_TRUE(IsPromoted);
+  GV = M->getNamedValue("_ZN4Impl3RunEv");
+  ASSERT_TRUE(GV);
+  auto *F1 = dyn_cast<Function>(GV);
+  EXPECT_EQ(F1, CS.getCalledFunction());
+}
+
+// The vtable slot holds __cxa_pure_virtual, whose type (void ()) does not fit
+// the one-argument call site, so the call must be left indirect.
+TEST(CallPromotionUtilsTest, TryPromoteCallPureVirtual) {
+  LLVMContext C;
+  std::unique_ptr<Module> M = parseIR(C,
+                                      R"IR(
+%class.Impl = type <{ %class.Interface, i32, [4 x i8] }>
+%class.Interface = type { i32 (...)** }
+
+@_ZTV4Impl = constant { [3 x i8*] } { [3 x i8*] [i8* null, i8* null, i8* bitcast (void ()* @__cxa_pure_virtual to i8*)] }
+
+define void @f() {
+entry:
+  %o = alloca %class.Impl
+  %base = getelementptr %class.Impl, %class.Impl* %o, i64 0, i32 0, i32 0
+  store i32 (...)** bitcast (i8** getelementptr inbounds ({ [3 x i8*] }, { [3 x i8*] }* @_ZTV4Impl, i64 0, inrange i32 0, i64 2) to i32 (...)**), i32 (...)*** %base
+  %base.i = getelementptr inbounds %class.Impl, %class.Impl* %o, i64 0, i32 0
+  %c = bitcast %class.Interface* %base.i to void (%class.Interface*)***
+  %vtable.i = load void (%class.Interface*)**, void (%class.Interface*)*** %c
+  %fp = load void (%class.Interface*)*, void (%class.Interface*)** %vtable.i
+  call void %fp(%class.Interface* nonnull %base.i)
+  ret void
+}
+
+declare void @__cxa_pure_virtual()
+)IR");
+  ASSERT_TRUE(M);
+
+  Function *F = M->getFunction("f");
+  ASSERT_TRUE(F);
+  auto *CI = dyn_cast<CallInst>(&*++F->front().rbegin());
+  ASSERT_TRUE(CI);
+  CallSite CS(CI);
+  ASSERT_FALSE(CS.getCalledFunction());
+  EXPECT_FALSE(tryPromoteCall(CS));
+  EXPECT_FALSE(CS.getCalledFunction());
+}