Index: include/llvm/IR/IntrinsicsARM.td =================================================================== --- include/llvm/IR/IntrinsicsARM.td +++ include/llvm/IR/IntrinsicsARM.td @@ -765,5 +765,11 @@ def int_arm_neon_udot : Neon_Dot_Intrinsic; def int_arm_neon_sdot : Neon_Dot_Intrinsic; +def int_arm_set_loop_elements : + Intrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty], []>; +def int_arm_loop_end : + Intrinsic<[llvm_i32_ty], + [llvm_i32_ty, llvm_i32_ty], []>; +def int_arm_get_active_mask_4 : Intrinsic<[llvm_v4i1_ty], [llvm_i32_ty], []>; } // end TargetPrefix Index: lib/Target/ARM/ARM.h =================================================================== --- lib/Target/ARM/ARM.h +++ lib/Target/ARM/ARM.h @@ -37,6 +37,8 @@ Pass *createARMParallelDSPPass(); +Pass *createARMHardwareLoopsPass(); +FunctionPass *createARMFinaliseHardwareLoopsPass(); FunctionPass *createARMISelDag(ARMBaseTargetMachine &TM, CodeGenOpt::Level OptLevel); FunctionPass *createA15SDOptimizerPass(); @@ -62,6 +64,7 @@ void initializeARMParallelDSPPass(PassRegistry &); +void initializeARMHardwareLoopsPass(PassRegistry &); void initializeARMLoadStoreOptPass(PassRegistry &); void initializeARMPreAllocLoadStoreOptPass(PassRegistry &); void initializeARMCodeGenPreparePass(PassRegistry &); Index: lib/Target/ARM/ARMFinalizeHardwareLoops.cpp =================================================================== --- /dev/null +++ lib/Target/ARM/ARMFinalizeHardwareLoops.cpp @@ -0,0 +1,256 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "ARM.h" +#include "ARMBaseInstrInfo.h" +#include "ARMBaseRegisterInfo.h" +#include "llvm/CodeGen/MachineFunctionPass.h" +#include "llvm/CodeGen/MachineLoopInfo.h" +#include "llvm/CodeGen/MachineRegisterInfo.h" + +using namespace llvm; + +#define DEBUG_TYPE "arm-finalise-hardware-loops" +#define ARM_FINALISE_HW_LOOPS_NAME "ARM hardware loop finalisation pass" + +namespace { + + class ARMFinaliseHWLoops : public MachineFunctionPass { + const ARMBaseInstrInfo *TII = nullptr; + + public: + static char ID; + + ARMFinaliseHWLoops() : MachineFunctionPass(ID) { } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.setPreservesCFG(); + AU.addRequired(); + MachineFunctionPass::getAnalysisUsage(AU); + } + + bool runOnMachineFunction(MachineFunction &MF) override; + + bool ProcessLoop(MachineLoop *ML); + + void Expand(MachineInstr *Start, MachineInstr *Dec, MachineInstr *End, + MachineInstr *ActiveMask, + SmallVectorImpl &Predicated); + + MachineFunctionProperties getRequiredProperties() const override { + return MachineFunctionProperties().set( + MachineFunctionProperties::Property::NoVRegs); + } + + StringRef getPassName() const override { + return ARM_FINALISE_HW_LOOPS_NAME; + } + }; +} + +char ARMFinaliseHWLoops::ID = 0; + +bool ARMFinaliseHWLoops::runOnMachineFunction(MachineFunction &MF) { + auto &MLI = getAnalysis(); + TII = + static_cast(MF.getSubtarget().getInstrInfo()); + LLVM_DEBUG(dbgs() << " ------- ARM HWLOOPS on " << MF.getName() << "\n"); + + bool Changed = false; + for (auto ML : MLI) { + if (!ML->getExitingBlock() || !ML->getHeader() || !ML->getLoopLatch()) + continue; + Changed |= ProcessLoop(ML); + } + return Changed; +} + +bool ARMFinaliseHWLoops::ProcessLoop(MachineLoop *ML) { + + LLVM_DEBUG(dbgs() << "ARM HWLOOPS: Processing " << *ML); + auto SearchForStart = [](MachineBasicBlock 
*MBB) -> MachineInstr* { + for (auto &MI : *MBB) { + if (MI.getOpcode() == ARM::t2LoopStart) + return &MI; + } + return nullptr; + }; + + MachineInstr *Start = nullptr; + + if (auto *Preheader = ML->getLoopPreheader()) { + Start = SearchForStart(Preheader); + if (!Start) { + if (Preheader->pred_size() == 1) { + MachineBasicBlock *PrePreheader = *Preheader->pred_begin(); + Start = SearchForStart(PrePreheader); + } + } + } + + if (!Start) + return false; + LLVM_DEBUG(dbgs() << "ARM HWLOOPS: Found Loop Start: " << *Start); + + auto IsLoopDec = [](MachineInstr &MI) { + return MI.getOpcode() == ARM::t2LoopDec; + }; + + auto IsLoopEnd = [](MachineInstr &MI) { + return MI.getOpcode() == ARM::t2LoopEnd; + }; + + auto IsActiveMask = [](MachineInstr &MI) { + return MI.getOpcode() == ARM::t2ActiveMask; + }; + + auto IsPredicated = [](MachineInstr &MI) { + switch (MI.getOpcode()) { + default: + break; + case ARM::VMSTR32: + case ARM::VMLDR32: + return true; + } + return false; + }; + + MachineInstr *Dec = nullptr; + MachineInstr *End = nullptr; + MachineInstr *ActiveMask = nullptr; + bool FoundPredicated = false; + bool IsProfitable = true; + SmallVector Predicated; + + for (auto *MBB : ML->getBlocks()) { + for (auto &MI : *MBB) { + // TODO: For scalar loops, check for any instructions that means a + // low-overhead loop wouldn't be profitable. Should we bail if LR has + // been spilt? We'd still need a register to control the loop count but + // the loop index may increase whereas LE(TP) decrement it... + // + // Not inserting a low-overhead loop for a vector loop is not really + // option here as we'd either: + // - Need to reconstruct a vector loop and a scalar epilogue. + // - Try to use VIDUP and create a VPT block to predicate the lanes, + // which would require using a Q register, all of which may be already + // allocated, for the VIDUP result. 
It looks like VIDUP wouldn't even be + // helpful for 16xi8 vectors because the instruction can only increment + // by a maximum of 8. + + if (IsLoopDec(MI)) + Dec = &MI; + else if (IsLoopEnd(MI)) + End = &MI; + else if (IsActiveMask(MI)) + ActiveMask = &MI; + else if (IsPredicated(MI)) { + FoundPredicated = true; + Predicated.push_back(&MI); + } + } + } + + // Check that we've found the necessary components + if (!Dec || !End || (FoundPredicated && !ActiveMask)) + return false; + + if (!IsProfitable) + return false; + + LLVM_DEBUG(dbgs() << "ARM HWLOOPS: Found Loop Dec: " << *Dec); + LLVM_DEBUG(dbgs() << "ARM HWLOOPS: Found Loop End: " << *End); + + // TODO: Verify that the cmp and br from the WLS either branch to the header + // or the exit block. + // TODO: Verify that the cmp and br from the LE either branch to the header + // or the exit block. + // TODO: Verify that all predicated instructions are using ActiveMask. + + Expand(Start, Dec, End, ActiveMask, Predicated); + return true; +} + +void ARMFinaliseHWLoops::Expand(MachineInstr *Start, MachineInstr *Dec, + MachineInstr *End, MachineInstr *ActiveMask, + SmallVectorImpl &Predicated) { + auto ExpandLoopStart = [this](MachineInstr *Start) { + MachineBasicBlock &MBB = *Start->getParent(); + MachineInstrBuilder MIB = BuildMI(MBB, Start, Start->getDebugLoc(), + TII->get(ARM::t2WLSTP)); + MIB.addDef(ARM::LR); + unsigned OpIdx = 0; + MIB.add(Start->getOperand(OpIdx++)); + MIB.add(Start->getOperand(OpIdx++)); + MIB.add(Start->getOperand(OpIdx++)); + MIB.add(predOps(ARMCC::AL)); + LLVM_DEBUG(dbgs() << "ARM HWLOOPS: Inserted WLSTP: " << *MIB << "\n"); + Start->eraseFromParent(); + }; + + auto ExpandLoad = [this](MachineInstr *MI) { + MachineBasicBlock &MBB = *MI->getParent(); + MachineInstrBuilder MIB = BuildMI(MBB, MI, MI->getDebugLoc(), + TII->get(ARM::t2VLDRW)); + unsigned OpIdx = 0; + MIB.add(MI->getOperand(OpIdx++)); + MIB.add(MI->getOperand(OpIdx++)); + MIB.add(predOps(ARMCC::AL)); + MI->eraseFromParent(); + }; 
+ + auto ExpandStore = [this](MachineInstr *MI) { + MachineBasicBlock &MBB = *MI->getParent(); + MachineInstrBuilder MIB = BuildMI(MBB, MI, MI->getDebugLoc(), + TII->get(ARM::t2VSTRW)); + unsigned OpIdx = 0; + MIB.add(MI->getOperand(OpIdx++)); + MIB.add(MI->getOperand(OpIdx++)); + MIB.add(predOps(ARMCC::AL)); + MI->eraseFromParent(); + }; + + auto RemoveActiveMask = [](MachineInstr *MI) { + MI->eraseFromParent(); + }; + + // Combine the LoopDec and LoopEnd instructions into LE(TP). + auto ExpandLoopEnd = [this](MachineInstr *Dec, MachineInstr *End) { + // TODO: Check and handle the causes where LR is spilt between Dec and End. + MachineBasicBlock &MBB = *End->getParent(); + MachineInstrBuilder MIB = BuildMI(MBB, End, End->getDebugLoc(), + TII->get(ARM::t2LETP)); + MIB.addDef(ARM::LR); + unsigned OpIdx = 0; + MIB.add(End->getOperand(OpIdx++)); + MIB.add(End->getOperand(OpIdx++)); + MIB.add(predOps(ARMCC::AL)); + LLVM_DEBUG(dbgs() << "ARM HWLOOPS: Inserted LETP: " << *MIB << "\n"); + End->eraseFromParent(); + Dec->eraseFromParent(); + }; + + ExpandLoopStart(Start); + ExpandLoopEnd(Dec, End); + + if (ActiveMask) { + for (auto *MI : Predicated) { + if (MI->mayLoad()) + ExpandLoad(MI); + else if (MI->mayStore()) + ExpandStore(MI); + else + llvm_unreachable("unhandled predicated instruction"); + } + RemoveActiveMask(ActiveMask); + } +} + +FunctionPass *llvm::createARMFinaliseHardwareLoopsPass() { + return new ARMFinaliseHWLoops(); +} Index: lib/Target/ARM/ARMHardwareLoops.cpp =================================================================== --- /dev/null +++ lib/Target/ARM/ARMHardwareLoops.cpp @@ -0,0 +1,378 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "ARM.h" +#include "llvm/Pass.h" +#include "llvm/PassRegistry.h" +#include "llvm/PassSupport.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/LoopPass.h" +#include "llvm/Analysis/ScalarEvolution.h" +#include "llvm/Analysis/ScalarEvolutionExpander.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Value.h" +#include "llvm/Support/Debug.h" +#include "llvm/Transforms/Scalar.h" + +#define DEBUG_TYPE "arm-hardware-loops" + +#define ARM_HW_LOOPS_NAME "ARM Hardware Loops" + +using namespace llvm; + +namespace { + + class HardwareLoop { + Loop *L = nullptr; + const SCEV *TotalEltSCEV = nullptr; + ConstantInt *Factor; + Instruction *Predicate = nullptr; + ScalarEvolution &SE; + bool IsScalar = false; + Module *M = nullptr; + IntegerType *Int32Ty = nullptr; + + public: + HardwareLoop() = delete; + + HardwareLoop(Loop *L, const SCEV *Elts, ConstantInt *Factor, + Instruction *Pred, ScalarEvolution &SE) : + L(L), TotalEltSCEV(Elts), Factor(Factor), Predicate(Pred), SE(SE) { + IsScalar = Factor->equalsInt(1); + M = L->getHeader()->getParent()->getParent(); + Int32Ty = Type::getInt32Ty(M->getContext()); + } + + void Insert(); + void HandleVector(); + }; + + class ARMHardwareLoops : public LoopPass { + ScalarEvolution *SE = nullptr; + DominatorTree *DT = nullptr; + + public: + static char ID; + + ARMHardwareLoops() : LoopPass(ID) { } + + bool doInitialization(Loop *L, LPPassManager &LPM) override { + return true; + } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + LoopPass::getAnalysisUsage(AU); + AU.addRequired(); + AU.addRequired(); + } + + bool runOnLoop(Loop *L, LPPassManager &) override; + 
}; +} + +static const SCEV* CalcTotalElts(ConstantInt *Factor, + const SCEV *TripCount, + ScalarEvolution &SE) { + if (Factor->equalsInt(1)) + return TripCount; + + const SCEV *FactorSCEV = SE.getSCEV(Factor); + IntegerType *Int32Ty = Factor->getType(); + + if (auto *Count = dyn_cast(TripCount)) { + const SCEV *Elts = SE.getMulExpr(TripCount, FactorSCEV); + unsigned Rem = Count->getAPInt().urem(Factor->getZExtValue()); + if (Rem == 0) + return Elts; + else + return SE.getAddExpr(Elts, SE.getSCEV(ConstantInt::get(Int32Ty, Rem))); + } + + auto VisitAdd = [&](const SCEVAddExpr *S) -> const SCEVMulExpr* { + LLVM_DEBUG(dbgs() << "ARM HWLOOPS: VisitAdd " << *S << "\n"); + if (auto *Const = dyn_cast(S->getOperand(0))) { + if (Const->getAPInt() != -Factor->getValue()) + return nullptr; + } else + return nullptr; + return dyn_cast(S->getOperand(1)); + }; + + auto VisitMul = [&](const SCEVMulExpr *S) -> const SCEVUDivExpr* { + LLVM_DEBUG(dbgs() << "ARM HWLOOPS: VisitMul " << *S << "\n"); + if (auto *Const = dyn_cast(S->getOperand(0))) { + if (Const->getValue() != Factor) + return nullptr; + } else + return nullptr; + return dyn_cast(S->getOperand(1)); + }; + + auto VisitDiv = [&](const SCEVUDivExpr *S) -> const SCEV* { + LLVM_DEBUG(dbgs() << "ARM HWLOOPS: VisitDiv " << *S << "\n"); + if (auto *Const = dyn_cast(S->getRHS())) { + if (Const->getValue() != Factor) + return nullptr; + } else + return nullptr; + + if (auto *RoundUp = dyn_cast(S->getLHS())) { + if (auto *Const = dyn_cast(RoundUp->getOperand(0))) { + if (Const->getAPInt() != (Factor->getValue() - 1)) + return nullptr; + } else + return nullptr; + + LLVM_DEBUG(dbgs() << "ARM HWLOOPS: Elements: " + << *RoundUp->getOperand(1) << "\n"); + return RoundUp->getOperand(1); + } + return nullptr; + }; + + // (1 + ((-4 + (4 * ((3 + %N) /u 4))) /u 4)) + if (auto *TC = dyn_cast(TripCount)) + if (auto *Div = dyn_cast(TC->getOperand(1))) + if (auto *Add = dyn_cast(Div->getLHS())) + if (auto *Mul = VisitAdd(Add)) + if (auto *Div = 
VisitMul(Mul)) + if (auto *Elts = VisitDiv(Div)) + return Elts; + + return nullptr; +} + +bool ARMHardwareLoops::runOnLoop(Loop *L, LPPassManager &LPM) { + SE = &getAnalysis().getSE(); + DT = &getAnalysis().getDomTree(); + Function *F = L->getHeader()->getParent(); + Module *M = F->getParent(); + + LLVM_DEBUG(dbgs() << "---- ARM HWLOOPS on: " << F->getName() << "\n"); + + IntegerType *Int32Ty = Type::getInt32Ty(M->getContext()); + LLVM_DEBUG(dbgs() << "ARM HWLOOPS: Looking at loop: " << *L << "\n"); + // Definition: loops that satisfy the following: + // - has a preheader. + // - innermost loop. + // - trip count can be calculated. + // - has a single exit block. + // - lcssa form? + auto ValidLoopStructure = [&](Loop *L) { + if (!L->getSubLoops().empty() || !L->getLoopPreheader() || + !L->getLoopPreheader()->getUniquePredecessor() || + !L->getExitBlock() || L->getNumBlocks() != 1 || + //L->isLCSSAForm(*DT) || + !SE->getBackedgeTakenCount(L)) + return false; + + const SCEV *BackedgeTakenCount = SE->getBackedgeTakenCount(L); + if (isa(BackedgeTakenCount)) { + LLVM_DEBUG(dbgs() << "ARM HWLOOPS: Can't compute backedge count.\n"); + return false; + } + + const SCEV *TripCountSCEV = + SE->getAddExpr(BackedgeTakenCount, + SE->getOne(BackedgeTakenCount->getType())); + + return SE->getUnsignedRangeMax(TripCountSCEV).getBitWidth() <= 32; + }; + + if (!ValidLoopStructure(L)) { + LLVM_DEBUG(dbgs() << "ARM HWLOOPS: Invalid loop structure.\n"); + return false; + } + + auto CollectLoopStuff = [&](Loop *L) -> HardwareLoop* { + VectorType *VecTy = nullptr; + Instruction *Predicate = nullptr; + + // Inspect the instructions for vector operations. 
+ for (auto *BB : L->getBlocks()) { + for (auto &I : *BB) { + if (!isa(I.getType())) + continue; + + auto *VTy = cast(I.getType()); + if (!VecTy) + VecTy = VTy; + else if (VecTy->getNumElements() != VTy->getNumElements()) + return nullptr; + + if (auto *Call = dyn_cast(&I)) { + if (Call->getIntrinsicID() == Intrinsic::masked_load || + Call->getIntrinsicID() == Intrinsic::masked_store) { + if (!Predicate) + Predicate = cast(Call->getOperand(2)); + else if (Predicate != cast(Call->getOperand(2))) + return nullptr; + } + } + } + } + + // Vector loops that have had their tail folded into the body will contain + // predicated load/store intrinsics with a lane mask derived from the + // following: + // + // vector.ph: + // %n.rnd.up = add i32 %N, 3 + // %n.vec = and i32 %n.rnd.up, -4 + // %trip.count.minus.1 = add i32 %N, -1 + // %broadcast.splatinsert11 = insertelement <4 x i32> undef, + // i32 %trip.count.minus.1, i32 0 + // %broadcast.splat12 = shufflevector <4 x i32> %broadcast.splatinsert11, + // <4 x i32> undef, <4 x i32> zeroinitializer + // br label %vector.body + // + // vector.body: + // %index = phi i32 [ 0, %vector.ph ], [ %index.next, %vector.body ] + // %broadcast.splatinsert = insertelement <4 x i32> undef, i32 %index, i32 0 + // %broadcast.splat = shufflevector <4 x i32> %broadcast.splatinsert, + // <4 x i32> undef, <4 x i32> zeroinitializer + // %induction = add <4 x i32> %broadcast.splat, + // %10 = icmp ule <4 x i32> %induction, %broadcast.splat12 + // + // Where %index is the induction variable that controls the loop trip count. + if (Predicate) { + LLVM_DEBUG(dbgs() << "ARM HWLOOP: Found predicate: " << *Predicate << "\n"); + } + + ConstantInt *Factor = VecTy ? 
+ ConstantInt::get(Int32Ty, VecTy->getNumElements()) : + ConstantInt::get(Int32Ty, 1); + + const SCEV *BackedgeTakenCount = SE->getBackedgeTakenCount(L); + const SCEV *TripCountSCEV = + SE->getAddExpr(BackedgeTakenCount, + SE->getOne(BackedgeTakenCount->getType())); + const SCEV *Elts = CalcTotalElts(Factor, TripCountSCEV, *SE); + + return new HardwareLoop(L, Elts, Factor, Predicate, *SE); + }; + + if (auto *HWLoop = CollectLoopStuff(L)) { + HWLoop->Insert(); + delete HWLoop; + } + + return true; +} + +void HardwareLoop::Insert() { + BasicBlock *Preheader = L->getLoopPreheader(); + BasicBlock *PrePreheader = Preheader->getUniquePredecessor(); + BasicBlock *Header = L->getHeader(); + BasicBlock *Exit = L->getExitBlock(); + const DataLayout &DL = M->getDataLayout(); + + LLVM_DEBUG(dbgs() << "ARM HWLOOP:\n" + << " - PrePreheader: " << PrePreheader->getName() << "\n" + << " - Preheader: " << Preheader->getName() << "\n" + << " - Header: " << Header->getName() << "\n" + << " - Exit: " << Exit->getName() << "\n" + << " - Elements: " << *TotalEltSCEV + << "\n"); + + auto InsertSetup = [this](BasicBlock *Preheader, Instruction *LoopGuard, + Value *NumElts, Value *Factor) { + LLVM_DEBUG(dbgs() << "ARM HWLOOPS: Loop Guard: " << *LoopGuard << "\n"); + IRBuilder<> Builder(SE.getContext()); + Builder.SetInsertPoint(LoopGuard); + + Function *Setup = + Intrinsic::getDeclaration(M, Intrinsic::arm_set_loop_elements); + Value *Ops[] = { NumElts, Factor }; + Instruction *Call = Builder.CreateCall(Setup, Ops); + + Value *Cmp = Builder.CreateICmpNE(Call, ConstantInt::get(Int32Ty, 0)); + LoopGuard->setOperand(0, Cmp); + + if (LoopGuard->getSuccessor(0) != Preheader) { + LoopGuard->setSuccessor(1, LoopGuard->getSuccessor(0)); + LoopGuard->setSuccessor(0, Preheader); + } + LLVM_DEBUG(dbgs() << "ARM HWLOOPS: Updated loop guard: " + << *LoopGuard << "\n"); + }; + + auto InsertDec = [this](BasicBlock *Header, Value *NumElts, Value *Factor) { + Instruction *BackBranch = 
Header->getTerminator(); + IRBuilder<> Builder(SE.getContext()); + Builder.SetInsertPoint(BackBranch); + + Function *Decrement = Intrinsic::getDeclaration(M, Intrinsic::arm_loop_end); + Value *Ops[] = { NumElts, Factor }; + Instruction *Call = Builder.CreateCall(Decrement, Ops); + Value *Cmp = Builder.CreateICmpSGT(Call, ConstantInt::get(Int32Ty, 0)); + + if (BackBranch->getSuccessor(0) != Header) { + BackBranch->setSuccessor(1, BackBranch->getSuccessor(0)); + BackBranch->setSuccessor(0, Header); + } + + // TODO: Try to remove the original compare chain: phi, add, cmp? + + BackBranch->setOperand(0, Cmp); + LLVM_DEBUG(dbgs() << "ARM HWLOOPS: Loop End: " << *Call << "\n"); + return Call; + }; + + auto InsertElts = [this](BasicBlock *Preheader, BasicBlock *Header, + Value *NumElts, Value *EltsRem) { + IRBuilder<> Builder(SE.getContext()); + Builder.SetInsertPoint(Header->getFirstNonPHI()); + PHINode *Index = Builder.CreatePHI(NumElts->getType(), 2); + Index->addIncoming(NumElts, Preheader); + Index->addIncoming(EltsRem, Header); + LLVM_DEBUG(dbgs() << "ARM HWLOOPS: Index PHI: " << *Index << "\n"); + return Index; + }; + + auto InsertActiveMask = [this](Value *Elts) { + IRBuilder<> Builder(SE.getContext()); + Builder.SetInsertPoint(Predicate); + Function *F = + Intrinsic::getDeclaration(M, Intrinsic::arm_get_active_mask_4); + Value *Ops[] = { Elts }; + Instruction *ActiveMask = Builder.CreateCall(F, Ops); + LLVM_DEBUG(dbgs() << "ARM HWLOOPS: Active Lane Mask: " + << *ActiveMask << "\n"); + Predicate->replaceAllUsesWith(ActiveMask); + }; + + SCEVExpander Expander(SE, DL, "HWLoopExpander"); + auto *LoopGuard = cast(PrePreheader->getTerminator()); + Value *TotalElts = Expander.expandCodeFor(TotalEltSCEV, + TotalEltSCEV->getType(), + LoopGuard); + InsertSetup(Preheader, LoopGuard, TotalElts, Factor); + Instruction *LoopEnd = InsertDec(Header, TotalElts, Factor); + PHINode *Elts = InsertElts(Preheader, Header, TotalElts, LoopEnd); + LoopEnd->setOperand(0, Elts); + 
LLVM_DEBUG(dbgs() << "ARM HWLOOPS: Updated Loop End: " << *LoopEnd << "\n"); + + if (!IsScalar && Predicate) + InsertActiveMask(Elts); +} + +INITIALIZE_PASS_BEGIN(ARMHardwareLoops, DEBUG_TYPE, ARM_HW_LOOPS_NAME, false, false) +INITIALIZE_PASS_END(ARMHardwareLoops, DEBUG_TYPE, ARM_HW_LOOPS_NAME, false, false) + +char ARMHardwareLoops::ID = 0; + +Pass *llvm::createARMHardwareLoopsPass() { + return new ARMHardwareLoops(); +} Index: lib/Target/ARM/ARMISelDAGToDAG.cpp =================================================================== --- lib/Target/ARM/ARMISelDAGToDAG.cpp +++ lib/Target/ARM/ARMISelDAGToDAG.cpp @@ -2985,7 +2985,65 @@ unsigned CC = (unsigned) cast(N2)->getZExtValue(); + // Handle loops. + if (InFlag.getOpcode() == ARMISD::CMP && + InFlag.getOperand(0).getOpcode() == ISD::INTRINSIC_W_CHAIN) { + SDValue Int = InFlag.getOperand(0); + uint64_t ID = cast(Int->getOperand(1))->getZExtValue(); + if (ID != Intrinsic::arm_loop_end) + return; + + // TODO: Check that the CMP is in the form we expect. + + SDValue Elements = Int.getOperand(2); + SDValue Size = CurDAG->getTargetConstant( + cast(Int.getOperand(3))->getZExtValue(), dl, MVT::i32); + + SDValue Args[] = { Elements, Size, Int.getOperand(0) }; + SDNode *LoopDec = + CurDAG->getMachineNode(ARM::t2LoopDec, dl, + CurDAG->getVTList(MVT::i32, MVT::Other), + Args); + ReplaceUses(Int.getNode(), LoopDec); + + SDValue EndArgs[] = { SDValue(LoopDec, 0), N1, Chain }; + SDNode *LoopEnd = + CurDAG->getMachineNode(ARM::t2LoopEnd, dl, MVT::Other, EndArgs); + + ReplaceUses(N, LoopEnd); + CurDAG->RemoveDeadNode(N); + CurDAG->RemoveDeadNode(InFlag.getNode()); + CurDAG->RemoveDeadNode(Int.getNode()); + return; + } + if (InFlag.getOpcode() == ARMISD::CMPZ) { + // Handle loops. 
+ if (InFlag.getOperand(0).getOpcode() == ISD::INTRINSIC_W_CHAIN) { + SDValue Int = InFlag.getOperand(0); + uint64_t ID = cast(Int->getOperand(1))->getZExtValue(); + + if (ID == Intrinsic::arm_set_loop_elements) { + Chain = Int.getOperand(0); + SDValue Elements = Int.getOperand(2); + SDValue Size = CurDAG->getTargetConstant( + cast(Int.getOperand(3))->getZExtValue(), dl, MVT::i32); + SDValue Args[] = { Size, Elements, N1, Chain }; + SDNode *LoopStart = + CurDAG->getMachineNode(ARM::t2LoopStart, dl, MVT::Other, Args); + ReplaceUses(N, LoopStart); + + SDValue LR = + CurDAG->getCopyFromReg(SDValue(LoopStart, 0), dl, ARM::LR, + MVT::i32, SDValue(LoopStart, 0)); + ReplaceUses(Int, LR); + CurDAG->RemoveDeadNode(N); + CurDAG->RemoveDeadNode(InFlag.getNode()); + CurDAG->RemoveDeadNode(Int.getNode()); + return; + } + } + bool SwitchEQNEToPLMI; SelectCMPZ(InFlag.getNode(), SwitchEQNEToPLMI); InFlag = N->getOperand(4); Index: lib/Target/ARM/ARMISelLowering.cpp =================================================================== --- lib/Target/ARM/ARMISelLowering.cpp +++ lib/Target/ARM/ARMISelLowering.cpp @@ -526,6 +526,10 @@ setOperationAction(ISD::FMAXNUM, MVT::f16, Legal); } + const MVT pTypes[] = { MVT::v16i1, MVT::v8i1, MVT::v4i1 }; + for (auto VT : pTypes) + addRegisterClass(VT, &ARM::VCCRRegClass); + for (MVT VT : MVT::vector_valuetypes()) { for (MVT InnerVT : MVT::vector_valuetypes()) { setTruncStoreAction(VT, InnerVT, Expand); @@ -12684,6 +12688,7 @@ SDValue RHS = Cmp.getOperand(1); SDValue Chain = N->getOperand(0); SDValue BB = N->getOperand(1); + SDValue ARMcc = N->getOperand(2); ARMCC::CondCodes CC = (ARMCC::CondCodes)cast(ARMcc)->getZExtValue(); Index: lib/Target/ARM/ARMInstrThumb2.td =================================================================== --- lib/Target/ARM/ARMInstrThumb2.td +++ lib/Target/ARM/ARMInstrThumb2.td @@ -1235,6 +1235,89 @@ 4, IIC_iALUi, []>, Sched<[WriteALU, ReadALU]>; +let isBranch = 1, isTerminator = 1, hasSideEffects = 1 in { + def 
t2LoopStart : + t2PseudoInst<(outs), + (ins imm0_7:$size, rGPR:$elts, brtarget:$target), + 4, IIC_Br, []>, Sched<[WriteBr]>; + def t2WLSTP : + T2I<(outs GPRlr:$Rm), (ins imm0_7:$size, GPRlr:$elts, brtarget:$target), IIC_Br, + "wlstp.$size", "\t$Rm, $elts, $target", []>, Sched<[WriteBr]> { + bits<5> Rm; + bits<2> size; + bits<5> elts; + bits<12> target; + } +} + +def t2LoopDec : + t2PseudoInst<(outs GPRlr:$Rm), + (ins GPRlr:$Rn, imm0_7:$size), + 4, IIC_Br, + []>, + Sched<[WriteBr]>; + +let isBranch = 1, isTerminator = 1, hasSideEffects = 1 in { + def t2LoopEnd : + t2PseudoInst<(outs), + (ins GPRlr:$elts, brtarget:$target), + 4, IIC_Br, []>, Sched<[WriteBr]>; + def t2LETP : + T2I<(outs GPRlr:$Rm), (ins GPRlr:$elts, brtarget:$target), IIC_Br, + "letp", "\t$target", []>, Sched<[WriteBr]> { + bits<5> Rm; + bits<5> elts; + bits<12> target; + } +} + +def t2ActiveMask : + t2PseudoInst<(outs VCCR:$pred), + (ins rGPR:$elts), + 4, IIC_Br, + [(set VCCR:$pred, (int_arm_get_active_mask_4 rGPR:$elts))]>, + Sched<[WriteBr]>; + +def nonext_masked_load : + PatFrag<(ops node:$ptr, node:$pred, node:$def), + (masked_load node:$ptr, node:$pred, node:$def), [{ + return cast(N)->getExtensionType() == ISD::NON_EXTLOAD; +}]>; +def nontrunc_masked_store : + PatFrag<(ops node:$val, node:$ptr, node:$pred), + (masked_store node:$val, node:$ptr, node:$pred), [{ + return !cast(N)->isTruncatingStore(); +}]>; + +def VMLDR32 : t2PseudoInst<(outs QPR:$vec), + (ins t2addrmode_imm12:$addr, VCCR:$pred, i32imm:$imm), 4, + NoItinerary, []>, Sched<[WriteLd]>; +let mayLoad = 1 in +def t2VLDRW : T2I<(outs QPR:$Rm), + (ins rGPR:$addr), NoItinerary, + "vldrw", "\t$Rm, [$addr]", []>, Sched<[WriteLd]> { + bits<6> Rm; + bits<5> addr; +} + +def VMSTR32 : t2PseudoInst<(outs), + (ins QPR:$vec, t2addrmode_imm12:$addr, VCCR:$pred, i32imm:$imm), 4, + NoItinerary, []>, Sched<[WriteST]>; +let mayStore = 1 in +def t2VSTRW : T2I<(outs), + (ins QPR:$Rm, rGPR:$addr), NoItinerary, + "vstrw", "\t$Rm, [$addr]", []>, 
Sched<[WriteST]> { + bits<6> Rm; + bits<5> addr; +} + +def : Pat<(v4i32 (nonext_masked_load rGPR:$addr, (v4i1 VCCR:$pred), undef)), + (v4i32 (VMLDR32 rGPR:$addr, (i32 0), (v4i1 VCCR:$pred), (i32 2)))>; +def : Pat<(nontrunc_masked_store (v4i32 QPR:$vec), rGPR:$addr, (v4i1 VCCR:$pred)), + (VMSTR32 (v4i32 QPR:$vec), rGPR:$addr, (i32 0), (v4i1 VCCR:$pred), + (i32 2))>; + + //===----------------------------------------------------------------------===// // Load / store Instructions. Index: lib/Target/ARM/ARMRegisterInfo.td =================================================================== --- lib/Target/ARM/ARMRegisterInfo.td +++ lib/Target/ARM/ARMRegisterInfo.td @@ -254,6 +254,11 @@ let DiagnosticString = "operand must be a register sp"; } +def GPRlr : RegisterClass<"ARM", [i32], 32, (add LR)>; + +def VPR : ARMReg<32, "vpr">; +def VCCR : RegisterClass<"ARM", [i32, v16i1, v8i1, v4i1], 32, (add VPR)>; + // restricted GPR register class. Many Thumb2 instructions allow the full // register range for operands, but have undefined behaviours when PC // or SP (R13 or R15) are used. The ARM ISA refers to these operands Index: lib/Target/ARM/ARMTargetMachine.cpp =================================================================== --- lib/Target/ARM/ARMTargetMachine.cpp +++ lib/Target/ARM/ARMTargetMachine.cpp @@ -89,6 +89,7 @@ initializeARMLoadStoreOptPass(Registry); initializeARMPreAllocLoadStoreOptPass(Registry); initializeARMParallelDSPPass(Registry); + initializeARMHardwareLoopsPass(Registry); initializeARMCodeGenPreparePass(Registry); initializeARMConstantIslandsPass(Registry); initializeARMExecutionDomainFixPass(Registry); @@ -409,6 +410,9 @@ TargetPassConfig::addIRPasses(); + addPass(createARMHardwareLoopsPass()); + addPass(createDeadCodeEliminationPass()); + // Run the parallel DSP pass. 
if (getOptLevel() == CodeGenOpt::Aggressive) addPass(createARMParallelDSPPass()); @@ -493,6 +497,8 @@ addPass(createBreakFalseDeps()); } + addPass(createARMFinaliseHardwareLoopsPass()); + // Expand some pseudo instructions into multiple instructions to allow // proper scheduling. addPass(createARMExpandPseudoPass()); Index: lib/Target/ARM/ARMTargetTransformInfo.h =================================================================== --- lib/Target/ARM/ARMTargetTransformInfo.h +++ lib/Target/ARM/ARMTargetTransformInfo.h @@ -180,6 +180,10 @@ bool UseMaskForCond = false, bool UseMaskForGaps = false); + bool isLegalMaskedStore(Type *Ty) { return true; } + + bool isLegalMaskedLoad(Type *Ty) { return true; } + void getUnrollingPreferences(Loop *L, ScalarEvolution &SE, TTI::UnrollingPreferences &UP); Index: lib/Target/ARM/CMakeLists.txt =================================================================== --- lib/Target/ARM/CMakeLists.txt +++ lib/Target/ARM/CMakeLists.txt @@ -29,7 +29,9 @@ ARMConstantPoolValue.cpp ARMExpandPseudoInsts.cpp ARMFastISel.cpp + ARMFinalizeHardwareLoops.cpp ARMFrameLowering.cpp + ARMHardwareLoops.cpp ARMHazardRecognizer.cpp ARMInstructionSelector.cpp ARMISelDAGToDAG.cpp Index: test/CodeGen/Thumb2/mve-tailpred.ll =================================================================== --- /dev/null +++ test/CodeGen/Thumb2/mve-tailpred.ll @@ -0,0 +1,78 @@ +; RUN: opt -mtriple=thumbv8 -mcpu=cortex-a72 %s -arm-hardware-loops -dce -S -o - | FileCheck %s --check-prefix=CHECK-OPT +; RUN: llc -mtriple=thumbv8 -mcpu=cortex-a72 %s -S -o - | FileCheck %s --check-prefix=CHECK-LLC + +; CHECK-OPT-LABEL: mul_N +; CHECK-OPT: %0 = call i32 @llvm.arm.set.loop.elements(i32 %N, i32 4) +; CHECK-OPT: br i1 %1, label %vector.ph, label %for.cond.cleanup + +; CHECK-OPT: vector.ph: +; CHECK-OPT: br label %vector.body + +; CHECK-OPT: vector.body: +; CHECK-OPT: %index = phi i32 [ 0, %vector.ph ], [ %index.next, %vector.body ] +; CHECK-OPT: %2 = phi i32 [ %N, %vector.ph ], [ %11, %vector.body ] +; 
CHECK-OPT: %3 = getelementptr inbounds i32, i32* %a, i32 %index +; CHECK-OPT: %4 = call <4 x i1> @llvm.arm.get.active.mask.4(i32 %2 +; CHECK-OPT: %wide.masked.load = tail call <4 x i32> @llvm.masked.load.v4i32.p0v4i32(<4 x i32>* %5, i32 4, <4 x i1> %4, <4 x i32> undef) +; CHECK-OPT: %wide.masked.load12 = tail call <4 x i32> @llvm.masked.load.v4i32.p0v4i32(<4 x i32>* %7, i32 4, <4 x i1> %4, <4 x i32> undef) +; CHECK-OPT: %8 = mul nsw <4 x i32> %wide.masked.load12, %wide.masked.load +; CHECK-OPT: tail call void @llvm.masked.store.v4i32.p0v4i32(<4 x i32> %8, <4 x i32>* %10, i32 4, <4 x i1> %4) +; CHECK-OPT: %index.next = add i32 %index, 4 +; CHECK-OPT: %11 = call i32 @llvm.arm.loop.end(i32 %2, i32 4) +; CHECK-OPT: %12 = icmp sgt i32 %11, 0 +; CHECK-OPT: br i1 %12, label %vector.body, label %for.cond.cleanup + +; CHECK-LLC-LABEL: mul_N +; CHECK-LLC: wlstp.#4 lr, r3, .LBB0_3 +; CHECK-LLC: .LBB0_2: +; CHECK-LLC: vldrw q8, [r0] +; CHECK-LLC: vldrw q9, [r1] +; CHECK-LLC: adds r0, #16 +; CHECK-LLC: adds r1, #16 +; CHECK-LLC: adds r3, #4 +; CHECK-LLC: vmul.i32 q8, q9, q8 +; CHECK-LLC: vstrw q8, [r2] +; CHECK-LLC: adds r2, #16 +; CHECK-LLC: letp .LBB0_2 +; CHECK-LLC: b .LBB0_3 + +define dso_local arm_aapcs_vfpcc void @mul_N(i32* noalias nocapture readonly %a, i32* noalias nocapture readonly %b, i32* noalias nocapture %c, i32 %N) { +entry: + %cmp8 = icmp eq i32 %N, 0 + br i1 %cmp8, label %for.cond.cleanup, label %vector.ph + +vector.ph: + %n.rnd.up = add i32 %N, 3 + %n.vec = and i32 %n.rnd.up, -4 + %trip.count.minus.1 = add i32 %N, -1 + %broadcast.splatinsert10 = insertelement <4 x i32> undef, i32 %trip.count.minus.1, i32 0 + %broadcast.splat11 = shufflevector <4 x i32> %broadcast.splatinsert10, <4 x i32> undef, <4 x i32> zeroinitializer + br label %vector.body + +vector.body: + %index = phi i32 [ 0, %vector.ph ], [ %index.next, %vector.body ] + %broadcast.splatinsert = insertelement <4 x i32> undef, i32 %index, i32 0 + %broadcast.splat = shufflevector <4 x i32> 
%broadcast.splatinsert, <4 x i32> undef, <4 x i32> zeroinitializer + %induction = add <4 x i32> %broadcast.splat, + %0 = getelementptr inbounds i32, i32* %a, i32 %index + %1 = icmp ule <4 x i32> %induction, %broadcast.splat11 + %2 = bitcast i32* %0 to <4 x i32>* + %wide.masked.load = tail call <4 x i32> @llvm.masked.load.v4i32.p0v4i32(<4 x i32>* %2, i32 4, <4 x i1> %1, <4 x i32> undef) + %3 = getelementptr inbounds i32, i32* %b, i32 %index + %4 = bitcast i32* %3 to <4 x i32>* + %wide.masked.load12 = tail call <4 x i32> @llvm.masked.load.v4i32.p0v4i32(<4 x i32>* %4, i32 4, <4 x i1> %1, <4 x i32> undef) + %5 = mul nsw <4 x i32> %wide.masked.load12, %wide.masked.load + %6 = getelementptr inbounds i32, i32* %c, i32 %index + %7 = bitcast i32* %6 to <4 x i32>* + tail call void @llvm.masked.store.v4i32.p0v4i32(<4 x i32> %5, <4 x i32>* %7, i32 4, <4 x i1> %1) + %index.next = add i32 %index, 4 + %8 = icmp eq i32 %index.next, %n.vec + br i1 %8, label %for.cond.cleanup, label %vector.body + +for.cond.cleanup: + ret void +} + +declare <4 x i32> @llvm.masked.load.v4i32.p0v4i32(<4 x i32>*, i32 immarg, <4 x i1>, <4 x i32>) + +declare void @llvm.masked.store.v4i32.p0v4i32(<4 x i32>, <4 x i32>*, i32 immarg, <4 x i1>)