diff --git a/clang/lib/CodeGen/CGOpenMPRuntime.cpp b/clang/lib/CodeGen/CGOpenMPRuntime.cpp --- a/clang/lib/CodeGen/CGOpenMPRuntime.cpp +++ b/clang/lib/CodeGen/CGOpenMPRuntime.cpp @@ -6831,67 +6831,30 @@ const Expr *getMapExpr() const { return MapExpr; } }; - /// Class that associates information with a base pointer to be passed to the - /// runtime library. - class BasePointerInfo { - /// The base pointer. - llvm::Value *Ptr = nullptr; - /// The base declaration that refers to this device pointer, or null if - /// there is none. - const ValueDecl *DevPtrDecl = nullptr; - - public: - BasePointerInfo(llvm::Value *Ptr, const ValueDecl *DevPtrDecl = nullptr) - : Ptr(Ptr), DevPtrDecl(DevPtrDecl) {} - llvm::Value *operator*() const { return Ptr; } - const ValueDecl *getDevicePtrDecl() const { return DevPtrDecl; } - void setDevicePtrDecl(const ValueDecl *D) { DevPtrDecl = D; } - }; - + using MapBaseValuesArrayTy = llvm::OpenMPIRBuilder::MapValuesArrayTy; + using MapValuesArrayTy = llvm::OpenMPIRBuilder::MapValuesArrayTy; + using MapFlagsArrayTy = llvm::OpenMPIRBuilder::MapFlagsArrayTy; + using MapDimArrayTy = llvm::OpenMPIRBuilder::MapDimArrayTy; + using MapNonContiguousArrayTy = + llvm::OpenMPIRBuilder::MapNonContiguousArrayTy; using MapExprsArrayTy = SmallVector; - using MapBaseValuesArrayTy = SmallVector; - using MapValuesArrayTy = SmallVector; - using MapFlagsArrayTy = SmallVector; - using MapMappersArrayTy = SmallVector; - using MapDimArrayTy = SmallVector; - using MapNonContiguousArrayTy = SmallVector; + using MapValueDeclsArrayTy = SmallVector; /// This structure contains combined information generated for mappable /// clauses, including base pointers, pointers, sizes, map types, user-defined /// mappers, and non-contiguous information. 
- struct MapCombinedInfoTy { - struct StructNonContiguousInfo { - bool IsNonContiguous = false; - MapDimArrayTy Dims; - MapNonContiguousArrayTy Offsets; - MapNonContiguousArrayTy Counts; - MapNonContiguousArrayTy Strides; - }; + struct MapCombinedInfoTy : llvm::OpenMPIRBuilder::MapInfosTy { MapExprsArrayTy Exprs; - MapBaseValuesArrayTy BasePointers; - MapValuesArrayTy Pointers; - MapValuesArrayTy Sizes; - MapFlagsArrayTy Types; - MapMappersArrayTy Mappers; - StructNonContiguousInfo NonContigInfo; + MapValueDeclsArrayTy Mappers; + MapValueDeclsArrayTy DevicePtrDecls; /// Append arrays in \a CurInfo. void append(MapCombinedInfoTy &CurInfo) { Exprs.append(CurInfo.Exprs.begin(), CurInfo.Exprs.end()); - BasePointers.append(CurInfo.BasePointers.begin(), - CurInfo.BasePointers.end()); - Pointers.append(CurInfo.Pointers.begin(), CurInfo.Pointers.end()); - Sizes.append(CurInfo.Sizes.begin(), CurInfo.Sizes.end()); - Types.append(CurInfo.Types.begin(), CurInfo.Types.end()); + DevicePtrDecls.append(CurInfo.DevicePtrDecls.begin(), + CurInfo.DevicePtrDecls.end()); Mappers.append(CurInfo.Mappers.begin(), CurInfo.Mappers.end()); - NonContigInfo.Dims.append(CurInfo.NonContigInfo.Dims.begin(), - CurInfo.NonContigInfo.Dims.end()); - NonContigInfo.Offsets.append(CurInfo.NonContigInfo.Offsets.begin(), - CurInfo.NonContigInfo.Offsets.end()); - NonContigInfo.Counts.append(CurInfo.NonContigInfo.Counts.begin(), - CurInfo.NonContigInfo.Counts.end()); - NonContigInfo.Strides.append(CurInfo.NonContigInfo.Strides.begin(), - CurInfo.NonContigInfo.Strides.end()); + llvm::OpenMPIRBuilder::MapInfosTy::append(CurInfo); } }; @@ -7638,6 +7601,7 @@ assert(Size && "Failed to determine structure size"); CombinedInfo.Exprs.emplace_back(MapDecl, MapExpr); CombinedInfo.BasePointers.push_back(BP.getPointer()); + CombinedInfo.DevicePtrDecls.push_back(nullptr); CombinedInfo.Pointers.push_back(LB.getPointer()); CombinedInfo.Sizes.push_back(CGF.Builder.CreateIntCast( Size, CGF.Int64Ty, /*isSigned=*/true)); @@ 
-7649,6 +7613,7 @@ } CombinedInfo.Exprs.emplace_back(MapDecl, MapExpr); CombinedInfo.BasePointers.push_back(BP.getPointer()); + CombinedInfo.DevicePtrDecls.push_back(nullptr); CombinedInfo.Pointers.push_back(LB.getPointer()); Size = CGF.Builder.CreatePtrDiff( CGF.Int8Ty, CGF.Builder.CreateConstGEP(HB, 1).getPointer(), @@ -7666,6 +7631,7 @@ (Next == CE && MapType != OMPC_MAP_unknown)) { CombinedInfo.Exprs.emplace_back(MapDecl, MapExpr); CombinedInfo.BasePointers.push_back(BP.getPointer()); + CombinedInfo.DevicePtrDecls.push_back(nullptr); CombinedInfo.Pointers.push_back(LB.getPointer()); CombinedInfo.Sizes.push_back( CGF.Builder.CreateIntCast(Size, CGF.Int64Ty, /*isSigned=*/true)); @@ -8168,7 +8134,8 @@ [&UseDeviceDataCombinedInfo](const ValueDecl *VD, llvm::Value *Ptr, CodeGenFunction &CGF) { UseDeviceDataCombinedInfo.Exprs.push_back(VD); - UseDeviceDataCombinedInfo.BasePointers.emplace_back(Ptr, VD); + UseDeviceDataCombinedInfo.BasePointers.emplace_back(Ptr); + UseDeviceDataCombinedInfo.DevicePtrDecls.emplace_back(VD); UseDeviceDataCombinedInfo.Pointers.push_back(Ptr); UseDeviceDataCombinedInfo.Sizes.push_back( llvm::Constant::getNullValue(CGF.Int64Ty)); @@ -8337,8 +8304,7 @@ assert(RelevantVD && "No relevant declaration related with device pointer??"); - CurInfo.BasePointers[CurrentBasePointersIdx].setDevicePtrDecl( - RelevantVD); + CurInfo.DevicePtrDecls[CurrentBasePointersIdx] = RelevantVD; CurInfo.Types[CurrentBasePointersIdx] |= OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM; } @@ -8377,7 +8343,8 @@ OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF); } CurInfo.Exprs.push_back(L.VD); - CurInfo.BasePointers.emplace_back(BasePtr, L.VD); + CurInfo.BasePointers.emplace_back(BasePtr); + CurInfo.DevicePtrDecls.emplace_back(L.VD); CurInfo.Pointers.push_back(Ptr); CurInfo.Sizes.push_back( llvm::Constant::getNullValue(this->CGF.Int64Ty)); @@ -8472,6 +8439,7 @@ CombinedInfo.Exprs.push_back(VD); // Base is the base of the struct 
CombinedInfo.BasePointers.push_back(PartialStruct.Base.getPointer()); + CombinedInfo.DevicePtrDecls.push_back(nullptr); // Pointer is the address of the lowest element llvm::Value *LB = LBAddr.getPointer(); const CXXMethodDecl *MD = @@ -8593,6 +8561,7 @@ VDLVal.getPointer(CGF)); CombinedInfo.Exprs.push_back(VD); CombinedInfo.BasePointers.push_back(ThisLVal.getPointer(CGF)); + CombinedInfo.DevicePtrDecls.push_back(nullptr); CombinedInfo.Pointers.push_back(ThisLValVal.getPointer(CGF)); CombinedInfo.Sizes.push_back( CGF.Builder.CreateIntCast(CGF.getTypeSize(CGF.getContext().VoidPtrTy), @@ -8619,6 +8588,7 @@ VDLVal.getPointer(CGF)); CombinedInfo.Exprs.push_back(VD); CombinedInfo.BasePointers.push_back(VarLVal.getPointer(CGF)); + CombinedInfo.DevicePtrDecls.push_back(nullptr); CombinedInfo.Pointers.push_back(VarLValVal.getPointer(CGF)); CombinedInfo.Sizes.push_back(CGF.Builder.CreateIntCast( CGF.getTypeSize( @@ -8630,6 +8600,7 @@ VDLVal.getPointer(CGF)); CombinedInfo.Exprs.push_back(VD); CombinedInfo.BasePointers.push_back(VarLVal.getPointer(CGF)); + CombinedInfo.DevicePtrDecls.push_back(nullptr); CombinedInfo.Pointers.push_back(VarRVal.getScalarVal()); CombinedInfo.Sizes.push_back(llvm::ConstantInt::get(CGF.Int64Ty, 0)); } @@ -8654,7 +8625,7 @@ OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF | OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT)) continue; - llvm::Value *BasePtr = LambdaPointers.lookup(*BasePointers[I]); + llvm::Value *BasePtr = LambdaPointers.lookup(BasePointers[I]); assert(BasePtr && "Unable to find base lambda address."); int TgtIdx = -1; for (unsigned J = I; J > 0; --J) { @@ -8696,7 +8667,8 @@ // pass its value. 
if (VD && (DevPointersMap.count(VD) || HasDevAddrsMap.count(VD))) { CombinedInfo.Exprs.push_back(VD); - CombinedInfo.BasePointers.emplace_back(Arg, VD); + CombinedInfo.BasePointers.emplace_back(Arg); + CombinedInfo.DevicePtrDecls.emplace_back(VD); CombinedInfo.Pointers.push_back(Arg); CombinedInfo.Sizes.push_back(CGF.Builder.CreateIntCast( CGF.getTypeSize(CGF.getContext().VoidPtrTy), CGF.Int64Ty, @@ -8938,6 +8910,7 @@ if (CI.capturesThis()) { CombinedInfo.Exprs.push_back(nullptr); CombinedInfo.BasePointers.push_back(CV); + CombinedInfo.DevicePtrDecls.push_back(nullptr); CombinedInfo.Pointers.push_back(CV); const auto *PtrTy = cast(RI.getType().getTypePtr()); CombinedInfo.Sizes.push_back( @@ -8950,6 +8923,7 @@ const VarDecl *VD = CI.getCapturedVar(); CombinedInfo.Exprs.push_back(VD->getCanonicalDecl()); CombinedInfo.BasePointers.push_back(CV); + CombinedInfo.DevicePtrDecls.push_back(nullptr); CombinedInfo.Pointers.push_back(CV); if (!RI.getType()->isAnyPointerType()) { // We have to signal to the runtime captures passed by value that are @@ -8981,6 +8955,7 @@ auto I = FirstPrivateDecls.find(VD); CombinedInfo.Exprs.push_back(VD->getCanonicalDecl()); CombinedInfo.BasePointers.push_back(CV); + CombinedInfo.DevicePtrDecls.push_back(nullptr); if (I != FirstPrivateDecls.end() && ElementType->isAnyPointerType()) { Address PtrAddr = CGF.EmitLoadOfReference(CGF.MakeAddrLValue( CV, ElementType, CGF.getContext().getDeclAlign(VD), @@ -9266,7 +9241,7 @@ } for (unsigned I = 0; I < Info.NumberOfPtrs; ++I) { - llvm::Value *BPVal = *CombinedInfo.BasePointers[I]; + llvm::Value *BPVal = CombinedInfo.BasePointers[I]; llvm::Value *BP = CGF.Builder.CreateConstInBoundsGEP2_32( llvm::ArrayType::get(CGM.VoidPtrTy, Info.NumberOfPtrs), Info.RTArgs.BasePointersArray, 0, I); @@ -9277,8 +9252,7 @@ CGF.Builder.CreateStore(BPVal, BPAddr); if (Info.requiresDevicePointerInfo()) - if (const ValueDecl *DevVD = - CombinedInfo.BasePointers[I].getDevicePtrDecl()) + if (const ValueDecl *DevVD = 
CombinedInfo.DevicePtrDecls[I]) Info.CaptureDeviceAddrMap.try_emplace(DevVD, BPAddr); llvm::Value *PVal = CombinedInfo.Pointers[I]; @@ -9592,7 +9566,7 @@ // Fill up the runtime mapper handle for all components. for (unsigned I = 0; I < Info.BasePointers.size(); ++I) { llvm::Value *CurBaseArg = MapperCGF.Builder.CreateBitCast( - *Info.BasePointers[I], CGM.getTypes().ConvertTypeForMem(C.VoidPtrTy)); + Info.BasePointers[I], CGM.getTypes().ConvertTypeForMem(C.VoidPtrTy)); llvm::Value *CurBeginArg = MapperCGF.Builder.CreateBitCast( Info.Pointers[I], CGM.getTypes().ConvertTypeForMem(C.VoidPtrTy)); llvm::Value *CurSizeArg = Info.Sizes[I]; @@ -10028,6 +10002,7 @@ if (CI->capturesVariableArrayType()) { CurInfo.Exprs.push_back(nullptr); CurInfo.BasePointers.push_back(*CV); + CurInfo.DevicePtrDecls.push_back(nullptr); CurInfo.Pointers.push_back(*CV); CurInfo.Sizes.push_back(CGF.Builder.CreateIntCast( CGF.getTypeSize(RI->getType()), CGF.Int64Ty, /*isSigned=*/true)); diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h --- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h +++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h @@ -1445,6 +1445,49 @@ bool separateBeginEndCalls() { return SeparateBeginEndCalls; } }; + using MapValuesArrayTy = SmallVector; + using MapFlagsArrayTy = SmallVector; + using MapNamesArrayTy = SmallVector; + using MapDimArrayTy = SmallVector; + using MapNonContiguousArrayTy = SmallVector; + + /// This structure contains combined information generated for mappable + /// clauses, including base pointers, pointers, sizes, map types, map names, + /// and non-contiguous information. 
+ struct MapInfosTy { + struct StructNonContiguousInfo { + bool IsNonContiguous = false; + MapDimArrayTy Dims; + MapNonContiguousArrayTy Offsets; + MapNonContiguousArrayTy Counts; + MapNonContiguousArrayTy Strides; + }; + MapValuesArrayTy BasePointers; + MapValuesArrayTy Pointers; + MapValuesArrayTy Sizes; + MapFlagsArrayTy Types; + MapNamesArrayTy Names; + StructNonContiguousInfo NonContigInfo; + + /// Append the arrays of \a CurInfo (IsNonContiguous is left unchanged). + void append(MapInfosTy &CurInfo) { + BasePointers.append(CurInfo.BasePointers.begin(), + CurInfo.BasePointers.end()); + Pointers.append(CurInfo.Pointers.begin(), CurInfo.Pointers.end()); + Sizes.append(CurInfo.Sizes.begin(), CurInfo.Sizes.end()); + Types.append(CurInfo.Types.begin(), CurInfo.Types.end()); + Names.append(CurInfo.Names.begin(), CurInfo.Names.end()); + NonContigInfo.Dims.append(CurInfo.NonContigInfo.Dims.begin(), + CurInfo.NonContigInfo.Dims.end()); + NonContigInfo.Offsets.append(CurInfo.NonContigInfo.Offsets.begin(), + CurInfo.NonContigInfo.Offsets.end()); + NonContigInfo.Counts.append(CurInfo.NonContigInfo.Counts.begin(), + CurInfo.NonContigInfo.Counts.end()); + NonContigInfo.Strides.append(CurInfo.NonContigInfo.Strides.begin(), + CurInfo.NonContigInfo.Strides.end()); + } + }; + /// Emit the arguments to be passed to the runtime library based on the /// arrays of base pointers, pointers, sizes, map types, and mappers. If /// ForEndCall, emit map types to be passed for the end of the region instead