diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -505,6 +505,7 @@
   /// Canonicalize the position of binops relative to shufflevector.
   Instruction *foldVectorBinop(BinaryOperator &Inst);
+  Instruction *foldVScaleSplatAddSub(BinaryOperator &Inst);
   Instruction *foldVectorSelect(SelectInst &Sel);
   Instruction *foldSelectShuffle(ShuffleVectorInst &Shuf);
diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -1575,6 +1575,70 @@
   return true;
 }
 
+// Combine two nested Add/Sub operations of the form:
+//   (A +/- splat(B)) +/- splat(C) -> A +/- splat(B +/- C)
+// where B and C are VScale multiplied by a constant.
+Instruction *InstCombinerImpl::foldVScaleSplatAddSub(BinaryOperator &Inst) {
+  // Matches a value that is any of:
+  // 1) VScale
+  // 2) A multiplication of VScale by a constant
+  // 3) VScale shifted left by a constant amount
+  auto m_ConstMultipliedVscale =
+      m_CombineOr(m_CombineOr(m_VScale(), m_c_Mul(m_Constant(), m_VScale())),
+                  m_Shl(m_VScale(), m_Constant()));
+
+  // Matches a splat of the expression above.
+  auto m_SplatVscale =
+      m_Shuffle(m_InsertElt(m_Value(), m_ConstMultipliedVscale, m_ZeroInt()),
+                m_Value(), m_ZeroMask());
+
+  Instruction *SplatB, *SplatC;
+  Value *A, *B, *C;
+  BinaryOperator::BinaryOps NewOpcode1, NewOpcode2;
+
+  if (match(&Inst,
+            m_c_Add(m_c_Add(m_Value(A), m_SplatVscale), m_SplatVscale)) &&
+      match(&Inst, m_c_Add(m_c_Add(m_Specific(A), m_Instruction(SplatB)),
+                           m_Instruction(SplatC)))) {
+    // (A + splat(B)) + splat(C) -> A + splat(C + B)
+    NewOpcode1 = Instruction::Add;
+    NewOpcode2 = Instruction::Add;
+  } else if (match(&Inst,
+                   m_c_Add(m_Sub(m_Value(A), m_SplatVscale), m_SplatVscale)) &&
+             match(&Inst, m_c_Add(m_Sub(m_Specific(A), m_Instruction(SplatB)),
+                                  m_Instruction(SplatC)))) {
+    // (A - splat(B)) + splat(C) -> A - splat(B - C)
+    NewOpcode1 = Instruction::Sub;
+    NewOpcode2 = Instruction::Sub;
+  } else if (match(&Inst,
+                   m_Sub(m_c_Add(m_Value(A), m_SplatVscale), m_SplatVscale)) &&
+             match(&Inst, m_Sub(m_c_Add(m_Specific(A), m_Instruction(SplatB)),
+                                m_Instruction(SplatC)))) {
+    // (A + splat(B)) - splat(C) -> A + splat(B - C)
+    NewOpcode1 = Instruction::Sub;
+    NewOpcode2 = Instruction::Add;
+  } else if (match(&Inst,
+                   m_Sub(m_Sub(m_Value(A), m_SplatVscale), m_SplatVscale)) &&
+             match(&Inst, m_Sub(m_Sub(m_Specific(A), m_Instruction(SplatB)),
+                                m_Instruction(SplatC)))) {
+    // (A - splat(B)) - splat(C) -> A - splat(B + C)
+    NewOpcode1 = Instruction::Add;
+    NewOpcode2 = Instruction::Sub;
+  } else {
+    return nullptr;
+  }
+
+  B = getSplatValue(SplatB);
+  C = getSplatValue(SplatC);
+
+  // Combine the two splatted scalars with NewOpcode1, splat the result, and
+  // create the new outer binary operation.
+  auto *NewOp = Builder.CreateBinOp(NewOpcode1, B, C);
+  auto EC = cast<VectorType>(Inst.getType())->getElementCount();
+  auto *SplatNewOp = Builder.CreateVectorSplat(EC, NewOp);
+  return BinaryOperator::Create(NewOpcode2, A, SplatNewOp);
+}
+
 Instruction *InstCombinerImpl::foldVectorBinop(BinaryOperator &Inst) {
   if (!isa<VectorType>(Inst.getType()))
     return nullptr;
@@ -1586,6 +1650,9 @@
   assert(cast<VectorType>(RHS->getType())->getElementCount() ==
          cast<VectorType>(Inst.getType())->getElementCount());
 
+  if (auto *I = foldVScaleSplatAddSub(Inst))
+    return I;
+
   // If both operands of the binop are vector concatenations, then perform the
   // narrow binop on each pair of the source operands followed by concatenation
   // of the results.
diff --git a/llvm/test/Transforms/InstCombine/vscale.ll b/llvm/test/Transforms/InstCombine/vscale.ll
--- a/llvm/test/Transforms/InstCombine/vscale.ll
+++ b/llvm/test/Transforms/InstCombine/vscale.ll
@@ -40,6 +40,85 @@
   ret i64 %ext
 }
 
+define <vscale x 2 x i64> @test_add_add_splat_vscale(<vscale x 2 x i64> %A) {
+; CHECK-LABEL: @test_add_add_splat_vscale(
+; CHECK-NEXT:    [[VSCALE:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[TMP1:%.*]] = shl i64 [[VSCALE]], 1
+; CHECK-NEXT:    [[DOTSPLATINSERT:%.*]] = insertelement <vscale x 2 x i64> poison, i64 [[TMP1]], i64 0
+; CHECK-NEXT:    [[DOTSPLAT:%.*]] = shufflevector <vscale x 2 x i64> [[DOTSPLATINSERT]], <vscale x 2 x i64> poison, <vscale x 2 x i32> zeroinitializer
+; CHECK-NEXT:    [[RESULT:%.*]] = add <vscale x 2 x i64> [[DOTSPLAT]], [[A:%.*]]
+; CHECK-NEXT:    ret <vscale x 2 x i64> [[RESULT]]
+;
+  %vscale = call i64 @llvm.vscale.i64()
+  %splatinsert = insertelement <vscale x 2 x i64> poison, i64 %vscale, i64 0
+  %splat = shufflevector <vscale x 2 x i64> %splatinsert, <vscale x 2 x i64> poison, <vscale x 2 x i32> zeroinitializer
+  %nested.operation = add <vscale x 2 x i64> %A, %splat
+  %result = add <vscale x 2 x i64> %nested.operation, %splat
+  ret <vscale x 2 x i64> %result
+}
+
+define <vscale x 2 x i64> @test_add_sub_splat_vscale(<vscale x 2 x i64> %A) {
+; CHECK-LABEL: @test_add_sub_splat_vscale(
+; CHECK-NEXT:    [[VSCALE:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[TMP1:%.*]] = shl i64 [[VSCALE]], 1
+; CHECK-NEXT:    [[DOTSPLATINSERT:%.*]] = insertelement <vscale x 2 x i64> poison, i64 [[TMP1]], i64 0
+; CHECK-NEXT:    [[DOTSPLAT:%.*]] = shufflevector <vscale x 2 x i64> [[DOTSPLATINSERT]], <vscale x 2 x i64> poison, <vscale x 2 x i32> zeroinitializer
+; CHECK-NEXT:    [[RESULT:%.*]] = add <vscale x 2 x i64> [[DOTSPLAT]], [[A:%.*]]
+; CHECK-NEXT:    ret <vscale x 2 x i64> [[RESULT]]
+;
+  %vscale = call i64 @llvm.vscale.i64()
+  %splatinsert = insertelement <vscale x 2 x i64> poison, i64 %vscale, i64 0
+  %splat = shufflevector <vscale x 2 x i64> %splatinsert, <vscale x 2 x i64> poison, <vscale x 2 x i32> zeroinitializer
+  %1 = mul i64 %vscale, 3
+  %splatinsert.3 = insertelement <vscale x 2 x i64> poison, i64 %1, i64 0
+  %splat.3 = shufflevector <vscale x 2 x i64> %splatinsert.3, <vscale x 2 x i64> poison, <vscale x 2 x i32> zeroinitializer
+  %nested.operation = add <vscale x 2 x i64> %A, %splat.3
+  %result = sub <vscale x 2 x i64> %nested.operation, %splat
+  ret <vscale x 2 x i64> %result
+}
+
+define <vscale x 2 x i64> @test_sub_add_splat_vscale(<vscale x 2 x i64> %A) {
+; CHECK-LABEL: @test_sub_add_splat_vscale(
+; CHECK-NEXT:    [[VSCALE:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[DOTSPLATINSERT_NEG:%.*]] = insertelement <vscale x 2 x i64> poison, i64 [[VSCALE]], i64 0
+; CHECK-NEXT:    [[DOTSPLAT_NEG:%.*]] = shufflevector <vscale x 2 x i64> [[DOTSPLATINSERT_NEG]], <vscale x 2 x i64> poison, <vscale x 2 x i32> zeroinitializer
+; CHECK-NEXT:    [[RESULT:%.*]] = add <vscale x 2 x i64> [[DOTSPLAT_NEG]], [[A:%.*]]
+; CHECK-NEXT:    ret <vscale x 2 x i64> [[RESULT]]
+;
+  %vscale = call i64 @llvm.vscale.i64()
+  %splatinsert = insertelement <vscale x 2 x i64> poison, i64 %vscale, i64 0
+  %splat = shufflevector <vscale x 2 x i64> %splatinsert, <vscale x 2 x i64> poison, <vscale x 2 x i32> zeroinitializer
+  %1 = shl i64 %vscale, 1
+  %splatinsert.2 = insertelement <vscale x 2 x i64> poison, i64 %1, i64 0
+  %splat.2 = shufflevector <vscale x 2 x i64> %splatinsert.2, <vscale x 2 x i64> poison, <vscale x 2 x i32> zeroinitializer
+  %nested.operation = sub <vscale x 2 x i64> %A, %splat
+  %result = add <vscale x 2 x i64> %nested.operation, %splat.2
+  ret <vscale x 2 x i64> %result
+}
+
+define <vscale x 2 x i64> @test_sub_sub_splat_vscale(<vscale x 2 x i64> %A) {
+; CHECK-LABEL: @test_sub_sub_splat_vscale(
+; CHECK-NEXT:    [[VSCALE:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[DOTNEG:%.*]] = mul i64 [[VSCALE]], -3
+; CHECK-NEXT:    [[DOTSPLATINSERT_NEG:%.*]] = insertelement <vscale x 2 x i64> poison, i64 [[DOTNEG]], i64 0
+; CHECK-NEXT:    [[DOTSPLAT_NEG:%.*]] = shufflevector <vscale x 2 x i64> [[DOTSPLATINSERT_NEG]], <vscale x 2 x i64> poison, <vscale x 2 x i32> zeroinitializer
+; CHECK-NEXT:    [[RESULT:%.*]] = add <vscale x 2 x i64> [[DOTSPLAT_NEG]], [[A:%.*]]
+; CHECK-NEXT:    ret <vscale x 2 x i64> [[RESULT]]
+;
+  %vscale = call i64 @llvm.vscale.i64()
+  %splatinsert = insertelement <vscale x 2 x i64> poison, i64 %vscale, i64 0
+  %splat = shufflevector <vscale x 2 x i64> %splatinsert, <vscale x 2 x i64> poison, <vscale x 2 x i32> zeroinitializer
+  %1 = shl i64 %vscale, 1
+  %splatinsert.2 = insertelement <vscale x 2 x i64> poison, i64 %1, i64 0
+  %splat.2 = shufflevector <vscale x 2 x i64> %splatinsert.2, <vscale x 2 x i64> poison, <vscale x 2 x i32> zeroinitializer
+  %nested.operation = sub <vscale x 2 x i64> %A, %splat
+  %result = sub <vscale x 2 x i64> %nested.operation, %splat.2
+  ret <vscale x 2 x i64> %result
+}
+
 declare i32 @llvm.vscale.i32()
+declare i64 @llvm.vscale.i64()
+
+declare void @use()
 
 attributes #0 = { vscale_range(1,16) }
diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/sve-gather-scatter.ll b/llvm/test/Transforms/LoopVectorize/AArch64/sve-gather-scatter.ll
--- a/llvm/test/Transforms/LoopVectorize/AArch64/sve-gather-scatter.ll
+++ b/llvm/test/Transforms/LoopVectorize/AArch64/sve-gather-scatter.ll
@@ -317,9 +317,12 @@
 ; CHECK-NEXT:    [[TMP14:%.*]] = call i64 @llvm.vscale.i64()
 ; CHECK-NEXT:    [[TMP15:%.*]] = shl nuw nsw i64 [[TMP14]], 3
 ; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[TMP15]]
-; CHECK-NEXT:    [[VEC_IND_NEXT]] = add <vscale x 4 x i64> [[STEP_ADD]], [[DOTSPLAT]]
-; CHECK-NEXT:    [[TMP16:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
-; CHECK-NEXT:    br i1 [[TMP16]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP11:![0-9]+]]
+; CHECK-NEXT:    [[TMP16:%.*]] = shl nuw nsw i64 [[TMP4]], 3
+; CHECK-NEXT:    [[DOTSPLATINSERT3:%.*]] = insertelement <vscale x 4 x i64> poison, i64 [[TMP16]], i64 0
+; CHECK-NEXT:    [[DOTSPLAT4:%.*]] = shufflevector <vscale x 4 x i64> [[DOTSPLATINSERT3]], <vscale x 4 x i64> poison, <vscale x 4 x i32> zeroinitializer
+; CHECK-NEXT:    [[VEC_IND_NEXT]] = add <vscale x 4 x i64> [[VEC_IND]], [[DOTSPLAT4]]
+; CHECK-NEXT:    [[TMP17:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
+; CHECK-NEXT:    br i1 [[TMP17]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP11:![0-9]+]]
 ; CHECK:       middle.block:
 ; CHECK-NEXT:    [[CMP_N:%.*]] = icmp eq i64 [[N_VEC]], [[N]]
 ; CHECK-NEXT:    br i1 [[CMP_N]], label [[FOR_COND_CLEANUP:%.*]], label [[SCALAR_PH]]
@@ -330,15 +333,16 @@
 ; CHECK-NEXT:    [[INDVARS_IV:%.*]] = phi i64 [ [[INDVARS_IV_NEXT:%.*]], [[FOR_BODY]] ], [ [[BC_RESUME_VAL]], [[SCALAR_PH]] ]
 ; CHECK-NEXT:    [[INDVARS_IV_STRIDE2:%.*]] = shl i64 [[INDVARS_IV]], 1
 ; CHECK-NEXT:    [[ARRAYIDX:%.*]] = getelementptr inbounds float, ptr [[B]], i64 [[INDVARS_IV_STRIDE2]]
-; CHECK-NEXT:    [[TMP17:%.*]] = load float, ptr [[ARRAYIDX]], align 4
+; CHECK-NEXT:    [[TMP18:%.*]] = load float, ptr [[ARRAYIDX]], align 4
 ; CHECK-NEXT:    [[ARRAYIDX2:%.*]] = getelementptr inbounds float, ptr [[A]], i64 [[INDVARS_IV]]
-; CHECK-NEXT:    store float [[TMP17]], ptr [[ARRAYIDX2]], align 4
+; CHECK-NEXT:    store float [[TMP18]], ptr [[ARRAYIDX2]], align 4
 ; CHECK-NEXT:    [[INDVARS_IV_NEXT]] = add nuw nsw i64 [[INDVARS_IV]], 1
 ; CHECK-NEXT:    [[EXITCOND_NOT:%.*]] = icmp eq i64 [[INDVARS_IV_NEXT]], [[N]]
 ; CHECK-NEXT:    br i1 [[EXITCOND_NOT]], label [[FOR_COND_CLEANUP]], label [[FOR_BODY]], !llvm.loop [[LOOP12:![0-9]+]]
 ; CHECK:       for.cond.cleanup:
 ; CHECK-NEXT:    ret void
 ;
+
 entry:
   br label %for.body
diff --git a/llvm/test/Transforms/LoopVectorize/scalable-inductions.ll b/llvm/test/Transforms/LoopVectorize/scalable-inductions.ll
--- a/llvm/test/Transforms/LoopVectorize/scalable-inductions.ll
+++ b/llvm/test/Transforms/LoopVectorize/scalable-inductions.ll
@@ -46,9 +46,12 @@
 ; CHECK-NEXT:    [[TMP17:%.*]] = call i64 @llvm.vscale.i64()
 ; CHECK-NEXT:    [[TMP18:%.*]] = shl i64 [[TMP17]], 2
 ; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[TMP18]]
-; CHECK-NEXT:    [[VEC_IND_NEXT]] = add <vscale x 2 x i64> [[STEP_ADD]], [[DOTSPLAT]]
-; CHECK-NEXT:    [[TMP19:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
-; CHECK-NEXT:    br i1 [[TMP19]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
+; CHECK-NEXT:    [[TMP19:%.*]] = shl i64 [[TMP5]], 2
+; CHECK-NEXT:    [[DOTSPLATINSERT3:%.*]] = insertelement <vscale x 2 x i64> poison, i64 [[TMP19]], i64 0
+; CHECK-NEXT:    [[DOTSPLAT4:%.*]] = shufflevector <vscale x 2 x i64> [[DOTSPLATINSERT3]], <vscale x 2 x i64> poison, <vscale x 2 x i32> zeroinitializer
+; CHECK-NEXT:    [[VEC_IND_NEXT]] = add <vscale x 2 x i64> [[VEC_IND]], [[DOTSPLAT4]]
+; CHECK-NEXT:    [[TMP20:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
+; CHECK-NEXT:    br i1 [[TMP20]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
 ; CHECK:       middle.block:
 ; CHECK-NEXT:    [[CMP_N:%.*]] = icmp eq i64 [[N_MOD_VF]], 0
 ; CHECK-NEXT:    br i1 [[CMP_N]], label [[EXIT:%.*]], label [[SCALAR_PH]]
@@ -58,8 +61,8 @@
 ; CHECK:       for.body:
 ; CHECK-NEXT:    [[I_08:%.*]] = phi i64 [ [[INC:%.*]], [[FOR_BODY]] ], [ [[BC_RESUME_VAL]], [[SCALAR_PH]] ]
 ; CHECK-NEXT:    [[ARRAYIDX:%.*]] = getelementptr inbounds i64, ptr [[B]], i64 [[I_08]]
-; CHECK-NEXT:    [[TMP20:%.*]] = load i64, ptr [[ARRAYIDX]], align 8
-; CHECK-NEXT:    [[ADD:%.*]] = add nsw i64 [[TMP20]], [[I_08]]
+; CHECK-NEXT:    [[TMP21:%.*]] = load i64, ptr [[ARRAYIDX]], align 8
+; CHECK-NEXT:    [[ADD:%.*]] = add nsw i64 [[TMP21]], [[I_08]]
 ; CHECK-NEXT:    [[ARRAYIDX1:%.*]] = getelementptr inbounds i64, ptr [[A]], i64 [[I_08]]
 ; CHECK-NEXT:    store i64 [[ADD]], ptr [[ARRAYIDX1]], align 8
 ; CHECK-NEXT:    [[INC]] = add nuw nsw i64 [[I_08]], 1
@@ -125,9 +128,12 @@
 ; CHECK-NEXT:    [[TMP14:%.*]] = call i64 @llvm.vscale.i64()
 ; CHECK-NEXT:    [[TMP15:%.*]] = shl i64 [[TMP14]], 1
 ; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[TMP15]]
-; CHECK-NEXT:    [[VEC_IND_NEXT]] = add <vscale x 1 x i64> [[STEP_ADD]], [[DOTSPLAT]]
-; CHECK-NEXT:    [[TMP16:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
-; CHECK-NEXT:    br i1 [[TMP16]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP5:![0-9]+]]
+; CHECK-NEXT:    [[TMP16:%.*]] = shl i64 [[TMP5]], 1
+; CHECK-NEXT:    [[DOTSPLATINSERT3:%.*]] = insertelement <vscale x 1 x i64> poison, i64 [[TMP16]], i64 0
+; CHECK-NEXT:    [[DOTSPLAT4:%.*]] = shufflevector <vscale x 1 x i64> [[DOTSPLATINSERT3]], <vscale x 1 x i64> poison, <vscale x 1 x i32> zeroinitializer
+; CHECK-NEXT:    [[VEC_IND_NEXT]] = add <vscale x 1 x i64> [[VEC_IND]], [[DOTSPLAT4]]
+; CHECK-NEXT:    [[TMP17:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
+; CHECK-NEXT:    br i1 [[TMP17]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP5:![0-9]+]]
 ; CHECK:       middle.block:
 ; CHECK-NEXT:    [[CMP_N:%.*]] = icmp eq i64 [[N_MOD_VF]], 0
 ; CHECK-NEXT:    br i1 [[CMP_N]], label [[EXIT:%.*]], label [[SCALAR_PH]]
@@ -137,8 +143,8 @@
 ; CHECK:       for.body:
 ; CHECK-NEXT:    [[I_08:%.*]] = phi i64 [ [[INC:%.*]], [[FOR_BODY]] ], [ [[BC_RESUME_VAL]], [[SCALAR_PH]] ]
 ; CHECK-NEXT:    [[ARRAYIDX:%.*]] = getelementptr inbounds i64, ptr [[B]], i64 [[I_08]]
-; CHECK-NEXT:    [[TMP17:%.*]] = load i64, ptr [[ARRAYIDX]], align 8
-; CHECK-NEXT:    [[ADD:%.*]] = add nsw i64 [[TMP17]], [[I_08]]
+; CHECK-NEXT:    [[TMP18:%.*]] = load i64, ptr [[ARRAYIDX]], align 8
+; CHECK-NEXT:    [[ADD:%.*]] = add nsw i64 [[TMP18]], [[I_08]]
 ; CHECK-NEXT:    [[ARRAYIDX1:%.*]] = getelementptr inbounds i64, ptr [[A]], i64 [[I_08]]
 ; CHECK-NEXT:    store i64 [[ADD]], ptr [[ARRAYIDX1]], align 8
 ; CHECK-NEXT:    [[INC]] = add nuw nsw i64 [[I_08]], 1
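
For illustration only (not part of the patch), the following standalone IR sketches the fold that foldVScaleSplatAddSub performs; it can be run through `opt -passes=instcombine -S`. The `<vscale x 2 x i64>` type and the value names here are illustrative assumptions chosen to mirror @test_add_add_splat_vscale above:

  ; Input: two nested adds of the same vscale-derived splat.
  declare i64 @llvm.vscale.i64()

  define <vscale x 2 x i64> @example(<vscale x 2 x i64> %A) {
    %vscale = call i64 @llvm.vscale.i64()
    %ins = insertelement <vscale x 2 x i64> poison, i64 %vscale, i64 0
    %splat = shufflevector <vscale x 2 x i64> %ins, <vscale x 2 x i64> poison, <vscale x 2 x i32> zeroinitializer
    %inner = add <vscale x 2 x i64> %A, %splat
    %outer = add <vscale x 2 x i64> %inner, %splat
    ret <vscale x 2 x i64> %outer
  }

With the patch applied, the two splat operands are combined into a single scalar expression (here vscale + vscale, which later combines reduce to `shl i64 %vscale, 1`), splatted once, and added to %A with a single vector add, as the expected CHECK lines of @test_add_add_splat_vscale show.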