diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -7006,22 +7006,41 @@
              match(RedOp, m_Mul(m_Instruction(Op0), m_Instruction(Op1)))) {
     if (match(Op0, m_ZExtOrSExt(m_Value())) &&
         Op0->getOpcode() == Op1->getOpcode() &&
-        Op0->getOperand(0)->getType() == Op1->getOperand(0)->getType() &&
         !TheLoop->isLoopInvariant(Op0) && !TheLoop->isLoopInvariant(Op1)) {
       bool IsUnsigned = isa<ZExtInst>(Op0);
-      auto *ExtType = VectorType::get(Op0->getOperand(0)->getType(), VectorTy);
-      // Matched reduce(mul(ext, ext))
-      InstructionCost ExtCost =
-          TTI.getCastInstrCost(Op0->getOpcode(), VectorTy, ExtType,
-                               TTI::CastContextHint::None, CostKind, Op0);
+      Type *Op0Ty = Op0->getOperand(0)->getType();
+      Type *Op1Ty = Op1->getOperand(0)->getType();
+      Type *LargestOpTy =
+          Op0Ty->getIntegerBitWidth() < Op1Ty->getIntegerBitWidth() ? Op1Ty
+                                                                    : Op0Ty;
+      auto *ExtType = VectorType::get(LargestOpTy, VectorTy);
+
+      // Matched reduce(mul(ext(A), ext(B))), where the two ext may be of
+      // different sizes. We take the largest type as the ext to reduce, and add
+      // the remaining cost as, for example reduce(mul(ext(ext(A)), ext(B))).
+      InstructionCost ExtCost0 = TTI.getCastInstrCost(
+          Op0->getOpcode(), VectorTy, VectorType::get(Op0Ty, VectorTy),
+          TTI::CastContextHint::None, CostKind, Op0);
+      InstructionCost ExtCost1 = TTI.getCastInstrCost(
+          Op1->getOpcode(), VectorTy, VectorType::get(Op1Ty, VectorTy),
+          TTI::CastContextHint::None, CostKind, Op1);
       InstructionCost MulCost =
           TTI.getArithmeticInstrCost(Instruction::Mul, VectorTy, CostKind);
       InstructionCost RedCost = TTI.getExtendedAddReductionCost(
           /*IsMLA=*/true, IsUnsigned, RdxDesc.getRecurrenceType(), ExtType,
           CostKind);
+      InstructionCost ExtraExtCost = 0;
+      if (Op0Ty != LargestOpTy || Op1Ty != LargestOpTy) {
+        Instruction *ExtraExtOp = (Op0Ty != LargestOpTy) ? Op0 : Op1;
+        ExtraExtCost = TTI.getCastInstrCost(
+            ExtraExtOp->getOpcode(), ExtType,
+            VectorType::get(ExtraExtOp->getOperand(0)->getType(), VectorTy),
+            TTI::CastContextHint::None, CostKind, ExtraExtOp);
+      }

-      if (RedCost.isValid() && RedCost < ExtCost * 2 + MulCost + BaseCost)
+      if (RedCost.isValid() &&
+          (RedCost + ExtraExtCost) < (ExtCost0 + ExtCost1 + MulCost + BaseCost))
         return I == RetI ? RedCost : 0;
     } else if (!match(I, m_ZExtOrSExt(m_Value()))) {
       // Matched reduce(mul())
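
Note on the comparison: the fused option is costed as the extended reduction plus one extra extend for the narrower operand, and is taken only when that beats the unfused ext + ext + mul sequence plus the base reduction cost. Below is a minimal standalone sketch of that decision, with made-up integer costs standing in for the TTI queries; the struct, field names, and values are illustrative, not LLVM API.

#include <cstdio>

// Hypothetical per-instruction costs; in LLVM these come from
// TargetTransformInfo queries and depend on the target (here, MVE).
struct Costs {
  int ExtCost0;     // extend operand 0 to the reduction element type
  int ExtCost1;     // extend operand 1 to the reduction element type
  int MulCost;      // the vector multiply
  int RedCost;      // the extending multiply-add reduction (e.g. VMLALV)
  int ExtraExtCost; // extend the narrower operand up to the larger ext type
  int BaseCost;     // the plain add reduction that would otherwise remain
};

// Mirrors the patched comparison: prefer reduce(mul(ext(A), ext(B))) only
// when the fused cost (plus the extra extend, if the two exts differ in
// size) beats the unfused ext + ext + mul + reduction sequence.
static bool preferExtendedReduction(const Costs &C) {
  return C.RedCost + C.ExtraExtCost <
         C.ExtCost0 + C.ExtCost1 + C.MulCost + C.BaseCost;
}

int main() {
  // reduce(mul(sext(<8 x i8> A), sext(<8 x i16> B))): the i8 side needs one
  // extra extend to i16 before the extending multiply-accumulate.
  Costs C{/*ExtCost0=*/1, /*ExtCost1=*/1, /*MulCost=*/1,
          /*RedCost=*/2, /*ExtraExtCost=*/1, /*BaseCost=*/1};
  std::printf("prefer fused reduction: %s\n",
              preferExtendedReduction(C) ? "yes" : "no");
}
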
diff --git a/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll b/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll
--- a/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll
+++ b/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll
@@ -1133,33 +1133,33 @@
   ret i8 %r.0.lcssa
 }

-; 4x or 8x as different types
+; 8x as different types
 define i32 @red_mla_ext_s8_s16_s32(i8* noalias nocapture readonly %A, i16* noalias nocapture readonly %B, i32 %n) #0 {
 ; CHECK-LABEL: @red_mla_ext_s8_s16_s32(
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[CMP9_NOT:%.*]] = icmp eq i32 [[N:%.*]], 0
 ; CHECK-NEXT:    br i1 [[CMP9_NOT]], label [[FOR_COND_CLEANUP:%.*]], label [[VECTOR_PH:%.*]]
 ; CHECK:       vector.ph:
-; CHECK-NEXT:    [[N_RND_UP:%.*]] = add i32 [[N]], 3
-; CHECK-NEXT:    [[N_VEC:%.*]] = and i32 [[N_RND_UP]], -4
+; CHECK-NEXT:    [[N_RND_UP:%.*]] = add i32 [[N]], 7
+; CHECK-NEXT:    [[N_VEC:%.*]] = and i32 [[N_RND_UP]], -8
 ; CHECK-NEXT:    br label [[VECTOR_BODY:%.*]]
 ; CHECK:       vector.body:
 ; CHECK-NEXT:    [[INDEX:%.*]] = phi i32 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi i32 [ 0, [[VECTOR_PH]] ], [ [[TMP9:%.*]], [[VECTOR_BODY]] ]
-; CHECK-NEXT:    [[ACTIVE_LANE_MASK:%.*]] = call <4 x i1> @llvm.get.active.lane.mask.v4i1.i32(i32 [[INDEX]], i32 [[N]])
+; CHECK-NEXT:    [[ACTIVE_LANE_MASK:%.*]] = call <8 x i1> @llvm.get.active.lane.mask.v8i1.i32(i32 [[INDEX]], i32 [[N]])
 ; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds i8, i8* [[A:%.*]], i32 [[INDEX]]
-; CHECK-NEXT:    [[TMP1:%.*]] = bitcast i8* [[TMP0]] to <4 x i8>*
-; CHECK-NEXT:    [[WIDE_MASKED_LOAD:%.*]] = call <4 x i8> @llvm.masked.load.v4i8.p0v4i8(<4 x i8>* [[TMP1]], i32 1, <4 x i1> [[ACTIVE_LANE_MASK]], <4 x i8> poison)
-; CHECK-NEXT:    [[TMP2:%.*]] = sext <4 x i8> [[WIDE_MASKED_LOAD]] to <4 x i32>
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast i8* [[TMP0]] to <8 x i8>*
+; CHECK-NEXT:    [[WIDE_MASKED_LOAD:%.*]] = call <8 x i8> @llvm.masked.load.v8i8.p0v8i8(<8 x i8>* [[TMP1]], i32 1, <8 x i1> [[ACTIVE_LANE_MASK]], <8 x i8> poison)
+; CHECK-NEXT:    [[TMP2:%.*]] = sext <8 x i8> [[WIDE_MASKED_LOAD]] to <8 x i32>
 ; CHECK-NEXT:    [[TMP3:%.*]] = getelementptr inbounds i16, i16* [[B:%.*]], i32 [[INDEX]]
-; CHECK-NEXT:    [[TMP4:%.*]] = bitcast i16* [[TMP3]] to <4 x i16>*
-; CHECK-NEXT:    [[WIDE_MASKED_LOAD1:%.*]] = call <4 x i16> @llvm.masked.load.v4i16.p0v4i16(<4 x i16>* [[TMP4]], i32 2, <4 x i1> [[ACTIVE_LANE_MASK]], <4 x i16> poison)
-; CHECK-NEXT:    [[TMP5:%.*]] = sext <4 x i16> [[WIDE_MASKED_LOAD1]] to <4 x i32>
-; CHECK-NEXT:    [[TMP6:%.*]] = mul nsw <4 x i32> [[TMP5]], [[TMP2]]
-; CHECK-NEXT:    [[TMP7:%.*]] = select <4 x i1> [[ACTIVE_LANE_MASK]], <4 x i32> [[TMP6]], <4 x i32> zeroinitializer
-; CHECK-NEXT:    [[TMP8:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP7]])
+; CHECK-NEXT:    [[TMP4:%.*]] = bitcast i16* [[TMP3]] to <8 x i16>*
+; CHECK-NEXT:    [[WIDE_MASKED_LOAD1:%.*]] = call <8 x i16> @llvm.masked.load.v8i16.p0v8i16(<8 x i16>* [[TMP4]], i32 2, <8 x i1> [[ACTIVE_LANE_MASK]], <8 x i16> poison)
+; CHECK-NEXT:    [[TMP5:%.*]] = sext <8 x i16> [[WIDE_MASKED_LOAD1]] to <8 x i32>
+; CHECK-NEXT:    [[TMP6:%.*]] = mul nsw <8 x i32> [[TMP5]], [[TMP2]]
+; CHECK-NEXT:    [[TMP7:%.*]] = select <8 x i1> [[ACTIVE_LANE_MASK]], <8 x i32> [[TMP6]], <8 x i32> zeroinitializer
+; CHECK-NEXT:    [[TMP8:%.*]] = call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> [[TMP7]])
 ; CHECK-NEXT:    [[TMP9]] = add i32 [[TMP8]], [[VEC_PHI]]
-; CHECK-NEXT:    [[INDEX_NEXT]] = add i32 [[INDEX]], 4
+; CHECK-NEXT:    [[INDEX_NEXT]] = add i32 [[INDEX]], 8
 ; CHECK-NEXT:    [[TMP10:%.*]] = icmp eq i32 [[INDEX_NEXT]], [[N_VEC]]
 ; CHECK-NEXT:    br i1 [[TMP10]], label [[FOR_COND_CLEANUP]], label [[VECTOR_BODY]], !llvm.loop [[LOOP26:![0-9]+]]
 ; CHECK:       for.cond.cleanup:
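
For context, the scalar loop the updated test corresponds to looks roughly like this hand-written C++ equivalent (an assumption for illustration, not part of the test suite). With the cost-model change, MVE can tail-fold it at 8 lanes with an extending multiply-accumulate reduction instead of settling for 4.

#include <cstdint>

// Roughly the scalar form of @red_mla_ext_s8_s16_s32: an i8 and an i16
// element are each sign-extended to i32, multiplied, and accumulated.
int32_t red_mla_ext_s8_s16_s32(const int8_t *A, const int16_t *B,
                               uint32_t n) {
  int32_t s = 0;
  for (uint32_t i = 0; i < n; ++i)
    s += int32_t(A[i]) * int32_t(B[i]);
  return s;
}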