diff --git a/llvm/lib/Transforms/Utils/CodeExtractor.cpp b/llvm/lib/Transforms/Utils/CodeExtractor.cpp
--- a/llvm/lib/Transforms/Utils/CodeExtractor.cpp
+++ b/llvm/lib/Transforms/Utils/CodeExtractor.cpp
@@ -535,6 +535,52 @@
       continue;
     }
 
+    // Find bitcasts in the outlined region that have lifetime marker users
+    // outside that region. Replace the lifetime marker use with an
+    // outside region bitcast to avoid unnecessary alloca/reload instructions
+    // and extra lifetime markers.
+    SmallVector<IntrinsicInst *, 2> LifetimeBitcastUsers;
+    for (User *U : AI->users()) {
+      if (!definedInRegion(Blocks, U))
+        continue;
+
+      // Only look through pointer casts (and all-zero GEPs), i.e. in-region
+      // users that point at the first byte of the alloca. A GEP with a
+      // non-zero constant offset points into the middle of the object, and
+      // re-pointing its markers at the alloca base would change their range.
+      if (U->stripPointerCasts() != AI)
+        continue;
+
+      Instruction *Bitcast = cast<Instruction>(U);
+      for (User *BU : Bitcast->users()) {
+        IntrinsicInst *IntrInst = dyn_cast<IntrinsicInst>(BU);
+        if (!IntrInst)
+          continue;
+
+        if (!IntrInst->isLifetimeStartOrEnd())
+          continue;
+
+        if (definedInRegion(Blocks, IntrInst))
+          continue;
+
+        LLVM_DEBUG(dbgs() << "Replace use of extracted region bitcast"
+                          << *Bitcast << " in out-of-region lifetime marker "
+                          << *IntrInst << "\n");
+        LifetimeBitcastUsers.push_back(IntrInst);
+      }
+    }
+
+    // Rebuild each marker's pointer operand from the alloca, right in front
+    // of the marker. Keep the operand's existing type: the lifetime intrinsics
+    // are overloaded on the pointer type, so casting to a hard-coded
+    // addrspace(0) i8* would break allocas in a non-zero address space.
+    for (IntrinsicInst *LifetimeMarker : LifetimeBitcastUsers) {
+      Value *OldPtr = LifetimeMarker->getArgOperand(1);
+      CastInst *CastI = CastInst::CreatePointerCast(
+          AI, OldPtr->getType(), "lt.cast", LifetimeMarker);
+      LifetimeMarker->setArgOperand(1, CastI);
+    }
+
     // Follow any bitcasts.
SmallVector Bitcasts; SmallVector BitcastLifetimeInfo; diff --git a/llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp b/llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp --- a/llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp +++ b/llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp @@ -282,4 +282,53 @@ EXPECT_FALSE(verifyFunction(*Func)); EXPECT_FALSE(CE.verifyAssumptionCache(*Func, *Outlined, &AC)); } + +TEST(CodeExtractor, RemoveBitcastUsesFromOuterLifetimeMarkers) { + LLVMContext Ctx; + SMDiagnostic Err; + std::unique_ptr M(parseAssemblyString(R"ir( + target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128" + target triple = "x86_64-unknown-linux-gnu" + + declare void @use(i32*) + declare void @llvm.lifetime.start.p0i8(i64, i8*) + declare void @llvm.lifetime.end.p0i8(i64, i8*) + + define void @foo() { + entry: + %0 = alloca i32 + br label %extract + + extract: + %1 = bitcast i32* %0 to i8* + call void @llvm.lifetime.start.p0i8(i64 4, i8* %1) + call void @use(i32* %0) + br label %exit + + exit: + call void @use(i32* %0) + call void @llvm.lifetime.end.p0i8(i64 4, i8* %1) + ret void + } + )ir", + Err, Ctx)); + + Function *Func = M->getFunction("foo"); + SmallVector Blocks{getBlockByName(Func, "extract")}; + + CodeExtractor CE(Blocks); + EXPECT_TRUE(CE.isEligible()); + + CodeExtractorAnalysisCache CEAC(*Func); + SetVector Inputs, Outputs, SinkingCands, HoistingCands; + BasicBlock *CommonExit = nullptr; + CE.findAllocas(CEAC, SinkingCands, HoistingCands, CommonExit); + CE.findInputsOutputs(Inputs, Outputs, SinkingCands); + EXPECT_EQ(Outputs.size(), 0U); + + Function *Outlined = CE.extractCodeRegion(CEAC); + EXPECT_TRUE(Outlined); + EXPECT_FALSE(verifyFunction(*Outlined)); + EXPECT_FALSE(verifyFunction(*Func)); +} } // end anonymous namespace