Index: llvm/trunk/lib/Transforms/Coroutines/CoroFrame.cpp
===================================================================
--- llvm/trunk/lib/Transforms/Coroutines/CoroFrame.cpp
+++ llvm/trunk/lib/Transforms/Coroutines/CoroFrame.cpp
@@ -171,19 +171,22 @@
   for (auto *CE : Shape.CoroEnds)
     getBlockData(CE->getParent()).End = true;
 
-  // Mark all suspend blocks and indicate that kill everything they consume.
-  // Note, that crossing coro.save is used to indicate suspend, as any code
+  // Mark all suspend blocks and indicate that they kill everything they
+  // consume. Note that crossing coro.save also requires a spill, as any code
   // between coro.save and coro.suspend may resume the coroutine and all of the
   // state needs to be saved by that time.
-  for (CoroSuspendInst *CSI : Shape.CoroSuspends) {
-    CoroSaveInst *const CoroSave = CSI->getCoroSave();
-    BasicBlock *const CoroSaveBB = CoroSave->getParent();
-    auto &B = getBlockData(CoroSaveBB);
+  auto markSuspendBlock = [&](IntrinsicInst *BarrierInst) {
+    BasicBlock *SuspendBlock = BarrierInst->getParent();
+    auto &B = getBlockData(SuspendBlock);
     B.Suspend = true;
     B.Kills |= B.Consumes;
+  };
+  for (CoroSuspendInst *CSI : Shape.CoroSuspends) {
+    markSuspendBlock(CSI);
+    markSuspendBlock(CSI->getCoroSave());
   }
 
-  // Iterate propagating consumes and kills until they stop changing
+  // Iterate propagating consumes and kills until they stop changing.
   int Iteration = 0;
   (void)Iteration;
 
@@ -533,6 +536,13 @@
          isa<BinaryOperator>(&V) || isa<CmpInst>(&V) || isa<SelectInst>(&V);
 }
 
+// Check for structural coroutine intrinsics that should not be spilled into
+// the coroutine frame.
+static bool isCoroutineStructureIntrinsic(Instruction &I) {
+  return isa<CoroIdInst>(&I) || isa<CoroBeginInst>(&I) ||
+         isa<CoroSaveInst>(&I) || isa<CoroSuspendInst>(&I);
+}
+
 // For every use of the value that is across suspend point, recreate that value
 // after a suspend point.
 static void rewriteMaterializableInstructions(IRBuilder<> &IRB,
@@ -647,10 +657,13 @@
     Shape.CoroBegin->getId()->clearPromise();
   }
 
-  // Make sure that all coro.saves and the fallthrough coro.end are in their
-  // own block to simplify the logic of building up SuspendCrossing data.
-  for (CoroSuspendInst *CSI : Shape.CoroSuspends)
+  // Make sure that all coro.save, coro.suspend and the fallthrough coro.end
+  // intrinsics are in their own blocks to simplify the logic of building up
+  // SuspendCrossing data.
+  for (CoroSuspendInst *CSI : Shape.CoroSuspends) {
     splitAround(CSI->getCoroSave(), "CoroSave");
+    splitAround(CSI, "CoroSuspend");
+  }
 
   // Put fallthrough CoroEnd into its own block. Note: Shape::buildFrom places
   // the fallthrough coro.end as the first element of CoroEnds array.
@@ -686,18 +699,9 @@
         Spills.emplace_back(&A, U);
 
   for (Instruction &I : instructions(F)) {
-    // token returned by CoroSave is an artifact of how we build save/suspend
-    // pairs and should not be part of the Coroutine Frame
-    if (isa<CoroSaveInst>(&I))
-      continue;
-    // CoroBeginInst returns a handle to a coroutine which is passed as a sole
-    // parameter to .resume and .cleanup parts and should not go into coroutine
-    // frame.
-    if (isa<CoroBeginInst>(&I))
-      continue;
-    // A token returned CoroIdInst is used to tie together structural intrinsics
-    // in a coroutine. It should not be saved to the coroutine frame.
-    if (isa<CoroIdInst>(&I))
+    // Values returned from coroutine structure intrinsics should not be part
+    // of the Coroutine Frame.
+    if (isCoroutineStructureIntrinsic(I))
       continue;
     // The Coroutine Promise always included into coroutine frame, no need to
     // check for suspend crossing.
Index: llvm/trunk/test/Transforms/Coroutines/coro-split-02.ll
===================================================================
--- llvm/trunk/test/Transforms/Coroutines/coro-split-02.ll
+++ llvm/trunk/test/Transforms/Coroutines/coro-split-02.ll
@@ -0,0 +1,54 @@
+; Tests that coro-split can handle the case when code after coro.suspend uses
+; a value produced between coro.save and coro.suspend (%Result.i19).
+; RUN: opt < %s -coro-split -S | FileCheck %s
+
+%"struct.std::coroutine_handle" = type { i8* }
+%"struct.std::coroutine_handle.0" = type { %"struct.std::coroutine_handle" }
+%"struct.lean_future::Awaiter" = type { i32, %"struct.std::coroutine_handle.0" }
+
+declare i8* @malloc(i64)
+declare void @print(i32)
+
+define void @a() "coroutine.presplit"="1" {
+entry:
+  %ref.tmp7 = alloca %"struct.lean_future::Awaiter", align 8
+  %id = call token @llvm.coro.id(i32 0, i8* null, i8* null, i8* null)
+  %alloc = call i8* @malloc(i64 16) #3
+  %vFrame = call noalias nonnull i8* @llvm.coro.begin(token %id, i8* %alloc)
+
+  %save = call token @llvm.coro.save(i8* null)
+  %Result.i19 = getelementptr inbounds %"struct.lean_future::Awaiter", %"struct.lean_future::Awaiter"* %ref.tmp7, i64 0, i32 0
+  %suspend = call i8 @llvm.coro.suspend(token %save, i1 false)
+  switch i8 %suspend, label %exit [
+    i8 0, label %await.ready
+    i8 1, label %exit
+  ]
+await.ready:
+  %val = load i32, i32* %Result.i19
+  call void @print(i32 %val)
+  br label %exit
+exit:
+  call void @llvm.coro.end(i8* null, i1 false)
+  ret void
+}
+
+; CHECK-LABEL: @a.resume(
+; CHECK: getelementptr inbounds %a.Frame
+; CHECK-NEXT: getelementptr inbounds %"struct.lean_future::Awaiter"
+; CHECK-NEXT: %val = load i32, i32* %Result
+; CHECK-NEXT: call void @print(i32 %val)
+; CHECK-NEXT: ret void
+
+declare token @llvm.coro.id(i32, i8* readnone, i8* nocapture readonly, i8*)
+declare i1 @llvm.coro.alloc(token) #3
+declare noalias nonnull i8* @"\01??2@YAPEAX_K@Z"(i64) local_unnamed_addr
+declare i64 @llvm.coro.size.i64() #5
+declare i8* @llvm.coro.begin(token, i8* writeonly) #3
+declare void @"\01?puts@@YAXZZ"(...)
+declare token @llvm.coro.save(i8*) #3
+declare i8* @llvm.coro.frame() #5
+declare i8 @llvm.coro.suspend(token, i1) #3
+declare void @"\01??3@YAXPEAX@Z"(i8*) local_unnamed_addr #10
+declare i8* @llvm.coro.free(token, i8* nocapture readonly) #2
+declare void @llvm.coro.end(i8*, i1) #3
+