diff --git a/clang/lib/CodeGen/CGOpenMPRuntime.h b/clang/lib/CodeGen/CGOpenMPRuntime.h --- a/clang/lib/CodeGen/CGOpenMPRuntime.h +++ b/clang/lib/CodeGen/CGOpenMPRuntime.h @@ -1439,9 +1439,9 @@ bool SeparateBeginEndCalls) : llvm::OpenMPIRBuilder::TargetDataInfo(RequiresDevicePointerInfo, SeparateBeginEndCalls) {} - /// Map between the a declaration of a capture and the corresponding base - /// pointer address where the runtime returns the device pointers. - llvm::DenseMap CaptureDeviceAddrMap; + /// Map between the declaration of a capture and the corresponding new + /// alloca address where the runtime returns the device pointers. + llvm::DenseMap CaptureDeviceAddrMap; }; /// Emit the target data mapping code associated with \a D. diff --git a/clang/lib/CodeGen/CGOpenMPRuntime.cpp b/clang/lib/CodeGen/CGOpenMPRuntime.cpp --- a/clang/lib/CodeGen/CGOpenMPRuntime.cpp +++ b/clang/lib/CodeGen/CGOpenMPRuntime.cpp @@ -6821,6 +6821,7 @@ const Expr *getMapExpr() const { return MapExpr; } }; + using DeviceInfoTy = llvm::OpenMPIRBuilder::DeviceInfoTy; using MapBaseValuesArrayTy = llvm::OpenMPIRBuilder::MapValuesArrayTy; using MapValuesArrayTy = llvm::OpenMPIRBuilder::MapValuesArrayTy; using MapFlagsArrayTy = llvm::OpenMPIRBuilder::MapFlagsArrayTy; @@ -7592,6 +7593,7 @@ CombinedInfo.Exprs.emplace_back(MapDecl, MapExpr); CombinedInfo.BasePointers.push_back(BP.getPointer()); CombinedInfo.DevicePtrDecls.push_back(nullptr); + CombinedInfo.DevicePointers.push_back(DeviceInfoTy::None); CombinedInfo.Pointers.push_back(LB.getPointer()); CombinedInfo.Sizes.push_back(CGF.Builder.CreateIntCast( Size, CGF.Int64Ty, /*isSigned=*/true)); @@ -7604,6 +7606,7 @@ CombinedInfo.Exprs.emplace_back(MapDecl, MapExpr); CombinedInfo.BasePointers.push_back(BP.getPointer()); CombinedInfo.DevicePtrDecls.push_back(nullptr); + CombinedInfo.DevicePointers.push_back(DeviceInfoTy::None); CombinedInfo.Pointers.push_back(LB.getPointer()); Size = CGF.Builder.CreatePtrDiff( CGF.Int8Ty, 
CGF.Builder.CreateConstGEP(HB, 1).getPointer(), @@ -7622,6 +7625,7 @@ CombinedInfo.Exprs.emplace_back(MapDecl, MapExpr); CombinedInfo.BasePointers.push_back(BP.getPointer()); CombinedInfo.DevicePtrDecls.push_back(nullptr); + CombinedInfo.DevicePointers.push_back(DeviceInfoTy::None); CombinedInfo.Pointers.push_back(LB.getPointer()); CombinedInfo.Sizes.push_back( CGF.Builder.CreateIntCast(Size, CGF.Int64Ty, /*isSigned=*/true)); @@ -8122,10 +8126,12 @@ auto &&UseDeviceDataCombinedInfoGen = [&UseDeviceDataCombinedInfo](const ValueDecl *VD, llvm::Value *Ptr, - CodeGenFunction &CGF) { + CodeGenFunction &CGF, bool IsDevAddr) { UseDeviceDataCombinedInfo.Exprs.push_back(VD); UseDeviceDataCombinedInfo.BasePointers.emplace_back(Ptr); UseDeviceDataCombinedInfo.DevicePtrDecls.emplace_back(VD); + UseDeviceDataCombinedInfo.DevicePointers.emplace_back( + IsDevAddr ? DeviceInfoTy::Address : DeviceInfoTy::Pointer); UseDeviceDataCombinedInfo.Pointers.push_back(Ptr); UseDeviceDataCombinedInfo.Sizes.push_back( llvm::Constant::getNullValue(CGF.Int64Ty)); @@ -8165,7 +8171,7 @@ } else { Ptr = CGF.EmitLoadOfScalar(CGF.EmitLValue(IE), IE->getExprLoc()); } - UseDeviceDataCombinedInfoGen(VD, Ptr, CGF); + UseDeviceDataCombinedInfoGen(VD, Ptr, CGF, IsDevAddr); } }; @@ -8192,6 +8198,7 @@ // item. if (CI != Data.end()) { if (IsDevAddr) { + CI->ForDeviceAddr = IsDevAddr; CI->ReturnDevicePointer = true; Found = true; break; @@ -8204,6 +8211,7 @@ PrevCI == CI->Components.rend() || isa(PrevCI->getAssociatedExpression()) || !VarD || VarD->hasLocalStorage()) { + CI->ForDeviceAddr = IsDevAddr; CI->ReturnDevicePointer = true; Found = true; break; @@ -8295,6 +8303,8 @@ "No relevant declaration related with device pointer??"); CurInfo.DevicePtrDecls[CurrentBasePointersIdx] = RelevantVD; + CurInfo.DevicePointers[CurrentBasePointersIdx] = + L.ForDeviceAddr ? 
DeviceInfoTy::Address : DeviceInfoTy::Pointer; CurInfo.Types[CurrentBasePointersIdx] |= OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM; } @@ -8335,6 +8345,8 @@ CurInfo.Exprs.push_back(L.VD); CurInfo.BasePointers.emplace_back(BasePtr); CurInfo.DevicePtrDecls.emplace_back(L.VD); + CurInfo.DevicePointers.emplace_back( + L.ForDeviceAddr ? DeviceInfoTy::Address : DeviceInfoTy::Pointer); CurInfo.Pointers.push_back(Ptr); CurInfo.Sizes.push_back( llvm::Constant::getNullValue(this->CGF.Int64Ty)); @@ -8430,6 +8442,7 @@ // Base is the base of the struct CombinedInfo.BasePointers.push_back(PartialStruct.Base.getPointer()); CombinedInfo.DevicePtrDecls.push_back(nullptr); + CombinedInfo.DevicePointers.push_back(DeviceInfoTy::None); // Pointer is the address of the lowest element llvm::Value *LB = LBAddr.getPointer(); const CXXMethodDecl *MD = @@ -8552,6 +8565,7 @@ CombinedInfo.Exprs.push_back(VD); CombinedInfo.BasePointers.push_back(ThisLVal.getPointer(CGF)); CombinedInfo.DevicePtrDecls.push_back(nullptr); + CombinedInfo.DevicePointers.push_back(DeviceInfoTy::None); CombinedInfo.Pointers.push_back(ThisLValVal.getPointer(CGF)); CombinedInfo.Sizes.push_back( CGF.Builder.CreateIntCast(CGF.getTypeSize(CGF.getContext().VoidPtrTy), @@ -8579,6 +8593,7 @@ CombinedInfo.Exprs.push_back(VD); CombinedInfo.BasePointers.push_back(VarLVal.getPointer(CGF)); CombinedInfo.DevicePtrDecls.push_back(nullptr); + CombinedInfo.DevicePointers.push_back(DeviceInfoTy::None); CombinedInfo.Pointers.push_back(VarLValVal.getPointer(CGF)); CombinedInfo.Sizes.push_back(CGF.Builder.CreateIntCast( CGF.getTypeSize( @@ -8591,6 +8606,7 @@ CombinedInfo.Exprs.push_back(VD); CombinedInfo.BasePointers.push_back(VarLVal.getPointer(CGF)); CombinedInfo.DevicePtrDecls.push_back(nullptr); + CombinedInfo.DevicePointers.push_back(DeviceInfoTy::None); CombinedInfo.Pointers.push_back(VarRVal.getScalarVal()); CombinedInfo.Sizes.push_back(llvm::ConstantInt::get(CGF.Int64Ty, 0)); } @@ -8659,6 +8675,7 @@ 
CombinedInfo.Exprs.push_back(VD); CombinedInfo.BasePointers.emplace_back(Arg); CombinedInfo.DevicePtrDecls.emplace_back(VD); + CombinedInfo.DevicePointers.emplace_back(DeviceInfoTy::Pointer); CombinedInfo.Pointers.push_back(Arg); CombinedInfo.Sizes.push_back(CGF.Builder.CreateIntCast( CGF.getTypeSize(CGF.getContext().VoidPtrTy), CGF.Int64Ty, @@ -8899,6 +8916,7 @@ CombinedInfo.Exprs.push_back(nullptr); CombinedInfo.BasePointers.push_back(CV); CombinedInfo.DevicePtrDecls.push_back(nullptr); + CombinedInfo.DevicePointers.push_back(DeviceInfoTy::None); CombinedInfo.Pointers.push_back(CV); const auto *PtrTy = cast(RI.getType().getTypePtr()); CombinedInfo.Sizes.push_back( @@ -8912,6 +8930,7 @@ CombinedInfo.Exprs.push_back(VD->getCanonicalDecl()); CombinedInfo.BasePointers.push_back(CV); CombinedInfo.DevicePtrDecls.push_back(nullptr); + CombinedInfo.DevicePointers.push_back(DeviceInfoTy::None); CombinedInfo.Pointers.push_back(CV); if (!RI.getType()->isAnyPointerType()) { // We have to signal to the runtime captures passed by value that are @@ -8944,6 +8963,7 @@ CombinedInfo.Exprs.push_back(VD->getCanonicalDecl()); CombinedInfo.BasePointers.push_back(CV); CombinedInfo.DevicePtrDecls.push_back(nullptr); + CombinedInfo.DevicePointers.push_back(DeviceInfoTy::None); if (I != FirstPrivateDecls.end() && ElementType->isAnyPointerType()) { Address PtrAddr = CGF.EmitLoadOfReference(CGF.MakeAddrLValue( CV, ElementType, CGF.getContext().getDeclAlign(VD), @@ -9025,7 +9045,6 @@ CGOpenMPRuntime::TargetDataInfo &Info, llvm::OpenMPIRBuilder &OMPBuilder, bool IsNonContiguous = false) { CodeGenModule &CGM = CGF.CGM; - ASTContext &Ctx = CGF.getContext(); // Reset the array information. 
Info.clearArrayInfo(); @@ -9047,11 +9066,9 @@ FillInfoMap); } - auto DeviceAddrCB = [&](unsigned int I, llvm::Value *BP, llvm::Value *BPVal) { + auto DeviceAddrCB = [&](unsigned int I, llvm::Value *NewDecl) { if (const ValueDecl *DevVD = CombinedInfo.DevicePtrDecls[I]) { - Address BPAddr(BP, BPVal->getType(), - Ctx.getTypeAlignInChars(Ctx.VoidPtrTy)); - Info.CaptureDeviceAddrMap.try_emplace(DevVD, BPAddr); + Info.CaptureDeviceAddrMap.try_emplace(DevVD, NewDecl); } }; @@ -9775,6 +9792,8 @@ CurInfo.Exprs.push_back(nullptr); CurInfo.BasePointers.push_back(*CV); CurInfo.DevicePtrDecls.push_back(nullptr); + CurInfo.DevicePointers.push_back( + MappableExprsHandler::DeviceInfoTy::None); CurInfo.Pointers.push_back(*CV); CurInfo.Sizes.push_back(CGF.Builder.CreateIntCast( CGF.getTypeSize(RI->getType()), CGF.Int64Ty, /*isSigned=*/true)); @@ -10454,12 +10473,9 @@ CGF.Builder.GetInsertPoint()); }; - auto DeviceAddrCB = [&](unsigned int I, llvm::Value *BP, llvm::Value *BPVal) { + auto DeviceAddrCB = [&](unsigned int I, llvm::Value *NewDecl) { if (const ValueDecl *DevVD = CombinedInfo.DevicePtrDecls[I]) { - ASTContext &Ctx = CGF.getContext(); - Address BPAddr(BP, BPVal->getType(), - Ctx.getTypeAlignInChars(Ctx.VoidPtrTy)); - Info.CaptureDeviceAddrMap.try_emplace(DevVD, BPAddr); + Info.CaptureDeviceAddrMap.try_emplace(DevVD, NewDecl); } }; diff --git a/clang/lib/CodeGen/CGStmtOpenMP.cpp b/clang/lib/CodeGen/CGStmtOpenMP.cpp --- a/clang/lib/CodeGen/CGStmtOpenMP.cpp +++ b/clang/lib/CodeGen/CGStmtOpenMP.cpp @@ -7158,14 +7158,13 @@ void CodeGenFunction::EmitOMPUseDevicePtrClause( const OMPUseDevicePtrClause &C, OMPPrivateScope &PrivateScope, - const llvm::DenseMap &CaptureDeviceAddrMap) { - auto OrigVarIt = C.varlist_begin(); - auto InitIt = C.inits().begin(); - for (const Expr *PvtVarIt : C.private_copies()) { - const auto *OrigVD = - cast(cast(*OrigVarIt)->getDecl()); - const auto *InitVD = cast(cast(*InitIt)->getDecl()); - const auto *PvtVD = cast(cast(PvtVarIt)->getDecl()); + const 
llvm::DenseMap + CaptureDeviceAddrMap) { + llvm::SmallDenseSet, 4> Processed; + for (const Expr *OrigVarIt : C.varlists()) { + const auto *OrigVD = cast(cast(OrigVarIt)->getDecl()); + if (!Processed.insert(OrigVD).second) + continue; // In order to identify the right initializer we need to match the // declaration used by the mapping logic. In some cases we may get @@ -7186,32 +7185,16 @@ if (InitAddrIt == CaptureDeviceAddrMap.end()) continue; - // Initialize the temporary initialization variable with the address - // we get from the runtime library. We have to cast the source address - // because it is always a void *. References are materialized in the - // privatization scope, so the initialization here disregards the fact - // the original variable is a reference. llvm::Type *Ty = ConvertTypeForMem(OrigVD->getType().getNonReferenceType()); - Address InitAddr = Builder.CreateElementBitCast(InitAddrIt->second, Ty); - setAddrOfLocalVar(InitVD, InitAddr); - - // Emit private declaration, it will be initialized by the value we - // declaration we just added to the local declarations map. - EmitDecl(*PvtVD); - - // The initialization variables reached its purpose in the emission - // of the previous declaration, so we don't need it anymore. - LocalDeclMap.erase(InitVD); // Return the address of the private variable. - bool IsRegistered = - PrivateScope.addPrivate(OrigVD, GetAddrOfLocalVar(PvtVD)); + bool IsRegistered = PrivateScope.addPrivate( + OrigVD, + Address(InitAddrIt->second, Ty, + getContext().getTypeAlignInChars(getContext().VoidPtrTy))); assert(IsRegistered && "firstprivate var already registered as private"); // Silence the warning about unused variable. 
(void)IsRegistered; - - ++OrigVarIt; - ++InitIt; } } @@ -7226,7 +7209,8 @@ void CodeGenFunction::EmitOMPUseDeviceAddrClause( const OMPUseDeviceAddrClause &C, OMPPrivateScope &PrivateScope, - const llvm::DenseMap &CaptureDeviceAddrMap) { + const llvm::DenseMap + CaptureDeviceAddrMap) { llvm::SmallDenseSet, 4> Processed; for (const Expr *Ref : C.varlists()) { const VarDecl *OrigVD = getBaseDecl(Ref); @@ -7251,7 +7235,11 @@ if (InitAddrIt == CaptureDeviceAddrMap.end()) continue; - Address PrivAddr = InitAddrIt->getSecond(); + llvm::Type *Ty = ConvertTypeForMem(OrigVD->getType().getNonReferenceType()); + + Address PrivAddr = + Address(InitAddrIt->second, Ty, + getContext().getTypeAlignInChars(getContext().VoidPtrTy)); // For declrefs and variable length array need to load the pointer for // correct mapping, since the pointer to the data was passed to the runtime. if (isa(Ref->IgnoreParenImpCasts()) || diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h --- a/clang/lib/CodeGen/CodeGenFunction.h +++ b/clang/lib/CodeGen/CodeGenFunction.h @@ -3404,10 +3404,12 @@ OMPPrivateScope &PrivateScope); void EmitOMPUseDevicePtrClause( const OMPUseDevicePtrClause &C, OMPPrivateScope &PrivateScope, - const llvm::DenseMap &CaptureDeviceAddrMap); + const llvm::DenseMap + CaptureDeviceAddrMap); void EmitOMPUseDeviceAddrClause( const OMPUseDeviceAddrClause &C, OMPPrivateScope &PrivateScope, - const llvm::DenseMap &CaptureDeviceAddrMap); + const llvm::DenseMap + CaptureDeviceAddrMap); /// Emit code for copyin clause in \a D directive. 
The next code is /// generated at the start of outlined functions for directives: /// \code diff --git a/clang/test/OpenMP/target_data_use_device_ptr_codegen.cpp b/clang/test/OpenMP/target_data_use_device_ptr_codegen.cpp --- a/clang/test/OpenMP/target_data_use_device_ptr_codegen.cpp +++ b/clang/test/OpenMP/target_data_use_device_ptr_codegen.cpp @@ -411,11 +411,11 @@ // CK2: [[BP2:%.+]] = getelementptr inbounds [3 x ptr], ptr %{{.+}}, i32 0, i32 2 // CK2: store ptr [[RVAL2:%.+]], ptr [[BP2]], // CK2: call void @__tgt_target_data_begin{{.+}}[[MTYPE03]] + // CK2: [[VAL1:%.+]] = load ptr, ptr [[BP1]], + // CK2: store ptr [[VAL1]], ptr [[PVT1:%.+]], // CK2: [[VAL2:%.+]] = load ptr, ptr [[BP2]], // CK2: store ptr [[VAL2]], ptr [[PVT2:%.+]], // CK2: store ptr [[PVT2]], ptr [[_PVT2:%.+]], - // CK2: [[VAL1:%.+]] = load ptr, ptr [[BP1]], - // CK2: store ptr [[VAL1]], ptr [[PVT1:%.+]], // CK2: store ptr [[PVT1]], ptr [[_PVT1:%.+]], // CK2: [[TT2:%.+]] = load ptr, ptr [[_PVT2]], // CK2: [[_TT2:%.+]] = load ptr, ptr [[TT2]], diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h --- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h +++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h @@ -1568,6 +1568,9 @@ public: TargetDataRTArgs RTArgs; + SmallMapVector, 4> + DevicePtrInfoMap; + /// Indicate whether any user-defined mapper exists. bool HasMapper = false; /// The total number of pointers passed to the runtime library. 
@@ -1594,7 +1597,9 @@ bool separateBeginEndCalls() { return SeparateBeginEndCalls; } }; + enum class DeviceInfoTy { None, Pointer, Address }; using MapValuesArrayTy = SmallVector; + using MapDeviceInfoArrayTy = SmallVector; using MapFlagsArrayTy = SmallVector; using MapNamesArrayTy = SmallVector; using MapDimArrayTy = SmallVector; @@ -1613,6 +1618,7 @@ }; MapValuesArrayTy BasePointers; MapValuesArrayTy Pointers; + MapDeviceInfoArrayTy DevicePointers; MapValuesArrayTy Sizes; MapFlagsArrayTy Types; MapNamesArrayTy Names; @@ -1623,6 +1629,8 @@ BasePointers.append(CurInfo.BasePointers.begin(), CurInfo.BasePointers.end()); Pointers.append(CurInfo.Pointers.begin(), CurInfo.Pointers.end()); + DevicePointers.append(CurInfo.DevicePointers.begin(), + CurInfo.DevicePointers.end()); Sizes.append(CurInfo.Sizes.begin(), CurInfo.Sizes.end()); Types.append(CurInfo.Types.begin(), CurInfo.Types.end()); Names.append(CurInfo.Names.begin(), CurInfo.Names.end()); @@ -1659,7 +1667,7 @@ void emitOffloadingArrays( InsertPointTy AllocaIP, InsertPointTy CodeGenIP, MapInfosTy &CombinedInfo, TargetDataInfo &Info, bool IsNonContiguous = false, - function_ref DeviceAddrCB = nullptr, + function_ref DeviceAddrCB = nullptr, function_ref CustomMapperCB = nullptr); /// Creates offloading entry for the provided entry ID \a ID, address \a @@ -2046,7 +2054,7 @@ function_ref BodyGenCB = nullptr, - function_ref DeviceAddrCB = nullptr, + function_ref DeviceAddrCB = nullptr, function_ref CustomMapperCB = nullptr); using TargetBodyGenCallbackTy = function_ref BodyGenCB, - function_ref DeviceAddrCB, + function_ref DeviceAddrCB, function_ref CustomMapperCB) { if (!updateToLocation(Loc)) return InsertPointTy(); @@ -4127,6 +4127,14 @@ Builder.CreateCall(BeginMapperFunc, OffloadingArgs); + for (auto DeviceMap : Info.DevicePtrInfoMap) { + if (isa(DeviceMap.second.second)) { + auto *LI = + Builder.CreateLoad(Builder.getPtrTy(), DeviceMap.second.first); + Builder.CreateStore(LI, DeviceMap.second.second); + } + } + 
// If device pointer privatization is required, emit the body of the // region here. It will have to be duplicated: with and without // privatization. @@ -4538,7 +4546,7 @@ void OpenMPIRBuilder::emitOffloadingArrays( InsertPointTy AllocaIP, InsertPointTy CodeGenIP, MapInfosTy &CombinedInfo, TargetDataInfo &Info, bool IsNonContiguous, - function_ref DeviceAddrCB, + function_ref DeviceAddrCB, function_ref CustomMapperCB) { // Reset the array information. @@ -4673,9 +4681,21 @@ BPVal, BP, M.getDataLayout().getPrefTypeAlign(Builder.getInt8PtrTy())); if (Info.requiresDevicePointerInfo()) { - assert(DeviceAddrCB && - "DeviceAddrCB missing for DevicePtr code generation"); - DeviceAddrCB(I, BP, BPVal); + if (CombinedInfo.DevicePointers[I] == DeviceInfoTy::Pointer) { + CodeGenIP = Builder.saveIP(); + Builder.restoreIP(AllocaIP); + Info.DevicePtrInfoMap[BPVal] = { + BP, Builder.CreateAlloca(Builder.getPtrTy())}; + Builder.restoreIP(CodeGenIP); + assert(DeviceAddrCB && + "DeviceAddrCB missing for DevicePtr code generation"); + DeviceAddrCB(I, Info.DevicePtrInfoMap[BPVal].second); + } else if (CombinedInfo.DevicePointers[I] == DeviceInfoTy::Address) { + Info.DevicePtrInfoMap[BPVal] = {BP, BP}; + assert(DeviceAddrCB && + "DeviceAddrCB missing for DevicePtr code generation"); + DeviceAddrCB(I, BP); + } } Value *PVal = CombinedInfo.Pointers[I];