diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp --- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp +++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp @@ -8012,10 +8012,20 @@ Value *RHS, const Twine &Name, bool UseSelect) { unsigned RdxOpcode = RecurrenceDescriptor::getOpcode(Kind); switch (Kind) { - case RecurKind::Add: - case RecurKind::Mul: case RecurKind::Or: + if (UseSelect && + LHS->getType() == CmpInst::makeCmpResultType(LHS->getType())) + return Builder.CreateSelect(LHS, Builder.getTrue(), RHS, Name); + return Builder.CreateBinOp((Instruction::BinaryOps)RdxOpcode, LHS, RHS, + Name); case RecurKind::And: + if (UseSelect && + LHS->getType() == CmpInst::makeCmpResultType(LHS->getType())) + return Builder.CreateSelect(LHS, RHS, Builder.getFalse(), Name); + return Builder.CreateBinOp((Instruction::BinaryOps)RdxOpcode, LHS, RHS, + Name); + case RecurKind::Add: + case RecurKind::Mul: case RecurKind::Xor: case RecurKind::FAdd: case RecurKind::FMul: @@ -8059,8 +8069,12 @@ static Value *createOp(IRBuilder<> &Builder, RecurKind RdxKind, Value *LHS, Value *RHS, const Twine &Name, const ReductionOpsListType &ReductionOps) { - bool UseSelect = ReductionOps.size() == 2; - assert((!UseSelect || isa(ReductionOps[1][0])) && + bool UseSelect = ReductionOps.size() == 2 || + // Logical or/and. + (ReductionOps.size() == 1 && + isa(ReductionOps.front().front())); + assert((!UseSelect || ReductionOps.size() != 2 || + isa(ReductionOps[1][0])) && "Expected cmp + select pairs for reduction"); Value *Op = createOp(Builder, RdxKind, LHS, RHS, Name, UseSelect); if (RecurrenceDescriptor::isIntMinMaxRecurrenceKind(RdxKind)) { @@ -8198,10 +8212,10 @@ /// Checks if the instruction is in basic block \p BB. /// For a cmp+sel min/max reduction check that both ops are in \p BB. static bool hasSameParent(Instruction *I, BasicBlock *BB) { - if (isCmpSelMinMax(I)) { + if (isCmpSelMinMax(I) || (isBoolLogicOp(I) && isa(I))) { auto *Sel = cast(I); - auto *Cmp = cast(Sel->getCondition()); - return Sel->getParent() == BB && Cmp->getParent() == BB; + auto *Cmp = dyn_cast(Sel->getCondition()); + return Sel->getParent() == BB && Cmp && Cmp->getParent() == BB; } return I->getParent() == BB; } diff --git a/llvm/test/Transforms/SLPVectorizer/X86/reduction-logical.ll b/llvm/test/Transforms/SLPVectorizer/X86/reduction-logical.ll --- a/llvm/test/Transforms/SLPVectorizer/X86/reduction-logical.ll +++ b/llvm/test/Transforms/SLPVectorizer/X86/reduction-logical.ll @@ -480,7 +480,7 @@ ; CHECK-NEXT: [[S3:%.*]] = select i1 [[C:%.*]], i1 [[C]], i1 false ; CHECK-NEXT: [[TMP2:%.*]] = freeze <4 x i1> [[TMP1]] ; CHECK-NEXT: [[TMP3:%.*]] = call i1 @llvm.vector.reduce.and.v4i1(<4 x i1> [[TMP2]]) -; CHECK-NEXT: [[OP_EXTRA:%.*]] = and i1 [[TMP3]], [[S3]] +; CHECK-NEXT: [[OP_EXTRA:%.*]] = select i1 [[TMP3]], i1 [[S3]], i1 false ; CHECK-NEXT: ret i1 [[OP_EXTRA]] ; %x0 = extractelement <4 x i32> %x, i32 0 @@ -509,7 +509,7 @@ ; CHECK-NEXT: [[S3:%.*]] = select i1 [[C:%.*]], i1 true, i1 [[C]] ; CHECK-NEXT: [[TMP2:%.*]] = freeze <4 x i1> [[TMP1]] ; CHECK-NEXT: [[TMP3:%.*]] = call i1 @llvm.vector.reduce.or.v4i1(<4 x i1> [[TMP2]]) -; CHECK-NEXT: [[OP_EXTRA:%.*]] = or i1 [[TMP3]], [[S3]] +; CHECK-NEXT: [[OP_EXTRA:%.*]] = select i1 [[TMP3]], i1 true, i1 [[S3]] ; CHECK-NEXT: ret i1 [[OP_EXTRA]] ; %x0 = extractelement <4 x i32> %x, i32 0