diff --git a/clang/lib/CodeGen/TargetInfo.cpp b/clang/lib/CodeGen/TargetInfo.cpp
--- a/clang/lib/CodeGen/TargetInfo.cpp
+++ b/clang/lib/CodeGen/TargetInfo.cpp
@@ -10228,12 +10228,25 @@
 private:
   void setCCs();
 };
+
+/// ABI info for SPIR-V targets. Overrides computeInfo() so that kernel
+/// arguments are classified by classifyKernelArgumentType().
+class SPIRVABIInfo : public CommonSPIRABIInfo {
+public:
+  SPIRVABIInfo(CodeGenTypes &CGT) : CommonSPIRABIInfo(CGT) {}
+  void computeInfo(CGFunctionInfo &FI) const override;
+
+private:
+  ABIArgInfo classifyKernelArgumentType(QualType Ty) const;
+};
 } // end anonymous namespace
 namespace {
 class CommonSPIRTargetCodeGenInfo : public TargetCodeGenInfo {
 public:
   CommonSPIRTargetCodeGenInfo(CodeGen::CodeGenTypes &CGT)
       : TargetCodeGenInfo(std::make_unique<CommonSPIRABIInfo>(CGT)) {}
+  CommonSPIRTargetCodeGenInfo(std::unique_ptr<ABIInfo> ABIInfo)
+      : TargetCodeGenInfo(std::move(ABIInfo)) {}
 
   LangAS getASTAllocaAddressSpace() const override {
     return getLangASFromTargetAS(
@@ -10242,18 +10255,67 @@
 
   unsigned getOpenCLKernelCallingConv() const override;
 };
-
+/// Target code generation hooks for SPIR-V. Uses SPIRVABIInfo and overrides
+/// the CUDA/HIP kernel calling-convention hook.
+class SPIRVTargetCodeGenInfo : public CommonSPIRTargetCodeGenInfo {
+public:
+  SPIRVTargetCodeGenInfo(CodeGen::CodeGenTypes &CGT)
+      : CommonSPIRTargetCodeGenInfo(std::make_unique<SPIRVABIInfo>(CGT)) {}
+  void setCUDAKernelCallingConvention(const FunctionType *&FT) const override;
+};
 } // End anonymous namespace.
+
 void CommonSPIRABIInfo::setCCs() {
   assert(getRuntimeCC() == llvm::CallingConv::C);
   RuntimeCC = llvm::CallingConv::SPIR_FUNC;
 }
 
+/// Classifies one argument of a SPIR_KERNEL function. In HIP mode, a pointer
+/// in the default address space is coerced to the target address space of
+/// LangAS::cuda_device; everything else goes through classifyArgumentType().
+ABIArgInfo SPIRVABIInfo::classifyKernelArgumentType(QualType Ty) const {
+  if (getContext().getLangOpts().HIP) {
+    // Coerce pointer arguments with default address space to CrossWorkGroup
+    // pointers for HIPSPV. When the language mode is HIP, the SPIRTargetInfo
+    // maps cuda_device to SPIR-V's CrossWorkGroup address space.
+    llvm::Type *LTy = CGT.ConvertType(Ty);
+    auto DefaultAS = getContext().getTargetAddressSpace(LangAS::Default);
+    auto GlobalAS = getContext().getTargetAddressSpace(LangAS::cuda_device);
+    if (LTy->isPointerTy() && LTy->getPointerAddressSpace() == DefaultAS) {
+      LTy = llvm::PointerType::get(
+          cast<llvm::PointerType>(LTy)->getElementType(), GlobalAS);
+      return ABIArgInfo::getDirect(LTy, 0, nullptr, false);
+    }
+  }
+  return classifyArgumentType(Ty);
+}
+
+/// Computes the return and argument ABI info for FI, routing arguments of
+/// SPIR_KERNEL functions through classifyKernelArgumentType().
+void SPIRVABIInfo::computeInfo(CGFunctionInfo &FI) const {
+  // The logic is same as in DefaultABIInfo with an exception on the kernel
+  // arguments handling.
+  llvm::CallingConv::ID CC = FI.getCallingConvention();
+
+  if (!getCXXABI().classifyReturnType(FI))
+    FI.getReturnInfo() = classifyReturnType(FI.getReturnType());
+
+  for (auto &I : FI.arguments()) {
+    if (CC == llvm::CallingConv::SPIR_KERNEL) {
+      I.info = classifyKernelArgumentType(I.type);
+    } else {
+      I.info = classifyArgumentType(I.type);
+    }
+  }
+}
+
 namespace clang {
 namespace CodeGen {
 void computeSPIRKernelABIInfo(CodeGenModule &CGM, CGFunctionInfo &FI) {
-  DefaultABIInfo SPIRABI(CGM.getTypes());
-  SPIRABI.computeInfo(FI);
+  if (CGM.getTarget().getTriple().isSPIRV())
+    SPIRVABIInfo(CGM.getTypes()).computeInfo(FI);
+  else
+    CommonSPIRABIInfo(CGM.getTypes()).computeInfo(FI);
 }
 }
 }
@@ -10262,6 +10324,18 @@
   return llvm::CallingConv::SPIR_KERNEL;
 }
 
+/// In HIP mode, rewrites FT to use the CC_OpenCLKernel calling convention;
+/// otherwise leaves FT untouched.
+void SPIRVTargetCodeGenInfo::setCUDAKernelCallingConvention(
+    const FunctionType *&FT) const {
+  // Convert HIP kernels to SPIR-V kernels.
+  if (getABIInfo().getContext().getLangOpts().HIP) {
+    FT = getABIInfo().getContext().adjustFunctionType(
+        FT, FT->getExtInfo().withCallingConv(CC_OpenCLKernel));
+    return;
+  }
+}
+
 static bool appendType(SmallStringEnc &Enc, QualType QType,
                        const CodeGen::CodeGenModule &CGM,
                        TypeStringCache &TSC);
@@ -11327,9 +11401,10 @@
     return SetCGInfo(new ARCTargetCodeGenInfo(Types));
   case llvm::Triple::spir:
   case llvm::Triple::spir64:
+    return SetCGInfo(new CommonSPIRTargetCodeGenInfo(Types));
   case llvm::Triple::spirv32:
   case llvm::Triple::spirv64:
-    return SetCGInfo(new CommonSPIRTargetCodeGenInfo(Types));
+    return SetCGInfo(new SPIRVTargetCodeGenInfo(Types));
   case llvm::Triple::ve:
     return SetCGInfo(new VETargetCodeGenInfo(Types));
   }
diff --git a/clang/test/CodeGenHIP/hipspv-kernel.cpp b/clang/test/CodeGenHIP/hipspv-kernel.cpp
new file mode 100644
--- /dev/null
+++ b/clang/test/CodeGenHIP/hipspv-kernel.cpp
@@ -0,0 +1,9 @@
+// RUN: %clang_cc1 -triple spirv64 -x hip -emit-llvm -fcuda-is-device \
+// RUN:   -o - %s | FileCheck %s
+
+#define __global__ __attribute__((global))
+
+// CHECK: define {{.*}}spir_kernel void @_Z3fooPff(float addrspace(1)* {{.*}}, float {{.*}})
+__global__ void foo(float *a, float b) {
+  *a = b;
+}