diff --git a/llvm/include/llvm/IR/Assumptions.h b/llvm/include/llvm/IR/Assumptions.h
--- a/llvm/include/llvm/IR/Assumptions.h
+++ b/llvm/include/llvm/IR/Assumptions.h
@@ -15,6 +15,7 @@
 #ifndef LLVM_IR_ASSUMPTIONS_H
 #define LLVM_IR_ASSUMPTIONS_H
 
+#include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/ADT/StringSet.h"
 
@@ -44,11 +45,25 @@
 };
 
 /// Return true if \p F has the assumption \p AssumptionStr attached.
-bool hasAssumption(Function &F, const KnownAssumptionString &AssumptionStr);
+bool hasAssumption(const Function &F,
+                   const KnownAssumptionString &AssumptionStr);
 
 /// Return true if \p CB or the callee has the assumption \p AssumptionStr
 /// attached.
-bool hasAssumption(CallBase &CB, const KnownAssumptionString &AssumptionStr);
+bool hasAssumption(const CallBase &CB,
+                   const KnownAssumptionString &AssumptionStr);
+
+/// Return the set of all assumptions for the function \p F.
+DenseSet<StringRef> getAssumptions(const Function &F);
+
+/// Return the set of all assumptions for the call \p CB.
+DenseSet<StringRef> getAssumptions(const CallBase &CB);
+
+/// Appends the set of assumptions \p Assumptions to \p F.
+void addAssumptions(Function &F, const DenseSet<StringRef> &Assumptions);
+
+/// Appends the set of assumptions \p Assumptions to \p CB.
+void addAssumptions(CallBase &CB, const DenseSet<StringRef> &Assumptions);
 
 } // namespace llvm
 
diff --git a/llvm/include/llvm/Transforms/IPO/Attributor.h b/llvm/include/llvm/Transforms/IPO/Attributor.h
--- a/llvm/include/llvm/Transforms/IPO/Attributor.h
+++ b/llvm/include/llvm/Transforms/IPO/Attributor.h
@@ -101,6 +101,7 @@
 #include "llvm/ADT/GraphTraits.h"
 #include "llvm/ADT/MapVector.h"
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SetOperations.h"
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/Triple.h"
 #include "llvm/ADT/iterator.h"
@@ -4609,6 +4610,96 @@
   static const char ID;
 };
 
+/// An abstract attribute for getting assumption information.
+struct AAAssumptionInfo : public StateWrapper<BooleanState, AbstractAttribute> {
+  using Base = StateWrapper<BooleanState, AbstractAttribute>;
+
+  /// A wrapper around a set that has semantics for handling unions and
+  /// intersections with a "universal" set that contains all elements.
+  struct AssumptionSet {
+    /// Creates a universal set with no concrete elements.
+    AssumptionSet() : Universal(true) {}
+
+    /// Creates a non-universal set with concrete values.
+    explicit AssumptionSet(const DenseSet<StringRef> &Assumptions)
+        : Universal(false), Assumptions(Assumptions) {}
+
+    const DenseSet<StringRef> &getAssumptions() const { return Assumptions; }
+
+    bool isUniversal() const { return Universal; }
+
+    bool empty() const { return Assumptions.empty() && !Universal; }
+
+    /// Finds A := A ^ B where A or B could be the "Universal" set which
+    /// contains every possible attribute. Returns true if changes were made.
+    bool getIntersection(const AssumptionSet &NewAssumptions) {
+      bool IsUniversal = Universal;
+      unsigned Size = Assumptions.size();
+
+      // A := A ^ U = A
+      if (NewAssumptions.isUniversal())
+        return false;
+
+      // A := U ^ B = B
+      if (Universal)
+        Assumptions = NewAssumptions.getAssumptions();
+      else
+        set_intersect(Assumptions, NewAssumptions.getAssumptions());
+
+      Universal &= NewAssumptions.isUniversal();
+      return IsUniversal != Universal || Size != Assumptions.size();
+    }
+
+    /// Finds A := A u B where A or B could be the "Universal" set which
+    /// contains every possible attribute. Returns true if changes were made.
+    bool getUnion(const AssumptionSet &NewAssumptions) {
+      bool IsUniversal = Universal;
+      unsigned Size = Assumptions.size();
+
+      // A := A u U = U = U u B
+      if (!NewAssumptions.isUniversal() && !Universal)
+        set_union(Assumptions, NewAssumptions.getAssumptions());
+
+      Universal |= NewAssumptions.isUniversal();
+      return IsUniversal != Universal || Size != Assumptions.size();
+    }
+
+  private:
+    /// Indicates if this set is "universal", containing every possible element.
+    bool Universal;
+
+    /// The set of currently active assumptions.
+    DenseSet<StringRef> Assumptions;
+  };
+
+  AAAssumptionInfo(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
+
+  /// Returns true if the assumption set contains the assumption \p Assumption.
+  virtual bool hasAssumption(const StringRef Assumption) const = 0;
+
+  /// Create an abstract attribute view for the position \p IRP.
+  static AAAssumptionInfo &createForPosition(const IRPosition &IRP,
+                                             Attributor &A);
+
+  /// See AbstractAttribute::getName()
+  const std::string getName() const override { return "AAAssumptionInfo"; }
+
+  /// See AbstractAttribute::getIdAddr()
+  const char *getIdAddr() const override { return &ID; }
+
+  /// This function should return true if the type of the \p AA is
+  /// AAAssumptionInfo
+  static bool classof(const AbstractAttribute *AA) {
+    return (AA->getIdAddr() == &ID);
+  }
+
+  /// Unique ID (due to the unique address)
+  static const char ID;
+
+  /// Get the current set of active assumptions.
+  virtual const AssumptionSet getAssumptionSet() const = 0;
+};
+
 raw_ostream &operator<<(raw_ostream &, const AAPointerInfo::Access &);
 
 /// Run options, used by the pass manager.
diff --git a/llvm/lib/IR/Assumptions.cpp b/llvm/lib/IR/Assumptions.cpp
--- a/llvm/lib/IR/Assumptions.cpp
+++ b/llvm/lib/IR/Assumptions.cpp
@@ -9,6 +9,8 @@
 //===----------------------------------------------------------------------===//
 
 #include "llvm/IR/Assumptions.h"
+#include "llvm/ADT/SetOperations.h"
+#include "llvm/ADT/StringExtras.h"
 #include "llvm/IR/Attributes.h"
 #include "llvm/IR/Function.h"
 #include "llvm/IR/InstrTypes.h"
@@ -29,15 +31,33 @@
     return Assumption == AssumptionStr;
   });
 }
+
+/// Parses the comma-separated string attribute \p A into a set of assumptions.
+DenseSet<StringRef> getAssumptions(const Attribute &A) {
+  if (!A.isValid())
+    return DenseSet<StringRef>();
+  assert(A.isStringAttribute() && "Expected a string attribute!");
+
+  DenseSet<StringRef> Assumptions;
+  SmallVector<StringRef, 8> Strings;
+  // Drop empty entries so that values such as "" or "A,,B" never yield an
+  // empty assumption string.
+  A.getValueAsString().split(Strings, ",", /*MaxSplit=*/-1,
+                             /*KeepEmpty=*/false);
+
+  for (StringRef Str : Strings)
+    Assumptions.insert(Str);
+  return Assumptions;
+}
 } // namespace
 
-bool llvm::hasAssumption(Function &F,
+bool llvm::hasAssumption(const Function &F,
                          const KnownAssumptionString &AssumptionStr) {
   const Attribute &A = F.getFnAttribute(AssumptionAttrKey);
   return ::hasAssumption(A, AssumptionStr);
 }
 
-bool llvm::hasAssumption(CallBase &CB,
+bool llvm::hasAssumption(const CallBase &CB,
                          const KnownAssumptionString &AssumptionStr) {
   if (Function *F = CB.getCalledFunction())
     if (hasAssumption(*F, AssumptionStr))
@@ -47,6 +67,47 @@
   return ::hasAssumption(A, AssumptionStr);
 }
 
+DenseSet<StringRef> llvm::getAssumptions(const Function &F) {
+  const Attribute &A = F.getFnAttribute(AssumptionAttrKey);
+  return ::getAssumptions(A);
+}
+
+DenseSet<StringRef> llvm::getAssumptions(const CallBase &CB) {
+  const Attribute &A = CB.getFnAttr(AssumptionAttrKey);
+  return ::getAssumptions(A);
+}
+
+/// Merges \p Assumptions into the assumption attribute of \p Site, which is
+/// either a Function or a CallBase. The attribute is only rewritten when the
+/// union actually adds a new assumption.
+template <typename AttrSite>
+static void addAssumptionsImpl(AttrSite &Site,
+                               const DenseSet<StringRef> &Assumptions) {
+  if (Assumptions.empty())
+    return;
+
+  DenseSet<StringRef> CurAssumptions = llvm::getAssumptions(Site);
+
+  // set_union returns true only if it inserted a new element; otherwise the
+  // attribute already holds every assumption and need not be rewritten.
+  if (!set_union(CurAssumptions, Assumptions))
+    return;
+
+  LLVMContext &Ctx = Site.getContext();
+  Site.addFnAttr(llvm::Attribute::get(
+      Ctx, llvm::AssumptionAttrKey,
+      llvm::join(CurAssumptions.begin(), CurAssumptions.end(), ",")));
+}
+
+void llvm::addAssumptions(Function &F, const DenseSet<StringRef> &Assumptions) {
+  addAssumptionsImpl(F, Assumptions);
+}
+
+void llvm::addAssumptions(CallBase &CB,
+                          const DenseSet<StringRef> &Assumptions) {
+  addAssumptionsImpl(CB, Assumptions);
+}
+
 StringSet<> llvm::KnownAssumptionStrings({
     "omp_no_openmp",          // OpenMP 5.1
     "omp_no_openmp_routines", // OpenMP 5.1
diff --git a/llvm/lib/Transforms/IPO/Attributor.cpp b/llvm/lib/Transforms/IPO/Attributor.cpp
--- a/llvm/lib/Transforms/IPO/Attributor.cpp
+++ b/llvm/lib/Transforms/IPO/Attributor.cpp
@@ -2494,6 +2494,9 @@
   // Every function can be "readnone/argmemonly/inaccessiblememonly/...".
   getOrCreateAAFor<AAMemoryLocation>(FPos);
 
+  // Every function can track active assumptions.
+  getOrCreateAAFor<AAAssumptionInfo>(FPos);
+
   // Every function might be applicable for Heap-To-Stack conversion.
   if (EnableHeapToStack)
     getOrCreateAAFor<AAHeapToStack>(FPos);
diff --git a/llvm/lib/Transforms/IPO/AttributorAttributes.cpp b/llvm/lib/Transforms/IPO/AttributorAttributes.cpp
--- a/llvm/lib/Transforms/IPO/AttributorAttributes.cpp
+++ b/llvm/lib/Transforms/IPO/AttributorAttributes.cpp
@@ -15,6 +15,7 @@
 
 #include "llvm/ADT/APInt.h"
 #include "llvm/ADT/SCCIterator.h"
+#include "llvm/ADT/SetOperations.h"
 #include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/ADT/Statistic.h"
 #include "llvm/Analysis/AliasAnalysis.h"
@@ -28,6 +29,7 @@
 #include "llvm/Analysis/ScalarEvolution.h"
 #include "llvm/Analysis/TargetTransformInfo.h"
 #include "llvm/Analysis/ValueTracking.h"
+#include "llvm/IR/Assumptions.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/Instruction.h"
@@ -146,6 +148,7 @@
 PIPE_OPERATOR(AACallEdges)
 PIPE_OPERATOR(AAFunctionReachability)
 PIPE_OPERATOR(AAPointerInfo)
+PIPE_OPERATOR(AAAssumptionInfo)
 
 #undef PIPE_OPERATOR
 
@@ -9658,6 +9661,7 @@
   }
 
   void trackStatistics() const override {}
+
 private:
   bool canReachUnknownCallee() const override {
     return WholeFunction.CanReachUnknownCallee;
@@ -9671,6 +9675,132 @@
   DenseMap<const CallBase *, QueryResolver> CBQueries;
 };
 
+/// ---------------------- Assumption Propagation ------------------------------
+struct AAAssumptionInfoImpl : public AAAssumptionInfo {
+  AAAssumptionInfoImpl(const IRPosition &IRP, Attributor &A)
+      : AAAssumptionInfo(IRP, A), Assumptions() {}
+
+  const AssumptionSet getAssumptionSet() const override { return Assumptions; }
+
+  bool hasAssumption(const StringRef Assumption) const override {
+    return isValidState() && Assumptions.getAssumptions().contains(Assumption);
+  }
+
+  /// See AbstractAttribute::getAsStr()
+  const std::string getAsStr() const override {
+    if (Assumptions.isUniversal())
+      return "[Universal]";
+
+    std::string AssumptionStr =
+        llvm::join(Assumptions.getAssumptions().begin(),
+                   Assumptions.getAssumptions().end(), ",");
+    return "[" + AssumptionStr + "]";
+  }
+
+  /// Set of all currently valid assumptions, initially a universal set.
+  AssumptionSet Assumptions;
+};
+
+/// Propagates assumption information from a parent function to all of its
+/// direct successors. Conceptually this computes the dominator set of the
+/// call graph, except that attributes are merged instead of nodes.
+struct AAAssumptionInfoFunction final : AAAssumptionInfoImpl {
+  AAAssumptionInfoFunction(const IRPosition &IRP, Attributor &A)
+      : AAAssumptionInfoImpl(IRP, A),
+        CurrentAssumptions(getAssumptions(*IRP.getAssociatedFunction())) {}
+
+  /// See AbstractAttribute::manifest(...).
+  ChangeStatus manifest(Attributor &A) override {
+    Function *AssociatedFunction = getIRPosition().getAssociatedFunction();
+
+    addAssumptions(*AssociatedFunction, Assumptions.getAssumptions());
+
+    return !Assumptions.empty() ? ChangeStatus::CHANGED
+                                : ChangeStatus::UNCHANGED;
+  }
+
+  /// See AbstractAttribute::updateImpl(...).
+  ChangeStatus updateImpl(Attributor &A) override {
+    bool Changed = false;
+
+    auto CallSitePred = [&](AbstractCallSite ACS) {
+      const auto &AssumptionAA = A.getAAFor<AAAssumptionInfo>(
+          *this, IRPosition::function(*ACS.getInstruction()->getFunction()),
+          DepClassTy::REQUIRED);
+      Changed |= Assumptions.getIntersection(AssumptionAA.getAssumptionSet());
+      return !Assumptions.empty() || !CurrentAssumptions.empty();
+    };
+
+    bool AllCallSitesKnown;
+    // Get the intersection of all assumptions held by this node's predecessors.
+    if (!A.checkForAllCallSites(CallSitePred, *this, true, AllCallSitesKnown)) {
+      if (AllCallSitesKnown)
+        return indicatePessimisticFixpoint();
+
+      // If we don't know all the call sites then this is an entry into the call
+      // graph. This is the first set that isn't considered universal.
+      Assumptions = AssumptionSet(CurrentAssumptions);
+      return indicateOptimisticFixpoint();
+    }
+
+    // Add this function's attributes to the set.
+    Changed |= Assumptions.getUnion(AssumptionSet(CurrentAssumptions));
+
+    return Changed ? ChangeStatus::CHANGED : ChangeStatus::UNCHANGED;
+  }
+
+  void trackStatistics() const override {}
+
+  /// The assumptions already associated with this function.
+  const DenseSet<StringRef> CurrentAssumptions;
+};
+
+/// Assumption Info defined for call sites.
+struct AAAssumptionInfoCallSite final : AAAssumptionInfoImpl {
+  AAAssumptionInfoCallSite(const IRPosition &IRP, Attributor &A)
+      : AAAssumptionInfoImpl(IRP, A) {}
+
+  /// See AbstractAttribute::initialize(...).
+  void initialize(Attributor &A) override {
+    // Call site positions are anchored at a CallBase; cast<> asserts that.
+    const auto &AssociatedCall = cast<CallBase>(getAssociatedValue());
+    Assumptions = AssumptionSet(getAssumptions(AssociatedCall));
+  }
+
+  /// See AbstractAttribute::manifest(...).
+  ChangeStatus manifest(Attributor &A) override {
+    // The parent function is the anchor scope, i.e. the caller containing this
+    // call site. The callee may be unknown (null) for indirect calls.
+    const IRPosition &FnPos = IRPosition::function(*getAnchorScope());
+    auto *AssumptionAA =
+        A.lookupAAFor<AAAssumptionInfo>(FnPos, this, DepClassTy::REQUIRED);
+
+    assert(AssumptionAA && "Attempting to manifest using an uninitialized AA!");
+
+    // Add in the assumptions derived for the parent function.
+    Assumptions.getUnion(AssumptionAA->getAssumptionSet());
+
+    auto &AssociatedCall = cast<CallBase>(getAssociatedValue());
+
+    addAssumptions(AssociatedCall, Assumptions.getAssumptions());
+
+    return !Assumptions.empty() ? ChangeStatus::CHANGED
+                                : ChangeStatus::UNCHANGED;
+  }
+
+  /// See AbstractAttribute::updateImpl(...).
+  ChangeStatus updateImpl(Attributor &A) override {
+    // Follow the caller (anchor scope), matching manifest().
+    const IRPosition &FnPos = IRPosition::function(*getAnchorScope());
+    auto &AssumptionAA =
+        A.getAAFor<AAAssumptionInfo>(*this, FnPos, DepClassTy::REQUIRED);
+    return clampStateAndIndicateChange(getState(), AssumptionAA.getState());
+  }
+
+  /// See AbstractAttribute::trackStatistics()
+  void trackStatistics() const override {}
+};
+
 } // namespace
 
 AACallGraphNode *AACallEdgeIterator::operator*() const {
@@ -9706,6 +9836,7 @@
 const char AACallEdges::ID = 0;
 const char AAFunctionReachability::ID = 0;
 const char AAPointerInfo::ID = 0;
+const char AAAssumptionInfo::ID = 0;
 
 // Macro magic to create the static generator function for attributes that
 // follow the naming scheme.
@@ -9808,6 +9939,7 @@
 CREATE_FUNCTION_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAReturnedValues)
 CREATE_FUNCTION_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAMemoryLocation)
 CREATE_FUNCTION_ABSTRACT_ATTRIBUTE_FOR_POSITION(AACallEdges)
+CREATE_FUNCTION_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAAssumptionInfo)
 
 CREATE_VALUE_ABSTRACT_ATTRIBUTE_FOR_POSITION(AANonNull)
 CREATE_VALUE_ABSTRACT_ATTRIBUTE_FOR_POSITION(AANoAlias)
diff --git a/llvm/test/Transforms/Attributor/assumes_info.ll b/llvm/test/Transforms/Attributor/assumes_info.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/Transforms/Attributor/assumes_info.ll
@@ -0,0 +1,81 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --function-signature --check-attributes --check-globals
+; RUN: opt -attributor -enable-new-pm=0 -attributor-manifest-internal -attributor-max-iterations-verify -attributor-annotate-decl-cs -attributor-max-iterations=2 -S < %s | FileCheck %s
+; RUN: opt -aa-pipeline=basic-aa -passes=attributor -attributor-manifest-internal -attributor-max-iterations-verify -attributor-annotate-decl-cs -attributor-max-iterations=2 -S < %s | FileCheck %s
+; RUN: opt -attributor-cgscc -enable-new-pm=0 -attributor-manifest-internal -attributor-annotate-decl-cs -S < %s | FileCheck %s
+; RUN: opt -aa-pipeline=basic-aa -passes=attributor-cgscc -attributor-manifest-internal -attributor-annotate-decl-cs -S < %s | FileCheck %s
+declare void @call()
+
+define dso_local void @entry(i1 %cond) #0 {
+; CHECK-LABEL: define {{[^@]+}}@entry
+; CHECK-SAME: (i1 [[COND:%.*]]) #[[ATTR0:[0-9]+]] {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    call void @foo(i1 [[COND]])
+; CHECK-NEXT:    call void @bar()
+; CHECK-NEXT:    call void @qux()
+; CHECK-NEXT:    ret void
+;
+entry:
+  call void @foo(i1 %cond)
+  call void @bar()
+  call void @qux()
+  ret void
+}
+
+define internal void @foo(i1 %cond) #1 {
+; CHECK-LABEL: define {{[^@]+}}@foo
+; CHECK-SAME: (i1 [[COND:%.*]]) #[[ATTR1:[0-9]+]] {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    call void @baz(i1 [[COND]])
+; CHECK-NEXT:    ret void
+;
+entry:
+  call void @baz(i1 %cond)
+  ret void
+}
+
+define internal void @bar() #2 {
+; CHECK-LABEL: define {{[^@]+}}@bar
+; CHECK-SAME: () #[[ATTR2:[0-9]+]] {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    call void @baz(i1 noundef false)
+; CHECK-NEXT:    ret void
+;
+entry:
+  call void @baz(i1 0)
+  ret void
+}
+
+define internal void @baz(i1 %Cond) {
+entry:
+  %tobool = icmp ne i1 %Cond, 0
+  br i1 %tobool, label %if.then, label %if.end
+
+if.then:
+  call void @baz(i1 0)
+  br label %if.end
+
+if.end:
+  call void @qux()
+  ret void
+}
+
+define internal void @qux() {
+; CHECK-LABEL: define {{[^@]+}}@qux
+; CHECK-SAME: () #[[ATTR0]] {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    call void @call()
+; CHECK-NEXT:    ret void
+;
+entry:
+  call void @call()
+  ret void
+}
+
+attributes #0 = { "llvm.assume"="A" }
+attributes #1 = { "llvm.assume"="B" }
+attributes #2 = { "llvm.assume"="B,C" }
+;.
+; CHECK: attributes #[[ATTR0]] = { "llvm.assume"="A" }
+; CHECK: attributes #[[ATTR1]] = { "llvm.assume"="B,A" }
+; CHECK: attributes #[[ATTR2]] = { "llvm.assume"="B,C,A" }
+;.
diff --git a/llvm/test/Transforms/Attributor/depgraph.ll b/llvm/test/Transforms/Attributor/depgraph.ll
--- a/llvm/test/Transforms/Attributor/depgraph.ll
+++ b/llvm/test/Transforms/Attributor/depgraph.ll
@@ -123,6 +123,8 @@
 ; GRAPH-NEXT: updates [AAMemoryLocation] for CtxI '  %6 = call i32* @checkAndAdvance(i32* %5)' at position {cs: [@-1]} with state memory:argument
 ; GRAPH-NEXT: updates [AAMemoryLocation] for CtxI '  %6 = call i32* @checkAndAdvance(i32* %5)' at position {cs: [@-1]} with state memory:argument
 ; GRAPH-EMPTY:
+; GRAPH-NEXT: [AAAssumptionInfo] for CtxI '  %2 = load i32, i32* %0, align 4' at position {fn:checkAndAdvance [checkAndAdvance@-1]} with state []
+; GRAPH-EMPTY:
 ; GRAPH-NEXT: [AAHeapToStack] for CtxI '  %2 = load i32, i32* %0, align 4' at position {fn:checkAndAdvance [checkAndAdvance@-1]} with state [H2S] Mallocs Good/Bad: 0/0
 ; GRAPH-EMPTY:
 ; GRAPH-NEXT: [AAValueSimplify] for CtxI '  %2 = load i32, i32* %0, align 4' at position {fn_ret:checkAndAdvance [checkAndAdvance@-1]} with state not-simple