diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -38,19 +38,23 @@
     "disable-vector-combine", cl::init(false), cl::Hidden,
     cl::desc("Disable all vector combine transforms"));
 
-/// Compare the relative costs of extracts followed by scalar operation vs.
-/// vector operation followed by extract:
-/// opcode (extelt V0, C), (extelt V1, C) --> extelt (opcode V0, V1), C
-/// Unless the vector op is much more expensive than the scalar op, this
-/// eliminates an extract.
+static cl::opt<bool> DisableBinopExtractShuffle(
+    "disable-binop-extract-shuffle", cl::init(false), cl::Hidden,
+    cl::desc("Disable binop extract to shuffle transforms"));
+
+
+/// Compare the relative costs of 2 extracts followed by scalar operation vs.
+/// vector operation(s) followed by extract. Return true if the existing
+/// instructions are cheaper than a vector alternative. Otherwise, return false
+/// and, if one of the extracts should be transformed to a shufflevector, set
+/// \p ConvertToShuffle to that extract instruction.
 static bool isExtractExtractCheap(Instruction *Ext0, Instruction *Ext1,
                                   unsigned Opcode,
-                                  const TargetTransformInfo &TTI) {
+                                  const TargetTransformInfo &TTI,
+                                  Instruction *&ConvertToShuffle) {
   assert(isa<ConstantInt>(Ext0->getOperand(1)) &&
-         (cast<ConstantInt>(Ext0->getOperand(1))->getZExtValue() ==
-          cast<ConstantInt>(Ext1->getOperand(1))->getZExtValue()) &&
-         "Expected same constant extract index");
-
+         isa<ConstantInt>(Ext1->getOperand(1)) &&
+         "Expected constant extract indexes");
   Type *ScalarTy = Ext0->getType();
   Type *VecTy = Ext0->getOperand(0)->getType();
   int ScalarOpCost, VectorOpCost;
@@ -69,31 +73,73 @@
                                           CmpInst::makeCmpResultType(VecTy));
   }
 
-  // Get cost estimate for the extract element. This cost will factor into
+  // Get cost estimates for the extract elements. These costs will factor into
   // both sequences.
-  unsigned ExtIndex = cast<ConstantInt>(Ext0->getOperand(1))->getZExtValue();
-  int ExtractCost = TTI.getVectorInstrCost(Instruction::ExtractElement,
-                                           VecTy, ExtIndex);
+  unsigned Ext0Index = cast<ConstantInt>(Ext0->getOperand(1))->getZExtValue();
+  unsigned Ext1Index = cast<ConstantInt>(Ext1->getOperand(1))->getZExtValue();
+
+  int Extract0Cost = TTI.getVectorInstrCost(Instruction::ExtractElement,
+                                            VecTy, Ext0Index);
+  int Extract1Cost = TTI.getVectorInstrCost(Instruction::ExtractElement,
+                                            VecTy, Ext1Index);
+
+  // A more expensive extract will always be replaced by a splat shuffle.
+  // For example, if Ext0 is more expensive:
+  // opcode (extelt V0, Ext0), (extelt V1, Ext1) -->
+  // extelt (opcode (splat V0, Ext0), V1), Ext1
+  // TODO: Evaluate whether that always results in lowest cost. Alternatively,
+  //       check the cost of creating a broadcast shuffle and shuffling both
+  //       operands to element 0.
+  int CheapExtractCost = std::min(Extract0Cost, Extract1Cost);
 
   // Extra uses of the extracts mean that we include those costs in the
   // vector total because those instructions will not be eliminated.
   int OldCost, NewCost;
-  if (Ext0->getOperand(0) == Ext1->getOperand(0)) {
-    // Handle a special case. If the 2 operands are identical, adjust the
+  if (Ext0->getOperand(0) == Ext1->getOperand(0) && Ext0Index == Ext1Index) {
+    // Handle a special case. If the 2 extracts are identical, adjust the
     // formulas to account for that. The extra use charge allows for either the
     // CSE'd pattern or an unoptimized form with identical values:
     // opcode (extelt V, C), (extelt V, C) --> extelt (opcode V, V), C
     bool HasUseTax = Ext0 == Ext1 ? !Ext0->hasNUses(2)
                                   : !Ext0->hasOneUse() || !Ext1->hasOneUse();
-    OldCost = ExtractCost + ScalarOpCost;
-    NewCost = VectorOpCost + ExtractCost + HasUseTax * ExtractCost;
+    OldCost = CheapExtractCost + ScalarOpCost;
+    NewCost = VectorOpCost + CheapExtractCost + HasUseTax * CheapExtractCost;
   } else {
     // Handle the general case. Each extract is actually a different value:
-    // opcode (extelt V0, C), (extelt V1, C) --> extelt (opcode V0, V1), C
-    OldCost = 2 * ExtractCost + ScalarOpCost;
-    NewCost = VectorOpCost + ExtractCost + !Ext0->hasOneUse() * ExtractCost +
-              !Ext1->hasOneUse() * ExtractCost;
+    // opcode (extelt V0, C0), (extelt V1, C1) --> extelt (opcode V0, V1), C
+    OldCost = Extract0Cost + Extract1Cost + ScalarOpCost;
+    NewCost = VectorOpCost + CheapExtractCost +
+              !Ext0->hasOneUse() * Extract0Cost +
+              !Ext1->hasOneUse() * Extract1Cost;
   }
+
+  if (Ext0Index == Ext1Index) {
+    // If the extract indexes are identical, no shuffle is needed.
+    ConvertToShuffle = nullptr;
+  } else {
+    if (IsBinOp && DisableBinopExtractShuffle)
+      return true;
+
+    // If we are extracting from 2 different indexes, then one operand must be
+    // shuffled before performing the vector operation. The shuffle mask is
+    // undefined except for 1 lane that is being translated to the remaining
+    // extraction lane. Therefore, it is a splat shuffle. Ex:
+    // ShufMask = { undef, undef, 0, undef }
+    // TODO: The cost model has an option for a "broadcast" shuffle
+    //       (splat-from-element-0), but no option for a more general splat.
+    NewCost +=
+        TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, VecTy);
+
+    // The more expensive extract will be replaced by a shuffle. If the extracts
+    // have the same cost, replace the extract with the higher index.
+    if (Extract0Cost > Extract1Cost)
+      ConvertToShuffle = Ext0;
+    else if (Extract1Cost > Extract0Cost)
+      ConvertToShuffle = Ext1;
+    else
+      ConvertToShuffle = Ext0Index > Ext1Index ? Ext0 : Ext1;
+  }
+
   // Aggressively form a vector op if the cost is equal because the transform
   // may enable further optimization.
   // Codegen can reverse this transform (scalarize) if it was not profitable.
@@ -162,12 +208,33 @@
       V0->getType() != V1->getType())
     return false;
 
-  // TODO: Handle C0 != C1 by shuffling 1 of the operands.
-  if (C0 != C1)
+  Instruction *ConvertToShuffle;
+  if (isExtractExtractCheap(Ext0, Ext1, I.getOpcode(), TTI, ConvertToShuffle))
     return false;
 
-  if (isExtractExtractCheap(Ext0, Ext1, I.getOpcode(), TTI))
-    return false;
+  if (ConvertToShuffle) {
+    // The shuffle mask is undefined except for 1 lane that is being translated
+    // to the cheap extraction lane. Example:
+    // ShufMask = { 2, undef, undef, undef }
+    uint64_t SplatIndex = ConvertToShuffle == Ext0 ? C0 : C1;
+    uint64_t CheapExtIndex = ConvertToShuffle == Ext0 ? C1 : C0;
+    Type *VecTy = V0->getType();
+    Type *I32Ty = IntegerType::getInt32Ty(I.getContext());
+    UndefValue *Undef = UndefValue::get(I32Ty);
+    SmallVector<Constant *, 32> ShufMask(VecTy->getVectorNumElements(), Undef);
+    ShufMask[CheapExtIndex] = ConstantInt::get(I32Ty, SplatIndex);
+    IRBuilder<> Builder(ConvertToShuffle);
+
+    // extelt X, C --> extelt (splat X), C'
+    Value *Shuf = Builder.CreateShuffleVector(ConvertToShuffle->getOperand(0),
+                                              UndefValue::get(VecTy),
+                                              ConstantVector::get(ShufMask));
+    Value *NewExt = Builder.CreateExtractElement(Shuf, CheapExtIndex);
+    if (ConvertToShuffle == Ext0)
+      Ext0 = cast<Instruction>(NewExt);
+    else
+      Ext1 = cast<Instruction>(NewExt);
+  }
 
   if (Pred != CmpInst::BAD_ICMP_PREDICATE)
     foldExtExtCmp(Ext0, Ext1, I, TTI);
diff --git a/llvm/test/Transforms/VectorCombine/X86/extract-binop.ll b/llvm/test/Transforms/VectorCombine/X86/extract-binop.ll
--- a/llvm/test/Transforms/VectorCombine/X86/extract-binop.ll
+++ b/llvm/test/Transforms/VectorCombine/X86/extract-binop.ll
@@ -251,14 +251,18 @@
   ret i8 %r
 }
 
-; TODO: Different extract indexes requires a shuffle.
-
 define i8 @ext0_ext1_add(<16 x i8> %x, <16 x i8> %y) {
-; CHECK-LABEL: @ext0_ext1_add(
-; CHECK-NEXT:    [[E0:%.*]] = extractelement <16 x i8> [[X:%.*]], i32 0
-; CHECK-NEXT:    [[E1:%.*]] = extractelement <16 x i8> [[Y:%.*]], i32 1
-; CHECK-NEXT:    [[R:%.*]] = add nuw i8 [[E0]], [[E1]]
-; CHECK-NEXT:    ret i8 [[R]]
+; SSE-LABEL: @ext0_ext1_add(
+; SSE-NEXT:    [[E0:%.*]] = extractelement <16 x i8> [[X:%.*]], i32 0
+; SSE-NEXT:    [[E1:%.*]] = extractelement <16 x i8> [[Y:%.*]], i32 1
+; SSE-NEXT:    [[R:%.*]] = add nuw i8 [[E0]], [[E1]]
+; SSE-NEXT:    ret i8 [[R]]
+;
+; AVX-LABEL: @ext0_ext1_add(
+; AVX-NEXT:    [[TMP1:%.*]] = shufflevector <16 x i8> [[Y:%.*]], <16 x i8> undef, <16 x i32> <i32 1, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef>
+; AVX-NEXT:    [[TMP2:%.*]] = add nuw <16 x i8> [[X:%.*]], [[TMP1]]
+; AVX-NEXT:    [[TMP3:%.*]] = extractelement <16 x i8> [[TMP2]], i32 0
+; AVX-NEXT:    ret i8 [[TMP3]]
 ;
   %e0 = extractelement <16 x i8> %x, i32 0
   %e1 = extractelement <16 x i8> %y, i32 1
@@ -267,11 +271,17 @@
 }
 
 define i8 @ext5_ext0_add(<16 x i8> %x, <16 x i8> %y) {
-; CHECK-LABEL: @ext5_ext0_add(
-; CHECK-NEXT:    [[E0:%.*]] = extractelement <16 x i8> [[X:%.*]], i32 5
-; CHECK-NEXT:    [[E1:%.*]] = extractelement <16 x i8> [[Y:%.*]], i32 0
-; CHECK-NEXT:    [[R:%.*]] = sub nsw i8 [[E0]], [[E1]]
-; CHECK-NEXT:    ret i8 [[R]]
+; SSE-LABEL: @ext5_ext0_add(
+; SSE-NEXT:    [[E0:%.*]] = extractelement <16 x i8> [[X:%.*]], i32 5
+; SSE-NEXT:    [[E1:%.*]] = extractelement <16 x i8> [[Y:%.*]], i32 0
+; SSE-NEXT:    [[R:%.*]] = sub nsw i8 [[E0]], [[E1]]
+; SSE-NEXT:    ret i8 [[R]]
+;
+; AVX-LABEL: @ext5_ext0_add(
+; AVX-NEXT:    [[TMP1:%.*]] = shufflevector <16 x i8> [[X:%.*]], <16 x i8> undef, <16 x i32> <i32 5, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef>
+; AVX-NEXT:    [[TMP2:%.*]] = sub nsw <16 x i8> [[TMP1]], [[Y:%.*]]
+; AVX-NEXT:    [[TMP3:%.*]] = extractelement <16 x i8> [[TMP2]], i64 0
+; AVX-NEXT:    ret i8 [[TMP3]]
 ;
   %e0 = extractelement <16 x i8> %x, i32 5
   %e1 = extractelement <16 x i8> %y, i32 0
@@ -280,11 +290,17 @@
 }
 
 define i8 @ext1_ext6_add(<16 x i8> %x, <16 x i8> %y) {
-; CHECK-LABEL: @ext1_ext6_add(
-; CHECK-NEXT:    [[E0:%.*]] = extractelement <16 x i8> [[X:%.*]], i32 1
-; CHECK-NEXT:    [[E1:%.*]] = extractelement <16 x i8> [[Y:%.*]], i32 6
-; CHECK-NEXT:    [[R:%.*]] = and i8 [[E0]], [[E1]]
-; CHECK-NEXT:    ret i8 [[R]]
+; SSE-LABEL: @ext1_ext6_add(
+; SSE-NEXT:    [[E0:%.*]] = extractelement <16 x i8> [[X:%.*]], i32 1
+; SSE-NEXT:    [[E1:%.*]] = extractelement <16 x i8> [[Y:%.*]], i32 6
+; SSE-NEXT:    [[R:%.*]] = and i8 [[E0]], [[E1]]
+; SSE-NEXT:    ret i8 [[R]]
+;
+; AVX-LABEL: @ext1_ext6_add(
+; AVX-NEXT:    [[TMP1:%.*]] = shufflevector <16 x i8> [[Y:%.*]], <16 x i8> undef, <16 x i32> <i32 undef, i32 6, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef>
+; AVX-NEXT:    [[TMP2:%.*]] = and <16 x i8> [[X:%.*]], [[TMP1]]
+; AVX-NEXT:    [[TMP3:%.*]] = extractelement <16 x i8> [[TMP2]], i32 1
+; AVX-NEXT:    ret i8 [[TMP3]]
 ;
   %e0 = extractelement <16 x i8> %x, i32 1
   %e1 = extractelement <16 x i8> %y, i32 6
@@ -294,10 +310,10 @@
 
 define float @ext1_ext0_fmul(<4 x float> %x) {
 ; CHECK-LABEL: @ext1_ext0_fmul(
-; CHECK-NEXT:    [[E0:%.*]] = extractelement <4 x float> [[X:%.*]], i32 1
-; CHECK-NEXT:    [[E1:%.*]] = extractelement <4 x float> [[X]], i32 0
-; CHECK-NEXT:    [[R:%.*]] = fmul float [[E0]], [[E1]]
-; CHECK-NEXT:    ret float [[R]]
+; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <4 x float> [[X:%.*]], <4 x float> undef, <4 x i32> <i32 1, i32 undef, i32 undef, i32 undef>
+; CHECK-NEXT:    [[TMP2:%.*]] = fmul <4 x float> [[TMP1]], [[X]]
+; CHECK-NEXT:    [[TMP3:%.*]] = extractelement <4 x float> [[TMP2]], i64 0
+; CHECK-NEXT:    ret float [[TMP3]]
 ;
   %e0 = extractelement <4 x float> %x, i32 1
   %e1 = extractelement <4 x float> %x, i32 0
@@ -309,9 +325,10 @@
 ; CHECK-LABEL: @ext0_ext3_fmul_extra_use1(
 ; CHECK-NEXT:    [[E0:%.*]] = extractelement <4 x float> [[X:%.*]], i32 0
 ; CHECK-NEXT:    call void @use_f32(float [[E0]])
-; CHECK-NEXT:    [[E1:%.*]] = extractelement <4 x float> [[X]], i32 3
-; CHECK-NEXT:    [[R:%.*]] = fmul nnan float [[E0]], [[E1]]
-; CHECK-NEXT:    ret float [[R]]
+; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <4 x float> [[X]], <4 x float> undef, <4 x i32> <i32 3, i32 undef, i32 undef, i32 undef>
+; CHECK-NEXT:    [[TMP2:%.*]] = fmul nnan <4 x float> [[X]], [[TMP1]]
+; CHECK-NEXT:    [[TMP3:%.*]] = extractelement <4 x float> [[TMP2]], i32 0
+; CHECK-NEXT:    ret float [[TMP3]]
 ;
   %e0 = extractelement <4 x float> %x, i32 0
   call void @use_f32(float %e0)
@@ -336,11 +353,17 @@
 }
 
 define float @ext0_ext4_fmul_v8f32(<8 x float> %x) {
-; CHECK-LABEL: @ext0_ext4_fmul_v8f32(
-; CHECK-NEXT:    [[E0:%.*]] = extractelement <8 x float> [[X:%.*]], i32 0
-; CHECK-NEXT:    [[E1:%.*]] = extractelement <8 x float> [[X]], i32 4
-; CHECK-NEXT:    [[R:%.*]] = fadd float [[E0]], [[E1]]
-; CHECK-NEXT:    ret float [[R]]
+; SSE-LABEL: @ext0_ext4_fmul_v8f32(
+; SSE-NEXT:    [[E0:%.*]] = extractelement <8 x float> [[X:%.*]], i32 0
+; SSE-NEXT:    [[E1:%.*]] = extractelement <8 x float> [[X]], i32 4
+; SSE-NEXT:    [[R:%.*]] = fadd float [[E0]], [[E1]]
+; SSE-NEXT:    ret float [[R]]
+;
+; AVX-LABEL: @ext0_ext4_fmul_v8f32(
+; AVX-NEXT:    [[TMP1:%.*]] = shufflevector <8 x float> [[X:%.*]], <8 x float> undef, <8 x i32> <i32 4, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef>
+; AVX-NEXT:    [[TMP2:%.*]] = fadd <8 x float> [[X]], [[TMP1]]
+; AVX-NEXT:    [[TMP3:%.*]] = extractelement <8 x float> [[TMP2]], i32 0
+; AVX-NEXT:    ret float [[TMP3]]
 ;
   %e0 = extractelement <8 x float> %x, i32 0
   %e1 = extractelement <8 x float> %x, i32 4
@@ -349,11 +372,17 @@
 }
 
 define float @ext7_ext4_fmul_v8f32(<8 x float> %x) {
-; CHECK-LABEL: @ext7_ext4_fmul_v8f32(
-; CHECK-NEXT:    [[E0:%.*]] = extractelement <8 x float> [[X:%.*]], i32 7
-; CHECK-NEXT:    [[E1:%.*]] = extractelement <8 x float> [[X]], i32 4
-; CHECK-NEXT:    [[R:%.*]] = fadd float [[E0]], [[E1]]
-; CHECK-NEXT:    ret float [[R]]
+; SSE-LABEL: @ext7_ext4_fmul_v8f32(
+; SSE-NEXT:    [[E0:%.*]] = extractelement <8 x float> [[X:%.*]], i32 7
+; SSE-NEXT:    [[E1:%.*]] = extractelement <8 x float> [[X]], i32 4
+; SSE-NEXT:    [[R:%.*]] = fadd float [[E0]], [[E1]]
+; SSE-NEXT:    ret float [[R]]
+;
+; AVX-LABEL: @ext7_ext4_fmul_v8f32(
+; AVX-NEXT:    [[TMP1:%.*]] = shufflevector <8 x float> [[X:%.*]], <8 x float> undef, <8 x i32> <i32 undef, i32 undef, i32 undef, i32 undef, i32 7, i32 undef, i32 undef, i32 undef>
+; AVX-NEXT:    [[TMP2:%.*]] = fadd <8 x float> [[TMP1]], [[X]]
+; AVX-NEXT:    [[TMP3:%.*]] = extractelement <8 x float> [[TMP2]], i64 4
+; AVX-NEXT:    ret float [[TMP3]]
 ;
   %e0 = extractelement <8 x float> %x, i32 7
   %e1 = extractelement <8 x float> %x, i32 4
diff --git a/llvm/test/Transforms/VectorCombine/X86/extract-cmp.ll b/llvm/test/Transforms/VectorCombine/X86/extract-cmp.ll
--- a/llvm/test/Transforms/VectorCombine/X86/extract-cmp.ll
+++ b/llvm/test/Transforms/VectorCombine/X86/extract-cmp.ll
@@ -102,11 +102,17 @@
 }
 
 define i1 @cmp01_v2f64(<2 x double> %x, <2 x double> %y) {
-; CHECK-LABEL: @cmp01_v2f64(
-; CHECK-NEXT:    [[X0:%.*]] = extractelement <2 x double> [[X:%.*]], i32 0
-; CHECK-NEXT:    [[Y1:%.*]] = extractelement <2 x double> [[Y:%.*]], i32 1
-; CHECK-NEXT:    [[CMP:%.*]] = fcmp oge double [[X0]], [[Y1]]
-; CHECK-NEXT:    ret i1 [[CMP]]
+; SSE-LABEL: @cmp01_v2f64(
+; SSE-NEXT:    [[X0:%.*]] = extractelement <2 x double> [[X:%.*]], i32 0
+; SSE-NEXT:    [[Y1:%.*]] = extractelement <2 x double> [[Y:%.*]], i32 1
+; SSE-NEXT:    [[CMP:%.*]] = fcmp oge double [[X0]], [[Y1]]
+; SSE-NEXT:    ret i1 [[CMP]]
+;
+; AVX-LABEL: @cmp01_v2f64(
+; AVX-NEXT:    [[TMP1:%.*]] = shufflevector <2 x double> [[Y:%.*]], <2 x double> undef, <2 x i32> <i32 1, i32 undef>
+; AVX-NEXT:    [[TMP2:%.*]] = fcmp oge <2 x double> [[X:%.*]], [[TMP1]]
+; AVX-NEXT:    [[TMP3:%.*]] = extractelement <2 x i1> [[TMP2]], i32 0
+; AVX-NEXT:    ret i1 [[TMP3]]
 ;
   %x0 = extractelement <2 x double> %x, i32 0
   %y1 = extractelement <2 x double> %y, i32 1
@@ -115,11 +121,17 @@
 }
 
 define i1 @cmp10_v2f64(<2 x double> %x, <2 x double> %y) {
-; CHECK-LABEL: @cmp10_v2f64(
-; CHECK-NEXT:    [[X1:%.*]] = extractelement <2 x double> [[X:%.*]], i32 1
-; CHECK-NEXT:    [[Y0:%.*]] = extractelement <2 x double> [[Y:%.*]], i32 0
-; CHECK-NEXT:    [[CMP:%.*]] = fcmp ule double [[X1]], [[Y0]]
-; CHECK-NEXT:    ret i1 [[CMP]]
+; SSE-LABEL: @cmp10_v2f64(
+; SSE-NEXT:    [[X1:%.*]] = extractelement <2 x double> [[X:%.*]], i32 1
+; SSE-NEXT:    [[Y0:%.*]] = extractelement <2 x double> [[Y:%.*]], i32 0
+; SSE-NEXT:    [[CMP:%.*]] = fcmp ule double [[X1]], [[Y0]]
+; SSE-NEXT:    ret i1 [[CMP]]
+;
+; AVX-LABEL: @cmp10_v2f64(
+; AVX-NEXT:    [[TMP1:%.*]] = shufflevector <2 x double> [[X:%.*]], <2 x double> undef, <2 x i32> <i32 1, i32 undef>
+; AVX-NEXT:    [[TMP2:%.*]] = fcmp ule <2 x double> [[TMP1]], [[Y:%.*]]
+; AVX-NEXT:    [[TMP3:%.*]] = extractelement <2 x i1> [[TMP2]], i64 0
+; AVX-NEXT:    ret i1 [[TMP3]]
 ;
   %x1 = extractelement <2 x double> %x, i32 1
   %y0 = extractelement <2 x double> %y, i32 0
@@ -129,10 +141,10 @@
 
 define i1 @cmp12_v4i32(<4 x i32> %x, <4 x i32> %y) {
 ; CHECK-LABEL: @cmp12_v4i32(
-; CHECK-NEXT:    [[X1:%.*]] = extractelement <4 x i32> [[X:%.*]], i32 1
-; CHECK-NEXT:    [[Y2:%.*]] = extractelement <4 x i32> [[Y:%.*]], i32 2
-; CHECK-NEXT:    [[CMP:%.*]] = icmp sgt i32 [[X1]], [[Y2]]
-; CHECK-NEXT:    ret i1 [[CMP]]
+; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <4 x i32> [[Y:%.*]], <4 x i32> undef, <4 x i32> <i32 undef, i32 2, i32 undef, i32 undef>
+; CHECK-NEXT:    [[TMP2:%.*]] = icmp sgt <4 x i32> [[X:%.*]], [[TMP1]]
+; CHECK-NEXT:    [[TMP3:%.*]] = extractelement <4 x i1> [[TMP2]], i32 1
+; CHECK-NEXT:    ret i1 [[TMP3]]
 ;
   %x1 = extractelement <4 x i32> %x, i32 1
   %y2 = extractelement <4 x i32> %y, i32 2
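
For reference, here is a minimal before/after sketch of the new shuffle path in LLVM IR, distilled from the cmp12_v4i32 test above. The value names and the shuffle mask follow the expected CHECK lines; whether the transform actually fires is target-dependent, since it requires isExtractExtractCheap to return false and set ConvertToShuffle.

; Before: two extracts from different lanes feed one scalar compare.
  %x1 = extractelement <4 x i32> %x, i32 1
  %y2 = extractelement <4 x i32> %y, i32 2
  %cmp = icmp sgt i32 %x1, %y2

; After: the more expensive extract (%y, lane 2) becomes a splat shuffle
; into the cheap extraction lane (lane 1), the compare is vectorized, and
; a single lane-1 extract remains.
  %1 = shufflevector <4 x i32> %y, <4 x i32> undef, <4 x i32> <i32 undef, i32 2, i32 undef, i32 undef>
  %2 = icmp sgt <4 x i32> %x, %1
  %cmp = extractelement <4 x i1> %2, i32 1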