diff --git a/llvm/lib/Transforms/Coroutines/CoroSplit.cpp b/llvm/lib/Transforms/Coroutines/CoroSplit.cpp --- a/llvm/lib/Transforms/Coroutines/CoroSplit.cpp +++ b/llvm/lib/Transforms/Coroutines/CoroSplit.cpp @@ -934,7 +934,8 @@ BasicBlock *UnconditionalSucc = nullptr; Instruction *I = InitialInst; - while (I->isTerminator()) { + while (I->isTerminator() || + (isa(I) && I->getNextNode()->isTerminator())) { if (isa(I)) { if (I != InitialInst) { // If InitialInst is an unconditional branch, @@ -958,6 +959,29 @@ I = BB->getFirstNonPHIOrDbgOrLifetime(); continue; } + } else if (auto *CondCmp = dyn_cast(I)) { + auto *BR = dyn_cast(I->getNextNode()); + if (BR && BR->isConditional() && CondCmp == BR->getCondition()) { + // If the case number of suspended switch instruction is reduced to + // 1, then it is simplified to CmpInst in llvm::ConstantFoldTerminator. + // And the comparsion looks like : %cond = icmp eq i8 %V, constant. + ConstantInt *CondConst = dyn_cast(CondCmp->getOperand(1)); + if (CondConst && CondCmp->getPredicate() == CmpInst::ICMP_EQ) { + Value *V = CondCmp->getOperand(0); + auto it = ResolvedValues.find(V); + if (it != ResolvedValues.end()) + V = it->second; + + if (ConstantInt *Cond0 = dyn_cast(V)) { + BasicBlock *BB = Cond0->equalsInt(CondConst->getZExtValue()) + ? BR->getSuccessor(0) + : BR->getSuccessor(1); + scanPHIsAndUpdateValueMap(I, BB, ResolvedValues); + I = BB->getFirstNonPHIOrDbgOrLifetime(); + continue; + } + } + } } else if (auto *SI = dyn_cast(I)) { Value *V = SI->getCondition(); auto it = ResolvedValues.find(V); diff --git a/llvm/test/Transforms/Coroutines/coro-split-musttail3.ll b/llvm/test/Transforms/Coroutines/coro-split-musttail3.ll new file mode 100644 --- /dev/null +++ b/llvm/test/Transforms/Coroutines/coro-split-musttail3.ll @@ -0,0 +1,101 @@ +; Tests that coro-split will convert coro.resume followed by a suspend to a +; musttail call. 
+; RUN: opt < %s -coro-split -S | FileCheck %s +; RUN: opt < %s -passes=coro-split -S | FileCheck %s + +define void @f() #0 { +entry: + %id = call token @llvm.coro.id(i32 0, i8* null, i8* null, i8* null) + %alloc = call i8* @malloc(i64 16) #3 + %vFrame = call noalias nonnull i8* @llvm.coro.begin(token %id, i8* %alloc) + + %save = call token @llvm.coro.save(i8* null) + %addr1 = call i8* @llvm.coro.subfn.addr(i8* null, i8 0) + %pv1 = bitcast i8* %addr1 to void (i8*)* + call fastcc void %pv1(i8* null) + + %suspend = call i8 @llvm.coro.suspend(token %save, i1 false) + %cmp = icmp eq i8 %suspend, 0 + br i1 %cmp, label %await.suspend, label %exit +await.suspend: + %save2 = call token @llvm.coro.save(i8* null) + %br0 = call i8 @switch_result() + switch i8 %br0, label %unreach [ + i8 0, label %await.resume3 + i8 1, label %await.resume1 + i8 2, label %await.resume2 + ] +await.resume1: + %hdl = call i8* @g() + %addr2 = call i8* @llvm.coro.subfn.addr(i8* %hdl, i8 0) + %pv2 = bitcast i8* %addr2 to void (i8*)* + call fastcc void %pv2(i8* %hdl) + br label %final.suspend +await.resume2: + %hdl2 = call i8* @h() + %addr3 = call i8* @llvm.coro.subfn.addr(i8* %hdl2, i8 0) + %pv3 = bitcast i8* %addr3 to void (i8*)* + call fastcc void %pv3(i8* %hdl2) + br label %final.suspend +await.resume3: + %addr4 = call i8* @llvm.coro.subfn.addr(i8* null, i8 0) + %pv4 = bitcast i8* %addr4 to void (i8*)* + call fastcc void %pv4(i8* null) + br label %final.suspend +final.suspend: + %suspend2 = call i8 @llvm.coro.suspend(token %save2, i1 false) + %cmp2 = icmp eq i8 %suspend2, 0 ; icmp eq + br i1 on the suspend result: the form a one-case suspend switch takes after ConstantFoldTerminator + br i1 %cmp2, label %pre.exit, label %exit +pre.exit: + br label %exit +exit: + call i1 @llvm.coro.end(i8* null, i1 false) + ret void +unreach: + unreachable +} + +; Verify that in the initial function @f the resume call made before the first suspend is not marked musttail.
+; CHECK-LABEL: @f( +; CHECK: %[[addr1:.+]] = call i8* @llvm.coro.subfn.addr(i8* null, i8 0) +; CHECK-NEXT: %[[pv1:.+]] = bitcast i8* %[[addr1]] to void (i8*)* +; CHECK-NOT: musttail call fastcc void %[[pv1]](i8* null) + +; Verify that in the resume function @f.resume each of the three resume calls (one per switch case) is marked musttail and immediately followed by ret void. +; CHECK-LABEL: @f.resume( +; CHECK: %[[hdl:.+]] = call i8* @g() +; CHECK-NEXT: %[[addr2:.+]] = call i8* @llvm.coro.subfn.addr(i8* %[[hdl]], i8 0) +; CHECK-NEXT: %[[pv2:.+]] = bitcast i8* %[[addr2]] to void (i8*)* +; CHECK-NEXT: musttail call fastcc void %[[pv2]](i8* %[[hdl]]) +; CHECK-NEXT: ret void +; CHECK: %[[hdl2:.+]] = call i8* @h() +; CHECK-NEXT: %[[addr3:.+]] = call i8* @llvm.coro.subfn.addr(i8* %[[hdl2]], i8 0) +; CHECK-NEXT: %[[pv3:.+]] = bitcast i8* %[[addr3]] to void (i8*)* +; CHECK-NEXT: musttail call fastcc void %[[pv3]](i8* %[[hdl2]]) +; CHECK-NEXT: ret void +; CHECK: %[[addr4:.+]] = call i8* @llvm.coro.subfn.addr(i8* null, i8 0) +; CHECK-NEXT: %[[pv4:.+]] = bitcast i8* %[[addr4]] to void (i8*)* +; CHECK-NEXT: musttail call fastcc void %[[pv4]](i8* null) +; CHECK-NEXT: ret void + + + +declare token @llvm.coro.id(i32, i8* readnone, i8* nocapture readonly, i8*) #1 +declare i1 @llvm.coro.alloc(token) #2 +declare i64 @llvm.coro.size.i64() #3 +declare i8* @llvm.coro.begin(token, i8* writeonly) #2 +declare token @llvm.coro.save(i8*) #2 +declare i8* @llvm.coro.frame() #3 +declare i8 @llvm.coro.suspend(token, i1) #2 +declare i8* @llvm.coro.free(token, i8* nocapture readonly) #1 +declare i1 @llvm.coro.end(i8*, i1) #2 +declare i8* @llvm.coro.subfn.addr(i8* nocapture readonly, i8) #1 +declare i8* @malloc(i64) +declare i8 @switch_result() +declare i8* @g() +declare i8* @h() + +attributes #0 = { "coroutine.presplit"="1" } +attributes #1 = { argmemonly nounwind readonly } +attributes #2 = { nounwind } +attributes #3 = { nounwind readnone }