diff --git a/llvm/lib/Transforms/Utils/LoopUtils.cpp b/llvm/lib/Transforms/Utils/LoopUtils.cpp
--- a/llvm/lib/Transforms/Utils/LoopUtils.cpp
+++ b/llvm/lib/Transforms/Utils/LoopUtils.cpp
@@ -1024,8 +1024,26 @@
   RdxFlags.IsMaxOp = RdxKind == RecurKind::SMax || RdxKind == RecurKind::UMax ||
                      RdxKind == RecurKind::FMax;
   RdxFlags.IsSigned = RdxKind == RecurKind::SMax || RdxKind == RecurKind::SMin;
-
+  // Special case for i1 'or' and 'and' reductions: no need to emit a reduction
+  // here, just x != <0, 0, .., 0> for an 'or' reduction and x == <1, 1, .., 1>
+  // for an 'and' reduction.
   auto *SrcVecEltTy = cast<VectorType>(Src->getType())->getElementType();
+  if ((RdxKind == RecurKind::And || RdxKind == RecurKind::Or) &&
+      SrcVecEltTy == Builder.getInt1Ty()) {
+    Value *Res = Builder.CreateBitCast(
+        Src, Builder.getIntNTy(cast<VectorType>(Src->getType())
+                                   ->getElementCount()
+                                   .getFixedValue()));
+    if (RdxKind == RecurKind::And) {
+      Res = Builder.CreateICmpEQ(Res,
+                                 ConstantInt::getAllOnesValue(Res->getType()));
+    } else {
+      assert(RdxKind == RecurKind::Or && "Expected or reduction.");
+      Res = Builder.CreateIsNotNull(Res);
+    }
+    return Res;
+  }
+
   switch (RdxKind) {
   case RecurKind::Add:
     return Builder.CreateAddReduce(Src);
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -6773,8 +6773,30 @@
   RecurrenceDescriptor RdxDesc =
       Legal->getReductionVars()[cast<PHINode>(ReductionPhi)];
-  unsigned BaseCost = TTI.getArithmeticReductionCost(RdxDesc.getOpcode(),
-                                                     VectorTy, false, CostKind);
+  unsigned BaseCost;
+  RecurKind RdxKind = RdxDesc.getRecurrenceKind();
+  RdxDesc.getRecurrenceType();
+  if ((RdxKind == RecurKind::Or || RdxKind == RecurKind::And) &&
+      VectorTy->getElementType() ==
+          IntegerType::getInt1Ty(VectorTy->getContext())) {
+    // Or reduction for i1 is represented as:
+    // %val = bitcast <ReduxWidth x i1> to iReduxWidth
+    // %res = cmp ne iReduxWidth %val, 0
+    // And reduction for i1 is represented as:
+    // %val = bitcast <ReduxWidth x i1> to iReduxWidth
+    // %res = cmp eq iReduxWidth %val, 11111
+    Type *ValTy = IntegerType::get(VectorTy->getContext(),
+                                   VectorTy->getElementCount().getFixedValue());
+    BaseCost = TTI.getCastInstrCost(Instruction::BitCast, ValTy, VectorTy,
+                                    TTI::CastContextHint::None,
+                                    TTI::TCK_RecipThroughput) +
+               TTI.getCmpSelInstrCost(Instruction::ICmp, ValTy,
+                                      CmpInst::makeCmpResultType(ValTy));
+  } else {
+    BaseCost =
+        TTI.getArithmeticReductionCost(RdxDesc.getOpcode(), VectorTy,
+                                       /*IsPairwiseForm=*/false, CostKind);
+  }
   // Get the operand that was not the reduction chain and match it to one of the
   // patterns, returning the better cost if it is found.
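For reference, a minimal IR sketch (not part of the patch; the function names, lane count, and constants are illustrative) of the two shapes the LoopUtils special case above emits and the LoopVectorize cost code models:

    define i1 @any_of(<8 x i1> %mask) {
      ; 'or' reduction of i1: some lane is set iff the packed mask is non-zero.
      %m = bitcast <8 x i1> %mask to i8
      %any = icmp ne i8 %m, 0
      ret i1 %any
    }

    define i1 @all_of(<8 x i1> %mask) {
      ; 'and' reduction of i1: all lanes are set iff the packed mask is all-ones.
      %m = bitcast <8 x i1> %mask to i8
      %all = icmp eq i8 %m, -1
      ret i1 %all
    }

Accordingly, the base cost is charged as one BitCast plus one ICmp rather than a generic arithmetic reduction.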
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -7027,8 +7027,25 @@
   case RecurKind::FAdd:
   case RecurKind::FMul: {
     unsigned RdxOpcode = RecurrenceDescriptor::getOpcode(RdxKind);
-    VectorCost = TTI->getArithmeticReductionCost(RdxOpcode, VectorTy,
-                                                 /*IsPairwiseForm=*/false);
+    if ((RdxKind == RecurKind::Or || RdxKind == RecurKind::And) &&
+        ScalarTy == IntegerType::getInt1Ty(FirstReducedVal->getContext())) {
+      // Or reduction for i1 is represented as:
+      // %val = bitcast <ReduxWidth x i1> to iReduxWidth
+      // %res = cmp ne iReduxWidth %val, 0
+      // And reduction for i1 is represented as:
+      // %val = bitcast <ReduxWidth x i1> to iReduxWidth
+      // %res = cmp eq iReduxWidth %val, 11111
+      Type *ValTy =
+          IntegerType::get(FirstReducedVal->getContext(), ReduxWidth);
+      VectorCost = TTI->getCastInstrCost(Instruction::BitCast, ValTy,
+                                         VectorTy, TTI::CastContextHint::None,
+                                         TTI::TCK_RecipThroughput) +
+                   TTI->getCmpSelInstrCost(Instruction::ICmp, ValTy,
+                                           CmpInst::makeCmpResultType(ValTy));
+    } else {
+      VectorCost = TTI->getArithmeticReductionCost(RdxOpcode, VectorTy,
+                                                   /*IsPairwiseForm=*/false);
+    }
     ScalarCost = TTI->getArithmeticInstrCost(RdxOpcode, ScalarTy);
     break;
   }
diff --git a/llvm/test/Transforms/SLPVectorizer/X86/compare-reduce.ll b/llvm/test/Transforms/SLPVectorizer/X86/compare-reduce.ll
--- a/llvm/test/Transforms/SLPVectorizer/X86/compare-reduce.ll
+++ b/llvm/test/Transforms/SLPVectorizer/X86/compare-reduce.ll
@@ -84,9 +84,10 @@
 ; CHECK-NEXT:    [[TMP1:%.*]] = extractelement <4 x float> [[X:%.*]], i32 3
 ; CHECK-NEXT:    [[CMP3WRONG:%.*]] = fcmp olt float [[TMP1]], 4.200000e+01
 ; CHECK-NEXT:    [[TMP2:%.*]] = fcmp ogt <4 x float> [[X]],
-; CHECK-NEXT:    [[TMP3:%.*]] = call i1 @llvm.vector.reduce.or.v4i1(<4 x i1> [[TMP2]])
-; CHECK-NEXT:    [[TMP4:%.*]] = or i1 [[TMP3]], [[CMP3WRONG]]
-; CHECK-NEXT:    [[R:%.*]] = select i1 [[TMP4]], float -1.000000e+00, float 1.000000e+00
+; CHECK-NEXT:    [[TMP3:%.*]] = bitcast <4 x i1> [[TMP2]] to i4
+; CHECK-NEXT:    [[TMP4:%.*]] = icmp ne i4 [[TMP3]], 0
+; CHECK-NEXT:    [[TMP5:%.*]] = or i1 [[TMP4]], [[CMP3WRONG]]
+; CHECK-NEXT:    [[R:%.*]] = select i1 [[TMP5]], float -1.000000e+00, float 1.000000e+00
 ; CHECK-NEXT:    ret float [[R]]
 ;
   %x0 = extractelement <4 x float> %x, i32 0
@@ -111,9 +112,10 @@
 ; CHECK-NEXT:    [[TMP1:%.*]] = extractelement <4 x float> [[X:%.*]], i32 3
 ; CHECK-NEXT:    [[CMP3WRONG:%.*]] = fcmp olt float [[TMP1]], 4.200000e+01
 ; CHECK-NEXT:    [[TMP2:%.*]] = fcmp ogt <4 x float> [[X]],
-; CHECK-NEXT:    [[TMP3:%.*]] = call i1 @llvm.vector.reduce.or.v4i1(<4 x i1> [[TMP2]])
-; CHECK-NEXT:    [[TMP4:%.*]] = or i1 [[TMP3]], [[CMP3WRONG]]
-; CHECK-NEXT:    [[R:%.*]] = select i1 [[TMP4]], float -1.000000e+00, float 1.000000e+00
+; CHECK-NEXT:    [[TMP3:%.*]] = bitcast <4 x i1> [[TMP2]] to i4
+; CHECK-NEXT:    [[TMP4:%.*]] = icmp ne i4 [[TMP3]], 0
+; CHECK-NEXT:    [[TMP5:%.*]] = or i1 [[TMP4]], [[CMP3WRONG]]
+; CHECK-NEXT:    [[R:%.*]] = select i1 [[TMP5]], float -1.000000e+00, float 1.000000e+00
 ; CHECK-NEXT:    ret float [[R]]
 ;
   %x0 = extractelement <4 x float> %x, i32 0
@@ -138,9 +140,10 @@
 ; CHECK-NEXT:    [[TMP1:%.*]] = extractelement <4 x i32> [[X:%.*]], i32 3
 ; CHECK-NEXT:    [[CMP3WRONG:%.*]] = icmp slt i32 [[TMP1]], 42
 ; CHECK-NEXT:    [[TMP2:%.*]] = icmp sgt <4 x i32> [[X]],
-; CHECK-NEXT:    [[TMP3:%.*]] = call i1 @llvm.vector.reduce.or.v4i1(<4 x i1> [[TMP2]])
-; CHECK-NEXT:    [[TMP4:%.*]] = or i1 [[TMP3]], [[CMP3WRONG]]
-; CHECK-NEXT:    [[R:%.*]] = select i1 [[TMP4]], i32 -1, i32 1
+; CHECK-NEXT:    [[TMP3:%.*]] = bitcast <4 x i1> [[TMP2]] to i4
+; CHECK-NEXT:    [[TMP4:%.*]] = icmp ne i4 [[TMP3]], 0
+; CHECK-NEXT:    [[TMP5:%.*]] = or i1 [[TMP4]], [[CMP3WRONG]]
+; CHECK-NEXT:    [[R:%.*]] = select i1 [[TMP5]], i32 -1, i32 1
 ; CHECK-NEXT:    ret i32 [[R]]
 ;
   %x0 = extractelement <4 x i32> %x, i32 0
@@ -169,9 +172,10 @@
 ; CHECK-NEXT:    [[TMP2:%.*]] = extractelement <4 x i32> [[X:%.*]], i32 3
 ; CHECK-NEXT:    [[CMP3WRONG:%.*]] = icmp slt i32 [[TMP2]], [[TMP1]]
 ; CHECK-NEXT:    [[TMP3:%.*]] = icmp sgt <4 x i32> [[X]], [[Y]]
-; CHECK-NEXT:    [[TMP4:%.*]] = call i1 @llvm.vector.reduce.or.v4i1(<4 x i1> [[TMP3]])
-; CHECK-NEXT:    [[TMP5:%.*]] = or i1 [[TMP4]], [[CMP3WRONG]]
-; CHECK-NEXT:    [[R:%.*]] = select i1 [[TMP5]], i32 -1, i32 1
+; CHECK-NEXT:    [[TMP4:%.*]] = bitcast <4 x i1> [[TMP3]] to i4
+; CHECK-NEXT:    [[TMP5:%.*]] = icmp ne i4 [[TMP4]], 0
+; CHECK-NEXT:    [[TMP6:%.*]] = or i1 [[TMP5]], [[CMP3WRONG]]
+; CHECK-NEXT:    [[R:%.*]] = select i1 [[TMP6]], i32 -1, i32 1
 ; CHECK-NEXT:    ret i32 [[R]]
 ;
   %x0 = extractelement <4 x i32> %x, i32 0
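The updated checks only exercise the 'or' form. As a rough before/after sketch (assumed for illustration, not taken from the patch's tests; the function names and the constant 42 are made up), the kind of scalar chain SLP recognizes and the form it now emits for it:

    ; Scalar i1 'or' reduction over per-lane compares.
    define i1 @or_reduction_scalar(<4 x i32> %x) {
      %x0 = extractelement <4 x i32> %x, i32 0
      %x1 = extractelement <4 x i32> %x, i32 1
      %x2 = extractelement <4 x i32> %x, i32 2
      %x3 = extractelement <4 x i32> %x, i32 3
      %c0 = icmp sgt i32 %x0, 42
      %c1 = icmp sgt i32 %x1, 42
      %c2 = icmp sgt i32 %x2, 42
      %c3 = icmp sgt i32 %x3, 42
      %o0 = or i1 %c0, %c1
      %o1 = or i1 %o0, %c2
      %o2 = or i1 %o1, %c3
      ret i1 %o2
    }

    ; Vectorized: the reduction becomes bitcast + icmp rather than a call to
    ; @llvm.vector.reduce.or.v4i1.
    define i1 @or_reduction_vector(<4 x i32> %x) {
      %c = icmp sgt <4 x i32> %x, <i32 42, i32 42, i32 42, i32 42>
      %b = bitcast <4 x i1> %c to i4
      %r = icmp ne i4 %b, 0
      ret i1 %r
    }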