diff --git a/llvm/lib/Target/ARM/ARMBaseInstrInfo.h b/llvm/lib/Target/ARM/ARMBaseInstrInfo.h --- a/llvm/lib/Target/ARM/ARMBaseInstrInfo.h +++ b/llvm/lib/Target/ARM/ARMBaseInstrInfo.h @@ -504,25 +504,6 @@ // This table shows the VPT instruction variants, i.e. the different // mask field encodings, see also B5.6. Predication/conditional execution in // the ArmARM. - - -inline static ARM::PredBlockMask getARMVPTBlockMask(unsigned NumInsts) { - switch (NumInsts) { - case 1: - return ARM::PredBlockMask::T; - case 2: - return ARM::PredBlockMask::TT; - case 3: - return ARM::PredBlockMask::TTT; - case 4: - return ARM::PredBlockMask::TTTT; - default: - break; - }; - llvm_unreachable("Unexpected number of instruction in a VPT block"); -} - - static inline bool isVPTOpcode(int Opc) { return Opc == ARM::MVE_VPTv16i8 || Opc == ARM::MVE_VPTv16u8 || Opc == ARM::MVE_VPTv16s8 || Opc == ARM::MVE_VPTv8i16 || diff --git a/llvm/lib/Target/ARM/ARMLowOverheadLoops.cpp b/llvm/lib/Target/ARM/ARMLowOverheadLoops.cpp --- a/llvm/lib/Target/ARM/ARMLowOverheadLoops.cpp +++ b/llvm/lib/Target/ARM/ARMLowOverheadLoops.cpp @@ -191,6 +191,7 @@ SetVector CurrentPredicate; SmallVector VPTBlocks; SmallPtrSet ToRemove; + SmallPtrSet BlockMasksToRecompute; bool Revert = false; bool CannotTailPredicate = false; @@ -1183,11 +1184,9 @@ if (Block.HasNonUniformPredicate()) { PredicatedMI *Divergent = Block.getDivergent(); if (isVCTP(Divergent->MI)) { - // The vctp will be removed, so the size of the vpt block needs to be - // modified. - uint64_t Size = (uint64_t)getARMVPTBlockMask(Block.size() - 1); - Block.getVPST()->getOperand(0).setImm(Size); - LLVM_DEBUG(dbgs() << "ARM Loops: Modified VPT block mask.\n"); + // The vctp will be removed, so the block mask of the VPST/VPT will need + // to be recomputed. 
+ LoLoop.BlockMasksToRecompute.insert(Block.getVPST()); } else if (Block.IsOnlyPredicatedOn(LoLoop.VCTP)) { // The VPT block has a non-uniform predicate but it's entry is guarded // only by a vctp, which means we: @@ -1211,13 +1210,15 @@ ++Size; ++I; } + // Create the VPST with a dummy mask of 0; it is recomputed later. MachineInstrBuilder MIB = BuildMI(*InsertAt->getParent(), InsertAt, InsertAt->getDebugLoc(), TII->get(ARM::MVE_VPST)); - MIB.addImm((uint64_t)getARMVPTBlockMask(Size)); + MIB.addImm(0); LLVM_DEBUG(dbgs() << "ARM Loops: Removing VPST: " << *Block.getVPST()); LLVM_DEBUG(dbgs() << "ARM Loops: Created VPST: " << *MIB); LoLoop.ToRemove.insert(Block.getVPST()); + LoLoop.BlockMasksToRecompute.insert(MIB.getInstr()); } } else if (Block.IsOnlyPredicatedOn(LoLoop.VCTP)) { // A vpt block which is only predicated upon vctp and has no internal vpr @@ -1288,6 +1289,11 @@ LLVM_DEBUG(dbgs() << "ARM Loops: Erasing " << *I); I->eraseFromParent(); } + for (auto *I : LoLoop.BlockMasksToRecompute) { + LLVM_DEBUG(dbgs() << "ARM Loops: Recomputing VPT/VPST Block Mask: " << *I); + recomputeVPTBlockMask(*I); + LLVM_DEBUG(dbgs() << " ... done: " << *I); + } } PostOrderLoopTraversal DFS(LoLoop.ML, *MLI); diff --git a/llvm/lib/Target/ARM/MVEVPTBlockPass.cpp b/llvm/lib/Target/ARM/MVEVPTBlockPass.cpp --- a/llvm/lib/Target/ARM/MVEVPTBlockPass.cpp +++ b/llvm/lib/Target/ARM/MVEVPTBlockPass.cpp @@ -94,37 +94,6 @@ return &*CmpMI; } -static ARM::PredBlockMask ExpandBlockMask(ARM::PredBlockMask BlockMask, - ARMVCC::VPTCodes Kind) { - using PredBlockMask = ARM::PredBlockMask; - assert(Kind != ARMVCC::None && "Cannot expand mask with 'None'"); - assert(countTrailingZeros((unsigned)BlockMask) != 0 && - "Mask is already full"); - - auto ChooseMask = [&](PredBlockMask AddedThen, PredBlockMask AddedElse) { - return (Kind == ARMVCC::Then) ? 
AddedThen : AddedElse; - }; - - switch (BlockMask) { - case PredBlockMask::T: - return ChooseMask(PredBlockMask::TT, PredBlockMask::TE); - case PredBlockMask::TT: - return ChooseMask(PredBlockMask::TTT, PredBlockMask::TTE); - case PredBlockMask::TE: - return ChooseMask(PredBlockMask::TET, PredBlockMask::TEE); - case PredBlockMask::TTT: - return ChooseMask(PredBlockMask::TTTT, PredBlockMask::TTTE); - case PredBlockMask::TTE: - return ChooseMask(PredBlockMask::TTET, PredBlockMask::TTEE); - case PredBlockMask::TET: - return ChooseMask(PredBlockMask::TETT, PredBlockMask::TETE); - case PredBlockMask::TEE: - return ChooseMask(PredBlockMask::TEET, PredBlockMask::TEEE); - default: - llvm_unreachable("Unknown Mask"); - } -} - // Advances Iter past a block of predicated instructions. // Returns true if it successfully skipped the whole block of predicated // instructions. Returns false when it stopped early (due to MaxSteps), or if @@ -162,6 +131,22 @@ return false; } +// Creates a T, TT, TTT or TTTT BlockMask for a BlockSize of 1, 2, 3 or 4. +static ARM::PredBlockMask GetInitialBlockMask(unsigned BlockSize) { + switch (BlockSize) { + case 1: + return ARM::PredBlockMask::T; + case 2: + return ARM::PredBlockMask::TT; + case 3: + return ARM::PredBlockMask::TTT; + case 4: + return ARM::PredBlockMask::TTTT; + default: + llvm_unreachable("Invalid BlockSize!"); + } +} + // Given an iterator (Iter) that points at an instruction with a "Then" // predicate, tries to create the largest block of continuous predicated // instructions possible, and returns the VPT Block Mask of that block. @@ -190,7 +175,7 @@ }); // Generate the initial BlockMask - ARM::PredBlockMask BlockMask = getARMVPTBlockMask(BlockSize); + ARM::PredBlockMask BlockMask = GetInitialBlockMask(BlockSize); // Remove VPNOTs while there's still room in the block, so we can make the // largest block possible. 
@@ -232,7 +217,7 @@ // Change the predicate and update the mask Iter->getOperand(OpIdx).setImm(CurrentPredicate); - BlockMask = ExpandBlockMask(BlockMask, CurrentPredicate); + BlockMask = expandPredBlockMask(BlockMask, CurrentPredicate); LLVM_DEBUG(dbgs() << " adding : "; Iter->dump()); } diff --git a/llvm/lib/Target/ARM/Thumb2InstrInfo.h b/llvm/lib/Target/ARM/Thumb2InstrInfo.h --- a/llvm/lib/Target/ARM/Thumb2InstrInfo.h +++ b/llvm/lib/Target/ARM/Thumb2InstrInfo.h @@ -78,6 +78,13 @@ Register PredReg; return getVPTInstrPredicate(MI, PredReg); } + +// Recomputes the Block Mask of Instr, a VPT or VPST instruction. +// This rebuilds the block mask of the instruction depending on the predicates +// of the instructions following it. This should only be used after the +// MVEVPTBlockInsertion pass has run, and should be used whenever a predicated +// instruction is added to/removed from the block. +void recomputeVPTBlockMask(MachineInstr &Instr); } // namespace llvm #endif diff --git a/llvm/lib/Target/ARM/Thumb2InstrInfo.cpp b/llvm/lib/Target/ARM/Thumb2InstrInfo.cpp --- a/llvm/lib/Target/ARM/Thumb2InstrInfo.cpp +++ b/llvm/lib/Target/ARM/Thumb2InstrInfo.cpp @@ -737,3 +737,35 @@ PredReg = MI.getOperand(PIdx+1).getReg(); return (ARMVCC::VPTCodes)MI.getOperand(PIdx).getImm(); } + +void llvm::recomputeVPTBlockMask(MachineInstr &Instr) { + assert(isVPTOpcode(Instr.getOpcode()) && "Not a VPST or VPT Instruction!"); + + MachineOperand &MaskOp = Instr.getOperand(0); + assert(MaskOp.isImm() && "Operand 0 is not the block mask of the VPT/VPST?!"); + + MachineBasicBlock::iterator Iter = ++Instr.getIterator(), + End = Instr.getParent()->end(); + + // Verify that the instruction after the VPT/VPST is predicated (it should + // be), and skip it. The predicate is queried inside the assert so that
+ // builds without assertions do not see an unused variable. + assert(Iter != End && "Expected a predicated instruction after VPT/VPST!"); + assert( + getVPTInstrPredicate(*Iter) == ARMVCC::Then && + "VPT/VPST should be followed by an instruction with a 'then' predicate!"); + ++Iter; + + // Iterate over the predicated instructions, updating the BlockMask as we go. + ARM::PredBlockMask BlockMask = ARM::PredBlockMask::T; + while (Iter != End) { + ARMVCC::VPTCodes Pred = getVPTInstrPredicate(*Iter); + if (Pred == ARMVCC::None) + break; + BlockMask = expandPredBlockMask(BlockMask, Pred); + ++Iter; + } + + // Rewrite the BlockMask. + MaskOp.setImm((int64_t)(BlockMask)); +} diff --git a/llvm/lib/Target/ARM/Utils/ARMBaseInfo.h b/llvm/lib/Target/ARM/Utils/ARMBaseInfo.h --- a/llvm/lib/Target/ARM/Utils/ARMBaseInfo.h +++ b/llvm/lib/Target/ARM/Utils/ARMBaseInfo.h @@ -121,6 +121,12 @@ }; } // namespace ARM +// Expands a PredBlockMask by adding an E or a T at the end, depending on Kind. +// e.g. expandPredBlockMask(T, Then) = TT, expandPredBlockMask(TT, Else) = TTE, +// and so on. +ARM::PredBlockMask expandPredBlockMask(ARM::PredBlockMask BlockMask, + ARMVCC::VPTCodes Kind); + inline static const char *ARMVPTPredToString(ARMVCC::VPTCodes CC) { switch (CC) { case ARMVCC::None: return "none"; diff --git a/llvm/lib/Target/ARM/Utils/ARMBaseInfo.cpp b/llvm/lib/Target/ARM/Utils/ARMBaseInfo.cpp --- a/llvm/lib/Target/ARM/Utils/ARMBaseInfo.cpp +++ b/llvm/lib/Target/ARM/Utils/ARMBaseInfo.cpp @@ -15,6 +15,37 @@ using namespace llvm; namespace llvm { +ARM::PredBlockMask expandPredBlockMask(ARM::PredBlockMask BlockMask, + ARMVCC::VPTCodes Kind) { + using PredBlockMask = ARM::PredBlockMask; + assert(Kind != ARMVCC::None && "Cannot expand a mask with None!"); + assert(countTrailingZeros((unsigned)BlockMask) != 0 && + "Mask is already full"); + + auto ChooseMask = [&](PredBlockMask AddedThen, PredBlockMask AddedElse) { + return Kind == ARMVCC::Then ? 
AddedThen : AddedElse; + }; + + switch (BlockMask) { + case PredBlockMask::T: + return ChooseMask(PredBlockMask::TT, PredBlockMask::TE); + case PredBlockMask::TT: + return ChooseMask(PredBlockMask::TTT, PredBlockMask::TTE); + case PredBlockMask::TE: + return ChooseMask(PredBlockMask::TET, PredBlockMask::TEE); + case PredBlockMask::TTT: + return ChooseMask(PredBlockMask::TTTT, PredBlockMask::TTTE); + case PredBlockMask::TTE: + return ChooseMask(PredBlockMask::TTET, PredBlockMask::TTEE); + case PredBlockMask::TET: + return ChooseMask(PredBlockMask::TETT, PredBlockMask::TETE); + case PredBlockMask::TEE: + return ChooseMask(PredBlockMask::TEET, PredBlockMask::TEEE); + default: + llvm_unreachable("Unknown Mask"); + } +} + namespace ARMSysReg { // lookup system register using 12-bit SYSm value.