diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -1897,6 +1897,34 @@
     }
     break;
   }
+  case Intrinsic::vector_reduce_or:
+  case Intrinsic::vector_reduce_and: {
+    // Canonicalize logical or/and reductions:
+    // Or reduction for i1 is represented as:
+    // %val = bitcast <ReduxWidth x i1> to iReduxWidth
+    // %res = cmp ne iReduxWidth %val, 0
+    // And reduction for i1 is represented as:
+    // %val = bitcast <ReduxWidth x i1> to iReduxWidth
+    // %res = cmp eq iReduxWidth %val, 11111
+    Value *Arg = II->getArgOperand(0);
+    Type *RetTy = II->getType();
+    if (RetTy == Builder.getInt1Ty())
+      if (auto *FVTy = dyn_cast<FixedVectorType>(Arg->getType())) {
+        Value *Res = Builder.CreateBitCast(
+            Arg, Builder.getIntNTy(FVTy->getNumElements()));
+        if (IID == Intrinsic::vector_reduce_and) {
+          Res = Builder.CreateICmpEQ(
+              Res, ConstantInt::getAllOnesValue(Res->getType()));
+        } else {
+          assert(IID == Intrinsic::vector_reduce_or &&
+                 "Expected or reduction.");
+          Res = Builder.CreateIsNotNull(Res);
+        }
+        replaceInstUsesWith(CI, Res);
+        return eraseInstFromFunction(CI);
+      }
+    break;
+  }
   default: {
     // Handle target specific intrinsics
     Optional<Instruction *> V = targetInstCombineIntrinsic(*II);
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -6943,8 +6943,31 @@
   RecurrenceDescriptor RdxDesc =
       Legal->getReductionVars()[cast<PHINode>(ReductionPhi)];
-  unsigned BaseCost = TTI.getArithmeticReductionCost(RdxDesc.getOpcode(),
-                                                     VectorTy, false, CostKind);
+  unsigned BaseCost;
+  RecurKind RdxKind = RdxDesc.getRecurrenceKind();
+  if ((RdxKind == RecurKind::Or || RdxKind == RecurKind::And) &&
+      isa<FixedVectorType>(VectorTy) &&
+      VectorTy->getElementType() ==
+          IntegerType::getInt1Ty(VectorTy->getContext())) {
+    // Or reduction for i1 is represented as:
+    // %val = bitcast <ReduxWidth x i1> to iReduxWidth
+    // %res = cmp ne iReduxWidth %val, 0
+    // And reduction for i1 is represented as:
+    // %val = bitcast <ReduxWidth x i1> to iReduxWidth
+    // %res = cmp eq iReduxWidth %val, 11111
+    Type *ValTy = IntegerType::get(VectorTy->getContext(),
+                                   VectorTy->getElementCount().getFixedValue());
+    BaseCost = TTI.getCastInstrCost(Instruction::BitCast, ValTy, VectorTy,
+                                    TTI::CastContextHint::None,
+                                    TTI::TCK_RecipThroughput) +
+               TTI.getCmpSelInstrCost(Instruction::ICmp, ValTy,
+                                      CmpInst::makeCmpResultType(ValTy));
+  } else {
+    BaseCost =
+        TTI.getArithmeticReductionCost(RdxDesc.getOpcode(), VectorTy,
+                                       /*IsPairwiseForm=*/false, CostKind);
+  }
   // Get the operand that was not the reduction chain and match it to one of the
   // patterns, returning the better cost if it is found.
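Note: the InstCombineCalls.cpp hunk above canonicalizes an i1 or/and reduction intrinsic into a bitcast of the whole mask plus a single integer compare, exactly as its comment block describes. A minimal before/after sketch in LLVM IR (hypothetical example, not a test from this patch; @src and @tgt are illustrative names):

  declare i1 @llvm.vector.reduce.or.v4i1(<4 x i1>)

  ; Before: intrinsic form of an i1 or-reduction over a 4-lane mask.
  define i1 @src(<4 x i1> %mask) {
    %r = call i1 @llvm.vector.reduce.or.v4i1(<4 x i1> %mask)
    ret i1 %r
  }

  ; After the fold: any set bit in the mask makes the result true.
  ; Builder.CreateIsNotNull emits this icmp ne against zero.
  define i1 @tgt(<4 x i1> %mask) {
    %val = bitcast <4 x i1> %mask to i4
    %r = icmp ne i4 %val, 0
    ret i1 %r
  }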
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -7070,8 +7070,25 @@
     case RecurKind::FAdd:
     case RecurKind::FMul: {
       unsigned RdxOpcode = RecurrenceDescriptor::getOpcode(RdxKind);
-      VectorCost = TTI->getArithmeticReductionCost(RdxOpcode, VectorTy,
-                                                   /*IsPairwiseForm=*/false);
+      if ((RdxKind == RecurKind::Or || RdxKind == RecurKind::And) &&
+          ScalarTy == IntegerType::getInt1Ty(FirstReducedVal->getContext())) {
+        // Or reduction for i1 is represented as:
+        // %val = bitcast <ReduxWidth x i1> to iReduxWidth
+        // %res = cmp ne iReduxWidth %val, 0
+        // And reduction for i1 is represented as:
+        // %val = bitcast <ReduxWidth x i1> to iReduxWidth
+        // %res = cmp eq iReduxWidth %val, 11111
+        Type *ValTy =
+            IntegerType::get(FirstReducedVal->getContext(), ReduxWidth);
+        VectorCost = TTI->getCastInstrCost(Instruction::BitCast, ValTy,
+                                           VectorTy, TTI::CastContextHint::None,
+                                           TTI::TCK_RecipThroughput) +
+                     TTI->getCmpSelInstrCost(Instruction::ICmp, ValTy,
+                                             CmpInst::makeCmpResultType(ValTy));
+      } else {
+        VectorCost = TTI->getArithmeticReductionCost(RdxOpcode, VectorTy,
+                                                     /*IsPairwiseForm=*/false);
+      }
       ScalarCost = TTI->getArithmeticInstrCost(RdxOpcode, ScalarTy);
       break;
     }
diff --git a/llvm/test/Transforms/InstCombine/vector-logical-reductions.ll b/llvm/test/Transforms/InstCombine/vector-logical-reductions.ll
--- a/llvm/test/Transforms/InstCombine/vector-logical-reductions.ll
+++ b/llvm/test/Transforms/InstCombine/vector-logical-reductions.ll
@@ -4,8 +4,9 @@
 define float @reduction_logical_or(<4 x float> %x) {
 ; CHECK-LABEL: @reduction_logical_or(
 ; CHECK-NEXT:    [[TMP1:%.*]] = fcmp ogt <4 x float> [[X:%.*]],
-; CHECK-NEXT:    [[TMP2:%.*]] = call i1 @llvm.vector.reduce.or.v4i1(<4 x i1> [[TMP1]])
-; CHECK-NEXT:    [[R:%.*]] = select i1 [[TMP2]], float -1.000000e+00, float 1.000000e+00
+; CHECK-NEXT:    [[TMP2:%.*]] = bitcast <4 x i1> [[TMP1]] to i4
+; CHECK-NEXT:    [[DOTNOT:%.*]] = icmp eq i4 [[TMP2]], 0
+; CHECK-NEXT:    [[R:%.*]] = select i1 [[DOTNOT]], float 1.000000e+00, float -1.000000e+00
 ; CHECK-NEXT:    ret float [[R]]
 ;
   %1 = fcmp ogt <4 x float> %x,
@@ -17,8 +18,9 @@
 define float @reduction_logical_and(<4 x float> %x) {
 ; CHECK-LABEL: @reduction_logical_and(
 ; CHECK-NEXT:    [[TMP1:%.*]] = fcmp ogt <4 x float> [[X:%.*]],
-; CHECK-NEXT:    [[TMP2:%.*]] = call i1 @llvm.vector.reduce.and.v4i1(<4 x i1> [[TMP1]])
-; CHECK-NEXT:    [[R:%.*]] = select i1 [[TMP2]], float -1.000000e+00, float 1.000000e+00
+; CHECK-NEXT:    [[TMP2:%.*]] = bitcast <4 x i1> [[TMP1]] to i4
+; CHECK-NEXT:    [[TMP3:%.*]] = icmp eq i4 [[TMP2]], -1
+; CHECK-NEXT:    [[R:%.*]] = select i1 [[TMP3]], float -1.000000e+00, float 1.000000e+00
 ; CHECK-NEXT:    ret float [[R]]
 ;
   %1 = fcmp ogt <4 x float> %x,
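In reduction_logical_or above, the icmp ne produced by the fold is then inverted by a later select fold into an icmp eq against zero with swapped select arms, hence the [[DOTNOT]] value name. The and-reduction keeps its direct form: an all-ones compare (ConstantInt::getAllOnesValue prints as -1), which the updated CHECK lines in reduction_logical_and verify. A minimal sketch of that shape (illustrative function name, not from this patch):

  ; All four mask bits must be set: compare against all-ones (-1 as i4).
  define i1 @tgt_and(<4 x i1> %mask) {
    %val = bitcast <4 x i1> %mask to i4
    %r = icmp eq i4 %val, -1
    ret i1 %r
  }

The cost-model hunks in LoopVectorize.cpp and SLPVectorizer.cpp price exactly this lowering: for an i1 or/and reduction, the cost becomes getCastInstrCost for the BitCast plus getCmpSelInstrCost for the ICmp on the iReduxWidth integer type, rather than the generic getArithmeticReductionCost.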