Index: llvm/trunk/lib/Transforms/Coroutines/CoroSplit.cpp
===================================================================
--- llvm/trunk/lib/Transforms/Coroutines/CoroSplit.cpp
+++ llvm/trunk/lib/Transforms/Coroutines/CoroSplit.cpp
@@ -27,6 +27,7 @@
 #include "llvm/IR/LegacyPassManager.h"
 #include "llvm/IR/Verifier.h"
 #include "llvm/Transforms/Scalar.h"
+#include "llvm/Transforms/Utils/BasicBlockUtils.h"
 #include "llvm/Transforms/Utils/Cloning.h"
 #include "llvm/Transforms/Utils/Local.h"
 #include "llvm/Transforms/Utils/ValueMapper.h"
@@ -400,6 +401,128 @@
   FPM.doFinalization();
 }
 
+// Assuming we arrived at the block NewBlock from the Prev instruction, record
+// the incoming value of each of NewBlock's PHIs in the ResolvedValues map,
+// looking through values that were themselves resolved earlier.
+//
+// All PHIs of a block take their values simultaneously on entry, so resolve
+// every incoming value against the map as it stood before entering NewBlock
+// and only then publish the results. Updating the map while scanning would
+// let a PHI whose incoming value is a sibling PHI of the same block observe
+// the sibling's new value instead of its old one.
+static void
+scanPHIsAndUpdateValueMap(Instruction *Prev, BasicBlock *NewBlock,
+                          DenseMap<Value *, Value *> &ResolvedValues) {
+  auto *PrevBB = Prev->getParent();
+  SmallVector<std::pair<PHINode *, Value *>, 8> NewValues;
+  for (Instruction &I : *NewBlock) {
+    auto *PN = dyn_cast<PHINode>(&I);
+    if (!PN)
+      break; // PHIs are always grouped at the start of a block.
+    Value *V = PN->getIncomingValueForBlock(PrevBB);
+    // See if we already resolved it.
+    auto VI = ResolvedValues.find(V);
+    if (VI != ResolvedValues.end())
+      V = VI->second;
+    NewValues.emplace_back(PN, V);
+  }
+
+  // Remember the values.
+  for (auto &Entry : NewValues)
+    ResolvedValues[Entry.first] = Entry.second;
+}
+
+// Replace a sequence of branches leading to a ret with a clone of that ret
+// instruction. Suspend points are represented by switches: track the PHI
+// values along the path and select the correct case successor when a switch
+// condition resolves to a constant. Returns true if InitialInst is, or has
+// been replaced by, a ret.
+static bool simplifyTerminatorLeadingToRet(Instruction *InitialInst) {
+  DenseMap<Value *, Value *> ResolvedValues;
+  // Blocks entered so far. Reaching a block a second time means the path is
+  // a cycle (e.g. `br label %self`) that never reaches a ret; without this
+  // check the walk below would never terminate.
+  SmallPtrSet<BasicBlock *, 8> Visited;
+  Visited.insert(InitialInst->getParent());
+
+  Instruction *I = InitialInst;
+  while (I->isTerminator()) {
+    if (isa<ReturnInst>(I)) {
+      if (I != InitialInst) {
+        // The ret may use PHIs of the blocks we walked through; rewrite its
+        // operands to the values resolved along the path actually taken.
+        Instruction *Ret = I->clone();
+        for (Use &Op : Ret->operands())
+          if (Value *Resolved = ResolvedValues.lookup(Op.get()))
+            Op.set(Resolved);
+
+        // InitialInst's block stops branching to its successors, so drop its
+        // entries from their PHIs: those blocks may still be reachable from
+        // elsewhere. successors() yields one entry per edge, so a switch with
+        // several cases targeting the same block removes every PHI entry.
+        BasicBlock *BB = InitialInst->getParent();
+        for (BasicBlock *Successor : successors(BB))
+          Successor->removePredecessor(BB, /*DontDeleteUselessPHIs=*/true);
+
+        ReplaceInstWithInst(InitialInst, Ret);
+      }
+      return true;
+    }
+
+    // Determine the block this terminator transfers control to, if it is
+    // known statically.
+    BasicBlock *Succ = nullptr;
+    if (auto *BR = dyn_cast<BranchInst>(I)) {
+      if (BR->isUnconditional())
+        Succ = BR->getSuccessor(0);
+    } else if (auto *SI = dyn_cast<SwitchInst>(I)) {
+      Value *V = SI->getCondition();
+      auto It = ResolvedValues.find(V);
+      if (It != ResolvedValues.end())
+        V = It->second;
+      if (auto *Cond = dyn_cast<ConstantInt>(V))
+        Succ = SI->findCaseValue(Cond)->getCaseSuccessor();
+    }
+
+    if (!Succ || !Visited.insert(Succ).second)
+      return false;
+    scanPHIsAndUpdateValueMap(I, Succ, ResolvedValues);
+    I = Succ->getFirstNonPHIOrDbgOrLifetime();
+  }
+  return false;
+}
+
+// Add musttail to any resume instructions that are immediately followed by a
+// suspend (i.e. ret). We do this even in -O0 to support guaranteed tail call
+// for symmetrical coroutine control transfer (C++ Coroutines TS extension).
+// This transformation is done only in the resume part of the coroutine that has
+// identical signature and calling convention as the coro.resume call.
+static void addMustTailToCoroResumes(Function &F) {
+  bool changed = false;
+
+  // Collect potential resume instructions.
+  SmallVector<CallInst *, 4> Resumes;
+  for (auto &I : instructions(F))
+    if (auto *Call = dyn_cast<CallInst>(&I))
+      if (auto *CalledValue = Call->getCalledValue())
+        // CoroEarly pass replaced coro resumes with indirect calls to an
+        // address returned by the CoroSubFnInst intrinsic. See if it is one
+        // of those.
+        if (isa<CoroSubFnInst>(CalledValue->stripPointerCasts()))
+          Resumes.push_back(Call);
+
+  // Set musttail on those that are followed by a ret instruction.
+  for (CallInst *Call : Resumes)
+    if (simplifyTerminatorLeadingToRet(Call->getNextNode())) {
+      Call->setTailCallKind(CallInst::TCK_MustTail);
+      changed = true;
+    }
+
+  // Folding branches into a ret may have left blocks unreachable.
+  if (changed)
+    removeUnreachableBlocks(F);
+}
+
 // Coroutine has no suspend points. Remove heap allocation for the coroutine
 // frame if possible.
 static void handleNoSuspendCoroutine(CoroBeginInst *CoroBegin, Type *FrameTy) {
@@ -608,6 +731,8 @@
   postSplitCleanup(*DestroyClone);
   postSplitCleanup(*CleanupClone);
 
+  addMustTailToCoroResumes(*ResumeClone);
+
   // Store addresses resume/destroy/cleanup functions in the coroutine frame.
   updateCoroFrame(Shape, ResumeClone, DestroyClone, CleanupClone);
 
Index: llvm/trunk/test/Transforms/Coroutines/coro-split-musttail.ll
===================================================================
--- llvm/trunk/test/Transforms/Coroutines/coro-split-musttail.ll
+++ llvm/trunk/test/Transforms/Coroutines/coro-split-musttail.ll
@@ -0,0 +1,60 @@
+; Tests that coro-split will convert coro.resume followed by a suspend to a
+; musttail call.
+; RUN: opt < %s -coro-split -S | FileCheck %s
+
+define void @f() "coroutine.presplit"="1" {
+entry:
+  %id = call token @llvm.coro.id(i32 0, i8* null, i8* null, i8* null)
+  %alloc = call i8* @malloc(i64 16) #3
+  %vFrame = call noalias nonnull i8* @llvm.coro.begin(token %id, i8* %alloc)
+
+  %save = call token @llvm.coro.save(i8* null)
+  %addr1 = call i8* @llvm.coro.subfn.addr(i8* null, i8 0)
+  %pv1 = bitcast i8* %addr1 to void (i8*)*
+  call fastcc void %pv1(i8* null)
+
+  %suspend = call i8 @llvm.coro.suspend(token %save, i1 false)
+  switch i8 %suspend, label %exit [
+    i8 0, label %await.ready
+    i8 1, label %exit
+  ]
+await.ready:
+  %save2 = call token @llvm.coro.save(i8* null)
+  %addr2 = call i8* @llvm.coro.subfn.addr(i8* null, i8 0)
+  %pv2 = bitcast i8* %addr2 to void (i8*)*
+  call fastcc void %pv2(i8* null)
+
+  %suspend2 = call i8 @llvm.coro.suspend(token %save2, i1 false)
+  switch i8 %suspend2, label %exit [
+    i8 0, label %exit
+    i8 1, label %exit
+  ]
+exit:
+  call i1 @llvm.coro.end(i8* null, i1 false)
+  ret void
+}
+
+; Verify that in the initial function resume is not marked with musttail.
+; CHECK-LABEL: @f(
+; CHECK: %[[addr1:.+]] = call i8* @llvm.coro.subfn.addr(i8* null, i8 0)
+; CHECK-NEXT: %[[pv1:.+]] = bitcast i8* %[[addr1]] to void (i8*)*
+; CHECK-NOT: musttail call fastcc void %[[pv1]](i8* null)
+
+; Verify that in the resume part resume call is marked with musttail.
+; CHECK-LABEL: @f.resume(
+; CHECK: %[[addr2:.+]] = call i8* @llvm.coro.subfn.addr(i8* null, i8 0)
+; CHECK-NEXT: %[[pv2:.+]] = bitcast i8* %[[addr2]] to void (i8*)*
+; CHECK-NEXT: musttail call fastcc void %[[pv2]](i8* null)
+; CHECK-NEXT: ret void
+
+declare token @llvm.coro.id(i32, i8* readnone, i8* nocapture readonly, i8*)
+declare i1 @llvm.coro.alloc(token) #3
+declare i64 @llvm.coro.size.i64() #5
+declare i8* @llvm.coro.begin(token, i8* writeonly) #3
+declare token @llvm.coro.save(i8*) #3
+declare i8* @llvm.coro.frame() #5
+declare i8 @llvm.coro.suspend(token, i1) #3
+declare i8* @llvm.coro.free(token, i8* nocapture readonly) #2
+declare i1 @llvm.coro.end(i8*, i1) #3
+declare i8* @llvm.coro.subfn.addr(i8* nocapture readonly, i8) #5
+declare i8* @malloc(i64)
Index: llvm/trunk/test/Transforms/Coroutines/coro-split-musttail-loop.ll
===================================================================
--- llvm/trunk/test/Transforms/Coroutines/coro-split-musttail-loop.ll
+++ llvm/trunk/test/Transforms/Coroutines/coro-split-musttail-loop.ll
@@ -0,0 +1,45 @@
+; Tests that coro-split terminates, and does not add musttail, when a
+; coro.resume is followed by a branch into an infinite loop rather than by a
+; suspend.
+; RUN: opt < %s -coro-split -S | FileCheck %s
+
+define void @g() "coroutine.presplit"="1" {
+entry:
+  %id = call token @llvm.coro.id(i32 0, i8* null, i8* null, i8* null)
+  %alloc = call i8* @malloc(i64 16)
+  %vFrame = call noalias nonnull i8* @llvm.coro.begin(token %id, i8* %alloc)
+
+  %save = call token @llvm.coro.save(i8* null)
+  %suspend = call i8 @llvm.coro.suspend(token %save, i1 false)
+  switch i8 %suspend, label %exit [
+    i8 0, label %await.ready
+    i8 1, label %exit
+  ]
+await.ready:
+  %addr = call i8* @llvm.coro.subfn.addr(i8* null, i8 0)
+  %pv = bitcast i8* %addr to void (i8*)*
+  call fastcc void %pv(i8* null)
+  br label %loop
+loop:
+  br label %loop
+exit:
+  call i1 @llvm.coro.end(i8* null, i1 false)
+  ret void
+}
+
+; Verify that in the resume part the resume call is not marked with musttail
+; and is still followed by the branch into the loop.
+; CHECK-LABEL: @g.resume(
+; CHECK: %[[addr:.+]] = call i8* @llvm.coro.subfn.addr(i8* null, i8 0)
+; CHECK-NEXT: %[[pv:.+]] = bitcast i8* %[[addr]] to void (i8*)*
+; CHECK-NOT: musttail
+; CHECK: call fastcc void %[[pv]](i8* null)
+; CHECK-NEXT: br label %
+
+declare token @llvm.coro.id(i32, i8* readnone, i8* nocapture readonly, i8*)
+declare i8* @llvm.coro.begin(token, i8* writeonly)
+declare token @llvm.coro.save(i8*)
+declare i8 @llvm.coro.suspend(token, i1)
+declare i1 @llvm.coro.end(i8*, i1)
+declare i8* @llvm.coro.subfn.addr(i8* nocapture readonly, i8)
+declare i8* @malloc(i64)