Index: llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
===================================================================
--- llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -7079,10 +7079,14 @@
              match(RedOp, m_Mul(m_Instruction(Op0), m_Instruction(Op1)))) {
     if (match(Op0, m_ZExtOrSExt(m_Value())) &&
         Op0->getOpcode() == Op1->getOpcode() &&
-        Op0->getOperand(0)->getType() == Op1->getOperand(0)->getType() &&
         !TheLoop->isLoopInvariant(Op0) && !TheLoop->isLoopInvariant(Op1)) {
       bool IsUnsigned = isa<ZExtInst>(Op0);
-      auto *ExtType = VectorType::get(Op0->getOperand(0)->getType(), VectorTy);
+      Type *Op0Ty = Op0->getOperand(0)->getType();
+      Type *Op1Ty = Op1->getOperand(0)->getType();
+      Type *LargestOpTy =
+          Op0Ty->getIntegerBitWidth() < Op1Ty->getIntegerBitWidth() ? Op1Ty
+                                                                    : Op0Ty;
+      auto *ExtType = VectorType::get(LargestOpTy, VectorTy);
       // Matched reduce(mul(ext, ext))
       InstructionCost ExtCost =
           TTI.getCastInstrCost(Op0->getOpcode(), VectorTy, ExtType,
Index: llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll
===================================================================
--- llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll
+++ llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll
@@ -1133,33 +1133,33 @@
   ret i8 %r.0.lcssa
 }
 
-; 4x or 8x as different types
+; 8x as different types
 define i32 @red_mla_ext_s8_s16_s32(i8* noalias nocapture readonly %A, i16* noalias nocapture readonly %B, i32 %n) #0 {
 ; CHECK-LABEL: @red_mla_ext_s8_s16_s32(
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[CMP9_NOT:%.*]] = icmp eq i32 [[N:%.*]], 0
 ; CHECK-NEXT:    br i1 [[CMP9_NOT]], label [[FOR_COND_CLEANUP:%.*]], label [[VECTOR_PH:%.*]]
 ; CHECK:       vector.ph:
-; CHECK-NEXT:    [[N_RND_UP:%.*]] = add i32 [[N]], 3
-; CHECK-NEXT:    [[N_VEC:%.*]] = and i32 [[N_RND_UP]], -4
+; CHECK-NEXT:    [[N_RND_UP:%.*]] = add i32 [[N]], 7
+; CHECK-NEXT:    [[N_VEC:%.*]] = and i32 [[N_RND_UP]], -8
 ; CHECK-NEXT:    br label [[VECTOR_BODY:%.*]]
 ; CHECK:       vector.body:
 ; CHECK-NEXT:    [[INDEX:%.*]] = phi i32 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi i32 [ 0, [[VECTOR_PH]] ], [ [[TMP9:%.*]], [[VECTOR_BODY]] ]
-; CHECK-NEXT:    [[ACTIVE_LANE_MASK:%.*]] = call <4 x i1> @llvm.get.active.lane.mask.v4i1.i32(i32 [[INDEX]], i32 [[N]])
+; CHECK-NEXT:    [[ACTIVE_LANE_MASK:%.*]] = call <8 x i1> @llvm.get.active.lane.mask.v8i1.i32(i32 [[INDEX]], i32 [[N]])
 ; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds i8, i8* [[A:%.*]], i32 [[INDEX]]
-; CHECK-NEXT:    [[TMP1:%.*]] = bitcast i8* [[TMP0]] to <4 x i8>*
-; CHECK-NEXT:    [[WIDE_MASKED_LOAD:%.*]] = call <4 x i8> @llvm.masked.load.v4i8.p0v4i8(<4 x i8>* [[TMP1]], i32 1, <4 x i1> [[ACTIVE_LANE_MASK]], <4 x i8> poison)
-; CHECK-NEXT:    [[TMP2:%.*]] = sext <4 x i8> [[WIDE_MASKED_LOAD]] to <4 x i32>
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast i8* [[TMP0]] to <8 x i8>*
+; CHECK-NEXT:    [[WIDE_MASKED_LOAD:%.*]] = call <8 x i8> @llvm.masked.load.v8i8.p0v8i8(<8 x i8>* [[TMP1]], i32 1, <8 x i1> [[ACTIVE_LANE_MASK]], <8 x i8> poison)
+; CHECK-NEXT:    [[TMP2:%.*]] = sext <8 x i8> [[WIDE_MASKED_LOAD]] to <8 x i32>
 ; CHECK-NEXT:    [[TMP3:%.*]] = getelementptr inbounds i16, i16* [[B:%.*]], i32 [[INDEX]]
-; CHECK-NEXT:    [[TMP4:%.*]] = bitcast i16* [[TMP3]] to <4 x i16>*
-; CHECK-NEXT:    [[WIDE_MASKED_LOAD1:%.*]] = call <4 x i16> @llvm.masked.load.v4i16.p0v4i16(<4 x i16>* [[TMP4]], i32 2, <4 x i1> [[ACTIVE_LANE_MASK]], <4 x i16> poison)
-; CHECK-NEXT:    [[TMP5:%.*]] = sext <4 x i16> [[WIDE_MASKED_LOAD1]] to <4 x i32>
-; CHECK-NEXT:    [[TMP6:%.*]] = mul nsw <4 x i32> [[TMP5]], [[TMP2]]
-; CHECK-NEXT:    [[TMP7:%.*]] = select <4 x i1> [[ACTIVE_LANE_MASK]], <4 x i32> [[TMP6]], <4 x i32> zeroinitializer
-; CHECK-NEXT:    [[TMP8:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP7]])
+; CHECK-NEXT:    [[TMP4:%.*]] = bitcast i16* [[TMP3]] to <8 x i16>*
+; CHECK-NEXT:    [[WIDE_MASKED_LOAD1:%.*]] = call <8 x i16> @llvm.masked.load.v8i16.p0v8i16(<8 x i16>* [[TMP4]], i32 2, <8 x i1> [[ACTIVE_LANE_MASK]], <8 x i16> poison)
+; CHECK-NEXT:    [[TMP5:%.*]] = sext <8 x i16> [[WIDE_MASKED_LOAD1]] to <8 x i32>
+; CHECK-NEXT:    [[TMP6:%.*]] = mul nsw <8 x i32> [[TMP5]], [[TMP2]]
+; CHECK-NEXT:    [[TMP7:%.*]] = select <8 x i1> [[ACTIVE_LANE_MASK]], <8 x i32> [[TMP6]], <8 x i32> zeroinitializer
+; CHECK-NEXT:    [[TMP8:%.*]] = call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> [[TMP7]])
 ; CHECK-NEXT:    [[TMP9]] = add i32 [[TMP8]], [[VEC_PHI]]
-; CHECK-NEXT:    [[INDEX_NEXT]] = add i32 [[INDEX]], 4
+; CHECK-NEXT:    [[INDEX_NEXT]] = add i32 [[INDEX]], 8
 ; CHECK-NEXT:    [[TMP10:%.*]] = icmp eq i32 [[INDEX_NEXT]], [[N_VEC]]
 ; CHECK-NEXT:    br i1 [[TMP10]], label [[FOR_COND_CLEANUP]], label [[VECTOR_BODY]], !llvm.loop [[LOOP26:![0-9]+]]
 ; CHECK:       for.cond.cleanup: