Index: llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
===================================================================
--- llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
+++ llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
@@ -1000,6 +1000,139 @@
   return true;
 }
 
+static bool canTailPredicateInstruction(Instruction &I) {
+  // NOTE(review): the template arguments of these checks were lost when
+  // this patch was mangled; the instruction kinds below are a plausible
+  // reconstruction (reject widening/narrowing casts and lane shuffles) -
+  // confirm against the original revision.
+  if (isa<TruncInst>(&I) ||
+      isa<ZExtInst>(&I) ||
+      isa<SExtInst>(&I) ||
+      isa<ShuffleVectorInst>(&I)) {
+    LLVM_DEBUG(dbgs() << "tail-predication, not allowing instruction: ";
+               I.dump());
+    return false;
+  }
+  return true;
+}
+
+// Collect the loop control instructions, i.e. the induction variable, its
+// increment, and the loop-exit compare, into \p LoopControl so they can be
+// exempted from the tail-predication legality checks. Returns false if no
+// canonical induction variable can be found.
+static bool getLoopControlInstructions(SmallSet<Instruction *, 4> &LoopControl,
+                                       Loop *L, ScalarEvolution &SE) {
+  PHINode *IndVar = L->getInductionVariable(SE);
+  if (!IndVar)
+    return false;
+
+  LLVM_DEBUG(dbgs() << "found ind var: "; IndVar->dump());
+  LoopControl.insert(IndVar);
+
+  for (auto *U : IndVar->users()) {
+    Instruction *I = dyn_cast<Instruction>(U);
+    if (!I)
+      continue;
+
+    // Now we simply look for the pattern (only the opcodes):
+    //
+    //   %inc = add nuw nsw i32 %i, 1
+    //   %exitcond = icmp eq i32 %inc, %N
+    //   br i1 %exitcond, label
+    //
+    // If we get this wrong somehow, we will probably find that the datatypes
+    // used in the loop are not uniform and the loop is rejected for
+    // tail-predication, which is safe.
+    if (I->getOpcode() == Instruction::Add) {
+      LLVM_DEBUG(dbgs() << "found loop incr: "; U->dump());
+      LoopControl.insert(I);
+      // Guard against an increment without users: dereferencing user_begin()
+      // on an empty use-list is undefined behaviour.
+      auto *Cmp = I->user_empty() ? nullptr
+                                  : dyn_cast<ICmpInst>(*I->user_begin());
+      if (Cmp) {
+        LLVM_DEBUG(dbgs() << "Found loop test"; Cmp->dump());
+        LoopControl.insert(Cmp);
+      }
+    }
+  }
+  return true;
+}
+
+// Placeholder check that all loads/stores in \p LoadStores are consecutive,
+// so that tail-predication masks/predicates the right lanes.
+static bool consecutiveLoadStores(SmallVector<Instruction *, 4> &LoadStores) {
+  // TODO
+  return true;
+}
+
+// To set up a tail-predicated loop, we need to know the total number of
+// elements processed by that loop. Thus, we need to determine the element
+// size and:
+// 1) it should be uniform for all operations in the vector loop, so we
+//    e.g. don't want any widening/narrowing operations.
+// 2) it should be smaller than i64s because we don't have vector operations
+//    that work on i64s.
+// 3) we don't want elements to be reversed or shuffled, to make sure the
+//    tail-predication masks/predicates the right lanes.
+//
+static bool canTailPredicateLoop(Loop *L, LoopInfo *LI, ScalarEvolution &SE) {
+  SmallSet<Instruction *, 4> LoopControl;
+  // Expect exactly the induction phi, its increment and the exit compare.
+  if (!getLoopControlInstructions(LoopControl, L, SE) ||
+      LoopControl.size() != 3)
+    return false;
+
+  LLVM_DEBUG(dbgs() << "tail-predication: check allowed instruction\n");
+
+  SmallSet<Type *, 4> Types;
+  SmallVector<Instruction *, 4> LoadStores;
+  for (BasicBlock *BB : L->blocks()) {
+    for (Instruction &I : BB->instructionsWithoutDebug()) {
+      // NOTE(review): the template argument of this isa<> was also lost;
+      // BranchInst (void-typed, always present) is a reconstruction.
+      if (isa<BranchInst>(&I) || LoopControl.count(&I))
+        continue;
+      if (!canTailPredicateInstruction(I))
+        return false;
+
+      Type *T = I.getType();
+      if (T->isPointerTy())
+        T = T->getPointerElementType();
+
+      if (T->getScalarSizeInBits() > 32) {
+        LLVM_DEBUG(dbgs() << "Unsupported Type: "; T->dump());
+        return false;
+      }
+
+      // Collect all types
+      if (!T->isVoidTy()) {
+        LLVM_DEBUG(dbgs() << "Adding type: "; T->dump());
+        Types.insert(T);
+      }
+
+      // And collect all loads and stores
+      if (isa<LoadInst>(I) || isa<StoreInst>(I))
+        LoadStores.push_back(&I);
+    }
+  }
+
+  if (Types.size() > 1) {
+    LLVM_DEBUG(dbgs() << "Need uniform element size, but found too many types: "
+                      << Types.size() << "\n");
+    return false;
+  }
+
+  LLVM_DEBUG(dbgs() << "Number of loads/stores to analyse: "
+                    << LoadStores.size() << "\n");
+
+  if (!consecutiveLoadStores(LoadStores))
+    return false;
+
+  LLVM_DEBUG(dbgs() << "tail-predication: all instructions allowed!\n");
+  return true;
+}
+
 bool ARMTTIImpl::preferPredicateOverEpilogue(Loop *L, LoopInfo *LI,
                                              ScalarEvolution &SE,
                                              AssumptionCache &AC,
@@ -1032,14 +1165,7 @@
     return false;
   }
 
-  // TODO: to set up a tail-predicated loop, which works by setting up
-  // the total number of elements processed by the loop, we need to
-  // determine the element size here, and if it is uniform for all operations
-  // in the vector loop. This means we will reject narrowing/widening
-  // operations, and don't want to predicate the vector loop, which is
-  // the main prep step for tail-predicated loops.
-
-  return false;
+  return canTailPredicateLoop(L, LI, SE);
 }
 
 
Index: llvm/test/Transforms/LoopVectorize/ARM/prefer-tail-loop-folding.ll
===================================================================
--- llvm/test/Transforms/LoopVectorize/ARM/prefer-tail-loop-folding.ll
+++ llvm/test/Transforms/LoopVectorize/ARM/prefer-tail-loop-folding.ll
@@ -47,3 +47,152 @@
   %exitcond = icmp eq i64 %indvars.iv.next, 430
   br i1 %exitcond, label %for.cond.cleanup, label %for.body
 }
+
+define dso_local void @prefer_folding(i32* noalias nocapture %A, i32* noalias nocapture readonly %B, i32* noalias nocapture readonly %C, i32 %N) local_unnamed_addr #0 {
+; CHECK-LABEL: prefer_folding(
+; PREFER-FOLDING: vector.body:
+; PREFER-FOLDING: call <4 x i32> @llvm.masked.load.v4i32.p0v4i32
+; PREFER-FOLDING: call <4 x i32> @llvm.masked.load.v4i32.p0v4i32
+; PREFER-FOLDING: call void @llvm.masked.store.v4i32.p0v4i32
+; PREFER-FOLDING: br i1 %{{.*}}, label %{{.*}}, label %vector.body
+entry:
+  br label %for.body
+
+for.cond.cleanup:
+  ret void
+
+for.body:
+  %i.09 = phi i32 [ 0, %entry ], [ %add3, %for.body ]
+  %arrayidx = getelementptr inbounds i32, i32* %B, i32 %i.09
+  %0 = load i32, i32* %arrayidx, align 4
+  %arrayidx1 = getelementptr inbounds i32, i32* %C, i32 %i.09
+  %1 = load i32, i32* %arrayidx1, align 4
+  %add = add nsw i32 %1, %0
+  %arrayidx2 = getelementptr inbounds i32, i32* %A, i32 %i.09
+  store i32 %add, i32* %arrayidx2, align 4
+  %add3 = add nuw nsw i32 %i.09, 1
+  %exitcond = icmp eq i32 %add3, 431
+  br i1 %exitcond, label %for.cond.cleanup, label %for.body
+}
+
+define dso_local void @mixed_types(i16* noalias nocapture %A, i16* noalias nocapture readonly %B, i16* noalias nocapture readonly %C, i32 %N, i32* noalias nocapture %D, i32* noalias nocapture readonly %E, i32* noalias nocapture readonly %F) local_unnamed_addr #0 {
+; CHECK-LABEL: mixed_types(
+; PREFER-FOLDING: vector.body:
+; PREFER-FOLDING-NOT: llvm.masked.load
+; PREFER-FOLDING-NOT: llvm.masked.store
+; PREFER-FOLDING: br i1 %{{.*}}, label %{{.*}}, label %vector.body
+entry:
+  br label %for.body
+
+for.cond.cleanup:
+  ret void
+
+for.body:
+  %i.018 = phi i32 [ 0, %entry ], [ %add9, %for.body ]
+  %arrayidx = getelementptr inbounds i16, i16* %B, i32 %i.018
+  %0 = load i16, i16* %arrayidx, align 2
+  %arrayidx1 = getelementptr inbounds i16, i16* %C, i32 %i.018
+  %1 = load i16, i16* %arrayidx1, align 2
+  %add = add i16 %1, %0
+  %arrayidx4 = getelementptr inbounds i16, i16* %A, i32 %i.018
+  store i16 %add, i16* %arrayidx4, align 2
+  %arrayidx5 = getelementptr inbounds i32, i32* %E, i32 %i.018
+  %2 = load i32, i32* %arrayidx5, align 4
+  %arrayidx6 = getelementptr inbounds i32, i32* %F, i32 %i.018
+  %3 = load i32, i32* %arrayidx6, align 4
+  %add7 = add nsw i32 %3, %2
+  %arrayidx8 = getelementptr inbounds i32, i32* %D, i32 %i.018
+  store i32 %add7, i32* %arrayidx8, align 4
+  %add9 = add nuw nsw i32 %i.018, 1
+  %exitcond = icmp eq i32 %add9, 431
+  br i1 %exitcond, label %for.cond.cleanup, label %for.body
+}
+
+define hidden void @unsupported_i64_type(i64* noalias nocapture %A, i64* noalias nocapture readonly %B, i64* noalias nocapture readonly %C, i32 %N) local_unnamed_addr #0 {
+; CHECK-LABEL: unsupported_i64_type(
+; PREFER-FOLDING-NOT: vector.body:
+; PREFER-FOLDING-NOT: llvm.masked.load
+; PREFER-FOLDING-NOT: llvm.masked.store
+; PREFER-FOLDING: for.body:
+entry:
+  br label %for.body
+
+for.cond.cleanup:
+  ret void
+
+for.body:
+  %i.09 = phi i32 [ 0, %entry ], [ %add3, %for.body ]
+  %arrayidx = getelementptr inbounds i64, i64* %B, i32 %i.09
+  %0 = load i64, i64* %arrayidx, align 8
+  %arrayidx1 = getelementptr inbounds i64, i64* %C, i32 %i.09
+  %1 = load i64, i64* %arrayidx1, align 8
+  %add = add nsw i64 %1, %0
+  %arrayidx2 = getelementptr inbounds i64, i64* %A, i32 %i.09
+  store i64 %add, i64* %arrayidx2, align 8
+  %add3 = add nuw nsw i32 %i.09, 1
+  %exitcond = icmp eq i32 %add3, 431
+  br i1 %exitcond, label %for.cond.cleanup, label %for.body
+}
+
+define dso_local void @non_uniform_elem_size(i32* noalias nocapture %A, i8* noalias nocapture readonly %B, i32* noalias nocapture readonly %C, i32 %N) local_unnamed_addr #0 {
+; CHECK-LABEL: non_uniform_elem_size(
+; PREFER-FOLDING: vector.body:
+; PREFER-FOLDING-NOT: llvm.masked.load
+; PREFER-FOLDING-NOT: llvm.masked.store
+; PREFER-FOLDING: br i1 %{{.*}}, label %{{.*}}, label %vector.body
+entry:
+  br label %for.body
+
+for.cond.cleanup:
+  ret void
+
+for.body:
+  %i.09 = phi i32 [ 0, %entry ], [ %add3, %for.body ]
+  %arrayidx = getelementptr inbounds i8, i8* %B, i32 %i.09
+  %0 = load i8, i8* %arrayidx, align 1
+  %conv = zext i8 %0 to i32
+  %arrayidx1 = getelementptr inbounds i32, i32* %C, i32 %i.09
+  %1 = load i32, i32* %arrayidx1, align 4
+  %add = add nsw i32 %1, %conv
+  %arrayidx2 = getelementptr inbounds i32, i32* %A, i32 %i.09
+  store i32 %add, i32* %arrayidx2, align 4
+  %add3 = add nuw nsw i32 %i.09, 1
+  %exitcond = icmp eq i32 %add3, 431
+  br i1 %exitcond, label %for.cond.cleanup, label %for.body
+}
+
+define dso_local void @pragma_vect_predicate_disable(i32* noalias nocapture %A, i32* noalias nocapture readonly %B, i32* noalias nocapture readonly %C, i32 %N) local_unnamed_addr #0 {
+; CHECK-LABEL: pragma_vect_predicate_disable(
+;
+; FIXME:
+; respect loop hint predicate.enable = false,
+; and don't tail-fold here:
+;
+; PREFER-FOLDING: call <4 x i32> @llvm.masked.load.v4i32.p0v4i32
+; PREFER-FOLDING: call <4 x i32> @llvm.masked.load.v4i32.p0v4i32
+; PREFER-FOLDING: call void @llvm.masked.store.v4i32.p0v4i32
+; PREFER-FOLDING: br i1 %{{.*}}, label %{{.*}}, label %vector.body
+entry:
+  br label %for.body
+
+for.cond.cleanup:
+  ret void
+
+for.body:
+  %i.09 = phi i32 [ 0, %entry ], [ %add3, %for.body ]
+  %arrayidx = getelementptr inbounds i32, i32* %B, i32 %i.09
+  %0 = load i32, i32* %arrayidx, align 4
+  %arrayidx1 = getelementptr inbounds i32, i32* %C, i32 %i.09
+  %1 = load i32, i32* %arrayidx1, align 4
+  %add = add nsw i32 %1, %0
+  %arrayidx2 = getelementptr inbounds i32, i32* %A, i32 %i.09
+  store i32 %add, i32* %arrayidx2, align 4
+  %add3 = add nuw nsw i32 %i.09, 1
+  %exitcond = icmp eq i32 %add3, 431
+  br i1 %exitcond, label %for.cond.cleanup, label %for.body, !llvm.loop !7
+}
+
+attributes #0 = { nofree norecurse nounwind "target-features"="+armv8.1-m.main,+mve.fp" }
+
+!7 = distinct !{!7, !8}
+!8 = !{!"llvm.loop.vectorize.predicate.enable", i1 false}