diff --git a/llvm/include/llvm/Transforms/Vectorize/LoopVectorizationLegality.h b/llvm/include/llvm/Transforms/Vectorize/LoopVectorizationLegality.h
--- a/llvm/include/llvm/Transforms/Vectorize/LoopVectorizationLegality.h
+++ b/llvm/include/llvm/Transforms/Vectorize/LoopVectorizationLegality.h
@@ -312,6 +312,12 @@
   // Returns true if the NoNaN attribute is set on the function.
   bool hasFunNoNaNAttr() const { return HasFunNoNaNAttr; }
 
+  /// Returns all assume calls in predicated blocks. They need to be dropped
+  /// when flattening the CFG.
+  const SmallPtrSetImpl<Instruction *> &getConditionalAssumes() const {
+    return ConditionalAssumes;
+  }
+
 private:
   /// Return true if the pre-header, exiting and latch blocks of \p Lp and all
   /// its nested loops are considered legal for vectorization. These legal
@@ -468,6 +474,10 @@
   /// While vectorizing these instructions we have to generate a
   /// call to the appropriate masked intrinsic
   SmallPtrSet<const Instruction *, 8> MaskedOp;
+
+  /// Assume instructions in predicated blocks must be dropped if the CFG gets
+  /// flattened.
+  SmallPtrSet<Instruction *, 8> ConditionalAssumes;
 };
 
 } // namespace llvm
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp
--- a/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp
@@ -13,14 +13,16 @@
 // pass. It should be easy to create an analysis pass around it if there
 // is a need (but D45420 needs to happen first).
 //
-#include "llvm/Transforms/Vectorize/LoopVectorize.h"
 #include "llvm/Transforms/Vectorize/LoopVectorizationLegality.h"
 #include "llvm/Analysis/Loads.h"
 #include "llvm/Analysis/ValueTracking.h"
 #include "llvm/Analysis/VectorUtils.h"
 #include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/PatternMatch.h"
+#include "llvm/Transforms/Vectorize/LoopVectorize.h"
 
 using namespace llvm;
+using namespace PatternMatch;
 
 #define LV_NAME "loop-vectorize"
 #define DEBUG_TYPE LV_NAME
@@ -897,6 +899,14 @@
       if (C->canTrap())
         return false;
     }
+
+    // We can predicate blocks with calls to assume, as long as we drop them in
+    // case we flatten the CFG via predication.
+    if (match(&I, m_Intrinsic<Intrinsic::assume>())) {
+      ConditionalAssumes.insert(&I);
+      continue;
+    }
+
     // We might be able to hoist the load.
     if (I.mayReadFromMemory()) {
       auto *LI = dyn_cast<LoadInst>(&I);
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -7174,6 +7174,13 @@
   // visit each basic block after having visited its predecessor basic blocks.
   // ---------------------------------------------------------------------------
 
+  // Add assume instructions we need to drop to DeadInstructions, to prevent
+  // them from being added to the VPlan.
+  // TODO: We only need to drop assumes in blocks that get flattened. If the
+  // control flow is preserved, we should keep them.
+  auto &ConditionalAssumes = Legal->getConditionalAssumes();
+  DeadInstructions.insert(ConditionalAssumes.begin(),
+                          ConditionalAssumes.end());
+
   // Create a dummy pre-entry VPBasicBlock to start building the VPlan.
   VPBasicBlock *VPBB = new VPBasicBlock("Pre-Entry");
   auto Plan = std::make_unique<VPlan>(VPBB);
diff --git a/llvm/test/Transforms/LoopVectorize/predicate-assume.ll b/llvm/test/Transforms/LoopVectorize/predicate-assume.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/Transforms/LoopVectorize/predicate-assume.ll
@@ -0,0 +1,142 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; REQUIRES: asserts
+; RUN: opt -loop-vectorize -force-vector-width=4 -debug -S %s 2>&1 | FileCheck %s
+
+
+; Test case for PR43620. Make sure we can vectorize with predication in the
+; presence of assume calls. When generating code and flattening the CFG, we
+; drop the assume calls.
+
+; CHECK: digraph VPlan {
+; CHECK: N0 [label =
+; CHECK-NEXT:   "for.body:\n" +
+; CHECK-NEXT:   "WIDEN-INDUCTION %indvars.iv = phi 0, %indvars.iv.next\l" +
+; CHECK-NEXT:   "WIDEN\l" +
+; CHECK-NEXT:   "  %cmp1 = icmp %indvars.iv, 495616\l"
+; CHECK-NEXT: ]
+; CHECK-NEXT: N0 -> N1 [ label=""]
+; CHECK-NEXT: N1 [label =
+; CHECK-NEXT:   "if.else:\n" +
+; CHECK-NEXT:   "WIDEN\l" +
+; CHECK-NEXT:   "  %cmp2 = icmp %indvars.iv, 991232\l" +
+; CHECK-NEXT:   "REPLICATE call %cmp2, @llvm.assume\l"
+; CHECK-NEXT: ]
+; CHECK-NEXT: N1 -> N2 [ label=""]
+; CHECK-NEXT: N2 [label =
+; CHECK-NEXT:   "if.end5:\n" +
+; CHECK-NEXT:   "EMIT %vp6568 = not %vp5040\l" +
+; CHECK-NEXT:   "BLEND %x.0 = 4.200000e+01/%vp6568 2.300000e+01/%vp5040\l" +
+; CHECK-NEXT:   "CLONE %arrayidx = getelementptr %a, %indvars.iv\l" +
+; CHECK-NEXT:   "WIDEN %1 = load %arrayidx\l" +
+; CHECK-NEXT:   "WIDEN\l" +
+; CHECK-NEXT:   "  %mul = fmul %x.0, %1\l" +
+; CHECK-NEXT:   "CLONE %arrayidx7 = getelementptr %b, %indvars.iv\l" +
+; CHECK-NEXT:   "WIDEN store %mul, %arrayidx7\l"
+; CHECK-NEXT: ]
+; CHECK-NEXT: }
+
+
+define void @foo(float* noalias nocapture readonly %a, float* noalias nocapture %b, i32 %n) {
+; CHECK-LABEL: @foo(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[CMP15:%.*]] = icmp eq i32 [[N:%.*]], 0
+; CHECK-NEXT:    br i1 [[CMP15]], label [[FOR_COND_CLEANUP:%.*]], label [[FOR_BODY_PREHEADER:%.*]]
+; CHECK:       for.body.preheader:
+; CHECK-NEXT:    [[TMP0:%.*]] = zext i32 [[N]] to i64
+; CHECK-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 [[TMP0]], 4
+; CHECK-NEXT:    br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
+; CHECK:       vector.ph:
+; CHECK-NEXT:    [[N_MOD_VF:%.*]] = urem i64 [[TMP0]], 4
+; CHECK-NEXT:    [[N_VEC:%.*]] = sub i64 [[TMP0]], [[N_MOD_VF]]
+; CHECK-NEXT:    br label [[VECTOR_BODY:%.*]]
+; CHECK:       vector.body:
+; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[VEC_IND:%.*]] = phi <4 x i64> [ <i64 0, i64 1, i64 2, i64 3>, [[VECTOR_PH]] ], [ [[VEC_IND_NEXT:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[TMP1:%.*]] = add i64 [[INDEX]], 0
+; CHECK-NEXT:    [[TMP2:%.*]] = add i64 [[INDEX]], 1
+; CHECK-NEXT:    [[TMP3:%.*]] = add i64 [[INDEX]], 2
+; CHECK-NEXT:    [[TMP4:%.*]] = add i64 [[INDEX]], 3
+; CHECK-NEXT:    [[TMP5:%.*]] = icmp ult <4 x i64> [[VEC_IND]], <i64 495616, i64 495616, i64 495616, i64 495616>
+; CHECK-NEXT:    [[TMP6:%.*]] = icmp ult <4 x i64> [[VEC_IND]], <i64 991232, i64 991232, i64 991232, i64 991232>
+; CHECK-NEXT:    [[TMP7:%.*]] = xor <4 x i1> [[TMP5]], <i1 true, i1 true, i1 true, i1 true>
+; CHECK-NEXT:    [[PREDPHI:%.*]] = select <4 x i1> [[TMP5]], <4 x float> <float 2.300000e+01, float 2.300000e+01, float 2.300000e+01, float 2.300000e+01>, <4 x float> <float 4.200000e+01, float 4.200000e+01, float 4.200000e+01, float 4.200000e+01>
+; CHECK-NEXT:    [[TMP8:%.*]] = getelementptr inbounds float, float* [[A:%.*]], i64 [[TMP1]]
+; CHECK-NEXT:    [[TMP9:%.*]] = getelementptr inbounds float, float* [[TMP8]], i32 0
+; CHECK-NEXT:    [[TMP10:%.*]] = bitcast float* [[TMP9]] to <4 x float>*
+; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <4 x float>, <4 x float>* [[TMP10]], align 4
+; CHECK-NEXT:    [[TMP11:%.*]] = fmul <4 x float> [[PREDPHI]], [[WIDE_LOAD]]
+; CHECK-NEXT:    [[TMP12:%.*]] = getelementptr inbounds float, float* [[B:%.*]], i64 [[TMP1]]
+; CHECK-NEXT:    [[TMP13:%.*]] = getelementptr inbounds float, float* [[TMP12]], i32 0
+; CHECK-NEXT:    [[TMP14:%.*]] = bitcast float* [[TMP13]] to <4 x float>*
+; CHECK-NEXT:    store <4 x float> [[TMP11]], <4 x float>* [[TMP14]], align 4
+; CHECK-NEXT:    [[INDEX_NEXT]] = add i64 [[INDEX]], 4
+; CHECK-NEXT:    [[VEC_IND_NEXT]] = add <4 x i64> [[VEC_IND]], <i64 4, i64 4, i64 4, i64 4>
+; CHECK-NEXT:    [[TMP15:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
+; CHECK-NEXT:    br i1 [[TMP15]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop !0
+; CHECK:       middle.block:
+; CHECK-NEXT:    [[CMP_N:%.*]] = icmp eq i64 [[TMP0]], [[N_VEC]]
+; CHECK-NEXT:    br i1 [[CMP_N]], label [[FOR_COND_CLEANUP_LOOPEXIT:%.*]], label [[SCALAR_PH]]
+; CHECK:       scalar.ph:
+; CHECK-NEXT:    [[BC_RESUME_VAL:%.*]] = phi i64 [ [[N_VEC]], [[MIDDLE_BLOCK]] ], [ 0, [[FOR_BODY_PREHEADER]] ]
+; CHECK-NEXT:    br label [[FOR_BODY:%.*]]
+; CHECK:       for.cond.cleanup.loopexit:
+; CHECK-NEXT:    br label [[FOR_COND_CLEANUP]]
+; CHECK:       for.cond.cleanup:
+; CHECK-NEXT:    ret void
+; CHECK:       for.body:
+; CHECK-NEXT:    [[INDVARS_IV:%.*]] = phi i64 [ [[BC_RESUME_VAL]], [[SCALAR_PH]] ], [ [[INDVARS_IV_NEXT:%.*]], [[IF_END5:%.*]] ]
+; CHECK-NEXT:    [[CMP1:%.*]] = icmp ult i64 [[INDVARS_IV]], 495616
+; CHECK-NEXT:    br i1 [[CMP1]], label [[IF_END5]], label [[IF_ELSE:%.*]]
+; CHECK:       if.else:
+; CHECK-NEXT:    [[CMP2:%.*]] = icmp ult i64 [[INDVARS_IV]], 991232
+; CHECK-NEXT:    tail call void @llvm.assume(i1 [[CMP2]])
+; CHECK-NEXT:    br label [[IF_END5]]
+; CHECK:       if.end5:
+; CHECK-NEXT:    [[X_0:%.*]] = phi float [ 4.200000e+01, [[IF_ELSE]] ], [ 2.300000e+01, [[FOR_BODY]] ]
+; CHECK-NEXT:    [[ARRAYIDX:%.*]] = getelementptr inbounds float, float* [[A]], i64 [[INDVARS_IV]]
+; CHECK-NEXT:    [[TMP16:%.*]] = load float, float* [[ARRAYIDX]], align 4
+; CHECK-NEXT:    [[MUL:%.*]] = fmul float [[X_0]], [[TMP16]]
+; CHECK-NEXT:    [[ARRAYIDX7:%.*]] = getelementptr inbounds float, float* [[B]], i64 [[INDVARS_IV]]
+; CHECK-NEXT:    store float [[MUL]], float* [[ARRAYIDX7]], align 4
+; CHECK-NEXT:    [[INDVARS_IV_NEXT]] = add nuw nsw i64 [[INDVARS_IV]], 1
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i64 [[INDVARS_IV_NEXT]], [[TMP0]]
+; CHECK-NEXT:    br i1 [[CMP]], label [[FOR_COND_CLEANUP_LOOPEXIT]], label [[FOR_BODY]], !llvm.loop !2
+;
+entry:
+  %cmp15 = icmp eq i32 %n, 0
+  br i1 %cmp15, label %for.cond.cleanup, label %for.body.preheader
+
+for.body.preheader:                               ; preds = %entry
+  %0 = zext i32 %n to i64
+  br label %for.body
+
+for.cond.cleanup.loopexit:                        ; preds = %if.end5
+  br label %for.cond.cleanup
+
+for.cond.cleanup:                                 ; preds = %for.cond.cleanup.loopexit, %entry
+  ret void
+
+for.body:                                         ; preds = %for.body.preheader, %if.end5
+  %indvars.iv = phi i64 [ 0, %for.body.preheader ], [ %indvars.iv.next, %if.end5 ]
+  %cmp1 = icmp ult i64 %indvars.iv, 495616
+  br i1 %cmp1, label %if.end5, label %if.else
+
+if.else:                                          ; preds = %for.body
+  %cmp2 = icmp ult i64 %indvars.iv, 991232
+  tail call void @llvm.assume(i1 %cmp2)
+  br label %if.end5
+
+if.end5:                                          ; preds = %for.body, %if.else
+  %x.0 = phi float [ 4.200000e+01, %if.else ], [ 2.300000e+01, %for.body ]
+  %arrayidx = getelementptr inbounds float, float* %a, i64 %indvars.iv
+  %1 = load float, float* %arrayidx, align 4
+  %mul = fmul float %x.0, %1
+  %arrayidx7 = getelementptr inbounds float, float* %b, i64 %indvars.iv
+  store float %mul, float* %arrayidx7, align 4
+  %indvars.iv.next = add nuw nsw i64 %indvars.iv, 1
+  %cmp = icmp eq i64 %indvars.iv.next, %0
+  br i1 %cmp, label %for.cond.cleanup.loopexit, label %for.body
+}
+
+; Function Attrs: nounwind willreturn
+declare void @llvm.assume(i1)
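
For reference, the test function @foo roughly corresponds to the following C source, a hand-reconstructed sketch from the test IR rather than part of the patch; the function and parameter names mirror the IR, and __builtin_assume is the Clang spelling that lowers to the @llvm.assume intrinsic:

void foo(const float *restrict a, float *restrict b, unsigned n) {
  for (unsigned long long i = 0; i < n; i++) {
    float x;
    if (i < 495616) {
      x = 23.0f;                    /* 2.300000e+01 incoming from %for.body */
    } else {
      __builtin_assume(i < 991232); /* becomes the predicated @llvm.assume */
      x = 42.0f;                    /* 4.200000e+01 incoming from %if.else */
    }
    b[i] = x * a[i];                /* the fmul/store widened in vector.body */
  }
}

Before this patch, the @llvm.assume call in %if.else prevented the block from being predicated, so a loop of this shape was not vectorized. With the patch, the call is recorded in ConditionalAssumes during legality analysis, added to DeadInstructions when building the VPlan, and dropped when the CFG is flattened, as the checks on vector.body above verify.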