Index: llvm/lib/Transforms/Vectorize/VectorCombine.cpp
===================================================================
--- llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -63,6 +63,9 @@
   const TargetTransformInfo &TTI;
   const DominatorTree &DT;
 
+  ExtractElementInst *getShuffleExtract(ExtractElementInst *Ext0,
+                                        ExtractElementInst *Ext1,
+                                        unsigned PreferredExtractIndex) const;
   bool isExtractExtractCheap(ExtractElementInst *Ext0, ExtractElementInst *Ext1,
                              unsigned Opcode,
                              ExtractElementInst *&ConvertToShuffle,
@@ -74,6 +77,7 @@
   bool foldExtractExtract(Instruction &I);
   bool foldBitcastShuf(Instruction &I);
   bool scalarizeBinopOrCmp(Instruction &I);
+  bool foldExtractedCmps(Instruction &I);
 };
 
 static void replaceValue(Value &Old, Value &New) {
@@ -81,6 +85,45 @@
   New.takeName(&Old);
 }
 
+/// Determine which, if any, of the inputs should be replaced by a shuffle.
+ExtractElementInst *VectorCombine::getShuffleExtract(
+    ExtractElementInst *Ext0, ExtractElementInst *Ext1,
+    unsigned PreferredExtractIndex = InvalidIndex) const {
+  assert(isa<ConstantInt>(Ext0->getIndexOperand()) &&
+         isa<ConstantInt>(Ext1->getIndexOperand()) &&
+         "Expected constant extract indexes");
+
+  unsigned Index0 = cast<ConstantInt>(Ext0->getIndexOperand())->getZExtValue();
+  unsigned Index1 = cast<ConstantInt>(Ext1->getIndexOperand())->getZExtValue();
+
+  // If the extract indexes are identical, no shuffle is needed.
+  if (Index0 == Index1)
+    return nullptr;
+
+  Type *VecTy = Ext0->getVectorOperand()->getType();
+  assert(VecTy == Ext1->getVectorOperand()->getType() && "Need matching types");
+  int Cost0 = TTI.getVectorInstrCost(Ext0->getOpcode(), VecTy, Index0);
+  int Cost1 = TTI.getVectorInstrCost(Ext1->getOpcode(), VecTy, Index1);
+
+  // We are extracting from 2 different indexes, so one operand must be
+  // shuffled before performing a vector operation and/or extract. The more
+  // expensive extract will be replaced by a shuffle.
+  if (Cost0 > Cost1)
+    return Ext0;
+  if (Cost1 > Cost0)
+    return Ext1;
+
+  // If the costs are equal and there is a preferred extract index, shuffle the
+  // opposite operand.
+  if (PreferredExtractIndex == Index0)
+    return Ext1;
+  if (PreferredExtractIndex == Index1)
+    return Ext0;
+
+  // Otherwise, replace the extract with the higher index.
+  return Index0 > Index1 ? Ext0 : Ext1;
+}
+
 /// Compare the relative costs of 2 extracts followed by scalar operation vs.
 /// vector operation(s) followed by extract. Return true if the existing
 /// instructions are cheaper than a vector alternative. Otherwise, return false
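When getShuffleExtract returns a non-null extract, the caller rewrites that extract as a lane-shift shuffle plus an extract from its partner's cheaper lane, via the pre-existing createShiftShuffle helper that foldExtractedCmps also calls below. A minimal hand-written sketch of the resulting IR shape, assuming a <4 x double> source, equal extract costs, no preferred index, and a partner extract at lane 0, so the higher lane 2 loses the tie-break (the function name and types are illustrative, not taken from the patch):

  define double @shifted_extract(<4 x double> %x) {
    ; lane 2 is shifted down to lane 0, then extracted from the cheap lane
    %s = shufflevector <4 x double> %x, <4 x double> undef, <4 x i32> <i32 2, i32 undef, i32 undef, i32 undef>
    %e = extractelement <4 x double> %s, i32 0
    ret double %e
  }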
@@ -152,10 +195,8 @@
            !Ext1->hasOneUse() * Extract1Cost;
   }
 
-  if (Ext0Index == Ext1Index) {
-    // If the extract indexes are identical, no shuffle is needed.
-    ConvertToShuffle = nullptr;
-  } else {
+  ConvertToShuffle = getShuffleExtract(Ext0, Ext1, PreferredExtractIndex);
+  if (ConvertToShuffle) {
     if (IsBinOp && DisableBinopExtractShuffle)
       return true;
 
@@ -168,20 +209,6 @@
     // (splat-from-element-0), but no option for a more general splat.
     NewCost +=
         TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, VecTy);
-
-    // The more expensive extract will be replaced by a shuffle. If the costs
-    // are equal and there is a preferred extract index, shuffle the opposite
-    // operand. Otherwise, replace the extract with the higher index.
-    if (Extract0Cost > Extract1Cost)
-      ConvertToShuffle = Ext0;
-    else if (Extract1Cost > Extract0Cost)
-      ConvertToShuffle = Ext1;
-    else if (PreferredExtractIndex == Ext0Index)
-      ConvertToShuffle = Ext1;
-    else if (PreferredExtractIndex == Ext1Index)
-      ConvertToShuffle = Ext0;
-    else
-      ConvertToShuffle = Ext0Index > Ext1Index ? Ext0 : Ext1;
   }
 
   // Aggressively form a vector op if the cost is equal because the transform
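The next hunk adds the new fold itself. As a hand-written IR sketch (function names are illustrative; lanes and constants mirror the @icmp_xor_v4i32 test updated below, where the tie-break shuffles the higher index 3 down to the cheap lane 1):

  define i1 @before(<4 x i32> %x) {
    %e1 = extractelement <4 x i32> %x, i32 3
    %e2 = extractelement <4 x i32> %x, i32 1
    %c1 = icmp sgt i32 %e1, 42
    %c2 = icmp sgt i32 %e2, -8
    %r = xor i1 %c1, %c2
    ret i1 %r
  }

  ; becomes a single vector compare against a partly-undef constant vector,
  ; a lane-shift shuffle, a vector binop, and one extract:
  define i1 @after(<4 x i32> %x) {
    %vcmp = icmp sgt <4 x i32> %x, <i32 undef, i32 -8, i32 undef, i32 42>
    %shift = shufflevector <4 x i1> %vcmp, <4 x i1> undef, <4 x i32> <i32 undef, i32 3, i32 undef, i32 undef>
    %vbinop = xor <4 x i1> %vcmp, %shift
    %r = extractelement <4 x i1> %vbinop, i64 1
    ret i1 %r
  }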
@@ -491,6 +518,88 @@
   return true;
 }
 
+/// Try to combine a scalar binop + 2 scalar compares of extracted elements of
+/// a vector into vector operations followed by extract.
+bool VectorCombine::foldExtractedCmps(Instruction &I) {
+  // We are looking for a scalar binop of booleans.
+  // binop i1 (cmp Pred I0, C0), (cmp Pred I1, C1)
+  if (!I.isBinaryOp() || !I.getType()->isIntegerTy(1))
+    return false;
+
+  // The compare predicates should match, and each compare should have a
+  // constant operand.
+  // TODO: Relax the one-use constraints.
+  Value *B0 = I.getOperand(0), *B1 = I.getOperand(1);
+  Instruction *I0, *I1;
+  Constant *C0, *C1;
+  CmpInst::Predicate P0, P1;
+  if (!match(B0, m_OneUse(m_Cmp(P0, m_Instruction(I0), m_Constant(C0)))) ||
+      !match(B1, m_OneUse(m_Cmp(P1, m_Instruction(I1), m_Constant(C1)))) ||
+      P0 != P1)
+    return false;
+
+  // The compare operands must be extracts of the same vector with constant
+  // extract indexes.
+  // TODO: Relax the one-use constraints.
+  Value *X;
+  uint64_t Index0, Index1;
+  if (!match(I0, m_OneUse(m_ExtractElt(m_Value(X), m_ConstantInt(Index0)))) ||
+      !match(I1, m_OneUse(m_ExtractElt(m_Specific(X), m_ConstantInt(Index1)))))
+    return false;
+
+  auto *Ext0 = cast<ExtractElementInst>(I0);
+  auto *Ext1 = cast<ExtractElementInst>(I1);
+  ExtractElementInst *ConvertToShuf = getShuffleExtract(Ext0, Ext1);
+  if (!ConvertToShuf)
+    return false;
+
+  // The original scalar pattern is:
+  // binop i1 (cmp Pred (ext X, Index0), C0), (cmp Pred (ext X, Index1), C1)
+  CmpInst::Predicate Pred = P0;
+  unsigned CmpOpcode = CmpInst::isFPPredicate(Pred) ? Instruction::FCmp
+                                                    : Instruction::ICmp;
+  auto *VecTy = dyn_cast<FixedVectorType>(X->getType());
+  if (!VecTy)
+    return false;
+
+  int OldCost = TTI.getVectorInstrCost(Ext0->getOpcode(), VecTy, Index0);
+  OldCost += TTI.getVectorInstrCost(Ext1->getOpcode(), VecTy, Index1);
+  OldCost += TTI.getCmpSelInstrCost(CmpOpcode, I0->getType()) * 2;
+  OldCost += TTI.getArithmeticInstrCost(I.getOpcode(), I.getType());
+
+  // The proposed vector pattern is:
+  // vcmp = cmp Pred X, VecC
+  // ext (binop vNi1 vcmp, (shuffle vcmp, Index1)), Index0
+  int CheapIndex = ConvertToShuf == Ext0 ? Index1 : Index0;
+  int ExpensiveIndex = ConvertToShuf == Ext0 ? Index0 : Index1;
+  auto *CmpTy = cast<VectorType>(CmpInst::makeCmpResultType(X->getType()));
+  int NewCost = TTI.getCmpSelInstrCost(CmpOpcode, X->getType());
+  NewCost +=
+      TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, CmpTy);
+  NewCost += TTI.getArithmeticInstrCost(I.getOpcode(), CmpTy);
+  NewCost += TTI.getVectorInstrCost(Ext0->getOpcode(), CmpTy, CheapIndex);
+
+  // Aggressively form vector ops if the cost is equal because the transform
+  // may enable further optimization.
+  // Codegen can reverse this transform (scalarize) if it was not profitable.
+  if (OldCost < NewCost)
+    return false;
+
+  // Create a vector constant from the 2 scalar constants.
+  SmallVector<Constant *, 32> CmpC(VecTy->getNumElements(),
+                                   UndefValue::get(VecTy->getElementType()));
+  CmpC[Index0] = C0;
+  CmpC[Index1] = C1;
+  Value *VCmp = Builder.CreateCmp(Pred, X, ConstantVector::get(CmpC));
+
+  Value *Shuf = createShiftShuffle(VCmp, ExpensiveIndex, CheapIndex, Builder);
+  Value *VecLogic = Builder.CreateBinOp(cast<BinaryOperator>(I).getOpcode(),
+                                        VCmp, Shuf);
+  Value *NewExt = Builder.CreateExtractElement(VecLogic, CheapIndex);
+  replaceValue(I, *NewExt);
+  return true;
+}
+
 /// This is the entry point for all transforms. Pass manager differences are
 /// handled in the callers of this function.
 bool VectorCombine::run() {
@@ -513,6 +622,7 @@
       MadeChange |= foldExtractExtract(I);
       MadeChange |= foldBitcastShuf(I);
       MadeChange |= scalarizeBinopOrCmp(I);
+      MadeChange |= foldExtractedCmps(I);
     }
   }
Index: llvm/test/Transforms/PhaseOrdering/X86/vector-reductions.ll
===================================================================
--- llvm/test/Transforms/PhaseOrdering/X86/vector-reductions.ll
+++ llvm/test/Transforms/PhaseOrdering/X86/vector-reductions.ll
@@ -294,16 +294,15 @@
 ; CHECK-NEXT:    [[TMP5:%.*]] = insertelement <2 x double> undef, double [[MUL]], i32 0
 ; CHECK-NEXT:    [[TMP6:%.*]] = shufflevector <2 x double> [[TMP5]], <2 x double> undef, <2 x i32> zeroinitializer
 ; CHECK-NEXT:    [[TMP7:%.*]] = fdiv <2 x double> [[TMP4]], [[TMP6]]
-; CHECK-NEXT:    [[TMP8:%.*]] = extractelement <2 x double> [[TMP7]], i32 0
-; CHECK-NEXT:    [[CMP:%.*]] = fcmp olt double [[TMP8]], 0x3EB0C6F7A0B5ED8D
-; CHECK-NEXT:    [[TMP9:%.*]] = extractelement <2 x double> [[TMP7]], i32 1
-; CHECK-NEXT:    [[CMP4:%.*]] = fcmp olt double [[TMP9]], 0x3EB0C6F7A0B5ED8D
-; CHECK-NEXT:    [[OR_COND:%.*]] = and i1 [[CMP]], [[CMP4]]
+; CHECK-NEXT:    [[TMP8:%.*]] = fcmp olt <2 x double> [[TMP7]], <double 0x3EB0C6F7A0B5ED8D, double 0x3EB0C6F7A0B5ED8D>
+; CHECK-NEXT:    [[SHIFT:%.*]] = shufflevector <2 x i1> [[TMP8]], <2 x i1> undef, <2 x i32> <i32 1, i32 undef>
+; CHECK-NEXT:    [[TMP9:%.*]] = and <2 x i1> [[TMP8]], [[SHIFT]]
+; CHECK-NEXT:    [[OR_COND:%.*]] = extractelement <2 x i1> [[TMP9]], i64 0
 ; CHECK-NEXT:    br i1 [[OR_COND]], label [[CLEANUP:%.*]], label [[LOR_LHS_FALSE:%.*]]
 ; CHECK:       lor.lhs.false:
 ; CHECK-NEXT:    [[TMP10:%.*]] = fcmp ule <2 x double> [[TMP7]], <double 1.000000e+00, double 1.000000e+00>
-; CHECK-NEXT:    [[SHIFT:%.*]] = shufflevector <2 x i1> [[TMP10]], <2 x i1> undef, <2 x i32> <i32 1, i32 undef>
-; CHECK-NEXT:    [[TMP11:%.*]] = or <2 x i1> [[TMP10]], [[SHIFT]]
+; CHECK-NEXT:    [[SHIFT2:%.*]] = shufflevector <2 x i1> [[TMP10]], <2 x i1> undef, <2 x i32> <i32 1, i32 undef>
+; CHECK-NEXT:    [[TMP11:%.*]] = or <2 x i1> [[TMP10]], [[SHIFT2]]
 ; CHECK-NEXT:    [[NOT_OR_COND1:%.*]] = extractelement <2 x i1> [[TMP11]], i32 0
 ; CHECK-NEXT:    ret i1 [[NOT_OR_COND1]]
 ; CHECK:       cleanup:
Index: llvm/test/Transforms/VectorCombine/X86/extract-cmp-binop.ll
===================================================================
--- llvm/test/Transforms/VectorCombine/X86/extract-cmp-binop.ll
+++ llvm/test/Transforms/VectorCombine/X86/extract-cmp-binop.ll
@@ -3,13 +3,20 @@
 ; RUN: opt < %s -vector-combine -S -mtriple=x86_64-- -mattr=avx2 | FileCheck %s --check-prefixes=CHECK,AVX
 
 define i1 @fcmp_and_v2f64(<2 x double> %a) {
-; CHECK-LABEL: @fcmp_and_v2f64(
-; CHECK-NEXT:    [[E1:%.*]] = extractelement <2 x double> [[A:%.*]], i32 0
-; CHECK-NEXT:    [[E2:%.*]] = extractelement <2 x double> [[A]], i32 1
-; CHECK-NEXT:    [[CMP1:%.*]] = fcmp olt double [[E1]], 4.200000e+01
-; CHECK-NEXT:    [[CMP2:%.*]] = fcmp olt double [[E2]], -8.000000e+00
-; CHECK-NEXT:    [[R:%.*]] = and i1 [[CMP1]], [[CMP2]]
-; CHECK-NEXT:    ret i1 [[R]]
+; SSE-LABEL: @fcmp_and_v2f64(
+; SSE-NEXT:    [[E1:%.*]] = extractelement <2 x double> [[A:%.*]], i32 0
+; SSE-NEXT:    [[E2:%.*]] = extractelement <2 x double> [[A]], i32 1
+; SSE-NEXT:    [[CMP1:%.*]] = fcmp olt double [[E1]], 4.200000e+01
+; SSE-NEXT:    [[CMP2:%.*]] = fcmp olt double [[E2]], -8.000000e+00
+; SSE-NEXT:    [[R:%.*]] = and i1 [[CMP1]], [[CMP2]]
+; SSE-NEXT:    ret i1 [[R]]
+;
+; AVX-LABEL: @fcmp_and_v2f64(
+; AVX-NEXT:    [[TMP1:%.*]] = fcmp olt <2 x double> [[A:%.*]], <double 4.200000e+01, double -8.000000e+00>
+; AVX-NEXT:    [[SHIFT:%.*]] = shufflevector <2 x i1> [[TMP1]], <2 x i1> undef, <2 x i32> <i32 1, i32 undef>
+; AVX-NEXT:    [[TMP2:%.*]] = and <2 x i1> [[TMP1]], [[SHIFT]]
+; AVX-NEXT:    [[R:%.*]] = extractelement <2 x i1> [[TMP2]], i64 0
+; AVX-NEXT:    ret i1 [[R]]
 ;
   %e1 = extractelement <2 x double> %a, i32 0
   %e2 = extractelement <2 x double> %a, i32 1
@@ -20,13 +27,20 @@
 }
 
 define i1 @fcmp_or_v4f64(<4 x double> %a) {
-; CHECK-LABEL: @fcmp_or_v4f64(
-; CHECK-NEXT:    [[E1:%.*]] = extractelement <4 x double> [[A:%.*]], i32 0
-; CHECK-NEXT:    [[E2:%.*]] = extractelement <4 x double> [[A]], i64 2
-; CHECK-NEXT:    [[CMP1:%.*]] = fcmp olt double [[E1]], 4.200000e+01
-; CHECK-NEXT:    [[CMP2:%.*]] = fcmp olt double [[E2]], -8.000000e+00
-; CHECK-NEXT:    [[R:%.*]] = or i1 [[CMP1]], [[CMP2]]
-; CHECK-NEXT:    ret i1 [[R]]
+; SSE-LABEL: @fcmp_or_v4f64(
+; SSE-NEXT:    [[E1:%.*]] = extractelement <4 x double> [[A:%.*]], i32 0
+; SSE-NEXT:    [[E2:%.*]] = extractelement <4 x double> [[A]], i64 2
+; SSE-NEXT:    [[CMP1:%.*]] = fcmp olt double [[E1]], 4.200000e+01
+; SSE-NEXT:    [[CMP2:%.*]] = fcmp olt double [[E2]], -8.000000e+00
+; SSE-NEXT:    [[R:%.*]] = or i1 [[CMP1]], [[CMP2]]
+; SSE-NEXT:    ret i1 [[R]]
+;
+; AVX-LABEL: @fcmp_or_v4f64(
+; AVX-NEXT:    [[TMP1:%.*]] = fcmp olt <4 x double> [[A:%.*]], <double 4.200000e+01, double undef, double -8.000000e+00, double undef>
+; AVX-NEXT:    [[SHIFT:%.*]] = shufflevector <4 x i1> [[TMP1]], <4 x i1> undef, <4 x i32> <i32 2, i32 undef, i32 undef, i32 undef>
+; AVX-NEXT:    [[TMP2:%.*]] = or <4 x i1> [[TMP1]], [[SHIFT]]
+; AVX-NEXT:    [[R:%.*]] = extractelement <4 x i1> [[TMP2]], i64 0
+; AVX-NEXT:    ret i1 [[R]]
 ;
   %e1 = extractelement <4 x double> %a, i32 0
   %e2 = extractelement <4 x double> %a, i64 2
@@ -38,11 +52,10 @@
 
 define i1 @icmp_xor_v4i32(<4 x i32> %a) {
 ; CHECK-LABEL: @icmp_xor_v4i32(
-; CHECK-NEXT:    [[E1:%.*]] = extractelement <4 x i32> [[A:%.*]], i32 3
-; CHECK-NEXT:    [[E2:%.*]] = extractelement <4 x i32> [[A]], i32 1
-; CHECK-NEXT:    [[CMP1:%.*]] = icmp sgt i32 [[E1]], 42
-; CHECK-NEXT:    [[CMP2:%.*]] = icmp sgt i32 [[E2]], -8
-; CHECK-NEXT:    [[R:%.*]] = xor i1 [[CMP1]], [[CMP2]]
+; CHECK-NEXT:    [[TMP1:%.*]] = icmp sgt <4 x i32> [[A:%.*]], <i32 undef, i32 -8, i32 undef, i32 42>
+; CHECK-NEXT:    [[SHIFT:%.*]] = shufflevector <4 x i1> [[TMP1]], <4 x i1> undef, <4 x i32> <i32 undef, i32 3, i32 undef, i32 undef>
+; CHECK-NEXT:    [[TMP2:%.*]] = xor <4 x i1> [[TMP1]], [[SHIFT]]
+; CHECK-NEXT:    [[R:%.*]] = extractelement <4 x i1> [[TMP2]], i64 1
 ; CHECK-NEXT:    ret i1 [[R]]
 ;
   %e1 = extractelement <4 x i32> %a, i32 3
@@ -56,13 +69,20 @@
 ; add is not canonical (should be xor), but that is ok.
 
 define i1 @icmp_add_v8i32(<8 x i32> %a) {
-; CHECK-LABEL: @icmp_add_v8i32(
-; CHECK-NEXT:    [[E1:%.*]] = extractelement <8 x i32> [[A:%.*]], i32 7
-; CHECK-NEXT:    [[E2:%.*]] = extractelement <8 x i32> [[A]], i32 2
-; CHECK-NEXT:    [[CMP1:%.*]] = icmp eq i32 [[E1]], 42
-; CHECK-NEXT:    [[CMP2:%.*]] = icmp eq i32 [[E2]], -8
-; CHECK-NEXT:    [[R:%.*]] = add i1 [[CMP1]], [[CMP2]]
-; CHECK-NEXT:    ret i1 [[R]]
+; SSE-LABEL: @icmp_add_v8i32(
+; SSE-NEXT:    [[E1:%.*]] = extractelement <8 x i32> [[A:%.*]], i32 7
+; SSE-NEXT:    [[E2:%.*]] = extractelement <8 x i32> [[A]], i32 2
+; SSE-NEXT:    [[CMP1:%.*]] = icmp eq i32 [[E1]], 42
+; SSE-NEXT:    [[CMP2:%.*]] = icmp eq i32 [[E2]], -8
+; SSE-NEXT:    [[R:%.*]] = add i1 [[CMP1]], [[CMP2]]
+; SSE-NEXT:    ret i1 [[R]]
+;
+; AVX-LABEL: @icmp_add_v8i32(
+; AVX-NEXT:    [[TMP1:%.*]] = icmp eq <8 x i32> [[A:%.*]], <i32 undef, i32 undef, i32 -8, i32 undef, i32 undef, i32 undef, i32 undef, i32 42>
+; AVX-NEXT:    [[SHIFT:%.*]] = shufflevector <8 x i1> [[TMP1]], <8 x i1> undef, <8 x i32> <i32 undef, i32 undef, i32 7, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef>
+; AVX-NEXT:    [[TMP2:%.*]] = add <8 x i1> [[TMP1]], [[SHIFT]]
+; AVX-NEXT:    [[R:%.*]] = extractelement <8 x i1> [[TMP2]], i64 2
+; AVX-NEXT:    ret i1 [[R]]
 ;
   %e1 = extractelement <8 x i32> %a, i32 7
   %e2 = extractelement <8 x i32> %a, i32 2
@@ -72,6 +92,8 @@
   ret i1 %r
 }
 
+; Negative test - this could CSE/simplify.
+
 define i1 @same_extract_index(<4 x i32> %a) {
 ; CHECK-LABEL: @same_extract_index(
 ; CHECK-NEXT:    [[E1:%.*]] = extractelement <4 x i32> [[A:%.*]], i32 2
@@ -89,6 +111,8 @@
   ret i1 %r
 }
 
+; Negative test - need identical predicates.
+
 define i1 @different_preds(<4 x i32> %a) {
 ; CHECK-LABEL: @different_preds(
 ; CHECK-NEXT:    [[E1:%.*]] = extractelement <4 x i32> [[A:%.*]], i32 1
@@ -106,6 +130,8 @@
   ret i1 %r
 }
 
+; Negative test - need 1 source vector.
+
 define i1 @different_source_vec(<4 x i32> %a, <4 x i32> %b) {
 ; CHECK-LABEL: @different_source_vec(
 ; CHECK-NEXT:    [[E1:%.*]] = extractelement <4 x i32> [[A:%.*]], i32 1
@@ -123,6 +149,8 @@
   ret i1 %r
 }
 
+; Negative test - don't try this with scalable vectors.
+
 define i1 @scalable(<vscale x 4 x i32> %a) {
 ; CHECK-LABEL: @scalable(
 ; CHECK-NEXT:    [[E1:%.*]] = extractelement <vscale x 4 x i32> [[A:%.*]], i32 3
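For contrast with the fixed-width cases above, a hand-written sketch of the scalable pattern that the dyn_cast<FixedVectorType> guard in foldExtractedCmps rejects (the vector type here is assumed from the truncated @scalable test; the rest of the body mirrors @icmp_xor_v4i32): with an unknown element count, neither the lane-indexed constant vector nor the fixed-length shuffle mask can be built, so the scalar form must be kept as-is:

  define i1 @scalable_sketch(<vscale x 4 x i32> %a) {
    %e1 = extractelement <vscale x 4 x i32> %a, i32 3
    %e2 = extractelement <vscale x 4 x i32> %a, i32 1
    %c1 = icmp sgt i32 %e1, 42
    %c2 = icmp sgt i32 %e2, -8
    %r = xor i1 %c1, %c2
    ret i1 %r
  }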