Index: llvm/include/llvm/Analysis/TargetTransformInfo.h
===================================================================
--- llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -776,6 +776,9 @@
   /// the scalarization cost of a load/store.
   bool supportsEfficientVectorElementLoadStore() const;
 
+  /// Return true if the target supports tail calls.
+  bool supportsTailCalls() const;
+
   /// Don't restrict interleaved unrolling to small loops.
   bool enableAggressiveInterleaving(bool LoopHasReductions) const;
 
@@ -1621,6 +1624,7 @@
   getOperandsScalarizationOverhead(ArrayRef<const Value *> Args,
                                    ArrayRef<Type *> Tys) = 0;
   virtual bool supportsEfficientVectorElementLoadStore() = 0;
+  virtual bool supportsTailCalls() = 0;
   virtual bool enableAggressiveInterleaving(bool LoopHasReductions) = 0;
   virtual MemCmpExpansionOptions
   enableMemCmpExpansion(bool OptSize, bool IsZeroCmp) const = 0;
@@ -2083,6 +2087,8 @@
     return Impl.supportsEfficientVectorElementLoadStore();
   }
 
+  bool supportsTailCalls() override { return Impl.supportsTailCalls(); }
+
   bool enableAggressiveInterleaving(bool LoopHasReductions) override {
     return Impl.enableAggressiveInterleaving(LoopHasReductions);
   }
Index: llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
===================================================================
--- llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -340,6 +340,8 @@
 
   bool supportsEfficientVectorElementLoadStore() const { return false; }
 
+  bool supportsTailCalls() const { return true; }
+
   bool enableAggressiveInterleaving(bool LoopHasReductions) const {
     return false;
   }
Index: llvm/lib/Analysis/TargetTransformInfo.cpp
===================================================================
--- llvm/lib/Analysis/TargetTransformInfo.cpp
+++ llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -522,6 +522,10 @@
   return TTIImpl->supportsEfficientVectorElementLoadStore();
 }
 
+bool TargetTransformInfo::supportsTailCalls() const {
+  return TTIImpl->supportsTailCalls();
+}
+
 bool TargetTransformInfo::enableAggressiveInterleaving(
     bool LoopHasReductions) const {
   return TTIImpl->enableAggressiveInterleaving(LoopHasReductions);
Index: llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.h
===================================================================
--- llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.h
+++ llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.h
@@ -74,6 +74,8 @@
 
   bool areInlineCompatible(const Function *Caller,
                            const Function *Callee) const;
+
+  bool supportsTailCalls() const;
 };
 
 } // end namespace llvm
Index: llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp
===================================================================
--- llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp
+++ llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp
@@ -139,3 +139,7 @@
   // becomes "fall through" to default value of 2.
   UP.BEInsns = 2;
 }
+
+bool WebAssemblyTTIImpl::supportsTailCalls() const {
+  return getST()->hasTailCall();
+}
Index: llvm/lib/Transforms/Coroutines/CoroSplit.cpp
===================================================================
--- llvm/lib/Transforms/Coroutines/CoroSplit.cpp
+++ llvm/lib/Transforms/Coroutines/CoroSplit.cpp
@@ -31,6 +31,7 @@
 #include "llvm/Analysis/CallGraph.h"
 #include "llvm/Analysis/ConstantFolding.h"
 #include "llvm/Analysis/LazyCallGraph.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
 #include "llvm/BinaryFormat/Dwarf.h"
 #include "llvm/IR/Argument.h"
 #include "llvm/IR/Attributes.h"
@@ -1571,7 +1572,8 @@
 }
 
 static void splitSwitchCoroutine(Function &F, coro::Shape &Shape,
-                                 SmallVectorImpl<Function *> &Clones) {
+                                 SmallVectorImpl<Function *> &Clones,
+                                 TargetTransformInfo &TTI) {
   assert(Shape.ABI == coro::ABI::Switch);
 
   createResumeEntryBlock(F, Shape);
@@ -1586,7 +1588,13 @@
   postSplitCleanup(*DestroyClone);
   postSplitCleanup(*CleanupClone);
 
-  addMustTailToCoroResumes(*ResumeClone);
+  // Mark the resume calls as musttail to support symmetric transfer.
+  // Skip targets that don't support tail calls.
+  //
+  // FIXME: Could we support symmetric transfer effectively without musttail
+  // calls?
+  if (TTI.supportsTailCalls())
+    addMustTailToCoroResumes(*ResumeClone);
 
   // Store addresses resume/destroy/cleanup functions in the coroutine frame.
   updateCoroFrame(Shape, ResumeClone, DestroyClone, CleanupClone);
@@ -1891,6 +1899,7 @@
 
 static coro::Shape splitCoroutine(Function &F,
                                   SmallVectorImpl<Function *> &Clones,
+                                  TargetTransformInfo &TTI,
                                   bool OptimizeFrame) {
   PrettyStackTraceFunction prettyStackTrace(F);
 
@@ -1913,7 +1922,7 @@
   } else {
     switch (Shape.ABI) {
     case coro::ABI::Switch:
-      splitSwitchCoroutine(F, Shape, Clones);
+      splitSwitchCoroutine(F, Shape, Clones, TTI);
       break;
     case coro::ABI::Async:
       splitAsyncCoroutine(F, Shape, Clones);
@@ -2090,7 +2099,8 @@
     F.setSplittedCoroutine();
 
     SmallVector<Function *, 4> Clones;
-    const coro::Shape Shape = splitCoroutine(F, Clones, OptimizeFrame);
+    const coro::Shape Shape = splitCoroutine(
+        F, Clones, FAM.getResult<TargetIRAnalysis>(F), OptimizeFrame);
     updateCallGraphAfterCoroutineSplit(*N, Shape, Clones, C, CG, AM, UR, FAM);
 
     if (!Shape.CoroSuspends.empty()) {
Index: llvm/test/Transforms/Coroutines/coro-split-musttail8.ll
===================================================================
--- /dev/null
+++ llvm/test/Transforms/Coroutines/coro-split-musttail8.ll
@@ -0,0 +1,59 @@
+; Tests that we don't convert coro.resume to a musttail call if the target is
+; Wasm32 without the tail-call feature.
+; REQUIRES: webassembly-registered-target
+; RUN: opt < %s -passes='cgscc(coro-split),simplifycfg,early-cse' -S | FileCheck %s
+
+target triple = "wasm32-unknown-unknown"
+
+define void @f() #0 {
+entry:
+  %id = call token @llvm.coro.id(i32 0, i8* null, i8* null, i8* null)
+  %alloc = call i8* @malloc(i64 16) #3
+  %vFrame = call noalias nonnull i8* @llvm.coro.begin(token %id, i8* %alloc)
+
+  %save = call token @llvm.coro.save(i8* null)
+  %addr1 = call i8* @llvm.coro.subfn.addr(i8* null, i8 0)
+  %pv1 = bitcast i8* %addr1 to void (i8*)*
+  call fastcc void %pv1(i8* null)
+
+  %suspend = call i8 @llvm.coro.suspend(token %save, i1 false)
+  switch i8 %suspend, label %exit [
+    i8 0, label %await.ready
+    i8 1, label %exit
+  ]
+await.ready:
+  %save2 = call token @llvm.coro.save(i8* null)
+  %addr2 = call i8* @llvm.coro.subfn.addr(i8* null, i8 0)
+  %pv2 = bitcast i8* %addr2 to void (i8*)*
+  call fastcc void %pv2(i8* null)
+  call void @print()
+  %suspend2 = call i8 @llvm.coro.suspend(token %save2, i1 false)
+  switch i8 %suspend2, label %exit [
+    i8 0, label %exit
+    i8 1, label %exit
+  ]
+exit:
+  call i1 @llvm.coro.end(i8* null, i1 false)
+  ret void
+}
+
+; CHECK-LABEL: @f.resume(
+; CHECK-NOT: musttail
+
+declare token @llvm.coro.id(i32, i8* readnone, i8* nocapture readonly, i8*) #1
+declare i1 @llvm.coro.alloc(token) #2
+declare i64 @llvm.coro.size.i64() #3
+declare i8* @llvm.coro.begin(token, i8* writeonly) #2
+declare token @llvm.coro.save(i8*) #2
+declare i8* @llvm.coro.frame() #3
+declare i8 @llvm.coro.suspend(token, i1) #2
+declare i8* @llvm.coro.free(token, i8* nocapture readonly) #1
+declare i1 @llvm.coro.end(i8*, i1) #2
+declare i8* @llvm.coro.subfn.addr(i8* nocapture readonly, i8) #1
+declare i8* @malloc(i64)
+declare void @print()
+
+attributes #0 = { presplitcoroutine }
+attributes #1 = { argmemonly nounwind readonly }
+attributes #2 = { nounwind }
+attributes #3 = { nounwind readnone }
Index: llvm/test/Transforms/Coroutines/coro-split-musttail9.ll
===================================================================
--- /dev/null
+++ llvm/test/Transforms/Coroutines/coro-split-musttail9.ll
@@ -0,0 +1,59 @@
+; Tests that we don't convert coro.resume to a musttail call if the target is
+; Wasm64 without the tail-call feature.
+; REQUIRES: webassembly-registered-target
+; RUN: opt < %s -passes='cgscc(coro-split),simplifycfg,early-cse' -S | FileCheck %s
+
+target triple = "wasm64-unknown-unknown"
+
+define void @f() #0 {
+entry:
+  %id = call token @llvm.coro.id(i32 0, i8* null, i8* null, i8* null)
+  %alloc = call i8* @malloc(i64 16) #3
+  %vFrame = call noalias nonnull i8* @llvm.coro.begin(token %id, i8* %alloc)
+
+  %save = call token @llvm.coro.save(i8* null)
+  %addr1 = call i8* @llvm.coro.subfn.addr(i8* null, i8 0)
+  %pv1 = bitcast i8* %addr1 to void (i8*)*
+  call fastcc void %pv1(i8* null)
+
+  %suspend = call i8 @llvm.coro.suspend(token %save, i1 false)
+  switch i8 %suspend, label %exit [
+    i8 0, label %await.ready
+    i8 1, label %exit
+  ]
+await.ready:
+  %save2 = call token @llvm.coro.save(i8* null)
+  %addr2 = call i8* @llvm.coro.subfn.addr(i8* null, i8 0)
+  %pv2 = bitcast i8* %addr2 to void (i8*)*
+  call fastcc void %pv2(i8* null)
+  call void @print()
+  %suspend2 = call i8 @llvm.coro.suspend(token %save2, i1 false)
+  switch i8 %suspend2, label %exit [
+    i8 0, label %exit
+    i8 1, label %exit
+  ]
+exit:
+  call i1 @llvm.coro.end(i8* null, i1 false)
+  ret void
+}
+
+; CHECK-LABEL: @f.resume(
+; CHECK-NOT: musttail
+
+declare token @llvm.coro.id(i32, i8* readnone, i8* nocapture readonly, i8*) #1
+declare i1 @llvm.coro.alloc(token) #2
+declare i64 @llvm.coro.size.i64() #3
+declare i8* @llvm.coro.begin(token, i8* writeonly) #2
+declare token @llvm.coro.save(i8*) #2
+declare i8* @llvm.coro.frame() #3
+declare i8 @llvm.coro.suspend(token, i1) #2
+declare i8* @llvm.coro.free(token, i8* nocapture readonly) #1
+declare i1 @llvm.coro.end(i8*, i1) #2
+declare i8* @llvm.coro.subfn.addr(i8* nocapture readonly, i8) #1
+declare i8* @malloc(i64)
+declare void @print()
+
+attributes #0 = { presplitcoroutine }
+attributes #1 = { argmemonly nounwind readonly }
+attributes #2 = { nounwind }
+attributes #3 = { nounwind readnone }