Index: llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
===================================================================
--- llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -5882,6 +5882,51 @@
   return nullptr;
 }
 
+/// Fold "all lanes equal" written as a mask test into one wide scalar compare:
+///
+///   %vec_ne = icmp ne <8 x i8> %lhs, %rhs
+///   %scalar_ne = bitcast <8 x i1> %vec_ne to i8
+///   %res = icmp eq i8 %scalar_ne, 0
+///
+/// becomes
+///
+///   %lhs.scalar = bitcast <8 x i8> %lhs to i64
+///   %rhs.scalar = bitcast <8 x i8> %rhs to i64
+///   %res = icmp eq i64 %lhs.scalar, %rhs.scalar
+///
+/// Returns the new compare for the caller to insert, or nullptr when \p I does
+/// not match the pattern.
+static Instruction *foldReductionIdiom(ICmpInst &I,
+                                       InstCombiner::BuilderTy &Builder) {
+  if (I.getType()->isVectorTy() || I.getPredicate() != ICmpInst::ICMP_EQ)
+    return nullptr;
+  auto *BitCast = dyn_cast<BitCastInst>(I.getOperand(0));
+  auto *Zero = dyn_cast<ConstantInt>(I.getOperand(1));
+  if (!BitCast || !BitCast->hasOneUse() || !Zero || !Zero->isZero())
+    return nullptr;
+  // The mask is zero iff no lane differs, so only `icmp ne` qualifies.
+  auto *VecCmp = dyn_cast<ICmpInst>(BitCast->getOperand(0));
+  if (!VecCmp || !VecCmp->hasOneUse() ||
+      VecCmp->getPredicate() != ICmpInst::ICMP_NE)
+    return nullptr;
+  Value *LHS = VecCmp->getOperand(0);
+  Value *RHS = VecCmp->getOperand(1);
+  // Vectors of pointers are valid icmp operands but have no integer bit width
+  // and cannot be bitcast to an integer, so require integer elements.
+  auto *VecTy = dyn_cast<FixedVectorType>(LHS->getType());
+  if (!VecTy || !VecTy->getElementType()->isIntegerTy())
+    return nullptr;
+  // NOTE(review): the scalar type may be wider than any integer the target
+  // handles natively (e.g. <32 x i8> -> i256); consider gating this on
+  // DataLayout::isLegalInteger().
+  auto *ScalarTy = Builder.getIntNTy(
+      VecTy->getNumElements() * VecTy->getElementType()->getIntegerBitWidth());
+  LHS = Builder.CreateBitCast(LHS, ScalarTy, LHS->getName() + ".scalar");
+  RHS = Builder.CreateBitCast(RHS, ScalarTy, RHS->getName() + ".scalar");
+  return ICmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, LHS, RHS,
+                          I.getName());
+}
+
 Instruction *InstCombinerImpl::visitICmpInst(ICmpInst &I) {
   bool Changed = false;
   const SimplifyQuery Q = SQ.getWithInstruction(&I);
@@ -6124,6 +6169,9 @@
   if (Instruction *Res = foldICmpInvariantGroup(I))
     return Res;
 
+  if (Instruction *Res = foldReductionIdiom(I, Builder))
+    return Res;
+
   return Changed ? &I : nullptr;
 }
 
Index: llvm/test/Transforms/InstCombine/icmp-vec.ll
===================================================================
--- llvm/test/Transforms/InstCombine/icmp-vec.ll
+++ llvm/test/Transforms/InstCombine/icmp-vec.ll
@@ -402,9 +402,9 @@
 
 define i1 @eq_cast_eq-1(<2 x i4> %x, <2 x i4> %y) {
 ; CHECK-LABEL: @eq_cast_eq-1(
-; CHECK-NEXT:    [[IC:%.*]] = icmp ne <2 x i4> [[X:%.*]], [[Y:%.*]]
-; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <2 x i1> [[IC]] to i2
-; CHECK-NEXT:    [[R:%.*]] = icmp eq i2 [[TMP1]], 0
+; CHECK-NEXT:    [[X_SCALAR:%.*]] = bitcast <2 x i4> [[X:%.*]] to i8
+; CHECK-NEXT:    [[Y_SCALAR:%.*]] = bitcast <2 x i4> [[Y:%.*]] to i8
+; CHECK-NEXT:    [[R:%.*]] = icmp eq i8 [[X_SCALAR]], [[Y_SCALAR]]
 ; CHECK-NEXT:    ret i1 [[R]]
 ;
   %ic = icmp eq <2 x i4> %x, %y
Index: llvm/test/Transforms/InstCombine/reduction-and-sext-zext-i1.ll
===================================================================
--- llvm/test/Transforms/InstCombine/reduction-and-sext-zext-i1.ll
+++ llvm/test/Transforms/InstCombine/reduction-and-sext-zext-i1.ll
@@ -98,14 +98,12 @@
 define i1 @reduce_and_pointer_cast(i8* %arg, i8* %arg1) {
 ; CHECK-LABEL: @reduce_and_pointer_cast(
 ; CHECK-NEXT:  bb:
-; CHECK-NEXT:    [[PTR1:%.*]] = bitcast i8* [[ARG1:%.*]] to <8 x i8>*
-; CHECK-NEXT:    [[PTR2:%.*]] = bitcast i8* [[ARG:%.*]] to <8 x i8>*
-; CHECK-NEXT:    [[LHS:%.*]] = load <8 x i8>, <8 x i8>* [[PTR1]], align 8
-; CHECK-NEXT:    [[RHS:%.*]] = load <8 x i8>, <8 x i8>* [[PTR2]], align 8
-; CHECK-NEXT:    [[CMP:%.*]] = icmp ne <8 x i8> [[LHS]], [[RHS]]
-; CHECK-NEXT:    [[TMP0:%.*]] = bitcast <8 x i1> [[CMP]] to i8
-; CHECK-NEXT:    [[TMP1:%.*]] = icmp eq i8 [[TMP0]], 0
-; CHECK-NEXT:    ret i1 [[TMP1]]
+; CHECK-NEXT:    [[TMP0:%.*]] = bitcast i8* [[ARG1:%.*]] to i64*
+; CHECK-NEXT:    [[LHS1:%.*]] = load i64, i64* [[TMP0]], align 8
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast i8* [[ARG:%.*]] to i64*
+; CHECK-NEXT:    [[RHS2:%.*]] = load i64, i64* [[TMP1]], align 8
+; CHECK-NEXT:    [[TMP2:%.*]] = icmp eq i64 [[LHS1]], [[RHS2]]
+; CHECK-NEXT:    ret i1 [[TMP2]]
 ;
 bb:
   %ptr1 = bitcast i8* %arg1 to <8 x i8>*
Index: llvm/test/Transforms/InstCombine/reduction-or-sext-zext-i1.ll
===================================================================
--- llvm/test/Transforms/InstCombine/reduction-or-sext-zext-i1.ll
+++ llvm/test/Transforms/InstCombine/reduction-or-sext-zext-i1.ll
@@ -98,13 +98,11 @@
 define i1 @reduce_or_pointer_cast(i8* %arg, i8* %arg1) {
 ; CHECK-LABEL: @reduce_or_pointer_cast(
 ; CHECK-NEXT:  bb:
-; CHECK-NEXT:    [[PTR1:%.*]] = bitcast i8* [[ARG1:%.*]] to <8 x i8>*
-; CHECK-NEXT:    [[PTR2:%.*]] = bitcast i8* [[ARG:%.*]] to <8 x i8>*
-; CHECK-NEXT:    [[LHS:%.*]] = load <8 x i8>, <8 x i8>* [[PTR1]], align 8
-; CHECK-NEXT:    [[RHS:%.*]] = load <8 x i8>, <8 x i8>* [[PTR2]], align 8
-; CHECK-NEXT:    [[CMP:%.*]] = icmp ne <8 x i8> [[LHS]], [[RHS]]
-; CHECK-NEXT:    [[TMP0:%.*]] = bitcast <8 x i1> [[CMP]] to i8
-; CHECK-NEXT:    [[DOTNOT:%.*]] = icmp eq i8 [[TMP0]], 0
+; CHECK-NEXT:    [[TMP0:%.*]] = bitcast i8* [[ARG1:%.*]] to i64*
+; CHECK-NEXT:    [[LHS1:%.*]] = load i64, i64* [[TMP0]], align 8
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast i8* [[ARG:%.*]] to i64*
+; CHECK-NEXT:    [[RHS2:%.*]] = load i64, i64* [[TMP1]], align 8
+; CHECK-NEXT:    [[DOTNOT:%.*]] = icmp eq i64 [[LHS1]], [[RHS2]]
 ; CHECK-NEXT:    ret i1 [[DOTNOT]]
 ;
 bb: