Index: llvm/lib/Target/ARM/MVETailPredication.cpp =================================================================== --- llvm/lib/Target/ARM/MVETailPredication.cpp +++ llvm/lib/Target/ARM/MVETailPredication.cpp @@ -373,15 +373,15 @@ EnableTailPredication == TailPredication::ForceEnabledNoReductions || EnableTailPredication == TailPredication::ForceEnabled; - // 1) Check that the original scalar loop TripCount (TC) belongs to this loop. - // The scalar tripcount corresponds the number of elements processed by the - // loop, so we will refer to that from this point on. Value *ElemCount = ActiveLaneMask->getOperand(1); auto *EC= SE->getSCEV(ElemCount); auto *TC = SE->getSCEV(TripCount); int VectorWidth = VecTy->getNumElements(); ConstantInt *ConstElemCount = nullptr; + // 1) Smoke tests that the original scalar loop TripCount (TC) belongs to + // this loop. The scalar tripcount corresponds to the number of elements + // processed by the loop, so we will refer to that from this point on. if (!SE->isLoopInvariant(EC, L)) { LLVM_DEBUG(dbgs() << "ARM TP: element count must be loop invariant.\n"); return false; @@ -405,6 +405,9 @@ // counting from 0. uint64_t TC2 = ConstElemCount->getZExtValue() + 1; + // If the tripcount values are inconsistent, we don't want to insert the + // VCTP and trigger tail-predication; it's better to keep intrinsic + // get.active.lane.mask and legalize this. if (TC1 != TC2) { LLVM_DEBUG(dbgs() << "ARM TP: inconsistent constant tripcount values: " << TC1 << " from set.loop.iterations, and " @@ -412,104 +415,53 @@ return false; } } else if (!ForceTailPredication) { - // Smoke tests if the element count is a runtime value. I.e., this isn't - // fully generic because that would require a full SCEV visitor here. It - // would require extracting the variable from the elementcount SCEV - // expression, and match this up with the tripcount SCEV expression. 
If - // this matches up, we know both expressions are bound by the same - // variable, and thus we know this tripcount belongs to this loop. The - // checks below will catch most cases though. - if (isa(EC) || isa(EC)) { - // If the element count is a simple AddExpr or SCEVUnknown, which is e.g. - // the case when the element count is just a variable %N, we can just see - // if it is an operand in the tripcount scev expression. - if (isa(TC) && !SE->hasOperand(TC, EC)) { - LLVM_DEBUG(dbgs() << "ARM TP: Can't verify the element counter\n"); - return false; - } - } else if (const SCEVAddRecExpr *AddRecExpr = dyn_cast(EC)) { - // For more complicated AddRecExpr, check that the corresponding loop and - // its loop hierarhy contains the trip count loop. - if (!AddRecExpr->getLoop()->contains(L)) { - LLVM_DEBUG(dbgs() << "ARM TP: Can't verify the element counter\n"); - return false; - } - } else { - LLVM_DEBUG(dbgs() << "ARM TP: Unsupported SCEV type, can't verify the " - "element counter\n"); + // 2) We need to prove that the sub expression that we create in the + // tail-predicated loop body, which calculates the remaining elements to be + // processed, is non-negative, i.e. it doesn't overflow: + // + // ((ElementCount + VectorWidth - 1) / VectorWidth) - TripCount >= 0 + // + // This is true if: + // + // TripCount == (ElementCount + VectorWidth - 1) / VectorWidth + // + // which is what we will be using here. 
+ + // ElementCount + (VW-1): + auto *ECPlusVWMinus1 = SE->getAddExpr(EC, + SE->getSCEV(ConstantInt::get(TripCount->getType(), VectorWidth - 1))); + + // Ceil = (ElementCount + (VW-1)) / VW + auto *Ceil = SE->getUDivExpr(ECPlusVWMinus1, + SE->getSCEV(ConstantInt::get(TripCount->getType(), VectorWidth))); + + LLVM_DEBUG( + dbgs() << "Analysing overflow behaviour for:\n"; + dbgs() << "ARM TP: - TripCount = "; TC->dump(); + dbgs() << "ARM TP: - ElemCount = "; EC->dump(); + dbgs() << "ARM TP: - VecWidth = " << VectorWidth << "\n"; + dbgs() << "ARM TP: - (ElemCount+VW-1) / VW = "; Ceil->dump(); + ); + + // As an example, almost all the tripcount expressions (produced by the + // vectoriser) look like this: + // + // TC = ((-4 + (4 * ((3 + %N) /u 4))) /u 4) + // + // and "(ElementCount + (VW-1)) / VW": + // + // Ceil = ((3 + %N) /u 4) + // + // To determine TC == Ceil, we look through the additional multiply/divide + // by 4 in the example above (VectorWidth) in TC and just check if Ceil is + // an operand in TC. If it isn't, we assume overflow can happen and bail: + // + if (!SE->hasOperand(SE->getBackedgeTakenCount(L), Ceil)) { + LLVM_DEBUG(dbgs() << "ARM TP: possible overflow in sub expression.\n"); return false; } } - // 2) Prove that the sub expression is non-negative, i.e. it doesn't overflow: - // - // (((ElementCount + (VectorWidth - 1)) / VectorWidth) - TripCount - // - // 2.1) First prove overflow can't happen in: - // - // ElementCount + (VectorWidth - 1) - // - // Because of a lack of context, it is difficult to get a useful bounds on - // this expression. 
But since ElementCount uses the same variables as the - // TripCount (TC), for which we can find meaningful value ranges, we use that - // instead and assert that: - // - // upperbound(TC) <= UINT_MAX - VectorWidth - // - unsigned SizeInBits = TripCount->getType()->getScalarSizeInBits(); - auto MaxMinusVW = APInt(SizeInBits, ~0) - APInt(SizeInBits, VectorWidth); - APInt UpperboundTC = SE->getUnsignedRangeMax(TC); - - if (UpperboundTC.ugt(MaxMinusVW) && !ForceTailPredication) { - LLVM_DEBUG(dbgs() << "ARM TP: Overflow possible in tripcount rounding:\n"; - dbgs() << "upperbound(TC) <= UINT_MAX - VectorWidth\n"; - dbgs() << UpperboundTC << " <= " << MaxMinusVW << " == false\n";); - return false; - } - - // 2.2) Make sure overflow doesn't happen in final expression: - // (((ElementCount + (VectorWidth - 1)) / VectorWidth) - TripCount, - // To do this, compare the full ranges of these subexpressions: - // - // Range(Ceil) <= Range(TC) - // - // where Ceil = ElementCount + (VW-1) / VW. If Ceil and TC are runtime - // values (and not constants), we have to compensate for the lowerbound value - // range to be off by 1. The reason is that the TC lives in the preheader in - // this form: - // - // %trip.count.minus = add nsw nuw i32 %N, -1 - // - // For the loop to be executed, %N has to be >= 1 and as a result the value - // range of %trip.count.minus has a lower bound of 0. Value %TC has this form: - // - // %5 = add nuw nsw i32 %4, 1 - // call void @llvm.set.loop.iterations.i32(i32 %5) - // - // where %5 is some expression using %N, which needs to have a lower bound of - // 1. Thus, if the ranges of Ceil and TC are not a single constant but a set, - // we first add 0 to TC such that we can do the <= comparison on both sets. 
- // - - // Tmp = ElementCount + (VW-1) - auto *ECPlusVWMinus1 = SE->getAddExpr(EC, - SE->getSCEV(ConstantInt::get(TripCount->getType(), VectorWidth - 1))); - // Ceil = ElementCount + (VW-1) / VW - auto *Ceil = SE->getUDivExpr(ECPlusVWMinus1, - SE->getSCEV(ConstantInt::get(TripCount->getType(), VectorWidth))); - - ConstantRange RangeCeil = SE->getUnsignedRange(Ceil) ; - ConstantRange RangeTC = SE->getUnsignedRange(TC) ; - if (!RangeTC.isSingleElement()) { - auto ZeroRange = - ConstantRange(APInt(TripCount->getType()->getScalarSizeInBits(), 0)); - RangeTC = RangeTC.unionWith(ZeroRange); - } - if (!RangeTC.contains(RangeCeil) && !ForceTailPredication) { - LLVM_DEBUG(dbgs() << "ARM TP: Overflow possible in sub\n"); - return false; - } - // 3) Find out if IV is an induction phi. Note that we can't use Loop // helpers here to get the induction variable, because the hardware loop is // no longer in loopsimplify form, and also the hwloop intrinsic uses a Index: llvm/test/CodeGen/Thumb2/LowOverheadLoops/tail-pred-basic.ll =================================================================== --- llvm/test/CodeGen/Thumb2/LowOverheadLoops/tail-pred-basic.ll +++ llvm/test/CodeGen/Thumb2/LowOverheadLoops/tail-pred-basic.ll @@ -478,96 +478,6 @@ ret void } -; CHECK-LABEL: wrong_tripcount_arg -; CHECK: vector.body: -; CHECK: call <4 x i1> @llvm.arm.mve.vctp32 -; CHECK-NOT: call <4 x i1> @llvm.get.active.lane.mask.v4i1.i32 -; CHECK: vector.body35: -; CHECK: call <4 x i1> @llvm.get.active.lane.mask.v4i1.i32 -; CHECK-NOT: call <4 x i1> @llvm.arm.mve.vctp32 -; CHECK: ret void -; -define dso_local void @wrong_tripcount_arg(i32* noalias nocapture %A, i32* noalias nocapture readonly %B, i32* noalias nocapture readonly %C, i32* noalias nocapture %D, i32 %N1, i32 %N2) local_unnamed_addr #0 { -entry: - %cmp29 = icmp sgt i32 %N1, 0 - %0 = add i32 %N1, 3 - %1 = lshr i32 %0, 2 - %2 = shl nuw i32 %1, 2 - %3 = add i32 %2, -4 - %4 = lshr i32 %3, 2 - %5 = add nuw nsw i32 %4, 1 - br i1 %cmp29, label 
%vector.ph, label %for.cond4.preheader - -vector.ph: ; preds = %entry - call void @llvm.set.loop.iterations.i32(i32 %5) - br label %vector.body - -vector.body: ; preds = %vector.body, %vector.ph - %lsr.iv62 = phi i32* [ %scevgep63, %vector.body ], [ %D, %vector.ph ] - %lsr.iv59 = phi i32* [ %scevgep60, %vector.body ], [ %C, %vector.ph ] - %lsr.iv56 = phi i32* [ %scevgep57, %vector.body ], [ %B, %vector.ph ] - %index = phi i32 [ 0, %vector.ph ], [ %index.next, %vector.body ] - %6 = phi i32 [ %5, %vector.ph ], [ %8, %vector.body ] - %lsr.iv5658 = bitcast i32* %lsr.iv56 to <4 x i32>* - %lsr.iv5961 = bitcast i32* %lsr.iv59 to <4 x i32>* - %lsr.iv6264 = bitcast i32* %lsr.iv62 to <4 x i32>* - %active.lane.mask = call <4 x i1> @llvm.get.active.lane.mask.v4i1.i32(i32 %index, i32 %N1) - %wide.masked.load = call <4 x i32> @llvm.masked.load.v4i32.p0v4i32(<4 x i32>* %lsr.iv5658, i32 4, <4 x i1> %active.lane.mask, <4 x i32> undef) - %wide.masked.load32 = call <4 x i32> @llvm.masked.load.v4i32.p0v4i32(<4 x i32>* %lsr.iv5961, i32 4, <4 x i1> %active.lane.mask, <4 x i32> undef) - %7 = add nsw <4 x i32> %wide.masked.load32, %wide.masked.load - call void @llvm.masked.store.v4i32.p0v4i32(<4 x i32> %7, <4 x i32>* %lsr.iv6264, i32 4, <4 x i1> %active.lane.mask) - %index.next = add i32 %index, 4 - %scevgep57 = getelementptr i32, i32* %lsr.iv56, i32 4 - %scevgep60 = getelementptr i32, i32* %lsr.iv59, i32 4 - %scevgep63 = getelementptr i32, i32* %lsr.iv62, i32 4 - %8 = call i32 @llvm.loop.decrement.reg.i32(i32 %6, i32 1) - %9 = icmp ne i32 %8, 0 - br i1 %9, label %vector.body, label %for.cond4.preheader - -for.cond4.preheader: ; preds = %vector.body, %entry - %cmp527 = icmp sgt i32 %N2, 0 - %10 = add i32 %N2, 3 - %11 = lshr i32 %10, 2 - %12 = shl nuw i32 %11, 2 - %13 = add i32 %12, -4 - %14 = lshr i32 %13, 2 - %15 = add nuw nsw i32 %14, 1 - br i1 %cmp527, label %vector.ph36, label %for.cond.cleanup6 - -vector.ph36: ; preds = %for.cond4.preheader - call void 
@llvm.set.loop.iterations.i32(i32 %15) - br label %vector.body35 - -vector.body35: ; preds = %vector.body35, %vector.ph36 - %lsr.iv53 = phi i32* [ %scevgep54, %vector.body35 ], [ %A, %vector.ph36 ] - %lsr.iv50 = phi i32* [ %scevgep51, %vector.body35 ], [ %C, %vector.ph36 ] - %lsr.iv = phi i32* [ %scevgep, %vector.body35 ], [ %B, %vector.ph36 ] - %index40 = phi i32 [ 0, %vector.ph36 ], [ %index.next41, %vector.body35 ] - %16 = phi i32 [ %15, %vector.ph36 ], [ %18, %vector.body35 ] - %lsr.iv49 = bitcast i32* %lsr.iv to <4 x i32>* - %lsr.iv5052 = bitcast i32* %lsr.iv50 to <4 x i32>* - %lsr.iv5355 = bitcast i32* %lsr.iv53 to <4 x i32>* - -; This has N1 as the tripcount / element count, which is the tripcount of the -; first loop and not this one: - %active.lane.mask46 = call <4 x i1> @llvm.get.active.lane.mask.v4i1.i32(i32 %index40, i32 %N1) - - %wide.masked.load47 = call <4 x i32> @llvm.masked.load.v4i32.p0v4i32(<4 x i32>* %lsr.iv49, i32 4, <4 x i1> %active.lane.mask46, <4 x i32> undef) - %wide.masked.load48 = call <4 x i32> @llvm.masked.load.v4i32.p0v4i32(<4 x i32>* %lsr.iv5052, i32 4, <4 x i1> %active.lane.mask46, <4 x i32> undef) - %17 = add nsw <4 x i32> %wide.masked.load48, %wide.masked.load47 - call void @llvm.masked.store.v4i32.p0v4i32(<4 x i32> %17, <4 x i32>* %lsr.iv5355, i32 4, <4 x i1> %active.lane.mask46) - %index.next41 = add i32 %index40, 4 - %scevgep = getelementptr i32, i32* %lsr.iv, i32 4 - %scevgep51 = getelementptr i32, i32* %lsr.iv50, i32 4 - %scevgep54 = getelementptr i32, i32* %lsr.iv53, i32 4 - %18 = call i32 @llvm.loop.decrement.reg.i32(i32 %16, i32 1) - %19 = icmp ne i32 %18, 0 - br i1 %19, label %vector.body35, label %for.cond.cleanup6 - -for.cond.cleanup6: ; preds = %vector.body35, %for.cond4.preheader - ret void -} - ; CHECK-LABEL: tripcount_arg_not_invariant ; CHECK: call <4 x i1> @llvm.get.active.lane.mask ; CHECK-NOT: vctp Index: llvm/test/CodeGen/Thumb2/LowOverheadLoops/tail-pred-forced.ll 
=================================================================== --- llvm/test/CodeGen/Thumb2/LowOverheadLoops/tail-pred-forced.ll +++ /dev/null @@ -1,61 +0,0 @@ -; RUN: opt -mtriple=thumbv8.1m.main -mve-tail-predication -tail-predication=enabled -mattr=+mve,+lob %s -S -o - | FileCheck %s --check-prefixes=CHECK,ENABLED -; RUN: opt -mtriple=thumbv8.1m.main -mve-tail-predication -tail-predication=force-enabled -mattr=+mve,+lob %s -S -o - | FileCheck %s --check-prefixes=CHECK,FORCED - -; CHECK-LABEL: set_iterations_not_rounded_up -; -; ENABLED: call <4 x i1> @llvm.get.active.lane.mask -; ENABLED-NOT: vctp -; -; FORCED-NOT: call <4 x i1> @llvm.get.active.lane.mask -; FORCED: vctp -; -; CHECK: ret void -; -define dso_local void @set_iterations_not_rounded_up(i32* noalias nocapture %A, i32* noalias nocapture readonly %B, i32* noalias nocapture readonly %C, i32 %N) local_unnamed_addr #0 { -entry: - %cmp8 = icmp sgt i32 %N, 0 - -; Here, v5 which is used in set.loop.iterations which is usually rounded up to -; a next multiple of the VF when emitted from the vectoriser, which means a -; bound can be put on this expression. Without this, we can't, and should flag -; this as potentially overflow behaviour. 
- - %v5 = add nuw nsw i32 %N, 1 - br i1 %cmp8, label %vector.ph, label %for.cond.cleanup - -vector.ph: ; preds = %entry - %trip.count.minus.1 = add i32 %N, -1 - call void @llvm.set.loop.iterations.i32(i32 %v5) - br label %vector.body - -vector.body: ; preds = %vector.body, %vector.ph - %lsr.iv17 = phi i32* [ %scevgep18, %vector.body ], [ %A, %vector.ph ] - %lsr.iv14 = phi i32* [ %scevgep15, %vector.body ], [ %C, %vector.ph ] - %lsr.iv = phi i32* [ %scevgep, %vector.body ], [ %B, %vector.ph ] - %index = phi i32 [ 0, %vector.ph ], [ %index.next, %vector.body ] - %v6 = phi i32 [ %v5, %vector.ph ], [ %v8, %vector.body ] - %lsr.iv13 = bitcast i32* %lsr.iv to <4 x i32>* - %lsr.iv1416 = bitcast i32* %lsr.iv14 to <4 x i32>* - %lsr.iv1719 = bitcast i32* %lsr.iv17 to <4 x i32>* - %active.lane.mask = call <4 x i1> @llvm.get.active.lane.mask.v4i1.i32(i32 %index, i32 %N) - %wide.masked.load = call <4 x i32> @llvm.masked.load.v4i32.p0v4i32(<4 x i32>* %lsr.iv13, i32 4, <4 x i1> %active.lane.mask, <4 x i32> undef) - %wide.masked.load12 = call <4 x i32> @llvm.masked.load.v4i32.p0v4i32(<4 x i32>* %lsr.iv1416, i32 4, <4 x i1> %active.lane.mask, <4 x i32> undef) - %v7 = add nsw <4 x i32> %wide.masked.load12, %wide.masked.load - call void @llvm.masked.store.v4i32.p0v4i32(<4 x i32> %v7, <4 x i32>* %lsr.iv1719, i32 4, <4 x i1> %active.lane.mask) - %index.next = add i32 %index, 4 - %scevgep = getelementptr i32, i32* %lsr.iv, i32 4 - %scevgep15 = getelementptr i32, i32* %lsr.iv14, i32 4 - %scevgep18 = getelementptr i32, i32* %lsr.iv17, i32 4 - %v8 = call i32 @llvm.loop.decrement.reg.i32(i32 %v6, i32 1) - %v9 = icmp ne i32 %v8, 0 - br i1 %v9, label %vector.body, label %for.cond.cleanup - -for.cond.cleanup: ; preds = %vector.body, %entry - ret void -} - -declare <4 x i32> @llvm.masked.load.v4i32.p0v4i32(<4 x i32>*, i32 immarg, <4 x i1>, <4 x i32>) -declare void @llvm.masked.store.v4i32.p0v4i32(<4 x i32>, <4 x i32>*, i32 immarg, <4 x i1>) -declare void @llvm.set.loop.iterations.i32(i32) 
-declare i32 @llvm.loop.decrement.reg.i32(i32, i32) -declare <4 x i1> @llvm.get.active.lane.mask.v4i1.i32(i32, i32)