diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -6253,21 +6253,37 @@
 }
 
 static std::pair<InstructionCost, InstructionCost>
-getVectorCallCosts(CallInst *CI, FixedVectorType *VecTy,
+getVectorCallCosts(ArrayRef<Value *> CIs, FixedVectorType *VecTy,
                    TargetTransformInfo *TTI, TargetLibraryInfo *TLI) {
+  auto *CI = cast<CallInst>(CIs[0]);
   Intrinsic::ID ID = getVectorIntrinsicIDForCall(CI, TLI);
 
   // Calculate the cost of the scalar and vector calls.
   SmallVector<Type *, 4> VecTys;
-  for (Use &Arg : CI->args())
+  SmallVector<const Value *> Args;
+  for (Use &Arg : CI->args()) {
     VecTys.push_back(
         FixedVectorType::get(Arg->getType(), VecTy->getNumElements()));
+    Args.push_back(Arg);
+  }
+
+  bool CanUseArgs = true;
+  for (Value *C : CIs) {
+    auto *CI = cast<CallInst>(C);
+    if (!all_of(enumerate(CI->args()), [&Args](const auto &A) {
+          return !isa<Constant>(A.value()) || A.value() == Args[A.index()];
+        })) {
+      Args.clear();
+      break;
+    }
+  }
+
   FastMathFlags FMF;
   if (auto *FPCI = dyn_cast<FPMathOperator>(CI))
     FMF = FPCI->getFastMathFlags();
-  SmallVector<const Value *> Arguments(CI->args());
-  IntrinsicCostAttributes CostAttrs(ID, VecTy, Arguments, VecTys, FMF,
-                                    dyn_cast<IntrinsicInst>(CI));
+  IntrinsicCostAttributes CostAttrs(ID, VecTy, Args, VecTys, FMF,
+                                    !Args.empty() ? dyn_cast<IntrinsicInst>(CI)
+                                                  : nullptr);
   auto IntrinsicCost =
       TTI->getIntrinsicInstrCost(CostAttrs, TTI::TCK_RecipThroughput);
@@ -7629,8 +7645,7 @@
                                      CI->getFunctionType()->params(), CostKind);
     };
     auto GetVectorCost = [=](InstructionCost CommonCost) {
-      auto *CI = cast<CallInst>(VL0);
-      auto VecCallCosts = getVectorCallCosts(CI, VecTy, TTI, TLI);
+      auto VecCallCosts = getVectorCallCosts(VL, VecTy, TTI, TLI);
       return std::min(VecCallCosts.first, VecCallCosts.second) + CommonCost;
     };
     return GetCostDiff(GetScalarCost, GetVectorCost);
diff --git a/llvm/test/Transforms/SLPVectorizer/AArch64/fshl.ll b/llvm/test/Transforms/SLPVectorizer/AArch64/fshl.ll
--- a/llvm/test/Transforms/SLPVectorizer/AArch64/fshl.ll
+++ b/llvm/test/Transforms/SLPVectorizer/AArch64/fshl.ll
@@ -6,21 +6,18 @@
 ; CHECK-LABEL: define i64 @fshl
 ; CHECK-SAME: (i64 [[OR1:%.*]], i64 [[OR2:%.*]], i64 [[OR3:%.*]]) {
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[TMP0:%.*]] = insertelement <2 x i64> poison, i64 [[OR2]], i32 0
-; CHECK-NEXT:    [[TMP1:%.*]] = insertelement <2 x i64> [[TMP0]], i64 [[OR3]], i32 1
-; CHECK-NEXT:    [[TMP2:%.*]] = call <2 x i64> @llvm.fshl.v2i64(<2 x i64> [[TMP1]], <2 x i64> zeroinitializer, <2 x i64> <i64 1, i64 2>)
-; CHECK-NEXT:    [[TMP3:%.*]] = insertelement <2 x i64> <i64 poison, i64 0>, i64 [[OR1]], i32 0
-; CHECK-NEXT:    [[TMP4:%.*]] = shufflevector <2 x i64> [[TMP1]], <2 x i64> <i64 poison, i64 0>, <2 x i32> <i32 0, i32 3>
-; CHECK-NEXT:    [[TMP5:%.*]] = call <2 x i64> @llvm.fshl.v2i64(<2 x i64> [[TMP3]], <2 x i64> [[TMP4]], <2 x i64> <i64 17, i64 21>)
-; CHECK-NEXT:    [[TMP6:%.*]] = shufflevector <2 x i64> [[TMP3]], <2 x i64> <i64 0, i64 poison>, <2 x i32> <i32 2, i32 0>
-; CHECK-NEXT:    [[TMP7:%.*]] = xor <2 x i64> [[TMP2]], [[TMP6]]
-; CHECK-NEXT:    [[TMP8:%.*]] = add <2 x i64> [[TMP7]], [[TMP3]]
-; CHECK-NEXT:    [[TMP9:%.*]] = xor <2 x i64> [[TMP5]], [[TMP8]]
-; CHECK-NEXT:    [[TMP10:%.*]] = extractelement <2 x i64> [[TMP9]], i32 0
-; CHECK-NEXT:    [[TMP11:%.*]] = extractelement <2 x i64> [[TMP8]], i32 1
-; CHECK-NEXT:    [[ADD3:%.*]] = or i64 [[TMP10]], [[TMP11]]
-; CHECK-NEXT:    [[TMP12:%.*]] = extractelement <2 x i64> [[TMP9]], i32 1
-; CHECK-NEXT:    [[XOR5:%.*]] = xor i64 [[ADD3]], [[TMP12]]
+; CHECK-NEXT:    [[OR4:%.*]] = tail call i64 @llvm.fshl.i64(i64 [[OR2]], i64 0, i64 1)
+; CHECK-NEXT:    [[XOR1:%.*]] = xor i64 [[OR4]], 0
+; CHECK-NEXT:    [[OR5:%.*]] = tail call i64 @llvm.fshl.i64(i64 [[OR3]], i64 0, i64 2)
+; CHECK-NEXT:    [[XOR2:%.*]] = xor i64 [[OR5]], [[OR1]]
+; CHECK-NEXT:    [[ADD1:%.*]] = add i64 [[XOR1]], [[OR1]]
+; CHECK-NEXT:    [[ADD2:%.*]] = add i64 0, [[XOR2]]
+; CHECK-NEXT:    [[OR6:%.*]] = tail call i64 @llvm.fshl.i64(i64 [[OR1]], i64 [[OR2]], i64 17)
+; CHECK-NEXT:    [[XOR3:%.*]] = xor i64 [[OR6]], [[ADD1]]
+; CHECK-NEXT:    [[OR7:%.*]] = tail call i64 @llvm.fshl.i64(i64 0, i64 0, i64 21)
+; CHECK-NEXT:    [[XOR4:%.*]] = xor i64 [[OR7]], [[ADD2]]
+; CHECK-NEXT:    [[ADD3:%.*]] = or i64 [[XOR3]], [[ADD2]]
+; CHECK-NEXT:    [[XOR5:%.*]] = xor i64 [[ADD3]], [[XOR4]]
 ; CHECK-NEXT:    ret i64 [[XOR5]]
 ;
 entry: