diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -1506,6 +1506,13 @@
   return true;
 }
 
+// Returns true if V is a constant multiple of vscale: vscale itself, a mul of
+// vscale by a constant, or a left shift of vscale by a constant amount.
+static bool isConstantMultipliedVScale(Value *V) {
+  return match(V, m_VScale()) || match(V, m_c_Mul(m_Constant(), m_VScale())) ||
+         match(V, m_Shl(m_VScale(), m_Constant()));
+}
+
 Instruction *InstCombinerImpl::foldVectorBinop(BinaryOperator &Inst) {
   if (!isa<VectorType>(Inst.getType()))
     return nullptr;
@@ -1517,6 +1524,62 @@
   assert(cast<VectorType>(RHS->getType())->getElementCount() ==
          cast<VectorType>(Inst.getType())->getElementCount());
 
+  // (A +/- splat(B)) +/- splat(C) -> A +/- splat(B +/- C)
+  // where B and C are constant multiples of vscale.
+  if (Opcode == Instruction::Add || Opcode == Instruction::Sub) {
+    Value *A, *B, *C;
+    C = getSplatValue(RHS);
+    auto *Nested = dyn_cast<BinaryOperator>(LHS);
+    // Sub is not commutative, so only try the swapped operand order for Add.
+    if ((!C || !isConstantMultipliedVScale(C)) && Opcode == Instruction::Add) {
+      C = getSplatValue(LHS);
+      Nested = dyn_cast<BinaryOperator>(RHS);
+    }
+
+    if (C && isConstantMultipliedVScale(C) && Nested) {
+      auto Opcode2 = Nested->getOpcode();
+      if (Opcode2 == Instruction::Add || Opcode2 == Instruction::Sub) {
+        A = Nested->getOperand(0);
+        B = getSplatValue(Nested->getOperand(1));
+        if ((!B || !isConstantMultipliedVScale(B)) &&
+            Opcode2 == Instruction::Add) {
+          B = getSplatValue(Nested->getOperand(0));
+          A = Nested->getOperand(1);
+        }
+
+        if (B && isConstantMultipliedVScale(B)) {
+          // Deduce the new opcodes from the signs the splats carry:
+          //   (A + B) + C == A + (B + C)    (A + B) - C == A + (B - C)
+          //   (A - B) + C == A - (B - C)    (A - B) - C == A - (B + C)
+          bool IsBPositive = Opcode2 == Instruction::Add;
+          bool IsCPositive = Opcode == Instruction::Add;
+          Instruction::BinaryOps NewOpcode1;
+          Instruction::BinaryOps NewOpcode2;
+          if (IsBPositive && IsCPositive) {
+            NewOpcode1 = Instruction::Add;
+            NewOpcode2 = Instruction::Add;
+          } else if (IsBPositive && !IsCPositive) {
+            NewOpcode1 = Instruction::Add;
+            NewOpcode2 = Instruction::Sub;
+          } else if (!IsBPositive && IsCPositive) {
+            NewOpcode1 = Instruction::Sub;
+            NewOpcode2 = Instruction::Sub;
+          } else /* (!IsBPositive && !IsCPositive) */ {
+            NewOpcode1 = Instruction::Sub;
+            NewOpcode2 = Instruction::Add;
+          }
+
+          // Combine the two splatted scalars, splat the result, and create
+          // the replacement binary operation.
+          auto *NewOp = Builder.CreateBinOp(NewOpcode2, B, C);
+          auto EC = cast<VectorType>(Inst.getType())->getElementCount();
+          auto *SplatNewOp = Builder.CreateVectorSplat(EC, NewOp);
+          return BinaryOperator::Create(NewOpcode1, A, SplatNewOp);
+        }
+      }
+    }
+  }
+
   // If both operands of the binop are vector concatenations, then perform the
   // narrow binop on each pair of the source operands followed by concatenation
   // of the results.
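For intuition, here is a minimal before/after sketch of the (A - B) - C case of the fold, matching @test_sub_sub_splat_vscale below. The <vscale x 2 x i64> type and all value names are illustrative only, and the final mul by -3 comes from InstCombine's existing negation canonicalization running on top of this fold:

  ; Before the fold: (%A - splat(vscale)) - splat(2*vscale)
  %vscale = call i64 @llvm.vscale.i64()
  %ins = insertelement <vscale x 2 x i64> poison, i64 %vscale, i64 0
  %splat = shufflevector <vscale x 2 x i64> %ins, <vscale x 2 x i64> poison, <vscale x 2 x i32> zeroinitializer
  %two = shl i64 %vscale, 1
  %ins2 = insertelement <vscale x 2 x i64> poison, i64 %two, i64 0
  %splat2 = shufflevector <vscale x 2 x i64> %ins2, <vscale x 2 x i64> poison, <vscale x 2 x i32> zeroinitializer
  %t = sub <vscale x 2 x i64> %A, %splat
  %r = sub <vscale x 2 x i64> %t, %splat2

  ; After the fold (plus negation folds): %A + splat(-3*vscale)
  %vscale = call i64 @llvm.vscale.i64()
  %neg3 = mul i64 %vscale, -3
  %ins = insertelement <vscale x 2 x i64> poison, i64 %neg3, i64 0
  %splat = shufflevector <vscale x 2 x i64> %ins, <vscale x 2 x i64> poison, <vscale x 2 x i32> zeroinitializer
  %r = add <vscale x 2 x i64> %splat, %A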
diff --git a/llvm/test/Transforms/InstCombine/vscale.ll b/llvm/test/Transforms/InstCombine/vscale.ll
--- a/llvm/test/Transforms/InstCombine/vscale.ll
+++ b/llvm/test/Transforms/InstCombine/vscale.ll
@@ -40,6 +40,85 @@
   ret i64 %ext
 }
 
+define <vscale x 2 x i64> @test_add_add_splat_vscale(<vscale x 2 x i64> %A) {
+; CHECK-LABEL: @test_add_add_splat_vscale(
+; CHECK-NEXT:    [[VSCALE:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[TMP1:%.*]] = shl i64 [[VSCALE]], 1
+; CHECK-NEXT:    [[DOTSPLATINSERT:%.*]] = insertelement <vscale x 2 x i64> poison, i64 [[TMP1]], i64 0
+; CHECK-NEXT:    [[DOTSPLAT:%.*]] = shufflevector <vscale x 2 x i64> [[DOTSPLATINSERT]], <vscale x 2 x i64> poison, <vscale x 2 x i32> zeroinitializer
+; CHECK-NEXT:    [[RESULT:%.*]] = add <vscale x 2 x i64> [[DOTSPLAT]], [[A:%.*]]
+; CHECK-NEXT:    ret <vscale x 2 x i64> [[RESULT]]
+;
+  %vscale = call i64 @llvm.vscale.i64()
+  %splatinsert = insertelement <vscale x 2 x i64> poison, i64 %vscale, i64 0
+  %splat = shufflevector <vscale x 2 x i64> %splatinsert, <vscale x 2 x i64> poison, <vscale x 2 x i32> zeroinitializer
+  %nested.operation = add <vscale x 2 x i64> %A, %splat
+  %result = add <vscale x 2 x i64> %nested.operation, %splat
+  ret <vscale x 2 x i64> %result
+}
+
+define <vscale x 2 x i64> @test_add_sub_splat_vscale(<vscale x 2 x i64> %A) {
+; CHECK-LABEL: @test_add_sub_splat_vscale(
+; CHECK-NEXT:    [[VSCALE:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[TMP1:%.*]] = shl i64 [[VSCALE]], 1
+; CHECK-NEXT:    [[DOTSPLATINSERT:%.*]] = insertelement <vscale x 2 x i64> poison, i64 [[TMP1]], i64 0
+; CHECK-NEXT:    [[DOTSPLAT:%.*]] = shufflevector <vscale x 2 x i64> [[DOTSPLATINSERT]], <vscale x 2 x i64> poison, <vscale x 2 x i32> zeroinitializer
+; CHECK-NEXT:    [[RESULT:%.*]] = add <vscale x 2 x i64> [[DOTSPLAT]], [[A:%.*]]
+; CHECK-NEXT:    ret <vscale x 2 x i64> [[RESULT]]
+;
+  %vscale = call i64 @llvm.vscale.i64()
+  %splatinsert = insertelement <vscale x 2 x i64> poison, i64 %vscale, i64 0
+  %splat = shufflevector <vscale x 2 x i64> %splatinsert, <vscale x 2 x i64> poison, <vscale x 2 x i32> zeroinitializer
+  %1 = mul i64 %vscale, 3
+  %splatinsert.3 = insertelement <vscale x 2 x i64> poison, i64 %1, i64 0
+  %splat.3 = shufflevector <vscale x 2 x i64> %splatinsert.3, <vscale x 2 x i64> poison, <vscale x 2 x i32> zeroinitializer
+  %nested.operation = add <vscale x 2 x i64> %A, %splat.3
+  %result = sub <vscale x 2 x i64> %nested.operation, %splat
+  ret <vscale x 2 x i64> %result
+}
+
+define <vscale x 2 x i64> @test_sub_add_splat_vscale(<vscale x 2 x i64> %A) {
+; CHECK-LABEL: @test_sub_add_splat_vscale(
+; CHECK-NEXT:    [[VSCALE:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[DOTSPLATINSERT_NEG:%.*]] = insertelement <vscale x 2 x i64> poison, i64 [[VSCALE]], i64 0
+; CHECK-NEXT:    [[DOTSPLAT_NEG:%.*]] = shufflevector <vscale x 2 x i64> [[DOTSPLATINSERT_NEG]], <vscale x 2 x i64> poison, <vscale x 2 x i32> zeroinitializer
+; CHECK-NEXT:    [[RESULT:%.*]] = add <vscale x 2 x i64> [[DOTSPLAT_NEG]], [[A:%.*]]
+; CHECK-NEXT:    ret <vscale x 2 x i64> [[RESULT]]
+;
+  %vscale = call i64 @llvm.vscale.i64()
+  %splatinsert = insertelement <vscale x 2 x i64> poison, i64 %vscale, i64 0
+  %splat = shufflevector <vscale x 2 x i64> %splatinsert, <vscale x 2 x i64> poison, <vscale x 2 x i32> zeroinitializer
+  %1 = shl i64 %vscale, 1
+  %splatinsert.2 = insertelement <vscale x 2 x i64> poison, i64 %1, i64 0
+  %splat.2 = shufflevector <vscale x 2 x i64> %splatinsert.2, <vscale x 2 x i64> poison, <vscale x 2 x i32> zeroinitializer
+  %nested.operation = sub <vscale x 2 x i64> %A, %splat
+  %result = add <vscale x 2 x i64> %nested.operation, %splat.2
+  ret <vscale x 2 x i64> %result
+}
+
+define <vscale x 2 x i64> @test_sub_sub_splat_vscale(<vscale x 2 x i64> %A) {
+; CHECK-LABEL: @test_sub_sub_splat_vscale(
+; CHECK-NEXT:    [[VSCALE:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[DOTNEG:%.*]] = mul i64 [[VSCALE]], -3
+; CHECK-NEXT:    [[DOTSPLATINSERT_NEG:%.*]] = insertelement <vscale x 2 x i64> poison, i64 [[DOTNEG]], i64 0
+; CHECK-NEXT:    [[DOTSPLAT_NEG:%.*]] = shufflevector <vscale x 2 x i64> [[DOTSPLATINSERT_NEG]], <vscale x 2 x i64> poison, <vscale x 2 x i32> zeroinitializer
+; CHECK-NEXT:    [[RESULT:%.*]] = add <vscale x 2 x i64> [[DOTSPLAT_NEG]], [[A:%.*]]
+; CHECK-NEXT:    ret <vscale x 2 x i64> [[RESULT]]
+;
+  %vscale = call i64 @llvm.vscale.i64()
+  %splatinsert = insertelement <vscale x 2 x i64> poison, i64 %vscale, i64 0
+  %splat = shufflevector <vscale x 2 x i64> %splatinsert, <vscale x 2 x i64> poison, <vscale x 2 x i32> zeroinitializer
+  %1 = shl i64 %vscale, 1
+  %splatinsert.2 = insertelement <vscale x 2 x i64> poison, i64 %1, i64 0
+  %splat.2 = shufflevector <vscale x 2 x i64> %splatinsert.2, <vscale x 2 x i64> poison, <vscale x 2 x i32> zeroinitializer
+  %nested.operation = sub <vscale x 2 x i64> %A, %splat
+  %result = sub <vscale x 2 x i64> %nested.operation, %splat.2
+  ret <vscale x 2 x i64> %result
+}
+
 declare i32 @llvm.vscale.i32()
+declare i64 @llvm.vscale.i64()
+
+declare void @use()
 
 attributes #0 = { vscale_range(1,16) }
diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/sve-gather-scatter.ll b/llvm/test/Transforms/LoopVectorize/AArch64/sve-gather-scatter.ll
--- a/llvm/test/Transforms/LoopVectorize/AArch64/sve-gather-scatter.ll
+++ b/llvm/test/Transforms/LoopVectorize/AArch64/sve-gather-scatter.ll
@@ -322,9 +322,12 @@
 ; CHECK-NEXT:    [[TMP15:%.*]] = call i64 @llvm.vscale.i64()
 ; CHECK-NEXT:    [[TMP16:%.*]] = shl nuw nsw i64 [[TMP15]], 3
 ; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[TMP16]]
-; CHECK-NEXT:    [[VEC_IND_NEXT]] = add <vscale x 4 x i64> [[STEP_ADD]], [[DOTSPLAT]]
-; CHECK-NEXT:    [[TMP17:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
-; CHECK-NEXT:    br i1 [[TMP17]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP11:![0-9]+]]
+; CHECK-NEXT:    [[TMP17:%.*]] = shl nuw nsw i64 [[TMP5]], 3
+; CHECK-NEXT:    [[DOTSPLATINSERT3:%.*]] = insertelement <vscale x 4 x i64> poison, i64 [[TMP17]], i64 0
+; CHECK-NEXT:    [[DOTSPLAT4:%.*]] = shufflevector <vscale x 4 x i64> [[DOTSPLATINSERT3]], <vscale x 4 x i64> poison, <vscale x 4 x i32> zeroinitializer
+; CHECK-NEXT:    [[VEC_IND_NEXT]] = add <vscale x 4 x i64> [[VEC_IND]], [[DOTSPLAT4]]
+; CHECK-NEXT:    [[TMP18:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
+; CHECK-NEXT:    br i1 [[TMP18]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP11:![0-9]+]]
 ; CHECK: middle.block:
 ; CHECK-NEXT:    [[CMP_N:%.*]] = icmp eq i64 [[N_MOD_VF]], 0
 ; CHECK-NEXT:    br i1 [[CMP_N]], label [[FOR_COND_CLEANUP:%.*]], label [[SCALAR_PH]]
@@ -335,9 +338,9 @@
 ; CHECK-NEXT:    [[INDVARS_IV:%.*]] = phi i64 [ [[INDVARS_IV_NEXT:%.*]], [[FOR_BODY]] ], [ [[BC_RESUME_VAL]], [[SCALAR_PH]] ]
 ; CHECK-NEXT:    [[INDVARS_IV_STRIDE2:%.*]] = shl i64 [[INDVARS_IV]], 1
 ; CHECK-NEXT:    [[ARRAYIDX:%.*]] = getelementptr inbounds float, ptr [[B]], i64 [[INDVARS_IV_STRIDE2]]
-; CHECK-NEXT:    [[TMP18:%.*]] = load float, ptr [[ARRAYIDX]], align 4
+; CHECK-NEXT:    [[TMP19:%.*]] = load float, ptr [[ARRAYIDX]], align 4
 ; CHECK-NEXT:    [[ARRAYIDX2:%.*]] = getelementptr inbounds float, ptr [[A]], i64 [[INDVARS_IV]]
-; CHECK-NEXT:    store float [[TMP18]], ptr [[ARRAYIDX2]], align 4
+; CHECK-NEXT:    store float [[TMP19]], ptr [[ARRAYIDX2]], align 4
 ; CHECK-NEXT:    [[INDVARS_IV_NEXT]] = add nuw nsw i64 [[INDVARS_IV]], 1
 ; CHECK-NEXT:    [[EXITCOND_NOT:%.*]] = icmp eq i64 [[INDVARS_IV_NEXT]], [[N]]
 ; CHECK-NEXT:    br i1 [[EXITCOND_NOT]], label [[FOR_COND_CLEANUP]], label [[FOR_BODY]], !llvm.loop [[LOOP12:![0-9]+]]
diff --git a/llvm/test/Transforms/LoopVectorize/scalable-inductions.ll b/llvm/test/Transforms/LoopVectorize/scalable-inductions.ll
--- a/llvm/test/Transforms/LoopVectorize/scalable-inductions.ll
+++ b/llvm/test/Transforms/LoopVectorize/scalable-inductions.ll
@@ -46,9 +46,12 @@
 ; CHECK-NEXT:    [[TMP17:%.*]] = call i64 @llvm.vscale.i64()
 ; CHECK-NEXT:    [[TMP18:%.*]] = shl i64 [[TMP17]], 2
 ; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[TMP18]]
-; CHECK-NEXT:    [[VEC_IND_NEXT]] = add <vscale x 2 x i64> [[STEP_ADD]], [[DOTSPLAT]]
-; CHECK-NEXT:    [[TMP19:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
-; CHECK-NEXT:    br i1 [[TMP19]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
+; CHECK-NEXT:    [[TMP19:%.*]] = shl i64 [[TMP5]], 2
+; CHECK-NEXT:    [[DOTSPLATINSERT3:%.*]] = insertelement <vscale x 2 x i64> poison, i64 [[TMP19]], i64 0
+; CHECK-NEXT:    [[DOTSPLAT4:%.*]] = shufflevector <vscale x 2 x i64> [[DOTSPLATINSERT3]], <vscale x 2 x i64> poison, <vscale x 2 x i32> zeroinitializer
+; CHECK-NEXT:    [[VEC_IND_NEXT]] = add <vscale x 2 x i64> [[VEC_IND]], [[DOTSPLAT4]]
+; CHECK-NEXT:    [[TMP20:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
+; CHECK-NEXT:    br i1 [[TMP20]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
 ; CHECK: middle.block:
 ; CHECK-NEXT:    [[CMP_N:%.*]] = icmp eq i64 [[N_MOD_VF]], 0
 ; CHECK-NEXT:    br i1 [[CMP_N]], label [[EXIT:%.*]], label [[SCALAR_PH]]
@@ -58,8 +61,8 @@
 ; CHECK: for.body:
 ; CHECK-NEXT:    [[I_08:%.*]] = phi i64 [ [[INC:%.*]], [[FOR_BODY]] ], [ [[BC_RESUME_VAL]], [[SCALAR_PH]] ]
 ; CHECK-NEXT:    [[ARRAYIDX:%.*]] = getelementptr inbounds i64, ptr [[B]], i64 [[I_08]]
-; CHECK-NEXT:    [[TMP20:%.*]] = load i64, ptr [[ARRAYIDX]], align 8
-; CHECK-NEXT:    [[ADD:%.*]] = add nsw i64 [[TMP20]], [[I_08]]
+; CHECK-NEXT:    [[TMP21:%.*]] = load i64, ptr [[ARRAYIDX]], align 8
+; CHECK-NEXT:    [[ADD:%.*]] = add nsw i64 [[TMP21]], [[I_08]]
 ; CHECK-NEXT:    [[ARRAYIDX1:%.*]] = getelementptr inbounds i64, ptr [[A]], i64 [[I_08]]
 ; CHECK-NEXT:    store i64 [[ADD]], ptr [[ARRAYIDX1]], align 8
 ; CHECK-NEXT:    [[INC]] = add nuw nsw i64 [[I_08]], 1
@@ -125,9 +128,12 @@
 ; CHECK-NEXT:    [[TMP14:%.*]] = call i64 @llvm.vscale.i64()
 ; CHECK-NEXT:    [[TMP15:%.*]] = shl i64 [[TMP14]], 1
 ; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[TMP15]]
-; CHECK-NEXT:    [[VEC_IND_NEXT]] = add <vscale x 1 x i64> [[STEP_ADD]], [[DOTSPLAT]]
-; CHECK-NEXT:    [[TMP16:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
-; CHECK-NEXT:    br i1 [[TMP16]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP5:![0-9]+]]
+; CHECK-NEXT:    [[TMP16:%.*]] = shl i64 [[TMP5]], 1
+; CHECK-NEXT:    [[DOTSPLATINSERT3:%.*]] = insertelement <vscale x 1 x i64> poison, i64 [[TMP16]], i64 0
+; CHECK-NEXT:    [[DOTSPLAT4:%.*]] = shufflevector <vscale x 1 x i64> [[DOTSPLATINSERT3]], <vscale x 1 x i64> poison, <vscale x 1 x i32> zeroinitializer
+; CHECK-NEXT:    [[VEC_IND_NEXT]] = add <vscale x 1 x i64> [[VEC_IND]], [[DOTSPLAT4]]
+; CHECK-NEXT:    [[TMP17:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
+; CHECK-NEXT:    br i1 [[TMP17]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP5:![0-9]+]]
 ; CHECK: middle.block:
 ; CHECK-NEXT:    [[CMP_N:%.*]] = icmp eq i64 [[N_MOD_VF]], 0
 ; CHECK-NEXT:    br i1 [[CMP_N]], label [[EXIT:%.*]], label [[SCALAR_PH]]
@@ -137,8 +143,8 @@
 ; CHECK: for.body:
 ; CHECK-NEXT:    [[I_08:%.*]] = phi i64 [ [[INC:%.*]], [[FOR_BODY]] ], [ [[BC_RESUME_VAL]], [[SCALAR_PH]] ]
 ; CHECK-NEXT:    [[ARRAYIDX:%.*]] = getelementptr inbounds i64, ptr [[B]], i64 [[I_08]]
-; CHECK-NEXT:    [[TMP17:%.*]] = load i64, ptr [[ARRAYIDX]], align 8
-; CHECK-NEXT:    [[ADD:%.*]] = add nsw i64 [[TMP17]], [[I_08]]
+; CHECK-NEXT:    [[TMP18:%.*]] = load i64, ptr [[ARRAYIDX]], align 8
+; CHECK-NEXT:    [[ADD:%.*]] = add nsw i64 [[TMP18]], [[I_08]]
 ; CHECK-NEXT:    [[ARRAYIDX1:%.*]] = getelementptr inbounds i64, ptr [[A]], i64 [[I_08]]
 ; CHECK-NEXT:    store i64 [[ADD]], ptr [[ARRAYIDX1]], align 8
 ; CHECK-NEXT:    [[INC]] = add nuw nsw i64 [[I_08]], 1
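The updated FileCheck expectations in the two LoopVectorize tests follow directly from the new fold: the vector induction update previously chained two splat additions (VEC_IND -> STEP_ADD -> VEC_IND_NEXT, each adding a splat of the per-part step), and now adds a single splat of the combined step, e.g. splat(shl vscale, 3) in the gather/scatter test. To spot-check the InstCombine expectations locally, the usual invocation applies (an illustrative command; the checked-in CHECK lines are maintained with update_test_checks.py):

  opt -passes=instcombine -S llvm/test/Transforms/InstCombine/vscale.ll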