[ARM] Use the mov source operand as the tail-predication element count.

When the instruction that defines the VCTP element-count register cannot be
moved to the loop start insertion point, and that instruction is a register
mov whose source has the same unique reaching definition at both the mov and
the insertion point, the mov source operand is passed to the loop start
instruction instead of giving up on tail predication. The chosen operand is
cached in LowOverheadLoop::TPNumElements and returned by
getLoopStartOperand() (previously getCount()).

The mov source operand is only read after the defining instruction has been
identified as a register mov, so that defs with fewer than two operands are
never indexed out of range.

Index: llvm/lib/Target/ARM/ARMLowOverheadLoops.cpp
===================================================================
--- llvm/lib/Target/ARM/ARMLowOverheadLoops.cpp
+++ llvm/lib/Target/ARM/ARMLowOverheadLoops.cpp
@@ -226,9 +226,10 @@
     MachineInstr *Dec = nullptr;
     MachineInstr *End = nullptr;
     MachineInstr *VCTP = nullptr;
-    SmallPtrSet<MachineInstr*, 4> SecondaryVCTPs;
+    MachineOperand TPNumElements;
+    SmallPtrSet<MachineInstr *, 4> SecondaryVCTPs;
     VPTBlock *CurrentBlock = nullptr;
-    SetVector<MachineInstr*> CurrentPredicate;
+    SetVector<MachineInstr *> CurrentPredicate;
     SmallVector<VPTBlock, 4> VPTBlocks;
     SmallPtrSet<MachineInstr*, 4> ToRemove;
     SmallVector<std::pair<MachineInstr*, MachineInstr*>, 1> Reductions;
@@ -239,7 +240,8 @@
     LowOverheadLoop(MachineLoop &ML, MachineLoopInfo &MLI,
                     ReachingDefAnalysis &RDA, const TargetRegisterInfo &TRI,
                     const ARMBaseInstrInfo &TII)
-      : ML(ML), MLI(MLI), RDA(RDA), TRI(TRI), TII(TII) {
+      : ML(ML), MLI(MLI), RDA(RDA), TRI(TRI), TII(TII),
+        TPNumElements(MachineOperand::CreateImm(0)) {
       MF = ML.getHeader()->getParent();
       if (auto *MBB = ML.getLoopPreheader())
         Preheader = MBB;
@@ -291,11 +293,10 @@
 
     SmallVectorImpl<VPTBlock> &getVPTBlocks() { return VPTBlocks; }
 
-    // Return the loop iteration count, or the number of elements if we're tail
-    // predicating.
-    MachineOperand &getCount() {
-      return IsTailPredicationLegal() ?
-        VCTP->getOperand(1) : Start->getOperand(0);
+    // Return the operand for the loop start instruction. This will be the loop
+    // iteration count, or the number of elements if we're tail predicating.
+    MachineOperand &getLoopStartOperand() {
+      return IsTailPredicationLegal() ? TPNumElements : Start->getOperand(0);
     }
 
     unsigned getStartOpcode() const {
@@ -453,7 +454,8 @@
   // of the iteration count, to the loop start instruction. The number of
   // elements is provided to the vctp instruction, so we need to check that
   // we can use this register at InsertPt.
-  Register NumElements = VCTP->getOperand(1).getReg();
+  TPNumElements = VCTP->getOperand(1);
+  Register NumElements = TPNumElements.getReg();
 
   // If the register is defined within loop, then we can't perform TP.
   // TODO: Check whether this is just a mov of a register that would be
@@ -466,9 +468,8 @@
   // The element count register maybe defined after InsertPt, in which case we
   // need to try to move either InsertPt or the def so that the [w|d]lstp can
   // use the value.
-  // TODO: On failing to move an instruction, check if the count is provided by
-  // a mov and whether we can use the mov operand directly.
   MachineBasicBlock *InsertBB = StartInsertPt->getParent();
+
   if (!RDA.isReachingDefLiveOut(StartInsertPt, NumElements)) {
     if (auto *ElemDef = RDA.getLocalLiveOutMIDef(InsertBB, NumElements)) {
       if (RDA.isSafeToMoveForwards(ElemDef, StartInsertPt)) {
@@ -482,9 +483,26 @@
                               StartInsertPt);
         LLVM_DEBUG(dbgs() << "ARM Loops: Moved start past: " << *ElemDef);
       } else {
-        LLVM_DEBUG(dbgs() << "ARM Loops: Unable to move element count to loop "
-                   << "start instruction.\n");
-        return false;
+        // If we'd fail to move an instruction and the element count is provided
+        // by a mov, use the mov operand if it will have the same value at the
+        // insertion point. Only read the source operand once the def is known
+        // to be a register mov: other defs may not have an operand 1 at all.
+        bool UseMovOperand = false;
+        if (isMovRegOpcode(ElemDef->getOpcode())) {
+          const MachineOperand &Operand = ElemDef->getOperand(1);
+          if (RDA.getUniqueReachingMIDef(ElemDef, Operand.getReg()) ==
+              RDA.getUniqueReachingMIDef(StartInsertPt, Operand.getReg())) {
+            TPNumElements = Operand;
+            NumElements = TPNumElements.getReg();
+            UseMovOperand = true;
+          }
+        }
+        if (!UseMovOperand) {
+          LLVM_DEBUG(dbgs()
+                     << "ARM Loops: Unable to move element count to loop "
+                     << "start instruction.\n");
+          return false;
+        }
       }
     }
   }
Index: llvm/test/CodeGen/Thumb2/LowOverheadLoops/mov-operand.ll
===================================================================
--- /dev/null
+++ llvm/test/CodeGen/Thumb2/LowOverheadLoops/mov-operand.ll
@@ -0,0 +1,81 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=thumbv8.1m.main-arm-none-eabi -mattr=+mve.fp -verify-machineinstrs -tail-predication=enabled -o - %s | FileCheck %s
+define arm_aapcs_vfpcc void @arm_var_f32_mve(float* %pSrc, i32 %blockSize, float* nocapture %pResult) {
+; CHECK-LABEL: .LBB0_1: @ %do.body.i
+; CHECK:         dlstp.32 lr, r1
+; CHECK-NEXT:    vadd.f32 s0, s3, s3
+; CHECK-NEXT:    vcvt.f32.u32 s4, s4
+; CHECK-NEXT:    vdiv.f32 s0, s0, s4
+; CHECK-NEXT:    vmov r3, s0
+; CHECK-NEXT:    vmov.i32 q0, #0x0
+; CHECK-NEXT:    vdup.32 q1, r3
+; CHECK-NEXT:    mov r3, r1
+; CHECK-NEXT:  .LBB0_3: @ %do.body
+; CHECK-NEXT:    @ =>This Inner Loop Header: Depth=1
+; CHECK-NEXT:    subs r3, #4
+; CHECK-NEXT:    vldrw.u32 q2, [r0], #16
+; CHECK-NEXT:    vsub.f32 q2, q2, q1
+; CHECK-NEXT:    vfma.f32 q0, q2, q2
+; CHECK-NEXT:    letp lr, .LBB0_3
+entry:
+  br label %do.body.i
+
+do.body.i:                                        ; preds = %entry, %do.body.i
+  %blkCnt.0.i = phi i32 [ %sub.i, %do.body.i ], [ %blockSize, %entry ]
+  %sumVec.0.i = phi <4 x float> [ %3, %do.body.i ], [ zeroinitializer, %entry ]
+  %pSrc.addr.0.i = phi float* [ %add.ptr.i, %do.body.i ], [ %pSrc, %entry ]
+  %0 = tail call <4 x i1> @llvm.arm.mve.vctp32(i32 %blkCnt.0.i)
+  %1 = bitcast float* %pSrc.addr.0.i to <4 x float>*
+  %2 = tail call fast <4 x float> @llvm.masked.load.v4f32.p0v4f32(<4 x float>* %1, i32 4, <4 x i1> %0, <4 x float> zeroinitializer)
+  %3 = tail call fast <4 x float> @llvm.arm.mve.add.predicated.v4f32.v4i1(<4 x float> %sumVec.0.i, <4 x float> %2, <4 x i1> %0, <4 x float> %sumVec.0.i)
+  %sub.i = add nsw i32 %blkCnt.0.i, -4
+  %add.ptr.i = getelementptr inbounds float, float* %pSrc.addr.0.i, i32 4
+  %cmp.i = icmp sgt i32 %blkCnt.0.i, 4
+  br i1 %cmp.i, label %do.body.i, label %arm_mean_f32_mve.exit
+
+arm_mean_f32_mve.exit:                            ; preds = %do.body.i
+  %4 = extractelement <4 x float> %3, i32 3
+  %add2.i.i = fadd fast float %4, %4
+  %conv.i = uitofp i32 %blockSize to float
+  %div.i = fdiv fast float %add2.i.i, %conv.i
+  %.splatinsert = insertelement <4 x float> undef, float %div.i, i32 0
+  %.splat = shufflevector <4 x float> %.splatinsert, <4 x float> undef, <4 x i32> zeroinitializer
+  br label %do.body
+
+do.body:                                          ; preds = %do.body, %arm_mean_f32_mve.exit
+  %blkCnt.0 = phi i32 [ %blockSize, %arm_mean_f32_mve.exit ], [ %sub, %do.body ]
+  %sumVec.0 = phi <4 x float> [ zeroinitializer, %arm_mean_f32_mve.exit ], [ %9, %do.body ]
+  %pSrc.addr.0 = phi float* [ %pSrc, %arm_mean_f32_mve.exit ], [ %add.ptr, %do.body ]
+  %5 = tail call <4 x i1> @llvm.arm.mve.vctp32(i32 %blkCnt.0)
+  %6 = bitcast float* %pSrc.addr.0 to <4 x float>*
+  %7 = tail call fast <4 x float> @llvm.masked.load.v4f32.p0v4f32(<4 x float>* %6, i32 4, <4 x i1> %5, <4 x float> zeroinitializer)
+  %8 = tail call fast <4 x float> @llvm.arm.mve.sub.predicated.v4f32.v4i1(<4 x float> %7, <4 x float> %.splat, <4 x i1> %5, <4 x float> undef)
+  %9 = tail call fast <4 x float> @llvm.arm.mve.fma.predicated.v4f32.v4i1(<4 x float> %8, <4 x float> %8, <4 x float> %sumVec.0, <4 x i1> %5)
+  %sub = add nsw i32 %blkCnt.0, -4
+  %add.ptr = getelementptr inbounds float, float* %pSrc.addr.0, i32 4
+  %cmp1 = icmp sgt i32 %blkCnt.0, 4
+  br i1 %cmp1, label %do.body, label %do.end
+
+do.end:                                           ; preds = %do.body
+  %10 = extractelement <4 x float> %9, i32 3
+  %add2.i = fadd fast float %10, %10
+  %sub2 = add i32 %blockSize, -1
+  %conv = uitofp i32 %sub2 to float
+  %div = fdiv fast float %add2.i, %conv
+  br label %cleanup
+
+cleanup:                                          ; preds = %do.end
+  store float %div, float* %pResult, align 4
+  ret void
+}
+
+declare <4 x float> @llvm.arm.mve.sub.predicated.v4f32.v4i1(<4 x float>, <4 x float>, <4 x i1>, <4 x float>)
+
+declare <4 x float> @llvm.arm.mve.fma.predicated.v4f32.v4i1(<4 x float>, <4 x float>, <4 x float>, <4 x i1>)
+
+declare <4 x i1> @llvm.arm.mve.vctp32(i32)
+
+declare <4 x float> @llvm.masked.load.v4f32.p0v4f32(<4 x float>*, i32 immarg, <4 x i1>, <4 x float>)
+
+declare <4 x float> @llvm.arm.mve.add.predicated.v4f32.v4i1(<4 x float>, <4 x float>, <4 x i1>, <4 x float>)
+