diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -9025,6 +9025,9 @@
       }
     }
 
+  // Sink scalar operands into the blocks of the predicated recipes using them.
+  VPlanTransforms::sinkScalarOperands(*Plan);
+
   std::string PlanName;
   raw_string_ostream RSO(PlanName);
   ElementCount VF = Range.Start;
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -1262,6 +1262,9 @@
   bool isPacked() const { return AlsoPack; }
 
   bool isPredicated() const { return IsPredicated; }
+
+  /// Mark this recipe as predicated, or as unpredicated if \p New is false.
+  void setIsPredicated(bool New = true) { IsPredicated = New; }
 };
 
 /// A recipe for generating conditional branches on the bits of a mask.
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.h b/llvm/lib/Transforms/Vectorize/VPlanTransforms.h
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.h
@@ -26,6 +26,14 @@
       Loop *OrigLoop, VPlanPtr &Plan,
       LoopVectorizationLegality::InductionList &Inductions,
       SmallPtrSetImpl<Instruction *> &DeadInstructions, ScalarEvolution &SE);
+
+  /// Sink the operands of predicated VPReplicateRecipes into the block of the
+  /// predicated recipe, so they are only computed when that block executes.
+  /// An operand is sunk only if it is defined by a VPReplicateRecipe in a
+  /// different block, has no side effects, does not access memory and all of
+  /// its users are recipes in the destination block. The operands of sunk
+  /// recipes are then considered as well. Returns true if any recipe moved.
+  static bool sinkScalarOperands(VPlan &Plan);
 };
 
 } // namespace llvm
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -99,3 +99,58 @@
     }
   }
 }
+
+bool VPlanTransforms::sinkScalarOperands(VPlan &Plan) {
+  ReversePostOrderTraversal<VPBlockRecursiveTraversalWrapper<VPBlockBase *>>
+      RPOT(VPBlockRecursiveTraversalWrapper<VPBlockBase *>(Plan.getEntry()));
+
+  bool Changed = false;
+  for (VPBlockBase *Base : RPOT) {
+    auto *VPBB = dyn_cast<VPBasicBlock>(Base);
+    if (!VPBB)
+      continue;
+
+    for (auto &Recipe : make_early_inc_range(*VPBB)) {
+      auto *RepR = dyn_cast<VPReplicateRecipe>(&Recipe);
+      if (!RepR || !RepR->isPredicated())
+        continue;
+
+      // Seed the worklist with the operands of the predicated recipe; they
+      // are sunk to the start of its block, ahead of their users.
+      VPBasicBlock *SinkTo = RepR->getParent();
+      SetVector<VPValue *> WorkList(RepR->op_begin(), RepR->op_end());
+
+      while (!WorkList.empty()) {
+        VPValue *C = WorkList.pop_back_val();
+        auto *SinkCandidate = dyn_cast_or_null<VPReplicateRecipe>(C->Def);
+        if (!SinkCandidate || SinkCandidate->getParent() == SinkTo)
+          continue;
+
+        // Do not sink recipes with side effects or memory accesses. Moving a
+        // load into SinkTo may move it past a store in an intervening block
+        // and change the value it reads.
+        auto *UnderlyingI =
+            cast<Instruction>(SinkCandidate->getUnderlyingValue());
+        if (UnderlyingI->mayHaveSideEffects() ||
+            UnderlyingI->mayReadOrWriteMemory())
+          continue;
+
+        // Every user must be a recipe in SinkTo, otherwise the moved
+        // definition would not dominate all of its uses. Users that are not
+        // recipes cannot be checked and conservatively block sinking.
+        if (any_of(SinkCandidate->users(), [SinkTo](VPUser *U) {
+              auto *UI = dyn_cast<VPRecipeBase>(U);
+              return !UI || UI->getParent() != SinkTo;
+            }))
+          continue;
+
+        SinkCandidate->moveBefore(*SinkTo, SinkTo->begin());
+        SinkCandidate->setIsPredicated();
+        // The moved recipe's own operands may now be sinkable, too.
+        WorkList.insert(SinkCandidate->op_begin(), SinkCandidate->op_end());
+        Changed = true;
+      }
+    }
+  }
+  return Changed;
+}
diff --git a/llvm/test/Transforms/LoopVectorize/X86/small-size.ll b/llvm/test/Transforms/LoopVectorize/X86/small-size.ll
--- a/llvm/test/Transforms/LoopVectorize/X86/small-size.ll
+++ 
b/llvm/test/Transforms/LoopVectorize/X86/small-size.ll @@ -567,40 +567,40 @@ ; CHECK-NEXT: [[TMP17:%.*]] = extractelement <4 x i1> [[TMP1]], i32 0 ; CHECK-NEXT: br i1 [[TMP17]], label [[PRED_STORE_IF:%.*]], label [[PRED_STORE_CONTINUE:%.*]] ; CHECK: pred.store.if: +; CHECK-NEXT: [[NEXT_GEP7:%.*]] = getelementptr i32, i32* [[DST:%.*]], i64 [[INDEX]] ; CHECK-NEXT: [[TMP18:%.*]] = zext i16 [[TMP4]] to i32 ; CHECK-NEXT: [[TMP19:%.*]] = shl nuw nsw i32 [[TMP18]], 7 -; CHECK-NEXT: [[NEXT_GEP7:%.*]] = getelementptr i32, i32* [[DST:%.*]], i64 [[INDEX]] ; CHECK-NEXT: store i32 [[TMP19]], i32* [[NEXT_GEP7]], align 4 ; CHECK-NEXT: br label [[PRED_STORE_CONTINUE]] ; CHECK: pred.store.continue: ; CHECK-NEXT: [[TMP20:%.*]] = extractelement <4 x i1> [[TMP1]], i32 1 ; CHECK-NEXT: br i1 [[TMP20]], label [[PRED_STORE_IF17:%.*]], label [[PRED_STORE_CONTINUE18:%.*]] ; CHECK: pred.store.if17: -; CHECK-NEXT: [[TMP21:%.*]] = zext i16 [[TMP8]] to i32 -; CHECK-NEXT: [[TMP22:%.*]] = shl nuw nsw i32 [[TMP21]], 7 -; CHECK-NEXT: [[TMP23:%.*]] = or i64 [[INDEX]], 1 -; CHECK-NEXT: [[NEXT_GEP8:%.*]] = getelementptr i32, i32* [[DST]], i64 [[TMP23]] -; CHECK-NEXT: store i32 [[TMP22]], i32* [[NEXT_GEP8]], align 4 +; CHECK-NEXT: [[TMP21:%.*]] = or i64 [[INDEX]], 1 +; CHECK-NEXT: [[NEXT_GEP8:%.*]] = getelementptr i32, i32* [[DST]], i64 [[TMP21]] +; CHECK-NEXT: [[TMP22:%.*]] = zext i16 [[TMP8]] to i32 +; CHECK-NEXT: [[TMP23:%.*]] = shl nuw nsw i32 [[TMP22]], 7 +; CHECK-NEXT: store i32 [[TMP23]], i32* [[NEXT_GEP8]], align 4 ; CHECK-NEXT: br label [[PRED_STORE_CONTINUE18]] ; CHECK: pred.store.continue18: ; CHECK-NEXT: [[TMP24:%.*]] = extractelement <4 x i1> [[TMP1]], i32 2 ; CHECK-NEXT: br i1 [[TMP24]], label [[PRED_STORE_IF19:%.*]], label [[PRED_STORE_CONTINUE20:%.*]] ; CHECK: pred.store.if19: -; CHECK-NEXT: [[TMP25:%.*]] = zext i16 [[TMP12]] to i32 -; CHECK-NEXT: [[TMP26:%.*]] = shl nuw nsw i32 [[TMP25]], 7 -; CHECK-NEXT: [[TMP27:%.*]] = or i64 [[INDEX]], 2 -; CHECK-NEXT: [[NEXT_GEP9:%.*]] = 
getelementptr i32, i32* [[DST]], i64 [[TMP27]] -; CHECK-NEXT: store i32 [[TMP26]], i32* [[NEXT_GEP9]], align 4 +; CHECK-NEXT: [[TMP25:%.*]] = or i64 [[INDEX]], 2 +; CHECK-NEXT: [[NEXT_GEP9:%.*]] = getelementptr i32, i32* [[DST]], i64 [[TMP25]] +; CHECK-NEXT: [[TMP26:%.*]] = zext i16 [[TMP12]] to i32 +; CHECK-NEXT: [[TMP27:%.*]] = shl nuw nsw i32 [[TMP26]], 7 +; CHECK-NEXT: store i32 [[TMP27]], i32* [[NEXT_GEP9]], align 4 ; CHECK-NEXT: br label [[PRED_STORE_CONTINUE20]] ; CHECK: pred.store.continue20: ; CHECK-NEXT: [[TMP28:%.*]] = extractelement <4 x i1> [[TMP1]], i32 3 ; CHECK-NEXT: br i1 [[TMP28]], label [[PRED_STORE_IF21:%.*]], label [[PRED_STORE_CONTINUE22]] ; CHECK: pred.store.if21: -; CHECK-NEXT: [[TMP29:%.*]] = zext i16 [[TMP16]] to i32 -; CHECK-NEXT: [[TMP30:%.*]] = shl nuw nsw i32 [[TMP29]], 7 -; CHECK-NEXT: [[TMP31:%.*]] = or i64 [[INDEX]], 3 -; CHECK-NEXT: [[NEXT_GEP10:%.*]] = getelementptr i32, i32* [[DST]], i64 [[TMP31]] -; CHECK-NEXT: store i32 [[TMP30]], i32* [[NEXT_GEP10]], align 4 +; CHECK-NEXT: [[TMP29:%.*]] = or i64 [[INDEX]], 3 +; CHECK-NEXT: [[NEXT_GEP10:%.*]] = getelementptr i32, i32* [[DST]], i64 [[TMP29]] +; CHECK-NEXT: [[TMP30:%.*]] = zext i16 [[TMP16]] to i32 +; CHECK-NEXT: [[TMP31:%.*]] = shl nuw nsw i32 [[TMP30]], 7 +; CHECK-NEXT: store i32 [[TMP31]], i32* [[NEXT_GEP10]], align 4 ; CHECK-NEXT: br label [[PRED_STORE_CONTINUE22]] ; CHECK: pred.store.continue22: ; CHECK-NEXT: [[INDEX_NEXT]] = add i64 [[INDEX]], 4 diff --git a/llvm/test/Transforms/LoopVectorize/if-pred-stores.ll b/llvm/test/Transforms/LoopVectorize/if-pred-stores.ll --- a/llvm/test/Transforms/LoopVectorize/if-pred-stores.ll +++ b/llvm/test/Transforms/LoopVectorize/if-pred-stores.ll @@ -130,11 +130,11 @@ ; VEC-NEXT: [[TMP8:%.*]] = extractelement <2 x i1> [[TMP4]], i32 1 ; VEC-NEXT: br i1 [[TMP8]], label [[PRED_STORE_IF1:%.*]], label [[PRED_STORE_CONTINUE2]] ; VEC: pred.store.if1: -; VEC-NEXT: [[TMP9:%.*]] = extractelement <2 x i32> [[WIDE_LOAD]], i32 1 -; VEC-NEXT: 
[[TMP10:%.*]] = add nsw i32 [[TMP9]], 20 -; VEC-NEXT: [[TMP11:%.*]] = add i64 [[INDEX]], 1 -; VEC-NEXT: [[TMP12:%.*]] = getelementptr inbounds i32, i32* [[F]], i64 [[TMP11]] -; VEC-NEXT: store i32 [[TMP10]], i32* [[TMP12]], align 4 +; VEC-NEXT: [[TMP9:%.*]] = add i64 [[INDEX]], 1 +; VEC-NEXT: [[TMP10:%.*]] = getelementptr inbounds i32, i32* [[F]], i64 [[TMP9]] +; VEC-NEXT: [[TMP11:%.*]] = extractelement <2 x i32> [[WIDE_LOAD]], i32 1 +; VEC-NEXT: [[TMP12:%.*]] = add nsw i32 [[TMP11]], 20 +; VEC-NEXT: store i32 [[TMP12]], i32* [[TMP10]], align 4 ; VEC-NEXT: br label [[PRED_STORE_CONTINUE2]] ; VEC: pred.store.continue2: ; VEC-NEXT: [[INDEX_NEXT]] = add i64 [[INDEX]], 2 @@ -565,12 +565,12 @@ ; VEC-NEXT: [[TMP9:%.*]] = extractelement <2 x i1> [[BROADCAST_SPLAT]], i32 1 ; VEC-NEXT: br i1 [[TMP9]], label [[PRED_STORE_IF2:%.*]], label [[PRED_STORE_CONTINUE3]] ; VEC: pred.store.if2: -; VEC-NEXT: [[TMP10:%.*]] = extractelement <2 x i8> [[WIDE_LOAD]], i32 1 -; VEC-NEXT: [[TMP11:%.*]] = zext i8 [[TMP10]] to i32 -; VEC-NEXT: [[TMP12:%.*]] = trunc i32 [[TMP11]] to i8 -; VEC-NEXT: [[TMP13:%.*]] = add i64 [[INDEX]], 1 -; VEC-NEXT: [[TMP14:%.*]] = getelementptr i8, i8* undef, i64 [[TMP13]] -; VEC-NEXT: store i8 [[TMP12]], i8* [[TMP14]], align 1 +; VEC-NEXT: [[TMP10:%.*]] = add i64 [[INDEX]], 1 +; VEC-NEXT: [[TMP11:%.*]] = getelementptr i8, i8* undef, i64 [[TMP10]] +; VEC-NEXT: [[TMP12:%.*]] = extractelement <2 x i8> [[WIDE_LOAD]], i32 1 +; VEC-NEXT: [[TMP13:%.*]] = zext i8 [[TMP12]] to i32 +; VEC-NEXT: [[TMP14:%.*]] = trunc i32 [[TMP13]] to i8 +; VEC-NEXT: store i8 [[TMP14]], i8* [[TMP11]], align 1 ; VEC-NEXT: br label [[PRED_STORE_CONTINUE3]] ; VEC: pred.store.continue3: ; VEC-NEXT: [[INDEX_NEXT]] = add i64 [[INDEX]], 2 diff --git a/llvm/test/Transforms/LoopVectorize/vplan-sink-scalars-and-merge.ll b/llvm/test/Transforms/LoopVectorize/vplan-sink-scalars-and-merge.ll --- a/llvm/test/Transforms/LoopVectorize/vplan-sink-scalars-and-merge.ll +++ 
b/llvm/test/Transforms/LoopVectorize/vplan-sink-scalars-and-merge.ll @@ -14,7 +14,6 @@ ; CHECK-NEXT: loop: ; CHECK-NEXT: WIDEN-INDUCTION %indvars.iv = phi 0, %indvars.iv.next ; CHECK-NEXT: EMIT vp<%2> = icmp ule ir<%indvars.iv> vp<%0> -; CHECK-NEXT: REPLICATE ir<%gep.b> = getelementptr ir<@b>, ir<0>, ir<%indvars.iv> ; CHECK-NEXT: Successor(s): pred.load ; CHECK: pred.load: { @@ -24,6 +23,7 @@ ; CHECK-NEXT: CondBit: vp<%2> (loop) ; CHECK: pred.load.if: +; CHECK-NEXT: REPLICATE ir<%gep.b> = getelementptr ir<@b>, ir<0>, ir<%indvars.iv> ; CHECK-NEXT: REPLICATE ir<%lv.b> = load ir<%gep.b> ; CHECK-NEXT: Successor(s): pred.load.continue @@ -33,9 +33,6 @@ ; CHECK-NEXT: } ; CHECK: loop.0: -; CHECK-NEXT: REPLICATE ir<%add> = add vp<%5>, ir<10> -; CHECK-NEXT: REPLICATE ir<%mul> = mul ir<2>, ir<%add> -; CHECK-NEXT: REPLICATE ir<%gep.a> = getelementptr ir<@a>, ir<0>, ir<%indvars.iv> ; CHECK-NEXT: Successor(s): pred.store ; CHECK: pred.store: { @@ -45,6 +42,9 @@ ; CHECK-NEXT: CondBit: vp<%2> (loop) ; CHECK: pred.store.if: +; CHECK-NEXT: REPLICATE ir<%add> = add vp<%5>, ir<10> +; CHECK-NEXT: REPLICATE ir<%mul> = mul ir<2>, ir<%add> +; CHECK-NEXT: REPLICATE ir<%gep.a> = getelementptr ir<@a>, ir<0>, ir<%indvars.iv> ; CHECK-NEXT: REPLICATE store ir<%mul>, ir<%gep.a> ; CHECK-NEXT: Successor(s): pred.store.continue @@ -85,7 +85,6 @@ ; CHECK-NEXT: loop: ; CHECK-NEXT: WIDEN-INDUCTION %indvars.iv = phi 0, %indvars.iv.next ; CHECK-NEXT: EMIT vp<%2> = icmp ule ir<%indvars.iv> vp<%0> -; CHECK-NEXT: REPLICATE ir<%gep.b> = getelementptr ir<@b>, ir<0>, ir<%indvars.iv> ; CHECK-NEXT: Successor(s): pred.load ; CHECK: pred.load: { @@ -95,6 +94,7 @@ ; CHECK-NEXT: CondBit: vp<%2> (loop) ; CHECK: pred.load.if: +; CHECK-NEXT: REPLICATE ir<%gep.b> = getelementptr ir<@b>, ir<0>, ir<%indvars.iv> ; CHECK-NEXT: REPLICATE ir<%lv.b> = load ir<%gep.b> ; CHECK-NEXT: Successor(s): pred.load.continue @@ -104,9 +104,7 @@ ; CHECK-NEXT: } ; CHECK: loop.0: -; CHECK-NEXT: REPLICATE ir<%add> = add vp<%5>, ir<10> ; 
CHECK-NEXT: WIDEN ir<%mul> = mul ir<%indvars.iv>, ir<2> -; CHECK-NEXT: REPLICATE ir<%gep.a> = getelementptr ir<@a>, ir<0>, ir<%mul> ; CHECK-NEXT: Successor(s): pred.store ; CHECK: pred.store: { @@ -116,6 +114,8 @@ ; CHECK-NEXT: CondBit: vp<%2> (loop) ; CHECK: pred.store.if: +; CHECK-NEXT: REPLICATE ir<%add> = add vp<%5>, ir<10> +; CHECK-NEXT: REPLICATE ir<%gep.a> = getelementptr ir<@a>, ir<0>, ir<%mul> ; CHECK-NEXT: REPLICATE store ir<%add>, ir<%gep.a> ; CHECK-NEXT: Successor(s): pred.store.continue @@ -156,7 +156,6 @@ ; CHECK-NEXT: loop: ; CHECK-NEXT: WIDEN-INDUCTION %indvars.iv = phi 0, %indvars.iv.next ; CHECK-NEXT: EMIT vp<%2> = icmp ule ir<%indvars.iv> vp<%0> -; CHECK-NEXT: REPLICATE ir<%gep.b> = getelementptr ir<@b>, ir<0>, ir<%indvars.iv> ; CHECK-NEXT: Successor(s): pred.load ; CHECK: pred.load: { @@ -166,6 +165,7 @@ ; CHECK-NEXT: CondBit: vp<%2> (loop) ; CHECK: pred.load.if: +; CHECK-NEXT: REPLICATE ir<%gep.b> = getelementptr ir<@b>, ir<0>, ir<%indvars.iv> ; CHECK-NEXT: REPLICATE ir<%lv.b> = load ir<%gep.b> (S->V) ; CHECK-NEXT: Successor(s): pred.load.continue @@ -177,7 +177,6 @@ ; CHECK: loop.0: ; CHECK-NEXT: WIDEN ir<%add> = add vp<%5>, ir<10> ; CHECK-NEXT: WIDEN ir<%mul> = mul ir<%indvars.iv>, ir<%add> -; CHECK-NEXT: REPLICATE ir<%gep.a> = getelementptr ir<@a>, ir<0>, ir<%mul> ; CHECK-NEXT: Successor(s): pred.store ; CHECK: pred.store: { @@ -187,6 +186,7 @@ ; CHECK-NEXT: CondBit: vp<%2> (loop) ; CHECK: pred.store.if: +; CHECK-NEXT: REPLICATE ir<%gep.a> = getelementptr ir<@a>, ir<0>, ir<%mul> ; CHECK-NEXT: REPLICATE store ir<%add>, ir<%gep.a> ; CHECK-NEXT: Successor(s): pred.store.continue