diff --git a/llvm/lib/Target/ARM/MVETailPredication.cpp b/llvm/lib/Target/ARM/MVETailPredication.cpp --- a/llvm/lib/Target/ARM/MVETailPredication.cpp +++ b/llvm/lib/Target/ARM/MVETailPredication.cpp @@ -8,8 +8,17 @@ // /// \file /// Armv8.1m introduced MVE, M-Profile Vector Extension, and low-overhead -/// branches to help accelerate DSP applications. These two extensions can be -/// combined to provide implicit vector predication within a low-overhead loop. +/// branches to help accelerate DSP applications. These two extensions, +/// combined with a new form of predication called tail-predication, can be used +/// to provide implicit vector predication within a low-overhead loop. +/// This is implicit because the predicate of active/inactive lanes is +/// calculated by hardware, and thus does not need to be explicitly passed +/// to vector instructions. The instructions responsible for this are the +/// DLSTP and WLSTP instructions, which set up a tail-predicated loop and the +/// total number of data elements processed by the loop. The loop-end +/// LETP instruction is responsible for decrementing and setting the remaining +/// elements to be processed and generating the mask of active lanes. +/// /// The HardwareLoops pass inserts intrinsics identifying loops that the /// backend will attempt to convert into a low-overhead loop. The vectorizer is /// responsible for generating a vectorized loop in which the lanes are @@ -21,10 +30,16 @@ /// - A loop containing multiple VCPT instructions, predicating multiple VPT /// blocks of instructions operating on different vector types. /// -/// This pass inserts the inserts the VCTP intrinsic to represent the effect of -/// tail predication. This will be picked up by the ARM Low-overhead loop pass, -/// which performs the final transformation to a DLSTP or WLSTP tail-predicated -/// loop. +/// This pass: +/// 1) Pattern matches the scalar iteration count produced by the vectoriser. 
+/// The scalar loop iteration count represents the number of elements to be +/// processed. +/// TODO: this could be emitted using an intrinsic, similar to the hardware +/// loop intrinsics, so that we don't need to pattern match this here. +/// 2) Inserts the VCTP intrinsic to represent the effect of +/// tail predication. This will be picked up by the ARM Low-overhead loop +/// pass, which performs the final transformation to a DLSTP or WLSTP +/// tail-predicated loop. #include "ARM.h" #include "ARMSubtarget.h" @@ -58,16 +73,17 @@ // Bookkeeping for pattern matching the loop trip count and the number of // elements processed by the loop. struct TripCountPattern { - // The Predicate used by the masked loads/stores, i.e. an icmp instruction - // which calculates active/inactive lanes + // An icmp instruction that calculates a predicate of active/inactive lanes + // used by the masked loads/stores. Instruction *Predicate = nullptr; - // The add instruction that increments the IV + // The add instruction that increments the IV. Value *TripCount = nullptr; // The number of elements processed by the vector loop. Value *NumElements = nullptr; + // Other instructions in the icmp chain that calculate the predicate. VectorType *VecTy = nullptr; Instruction *Shuffle = nullptr; Instruction *Induction = nullptr; @@ -117,8 +133,9 @@ /// loop will process if it is a runtime value. bool ComputeRuntimeElements(TripCountPattern &TCP); - /// Is the icmp that generates an i1 vector, based upon a loop counter - /// and a limit that is defined outside the loop. + /// Return whether this is the icmp that generates an i1 vector, based + /// upon a loop counter and a limit that is defined outside the loop, + /// that generates the active/inactive lanes required for tail-predication. bool isTailPredicate(TripCountPattern &TCP); /// Insert the intrinsic to represent the effect of tail predication. 
@@ -241,6 +258,7 @@ return true; } + LLVM_DEBUG(dbgs() << "ARM TP: Can't tail-predicate this loop.\n"); return false; } @@ -563,10 +581,10 @@ if (I->hasNUsesOrMore(1)) continue; - for (auto &U : I->operands()) { + for (auto &U : I->operands()) if (auto *OpI = dyn_cast(U)) MaybeDead.insert(OpI); - } + I->dropAllReferences(); Dead.insert(I); } @@ -638,30 +656,47 @@ SetVector Predicates; DenseMap NewPredicates; +#ifndef NDEBUG + // For debugging purposes, use this to indicate we have been able to + // pattern match the scalar loop trip count. + bool FoundScalarTC = false; +#endif + for (auto *I : MaskedInsts) { Intrinsic::ID ID = I->getIntrinsicID(); + // First, find the icmp used by this masked load/store. unsigned PredOp = ID == Intrinsic::masked_load ? 2 : 3; auto *Predicate = dyn_cast(I->getArgOperand(PredOp)); if (!Predicate || Predicates.count(Predicate)) continue; + // Step 1: using this icmp, now calculate the number of elements + // processed by this loop. TripCountPattern TCP(Predicate, TripCount, getVectorType(I)); - if (!(ComputeConstElements(TCP) || ComputeRuntimeElements(TCP))) continue; + LLVM_DEBUG(FoundScalarTC = true); + if (!isTailPredicate(TCP)) { - LLVM_DEBUG(dbgs() << "ARM TP: Not tail predicate: " << *Predicate << "\n"); + LLVM_DEBUG(dbgs() << "ARM TP: Not an icmp that generates tail predicate: " + << *Predicate << "\n"); continue; } - LLVM_DEBUG(dbgs() << "ARM TP: Found tail predicate: " << *Predicate << "\n"); + LLVM_DEBUG(dbgs() << "ARM TP: Found icmp generating tail predicate: " + << *Predicate << "\n"); Predicates.insert(Predicate); + + // Step 2: emit the VCTP intrinsic representing the effect of TP. InsertVCTPIntrinsic(TCP, NewPredicates); } - if (!NewPredicates.size()) + if (!NewPredicates.size()) { + LLVM_DEBUG(if (!FoundScalarTC) + dbgs() << "ARM TP: Can't determine loop itertion count\n"); return false; + } // Now clean up. ClonedVCTPInExitBlock = Cleanup(NewPredicates, Predicates, L);