diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -484,10 +484,42 @@
   return 0;
 }
 
+bool isPairwiseAdd(const Instruction *I) {
+  if (I->getOpcode() != Instruction::FAdd)
+    return false;
+
+  assert(I->getNumOperands() == 2);
+
+  unsigned SumIndices = 0;
+
+  for (int i = 0; i < 2; i++) {
+    const auto *Ext = dyn_cast<ExtractElementInst>(I->getOperand(i));
+
+    if (!Ext || !isa<ConstantInt>(Ext->getOperand(1)))
+      return false;
+
+    unsigned Index = cast<ConstantInt>(Ext->getOperand(1))->getZExtValue();
+
+    if (Index != 0 && Index != 1)
+      return false;
+
+    SumIndices += Index;
+  }
+
+  return SumIndices == 1;
+}
+
 int AArch64TTIImpl::getVectorInstrCost(unsigned Opcode, Type *Val,
                                        unsigned Index, const Instruction *I) {
   assert(Val->isVectorTy() && "This must be a vector type");
 
+  // The extract is free if it is part of a pairwise add.
+  if (I && I->hasOneUse()) {
+    auto *SingleUser = cast<Instruction>(*I->user_begin());
+    if (isPairwiseAdd(SingleUser))
+      return 0;
+  }
+
   if (Index != -1U) {
     // Legalize the type.
     std::pair<int, MVT> LT = TLI->getTypeLegalizationCost(DL, Val);
diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -232,9 +232,16 @@
   unsigned Ext0Index = cast<ConstantInt>(Ext0->getOperand(1))->getZExtValue();
   unsigned Ext1Index = cast<ConstantInt>(Ext1->getOperand(1))->getZExtValue();
 
-  int Extract0Cost =
+  // Use instruction context to calculate costs for the current pattern.
+  int OldExtract0Cost = TTI.getVectorInstrCost(Instruction::ExtractElement,
+                                               VecTy, Ext0Index, Ext0);
+  int OldExtract1Cost = TTI.getVectorInstrCost(Instruction::ExtractElement,
+                                               VecTy, Ext1Index, Ext1);
+
+  // Get context-less costs for the extractelements in the replacement pattern.
+  int NewExtract0Cost =
       TTI.getVectorInstrCost(Instruction::ExtractElement, VecTy, Ext0Index);
-  int Extract1Cost =
+  int NewExtract1Cost =
       TTI.getVectorInstrCost(Instruction::ExtractElement, VecTy, Ext1Index);
 
   // A more expensive extract will always be replaced by a splat shuffle.
@@ -244,7 +251,8 @@
   // TODO: Evaluate whether that always results in lowest cost. Alternatively,
   // check the cost of creating a broadcast shuffle and shuffling both
   // operands to element 0.
-  int CheapExtractCost = std::min(Extract0Cost, Extract1Cost);
+  int MinNewExtractCost = std::min(NewExtract0Cost, NewExtract1Cost);
+  int MinOldExtractCost = std::min(OldExtract0Cost, OldExtract1Cost);
 
   // Extra uses of the extracts mean that we include those costs in the
   // vector total because those instructions will not be eliminated.
@@ -256,15 +264,15 @@
     // opcode (extelt V, C), (extelt V, C) --> extelt (opcode V, V), C
     bool HasUseTax = Ext0 == Ext1 ? !Ext0->hasNUses(2)
                                   : !Ext0->hasOneUse() || !Ext1->hasOneUse();
-    OldCost = CheapExtractCost + ScalarOpCost;
-    NewCost = VectorOpCost + CheapExtractCost + HasUseTax * CheapExtractCost;
+    OldCost = MinOldExtractCost + ScalarOpCost;
+    NewCost = VectorOpCost + MinNewExtractCost + HasUseTax * MinNewExtractCost;
   } else {
     // Handle the general case. Each extract is actually a different value:
     // opcode (extelt V0, C0), (extelt V1, C1) --> extelt (opcode V0, V1), C
-    OldCost = Extract0Cost + Extract1Cost + ScalarOpCost;
-    NewCost = VectorOpCost + CheapExtractCost +
-              !Ext0->hasOneUse() * Extract0Cost +
-              !Ext1->hasOneUse() * Extract1Cost;
+    OldCost = OldExtract0Cost + OldExtract1Cost + ScalarOpCost;
+    NewCost = VectorOpCost + MinNewExtractCost +
+              !Ext0->hasOneUse() * NewExtract0Cost +
+              !Ext1->hasOneUse() * NewExtract1Cost;
   }
 
   ConvertToShuffle = getShuffleExtract(Ext0, Ext1, PreferredExtractIndex);
diff --git a/llvm/test/CodeGen/AArch64/combine-vectors-faddp.ll b/llvm/test/CodeGen/AArch64/combine-vectors-faddp.ll
--- a/llvm/test/CodeGen/AArch64/combine-vectors-faddp.ll
+++ b/llvm/test/CodeGen/AArch64/combine-vectors-faddp.ll
@@ -4,10 +4,10 @@
 define float @test_no_combine_for_faddp(<2 x float> %a) {
 ; CHECK-LABEL: @test_no_combine_for_faddp(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[SHIFT:%.*]] = shufflevector <2 x float> [[A:%.*]], <2 x float> undef, <2 x i32> <i32 1, i32 undef>
-; CHECK-NEXT:    [[TMP0:%.*]] = fadd <2 x float> [[A]], [[SHIFT]]
-; CHECK-NEXT:    [[TMP1:%.*]] = extractelement <2 x float> [[TMP0]], i32 0
-; CHECK-NEXT:    ret float [[TMP1]]
+; CHECK-NEXT:    [[TMP0:%.*]] = extractelement <2 x float> [[A:%.*]], i32 0
+; CHECK-NEXT:    [[TMP1:%.*]] = extractelement <2 x float> [[A]], i32 1
+; CHECK-NEXT:    [[TMP2:%.*]] = fadd float [[TMP0]], [[TMP1]]
+; CHECK-NEXT:    ret float [[TMP2]]
 ;
 entry:
   %0 = extractelement <2 x float> %a, i32 0
@@ -19,10 +19,10 @@
 define float @test_no_combine_for_faddp_swapped(<2 x float> %a) {
 ; CHECK-LABEL: @test_no_combine_for_faddp_swapped(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[SHIFT:%.*]] = shufflevector <2 x float> [[A:%.*]], <2 x float> undef, <2 x i32> <i32 1, i32 undef>
-; CHECK-NEXT:    [[TMP0:%.*]] = fadd <2 x float> [[SHIFT]], [[A]]
-; CHECK-NEXT:    [[TMP1:%.*]] = extractelement <2 x float> [[TMP0]], i64 0
-; CHECK-NEXT:    ret float [[TMP1]]
+; CHECK-NEXT:    [[TMP0:%.*]] = extractelement <2 x float> [[A:%.*]], i32 1
+; CHECK-NEXT:    [[TMP1:%.*]] = extractelement <2 x float> [[A]], i32 0
+; CHECK-NEXT:    [[TMP2:%.*]] = fadd float [[TMP0]], [[TMP1]]
+; CHECK-NEXT:    ret float [[TMP2]]
 ;
 entry:
   %0 = extractelement <2 x float> %a, i32 1
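
For reference, a minimal sketch of the scalar pairwise-add shape that the new
isPairwiseAdd helper matches (the function name below is illustrative, not part
of the patch; the tests above exercise exactly this pattern):

; With this patch the two extracts are costed as free, so VectorCombine keeps
; the scalar fadd instead of rewriting it into a shuffle plus vector fadd,
; leaving a pattern the AArch64 backend can select as a single faddp.
define float @pairwise_add_sketch(<2 x float> %a) {
entry:
  %lo = extractelement <2 x float> %a, i32 0
  %hi = extractelement <2 x float> %a, i32 1
  %sum = fadd float %lo, %hi
  ret float %sum
}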