diff --git a/llvm/lib/Transforms/Coroutines/CoroSplit.cpp b/llvm/lib/Transforms/Coroutines/CoroSplit.cpp --- a/llvm/lib/Transforms/Coroutines/CoroSplit.cpp +++ b/llvm/lib/Transforms/Coroutines/CoroSplit.cpp @@ -75,7 +75,7 @@ namespace { -/// A little helper class for building +/// A little helper class for building class CoroCloner { public: enum class Kind { @@ -563,7 +563,7 @@ // In the original function, the AllocaSpillBlock is a block immediately // following the allocation of the frame object which defines GEPs for // all the allocas that have been moved into the frame, and it ends by - // branching to the original beginning of the coroutine. Make this + // branching to the original beginning of the coroutine. Make this // the entry block of the cloned function. auto *Entry = cast(VMap[Shape.AllocaSpillBlock]); auto *OldEntry = &NewF->getEntryBlock(); @@ -1239,6 +1239,138 @@ S.resize(N); }
+/// Sink the lifetime.start markers of local variables down to their first
+/// uses.
+///
+/// For every alloca whose bitcasts are used only by lifetime.start /
+/// lifetime.end intrinsics, a clone of a dominating lifetime.start marker is
+/// inserted right before each earliest user of the variable, and the
+/// original lifetime.start markers are erased. Doing so minimizes the
+/// lifetime of each variable, hence minimizing the amount of data we end up
+/// putting on the frame.
+///
+/// An alloca is left completely untouched when:
+///  - a bitcast of it has a user that is not a lifetime.start/lifetime.end
+///    intrinsic,
+///  - it has no lifetime.start marker,
+///  - one of its earliest users is a PHI node, because a call cannot be
+///    inserted in front of a PHI, or
+///  - one of its earliest users is not dominated by any lifetime.start
+///    marker, because there is then no marker that can be cloned in front
+///    of it and erasing the original markers would change the lifetime.
+///
+/// NOTE(review): the template arguments used below (isa<>, cast<>,
+/// dyn_cast<>, SmallPtrSet<>, SmallVector<>) were absent from the reviewed
+/// text of this patch and are reconstructed from the surrounding comments
+/// and usage; confirm them against the original change.
+static void sinkLifetimeStartMarkers(Function &F) {
+  DominatorTree Dom(F);
+  for (Instruction &I : instructions(F)) {
+    // We look for this particular pattern:
+    //   %tmpX = alloca %.., align ...
+    //   %0 = bitcast %...* %tmpX to i8*
+    //   call void @llvm.lifetime.start.p0i8(i64 ..., i8* nonnull %0) #2
+    if (!isa<AllocaInst>(&I))
+      continue;
+
+    // There can be multiple lifetime start markers for the same variable.
+    SmallPtrSet<IntrinsicInst *, 2> LifetimeStartInsts;
+    // SinkBarriers stores the earliest instructions that use this local
+    // variable. When sinking the lifetime start intrinsics, we can never
+    // sink past these barriers.
+    SmallPtrSet<Instruction *, 4> SinkBarriers;
+
+    // Adds Barrier to SinkBarriers while maintaining the invariant that no
+    // instruction in SinkBarriers dominates another one.
+    auto AddSinkBarrier = [&](Instruction *Barrier) {
+      SmallPtrSet<Instruction *, 4> Dominated;
+      for (Instruction *S : SinkBarriers) {
+        // Barrier is already covered by an existing barrier.
+        if (Barrier == S || Dom.dominates(S, Barrier))
+          return;
+        if (Dom.dominates(Barrier, S))
+          Dominated.insert(S);
+      }
+      SinkBarriers.insert(Barrier);
+      for (Instruction *S : Dominated)
+        SinkBarriers.erase(S);
+    };
+
+    // Though strange, in theory there can be multiple BitCast instructions.
+    bool Valid = true;
+    for (User *U : I.users()) {
+      if (!isa<BitCastInst>(U))
+        continue;
+      for (User *CU : U->users()) {
+        // If we see any user of the BitCast that's not a lifetime start/end
+        // intrinsic, give up because it's too complex.
+        auto *CUI = dyn_cast<IntrinsicInst>(CU);
+        if (CUI && CUI->getIntrinsicID() == Intrinsic::lifetime_start) {
+          LifetimeStartInsts.insert(CUI);
+        } else if (CUI && CUI->getIntrinsicID() == Intrinsic::lifetime_end) {
+          AddSinkBarrier(CUI);
+        } else {
+          Valid = false;
+          break;
+        }
+      }
+      if (!Valid)
+        break;
+    }
+    if (!Valid || LifetimeStartInsts.empty())
+      continue;
+
+    // Every other user of the variable is also a sink barrier.
+    for (User *U : I.users())
+      if (!isa<BitCastInst>(U))
+        AddSinkBarrier(cast<Instruction>(U));
+
+    // Decide what to do with every sink barrier before touching the IR, so
+    // that an alloca we cannot handle is left unmodified.
+    SmallVector<std::pair<Instruction *, IntrinsicInst *>, 4> Sinks;
+    SmallVector<Instruction *, 2> DeadLifetimeEnds;
+    for (Instruction *S : SinkBarriers) {
+      auto *IS = dyn_cast<IntrinsicInst>(S);
+      if (IS && IS->getIntrinsicID() == Intrinsic::lifetime_end) {
+        // If we have a lifetime end marker in SinkBarriers, meaning it's
+        // not dominated by any other users, we can safely delete it.
+        DeadLifetimeEnds.push_back(S);
+        continue;
+      }
+      // We find an existing lifetime.start marker that dominates the barrier
+      // so that it can be cloned right before the barrier. We cannot clone
+      // an arbitrary lifetime.start marker because we want to make sure the
+      // BitCast instruction referred in the marker also dominates the
+      // barrier. Nothing can be inserted in front of a PHI node.
+      IntrinsicInst *ToClone = nullptr;
+      if (!isa<PHINode>(S)) {
+        for (IntrinsicInst *LifetimeStart : LifetimeStartInsts) {
+          if (Dom.dominates(LifetimeStart, S)) {
+            ToClone = LifetimeStart;
+            break;
+          }
+        }
+      }
+      if (!ToClone) {
+        Valid = false;
+        break;
+      }
+      Sinks.emplace_back(S, ToClone);
+    }
+    if (!Valid)
+      continue;
+
+    for (Instruction *End : DeadLifetimeEnds)
+      End->eraseFromParent();
+    // Insert a lifetime start marker right before each remaining barrier.
+    for (const auto &Sink : Sinks)
+      Sink.second->clone()->insertBefore(Sink.first);
+    // All the old lifetime.start markers are no longer necessary.
+    for (IntrinsicInst *LifetimeStart : LifetimeStartInsts)
+      LifetimeStart->eraseFromParent();
+  }
+}
+
 static void splitSwitchCoroutine(Function &F, coro::Shape &Shape, SmallVectorImpl &Clones) { assert(Shape.ABI == coro::ABI::Switch); @@ -1428,6 +1560,7 @@ return Shape; simplifySuspendPoints(Shape); + sinkLifetimeStartMarkers(F); buildCoroutineFrame(F, Shape); replaceFrameSize(Shape); diff --git a/llvm/test/Transforms/Coroutines/coro-split-02.ll b/llvm/test/Transforms/Coroutines/coro-split-02.ll --- a/llvm/test/Transforms/Coroutines/coro-split-02.ll +++ b/llvm/test/Transforms/Coroutines/coro-split-02.ll @@ -31,6 +31,8 @@ %val = load i32, i32* %Result.i19 %cast = bitcast i32* %testval to i8* call void @llvm.lifetime.start.p0i8(i64 4, i8* %cast) + %test = load i32, i32* %testval + call void @print(i32 %test) call void @llvm.lifetime.end.p0i8(i64 4, i8* %cast) call void @print(i32 %val) br label %exit @@ -40,14 +42,15 @@ } ; CHECK-LABEL: @a.resume( -; CHECK: %testval = alloca i32 ; CHECK: getelementptr inbounds %a.Frame ; CHECK-NEXT: getelementptr inbounds %"struct.lean_future::Awaiter" ; CHECK-NOT: call token @llvm.coro.save(i8* null) ; CHECK-NEXT: %val = load i32, i32* %Result ; CHECK-NEXT: %cast = bitcast i32* %testval to i8* ; CHECK-NEXT: call void @llvm.lifetime.start.p0i8(i64 4, i8* %cast) -; CHECK-NEXT: call void @llvm.lifetime.end.p0i8(i64 4, i8* %cast) +; CHECK-NEXT: %test = load i32, i32* %testval +; CHECK-NEXT: call void @print(i32 %test) +; CHECK-NEXT: call void @llvm.lifetime.end.p0i8(i64 4, i8* %cast) ; CHECK-NEXT: call void @print(i32 %val) ; CHECK-NEXT: ret void diff --git a/llvm/test/Transforms/Coroutines/coro-split-02.ll b/llvm/test/Transforms/Coroutines/coro-split-sink-lifetime.ll copy from llvm/test/Transforms/Coroutines/coro-split-02.ll copy to
llvm/test/Transforms/Coroutines/coro-split-sink-lifetime.ll --- a/llvm/test/Transforms/Coroutines/coro-split-02.ll +++ b/llvm/test/Transforms/Coroutines/coro-split-sink-lifetime.ll @@ -1,6 +1,5 @@ -; Tests that coro-split can handle the case when a code after coro.suspend uses -; a value produces between coro.save and coro.suspend (%Result.i19) -; and checks whether stray coro.saves are properly removed +; Tests that coro-split sinks the lifetime.start marker of each local variable +; to the place closest to its actual use. ; RUN: opt < %s -coro-split -S | FileCheck %s ; RUN: opt < %s -passes=coro-split -S | FileCheck %s @@ -15,6 +14,9 @@ entry: %ref.tmp7 = alloca %"struct.lean_future::Awaiter", align 8 %testval = alloca i32 + %cast = bitcast i32* %testval to i8* + ; lifetime of %testval starts here, but it is not used until await.ready. + call void @llvm.lifetime.start.p0i8(i64 4, i8* %cast) %id = call token @llvm.coro.id(i32 0, i8* null, i8* null, i8* null) %alloc = call i8* @malloc(i64 16) #3 %vFrame = call noalias nonnull i8* @llvm.coro.begin(token %id, i8* %alloc) @@ -29,8 +31,8 @@ await.ready: %StrayCoroSave = call token @llvm.coro.save(i8* null) %val = load i32, i32* %Result.i19 - %cast = bitcast i32* %testval to i8* - call void @llvm.lifetime.start.p0i8(i64 4, i8* %cast) + %test = load i32, i32* %testval + call void @print(i32 %test) call void @llvm.lifetime.end.p0i8(i64 4, i8* %cast) call void @print(i32 %val) br label %exit @@ -40,14 +42,15 @@ } ; CHECK-LABEL: @a.resume( -; CHECK: %testval = alloca i32 -; CHECK: getelementptr inbounds %a.Frame +; CHECK: %testval = alloca i32, align 4 +; CHECK-NEXT: getelementptr inbounds %a.Frame ; CHECK-NEXT: getelementptr inbounds %"struct.lean_future::Awaiter" -; CHECK-NOT: call token @llvm.coro.save(i8* null) +; CHECK-NEXT: %cast1 = bitcast i32* %testval to i8* ; CHECK-NEXT: %val = load i32, i32* %Result -; CHECK-NEXT: %cast = bitcast i32* %testval to i8* -; CHECK-NEXT: call void
@llvm.lifetime.start.p0i8(i64 4, i8* %cast) -; CHECK-NEXT: call void @llvm.lifetime.end.p0i8(i64 4, i8* %cast) +; CHECK-NEXT: call void @llvm.lifetime.start.p0i8(i64 4, i8* %cast1) +; CHECK-NEXT: %test = load i32, i32* %testval +; CHECK-NEXT: call void @print(i32 %test) +; CHECK-NEXT: call void @llvm.lifetime.end.p0i8(i64 4, i8* %cast1) ; CHECK-NEXT: call void @print(i32 %val) ; CHECK-NEXT: ret void