diff --git a/llvm/lib/Transforms/Coroutines/CoroFrame.cpp b/llvm/lib/Transforms/Coroutines/CoroFrame.cpp --- a/llvm/lib/Transforms/Coroutines/CoroFrame.cpp +++ b/llvm/lib/Transforms/Coroutines/CoroFrame.cpp @@ -77,11 +77,14 @@ // // For every basic block 'i' it maintains a BlockData that consists of: // Consumes: a bit vector which contains a set of indices of blocks that can -// reach block 'i' +// reach block 'i'. A block can trivially reach itself. // Kills: a bit vector which contains a set of indices of blocks that can -// reach block 'i', but one of the path will cross a suspend point +// reach block 'i' but there is a path crossing a suspend point +// not repeating 'i' (path to 'i' without cycles containing 'i'). // Suspend: a boolean indicating whether block 'i' contains a suspend point. // End: a boolean indicating whether block 'i' contains a coro.end intrinsic. +// KillLoop: There is a path from 'i' to 'i' not otherwise repeating 'i' that +// crosses a suspend point. // namespace { struct SuspendCrossingInfo { @@ -92,6 +95,7 @@ BitVector Kills; bool Suspend = false; bool End = false; + bool KillLoop = false; }; SmallVector Block; @@ -109,16 +113,31 @@ SuspendCrossingInfo(Function &F, coro::Shape &Shape); - bool hasPathCrossingSuspendPoint(BasicBlock *DefBB, BasicBlock *UseBB) const { - size_t const DefIndex = Mapping.blockToIndex(DefBB); - size_t const UseIndex = Mapping.blockToIndex(UseBB); - - bool const Result = Block[UseIndex].Kills[DefIndex]; - LLVM_DEBUG(dbgs() << UseBB->getName() << " => " << DefBB->getName() + /// Returns true if there is a path from \p From to \p To crossing a suspend + /// point without crossing \p From a 2nd time. + bool hasPathCrossingSuspendPoint(BasicBlock *From, BasicBlock *To) const { + size_t const FromIndex = Mapping.blockToIndex(From); + size_t const ToIndex = Mapping.blockToIndex(To); + bool const Result = Block[ToIndex].Kills[FromIndex]; + LLVM_DEBUG(dbgs() << From->getName() << " => " << To->getName() << " answer is " << Result << "\n"); return Result; } + /// Returns true if there is a path from \p From to \p To crossing a suspend + /// point without crossing \p From a 2nd time. If \p From is the same as \p To + /// this will also check if there is a looping path crossing a suspend point. + bool hasPathOrLoopCrossingSuspendPoint(BasicBlock *From, + BasicBlock *To) const { + size_t const FromIndex = Mapping.blockToIndex(From); + size_t const ToIndex = Mapping.blockToIndex(To); + bool Result = Block[ToIndex].Kills[FromIndex] || + (From == To && Block[ToIndex].KillLoop); + LLVM_DEBUG(dbgs() << From->getName() << " => " << To->getName() + << " answer is " << Result << " (path or loop)\n"); + return Result; + } + bool isDefinitionAcrossSuspend(BasicBlock *DefBB, User *U) const { auto *I = cast(U); @@ -271,6 +290,7 @@ } else { // This is reached when S block it not Suspend nor coro.end and it // need to make sure that it is not in the kill set. + S.KillLoop |= S.Kills[SuccNo]; S.Kills.reset(SuccNo); } @@ -1440,6 +1460,19 @@ for (auto *S : LifetimeStarts) if (Checker.isDefinitionAcrossSuspend(*S, I)) return true; + // Addresses are guaranteed to be identical after every lifetime.start so + // we cannot use the local stack if the address escaped and there is a + // suspend point between lifetime markers. This should also cover the + // case of a single lifetime.start intrinsic in a loop with suspend point. + if (PI.isEscaped()) { + for (auto *A : LifetimeStarts) { + for (auto *B : LifetimeStarts) { + if (Checker.hasPathOrLoopCrossingSuspendPoint(A->getParent(), + B->getParent())) + return true; + } + } + } return false; } // FIXME: Ideally the isEscaped check should come at the beginning. diff --git a/llvm/test/Transforms/Coroutines/coro-alloca-loop-carried-address.ll b/llvm/test/Transforms/Coroutines/coro-alloca-loop-carried-address.ll new file mode 100644 --- /dev/null +++ b/llvm/test/Transforms/Coroutines/coro-alloca-loop-carried-address.ll @@ -0,0 +1,86 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py +; RUN: opt < %s -passes='cgscc(coro-split),simplifycfg,early-cse' -S | FileCheck %s + +@escape_hatch0 = external global i64 +@escape_hatch1 = external global i64 + +define void @foo() presplitcoroutine { +; CHECK-LABEL: @foo( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[STACKVAR0:%.*]] = alloca i64, align 8 +; CHECK-NEXT: [[ID:%.*]] = call token @llvm.coro.id(i32 0, ptr null, ptr null, ptr @foo.resumers) +; CHECK-NEXT: [[ALLOC:%.*]] = call ptr @malloc(i64 40) +; CHECK-NEXT: [[VFRAME:%.*]] = call noalias nonnull ptr @llvm.coro.begin(token [[ID]], ptr [[ALLOC]]) +; CHECK-NEXT: store ptr @foo.resume, ptr [[VFRAME]], align 8 +; CHECK-NEXT: [[DESTROY_ADDR:%.*]] = getelementptr inbounds [[FOO_FRAME:%.*]], ptr [[VFRAME]], i32 0, i32 1 +; CHECK-NEXT: store ptr @foo.destroy, ptr [[DESTROY_ADDR]], align 8 +; CHECK-NEXT: [[STACKVAR0_RELOAD_ADDR:%.*]] = getelementptr inbounds [[FOO_FRAME]], ptr [[VFRAME]], i32 0, i32 2 +; CHECK-NEXT: [[STACKVAR1_RELOAD_ADDR:%.*]] = getelementptr inbounds [[FOO_FRAME]], ptr [[VFRAME]], i32 0, i32 3 +; CHECK-NEXT: [[STACKVAR0_INT:%.*]] = ptrtoint ptr [[STACKVAR0_RELOAD_ADDR]] to i64 +; CHECK-NEXT: store i64 [[STACKVAR0_INT]], ptr @escape_hatch0, align 4 +; CHECK-NEXT: [[STACKVAR1_INT:%.*]] = ptrtoint ptr [[STACKVAR1_RELOAD_ADDR]] to i64 +; CHECK-NEXT: store i64 [[STACKVAR1_INT]], ptr @escape_hatch1, align 4 +; CHECK-NEXT: br label [[LOOP:%.*]] +; CHECK: loop: +; CHECK-NEXT: store i64 1234, ptr [[STACKVAR0_RELOAD_ADDR]], align 4 +; CHECK-NEXT: call void @bar() +; CHECK-NEXT: [[INDEX_ADDR1:%.*]] = getelementptr inbounds [[FOO_FRAME]], ptr [[VFRAME]], i32 0, i32 4 +; CHECK-NEXT: store i1 false, ptr [[INDEX_ADDR1]], align 1 +; CHECK-NEXT: br i1 false, label [[LOOP]], label [[AFTERCOROEND:%.*]] +; CHECK: AfterCoroEnd: +; CHECK-NEXT: ret void +; +entry: + %stackvar0 = alloca i64 + %stackvar1 = alloca i64 + + ; address of %stackvar escapes and may be relied upon even after + ; suspending/resuming the coroutine regardless of the lifetime markers. + %id = call token @llvm.coro.id(i32 0, ptr null, ptr null, ptr null) + %size = call i64 @llvm.coro.size.i64() + %alloc = call ptr @malloc(i64 %size) + %vFrame = call noalias nonnull ptr @llvm.coro.begin(token %id, ptr %alloc) + + ; %stackvar0 must be rewritten to reference the coroutine Frame! + %stackvar0_int = ptrtoint ptr %stackvar0 to i64 + store i64 %stackvar0_int, ptr @escape_hatch0 + ; %stackvar1 must be rewritten to reference the coroutine Frame! + %stackvar1_int = ptrtoint ptr %stackvar1 to i64 + store i64 %stackvar1_int, ptr @escape_hatch1 + + br label %loop + +loop: + call void @llvm.lifetime.start(i64 8, ptr %stackvar0) + + store i64 1234, ptr %stackvar0 + + ; Call could potentially change value in memory referenced by %stackvar0 / + ; %stackvar1 and rely on it staying the same across suspension. + call void @bar() + + call void @llvm.lifetime.end(i64 8, ptr %stackvar0) + + %save = call token @llvm.coro.save(ptr null) + %suspend = call i8 @llvm.coro.suspend(token %save, i1 false) + switch i8 %suspend, label %exit [ + i8 0, label %loop + i8 1, label %exit + ] + +exit: + call i1 @llvm.coro.end(ptr null, i1 false) + ret void +} + +declare void @bar() +declare ptr @malloc(i64) + +declare token @llvm.coro.id(i32, ptr readnone, ptr nocapture readonly, ptr) +declare i64 @llvm.coro.size.i64() +declare ptr @llvm.coro.begin(token, ptr writeonly) +declare token @llvm.coro.save(ptr) +declare i8 @llvm.coro.suspend(token, i1) +declare i1 @llvm.coro.end(ptr, i1) +declare void @llvm.lifetime.start(i64, ptr nocapture) +declare void @llvm.lifetime.end(i64, ptr nocapture)