Index: llvm/include/llvm/Transforms/IPO/Attributor.h
===================================================================
--- llvm/include/llvm/Transforms/IPO/Attributor.h
+++ llvm/include/llvm/Transforms/IPO/Attributor.h
@@ -105,6 +105,7 @@
 #include "llvm/Analysis/MustExecute.h"
 #include "llvm/Analysis/TargetLibraryInfo.h"
 #include "llvm/IR/CallSite.h"
+#include "llvm/IR/ConstantRange.h"
 #include "llvm/IR/PassManager.h"
 
 namespace llvm {
@@ -2147,6 +2148,142 @@
   static const char ID;
 };
 
+/// Abstract state for the range of an integer value, made of an assumed and
+/// a known ConstantRange of width BitWidth.
+struct IntegerRangeState : public AbstractState {
+  /// Bitwidth of the associated value.
+  uint32_t BitWidth;
+
+  /// Assumed range, initially the empty set.
+  ConstantRange Assumed;
+
+  /// Known range, initially the full set.
+  ConstantRange Known;
+
+  IntegerRangeState(uint32_t BitWidth)
+      : BitWidth(BitWidth), Assumed(ConstantRange::getEmpty(BitWidth)),
+        Known(ConstantRange::getFull(BitWidth)) {}
+
+  /// See AbstractState::isValidState(). The state becomes invalid once the
+  /// assumed range is the full set, i.e., nothing useful is assumed anymore.
+  bool isValidState() const override {
+    return BitWidth > 0 && !Assumed.isFullSet();
+  }
+
+  /// See AbstractState::isAtFixpoint().
+  bool isAtFixpoint() const override { return Assumed == Known; }
+
+  /// Return the bitwidth of the associated value.
+  uint32_t getBitWidth() const { return BitWidth; }
+
+  /// See AbstractState::indicateOptimisticFixpoint(...).
+  ChangeStatus indicateOptimisticFixpoint() override {
+    Known = Assumed;
+    return ChangeStatus::CHANGED;
+  }
+
+  /// See AbstractState::indicatePessimisticFixpoint(...).
+  ChangeStatus indicatePessimisticFixpoint() override {
+    Assumed = Known;
+    return ChangeStatus::CHANGED;
+  }
+
+  /// Return the assumed range.
+  ConstantRange getAssumed() const { return Assumed; }
+
+  /// Return the known range.
+  ConstantRange getKnown() const { return Known; }
+
+  /// Unite the assumed range with \p R, clamped to the known range.
+  void unionAssumed(const ConstantRange &R) {
+    // Don't lose a known range.
+    Assumed = Assumed.unionWith(R).intersectWith(Known);
+  }
+
+  /// Unite the assumed range with the assumed range of \p R.
+  void unionAssumed(const IntegerRangeState &R) {
+    unionAssumed(R.getAssumed());
+  }
+
+  /// Unite the known range with \p R and the assumed range with the result.
+  void unionKnown(const ConstantRange &R) {
+    // Don't lose a known range.
+    Known = Known.unionWith(R);
+    Assumed = Assumed.unionWith(Known);
+  }
+
+  /// Unite the known range with the known range of \p R.
+  void unionKnown(const IntegerRangeState &R) { unionKnown(R.getKnown()); }
+
+  /// Intersect the known range with \p R.
+  void intersectKnown(const ConstantRange &R) {
+    // Restrict the assumed range as well; it must not exceed what is known.
+    Assumed = Assumed.intersectWith(R);
+    Known = Known.intersectWith(R);
+  }
+
+  /// Intersect the known range with the known range of \p R.
+  void intersectKnown(const IntegerRangeState &R) {
+    intersectKnown(R.getKnown());
+  }
+
+  /// Two states are equal if their assumed and known ranges are equal.
+  bool operator==(const IntegerRangeState &R) const {
+    return getAssumed() == R.getAssumed() && getKnown() == R.getKnown();
+  }
+
+  /// "Clamp" this state with \p R. The result is subtype dependent but it is
+  /// intended that only information assumed in both states will be assumed in
+  /// this one afterwards.
+  void operator^=(const IntegerRangeState &R) {
+    // NOTE: `^=` operator seems like `intersect` but in this case, we need to
+    // take `union`.
+    unionAssumed(R);
+  }
+
+  /// Merge \p R into this state by uniting both the known and assumed ranges.
+  void operator&=(const IntegerRangeState &R) {
+    // NOTE: `&=` operator seems like `intersect` but in this case, we need to
+    // take `union`.
+    unionKnown(R);
+    unionAssumed(R);
+  }
+};
+
+raw_ostream &operator<<(raw_ostream &OS, const IntegerRangeState &State);
+
+/// An abstract interface for value range analysis of integer values.
+struct AAValueConstantRange : public IntegerRangeState,
+                              public AbstractAttribute,
+                              public IRPosition {
+  AAValueConstantRange(const IRPosition &IRP)
+      : IntegerRangeState(
+            IRP.getAssociatedValue().getType()->getIntegerBitWidth()),
+        IRPosition(IRP) {}
+
+  /// Return an IR position, see struct IRPosition.
+  const IRPosition &getIRPosition() const override { return *this; }
+
+  /// See AbstractAttribute::getState(...).
+  IntegerRangeState &getState() override { return *this; }
+  const AbstractState &getState() const override { return *this; }
+
+  /// Create an abstract attribute view for the position \p IRP.
+  static AAValueConstantRange &createForPosition(const IRPosition &IRP,
+                                                 Attributor &A);
+
+  /// Return an assumed range for the associated value. \p U is meant for
+  /// context-sensitive queries; implementations may ignore it.
+  virtual ConstantRange getAssumedConstantRange(Use *U = nullptr) const = 0;
+
+  /// Return a known range for the associated value. \p U is meant for
+  /// context-sensitive queries; implementations may ignore it.
+  virtual ConstantRange getKnownConstantRange(Use *U = nullptr) const = 0;
+
+  /// Unique ID (due to the unique address)
+  static const char ID;
+};
+
 } // end namespace llvm
 
 #endif // LLVM_TRANSFORMS_IPO_FUNCTIONATTRS_H
Index: llvm/lib/Transforms/IPO/Attributor.cpp
===================================================================
--- llvm/lib/Transforms/IPO/Attributor.cpp
+++ llvm/lib/Transforms/IPO/Attributor.cpp
@@ -126,6 +126,7 @@
 PIPE_OPERATOR(AAHeapToStack)
 PIPE_OPERATOR(AAReachability)
 PIPE_OPERATOR(AAMemoryBehavior)
+PIPE_OPERATOR(AAValueConstantRange)
 
 #undef PIPE_OPERATOR
 } // namespace llvm
@@ -2579,12 +2580,14 @@
   CI->takeName(II);
   replaceAllInstructionUsesWith(*II, *CI);
 
-  // If this is a nounwind + mayreturn invoke we only remove the unwind edge.
-  // This is done by moving the invoke into a new and dead block and connecting
-  // the normal destination of the invoke with a branch that follows the call
-  // replacement we created above.
+  // If this is a nounwind + mayreturn invoke we only remove the
+  // unwind edge. This is done by moving the invoke into a new and
+  // dead block and connecting the normal destination of the invoke
+  // with a branch that follows the call replacement we created
+  // above.
   if (MayReturn) {
-    BasicBlock *NewDeadBB = SplitBlock(BB, II, nullptr, nullptr, nullptr, ".i2c");
+    BasicBlock *NewDeadBB =
+        SplitBlock(BB, II, nullptr, nullptr, nullptr, ".i2c");
     assert(isa<BranchInst>(BB->getTerminator()) &&
            BB->getTerminator()->getNumSuccessors() == 1 &&
            BB->getTerminator()->getSuccessor(0) == NewDeadBB);
@@ -2778,11 +2781,30 @@
       A.getAAFor<AAValueSimplify>(AA, IRPosition::value(V));
   Optional<Value *> SimplifiedV = ValueSimplifyAA.getAssumedSimplifiedValue(A);
   UsedAssumedInformation |= !ValueSimplifyAA.isKnown();
+  // The simplified value may be null, hence the null-tolerant casts below.
   if (!SimplifiedV.hasValue())
     return llvm::None;
   if (isa_and_nonnull<UndefValue>(SimplifiedV.getValue()))
     return llvm::None;
-  return dyn_cast_or_null<ConstantInt>(SimplifiedV.getValue());
+  if (auto *C = dyn_cast_or_null<ConstantInt>(SimplifiedV.getValue()))
+    return C;
+
+  if (V.getType()->isIntegerTy()) {
+    const auto &ValueConstantRangeAA =
+        A.getAAFor<AAValueConstantRange>(AA, IRPosition::value(V));
+    // A single-element range is a constant; an empty range acts like undef.
+    ConstantRange RangeV = ValueConstantRangeAA.getAssumedConstantRange();
+    LLVM_DEBUG(dbgs() << "[Attributor][getAssumedConstant] Range " << RangeV
+                      << "\n");
+    if (const APInt *SingleElement = RangeV.getSingleElement()) {
+      LLVM_DEBUG(dbgs() << V << " is replaced by " << *SingleElement << "\n");
+      return cast<ConstantInt>(ConstantInt::get(V.getType(), *SingleElement));
+    }
+    if (RangeV.isEmptySet())
+      return llvm::None;
+  }
+
+  return nullptr;
 }
 
 static bool
@@ -4863,7 +4885,342 @@
   if (UserI->mayWriteToMemory())
     removeAssumedBits(NO_WRITES);
 }
+/// ------------------ Value Constant Range Attribute --------------------------
+
+struct AAValueConstantRangeImpl : AAValueConstantRange {
+  using StateType = IntegerRangeState;
+  AAValueConstantRangeImpl(const IRPosition &IRP) : AAValueConstantRange(IRP) {}
+  ConstantRange getAssumedConstantRange(Use *U = nullptr) const override {
+    // TODO: Use \p U for context-sensitive queries; it is ignored for now.
+    return getAssumed();
+  }
+  ConstantRange getKnownConstantRange(Use *U = nullptr) const override {
+    // TODO: Use \p U for context-sensitive queries; it is ignored for now.
+    return getKnown();
+  }
+  /// See AbstractAttribute::getAsStr().
+  const std::string getAsStr() const override {
+    std::string Str;
+    llvm::raw_string_ostream OS(Str);
+    OS << "range(" << getBitWidth() << ")<";
+    getKnownConstantRange().print(OS);
+    OS << " / ";
+    getAssumedConstantRange().print(OS);
+    OS << ">";
+    return OS.str();
+  }
+
+  /// Create a !range metadata node holding \p AssumedConstantRange.
+  static MDNode *getMDNodeForConstantRange(Type *Ty, LLVMContext &Ctx,
+                                           ConstantRange AssumedConstantRange) {
+    Metadata *LowAndHigh[] = {ConstantAsMetadata::get(ConstantInt::get(
+                                  Ty, AssumedConstantRange.getLower())),
+                              ConstantAsMetadata::get(ConstantInt::get(
+                                  Ty, AssumedConstantRange.getUpper()))};
+    return MDNode::get(Ctx, LowAndHigh);
+  }
+
+  /// Return true if \p Assumed is strictly tighter than \p KnownRanges.
+  static bool isBetterRange(ConstantRange Assumed, MDNode *KnownRanges) {
+    if (Assumed.isFullSet())
+      return false;
+
+    if (!KnownRanges)
+      return true;
+
+    // If multiple ranges are annotated in IR, we give up to annotate assumed
+    // range for now.
+    // FIXME: If there exists a known range which contains assumed range, we
+    // can say assumed range is better.
+    if (KnownRanges->getNumOperands() > 2)
+      return false;
+
+    ConstantInt *Lower =
+        mdconst::extract<ConstantInt>(KnownRanges->getOperand(0));
+    ConstantInt *Upper =
+        mdconst::extract<ConstantInt>(KnownRanges->getOperand(1));
+
+    ConstantRange Known(Lower->getValue(), Upper->getValue());
+    return Known.contains(Assumed) && Known != Assumed;
+  }
+
+  /// Annotate \p I with a better !range; return true iff \p I was changed.
+  static bool setIfBetterRange(Instruction *I,
+                               ConstantRange AssumedConstantRange) {
+    auto *OldRangeMD = I->getMetadata(LLVMContext::MD_range);
+    if (isBetterRange(AssumedConstantRange, OldRangeMD)) {
+      if (AssumedConstantRange.isEmptySet()) {
+        // TODO: Replace value to undef if range is empty(=undef).
+        return false;
+      }
+      I->setMetadata(LLVMContext::MD_range,
+                     getMDNodeForConstantRange(I->getType(), I->getContext(),
+                                               AssumedConstantRange));
+      return true;
+    }
+    return false;
+  }
+
+  /// See AbstractAttribute::manifest(...).
+  ChangeStatus manifest(Attributor &A) override {
+    ConstantRange AssumedConstantRange = getAssumedConstantRange();
+    if (AssumedConstantRange.isFullSet())
+      return ChangeStatus::UNCHANGED;
+
+    auto &V = getAssociatedValue();
+    if (isa<CallInst>(&V) || isa<LoadInst>(&V))
+      if (setIfBetterRange(cast<Instruction>(&V), AssumedConstantRange))
+        return ChangeStatus::CHANGED;
+
+    return ChangeStatus::UNCHANGED;
+  }
+
+  /// See AbstractAttribute::trackStatistics()
+  void trackStatistics() const override {}
+
+  /// See AbstractAttribute::initialize(...).
+  void initialize(Attributor &A) override {
+    assert(getAssociatedValue().getType()->isIntegerTy() &&
+           "AAValueConstantRange should be associcated with an integer value.");
+  }
+};
+
+struct AAValueConstantRangeArgument final : public AAValueConstantRangeImpl {
+  AAValueConstantRangeArgument(const IRPosition &IRP)
+      : AAValueConstantRangeImpl(IRP) {}
+
+  /// See AbstractAttribute::updateImpl(...).
+  ChangeStatus updateImpl(Attributor &A) override {
+    // TODO: Use AAArgumentFromCallSiteArguments
+    IntegerRangeState S(getBitWidth());
+    clampCallSiteArgumentStates<AAValueConstantRange, IntegerRangeState>(
+        A, *this, S);
+
+    // TODO: If we know we visited all incoming values, thus none are assumed
+    // dead, we can take the known information from the state S.
+    return clampStateAndIndicateChange(this->getState(), S);
+  }
+
+  /// See AbstractAttribute::trackStatistics()
+  void trackStatistics() const override {
+    STATS_DECLTRACK_ARG_ATTR(value_range)
+  }
+};
+
+struct AAValueConstantRangeReturned : AAValueConstantRangeImpl {
+  AAValueConstantRangeReturned(const IRPosition &IRP)
+      : AAValueConstantRangeImpl(IRP) {}
+
+  /// See AbstractAttribute::updateImpl(...).
+  ChangeStatus updateImpl(Attributor &A) override {
+    // TODO: Use AAReturnedFromReturnedValues
+    // TODO: If we know we visited all returned values, thus none are assumed
+    // dead, we can take the known information from the state S.
+    IntegerRangeState S(getBitWidth());
+    clampReturnedValueStates<AAValueConstantRange, IntegerRangeState>(A, *this,
+                                                                      S);
+    return clampStateAndIndicateChange(this->getState(), S);
+  }
+
+  /// See AbstractAttribute::trackStatistics()
+  void trackStatistics() const override {
+    STATS_DECLTRACK_FNRET_ATTR(value_range)
+  }
+};
+
+struct AAValueConstantRangeFloating : AAValueConstantRangeImpl {
+  AAValueConstantRangeFloating(const IRPosition &IRP)
+      : AAValueConstantRangeImpl(IRP) {}
+  /// See AbstractAttribute::initialize(...).
+  void initialize(Attributor &A) override {
+    AAValueConstantRangeImpl::initialize(A);
+    Value &V = getAssociatedValue();
+    if (auto *C = dyn_cast<ConstantInt>(&V)) {
+      unionAssumed(ConstantRange(C->getValue()));
+      indicateOptimisticFixpoint();
+      return;
+    }
+
+    if (isa<UndefValue>(&V)) {
+      indicateOptimisticFixpoint();
+      return;
+    }
+    if (LoadInst *LI = dyn_cast<LoadInst>(&V))
+      if (auto *RangeMD = LI->getMetadata(LLVMContext::MD_range))
+        intersectKnown(getConstantRangeFromMetadata(*RangeMD));
+
+    LLVM_DEBUG(dbgs() << "[AAConstantRangeFloating] " << V
+                      << " is initialized to " << getState() << "\n");
+  }
+
+  /// See AbstractAttribute::updateImpl(...).
+  ChangeStatus updateImpl(Attributor &A) override {
+    auto VisitValueCB = [&](Value &V, IntegerRangeState &T,
+                            bool Stripped) -> bool {
+      if (Instruction *I = dyn_cast<Instruction>(&V)) {
+        if (auto *BinOp = dyn_cast<BinaryOperator>(I)) {
+          Value *LHS = BinOp->getOperand(0);
+          Value *RHS = BinOp->getOperand(1);
+
+          if (!LHS->getType()->isIntegerTy() || !RHS->getType()->isIntegerTy())
+            return false;
+
+          auto &LHSAA =
+              A.getAAFor<AAValueConstantRange>(*this, IRPosition::value(*LHS));
+          auto &RHSAA =
+              A.getAAFor<AAValueConstantRange>(*this, IRPosition::value(*RHS));
+
+          // FIXME: Prohibit loop for now.
+          if (this == &LHSAA || this == &RHSAA) {
+            T.indicatePessimisticFixpoint();
+            return false;
+          }
+
+          auto LHSAARange = LHSAA.getAssumedConstantRange();
+          auto RHSAARange = RHSAA.getAssumedConstantRange();
+          auto AssumedRange = LHSAARange.binaryOp(
+              Instruction::BinaryOps(I->getOpcode()), RHSAARange);
+          T.unionAssumed(AssumedRange);
+
+          // If both lhs and rhs are fixed, we can say the current one is fixed.
+          if (LHSAA.isAtFixpoint() && RHSAA.isAtFixpoint() && !Stripped)
+            T.indicateOptimisticFixpoint();
+
+          return true;
+        } else if (auto *CmpI = dyn_cast<CmpInst>(I)) {
+          Value *LHS = CmpI->getOperand(0);
+          Value *RHS = CmpI->getOperand(1);
+
+          // Give up with other than integers.
+          if (!LHS->getType()->isIntegerTy() || !RHS->getType()->isIntegerTy())
+            return false;
+
+          auto &LHSAA =
+              A.getAAFor<AAValueConstantRange>(*this, IRPosition::value(*LHS));
+          auto &RHSAA =
+              A.getAAFor<AAValueConstantRange>(*this, IRPosition::value(*RHS));
+
+          auto LHSAARange = LHSAA.getAssumedConstantRange();
+          auto RHSAARange = RHSAA.getAssumedConstantRange();
+
+          // If one of them is empty set, we can't decide.
+          if (LHSAARange.isEmptySet() || RHSAARange.isEmptySet())
+            return true;
+
+          auto AllowedRegion = ConstantRange::makeAllowedICmpRegion(
+              CmpI->getPredicate(), RHSAARange);
+
+          auto SatisfyingRegion = ConstantRange::makeSatisfyingICmpRegion(
+              CmpI->getPredicate(), RHSAARange);
+
+          // The result is false if no LHS value may satisfy the predicate, and
+          // true if every LHS value satisfies it for every RHS value.
+          bool MustFalse = AllowedRegion.intersectWith(LHSAARange).isEmptySet();
+          bool MustTrue = SatisfyingRegion.contains(LHSAARange);
+
+          assert((!MustTrue || !MustFalse) &&
+                 "Either MustTrue or MustFalse should be false!");
+
+          if (MustTrue)
+            T.unionAssumed(ConstantRange(APInt(/* numBits */ 1, /* val */ 1)));
+          else if (MustFalse)
+            T.unionAssumed(ConstantRange(APInt(/* numBits */ 1, /* val */ 0)));
+          else
+            T.unionAssumed(
+                ConstantRange(/* BitWidth */ 1, /* isFullSet */ true));
+
+          LLVM_DEBUG(dbgs() << "[AAValueConstantRange] " << *CmpI << " "
+                            << LHSAA << " " << RHSAA << "\n");
+
+          // If both lhs and rhs are fixed, we can say the current one is fixed.
+          if (LHSAA.isAtFixpoint() && RHSAA.isAtFixpoint() && !Stripped)
+            T.indicateOptimisticFixpoint();
+
+          return true;
+        } else {
+          // TODO: Add other instructions
+          T.indicatePessimisticFixpoint();
+          return false;
+        }
+      } else {
+        // If the value is not instruction, we query AA to Attributor.
+        const auto &AA =
+            A.getAAFor<AAValueConstantRange>(*this, IRPosition::value(V));
+        const IntegerRangeState &NS =
+            static_cast<const IntegerRangeState &>(AA.getState());
+        T ^= NS;
+        return T.isValidState();
+      }
+    };
+
+    IntegerRangeState T(getBitWidth());
+
+    if (!genericValueTraversal<AAValueConstantRange, IntegerRangeState>(
+            A, getIRPosition(), *this, T, VisitValueCB))
+      return indicatePessimisticFixpoint();
+
+    return clampStateAndIndicateChange(getState(), T);
+  }
+
+  /// See AbstractAttribute::trackStatistics()
+  void trackStatistics() const override {
+    STATS_DECLTRACK_FLOATING_ATTR(value_range)
+  }
+};
+
+struct AAValueConstantRangeFunction : AAValueConstantRangeImpl {
+  AAValueConstantRangeFunction(const IRPosition &IRP)
+      : AAValueConstantRangeImpl(IRP) {}
+
+  /// See AbstractAttribute::updateImpl(...).
+  ChangeStatus updateImpl(Attributor &A) override {
+    llvm_unreachable("AAValueConstantRange(Function|CallSite)::updateImpl will "
+                     "not be called");
+  }
+  /// See AbstractAttribute::trackStatistics()
+  void trackStatistics() const override {
+    STATS_DECLTRACK_FN_ATTR(value_range)
+  }
+};
+
+struct AAValueConstantRangeCallSite : AAValueConstantRangeFunction {
+  AAValueConstantRangeCallSite(const IRPosition &IRP)
+      : AAValueConstantRangeFunction(IRP) {}
+
+  /// See AbstractAttribute::trackStatistics()
+  void trackStatistics() const override {
+    STATS_DECLTRACK_CS_ATTR(value_range)
+  }
+};
+
+struct AAValueConstantRangeCallSiteReturned : AAValueConstantRangeReturned {
+  AAValueConstantRangeCallSiteReturned(const IRPosition &IRP)
+      : AAValueConstantRangeReturned(IRP) {}
+
+  /// See AbstractAttribute::initialize(...).
+  void initialize(Attributor &A) override {
+    AAValueConstantRangeReturned::initialize(A);
+    if (CallInst *CI = dyn_cast<CallInst>(&getAssociatedValue()))
+      if (auto *RangeMD = CI->getMetadata(LLVMContext::MD_range))
+        intersectKnown(getConstantRangeFromMetadata(*RangeMD));
+  }
+
+  /// See AbstractAttribute::trackStatistics()
+  void trackStatistics() const override {
+    STATS_DECLTRACK_CSRET_ATTR(value_range)
+  }
+};
+
+struct AAValueConstantRangeCallSiteArgument : AAValueConstantRangeFloating {
+  AAValueConstantRangeCallSiteArgument(const IRPosition &IRP)
+      : AAValueConstantRangeFloating(IRP) {}
+
+  /// See AbstractAttribute::trackStatistics()
+  void trackStatistics() const override {
+    STATS_DECLTRACK_CSARG_ATTR(value_range)
+  }
+};
 
 /// ----------------------------------------------------------------------------
 ///                               Attributor
 /// ----------------------------------------------------------------------------
@@ -5724,6 +6081,16 @@
             << static_cast<const AbstractState &>(S);
 }
 
+raw_ostream &llvm::operator<<(raw_ostream &OS, const IntegerRangeState &S) {
+  OS << "range-state(" << S.getBitWidth() << ")<";
+  S.getKnown().print(OS);
+  OS << " / ";
+  S.getAssumed().print(OS);
+  OS << "> && ";
+  return OS << (!S.isValidState() ? "top"
+                                  : (S.isAtFixpoint() ? "fix" : "notfixed"));
+}
+
 raw_ostream &llvm::operator<<(raw_ostream &OS, const AbstractState &S) {
   return OS << (!S.isValidState() ? "top" : (S.isAtFixpoint() ? "fix" : ""));
 }
@@ -5837,6 +6204,7 @@
 const char AAValueSimplify::ID = 0;
 const char AAHeapToStack::ID = 0;
 const char AAMemoryBehavior::ID = 0;
+const char AAValueConstantRange::ID = 0;
 
 // Macro magic to create the static generator function for attributes that
 // follow the naming scheme.
@@ -5942,6 +6310,7 @@
 CREATE_VALUE_ABSTRACT_ATTRIBUTE_FOR_POSITION(AADereferenceable)
 CREATE_VALUE_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAAlign)
 CREATE_VALUE_ABSTRACT_ATTRIBUTE_FOR_POSITION(AANoCapture)
+CREATE_VALUE_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAValueConstantRange)
 
 CREATE_ALL_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAValueSimplify)
 CREATE_ALL_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAIsDead)
Index: llvm/test/Transforms/Attributor/IPConstantProp/solve-after-each-resolving-undefs-for-function.ll
===================================================================
--- llvm/test/Transforms/Attributor/IPConstantProp/solve-after-each-resolving-undefs-for-function.ll
+++ llvm/test/Transforms/Attributor/IPConstantProp/solve-after-each-resolving-undefs-for-function.ll
@@ -38,7 +38,7 @@
 ; CHECK: ret1:
 ; CHECK-NEXT: ret i32 99
 ; CHECK: ret2:
-; CHECK-NEXT: ret i32 0
+; CHECK-NEXT: unreachable
 ;
 entry:
   br label %if.then
@@ -59,7 +59,7 @@
 ; CHECK-LABEL: define {{[^@]+}}@main
 ; CHECK-SAME: (i1 [[C:%.*]])
 ; CHECK-NEXT: [[RES:%.*]] = call i32 @test1(i1 [[C]])
-; CHECK-NEXT: ret i32 [[RES]]
+; CHECK-NEXT: ret i32 99
 ;
   %res = call i32 @test1(i1 %c)
   ret i32 %res
Index: llvm/test/Transforms/Attributor/value-simplify.ll
===================================================================
--- llvm/test/Transforms/Attributor/value-simplify.ll
+++ llvm/test/Transforms/Attributor/value-simplify.ll
@@ -170,14 +170,13 @@
 
 define internal i32 @ipccp3i(i32 %a) {
 ; CHECK-LABEL: define {{[^@]+}}@ipccp3i
-; CHECK-SAME: (i32 [[A:%.*]]) #1
-; CHECK-NEXT: [[C:%.*]] = icmp eq i32 [[A:%.*]], 7
+; CHECK-SAME: (i32 returned [[A:%.*]])
+; CHECK-NEXT: [[C:%.*]] = icmp eq i32 7, 7
 ; CHECK-NEXT: br i1 [[C]], label [[T:%.*]], label [[F:%.*]]
 ; CHECK: t:
-; CHECK-NEXT: ret i32 [[A]]
+; CHECK-NEXT: ret i32 7
 ; CHECK: f:
-; CHECK-NEXT: [[R:%.*]] = call i32 @ipccp3i(i32 5) #1
-; CHECK-NEXT: ret i32 [[R]]
+; CHECK-NEXT: unreachable
 ;
   %c = icmp eq i32 %a, 7
   br i1 %c, label %t, label %f
@@ -189,8 +188,8 @@
 }
 
 define i32 @ipccp3() {
-; CHECK-LABEL: define {{[^@]+}}@ipccp3() #1
-; CHECK-NEXT: [[R:%.*]] = call i32 @ipccp3i(i32 7) #1
+; CHECK-LABEL: define {{[^@]+}}@ipccp3()
+; CHECK-NEXT: [[R:%.*]] = call i32 @ipccp3i(i32 7)
 ; CHECK-NEXT: ret i32 [[R]]
 ;
   %r = call i32 @ipccp3i(i32 7)