Index: llvm/lib/Transforms/IPO/AttributorAttributes.cpp =================================================================== --- llvm/lib/Transforms/IPO/AttributorAttributes.cpp +++ llvm/lib/Transforms/IPO/AttributorAttributes.cpp @@ -413,8 +413,9 @@ /// Clamp the information known for all returned values of a function /// (identified by \p QueryingAA) into \p S. template -static void clampReturnedValueStates(Attributor &A, const AAType &QueryingAA, - StateType &S) { +static void clampReturnedValueStates( + Attributor &A, const AAType &QueryingAA, StateType &S, + const IRPosition::CallBaseContext *CBContext = nullptr) { LLVM_DEBUG(dbgs() << "[Attributor] Clamp return value states for " << QueryingAA << " into " << S << "\n"); @@ -431,7 +432,7 @@ // Callback for each possibly returned value. auto CheckReturnValue = [&](Value &RV) -> bool { - const IRPosition &RVPos = IRPosition::value(RV); + const IRPosition &RVPos = IRPosition::value(RV, CBContext); const AAType &AA = A.getAAFor(QueryingAA, RVPos); LLVM_DEBUG(dbgs() << "[Attributor] RV: " << RV << " AA: " << AA.getAsStr() << " @ " << RVPos << "\n"); @@ -453,7 +454,8 @@ /// Helper class for generic deduction: return value -> returned position. template + typename StateType = typename BaseType::StateType, + bool PropagateCallBaseContext = false> struct AAReturnedFromReturnedValues : public BaseType { AAReturnedFromReturnedValues(const IRPosition &IRP, Attributor &A) : BaseType(IRP, A) {} @@ -461,7 +463,9 @@ /// See AbstractAttribute::updateImpl(...). ChangeStatus updateImpl(Attributor &A) override { StateType S(StateType::getBestState(this->getState())); - clampReturnedValueStates(A, *this, S); + clampReturnedValueStates( + A, *this, S, + PropagateCallBaseContext ? this->getCallBaseContext() : nullptr); // TODO: If we know we visited all returned values, thus no are assumed // dead, we can take the known information from the state T. return clampStateAndIndicateChange(this->getState(), S); @@ -515,17 +519,59 @@ S ^= *T; } -/// Helper class for generic deduction: call site argument -> argument position. +/// This function is the bridge between argument position and the call base +/// context. template +bool getArgumentStateFromCallBaseContext(Attributor &A, + BaseType &QueryingAttribute, + IRPosition &Pos, StateType &State) { + assert((Pos.getPositionKind() == IRPosition::IRP_ARGUMENT) && + "Expected a 'argument' position !"); + const CallBase *CBContext = Pos.getCallBaseContext(); + if (!CBContext) + return false; + int ArgNo = Pos.getArgNo(); + if (ArgNo < 0) + return false; + Value *Val = CBContext->getArgOperand(ArgNo); + assert(Val && "CBContext argument value must not be nullptr!"); + + const StateType &CBArgumentState = + (const StateType &)A + .getAAFor(QueryingAttribute, IRPosition::value(*Val)) + .getState(); + + LLVM_DEBUG(dbgs() << "[Attributor] Briding Call site context to argument" + << "Position:" << Pos << " Argument Value: " << Val + << "CB Arg state:" << CBArgumentState << "\n"); + + // NOTE: If we want to do call site grouping it should happen here. + State ^= CBArgumentState; + return true; +} + +/// Helper class for generic deduction: call site argument -> argument position. +template struct AAArgumentFromCallSiteArguments : public BaseType { AAArgumentFromCallSiteArguments(const IRPosition &IRP, Attributor &A) : BaseType(IRP, A) {} /// See AbstractAttribute::updateImpl(...). ChangeStatus updateImpl(Attributor &A) override { - StateType S(StateType::getBestState(this->getState())); + StateType S = StateType::getBestState(this->getState()); + + if (BridgeCallBaseContext) { + bool Succsess = + getArgumentStateFromCallBaseContext( + A, *this, this->getIRPosition(), S); + if (Succsess) + return clampStateAndIndicateChange(this->getState(), S); + } clampCallSiteArgumentStates(A, *this, S); + // TODO: If we know we visited all incoming values, thus no are assumed // dead, we can take the known information from the state T. return clampStateAndIndicateChange(this->getState(), S); @@ -534,7 +580,8 @@ /// Helper class for generic replication: function returned -> cs returned. template + typename StateType = typename BaseType::StateType, + bool IntroduceCallBaseContext = false> struct AACallSiteReturnedFromReturned : public BaseType { AACallSiteReturnedFromReturned(const IRPosition &IRP, Attributor &A) : BaseType(IRP, A) {} @@ -552,7 +599,13 @@ if (!AssociatedFunction) return S.indicatePessimisticFixpoint(); - IRPosition FnPos = IRPosition::returned(*AssociatedFunction); + CallBase *CBContext = (CallBase *)&this->getAnchorValue(); + if (IntroduceCallBaseContext) + LLVM_DEBUG(dbgs() << "[Attributor] Introducing call base context:" + << *CBContext << "\n"); + + IRPosition FnPos = IRPosition::returned( + *AssociatedFunction, IntroduceCallBaseContext ? CBContext : nullptr); const AAType &AA = A.getAAFor(*this, FnPos); return clampStateAndIndicateChange( S, static_cast(AA.getState())); @@ -6771,9 +6824,11 @@ struct AAValueConstantRangeArgument final : AAArgumentFromCallSiteArguments< - AAValueConstantRange, AAValueConstantRangeImpl, IntegerRangeState> { + AAValueConstantRange, AAValueConstantRangeImpl, IntegerRangeState, + true /* BridgeCallBaseContext */> { using Base = AAArgumentFromCallSiteArguments< - AAValueConstantRange, AAValueConstantRangeImpl, IntegerRangeState>; + AAValueConstantRange, AAValueConstantRangeImpl, IntegerRangeState, + true /* BridgeCallBaseContext */>; AAValueConstantRangeArgument(const IRPosition &IRP, Attributor &A) : Base(IRP, A) {} @@ -6794,9 +6849,14 @@ struct AAValueConstantRangeReturned : AAReturnedFromReturnedValues { - using Base = AAReturnedFromReturnedValues; + AAValueConstantRangeImpl, + AAValueConstantRangeImpl::StateType, + /* PropogateCallBaseContext */ true> { + using Base = + AAReturnedFromReturnedValues; AAValueConstantRangeReturned(const IRPosition &IRP, Attributor &A) : Base(IRP, A) {} @@ -6862,13 +6922,13 @@ if (!LHS->getType()->isIntegerTy() || !RHS->getType()->isIntegerTy()) return false; - auto &LHSAA = - A.getAAFor(*this, IRPosition::value(*LHS)); + auto &LHSAA = A.getAAFor( + *this, IRPosition::value(*LHS, getCallBaseContext())); QuerriedAAs.push_back(&LHSAA); auto LHSAARange = LHSAA.getAssumedConstantRange(A, CtxI); - auto &RHSAA = - A.getAAFor(*this, IRPosition::value(*RHS)); + auto &RHSAA = A.getAAFor( + *this, IRPosition::value(*RHS, getCallBaseContext())); QuerriedAAs.push_back(&RHSAA); auto RHSAARange = RHSAA.getAssumedConstantRange(A, CtxI); @@ -6891,8 +6951,8 @@ if (!OpV.getType()->isIntegerTy()) return false; - auto &OpAA = - A.getAAFor(*this, IRPosition::value(OpV)); + auto &OpAA = A.getAAFor( + *this, IRPosition::value(OpV, getCallBaseContext())); QuerriedAAs.push_back(&OpAA); T.unionAssumed( OpAA.getAssumed().castOp(CastI->getOpcode(), getState().getBitWidth())); @@ -6909,11 +6969,11 @@ if (!LHS->getType()->isIntegerTy() || !RHS->getType()->isIntegerTy()) return false; - auto &LHSAA = - A.getAAFor(*this, IRPosition::value(*LHS)); + auto &LHSAA = A.getAAFor( + *this, IRPosition::value(*LHS, getCallBaseContext())); QuerriedAAs.push_back(&LHSAA); - auto &RHSAA = - A.getAAFor(*this, IRPosition::value(*RHS)); + auto &RHSAA = A.getAAFor( + *this, IRPosition::value(*RHS, getCallBaseContext())); QuerriedAAs.push_back(&RHSAA); auto LHSAARange = LHSAA.getAssumedConstantRange(A, CtxI); @@ -7044,10 +7104,16 @@ struct AAValueConstantRangeCallSiteReturned : AACallSiteReturnedFromReturned { + AAValueConstantRangeImpl, + AAValueConstantRangeImpl::StateType, + /* IntroduceCallBaseContext */ true> { AAValueConstantRangeCallSiteReturned(const IRPosition &IRP, Attributor &A) : AACallSiteReturnedFromReturned(IRP, A) {} + AAValueConstantRangeImpl, + AAValueConstantRangeImpl::StateType, + /* IntroduceCallBaseContext */ true>(IRP, + A) { + } /// See AbstractAttribute::initialize(...). void initialize(Attributor &A) override { Index: llvm/test/Transforms/Attributor/cb_range.ll =================================================================== --- /dev/null +++ llvm/test/Transforms/Attributor/cb_range.ll @@ -0,0 +1,92 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --function-signature --scrub-attributes +; RUN: opt -S -aa-pipeline=basic-aa -passes=attributor --attributor-enable-call-site-specific=true | FileCheck %s -check-prefix=CHECK + +define i32 @test_range(i32 %unknown) { +; CHECK-LABEL: define {{[^@]+}}@test_range +; CHECK-SAME: (i32 [[UNKNOWN:%.*]]) +; CHECK-NEXT: [[TMP1:%.*]] = icmp sgt i32 [[UNKNOWN]], 100 +; CHECK-NEXT: [[TMP2:%.*]] = select i1 [[TMP1]], i32 100, i32 0 +; CHECK-NEXT: ret i32 [[TMP2]] +; + %1 = icmp sgt i32 %unknown, 100 + %2 = select i1 %1, i32 100, i32 0 + ret i32 %2 +} + +define i32 @test1(i32 %unknown, i32 %b) { +; CHECK-LABEL: define {{[^@]+}}@test1 +; CHECK-SAME: (i32 [[UNKNOWN:%.*]], i32 [[B:%.*]]) +; CHECK-NEXT: [[TMP1:%.*]] = call i32 @test_range(i32 [[UNKNOWN]]) #0, !range !0 +; CHECK-NEXT: [[TMP2:%.*]] = sub nsw i32 [[TMP1]], [[B]] +; CHECK-NEXT: ret i32 [[TMP2]] +; + %1 = call i32 @test_range(i32 %unknown) + %2 = sub nsw i32 %1, %b + ret i32 %2 +} + +define i32 @test2(i32 %unknown, i32 %b) { +; CHECK-LABEL: define {{[^@]+}}@test2 +; CHECK-SAME: (i32 [[UNKNOWN:%.*]], i32 [[B:%.*]]) +; CHECK-NEXT: [[TMP1:%.*]] = call i32 @test_range(i32 [[UNKNOWN]]) #0, !range !0 +; CHECK-NEXT: [[TMP2:%.*]] = add nsw i32 [[TMP1]], [[B]] +; CHECK-NEXT: ret i32 [[TMP2]] +; + %1 = call i32 @test_range(i32 %unknown) + %2 = add nsw i32 %1, %b + ret i32 %2 +} + +; Positive checks + +define i32 @test1_pcheck(i32 %unknown) { +; CHECK-LABEL: define {{[^@]+}}@test1_pcheck +; CHECK-SAME: (i32 [[UNKNOWN:%.*]]) +; CHECK-NEXT: ret i32 1 +; + %1 = call i32 @test1(i32 %unknown, i32 20) + %2 = icmp sle i32 %1, 90 + %3 = zext i1 %2 to i32 + ret i32 %3 +} + +define i32 @test2_pcheck(i32 %unknown) { +; CHECK-LABEL: define {{[^@]+}}@test2_pcheck +; CHECK-SAME: (i32 [[UNKNOWN:%.*]]) +; CHECK-NEXT: ret i32 1 +; + %1 = call i32 @test2(i32 %unknown, i32 20) + %2 = icmp sge i32 %1, 20 + %3 = zext i1 %2 to i32 + ret i32 %3 +} + +; Negative checks + +define i32 @test1_ncheck(i32 %unknown) { +; CHECK-LABEL: define {{[^@]+}}@test1_ncheck +; CHECK-SAME: (i32 [[UNKNOWN:%.*]]) +; CHECK-NEXT: [[TMP1:%.*]] = call i32 @test1(i32 [[UNKNOWN]], i32 20) #0, !range !1 +; CHECK-NEXT: [[TMP2:%.*]] = icmp sle i32 [[TMP1]], 10 +; CHECK-NEXT: [[TMP3:%.*]] = zext i1 [[TMP2]] to i32 +; CHECK-NEXT: ret i32 [[TMP3]] +; + %1 = call i32 @test1(i32 %unknown, i32 20) + %2 = icmp sle i32 %1, 10 + %3 = zext i1 %2 to i32 + ret i32 %3 +} + +define i32 @test2_ncheck(i32 %unknown) { +; CHECK-LABEL: define {{[^@]+}}@test2_ncheck +; CHECK-SAME: (i32 [[UNKNOWN:%.*]]) +; CHECK-NEXT: [[TMP1:%.*]] = call i32 @test2(i32 [[UNKNOWN]], i32 20) #0, !range !2 +; CHECK-NEXT: [[TMP2:%.*]] = icmp sge i32 [[TMP1]], 30 +; CHECK-NEXT: [[TMP3:%.*]] = zext i1 [[TMP2]] to i32 +; CHECK-NEXT: ret i32 [[TMP3]] +; + %1 = call i32 @test2(i32 %unknown, i32 20) + %2 = icmp sge i32 %1, 30 + %3 = zext i1 %2 to i32 + ret i32 %3 +}