[ARM][MVE] Tail-predication: use the llvm.set.loop.elements intrinsic

When the loop preheader contains a call to llvm.set.loop.elements, the pass
records its operand as the number of elements processed by the loop and
erases the call (the intrinsic is never lowered). ComputeConstElements and
ComputeRuntimeElements then use that value as TCP.NumElements instead of
computing it themselves; ComputeConstElements only accepts it when it is a
ConstantInt. Two tests, loopelems_value and loopelems_const, are added to
basic-tail-pred.ll.

NOTE(review): this copy of the patch was damaged in transit. Line breaks
inside the hunks are collapsed, so it cannot be applied as-is; the original
line structure (including blank context lines, which the hunk headers count)
must be restored first. Text inside angle brackets that began with a letter
had been stripped; it is restored below (dyn_cast<IntrinsicInst>,
cast<IntrinsicInst>, the ConstantInt type tests, and the <i32 ...> vector
constants in the IR tests) -- TODO confirm against the upstream revision.

Index: llvm/lib/Target/ARM/MVETailPredication.cpp =================================================================== --- llvm/lib/Target/ARM/MVETailPredication.cpp +++ llvm/lib/Target/ARM/MVETailPredication.cpp @@ -123,7 +123,7 @@ private: /// Perform the relevant checks on the loop and convert if possible. - bool TryConvert(Value *TripCount); + bool TryConvert(Value *TripCount, Value *NumElements); /// Return whether this is a vectorized loop, that contains masked /// load/stores. @@ -131,7 +131,7 @@ /// Compute a value for the total number of elements that the predicated /// loop will process if it is a runtime value. - bool ComputeRuntimeElements(TripCountPattern &TCP); + bool ComputeRuntimeElements(TripCountPattern &TCP, Value *NumElements); /// Return whether this is the icmp that generates an i1 vector, based /// upon a loop counter and a limit that is defined outside the loop, @@ -222,6 +222,29 @@ return nullptr; }; + auto FindLoopNumElements = [](BasicBlock *BB) -> IntrinsicInst* { + for (auto &I : *BB) { + auto *Call = dyn_cast<IntrinsicInst>(&I); + if (!Call) + continue; + + Intrinsic::ID ID = Call->getIntrinsicID(); + if (ID == Intrinsic::set_loop_elements) + return cast<IntrinsicInst>(&I); + } + return nullptr; + }; + + // Search for the intrinsic that sets up the number of data elements + // processed by this loop. Always delete this intrinsic if it is found, + // because we can't and don't need to lower it. + Value *LoopNumElements = nullptr; + if (IntrinsicInst *I = FindLoopNumElements(Preheader)) { + LoopNumElements = I->getArgOperand(0); + LLVM_DEBUG(dbgs() << "ARM TP: Found numelements: "; I->dump()); + I->eraseFromParent(); + } + // Look for the hardware loop intrinsic that sets the iteration count.
IntrinsicInst *Setup = FindLoopIterations(Preheader); @@ -252,7 +275,7 @@ LLVM_DEBUG(dbgs() << "ARM TP: Running on Loop: " << *L << *Setup << "\n" << *Decrement << "\n"); - if (TryConvert(Setup->getArgOperand(0))) { + if (TryConvert(Setup->getArgOperand(0), LoopNumElements)) { if (ClonedVCTPInExitBlock) RematerializeIterCount(); return true; @@ -378,8 +401,9 @@ // the vector loop. Further checks are performed in function isTailPredicate(), // to verify 'induction' behaves as an induction variable. // -static bool ComputeConstElements(TripCountPattern &TCP) { - if (!dyn_cast<ConstantInt>(TCP.TripCount)) +static bool ComputeConstElements(TripCountPattern &TCP, Value *NumElements) { + if (!isa<ConstantInt>(TCP.TripCount) || + (NumElements && !isa<ConstantInt>(NumElements))) return false; ConstantInt *VF = ConstantInt::get( @@ -392,6 +416,13 @@ CC != ICmpInst::ICMP_ULT) return false; + // After setting Induction, if NumElements is set with the intrinsic, + // we don't need to look any further for it. + if (NumElements) { + TCP.NumElements = NumElements; + return true; + } + LLVM_DEBUG(dbgs() << "ARM TP: icmp with constants: "; TCP.Predicate->dump();); Value *ConstVec = TCP.Predicate->getOperand(1); @@ -449,7 +480,8 @@ return true; } -bool MVETailPredication::ComputeRuntimeElements(TripCountPattern &TCP) { +bool MVETailPredication::ComputeRuntimeElements(TripCountPattern &TCP, + Value *NumElements) { using namespace PatternMatch; const SCEV *TripCountSE = SE->getSCEV(TCP.TripCount); ConstantInt *VF = ConstantInt::get( @@ -464,6 +496,13 @@ Pred != ICmpInst::ICMP_ULE) return false; + // After setting Induction and Shuffle, if NumElements is set with the + // intrinsic, we don't need to look further for it.
+ if (NumElements) { + TCP.NumElements = NumElements; + return true; + } + LLVM_DEBUG(dbgs() << "Computing number of elements for vector trip count: "; TCP.TripCount->dump()); @@ -645,7 +684,7 @@ << "ARM TP: Inserted VCTP: " << *TailPredicate << "\n"); } -bool MVETailPredication::TryConvert(Value *TripCount) { +bool MVETailPredication::TryConvert(Value *TripCount, Value *NumElements) { if (!IsPredicatedVectorLoop()) { LLVM_DEBUG(dbgs() << "ARM TP: no masked instructions in loop.\n"); return false; @@ -675,8 +714,9 @@ // Step 1: using this icmp, now calculate the number of elements // processed by this loop. TripCountPattern TCP(Predicate, TripCount, getVectorType(I)); - if (!(ComputeConstElements(TCP) || ComputeRuntimeElements(TCP))) - continue; + if (!(ComputeConstElements(TCP, NumElements) || + ComputeRuntimeElements(TCP, NumElements))) + continue; LLVM_DEBUG(FoundScalarTC = true); Index: llvm/test/CodeGen/Thumb2/LowOverheadLoops/basic-tail-pred.ll =================================================================== --- llvm/test/CodeGen/Thumb2/LowOverheadLoops/basic-tail-pred.ll +++ llvm/test/CodeGen/Thumb2/LowOverheadLoops/basic-tail-pred.ll @@ -324,6 +324,107 @@ ret void } +define dso_local void @loopelems_value(i32* noalias nocapture %A, i32* noalias nocapture readonly %B, i32* noalias nocapture readonly %C, i32 %N) local_unnamed_addr #0 { +; CHECK-LABEL: loopelems_value( +; CHECK-NOT: call void @llvm.set.loop.elements.i32(i32 %N) +; CHECK: [[ELEMS:%[^ ]+]] = phi i32 [ %N, %vector.ph ], [ [[REMAINING:%[^ ]+]], %vector.body ] +; CHECK: [[VCTP:%[^ ]+]] = call <4 x i1> @llvm.arm.mve.vctp32(i32 [[ELEMS]]) +; CHECK: [[REMAINING]] = sub i32 [[ELEMS]], 4 +; CHECK: [[LD0:%[^ ]+]] = call <4 x i32> @llvm.masked.load.v4i32.p0v4i32(<4 x i32>* {{.*}}, i32 4, <4 x i1> [[VCTP]], <4 x i32> undef) +; CHECK: [[LD1:%[^ ]+]] = call <4 x i32> @llvm.masked.load.v4i32.p0v4i32(<4 x i32>* {{.*}}, i32 4, <4 x i1> [[VCTP]], <4 x i32> undef) +; CHECK: call void 
@llvm.masked.store.v4i32.p0v4i32(<4 x i32> {{.*}}, <4 x i32>* {{.*}}, i32 4, <4 x i1> [[VCTP]]) +entry: + %cmp8 = icmp sgt i32 %N, 0 + %0 = add i32 %N, 3 + %1 = lshr i32 %0, 2 + %2 = shl nuw i32 %1, 2 + %3 = add i32 %2, -4 + %4 = lshr i32 %3, 2 + %5 = add nuw nsw i32 %4, 1 + br i1 %cmp8, label %vector.ph, label %for.cond.cleanup + +vector.ph: + call void @llvm.set.loop.elements.i32(i32 %N) + %trip.count.minus.1 = add i32 %N, -1 + call void @llvm.set.loop.iterations.i32(i32 %5) + br label %vector.body + +vector.body: + %lsr.iv17 = phi i32* [ %scevgep18, %vector.body ], [ %A, %vector.ph ] + %lsr.iv14 = phi i32* [ %scevgep15, %vector.body ], [ %C, %vector.ph ] + %lsr.iv = phi i32* [ %scevgep, %vector.body ], [ %B, %vector.ph ] + %index = phi i32 [ 0, %vector.ph ], [ %index.next, %vector.body ] + %6 = phi i32 [ %5, %vector.ph ], [ %11, %vector.body ] + %lsr.iv1719 = bitcast i32* %lsr.iv17 to <4 x i32>* + %lsr.iv1416 = bitcast i32* %lsr.iv14 to <4 x i32>* + %lsr.iv13 = bitcast i32* %lsr.iv to <4 x i32>* + %broadcast.splatinsert10 = insertelement <4 x i32> undef, i32 %index, i32 0 + %broadcast.splat11 = shufflevector <4 x i32> %broadcast.splatinsert10, <4 x i32> undef, <4 x i32> zeroinitializer + %induction = or <4 x i32> %broadcast.splat11, <i32 0, i32 1, i32 2, i32 3> + %7 = insertelement <4 x i32> undef, i32 %trip.count.minus.1, i32 0 + %8 = shufflevector <4 x i32> %7, <4 x i32> undef, <4 x i32> zeroinitializer + %9 = icmp ule <4 x i32> %induction, %8 + %wide.masked.load = call <4 x i32> @llvm.masked.load.v4i32.p0v4i32(<4 x i32>* %lsr.iv13, i32 4, <4 x i1> %9, <4 x i32> undef) + %wide.masked.load12 = call <4 x i32> @llvm.masked.load.v4i32.p0v4i32(<4 x i32>* %lsr.iv1416, i32 4, <4 x i1> %9, <4 x i32> undef) + %10 = add nsw <4 x i32> %wide.masked.load12, %wide.masked.load + call void @llvm.masked.store.v4i32.p0v4i32(<4 x i32> %10, <4 x i32>* %lsr.iv1719, i32 4, <4 x i1> %9) + %index.next = add i32 %index, 4 + %scevgep = getelementptr i32, i32* %lsr.iv, i32 4 + %scevgep15 = getelementptr i32, i32* 
%lsr.iv14, i32 4 + %scevgep18 = getelementptr i32, i32* %lsr.iv17, i32 4 + %11 = call i32 @llvm.loop.decrement.reg.i32.i32.i32(i32 %6, i32 1) + %12 = icmp ne i32 %11, 0 + br i1 %12, label %vector.body, label %for.cond.cleanup + +for.cond.cleanup: + ret void +} + +define dso_local void @loopelems_const(i32* noalias nocapture %A, i32* noalias nocapture readonly %B, i32* noalias nocapture readonly %C, i32 %N) local_unnamed_addr #0 { +; CHECK-LABEL: loopelems_const( +; CHECK-NOT: call void @llvm.set.loop.elements.i32(i32 32003) +; CHECK: %{{.*}} = phi i32 [ 8001, %entry ], [ %{{[^ ]+}}, %vector.body ] +; CHECK: [[ELEMS:%[^ ]+]] = phi i32 [ 32003, %entry ], [ [[REMAINING:%[^ ]+]], %vector.body ] + +; CHECK: [[VCTP:%[^ ]+]] = call <4 x i1> @llvm.arm.mve.vctp32(i32 [[ELEMS]]) +; CHECK: [[REMAINING]] = sub i32 [[ELEMS]], 4 +; CHECK: [[LD0:%[^ ]+]] = call <4 x i32> @llvm.masked.load.v4i32.p0v4i32(<4 x i32>* {{.*}}, i32 4, <4 x i1> [[VCTP]], <4 x i32> undef) +; CHECK: [[LD1:%[^ ]+]] = call <4 x i32> @llvm.masked.load.v4i32.p0v4i32(<4 x i32>* {{.*}}, i32 4, <4 x i1> [[VCTP]], <4 x i32> undef) +; CHECK: call void @llvm.masked.store.v4i32.p0v4i32(<4 x i32> {{.*}}, <4 x i32>* {{.*}}, i32 4, <4 x i1> [[VCTP]]) +entry: + call void @llvm.set.loop.elements.i32(i32 32003) + call void @llvm.set.loop.iterations.i32(i32 8001) + br label %vector.body + +vector.body: + %lsr.iv14 = phi i32* [ %scevgep15, %vector.body ], [ %A, %entry ] + %lsr.iv11 = phi i32* [ %scevgep12, %vector.body ], [ %C, %entry ] + %lsr.iv = phi i32* [ %scevgep, %vector.body ], [ %B, %entry ] + %index = phi i32 [ 0, %entry ], [ %index.next, %vector.body ] + %0 = phi i32 [ 8001, %entry ], [ %3, %vector.body ] + %lsr.iv1416 = bitcast i32* %lsr.iv14 to <4 x i32>* + %lsr.iv1113 = bitcast i32* %lsr.iv11 to <4 x i32>* + %lsr.iv10 = bitcast i32* %lsr.iv to <4 x i32>* + %broadcast.splatinsert = insertelement <4 x i32> undef, i32 %index, i32 0 + %broadcast.splat = shufflevector <4 x i32> %broadcast.splatinsert, <4 x i32> undef, <4 x 
i32> zeroinitializer + %induction = or <4 x i32> %broadcast.splat, <i32 0, i32 1, i32 2, i32 3> + %1 = icmp ult <4 x i32> %induction, <i32 32003, i32 32003, i32 32003, i32 32003> + %wide.masked.load = call <4 x i32> @llvm.masked.load.v4i32.p0v4i32(<4 x i32>* %lsr.iv10, i32 4, <4 x i1> %1, <4 x i32> undef) + %wide.masked.load9 = call <4 x i32> @llvm.masked.load.v4i32.p0v4i32(<4 x i32>* %lsr.iv1113, i32 4, <4 x i1> %1, <4 x i32> undef) + %2 = add nsw <4 x i32> %wide.masked.load9, %wide.masked.load + call void @llvm.masked.store.v4i32.p0v4i32(<4 x i32> %2, <4 x i32>* %lsr.iv1416, i32 4, <4 x i1> %1) + %index.next = add i32 %index, 4 + %scevgep = getelementptr i32, i32* %lsr.iv, i32 4 + %scevgep12 = getelementptr i32, i32* %lsr.iv11, i32 4 + %scevgep15 = getelementptr i32, i32* %lsr.iv14, i32 4 + %3 = call i32 @llvm.loop.decrement.reg.i32.i32.i32(i32 %0, i32 1) + %4 = icmp ne i32 %3, 0 + br i1 %4, label %vector.body, label %for.cond.cleanup + +for.cond.cleanup: + ret void +} + declare <16 x i8> @llvm.masked.load.v16i8.p0v16i8(<16 x i8>*, i32 immarg, <16 x i1>, <16 x i8>) declare void @llvm.masked.store.v16i8.p0v16i8(<16 x i8>, <16 x i8>*, i32 immarg, <16 x i1>) declare <8 x i16> @llvm.masked.load.v8i16.p0v8i16(<8 x i16>*, i32 immarg, <8 x i1>, <8 x i16>) @@ -334,4 +435,4 @@ declare void @llvm.masked.store.v4i32.p0v4i32(<4 x i32>, <4 x i32>*, i32 immarg, <4 x i1>) declare void @llvm.set.loop.iterations.i32(i32) declare i32 @llvm.loop.decrement.reg.i32.i32.i32(i32, i32) - +declare void @llvm.set.loop.elements.i32(i32)