diff --git a/llvm/lib/Target/ARM/MVETailPredication.cpp b/llvm/lib/Target/ARM/MVETailPredication.cpp --- a/llvm/lib/Target/ARM/MVETailPredication.cpp +++ b/llvm/lib/Target/ARM/MVETailPredication.cpp @@ -119,10 +119,10 @@ /// load/stores. bool IsPredicatedVectorLoop(); - /// Perform checks on the arguments of @llvm.get.active.lane.mask - /// intrinsic: check if the first is a loop induction variable, and for the - /// the second check that no overflow can occur in the expression that use - /// this backedge-taken count. + /// Perform several checks on the arguments of @llvm.get.active.lane.mask + /// intrinsic. E.g., check that the loop induction variable and the element + /// count are of the form we expect, and also perform overflow checks for + /// the new expressions that are created. bool IsSafeActiveMask(IntrinsicInst *ActiveLaneMask, Value *TripCount, FixedVectorType *VecTy); @@ -373,10 +373,73 @@ EnableTailPredication == TailPredication::ForceEnabledNoReductions || EnableTailPredication == TailPredication::ForceEnabled; - // 1) TODO: Check that the TripCount (TC) belongs to this loop (originally). + // 1) Check that the original scalar loop TripCount (TC) belongs to this loop. // The scalar tripcount corresponds the number of elements processed by the // loop, so we will refer to that from this point on. 
- auto *ElemCountVal = ActiveLaneMask->getOperand(1); + Value *ElemCount = ActiveLaneMask->getOperand(1); + auto *EC= SE->getSCEV(ElemCount); + auto *TC = SE->getSCEV(TripCount); + int VectorWidth = VecTy->getNumElements(); + ConstantInt *ConstElemCount = nullptr; + + if (!SE->isLoopInvariant(EC, L)) { + LLVM_DEBUG(dbgs() << "ARM TP: element count must be loop invariant.\n"); + return false; + } + + if ((ConstElemCount = dyn_cast(ElemCount))) { + ConstantInt *TC = dyn_cast(TripCount); + if (!TC) { + LLVM_DEBUG(dbgs() << "ARM TP: Constant tripcount expected in " + "set.loop.iterations\n"); + return false; + } + + // Calculate 2 tripcount values and check that they are consistent with + // each other: + // i) The number of loop iterations extracted from the set.loop.iterations + // intrinsic, multiplied by the vector width: + uint64_t TC1 = TC->getZExtValue() * VectorWidth; + + // ii) TC1 has to be equal to the element count + 1, with the + 1 to + // compensate for counting from 0. + uint64_t TC2 = ConstElemCount->getZExtValue() + 1; + + if (TC1 != TC2) { + LLVM_DEBUG(dbgs() << "ARM TP: inconsistent constant tripcount values: " + << TC1 << " from set.loop.iterations, and " + << TC2 << " from get.active.lane.mask\n"); + return false; + } + } else { + // Smoke tests if the element count is a runtime value. I.e., this isn't + // fully generic because that would require a full SCEV visitor here. It + // would require extracting the variable from the elementcount SCEV + // expression, and match this up with the tripcount SCEV expression. If + // this matches up, we know both expressions are bound by the same + // variable, and thus we know this tripcount belongs to this loop. The + // checks below will catch most cases though. + if (isa(EC) || isa(EC)) { + // If the element count is a simple AddExpr or SCEVUnknown, which is e.g. + // the case when the element count is just a variable %N, we can just see + // if it is an operand in the tripcount scev expression. 
+ if (isa(TC) && !SE->hasOperand(TC, EC)) { + LLVM_DEBUG(dbgs() << "ARM TP: 1Can't verify the element counter\n"); + return false; + } + } else if (const SCEVAddRecExpr *AddRecExpr = dyn_cast(EC)) { + // For more complicated AddRecExpr, check that the corresponding loop and + // its loop hierarchy contains the trip count loop. + if (!AddRecExpr->getLoop()->contains(L)) { + LLVM_DEBUG(dbgs() << "ARM TP: 2Can't verify the element counter\n"); + return false; + } + } else { + LLVM_DEBUG(dbgs() << "ARM TP: Unsupported SCEV type, can't verify the " + "element counter\n"); + return false; + } + } // 2) Prove that the sub expression is non-negative, i.e. it doesn't overflow: // @@ -393,9 +456,7 @@ // // upperbound(TC) <= UINT_MAX - VectorWidth // - auto *TC = SE->getSCEV(TripCount); unsigned SizeInBits = TripCount->getType()->getScalarSizeInBits(); - int VectorWidth = VecTy->getNumElements(); auto Diff = APInt(SizeInBits, ~0) - APInt(SizeInBits, VectorWidth); uint64_t MaxMinusVW = Diff.getZExtValue(); // FIXME: since ranges can be negative we work with signed ranges here, but @@ -432,9 +493,9 @@ // 1. Thus, if the ranges of Ceil and TC are not a single constant but a set, // we first add 0 to TC such that we can do the <= comparison on both sets. 
// - auto *ElementCount = SE->getSCEV(ElemCountVal); + // Tmp = ElementCount + (VW-1) - auto *ECPlusVWMinus1 = SE->getAddExpr(ElementCount, + auto *ECPlusVWMinus1 = SE->getAddExpr(EC, SE->getSCEV(ConstantInt::get(TripCount->getType(), VectorWidth - 1))); // Ceil = ElementCount + (VW-1) / VW auto *Ceil = SE->getUDivExpr(ECPlusVWMinus1, diff --git a/llvm/test/CodeGen/Thumb2/LowOverheadLoops/basic-tail-pred.ll b/llvm/test/CodeGen/Thumb2/LowOverheadLoops/basic-tail-pred.ll --- a/llvm/test/CodeGen/Thumb2/LowOverheadLoops/basic-tail-pred.ll +++ b/llvm/test/CodeGen/Thumb2/LowOverheadLoops/basic-tail-pred.ll @@ -431,6 +431,195 @@ ret void } +; CHECK-LABEL: const_expected_in_set_loop +; CHECK: call <4 x i1> @llvm.get.active.lane.mask +; CHECK-NOT: vctp +; CHECK: ret void +; +define dso_local void @const_expected_in_set_loop(i32* noalias nocapture %A, i32* noalias nocapture readonly %B, i32* noalias nocapture readonly %C, i32 %N) local_unnamed_addr #0 { +entry: + %cmp8 = icmp sgt i32 %N, 0 + %0 = add i32 %N, 3 + %1 = lshr i32 %0, 2 + %2 = shl nuw i32 %1, 2 + %3 = add i32 %2, -4 + %4 = lshr i32 %3, 2 + %5 = add nuw nsw i32 %4, 1 + br i1 %cmp8, label %vector.ph, label %for.cond.cleanup + +vector.ph: + call void @llvm.set.loop.iterations.i32(i32 %5) + br label %vector.body + +vector.body: ; preds = %vector.body, %vector.ph + %lsr.iv17 = phi i32* [ %scevgep18, %vector.body ], [ %A, %vector.ph ] + %lsr.iv14 = phi i32* [ %scevgep15, %vector.body ], [ %C, %vector.ph ] + %lsr.iv = phi i32* [ %scevgep, %vector.body ], [ %B, %vector.ph ] + %index = phi i32 [ 0, %vector.ph ], [ %index.next, %vector.body ] + %6 = phi i32 [ %5, %vector.ph ], [ %8, %vector.body ] + %lsr.iv13 = bitcast i32* %lsr.iv to <4 x i32>* + %lsr.iv1416 = bitcast i32* %lsr.iv14 to <4 x i32>* + %lsr.iv1719 = bitcast i32* %lsr.iv17 to <4 x i32>* + + %active.lane.mask = call <4 x i1> @llvm.get.active.lane.mask.v4i1.i32(i32 %index, i32 42) + + %wide.masked.load = call <4 x i32> @llvm.masked.load.v4i32.p0v4i32(<4 x i32>* 
%lsr.iv13, i32 4, <4 x i1> %active.lane.mask, <4 x i32> undef) + %wide.masked.load12 = call <4 x i32> @llvm.masked.load.v4i32.p0v4i32(<4 x i32>* %lsr.iv1416, i32 4, <4 x i1> %active.lane.mask, <4 x i32> undef) + %7 = add nsw <4 x i32> %wide.masked.load12, %wide.masked.load + call void @llvm.masked.store.v4i32.p0v4i32(<4 x i32> %7, <4 x i32>* %lsr.iv1719, i32 4, <4 x i1> %active.lane.mask) + %index.next = add i32 %index, 4 + %scevgep = getelementptr i32, i32* %lsr.iv, i32 4 + %scevgep15 = getelementptr i32, i32* %lsr.iv14, i32 4 + %scevgep18 = getelementptr i32, i32* %lsr.iv17, i32 4 + %8 = call i32 @llvm.loop.decrement.reg.i32(i32 %6, i32 1) + %9 = icmp ne i32 %8, 0 + br i1 %9, label %vector.body, label %for.cond.cleanup + +for.cond.cleanup: ; preds = %vector.body, %entry + ret void +} + +; CHECK-LABEL: wrong_tripcount_arg +; CHECK: vector.body: +; CHECK: call <4 x i1> @llvm.arm.mve.vctp32 +; CHECK-NOT: call <4 x i1> @llvm.get.active.lane.mask.v4i1.i32 +; CHECK: vector.body35: +; CHECK: call <4 x i1> @llvm.get.active.lane.mask.v4i1.i32 +; CHECK-NOT: call <4 x i1> @llvm.arm.mve.vctp32 +; CHECK: ret void +; +define dso_local void @wrong_tripcount_arg(i32* noalias nocapture %A, i32* noalias nocapture readonly %B, i32* noalias nocapture readonly %C, i32* noalias nocapture %D, i32 %N1, i32 %N2) local_unnamed_addr #0 { +entry: + %cmp29 = icmp sgt i32 %N1, 0 + %0 = add i32 %N1, 3 + %1 = lshr i32 %0, 2 + %2 = shl nuw i32 %1, 2 + %3 = add i32 %2, -4 + %4 = lshr i32 %3, 2 + %5 = add nuw nsw i32 %4, 1 + br i1 %cmp29, label %vector.ph, label %for.cond4.preheader + +vector.ph: ; preds = %entry + call void @llvm.set.loop.iterations.i32(i32 %5) + br label %vector.body + +vector.body: ; preds = %vector.body, %vector.ph + %lsr.iv62 = phi i32* [ %scevgep63, %vector.body ], [ %D, %vector.ph ] + %lsr.iv59 = phi i32* [ %scevgep60, %vector.body ], [ %C, %vector.ph ] + %lsr.iv56 = phi i32* [ %scevgep57, %vector.body ], [ %B, %vector.ph ] + %index = phi i32 [ 0, %vector.ph ], [ 
%index.next, %vector.body ] + %6 = phi i32 [ %5, %vector.ph ], [ %8, %vector.body ] + %lsr.iv5658 = bitcast i32* %lsr.iv56 to <4 x i32>* + %lsr.iv5961 = bitcast i32* %lsr.iv59 to <4 x i32>* + %lsr.iv6264 = bitcast i32* %lsr.iv62 to <4 x i32>* + %active.lane.mask = call <4 x i1> @llvm.get.active.lane.mask.v4i1.i32(i32 %index, i32 %N1) + %wide.masked.load = call <4 x i32> @llvm.masked.load.v4i32.p0v4i32(<4 x i32>* %lsr.iv5658, i32 4, <4 x i1> %active.lane.mask, <4 x i32> undef) + %wide.masked.load32 = call <4 x i32> @llvm.masked.load.v4i32.p0v4i32(<4 x i32>* %lsr.iv5961, i32 4, <4 x i1> %active.lane.mask, <4 x i32> undef) + %7 = add nsw <4 x i32> %wide.masked.load32, %wide.masked.load + call void @llvm.masked.store.v4i32.p0v4i32(<4 x i32> %7, <4 x i32>* %lsr.iv6264, i32 4, <4 x i1> %active.lane.mask) + %index.next = add i32 %index, 4 + %scevgep57 = getelementptr i32, i32* %lsr.iv56, i32 4 + %scevgep60 = getelementptr i32, i32* %lsr.iv59, i32 4 + %scevgep63 = getelementptr i32, i32* %lsr.iv62, i32 4 + %8 = call i32 @llvm.loop.decrement.reg.i32(i32 %6, i32 1) + %9 = icmp ne i32 %8, 0 + br i1 %9, label %vector.body, label %for.cond4.preheader + +for.cond4.preheader: ; preds = %vector.body, %entry + %cmp527 = icmp sgt i32 %N2, 0 + %10 = add i32 %N2, 3 + %11 = lshr i32 %10, 2 + %12 = shl nuw i32 %11, 2 + %13 = add i32 %12, -4 + %14 = lshr i32 %13, 2 + %15 = add nuw nsw i32 %14, 1 + br i1 %cmp527, label %vector.ph36, label %for.cond.cleanup6 + +vector.ph36: ; preds = %for.cond4.preheader + call void @llvm.set.loop.iterations.i32(i32 %15) + br label %vector.body35 + +vector.body35: ; preds = %vector.body35, %vector.ph36 + %lsr.iv53 = phi i32* [ %scevgep54, %vector.body35 ], [ %A, %vector.ph36 ] + %lsr.iv50 = phi i32* [ %scevgep51, %vector.body35 ], [ %C, %vector.ph36 ] + %lsr.iv = phi i32* [ %scevgep, %vector.body35 ], [ %B, %vector.ph36 ] + %index40 = phi i32 [ 0, %vector.ph36 ], [ %index.next41, %vector.body35 ] + %16 = phi i32 [ %15, %vector.ph36 ], [ %18, %vector.body35 
] + %lsr.iv49 = bitcast i32* %lsr.iv to <4 x i32>* + %lsr.iv5052 = bitcast i32* %lsr.iv50 to <4 x i32>* + %lsr.iv5355 = bitcast i32* %lsr.iv53 to <4 x i32>* + +; This has N1 as the tripcount / element count, which is the tripcount of the +; first loop and not this one: + %active.lane.mask46 = call <4 x i1> @llvm.get.active.lane.mask.v4i1.i32(i32 %index40, i32 %N1) + + %wide.masked.load47 = call <4 x i32> @llvm.masked.load.v4i32.p0v4i32(<4 x i32>* %lsr.iv49, i32 4, <4 x i1> %active.lane.mask46, <4 x i32> undef) + %wide.masked.load48 = call <4 x i32> @llvm.masked.load.v4i32.p0v4i32(<4 x i32>* %lsr.iv5052, i32 4, <4 x i1> %active.lane.mask46, <4 x i32> undef) + %17 = add nsw <4 x i32> %wide.masked.load48, %wide.masked.load47 + call void @llvm.masked.store.v4i32.p0v4i32(<4 x i32> %17, <4 x i32>* %lsr.iv5355, i32 4, <4 x i1> %active.lane.mask46) + %index.next41 = add i32 %index40, 4 + %scevgep = getelementptr i32, i32* %lsr.iv, i32 4 + %scevgep51 = getelementptr i32, i32* %lsr.iv50, i32 4 + %scevgep54 = getelementptr i32, i32* %lsr.iv53, i32 4 + %18 = call i32 @llvm.loop.decrement.reg.i32(i32 %16, i32 1) + %19 = icmp ne i32 %18, 0 + br i1 %19, label %vector.body35, label %for.cond.cleanup6 + +for.cond.cleanup6: ; preds = %vector.body35, %for.cond4.preheader + ret void +} + +; CHECK-LABEL: tripcount_arg_not_invariant +; CHECK: call <4 x i1> @llvm.get.active.lane.mask +; CHECK-NOT: vctp +; CHECK: ret void +; +define dso_local void @tripcount_arg_not_invariant(i32* noalias nocapture %A, i32* noalias nocapture readonly %B, i32* noalias nocapture readonly %C, i32 %N) local_unnamed_addr #0 { +entry: + %cmp8 = icmp sgt i32 %N, 0 + %0 = add i32 %N, 3 + %1 = lshr i32 %0, 2 + %2 = shl nuw i32 %1, 2 + %3 = add i32 %2, -4 + %4 = lshr i32 %3, 2 + %5 = add nuw nsw i32 %4, 1 + br i1 %cmp8, label %vector.ph, label %for.cond.cleanup + +vector.ph: ; preds = %entry + %trip.count.minus.1 = add i32 %N, -1 + call void @llvm.set.loop.iterations.i32(i32 %5) + br label %vector.body + 
+vector.body: ; preds = %vector.body, %vector.ph + %lsr.iv17 = phi i32* [ %scevgep18, %vector.body ], [ %A, %vector.ph ] + %lsr.iv14 = phi i32* [ %scevgep15, %vector.body ], [ %C, %vector.ph ] + %lsr.iv = phi i32* [ %scevgep, %vector.body ], [ %B, %vector.ph ] + %index = phi i32 [ 0, %vector.ph ], [ %index.next, %vector.body ] + %6 = phi i32 [ %5, %vector.ph ], [ %8, %vector.body ] + + %lsr.iv13 = bitcast i32* %lsr.iv to <4 x i32>* + %lsr.iv1416 = bitcast i32* %lsr.iv14 to <4 x i32>* + %lsr.iv1719 = bitcast i32* %lsr.iv17 to <4 x i32>* + + %active.lane.mask = call <4 x i1> @llvm.get.active.lane.mask.v4i1.i32(i32 %index, i32 %index) + + %wide.masked.load = call <4 x i32> @llvm.masked.load.v4i32.p0v4i32(<4 x i32>* %lsr.iv13, i32 4, <4 x i1> %active.lane.mask, <4 x i32> undef) + %wide.masked.load12 = call <4 x i32> @llvm.masked.load.v4i32.p0v4i32(<4 x i32>* %lsr.iv1416, i32 4, <4 x i1> %active.lane.mask, <4 x i32> undef) + %7 = add nsw <4 x i32> %wide.masked.load12, %wide.masked.load + call void @llvm.masked.store.v4i32.p0v4i32(<4 x i32> %7, <4 x i32>* %lsr.iv1719, i32 4, <4 x i1> %active.lane.mask) + %index.next = add i32 %index, 4 + %scevgep = getelementptr i32, i32* %lsr.iv, i32 4 + %scevgep15 = getelementptr i32, i32* %lsr.iv14, i32 4 + %scevgep18 = getelementptr i32, i32* %lsr.iv17, i32 4 + %8 = call i32 @llvm.loop.decrement.reg.i32(i32 %6, i32 1) + %9 = icmp ne i32 %8, 0 + ;br i1 %9, label %vector.body, label %for.cond.cleanup + br i1 %9, label %vector.body, label %vector.ph + +for.cond.cleanup: ; preds = %vector.body, %entry + ret void +} + declare <16 x i8> @llvm.masked.load.v16i8.p0v16i8(<16 x i8>*, i32 immarg, <16 x i1>, <16 x i8>) declare void @llvm.masked.store.v16i8.p0v16i8(<16 x i8>, <16 x i8>*, i32 immarg, <16 x i1>) declare <8 x i16> @llvm.masked.load.v8i16.p0v8i16(<8 x i16>*, i32 immarg, <8 x i1>, <8 x i16>) diff --git a/llvm/test/CodeGen/Thumb2/LowOverheadLoops/tail-pred-const.ll b/llvm/test/CodeGen/Thumb2/LowOverheadLoops/tail-pred-const.ll --- 
a/llvm/test/CodeGen/Thumb2/LowOverheadLoops/tail-pred-const.ll +++ b/llvm/test/CodeGen/Thumb2/LowOverheadLoops/tail-pred-const.ll @@ -265,13 +265,13 @@ ret void } -; CHECK-LABEL: @overflow_BTC_plus_1( +; CHECK-LABEL: @inconsistent_tripcounts( ; CHECK: vector.body: ; CHECK-NOT: @llvm.arm.mve.vctp32 ; CHECK: @llvm.get.active.lane.mask ; CHECK: ret void ; -define dso_local void @overflow_BTC_plus_1(i32* noalias nocapture %A, i32* noalias nocapture readonly %B, i32* noalias nocapture readonly %C, i32* noalias nocapture readnone %D, i32 %N) local_unnamed_addr #0 { +define dso_local void @inconsistent_tripcounts(i32* noalias nocapture %A, i32* noalias nocapture readonly %B, i32* noalias nocapture readonly %C, i32* noalias nocapture readnone %D, i32 %N) local_unnamed_addr #0 { entry: call void @llvm.set.loop.iterations.i32(i32 8001) br label %vector.body @@ -316,63 +316,7 @@ ; define dso_local void @overflow_in_sub(i32* noalias nocapture %A, i32* noalias nocapture readonly %B, i32* noalias nocapture readonly %C, i32* noalias nocapture readnone %D, i32 %N) local_unnamed_addr #0 { entry: - call void @llvm.set.loop.iterations.i32(i32 8001) - br label %vector.body - -vector.body: - %lsr.iv14 = phi i32* [ %scevgep15, %vector.body ], [ %A, %entry ] - %lsr.iv11 = phi i32* [ %scevgep12, %vector.body ], [ %C, %entry ] - %lsr.iv = phi i32* [ %scevgep, %vector.body ], [ %B, %entry ] - %index = phi i32 [ 0, %entry ], [ %index.next, %vector.body ] - %0 = phi i32 [ 8001, %entry ], [ %3, %vector.body ] - %lsr.iv1416 = bitcast i32* %lsr.iv14 to <4 x i32>* - %lsr.iv1113 = bitcast i32* %lsr.iv11 to <4 x i32>* - %lsr.iv10 = bitcast i32* %lsr.iv to <4 x i32>* - %broadcast.splatinsert = insertelement <4 x i32> undef, i32 %index, i32 0 - %broadcast.splat = shufflevector <4 x i32> %broadcast.splatinsert, <4 x i32> undef, <4 x i32> zeroinitializer - %induction = add <4 x i32> %broadcast.splat, - -; Overflow in the substraction. 
This should hold: -; -; ceil(ElementCount / VectorWidth) >= TripCount -; -; But we have: -; -; ceil(3200 / 4) >= 8001 -; 8000 >= 8001 -; - %1 = call <4 x i1> @llvm.get.active.lane.mask.v4i1.i32(i32 %index, i32 31999) - - %wide.masked.load = call <4 x i32> @llvm.masked.load.v4i32.p0v4i32(<4 x i32>* %lsr.iv10, i32 4, <4 x i1> %1, <4 x i32> undef) - %wide.masked.load9 = call <4 x i32> @llvm.masked.load.v4i32.p0v4i32(<4 x i32>* %lsr.iv1113, i32 4, <4 x i1> %1, <4 x i32> undef) - %2 = add nsw <4 x i32> %wide.masked.load9, %wide.masked.load - call void @llvm.masked.store.v4i32.p0v4i32(<4 x i32> %2, <4 x i32>* %lsr.iv1416, i32 4, <4 x i1> %1) - %index.next = add i32 %index, 4 - %scevgep = getelementptr i32, i32* %lsr.iv, i32 4 - %scevgep12 = getelementptr i32, i32* %lsr.iv11, i32 4 - %scevgep15 = getelementptr i32, i32* %lsr.iv14, i32 4 - %3 = call i32 @llvm.loop.decrement.reg.i32(i32 %0, i32 1) - %4 = icmp ne i32 %3, 0 - br i1 %4, label %vector.body, label %for.cond.cleanup - -for.cond.cleanup: - ret void -} - -; CHECK-LABEL: @overflow_in_rounding_tripcount( -; CHECK: vector.body: -; CHECK-NOT: @llvm.arm.mve.vctp32 -; CHECK: @llvm.get.active.lane.mask -; CHECK: ret void -; -define dso_local void @overflow_in_rounding_tripcount(i32* noalias nocapture %A, i32* noalias nocapture readonly %B, i32* noalias nocapture readonly %C, i32* noalias nocapture readnone %D, i32 %N) local_unnamed_addr #0 { -entry: - -; TC = 4294967292 -; 4294967292 <= 4294967291 (MAX - vectorwidth) -; False -; - call void @llvm.set.loop.iterations.i32(i32 4294967291) + call void @llvm.set.loop.iterations.i32(i32 1073741824) br label %vector.body vector.body: