diff --git a/llvm/lib/Transforms/Coroutines/CoroFrame.cpp b/llvm/lib/Transforms/Coroutines/CoroFrame.cpp
--- a/llvm/lib/Transforms/Coroutines/CoroFrame.cpp
+++ b/llvm/lib/Transforms/Coroutines/CoroFrame.cpp
@@ -108,7 +108,6 @@
   size_t const DefIndex = Mapping.blockToIndex(DefBB);
   size_t const UseIndex = Mapping.blockToIndex(UseBB);
 
-  assert(Block[UseIndex].Consumes[DefIndex] && "use must consume def");
   bool const Result = Block[UseIndex].Kills[DefIndex];
   LLVM_DEBUG(dbgs() << UseBB->getName() << " => " << DefBB->getName()
                     << " answer is " << Result << "\n");
@@ -1396,6 +1395,30 @@
     Spills.clear();
   }
 
+  // Collect lifetime.start info for each alloca.
+  using LifetimeStart = SmallPtrSet<Instruction *, 2>;
+  llvm::DenseMap<Instruction *, std::unique_ptr<LifetimeStart>> LifetimeMap;
+  for (Instruction &I : instructions(F)) {
+    auto *II = dyn_cast<IntrinsicInst>(&I);
+    if (!II || II->getIntrinsicID() != Intrinsic::lifetime_start)
+      continue;
+
+    // The pointer operand may be the alloca itself or a cast of it.
+    auto *OpInst = dyn_cast<Instruction>(II->getOperand(1));
+    if (!OpInst)
+      continue;
+    auto *AI = dyn_cast<AllocaInst>(OpInst->stripPointerCasts());
+    if (!AI)
+      continue;
+
+    // When the marker uses the alloca directly, the intrinsic itself is the
+    // first point at which the alloca's contents are live.
+    std::unique_ptr<LifetimeStart> &Starts = LifetimeMap[AI];
+    if (!Starts)
+      Starts = std::make_unique<LifetimeStart>();
+    Starts->insert(isa<AllocaInst>(OpInst) ? II : OpInst);
+  }
+
   // Collect the spills for arguments and other not-materializable values.
   for (Argument &A : F.args())
     for (User *U : A.users())
@@ -1441,14 +1464,27 @@
       continue;
     }
 
-    for (User *U : I.users())
-      if (Checker.isDefinitionAcrossSuspend(I, U)) {
+    auto Iter = LifetimeMap.find(&I);
+    for (User *U : I.users()) {
+      // An alloca with lifetime.start markers only needs a frame slot when
+      // one of those markers is separated from the use by a suspend point;
+      // any other value is checked against its own definition.
+      bool NeedSpill = false;
+      if (Iter != LifetimeMap.end())
+        NeedSpill = llvm::any_of(*Iter->second, [&](Instruction *S) {
+          return Checker.isDefinitionAcrossSuspend(*S, U);
+        });
+      else
+        NeedSpill = Checker.isDefinitionAcrossSuspend(I, U);
+
+      if (NeedSpill) {
         // We cannot spill a token.
         if (I.getType()->isTokenTy())
           report_fatal_error(
               "token definition is separated from the use by a suspend point");
         Spills.emplace_back(&I, U);
       }
+    }
   }
   LLVM_DEBUG(dump("Spills", Spills));
   Shape.FrameTy = buildFrameType(F, Shape, Spills);
diff --git a/llvm/lib/Transforms/Coroutines/CoroSplit.cpp b/llvm/lib/Transforms/Coroutines/CoroSplit.cpp
--- a/llvm/lib/Transforms/Coroutines/CoroSplit.cpp
+++ b/llvm/lib/Transforms/Coroutines/CoroSplit.cpp
@@ -567,8 +567,9 @@
   // branching to the original beginning of the coroutine. Make this
   // the entry block of the cloned function.
   auto *Entry = cast<BasicBlock>(VMap[Shape.AllocaSpillBlock]);
+  auto *OldEntry = &NewF->getEntryBlock();
   Entry->setName("entry" + Suffix);
-  Entry->moveBefore(&NewF->getEntryBlock());
+  Entry->moveBefore(OldEntry);
   Entry->getTerminator()->eraseFromParent();
 
   // Clear all predecessors of the new entry block. There should be
@@ -581,8 +582,15 @@
   Builder.CreateUnreachable();
   BranchToEntry->eraseFromParent();
 
-  // TODO: move any allocas into Entry that weren't moved into the frame.
-  // (Currently we move all allocas into the frame.)
+  // Move any allocas into Entry that weren't moved into the frame. Capture
+  // the insertion point once so the moved allocas keep their relative order.
+  BasicBlock::iterator AllocaInsertPt = Entry->getFirstInsertionPt();
+  for (Instruction &I : llvm::make_early_inc_range(*OldEntry)) {
+    if (!isa<AllocaInst>(&I) || I.use_empty())
+      continue;
+
+    I.moveBefore(*Entry, AllocaInsertPt);
+  }
 
   // Branch from the entry to the appropriate place.
   Builder.SetInsertPoint(Entry);
diff --git a/llvm/test/Transforms/Coroutines/coro-split-02.ll b/llvm/test/Transforms/Coroutines/coro-split-02.ll
--- a/llvm/test/Transforms/Coroutines/coro-split-02.ll
+++ b/llvm/test/Transforms/Coroutines/coro-split-02.ll
@@ -14,6 +14,8 @@
 define void @a() "coroutine.presplit"="1" {
 entry:
   %ref.tmp7 = alloca %"struct.lean_future::Awaiter", align 8
+  %testval = alloca i32
+  %testval2 = alloca i8
   %id = call token @llvm.coro.id(i32 0, i8* null, i8* null, i8* null)
   %alloc = call i8* @malloc(i64 16) #3
   %vFrame = call noalias nonnull i8* @llvm.coro.begin(token %id, i8* %alloc)
@@ -28,6 +30,11 @@
 await.ready:
   %StrayCoroSave = call token @llvm.coro.save(i8* null)
   %val = load i32, i32* %Result.i19
+  %cast = bitcast i32* %testval to i8*
+  call void @llvm.lifetime.start.p0i8(i64 4, i8* %cast)
+  call void @llvm.lifetime.end.p0i8(i64 4, i8* %cast)
+  call void @llvm.lifetime.start.p0i8(i64 1, i8* %testval2)
+  call void @llvm.lifetime.end.p0i8(i64 1, i8* %testval2)
   call void @print(i32 %val)
   br label %exit
 exit:
@@ -36,10 +43,17 @@
 }
 
 ; CHECK-LABEL: @a.resume(
+; CHECK:         %testval = alloca i32
+; CHECK-NEXT:    %testval2 = alloca i8
 ; CHECK:         getelementptr inbounds %a.Frame
 ; CHECK-NEXT:    getelementptr inbounds %"struct.lean_future::Awaiter"
 ; CHECK-NOT:     call token @llvm.coro.save(i8* null)
 ; CHECK-NEXT:    %val = load i32, i32* %Result
+; CHECK-NEXT:    %cast = bitcast i32* %testval to i8*
+; CHECK-NEXT:    call void @llvm.lifetime.start.p0i8(i64 4, i8* %cast)
+; CHECK-NEXT:    call void @llvm.lifetime.end.p0i8(i64 4, i8* %cast)
+; CHECK-NEXT:    call void @llvm.lifetime.start.p0i8(i64 1, i8* %testval2)
+; CHECK-NEXT:    call void @llvm.lifetime.end.p0i8(i64 1, i8* %testval2)
 ; CHECK-NEXT:    call void @print(i32 %val)
 ; CHECK-NEXT:    ret void
 
@@ -55,4 +69,6 @@
 declare void @"\01??3@YAXPEAX@Z"(i8*) local_unnamed_addr #10
 declare i8* @llvm.coro.free(token, i8* nocapture readonly) #2
 declare i1 @llvm.coro.end(i8*, i1) #3
+declare void @llvm.lifetime.start.p0i8(i64, i8* nocapture) #4
+declare void @llvm.lifetime.end.p0i8(i64, i8* nocapture) #4
 