diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp --- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp +++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp @@ -487,6 +487,27 @@ getSameOpcode({BaseOp1, Op1}).getOpcode(); } +/// \returns true if \p CI has the same predicate as \p BaseCI and operands +/// compatible with it (per areCompatibleCmpOps), either as is or after +/// swapping both its operands and its predicate; false otherwise. +static bool isCmpSameOrSwapped(const CmpInst *BaseCI, const CmpInst *CI) { + assert(BaseCI->getOperand(0)->getType() == CI->getOperand(0)->getType() && + "Assessing comparisons of different types?"); + CmpInst::Predicate BasePred = BaseCI->getPredicate(); + CmpInst::Predicate Pred = CI->getPredicate(); + CmpInst::Predicate SwappedPred = CmpInst::getSwappedPredicate(Pred); + + Value *BaseOp0 = BaseCI->getOperand(0); + Value *BaseOp1 = BaseCI->getOperand(1); + Value *Op0 = CI->getOperand(0); + Value *Op1 = CI->getOperand(1); + + return (BasePred == Pred && + areCompatibleCmpOps(BaseOp0, BaseOp1, Op0, Op1)) || + (BasePred == SwappedPred && + areCompatibleCmpOps(BaseOp0, BaseOp1, Op1, Op0)); +} + /// \returns analysis of the Instructions in \p VL described in /// InstructionsState, the Opcode that we suppose the whole list /// could be vectorized even if its structure is diverse. 
@@ -534,52 +555,35 @@ continue; } } - } else if (IsCmpOp && isa(VL[Cnt])) { - auto *BaseInst = cast(VL[BaseIndex]); - auto *Inst = cast(VL[Cnt]); + } else if (auto *Inst = dyn_cast(VL[Cnt]); Inst && IsCmpOp) { + auto *BaseInst = cast(VL[BaseIndex]); Type *Ty0 = BaseInst->getOperand(0)->getType(); Type *Ty1 = Inst->getOperand(0)->getType(); if (Ty0 == Ty1) { - Value *BaseOp0 = BaseInst->getOperand(0); - Value *BaseOp1 = BaseInst->getOperand(1); - Value *Op0 = Inst->getOperand(0); - Value *Op1 = Inst->getOperand(1); - CmpInst::Predicate CurrentPred = - cast(VL[Cnt])->getPredicate(); - CmpInst::Predicate SwappedCurrentPred = - CmpInst::getSwappedPredicate(CurrentPred); + assert(InstOpcode == Opcode && "Expected same CmpInst opcode."); // Check for compatible operands. If the corresponding operands are not // compatible - need to perform alternate vectorization. - if (InstOpcode == Opcode) { - if (BasePred == CurrentPred && - areCompatibleCmpOps(BaseOp0, BaseOp1, Op0, Op1)) - continue; - if (BasePred == SwappedCurrentPred && - areCompatibleCmpOps(BaseOp0, BaseOp1, Op1, Op0)) - continue; - if (E == 2 && - (BasePred == CurrentPred || BasePred == SwappedCurrentPred)) - continue; - auto *AltInst = cast(VL[AltIndex]); - CmpInst::Predicate AltPred = AltInst->getPredicate(); - Value *AltOp0 = AltInst->getOperand(0); - Value *AltOp1 = AltInst->getOperand(1); - // Check if operands are compatible with alternate operands. 
- if (AltPred == CurrentPred && - areCompatibleCmpOps(AltOp0, AltOp1, Op0, Op1)) - continue; - if (AltPred == SwappedCurrentPred && - areCompatibleCmpOps(AltOp0, AltOp1, Op1, Op0)) + CmpInst::Predicate CurrentPred = Inst->getPredicate(); + CmpInst::Predicate SwappedCurrentPred = + CmpInst::getSwappedPredicate(CurrentPred); + + if (E == 2 && + (BasePred == CurrentPred || BasePred == SwappedCurrentPred)) + continue; + + if (isCmpSameOrSwapped(BaseInst, Inst)) + continue; + auto *AltInst = cast(VL[AltIndex]); + if (AltIndex != BaseIndex) { + if (isCmpSameOrSwapped(AltInst, Inst)) continue; - } - if (BaseIndex == AltIndex && BasePred != CurrentPred) { - assert(isValidForAlternation(Opcode) && - isValidForAlternation(InstOpcode) && - "Cast isn't safe for alternation, logic needs to be updated!"); + } else if (BasePred != CurrentPred) { + assert( + isValidForAlternation(InstOpcode) && + "CmpInst isn't safe for alternation, logic needs to be updated!"); AltIndex = Cnt; continue; } - auto *AltInst = cast(VL[AltIndex]); CmpInst::Predicate AltPred = AltInst->getPredicate(); if (BasePred == CurrentPred || BasePred == SwappedCurrentPred || AltPred == CurrentPred || AltPred == SwappedCurrentPred) @@ -4614,6 +4618,12 @@ return std::make_pair(Key, SubKey); } +/// Checks if the specified instruction \p I is an alternate operation for +/// the given \p MainOp and \p AltOp instructions. 
+static bool isAlternateInstruction(const Instruction *I, + const Instruction *MainOp, + const Instruction *AltOp); + void BoUpSLP::buildTree_rec(ArrayRef VL, unsigned Depth, const EdgeInfo &UserTreeIdx) { assert((allConstant(VL) || allSameType(VL)) && "Invalid types!"); @@ -5542,29 +5552,25 @@ })) { reorderInputsAccordingToOpcode(VL, Left, Right, *DL, *SE, *this); } else { - CmpInst::Predicate P0 = CI->getPredicate(); - CmpInst::Predicate AltP0 = cast(S.AltOp)->getPredicate(); - assert(P0 != AltP0 && + auto *MainCI = cast(S.MainOp); + auto *AltCI = cast(S.AltOp); + CmpInst::Predicate MainP = MainCI->getPredicate(); + CmpInst::Predicate AltP = AltCI->getPredicate(); + assert(MainP != AltP && "Expected different main/alternate predicates."); - CmpInst::Predicate AltP0Swapped = CmpInst::getSwappedPredicate(AltP0); - Value *BaseOp0 = VL0->getOperand(0); - Value *BaseOp1 = VL0->getOperand(1); // Collect operands - commute if it uses the swapped predicate or // alternate operation. for (Value *V : VL) { auto *Cmp = cast(V); Value *LHS = Cmp->getOperand(0); Value *RHS = Cmp->getOperand(1); - CmpInst::Predicate CurrentPred = Cmp->getPredicate(); - if (P0 == AltP0Swapped) { - if (CI != Cmp && S.AltOp != Cmp && - ((P0 == CurrentPred && - !areCompatibleCmpOps(BaseOp0, BaseOp1, LHS, RHS)) || - (AltP0 == CurrentPred && - areCompatibleCmpOps(BaseOp0, BaseOp1, LHS, RHS)))) + + if (isAlternateInstruction(Cmp, MainCI, AltCI)) { + if (AltP == CmpInst::getSwappedPredicate(Cmp->getPredicate())) + std::swap(LHS, RHS); + } else { + if (MainP == CmpInst::getSwappedPredicate(Cmp->getPredicate())) std::swap(LHS, RHS); - } else if (P0 != CurrentPred && AltP0 != CurrentPred) { - std::swap(LHS, RHS); } Left.push_back(LHS); Right.push_back(RHS); @@ -5851,25 +5857,27 @@ } } -/// Checks if the specified instruction \p I is an alternate operation for the -/// given \p MainOp and \p AltOp instructions. 
static bool isAlternateInstruction(const Instruction *I, const Instruction *MainOp, const Instruction *AltOp) { - if (auto *CI0 = dyn_cast<CmpInst>(MainOp)) { - auto *AltCI0 = cast<CmpInst>(AltOp); + if (auto *MainCI = dyn_cast<CmpInst>(MainOp)) { + auto *AltCI = cast<CmpInst>(AltOp); + CmpInst::Predicate MainP = MainCI->getPredicate(); + CmpInst::Predicate AltP = AltCI->getPredicate(); + assert(MainP != AltP && "Expected different main/alternate predicates."); auto *CI = cast<CmpInst>(I); - CmpInst::Predicate P0 = CI0->getPredicate(); - CmpInst::Predicate AltP0 = AltCI0->getPredicate(); - assert(P0 != AltP0 && "Expected different main/alternate predicates."); - CmpInst::Predicate AltP0Swapped = CmpInst::getSwappedPredicate(AltP0); - CmpInst::Predicate CurrentPred = CI->getPredicate(); - if (P0 == AltP0Swapped) - return I == AltCI0 || - (I != MainOp && - !areCompatibleCmpOps(CI0->getOperand(0), CI0->getOperand(1), - CI->getOperand(0), CI->getOperand(1))); - return AltP0 == CurrentPred || AltP0Swapped == CurrentPred; + if (isCmpSameOrSwapped(MainCI, CI)) + return false; + if (isCmpSameOrSwapped(AltCI, CI)) + return true; + CmpInst::Predicate P = CI->getPredicate(); + CmpInst::Predicate SwappedP = CmpInst::getSwappedPredicate(P); + + assert((MainP == P || AltP == P || MainP == SwappedP || AltP == SwappedP) && + "CmpInst expected to match either main or alternate predicate or " + "their swap."); + (void)AltP; // Only read by asserts; keeps NDEBUG builds warning-free. + return MainP != P && MainP != SwappedP; } return I->getOpcode() == AltOp->getOpcode(); } diff --git a/llvm/test/Transforms/PhaseOrdering/X86/vector-reductions-logical.ll b/llvm/test/Transforms/PhaseOrdering/X86/vector-reductions-logical.ll --- a/llvm/test/Transforms/PhaseOrdering/X86/vector-reductions-logical.ll +++ b/llvm/test/Transforms/PhaseOrdering/X86/vector-reductions-logical.ll @@ -86,15 +86,16 @@ define float @test_merge_anyof_v4sf(<4 x float> %t) { ; CHECK-LABEL: @test_merge_anyof_v4sf( ; CHECK-NEXT: entry: -; CHECK-NEXT: [[T_FR:%.*]] = freeze <4 x float> [[T:%.*]] -; CHECK-NEXT: [[TMP0:%.*]] = fcmp 
olt <4 x float> [[T_FR]], zeroinitializer -; CHECK-NEXT: [[TMP1:%.*]] = fcmp ogt <4 x float> [[T_FR]], -; CHECK-NEXT: [[TMP2:%.*]] = or <4 x i1> [[TMP1]], [[TMP0]] -; CHECK-NEXT: [[TMP3:%.*]] = bitcast <4 x i1> [[TMP2]] to i4 -; CHECK-NEXT: [[DOTNOT:%.*]] = icmp eq i4 [[TMP3]], 0 -; CHECK-NEXT: [[SHIFT:%.*]] = shufflevector <4 x float> [[T_FR]], <4 x float> poison, <4 x i32> -; CHECK-NEXT: [[TMP4:%.*]] = fadd <4 x float> [[T_FR]], [[SHIFT]] -; CHECK-NEXT: [[ADD:%.*]] = extractelement <4 x float> [[TMP4]], i64 0 +; CHECK-NEXT: [[SHUFFLE:%.*]] = shufflevector <4 x float> [[T:%.*]], <4 x float> poison, <8 x i32> +; CHECK-NEXT: [[TMP0:%.*]] = fcmp ogt <8 x float> [[SHUFFLE]], +; CHECK-NEXT: [[TMP1:%.*]] = fcmp olt <8 x float> [[SHUFFLE]], +; CHECK-NEXT: [[TMP2:%.*]] = shufflevector <8 x i1> [[TMP0]], <8 x i1> [[TMP1]], <8 x i32> +; CHECK-NEXT: [[TMP3:%.*]] = freeze <8 x i1> [[TMP2]] +; CHECK-NEXT: [[TMP4:%.*]] = bitcast <8 x i1> [[TMP3]] to i8 +; CHECK-NEXT: [[DOTNOT:%.*]] = icmp eq i8 [[TMP4]], 0 +; CHECK-NEXT: [[SHIFT:%.*]] = shufflevector <4 x float> [[T]], <4 x float> poison, <4 x i32> +; CHECK-NEXT: [[TMP5:%.*]] = fadd <4 x float> [[SHIFT]], [[T]] +; CHECK-NEXT: [[ADD:%.*]] = extractelement <4 x float> [[TMP5]], i64 0 ; CHECK-NEXT: [[RETVAL_0:%.*]] = select i1 [[DOTNOT]], float [[ADD]], float 0.000000e+00 ; CHECK-NEXT: ret float [[RETVAL_0]] ; @@ -400,14 +401,16 @@ define float @test_merge_anyof_v4si(<4 x i32> %t) { ; CHECK-LABEL: @test_merge_anyof_v4si( ; CHECK-NEXT: entry: -; CHECK-NEXT: [[T_FR:%.*]] = freeze <4 x i32> [[T:%.*]] -; CHECK-NEXT: [[TMP0:%.*]] = add <4 x i32> [[T_FR]], -; CHECK-NEXT: [[TMP1:%.*]] = icmp ult <4 x i32> [[TMP0]], -; CHECK-NEXT: [[TMP2:%.*]] = bitcast <4 x i1> [[TMP1]] to i4 -; CHECK-NEXT: [[DOTNOT:%.*]] = icmp eq i4 [[TMP2]], 0 -; CHECK-NEXT: [[SHIFT:%.*]] = shufflevector <4 x i32> [[T_FR]], <4 x i32> poison, <4 x i32> -; CHECK-NEXT: [[TMP3:%.*]] = add nsw <4 x i32> [[T_FR]], [[SHIFT]] -; CHECK-NEXT: [[ADD:%.*]] = extractelement <4 x 
i32> [[TMP3]], i64 0 +; CHECK-NEXT: [[SHUFFLE:%.*]] = shufflevector <4 x i32> [[T:%.*]], <4 x i32> poison, <8 x i32> +; CHECK-NEXT: [[TMP0:%.*]] = icmp sgt <8 x i32> [[SHUFFLE]], +; CHECK-NEXT: [[TMP1:%.*]] = icmp slt <8 x i32> [[SHUFFLE]], +; CHECK-NEXT: [[TMP2:%.*]] = shufflevector <8 x i1> [[TMP0]], <8 x i1> [[TMP1]], <8 x i32> +; CHECK-NEXT: [[TMP3:%.*]] = freeze <8 x i1> [[TMP2]] +; CHECK-NEXT: [[TMP4:%.*]] = bitcast <8 x i1> [[TMP3]] to i8 +; CHECK-NEXT: [[DOTNOT:%.*]] = icmp eq i8 [[TMP4]], 0 +; CHECK-NEXT: [[SHIFT:%.*]] = shufflevector <4 x i32> [[T]], <4 x i32> poison, <4 x i32> +; CHECK-NEXT: [[TMP5:%.*]] = add nsw <4 x i32> [[SHIFT]], [[T]] +; CHECK-NEXT: [[ADD:%.*]] = extractelement <4 x i32> [[TMP5]], i64 0 ; CHECK-NEXT: [[CONV:%.*]] = sitofp i32 [[ADD]] to float ; CHECK-NEXT: [[RETVAL_0:%.*]] = select i1 [[DOTNOT]], float [[CONV]], float 0.000000e+00 ; CHECK-NEXT: ret float [[RETVAL_0]] diff --git a/llvm/test/Transforms/SLPVectorizer/X86/alternate-cmp-swapped-pred.ll b/llvm/test/Transforms/SLPVectorizer/X86/alternate-cmp-swapped-pred.ll --- a/llvm/test/Transforms/SLPVectorizer/X86/alternate-cmp-swapped-pred.ll +++ b/llvm/test/Transforms/SLPVectorizer/X86/alternate-cmp-swapped-pred.ll @@ -5,16 +5,15 @@ ; CHECK-LABEL: @test( ; CHECK-NEXT: entry: ; CHECK-NEXT: [[CALL:%.*]] = load i16, i16* undef, align 2 -; CHECK-NEXT: [[TMP0:%.*]] = insertelement <8 x i16> , i16 [[CALL37:%.*]], i32 4 +; CHECK-NEXT: [[TMP0:%.*]] = insertelement <8 x i16> , i16 [[CALL37:%.*]], i32 3 ; CHECK-NEXT: [[TMP1:%.*]] = insertelement <8 x i16> [[TMP0]], i16 [[CALL]], i32 0 -; CHECK-NEXT: [[TMP2:%.*]] = insertelement <8 x i16> , i16 [[CALL37]], i32 3 -; CHECK-NEXT: [[TMP3:%.*]] = insertelement <8 x i16> [[TMP2]], i16 [[CALL37]], i32 6 -; CHECK-NEXT: [[TMP4:%.*]] = icmp slt <8 x i16> [[TMP1]], [[TMP3]] -; CHECK-NEXT: [[TMP5:%.*]] = icmp sgt <8 x i16> [[TMP1]], [[TMP3]] -; CHECK-NEXT: [[TMP6:%.*]] = shufflevector <8 x i1> [[TMP4]], <8 x i1> [[TMP5]], <8 x i32> -; CHECK-NEXT: 
[[TMP7:%.*]] = zext <8 x i1> [[TMP6]] to <8 x i16> -; CHECK-NEXT: [[TMP8:%.*]] = call i16 @llvm.vector.reduce.add.v8i16(<8 x i16> [[TMP7]]) -; CHECK-NEXT: [[OP_RDX:%.*]] = add i16 [[TMP8]], 0 +; CHECK-NEXT: [[SHUFFLE:%.*]] = shufflevector <8 x i16> [[TMP1]], <8 x i16> poison, <8 x i32> +; CHECK-NEXT: [[TMP2:%.*]] = icmp slt <8 x i16> [[SHUFFLE]], zeroinitializer +; CHECK-NEXT: [[TMP3:%.*]] = icmp sgt <8 x i16> [[SHUFFLE]], zeroinitializer +; CHECK-NEXT: [[TMP4:%.*]] = shufflevector <8 x i1> [[TMP2]], <8 x i1> [[TMP3]], <8 x i32> +; CHECK-NEXT: [[TMP5:%.*]] = zext <8 x i1> [[TMP4]] to <8 x i16> +; CHECK-NEXT: [[TMP6:%.*]] = call i16 @llvm.vector.reduce.add.v8i16(<8 x i16> [[TMP5]]) +; CHECK-NEXT: [[OP_RDX:%.*]] = add i16 [[TMP6]], 0 ; CHECK-NEXT: ret i16 [[OP_RDX]] ; entry: diff --git a/llvm/test/Transforms/SLPVectorizer/X86/reduction-logical.ll b/llvm/test/Transforms/SLPVectorizer/X86/reduction-logical.ll --- a/llvm/test/Transforms/SLPVectorizer/X86/reduction-logical.ll +++ b/llvm/test/Transforms/SLPVectorizer/X86/reduction-logical.ll @@ -228,14 +228,13 @@ define i1 @logical_and_icmp_clamp(<4 x i32> %x) { ; CHECK-LABEL: @logical_and_icmp_clamp( -; CHECK-NEXT: [[TMP1:%.*]] = icmp slt <4 x i32> [[X:%.*]], -; CHECK-NEXT: [[TMP2:%.*]] = icmp sgt <4 x i32> [[X]], -; CHECK-NEXT: [[TMP3:%.*]] = freeze <4 x i1> [[TMP2]] -; CHECK-NEXT: [[TMP4:%.*]] = call i1 @llvm.vector.reduce.and.v4i1(<4 x i1> [[TMP3]]) -; CHECK-NEXT: [[TMP5:%.*]] = freeze <4 x i1> [[TMP1]] -; CHECK-NEXT: [[TMP6:%.*]] = call i1 @llvm.vector.reduce.and.v4i1(<4 x i1> [[TMP5]]) -; CHECK-NEXT: [[OP_RDX:%.*]] = select i1 [[TMP4]], i1 [[TMP6]], i1 false -; CHECK-NEXT: ret i1 [[OP_RDX]] +; CHECK-NEXT: [[SHUFFLE:%.*]] = shufflevector <4 x i32> [[X:%.*]], <4 x i32> poison, <8 x i32> +; CHECK-NEXT: [[TMP1:%.*]] = icmp sgt <8 x i32> [[SHUFFLE]], +; CHECK-NEXT: [[TMP2:%.*]] = icmp slt <8 x i32> [[SHUFFLE]], +; CHECK-NEXT: [[TMP3:%.*]] = shufflevector <8 x i1> [[TMP1]], <8 x i1> [[TMP2]], <8 x i32> +; CHECK-NEXT: 
[[TMP4:%.*]] = freeze <8 x i1> [[TMP3]] +; CHECK-NEXT: [[TMP5:%.*]] = call i1 @llvm.vector.reduce.and.v8i1(<8 x i1> [[TMP4]]) +; CHECK-NEXT: ret i1 [[TMP5]] ; %x0 = extractelement <4 x i32> %x, i32 0 %x1 = extractelement <4 x i32> %x, i32 1 @@ -261,16 +260,15 @@ define i1 @logical_and_icmp_clamp_extra_use_cmp(<4 x i32> %x) { ; CHECK-LABEL: @logical_and_icmp_clamp_extra_use_cmp( -; CHECK-NEXT: [[TMP1:%.*]] = icmp slt <4 x i32> [[X:%.*]], -; CHECK-NEXT: [[TMP2:%.*]] = extractelement <4 x i1> [[TMP1]], i32 2 -; CHECK-NEXT: call void @use1(i1 [[TMP2]]) -; CHECK-NEXT: [[TMP3:%.*]] = icmp sgt <4 x i32> [[X]], -; CHECK-NEXT: [[TMP4:%.*]] = freeze <4 x i1> [[TMP3]] -; CHECK-NEXT: [[TMP5:%.*]] = call i1 @llvm.vector.reduce.and.v4i1(<4 x i1> [[TMP4]]) -; CHECK-NEXT: [[TMP6:%.*]] = freeze <4 x i1> [[TMP1]] -; CHECK-NEXT: [[TMP7:%.*]] = call i1 @llvm.vector.reduce.and.v4i1(<4 x i1> [[TMP6]]) -; CHECK-NEXT: [[OP_RDX:%.*]] = select i1 [[TMP5]], i1 [[TMP7]], i1 false -; CHECK-NEXT: ret i1 [[OP_RDX]] +; CHECK-NEXT: [[SHUFFLE:%.*]] = shufflevector <4 x i32> [[X:%.*]], <4 x i32> poison, <8 x i32> +; CHECK-NEXT: [[TMP1:%.*]] = icmp sgt <8 x i32> [[SHUFFLE]], +; CHECK-NEXT: [[TMP2:%.*]] = icmp slt <8 x i32> [[SHUFFLE]], +; CHECK-NEXT: [[TMP3:%.*]] = shufflevector <8 x i1> [[TMP1]], <8 x i1> [[TMP2]], <8 x i32> +; CHECK-NEXT: [[TMP4:%.*]] = extractelement <8 x i1> [[TMP3]], i32 6 +; CHECK-NEXT: call void @use1(i1 [[TMP4]]) +; CHECK-NEXT: [[TMP5:%.*]] = freeze <8 x i1> [[TMP3]] +; CHECK-NEXT: [[TMP6:%.*]] = call i1 @llvm.vector.reduce.and.v8i1(<8 x i1> [[TMP5]]) +; CHECK-NEXT: ret i1 [[TMP6]] ; %x0 = extractelement <4 x i32> %x, i32 0 %x1 = extractelement <4 x i32> %x, i32 1