Index: include/clang/Basic/Builtins.def
===================================================================
--- include/clang/Basic/Builtins.def
+++ include/clang/Basic/Builtins.def
@@ -1458,6 +1458,9 @@
 BUILTIN(__builtin_coro_end, "bv*Ib", "n")
 BUILTIN(__builtin_coro_suspend, "cIb", "n")
 BUILTIN(__builtin_coro_param, "bv*v*", "n")
+
+BUILTIN(__builtin_coro_frame_max_size, "z", "nc")
+
 // OpenCL v2.0 s6.13.16, s9.17.3.5 - Pipe functions.
 // We need the generic prototype, since the packet type could be anything.
 LANGBUILTIN(read_pipe, "i.", "tn", OCLC20_LANG)
Index: lib/CodeGen/CGBuiltin.cpp
===================================================================
--- lib/CodeGen/CGBuiltin.cpp
+++ lib/CodeGen/CGBuiltin.cpp
@@ -3377,6 +3377,9 @@
   case Builtin::BI__builtin_coro_param:
     return EmitCoroutineIntrinsic(E, Intrinsic::coro_param);
 
+  case Builtin::BI__builtin_coro_frame_max_size:
+    return EmitCoroutineFrameMaxSize(E);
+
   // OpenCL v2.0 s6.13.16.2, Built-in pipe read and write functions
   case Builtin::BIread_pipe:
   case Builtin::BIwrite_pipe: {
Index: lib/CodeGen/CGCoroutine.cpp
===================================================================
--- lib/CodeGen/CGCoroutine.cpp
+++ lib/CodeGen/CGCoroutine.cpp
@@ -14,8 +14,10 @@
 #include "CGCleanup.h"
 #include "CodeGenFunction.h"
 #include "llvm/ADT/ScopeExit.h"
+#include "clang/AST/RecursiveASTVisitor.h"
 #include "clang/AST/StmtCXX.h"
 #include "clang/AST/StmtVisitor.h"
+#include "clang/Basic/TargetInfo.h"
 
 using namespace clang;
 using namespace CodeGen;
@@ -758,3 +760,212 @@
   }
   return RValue::get(Call);
 }
+
+namespace {
+/// Accumulates a pessimistic coroutine frame size, in chars, by appending one
+/// member at a time and padding each member up to its own alignment.
+class FrameSizeBuilder {
+  const ASTContext &Context;
+  uint64_t FrameSize = 0;
+
+  // Pads the running size up to AlignInChars, then appends SizeInChars.
+  void append(uint64_t SizeInChars, uint64_t AlignInChars) {
+    FrameSize = llvm::alignTo(FrameSize, AlignInChars) + SizeInChars;
+  }
+
+public:
+  explicit FrameSizeBuilder(const ASTContext &Context) : Context(Context) {}
+
+  /// Returns the size accumulated so far, in chars.
+  uint64_t getFrameSize() const { return FrameSize; }
+
+  /// Appends a member of type \p Ty.
+  void addType(const QualType Ty) {
+    append(Context.getTypeSizeInChars(Ty).getQuantity(),
+           Context.getTypeAlignInChars(Ty).getQuantity());
+  }
+
+  /// Appends a member of the target integer type \p Ty.
+  void addType(TargetInfo::IntType Ty) {
+    const TargetInfo &TI = Context.getTargetInfo();
+    // TargetInfo reports widths and alignments in bits, while the frame size
+    // is tracked in chars, so convert both before accumulating.
+    append(Context.toCharUnitsFromBits(TI.getTypeWidth(Ty)).getQuantity(),
+           Context.toCharUnitsFromBits(TI.getTypeAlign(Ty)).getQuantity());
+  }
+};
+
+/// Walks a statement and adds the type of every variable declaration and
+/// materialized temporary it finds to a FrameSizeBuilder.
+class PotentialSpillsVisitor
+    : public RecursiveASTVisitor<PotentialSpillsVisitor> {
+  FrameSizeBuilder &Builder;
+  const QualType HandleTy;
+
+public:
+  PotentialSpillsVisitor(FrameSizeBuilder &Builder, QualType HandleTy)
+      : Builder(Builder), HandleTy(HandleTy) {}
+
+  bool shouldVisitImplicitCode() const { return true; }
+
+  // Add up the size of any variables explicitly declared within the
+  // coroutine, as well as the implicit __promise and __coro_gro variables.
+  bool VisitVarDecl(VarDecl *VD) {
+    Builder.addType(VD->getType());
+    return true;
+  }
+
+  // Add up the size of any temporaries materialized within the coroutine, as
+  // well as the implicit temporaries materialized when:
+  // 1. __promise.initial_suspend and .final_suspend are called to construct
+  //    awaitables
+  // 2. Coroutine handles are materialized via .from_address and
+  //    passed in as arguments to __builtin_coro_frame. Coroutine handles also
+  //    get bitcast to void* in LLVM IR and the bitcast spills, so the bitcast
+  //    also needs to be accounted for in the coroutine frame size
+  // 3. __promise.get_return_object is called to construct return objects
+  bool VisitMaterializeTemporaryExpr(MaterializeTemporaryExpr *E) {
+    QualType ETy = E->GetTemporaryExpr()->getType();
+    Builder.addType(ETy);
+
+    // LLVM spills the bitcast for this type, causing an additional increase
+    // in coroutine frame size.
+    if (ETy == HandleTy)
+      Builder.addType(ETy);
+
+    return true;
+  }
+};
+} // end anonymous namespace
+
+/// Reports "__builtin_coro_frame_max_size <Err>" as an error at the location
+/// of \p E and returns a zero of the builtin's size_t result type.
+static RValue emitFrameMaxSizeError(const CodeGenFunction &CGF,
+                                    const CallExpr *E, StringRef Err) {
+  CGF.CGM.Error(E->getExprLoc(), "__builtin_coro_frame_max_size " + Err.str());
+  return RValue::get(llvm::ConstantInt::get(CGF.SizeTy, 0));
+}
+
+/// Returns the type of the coroutine handle passed to await_suspend at the
+/// initial suspend point of \p CoroBody, or a null QualType when the initial
+/// suspend statement does not have the expected shape.
+static QualType findCoroutineHandleType(const CoroutineBodyStmt *CoroBody) {
+  // NOTE(review): the node kinds matched below (ExprWithCleanups, CoawaitExpr,
+  // CallExpr, CXXConstructExpr, CastExpr) are inferred from the accessors
+  // used on each node -- confirm against the AST of a real coroutine.
+  auto *InitSuspend =
+      dyn_cast_or_null<ExprWithCleanups>(CoroBody->getInitSuspendStmt());
+  if (!InitSuspend)
+    return QualType();
+  auto *Coawait = dyn_cast<CoawaitExpr>(InitSuspend->getSubExpr());
+  if (!Coawait)
+    return QualType();
+  auto *AwaitSuspend = dyn_cast_or_null<CallExpr>(Coawait->getSuspendExpr());
+  if (!AwaitSuspend || AwaitSuspend->getNumArgs() == 0)
+    return QualType();
+  auto *HandleCtor = dyn_cast<CXXConstructExpr>(AwaitSuspend->getArg(0));
+  if (!HandleCtor || HandleCtor->getNumArgs() == 0)
+    return QualType();
+  auto *Cast = dyn_cast<CastExpr>(HandleCtor->getArg(0));
+  if (!Cast)
+    return QualType();
+  return Cast->getSubExpr()->getType();
+}
+
+/// Emits the result of __builtin_coro_frame_max_size: a size_t constant that
+/// is a pessimistic estimate, in chars, of the frame size of the coroutine
+/// lambda whose closure type is the enclosing function's single template
+/// type argument. Misuse is diagnosed and yields zero.
+RValue CodeGenFunction::EmitCoroutineFrameMaxSize(const CallExpr *E) {
+  auto *FD = dyn_cast_or_null<FunctionDecl>(CurFuncDecl);
+  if (!FD)
+    return emitFrameMaxSizeError(*this, E,
+                                 "must be used from within a function");
+
+  // getTemplateSpecializationArgs() is null when the enclosing function is
+  // not a template specialization, so test it before dereferencing.
+  const TemplateArgumentList *Args = FD->getTemplateSpecializationArgs();
+  if (!Args || Args->size() != 1)
+    return emitFrameMaxSizeError(
+        *this, E,
+        "must be called from a function with a single template argument");
+
+  const TemplateArgument &Arg = Args->get(0);
+  if (Arg.getKind() != TemplateArgument::ArgKind::Type)
+    return emitFrameMaxSizeError(
+        *this, E,
+        "must be called from a function with a single template type argument");
+
+  // A non-class type argument yields a null declaration here, which is
+  // diagnosed instead of tripping an assertion inside cast<>.
+  const CXXRecordDecl *RD = Arg.getAsType()->getAsCXXRecordDecl();
+  if (!RD || !RD->isLambda())
+    return emitFrameMaxSizeError(*this, E,
+                                 "must be called from a function whose single "
+                                 "template type argument is a lambda type");
+
+  CXXMethodDecl *MD = RD->getLambdaCallOperator();
+  if (!MD || !MD->doesThisDeclarationHaveABody())
+    return emitFrameMaxSizeError(
+        *this, E,
+        "must be called from a function whose single template type argument is "
+        "a lambda with a body");
+
+  auto *CoroBody = dyn_cast_or_null<CoroutineBodyStmt>(MD->getBody());
+  if (!CoroBody)
+    return emitFrameMaxSizeError(
+        *this, E,
+        "must be called from a function whose single template type argument is "
+        "a coroutine lambda with a body");
+
+  // Find the promise type up front so that a missing promise declaration is
+  // diagnosed rather than handed to the builder as a null type.
+  QualType PromiseType;
+  if (auto *PromiseDeclStmt =
+          dyn_cast_or_null<DeclStmt>(CoroBody->getPromiseDeclStmt()))
+    if (PromiseDeclStmt->isSingleDecl())
+      if (auto *PromiseVarDecl =
+              dyn_cast<VarDecl>(PromiseDeclStmt->getSingleDecl()))
+        PromiseType = PromiseVarDecl->getType();
+  if (PromiseType.isNull())
+    return emitFrameMaxSizeError(
+        *this, E, "could not determine the promise type of the coroutine");
+
+  // Build the maximum potential coroutine frame size, adding alignment padding
+  // as necessary.
+  ASTContext &Ctx = getContext();
+  FrameSizeBuilder Builder(Ctx);
+
+  // Add space for the two llvm.coro.subfn pointers LLVM adds to the coroutine
+  // frame.
+  Builder.addType(Ctx.VoidPtrTy);
+  Builder.addType(Ctx.VoidPtrTy);
+  // The coroutine resume index added by LLVM to the coroutine frame will use
+  // the smallest viable integer type, which in some cases might be i1 (1 byte)
+  // vs. i64 (8 bytes). So the maximum size here might be 7 bytes larger than
+  // actual.
+  Builder.addType(Ctx.getTargetInfo().getInt64Type());
+  // Add space for the promise that LLVM adds to the coroutine frame.
+  Builder.addType(PromiseType);
+
+  // Next, we add space for any and all variables or temporaries that may spill
+  // across suspend point boundaries, and thus may be moved onto the coroutine
+  // frame. In reality LLVM may elide many of these moves, but this function
+  // calculates the upper bound for the frame size, so we must be pessimistic
+  // and assume all variables end up on the frame.
+
+  // First, the implicit 'this' variable on the lambda.
+  Builder.addType(MD->getThisType(Ctx));
+
+  // Next, each of the parameters to the lambda. These are the parameters of
+  // the lambda's call operator (MD), not of the enclosing function (FD).
+  for (const ParmVarDecl *PD : MD->parameters())
+    Builder.addType(PD->getType());
+
+  // Finally, any temporaries that may be materialized -- especially coroutine
+  // handles, for which LLVM introduces a bitcast and so must be counted twice.
+  PotentialSpillsVisitor V(Builder, findCoroutineHandleType(CoroBody));
+  V.TraverseStmt(CoroBody);
+
+  return RValue::get(llvm::ConstantInt::get(SizeTy, Builder.getFrameSize()));
+}
Index: lib/CodeGen/CodeGenFunction.h
===================================================================
--- lib/CodeGen/CodeGenFunction.h
+++ lib/CodeGen/CodeGenFunction.h
@@ -2880,6 +2880,7 @@
                         bool ignoreResult = false);
   LValue EmitCoyieldLValue(const CoyieldExpr *E);
   RValue EmitCoroutineIntrinsic(const CallExpr *E, unsigned int IID);
+  RValue EmitCoroutineFrameMaxSize(const CallExpr *E);
 
   void EnterCXXTryStmt(const CXXTryStmt &S, bool IsFnTryBlock = false);
   void ExitCXXTryStmt(const CXXTryStmt &S, bool IsFnTryBlock = false);