diff --git a/clang/lib/CodeGen/CGGPUBuiltin.cpp b/clang/lib/CodeGen/CGGPUBuiltin.cpp
--- a/clang/lib/CodeGen/CGGPUBuiltin.cpp
+++ b/clang/lib/CodeGen/CGGPUBuiltin.cpp
@@ -27,7 +27,7 @@
   llvm::FunctionType *VprintfFuncType = llvm::FunctionType::get(
       llvm::Type::getInt32Ty(M.getContext()), ArgTypes, false);
 
-  if (auto* F = M.getFunction("vprintf")) {
+  if (auto *F = M.getFunction("vprintf")) {
     // Our CUDA system header declares vprintf with the right signature, so
     // nobody else should have been able to declare vprintf with a bogus
     // signature.
@@ -66,39 +66,28 @@
 //
 // Note that by the time this function runs, E's args have already undergone the
 // standard C vararg promotion (short -> int, float -> double, etc.).
-RValue
-CodeGenFunction::EmitNVPTXDevicePrintfCallExpr(const CallExpr *E,
-                                               ReturnValueSlot ReturnValue) {
-  assert(getTarget().getTriple().isNVPTX());
-  assert(E->getBuiltinCallee() == Builtin::BIprintf);
-  assert(E->getNumArgs() >= 1); // printf always has at least one arg.
+
+// Packs the non-format arguments of a printf call (Args[1] onwards) into a
+// stack-allocated struct that serves as the argument buffer for vprintf.
+// Returns a pointer to the buffer (a null i8* if there are no such arguments)
+// together with the buffer's alloc size (0 if there are none). The caller must
+// have checked that every packed argument is a scalar.
+static std::pair<llvm::Value *, llvm::TypeSize>
+packArgsIntoNVPTXFormatBuffer(CodeGenFunction *CGF, const CallArgList &Args) {
+  const llvm::DataLayout &DL = CGF->CGM.getDataLayout();
+  llvm::LLVMContext &Ctx = CGF->CGM.getLLVMContext();
+  CGBuilderTy &Builder = CGF->Builder;
 
-  const llvm::DataLayout &DL = CGM.getDataLayout();
-  llvm::LLVMContext &Ctx = CGM.getLLVMContext();
-
-  CallArgList Args;
-  EmitCallArgs(Args,
-               E->getDirectCallee()->getType()->getAs<FunctionProtoType>(),
-               E->arguments(), E->getDirectCallee(),
-               /* ParamsToSkip = */ 0);
-
-  // We don't know how to emit non-scalar varargs.
-  if (llvm::any_of(llvm::drop_begin(Args), [&](const CallArg &A) {
-        return !A.getRValue(*this).isScalar();
-      })) {
-    CGM.ErrorUnsupported(E, "non-scalar arg to printf");
-    return RValue::get(llvm::ConstantInt::get(IntTy, 0));
-  }
-
   // Construct and fill the args buffer that we'll pass to vprintf.
   llvm::Value *BufferPtr;
+  llvm::TypeSize BufferSize = llvm::TypeSize::Fixed(0);
   if (Args.size() <= 1) {
-    // If there are no args, pass a null pointer to vprintf.
+    // If there are no args, pass a null pointer and size 0.
     BufferPtr = llvm::ConstantPointerNull::get(llvm::Type::getInt8PtrTy(Ctx));
   } else {
     llvm::SmallVector<llvm::Type *, 8> ArgTypes;
     for (unsigned I = 1, NumArgs = Args.size(); I < NumArgs; ++I)
-      ArgTypes.push_back(Args[I].getRValue(*this).getScalarVal()->getType());
+      ArgTypes.push_back(Args[I].getRValue(*CGF).getScalarVal()->getType());
 
     // Using llvm::StructType is correct only because printf doesn't accept
     // aggregates.  If we had to handle aggregates here, we'd have to manually
@@ -106,18 +95,47 @@
     // that the alignment of the llvm type was the same as the alignment of the
     // clang type.
     llvm::Type *AllocaTy = llvm::StructType::create(ArgTypes, "printf_args");
-    llvm::Value *Alloca = CreateTempAlloca(AllocaTy);
+    llvm::Value *Alloca = CGF->CreateTempAlloca(AllocaTy);
 
     for (unsigned I = 1, NumArgs = Args.size(); I < NumArgs; ++I) {
       llvm::Value *P = Builder.CreateStructGEP(AllocaTy, Alloca, I - 1);
-      llvm::Value *Arg = Args[I].getRValue(*this).getScalarVal();
+      llvm::Value *Arg = Args[I].getRValue(*CGF).getScalarVal();
       Builder.CreateAlignedStore(Arg, P, DL.getPrefTypeAlign(Arg->getType()));
     }
     BufferPtr = Builder.CreatePointerCast(Alloca, llvm::Type::getInt8PtrTy(Ctx));
+    BufferSize = DL.getTypeAllocSize(AllocaTy);
   }
+  return {BufferPtr, BufferSize};
+}
+
+// Lowers a device-side printf call on NVPTX to vprintf(Format, Buffer), where
+// Buffer is produced by packArgsIntoNVPTXFormatBuffer.
+RValue
+CodeGenFunction::EmitNVPTXDevicePrintfCallExpr(const CallExpr *E,
+                                               ReturnValueSlot ReturnValue) {
+  assert(getTarget().getTriple().isNVPTX());
+  assert(E->getBuiltinCallee() == Builtin::BIprintf);
+  assert(E->getNumArgs() >= 1); // printf always has at least one arg.
+
+  CallArgList Args;
+  EmitCallArgs(Args,
+               E->getDirectCallee()->getType()->getAs<FunctionProtoType>(),
+               E->arguments(), E->getDirectCallee(),
+               /* ParamsToSkip = */ 0);
+
+  // We don't know how to emit non-scalar varargs.
+  if (llvm::any_of(llvm::drop_begin(Args), [&](const CallArg &A) {
+        return !A.getRValue(*this).isScalar();
+      })) {
+    CGM.ErrorUnsupported(E, "non-scalar arg to printf");
+    return RValue::get(llvm::ConstantInt::get(IntTy, 0));
+  }
+
+  // vprintf takes only the buffer pointer, so the buffer size is unused here.
+  llvm::Value *BufferPtr = packArgsIntoNVPTXFormatBuffer(this, Args).first;
 
   // Invoke vprintf and return.
-  llvm::Function* VprintfFunc = GetVprintfDeclaration(CGM.getModule());
+  llvm::Function *VprintfFunc = GetVprintfDeclaration(CGM.getModule());
   return RValue::get(Builder.CreateCall(
       VprintfFunc, {Args[0].getRValue(*this).getScalarVal(), BufferPtr}));
 }