Index: lib/CodeGen/CGCoroutine.cpp
===================================================================
--- lib/CodeGen/CGCoroutine.cpp
+++ lib/CodeGen/CGCoroutine.cpp
@@ -523,7 +523,61 @@
   Builder.CreateCondBr(CoroAlloc, AllocBB, InitBB);
 
   EmitBlock(AllocBB);
-  auto *AllocateCall = EmitScalarExpr(S.getAllocate());
+  // Emit the call to the coroutine frame allocation function.
+  // FIXME: cast<> asserts if the allocation was not emitted as a plain call;
+  // confirm it can never be an invoke here (cf. AllocOrInvokeContBB below).
+  auto *AllocateCall = cast<llvm::CallInst>(EmitScalarExpr(S.getAllocate()));
+
+  // The backend coroutine split transform will move stores and loads for the
+  // coroutine function's arguments down past the first suspend point.
+  //
+  // However, in the case that the coroutine frame is, as per
+  // [dcl.fct.def.coroutine]/7, being allocated with an allocation function
+  // matching the coroutine function's arguments, we need to ensure that the
+  // allocation function is passed arguments that have values stored in them.
+  //
+  // Here, we generate instructions to store the coroutine function's arguments
+  // separately, and then pass them into the allocation function. First, we
+  // search for the coroutine function argument allocas that correspond to the
+  // arguments passed into the allocation function.
+  for (unsigned OpIdx = 0, OpEnd = AllocateCall->getNumArgOperands();
+       OpIdx != OpEnd; ++OpIdx) {
+    if (auto *AllocateOp =
+            dyn_cast<llvm::LoadInst>(AllocateCall->getArgOperand(OpIdx))) {
+      for (auto &AllocateOpOp : AllocateOp->operands()) {
+        if (auto *Alloca = dyn_cast<llvm::AllocaInst>(AllocateOpOp)) {
+          // We've found the alloca instruction. Now we search for the store
+          // instruction that stores the coroutine function argument into that
+          // alloca's address (not a store of the alloca's address elsewhere).
+          for (llvm::User *U : Alloca->users()) {
+            auto *Store = dyn_cast<llvm::StoreInst>(U);
+            if (Store && Store->getPointerOperand() == Alloca) {
+              // Now we generate an alloca, store, and a load, to replace the
+              // allocation function call instruction operands.
+              llvm::BasicBlock::iterator InsertPt = Builder.GetInsertPoint();
+              Builder.SetInsertPoint(AllocateCall);
+
+              llvm::AllocaInst *NewAlloca = Builder.CreateAlloca(
+                  Alloca->getAllocatedType(), Alloca->getArraySize(),
+                  "coro.allocate." + Alloca->getName());
+              NewAlloca->setAlignment(Alloca->getAlignment());
+
+              Address NewAllocaAddress = {
+                  NewAlloca,
+                  CharUnits::fromQuantity(NewAlloca->getAlignment())};
+              Builder.CreateStore(Store->getValueOperand(), NewAllocaAddress);
+
+              llvm::LoadInst *NewLoad = Builder.CreateLoad(NewAllocaAddress);
+              AllocateCall->setArgOperand(OpIdx, NewLoad);
+
+              Builder.SetInsertPoint(AllocBB, InsertPt);
+            }
+          }
+        }
+      }
+    }
+  }
+
   auto *AllocOrInvokeContBB = Builder.GetInsertBlock();
 
   // Handle allocation failure if 'ReturnStmtOnAllocFailure' was provided.
Index: lib/Sema/SemaCoroutine.cpp
===================================================================
--- lib/Sema/SemaCoroutine.cpp
+++ lib/Sema/SemaCoroutine.cpp
@@ -1050,7 +1050,12 @@
 
   const bool RequiresNoThrowAlloc = ReturnStmtOnAllocFailure != nullptr;
 
-  // FIXME: Add support for stateful allocators.
+  // [dcl.fct.def.coroutine]/7
+  // Lookup allocation functions using a parameter list composed of the
+  // requested size of the coroutine state being allocated, followed by
+  // the coroutine function's arguments. If a matching allocation function
+  // exists, use it. Otherwise, use an allocation function that just takes
+  // the requested size.
 
   FunctionDecl *OperatorNew = nullptr;
   FunctionDecl *OperatorDelete = nullptr;
@@ -1058,10 +1063,43 @@
   bool PassAlignment = false;
   SmallVector<Expr *, 1> PlacementArgs;
 
+  // [dcl.fct.def.coroutine]/7
+  // "The allocation function's name is looked up in the scope of P.
+  // [...] If the lookup finds an allocation function in the scope of P,
+  // overload resolution is performed on a function call created by assembling
+  // an argument list."
+  for (auto *PD : FD.parameters()) {
+    if (PD->getType()->isDependentType())
+      continue;
+
+    // Build a reference to the parameter.
+    auto PDLoc = PD->getLocation();
+    ExprResult PDRefExpr =
+        S.BuildDeclRefExpr(PD, PD->getOriginalType().getNonReferenceType(),
+                           ExprValueKind::VK_LValue, PDLoc);
+    if (PDRefExpr.isInvalid())
+      return false;
+
+    PlacementArgs.push_back(PDRefExpr.get());
+  }
   S.FindAllocationFunctions(Loc, SourceRange(),
                             /*UseGlobal*/ false, PromiseType,
                             /*isArray*/ false, PassAlignment, PlacementArgs,
-                            OperatorNew, UnusedResult);
+                            OperatorNew, UnusedResult, /*Diagnose*/ false);
+
+  // [dcl.fct.def.coroutine]/7
+  // "If no matching function is found, overload resolution is performed again
+  // on a function call created by passing just the amount of space required as
+  // an argument of type std::size_t."
+  // Retry even if PlacementArgs is already empty: the lookup above was silent,
+  // so this second lookup is what diagnoses a missing allocation function.
+  if (!OperatorNew) {
+    PlacementArgs.clear();
+    S.FindAllocationFunctions(Loc, SourceRange(),
+                              /*UseGlobal*/ false, PromiseType,
+                              /*isArray*/ false, PassAlignment,
+                              PlacementArgs, OperatorNew, UnusedResult);
+  }
 
   bool IsGlobalOverload =
       OperatorNew && !isa<CXXRecordDecl>(OperatorNew->getDeclContext());
@@ -1080,7 +1118,8 @@
                               OperatorNew, UnusedResult);
   }
 
-  assert(OperatorNew && "expected definition of operator new to be found");
+  if (!OperatorNew)
+    return false;
 
   if (RequiresNoThrowAlloc) {
     const auto *FT = OperatorNew->getType()->getAs<FunctionProtoType>();
Index: test/CodeGenCoroutines/coro-alloc.cpp
===================================================================
--- test/CodeGenCoroutines/coro-alloc.cpp
+++ test/CodeGenCoroutines/coro-alloc.cpp
@@ -106,6 +106,34 @@
   co_return;
 }
 
+struct promise_matching_placement_new_tag {};
+
+template<>
+struct std::experimental::coroutine_traits<void, promise_matching_placement_new_tag, int, float, double> {
+  struct promise_type {
+    void *operator new(unsigned long, promise_matching_placement_new_tag,
+                       int, float, double);
+    void get_return_object() {}
+    suspend_always initial_suspend() { return {}; }
+    suspend_always final_suspend() { return {}; }
+    void return_void() {}
+  };
+};
+
+// CHECK-LABEL: f1a(
+extern "C" void f1a(promise_matching_placement_new_tag, int x, float y, double z) {
+  // CHECK: %[[ID:.+]] = call token @llvm.coro.id(i32 16
+  // CHECK: %[[SIZE:.+]] = call i64 @llvm.coro.size.i64()
+  // CHECK: store i32 %x, i32* %coro.allocate.x.addr, align 4
+  // CHECK: %[[INT:.+]] = load i32, i32* %coro.allocate.x.addr, align 4
+  // CHECK: store float %y, float* %coro.allocate.y.addr, align 4
+  // CHECK: %[[FLOAT:.+]] = load float, float* %coro.allocate.y.addr, align 4
+  // CHECK: store double %z, double* %coro.allocate.z.addr, align 8
+  // CHECK: %[[DOUBLE:.+]] = load double, double* %coro.allocate.z.addr, align 8
+  // CHECK: call i8* @_ZNSt12experimental16coroutine_traitsIJv34promise_matching_placement_new_tagifdEE12promise_typenwEmS1_ifd(i64 %[[SIZE]], i32 %[[INT]], float %[[FLOAT]], double %[[DOUBLE]])
+  co_return;
+}
+
 struct promise_delete_tag {};
 
 template<>