diff --git a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp --- a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp +++ b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp @@ -497,6 +497,12 @@ /// one we abort as the kernel is malformed. CallBase *KernelDeinitCB = nullptr; + /// Flag to indicate if the associated function is a kernel entry. + bool IsKernelEntry = false; + + /// State to track what kernel entries can reach the associated function. + BooleanStateWithPtrSetVector<Function, false> ReachingKernelEntries; + /// Abstract State interface ///{ @@ -537,6 +543,8 @@ return false; if (ReachedUnknownParallelRegions != RHS.ReachedUnknownParallelRegions) return false; + if (ReachingKernelEntries != RHS.ReachingKernelEntries) + return false; return true; } @@ -2729,6 +2737,10 @@ if (!OMPInfoCache.Kernels.count(Fn)) return; + // Add itself to the reaching kernel and set IsKernelEntry. + ReachingKernelEntries.insert(Fn); + IsKernelEntry = true; + OMPInformationCache::RuntimeFunctionInfo &InitRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_target_init]; OMPInformationCache::RuntimeFunctionInfo &DeinitRFI = @@ -3213,6 +3225,9 @@ CheckRWInst, *this, UsedAssumedInformationInCheckRWInst)) SPMDCompatibilityTracker.indicatePessimisticFixpoint(); + if (!IsKernelEntry) + updateReachingKernelEntries(A); + // Callback to check a call instruction. auto CheckCallInst = [&](Instruction &I) { auto &CB = cast<CallBase>(I); @@ -3220,6 +3235,19 @@ auto &CBAA = A.getAAFor<AAKernelInfo>( *this, IRPosition::callsite_function(CB), DepClassTy::OPTIONAL); if (CBAA.getState().isValidState()) getState() ^= CBAA.getState(); + + Function *Callee = CB.getCalledFunction(); + if (Callee) { + // We need to propagate information to the callee, but since the + // construction of AA always starts with kernel entries, we have to + // create AAKernelInfoFunction for all called functions. However, here + // the caller doesn't depend on the callee. + // TODO: We might want to change the dependence here later if we need + // information from callee to caller.
+ A.getOrCreateAAFor<AAKernelInfo>(IRPosition::function(*Callee), this, + DepClassTy::NONE); + } + + return true; }; @@ -3231,6 +3259,35 @@ return StateBefore == getState() ? ChangeStatus::UNCHANGED : ChangeStatus::CHANGED; } + +private: + /// Update info regarding reaching kernels. + void updateReachingKernelEntries(Attributor &A) { + auto PredCallSite = [&](AbstractCallSite ACS) { + Function *Caller = ACS.getInstruction()->getFunction(); + + assert(Caller && "Caller is nullptr"); + + auto &CAA = + A.getOrCreateAAFor<AAKernelInfo>(IRPosition::function(*Caller)); + if (CAA.isValidState()) { + ReachingKernelEntries ^= CAA.ReachingKernelEntries; + return true; + } + + // We lost track of the caller of the associated function, any kernel + // could reach now. + ReachingKernelEntries.indicatePessimisticFixpoint(); + + return true; + }; + + bool AllCallSitesKnown; + if (!A.checkForAllCallSites(PredCallSite, *this, + true /* RequireAllCallSites */, + AllCallSitesKnown)) + ReachingKernelEntries.indicatePessimisticFixpoint(); + } }; /// The call site kernel info abstract attribute, basically, what can we say @@ -3377,6 +3434,182 @@ } }; +struct AAFoldRuntimeCall + : public StateWrapper<BooleanState, AbstractAttribute> { + using Base = StateWrapper<BooleanState, AbstractAttribute>; + + AAFoldRuntimeCall(const IRPosition &IRP, Attributor &A) : Base(IRP) {} + + /// Statistics are tracked as part of manifest for now. + void trackStatistics() const override {} + + /// Create an abstract attribute view for the position \p IRP.
+ static AAFoldRuntimeCall &createForPosition(const IRPosition &IRP, + Attributor &A); + + /// See AbstractAttribute::getName() + const std::string getName() const override { return "AAFoldRuntimeCall"; } + + /// See AbstractAttribute::getIdAddr() + const char *getIdAddr() const override { return &ID; } + + /// This function should return true if the type of the \p AA is + /// AAFoldRuntimeCall + static bool classof(const AbstractAttribute *AA) { + return (AA->getIdAddr() == &ID); + } + + static const char ID; +}; + +struct AAFoldRuntimeCallCallSite : AAFoldRuntimeCall { + AAFoldRuntimeCallCallSite(const IRPosition &IRP, Attributor &A) + : AAFoldRuntimeCall(IRP, A) {} + + /// See AbstractAttribute::getAsStr() + const std::string getAsStr() const override { + if (!isValidState()) + return "<invalid>"; + + std::string Str("simplified value: "); + + if (!SimplifiedValue.hasValue()) + return Str + std::string("none"); + + if (!SimplifiedValue.getValue()) + return Str + std::string("nullptr"); + + if (ConstantInt *CI = dyn_cast<ConstantInt>(SimplifiedValue.getValue())) + return Str + std::to_string(CI->getSExtValue()); + + return Str + std::string("unknown"); + } + + void initialize(Attributor &A) override { + Function *Callee = getAssociatedFunction(); + + auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache()); + const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee); + assert(It != OMPInfoCache.RuntimeFunctionIDMap.end() && + "Expected a known OpenMP runtime function"); + + RFKind = It->getSecond(); + + CallBase &CB = cast<CallBase>(getAssociatedValue()); + A.registerSimplificationCallback( + IRPosition::callsite_returned(CB), + [&](const IRPosition &IRP, const AbstractAttribute *AA, + bool &UsedAssumedInformation) -> Optional<Value *> { + if (!isAtFixpoint()) { + UsedAssumedInformation = true; + if (AA) + A.recordDependence(*this, *AA, DepClassTy::OPTIONAL); + } + return SimplifiedValue; + }); + } + + ChangeStatus updateImpl(Attributor &A) override { + ChangeStatus Changed = ChangeStatus::UNCHANGED; +
+ switch (RFKind) { + case OMPRTL___kmpc_is_spmd_exec_mode: + Changed = Changed | foldIsSPMDExecMode(A); + break; + default: + llvm_unreachable("Unhandled OpenMP runtime function!"); + } + + return Changed; + } + + ChangeStatus manifest(Attributor &A) override { + ChangeStatus Changed = ChangeStatus::UNCHANGED; + + if (SimplifiedValue.hasValue() && SimplifiedValue.getValue()) { + Instruction &CB = *getCtxI(); + A.changeValueAfterManifest(CB, **SimplifiedValue); + A.deleteAfterManifest(CB); + Changed = ChangeStatus::CHANGED; + } + + return Changed; + } + + ChangeStatus indicatePessimisticFixpoint() override { + SimplifiedValue = nullptr; + return AAFoldRuntimeCall::indicatePessimisticFixpoint(); + } + +private: + /// Fold __kmpc_is_spmd_exec_mode into a constant if possible. + ChangeStatus foldIsSPMDExecMode(Attributor &A) { + BooleanState StateBefore = getState(); + + unsigned AssumedSPMDCount = 0, KnownSPMDCount = 0; + unsigned AssumedNonSPMDCount = 0, KnownNonSPMDCount = 0; + auto &CallerKernelInfoAA = A.getAAFor<AAKernelInfo>( + *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED); + + if (!CallerKernelInfoAA.ReachingKernelEntries.isValidState()) + return indicatePessimisticFixpoint(); + + for (Kernel K : CallerKernelInfoAA.ReachingKernelEntries) { + auto &AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K), + DepClassTy::REQUIRED); + + if (!AA.isValidState()) { + SimplifiedValue = nullptr; + return indicatePessimisticFixpoint(); + } + + if (AA.SPMDCompatibilityTracker.isAssumed()) { + if (AA.SPMDCompatibilityTracker.isAtFixpoint()) + ++KnownSPMDCount; + else + ++AssumedSPMDCount; + } else { + if (AA.SPMDCompatibilityTracker.isAtFixpoint()) + ++KnownNonSPMDCount; + else + ++AssumedNonSPMDCount; + } + } + + if (KnownSPMDCount && KnownNonSPMDCount) + return indicatePessimisticFixpoint(); + + if (AssumedSPMDCount && AssumedNonSPMDCount) + return indicatePessimisticFixpoint(); + + auto &Ctx = getAnchorValue().getContext(); + if (KnownSPMDCount || AssumedSPMDCount) {
+ assert(KnownNonSPMDCount == 0 && AssumedNonSPMDCount == 0 && + "Expected only SPMD kernels!"); + // All reaching kernels are in SPMD mode. Update all function calls to + // __kmpc_is_spmd_exec_mode to 1. + SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), true); + } else { + assert(KnownSPMDCount == 0 && AssumedSPMDCount == 0 && + "Expected only non-SPMD kernels!"); + // All reaching kernels are in non-SPMD mode. Update all function + // calls to __kmpc_is_spmd_exec_mode to 0. + SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), false); + } + + return getState() == StateBefore ? ChangeStatus::UNCHANGED + : ChangeStatus::CHANGED; + } + + /// An optional value the associated value is assumed to fold to. That is, we + /// assume the associated value (which is a call) can be replaced by this + /// simplified value. + Optional<Value *> SimplifiedValue; + + /// The runtime function kind of the callee of the associated call site. + RuntimeFunction RFKind; +}; + } // namespace void OpenMPOpt::registerAAs(bool IsModulePass) { @@ -3393,6 +3626,18 @@ A.getOrCreateAAFor<AAKernelInfo>( IRPosition::function(*Kernel), /* QueryingAA */ nullptr, DepClassTy::NONE, /* ForceUpdate */ false, /* UpdateAfterInit */ false); + + auto &IsSPMDRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_is_spmd_exec_mode]; + IsSPMDRFI.foreachUse(SCC, [&](Use &U, Function &) { + CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &IsSPMDRFI); + if (!CI) + return false; + A.getOrCreateAAFor<AAFoldRuntimeCall>( + IRPosition::callsite_function(*CI), /* QueryingAA */ nullptr, + DepClassTy::NONE, /* ForceUpdate */ false, + /* UpdateAfterInit */ false); + return false; + }); } // Create CallSite AA for all Getters.
@@ -3436,6 +3681,7 @@ const char AAKernelInfo::ID = 0; const char AAExecutionDomain::ID = 0; const char AAHeapToShared::ID = 0; +const char AAFoldRuntimeCall::ID = 0; AAICVTracker &AAICVTracker::createForPosition(const IRPosition &IRP, Attributor &A) { @@ -3527,6 +3773,26 @@ return *AA; } +AAFoldRuntimeCall &AAFoldRuntimeCall::createForPosition(const IRPosition &IRP, + Attributor &A) { + AAFoldRuntimeCall *AA = nullptr; + switch (IRP.getPositionKind()) { + case IRPosition::IRP_INVALID: + case IRPosition::IRP_FLOAT: + case IRPosition::IRP_ARGUMENT: + case IRPosition::IRP_RETURNED: + case IRPosition::IRP_CALL_SITE_RETURNED: + case IRPosition::IRP_CALL_SITE_ARGUMENT: + case IRPosition::IRP_FUNCTION: + llvm_unreachable("KernelInfo can only be created for call site position!"); + case IRPosition::IRP_CALL_SITE: + AA = new (A.Allocator) AAFoldRuntimeCallCallSite(IRP, A); + break; + } + + return *AA; +} + PreservedAnalyses OpenMPOptPass::run(Module &M, ModuleAnalysisManager &AM) { if (!containsOpenMP(M)) return PreservedAnalyses::all(); diff --git a/llvm/test/Transforms/OpenMP/custom_state_machines.ll b/llvm/test/Transforms/OpenMP/custom_state_machines.ll --- a/llvm/test/Transforms/OpenMP/custom_state_machines.ll +++ b/llvm/test/Transforms/OpenMP/custom_state_machines.ll @@ -1713,8 +1713,8 @@ ; CHECK: if.end: ; CHECK-NEXT: [[TMP1:%.*]] = load i32, i32* [[A_ADDR]], align 4 ; CHECK-NEXT: [[SUB:%.*]] = sub nsw i32 [[TMP1]], 1 -; CHECK-NEXT: call void @simple_state_machine_interprocedural_nested_recursive_after.internalized(i32 [[SUB]]) #[[ATTR8]] -; CHECK-NEXT: call void @simple_state_machine_interprocedural_nested_recursive_after_after.internalized() #[[ATTR8]] +; CHECK-NEXT: call void @simple_state_machine_interprocedural_nested_recursive_after.internalized(i32 [[SUB]]) #[[ATTR7]] +; CHECK-NEXT: call void @simple_state_machine_interprocedural_nested_recursive_after_after.internalized() #[[ATTR7]] ; CHECK-NEXT: br label [[RETURN]] ; CHECK: return: ; CHECK-NEXT: ret void 
diff --git a/llvm/test/Transforms/OpenMP/is_spmd_exec_mode_fold.ll b/llvm/test/Transforms/OpenMP/is_spmd_exec_mode_fold.ll new file mode 100644 --- /dev/null +++ b/llvm/test/Transforms/OpenMP/is_spmd_exec_mode_fold.ll @@ -0,0 +1,180 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --function-signature --check-globals +; RUN: opt -S -passes=openmp-opt < %s | FileCheck %s +target triple = "nvptx64" + +%struct.ident_t = type { i32, i32, i32, i32, i8* } + +@is_spmd_exec_mode = weak constant i8 0 +@will_be_spmd_exec_mode = weak constant i8 1 +@non_spmd_exec_mode = weak constant i8 1 +@will_not_be_spmd_exec_mode = weak constant i8 1 +@G = external global i8 +@llvm.compiler.used = appending global [4 x i8*] [i8* @is_spmd_exec_mode, i8* @will_be_spmd_exec_mode, i8* @non_spmd_exec_mode, i8* @will_not_be_spmd_exec_mode ], section "llvm.metadata" + +;. +; CHECK: @[[IS_SPMD_EXEC_MODE:[a-zA-Z0-9_$"\\.-]+]] = weak constant i8 0 +; CHECK: @[[WILL_BE_SPMD_EXEC_MODE:[a-zA-Z0-9_$"\\.-]+]] = weak constant i8 0 +; CHECK: @[[NON_SPMD_EXEC_MODE:[a-zA-Z0-9_$"\\.-]+]] = weak constant i8 1 +; CHECK: @[[WILL_NOT_BE_SPMD_EXEC_MODE:[a-zA-Z0-9_$"\\.-]+]] = weak constant i8 1 +; CHECK: @[[G:[a-zA-Z0-9_$"\\.-]+]] = external global i8 +; CHECK: @[[LLVM_COMPILER_USED:[a-zA-Z0-9_$"\\.-]+]] = appending global [4 x i8*] [i8* @is_spmd_exec_mode, i8* @will_be_spmd_exec_mode, i8* @non_spmd_exec_mode, i8* @will_not_be_spmd_exec_mode], section "llvm.metadata" +;. 
+define weak void @is_spmd() { +; CHECK-LABEL: define {{[^@]+}}@is_spmd() { +; CHECK-NEXT: [[I:%.*]] = call i32 @__kmpc_target_init(%struct.ident_t* null, i1 true, i1 false, i1 false) +; CHECK-NEXT: call void @is_spmd_helper1() +; CHECK-NEXT: call void @is_spmd_helper2() +; CHECK-NEXT: call void @is_mixed_helper() +; CHECK-NEXT: call void @__kmpc_target_deinit(%struct.ident_t* null, i1 true, i1 false) +; CHECK-NEXT: ret void +; + %i = call i32 @__kmpc_target_init(%struct.ident_t* null, i1 true, i1 false, i1 false) + call void @is_spmd_helper1() + call void @is_spmd_helper2() + call void @is_mixed_helper() + call void @__kmpc_target_deinit(%struct.ident_t* null, i1 true, i1 false) + ret void +} + +define weak void @will_be_spmd() { +; CHECK-LABEL: define {{[^@]+}}@will_be_spmd() { +; CHECK-NEXT: [[I:%.*]] = call i32 @__kmpc_target_init(%struct.ident_t* null, i1 true, i1 false, i1 false) +; CHECK-NEXT: call void @is_spmd_helper2() +; CHECK-NEXT: call void @__kmpc_target_deinit(%struct.ident_t* null, i1 true, i1 false) +; CHECK-NEXT: ret void +; + %i = call i32 @__kmpc_target_init(%struct.ident_t* null, i1 false, i1 false, i1 false) + call void @is_spmd_helper2() + call void @__kmpc_target_deinit(%struct.ident_t* null, i1 false, i1 false) + ret void +} + +define weak void @non_spmd() { +; CHECK-LABEL: define {{[^@]+}}@non_spmd() { +; CHECK-NEXT: [[I:%.*]] = call i32 @__kmpc_target_init(%struct.ident_t* null, i1 false, i1 false, i1 false) +; CHECK-NEXT: call void @is_generic_helper1() +; CHECK-NEXT: call void @is_generic_helper2() +; CHECK-NEXT: call void @is_mixed_helper() +; CHECK-NEXT: call void @__kmpc_target_deinit(%struct.ident_t* null, i1 false, i1 false) +; CHECK-NEXT: ret void +; + %i = call i32 @__kmpc_target_init(%struct.ident_t* null, i1 false, i1 false, i1 false) + call void @is_generic_helper1() + call void @is_generic_helper2() + call void @is_mixed_helper() + call void @__kmpc_target_deinit(%struct.ident_t* null, i1 false, i1 false) + ret void +} + 
+define weak void @will_not_be_spmd() { +; CHECK-LABEL: define {{[^@]+}}@will_not_be_spmd() { +; CHECK-NEXT: [[I:%.*]] = call i32 @__kmpc_target_init(%struct.ident_t* null, i1 false, i1 false, i1 false) +; CHECK-NEXT: call void @is_generic_helper1() +; CHECK-NEXT: call void @is_generic_helper2() +; CHECK-NEXT: call void @is_mixed_helper() +; CHECK-NEXT: call void @__kmpc_target_deinit(%struct.ident_t* null, i1 false, i1 false) +; CHECK-NEXT: ret void +; + %i = call i32 @__kmpc_target_init(%struct.ident_t* null, i1 false, i1 false, i1 false) + call void @is_generic_helper1() + call void @is_generic_helper2() + call void @is_mixed_helper() + call void @__kmpc_target_deinit(%struct.ident_t* null, i1 false, i1 false) + ret void +} + +define internal void @is_spmd_helper1() { +; CHECK-LABEL: define {{[^@]+}}@is_spmd_helper1() { +; CHECK-NEXT: store i8 1, i8* @G, align 1 +; CHECK-NEXT: ret void +; + %isSPMD = call i8 @__kmpc_is_spmd_exec_mode() + store i8 %isSPMD, i8* @G + ret void +} + +define internal void @is_spmd_helper2() { +; CHECK-LABEL: define {{[^@]+}}@is_spmd_helper2() { +; CHECK-NEXT: br label [[F:%.*]] +; CHECK: t: +; CHECK-NEXT: unreachable +; CHECK: f: +; CHECK-NEXT: ret void +; + %isSPMD = call i8 @__kmpc_is_spmd_exec_mode() + %c = icmp eq i8 %isSPMD, 0 + br i1 %c, label %t, label %f +t: + call void @spmd_compatible() + ret void +f: + ret void +} + +define internal void @is_generic_helper1() { +; CHECK-LABEL: define {{[^@]+}}@is_generic_helper1() { +; CHECK-NEXT: store i8 0, i8* @G, align 1 +; CHECK-NEXT: ret void +; + %isSPMD = call i8 @__kmpc_is_spmd_exec_mode() + store i8 %isSPMD, i8* @G + ret void +} + +define internal void @is_generic_helper2() { +; CHECK-LABEL: define {{[^@]+}}@is_generic_helper2() { +; CHECK-NEXT: br label [[T:%.*]] +; CHECK: t: +; CHECK-NEXT: call void @foo() +; CHECK-NEXT: ret void +; CHECK: f: +; CHECK-NEXT: unreachable +; + %isSPMD = call i8 @__kmpc_is_spmd_exec_mode() + %c = icmp eq i8 %isSPMD, 0 + br i1 %c, label %t, label %f 
+t: + call void @foo() + ret void +f: + call void @bar() + ret void +} + +define internal void @is_mixed_helper() { +; CHECK-LABEL: define {{[^@]+}}@is_mixed_helper() { +; CHECK-NEXT: [[ISSPMD:%.*]] = call i8 @__kmpc_is_spmd_exec_mode() +; CHECK-NEXT: store i8 [[ISSPMD]], i8* @G, align 1 +; CHECK-NEXT: ret void +; + %isSPMD = call i8 @__kmpc_is_spmd_exec_mode() + store i8 %isSPMD, i8* @G + ret void +} + +declare void @spmd_compatible() "llvm.assume"="ompx_spmd_amenable" +declare i8 @__kmpc_is_spmd_exec_mode() +declare i32 @__kmpc_target_init(%struct.ident_t*, i1 zeroext, i1 zeroext, i1 zeroext) #1 +declare void @__kmpc_target_deinit(%struct.ident_t* nocapture readnone, i1 zeroext, i1 zeroext) #1 +declare void @foo() +declare void @bar() + +!llvm.module.flags = !{!0, !1} +!nvvm.annotations = !{!2, !3, !4, !5} + +!0 = !{i32 7, !"openmp", i32 50} +!1 = !{i32 7, !"openmp-device", i32 50} +!2 = !{void ()* @is_spmd, !"kernel", i32 1} +!3 = !{void ()* @will_be_spmd, !"kernel", i32 1} +!4 = !{void ()* @non_spmd, !"kernel", i32 1} +!5 = !{void ()* @will_not_be_spmd, !"kernel", i32 1} +;. +; CHECK: attributes #[[ATTR0:[0-9]+]] = { "llvm.assume"="ompx_spmd_amenable" } +;. +; CHECK: [[META0:![0-9]+]] = !{i32 7, !"openmp", i32 50} +; CHECK: [[META1:![0-9]+]] = !{i32 7, !"openmp-device", i32 50} +; CHECK: [[META2:![0-9]+]] = !{void ()* @is_spmd, !"kernel", i32 1} +; CHECK: [[META3:![0-9]+]] = !{void ()* @will_be_spmd, !"kernel", i32 1} +; CHECK: [[META4:![0-9]+]] = !{void ()* @non_spmd, !"kernel", i32 1} +; CHECK: [[META5:![0-9]+]] = !{void ()* @will_not_be_spmd, !"kernel", i32 1} +;.