diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -226,6 +226,10 @@
                           omp::IdentFlag Flags = omp::IdentFlag(0),
                           unsigned Reserve2Flags = 0);
 
+  /// Return the IR type corresponding to __kmpc_impl_lanemask_t from the
+  /// deviceRTL: i64 when the module's target triple is AMDGCN, i32 otherwise.
+  Type *getLanemaskType();
+
   /// Generate control flow and cleanup for cancellation.
   ///
   /// \param CancelFlag Flag indicating if the cancellation is performed.
@@ -427,6 +431,7 @@
   /// values.
   ///
   ///{
+  Type *LanemaskTy = nullptr;
 #define OMP_TYPE(VarName, InitValue) Type *VarName = nullptr;
 #define OMP_ARRAY_TYPE(VarName, ElemTy, ArraySize)                             \
   ArrayType *VarName##Ty = nullptr;                                            \
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def b/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def
--- a/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def
@@ -553,8 +553,9 @@
           Int16, VoidPtrPtr)
 __OMP_RTL(__kmpc_restore_team_static_memory, false, Void, Int16, Int16)
 __OMP_RTL(__kmpc_barrier_simple_spmd, false, Void, IdentPtr, Int32)
-__OMP_RTL(__kmpc_warp_active_thread_mask, false, Int32, )
-__OMP_RTL(__kmpc_syncwarp, false, Void, Int32)
+
+__OMP_RTL(__kmpc_warp_active_thread_mask, false, LanemaskTy, )
+__OMP_RTL(__kmpc_syncwarp, false, Void, LanemaskTy)
 
 __OMP_RTL(__last, false, Void, )
 
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -16,6 +16,7 @@
 
 #include "llvm/ADT/StringRef.h"
 #include "llvm/ADT/StringSwitch.h"
+#include "llvm/ADT/Triple.h"
 #include "llvm/IR/CFG.h"
 #include "llvm/IR/DebugInfo.h"
 #include "llvm/IR/IRBuilder.h"
@@ -62,6 +63,7 @@
 OpenMPIRBuilder::getOrCreateRuntimeFunction(Module &M, RuntimeFunction FnID) {
   FunctionType *FnTy = nullptr;
   Function *Fn = nullptr;
+  Type *LanemaskTy = getLanemaskType();
 
   // Try to find the declation in the module first.
   switch (FnID) {
@@ -217,6 +219,14 @@
   return Ident;
 }
 
+Type *OpenMPIRBuilder::getLanemaskType() {
+  LLVMContext &Ctx = M.getContext();
+  const Triple T(M.getTargetTriple());
+
+  // This test is adequate until deviceRTL has finer grained lane widths.
+  return T.isAMDGCN() ? Type::getInt64Ty(Ctx) : Type::getInt32Ty(Ctx);
+}
+
 Constant *OpenMPIRBuilder::getOrCreateSrcLocStr(StringRef LocStr) {
   Constant *&SrcLocStr = SrcLocStrMap[LocStr];
   if (!SrcLocStr) {
@@ -1177,6 +1187,8 @@
 void OpenMPIRBuilder::initializeTypes(Module &M) {
   LLVMContext &Ctx = M.getContext();
   StructType *T;
+  LanemaskTy = getLanemaskType();
+
 #define OMP_TYPE(VarName, InitValue) VarName = InitValue;
 #define OMP_ARRAY_TYPE(VarName, ElemTy, ArraySize)                             \
   VarName##Ty = ArrayType::get(ElemTy, ArraySize);                             \
diff --git a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
--- a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
+++ b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
@@ -331,6 +331,7 @@
   /// in OpenMPKinds.def.
   void initializeRuntimeFunctions() {
     Module &M = *((*ModuleSlice.begin())->getParent());
+    Type *LanemaskTy = OMPBuilder.getLanemaskType();
 
     // Helper macros for handling __VA_ARGS__ in OMP_RTL
 #define OMP_TYPE(VarName, ...)                                                 \