Index: clang/lib/CodeGen/CGOpenMPRuntimeNVPTX.cpp
===================================================================
--- clang/lib/CodeGen/CGOpenMPRuntimeNVPTX.cpp
+++ clang/lib/CodeGen/CGOpenMPRuntimeNVPTX.cpp
@@ -1799,9 +1799,8 @@
     llvm::Type *TypeParams[] = {getIdentTyPointerTy(), CGM.Int32Ty};
     auto *FnTy =
         llvm::FunctionType::get(CGM.VoidTy, TypeParams, /*isVarArg*/ false);
-    RTLFn = CGM.CreateRuntimeFunction(FnTy, /*Name*/ "__kmpc_barrier");
-    cast<llvm::Function>(RTLFn.getCallee())
-        ->addFnAttr(llvm::Attribute::Convergent);
+    RTLFn =
+        CGM.CreateConvergentRuntimeFunction(FnTy, /*Name*/ "__kmpc_barrier");
     break;
   }
   case OMPRTL__kmpc_barrier_simple_spmd: {
@@ -1810,10 +1809,8 @@
     llvm::Type *TypeParams[] = {getIdentTyPointerTy(), CGM.Int32Ty};
     auto *FnTy =
         llvm::FunctionType::get(CGM.VoidTy, TypeParams, /*isVarArg*/ false);
-    RTLFn =
-        CGM.CreateRuntimeFunction(FnTy, /*Name*/ "__kmpc_barrier_simple_spmd");
-    cast<llvm::Function>(RTLFn.getCallee())
-        ->addFnAttr(llvm::Attribute::Convergent);
+    RTLFn = CGM.CreateConvergentRuntimeFunction(
+        FnTy, /*Name*/ "__kmpc_barrier_simple_spmd");
     break;
   }
   case OMPRTL_NVPTX__kmpc_warp_active_thread_mask: {
Index: clang/lib/CodeGen/CodeGenModule.h
===================================================================
--- clang/lib/CodeGen/CodeGenModule.h
+++ clang/lib/CodeGen/CodeGenModule.h
@@ -1031,7 +1031,17 @@
   llvm::FunctionCallee
   CreateRuntimeFunction(llvm::FunctionType *Ty, StringRef Name,
                         llvm::AttributeList ExtraAttrs = llvm::AttributeList(),
-                        bool Local = false);
+                        bool Local = false, bool AssumeConvergent = false);
+
+  /// Same as CreateRuntimeFunction, but passes AssumeConvergent = true so the
+  /// convergent function attribute is added to ExtraAttrs.
+  llvm::FunctionCallee CreateConvergentRuntimeFunction(
+      llvm::FunctionType *Ty, StringRef Name,
+      llvm::AttributeList ExtraAttrs = llvm::AttributeList(),
+      bool Local = false) {
+    return CreateRuntimeFunction(Ty, Name, ExtraAttrs, Local,
+                                 /*AssumeConvergent=*/true);
+  }
 
   /// Create a new runtime global variable with the specified type and name.
   llvm::Constant *CreateRuntimeVariable(llvm::Type *Ty,
Index: clang/lib/CodeGen/CodeGenModule.cpp
===================================================================
--- clang/lib/CodeGen/CodeGenModule.cpp
+++ clang/lib/CodeGen/CodeGenModule.cpp
@@ -3332,8 +3332,16 @@
 /// type and name.
 llvm::FunctionCallee
 CodeGenModule::CreateRuntimeFunction(llvm::FunctionType *FTy, StringRef Name,
-                                     llvm::AttributeList ExtraAttrs,
-                                     bool Local) {
+                                     llvm::AttributeList ExtraAttrs, bool Local,
+                                     bool AssumeConvergent) {
+  // Merge the convergent request into the caller-supplied attribute list as a
+  // function attribute, so it is handled like any other ExtraAttrs entry.
+  if (AssumeConvergent) {
+    ExtraAttrs =
+        ExtraAttrs.addAttribute(VMContext, llvm::AttributeList::FunctionIndex,
+                                llvm::Attribute::Convergent);
+  }
+
   llvm::Constant *C =
       GetOrCreateLLVMFunction(Name, FTy, GlobalDecl(), /*ForVTable=*/false,
                               /*DontDefer=*/false, /*IsThunk=*/false,