Index: lib/Transforms/Coroutines/CoroEarly.cpp =================================================================== --- lib/Transforms/Coroutines/CoroEarly.cpp +++ lib/Transforms/Coroutines/CoroEarly.cpp @@ -13,6 +13,7 @@ #include "CoroInternal.h" #include "llvm/IR/CallSite.h" +#include "llvm/IR/IRBuilder.h" #include "llvm/IR/InstIterator.h" #include "llvm/IR/Module.h" #include "llvm/Pass.h" @@ -24,10 +25,18 @@ namespace { // Created on demand if CoroEarly pass has work to do. class Lowerer : public coro::LowererBase { + IRBuilder<> Builder; + PointerType *AnyResumeFnPtrTy; + void lowerResumeOrDestroy(CallSite CS, CoroSubFnInst::ResumeKind); + void lowerCoroPromise(CoroPromiseInst *Intrin); public: - Lowerer(Module &M) : LowererBase(M) {} + Lowerer(Module &M) + : LowererBase(M), Builder(Context), + AnyResumeFnPtrTy(FunctionType::get(Type::getVoidTy(Context), Int8Ptr, + /*isVarArg=*/false) + ->getPointerTo()) {} bool lowerEarlyIntrinsics(Function &F); }; } @@ -44,6 +53,34 @@ CS.setCallingConv(CallingConv::Fast); } +// Coroutine promise field is always at a fixed offset from the beginning of +// the coroutine frame. i8* coro.promise(i8*, i32 align, i1 from) intrinsic adds an offset +// to a passed pointer to move from coroutine frame to coroutine promise and +// vice versa. Since we don't know exactly which coroutine frame it is, we build +// a coroutine frame mock up starting with two function pointers, followed by a +// properly aligned coroutine promise field. +// TODO: Handle the case when coroutine promise alloca has align override. 
+void Lowerer::lowerCoroPromise(CoroPromiseInst *Intrin) { + Value *Operand = Intrin->getArgOperand(0); + int64_t Alignement = Intrin->getAlignment(); + Type *Int8Ty = Type::getInt8Ty(Context); + + auto SampleStruct = + StructType::get(Context, {AnyResumeFnPtrTy, AnyResumeFnPtrTy, Int8Ty}); + const DataLayout &DL = TheModule.getDataLayout(); + int64_t Offset = alignTo( + DL.getStructLayout(SampleStruct)->getElementOffset(2), Alignement); + if (Intrin->isFromPromise()) + Offset = -Offset; + + Builder.SetInsertPoint(Intrin); + Value *Replacement = + Builder.CreateConstInBoundsGEP1_32(Int8Ty, Operand, Offset); + + Intrin->replaceAllUsesWith(Replacement); + Intrin->eraseFromParent(); +} + // Prior to CoroSplit, calls to coro.begin needs to be marked as NoDuplicate, // as CoroSplit assumes there is exactly one coro.begin. After CoroSplit, // NoDuplicate attribute will be removed from coro.begin otherwise, it will @@ -91,6 +128,9 @@ case Intrinsic::coro_destroy: lowerResumeOrDestroy(CS, CoroSubFnInst::DestroyIndex); break; + case Intrinsic::coro_promise: + lowerCoroPromise(cast(&I)); + break; } Changed = true; } Index: lib/Transforms/Coroutines/CoroFrame.cpp =================================================================== --- lib/Transforms/Coroutines/CoroFrame.cpp +++ lib/Transforms/Coroutines/CoroFrame.cpp @@ -311,8 +311,11 @@ // Figure out how wide should be an integer type storing the suspend index. unsigned IndexBits = std::max(1U, Log2_64_Ceil(Shape.CoroSuspends.size())); - - SmallVector Types{FnPtrTy, FnPtrTy, Type::getIntNTy(C, IndexBits)}; + Type *PromiseType = Shape.PromiseAlloca + ? Shape.PromiseAlloca->getType()->getElementType() + : Type::getInt1Ty(C); + SmallVector Types{FnPtrTy, FnPtrTy, PromiseType, + Type::getIntNTy(C, IndexBits)}; Value *CurrentDef = nullptr; // Create an entry for every spilled value. @@ -321,6 +324,9 @@ continue; CurrentDef = S.def(); + // PromiseAlloca was already added to Types array earlier. 
+ if (CurrentDef == Shape.PromiseAlloca) + continue; Type *Ty = nullptr; if (auto *AI = dyn_cast(CurrentDef)) @@ -376,6 +382,9 @@ // we remember allocas and their indices to be handled once we processed // all the spills. SmallVector, 4> Allocas; + // Promise alloca (if present) has a fixed field number (Shape::PromiseField) + if (Shape.PromiseAlloca) + Allocas.emplace_back(Shape.PromiseAlloca, coro::Shape::PromiseField); // Create a load instruction to reload the spilled value from the coroutine // frame. @@ -400,7 +409,7 @@ ++Index; if (auto *AI = dyn_cast(CurrentValue)) { - // Spiled AllocaInst will be replaced with GEP from the coroutine frame + // Spilled AllocaInst will be replaced with GEP from the coroutine frame // there is no spill required. Allocas.emplace_back(AI, Index); if (!AI->isStaticAlloca()) @@ -444,7 +453,11 @@ for (auto &P : Allocas) { auto *G = Builder.CreateConstInBoundsGEP2_32(FrameTy, FramePtr, 0, P.second); - ReplaceInstWithInst(P.first, cast(G)); + // We are not using ReplaceInstWithInst(P.first, cast(G)) here, + // as we are changing the location of the instruction. + G->takeName(P.first); + P.first->replaceAllUsesWith(G); + P.first->eraseFromParent(); } return FramePtr; } @@ -568,6 +581,10 @@ } void coro::buildCoroutineFrame(Function &F, Shape &Shape) { + Shape.PromiseAlloca = Shape.CoroBegin->getId()->getPromise(); + if (Shape.PromiseAlloca) { + Shape.CoroBegin->getId()->clearPromise(); + } // Make sure that all coro.saves and the fallthrough coro.end are in their // own block to simplify the logic of building up SuspendCrossing data. @@ -621,6 +638,10 @@ // in a coroutine. It should not be saved to the coroutine frame. if (isa(&I)) continue; + // The coroutine promise is always included in the coroutine frame, no need to + // check for suspend crossing. 
+ if (Shape.PromiseAlloca == &I) + continue; for (User *U : I.users()) if (Checker.isDefinitionAcrossSuspend(I, U)) { Index: lib/Transforms/Coroutines/CoroInstr.h =================================================================== --- lib/Transforms/Coroutines/CoroInstr.h +++ lib/Transforms/Coroutines/CoroInstr.h @@ -80,6 +80,39 @@ enum { AlignArg, PromiseArg, CoroutineArg, InfoArg }; public: + IntrinsicInst *getCoroBegin() { + for (User *U : users()) + if (auto *II = dyn_cast(U)) + if (II->getIntrinsicID() == Intrinsic::coro_begin) + return II; + llvm_unreachable("no coro.begin associated with coro.id"); + } + + AllocaInst *getPromise() const { + Value *Arg = getArgOperand(PromiseArg); + return isa(Arg) + ? nullptr + : cast(Arg->stripPointerCasts()); + } + + void clearPromise() { + Value *Arg = getArgOperand(PromiseArg); + setArgOperand(PromiseArg, + ConstantPointerNull::get(Type::getInt8PtrTy(getContext()))); + if (isa(Arg)) + return; + assert((isa(Arg) || isa(Arg)) && + "unexpected instruction designating the promise"); + // TODO: Add a check that any remaining users of Inst are after coro.begin + // or add code to move the users after coro.begin. + auto Inst = cast(Arg); + if (Inst->use_empty()) { + Inst->eraseFromParent(); + return; + } else + Inst->moveBefore(getCoroBegin()->getNextNode()); + } + // Info argument of coro.id is // fresh out of the frontend: null ; // outlined : {Init, Return, Susp1, Susp2, ...} ; @@ -198,6 +231,27 @@ } }; +/// This represents the llvm.coro.promise instruction. 
+class LLVM_LIBRARY_VISIBILITY CoroPromiseInst : public IntrinsicInst { + enum { FrameArg, AlignArg, FromArg }; + +public: + bool isFromPromise() const { + return cast(getArgOperand(FromArg))->isOneValue(); + } + int64_t getAlignment() const { + return cast(getArgOperand(AlignArg))->getSExtValue(); + } + + // Methods to support type inquiry through isa, cast, and dyn_cast: + static inline bool classof(const IntrinsicInst *I) { + return I->getIntrinsicID() == Intrinsic::coro_promise; + } + static inline bool classof(const Value *V) { + return isa(V) && classof(cast(V)); + } +}; + /// This represents the llvm.coro.suspend instruction. class LLVM_LIBRARY_VISIBILITY CoroSuspendInst : public IntrinsicInst { enum { SaveArg, FinalArg }; Index: lib/Transforms/Coroutines/CoroInternal.h =================================================================== --- lib/Transforms/Coroutines/CoroInternal.h +++ lib/Transforms/Coroutines/CoroInternal.h @@ -74,17 +74,19 @@ enum { ResumeField, DestroyField, + PromiseField, IndexField, LastKnownField = IndexField }; StructType *FrameTy; Instruction *FramePtr; - BasicBlock* AllocaSpillBlock; - SwitchInst* ResumeSwitch; + BasicBlock *AllocaSpillBlock; + SwitchInst *ResumeSwitch; + AllocaInst *PromiseAlloca; bool HasFinalSuspend; - IntegerType* getIndexType() const { + IntegerType *getIndexType() const { assert(FrameTy && "frame type not assigned"); return cast(FrameTy->getElementType(IndexField)); } @@ -97,7 +99,7 @@ void buildFrom(Function &F); }; -void buildCoroutineFrame(Function& F, Shape& Shape); +void buildCoroutineFrame(Function &F, Shape &Shape); } // End namespace coro. 
} // End namespace llvm Index: lib/Transforms/Coroutines/Coroutines.cpp =================================================================== --- lib/Transforms/Coroutines/Coroutines.cpp +++ lib/Transforms/Coroutines/Coroutines.cpp @@ -198,6 +198,7 @@ Shape.FramePtr = nullptr; Shape.AllocaSpillBlock = nullptr; Shape.ResumeSwitch = nullptr; + Shape.PromiseAlloca = nullptr; Shape.HasFinalSuspend = false; } Index: test/Transforms/Coroutines/ex4.ll =================================================================== --- /dev/null +++ test/Transforms/Coroutines/ex4.ll @@ -0,0 +1,71 @@ +; Fourth example from Doc/Coroutines.rst (coroutine promise) +; RUN: opt < %s -O2 -enable-coroutines -S | FileCheck %s + +define i8* @f(i32 %n) { +entry: + %promise = alloca i32 + %pv = bitcast i32* %promise to i8* + %id = call token @llvm.coro.id(i32 0, i8* %pv, i8* null, i8* null) + %need.dyn.alloc = call i1 @llvm.coro.alloc(token %id) + br i1 %need.dyn.alloc, label %dyn.alloc, label %coro.begin +dyn.alloc: + %size = call i32 @llvm.coro.size.i32() + %alloc = call i8* @malloc(i32 %size) + br label %coro.begin +coro.begin: + %phi = phi i8* [ null, %entry ], [ %alloc, %dyn.alloc ] + %hdl = call noalias i8* @llvm.coro.begin(token %id, i8* %phi) + br label %loop +loop: + %n.val = phi i32 [ %n, %coro.begin ], [ %inc, %loop ] + %inc = add nsw i32 %n.val, 1 + store i32 %n.val, i32* %promise + %0 = call i8 @llvm.coro.suspend(token none, i1 false) + switch i8 %0, label %suspend [i8 0, label %loop + i8 1, label %cleanup] +cleanup: + %mem = call i8* @llvm.coro.free(token %id, i8* %hdl) + call void @free(i8* %mem) + br label %suspend +suspend: + call void @llvm.coro.end(i8* %hdl, i1 false) + ret i8* %hdl +} + +; CHECK-LABEL: @main +define i32 @main() { +entry: + %hdl = call i8* @f(i32 4) + %promise.addr.raw = call i8* @llvm.coro.promise(i8* %hdl, i32 4, i1 false) + %promise.addr = bitcast i8* %promise.addr.raw to i32* + %val0 = load i32, i32* %promise.addr + call void @print(i32 %val0) + call void 
@llvm.coro.resume(i8* %hdl) + %val1 = load i32, i32* %promise.addr + call void @print(i32 %val1) + call void @llvm.coro.resume(i8* %hdl) + %val2 = load i32, i32* %promise.addr + call void @print(i32 %val2) + call void @llvm.coro.destroy(i8* %hdl) + ret i32 0 +; CHECK: call void @print(i32 4) +; CHECK-NEXT: call void @print(i32 5) +; CHECK-NEXT: call void @print(i32 6) +; CHECK: ret i32 0 +} + +declare i8* @llvm.coro.promise(i8*, i32, i1) +declare i8* @malloc(i32) +declare void @free(i8*) +declare void @print(i32) + +declare token @llvm.coro.id(i32, i8*, i8*, i8*) +declare i1 @llvm.coro.alloc(token) +declare i32 @llvm.coro.size.i32() +declare i8* @llvm.coro.begin(token, i8*) +declare i8 @llvm.coro.suspend(token, i1) +declare i8* @llvm.coro.free(token, i8*) +declare void @llvm.coro.end(i8*, i1) + +declare void @llvm.coro.resume(i8*) +declare void @llvm.coro.destroy(i8*)