Index: lib/Transforms/Coroutines/CoroSplit.cpp =================================================================== --- lib/Transforms/Coroutines/CoroSplit.cpp +++ lib/Transforms/Coroutines/CoroSplit.cpp @@ -538,43 +538,92 @@ CoroBegin->eraseFromParent(); } -// look for a very simple pattern -// coro.save -// no other calls -// resume or destroy call -// coro.suspend -// -// If there are other calls between coro.save and coro.suspend, they can -// potentially resume or destroy the coroutine, so it is unsafe to eliminate a -// suspend point. -static bool simplifySuspendPoint(CoroSuspendInst *Suspend, - CoroBeginInst *CoroBegin) { - auto *Save = Suspend->getCoroSave(); - auto *BB = Suspend->getParent(); - if (BB != Save->getParent()) - return false; +// SimplifySuspendPoint needs to check that there are no calls between +// coro_save and coro_suspend, since any of the calls may potentially resume +// the coroutine and if that is the case we cannot eliminate the suspend point. +static bool hasCallsInBlockBetween(Instruction *From, Instruction *To) { + for (Instruction *I = From; I != To; I = I->getNextNode()) { + // Assume that no intrinsic can resume the coroutine. + if (isa<IntrinsicInst>(I)) + continue; - CallSite SingleCallSite; + if (CallSite CS = CallSite(I)) + return true; + } + return false; +} - // Check that we have only one CallSite. - for (Instruction *I = Save->getNextNode(); I != Suspend; - I = I->getNextNode()) { - if (isa(I)) - continue; - if (isa(I)) - continue; - if (CallSite CS = CallSite(I)) { - if (SingleCallSite) - return false; - else - SingleCallSite = CS; - } +static bool hasCallsInBlocksBetween(BasicBlock *SaveBB, BasicBlock *ResDesBB) { + SmallPtrSet<BasicBlock *, 8> Set; + SmallVector<BasicBlock *, 8> Worklist; + + Set.insert(SaveBB); + Worklist.push_back(ResDesBB); + + // Accumulate all blocks between SaveBB and ResDesBB.
Because CoroSaveIntr + // returns a token consumed by suspend instruction, all blocks in between + // will have to eventually hit SaveBB when going backwards from ResDesBB. + while (!Worklist.empty()) { + auto *BB = Worklist.pop_back_val(); + Set.insert(BB); + for (auto *Pred : predecessors(BB)) + if (Set.count(Pred) == 0) + Worklist.push_back(Pred); } - auto *CallInstr = SingleCallSite.getInstruction(); - if (!CallInstr) + + // SaveBB and ResDesBB are checked separately in hasCallsBetween. + Set.erase(SaveBB); + Set.erase(ResDesBB); + + for (auto *BB : Set) + if (hasCallsInBlockBetween(BB->getFirstNonPHI(), nullptr)) + return true; + + return false; +} + +static bool hasCallsBetween(Instruction *Save, Instruction *ResumeOrDestroy) { + auto *SaveBB = Save->getParent(); + auto *ResumeOrDestroyBB = ResumeOrDestroy->getParent(); + + if (SaveBB == ResumeOrDestroyBB) + return hasCallsInBlockBetween(Save->getNextNode(), ResumeOrDestroy); + + // Any calls from Save to the end of the block? + if (hasCallsInBlockBetween(Save->getNextNode(), nullptr)) + return true; + + // Any calls from the beginning of the block up to ResumeOrDestroy? + if (hasCallsInBlockBetween(ResumeOrDestroyBB->getFirstNonPHI(), + ResumeOrDestroy)) + return true; + + // Any calls in all of the blocks between SaveBB and ResumeOrDestroyBB? + if (hasCallsInBlocksBetween(SaveBB, ResumeOrDestroyBB)) + return true; + + return false; +} + +// If a SuspendIntrin is preceded by Resume or Destroy, we can eliminate the +// suspend point and replace it with normal control flow.
+static bool simplifySuspendPoint(CoroSuspendInst *Suspend, + CoroBeginInst *CoroBegin) { + Instruction *Prev = Suspend->getPrevNode(); + if (!Prev) { + auto *Pred = Suspend->getParent()->getSinglePredecessor(); + if (!Pred) + return false; + Prev = Pred->getTerminator(); + } + + CallSite CS{Prev}; + if (!CS) return false; - auto *Callee = SingleCallSite.getCalledValue()->stripPointerCasts(); + auto *CallInstr = CS.getInstruction(); + + auto *Callee = CS.getCalledValue()->stripPointerCasts(); // See if the callsite is for resumption or destruction of the coroutine. auto *SubFn = dyn_cast<CoroSubFnInst>(Callee); @@ -585,6 +634,13 @@ if (SubFn->getFrame() != CoroBegin) return false; + // See if the transformation is safe. Specifically, see if there are any + // calls in between Save and CallInstr. They can potentially resume the + // coroutine, rendering this optimization unsafe. + auto *Save = Suspend->getCoroSave(); + if (hasCallsBetween(Save, CallInstr)) + return false; + // Replace llvm.coro.suspend with the value that results in resumption over // the resume or cleanup path. Suspend->replaceAllUsesWith(SubFn->getRawIndex()); @@ -592,8 +648,22 @@ Save->eraseFromParent(); // No longer need a call to coro.resume or coro.destroy. + if (auto *Invoke = dyn_cast<InvokeInst>(CallInstr)) { + // The unwind edge disappears with the invoke; drop its PHI entries. + Invoke->getUnwindDest()->removePredecessor(Invoke->getParent()); + BranchInst::Create(Invoke->getNormalDest(), Invoke); + } + + // Grab the CalledValue from CS before erasing the CallInstr. + auto *CalledValue = CS.getCalledValue(); CallInstr->eraseFromParent(); + // If no more users, remove it. Usually it is a bitcast of SubFn. + if (CalledValue != SubFn && CalledValue->user_empty()) + if (auto *I = dyn_cast<Instruction>(CalledValue)) + I->eraseFromParent(); + + // Now we are good to remove SubFn.
if (SubFn->user_empty()) SubFn->eraseFromParent(); Index: test/Transforms/Coroutines/no-suspend.ll =================================================================== --- test/Transforms/Coroutines/no-suspend.ll +++ test/Transforms/Coroutines/no-suspend.ll @@ -37,15 +37,14 @@ } ; SimplifySuspendPoint will detect that coro.resume resumes itself and will -; replace suspend with a jump to %resume label turning it into no-suspend +; replace suspend with a jump to %resume label turning it into no-suspend ; coroutine. ; ; CHECK-LABEL: define void @simplify_resume( -; CHECK-NEXT: entry: -; CHECK-NEXT: call void @print(i32 0) +; CHECK: call void @print(i32 0) ; CHECK-NEXT: ret void ; -define void @simplify_resume() { +define void @simplify_resume(i8* %src, i8* %dst) { entry: %id = call token @llvm.coro.id(i32 0, i8* null, i8* null, i8* null) %need.dyn.alloc = call i1 @llvm.coro.alloc(token %id) @@ -60,6 +59,8 @@ br label %body body: %save = call token @llvm.coro.save(i8* %hdl) + ; memcpy intrinsics should not prevent simplification.
+ call void @llvm.memcpy.p0i8.p0i8.i64(i8* %dst, i8* %src, i64 1, i1 false) call void @llvm.coro.resume(i8* %hdl) %0 = call i8 @llvm.coro.suspend(token %save, i1 false) switch i8 %0, label %suspend [i8 0, label %resume @@ -90,7 +91,7 @@ ; CHECK-NEXT: call void @print(i32 1) ; CHECK-NEXT: ret void ; -define void @simplify_destroy() { +define void @simplify_destroy() personality i32 0 { entry: %id = call token @llvm.coro.id(i32 0, i8* null, i8* null, i8* null) %need.dyn.alloc = call i1 @llvm.coro.alloc(token %id) @@ -105,7 +106,9 @@ br label %body body: %save = call token @llvm.coro.save(i8* %hdl) - call void @llvm.coro.destroy(i8* %hdl) + invoke void @llvm.coro.destroy(i8* %hdl) to label %real_susp unwind label %lpad + +real_susp: %0 = call i8 @llvm.coro.suspend(token %save, i1 false) switch i8 %0, label %suspend [i8 0, label %resume i8 1, label %pre.cleanup] @@ -124,17 +127,77 @@ suspend: call i1 @llvm.coro.end(i8* %hdl, i1 false) ret void +lpad: + %lpval = landingpad { i8*, i32 } + cleanup + + call void @print(i32 2) + resume { i8*, i32 } %lpval } +; SimplifySuspendPoint will detect that coro.resume resumes itself and will +; replace suspend with a jump to %resume label turning it into no-suspend +; coroutine, even though coro.save and coro.resume are in different blocks.
+; +; CHECK-LABEL: define void @simplify_resume_with_inlined_if( +; CHECK: call void @print(i32 0) +; CHECK-NEXT: ret void +; +define void @simplify_resume_with_inlined_if(i8* %src, i8* %dst, i1 %cond) { +entry: + %id = call token @llvm.coro.id(i32 0, i8* null, i8* null, i8* null) + %need.dyn.alloc = call i1 @llvm.coro.alloc(token %id) + br i1 %need.dyn.alloc, label %dyn.alloc, label %coro.begin +dyn.alloc: + %size = call i32 @llvm.coro.size.i32() + %alloc = call i8* @malloc(i32 %size) + br label %coro.begin +coro.begin: + %phi = phi i8* [ null, %entry ], [ %alloc, %dyn.alloc ] + %hdl = call noalias i8* @llvm.coro.begin(token %id, i8* %phi) + br label %body +body: + %save = call token @llvm.coro.save(i8* %hdl) + br i1 %cond, label %if.then, label %if.else +if.then: + call void @llvm.memcpy.p0i8.p0i8.i64(i8* %dst, i8* %src, i64 1, i1 false) + br label %if.end +if.else: + call void @llvm.memcpy.p0i8.p0i8.i64(i8* %src, i8* %dst, i64 1, i1 false) + br label %if.end +if.end: + call void @llvm.coro.resume(i8* %hdl) + %0 = call i8 @llvm.coro.suspend(token %save, i1 false) + switch i8 %0, label %suspend [i8 0, label %resume + i8 1, label %pre.cleanup] +resume: + call void @print(i32 0) + br label %cleanup + +pre.cleanup: + call void @print(i32 1) + br label %cleanup + +cleanup: + %mem = call i8* @llvm.coro.free(token %id, i8* %hdl) + call void @free(i8* %mem) + br label %suspend +suspend: + call i1 @llvm.coro.end(i8* %hdl, i1 false) + ret void +} + + + ; SimplifySuspendPoint won't be able to simplify if it detects that there are ; other calls between coro.save and coro.suspend. They potentially can call ; resume or destroy, so we should not simplify this suspend point.
; ; -; CHECK-LABEL: define void @cannot_simplify( +; CHECK-LABEL: define void @cannot_simplify_other_calls( ; CHECK-NEXT: entry: ; CHECK-NEXT: call i8* @malloc -define void @cannot_simplify() { +define void @cannot_simplify_other_calls() { entry: %id = call token @llvm.coro.id(i32 0, i8* null, i8* null, i8* null) %need.dyn.alloc = call i1 @llvm.coro.alloc(token %id) @@ -171,6 +234,103 @@ ret void } +; SimplifySuspendPoint won't be able to simplify if it detects that there are +; other calls between coro.save and coro.suspend. They potentially can call +; resume or destroy, so we should not simplify this suspend point. +; +; CHECK-LABEL: define void @cannot_simplify_calls_in_terminator( +; CHECK-NEXT: entry: +; CHECK-NEXT: call i8* @malloc + +define void @cannot_simplify_calls_in_terminator() personality i32 0 { +entry: + %id = call token @llvm.coro.id(i32 0, i8* null, i8* null, i8* null) + %need.dyn.alloc = call i1 @llvm.coro.alloc(token %id) + br i1 %need.dyn.alloc, label %dyn.alloc, label %coro.begin +dyn.alloc: + %size = call i32 @llvm.coro.size.i32() + %alloc = call i8* @malloc(i32 %size) + br label %coro.begin +coro.begin: + %phi = phi i8* [ null, %entry ], [ %alloc, %dyn.alloc ] + %hdl = call noalias i8* @llvm.coro.begin(token %id, i8* %phi) + br label %body +body: + %save = call token @llvm.coro.save(i8* %hdl) + invoke void @foo() to label %resume_cont unwind label %lpad +resume_cont: + call void @llvm.coro.destroy(i8* %hdl) + %0 = call i8 @llvm.coro.suspend(token %save, i1 false) + switch i8 %0, label %suspend [i8 0, label %resume + i8 1, label %pre.cleanup] +resume: + call void @print(i32 0) + br label %cleanup + +pre.cleanup: + call void @print(i32 1) + br label %cleanup + +cleanup: + %mem = call i8* @llvm.coro.free(token %id, i8* %hdl) + call void @free(i8* %mem) + br label %suspend +suspend: + call i1 @llvm.coro.end(i8* %hdl, i1 false) + ret void +lpad: + %lpval = landingpad { i8*, i32 } + cleanup + + call void @print(i32 2) + resume { i8*, i32 } %lpval
+} + +; SimplifySuspendPoint won't be able to simplify if it detects that resume or +; destroy does not immediately precede coro.suspend. +; +; CHECK-LABEL: define void @cannot_simplify_not_last_instr( +; CHECK-NEXT: entry: +; CHECK-NEXT: call i8* @malloc + +define void @cannot_simplify_not_last_instr(i8* %dst, i8* %src) { +entry: + %id = call token @llvm.coro.id(i32 0, i8* null, i8* null, i8* null) + %need.dyn.alloc = call i1 @llvm.coro.alloc(token %id) + br i1 %need.dyn.alloc, label %dyn.alloc, label %coro.begin +dyn.alloc: + %size = call i32 @llvm.coro.size.i32() + %alloc = call i8* @malloc(i32 %size) + br label %coro.begin +coro.begin: + %phi = phi i8* [ null, %entry ], [ %alloc, %dyn.alloc ] + %hdl = call noalias i8* @llvm.coro.begin(token %id, i8* %phi) + br label %body +body: + %save = call token @llvm.coro.save(i8* %hdl) + call void @llvm.coro.destroy(i8* %hdl) + ; memcpy separates destroy from suspend, therefore cannot simplify. + call void @llvm.memcpy.p0i8.p0i8.i64(i8* %dst, i8* %src, i64 1, i1 false) + %0 = call i8 @llvm.coro.suspend(token %save, i1 false) + switch i8 %0, label %suspend [i8 0, label %resume + i8 1, label %pre.cleanup] +resume: + call void @print(i32 0) + br label %cleanup + +pre.cleanup: + call void @print(i32 1) + br label %cleanup + +cleanup: + %mem = call i8* @llvm.coro.free(token %id, i8* %hdl) + call void @free(i8* %mem) + br label %suspend +suspend: + call i1 @llvm.coro.end(i8* %hdl, i1 false) + ret void +} + declare i8* @malloc(i32) declare void @free(i8*) declare void @print(i32) @@ -187,3 +347,5 @@ declare void @llvm.coro.resume(i8*) declare void @llvm.coro.destroy(i8*) + +declare void @llvm.memcpy.p0i8.p0i8.i64(i8* nocapture writeonly, i8* nocapture readonly, i64, i1)