diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -5882,6 +5882,54 @@
   return nullptr;
 }
 
+/// This function folds patterns produced by lowering of reduce idioms, such
+/// as llvm.vector.reduce.and, which are lowered into instruction chains. This
+/// code attempts to generate fewer scalar comparisons instead of vector
+/// comparisons when possible.
+static Instruction *foldReductionIdiom(ICmpInst &I,
+                                       InstCombiner::BuilderTy &Builder,
+                                       const DataLayout &DL) {
+  if (I.getType()->isVectorTy())
+    return nullptr;
+  ICmpInst::Predicate OuterPred, InnerPred;
+  Value *LHS, *RHS;
+
+  // Match lowering of @llvm.vector.reduce.and. Turn
+  ///   %vec_ne = icmp ne <8 x i8> %lhs, %rhs
+  ///   %scalar_ne = bitcast <8 x i1> %vec_ne to i8
+  ///   %all_eq = icmp eq i8 %scalar_ne, 0
+  ///
+  /// into
+  ///
+  ///   %lhs.scalar = bitcast <8 x i8> %lhs to i64
+  ///   %rhs.scalar = bitcast <8 x i8> %rhs to i64
+  ///   %all_eq = icmp eq i64 %lhs.scalar, %rhs.scalar
+  if (!match(&I, m_ICmp(OuterPred,
+                        m_OneUse(m_BitCast(m_OneUse(
+                            m_ICmp(InnerPred, m_Value(LHS), m_Value(RHS))))),
+                        m_Zero())))
+    return nullptr;
+  auto *LHSTy = dyn_cast<FixedVectorType>(LHS->getType());
+  if (!LHSTy || !LHSTy->getElementType()->isIntegerTy())
+    return nullptr;
+  unsigned NumBits =
+      LHSTy->getNumElements() * LHSTy->getElementType()->getIntegerBitWidth();
+  // TODO: Relax this to "not wider than max legal integer type"?
+  if (!DL.isLegalInteger(NumBits))
+    return nullptr;
+
+  // TODO: Generalize to isEquality and support other patterns.
+  if (OuterPred == ICmpInst::ICMP_EQ && InnerPred == ICmpInst::ICMP_NE) {
+    auto *ScalarTy = Builder.getIntNTy(NumBits);
+    LHS = Builder.CreateBitCast(LHS, ScalarTy, LHS->getName() + ".scalar");
+    RHS = Builder.CreateBitCast(RHS, ScalarTy, RHS->getName() + ".scalar");
+    return ICmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, LHS, RHS,
+                            I.getName());
+  }
+
+  return nullptr;
+}
+
 Instruction *InstCombinerImpl::visitICmpInst(ICmpInst &I) {
   bool Changed = false;
   const SimplifyQuery Q = SQ.getWithInstruction(&I);
@@ -6124,6 +6172,9 @@
 
   if (Instruction *Res = foldICmpInvariantGroup(I))
     return Res;
 
+  if (Instruction *Res = foldReductionIdiom(I, Builder, DL))
+    return Res;
+
   return Changed ? &I : nullptr;
 }
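[Context, not part of the patch] A minimal IR sketch of the idiom this fold targets, using a hypothetical function name. Existing InstCombine folds are expected to turn the i1 reduction below into the icmp ne / bitcast <8 x i1> / icmp eq chain matched by foldReductionIdiom, and with this patch the whole sequence should collapse into two scalar bitcasts plus a single i64 equality compare on targets where i64 is a legal integer type:

  define i1 @bytes_eq(<8 x i8> %lhs, <8 x i8> %rhs) {
    %cmp = icmp eq <8 x i8> %lhs, %rhs                           ; lane-wise equality
    %all = call i1 @llvm.vector.reduce.and.v8i1(<8 x i1> %cmp)   ; all lanes equal?
    ret i1 %all
  }
  declare i1 @llvm.vector.reduce.and.v8i1(<8 x i1>)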
diff --git a/llvm/test/Transforms/InstCombine/icmp-vec.ll b/llvm/test/Transforms/InstCombine/icmp-vec.ll
--- a/llvm/test/Transforms/InstCombine/icmp-vec.ll
+++ b/llvm/test/Transforms/InstCombine/icmp-vec.ll
@@ -404,9 +404,9 @@
 
 define i1 @eq_cast_eq-1(<2 x i4> %x, <2 x i4> %y) {
 ; CHECK-LABEL: @eq_cast_eq-1(
-; CHECK-NEXT:    [[IC:%.*]] = icmp ne <2 x i4> [[X:%.*]], [[Y:%.*]]
-; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <2 x i1> [[IC]] to i2
-; CHECK-NEXT:    [[R:%.*]] = icmp eq i2 [[TMP1]], 0
+; CHECK-NEXT:    [[X_SCALAR:%.*]] = bitcast <2 x i4> [[X:%.*]] to i8
+; CHECK-NEXT:    [[Y_SCALAR:%.*]] = bitcast <2 x i4> [[Y:%.*]] to i8
+; CHECK-NEXT:    [[R:%.*]] = icmp eq i8 [[X_SCALAR]], [[Y_SCALAR]]
 ; CHECK-NEXT:    ret i1 [[R]]
 ;
   %ic = icmp eq <2 x i4> %x, %y
diff --git a/llvm/test/Transforms/InstCombine/reduction-and-sext-zext-i1.ll b/llvm/test/Transforms/InstCombine/reduction-and-sext-zext-i1.ll
--- a/llvm/test/Transforms/InstCombine/reduction-and-sext-zext-i1.ll
+++ b/llvm/test/Transforms/InstCombine/reduction-and-sext-zext-i1.ll
@@ -100,14 +100,12 @@
 define i1 @reduce_and_pointer_cast(i8* %arg, i8* %arg1) {
 ; CHECK-LABEL: @reduce_and_pointer_cast(
 ; CHECK-NEXT:  bb:
-; CHECK-NEXT:    [[PTR1:%.*]] = bitcast i8* [[ARG1:%.*]] to <8 x i8>*
-; CHECK-NEXT:    [[PTR2:%.*]] = bitcast i8* [[ARG:%.*]] to <8 x i8>*
-; CHECK-NEXT:    [[LHS:%.*]] = load <8 x i8>, <8 x i8>* [[PTR1]], align 8
-; CHECK-NEXT:    [[RHS:%.*]] = load <8 x i8>, <8 x i8>* [[PTR2]], align 8
-; CHECK-NEXT:    [[CMP:%.*]] = icmp ne <8 x i8> [[LHS]], [[RHS]]
-; CHECK-NEXT:    [[TMP0:%.*]] = bitcast <8 x i1> [[CMP]] to i8
-; CHECK-NEXT:    [[TMP1:%.*]] = icmp eq i8 [[TMP0]], 0
-; CHECK-NEXT:    ret i1 [[TMP1]]
+; CHECK-NEXT:    [[TMP0:%.*]] = bitcast i8* [[ARG1:%.*]] to i64*
+; CHECK-NEXT:    [[LHS1:%.*]] = load i64, i64* [[TMP0]], align 8
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast i8* [[ARG:%.*]] to i64*
+; CHECK-NEXT:    [[RHS2:%.*]] = load i64, i64* [[TMP1]], align 8
+; CHECK-NEXT:    [[TMP2:%.*]] = icmp eq i64 [[LHS1]], [[RHS2]]
+; CHECK-NEXT:    ret i1 [[TMP2]]
 ;
 bb:
   %ptr1 = bitcast i8* %arg1 to <8 x i8>*
@@ -144,14 +142,12 @@
 define i1 @reduce_and_pointer_cast_ne(i8* %arg, i8* %arg1) {
 ; CHECK-LABEL: @reduce_and_pointer_cast_ne(
 ; CHECK-NEXT:  bb:
-; CHECK-NEXT:    [[PTR1:%.*]] = bitcast i8* [[ARG1:%.*]] to <8 x i8>*
-; CHECK-NEXT:    [[PTR2:%.*]] = bitcast i8* [[ARG:%.*]] to <8 x i8>*
-; CHECK-NEXT:    [[LHS:%.*]] = load <8 x i8>, <8 x i8>* [[PTR1]], align 8
-; CHECK-NEXT:    [[RHS:%.*]] = load <8 x i8>, <8 x i8>* [[PTR2]], align 8
-; CHECK-NEXT:    [[CMP:%.*]] = icmp ne <8 x i8> [[LHS]], [[RHS]]
-; CHECK-NEXT:    [[TMP0:%.*]] = bitcast <8 x i1> [[CMP]] to i8
-; CHECK-NEXT:    [[TMP1:%.*]] = icmp ne i8 [[TMP0]], 0
-; CHECK-NEXT:    ret i1 [[TMP1]]
+; CHECK-NEXT:    [[TMP0:%.*]] = bitcast i8* [[ARG1:%.*]] to i64*
+; CHECK-NEXT:    [[LHS1:%.*]] = load i64, i64* [[TMP0]], align 8
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast i8* [[ARG:%.*]] to i64*
+; CHECK-NEXT:    [[RHS2:%.*]] = load i64, i64* [[TMP1]], align 8
+; CHECK-NEXT:    [[TMP2:%.*]] = icmp ne i64 [[LHS1]], [[RHS2]]
+; CHECK-NEXT:    ret i1 [[TMP2]]
 ;
 bb:
   %ptr1 = bitcast i8* %arg1 to <8 x i8>*
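[Context, not part of the patch] The CHECK lines in these tests follow the llvm/utils/update_test_checks.py format, so they can presumably be regenerated with that script after rebuilding opt; to inspect the new output by hand, something like the following should work:

  opt -passes=instcombine -S llvm/test/Transforms/InstCombine/reduction-and-sext-zext-i1.ll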
diff --git a/llvm/test/Transforms/InstCombine/reduction-or-sext-zext-i1.ll b/llvm/test/Transforms/InstCombine/reduction-or-sext-zext-i1.ll
--- a/llvm/test/Transforms/InstCombine/reduction-or-sext-zext-i1.ll
+++ b/llvm/test/Transforms/InstCombine/reduction-or-sext-zext-i1.ll
@@ -100,13 +100,11 @@
 define i1 @reduce_or_pointer_cast(i8* %arg, i8* %arg1) {
 ; CHECK-LABEL: @reduce_or_pointer_cast(
 ; CHECK-NEXT:  bb:
-; CHECK-NEXT:    [[PTR1:%.*]] = bitcast i8* [[ARG1:%.*]] to <8 x i8>*
-; CHECK-NEXT:    [[PTR2:%.*]] = bitcast i8* [[ARG:%.*]] to <8 x i8>*
-; CHECK-NEXT:    [[LHS:%.*]] = load <8 x i8>, <8 x i8>* [[PTR1]], align 8
-; CHECK-NEXT:    [[RHS:%.*]] = load <8 x i8>, <8 x i8>* [[PTR2]], align 8
-; CHECK-NEXT:    [[CMP:%.*]] = icmp ne <8 x i8> [[LHS]], [[RHS]]
-; CHECK-NEXT:    [[TMP0:%.*]] = bitcast <8 x i1> [[CMP]] to i8
-; CHECK-NEXT:    [[DOTNOT:%.*]] = icmp eq i8 [[TMP0]], 0
+; CHECK-NEXT:    [[TMP0:%.*]] = bitcast i8* [[ARG1:%.*]] to i64*
+; CHECK-NEXT:    [[LHS1:%.*]] = load i64, i64* [[TMP0]], align 8
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast i8* [[ARG:%.*]] to i64*
+; CHECK-NEXT:    [[RHS2:%.*]] = load i64, i64* [[TMP1]], align 8
+; CHECK-NEXT:    [[DOTNOT:%.*]] = icmp eq i64 [[LHS1]], [[RHS2]]
 ; CHECK-NEXT:    ret i1 [[DOTNOT]]
 ;
 bb: