diff --git a/llvm/include/llvm/Transforms/Vectorize/LoopVectorizationLegality.h b/llvm/include/llvm/Transforms/Vectorize/LoopVectorizationLegality.h --- a/llvm/include/llvm/Transforms/Vectorize/LoopVectorizationLegality.h +++ b/llvm/include/llvm/Transforms/Vectorize/LoopVectorizationLegality.h @@ -312,6 +312,12 @@ // Returns true if the NoNaN attribute is set on the function. bool hasFunNoNaNAttr() const { return HasFunNoNaNAttr; } + /// Returns all assume calls in predicated blocks. They need to be dropped + /// when flattening the CFG. + const SmallPtrSetImpl<Instruction *> &getConditionalAssumes() const { + return ConditionalAssumes; + } + private: /// Return true if the pre-header, exiting and latch blocks of \p Lp and all /// its nested loops are considered legal for vectorization. These legal @@ -468,6 +474,10 @@ /// While vectorizing these instructions we have to generate a /// call to the appropriate masked intrinsic SmallPtrSet<const Instruction *, 8> MaskedOp; + + /// Assume instructions in predicated blocks must be dropped if the CFG gets + /// flattened. + SmallPtrSet<Instruction *, 8> ConditionalAssumes; }; } // namespace llvm diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp --- a/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp +++ b/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp @@ -13,14 +13,16 @@ // pass. It should be easy to create an analysis pass around it if there // is a need (but D45420 needs to happen first). 
// -#include "llvm/Transforms/Vectorize/LoopVectorize.h" #include "llvm/Transforms/Vectorize/LoopVectorizationLegality.h" #include "llvm/Analysis/Loads.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/Analysis/VectorUtils.h" #include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/PatternMatch.h" +#include "llvm/Transforms/Vectorize/LoopVectorize.h" using namespace llvm; +using namespace PatternMatch; #define LV_NAME "loop-vectorize" #define DEBUG_TYPE LV_NAME @@ -897,6 +899,14 @@ if (C->canTrap()) return false; } + + // We can predicate blocks with calls to assume, as long as we drop them in + // case we flatten the CFG via predication. + if (match(&I, m_Intrinsic<Intrinsic::assume>())) { + ConditionalAssumes.insert(&I); + continue; + } + // We might be able to hoist the load. if (I.mayReadFromMemory()) { auto *LI = dyn_cast<LoadInst>(&I); diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp --- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp +++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp @@ -7176,6 +7176,13 @@ // visit each basic block after having visited its predecessor basic blocks. // --------------------------------------------------------------------------- + // Add assume instructions we need to drop to DeadInstructions, to prevent + // them from being added to the VPlan. + // TODO: We only need to drop assumes in blocks that get flattened. If the + // control flow is preserved, we should keep them. + auto &ConditionalAssumes = Legal->getConditionalAssumes(); + DeadInstructions.insert(ConditionalAssumes.begin(), ConditionalAssumes.end()); + // Create a dummy pre-entry VPBasicBlock to start building the VPlan. 
VPBasicBlock *VPBB = new VPBasicBlock("Pre-Entry"); auto Plan = std::make_unique<VPlan>(VPBB); diff --git a/llvm/test/Transforms/LoopVectorize/assume.ll b/llvm/test/Transforms/LoopVectorize/assume.ll --- a/llvm/test/Transforms/LoopVectorize/assume.ll +++ b/llvm/test/Transforms/LoopVectorize/assume.ll @@ -91,3 +91,51 @@ for.end: ; preds = %for.body ret void } + +; Test case for PR43620. Make sure we can vectorize with predication in presence +; of assume calls. For now, check that we drop all assumes in predicated blocks +; in the vector body. +define void @predicated_assume(float* noalias nocapture readonly %a, float* noalias nocapture %b, i32 %n) { + +; Check that the vector.body does not contain any assumes. + +; CHECK-LABEL: @predicated_assume( +; CHECK: vector.body: +; CHECK-NOT: call void @llvm.assume( +; CHECK: scalar.ph: +; +entry: + %cmp15 = icmp eq i32 %n, 0 + br i1 %cmp15, label %for.cond.cleanup, label %for.body.preheader + +for.body.preheader: ; preds = %entry + %0 = zext i32 %n to i64 + br label %for.body + +for.cond.cleanup.loopexit: ; preds = %if.end5 + br label %for.cond.cleanup + +for.cond.cleanup: ; preds = %for.cond.cleanup.loopexit, %entry + ret void + +for.body: ; preds = %for.body.preheader, %if.end5 + %indvars.iv = phi i64 [ 0, %for.body.preheader ], [ %indvars.iv.next, %if.end5 ] + %cmp1 = icmp ult i64 %indvars.iv, 495616 + br i1 %cmp1, label %if.end5, label %if.else + +if.else: ; preds = %for.body + %cmp2 = icmp ult i64 %indvars.iv, 991232 + tail call void @llvm.assume(i1 %cmp2) + br label %if.end5 + +if.end5: ; preds = %for.body, %if.else + %x.0 = phi float [ 4.200000e+01, %if.else ], [ 2.300000e+01, %for.body ] + %arrayidx = getelementptr inbounds float, float* %a, i64 %indvars.iv + %1 = load float, float* %arrayidx, align 4 + %mul = fmul float %x.0, %1 + %arrayidx7 = getelementptr inbounds float, float* %b, i64 %indvars.iv + store float %mul, float* %arrayidx7, align 4 + %indvars.iv.next = add nuw nsw i64 %indvars.iv, 1 + %cmp = icmp eq i64 
%indvars.iv.next, %0 + br i1 %cmp, label %for.cond.cleanup.loopexit, label %for.body +}