diff --git a/llvm/include/llvm/IR/Assumptions.h b/llvm/include/llvm/IR/Assumptions.h
--- a/llvm/include/llvm/IR/Assumptions.h
+++ b/llvm/include/llvm/IR/Assumptions.h
@@ -21,6 +21,7 @@
 namespace llvm {
 
 class Function;
+class CallBase;
 
 /// The key we use for assumption attributes.
 constexpr StringRef AssumptionAttrKey = "llvm.assume";
@@ -45,6 +46,9 @@
 /// Return true if \p F has the assumption \p AssumptionStr attached.
 bool hasAssumption(Function &F, const KnownAssumptionString &AssumptionStr);
 
+/// Return true if \p CB has the assumption \p AssumptionStr attached.
+bool hasAssumption(CallBase &CB, const KnownAssumptionString &AssumptionStr);
+
 } // namespace llvm
 
 #endif
diff --git a/llvm/lib/IR/Assumptions.cpp b/llvm/lib/IR/Assumptions.cpp
--- a/llvm/lib/IR/Assumptions.cpp
+++ b/llvm/lib/IR/Assumptions.cpp
@@ -11,6 +11,7 @@
 #include "llvm/IR/Assumptions.h"
 #include "llvm/IR/Attributes.h"
 #include "llvm/IR/Function.h"
+#include "llvm/IR/InstrTypes.h"
 
 using namespace llvm;
 
@@ -29,6 +30,21 @@
   });
 }
 
+bool llvm::hasAssumption(CallBase &CB,
+                         const KnownAssumptionString &AssumptionStr) {
+  const Attribute &A = CB.getFnAttr(AssumptionAttrKey);
+  if (!A.isValid())
+    return false;
+  assert(A.isStringAttribute() && "Expected a string attribute!");
+
+  SmallVector<StringRef, 8> Strings;
+  A.getValueAsString().split(Strings, ",");
+
+  return llvm::any_of(Strings, [=](StringRef Assumption) {
+    return Assumption == AssumptionStr;
+  });
+}
+
 StringSet<> llvm::KnownAssumptionStrings({
     "omp_no_openmp",          // OpenMP 5.1
     "omp_no_openmp_routines", // OpenMP 5.1
diff --git a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
--- a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
+++ b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
@@ -3705,13 +3705,25 @@
     Function *Callee = getAssociatedFunction();
 
     // Helper to lookup an assumption string.
-    auto HasAssumption = [](Function *Fn, StringRef AssumptionStr) {
-      return Fn && hasAssumption(*Fn, AssumptionStr);
+    auto HasAssumption = [](Value *V, StringRef AssumptionStr) {
+      if (V) {
+        if (Function *Fn = dyn_cast<Function>(V))
+          return hasAssumption(*Fn, AssumptionStr);
+        if (CallBase *CB = dyn_cast<CallBase>(V))
+          return hasAssumption(*CB, AssumptionStr);
+      }
+      return false;
     };
 
     // Check for SPMD-mode assumptions.
-    if (HasAssumption(Callee, "ompx_spmd_amenable"))
+    if (HasAssumption(&CB, "ompx_spmd_amenable")) {
       SPMDCompatibilityTracker.indicateOptimisticFixpoint();
+      indicateOptimisticFixpoint();
+    }
+
+    if (HasAssumption(Callee, "ompx_spmd_amenable")) {
+      SPMDCompatibilityTracker.indicateOptimisticFixpoint();
+    }
 
     // First weed out calls we do not care about, that is readonly/readnone
     // calls, intrinsics, and "no_openmp" calls. Neither of these can reach a