diff --git a/llvm/lib/Transforms/IPO/AttributorAttributes.cpp b/llvm/lib/Transforms/IPO/AttributorAttributes.cpp --- a/llvm/lib/Transforms/IPO/AttributorAttributes.cpp +++ b/llvm/lib/Transforms/IPO/AttributorAttributes.cpp @@ -413,8 +413,9 @@ /// Clamp the information known for all returned values of a function /// (identified by \p QueryingAA) into \p S. template -static void clampReturnedValueStates(Attributor &A, const AAType &QueryingAA, - StateType &S) { +static void clampReturnedValueStates( + Attributor &A, const AAType &QueryingAA, StateType &S, + const IRPosition::CallBaseContext *CBContext = nullptr) { LLVM_DEBUG(dbgs() << "[Attributor] Clamp return value states for " << QueryingAA << " into " << S << "\n"); @@ -431,7 +432,7 @@ // Callback for each possibly returned value. auto CheckReturnValue = [&](Value &RV) -> bool { - const IRPosition &RVPos = IRPosition::value(RV); + const IRPosition &RVPos = IRPosition::value(RV, CBContext); const AAType &AA = A.getAAFor(QueryingAA, RVPos); LLVM_DEBUG(dbgs() << "[Attributor] RV: " << RV << " AA: " << AA.getAsStr() << " @ " << RVPos << "\n"); @@ -453,7 +454,8 @@ /// Helper class for generic deduction: return value -> returned position. template + typename StateType = typename BaseType::StateType, + bool PropagateCallBaseContext = false> struct AAReturnedFromReturnedValues : public BaseType { AAReturnedFromReturnedValues(const IRPosition &IRP, Attributor &A) : BaseType(IRP, A) {} @@ -461,7 +463,9 @@ /// See AbstractAttribute::updateImpl(...). ChangeStatus updateImpl(Attributor &A) override { StateType S(StateType::getBestState(this->getState())); - clampReturnedValueStates(A, *this, S); + clampReturnedValueStates( + A, *this, S, + PropagateCallBaseContext ? this->getCallBaseContext() : nullptr); // TODO: If we know we visited all returned values, thus no are assumed // dead, we can take the known information from the state T. 
return clampStateAndIndicateChange(this->getState(), S); @@ -515,17 +519,57 @@ S ^= *T; } -/// Helper class for generic deduction: call site argument -> argument position. +/// Bridge argument position and call base context. If \p Pos has a context, +/// fold the call site argument state into \p State and return true, else false. template +bool getArgumentStateFromCallBaseContext(Attributor &A, + BaseType &QueryingAttribute, + IRPosition &Pos, StateType &State) { + assert((Pos.getPositionKind() == IRPosition::IRP_ARGUMENT) && + "Expected an 'argument' position !"); + const CallBase *CBContext = Pos.getCallBaseContext(); + if (!CBContext) + return false; + + int ArgNo = Pos.getArgNo(); + assert(ArgNo >= 0 && "Invalid Arg No!"); // 0 is a valid argument number. + + const auto &AA = A.getAAFor( + QueryingAttribute, IRPosition::callsite_argument(*CBContext, ArgNo)); + const StateType &CBArgumentState = + static_cast(AA.getState()); + + LLVM_DEBUG(dbgs() << "[Attributor] Briding Call site context to argument" + << "Position:" << Pos << "CB Arg state:" << CBArgumentState + << "\n"); + + // NOTE: If we want to do call site grouping it should happen here. + State ^= CBArgumentState; + return true; +} + +/// Helper class for generic deduction: call site argument -> argument position. +template struct AAArgumentFromCallSiteArguments : public BaseType { AAArgumentFromCallSiteArguments(const IRPosition &IRP, Attributor &A) : BaseType(IRP, A) {} /// See AbstractAttribute::updateImpl(...). ChangeStatus updateImpl(Attributor &A) override { - StateType S(StateType::getBestState(this->getState())); + StateType S = StateType::getBestState(this->getState()); + + if (BridgeCallBaseContext) { + bool Success = + getArgumentStateFromCallBaseContext( + A, *this, this->getIRPosition(), S); + if (Success) + return clampStateAndIndicateChange(this->getState(), S); + } clampCallSiteArgumentStates(A, *this, S); + // TODO: If we know we visited all incoming values, thus no are assumed // dead, we can take the known information from the state T. 
return clampStateAndIndicateChange(this->getState(), S); @@ -534,7 +578,8 @@ /// Helper class for generic replication: function returned -> cs returned. template + typename StateType = typename BaseType::StateType, + bool IntroduceCallBaseContext = false> struct AACallSiteReturnedFromReturned : public BaseType { AACallSiteReturnedFromReturned(const IRPosition &IRP, Attributor &A) : BaseType(IRP, A) {} @@ -552,7 +597,13 @@ if (!AssociatedFunction) return S.indicatePessimisticFixpoint(); - IRPosition FnPos = IRPosition::returned(*AssociatedFunction); + CallBase &CBContext = static_cast(this->getAnchorValue()); + if (IntroduceCallBaseContext) + LLVM_DEBUG(dbgs() << "[Attributor] Introducing call base context:" + << CBContext << "\n"); + + IRPosition FnPos = IRPosition::returned( + *AssociatedFunction, IntroduceCallBaseContext ? &CBContext : nullptr); const AAType &AA = A.getAAFor(*this, FnPos); return clampStateAndIndicateChange( S, static_cast(AA.getState())); @@ -6771,9 +6822,11 @@ struct AAValueConstantRangeArgument final : AAArgumentFromCallSiteArguments< - AAValueConstantRange, AAValueConstantRangeImpl, IntegerRangeState> { + AAValueConstantRange, AAValueConstantRangeImpl, IntegerRangeState, + true /* BridgeCallBaseContext */> { using Base = AAArgumentFromCallSiteArguments< - AAValueConstantRange, AAValueConstantRangeImpl, IntegerRangeState>; + AAValueConstantRange, AAValueConstantRangeImpl, IntegerRangeState, + true /* BridgeCallBaseContext */>; AAValueConstantRangeArgument(const IRPosition &IRP, Attributor &A) : Base(IRP, A) {} @@ -6794,9 +6847,14 @@ struct AAValueConstantRangeReturned : AAReturnedFromReturnedValues { - using Base = AAReturnedFromReturnedValues; + AAValueConstantRangeImpl, + AAValueConstantRangeImpl::StateType, + /* PropogateCallBaseContext */ true> { + using Base = + AAReturnedFromReturnedValues; AAValueConstantRangeReturned(const IRPosition &IRP, Attributor &A) : Base(IRP, A) {} @@ -6862,13 +6920,13 @@ if (!LHS->getType()->isIntegerTy() || 
!RHS->getType()->isIntegerTy()) return false; - auto &LHSAA = - A.getAAFor(*this, IRPosition::value(*LHS)); + auto &LHSAA = A.getAAFor( + *this, IRPosition::value(*LHS, getCallBaseContext())); QuerriedAAs.push_back(&LHSAA); auto LHSAARange = LHSAA.getAssumedConstantRange(A, CtxI); - auto &RHSAA = - A.getAAFor(*this, IRPosition::value(*RHS)); + auto &RHSAA = A.getAAFor( + *this, IRPosition::value(*RHS, getCallBaseContext())); QuerriedAAs.push_back(&RHSAA); auto RHSAARange = RHSAA.getAssumedConstantRange(A, CtxI); @@ -6891,8 +6949,8 @@ if (!OpV.getType()->isIntegerTy()) return false; - auto &OpAA = - A.getAAFor(*this, IRPosition::value(OpV)); + auto &OpAA = A.getAAFor( + *this, IRPosition::value(OpV, getCallBaseContext())); QuerriedAAs.push_back(&OpAA); T.unionAssumed( OpAA.getAssumed().castOp(CastI->getOpcode(), getState().getBitWidth())); @@ -6909,11 +6967,11 @@ if (!LHS->getType()->isIntegerTy() || !RHS->getType()->isIntegerTy()) return false; - auto &LHSAA = - A.getAAFor(*this, IRPosition::value(*LHS)); + auto &LHSAA = A.getAAFor( + *this, IRPosition::value(*LHS, getCallBaseContext())); QuerriedAAs.push_back(&LHSAA); - auto &RHSAA = - A.getAAFor(*this, IRPosition::value(*RHS)); + auto &RHSAA = A.getAAFor( + *this, IRPosition::value(*RHS, getCallBaseContext())); QuerriedAAs.push_back(&RHSAA); auto LHSAARange = LHSAA.getAssumedConstantRange(A, CtxI); @@ -7044,10 +7102,16 @@ struct AAValueConstantRangeCallSiteReturned : AACallSiteReturnedFromReturned { + AAValueConstantRangeImpl, + AAValueConstantRangeImpl::StateType, + /* IntroduceCallBaseContext */ true> { AAValueConstantRangeCallSiteReturned(const IRPosition &IRP, Attributor &A) : AACallSiteReturnedFromReturned(IRP, A) {} + AAValueConstantRangeImpl, + AAValueConstantRangeImpl::StateType, + /* IntroduceCallBaseContext */ true>(IRP, + A) { + } /// See AbstractAttribute::initialize(...). 
void initialize(Attributor &A) override { diff --git a/llvm/test/Transforms/Attributor/cb_range.ll b/llvm/test/Transforms/Attributor/cb_range.ll new file mode 100644 --- /dev/null +++ b/llvm/test/Transforms/Attributor/cb_range.ll @@ -0,0 +1,145 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --function-signature --scrub-attributes + +; call site specific analysis is disabled + +; RUN: opt -attributor -attributor-manifest-internal -attributor-max-iterations-verify -attributor-annotate-decl-cs -attributor-max-iterations=1 -attributor-enable-call-site-specific-deduction=false -S < %s | FileCheck %s --check-prefixes=CHECK,NOT_CGSCC_NPM,NOT_CGSCC_OPM,NOT_TUNIT_NPM,IS__TUNIT____,IS________OPM,IS__TUNIT_OPM +; RUN: opt -aa-pipeline=basic-aa -passes=attributor -attributor-manifest-internal -attributor-max-iterations-verify -attributor-annotate-decl-cs -attributor-max-iterations=1 -attributor-enable-call-site-specific-deduction=false -S < %s | FileCheck %s --check-prefixes=CHECK,NOT_CGSCC_OPM,NOT_CGSCC_NPM,NOT_TUNIT_OPM,IS__TUNIT____,IS________NPM,IS__TUNIT_NPM +; RUN: opt -attributor-cgscc -attributor-manifest-internal -attributor-annotate-decl-cs -attributor-enable-call-site-specific-deduction=false -S < %s | FileCheck %s --check-prefixes=CHECK,NOT_TUNIT_NPM,NOT_TUNIT_OPM,NOT_CGSCC_NPM,IS__CGSCC____,IS________OPM,IS__CGSCC_OPM +; RUN: opt -aa-pipeline=basic-aa -passes=attributor-cgscc -attributor-manifest-internal -attributor-annotate-decl-cs -attributor-enable-call-site-specific-deduction=false -S < %s | FileCheck %s --check-prefixes=CHECK,NOT_TUNIT_NPM,NOT_TUNIT_OPM,NOT_CGSCC_OPM,IS__CGSCC____,IS________NPM,IS__CGSCC_NPM + +; call site specific-deduction analysis is enabled + +; RUN: opt -attributor -attributor-manifest-internal -attributor-max-iterations-verify -attributor-annotate-decl-cs -attributor-max-iterations=1 -attributor-enable-call-site-specific-deduction=true -S < %s | FileCheck %s 
--check-prefixes=CHECK,NOT_CGSCC_NPM,NOT_CGSCC_OPM,NOT_TUNIT_NPM,IS__TUNIT____,IS________OPM,IS__TUNIT_OPM,CHECK_ENABLED,NOT_CGSCC_NPM_ENABLED,NOT_CGSCC_OPM_ENABLED,NOT_TUNIT_NPM_ENABLED,IS__TUNIT_____ENABLED,IS________OPM_ENABLED,IS__TUNIT_OPM_ENABLED +; RUN: opt -aa-pipeline=basic-aa -passes=attributor -attributor-manifest-internal -attributor-max-iterations-verify -attributor-annotate-decl-cs -attributor-max-iterations=1 -attributor-enable-call-site-specific-deduction=true -S < %s | FileCheck %s --check-prefixes=CHECK,NOT_CGSCC_OPM,NOT_CGSCC_NPM,NOT_TUNIT_OPM,IS__TUNIT____,IS________NPM,IS__TUNIT_NPM,CHECK_ENABLED,NOT_CGSCC_OPM_ENABLED,NOT_CGSCC_NPM_ENABLED,NOT_TUNIT_OPM_ENABLED,IS__TUNIT_____ENABLED,IS________NPM_ENABLED,IS__TUNIT_NPM_ENABLED +; RUN: opt -attributor-cgscc -attributor-manifest-internal -attributor-annotate-decl-cs -attributor-enable-call-site-specific-deduction=true -S < %s | FileCheck %s --check-prefixes=CHECK,NOT_TUNIT_NPM,NOT_TUNIT_OPM,NOT_CGSCC_NPM,IS__CGSCC____,IS________OPM,IS__CGSCC_OPM,CHECK_ENABLED,NOT_TUNIT_NPM_ENABLED,NOT_TUNIT_OPM_ENABLED,NOT_CGSCC_NPM_ENABLED,IS__CGSCC_____ENABLED,IS________OPM_ENABLED,IS__CGSCC_OPM_ENABLED +; RUN: opt -aa-pipeline=basic-aa -passes=attributor-cgscc -attributor-manifest-internal -attributor-annotate-decl-cs -attributor-enable-call-site-specific-deduction=true -S < %s | FileCheck %s --check-prefixes=CHECK,NOT_TUNIT_NPM,NOT_TUNIT_OPM,NOT_CGSCC_OPM,IS__CGSCC____,IS________NPM,IS__CGSCC_NPM,CHECK_ENABLED,NOT_TUNIT_NPM_ENABLED,NOT_TUNIT_OPM_ENABLED,NOT_CGSCC_OPM_ENABLED,IS__CGSCC_____ENABLED,IS________NPM_ENABLED,IS__CGSCC_NPM_ENABLED + +define i32 @test_range(i32 %unknown) { +; CHECK-LABEL: define {{[^@]+}}@test_range +; CHECK-SAME: (i32 [[UNKNOWN:%.*]]) +; CHECK-NEXT: [[TMP1:%.*]] = icmp sgt i32 [[UNKNOWN]], 100 +; CHECK-NEXT: [[TMP2:%.*]] = select i1 [[TMP1]], i32 100, i32 0 +; CHECK-NEXT: ret i32 [[TMP2]] +; + %1 = icmp sgt i32 %unknown, 100 + %2 = select i1 %1, i32 100, i32 0 + ret i32 %2 +} + 
+define i32 @test1(i32 %unknown, i32 %b) { +; IS__TUNIT____-LABEL: define {{[^@]+}}@test1 +; IS__TUNIT____-SAME: (i32 [[UNKNOWN:%.*]], i32 [[B:%.*]]) +; IS__TUNIT____-NEXT: [[TMP1:%.*]] = call i32 @test_range(i32 [[UNKNOWN]]) #0, !range !0 +; IS__TUNIT____-NEXT: [[TMP2:%.*]] = sub nsw i32 [[TMP1]], [[B]] +; IS__TUNIT____-NEXT: ret i32 [[TMP2]] +; +; IS__CGSCC____-LABEL: define {{[^@]+}}@test1 +; IS__CGSCC____-SAME: (i32 [[UNKNOWN:%.*]], i32 [[B:%.*]]) +; IS__CGSCC____-NEXT: [[TMP1:%.*]] = call i32 @test_range(i32 [[UNKNOWN]]) +; IS__CGSCC____-NEXT: [[TMP2:%.*]] = sub nsw i32 [[TMP1]], [[B]] +; IS__CGSCC____-NEXT: ret i32 [[TMP2]] +; + %1 = call i32 @test_range(i32 %unknown) + %2 = sub nsw i32 %1, %b + ret i32 %2 +} + +define i32 @test2(i32 %unknown, i32 %b) { +; IS__TUNIT____-LABEL: define {{[^@]+}}@test2 +; IS__TUNIT____-SAME: (i32 [[UNKNOWN:%.*]], i32 [[B:%.*]]) +; IS__TUNIT____-NEXT: [[TMP1:%.*]] = call i32 @test_range(i32 [[UNKNOWN]]) #0, !range !0 +; IS__TUNIT____-NEXT: [[TMP2:%.*]] = add nsw i32 [[TMP1]], [[B]] +; IS__TUNIT____-NEXT: ret i32 [[TMP2]] +; +; IS__CGSCC____-LABEL: define {{[^@]+}}@test2 +; IS__CGSCC____-SAME: (i32 [[UNKNOWN:%.*]], i32 [[B:%.*]]) +; IS__CGSCC____-NEXT: [[TMP1:%.*]] = call i32 @test_range(i32 [[UNKNOWN]]) +; IS__CGSCC____-NEXT: [[TMP2:%.*]] = add nsw i32 [[TMP1]], [[B]] +; IS__CGSCC____-NEXT: ret i32 [[TMP2]] +; + %1 = call i32 @test_range(i32 %unknown) + %2 = add nsw i32 %1, %b + ret i32 %2 +} + +; Positive checks + +define i32 @test1_pcheck(i32 %unknown) { +; IS__CGSCC____-LABEL: define {{[^@]+}}@test1_pcheck +; IS__CGSCC____-SAME: (i32 [[UNKNOWN:%.*]]) +; IS__CGSCC____-NEXT: [[TMP1:%.*]] = call i32 @test1(i32 [[UNKNOWN]], i32 20) +; IS__CGSCC____-NEXT: [[TMP2:%.*]] = icmp sle i32 [[TMP1]], 90 +; IS__CGSCC____-NEXT: [[TMP3:%.*]] = zext i1 [[TMP2]] to i32 +; IS__CGSCC____-NEXT: ret i32 [[TMP3]] +; +; IS__TUNIT_____ENABLED-LABEL: define {{[^@]+}}@test1_pcheck +; IS__TUNIT_____ENABLED-SAME: (i32 [[UNKNOWN:%.*]]) +; 
IS__TUNIT_____ENABLED-NEXT: ret i32 1 +; + %1 = call i32 @test1(i32 %unknown, i32 20) + %2 = icmp sle i32 %1, 90 + %3 = zext i1 %2 to i32 + ret i32 %3 +} + +define i32 @test2_pcheck(i32 %unknown) { +; IS__CGSCC____-LABEL: define {{[^@]+}}@test2_pcheck +; IS__CGSCC____-SAME: (i32 [[UNKNOWN:%.*]]) +; IS__CGSCC____-NEXT: [[TMP1:%.*]] = call i32 @test2(i32 [[UNKNOWN]], i32 20) +; IS__CGSCC____-NEXT: [[TMP2:%.*]] = icmp sge i32 [[TMP1]], 20 +; IS__CGSCC____-NEXT: [[TMP3:%.*]] = zext i1 [[TMP2]] to i32 +; IS__CGSCC____-NEXT: ret i32 [[TMP3]] +; +; IS__TUNIT_____ENABLED-LABEL: define {{[^@]+}}@test2_pcheck +; IS__TUNIT_____ENABLED-SAME: (i32 [[UNKNOWN:%.*]]) +; IS__TUNIT_____ENABLED-NEXT: ret i32 1 +; + %1 = call i32 @test2(i32 %unknown, i32 20) + %2 = icmp sge i32 %1, 20 + %3 = zext i1 %2 to i32 + ret i32 %3 +} + +; Negative checks + +define i32 @test1_ncheck(i32 %unknown) { +; IS__CGSCC____-LABEL: define {{[^@]+}}@test1_ncheck +; IS__CGSCC____-SAME: (i32 [[UNKNOWN:%.*]]) +; IS__CGSCC____-NEXT: [[TMP1:%.*]] = call i32 @test1(i32 [[UNKNOWN]], i32 20) +; IS__CGSCC____-NEXT: [[TMP2:%.*]] = icmp sle i32 [[TMP1]], 10 +; IS__CGSCC____-NEXT: [[TMP3:%.*]] = zext i1 [[TMP2]] to i32 +; IS__CGSCC____-NEXT: ret i32 [[TMP3]] +; +; IS__TUNIT_____ENABLED-LABEL: define {{[^@]+}}@test1_ncheck +; IS__TUNIT_____ENABLED-SAME: (i32 [[UNKNOWN:%.*]]) +; IS__TUNIT_____ENABLED-NEXT: [[TMP1:%.*]] = call i32 @test1(i32 [[UNKNOWN]], i32 20) #0, !range !1 +; IS__TUNIT_____ENABLED-NEXT: [[TMP2:%.*]] = icmp sle i32 [[TMP1]], 10 +; IS__TUNIT_____ENABLED-NEXT: [[TMP3:%.*]] = zext i1 [[TMP2]] to i32 +; IS__TUNIT_____ENABLED-NEXT: ret i32 [[TMP3]] +; + %1 = call i32 @test1(i32 %unknown, i32 20) + %2 = icmp sle i32 %1, 10 + %3 = zext i1 %2 to i32 + ret i32 %3 +} + +define i32 @test2_ncheck(i32 %unknown) { +; IS__CGSCC____-LABEL: define {{[^@]+}}@test2_ncheck +; IS__CGSCC____-SAME: (i32 [[UNKNOWN:%.*]]) +; IS__CGSCC____-NEXT: [[TMP1:%.*]] = call i32 @test2(i32 [[UNKNOWN]], i32 20) +; IS__CGSCC____-NEXT: 
[[TMP2:%.*]] = icmp sge i32 [[TMP1]], 30 +; IS__CGSCC____-NEXT: [[TMP3:%.*]] = zext i1 [[TMP2]] to i32 +; IS__CGSCC____-NEXT: ret i32 [[TMP3]] +; +; IS__TUNIT_____ENABLED-LABEL: define {{[^@]+}}@test2_ncheck +; IS__TUNIT_____ENABLED-SAME: (i32 [[UNKNOWN:%.*]]) +; IS__TUNIT_____ENABLED-NEXT: [[TMP1:%.*]] = call i32 @test2(i32 [[UNKNOWN]], i32 20) #0, !range !2 +; IS__TUNIT_____ENABLED-NEXT: [[TMP2:%.*]] = icmp sge i32 [[TMP1]], 30 +; IS__TUNIT_____ENABLED-NEXT: [[TMP3:%.*]] = zext i1 [[TMP2]] to i32 +; IS__TUNIT_____ENABLED-NEXT: ret i32 [[TMP3]] +; + %1 = call i32 @test2(i32 %unknown, i32 20) + %2 = icmp sge i32 %1, 30 + %3 = zext i1 %2 to i32 + ret i32 %3 +}