diff --git a/clang/include/clang/AST/ASTContext.h b/clang/include/clang/AST/ASTContext.h --- a/clang/include/clang/AST/ASTContext.h +++ b/clang/include/clang/AST/ASTContext.h @@ -538,6 +538,9 @@ /// need them (like static local vars). llvm::MapVector MangleNumbers; llvm::MapVector StaticLocalNumbers; + /// Mapping the associated device lambda mangling number if present. + mutable llvm::DenseMap + DeviceLambdaManglingNumbers; /// Mapping that stores parameterIndex values for ParmVarDecls when /// that value exceeds the bitfield size of ParmVarDeclBits.ParameterIndex. diff --git a/clang/include/clang/AST/DeclCXX.h b/clang/include/clang/AST/DeclCXX.h --- a/clang/include/clang/AST/DeclCXX.h +++ b/clang/include/clang/AST/DeclCXX.h @@ -1735,6 +1735,12 @@ getLambdaData().HasKnownInternalLinkage = HasKnownInternalLinkage; } + /// Set the device side mangling number. + void setDeviceLambdaManglingNumber(unsigned Num) const; + + /// Retrieve the device side mangling number. + unsigned getDeviceLambdaManglingNumber() const; + /// Returns the inheritance model used for this record. MSInheritanceModel getMSInheritanceModel() const; diff --git a/clang/include/clang/AST/Mangle.h b/clang/include/clang/AST/Mangle.h --- a/clang/include/clang/AST/Mangle.h +++ b/clang/include/clang/AST/Mangle.h @@ -107,6 +107,9 @@ virtual bool shouldMangleCXXName(const NamedDecl *D) = 0; virtual bool shouldMangleStringLiteral(const StringLiteral *SL) = 0; + virtual bool isDeviceMangleContext() const { return false; } + virtual void setDeviceMangleContext(bool) {} + // FIXME: consider replacing raw_ostream & with something like SmallString &. 
void mangleName(GlobalDecl GD, raw_ostream &); virtual void mangleCXXName(GlobalDecl GD, raw_ostream &) = 0; diff --git a/clang/include/clang/AST/MangleNumberingContext.h b/clang/include/clang/AST/MangleNumberingContext.h --- a/clang/include/clang/AST/MangleNumberingContext.h +++ b/clang/include/clang/AST/MangleNumberingContext.h @@ -52,6 +52,11 @@ /// this context. virtual unsigned getManglingNumber(const TagDecl *TD, unsigned MSLocalManglingNumber) = 0; + + /// Retrieve the mangling number of a new lambda expression with the + /// given call operator within the device context. No device number is + /// assigned if no device numbering context is associated. + virtual unsigned getDeviceManglingNumber(const CXXMethodDecl *) { return 0; } }; } // end namespace clang diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h --- a/clang/include/clang/Sema/Sema.h +++ b/clang/include/clang/Sema/Sema.h @@ -6558,7 +6558,7 @@ /// Number lambda for linkage purposes if necessary. void handleLambdaNumbering( CXXRecordDecl *Class, CXXMethodDecl *Method, - Optional> Mangling = None); + Optional> Mangling = None); /// Endow the lambda scope info with the relevant properties. 
void buildLambdaScope(sema::LambdaScopeInfo *LSI, diff --git a/clang/lib/AST/ASTImporter.cpp b/clang/lib/AST/ASTImporter.cpp --- a/clang/lib/AST/ASTImporter.cpp +++ b/clang/lib/AST/ASTImporter.cpp @@ -2848,6 +2848,8 @@ return CDeclOrErr.takeError(); D2CXX->setLambdaMangling(DCXX->getLambdaManglingNumber(), *CDeclOrErr, DCXX->hasKnownLambdaInternalLinkage()); + D2CXX->setDeviceLambdaManglingNumber( + DCXX->getDeviceLambdaManglingNumber()); } else if (DCXX->isInjectedClassName()) { // We have to be careful to do a similar dance to the one in // Sema::ActOnStartCXXMemberDeclarations diff --git a/clang/lib/AST/CXXABI.h b/clang/lib/AST/CXXABI.h --- a/clang/lib/AST/CXXABI.h +++ b/clang/lib/AST/CXXABI.h @@ -22,8 +22,9 @@ class CXXConstructorDecl; class DeclaratorDecl; class Expr; -class MemberPointerType; +class MangleContext; class MangleNumberingContext; +class MemberPointerType; /// Implements C++ ABI-specific semantic analysis functions. class CXXABI { @@ -75,6 +76,8 @@ /// Creates an instance of a C++ ABI class. 
CXXABI *CreateItaniumCXXABI(ASTContext &Ctx); CXXABI *CreateMicrosoftCXXABI(ASTContext &Ctx); +std::unique_ptr +createItaniumNumberingContext(MangleContext *); } #endif diff --git a/clang/lib/AST/DeclCXX.cpp b/clang/lib/AST/DeclCXX.cpp --- a/clang/lib/AST/DeclCXX.cpp +++ b/clang/lib/AST/DeclCXX.cpp @@ -1593,6 +1593,20 @@ return getLambdaData().ContextDecl.get(Source); } +void CXXRecordDecl::setDeviceLambdaManglingNumber(unsigned Num) const { + assert(isLambda() && "Not a lambda closure type!"); + if (Num) + getASTContext().DeviceLambdaManglingNumbers[this] = Num; +} + +unsigned CXXRecordDecl::getDeviceLambdaManglingNumber() const { + assert(isLambda() && "Not a lambda closure type!"); + auto I = getASTContext().DeviceLambdaManglingNumbers.find(this); + if (I != getASTContext().DeviceLambdaManglingNumbers.end()) + return I->second; + return 0; +} + static CanQualType GetConversionType(ASTContext &Context, NamedDecl *Conv) { QualType T = cast(Conv->getUnderlyingDecl()->getAsFunction()) diff --git a/clang/lib/AST/ItaniumCXXABI.cpp b/clang/lib/AST/ItaniumCXXABI.cpp --- a/clang/lib/AST/ItaniumCXXABI.cpp +++ b/clang/lib/AST/ItaniumCXXABI.cpp @@ -258,3 +258,9 @@ CXXABI *clang::CreateItaniumCXXABI(ASTContext &Ctx) { return new ItaniumCXXABI(Ctx); } + +std::unique_ptr +clang::createItaniumNumberingContext(MangleContext *Mangler) { + return std::make_unique( + cast(Mangler)); +} diff --git a/clang/lib/AST/ItaniumMangle.cpp b/clang/lib/AST/ItaniumMangle.cpp --- a/clang/lib/AST/ItaniumMangle.cpp +++ b/clang/lib/AST/ItaniumMangle.cpp @@ -125,6 +125,8 @@ llvm::DenseMap Discriminator; llvm::DenseMap Uniquifier; + bool IsDevCtx = false; + public: explicit ItaniumMangleContextImpl(ASTContext &Context, DiagnosticsEngine &Diags) @@ -137,6 +139,10 @@ bool shouldMangleStringLiteral(const StringLiteral *) override { return false; } + + bool isDeviceMangleContext() const override { return IsDevCtx; } + void setDeviceMangleContext(bool IsDev) override { IsDevCtx = IsDev; } + void 
mangleCXXName(GlobalDecl GD, raw_ostream &) override; void mangleThunk(const CXXMethodDecl *MD, const ThunkInfo &Thunk, raw_ostream &) override; @@ -1876,7 +1882,15 @@ // (in lexical order) with that same and context. // // The AST keeps track of the number for us. - unsigned Number = Lambda->getLambdaManglingNumber(); + // + // In CUDA/HIP, to ensure the consistent lambda numbering between the device- + // and host-side compilations, an extra device mangle context may be created + // if the host-side CXX ABI has different numbering for lambda. In such a case, + // if the mangle context is that device-side one, use the device-side lambda + // mangling number for this lambda. + unsigned Number = Context.isDeviceMangleContext() + ? Lambda->getDeviceLambdaManglingNumber() + : Lambda->getLambdaManglingNumber(); assert(Number > 0 && "Lambda should be mangled as an unnamed class"); if (Number > 1) mangleNumber(Number - 2); diff --git a/clang/lib/AST/MicrosoftCXXABI.cpp b/clang/lib/AST/MicrosoftCXXABI.cpp --- a/clang/lib/AST/MicrosoftCXXABI.cpp +++ b/clang/lib/AST/MicrosoftCXXABI.cpp @@ -16,6 +16,7 @@ #include "clang/AST/Attr.h" #include "clang/AST/CXXInheritance.h" #include "clang/AST/DeclCXX.h" +#include "clang/AST/Mangle.h" #include "clang/AST/MangleNumberingContext.h" #include "clang/AST/RecordLayout.h" #include "clang/AST/Type.h" @@ -64,6 +65,19 @@ } }; +class MSHIPNumberingContext : public MicrosoftNumberingContext { + std::unique_ptr DeviceCtx; + +public: + MSHIPNumberingContext(MangleContext *Mangler) { + DeviceCtx = createItaniumNumberingContext(Mangler); + } + + unsigned getDeviceManglingNumber(const CXXMethodDecl *CallOperator) override { + return DeviceCtx->getManglingNumber(CallOperator); + } +}; + class MicrosoftCXXABI : public CXXABI { ASTContext &Context; llvm::SmallDenseMap RecordToCopyCtor; @@ -73,8 +87,19 @@ llvm::SmallDenseMap UnnamedTagDeclToTypedefNameDecl; + // MangleContext for device numbering context, which is based on Itanium C++ + // ABI. 
+ std::unique_ptr Mangler; + public: - MicrosoftCXXABI(ASTContext &Ctx) : Context(Ctx) { } + MicrosoftCXXABI(ASTContext &Ctx) : Context(Ctx) { + if (Context.getLangOpts().CUDA) { + assert(Context.getTargetInfo().getCXXABI().isMicrosoft() && + Context.getAuxTargetInfo()->getCXXABI().isItaniumFamily() && + "Unexpected combination of C++ ABIs."); + Mangler.reset(Context.createMangleContext(Context.getAuxTargetInfo())); + } + } MemberPointerInfo getMemberPointerInfo(const MemberPointerType *MPT) const override; @@ -133,6 +158,8 @@ std::unique_ptr createMangleNumberingContext() const override { + if (Context.getLangOpts().CUDA) + return std::make_unique(Mangler.get()); return std::make_unique(); } }; @@ -266,4 +293,3 @@ CXXABI *clang::CreateMicrosoftCXXABI(ASTContext &Ctx) { return new MicrosoftCXXABI(Ctx); } - diff --git a/clang/lib/CodeGen/CGCUDANV.cpp b/clang/lib/CodeGen/CGCUDANV.cpp --- a/clang/lib/CodeGen/CGCUDANV.cpp +++ b/clang/lib/CodeGen/CGCUDANV.cpp @@ -190,6 +190,12 @@ CharPtrTy = llvm::PointerType::getUnqual(Types.ConvertType(Ctx.CharTy)); VoidPtrTy = cast(Types.ConvertType(Ctx.VoidPtrTy)); VoidPtrPtrTy = VoidPtrTy->getPointerTo(); + // If the host and device have different C++ ABIs, mark it as the device + // mangle context so that the mangling needs to retrieve the additional device + // lambda mangling number instead of the regular host one. 
+ DeviceMC->setDeviceMangleContext( + CGM.getContext().getTargetInfo().getCXXABI().isMicrosoft() && + CGM.getContext().getAuxTargetInfo()->getCXXABI().isItaniumFamily()); } llvm::FunctionCallee CGNVCUDARuntime::getSetupArgumentFn() const { diff --git a/clang/lib/Sema/SemaLambda.cpp b/clang/lib/Sema/SemaLambda.cpp --- a/clang/lib/Sema/SemaLambda.cpp +++ b/clang/lib/Sema/SemaLambda.cpp @@ -429,15 +429,16 @@ void Sema::handleLambdaNumbering( CXXRecordDecl *Class, CXXMethodDecl *Method, - Optional> Mangling) { + Optional> Mangling) { if (Mangling) { - unsigned ManglingNumber; bool HasKnownInternalLinkage; + unsigned ManglingNumber, DeviceManglingNumber; Decl *ManglingContextDecl; - std::tie(ManglingNumber, HasKnownInternalLinkage, ManglingContextDecl) = - Mangling.getValue(); + std::tie(HasKnownInternalLinkage, ManglingNumber, DeviceManglingNumber, + ManglingContextDecl) = Mangling.getValue(); Class->setLambdaMangling(ManglingNumber, ManglingContextDecl, HasKnownInternalLinkage); + Class->setDeviceLambdaManglingNumber(DeviceManglingNumber); return; } @@ -473,6 +474,7 @@ unsigned ManglingNumber = MCtx->getManglingNumber(Method); Class->setLambdaMangling(ManglingNumber, ManglingContextDecl, HasKnownInternalLinkage); + Class->setDeviceLambdaManglingNumber(MCtx->getDeviceManglingNumber(Method)); } } diff --git a/clang/lib/Sema/TreeTransform.h b/clang/lib/Sema/TreeTransform.h --- a/clang/lib/Sema/TreeTransform.h +++ b/clang/lib/Sema/TreeTransform.h @@ -12504,10 +12504,11 @@ E->getCaptureDefault()); getDerived().transformedLocalDecl(OldClass, {Class}); - Optional> Mangling; + Optional> Mangling; if (getDerived().ReplacingOriginal()) - Mangling = std::make_tuple(OldClass->getLambdaManglingNumber(), - OldClass->hasKnownLambdaInternalLinkage(), + Mangling = std::make_tuple(OldClass->hasKnownLambdaInternalLinkage(), + OldClass->getLambdaManglingNumber(), + OldClass->getDeviceLambdaManglingNumber(), OldClass->getLambdaContextDecl()); // Build the call operator. 
diff --git a/clang/lib/Serialization/ASTReaderDecl.cpp b/clang/lib/Serialization/ASTReaderDecl.cpp --- a/clang/lib/Serialization/ASTReaderDecl.cpp +++ b/clang/lib/Serialization/ASTReaderDecl.cpp @@ -1748,6 +1748,7 @@ Lambda.NumExplicitCaptures = Record.readInt(); Lambda.HasKnownInternalLinkage = Record.readInt(); Lambda.ManglingNumber = Record.readInt(); + D->setDeviceLambdaManglingNumber(Record.readInt()); Lambda.ContextDecl = readDeclID(); Lambda.Captures = (Capture *)Reader.getContext().Allocate( sizeof(Capture) * Lambda.NumCaptures); diff --git a/clang/lib/Serialization/ASTWriter.cpp b/clang/lib/Serialization/ASTWriter.cpp --- a/clang/lib/Serialization/ASTWriter.cpp +++ b/clang/lib/Serialization/ASTWriter.cpp @@ -5663,6 +5663,7 @@ Record->push_back(Lambda.NumExplicitCaptures); Record->push_back(Lambda.HasKnownInternalLinkage); Record->push_back(Lambda.ManglingNumber); + Record->push_back(D->getDeviceLambdaManglingNumber()); AddDeclRef(D->getLambdaContextDecl()); AddTypeSourceInfo(Lambda.MethodTyInfo); for (unsigned I = 0, N = Lambda.NumCaptures; I != N; ++I) { diff --git a/clang/test/CodeGenCUDA/ms-linker-options.cu b/clang/test/CodeGenCUDA/ms-linker-options.cu --- a/clang/test/CodeGenCUDA/ms-linker-options.cu +++ b/clang/test/CodeGenCUDA/ms-linker-options.cu @@ -2,12 +2,12 @@ // RUN: -fno-autolink -triple amdgcn-amd-amdhsa \ // RUN: | FileCheck -check-prefix=DEV %s // RUN: %clang_cc1 -emit-llvm -o - -fms-extensions -x hip %s -triple \ -// RUN: x86_64-pc-windows-msvc | FileCheck -check-prefix=HOST %s +// RUN: x86_64-pc-windows-msvc -aux-triple amdgcn | FileCheck -check-prefix=HOST %s // RUN: %clang_cc1 -emit-llvm -o - -fcuda-is-device -fms-extensions %s \ // RUN: -fno-autolink -triple amdgcn-amd-amdhsa \ // RUN: | FileCheck -check-prefix=DEV %s // RUN: %clang_cc1 -emit-llvm -o - -fms-extensions %s -triple \ -// RUN: x86_64-pc-windows-msvc | FileCheck -check-prefix=HOST %s +// RUN: x86_64-pc-windows-msvc -aux-triple amdgcn | FileCheck -check-prefix=HOST %s // 
DEV-NOT: llvm.linker.options // DEV-NOT: llvm.dependent-libraries diff --git a/clang/test/CodeGenCUDA/unnamed-types.cu b/clang/test/CodeGenCUDA/unnamed-types.cu --- a/clang/test/CodeGenCUDA/unnamed-types.cu +++ b/clang/test/CodeGenCUDA/unnamed-types.cu @@ -1,12 +1,17 @@ // RUN: %clang_cc1 -std=c++11 -x hip -triple x86_64-linux-gnu -aux-triple amdgcn-amd-amdhsa -emit-llvm %s -o - | FileCheck %s --check-prefix=HOST +// RUN: %clang_cc1 -std=c++11 -x hip -triple x86_64-pc-windows-msvc -aux-triple amdgcn-amd-amdhsa -emit-llvm %s -o - | FileCheck %s --check-prefix=MSVC // RUN: %clang_cc1 -std=c++11 -x hip -triple amdgcn-amd-amdhsa -fcuda-is-device -emit-llvm %s -o - | FileCheck %s --check-prefix=DEVICE #include "Inputs/cuda.h" // HOST: @0 = private unnamed_addr constant [43 x i8] c"_Z2k0IZZ2f1PfENKUlS0_E_clES0_EUlfE_EvS0_T_\00", align 1 +// HOST: @1 = private unnamed_addr constant [60 x i8] c"_Z2k1IZ2f1PfEUlfE_Z2f1S0_EUlffE_Z2f1S0_EUlfE0_EvS0_T_T0_T1_\00", align 1 +// Check that, on MSVC, the same device kernel mangling name is generated. 
+// MSVC: @0 = private unnamed_addr constant [43 x i8] c"_Z2k0IZZ2f1PfENKUlS0_E_clES0_EUlfE_EvS0_T_\00", align 1 +// MSVC: @1 = private unnamed_addr constant [60 x i8] c"_Z2k1IZ2f1PfEUlfE_Z2f1S0_EUlffE_Z2f1S0_EUlfE0_EvS0_T_T0_T1_\00", align 1 __device__ float d0(float x) { - return [](float x) { return x + 2.f; }(x); + return [](float x) { return x + 1.f; }(x); } __device__ float d1(float x) { @@ -14,11 +19,21 @@ } // DEVICE: amdgpu_kernel void @_Z2k0IZZ2f1PfENKUlS0_E_clES0_EUlfE_EvS0_T_( +// DEVICE: define internal float @_ZZZ2f1PfENKUlS_E_clES_ENKUlfE_clEf( template __global__ void k0(float *p, F f) { p[0] = f(p[0]) + d0(p[1]) + d1(p[2]); } +// DEVICE: amdgpu_kernel void @_Z2k1IZ2f1PfEUlfE_Z2f1S0_EUlffE_Z2f1S0_EUlfE0_EvS0_T_T0_T1_( +// DEVICE: define internal float @_ZZ2f1PfENKUlfE_clEf( +// DEVICE: define internal float @_ZZ2f1PfENKUlffE_clEff( +// DEVICE: define internal float @_ZZ2f1PfENKUlfE0_clEf( +template +__global__ void k1(float *p, F0 f0, F1 f1, F2 f2) { + p[0] = f0(p[0]) + f1(p[1], p[2]) + f2(p[3]); +} + void f0(float *p) { [](float *p) { *p = 1.f; @@ -29,11 +44,17 @@ // linkages are still required to keep the original `internal` linkage. 
// HOST: define internal void @_ZZ2f1PfENKUlS_E_clES_( -// DEVICE: define internal float @_ZZZ2f1PfENKUlS_E_clES_ENKUlfE_clEf( void f1(float *p) { [](float *p) { - k0<<<1,1>>>(p, [] __device__ (float x) { return x + 1.f; }); + k0<<<1,1>>>(p, [] __device__ (float x) { return x + 3.f; }); }(p); + k1<<<1,1>>>(p, + [] __device__ (float x) { return x + 4.f; }, + [] __device__ (float x, float y) { return x * y; }, + [] __device__ (float x) { return x + 5.f; }); } // HOST: @__hip_register_globals // HOST: __hipRegisterFunction{{.*}}@_Z17__device_stub__k0IZZ2f1PfENKUlS0_E_clES0_EUlfE_EvS0_T_{{.*}}@0 +// HOST: __hipRegisterFunction{{.*}}@_Z17__device_stub__k1IZ2f1PfEUlfE_Z2f1S0_EUlffE_Z2f1S0_EUlfE0_EvS0_T_T0_T1_{{.*}}@1 +// MSVC: __hipRegisterFunction{{.*}}@"??$k0@V@?0???R1?0??f1@@YAXPEAM@Z@QEBA@0@Z@@@YAXPEAMV@?0???R0?0??f1@@YAX0@Z@QEBA@0@Z@@Z{{.*}}@0 +// MSVC: __hipRegisterFunction{{.*}}@"??$k1@V@?0??f1@@YAXPEAM@Z@V@?0??2@YAX0@Z@V@?0??2@YAX0@Z@@@YAXPEAMV@?0??f1@@YAX0@Z@V@?0??1@YAX0@Z@V@?0??1@YAX0@Z@@Z{{.*}}@1