Changeset View
Standalone View
llvm/lib/Transforms/IPO/OpenMPOpt.cpp
Show First 20 Lines • Show All 1,827 Lines • ▼ Show 20 Lines | bool runAttributor(bool IsModulePass) { | |||||||||
ChangeStatus Changed = A.run(); | ChangeStatus Changed = A.run(); | |||||||||
LLVM_DEBUG(dbgs() << "[Attributor] Done with " << SCC.size() | LLVM_DEBUG(dbgs() << "[Attributor] Done with " << SCC.size() | |||||||||
<< " functions, result: " << Changed << ".\n"); | << " functions, result: " << Changed << ".\n"); | |||||||||
return Changed == ChangeStatus::CHANGED; | return Changed == ChangeStatus::CHANGED; | |||||||||
} | } | |||||||||
void registerFoldRuntimeCall(RuntimeFunction RF); | ||||||||||
jdoerfert: copy and paste. | ||||||||||
/// Populate the Attributor with abstract attribute opportunities in the | /// Populate the Attributor with abstract attribute opportunities in the | |||||||||
/// function. | /// function. | |||||||||
void registerAAs(bool IsModulePass); | void registerAAs(bool IsModulePass); | |||||||||
}; | }; | |||||||||
Kernel OpenMPOpt::getUniqueKernelFor(Function &F) { | Kernel OpenMPOpt::getUniqueKernelFor(Function &F) { | |||||||||
if (!OMPInfoCache.ModuleSlice.count(&F)) | if (!OMPInfoCache.ModuleSlice.count(&F)) | |||||||||
return nullptr; | return nullptr; | |||||||||
▲ Show 20 Lines • Show All 1,657 Lines • ▼ Show 20 Lines | void initialize(Attributor &A) override { | |||||||||
const unsigned int WrapperFunctionArgNo = 6; | const unsigned int WrapperFunctionArgNo = 6; | |||||||||
RuntimeFunction RF = It->getSecond(); | RuntimeFunction RF = It->getSecond(); | |||||||||
switch (RF) { | switch (RF) { | |||||||||
// All the functions we know are compatible with SPMD mode. | // All the functions we know are compatible with SPMD mode. | |||||||||
case OMPRTL___kmpc_is_spmd_exec_mode: | case OMPRTL___kmpc_is_spmd_exec_mode: | |||||||||
case OMPRTL___kmpc_for_static_fini: | case OMPRTL___kmpc_for_static_fini: | |||||||||
case OMPRTL___kmpc_global_thread_num: | case OMPRTL___kmpc_global_thread_num: | |||||||||
case OMPRTL___kmpc_get_hardware_num_threads_in_block: | ||||||||||
case OMPRTL___kmpc_get_hardware_num_blocks: | ||||||||||
case OMPRTL___kmpc_single: | case OMPRTL___kmpc_single: | |||||||||
case OMPRTL___kmpc_end_single: | case OMPRTL___kmpc_end_single: | |||||||||
case OMPRTL___kmpc_master: | case OMPRTL___kmpc_master: | |||||||||
case OMPRTL___kmpc_end_master: | case OMPRTL___kmpc_end_master: | |||||||||
case OMPRTL___kmpc_barrier: | case OMPRTL___kmpc_barrier: | |||||||||
break; | break; | |||||||||
case OMPRTL___kmpc_for_static_init_4: | case OMPRTL___kmpc_for_static_init_4: | |||||||||
case OMPRTL___kmpc_for_static_init_4u: | case OMPRTL___kmpc_for_static_init_4u: | |||||||||
▲ Show 20 Lines • Show All 188 Lines • ▼ Show 20 Lines | A.registerSimplificationCallback( | |||||||||
A.recordDependence(*this, *AA, DepClassTy::OPTIONAL); | A.recordDependence(*this, *AA, DepClassTy::OPTIONAL); | |||||||||
} | } | |||||||||
return SimplifiedValue; | return SimplifiedValue; | |||||||||
}); | }); | |||||||||
} | } | |||||||||
ChangeStatus updateImpl(Attributor &A) override { | ChangeStatus updateImpl(Attributor &A) override { | |||||||||
ChangeStatus Changed = ChangeStatus::UNCHANGED; | ChangeStatus Changed = ChangeStatus::UNCHANGED; | |||||||||
switch (RFKind) { | switch (RFKind) { | |||||||||
case OMPRTL___kmpc_is_spmd_exec_mode: | case OMPRTL___kmpc_is_spmd_exec_mode: | |||||||||
Changed |= foldIsSPMDExecMode(A); | Changed |= foldIsSPMDExecMode(A); | |||||||||
break; | break; | |||||||||
case OMPRTL___kmpc_is_generic_main_thread_id: | case OMPRTL___kmpc_is_generic_main_thread_id: | |||||||||
Changed |= foldIsGenericMainThread(A); | Changed |= foldIsGenericMainThread(A); | |||||||||
break; | break; | |||||||||
case OMPRTL___kmpc_parallel_level: | case OMPRTL___kmpc_parallel_level: | |||||||||
Changed |= foldParallelLevel(A); | Changed |= foldParallelLevel(A); | |||||||||
break; | break; | |||||||||
case OMPRTL___kmpc_get_hardware_num_threads_in_block: | ||||||||||
Changed = Changed | foldKernelFnAttribute(A, "omp_target_thread_limit"); | ||||||||||
break; | ||||||||||
case OMPRTL___kmpc_get_hardware_num_blocks: | ||||||||||
Changed = Changed | foldKernelFnAttribute(A, "omp_target_num_teams"); | ||||||||||
break; | ||||||||||
default: | default: | |||||||||
llvm_unreachable("Unhandled OpenMP runtime function!"); | llvm_unreachable("Unhandled OpenMP runtime function!"); | |||||||||
} | } | |||||||||
return Changed; | return Changed; | |||||||||
} | } | |||||||||
ChangeStatus manifest(Attributor &A) override { | ChangeStatus manifest(Attributor &A) override { | |||||||||
▲ Show 20 Lines • Show All 99 Lines • ▼ Show 20 Lines | ChangeStatus foldIsGenericMainThread(Attributor &A) { | |||||||||
else | else | |||||||||
return indicatePessimisticFixpoint(); | return indicatePessimisticFixpoint(); | |||||||||
return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED | return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED | |||||||||
: ChangeStatus::CHANGED; | : ChangeStatus::CHANGED; | |||||||||
} | } | |||||||||
/// Fold __kmpc_parallel_level into a constant if possible. | /// Fold __kmpc_parallel_level into a constant if possible. | |||||||||
ChangeStatus foldParallelLevel(Attributor &A) { | ChangeStatus foldParallelLevel(Attributor &A) { | |||||||||
jdoerfert: | ||||||||||
Optional<Value *> SimplifiedValueBefore = SimplifiedValue; | Optional<Value *> SimplifiedValueBefore = SimplifiedValue; | |||||||||
auto &CallerKernelInfoAA = A.getAAFor<AAKernelInfo>( | auto &CallerKernelInfoAA = A.getAAFor<AAKernelInfo>( | |||||||||
*this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED); | *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED); | |||||||||
if (!CallerKernelInfoAA.ParallelLevels.isValidState()) | if (!CallerKernelInfoAA.ParallelLevels.isValidState()) | |||||||||
return indicatePessimisticFixpoint(); | return indicatePessimisticFixpoint(); | |||||||||
if (!CallerKernelInfoAA.ReachingKernelEntries.isValidState()) | if (!CallerKernelInfoAA.ReachingKernelEntries.isValidState()) | |||||||||
jdoerfert: ```
if (!CallerKernelInfoAA.ReachingKernelEntries.isValid())
  return indicatePessimisticFixpoint();
```
Also below. | ||||||||||
return indicatePessimisticFixpoint(); | return indicatePessimisticFixpoint(); | |||||||||
if (CallerKernelInfoAA.ReachingKernelEntries.empty()) { | if (CallerKernelInfoAA.ReachingKernelEntries.empty()) { | |||||||||
assert(!SimplifiedValue.hasValue() && | assert(!SimplifiedValue.hasValue() && | |||||||||
"SimplifiedValue should keep none at this point"); | "SimplifiedValue should keep none at this point"); | |||||||||
return ChangeStatus::UNCHANGED; | return ChangeStatus::UNCHANGED; | |||||||||
} | } | |||||||||
unsigned AssumedSPMDCount = 0, KnownSPMDCount = 0; | unsigned AssumedSPMDCount = 0, KnownSPMDCount = 0; | |||||||||
unsigned AssumedNonSPMDCount = 0, KnownNonSPMDCount = 0; | unsigned AssumedNonSPMDCount = 0, KnownNonSPMDCount = 0; | |||||||||
for (Kernel K : CallerKernelInfoAA.ReachingKernelEntries) { | for (Kernel K : CallerKernelInfoAA.ReachingKernelEntries) { | |||||||||
jdoerfert: Use an early exit instead:
if (!...)
  return ...
| ||||||||||
auto &AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K), | auto &AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K), | |||||||||
DepClassTy::REQUIRED); | DepClassTy::REQUIRED); | |||||||||
jdoerfert: You should set CurrentAttrValue = NextAttrValue at the end of the loop. | ||||||||||
if (!AA.SPMDCompatibilityTracker.isValidState()) | if (!AA.SPMDCompatibilityTracker.isValidState()) | |||||||||
return indicatePessimisticFixpoint(); | return indicatePessimisticFixpoint(); | |||||||||
if (AA.SPMDCompatibilityTracker.isAssumed()) { | if (AA.SPMDCompatibilityTracker.isAssumed()) { | |||||||||
if (AA.SPMDCompatibilityTracker.isAtFixpoint()) | if (AA.SPMDCompatibilityTracker.isAtFixpoint()) | |||||||||
jdoerfert: | ||||||||||
++KnownSPMDCount; | ++KnownSPMDCount; | |||||||||
else | else | |||||||||
jdoerfert: No reaching kernels is fine; keep it at none (which is the default) and just return UNCHANGED. | ||||||||||
++AssumedSPMDCount; | ++AssumedSPMDCount; | |||||||||
jdoerfert: Make this function take the string attribute so that we can have a single function instead of two. | ||||||||||
} else { | } else { | |||||||||
if (AA.SPMDCompatibilityTracker.isAtFixpoint()) | if (AA.SPMDCompatibilityTracker.isAtFixpoint()) | |||||||||
++KnownNonSPMDCount; | ++KnownNonSPMDCount; | |||||||||
else | else | |||||||||
++AssumedNonSPMDCount; | ++AssumedNonSPMDCount; | |||||||||
} | } | |||||||||
} | } | |||||||||
Show All 9 Lines | if (AssumedSPMDCount || KnownSPMDCount) { | |||||||||
assert(KnownNonSPMDCount == 0 && AssumedNonSPMDCount == 0 && | assert(KnownNonSPMDCount == 0 && AssumedNonSPMDCount == 0 && | |||||||||
"Expected only SPMD kernels!"); | "Expected only SPMD kernels!"); | |||||||||
SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), 1); | SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), 1); | |||||||||
} else { | } else { | |||||||||
assert(KnownSPMDCount == 0 && AssumedSPMDCount == 0 && | assert(KnownSPMDCount == 0 && AssumedSPMDCount == 0 && | |||||||||
"Expected only non-SPMD kernels!"); | "Expected only non-SPMD kernels!"); | |||||||||
SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), 0); | SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), 0); | |||||||||
} | } | |||||||||
return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED | ||||||||||
: ChangeStatus::CHANGED; | ||||||||||
} | ||||||||||
ChangeStatus foldKernelFnAttribute(Attributor &A, llvm::StringRef Attr) { | ||||||||||
// Specialize only if all the calls agree with the attribute constant value | ||||||||||
int32_t CurrentAttrValue = -1; | ||||||||||
Optional<Value *> SimplifiedValueBefore = SimplifiedValue; | ||||||||||
auto &CallerKernelInfoAA = A.getAAFor<AAKernelInfo>( | ||||||||||
*this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED); | ||||||||||
if (!CallerKernelInfoAA.ReachingKernelEntries.isValidState()) | ||||||||||
return indicatePessimisticFixpoint(); | ||||||||||
// Iterate over the kernels that reach this function | ||||||||||
for (Kernel K : CallerKernelInfoAA.ReachingKernelEntries) { | ||||||||||
int32_t NextAttrVal = -1; | ||||||||||
if (K->hasFnAttribute(Attr)) | ||||||||||
NextAttrVal = | ||||||||||
std::stoi(K->getFnAttribute(Attr).getValueAsString().str()); | ||||||||||
if (NextAttrVal == -1 || | ||||||||||
(CurrentAttrValue != -1 && CurrentAttrValue != NextAttrVal)) | ||||||||||
return indicatePessimisticFixpoint(); | ||||||||||
CurrentAttrValue = NextAttrVal; | ||||||||||
} | ||||||||||
if (CurrentAttrValue != -1) { | ||||||||||
auto &Ctx = getAnchorValue().getContext(); | ||||||||||
SimplifiedValue = | ||||||||||
ConstantInt::get(Type::getInt32Ty(Ctx), CurrentAttrValue); | ||||||||||
} | ||||||||||
return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED | return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED | |||||||||
: ChangeStatus::CHANGED; | : ChangeStatus::CHANGED; | |||||||||
} | } | |||||||||
/// An optional value the associated value is assumed to fold to. That is, we | /// An optional value the associated value is assumed to fold to. That is, we | |||||||||
JonChesterfield: I haven't read through this part, but if we can only fold it to a constant sometimes, we shouldn't mark the calls that survive noinline, as that'll be expensive for the cases that this pass misses. | ||||||||||
jdoerfert: We are about to remove noinline from known runtime functions such that we can keep them around until we get to OpenMP-Opt as calls. This will have the effect we want without any drawbacks. Thus, adding noinline in the runtime will be totally fine. | ||||||||||
/// assume the associated value (which is a call) can be replaced by this | /// assume the associated value (which is a call) can be replaced by this | |||||||||
/// simplified value. | /// simplified value. | |||||||||
Optional<Value *> SimplifiedValue; | Optional<Value *> SimplifiedValue; | |||||||||
/// The runtime function kind of the callee of the associated call site. | /// The runtime function kind of the callee of the associated call site. | |||||||||
RuntimeFunction RFKind; | RuntimeFunction RFKind; | |||||||||
}; | }; | |||||||||
} // namespace | } // namespace | |||||||||
/// Register folding callsite | ||||||||||
void OpenMPOpt::registerFoldRuntimeCall(RuntimeFunction RF) { | ||||||||||
auto &RFI = OMPInfoCache.RFIs[RF]; | ||||||||||
RFI.foreachUse(SCC, [&](Use &U, Function &F) { | ||||||||||
CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &RFI); | ||||||||||
if (!CI) | ||||||||||
jdoerfert: Leftover. (You can also just pipe an IRPosition into errs().) | ||||||||||
return false; | ||||||||||
A.getOrCreateAAFor<AAFoldRuntimeCall>( | ||||||||||
tianshilei1992: No need to have `else` here because the code above already returns. | ||||||||||
IRPosition::callsite_returned(*CI), /* QueryingAA */ nullptr, | ||||||||||
tianshilei1992: This part needs to be changed. Refer to the trunk for more details. Basically it should be `IRPosition::callsite_returned(*CI)`. | ||||||||||
DepClassTy::NONE, /* ForceUpdate */ false, | ||||||||||
/* UpdateAfterInit */ false); | ||||||||||
tianshilei1992: No need for `else` here as well. | ||||||||||
return false; | ||||||||||
tianshilei1992: Directly indicate pessimistic state because we don't know clearly what the number is. | ||||||||||
}); | ||||||||||
} | ||||||||||
void OpenMPOpt::registerAAs(bool IsModulePass) { | void OpenMPOpt::registerAAs(bool IsModulePass) { | |||||||||
if (SCC.empty()) | if (SCC.empty()) | |||||||||
return; | return; | |||||||||
if (IsModulePass) { | if (IsModulePass) { | |||||||||
// Ensure we create the AAKernelInfo AAs first and without triggering an | // Ensure we create the AAKernelInfo AAs first and without triggering an | |||||||||
// update. This will make sure we register all value simplification | // update. This will make sure we register all value simplification | |||||||||
tianshilei1992: This has to be updated as `SimplifiedValue` is not part of `BooleanState`. Refer to `foldIsSPMDExecMode`. | ||||||||||
// callbacks before any other AA has the chance to create an AAValueSimplify | // callbacks before any other AA has the chance to create an AAValueSimplify | |||||||||
// or similar. | // or similar. | |||||||||
for (Function *Kernel : OMPInfoCache.Kernels) | for (Function *Kernel : OMPInfoCache.Kernels) | |||||||||
A.getOrCreateAAFor<AAKernelInfo>( | A.getOrCreateAAFor<AAKernelInfo>( | |||||||||
IRPosition::function(*Kernel), /* QueryingAA */ nullptr, | IRPosition::function(*Kernel), /* QueryingAA */ nullptr, | |||||||||
DepClassTy::NONE, /* ForceUpdate */ false, | DepClassTy::NONE, /* ForceUpdate */ false, | |||||||||
tianshilei1992: Update accordingly. | ||||||||||
/* UpdateAfterInit */ false); | /* UpdateAfterInit */ false); | |||||||||
auto &IsMainRFI = | ||||||||||
OMPInfoCache.RFIs[OMPRTL___kmpc_is_generic_main_thread_id]; | ||||||||||
IsMainRFI.foreachUse(SCC, [&](Use &U, Function &F) { | ||||||||||
CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &IsMainRFI); | ||||||||||
if (!CI) | ||||||||||
return false; | ||||||||||
A.getOrCreateAAFor<AAFoldRuntimeCall>( | ||||||||||
IRPosition::callsite_returned(*CI), /* QueryingAA */ nullptr, | ||||||||||
DepClassTy::NONE, /* ForceUpdate */ false, | ||||||||||
/* UpdateAfterInit */ false); | ||||||||||
return false; | ||||||||||
}); | ||||||||||
auto &IsSPMDRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_is_spmd_exec_mode]; | registerFoldRuntimeCall(OMPRTL___kmpc_is_generic_main_thread_id); | |||||||||
IsSPMDRFI.foreachUse(SCC, [&](Use &U, Function &) { | registerFoldRuntimeCall(OMPRTL___kmpc_is_spmd_exec_mode); | |||||||||
CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &IsSPMDRFI); | registerFoldRuntimeCall(OMPRTL___kmpc_parallel_level); | |||||||||
if (!CI) | registerFoldRuntimeCall(OMPRTL___kmpc_get_hardware_num_threads_in_block); | |||||||||
return false; | registerFoldRuntimeCall(OMPRTL___kmpc_get_hardware_num_blocks); | |||||||||
A.getOrCreateAAFor<AAFoldRuntimeCall>( | ||||||||||
IRPosition::callsite_returned(*CI), /* QueryingAA */ nullptr, | ||||||||||
DepClassTy::NONE, /* ForceUpdate */ false, | ||||||||||
/* UpdateAfterInit */ false); | ||||||||||
return false; | ||||||||||
}); | ||||||||||
auto &ParallelLevelRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_level]; | ||||||||||
ParallelLevelRFI.foreachUse(SCC, [&](Use &U, Function &) { | ||||||||||
CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &ParallelLevelRFI); | ||||||||||
if (!CI) | ||||||||||
return false; | ||||||||||
A.getOrCreateAAFor<AAFoldRuntimeCall>( | ||||||||||
IRPosition::callsite_returned(*CI), /* QueryingAA */ nullptr, | ||||||||||
DepClassTy::NONE, /* ForceUpdate */ false, | ||||||||||
/* UpdateAfterInit */ false); | ||||||||||
return false; | ||||||||||
}); | ||||||||||
} | } | |||||||||
// Create CallSite AA for all Getters. | // Create CallSite AA for all Getters. | |||||||||
for (int Idx = 0; Idx < OMPInfoCache.ICVs.size() - 1; ++Idx) { | for (int Idx = 0; Idx < OMPInfoCache.ICVs.size() - 1; ++Idx) { | |||||||||
auto ICVInfo = OMPInfoCache.ICVs[static_cast<InternalControlVar>(Idx)]; | auto ICVInfo = OMPInfoCache.ICVs[static_cast<InternalControlVar>(Idx)]; | |||||||||
auto &GetterRFI = OMPInfoCache.RFIs[ICVInfo.Getter]; | auto &GetterRFI = OMPInfoCache.RFIs[ICVInfo.Getter]; | |||||||||
▲ Show 20 Lines • Show All 408 Lines • Show Last 20 Lines |
copy and paste.