diff --git a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
--- a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
+++ b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
@@ -1170,6 +1170,14 @@
   virtual Value *getReplacementValue(InternalControlVar ICV,
                                      const Instruction *I, Attributor &A) = 0;
 
+  /// Check if any value was tracked.
+  virtual bool hasTrackedValue(InternalControlVar &ICV) const { return false; }
+
+  /// Check if any call potentially changes the ICV.
+  virtual bool hasUnknownCall(InternalControlVar &ICV, Attributor &A) const {
+    return false;
+  }
+
   /// See AbstractAttribute::getName()
   const std::string getName() const override { return "AAICVTracker"; }
 
@@ -1216,7 +1224,7 @@
       auto ReplaceAndDeleteCB = [&](Use &U, Function &Caller) {
         CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &GetterRFI);
         Instruction *UserI = cast<Instruction>(U.getUser());
-        Value *ReplVal = getReplacementValue(ICV, UserI, A);
+        Value *ReplVal = ICVReplacementValuesMap[ICV].lookup(UserI);
         if (!ReplVal || !CI)
           return false;
 
@@ -1237,6 +1245,30 @@
                   InternalControlVar::ICV___last>
       ICVValuesMap;
 
+  // Map of ICV to their values at specific program point.
+  EnumeratedArray<DenseMap<Instruction *, Value *>, InternalControlVar,
+                  InternalControlVar::ICV___last>
+      ICVReplacementValuesMap;
+
+  /// See AAICVTracker::hasTrackedValue(). True once ICVValuesMap holds at
+  /// least one entry for \p ICV.
+  bool hasTrackedValue(InternalControlVar &ICV) const override {
+    return !ICVValuesMap[ICV].empty();
+  }
+
+  /// See AAICVTracker::hasUnknownCall(). Walks every instruction of the anchor
+  /// scope and asks callChangesICV() about it.
+  /// NOTE(review): callChangesICV() may query the callee's tracker, which ends
+  /// up back in hasUnknownCall(); for recursive call chains this looks like
+  /// unbounded recursion. Confirm and add a cycle guard if so.
+  bool hasUnknownCall(InternalControlVar &ICV, Attributor &A) const override {
+    for (BasicBlock &BB : *getAnchorScope())
+      for (Instruction &I : BB)
+        if (callChangesICV(A, &I, ICV))
+          return true;
+    return false;
+  }
+
   // Currently only nthreads is being tracked.
   // this array will only grow with time.
   InternalControlVar TrackableICVs[1] = {ICV_nthreads};
@@ -1250,6 +1282,7 @@
 
     for (InternalControlVar ICV : TrackableICVs) {
       auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter];
+      auto &GetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Getter];
 
       auto TrackValues = [&](Use &U, Function &) {
         CallInst *CI = OpenMPOpt::getCallIfRegularCall(U);
@@ -1264,20 +1297,85 @@
         return false;
       };
 
+      // Map replacement values to Getters and their uses.
+      auto MapReplacementValues = [&](Use &U, Function &) {
+        CallInst *CI = OpenMPOpt::getCallIfRegularCall(U);
+        if (!CI)
+          return false;
+
+        Value *ReplVal = getReplacementValue(ICV, CI, A);
+        if (ICVReplacementValuesMap[ICV]
+                .insert(std::make_pair(CI, ReplVal))
+                .second)
+          HasChanged = ChangeStatus::CHANGED;
+
+        assert((!ICVReplacementValuesMap[ICV].lookup(CI) ||
+                ICVReplacementValuesMap[ICV].lookup(CI) == ReplVal) &&
+               "Getter should be either not mapped or mapped to ReplVal");
+
+        return false;
+      };
+
       SetterRFI.foreachUse(TrackValues, F);
+      GetterRFI.foreachUse(MapReplacementValues, F);
     }
 
     return HasChanged;
   }
 
+  /// Check if \p I is a call and whether it changes the ICV.
+  ///
+  /// Non-call instructions and calls to the ICV getter never change it. Calls
+  /// to the setter, indirect calls, and calls to declarations are assumed to
+  /// change it. For a defined callee the answer comes from its AAICVTracker.
+  bool callChangesICV(Attributor &A, const Instruction *I,
+                      InternalControlVar &ICV) const {
+    const auto *CB = dyn_cast<CallBase>(I);
+    if (!CB)
+      return false;
+
+    // getCalledFunction() is null for indirect calls. Nothing is known about
+    // the callee then, so bail out before comparing against (or dereferencing)
+    // a null function.
+    Function *CalledFunction = CB->getCalledFunction();
+    if (!CalledFunction)
+      return true;
+
+    auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
+    auto &GetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Getter];
+    auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter];
+
+    if (CalledFunction == GetterRFI.Declaration)
+      return false;
+    if (CalledFunction == SetterRFI.Declaration)
+      return true;
+
+    // Since we don't know, assume it changes the ICV.
+    if (CalledFunction->isDeclaration())
+      return true;
+
+    const auto &ICVTrackingAA =
+        A.getAAFor<AAICVTracker>(*this, IRPosition::function(*CalledFunction));
+
+    // A callee that has a tracked value of its own, or that contains a call
+    // which might change the ICV, changes it for its caller as well.
+    if (ICVTrackingAA.isAssumedTracked())
+      return ICVTrackingAA.hasTrackedValue(ICV) ||
+             ICVTrackingAA.hasUnknownCall(ICV, A);
+
+    return true;
+  }
+
   /// Return the value with which \p I can be replaced for specific \p ICV.
   Value *getReplacementValue(InternalControlVar ICV, const Instruction *I,
                              Attributor &A) override {
+    // Reuse the value recorded for \p I by an earlier query, if any.
+    if (ICVReplacementValuesMap[ICV].count(I))
+      return ICVReplacementValuesMap[ICV].lookup(I);
+
     const BasicBlock *CurrBB = I->getParent();
 
     auto &ValuesSet = ICVValuesMap[ICV];
-    auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
-    auto &GetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Getter];
 
     for (const auto &ICVVal : ValuesSet) {
       if (CurrBB == ICVVal.Inst->getParent()) {
@@ -1286,19 +1384,65 @@
 
         // both instructions are in the same BB and at \p I we know the ICV
         // value.
-        while (I != ICVVal.Inst) {
+        const Instruction *CurrInst = I;
+        while (CurrInst != ICVVal.Inst) {
           // we don't yet know if a call might update an ICV.
           // TODO: check callsite AA for value.
-          if (const auto *CB = dyn_cast<CallBase>(I))
-            if (CB->getCalledFunction() != GetterRFI.Declaration)
-              return nullptr;
+          if (callChangesICV(A, CurrInst, ICV))
+            return nullptr;
 
-          I = I->getPrevNode();
+          CurrInst = CurrInst->getPrevNode();
         }
 
         // No call in between, return the value.
         return ICVVal.TrackedValue;
       }
+
+      // Tracked in another block. Calls in \p CurrBB ahead of \p I run after
+      // every predecessor, so if one of them might change the ICV this tracked
+      // value cannot be relied on at \p I.
+      bool ChangedBeforeI = false;
+      for (const Instruction *PrevInst = I->getPrevNode();
+           PrevInst && !ChangedBeforeI; PrevInst = PrevInst->getPrevNode())
+        ChangedBeforeI = callChangesICV(A, PrevInst, ICV);
+      if (ChangedBeforeI)
+        continue;
+
+      auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
+      auto &Explorer = OMPInfoCache.getMustBeExecutedContextExplorer();
+
+      // Vet every predecessor before using the tracked value; returning on
+      // the first match would make the result depend on the order in which
+      // predecessors are visited.
+      // TODO: only direct predecessors are inspected; calls on longer paths
+      // between the setter and \p I are not.
+      bool TrackedInPred = false;
+      for (const BasicBlock *Pred : predecessors(CurrBB)) {
+        if (ICVVal.Inst->getParent() == Pred) {
+          if (!Explorer.findInContextOf(ICVVal.Inst, I))
+            return nullptr;
+
+          // If any of the calls after the tracked value, might change the
+          // ICV, we don't know the value. Walk to the end of the block so
+          // that a terminator which is itself a call is checked as well.
+          for (const Instruction *NextInst = ICVVal.Inst->getNextNode();
+               NextInst; NextInst = NextInst->getNextNode())
+            if (callChangesICV(A, NextInst, ICV))
+              return nullptr;
+
+          TrackedInPred = true;
+          continue;
+        }
+
+        // if any of the calls in a block that doesn't have tracked value might
+        // change the ICV, we don't know the value.
+        for (const Instruction &Inst : *Pred)
+          if (callChangesICV(A, &Inst, ICV))
+            return nullptr;
+      }
+
+      if (TrackedInPred)
+        return ICVVal.TrackedValue;
     }
 
     // No value was tracked.
@@ -1362,7 +1506,9 @@
   OMPInformationCache InfoCache(*(Functions.back()->getParent()), AG, Allocator,
                                 /*CGSCC*/ Functions, OMPInModule.getKernels());
 
-  Attributor A(Functions, InfoCache, CGUpdater);
+  SetVector<Function *> ModuleSlice(InfoCache.ModuleSlice.begin(),
+                                    InfoCache.ModuleSlice.end());
+  Attributor A(ModuleSlice, InfoCache, CGUpdater);
 
   // TODO: Compute the module slice we are allowed to look at.
   OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A);
@@ -1428,7 +1574,9 @@
         *(Functions.back()->getParent()), AG, Allocator,
         /*CGSCC*/ Functions, OMPInModule.getKernels());
 
-    Attributor A(Functions, InfoCache, CGUpdater);
+    SetVector<Function *> ModuleSlice(InfoCache.ModuleSlice.begin(),
+                                      InfoCache.ModuleSlice.end());
+    Attributor A(ModuleSlice, InfoCache, CGUpdater);
 
     // TODO: Compute the module slice we are allowed to look at.
     OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A);
diff --git a/llvm/test/Transforms/OpenMP/icv_tracking.ll b/llvm/test/Transforms/OpenMP/icv_tracking.ll
--- a/llvm/test/Transforms/OpenMP/icv_tracking.ll
+++ b/llvm/test/Transforms/OpenMP/icv_tracking.ll
@@ -7,6 +7,30 @@
 @.str = private unnamed_addr constant [23 x i8] c";unknown;unknown;0;0;;\00", align 1
 @0 = private unnamed_addr global %struct.ident_t { i32 0, i32 2, i32 0, i32 0, i8* getelementptr inbounds ([23 x i8], [23 x i8]* @.str, i32 0, i32 0) }, align 8
 
+; doesn't modify any ICVs.
+define i32 @icv_free_use(i32 %0) {
+; CHECK-LABEL: define {{[^@]+}}@icv_free_use
+; CHECK-SAME: (i32 [[TMP0:%.*]])
+; CHECK-NEXT:    [[TMP2:%.*]] = add nsw i32 [[TMP0]], 1
+; CHECK-NEXT:    ret i32 [[TMP2]]
+;
+  %2 = add nsw i32 %0, 1
+  ret i32 %2
+}
+
+; calls @use, which is expected to be treated as possibly modifying ICVs.
+define i32 @bad_use(i32 %0) {
+; CHECK-LABEL: define {{[^@]+}}@bad_use
+; CHECK-SAME: (i32 [[TMP0:%.*]])
+; CHECK-NEXT:    tail call void @use(i32 [[TMP0]])
+; CHECK-NEXT:    [[TMP2:%.*]] = add nsw i32 [[TMP0]], 1
+; CHECK-NEXT:    ret i32 [[TMP2]]
+;
+  tail call void @use(i32 %0)
+  %2 = add nsw i32 %0, 1
+  ret i32 %2
+}
+
 define dso_local i32 @foo(i32 %0, i32 %1) {
 ; CHECK-LABEL: define {{[^@]+}}@foo
 ; CHECK-SAME: (i32 [[TMP0:%.*]], i32 [[TMP1:%.*]])
@@ -105,5 +129,154 @@
   ret void
 }
 
+define dso_local i32 @bar1(i32 %0, i32 %1) {
+; CHECK-LABEL: define {{[^@]+}}@bar1
+; CHECK-SAME: (i32 [[TMP0:%.*]], i32 [[TMP1:%.*]])
+; CHECK-NEXT:    [[TMP3:%.*]] = icmp sgt i32 [[TMP0]], [[TMP1]]
+; CHECK-NEXT:    [[TMP4:%.*]] = select i1 [[TMP3]], i32 [[TMP0]], i32 [[TMP1]]
+; CHECK-NEXT:    tail call void @omp_set_num_threads(i32 [[TMP4]])
+; CHECK-NEXT:    tail call void (%struct.ident_t*, i32, void (i32*, i32*, ...)*, ...) @__kmpc_fork_call(%struct.ident_t* nonnull @0, i32 0, void (i32*, i32*, ...)* bitcast (void (i32*, i32*)* @.omp_outlined..2 to void (i32*, i32*, ...)*))
+; CHECK-NEXT:    [[TMP5:%.*]] = tail call i32 @omp_get_max_threads()
+; CHECK-NEXT:    tail call void @use(i32 [[TMP5]])
+; CHECK-NEXT:    ret i32 0
+;
+  %3 = icmp sgt i32 %0, %1
+  %4 = select i1 %3, i32 %0, i32 %1
+  tail call void @omp_set_num_threads(i32 %4)
+  %5 = tail call i32 @omp_get_max_threads()
+  tail call void (%struct.ident_t*, i32, void (i32*, i32*, ...)*, ...) @__kmpc_fork_call(%struct.ident_t* nonnull @0, i32 0, void (i32*, i32*, ...)* bitcast (void (i32*, i32*)* @.omp_outlined..2 to void (i32*, i32*, ...)*))
+  %6 = tail call i32 @omp_get_max_threads()
+  tail call void @use(i32 %6)
+  ret i32 0
+}
+
+define internal void @.omp_outlined..2(i32* %0, i32* %1) {
+; CHECK-LABEL: define {{[^@]+}}@.omp_outlined..2
+; CHECK-SAME: (i32* [[TMP0:%.*]], i32* [[TMP1:%.*]])
+; CHECK-NEXT:    [[TMP3:%.*]] = tail call i32 @omp_get_max_threads()
+; CHECK-NEXT:    [[TMP4:%.*]] = tail call i32 @icv_free_use(i32 [[TMP3]])
+; CHECK-NEXT:    tail call void @omp_set_num_threads(i32 10)
+; CHECK-NEXT:    [[TMP5:%.*]] = tail call i32 @icv_free_use(i32 10)
+; CHECK-NEXT:    [[TMP6:%.*]] = tail call i32 @icv_free_use(i32 10)
+; CHECK-NEXT:    ret void
+;
+  %3 = tail call i32 @omp_get_max_threads()
+  %4 = tail call i32 @icv_free_use(i32 %3)
+  tail call void @omp_set_num_threads(i32 10)
+  %5 = tail call i32 @omp_get_max_threads()
+  %6 = tail call i32 @icv_free_use(i32 %5)
+  %7 = tail call i32 @omp_get_max_threads()
+  %8 = tail call i32 @icv_free_use(i32 %7)
+  ret void
+}
+define void @test(i1 %0) {
+; CHECK-LABEL: define {{[^@]+}}@test
+; CHECK-SAME: (i1 [[TMP0:%.*]])
+; CHECK-NEXT:    call void @omp_set_num_threads(i32 2)
+; CHECK-NEXT:    [[TMP2:%.*]] = icmp eq i1 [[TMP0]], false
+; CHECK-NEXT:    br i1 [[TMP2]], label [[TMP4:%.*]], label [[TMP3:%.*]]
+; CHECK:       3:
+; CHECK-NEXT:    call void @use(i32 10)
+; CHECK-NEXT:    br label [[TMP4]]
+; CHECK:       4:
+; CHECK-NEXT:    [[TMP5:%.*]] = call i32 @omp_get_max_threads()
+; CHECK-NEXT:    call void @use(i32 [[TMP5]])
+; CHECK-NEXT:    ret void
+;
+  call void @omp_set_num_threads(i32 2)
+  %2 = icmp eq i1 %0, 0
+  br i1 %2, label %4, label %3
+
+3:                                                ; preds = %1
+  call void @use(i32 10)
+  br label %4
+
+4:                                                ; preds = %3, %1
+  %5 = call i32 @omp_get_max_threads()
+  call void @use(i32 %5)
+  ret void
+}
+
+define void @test1(i1 %0) {
+; CHECK-LABEL: define {{[^@]+}}@test1
+; CHECK-SAME: (i1 [[TMP0:%.*]])
+; CHECK-NEXT:    call void @omp_set_num_threads(i32 2)
+; CHECK-NEXT:    [[TMP2:%.*]] = icmp eq i1 [[TMP0]], false
+; CHECK-NEXT:    br i1 [[TMP2]], label [[TMP5:%.*]], label [[TMP3:%.*]]
+; CHECK:       3:
+; CHECK-NEXT:    [[TMP4:%.*]] = call i32 @icv_free_use(i32 10)
+; CHECK-NEXT:    br label [[TMP5]]
+; CHECK:       5:
+; CHECK-NEXT:    call void @use(i32 2)
+; CHECK-NEXT:    ret void
+;
+  call void @omp_set_num_threads(i32 2)
+  %2 = icmp eq i1 %0, 0
+  br i1 %2, label %5, label %3
+
+3:                                                ; preds = %1
+  %4 = call i32 @icv_free_use(i32 10)
+  br label %5
+
+5:                                                ; preds = %3, %1
+  %6 = call i32 @omp_get_max_threads()
+  call void @use(i32 %6)
+  ret void
+}
+
+define void @bad_use_test(i1 %0) {
+; CHECK-LABEL: define {{[^@]+}}@bad_use_test
+; CHECK-SAME: (i1 [[TMP0:%.*]])
+; CHECK-NEXT:    call void @omp_set_num_threads(i32 2)
+; CHECK-NEXT:    [[TMP2:%.*]] = icmp eq i1 [[TMP0]], false
+; CHECK-NEXT:    br i1 [[TMP2]], label [[TMP5:%.*]], label [[TMP3:%.*]]
+; CHECK:       3:
+; CHECK-NEXT:    [[TMP4:%.*]] = call i32 @bad_use(i32 10)
+; CHECK-NEXT:    br label [[TMP5]]
+; CHECK:       5:
+; CHECK-NEXT:    [[TMP6:%.*]] = call i32 @omp_get_max_threads()
+; CHECK-NEXT:    call void @use(i32 [[TMP6]])
+; CHECK-NEXT:    ret void
+;
+  call void @omp_set_num_threads(i32 2)
+  %2 = icmp eq i1 %0, 0
+  br i1 %2, label %5, label %3
+
+3:                                                ; preds = %1
+  %4 = call i32 @bad_use(i32 10)
+  br label %5
+
+5:                                                ; preds = %3, %1
+  %6 = call i32 @omp_get_max_threads()
+  call void @use(i32 %6)
+  ret void
+}
+
+define void @test2(i1 %0) {
+; CHECK-LABEL: define {{[^@]+}}@test2
+; CHECK-SAME: (i1 [[TMP0:%.*]])
+; CHECK-NEXT:    [[TMP2:%.*]] = icmp eq i1 [[TMP0]], false
+; CHECK-NEXT:    br i1 [[TMP2]], label [[TMP4:%.*]], label [[TMP3:%.*]]
+; CHECK:       3:
+; CHECK-NEXT:    call void @omp_set_num_threads(i32 4)
+; CHECK-NEXT:    br label [[TMP4]]
+; CHECK:       4:
+; CHECK-NEXT:    [[TMP5:%.*]] = call i32 @omp_get_max_threads()
+; CHECK-NEXT:    call void @use(i32 [[TMP5]])
+; CHECK-NEXT:    ret void
+;
+  %2 = icmp eq i1 %0, 0
+  br i1 %2, label %4, label %3
+
+3:                                                ; preds = %1
+  call void @omp_set_num_threads(i32 4)
+  br label %4
+
+4:                                                ; preds = %3, %1
+  %5 = call i32 @omp_get_max_threads()
+  call void @use(i32 %5)
+  ret void
+}
+
 !0 = !{!1}
 !1 = !{i64 2, i64 -1, i64 -1, i1 true}