diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -1171,16 +1171,30 @@
     /// \returns the score of placing \p V1 and \p V2 in consecutive lanes.
     /// Also, checks if \p V1 and \p V2 are compatible with instructions in \p
     /// MainAltOps.
-    static int getShallowScore(Value *V1, Value *V2, const DataLayout &DL,
-                               ScalarEvolution &SE, int NumLanes,
-                               ArrayRef<Value *> MainAltOps,
-                               const TargetTransformInfo *TTI) {
+    /// \p U1 and \p U2 are the instructions using \p V1 and \p V2 in the
+    /// look-ahead exploration, or nullptr at the top level.
+    int getShallowScore(Value *V1, Value *V2, Instruction *U1, Instruction *U2,
+                        const DataLayout &DL, ScalarEvolution &SE, int NumLanes,
+                        ArrayRef<Value *> MainAltOps) {
       if (V1 == V2) {
         if (isa<LoadInst>(V1)) {
+          // V1 == V2 here, so a single value needs checking. Returns true if
+          // every user of V is U1, U2 or already part of the tree, i.e. V
+          // won't need to be extracted.
+          auto AllUsersAreInternal = [NumLanes, U1, U2, this](Value *V) {
+            // Bail out if we have too many uses to save compilation time.
+            static constexpr unsigned VectorUsesLimit = 3;
+            unsigned Limit = VectorUsesLimit * NumLanes + 1;
+            if (V->hasNUsesOrMore(Limit))
+              return false;
+            return llvm::all_of(V->users(), [U1, U2, this](Value *U) {
+              return U == U1 || U == U2 || R.getTreeEntry(U) != nullptr;
+            });
+          };
           // A broadcast of a load can be cheaper on some targets.
-          // TODO: For now accept a broadcast load with no other internal uses.
-          if (TTI->isLegalBroadcastLoad(V1->getType(), NumLanes) &&
-              V1->getNumUses() == (unsigned)NumLanes)
+          if (R.TTI->isLegalBroadcastLoad(V1->getType(), NumLanes) &&
+              (V1->getNumUses() == (unsigned)NumLanes ||
+               AllUsersAreInternal(V1)))
             return VLOperands::ScoreSplatLoads;
         }
         return VLOperands::ScoreSplat;
@@ -1358,12 +1372,13 @@
     /// Look-ahead SLP: Auto-vectorization in the presence of commutative
     /// operations, CGO 2018 by Vasileios Porpodas, Rodrigo C. O. Rocha,
     /// Luís F. W. Góes
-    int getScoreAtLevelRec(Value *LHS, Value *RHS, int CurrLevel, int MaxLevel,
+    int getScoreAtLevelRec(Value *LHS, Value *RHS, Instruction *U1,
+                           Instruction *U2, int CurrLevel, int MaxLevel,
                            ArrayRef<Value *> MainAltOps) {
 
       // Get the shallow score of V1 and V2.
       int ShallowScoreAtThisLevel =
-          getShallowScore(LHS, RHS, DL, SE, getNumLanes(), MainAltOps, R.TTI);
+          getShallowScore(LHS, RHS, U1, U2, DL, SE, getNumLanes(), MainAltOps);
 
       // If reached MaxLevel,
       // or if V1 and V2 are not instructions,
@@ -1406,7 +1421,7 @@
           // Recursively calculate the cost at each level
           int TmpScore =
               getScoreAtLevelRec(I1->getOperand(OpIdx1), I2->getOperand(OpIdx2),
-                                 CurrLevel + 1, MaxLevel, None);
+                                 I1, I2, CurrLevel + 1, MaxLevel, None);
           // Look for the best score.
           if (TmpScore > VLOperands::ScoreFail && TmpScore > MaxTmpScore) {
             MaxTmpScore = TmpScore;
@@ -1436,8 +1451,10 @@
     int getLookAheadScore(Value *LHS, Value *RHS, ArrayRef<Value *> MainAltOps,
                           int Lane, unsigned OpIdx, unsigned Idx,
                           bool &IsUsed) {
-      int Score =
-          getScoreAtLevelRec(LHS, RHS, 1, LookAheadMaxDepth, MainAltOps);
+      // The top-level LHS/RHS have no user within the explored expression, so
+      // start with null U1/U2; the recursion passes down the parent I1/I2.
+      int Score = getScoreAtLevelRec(LHS, RHS, /*U1=*/nullptr, /*U2=*/nullptr,
+                                     1, LookAheadMaxDepth, MainAltOps);
       if (Score) {
         int SplatScore = getSplatScore(Lane, OpIdx, Idx);
         if (Score <= -SplatScore) {
diff --git a/llvm/test/Transforms/SLPVectorizer/X86/lookahead.ll b/llvm/test/Transforms/SLPVectorizer/X86/lookahead.ll
--- a/llvm/test/Transforms/SLPVectorizer/X86/lookahead.ll
+++ b/llvm/test/Transforms/SLPVectorizer/X86/lookahead.ll
@@ -706,21 +706,19 @@
 ; CHECK-NEXT:    [[TMP1:%.*]] = load <2 x double>, <2 x double>* [[TMP0]], align 8
 ; CHECK-NEXT:    [[GEP_2_0:%.*]] = getelementptr inbounds double, double* [[ARRAY2:%.*]], i64 0
 ; CHECK-NEXT:    [[GEP_2_1:%.*]] = getelementptr inbounds double, double* [[ARRAY2]], i64 1
-; CHECK-NEXT:    [[TMP2:%.*]] = bitcast double* [[GEP_2_0]] to <2 x double>*
-; CHECK-NEXT:    [[TMP3:%.*]] = load <2 x double>, <2 x double>* [[TMP2]], align 8
-; CHECK-NEXT:    [[SHUFFLE:%.*]] = shufflevector <2 x double> [[TMP3]], <2 x double> poison, <2 x i32> <i32 1, i32 0>
-; CHECK-NEXT:    [[TMP4:%.*]] = fmul <2 x double> [[TMP1]], [[SHUFFLE]]
-; CHECK-NEXT:    [[TMP5:%.*]] = extractelement <2 x double> [[SHUFFLE]], i32 1
-; CHECK-NEXT:    [[TMP6:%.*]] = insertelement <2 x double> poison, double [[TMP5]], i32 0
-; CHECK-NEXT:    [[TMP7:%.*]] = extractelement <2 x double> [[SHUFFLE]], i32 0
-; CHECK-NEXT:    [[TMP8:%.*]] = insertelement <2 x double> [[TMP6]], double [[TMP7]], i32 1
-; CHECK-NEXT:    [[TMP9:%.*]] = fmul <2 x double> [[TMP1]], [[TMP8]]
-; CHECK-NEXT:    [[TMP10:%.*]] = fadd <2 x double> [[TMP4]], [[TMP9]]
-; CHECK-NEXT:    [[TMP11:%.*]] = insertelement <2 x double> [[TMP6]], double [[TMP5]], i32 1
-; CHECK-NEXT:    [[TMP12:%.*]] = fsub <2 x double> [[TMP10]], [[TMP11]]
-; CHECK-NEXT:    [[TMP13:%.*]] = extractelement <2 x double> [[TMP12]], i32 0
-; CHECK-NEXT:    [[TMP14:%.*]] = extractelement <2 x double> [[TMP12]], i32 1
-; CHECK-NEXT:    [[RES:%.*]] = fadd double [[TMP13]], [[TMP14]]
+; CHECK-NEXT:    [[LD_2_0:%.*]] = load double, double* [[GEP_2_0]], align 8
+; CHECK-NEXT:    [[LD_2_1:%.*]] = load double, double* [[GEP_2_1]], align 8
+; CHECK-NEXT:    [[TMP2:%.*]] = insertelement <2 x double> poison, double [[LD_2_0]], i32 0
+; CHECK-NEXT:    [[TMP3:%.*]] = insertelement <2 x double> [[TMP2]], double [[LD_2_0]], i32 1
+; CHECK-NEXT:    [[TMP4:%.*]] = fmul <2 x double> [[TMP1]], [[TMP3]]
+; CHECK-NEXT:    [[TMP5:%.*]] = insertelement <2 x double> poison, double [[LD_2_1]], i32 0
+; CHECK-NEXT:    [[TMP6:%.*]] = insertelement <2 x double> [[TMP5]], double [[LD_2_1]], i32 1
+; CHECK-NEXT:    [[TMP7:%.*]] = fmul <2 x double> [[TMP1]], [[TMP6]]
+; CHECK-NEXT:    [[TMP8:%.*]] = fadd <2 x double> [[TMP4]], [[TMP7]]
+; CHECK-NEXT:    [[TMP9:%.*]] = fsub <2 x double> [[TMP8]], [[TMP3]]
+; CHECK-NEXT:    [[TMP10:%.*]] = extractelement <2 x double> [[TMP9]], i32 0
+; CHECK-NEXT:    [[TMP11:%.*]] = extractelement <2 x double> [[TMP9]], i32 1
+; CHECK-NEXT:    [[RES:%.*]] = fadd double [[TMP10]], [[TMP11]]
 ; CHECK-NEXT:    ret double [[RES]]
 ;
 entry: