diff --git a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
--- a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
+++ b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
@@ -3462,6 +3462,9 @@
     case OMPRTL___kmpc_is_spmd_exec_mode:
       Changed |= foldIsSPMDExecMode(A);
       break;
+    case OMPRTL___kmpc_is_generic_main_thread_id:
+      Changed |= foldIsGenericMainThread(A);
+      break;
     default:
       llvm_unreachable("Unhandled OpenMP runtime function!");
     }
@@ -3476,6 +3479,10 @@
       Instruction &CB = *getCtxI();
       A.changeValueAfterManifest(CB, **SimplifiedValue);
       A.deleteAfterManifest(CB);
+
+      LLVM_DEBUG(dbgs() << TAG << "Folding runtime call: " << CB << " with "
+                        << **SimplifiedValue << "\n");
+
       Changed = ChangeStatus::CHANGED;
     }
 
@@ -3552,6 +3559,30 @@
                                                     : ChangeStatus::CHANGED;
   }
 
+  /// Fold __kmpc_is_generic_main_thread_id into a constant if possible. The
+  /// call folds to `true` only if it is executed by the initial thread only.
+  ChangeStatus foldIsGenericMainThread(Attributor &A) {
+    Optional<Value *> SimplifiedValueBefore = SimplifiedValue;
+
+    CallBase &CB = cast<CallBase>(getAssociatedValue());
+    Function *F = CB.getFunction();
+    const auto &ExecutionDomainAA = A.getAAFor<AAExecutionDomain>(
+        *this, IRPosition::function(*F), DepClassTy::REQUIRED);
+
+    // Without a proof that only the initial thread reaches this call the
+    // result is not a constant; give up for good instead of re-iterating.
+    if (!ExecutionDomainAA.isValidState() ||
+        !ExecutionDomainAA.isExecutedByInitialThreadOnly(CB)) {
+      SimplifiedValue = nullptr;
+      return indicatePessimisticFixpoint();
+    }
+
+    auto &Ctx = getAnchorValue().getContext();
+    SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), true);
+    return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
+                                                    : ChangeStatus::CHANGED;
+  }
+
   /// An optional value the associated value is assumed to fold to. That is, we
   /// assume the associated value (which is a call) can be replaced by this
   /// simplified value.
@@ -3578,6 +3609,20 @@
         DepClassTy::NONE, /* ForceUpdate */ false,
         /* UpdateAfterInit */ false);
 
+  // Try to fold every regular call to __kmpc_is_generic_main_thread_id.
+  auto &IsMainRFI =
+      OMPInfoCache.RFIs[OMPRTL___kmpc_is_generic_main_thread_id];
+  IsMainRFI.foreachUse(SCC, [&](Use &U, Function &) {
+    CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &IsMainRFI);
+    if (!CI)
+      return false;
+    A.getOrCreateAAFor<AAFoldRuntimeCall>(
+        IRPosition::callsite_returned(*CI), /* QueryingAA */ nullptr,
+        DepClassTy::NONE, /* ForceUpdate */ false,
+        /* UpdateAfterInit */ false);
+    return false;
+  });
+
   auto &IsSPMDRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_is_spmd_exec_mode];
   IsSPMDRFI.foreachUse(SCC, [&](Use &U, Function &) {
     CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &IsSPMDRFI);
diff --git a/llvm/test/Transforms/OpenMP/fold_generic_main_thread.ll b/llvm/test/Transforms/OpenMP/fold_generic_main_thread.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/Transforms/OpenMP/fold_generic_main_thread.ll
@@ -0,0 +1,85 @@
+; RUN: opt -S -passes='openmp-opt' < %s | FileCheck %s
+; ModuleID = 'fold_generic_main_thread.c'
+
+%struct.ident_t = type { i32, i32, i32, i32, i8* }
+
+@0 = private unnamed_addr constant [1 x i8] c"\00", align 1
+@1 = private unnamed_addr constant %struct.ident_t { i32 0, i32 2, i32 0, i32 0, i8* getelementptr inbounds ([1 x i8], [1 x i8]* @0, i32 0, i32 0) }, align 8
+
+define void @kernel() {
+  %call = call i32 @__kmpc_target_init(%struct.ident_t* nonnull @1, i1 false, i1 false, i1 false)
+  %cmp = icmp eq i32 %call, -1
+  br i1 %cmp, label %if.then, label %if.else
+if.then:
+  call void @foo()
+  br label %if.end
+if.else:
+  call void @bar()
+  br label %if.end
+if.end:
+  call void @__kmpc_target_deinit(%struct.ident_t* null, i1 false, i1 true)
+  ret void
+}
+
+; @foo is only called on the `%call == -1` path of @kernel, so the runtime call
+; is expected to be folded away.
+; CHECK-LABEL: define internal void @foo()
+; CHECK-NOT:     call {{.*}}@__kmpc_is_generic_main_thread_id
+define internal void @foo() {
+entry:
+  %tid = call i32 @__kmpc_get_hardware_thread_id()
+  %ismain = call signext i8 @__kmpc_is_generic_main_thread_id(i32 %tid)
+  %pred = icmp eq i8 %ismain, 1
+  br i1 %pred, label %if.then, label %if.end
+
+if.then:
+  call void @baz()
+  br label %if.end
+
+if.end:
+  ret void
+}
+
+; @bar is called on the other path of @kernel, so the runtime call is kept.
+; CHECK-LABEL: define internal void @bar()
+; CHECK:         call signext i8 @__kmpc_is_generic_main_thread_id(i32 %tid)
+define internal void @bar() {
+entry:
+  %tid = call i32 @__kmpc_get_hardware_thread_id()
+  %ismain = call signext i8 @__kmpc_is_generic_main_thread_id(i32 %tid)
+  %pred = icmp eq i8 %ismain, 1
+  br i1 %pred, label %if.then, label %if.end
+
+if.then:
+  call void @baz()
+  br label %if.end
+
+if.end:
+  ret void
+}
+
+define internal void @baz() {
+entry:
+  ret void
+}
+
+declare i8 @__kmpc_is_generic_main_thread_id(i32)
+
+declare i32 @__kmpc_get_hardware_thread_id()
+
+declare i32 @__kmpc_target_init(%struct.ident_t*, i1, i1, i1)
+
+declare void @__kmpc_target_deinit(%struct.ident_t*, i1, i1)
+
+!llvm.dbg.cu = !{!0}
+!llvm.module.flags = !{!3, !4, !5, !6}
+!nvvm.annotations = !{!7}
+
+!0 = distinct !DICompileUnit(language: DW_LANG_C99, file: !1, producer: "clang version 13.0.0", isOptimized: false, runtimeVersion: 0, emissionKind: FullDebug, enums: !2, splitDebugInlining: false, nameTableKind: None)
+!1 = !DIFile(filename: "fold_generic_main_thread.c", directory: "/tmp/fold_generic_main_thread.c")
+!2 = !{}
+!3 = !{i32 2, !"Debug Info Version", i32 3}
+!4 = !{i32 1, !"wchar_size", i32 4}
+!5 = !{i32 7, !"openmp", i32 50}
+!6 = !{i32 7, !"openmp-device", i32 50}
+!7 = !{void ()* @kernel, !"kernel", i32 1}