Index: .gitignore =================================================================== --- .gitignore +++ .gitignore @@ -20,6 +20,7 @@ .sw? #OS X specific files. .DS_store +._* # Nested build directory /build Index: CMakeLists.txt =================================================================== --- CMakeLists.txt +++ CMakeLists.txt @@ -603,6 +603,11 @@ LLVM_TARGETS_TO_BUILD "${LLVM_TARGETS_TO_BUILD}") list(REMOVE_DUPLICATES LLVM_TARGETS_TO_BUILD) +if ("AArch64" IN_LIST LLVM_TARGETS_TO_BUILD) + set(LINK_AARCH64_INTO_TOOLS ON) +else() + set(LINK_AARCH64_INTO_TOOLS OFF) +endif() # By default, we target the host, but this can be overridden at CMake # invocation time. set(LLVM_DEFAULT_TARGET_TRIPLE "${LLVM_HOST_TRIPLE}" CACHE STRING Index: docs/LangRef.rst =================================================================== --- docs/LangRef.rst +++ docs/LangRef.rst @@ -4406,6 +4406,9 @@ !1 = !DISubrange(count: 5, lowerBound: 1) ; array counting from 1 !2 = !DISubrange(count: -1) ; empty array. + !3 = !DIExpression(DW_OP_constu, 42) + !4 = !DISubrange(count !3, lowerBound: 1) ; array counting from 1 to 43 + ; Scopes used in rest of example !6 = !DIFile(filename: "vla.c", directory: "/path/to/file") !7 = distinct !DICompileUnit(language: DW_LANG_C99, ... Index: include/llvm-c/Core.h =================================================================== --- include/llvm-c/Core.h +++ include/llvm-c/Core.h @@ -154,7 +154,9 @@ LLVMVectorTypeKind, /**< SIMD 'packed' format, or other vector type */ LLVMMetadataTypeKind, /**< Metadata */ LLVMX86_MMXTypeKind, /**< X86 MMX */ - LLVMTokenTypeKind /**< Tokens */ + LLVMTokenTypeKind, /**< Tokens */ + LLVMSVEVecTypeKind, /**< SVE Vector */ + LLVMSVEPredTypeKind /**< SVE Predicate */ } LLVMTypeKind; typedef enum { @@ -273,6 +275,8 @@ LLVMInlineAsmValueKind, LLVMInstructionValueKind, + LLVMVScaleValueKind, + LLVMStepVectorValueKind, } LLVMValueKind; typedef enum { @@ -1411,7 +1415,9 @@ macro(GlobalObject) \ macro(Function) \ macro(GlobalVariable) \ + macro(StepVector) \ macro(UndefValue) \ + macro(VScale) \ macro(Instruction) \ macro(BinaryOperator) \ macro(CallInst) \ Index: include/llvm-c/DebugInfo.h =================================================================== --- include/llvm-c/DebugInfo.h +++ include/llvm-c/DebugInfo.h @@ -979,6 +979,7 @@ LLVMBool LocalToUnit, LLVMMetadataRef Expr, LLVMMetadataRef Decl, + LLVMDIFlags Flags, uint32_t AlignInBits); /** * Create a new temporary \c MDNode. Suitable for use in constructing cyclic @@ -1035,6 +1036,7 @@ LLVMMetadataRef Ty, LLVMBool LocalToUnit, LLVMMetadataRef Decl, + LLVMDIFlags Flags, uint32_t AlignInBits); /** Index: include/llvm/ADT/DenseMapInfo.h =================================================================== --- include/llvm/ADT/DenseMapInfo.h +++ include/llvm/ADT/DenseMapInfo.h @@ -22,6 +22,7 @@ #include #include #include +#include namespace llvm { @@ -206,6 +207,45 @@ } }; + +// Provide DenseMapInfo for tree member tuples. 
+template +struct DenseMapInfo > { + typedef std::tuple Tuple; + typedef DenseMapInfo FirstInfo; + typedef DenseMapInfo SecondInfo; + typedef DenseMapInfo ThirdInfo; + + static inline Tuple getEmptyKey() { + return std::make_tuple(FirstInfo::getEmptyKey(), + SecondInfo::getEmptyKey(), + ThirdInfo::getEmptyKey()); + } + static inline Tuple getTombstoneKey() { + return std::make_tuple(FirstInfo::getTombstoneKey(), + SecondInfo::getTombstoneKey(), + ThirdInfo::getTombstoneKey()); + } + static unsigned getHashValue(const Tuple& TupleVal) { + uint64_t key = (uint64_t)FirstInfo::getHashValue(std::get<0>(TupleVal)) << 32 + | (uint64_t)SecondInfo::getHashValue(std::get<1>(TupleVal)); + // TODO: not sure what to do about the third member, + // for the current usage it does not offer anything useful + key += ~(key << 32); + key ^= (key >> 22); + key += ~(key << 13); + key ^= (key >> 8); + key += (key << 3); + key ^= (key >> 15); + key += ~(key << 27); + key ^= (key >> 31); + return (unsigned)key; + } + static bool isEqual(const Tuple &LHS, const Tuple &RHS) { + return LHS == RHS; + } +}; + // Provide DenseMapInfo for StringRefs. template <> struct DenseMapInfo { static inline StringRef getEmptyKey() { Index: include/llvm/Analysis/AliasSetTracker.h =================================================================== --- include/llvm/Analysis/AliasSetTracker.h +++ include/llvm/Analysis/AliasSetTracker.h @@ -43,6 +43,7 @@ class StoreInst; class VAArgInst; class Value; +class IntrinsicInst; class AliasSet : public ilist_node { friend class AliasSetTracker; @@ -219,6 +220,7 @@ iterator begin() const { return iterator(PtrList); } iterator end() const { return iterator(); } bool empty() const { return PtrList == nullptr; } + unsigned getRefCount() const { return RefCount; } // Unfortunately, ilist::size() is linear, so we have to add code to keep // track of the list's exact size. @@ -367,6 +369,7 @@ void add(LoadInst *LI); void add(StoreInst *SI); void add(VAArgInst *VAAI); + void add(IntrinsicInst *I, bool IsWrite); void add(AnyMemSetInst *MSI); void add(AnyMemTransferInst *MTI); void add(Instruction *I); // Dispatch to one of the other add methods... Index: include/llvm/Analysis/ConstantFolding.h =================================================================== --- include/llvm/Analysis/ConstantFolding.h +++ include/llvm/Analysis/ConstantFolding.h @@ -77,7 +77,8 @@ /// operands. If it fails, it returns a constant expression of the specified /// operands. Constant *ConstantFoldBinaryOpOperands(unsigned Opcode, Constant *LHS, - Constant *RHS, const DataLayout &DL); + Constant *RHS, const DataLayout &DL, + bool HasNUW = false, bool HasNSW=false); /// Attempt to constant fold a select instruction with the specified /// operands. The constant result is returned if successful; if not, null is Index: include/llvm/Analysis/InstructionSimplify.h =================================================================== --- include/llvm/Analysis/InstructionSimplify.h +++ include/llvm/Analysis/InstructionSimplify.h @@ -178,7 +178,7 @@ const SimplifyQuery &Q); /// Given operands for a ShuffleVectorInst, fold the result or return null. -Value *SimplifyShuffleVectorInst(Value *Op0, Value *Op1, Constant *Mask, +Value *SimplifyShuffleVectorInst(Value *Op0, Value *Op1, Value *Mask, Type *RetTy, const SimplifyQuery &Q); //=== Helper functions for higher up the class hierarchy. 
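A quick usage sketch for the three-member tuple DenseMapInfo specialization added above; the concrete member types here are illustrative assumptions, not taken from the patch (any types with their own DenseMapInfo, such as pointers and unsigned, will do):

    #include "llvm/ADT/DenseMap.h"
    #include "llvm/IR/Value.h"
    #include <tuple>
    using namespace llvm;

    // With DenseMapInfo<std::tuple<T, U, V>> available, a tuple can serve
    // directly as a DenseMap key.
    using AccessKey = std::tuple<const Value *, const Value *, unsigned>;
    DenseMap<AccessKey, unsigned> AccessOrder;

    void recordAccess(const Value *Ptr, const Value *Ctx, unsigned Flags,
                      unsigned Idx) {
      AccessOrder[std::make_tuple(Ptr, Ctx, Flags)] = Idx;
    }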
Index: include/llvm/Analysis/Loads.h =================================================================== --- include/llvm/Analysis/Loads.h +++ include/llvm/Analysis/Loads.h @@ -98,13 +98,20 @@ unsigned *NumScanedInst = nullptr); /// Scan backwards to see if we have the value of the given pointer available -/// locally within a small number of instructions. +/// locally within a small number of instructions. This method comes in a +/// 'masked' and an 'unmasked' form, where the latter is the simplified form +/// and just calls the 'masked' form with an 'all true' mask, and an 'undef' +/// passthru value. /// /// You can use this function to scan across multiple blocks: after you call /// this function, if ScanFrom points at the beginning of the block, it's safe /// to continue scanning the predecessors. /// /// \param Ptr The pointer we want the load and store to originate from. +/// \param Mask The mask used for the load or store, or nullptr if the operation +/// is unmasked. +/// \param Passthru The passthru value specifying the expected 'false' lanes in +/// the vector. /// \param AccessTy The access type of the pointer. /// \param AtLeastAtomic Are we looking for at-least an atomic load/store ? In /// case it is false, we can return an atomic or non-atomic load or store. In @@ -120,11 +127,19 @@ /// location in memory, as opposed to the value operand of a store. /// /// \returns The found value, or nullptr if no value is found. +Value *FindAvailablePtrMaskedLoadStore(Value *Ptr, Value *Mask, Value *Passthru, + Type *AccessTy, bool AtLeastAtomic, + BasicBlock *ScanBB, + BasicBlock::iterator &ScanFrom, + unsigned MaxInstsToScan, + AliasAnalysis *AA, bool *IsLoad, + unsigned *NumScanedInst); + Value *FindAvailablePtrLoadStore(Value *Ptr, Type *AccessTy, bool AtLeastAtomic, BasicBlock *ScanBB, BasicBlock::iterator &ScanFrom, unsigned MaxInstsToScan, AliasAnalysis *AA, bool *IsLoad, unsigned *NumScanedInst); -} +} // namespace llvm #endif Index: include/llvm/Analysis/LoopAccessAnalysis.h =================================================================== --- include/llvm/Analysis/LoopAccessAnalysis.h +++ include/llvm/Analysis/LoopAccessAnalysis.h @@ -24,6 +24,7 @@ #include "llvm/Analysis/ScalarEvolutionExpressions.h" #include "llvm/IR/DiagnosticInfo.h" #include "llvm/IR/ValueHandle.h" +#include "llvm/IR/IntrinsicInst.h" #include "llvm/Pass.h" #include "llvm/Support/raw_ostream.h" @@ -146,7 +147,8 @@ Instruction *getDestination(const LoopAccessInfo &LAI) const; /// Dependence types that don't prevent vectorization. - static bool isSafeForVectorization(DepType Type); + static bool isSafeForVectorization(DepType Type, + const TargetTransformInfo *TTI); /// Lexically forward dependence. bool isForward() const; @@ -162,9 +164,11 @@ const SmallVectorImpl &Instrs) const; }; - MemoryDepChecker(PredicatedScalarEvolution &PSE, const Loop *L) - : PSE(PSE), InnermostLoop(L), AccessIdx(0), MaxSafeRegisterWidth(-1U), - ShouldRetryWithRuntimeCheck(false), SafeForVectorization(true), + MemoryDepChecker(PredicatedScalarEvolution &PSE, const Loop *L, + const TargetTransformInfo *TTI) + : PSE(PSE), InnermostLoop(L), TTI(TTI), AccessIdx(0), + MaxSafeRegisterWidth(-1U), ShouldRetryWithRuntimeCheck(false), + RuntimeChecksFeasible(true), SafeForVectorization(true), RecordDependences(true) {} /// Register the location (instructions are given increasing numbers) @@ -185,7 +189,17 @@ ++AccessIdx; } - /// Check whether the dependencies between the accesses are safe. 
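For the FindAvailablePtrMaskedLoadStore entry point added to Loads.h above, the header comment says the unmasked form simply delegates with an "all true" mask and an "undef" passthru. A minimal sketch of that relationship; passing nullptr to mean "unmasked" is an assumption of the sketch, not necessarily what the implementation does:

    #include "llvm/Analysis/Loads.h"
    using namespace llvm;

    Value *FindAvailablePtrLoadStore(Value *Ptr, Type *AccessTy,
                                     bool AtLeastAtomic, BasicBlock *ScanBB,
                                     BasicBlock::iterator &ScanFrom,
                                     unsigned MaxInstsToScan, AliasAnalysis *AA,
                                     bool *IsLoad, unsigned *NumScanedInst) {
      // No mask/passthru supplied: treat every lane as active.
      return FindAvailablePtrMaskedLoadStore(Ptr, /*Mask=*/nullptr,
                                             /*Passthru=*/nullptr, AccessTy,
                                             AtLeastAtomic, ScanBB, ScanFrom,
                                             MaxInstsToScan, AA, IsLoad,
                                             NumScanedInst);
    }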
+ /// Register the location (instructions are given increasing numbers) + /// of a write access. + void addAccess(MemSetInst *MSI) { + Value *Ptr = MSI->getRawDest(); + Accesses[MemAccessInfo(Ptr, false)].push_back(AccessIdx); + InstMap.push_back(MSI); + ++AccessIdx; + } + + /// Check whether the dependencies between the accesses are safe and whether + /// runtime checks are feasible. /// /// Only checks sets with elements in \p CheckDeps. bool areDepsSafe(DepCandidates &AccessSets, MemAccessInfoList &CheckDeps, @@ -199,13 +213,19 @@ /// the accesses safely with. uint64_t getMaxSafeDepDistBytes() { return MaxSafeDepDistBytes; } + /// The maximum number of bytes we can vectorize without introducing + /// issues for hardware store->load forwarding. + unsigned getMaxDepDistBytesWithSLF() { return MaxDepDistWithSLF; } + /// Return the number of elements that are safe to operate on /// simultaneously, multiplied by the size of the element in bits. uint64_t getMaxSafeRegisterWidth() const { return MaxSafeRegisterWidth; } /// In same cases when the dependency check fails we can still /// vectorize the loop with a dynamic array access check. - bool shouldRetryWithRuntimeCheck() { return ShouldRetryWithRuntimeCheck; } + bool shouldRetryWithRuntimeCheck() { + return RuntimeChecksFeasible && ShouldRetryWithRuntimeCheck; + } /// Returns the memory dependences. If null is returned we exceeded /// the MaxDependences threshold and this information is not @@ -237,6 +257,10 @@ SmallVector getInstructionsForAccess(Value *Ptr, bool isWrite) const; + const SmallVector &getUnsafeDependences() const { + return UnsafeDependences; + } + private: /// A wrapper around ScalarEvolution, used to add runtime SCEV checks, and /// applies dynamic knowledge to simplify SCEV expressions and convert them @@ -246,6 +270,7 @@ /// that a memory access is strided and doesn't wrap. PredicatedScalarEvolution &PSE; const Loop *InnermostLoop; + const TargetTransformInfo *TTI; /// Maps access locations (ptr, read/write) to program order. DenseMap > Accesses; @@ -259,6 +284,10 @@ // We can access this many bytes in parallel safely. uint64_t MaxSafeDepDistBytes; + // Maximum safe VF which should not introduce problems for h/w store->load + // forwarding. + unsigned MaxDepDistWithSLF; + /// Number of elements (from consecutive iterations) that are safe to /// operate on simultaneously, multiplied by the size of the element in bits. /// The size of the element is taken from the memory access that is most @@ -269,6 +298,10 @@ /// vectorize this loop with runtime checks. bool ShouldRetryWithRuntimeCheck; + /// If we see a non-unknown unsafe dependence, there is no point in generating + /// runtime checks. + bool RuntimeChecksFeasible; + /// No memory dependence was encountered that would inhibit /// vectorization. bool SafeForVectorization; @@ -282,6 +315,11 @@ /// RecordDependences is true. SmallVector Dependences; + /// Unsafe memory dependences collected during the analysis + // + // Used by for OptRemark generation + SmallVector UnsafeDependences; + /// Check whether there is a plausible dependence between the two /// accesses. /// @@ -337,7 +375,8 @@ AliasSetId(AliasSetId), Expr(Expr) {} }; - RuntimePointerChecking(ScalarEvolution *SE) : Need(false), SE(SE) {} + RuntimePointerChecking(ScalarEvolution *SE) : Need(false), Strided(false), + SE(SE) {} /// Reset the state of the pointer runtime information. void reset() { @@ -426,6 +465,9 @@ /// This flag indicates if we need to add the runtime check. 
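The new MemoryDepChecker queries above (getMaxDepDistBytesWithSLF, the RuntimeChecksFeasible gate on shouldRetryWithRuntimeCheck) feed the vectorizer's width selection. A hedged sketch of one way a client could clamp the vectorization factor to respect hardware store->load forwarding; the sentinel values treated as "no constraint" are assumptions:

    #include "llvm/Analysis/LoopAccessAnalysis.h"
    #include <algorithm>
    using namespace llvm;

    unsigned clampVFForStoreLoadForwarding(MemoryDepChecker &DepChecker,
                                           unsigned VF,
                                           unsigned EltSizeInBytes) {
      unsigned SLFLimit = DepChecker.getMaxDepDistBytesWithSLF();
      if (SLFLimit == 0 || SLFLimit == ~0U)
        return VF;                 // assumed: no forwarding constraint recorded
      return std::min(VF, std::max(1U, SLFLimit / EltSizeInBytes));
    }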
bool Need; + /// This flag indicates if the pointer accesses are strided. + bool Strided; + /// Information about the pointers that may require checking. SmallVector Pointers; @@ -492,13 +534,27 @@ /// PSE must be emitted in order for the results of this analysis to be valid. class LoopAccessInfo { public: + /// Reasons why memory accesses cannot be vectorized (used for OptRemarks) + enum class FailureReason { + UnsafeDataDependence, + UnsafeDataDependenceTriedRT, + UnknownArrayBounds, + Unknown + }; +public: LoopAccessInfo(Loop *L, ScalarEvolution *SE, const TargetLibraryInfo *TLI, - AliasAnalysis *AA, DominatorTree *DT, LoopInfo *LI); + const TargetTransformInfo *TTI, AliasAnalysis *AA, + DominatorTree *DT, LoopInfo *LI); /// Return true we can analyze the memory accesses in the loop and there are /// no memory dependence cycles. bool canVectorizeMemory() const { return CanVecMem; } + /// Return reason describing why memory access cannot be vectorized. + // + // Used for the OptRemark generation. + FailureReason getFailureReason() const { return FailReason; } + const RuntimePointerChecking *getRuntimePointerChecking() const { return PtrRtChecking.get(); } @@ -517,7 +573,7 @@ /// Returns true if the value V is uniform within the loop. bool isUniform(Value *V) const; - uint64_t getMaxSafeDepDistBytes() const { return MaxSafeDepDistBytes; } + uint64_t getMaxSafeDepDistBytes() const; unsigned getNumStores() const { return NumStores; } unsigned getNumLoads() const { return NumLoads;} @@ -568,7 +624,15 @@ /// If the loop has any store to invariant address, then it returns true, /// else returns false. bool hasStoreToLoopInvariantAddress() const { - return StoreToLoopInvariantAddress; + return !InvariantStores.empty(); + } + + const SmallVectorImpl &getInvariantStores() const { + return InvariantStores; + } + + const SmallPtrSet &getUncomputablePtrs() const { + return UncomputablePtrs; } /// Used to add runtime SCEV checks. Simplifies SCEV expressions and converts @@ -612,18 +676,21 @@ std::unique_ptr DepChecker; Loop *TheLoop; + const TargetTransformInfo *TTI; unsigned NumLoads; unsigned NumStores; uint64_t MaxSafeDepDistBytes; + uint64_t MaxDepDistBytesWithSLF; /// Cache the result of analyzeLoop. bool CanVecMem; - /// Indicator for storing to uniform addresses. - /// If a loop has write to a loop invariant address then it should be true. - bool StoreToLoopInvariantAddress; + /// \brief List of stores to uniform addresses. + SmallVector InvariantStores; + /// \brief List of stores to uniform addresses. + SmallVector InvariantMemSets; /// The diagnostics report generated for the analysis. E.g. why we /// couldn't analyze the loop. @@ -635,6 +702,17 @@ /// Set of symbolic strides values. SmallPtrSet StrideSet; + + /// Allows analysis of uncounted loops (trip count undefined) + bool AllowUncountedLoops; + + /// Reason why memory accesses cannot be vectorized (used for OptRemarks) + FailureReason FailReason; + + /// Set of uncomputable pointers. + // + // Used when emitting OptRemarks + SmallPtrSet UncomputablePtrs; }; Value *stripIntegerCast(Value *V); @@ -725,6 +803,7 @@ // The used analysis passes. ScalarEvolution *SE; const TargetLibraryInfo *TLI; + const TargetTransformInfo *TTI; AliasAnalysis *AA; DominatorTree *DT; LoopInfo *LI; Index: include/llvm/Analysis/LoopInfo.h =================================================================== --- include/llvm/Analysis/LoopInfo.h +++ include/llvm/Analysis/LoopInfo.h @@ -577,7 +577,8 @@ /// it returns an unknown location. 
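The FailureReason enum above exists to drive OptRemark generation. A small sketch of a consumer; the message strings are invented for illustration:

    #include "llvm/Analysis/LoopAccessAnalysis.h"
    #include "llvm/Support/ErrorHandling.h"
    using namespace llvm;

    static const char *describe(LoopAccessInfo::FailureReason R) {
      switch (R) {
      case LoopAccessInfo::FailureReason::UnsafeDataDependence:
        return "unsafe dependence between memory accesses";
      case LoopAccessInfo::FailureReason::UnsafeDataDependenceTriedRT:
        return "unsafe dependence; runtime pointer checks do not help";
      case LoopAccessInfo::FailureReason::UnknownArrayBounds:
        return "cannot compute array bounds for runtime checks";
      case LoopAccessInfo::FailureReason::Unknown:
        return "reason unknown";
      }
      llvm_unreachable("covered switch");
    }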
DebugLoc getStartLoc() const; - /// Return the source code span of the loop. + void getEarlyExitLocations(std::vector &Locs) const; + LocRange getLocRange() const; StringRef getName() const { @@ -594,6 +595,8 @@ friend class LoopBase; explicit Loop(BasicBlock *BB) : LoopBase(BB) {} ~Loop() = default; + + void getAttachedDebugLocations(std::vector &Locs) const; }; //===----------------------------------------------------------------------===// Index: include/llvm/Analysis/LoopVectorizationAnalysis.h =================================================================== --- /dev/null +++ include/llvm/Analysis/LoopVectorizationAnalysis.h @@ -0,0 +1,1164 @@ +//===- llvm/Analysis/LoopVectorizationAnalysis.h ----------------*- C++ -*-===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file defines the interface to determine whether or not a loop can +// be safely vectorized +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_ANALYSIS_LOOPVECTORIZATIONANALYSIS_H +#define LLVM_ANALYSIS_LOOPVECTORIZATIONANALYSIS_H + +// TODO: Can we remove some of these and move more functions to the .cpp file? +#include "llvm/ADT/MapVector.h" +#include "llvm/ADT/SmallSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Analysis/CodeMetrics.h" +#include "llvm/Analysis/LoopAccessAnalysis.h" +#include "llvm/Analysis/LoopIterator.h" +#include "llvm/Analysis/PostDominators.h" +#include "llvm/Analysis/VectorUtils.h" +#include "llvm/Pass.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/LVCommon.h" +#include "llvm/Transforms/Utils/LoopUtils.h" + +namespace llvm { + +class Value; +class DataLayout; +class ScalarEvolution; +class Loop; +class SCEV; + +/// Information about vectorization costs +struct VectorizationFactor { + unsigned Width; // Vector width with best cost + unsigned Cost; // Cost of the loop with that width + bool isFixed; // Is the width an absolute value or a scale. +}; + +/// Type of exit +enum LoopExitKind { + EK_Unknown, ///< Basic case for lookups of unmapped exits + EK_None, ///< Block does not exit loop. + EK_Counted, ///< Loop exit with defined trip count (comparison + /// between induction and a loop-invariant value) + EK_LoadCompare ///< Loop exit due to a comparison between a + /// loaded value and a loop-invariant value +}; + +// Forward declarations +class SLVLoopVectorizationLegality; +class SLVLoopVectorizationCostModel; +class LoopVectorizationRequirements; +class SLVLoopVectorizeHints; + + +class AssumptionCache; +class DemandedBits; +class InductionDescriptor; +class RecurrenceDescriptor; +class TargetLibraryInfo; +class TargetTransformInfo; + + // TODO: Move. +/// Maximum vectorization interleave count. +static const unsigned MaxInterleaveFactor = 16; + + +/// Helpers. TODO: Move? non-staticize? +/// A helper function for converting Scalar types to vector types. +/// If the incoming type is void, we return void. If the VF is 1, we return +/// the scalar type. 
+inline Type* ToVectorTy(Type *Scalar, unsigned VF) { + if (Scalar->isVoidTy() || VF == 1) + return Scalar; + return VectorType::get(Scalar, VF); +} +inline Type* ToVectorTy(Type *Scalar, VectorizationFactor VF) { + if (Scalar->isVoidTy() || VF.Width == 1) + return Scalar; + return VectorType::get(Scalar, VF.Width, !VF.isFixed); +} + +inline Type *smallestIntegerVectorType(Type *T1, Type *T2) { + IntegerType *I1 = cast(T1->getVectorElementType()); + IntegerType *I2 = cast(T2->getVectorElementType()); + return I1->getBitWidth() < I2->getBitWidth() ? T1 : T2; +} +inline Type *largestIntegerVectorType(Type *T1, Type *T2) { + IntegerType *I1 = cast(T1->getVectorElementType()); + IntegerType *I2 = cast(T2->getVectorElementType()); + return I1->getBitWidth() > I2->getBitWidth() ? T1 : T2; +} + +/// A helper function that returns GEP instruction and knows to skip a +/// 'bitcast'. The 'bitcast' may be skipped if the source and the destination +/// pointee types of the 'bitcast' have the same size. +/// For example: +/// bitcast double** %var to i64* - can be skipped +/// bitcast double** %var to i8* - can not +static GetElementPtrInst *getGEPInstruction(Value *Ptr) { + + if (isa(Ptr)) + return cast(Ptr); + + if (isa(Ptr) && + isa(cast(Ptr)->getOperand(0))) { + Type *BitcastTy = Ptr->getType(); + Type *GEPTy = cast(Ptr)->getSrcTy(); + if (!isa(BitcastTy) || !isa(GEPTy)) + return nullptr; + Type *Pointee1Ty = cast(BitcastTy)->getPointerElementType(); + Type *Pointee2Ty = cast(GEPTy)->getPointerElementType(); + const DataLayout &DL = cast(Ptr)->getModule()->getDataLayout(); + if (DL.getTypeSizeInBits(Pointee1Ty) == DL.getTypeSizeInBits(Pointee2Ty)) + return cast(cast(Ptr)->getOperand(0)); + } + return nullptr; +} + + +/// SLVLoopVectorizationCostModel - estimates the expected speedups due to +/// vectorization. +/// In many cases vectorization is not profitable. This can happen because of +/// a number of reasons. In this class we mainly attempt to predict the +/// expected speedup/slowdowns due to the supported instruction set. We use the +/// TargetTransformInfo to query the different backends for the cost of +/// different operations. +class SLVLoopVectorizationCostModel { +public: + SLVLoopVectorizationCostModel(Loop *L, PredicatedScalarEvolution &PSE, + LoopInfo *LI, SLVLoopVectorizationLegality *Legal, + const TargetTransformInfo &TTI, + const TargetLibraryInfo *TLI, DemandedBits *DB, + AssumptionCache *AC, const Function *F, + const SLVLoopVectorizeHints *Hints) + : TheLoop(L), PSE(PSE), LI(LI), Legal(Legal), TTI(TTI), TLI(TLI), DB(DB), + AC(AC), TheFunction(F), Hints(Hints) {} + + /// \return The most profitable vectorization factor and the cost of that VF. + /// This method checks every power of two up to VF. If UserVF is not ZERO + /// then this vectorization factor will be selected if vectorization is + /// possible. + VectorizationFactor selectVectorizationFactor(bool OptForSize); + + /// \return The size (in bits) of the smallest and widest types in the code + /// that needs to be vectorized. We ignore values that remain scalar such as + /// 64 bit loop indices. + std::pair getSmallestAndWidestTypes(); + + /// \return The desired interleave count. + /// If interleave count has been specified by metadata it will be returned. + /// Otherwise, the interleave count is computed and returned. VF and LoopCost + /// are the selected vectorization factor and the cost of the selected VF. 
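The VectorizationFactor overload of ToVectorTy above is what lets the rest of the analysis talk about scalable widths: when isFixed is false the width acts as a per-vscale element count. A small sketch; the printed type forms in the comments reflect this patch series and are otherwise an assumption:

    #include "llvm/Analysis/LoopVectorizationAnalysis.h"
    #include "llvm/IR/LLVMContext.h"
    #include "llvm/IR/Type.h"
    using namespace llvm;

    void toVectorTyExample(LLVMContext &Ctx) {
      Type *F32 = Type::getFloatTy(Ctx);

      VectorizationFactor Fixed = {/*Width=*/4, /*Cost=*/0, /*isFixed=*/true};
      Type *V4F32 = ToVectorTy(F32, Fixed);        // <4 x float>

      VectorizationFactor Scalable = {/*Width=*/4, /*Cost=*/0, /*isFixed=*/false};
      Type *NxV4F32 = ToVectorTy(F32, Scalable);   // scalable, e.g. <n x 4 x float>

      (void)V4F32;
      (void)NxV4F32;
    }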
+ unsigned selectInterleaveCount(bool OptForSize, VectorizationFactor VF, + unsigned LoopCost); + + /// \return The most profitable unroll factor. + /// This method finds the best unroll-factor based on register pressure and + /// other parameters. VF and LoopCost are the selected vectorization factor + /// and the cost of the selected VF. + unsigned computeInterleaveCount(bool OptForSize, VectorizationFactor VF, + unsigned LoopCost); + + /// \brief A struct that represents some properties of the register usage + /// of a loop. + struct RegisterUsage { + /// Holds the number of loop invariant values that are used in the loop. + unsigned LoopInvariantRegs; + /// Holds the maximum number of concurrent live intervals in the loop. + unsigned MaxLocalUsers; + /// Holds the number of instructions in the loop. + unsigned NumInstructions; + }; + + /// \return Returns information about the register usages of the loop for the + /// given vectorization factors. + SmallVector + calculateRegisterUsage(const SmallVector &VFs); + + /// Estimate the overhead of scalarizing a value. Insert and Extract are set + /// if the result needs to be inserted and/or extracted from vectors. + unsigned getScalarizationOverhead(Type *Ty, bool Insert, bool Extract, + const TargetTransformInfo &TTI); + + // Estimate cost of a call instruction CI if it were vectorized with factor VF. + // Return the cost of the instruction, including scalarization overhead if it's + // needed. The flag NeedToScalarize shows if the call needs to be scalarized - + // i.e. either vector version isn't available, or is too expensive. + unsigned getVectorCallCost(CallInst *CI, unsigned VF, + const TargetTransformInfo &TTI, + const TargetLibraryInfo *TLI, + bool &NeedToScalarize); + + // Estimate cost of an intrinsic call instruction CI if it were vectorized with + // factor VF. Return the cost of the instruction, including scalarization + // overhead if it's needed. + unsigned getVectorIntrinsicCost(CallInst *CI, unsigned VF, + const TargetTransformInfo &TTI, + const TargetLibraryInfo *TLI); + + /// Collect values we want to ignore in the cost model. + void collectValuesToIgnore(); + +private: + /// Returns the expected execution cost. The unit of the cost does + /// not matter because we use the 'cost' units to compare different + /// vector widths. The cost that is returned is *not* normalized by + /// the factor width. + unsigned expectedCost(VectorizationFactor VF); + + /// Returns the execution time cost of an instruction for a given vector + /// width. Vector width of one means scalar. + unsigned getInstructionCost(Instruction *I, VectorizationFactor VF); + + /// Returns whether the instruction is a load or store and will be a emitted + /// as a vector operation. + bool isConsecutiveLoadOrStore(Instruction *I); + +public: + /// Map of scalar integer values to the smallest bitwidth they can be legally + /// represented as. The vector equivalents of these values should be truncated + /// to this type. + MapVector MinBWs; + + /// The loop that we evaluate. + Loop *TheLoop; + /// Scev analysis. + PredicatedScalarEvolution &PSE; + /// Loop Info analysis. + LoopInfo *LI; + /// Vectorization legality. + SLVLoopVectorizationLegality *Legal; + /// Vector target information. + const TargetTransformInfo &TTI; + /// Target Library Info. + const TargetLibraryInfo *TLI; + /// Demanded bits analysis. + DemandedBits *DB; + /// Assumption cache. + AssumptionCache *AC; + const Function *TheFunction; + /// Loop Vectorize Hint. 
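A sketch of the expected call order for the cost-model interface above; the cost model CM and the OptForSize flag are assumed to come from the surrounding driver:

    #include "llvm/Analysis/LoopVectorizationAnalysis.h"
    #include <utility>
    using namespace llvm;

    std::pair<VectorizationFactor, unsigned>
    pickFactors(SLVLoopVectorizationCostModel &CM, bool OptForSize) {
      VectorizationFactor VF = CM.selectVectorizationFactor(OptForSize);
      unsigned IC = CM.selectInterleaveCount(OptForSize, VF, VF.Cost);
      return {VF, IC};
    }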
+ const SLVLoopVectorizeHints *Hints; + /// Values to ignore in the cost model. + SmallPtrSet ValuesToIgnore; + /// Values to ignore in the cost model when VF > 1. + SmallPtrSet VecValuesToIgnore; +}; + +/// Utility class for getting and setting loop vectorizer hints in the form +/// of loop metadata. +/// This class keeps a number of loop annotations locally (as member variables) +/// and can, upon request, write them back as metadata on the loop. It will +/// initially scan the loop for existing metadata, and will update the local +/// values based on information in the loop. +/// We cannot write all values to metadata, as the mere presence of some info, +/// for example 'force', means a decision has been made. So, we need to be +/// careful NOT to add them if the user hasn't specifically asked so. +class SLVLoopVectorizeHints { + enum HintKind { + HK_WIDTH, + HK_UNROLL, + HK_FORCE + }; + + /// Hint - associates name and validation with the hint value. + struct Hint { + const char * Name; + unsigned Value; // This may have to change for non-numeric values. + HintKind Kind; + + Hint(const char * Name, unsigned Value, HintKind Kind) + : Name(Name), Value(Value), Kind(Kind) { } + + bool validate(unsigned Val) { + switch (Kind) { + case HK_WIDTH: + return isPowerOf2_32(Val) && Val <= VectorizerParams::MaxVectorWidth; + case HK_UNROLL: + return isPowerOf2_32(Val) && Val <= MaxInterleaveFactor; + case HK_FORCE: + return (Val <= 1); + } + return false; + } + }; + + /// Vectorization width. + Hint Width; + /// Vectorization interleave factor. + Hint Interleave; + /// Vectorization forced + Hint Force; + + /// Return the loop metadata prefix. + static StringRef Prefix() { return "llvm.loop."; } + +public: + enum ForceKind { + FK_Undefined = -1, ///< Not selected. + FK_Disabled = 0, ///< Forcing disabled. + FK_Enabled = 1, ///< Forcing enabled. + }; + + SLVLoopVectorizeHints(const Loop *L, bool DisableInterleaving, + OptimizationRemarkEmitter &ORE); + + /// Mark the loop L as already vectorized by setting the width to 1. + void setAlreadyVectorized() { + Width.Value = Interleave.Value = 1; + Hint Hints[] = {Width, Interleave}; + writeHintsToMetadata(Hints); + } + + bool allowVectorization(Function *F, Loop *L, bool AlwaysVectorize) const; + void emitRemarkWithHints() const; + + unsigned getWidth() const { return Width.Value; } + unsigned getInterleave() const { return Interleave.Value; } + enum ForceKind getForce() const { return (ForceKind)Force.Value; } + const char *vectorizeAnalysisPassName() const; + + bool allowReordering() const { + // When enabling loop hints are provided we allow the vectorizer to change + // the order of operations that is given by the scalar loop. This is not + // enabled by default because can be unsafe or inefficient. For example, + // reordering floating-point operations will change the way round-off + // error accumulates in the loop. + return getForce() == SLVLoopVectorizeHints::FK_Enabled || getWidth() > 1; + } + +private: + /// Find hints specified in the loop metadata and update local values. + void getHintsFromMetadata(); + + /// Checks string hint with one operand and set value if valid. + void setHint(StringRef Name, Metadata *Arg); + + /// Create a new hint from name / value pair. + MDNode *createHintMetadata(StringRef Name, unsigned V) const; + + /// Matches metadata with hint name. + bool matchesHintMetadataName(MDNode *Node, ArrayRef HintTypes); + + /// Sets current hints into loop metadata, keeping other values intact. 
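A sketch of how a driver might consult SLVLoopVectorizeHints before doing any analysis work; this mirrors the upstream LoopVectorize flow and is not lifted from the patch:

    #include "llvm/Analysis/LoopVectorizationAnalysis.h"
    using namespace llvm;

    bool shouldAttemptVectorization(Function *F, Loop *L,
                                    OptimizationRemarkEmitter &ORE) {
      SLVLoopVectorizeHints Hints(L, /*DisableInterleaving=*/false, ORE);
      if (Hints.getForce() == SLVLoopVectorizeHints::FK_Disabled)
        return false;                       // explicitly disabled via metadata
      return Hints.allowVectorization(F, L, /*AlwaysVectorize=*/true);
    }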
+ void writeHintsToMetadata(ArrayRef HintTypes); + + /// The loop these hints belong to. + const Loop *TheLoop; + + /// Interface to emit optimization remarks. + OptimizationRemarkEmitter &ORE; +}; + +/// \brief This holds vectorization requirements that must be verified late in +/// the process. The requirements are set by legalize and costmodel. Once +/// vectorization has been determined to be possible and profitable the +/// requirements can be verified by looking for metadata or compiler options. +/// For example, some loops require FP commutativity which is only allowed if +/// vectorization is explicitly specified or if the fast-math compiler option +/// has been provided. +/// Late evaluation of these requirements allows helpful diagnostics to be +/// composed that tells the user what need to be done to vectorize the loop. For +/// example, by specifying #pragma clang loop vectorize or -ffast-math. Late +/// evaluation should be used only when diagnostics can generated that can be +/// followed by a non-expert user. +class LoopVectorizationRequirements { +public: + LoopVectorizationRequirements() + : NumRuntimePointerChecks(0), UnsafeAlgebraInst(nullptr) {} + + void addUnsafeAlgebraInst(Instruction *I) { + // First unsafe algebra instruction. + if (!UnsafeAlgebraInst) + UnsafeAlgebraInst = I; + } + + void addRuntimePointerChecks(unsigned Num) { NumRuntimePointerChecks = Num; } + + bool doesNotMeet(Function *F, Loop *L, const SLVLoopVectorizeHints &Hints); + +private: + unsigned NumRuntimePointerChecks; + Instruction *UnsafeAlgebraInst; +}; + +/// \brief The group of interleaved loads/stores sharing the same stride and +/// close to each other. +/// +/// Each member in this group has an index starting from 0, and the largest +/// index should be less than interleaved factor, which is equal to the absolute +/// value of the access's stride. +/// +/// E.g. An interleaved load group of factor 4: +/// for (unsigned i = 0; i < 1024; i+=4) { +/// a = A[i]; // Member of index 0 +/// b = A[i+1]; // Member of index 1 +/// d = A[i+3]; // Member of index 3 +/// ... +/// } +/// +/// An interleaved store group of factor 4: +/// for (unsigned i = 0; i < 1024; i+=4) { +/// ... +/// A[i] = a; // Member of index 0 +/// A[i+1] = b; // Member of index 1 +/// A[i+2] = c; // Member of index 2 +/// A[i+3] = d; // Member of index 3 +/// } +/// +/// Note: the interleaved load group could have gaps (missing members), but +/// the interleaved store group doesn't allow gaps. +class InterleaveGroup { +public: + InterleaveGroup(Instruction *Instr, int Stride, unsigned Align) + : Align(Align), SmallestKey(0), LargestKey(0), InsertPos(Instr) { + assert(Align && "The alignment should be non-zero"); + + Factor = std::abs(Stride); + assert(Factor > 1 && "Invalid interleave factor"); + + Reverse = Stride < 0; + Members[0] = Instr; + } + + bool isReverse() const { return Reverse; } + unsigned getFactor() const { return Factor; } + unsigned getAlignment() const { return Align; } + unsigned getNumMembers() const { return Members.size(); } + + /// \brief Try to insert a new member \p Instr with index \p Index and + /// alignment \p NewAlign. The index is related to the leader and it could be + /// negative if it is the new leader. + /// + /// \returns false if the instruction doesn't belong to the group. 
+ bool insertMember(Instruction *Instr, int Index, unsigned NewAlign) { + assert(NewAlign && "The new member's alignment should be non-zero"); + + int Key = Index + SmallestKey; + + // Skip if there is already a member with the same index. + if (Members.count(Key)) + return false; + + if (Key > LargestKey) { + // The largest index is always less than the interleave factor. + if (Index >= static_cast(Factor)) + return false; + + LargestKey = Key; + } else if (Key < SmallestKey) { + // The largest index is always less than the interleave factor. + if (LargestKey - Key >= static_cast(Factor)) + return false; + + SmallestKey = Key; + } + + // It's always safe to select the minimum alignment. + Align = std::min(Align, NewAlign); + Members[Key] = Instr; + return true; + } + + /// \brief Get the member with the given index \p Index + /// + /// \returns nullptr if contains no such member. + Instruction *getMember(unsigned Index) const { + int Key = SmallestKey + Index; + if (!Members.count(Key)) + return nullptr; + + return Members.find(Key)->second; + } + + /// \brief Get the index for the given member. Unlike the key in the member + /// map, the index starts from 0. + unsigned getIndex(Instruction *Instr) const { + for (auto I : Members) + if (I.second == Instr) + return I.first - SmallestKey; + + llvm_unreachable("InterleaveGroup contains no such member"); + } + + Instruction *getInsertPos() const { return InsertPos; } + void setInsertPos(Instruction *Inst) { InsertPos = Inst; } + +private: + unsigned Factor; // Interleave Factor. + bool Reverse; + unsigned Align; + DenseMap Members; + int SmallestKey; + int LargestKey; + + // To avoid breaking dependences, vectorized instructions of an interleave + // group should be inserted at either the first load or the last store in + // program order. + // + // E.g. %even = load i32 // Insert Position + // %add = add i32 %even // Use of %even + // %odd = load i32 + // + // store i32 %even + // %odd = add i32 // Def of %odd + // store i32 %odd // Insert Position + Instruction *InsertPos; +}; + +/// \brief Drive the analysis of interleaved memory accesses in the loop. +/// +/// Use this class to analyze interleaved accesses only when we can vectorize +/// a loop. Otherwise it's meaningless to do analysis as the vectorization +/// on interleaved accesses is unsafe. +/// +/// The analysis collects interleave groups and records the relationships +/// between the member and the group in a map. +class InterleavedAccessInfo { +public: + InterleavedAccessInfo(PredicatedScalarEvolution &PSE, Loop *L, + DominatorTree *DT, LoopInfo *LI) + : PSE(PSE), TheLoop(L), DT(DT), LI(LI), LAI(nullptr), + RequiresScalarEpilogue(false) {} + + ~InterleavedAccessInfo() { + SmallSet DelSet; + // Avoid releasing a pointer twice. + for (auto &I : InterleaveGroupMap) + DelSet.insert(I.second); + for (auto *Ptr : DelSet) + delete Ptr; + } + + /// \brief Analyze the interleaved accesses and collect them in interleave + /// groups. Substitute symbolic strides using \p Strides. + void analyzeInterleaving(const ValueToValueMap &Strides); + + /// \brief Check if \p Instr belongs to any interleave group. + bool isInterleaved(Instruction *Instr) const { + return InterleaveGroupMap.count(Instr); + } + + /// \brief Return the maximum interleave factor of all interleaved groups. 
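Tracing insertMember above with a concrete stride-4 load group; L0 and L2 stand for the member load instructions:

    #include "llvm/Analysis/LoopVectorizationAnalysis.h"
    #include <cassert>
    using namespace llvm;

    void interleaveGroupExample(Instruction *L0, Instruction *L2) {
      InterleaveGroup Group(L0, /*Stride=*/4, /*Align=*/8);  // leader at index 0
      bool Inserted = Group.insertMember(L2, /*Index=*/2, /*NewAlign=*/4);
      assert(Inserted && Group.getFactor() == 4 && !Group.isReverse());
      (void)Inserted;

      Instruction *M2 = Group.getMember(2);  // L2
      Instruction *M1 = Group.getMember(1);  // nullptr: load groups may have gaps
      (void)M2;
      (void)M1;
    }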
+ unsigned getMaxInterleaveFactor() const { + unsigned MaxFactor = 1; + for (auto &Entry : InterleaveGroupMap) + MaxFactor = std::max(MaxFactor, Entry.second->getFactor()); + return MaxFactor; + } + + /// \brief Get the interleave group that \p Instr belongs to. + /// + /// \returns nullptr if doesn't have such group. + InterleaveGroup *getInterleaveGroup(Instruction *Instr) const { + if (InterleaveGroupMap.count(Instr)) + return InterleaveGroupMap.find(Instr)->second; + return nullptr; + } + + /// \brief Returns true if an interleaved group that may access memory + /// out-of-bounds requires a scalar epilogue iteration for correctness. + bool requiresScalarEpilogue() const { return RequiresScalarEpilogue; } + + /// \brief Initialize the LoopAccessInfo used for dependence checking. + void setLAI(const LoopAccessInfo *Info) { LAI = Info; } + +private: + /// A wrapper around ScalarEvolution, used to add runtime SCEV checks. + /// Simplifies SCEV expressions in the context of existing SCEV assumptions. + /// The interleaved access analysis can also add new predicates (for example + /// by versioning strides of pointers). + PredicatedScalarEvolution &PSE; + Loop *TheLoop; + DominatorTree *DT; + LoopInfo *LI; + const LoopAccessInfo *LAI; + + /// True if the loop may contain non-reversed interleaved groups with + /// out-of-bounds accesses. We ensure we don't speculatively access memory + /// out-of-bounds by executing at least one scalar epilogue iteration. + bool RequiresScalarEpilogue; + + /// Holds the relationships between the members and the interleave group. + DenseMap InterleaveGroupMap; + + /// Holds dependences among the memory accesses in the loop. It maps a source + /// access to a set of dependent sink accesses. + DenseMap> Dependences; + + /// \brief The descriptor for a strided memory access. + struct StrideDescriptor { + StrideDescriptor(int64_t Stride, const SCEV *Scev, uint64_t Size, + unsigned Align) + : Stride(Stride), Scev(Scev), Size(Size), Align(Align) {} + + StrideDescriptor() = default; + + // The access's stride. It is negative for a reverse access. + int64_t Stride = 0; + const SCEV *Scev = nullptr; // The scalar expression of this access + uint64_t Size = 0; // The size of the memory object. + unsigned Align = 0; // The alignment of this access. + }; + + /// \brief A type for holding instructions and their stride descriptors. + typedef std::pair StrideEntry; + + /// \brief Create a new interleave group with the given instruction \p Instr, + /// stride \p Stride and alignment \p Align. + /// + /// \returns the newly created interleave group. + InterleaveGroup *createInterleaveGroup(Instruction *Instr, int Stride, + unsigned Align) { + assert(!InterleaveGroupMap.count(Instr) && + "Already in an interleaved access group"); + InterleaveGroupMap[Instr] = new InterleaveGroup(Instr, Stride, Align); + return InterleaveGroupMap[Instr]; + } + + /// \brief Release the group and remove all the relationships. + void releaseGroup(InterleaveGroup *Group) { + for (unsigned i = 0; i < Group->getFactor(); i++) + if (Instruction *Member = Group->getMember(i)) + InterleaveGroupMap.erase(Member); + + delete Group; + } + + /// \brief Collect all the accesses with a constant stride in program order. + void collectConstStrideAccesses( + MapVector &AccessStrideInfo, + const ValueToValueMap &Strides); + + /// \brief Returns true if \p Stride is allowed in an interleaved group. 
+ static bool isStrided(int Stride) { + unsigned Factor = std::abs(Stride); + return Factor >= 2 && Factor <= MaxInterleaveGroupFactor; + } + + /// \brief Returns true if \p BB is a predicated block. + bool isPredicated(BasicBlock *BB) const { + return LoopAccessInfo::blockNeedsPredication(BB, TheLoop, DT); + } + + /// \brief Returns true if LoopAccessInfo can be used for dependence queries. + bool areDependencesValid() const { + return LAI && LAI->getDepChecker().getDependences(); + } + + /// \brief Returns true if memory accesses \p A and \p B can be reordered, if + /// necessary, when constructing interleaved groups. + /// + /// \p A must precede \p B in program order. We return false if reordering is + /// not necessary or is prevented because \p A and \p B may be dependent. + bool canReorderMemAccessesForInterleavedGroups(StrideEntry *A, + StrideEntry *B) const { + + // Code motion for interleaved accesses can potentially hoist strided loads + // and sink strided stores. The code below checks the legality of the + // following two conditions: + // + // 1. Potentially moving a strided load (B) before any store (A) that + // precedes B, or + // + // 2. Potentially moving a strided store (A) after any load or store (B) + // that A precedes. + // + // It's legal to reorder A and B if we know there isn't a dependence from A + // to B. Note that this determination is conservative since some + // dependences could potentially be reordered safely. + + // A is potentially the source of a dependence. + auto *Src = A->first; + auto SrcDes = A->second; + + // B is potentially the sink of a dependence. + auto *Sink = B->first; + auto SinkDes = B->second; + + // Code motion for interleaved accesses can't violate WAR dependences. + // Thus, reordering is legal if the source isn't a write. + if (!Src->mayWriteToMemory()) + return true; + + // At least one of the accesses must be strided. + if (!isStrided(SrcDes.Stride) && !isStrided(SinkDes.Stride)) + return true; + + // If dependence information is not available from LoopAccessInfo, + // conservatively assume the instructions can't be reordered. + if (!areDependencesValid()) + return false; + + // If we know there is a dependence from source to sink, assume the + // instructions can't be reordered. Otherwise, reordering is legal. + return !Dependences.count(Src) || !Dependences.lookup(Src).count(Sink); + } + + /// \brief Collect the dependences from LoopAccessInfo. + /// + /// We process the dependences once during the interleaved access analysis to + /// enable constant-time dependence queries. + void collectDependences() { + if (!areDependencesValid()) + return; + auto *Deps = LAI->getDepChecker().getDependences(); + for (auto Dep : *Deps) + Dependences[Dep.getSource(*LAI)].insert(Dep.getDestination(*LAI)); + } +}; + +/// A class to describe an 'Escapee', which has either a single Escapee Merge +/// node, or alternatively a single Store Node. An Escapee can only be +/// constructed by the EscapeeFactory class, which collects the set of +/// Escapees in the loop. 
+class Escapee { +protected: + Escapee() {} + + template + Escapee(Instruction *M, ItTy S, ItTy E, bool IsReduction = false) : + Merge(M), IsReduction(IsReduction) { + assert((isa(M) || isa(M)) && + "MergeNode must be a PHI or Store"); + for (ItTy I = S; I != E; ++I) + Values.push_back(*I); + } + +public: + StoreInst *getStore() const { + return Store; + } + + void setStore(StoreInst *S) { + Store = S; + } + + Instruction* getMergeNode() const { + return Merge; + } + + iterator_range::iterator> getValues() { + return make_range(Values.begin(), Values.end()); + } + + Value *getValue(unsigned Idx) { + return Values[Idx]; + } + + unsigned getNumValues() { + return Values.size(); + } + + bool contains(const Value *V) { + return std::find(Values.begin(), Values.end(), V) != Values.end(); + } + + bool isReduction() { return IsReduction; } + + bool isReductionValue(const Value *V) { return IsReduction && contains(V); } + +private: + Instruction* Merge; + StoreInst* Store; + SmallVector Values; + +protected: + bool IsReduction; + + // Needed to set IsReduction of an Escapee object + // from within EscapeeFactory::CreateEscapee(). + friend class EscapeeFactory; +}; + +class EscapeeFactory : public Escapee { +private: + typedef std::tuple MergeValTuple; + // Convenience method to get a list of values from an escapee merge node + void getEscapeeValuesFromMergeNode(PHINode *MergeNode, + SmallVectorImpl &Values); + +public: + // Constructor + EscapeeFactory(LoopBlocksDFS *DFS, PostDominatorTree *PDT) : + DFS(DFS), PDT(PDT) {} + + // Manually add escapees + Escapee* CreateEscapee(PHINode *Recurrence, RecurrenceDescriptor &RD); + Escapee* CreateEscapee(PHINode *MergeNode); + + iterator_range::iterator> getEscapees() { + return make_range(Escapees.begin(), Escapees.end()); + } + + bool hasEscapees() { return Escapees.size() > 0; } + + // If returns true, Res is filled with the corresponding escapee. + bool canVectorizeEscapeeValue(Instruction *Val, Escapee *&Res); + +private: + // Escapee cache with quick lookup by merge node. + MapVector Escapees; + // DFS is needed to sort incoming edges + LoopBlocksDFS *DFS; + // PDT is needed to check position of the merge node. + PostDominatorTree *PDT; +}; + +/// SLVLoopVectorizationLegality checks if it is legal to vectorize a loop, and +/// to what vectorization factor. +/// This class does not look at the profitability of vectorization, only the +/// legality. This class has two main kinds of checks: +/// * Memory checks - The code in canVectorizeMemory checks if vectorization +/// will change the order of memory accesses in a way that will change the +/// correctness of the program. +/// * Scalars checks - The code in canVectorizeInstrs and canVectorizeMemory +/// checks for a number of different conditions, such as the availability of a +/// single induction variable, that all types are supported and vectorize-able, +/// etc. This code reflects the capabilities of InnerLoopVectorizer. +/// This class is also used by InnerLoopVectorizer for identifying +/// induction variable and the different reduction variables. 
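A sketch of querying the Escapee machinery above. MergePHI is the merge node of a value that is live past the loop; whether CreateEscapee can fail is not visible from the header, so the sketch assumes it succeeds:

    #include "llvm/Analysis/LoopVectorizationAnalysis.h"
    using namespace llvm;

    void escapeeExample(EscapeeFactory &EF, PHINode *MergePHI) {
      Escapee *E = EF.CreateEscapee(MergePHI);
      for (Value *V : E->getValues()) {
        (void)V;                  // the per-iteration values being merged
      }
      if (E->isReduction()) {
        // reduction escapees are additionally described by a
        // RecurrenceDescriptor when created through the other overload
      }
    }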
+class SLVLoopVectorizationLegality { +public: + SLVLoopVectorizationLegality(Loop *L, PredicatedScalarEvolution &PSE, + DominatorTree *DT,PostDominatorTree *PDT, + TargetLibraryInfo *TLI, AliasAnalysis *AA, + Function *F,const TargetTransformInfo *TTI, + std::function *GetLAA, + LoopInfo *LI, OptimizationRemarkEmitter *ORE, + LoopVectorizationRequirements *R, + const SLVLoopVectorizeHints *H) + : NumPredStores(0), TheLoop(L), PSE(PSE), TLI(TLI), TheFunction(F), + TTI(TTI), DT(DT), PDT(PDT), AA(AA), GetLAA(GetLAA), LAI(nullptr), + ORE(ORE), InterleaveInfo(PSE, L, DT, LI), Induction(nullptr), + WidestIndTy(nullptr), HasFunNoNaNAttr(false), Requirements(R), + Hints(H), ScalarizedReduction(false), + AllowUncounted(true), IsUncounted(false) { + DFS = new LoopBlocksDFS(TheLoop); + DFS->perform(LI); + EF = new EscapeeFactory(DFS, PDT); + } + // TODO AllowUncounted(EnableUncountedLoops), IsUncounted(false) {} + + /// Returns true if the function has an attribute saying that + /// we can assume the absence of NaNs. + bool hasNoNaNAttr(void) const { return HasFunNoNaNAttr; } + + /// ReductionList contains the reduction descriptors for all + /// of the reductions that were found in the loop. + typedef DenseMap ReductionList; + + /// InductionList saves induction variables and maps them to the + /// induction descriptor. + typedef MapVector InductionList; + + typedef SmallVector ConditionExprs; + + // foobar + struct LoopExit { + LoopExit() + : Kind(EK_Unknown), ExitingBlock(nullptr), ExitBlock(nullptr) {} + LoopExit(LoopExitKind Kind, BasicBlock *Exiting, BasicBlock *Exit) + : Kind(Kind), ExitingBlock(Exiting), ExitBlock(Exit) {} + LoopExit(LoopExitKind Kind, BasicBlock *Exiting, BasicBlock *Exit, + ConditionExprs &SubExprs) + : Kind(Kind), ExitingBlock(Exiting), ExitBlock(Exit) { + Nodes.insert(Nodes.begin(), SubExprs.begin(), SubExprs.end()); + } + + LoopExitKind Kind; + BasicBlock *ExitingBlock; + BasicBlock *ExitBlock; + ConditionExprs Nodes; + }; + + /// Mapping of all exits to a known type + /// TODO: May need more info later. + typedef SmallVector ExitList; + + /// Returns true if it is legal to vectorize this loop. + /// This does not mean that it is profitable to vectorize this + /// loop, only that it is legal to do so. + bool canVectorize(); + + /// Returns the Induction variable. + PHINode *getInduction() { return Induction; } + + /// Returns the reduction variables found in the loop. + ReductionList *getReductionVars() { return &Reductions; } + + /// Returns the induction variables found in the loop. + InductionList *getInductionVars() { return &Inductions; } + + // Returns the EscapeeFactory + EscapeeFactory *getEF() { return EF; } + + /// Return all known exits with their type + /// TODO: Remove this in favour of the range func below? + ExitList *getLoopExits() { return &Exits; } + + // TODO: Should this be part of loopinfo? + iterator_range exits() { + return make_range(Exits.begin(), Exits.end()); + } + + LoopExitKind getBlockExitKind(BasicBlock *BB) { + for (auto &LE : Exits) + if (LE.ExitingBlock == BB) + return LE.Kind; + + // If we didn't find the block in the known exits, + // then it isn't an exiting block. + return EK_None; + } + + /// Returns the widest induction type. + Type *getWidestInductionType() { return WidestIndTy; } + + /// Returns True if V is an induction variable in this loop. + bool isInductionVariable(const Value *V); + + /// Return true if the block BB needs to be predicated in order for the loop + /// to be vectorized. 
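A sketch built on the exit bookkeeping above: uncounted (EK_LoadCompare) exits are the new case this analysis is designed to recognise, so a client can separate them from ordinary counted exits like this:

    #include "llvm/Analysis/LoopVectorizationAnalysis.h"
    using namespace llvm;

    bool hasOnlyCountedExits(SLVLoopVectorizationLegality &Legal) {
      for (auto &LE : Legal.exits())
        if (LE.Kind == EK_LoadCompare)
          return false;            // exit condition depends on loaded data
      return true;
    }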
+ bool blockNeedsPredication(BasicBlock *BB); + + /// Check if this pointer is consecutive when vectorizing. This happens + /// when the last index of the GEP is the induction variable, or that the + /// pointer itself is an induction variable. + /// This check allows us to vectorize A[idx] into a wide load/store. + /// Returns: + /// 0 - Stride is unknown or non-consecutive. + /// 1 - Address is consecutive. + /// -1 - Address is consecutive, and decreasing. + int isConsecutivePtr(Value *Ptr); + + /// Returns true if the value V is uniform within the loop. + bool isUniform(Value *V); + + /// Returns true if this instruction will remain scalar after vectorization. + bool isUniformAfterVectorization(Instruction* I) { return Uniforms.count(I); } + + /// Returns the information that we collected about runtime memory check. + const RuntimePointerChecking *getRuntimePointerChecking() const { + return LAI->getRuntimePointerChecking(); + } + + const LoopAccessInfo *getLAI() const { + return LAI; + } + + /// \brief Check if \p Instr belongs to any interleaved access group. + bool isAccessInterleaved(Instruction *Instr) { + return InterleaveInfo.isInterleaved(Instr); + } + + /// \brief Get the interleaved access group that \p Instr belongs to. + const InterleaveGroup *getInterleavedAccessGroup(Instruction *Instr) { + return InterleaveInfo.getInterleaveGroup(Instr); + } + + unsigned getMaxSafeDepDistBytes() { return LAI->getMaxSafeDepDistBytes(); } + + bool hasStride(Value *V) { return StrideSet.count(V); } + bool mustCheckStrides() { return !StrideSet.empty(); } + SmallPtrSet::iterator strides_begin() { + return StrideSet.begin(); + } + SmallPtrSet::iterator strides_end() { return StrideSet.end(); } + + /// Returns true if the target machine supports masked store operation + /// for the given \p DataType and kind of access to \p Ptr. + bool isLegalMaskedStore(Type *DataType, Value *Ptr); + + /// Returns true if the target machine supports masked load operation + /// for the given \p DataType and kind of access to \p Ptr. + bool isLegalMaskedLoad(Type *DataType, Value *Ptr); + + /// Returns true if vector representation of the instruction \p I + /// requires mask. + bool isMaskRequired(const Instruction* I) { + return (MaskedOp.count(I) != 0); + } + /// Returns true if the loop requires masked operations for vectorisation to + /// be legal. + bool hasMaskedOperations() { + return MaskedOp.begin() != MaskedOp.end(); + } + unsigned getNumStores() const { + return LAI->getNumStores(); + } + unsigned getNumLoads() const { + return LAI->getNumLoads(); + } + unsigned getNumPredStores() const { + return NumPredStores; + } + + /// Returns true if reductions exist that the target cannot perform directly. + bool hasScalarizedReduction() const { + return !Reductions.empty() && ScalarizedReduction; + } + + bool isUncountedLoop() const { + return IsUncounted; + } + +private: + /// Check if a single basic block loop is vectorizable. + /// At this point we know that this is a loop with a constant trip count + /// and we only need to check individual instructions. + bool canVectorizeInstrs(); + + /// When we vectorize loops we may change the order in which + /// we read and write from memory. This method checks if it is + /// legal to vectorize the code, considering only memory constrains. + /// Returns true if the loop is vectorizable + bool canVectorizeMemory(); + + /// Return true if we can vectorize this loop using the IF-conversion + /// transformation. 
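A sketch combining the masking queries above to decide whether a memory access can be widened under a mask rather than scalarized and predicated:

    #include "llvm/Analysis/LoopVectorizationAnalysis.h"
    #include "llvm/IR/Instructions.h"
    using namespace llvm;

    bool canWidenWithMask(SLVLoopVectorizationLegality &Legal, Instruction *I) {
      if (auto *LI = dyn_cast<LoadInst>(I))
        return !Legal.isMaskRequired(LI) ||
               Legal.isLegalMaskedLoad(LI->getType(), LI->getPointerOperand());
      if (auto *SI = dyn_cast<StoreInst>(I))
        return !Legal.isMaskRequired(SI) ||
               Legal.isLegalMaskedStore(SI->getValueOperand()->getType(),
                                        SI->getPointerOperand());
      return false;   // not a plain load/store
    }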
+ bool canVectorizeWithIfConvert(); + + /// Collect the variables that need to stay uniform after vectorization. + void collectLoopUniforms(); + + /// Return true if all of the instructions in the block can be speculatively + /// executed. \p SafePtrs is a list of addresses that are known to be legal + /// and we know that we can read from them without segfault. + bool blockCanBePredicated(BasicBlock *BB, SmallPtrSetImpl &SafePtrs); + + /// Returns true if a loop variable within a loop that has an outside user can + /// be safely vectorized and the appropriate terminating value extracted + bool canVectorizeEscapee(Instruction *Esc); + + bool findConditionSubExprs(Value *V, ConditionExprs &SubExprs); + bool findConditionSubExprsRecurse(Value *V, ConditionExprs &SubExprs); + + /// Returns true if all possible exits from the loop can be used to form a + /// combined exit in the vectorized loop latch, with destructive operations + /// controlled with predication during the loop + bool canVectorizeExits(); + + /// \brief Collect memory access with loop invariant strides. + /// + /// Looks for accesses like "a[i * StrideA]" where "StrideA" is loop + /// invariant. + void collectStridedAccess(Value *LoadOrStoreInst); + + unsigned NumPredStores; + + /// The loop that we evaluate. + Loop *TheLoop; + /// A wrapper around ScalarEvolution used to add runtime SCEV checks. + /// Applies dynamic knowledge to simplify SCEV expressions in the context + /// of existing SCEV assumptions. The analysis will also add a minimal set + /// of new predicates if this is required to enable vectorization and + /// unrolling. + PredicatedScalarEvolution &PSE; + /// Target Library Info. + TargetLibraryInfo *TLI; + /// Parent function + Function *TheFunction; + /// Target Transform Info + const TargetTransformInfo *TTI; + /// Dominator Tree. + DominatorTree *DT; + /// Loop Info + //LoopInfo *LI; + /// PostDominator Tree. + PostDominatorTree *PDT; + // Depth first order graph of loop + LoopBlocksDFS *DFS; + // Alias analysis + AliasAnalysis *AA; + // LoopAccess analysis. + std::function *GetLAA; + // And the loop-accesses info corresponding to this loop. This pointer is + // null until canVectorizeMemory sets it up. + const LoopAccessInfo *LAI; + /// Interface to emit optimization remarks. + OptimizationRemarkEmitter *ORE; + + /// The interleave access information contains groups of interleaved accesses + /// with the same stride and close to each other. + InterleavedAccessInfo InterleaveInfo; + + // --- vectorization state --- // + + /// Holds the integer induction variable. This is the counter of the + /// loop. + PHINode *Induction; + /// Holds the reduction variables. + ReductionList Reductions; + /// Holds all of the induction variables that we found in the loop. + /// Notice that inductions don't need to start at zero and that induction + /// variables can be pointers. + InductionList Inductions; + /// Holds and creates all Escapees + EscapeeFactory *EF; + /// Holds a mapping of all loop exits discovered and their type + ExitList Exits; + /// Holds the widest induction type encountered. + Type *WidestIndTy; + + /// Allowed outside users. This holds the reduction + /// vars which can be accessed from outside the loop. + SmallPtrSet AllowedExit; + /// This set holds the variables which are known to be uniform after + /// vectorization. + SmallPtrSet Uniforms; + + /// Can we assume the absence of NaNs. + bool HasFunNoNaNAttr; + + /// Vectorization requirements that will go through late-evaluation. 
+ LoopVectorizationRequirements *Requirements; + + /// Used to emit an analysis of any legality issues. + const SLVLoopVectorizeHints *Hints; + + ValueToValueMap Strides; + SmallPtrSet StrideSet; + + /// While vectorizing these instructions we have to generate a + /// call to the appropriate masked intrinsic + SmallPtrSet MaskedOp; + + /// Does any in-vector reduction require scalarization? + bool ScalarizedReduction; + + /// If enabled, we will vectorize (some) loops which do not have + /// a defined trip count that SCEV can determine. + bool AllowUncounted; + + /// If set, the current loop is uncounted + bool IsUncounted; +}; + +class LoopVectorizationAnalysis : public FunctionPass { +public: + static char ID; + + LoopVectorizationAnalysis() : FunctionPass(ID) { + initializeLoopVectorizationAnalysisPass(*PassRegistry::getPassRegistry()); + } + + bool runOnFunction(Function &F) override; + void getAnalysisUsage(AnalysisUsage &AU) const override; + + // TODO: cache, interface +}; +} // End llvm namespace + +#endif Index: include/llvm/Analysis/MemoryLocation.h =================================================================== --- include/llvm/Analysis/MemoryLocation.h +++ include/llvm/Analysis/MemoryLocation.h @@ -20,6 +20,7 @@ #include "llvm/ADT/DenseMapInfo.h" #include "llvm/IR/CallSite.h" #include "llvm/IR/Metadata.h" +#include "llvm/IR/IntrinsicInst.h" namespace llvm { @@ -76,6 +77,7 @@ static MemoryLocation get(const VAArgInst *VI); static MemoryLocation get(const AtomicCmpXchgInst *CXI); static MemoryLocation get(const AtomicRMWInst *RMWI); + static MemoryLocation get(const MemSetInst *MSI); static MemoryLocation get(const Instruction *Inst) { return *MemoryLocation::getOrNone(Inst); } @@ -92,6 +94,8 @@ case Instruction::AtomicRMW: return get(cast(Inst)); default: + if (isa(Inst)) + return get(cast(Inst)); return None; } } Index: include/llvm/Analysis/ScalarEvolution.h =================================================================== --- include/llvm/Analysis/ScalarEvolution.h +++ include/llvm/Analysis/ScalarEvolution.h @@ -58,6 +58,8 @@ class Constant; class ConstantInt; class DataLayout; +class TargetLibraryInfo; +class TargetTransformInfo; class DominatorTree; class GEPOperator; class Instruction; @@ -191,7 +193,7 @@ protected: SCEVPredicateKind Kind; - ~SCEVPredicate() = default; + virtual ~SCEVPredicate() {} SCEVPredicate(const SCEVPredicate &) = default; SCEVPredicate &operator=(const SCEVPredicate &) = default; @@ -483,7 +485,7 @@ } ScalarEvolution(Function &F, TargetLibraryInfo &TLI, AssumptionCache &AC, - DominatorTree &DT, LoopInfo &LI); + DominatorTree &DT, LoopInfo &LI, TargetTransformInfo &TTI); ScalarEvolution(ScalarEvolution &&Arg); ~ScalarEvolution(); @@ -1106,6 +1108,10 @@ /// The loop information for the function we are currently analyzing. LoopInfo &LI; + /// TTI - The target transform information for the target we are targeting. + /// + TargetTransformInfo &TTI; + /// This SCEV is used to represent unknown trip counts and things. 
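// A small sketch of the MemoryLocation::get(const MemSetInst *) overload added
// above (the new IntrinsicInst.h include makes MemSetInst visible there); the
// helper name is illustrative only.
Optional<MemoryLocation> locateMemSet(const Instruction *I) {
  if (const auto *MSI = dyn_cast<MemSetInst>(I))
    return MemoryLocation::get(MSI); // Presumably the memset destination range.
  return None;
}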
std::unique_ptr CouldNotCompute; Index: include/llvm/Analysis/TargetLibraryInfo.h =================================================================== --- include/llvm/Analysis/TargetLibraryInfo.h +++ include/llvm/Analysis/TargetLibraryInfo.h @@ -18,6 +18,7 @@ #include "llvm/IR/Module.h" #include "llvm/IR/PassManager.h" #include "llvm/Pass.h" +#include namespace llvm { template class ArrayRef; @@ -28,7 +29,15 @@ struct VecDesc { StringRef ScalarFnName; StringRef VectorFnName; - unsigned VectorizationFactor; + VectorType::ElementCount VectorizationFactor; + bool Masked; + unsigned Priority; +}; + +/// Describes a vector function, by Name and Signature. +struct VectorFnInfo { + std::string Name; + FunctionType *Signature; }; enum LibFunc { @@ -52,6 +61,10 @@ static StringRef const StandardNames[NumLibFuncs]; bool ShouldExtI32Param, ShouldExtI32Return, ShouldSignExtI32Param; + /// Holds the mapping between a scalar function and its available vector + /// versions. + std::multimap VectorFunctionInfo; + enum AvailabilityState { StandardName = 3, // (memset to all ones) CustomName = 1, @@ -87,6 +100,7 @@ enum VectorLibrary { NoLibrary, // Don't use any vector library. Accelerate, // Use Accelerate framework. + SLEEF, // SIMD Library for Evaluating Elementary Functions. SVML // Intel short vector math library. }; @@ -145,12 +159,22 @@ /// Calls addVectorizableFunctions with a known preset of functions for the /// given vector library. - void addVectorizableFunctionsFromVecLib(enum VectorLibrary VecLib); + void addVectorizableFunctionsFromVecLib(enum VectorLibrary VecLib, + const Triple &T); + /// Maintain the original interface until the clang patches have landed. + void addVectorizableFunctionsFromVecLib(enum VectorLibrary VecLib) { + addVectorizableFunctionsFromVecLib(VecLib, Triple()); + } + + /// Adds vector routines that are available as global external function + /// pointers in the module \param M. + void addOpenMPVectorFunctions(Module *M); /// Return true if the function F has a vector equivalent with vectorization /// factor VF. - bool isFunctionVectorizable(StringRef F, unsigned VF) const { - return !getVectorizedFunction(F, VF).empty(); + bool isFunctionVectorizable(StringRef F, VectorType::ElementCount VF, + bool Masked, FunctionType *Sign) const { + return !getVectorizedFunction(F, VF, Masked, Sign).empty(); } /// Return true if the function F has a vector equivalent with any @@ -159,11 +183,13 @@ /// Return the name of the equivalent of F, vectorized with factor VF. If no /// such mapping exists, return the empty string. - StringRef getVectorizedFunction(StringRef F, unsigned VF) const; + std::string getVectorizedFunction(StringRef F, VectorType::ElementCount VF, + bool Masked, + FunctionType *Sign = nullptr) const; /// Return true if the function F has a scalar equivalent, and set VF to be /// the vectorization factor. - bool isFunctionScalarizable(StringRef F, unsigned &VF) const { + bool isFunctionScalarizable(StringRef F, VectorType::ElementCount &VF) const { return !getScalarizedFunction(F, VF).empty(); } @@ -171,7 +197,8 @@ /// exists, return the empty string. /// /// Set VF to the vectorization factor. 
- StringRef getScalarizedFunction(StringRef F, unsigned &VF) const; + StringRef getScalarizedFunction(StringRef F, + VectorType::ElementCount &VF) const; /// Set to true iff i32 parameters to library functions should have signext /// or zeroext attributes if they correspond to C-level int or unsigned int, @@ -196,6 +223,26 @@ /// Returns the size of the wchar_t type in bytes or 0 if the size is unknown. /// This queries the 'wchar_size' metadata. unsigned getWCharSize(const Module &M) const; + + /// Returns size of the default wchar_t type on target \p T. This is mostly + /// intended to verify that the size in the frontend matches LLVM. All other + /// queries should use getWCharSize() instead. + static unsigned getTargetWCharSize(const Triple &T); + + /// Demangle a scalar/vector mangled name + static std::pair demangle(const std::string In); + + /// Checks the validity of a mangled scalar/vector name. + static bool isMangledName(std::string Name); + + /// Check the validity of \param Ty to be used as a vector function + /// signature. If so, it sets \param FTy to the vector function signature. + static bool isValidSignature(Type *Ty, FunctionType *&FTY); + + /// Mangles \param VecName and \param ScalarName using a prefix, a middlefix + /// and a suffix. + static std::string mangle(const std::string VecName, + const std::string ScalarName); }; /// Provides information about what library functions are available for @@ -247,14 +294,16 @@ bool has(LibFunc F) const { return Impl->getState(F) != TargetLibraryInfoImpl::Unavailable; } - bool isFunctionVectorizable(StringRef F, unsigned VF) const { - return Impl->isFunctionVectorizable(F, VF); + bool isFunctionVectorizable(StringRef F, VectorType::ElementCount VF, + bool Masked, FunctionType *Sign) const { + return Impl->isFunctionVectorizable(F, VF, Masked, Sign); } bool isFunctionVectorizable(StringRef F) const { return Impl->isFunctionVectorizable(F); } - StringRef getVectorizedFunction(StringRef F, unsigned VF) const { - return Impl->getVectorizedFunction(F, VF); + std::string getVectorizedFunction(StringRef F, VectorType::ElementCount VF, + bool Masked, FunctionType *Sign) const { + return Impl->getVectorizedFunction(F, VF, Masked, Sign); } /// Tests if the function is both available and a candidate for optimized code @@ -336,6 +385,15 @@ FunctionAnalysisManager::Invalidator &) { return false; } + + /// Check whether the vector pcs should be used + void setCallingConv(CallInst *CI) const { + StringRef FName = CI->getCalledFunction()->getName(); + if (FName.startswith("_ZGVn")) { + CI->setCallingConv(CallingConv::AArch64_VectorCall); + } + } + }; /// Analysis pass providing the \c TargetLibraryInfo. Index: include/llvm/Analysis/TargetTransformInfo.h =================================================================== --- include/llvm/Analysis/TargetTransformInfo.h +++ include/llvm/Analysis/TargetTransformInfo.h @@ -73,6 +73,117 @@ } }; +/// \brief Information about the memory access pattern for a given +/// load/store instruction. 
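// A hedged sketch of driving the extended TLI interface above: look up a
// vector variant of a scalar call for a given element count, masking
// requirement and vector signature. The surrounding objects (TLI, CI, VecTy,
// VecSig) are assumed to exist and are not part of the patch.
std::string findVectorVariant(const TargetLibraryInfo &TLI, CallInst *CI,
                              VectorType *VecTy, bool Masked,
                              FunctionType *VecSig) {
  StringRef ScalarName = CI->getCalledFunction()->getName();
  VectorType::ElementCount VF = VecTy->getElementCount();
  if (!TLI.isFunctionVectorizable(ScalarName, VF, Masked, VecSig))
    return std::string();
  return TLI.getVectorizedFunction(ScalarName, VF, Masked, VecSig);
}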
+class MemAccessInfo { +public: + enum MemAccessInfoType { + UNIFORM = 0, + STRIDED = 1, + NONSTRIDED = 2 + }; + +private: + // The type of access (uniform, strided, gather/scatter) + MemAccessInfoType Access; + + // Predicated access + bool IsMasked; + + struct StridedInfoStruct { + int Stride; + bool IsReversed; + }; + + struct NonStridedInfoStruct { + Type *IndexType; + bool IsSignedIndex; + }; + + // Access-specific information + union { + struct StridedInfoStruct StridedInfo; + struct NonStridedInfoStruct NonStridedInfo; + }; + + // Privater constructors + MemAccessInfo () : Access(UNIFORM), IsMasked(false) {} + + MemAccessInfo (int stride, bool isReversed, bool isMasked) : + Access(STRIDED), IsMasked(isMasked) { + StridedInfo.Stride = stride; + StridedInfo.IsReversed = isReversed; + } + + MemAccessInfo (Type *idxType, bool isSignedIndex, bool isMasked) : + Access(NONSTRIDED), IsMasked(isMasked) { + NonStridedInfo.IsSignedIndex = isSignedIndex; + NonStridedInfo.IndexType = idxType; + } + +public: + // Static methods to create a MemAccessInfo + static MemAccessInfo getUniformInfo() { + return MemAccessInfo(); + } + + static MemAccessInfo getStridedInfo(int stride, bool isReversed, + bool isMasked) { + return MemAccessInfo(stride, isReversed, isMasked); + } + + static MemAccessInfo getNonStridedInfo(Type *idxType, bool isMasked, + bool isSignedIdx = true) { + return MemAccessInfo(idxType, isSignedIdx, isMasked); + } + + // Accessor methods + MemAccessInfoType getAccessType() const { + return Access; + } + + bool isStrided() const { + return Access == STRIDED; + } + + bool isUniform() const { + return Access == UNIFORM; + } + + bool isNonStrided() const { + return Access == NONSTRIDED; + } + + bool isMasked() const { + assert(!(IsMasked && isUniform()) && + "Uniform access cannot be predicated"); + return IsMasked; + } + + bool isReversed() const { + assert(isStrided() && + "Cannot get reversed stride from non-strided access"); + return StridedInfo.IsReversed; + } + + int getStride() const { + assert(isStrided() && + "Cannot get stride from non-strided access"); + return StridedInfo.Stride; + } + + Type *getIndexType() const { + assert(isNonStrided() && + "Cannot get index type of non-gather/scatter access"); + return NonStridedInfo.IndexType; + } + + bool isSignedIndex() const { + assert(isNonStrided() && "Expecting a non-strided access"); + return NonStridedInfo.IsSignedIndex; + } +}; + /// This pass provides access to the codegen interfaces that are needed /// for IR-level transformations. class TargetTransformInfo { @@ -566,7 +677,7 @@ unsigned getScalarizationOverhead(Type *Ty, bool Insert, bool Extract) const; unsigned getOperandsScalarizationOverhead(ArrayRef Args, - unsigned VF) const; + VectorType::ElementCount VF) const; /// If target has efficient vector element load/store instructions, it can /// return true here so that insertion/extraction costs are not added to @@ -724,6 +835,13 @@ /// This is currently measured in number of instructions. unsigned getPrefetchDistance() const; + /// \return The width of the largest possible register supported by the target + /// architecture. This is an upper bound rather than its actual width. The + /// lower bound (returned by getRegisterBitWidth) is the more common question + /// to ask but for cases when a transform is only safe when the register is + /// smaller than X, this function should be used. 
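// A minimal sketch of building the MemAccessInfo descriptors defined above;
// the concrete stride, index type and LLVMContext are illustrative only.
void describeAccesses(LLVMContext &Ctx) {
  MemAccessInfo Uniform = MemAccessInfo::getUniformInfo();
  MemAccessInfo RevUnit = MemAccessInfo::getStridedInfo(/*stride=*/-1,
                                                        /*isReversed=*/true,
                                                        /*isMasked=*/true);
  MemAccessInfo Gather = MemAccessInfo::getNonStridedInfo(
      Type::getInt64Ty(Ctx), /*isMasked=*/false);
  // Consumers branch on the shape; |stride| == 1 means a contiguous access.
  assert(Uniform.isUniform() && RevUnit.isReversed() && Gather.isNonStrided());
  (void)Uniform; (void)RevUnit; (void)Gather;
}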
+ unsigned getRegisterBitWidthUpperBound(bool Vector) const; + /// \return Some HW prefetchers can handle accesses up to a certain constant /// stride. This is the minimum stride in bytes where it makes sense to start /// adding SW prefetches. The default is 1, i.e. prefetch with any stride. @@ -795,6 +913,12 @@ int getMemoryOpCost(unsigned Opcode, Type *Src, unsigned Alignment, unsigned AddressSpace, const Instruction *I = nullptr) const; + // \return The cost of Load and Store instructions for + // a memory access pattern described by Info. + unsigned getVectorMemoryOpCost(unsigned Opcode, Type *Src, Value *Ptr, + unsigned Alignment, unsigned AddressSpace, + const MemAccessInfo &Info, Instruction *I) const; + /// \return The cost of masked Load and Store instructions. int getMaskedMemoryOpCost(unsigned Opcode, Type *Src, unsigned Alignment, unsigned AddressSpace) const; @@ -843,8 +967,8 @@ /// Three cases are handled: 1. scalar instruction 2. vector instruction /// 3. scalar instruction which is to be vectorized with VF. int getIntrinsicInstrCost(Intrinsic::ID ID, Type *RetTy, - ArrayRef Args, FastMathFlags FMF, - unsigned VF = 1) const; + ArrayRef Args, FastMathFlags FMF, + VectorType::ElementCount VF = VectorType::SingleElement()) const; /// \returns The cost of Intrinsic instructions. Types analysis only. /// If ScalarizationCostPassed is UINT_MAX, the cost of scalarizing the @@ -894,6 +1018,15 @@ Value *getOrCreateResultFromMemIntrinsic(IntrinsicInst *Inst, Type *ExpectedType) const; + /// \returns True if the target has instructions for a load/store with + /// an access pattern described by Info. + /// \param Ty is the result (load) or operand (store) type of the + /// memory operation. + /// \param Info is a struct that describes the memory access + /// (strided,gather,uniform) + bool hasVectorMemoryOp(unsigned Opcode, Type *Ty, + const MemAccessInfo &Info) const; + /// \returns The type to use in a loop expansion of a memcpy call. Type *getMemcpyLoopLoweringType(LLVMContext &Context, Value *Length, unsigned SrcAlign, unsigned DestAlign) const; @@ -915,6 +1048,17 @@ bool areInlineCompatible(const Function *Caller, const Function *Callee) const; + /// \returns True if the target can efficiently support vectorized non-unit + // strides. + bool canVectorizeNonUnitStrides(bool forceFixedWidth = false) const; + + /// \returns True if the target can efficiently handle store->loads + /// even when forwarding is prevented by the address not being a multiple + /// of the VF. For vector architectures with predication and unaligned + /// memory access, the benefit of vectorization may still outweigh + /// the cost for lack of store-load forwarding. + bool vectorizePreventedSLForwarding(void) const; + /// The type of load/store indexing. enum MemIndexedMode { MIM_Unindexed, ///< No indexing. @@ -964,10 +1108,12 @@ /// Flags describing the kind of vector reduction. struct ReductionFlags { - ReductionFlags() : IsMaxOp(false), IsSigned(false), NoNaN(false) {} + ReductionFlags() + : IsMaxOp(false), IsSigned(false), NoNaN(false), IsOrdered(false) {} bool IsMaxOp; ///< If the op a min/max kind, true if it's a max operation. bool IsSigned; ///< Whether the operation is a signed int reduction. bool NoNaN; ///< If op is an fp min/max, whether NaNs may be present. + bool IsOrdered; ///< True if the reduction is an ordered/strict reduction. 
}; /// \returns True if the target wants to handle the given reduction idiom in @@ -978,6 +1124,11 @@ /// \returns True if the target wants to expand the given reduction intrinsic /// into a shuffle sequence. bool shouldExpandReduction(const IntrinsicInst *II) const; + + /// \returns True if the target can handle the reduction. + bool canReduceInVector(unsigned Opcode, Type *ScalarTy, + ReductionFlags Flags) const; + /// @} private: @@ -1061,8 +1212,8 @@ virtual bool useColdCCForColdCall(Function &F) = 0; virtual unsigned getScalarizationOverhead(Type *Ty, bool Insert, bool Extract) = 0; - virtual unsigned getOperandsScalarizationOverhead(ArrayRef Args, - unsigned VF) = 0; + virtual unsigned getOperandsScalarizationOverhead( + ArrayRef Args, VectorType::ElementCount VF) = 0; virtual bool supportsEfficientVectorElementLoadStore() = 0; virtual bool enableAggressiveInterleaving(bool LoopHasReductions) = 0; virtual const MemCmpExpansionOptions *enableMemCmpExpansion( @@ -1096,6 +1247,7 @@ virtual llvm::Optional getCacheSize(CacheLevel Level) = 0; virtual llvm::Optional getCacheAssociativity(CacheLevel Level) = 0; virtual unsigned getPrefetchDistance() = 0; + virtual unsigned getRegisterBitWidthUpperBound(bool Vector) = 0; virtual unsigned getMinPrefetchStride() = 0; virtual unsigned getMaxPrefetchIterationsAhead() = 0; virtual unsigned getMaxInterleaveFactor(unsigned VF) = 0; @@ -1118,6 +1270,11 @@ unsigned Index) = 0; virtual int getMemoryOpCost(unsigned Opcode, Type *Src, unsigned Alignment, unsigned AddressSpace, const Instruction *I) = 0; + virtual unsigned getVectorMemoryOpCost(unsigned Opcode, Type *Src, Value *Ptr, + unsigned Alignment, + unsigned AddressSpace, + const MemAccessInfo &Info, + Instruction *I) = 0; virtual int getMaskedMemoryOpCost(unsigned Opcode, Type *Src, unsigned Alignment, unsigned AddressSpace) = 0; @@ -1137,7 +1294,8 @@ ArrayRef Tys, FastMathFlags FMF, unsigned ScalarizationCostPassed) = 0; virtual int getIntrinsicInstrCost(Intrinsic::ID ID, Type *RetTy, - ArrayRef Args, FastMathFlags FMF, unsigned VF) = 0; + ArrayRef Args, FastMathFlags FMF, + VectorType::ElementCount VF) = 0; virtual int getCallInstrCost(Function *F, Type *RetTy, ArrayRef Tys) = 0; virtual unsigned getNumberOfParts(Type *Tp) = 0; @@ -1149,6 +1307,10 @@ virtual unsigned getAtomicMemIntrinsicMaxElementSize() const = 0; virtual Value *getOrCreateResultFromMemIntrinsic(IntrinsicInst *Inst, Type *ExpectedType) = 0; + + virtual bool hasVectorMemoryOp(unsigned Opcode, Type *Ty, + const MemAccessInfo &Info) = 0; + virtual Type *getMemcpyLoopLoweringType(LLVMContext &Context, Value *Length, unsigned SrcAlign, unsigned DestAlign) const = 0; @@ -1157,6 +1319,10 @@ unsigned RemainingBytes, unsigned SrcAlign, unsigned DestAlign) const = 0; virtual bool areInlineCompatible(const Function *Caller, const Function *Callee) const = 0; + + virtual bool canVectorizeNonUnitStrides(bool forceFixedWidth) const = 0; + virtual bool vectorizePreventedSLForwarding() const = 0; + virtual bool isIndexedLoadLegal(MemIndexedMode Mode, Type *Ty) const = 0; virtual bool isIndexedStoreLegal(MemIndexedMode Mode,Type *Ty) const = 0; virtual unsigned getLoadStoreVecRegBitWidth(unsigned AddrSpace) const = 0; @@ -1177,6 +1343,9 @@ virtual bool useReductionIntrinsic(unsigned Opcode, Type *Ty, ReductionFlags) const = 0; virtual bool shouldExpandReduction(const IntrinsicInst *II) const = 0; + virtual bool canReduceInVector(unsigned Opcode, Type *ScalarTy, + ReductionFlags Flags) const = 0; + virtual int getInstructionLatency(const 
Instruction *I) = 0; }; @@ -1324,7 +1493,7 @@ return Impl.getScalarizationOverhead(Ty, Insert, Extract); } unsigned getOperandsScalarizationOverhead(ArrayRef Args, - unsigned VF) override { + VectorType::ElementCount VF) override { return Impl.getOperandsScalarizationOverhead(Args, VF); } @@ -1383,6 +1552,10 @@ unsigned getRegisterBitWidth(bool Vector) const override { return Impl.getRegisterBitWidth(Vector); } + unsigned getRegisterBitWidthUpperBound(bool Vector) override { + return Impl.getRegisterBitWidthUpperBound(Vector); + } + unsigned getMinVectorRegisterBitWidth() override { return Impl.getMinVectorRegisterBitWidth(); } @@ -1455,6 +1628,13 @@ unsigned AddressSpace, const Instruction *I) override { return Impl.getMemoryOpCost(Opcode, Src, Alignment, AddressSpace, I); } + unsigned getVectorMemoryOpCost(unsigned Opcode, Type *Src, Value *Ptr, + unsigned Alignment, unsigned AddressSpace, + const MemAccessInfo &Info, + Instruction *I) override { + return Impl.getVectorMemoryOpCost(Opcode, Src, Ptr, Alignment, AddressSpace, + Info, I); + } int getMaskedMemoryOpCost(unsigned Opcode, Type *Src, unsigned Alignment, unsigned AddressSpace) override { return Impl.getMaskedMemoryOpCost(Opcode, Src, Alignment, AddressSpace); @@ -1485,7 +1665,8 @@ ScalarizationCostPassed); } int getIntrinsicInstrCost(Intrinsic::ID ID, Type *RetTy, - ArrayRef Args, FastMathFlags FMF, unsigned VF) override { + ArrayRef Args, FastMathFlags FMF, + VectorType::ElementCount VF) override { return Impl.getIntrinsicInstrCost(ID, RetTy, Args, FMF, VF); } int getCallInstrCost(Function *F, Type *RetTy, @@ -1513,6 +1694,12 @@ Type *ExpectedType) override { return Impl.getOrCreateResultFromMemIntrinsic(Inst, ExpectedType); } + + bool hasVectorMemoryOp(unsigned Opcode, Type *Ty, + const MemAccessInfo &Info) override { + return Impl.hasVectorMemoryOp(Opcode, Ty, Info); + } + Type *getMemcpyLoopLoweringType(LLVMContext &Context, Value *Length, unsigned SrcAlign, unsigned DestAlign) const override { @@ -1530,6 +1717,15 @@ const Function *Callee) const override { return Impl.areInlineCompatible(Caller, Callee); } + + bool canVectorizeNonUnitStrides(bool forceFixedWidth) const override { + return Impl.canVectorizeNonUnitStrides(forceFixedWidth); + } + + bool vectorizePreventedSLForwarding() const override { + return Impl.vectorizePreventedSLForwarding(); + } + bool isIndexedLoadLegal(MemIndexedMode Mode, Type *Ty) const override { return Impl.isIndexedLoadLegal(Mode, Ty, getDataLayout()); } @@ -1574,6 +1770,10 @@ bool shouldExpandReduction(const IntrinsicInst *II) const override { return Impl.shouldExpandReduction(II); } + bool canReduceInVector(unsigned Opcode, Type *ScalarTy, + ReductionFlags Flags) const override { + return Impl.canReduceInVector(Opcode, ScalarTy, Flags); + } int getInstructionLatency(const Instruction *I) override { return Impl.getInstructionLatency(I); } Index: include/llvm/Analysis/TargetTransformInfoImpl.h =================================================================== --- include/llvm/Analysis/TargetTransformInfoImpl.h +++ include/llvm/Analysis/TargetTransformInfoImpl.h @@ -298,7 +298,9 @@ } unsigned getOperandsScalarizationOverhead(ArrayRef Args, - unsigned VF) { return 0; } + VectorType::ElementCount VF) { + return 0; + } bool supportsEfficientVectorElementLoadStore() { return false; } @@ -390,6 +392,8 @@ unsigned getPrefetchDistance() { return 0; } + unsigned getRegisterBitWidthUpperBound(bool Vector) { return 32; } + unsigned getMinPrefetchStride() { return 1; } unsigned 
getMaxPrefetchIterationsAhead() { return UINT_MAX; } @@ -434,6 +438,12 @@ return 1; } + unsigned getVectorMemoryOpCost(unsigned Opcode, Type *Src, Value *Ptr, + unsigned Alignment, unsigned AddressSpace, + const MemAccessInfo &Info, Instruction *I) { + return 1; + } + unsigned getMaskedMemoryOpCost(unsigned Opcode, Type *Src, unsigned Alignment, unsigned AddressSpace) { return 1; @@ -459,7 +469,8 @@ return 1; } unsigned getIntrinsicInstrCost(Intrinsic::ID ID, Type *RetTy, - ArrayRef Args, FastMathFlags FMF, unsigned VF) { + ArrayRef Args, FastMathFlags FMF, + VectorType::ElementCount VF) { return 1; } @@ -474,7 +485,10 @@ return 0; } - unsigned getArithmeticReductionCost(unsigned, Type *, bool) { return 1; } + unsigned getArithmeticReductionCost(unsigned Opcode, Type *Ty, + bool IsPairwiseForm) { + return 1; + } unsigned getMinMaxReductionCost(Type *, Type *, bool, bool) { return 1; } @@ -520,6 +534,14 @@ Callee->getFnAttribute("target-features")); } + bool canVectorizeNonUnitStrides(bool forceFixedWidth = false) const { + return false; + } + + bool vectorizePreventedSLForwarding(void) const { + return false; + } + bool isIndexedLoadLegal(TTI::MemIndexedMode Mode, Type *Ty, const DataLayout &DL) const { return false; @@ -562,12 +584,18 @@ bool useReductionIntrinsic(unsigned Opcode, Type *Ty, TTI::ReductionFlags Flags) const { + if (auto *VT = dyn_cast(Ty)) + return VT->isScalable(); return false; } bool shouldExpandReduction(const IntrinsicInst *II) const { return true; } + bool canReduceInVector(unsigned Opcode, Type *ScalarTy, + TTI::ReductionFlags Flags) const { + return !Flags.IsOrdered; + } protected: // Obtain the minimum required size to hold the value (without the sign) @@ -816,6 +844,31 @@ U->getNumOperands() == 1 ? U->getOperand(0)->getType() : nullptr); } + bool hasVectorMemoryOp(unsigned Opcode, Type *Ty, const MemAccessInfo &Info) { + bool IsConsecutive = Info.getAccessType() == MemAccessInfo::STRIDED && + std::abs(Info.getStride()) == 1; + if (Info.isUniform()) + return true; + if (Info.isMasked() && Opcode == Instruction::Load) + return static_cast(this) + ->isLegalMaskedLoad(Ty->getVectorElementType()/*, IsConsecutive*/); + if (Info.isMasked() && Opcode == Instruction::Store) + return static_cast(this) + ->isLegalMaskedStore(Ty->getVectorElementType()/*, IsConsecutive*/); + return IsConsecutive; + } + + unsigned getVectorMemoryOpCost(unsigned Opcode, Type *Src, Value *Ptr, + unsigned Alignment, unsigned AddressSpace, + const MemAccessInfo &Info, Instruction *I) { + if (Info.isMasked()) + return static_cast(this) + ->getMaskedMemoryOpCost(Opcode, Src, Alignment, AddressSpace); + else + return static_cast(this) + ->getMemoryOpCost(Opcode, Src, Alignment, AddressSpace, I); + } + int getInstructionLatency(const Instruction *I) { SmallVector Operands(I->value_op_begin(), I->value_op_end()); Index: include/llvm/Analysis/VectorUtils.h =================================================================== --- include/llvm/Analysis/VectorUtils.h +++ include/llvm/Analysis/VectorUtils.h @@ -43,11 +43,17 @@ /// ctlz,cttz and powi special intrinsics whose argument is scalar. bool hasVectorInstrinsicScalarOpd(Intrinsic::ID ID, unsigned ScalarOpdIdx); +/// Returns true if the given vector intrinsic is maskable, and the +/// position of the mask parameter if available. +std::pair isMaskedVectorIntrinsic(Intrinsic::ID ID); + /// Returns intrinsic ID for call. 
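// A hedged sketch, not from the patch, of how a cost-model client might
// combine the two hooks above: first ask whether the target has a native
// instruction for this access shape, then ask what it costs. TTI, the load L,
// the vector type VecTy and the address space 0 are assumed by this example.
unsigned vectorLoadCost(const TargetTransformInfo &TTI, LoadInst *L,
                        Type *VecTy, const MemAccessInfo &Info) {
  if (!TTI.hasVectorMemoryOp(Instruction::Load, VecTy, Info))
    return std::numeric_limits<unsigned>::max(); // Force scalarization.
  return TTI.getVectorMemoryOpCost(Instruction::Load, VecTy,
                                   L->getPointerOperand(), L->getAlignment(),
                                   /*AddressSpace=*/0, Info, L);
}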
/// For the input call instruction it finds mapping intrinsic and returns /// its intrinsic ID, in case it does not found it return not_intrinsic. +/// If UseMask is true, then find a masking vectorized function if available. Intrinsic::ID getVectorIntrinsicIDForCall(const CallInst *CI, - const TargetLibraryInfo *TLI); + const TargetLibraryInfo *TLI, + bool UseMask = false); /// Find the operand of the GEP that should be checked for consecutive /// stores. This ignores trailing indices that have no effect on the final @@ -115,6 +121,23 @@ DemandedBits &DB, const TargetTransformInfo *TTI=nullptr); +// Temporary helper function to implement what was 'test any_true' using +// vector reductions. Only targets AArch64 SVE. +Value *getAnyTrueReduction(IRBuilder<> &Builder, Value *Src, + const Twine &Name = ""); +// Temporary helper function to implement what was 'test all_true' using +// vector reductions. Only targets AArch64 SVE. +Value *getAllTrueReduction(IRBuilder<> &Builder, Value *Src, + const Twine &Name = ""); +// Temporary helper function to implement what was 'test all_false' using +// vector reductions. Only targets AArch64 SVE. +Value *getAllFalseReduction(IRBuilder<> &Builder, Value *Src, + const Twine &Name = ""); +// Helper function to implement what was 'test 'last_true' using +// extractelement. +Value *getLastTrueVector(IRBuilder<> &Builder, Value *Src, + const Twine &Name = ""); + /// Specifically, let Kinds = [MD_tbaa, MD_alias_scope, MD_noalias, MD_fpmath, /// MD_nontemporal]. For K in Kinds, we get the MDNode for K from each of the /// elements of VL, compute their "intersection" (i.e., the most generic Index: include/llvm/BinaryFormat/Dwarf.def =================================================================== --- include/llvm/BinaryFormat/Dwarf.def +++ include/llvm/BinaryFormat/Dwarf.def @@ -729,6 +729,8 @@ HANDLE_DW_CC(0xc9, LLVM_PreserveMost) HANDLE_DW_CC(0xca, LLVM_PreserveAll) HANDLE_DW_CC(0xcb, LLVM_X86RegCall) +// FIXME: Is this an internal number or something defined in a spec? +HANDLE_DW_CC(0xcc, LLVM_AAVPCS) // From GCC source code (include/dwarf2.h): This DW_CC_ value is not currently // generated by any toolchain. It is used internally to GDB to indicate OpenCL C // functions that have been compiled with the IBM XL C for OpenCL compiler and use Index: include/llvm/Bitcode/LLVMBitCodes.h =================================================================== --- include/llvm/Bitcode/LLVMBitCodes.h +++ include/llvm/Bitcode/LLVMBitCodes.h @@ -167,7 +167,10 @@ TYPE_CODE_FUNCTION = 21, // FUNCTION: [vararg, retty, paramty x N] - TYPE_CODE_TOKEN = 22 // TOKEN + TYPE_CODE_TOKEN = 22, // TOKEN + + TYPE_CODE_SVE_VEC = 23, // SVE hack + TYPE_CODE_SVE_PRED = 24 }; enum OperandBundleTagCode { @@ -311,6 +314,10 @@ METADATA_INDEX_OFFSET = 38, // [offset] METADATA_INDEX = 39, // [bitpos] METADATA_LABEL = 40, // [distinct, scope, name, file, line] + METADATA_STRING_TYPE = 41, // [distinct, name, size, align, ...] + METADATA_FORTRAN_ARRAY_TYPE = 42, // [distinct, name, [bounds ...], ...] + METADATA_FORTRAN_SUBRANGE = 43, // [distinct, lbound, lbnde, ubound, ubnde] + METADATA_COMMON_BLOCK = 44, // [distinct, scope, name, variable,...] 
}; // The constants block (CONSTANTS_BLOCK_ID) describes emission for each @@ -342,6 +349,8 @@ CST_CODE_INLINEASM = 23, // INLINEASM: [sideeffect|alignstack| // asmdialect,asmstr,conststr] CST_CODE_CE_GEP_WITH_INRANGE_INDEX = 24, // [opty, flags, n x operands] + CST_CODE_VSCALE = 25, // VSCALE + CST_CODE_STEPVEC = 26 // STEPVEC }; /// CastOpcodes - These are values used in the bitcode files to encode which Index: include/llvm/CodeGen/Analysis.h =================================================================== --- include/llvm/CodeGen/Analysis.h +++ include/llvm/CodeGen/Analysis.h @@ -62,6 +62,11 @@ return ComputeLinearIndex(Ty, Indices.begin(), Indices.end(), CurIndex); } +struct FieldOffsets { + uint64_t UnscaledBytes; + uint64_t ScaledBytes; +}; + /// ComputeValueVTs - Given an LLVM IR type, compute a sequence of /// EVTs that represent all the individual underlying /// non-aggregate types that comprise it. @@ -71,8 +76,8 @@ /// void ComputeValueVTs(const TargetLowering &TLI, const DataLayout &DL, Type *Ty, SmallVectorImpl &ValueVTs, - SmallVectorImpl *Offsets = nullptr, - uint64_t StartingOffset = 0); + SmallVectorImpl *Offsets = nullptr, + FieldOffsets StartingOffset = {0,0}); /// ExtractTypeInfo - Returns the type info, possibly bitcast, encoded in V. GlobalValue *ExtractTypeInfo(Value *V); Index: include/llvm/CodeGen/BasicTTIImpl.h =================================================================== --- include/llvm/CodeGen/BasicTTIImpl.h +++ include/llvm/CodeGen/BasicTTIImpl.h @@ -441,6 +441,8 @@ unsigned getRegisterBitWidth(bool Vector) const { return 32; } + unsigned getRegisterBitWidthUpperBound(bool Vector) { return 32; } + /// Estimate the overhead of scalarizing an instruction. Insert and Extract /// are set if the result needs to be inserted and/or extracted from vectors. unsigned getScalarizationOverhead(Type *Ty, bool Insert, bool Extract) { @@ -463,16 +465,16 @@ /// non-constant operands. The types of the arguments are ordinarily /// scalar, in which case the costs are multiplied with VF. unsigned getOperandsScalarizationOverhead(ArrayRef Args, - unsigned VF) { + VectorType::ElementCount VF) { unsigned Cost = 0; SmallPtrSet UniqueOperands; for (const Value *A : Args) { if (!isa(A) && UniqueOperands.insert(A).second) { - Type *VecTy = nullptr; - if (A->getType()->isVectorTy()) { - VecTy = A->getType(); + auto VecTy = dyn_cast(A->getType()); + if (VecTy) { // If A is a vector operand, VF should be 1 or correspond to A. - assert((VF == 1 || VF == VecTy->getVectorNumElements()) && + assert((VF == VectorType::SingleElement() || + VF == VecTy->getElementCount()) && "Vector argument does not match VF"); } else @@ -485,15 +487,13 @@ return Cost; } - unsigned getScalarizationOverhead(Type *VecTy, ArrayRef Args) { - assert(VecTy->isVectorTy()); - + unsigned getScalarizationOverhead(Type *Ty, ArrayRef Args) { + VectorType* VecTy = cast(Ty); unsigned Cost = 0; Cost += getScalarizationOverhead(VecTy, true, false); if (!Args.empty()) - Cost += getOperandsScalarizationOverhead(Args, - VecTy->getVectorNumElements()); + Cost += getOperandsScalarizationOverhead(Args, VecTy->getElementCount()); else // When no information on arguments is provided, we add the cost // associated with one argument as a heuristic. @@ -897,10 +897,11 @@ /// Get intrinsic cost based on arguments. 
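// A hedged sketch of the updated ComputeValueVTs() interface above, where the
// starting offset is now a FieldOffsets pair of vscale-unscaled and
// vscale-scaled bytes. TLI, DL and Ty are assumed; the per-value offsets
// output is left as nullptr because its element type is elided in the hunk.
void computeVTs(const TargetLowering &TLI, const DataLayout &DL, Type *Ty) {
  SmallVector<EVT, 4> ValueVTs;
  FieldOffsets Start = {/*UnscaledBytes=*/0, /*ScaledBytes=*/0};
  ComputeValueVTs(TLI, DL, Ty, ValueVTs, /*Offsets=*/nullptr, Start);
}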
unsigned getIntrinsicInstrCost(Intrinsic::ID IID, Type *RetTy, - ArrayRef Args, FastMathFlags FMF, - unsigned VF = 1) { + ArrayRef Args, FastMathFlags FMF, + VectorType::ElementCount VF = VectorType::SingleElement()) { unsigned RetVF = (RetTy->isVectorTy() ? RetTy->getVectorNumElements() : 1); - assert((RetVF == 1 || VF == 1) && "VF > 1 and RetVF is a vector type"); + assert((RetVF == 1 || VF == VectorType::SingleElement()) && + "VF > 1 and RetVF is a vector type"); switch (IID) { default: { @@ -908,18 +909,19 @@ SmallVector Types; for (Value *Op : Args) { Type *OpTy = Op->getType(); - assert(VF == 1 || !OpTy->isVectorTy()); - Types.push_back(VF == 1 ? OpTy : VectorType::get(OpTy, VF)); + assert(VF == VectorType::SingleElement() || !OpTy->isVectorTy()); + Types.push_back(VF == VectorType::SingleElement() + ? OpTy : VectorType::get(OpTy, VF)); } - if (VF > 1 && !RetTy->isVoidTy()) + if (VF != VectorType::SingleElement() && !RetTy->isVoidTy()) RetTy = VectorType::get(RetTy, VF); // Compute the scalarization overhead based on Args for a vector // intrinsic. A vectorizer will pass a scalar RetTy and VF > 1, while // CostModel will pass a vector RetTy and VF is 1. unsigned ScalarizationCost = std::numeric_limits::max(); - if (RetVF > 1 || VF > 1) { + if (RetVF > 1 || VF != VectorType::SingleElement()) { ScalarizationCost = 0; if (!RetTy->isVoidTy()) ScalarizationCost += getScalarizationOverhead(RetTy, true, false); @@ -930,7 +932,7 @@ getIntrinsicInstrCost(IID, RetTy, Types, FMF, ScalarizationCost); } case Intrinsic::masked_scatter: { - assert(VF == 1 && "Can't vectorize types here."); + assert(VF == VectorType::SingleElement() && "Can't vectorize type."); Value *Mask = Args[3]; bool VarMask = !isa(Mask); unsigned Alignment = cast(Args[2])->getZExtValue(); @@ -941,7 +943,7 @@ Alignment); } case Intrinsic::masked_gather: { - assert(VF == 1 && "Can't vectorize types here."); + assert(VF == VectorType::SingleElement() && "Can't vectorize type."); Value *Mask = Args[2]; bool VarMask = !isa(Mask); unsigned Alignment = cast(Args[1])->getZExtValue(); @@ -1076,6 +1078,9 @@ case Intrinsic::fmuladd: ISDs.push_back(ISD::FMA); break; + case Intrinsic::bswap: + ISDs.push_back(ISD::BSWAP); + break; // FIXME: We should return 0 whenever getIntrinsicCost == TCC_Free. case Intrinsic::lifetime_start: case Intrinsic::lifetime_end: Index: include/llvm/CodeGen/ISDOpcodes.h =================================================================== --- include/llvm/CodeGen/ISDOpcodes.h +++ include/llvm/CodeGen/ISDOpcodes.h @@ -328,6 +328,18 @@ /// vector_length(VECTOR2) must be valid VECTOR1 indices. INSERT_SUBVECTOR, + /// WARNING: We have intentionally changed the meaning of this opcode. + /// IDX is now treated as a multiple of N for inputs of type NxMx??. + /// No change occurs for fixed width vectors as N=1 but for scalable + /// vectors it means we can half a vector using the indices 0 and M/2. + /// NOTE: The related change is also applied to INSERT_SUBVECTOR. + /// + /// TODO: We shall maintain this for the foreseeable future as doing + /// otherwise requires work with the result being no less contentious. + /// Two potential routes would be to introduce EXTRACT_HI/EXTRACT_LO opcodes + /// or maintain the original behaviour and correctly calculate the index, + /// along with extending optimisation to remove extracts of concat vectors. 
+ /// /// EXTRACT_SUBVECTOR(VECTOR, IDX) - Returns a subvector from VECTOR (an /// vector value) starting with the element number IDX, which must be a /// constant multiple of the result vector length. @@ -342,6 +354,16 @@ /// in terms of the element size of VEC1/VEC2, not in terms of bytes. VECTOR_SHUFFLE, + /// VECTOR_SHUFFLE_VAR(VEC1, VEC2, VEC3) - like VECTOR_SHUFFLE, + /// except that the mask is represented as an SDNode and any out-of-range + /// mask element produces undefined (potentially faulting) behavior. + /// The mask elements can have any integer type. + VECTOR_SHUFFLE_VAR, + + /// SERIES_VECTOR(INITIAL, STEP) - Creates a vector, with the first lane + /// containing INITIAL and each subsequent lane incremented by STEP + SERIES_VECTOR, + /// SCALAR_TO_VECTOR(VAL) - This represents the operation of loading a /// scalar value into element 0 of the resultant vector type. The top /// elements 1 to N-1 of the N-element vector are undefined. The type @@ -350,6 +372,9 @@ /// than the vector element type, and is implicitly truncated to it. SCALAR_TO_VECTOR, + /// SPLAT_VECTOR(VAL) - Duplicates the value across all lanes of a vector + SPLAT_VECTOR, + /// MULHU/MULHS - Multiply high - Multiply two integers of type iN, /// producing an unsigned/signed value of type i[2*N], then return the top /// part. @@ -567,6 +592,11 @@ /// when a single input is NaN, NaN is returned. FMINNAN, FMAXNAN, + /// FMINIMUM/FMAXIMUM - NaN-propagating minimum/maximum that also treat -0.0 + /// as less than 0.0. While FMINNUM_IEEE/FMAXNUM_IEEE follow IEEE 754-2008 + /// semantics, FMINIMUM/FMAXIMUM follow IEEE 754-2018 draft semantics. + FMINIMUM, FMAXIMUM, + /// FSINCOS - Compute both fsin and fcos as a single operation. FSINCOS, @@ -812,6 +842,12 @@ /// known nonzero constant. The only operand here is the chain. GET_DYNAMIC_AREA_OFFSET, + /// VSCALE(IMM) - Returns the runtime scaling factor used to calculate the + /// number of elements within a scalable vector. IMM is a constant integer + /// multiplier that is applied to the runtime value and is usual some + /// multiple of MVT.getVectorNumElements(). + VSCALE, + /// Generic reduction nodes. These nodes represent horizontal vector /// reduction operations, producing a scalar result. /// The STRICT variants perform reductions in sequential order. The first @@ -820,11 +856,14 @@ VECREDUCE_STRICT_FADD, VECREDUCE_STRICT_FMUL, /// These reductions are non-strict, and have a single vector operand. VECREDUCE_FADD, VECREDUCE_FMUL, + /// FMIN/FMAX nodes can have flags, for NaN/NoNaN variants. + VECREDUCE_FMAX, VECREDUCE_FMIN, + /// Integer reductions may have a result type larger than the vector element + /// type. However, the reduction is performed using the vector element type + /// and the value in the top bits is unspecified. VECREDUCE_ADD, VECREDUCE_MUL, VECREDUCE_AND, VECREDUCE_OR, VECREDUCE_XOR, VECREDUCE_SMAX, VECREDUCE_SMIN, VECREDUCE_UMAX, VECREDUCE_UMIN, - /// FMIN/FMAX nodes can have flags, for NaN/NoNaN variants. - VECREDUCE_FMAX, VECREDUCE_FMIN, /// BUILTIN_OP_END - This must be the last enum value in this list. /// The target-specific pre-isel opcode values start here. @@ -876,6 +915,23 @@ static const int LAST_INDEXED_MODE = POST_DEC + 1; //===--------------------------------------------------------------------===// + /// MemIndexType enum - This enum defines how to interpret MGATHER/SCATTER's + /// index parameter when calculating addresses. 
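// A hedged sketch (helper name and context assumed, not from the patch) of
// how the new nodes compose during lowering for a scalable type VT such as
// <vscale x 4 x i32>: a splat of Start, the linear sequence Start + i * Step,
// and the runtime element count 4 * vscale obtained through ISD::VSCALE.
SDValue buildScalableExamples(SelectionDAG &DAG, const SDLoc &DL, EVT VT,
                              SDValue Start, SDValue Step) {
  SDValue Splat = DAG.getNode(ISD::SPLAT_VECTOR, DL, VT, Start);
  SDValue Series = DAG.getNode(ISD::SERIES_VECTOR, DL, VT, Start, Step);
  SDValue NumElts =
      DAG.getNode(ISD::VSCALE, DL, MVT::i64,
                  DAG.getConstant(VT.getVectorNumElements(), DL, MVT::i64));
  (void)Splat; (void)NumElts;
  return Series;
}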
+ /// + /// SIGNED_SCALED Addr = Base + ((signed)Index * sizeof(element)) + /// SIGNED_UNSCALED Addr = Base + (signed)Index + /// UNSIGNED_SCALED Addr = Base + ((unsigned)Index * sizeof(element)) + /// UNSIGNED_UNSCALED Addr = Base + (unsigned)Index + enum MemIndexType { + SIGNED_SCALED = 0, + SIGNED_UNSCALED, + UNSIGNED_SCALED, + UNSIGNED_UNSCALED + }; + + static const int LAST_MEM_INDEX_TYPE = UNSIGNED_UNSCALED + 1; + + //===--------------------------------------------------------------------===// /// LoadExtType enum - This enum defines the three variants of LOADEXT /// (load with extension). /// Index: include/llvm/CodeGen/MIRYamlMapping.h =================================================================== --- include/llvm/CodeGen/MIRYamlMapping.h +++ include/llvm/CodeGen/MIRYamlMapping.h @@ -252,7 +252,7 @@ if (Object.Type != MachineStackObject::VariableSized) YamlIO.mapRequired("size", Object.Size); YamlIO.mapOptional("alignment", Object.Alignment, (unsigned)0); - YamlIO.mapOptional("stack-id", Object.StackID); + YamlIO.mapOptional("stack-id", Object.StackID, (uint8_t)0); YamlIO.mapOptional("callee-saved-register", Object.CalleeSavedRegister, StringValue()); // Don't print it out when it's empty. YamlIO.mapOptional("callee-saved-restored", Object.CalleeSavedRestored, Index: include/llvm/CodeGen/MachineFrameInfo.h =================================================================== --- include/llvm/CodeGen/MachineFrameInfo.h +++ include/llvm/CodeGen/MachineFrameInfo.h @@ -16,6 +16,7 @@ #include "llvm/ADT/SmallVector.h" #include "llvm/Support/DataTypes.h" +#include "llvm/CodeGen/MachineBasicBlock.h" #include #include @@ -450,7 +451,8 @@ assert(unsigned(ObjectIdx+NumFixedObjects) < Objects.size() && "Invalid Object Idx!"); Objects[ObjectIdx+NumFixedObjects].Alignment = Align; - ensureMaxAlignment(Align); + if (getStackID(ObjectIdx) == 0) + ensureMaxAlignment(Align); } /// Return the underlying Alloca of the specified @@ -703,8 +705,10 @@ /// Remove or mark dead a statically sized stack object. void RemoveStackObject(int ObjectIdx) { + StackObject& O = Objects[ObjectIdx+NumFixedObjects]; + // Mark it dead. - Objects[ObjectIdx+NumFixedObjects].Size = ~0ULL; + O.Size = ~0ULL; } /// Notify the MachineFrameInfo object that a variable sized object has been Index: include/llvm/CodeGen/MachineInstr.h =================================================================== --- include/llvm/CodeGen/MachineInstr.h +++ include/llvm/CodeGen/MachineInstr.h @@ -1066,9 +1066,6 @@ /// /// If GroupNo is not NULL, it will receive the number of the operand group /// containing OpIdx. - /// - /// The flag operand is an immediate that can be decoded with methods like - /// InlineAsm::hasRegClassConstraint(). int findInlineAsmFlagIdx(unsigned OpIdx, unsigned *GroupNo = nullptr) const; /// Compute the static register class constraint for operand OpIdx. Index: include/llvm/CodeGen/MachineScheduler.h =================================================================== --- include/llvm/CodeGen/MachineScheduler.h +++ include/llvm/CodeGen/MachineScheduler.h @@ -175,6 +175,11 @@ /// of the same vreg. \sa MachineSchedStrategy::shouldTrackLaneMasks() bool ShouldTrackLaneMasks = false; + /// When trying to see whether latency determines ordering or not, this flag + /// determines whether we always consider reducing the path height + /// (bottom up) or path depth (top down). + bool AlwaysReduceLatencyHeight = false; + // Allow the scheduler to force top-down or bottom-up scheduling. 
If neither // is true, the scheduler runs in both directions and converges. bool OnlyTopDown = false; @@ -214,6 +219,11 @@ /// This has to be enabled in combination with shouldTrackPressure(). virtual bool shouldTrackLaneMasks() const { return false; } + /// When trying to see whether latency determines ordering or not, this + /// function returns true if we should always consider reducing the path + /// height (bottom up) or path depth (top down). + virtual bool alwaysReduceLatencyHeight() const { return false; } + // If this method returns true, handling of the scheduling regions // themselves (in case of a scheduling boundary in MBB) will be done // beginning with the topmost region of MBB. @@ -852,6 +862,8 @@ // Critical resource consumption of the best candidate. SchedResourceDelta ResDelta; + bool PressureExceedsLimit; + SchedCandidate() { reset(CandPolicy()); } SchedCandidate(const CandPolicy &Policy) { reset(Policy); } @@ -862,6 +874,7 @@ AtTop = false; RPDelta = RegPressureDelta(); ResDelta = SchedResourceDelta(); + PressureExceedsLimit = false; } bool isValid() const { return SU; } @@ -874,6 +887,7 @@ AtTop = Best.AtTop; RPDelta = Best.RPDelta; ResDelta = Best.ResDelta; + PressureExceedsLimit = Best.PressureExceedsLimit; } void initResourceDelta(const ScheduleDAGMI *DAG, @@ -908,7 +922,8 @@ GenericSchedulerBase::CandReason Reason); bool tryLatency(GenericSchedulerBase::SchedCandidate &TryCand, GenericSchedulerBase::SchedCandidate &Cand, - SchedBoundary &Zone); + SchedBoundary &Zone, + bool AlwaysReduceLatencyHeight = false); bool tryPressure(const PressureChange &TryP, const PressureChange &CandP, GenericSchedulerBase::SchedCandidate &TryCand, @@ -941,6 +956,10 @@ return RegionPolicy.ShouldTrackLaneMasks; } + bool alwaysReduceLatencyHeight() const override { + return RegionPolicy.AlwaysReduceLatencyHeight; + } + void initialize(ScheduleDAGMI *dag) override; SUnit *pickNode(bool &IsTopNode) override; Index: include/llvm/CodeGen/Passes.h =================================================================== --- include/llvm/CodeGen/Passes.h +++ include/llvm/CodeGen/Passes.h @@ -374,11 +374,28 @@ /// integrity. ModulePass *createForwardControlFlowIntegrityPass(); + /// ContiguousLoadStore Pass - This pass identifies structured load/store + /// instructions where each element has the same operation applied to it, + /// in which case they can be replaced with contiguous load/stores + /// + FunctionPass *createContiguousLoadStorePass(); + /// InterleavedAccess Pass - This pass identifies and matches interleaved /// memory accesses to target specific intrinsics. /// FunctionPass *createInterleavedAccessPass(); + /// InterleavedGatherScatterStoreSink Pass - This pass makes preparations for + /// the InterleavedGatherScatter pass, by sinking as many stores as possible + /// to the end of basic blocks + /// + FunctionPass *createInterleavedGatherScatterStoreSinkPass(); + + /// InterleavedGatherScatter Pass - This pass identifies and matches + /// interleaved gathers and scatters to target specific intrinsics. + /// + FunctionPass *createInterleavedGatherScatterPass(); + /// LowerEmuTLS - This pass generates __emutls_[vt].xyz variables for all /// TLS variables for the emulated TLS model. 
/// Index: include/llvm/CodeGen/RegisterPressure.h =================================================================== --- include/llvm/CodeGen/RegisterPressure.h +++ include/llvm/CodeGen/RegisterPressure.h @@ -483,6 +483,12 @@ ArrayRef CriticalPSets, ArrayRef MaxPressureLimit); + // Get the highest current pressure of any register set as a percentage of + // the corresponding register pressure limit. This is useful for the + // scheduler when trying to decide whether to prioritise latency or register + // pressure. + unsigned getHighestUpwardPressureFactor(const PressureDiff *PDiff) const; + void getUpwardPressureDelta(const MachineInstr *MI, /*const*/ PressureDiff &PDiff, RegPressureDelta &Delta, Index: include/llvm/CodeGen/RegisterScavenging.h =================================================================== --- include/llvm/CodeGen/RegisterScavenging.h +++ include/llvm/CodeGen/RegisterScavenging.h @@ -155,11 +155,16 @@ /// Make a register of the specific register class /// available and do the appropriate bookkeeping. SPAdj is the stack /// adjustment due to call frame, it's passed along to eliminateFrameIndex(). + /// SRLiveRangeEndsHere can help to signal the register scavenger that it may + /// reuse the destination register of 'I' if there is no overlap due to + /// early clobber. /// Returns the scavenged register. /// This is deprecated as it depends on the quality of the kill flags being /// present; Use scavengeRegisterBackwards() instead! unsigned scavengeRegister(const TargetRegisterClass *RC, - MachineBasicBlock::iterator I, int SPAdj); + MachineBasicBlock::iterator I, int SPAdj, + bool SRLiveRangeEndsHere = false); + unsigned scavengeRegister(const TargetRegisterClass *RegClass, int SPAdj) { return scavengeRegister(RegClass, MBBI, SPAdj); } Index: include/llvm/CodeGen/SelectionDAG.h =================================================================== --- include/llvm/CodeGen/SelectionDAG.h +++ include/llvm/CodeGen/SelectionDAG.h @@ -754,6 +754,10 @@ return getNode(ISD::UNDEF, SDLoc(), VT); } + + return getNode(ISD::SPLAT_VECTOR, DL, VT, Op); + + SmallVector Ops(VT.getVectorNumElements(), Op); return getNode(ISD::BUILD_VECTOR, DL, VT, Ops); } @@ -866,6 +870,12 @@ return getNode(ISD::UNDEF, SDLoc(), VT); } + /// Return the runtime scaling factor applicable to scalable vectors that is + /// itself scaled by 'MulImm'. + SDValue getVScale(const SDLoc &DL, EVT VT, int64_t MulImm=1) { + return getNode(ISD::VSCALE, DL, VT, getConstant(MulImm, DL, VT)); + } + /// Return a GLOBAL_OFFSET_TABLE node. This does not have a useful SDLoc. 
SDValue getGLOBAL_OFFSET_TABLE(EVT VT) { return getNode(ISD::GLOBAL_OFFSET_TABLE, SDLoc(), VT); @@ -954,8 +964,13 @@ "Cannot compare scalars to vectors"); assert(LHS.getValueType().isVector() == VT.isVector() && "Cannot compare scalars to vectors"); + assert(LHS.getValueType() == RHS.getValueType() && + "Cannot compare different types"); + assert(LHS.getValueType().isScalableVector() == + VT.isScalableVector() && + "Cannot compare different types"); assert(Cond != ISD::SETCC_INVALID && - "Cannot create a setCC of an invalid node."); + "Cannot create a setCC of an invalid node."); return getNode(ISD::SETCC, DL, VT, LHS, RHS, getCondCode(Cond)); } @@ -1105,9 +1120,11 @@ MachineMemOperand *MMO, bool IsTruncating = false, bool IsCompressing = false); SDValue getMaskedGather(SDVTList VTs, EVT VT, const SDLoc &dl, - ArrayRef Ops, MachineMemOperand *MMO); + ArrayRef Ops, MachineMemOperand *MMO, + ISD::LoadExtType ExtTy, ISD::MemIndexType IndexType); SDValue getMaskedScatter(SDVTList VTs, EVT VT, const SDLoc &dl, - ArrayRef Ops, MachineMemOperand *MMO); + ArrayRef Ops, MachineMemOperand *MMO, + bool IsTrunc, ISD::MemIndexType IndexType); /// Return (create a new or find existing) a target-specific node. /// TargetMemSDNode should be derived class from MemSDNode. @@ -1402,6 +1419,12 @@ ArrayRef Ops, const SDNodeFlags Flags = SDNodeFlags()); + SDValue FoldSeriesVectorBinOp(unsigned Opcode, SDLoc DL, EVT VT, SDValue N1, + SDValue N2, const SDNodeFlags Flags); + + SDValue FoldSplatVectorBinOp(unsigned Opcode, SDLoc DL, EVT VT, SDValue N1, + SDValue N2, const SDNodeFlags Flags); + /// Constant fold a setcc to true or false. SDValue FoldSetCC(EVT VT, SDValue N1, SDValue N2, ISD::CondCode Cond, const SDLoc &dl); @@ -1518,6 +1541,10 @@ bool areNonVolatileConsecutiveLoads(LoadSDNode *LD, LoadSDNode *Base, unsigned Bytes, int Dist) const; + /// Returns true if Op is a splat of an integer constant. The splatted value + /// is returned via SplatValue when not null. + bool isConstantIntSplat(SDValue Op, APInt* SplatValue); + /// Infer alignment of a load / store address. Return 0 if /// it cannot be inferred. unsigned InferPtrAlignment(SDValue Ptr) const; Index: include/llvm/CodeGen/SelectionDAGNodes.h =================================================================== --- include/llvm/CodeGen/SelectionDAGNodes.h +++ include/llvm/CodeGen/SelectionDAGNodes.h @@ -86,7 +86,10 @@ /// If N is a BUILD_VECTOR node whose elements are all the same constant or /// undefined, return true and return the constant value in \p SplatValue. - bool isConstantSplatVector(const SDNode *N, APInt &SplatValue); + /// This sets \p SplatValue to the smallest possible splat unless AllowShrink + /// is set to false. + bool isConstantSplatVector(const SDNode *N, APInt &SplatValue, + bool AllowShrink = true); /// Return true if the specified node is a BUILD_VECTOR where all of the /// elements are ~0 or undef. @@ -525,16 +528,22 @@ class LSBaseSDNodeBitfields { friend class LSBaseSDNode; + friend class MaskedGatherScatterSDNode; uint16_t : NumMemSDNodeBits; - uint16_t AddressingMode : 3; // enum ISD::MemIndexedMode + // This storage is shared between disparate class hierarchies to hold an + // enumeration specific to the class hierarchy in use. 
+ // LSBaseSDNode => enum ISD::MemIndexedMode + // MaskedGatherScatterSDNode => enum ISD::MemIndexType + uint16_t AddressingMode : 3; }; enum { NumLSBaseSDNodeBits = NumMemSDNodeBits + 3 }; class LoadSDNodeBitfields { friend class LoadSDNode; friend class MaskedLoadSDNode; + friend class MaskedGatherSDNode; uint16_t : NumLSBaseSDNodeBits; @@ -545,6 +554,7 @@ class StoreSDNodeBitfields { friend class StoreSDNode; friend class MaskedStoreSDNode; + friend class MaskedScatterSDNode; uint16_t : NumLSBaseSDNodeBits; @@ -1446,7 +1456,10 @@ friend class SelectionDAG; ShuffleVectorSDNode(EVT VT, unsigned Order, const DebugLoc &dl, const int *M) - : SDNode(ISD::VECTOR_SHUFFLE, Order, dl, getSDVTList(VT)), Mask(M) {} + : SDNode(ISD::VECTOR_SHUFFLE, Order, dl, getSDVTList(VT)), Mask(M) { + assert(!VT.isScalableVector() && + "ISD::VECTOR_SHUFFLE does not support scalable vectors!"); + } public: ArrayRef getMask() const { @@ -2191,8 +2204,24 @@ MaskedGatherScatterSDNode(ISD::NodeType NodeTy, unsigned Order, const DebugLoc &dl, SDVTList VTs, EVT MemVT, - MachineMemOperand *MMO) - : MemSDNode(NodeTy, Order, dl, VTs, MemVT, MMO) {} + MachineMemOperand *MMO, ISD::MemIndexType IndexType) + : MemSDNode(NodeTy, Order, dl, VTs, MemVT, MMO) { + LSBaseSDNodeBits.AddressingMode = IndexType; + assert(getIndexType() == IndexType && "Value truncated"); + } + + /// How is Index applied to BasePtr when computing addresses. + ISD::MemIndexType getIndexType() const { + return static_cast(LSBaseSDNodeBits.AddressingMode); + } + bool isIndexScaled() const { + return (getIndexType() == ISD::SIGNED_SCALED) || + (getIndexType() == ISD::UNSIGNED_SCALED); + } + bool isIndexSigned() const { + return (getIndexType() == ISD::SIGNED_SCALED) || + (getIndexType() == ISD::SIGNED_UNSCALED); + } // In the both nodes address is Op1, mask is Op2: // MaskedGatherSDNode (Chain, passthru, mask, base, index, scale) @@ -2215,10 +2244,18 @@ class MaskedGatherSDNode : public MaskedGatherScatterSDNode { public: friend class SelectionDAG; - MaskedGatherSDNode(unsigned Order, const DebugLoc &dl, SDVTList VTs, - EVT MemVT, MachineMemOperand *MMO) - : MaskedGatherScatterSDNode(ISD::MGATHER, Order, dl, VTs, MemVT, MMO) {} + ISD::LoadExtType ETy, EVT MemVT, MachineMemOperand *MMO, + ISD::MemIndexType IndexType) + : MaskedGatherScatterSDNode(ISD::MGATHER, Order, dl, VTs, MemVT, MMO, + IndexType) { + LoadSDNodeBits.ExtTy = ETy; + } + + ISD::LoadExtType getExtensionType() const { + return ISD::LoadExtType(LoadSDNodeBits.ExtTy); + } + const SDValue &getSrc0() const { return getValue(); } static bool classof(const SDNode *N) { return N->getOpcode() == ISD::MGATHER; @@ -2232,8 +2269,17 @@ friend class SelectionDAG; MaskedScatterSDNode(unsigned Order, const DebugLoc &dl, SDVTList VTs, - EVT MemVT, MachineMemOperand *MMO) - : MaskedGatherScatterSDNode(ISD::MSCATTER, Order, dl, VTs, MemVT, MMO) {} + bool isTrunc, EVT MemVT, MachineMemOperand *MMO, + ISD::MemIndexType IndexType) + : MaskedGatherScatterSDNode(ISD::MSCATTER, Order, dl, VTs, MemVT, MMO, + IndexType) { + StoreSDNodeBits.IsTruncating = isTrunc; + } + + /// Return true if the op does a truncation before store. + /// For integers this is the same as doing a TRUNCATE and storing the result. + /// For floats, it is the same as doing an FP_ROUND and storing the result. 
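// A small usage sketch of the index-type state introduced above; N is assumed
// to be an SDNode already known to be an ISD::MGATHER node.
void inspectGather(SDNode *N) {
  auto *MGN = cast<MaskedGatherSDNode>(N);
  bool SExtIndex = MGN->isIndexSigned();  // SIGNED_{SCALED,UNSCALED}
  bool ScaleIndex = MGN->isIndexScaled(); // {SIGNED,UNSIGNED}_SCALED
  ISD::LoadExtType ExtTy = MGN->getExtensionType();
  (void)SExtIndex; (void)ScaleIndex; (void)ExtTy;
}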
+ bool isTruncatingStore() const { return StoreSDNodeBits.IsTruncating; } static bool classof(const SDNode *N) { return N->getOpcode() == ISD::MSCATTER; Index: include/llvm/CodeGen/SelectionDAGTargetInfo.h =================================================================== --- include/llvm/CodeGen/SelectionDAGTargetInfo.h +++ include/llvm/CodeGen/SelectionDAGTargetInfo.h @@ -150,7 +150,8 @@ // Return true when the decision to generate FMA's (or FMS, FMLA etc) rather // than FMUL and ADD is delegated to the machine combiner. - virtual bool generateFMAsInMachineCombiner(CodeGenOpt::Level OptLevel) const { + virtual bool generateFMAsInMachineCombiner(SelectionDAG &DAG, + CodeGenOpt::Level OptLevel) const { return false; } }; Index: include/llvm/CodeGen/TargetFrameLowering.h =================================================================== --- include/llvm/CodeGen/TargetFrameLowering.h +++ include/llvm/CodeGen/TargetFrameLowering.h @@ -332,6 +332,11 @@ return true; } + /// Returns the StackID to which type T should be allocated. + /// This always defaults to StackID 0, but targets that want to support + /// multiple StackIDs may want to override this method. + virtual unsigned getStackIDForType(const Type *T) const { return 0; } + /// Check if given function is safe for not having callee saved registers. /// This is used when interprocedural register allocation is enabled. static bool isSafeForNoCSROpt(const Function &F) { Index: include/llvm/CodeGen/TargetInstrInfo.h =================================================================== --- include/llvm/CodeGen/TargetInstrInfo.h +++ include/llvm/CodeGen/TargetInstrInfo.h @@ -246,14 +246,14 @@ } /// If the specified machine instruction has a load from a stack slot, - /// return true along with the FrameIndex of the loaded stack slot and the - /// machine mem operand containing the reference. + /// return true along with the FrameIndices of the loaded stack slot and the + /// machine mem operands containing the reference. /// If not, return false. Unlike isLoadFromStackSlot, this returns true for /// any instructions that loads from the stack. This is just a hint, as some /// cases may be missed. - virtual bool hasLoadFromStackSlot(const MachineInstr &MI, - const MachineMemOperand *&MMO, - int &FrameIndex) const; + virtual bool hasLoadFromStackSlot( + const MachineInstr &MI, + SmallVectorImpl &Accesses) const; /// If the specified machine instruction is a direct /// store to a stack slot, return the virtual or physical register number of @@ -284,14 +284,14 @@ } /// If the specified machine instruction has a store to a stack slot, - /// return true along with the FrameIndex of the loaded stack slot and the - /// machine mem operand containing the reference. + /// return true along with the FrameIndices of the loaded stack slot and the + /// machine mem operands containing the reference. /// If not, return false. Unlike isStoreToStackSlot, /// this returns true for any instructions that stores to the /// stack. This is just a hint, as some cases may be missed. - virtual bool hasStoreToStackSlot(const MachineInstr &MI, - const MachineMemOperand *&MMO, - int &FrameIndex) const; + virtual bool hasStoreToStackSlot( + const MachineInstr &MI, + SmallVectorImpl &Accesses) const; /// Return true if the specified machine instruction /// is a copy of one stack slot to another and has no other effect. 
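// A hedged usage sketch of the updated hook above, which now reports every
// stack access of MI instead of a single FrameIndex/MemOperand pair. The
// element type of Accesses is elided in the hunk; it is assumed here to be
// const MachineMemOperand *.
void collectSpillSlots(const TargetInstrInfo *TII, const MachineInstr &MI) {
  SmallVector<const MachineMemOperand *, 2> Accesses;
  if (!TII->hasLoadFromStackSlot(MI, Accesses))
    return;
  for (const MachineMemOperand *MMO : Accesses)
    if (const auto *PSV = dyn_cast_or_null<FixedStackPseudoSourceValue>(
            MMO->getPseudoValue()))
      (void)PSV->getFrameIndex(); // One frame index per reported access.
}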
Index: include/llvm/CodeGen/TargetLowering.h =================================================================== --- include/llvm/CodeGen/TargetLowering.h +++ include/llvm/CodeGen/TargetLowering.h @@ -88,6 +88,7 @@ class TargetRegisterClass; class TargetLibraryInfo; class TargetRegisterInfo; +class TargetTransformInfo; class Value; namespace Sched { @@ -1099,7 +1100,7 @@ } return EVT::getVectorVT(Ty->getContext(), EVT::getEVT(Elm, false), - VTy->getNumElements()); + VTy->getNumElements(), VTy->isScalable()); } return EVT::getEVT(Ty, AllowUnknown); } @@ -2211,6 +2212,38 @@ return false; } + /// Lower a set of compatible gather load intrinsics to an interleaved + /// load. Return true on success. + /// + /// \p Gathers is the list of gather instructions sorted by offset + /// \p FirstGather is the first gather in the IR. + /// \p OffsetFirstGather is a byte-offset to adjust the + /// address of the first gather instruction. + /// \p Factor is the interleave factor + virtual bool lowerGathersToInterleavedLoad(ArrayRef Gathers, + IntrinsicInst *FirstGather, + int OffsetFirstGather, + unsigned Factor, + TargetTransformInfo *TTI) const { + return false; + } + + /// Lower a set of compatible scatter store intrinsics to an + /// interleaved store. Return true on success. + /// + /// \p ValuesToStore is the list of values stored, in memory order + /// \p FirstScatterAddress is the address ptr from the first scatter + /// \p ReplaceNode points to where the replacement nodes should be built + /// (this should be the last scatter in instruction order) + /// \p Factor is the interleave factor + virtual bool lowerScattersToInterleavedStore(ArrayRef ValuesToStore, + Value *FirstScatterAddress, + IntrinsicInst *ReplaceNode, + unsigned Factor, + TargetTransformInfo *TTI) const { + return false; + } + /// Return true if zero-extending the specific node Val to type VT2 is free /// (either because it's implicitly zero-extended such as ARM ldrb / ldrh or /// because it's folded such as X86 zero-extending loads). @@ -3627,6 +3660,11 @@ SDValue getVectorElementPointer(SelectionDAG &DAG, SDValue VecPtr, EVT VecVT, SDValue Index) const; + + /// Expand a VECREDUCE_* into an explicit calculation. If Count is specified, + /// only the first Count elements of the vector are used. + SDValue expandVecReduce(SDNode *Node, SelectionDAG &DAG) const; + //===--------------------------------------------------------------------===// // Instruction Emitting Hooks // Index: include/llvm/CodeGen/TargetPassConfig.h =================================================================== --- include/llvm/CodeGen/TargetPassConfig.h +++ include/llvm/CodeGen/TargetPassConfig.h @@ -356,6 +356,13 @@ /// are required for fast register allocation. virtual void addFastRegAlloc(FunctionPass *RegAllocPass); + /// addPostCoalesce - Add passes to the optimized register allocation pipeline + /// after coalescing is complete, but before further scheduling or register + /// allocation. + virtual bool addPostCoalesce() { + return false; + } + /// addOptimizedRegAlloc - Add passes related to register allocation. /// LLVMTargetMachine provides standard regalloc passes for most targets. 
virtual void addOptimizedRegAlloc(FunctionPass *RegAllocPass); Index: include/llvm/CodeGen/ValueTypes.h =================================================================== --- include/llvm/CodeGen/ValueTypes.h +++ include/llvm/CodeGen/ValueTypes.h @@ -75,9 +75,7 @@ MVT M = MVT::getVectorVT(VT.V, NumElements, IsScalable); if (M.SimpleTy != MVT::INVALID_SIMPLE_VALUE_TYPE) return M; - - assert(!IsScalable && "We don't support extended scalable types yet"); - return getExtendedVectorVT(Context, VT, NumElements); + return getExtendedVectorVT(Context, VT, NumElements, IsScalable); } /// Returns the EVT that represents a vector EC.Min elements in length, @@ -86,19 +84,15 @@ MVT M = MVT::getVectorVT(VT.V, EC); if (M.SimpleTy != MVT::INVALID_SIMPLE_VALUE_TYPE) return M; - assert (!EC.Scalable && "We don't support extended scalable types yet"); - return getExtendedVectorVT(Context, VT, EC.Min); + return getExtendedVectorVT(Context, VT, EC); } /// Return a vector with the same number of elements as this vector, but /// with the element type converted to an integer type with the same /// bitwidth. EVT changeVectorElementTypeToInteger() const { - if (!isSimple()) { - assert (!isScalableVector() && - "We don't support extended scalable types yet"); + if (!isSimple()) return changeExtendedVectorElementTypeToInteger(); - } MVT EltTy = getSimpleVT().getVectorElementType(); unsigned BitWidth = EltTy.getSizeInBits(); MVT IntTy = MVT::getIntegerVT(BitWidth); @@ -109,6 +103,18 @@ return VecTy; } + /// Return a VT for a vector type whose attributes match ourselves + /// with the exception of the element type that is chosen by the caller. + EVT changeVectorElementType(EVT EltVT) const { + if (!isSimple()) + return changeExtendedVectorElementType(EltVT); + MVT VecTy = MVT::getVectorVT(EltVT.V, getVectorNumElements(), + isScalableVector()); + assert(VecTy.SimpleTy != MVT::INVALID_SIMPLE_VALUE_TYPE && + "Simple vector VT not representable by simple integer vector VT!"); + return VecTy; + } + /// Return the type converted to an equivalently sized integer or vector /// with integer element type. Similar to changeVectorElementTypeToInteger, /// but also handles scalars. @@ -155,12 +161,7 @@ /// Return true if this is a vector type where the runtime /// length is machine dependent bool isScalableVector() const { - // FIXME: We don't support extended scalable types yet, because the - // matching IR type doesn't exist. Once it has been added, this can - // be changed to call isExtendedScalableVector. - if (!isSimple()) - return false; - return V.isScalableVector(); + return isSimple() ? V.isScalableVector() : isExtendedScalableVector(); } /// Return true if this is a 16-bit vector type. @@ -263,7 +264,7 @@ /// Given a vector type, return the type of each element. EVT getVectorElementType() const { - assert(isVector() && "Invalid vector type!"); + assert((isVector() || isScalableVector()) && "Invalid vector type!"); if (isSimple()) return V.getVectorElementType(); return getExtendedVectorElementType(); @@ -271,7 +272,7 @@ /// Given a vector type, return the number of elements it contains. 
unsigned getVectorNumElements() const { - assert(isVector() && "Invalid vector type!"); + assert((isVector() || isScalableVector()) && "Invalid vector type!"); if (isSimple()) return V.getVectorNumElements(); return getExtendedVectorNumElements(); @@ -283,9 +284,7 @@ if (isSimple()) return V.getVectorElementCount(); - assert(!isScalableVector() && - "We don't support extended scalable types yet"); - return {getExtendedVectorNumElements(), false}; + return {getExtendedVectorNumElements(), isExtendedScalableVector()}; } /// Return the size of the specified value type in bits. @@ -411,10 +410,13 @@ // These are all out-of-line to prevent users of this header file // from having a dependency on Type.h. EVT changeExtendedTypeToInteger() const; + EVT changeExtendedVectorElementType(EVT EltVT) const; EVT changeExtendedVectorElementTypeToInteger() const; static EVT getExtendedIntegerVT(LLVMContext &C, unsigned BitWidth); static EVT getExtendedVectorVT(LLVMContext &C, EVT VT, - unsigned NumElements); + unsigned NumElements, bool IsScalable); + static EVT getExtendedVectorVT(LLVMContext &Context, EVT VT, + MVT::ElementCount EC); bool isExtendedFloatingPoint() const LLVM_READONLY; bool isExtendedInteger() const LLVM_READONLY; bool isExtendedScalarInteger() const LLVM_READONLY; @@ -427,8 +429,10 @@ bool isExtended512BitVector() const LLVM_READONLY; bool isExtended1024BitVector() const LLVM_READONLY; bool isExtended2048BitVector() const LLVM_READONLY; + bool isExtendedScalableVector() const LLVM_READONLY; EVT getExtendedVectorElementType() const; unsigned getExtendedVectorNumElements() const LLVM_READONLY; + MVT::ElementCount getExtendedVectorElementCount() const LLVM_READONLY; unsigned getExtendedSizeInBits() const LLVM_READONLY; }; Index: include/llvm/CodeGen/ValueTypes.td =================================================================== --- include/llvm/CodeGen/ValueTypes.td +++ include/llvm/CodeGen/ValueTypes.td @@ -118,34 +118,36 @@ def v2f16 : ValueType<32 , 85>; // 2 x f16 vector value def v4f16 : ValueType<64 , 86>; // 4 x f16 vector value def v8f16 : ValueType<128, 87>; // 8 x f16 vector value -def v1f32 : ValueType<32 , 88>; // 1 x f32 vector value -def v2f32 : ValueType<64 , 89>; // 2 x f32 vector value -def v4f32 : ValueType<128, 90>; // 4 x f32 vector value -def v8f32 : ValueType<256, 91>; // 8 x f32 vector value -def v16f32 : ValueType<512, 92>; // 16 x f32 vector value -def v1f64 : ValueType<64, 93>; // 1 x f64 vector value -def v2f64 : ValueType<128, 94>; // 2 x f64 vector value -def v4f64 : ValueType<256, 95>; // 4 x f64 vector value -def v8f64 : ValueType<512, 96>; // 8 x f64 vector value +def v16f16 : ValueType<256, 88>; // 16 x f16 vector value +def v32f16 : ValueType<512, 89>; // 32 x f16 vector value +def v1f32 : ValueType<32 , 90>; // 1 x f32 vector value +def v2f32 : ValueType<64 , 91>; // 2 x f32 vector value +def v4f32 : ValueType<128, 92>; // 4 x f32 vector value +def v8f32 : ValueType<256, 93>; // 8 x f32 vector value +def v16f32 : ValueType<512, 94>; // 16 x f32 vector value +def v1f64 : ValueType<64, 95>; // 1 x f64 vector value +def v2f64 : ValueType<128, 96>; // 2 x f64 vector value +def v4f64 : ValueType<256, 97>; // 4 x f64 vector value +def v8f64 : ValueType<512, 98>; // 8 x f64 vector value -def nxv2f16 : ValueType<32 , 97>; // n x 2 x f16 vector value -def nxv4f16 : ValueType<64 , 98>; // n x 4 x f16 vector value -def nxv8f16 : ValueType<128, 99>; // n x 8 x f16 vector value -def nxv1f32 : ValueType<32 , 100>; // n x 1 x f32 vector value -def nxv2f32 : 
ValueType<64 , 101>; // n x 2 x f32 vector value -def nxv4f32 : ValueType<128, 102>; // n x 4 x f32 vector value -def nxv8f32 : ValueType<256, 103>; // n x 8 x f32 vector value -def nxv16f32 : ValueType<512, 104>; // n x 16 x f32 vector value -def nxv1f64 : ValueType<64, 105>; // n x 1 x f64 vector value -def nxv2f64 : ValueType<128, 106>; // n x 2 x f64 vector value -def nxv4f64 : ValueType<256, 107>; // n x 4 x f64 vector value -def nxv8f64 : ValueType<512, 108>; // n x 8 x f64 vector value +def nxv2f16 : ValueType<32 , 99>; // n x 2 x f16 vector value +def nxv4f16 : ValueType<64 , 100>; // n x 4 x f16 vector value +def nxv8f16 : ValueType<128, 101>; // n x 8 x f16 vector value +def nxv1f32 : ValueType<32 , 102>; // n x 1 x f32 vector value +def nxv2f32 : ValueType<64 , 103>; // n x 2 x f32 vector value +def nxv4f32 : ValueType<128, 104>; // n x 4 x f32 vector value +def nxv8f32 : ValueType<256, 105>; // n x 8 x f32 vector value +def nxv16f32 : ValueType<512, 106>; // n x 16 x f32 vector value +def nxv1f64 : ValueType<64, 107>; // n x 1 x f64 vector value +def nxv2f64 : ValueType<128, 108>; // n x 2 x f64 vector value +def nxv4f64 : ValueType<256, 109>; // n x 4 x f64 vector value +def nxv8f64 : ValueType<512, 110>; // n x 8 x f64 vector value -def x86mmx : ValueType<64 , 109>; // X86 MMX value -def FlagVT : ValueType<0 , 110>; // Pre-RA sched glue -def isVoid : ValueType<0 , 111>; // Produces no value -def untyped: ValueType<8 , 112>; // Produces an untyped value -def ExceptRef: ValueType<0, 113>; // WebAssembly's except_ref type +def x86mmx : ValueType<64 , 111>; // X86 MMX value +def FlagVT : ValueType<0 , 112>; // Pre-RA sched glue +def isVoid : ValueType<0 , 113>; // Produces no value +def untyped: ValueType<8 , 114>; // Produces an untyped value +def ExceptRef: ValueType<0, 115>; // WebAssembly's except_ref type def token : ValueType<0 , 248>; // TokenTy def MetadataVT: ValueType<0, 249>; // Metadata Index: include/llvm/Config/llvm-config.h.cmake =================================================================== --- include/llvm/Config/llvm-config.h.cmake +++ include/llvm/Config/llvm-config.h.cmake @@ -17,6 +17,9 @@ /* Define if LLVM_ENABLE_DUMP is enabled */ #cmakedefine LLVM_ENABLE_DUMP +/* Define if we link AArch64 into opt tools */ +#cmakedefine LINK_AARCH64_INTO_TOOLS + /* Define if we link Polly to the tools */ #cmakedefine LINK_POLLY_INTO_TOOLS Index: include/llvm/IR/CallingConv.h =================================================================== --- include/llvm/IR/CallingConv.h +++ include/llvm/IR/CallingConv.h @@ -220,6 +220,12 @@ /// shader if tessellation is in use, or otherwise the vertex shader. AMDGPU_ES = 96, + // Calling convention between AArch64 Advanced SIMD functions + AArch64_VectorCall = 97, + + /// Calling convention between AArch64 SVE functions + AArch64_SVE_VectorCall = 98, + /// The highest possible calling convention ID. Must be some 2^k - 1. MaxID = 1023 }; Index: include/llvm/IR/Constants.h =================================================================== --- include/llvm/IR/Constants.h +++ include/llvm/IR/Constants.h @@ -511,7 +511,7 @@ public: /// Return a ConstantVector with the specified constant in each element. - static Constant *getSplat(unsigned NumElts, Constant *Elt); + static Constant *getSplat(VectorType::ElementCount EC, Constant *Elt); /// Specialize the getType() method to always return a VectorType, /// which reduces the amount of casting needed in parts of the compiler. 
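With getSplat now keyed off VectorType::ElementCount, a single call site can produce either a fixed-width or a scalable splat constant. A minimal sketch, assuming only the declarations added in this patch (the ElementCount class and the ElementCount-based getSplat); everything else is the stock LLVM C++ API.

#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/LLVMContext.h"
using namespace llvm;

// Splat the i32 constant 1 across either <4 x i32> (Scalable == false) or
// the scalable <n x 4 x i32> (Scalable == true).
static Constant *splatOfOne(LLVMContext &Ctx, bool Scalable) {
  Constant *One = ConstantInt::get(Type::getInt32Ty(Ctx), 1);
  VectorType::ElementCount EC(/*MinElts=*/4, /*IsScalable=*/Scalable);
  return ConstantVector::getSplat(EC, One);
}

Passing the element count as one value keeps the fixed/scalable decision out of every caller; only the code that builds the ElementCount needs to know which kind of vector it is dealing with.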
@@ -1189,6 +1189,11 @@ Type *OnlyIfReducedTy = nullptr); static Constant *getShuffleVector(Constant *V1, Constant *V2, Constant *Mask, Type *OnlyIfReducedTy = nullptr); + static Constant *getRuntimeNumElements(Type *Ty, Type *SrcTy); + static Constant *getSeriesVector(VectorType::ElementCount EC, + Constant *Start, Constant* Step, + bool HasNUW = false, bool HasNSW = false, + Type *OnlyIfReducedTy = nullptr); static Constant *getExtractValue(Constant *Agg, ArrayRef Idxs, Type *OnlyIfReducedTy = nullptr); static Constant *getInsertValue(Constant *Agg, Constant *Val, @@ -1263,6 +1268,29 @@ DEFINE_TRANSPARENT_OPERAND_ACCESSORS(ConstantExpr, Constant) //===----------------------------------------------------------------------===// +/// A constant vector representing the numeric sequence "0, 1, 2, 3, 4...". +/// NOTE: Element values that result from wrapping are considered poison. +/// +class StepVector final : public ConstantData { + StepVector(const StepVector &) = delete; + + friend class Constant; + void destroyConstantImpl(); + + explicit StepVector(Type *T) : ConstantData(T, StepVectorVal) {} + +public: + /// Static factory methods - Return a 'stepvector' object of the specified + /// type or a ConstantVector if the result's vector length is known. + static Constant *get(Type *T); + + /// Methods for support type inquiry through isa, cast, and dyn_cast: + static bool classof(const Value *V) { + return V->getValueID() == StepVectorVal; + } +}; + +//===----------------------------------------------------------------------===// /// 'undef' values are things that do not have specified contents. /// These are used for a variety of purposes, including global variable /// initializers and operands to instructions. 'undef' values can occur with @@ -1309,6 +1337,28 @@ } }; +//===----------------------------------------------------------------------===// +/// A constant representing the scaling factor 'n' of a scalable vector of the +/// form ''. +/// +class VScale final : public ConstantData { + VScale(const VScale &) = delete; + + friend class Constant; + void destroyConstantImpl(); + + explicit VScale(Type *T) : ConstantData(T, VScaleVal) {} + +public: + /// Static factory methods - Return a 'vscale' object of the specified type. + static Constant *get(Type *T); + + /// Methods for support type inquiry through isa, cast, and dyn_cast: + static bool classof(const Value *V) { + return V->getValueID() == VScaleVal; + } +}; + } // end namespace llvm #endif // LLVM_IR_CONSTANTS_H Index: include/llvm/IR/DIBuilder.h =================================================================== --- include/llvm/IR/DIBuilder.h +++ include/llvm/IR/DIBuilder.h @@ -192,6 +192,12 @@ DIBasicType *createBasicType(StringRef Name, uint64_t SizeInBits, unsigned Encoding); + /// Create debugging information entry for a string + /// type. + /// \param Name Type name. + /// \param SizeInBits Size of the type. + DIStringType *createStringType(StringRef Name, uint64_t SizeInBits); + /// Create debugging information entry for a qualified /// type, e.g. 'const int'. /// \param Tag Tag identifing type, e.g. dwarf::TAG_volatile_type @@ -479,6 +485,14 @@ DICompositeType *createArrayType(uint64_t Size, uint32_t AlignInBits, DIType *Ty, DINodeArray Subscripts); + /// Create debugging information entry for a Fortran array. + /// \param Size Array size. + /// \param AlignInBits Alignment. + /// \param Ty Element type. + /// \param Subscripts Subscripts. 
+ DIFortranArrayType *createFortranArrayType( + uint64_t Size, uint32_t AlignInBits, DIType *Ty, DINodeArray Subs); + /// Create debugging information entry for a vector type. /// \param Size Array size. /// \param AlignInBits Alignment. @@ -562,6 +576,12 @@ DISubrange *getOrCreateSubrange(int64_t Lo, int64_t Count); DISubrange *getOrCreateSubrange(int64_t Lo, Metadata *CountNode); + /// Create a descriptor for a value range. This + /// implicitly uniques the values returned. + DIFortranSubrange *getOrCreateFortranSubrange( + int64_t CLBound, int64_t CUBound, bool NoUBound, Metadata *Lbound, + Metadata * Lbndexp, Metadata *Ubound, Metadata * Ubndexp); + /// Create a new descriptor for the specified variable. /// \param Context Variable scope. /// \param Name Name of the variable. @@ -580,14 +600,14 @@ DIScope *Context, StringRef Name, StringRef LinkageName, DIFile *File, unsigned LineNo, DIType *Ty, bool isLocalToUnit, DIExpression *Expr = nullptr, MDNode *Decl = nullptr, - uint32_t AlignInBits = 0); + DINode::DIFlags Flags = DINode::FlagZero, uint32_t AlignInBits = 0); /// Identical to createGlobalVariable /// except that the resulting DbgNode is temporary and meant to be RAUWed. DIGlobalVariable *createTempGlobalVariableFwdDecl( DIScope *Context, StringRef Name, StringRef LinkageName, DIFile *File, unsigned LineNo, DIType *Ty, bool isLocalToUnit, MDNode *Decl = nullptr, - uint32_t AlignInBits = 0); + DINode::DIFlags Flags = DINode::FlagZero, uint32_t AlignInBits = 0); /// Create a new descriptor for an auto variable. This is a local variable /// that is not a subprogram parameter. @@ -707,6 +727,17 @@ DITemplateParameterArray TParams = nullptr, DITypeArray ThrownTypes = nullptr); + /// Create a common block entry for a Fortran common block. + /// \param Scope Scope of this common block + /// \param Name The name of this common block + /// \param File The file in which this common block is defined + /// \param LineNo Line number + /// \param VarList List of variables that are located in the common block + /// \param AlignInBits Common block alignment + DICommonBlock *createCommonBlock(DIScope *Scope, DIGlobalVariable *decl, + StringRef Name, DIFile *File, + unsigned LineNo, uint32_t AlignInBits = 0); + /// This creates new descriptor for a namespace with the specified /// parent scope. /// \param Scope Namespace scope @@ -727,7 +758,9 @@ DIModule *createModule(DIScope *Scope, StringRef Name, StringRef ConfigurationMacros, StringRef IncludePath, - StringRef ISysRoot); + StringRef ISysRoot, + DIFile *F = nullptr, + unsigned LineNo = 0); /// This creates a descriptor for a lexical block with a new file /// attached. This merely extends the existing Index: include/llvm/IR/DataLayout.h =================================================================== --- include/llvm/IR/DataLayout.h +++ include/llvm/IR/DataLayout.h @@ -420,6 +420,19 @@ return 8 * getTypeStoreSize(Ty); } + /// Returns whether the type's store size is known at compile time. + /// + /// Scalable vectors have a fixed size, but it is only known at runtime. + /// + /// TODO: Use this in cases when the absolute value of getTypeStoreSizeInBits + /// is required. At some point we will need to change this class so that we + /// can better compare the sizes of scalable vectors without needing to + /// know their absolute sizes. + bool isTypeStoreSizeKnown(Type *Ty) const { + auto* VTy = dyn_cast<VectorType>(Ty); + return !VTy || !VTy->isScalable(); + } + /// Returns the offset in bytes between successive objects of the /// specified type, including alignment padding.
/// Index: include/llvm/IR/DebugInfoFlags.def =================================================================== --- include/llvm/IR/DebugInfoFlags.def +++ include/llvm/IR/DebugInfoFlags.def @@ -48,6 +48,9 @@ HANDLE_DI_FLAG((1 << 24), FixedEnum) HANDLE_DI_FLAG((1 << 25), Thunk) HANDLE_DI_FLAG((1 << 26), Trivial) +HANDLE_DI_FLAG((1 << 27), Pure) +HANDLE_DI_FLAG((1 << 28), Elemental) +HANDLE_DI_FLAG((1 << 29), Recursive) // To avoid needing a dedicated value for IndirectVirtualBase, we use // the bitwise or of Virtual and FwdDecl, which does not otherwise @@ -57,7 +60,7 @@ #ifdef DI_FLAG_LARGEST_NEEDED // intended to be used with ADT/BitmaskEnum.h // NOTE: always must be equal to largest flag, check this when adding new flag -HANDLE_DI_FLAG((1 << 26), Largest) +HANDLE_DI_FLAG((1 << 29), Largest) #undef DI_FLAG_LARGEST_NEEDED #endif Index: include/llvm/IR/DebugInfoMetadata.h =================================================================== --- include/llvm/IR/DebugInfoMetadata.h +++ include/llvm/IR/DebugInfoMetadata.h @@ -217,10 +217,13 @@ return false; case GenericDINodeKind: case DISubrangeKind: + case DIFortranSubrangeKind: case DIEnumeratorKind: case DIBasicTypeKind: + case DIStringTypeKind: case DIDerivedTypeKind: case DICompositeTypeKind: + case DIFortranArrayTypeKind: case DISubroutineTypeKind: case DIFileKind: case DICompileUnitKind: @@ -228,6 +231,7 @@ case DILexicalBlockKind: case DILexicalBlockFileKind: case DINamespaceKind: + case DICommonBlockKind: case DITemplateTypeParameterKind: case DITemplateValueParameterKind: case DIGlobalVariableKind: @@ -371,7 +375,7 @@ return getOperand(0).get(); } - typedef PointerUnion CountType; + typedef PointerUnion3 CountType; CountType getCount() const { if (auto *MD = dyn_cast(getRawCountNode())) @@ -380,6 +384,9 @@ if (auto *DV = dyn_cast(getRawCountNode())) return CountType(DV); + if (auto *DE = dyn_cast(getRawCountNode())) + return CountType(DE); + return CountType(); } @@ -388,6 +395,71 @@ } }; +/// Fortran array subrange +class DIFortranSubrange : public DINode { + friend class LLVMContextImpl; + friend class MDNode; + + int64_t CLowerBound; + int64_t CUpperBound; + bool NoUpperBound; + + DIFortranSubrange(LLVMContext &C, StorageType Storage, int64_t CLowerBound, + int64_t CUpperBound, bool NoUpperBound, + ArrayRef Ops) + : DINode(C, DIFortranSubrangeKind, Storage, + dwarf::DW_TAG_subrange_type, Ops), CLowerBound(CLowerBound), + CUpperBound(CUpperBound), NoUpperBound(NoUpperBound) {} + ~DIFortranSubrange() = default; + + static DIFortranSubrange *getImpl(LLVMContext &Context, int64_t CLBound, + int64_t CUBound, bool NoUpperBound, + Metadata *Lbound, Metadata *Lbndexp, + Metadata *Ubound, Metadata *Ubndexp, + StorageType Storage, + bool ShouldCreate = true); + + TempDIFortranSubrange cloneImpl() const { + return getTemporary(getContext(), getCLowerBound(), getCUpperBound(), + noUpperBound(), getRawLowerBound(), + getRawLowerBoundExpression(), getRawUpperBound(), + getRawUpperBoundExpression()); + } + +public: + DEFINE_MDNODE_GET(DIFortranSubrange, (int64_t CLB, int64_t CUB, bool NUB, + Metadata *LBound, Metadata *LBndExp, + Metadata *UBound, Metadata *UBndExp), + (CLB, CUB, NUB, LBound, LBndExp, UBound, UBndExp)) + + TempDIFortranSubrange clone() const { return cloneImpl(); } + + DIVariable *getLowerBound() const { + return cast_or_null(getRawLowerBound()); + } + DIExpression *getLowerBoundExp() const { + return cast_or_null(getRawLowerBoundExpression()); + } + DIVariable *getUpperBound() const { + return cast_or_null(getRawUpperBound()); 
+ } + DIExpression *getUpperBoundExp() const { + return cast_or_null(getRawUpperBoundExpression()); + } + + int64_t getCLowerBound() const { return CLowerBound; } + int64_t getCUpperBound() const { return CUpperBound; } + Metadata *getRawLowerBound() const { return getOperand(0); } + Metadata *getRawLowerBoundExpression() const { return getOperand(1); } + Metadata *getRawUpperBound() const { return getOperand(2); } + Metadata *getRawUpperBoundExpression() const { return getOperand(3); } + bool noUpperBound() const { return NoUpperBound; } + + static bool classof(const Metadata *MD) { + return MD->getMetadataID() == DIFortranSubrangeKind; + } +}; + /// Enumeration value. /// /// TODO: Add a pointer to the context (DW_TAG_enumeration_type) once that no @@ -477,8 +549,10 @@ default: return false; case DIBasicTypeKind: + case DIStringTypeKind: case DIDerivedTypeKind: case DICompositeTypeKind: + case DIFortranArrayTypeKind: case DISubroutineTypeKind: case DIFileKind: case DICompileUnitKind: @@ -486,6 +560,7 @@ case DILexicalBlockKind: case DILexicalBlockFileKind: case DINamespaceKind: + case DICommonBlockKind: case DIModuleKind: return true; } @@ -719,8 +794,10 @@ default: return false; case DIBasicTypeKind: + case DIStringTypeKind: case DIDerivedTypeKind: case DICompositeTypeKind: + case DIFortranArrayTypeKind: case DISubroutineTypeKind: return true; } @@ -766,6 +843,12 @@ DEFINE_MDNODE_GET(DIBasicType, (unsigned Tag, StringRef Name), (Tag, Name, 0, 0, 0)) DEFINE_MDNODE_GET(DIBasicType, + (unsigned Tag, StringRef Name, uint64_t SizeInBits), + (Tag, Name, SizeInBits, 0, 0)) + DEFINE_MDNODE_GET(DIBasicType, + (unsigned Tag, MDString *Name, uint64_t SizeInBits), + (Tag, Name, SizeInBits, 0, 0)) + DEFINE_MDNODE_GET(DIBasicType, (unsigned Tag, StringRef Name, uint64_t SizeInBits, uint32_t AlignInBits, unsigned Encoding), (Tag, Name, SizeInBits, AlignInBits, Encoding)) @@ -789,6 +872,99 @@ } }; +/// String type, Fortran CHARACTER(n) +class DIStringType : public DIType { + friend class LLVMContextImpl; + friend class MDNode; + + unsigned Encoding; + + DIStringType(LLVMContext &C, StorageType Storage, unsigned Tag, + uint64_t SizeInBits, uint32_t AlignInBits, unsigned Encoding, + ArrayRef Ops) + : DIType(C, DIStringTypeKind, Storage, Tag, 0, SizeInBits, AlignInBits, 0, + FlagZero, Ops), + Encoding(Encoding) {} + ~DIStringType() = default; + + static DIStringType *getImpl(LLVMContext &Context, unsigned Tag, + StringRef Name, Metadata *StringLength, + Metadata *StrLenExp, uint64_t SizeInBits, + uint32_t AlignInBits, unsigned Encoding, + StorageType Storage, bool ShouldCreate = true) { + return getImpl(Context, Tag, getCanonicalMDString(Context, Name), + StringLength, StrLenExp, SizeInBits, AlignInBits, Encoding, + Storage, ShouldCreate); + } + static DIStringType *getImpl(LLVMContext &Context, unsigned Tag, + MDString *Name, Metadata *StringLength, + Metadata *StrLenExp, uint64_t SizeInBits, + uint32_t AlignInBits, unsigned Encoding, + StorageType Storage, bool ShouldCreate = true); + + TempDIStringType cloneImpl() const { + return getTemporary(getContext(), getTag(), getName(), getRawStringLength(), + getRawStringLengthExp(), getSizeInBits(), + getAlignInBits(), getEncoding()); + } + +public: + DEFINE_MDNODE_GET(DIStringType, (unsigned Tag, StringRef Name), + (Tag, Name, nullptr, nullptr, 0, 0, 0)) + DEFINE_MDNODE_GET(DIStringType, + (unsigned Tag, StringRef Name, uint64_t SizeInBits, + uint32_t AlignInBits), + (Tag, Name, nullptr, nullptr, SizeInBits, AlignInBits, 0)) + DEFINE_MDNODE_GET(DIStringType, 
+ (unsigned Tag, MDString *Name, uint64_t SizeInBits, + uint32_t AlignInBits), + (Tag, Name, nullptr, nullptr, SizeInBits, AlignInBits, 0)) + DEFINE_MDNODE_GET(DIStringType, + (unsigned Tag, StringRef Name, Metadata *StringLength, + Metadata *StringLengthExp, uint64_t SizeInBits, + uint32_t AlignInBits), + (Tag, Name, StringLength, StringLengthExp, SizeInBits, + AlignInBits, 0)) + DEFINE_MDNODE_GET(DIStringType, + (unsigned Tag, MDString *Name, Metadata *StringLength, + Metadata *StringLengthExp, uint64_t SizeInBits, + uint32_t AlignInBits), + (Tag, Name, StringLength, StringLengthExp, SizeInBits, + AlignInBits, 0)) + DEFINE_MDNODE_GET(DIStringType, + (unsigned Tag, StringRef Name, Metadata *StringLength, + Metadata *StringLengthExp, uint64_t SizeInBits, + uint32_t AlignInBits, unsigned Encoding), + (Tag, Name, StringLength, StringLengthExp, SizeInBits, + AlignInBits, Encoding)) + DEFINE_MDNODE_GET(DIStringType, + (unsigned Tag, MDString *Name, Metadata *StringLength, + Metadata *StringLengthExp, uint64_t SizeInBits, + uint32_t AlignInBits, unsigned Encoding), + (Tag, Name, StringLength, StringLengthExp, SizeInBits, + AlignInBits, Encoding)) + + TempDIStringType clone() const { return cloneImpl(); } + + static bool classof(const Metadata *MD) { + return MD->getMetadataID() == DIStringTypeKind; + } + + DIVariable *getStringLength() const { + return cast_or_null(getRawStringLength()); + } + + DIExpression *getStringLengthExp() const { + return cast_or_null(getRawStringLengthExp()); + } + + unsigned getEncoding() const { return Encoding; } + + Metadata *getRawStringLength() const { return getOperand(3); } + + Metadata *getRawStringLengthExp() const { return getOperand(4); } +}; + /// Derived types. /// /// This includes qualified types, pointers, references, friends, typedefs, and @@ -1067,6 +1243,16 @@ Metadata *getRawDiscriminator() const { return getOperand(8); } DIDerivedType *getDiscriminator() const { return getOperandAs(8); } + bool isScalableVector() const { + if (!isVector()) + return false; + const DINodeArray Elts = getElements(); + assert( + Elts.size() == 1 && Elts[0]->getTag() == dwarf::DW_TAG_subrange_type && + "Invalid vector element array, expected one element of type subrange"); + return !cast(Elts[0])->getCount().is(); + }; + /// Replace operands. /// /// If this \a isUniqued() and not \a isResolved(), on a uniquing collision @@ -1096,6 +1282,90 @@ } }; +/// Fortran array types. 
+class DIFortranArrayType : public DIType { + friend class LLVMContextImpl; + friend class MDNode; + + DIFortranArrayType(LLVMContext &C, StorageType Storage, unsigned Tag, + unsigned Line, uint64_t SizeInBits, uint32_t AlignInBits, + uint64_t OffsetInBits, DIFlags Flags, + ArrayRef Ops) + : DIType(C, DIFortranArrayTypeKind, Storage, Tag, Line, SizeInBits, + AlignInBits, OffsetInBits, Flags, Ops) {} + ~DIFortranArrayType() = default; + + static DIFortranArrayType * + getImpl(LLVMContext &Context, unsigned Tag, StringRef Name, Metadata *File, + unsigned Line, DIScopeRef Scope, DITypeRef BaseType, + uint64_t SizeInBits, uint32_t AlignInBits, uint64_t OffsetInBits, + DIFlags Flags, DINodeArray Elements, StorageType Storage, + bool ShouldCreate = true) { + return getImpl( + Context, Tag, getCanonicalMDString(Context, Name), File, Line, Scope, + BaseType, SizeInBits, AlignInBits, OffsetInBits, Flags, Elements.get(), + Storage, ShouldCreate); + } + static DIFortranArrayType * + getImpl(LLVMContext &Context, unsigned Tag, MDString *Name, Metadata *File, + unsigned Line, Metadata *Scope, Metadata *BaseType, + uint64_t SizeInBits, uint32_t AlignInBits, uint64_t OffsetInBits, + DIFlags Flags, Metadata *Elements, StorageType Storage, + bool ShouldCreate = true); + + TempDIFortranArrayType cloneImpl() const { + return getTemporary(getContext(), getTag(), getName(), getFile(), getLine(), + getScope(), getBaseType(), getSizeInBits(), + getAlignInBits(), getOffsetInBits(), getFlags(), + getElements()); + } + +public: + DEFINE_MDNODE_GET(DIFortranArrayType, + (unsigned Tag, StringRef Name, DIFile *File, unsigned Line, + DIScopeRef Scope, DITypeRef BaseType, uint64_t SizeInBits, + uint32_t AlignInBits, uint64_t OffsetInBits, + DIFlags Flags, DINodeArray Elements), + (Tag, Name, File, Line, Scope, BaseType, SizeInBits, + AlignInBits, OffsetInBits, Flags, Elements)) + DEFINE_MDNODE_GET(DIFortranArrayType, + (unsigned Tag, MDString *Name, Metadata *File, + unsigned Line, Metadata *Scope, Metadata *BaseType, + uint64_t SizeInBits, uint32_t AlignInBits, + uint64_t OffsetInBits, DIFlags Flags, Metadata *Elements), + (Tag, Name, File, Line, Scope, BaseType, SizeInBits, + AlignInBits, OffsetInBits, Flags, Elements)) + + TempDIFortranArrayType clone() const { return cloneImpl(); } + + DITypeRef getBaseType() const { return DITypeRef(getRawBaseType()); } + DINodeArray getElements() const { + return cast_or_null(getRawElements()); + } + + Metadata *getRawBaseType() const { return getOperand(3); } + Metadata *getRawElements() const { return getOperand(4); } + + /// Replace operands. + /// + /// If this \a isUniqued() and not \a isResolved(), on a uniquing collision + /// this will be RAUW'ed and deleted. Use a \a TrackingMDRef to keep track + /// of its movement if necessary. + /// @{ + void replaceElements(DINodeArray Elements) { +#ifndef NDEBUG + for (DINode *Op : getElements()) + assert(is_contained(Elements->operands(), Op) && + "Lost a member during member list replacement"); +#endif + replaceOperandWith(4, Elements.get()); + } + + static bool classof(const Metadata *MD) { + return MD->getMetadataID() == DIFortranArrayTypeKind; + } +}; + /// Type array for a subprogram. /// /// TODO: Fold the array of types in directly as operands. 
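Taken together, DIFortranSubrange and DIFortranArrayType let the bounds of a Fortran dimension be either compile-time constants or variables/expressions evaluated at runtime. A minimal sketch of how a frontend might describe INTEGER :: A(1:10), assuming the DIBuilder entry points declared earlier in this patch (getOrCreateFortranSubrange, createFortranArrayType); the size and alignment values are illustrative only.

#include "llvm/BinaryFormat/Dwarf.h"
#include "llvm/IR/DIBuilder.h"
using namespace llvm;

// Build debug metadata for a rank-1 Fortran array with constant bounds 1:10.
static DIFortranArrayType *describeArray(DIBuilder &DIB) {
  DIBasicType *Int = DIB.createBasicType("integer", 32, dwarf::DW_ATE_signed);
  // Constant bounds, so the variable/expression operands stay null.
  DIFortranSubrange *Dim = DIB.getOrCreateFortranSubrange(
      /*CLBound=*/1, /*CUBound=*/10, /*NoUBound=*/false,
      /*Lbound=*/nullptr, /*Lbndexp=*/nullptr,
      /*Ubound=*/nullptr, /*Ubndexp=*/nullptr);
  Metadata *Dims[] = {Dim};
  DINodeArray Subs = DIB.getOrCreateArray(Dims);
  return DIB.createFortranArrayType(/*Size=*/10 * 32, /*AlignInBits=*/32,
                                    Int, Subs);
}

For adjustable arrays the constant bounds can be left unused and the Lbound/Ubound operands pointed at DIVariable or DIExpression nodes instead, which is what the getLowerBound()/getUpperBound() accessors above return.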
@@ -1718,6 +1988,9 @@ bool isExplicit() const { return getFlags() & FlagExplicit; } bool isPrototyped() const { return getFlags() & FlagPrototyped; } bool isMainSubprogram() const { return getFlags() & FlagMainSubprogram; } + bool isPure() const { return getFlags() & FlagPure; } + bool isElemental() const { return getFlags() & FlagElemental; } + bool isRecursive() const { return getFlags() & FlagRecursive; } /// Check if this is reference-qualified. /// @@ -2032,41 +2305,51 @@ class DIModule : public DIScope { friend class LLVMContextImpl; friend class MDNode; + unsigned Line; - DIModule(LLVMContext &Context, StorageType Storage, ArrayRef Ops) - : DIScope(Context, DIModuleKind, Storage, dwarf::DW_TAG_module, Ops) {} + DIModule(LLVMContext &Context, StorageType Storage, unsigned Line, + ArrayRef Ops) + : DIScope(Context, DIModuleKind, Storage, dwarf::DW_TAG_module, Ops), + Line(Line) {} ~DIModule() = default; static DIModule *getImpl(LLVMContext &Context, DIScope *Scope, StringRef Name, StringRef ConfigurationMacros, StringRef IncludePath, StringRef ISysRoot, + DIFile *File, unsigned Line, StorageType Storage, bool ShouldCreate = true) { return getImpl(Context, Scope, getCanonicalMDString(Context, Name), getCanonicalMDString(Context, ConfigurationMacros), getCanonicalMDString(Context, IncludePath), getCanonicalMDString(Context, ISysRoot), + static_cast(File), Line, Storage, ShouldCreate); } static DIModule *getImpl(LLVMContext &Context, Metadata *Scope, MDString *Name, MDString *ConfigurationMacros, MDString *IncludePath, MDString *ISysRoot, + Metadata *File, unsigned Line, StorageType Storage, bool ShouldCreate = true); TempDIModule cloneImpl() const { return getTemporary(getContext(), getScope(), getName(), getConfigurationMacros(), getIncludePath(), - getISysRoot()); + getISysRoot(), getFile(), getLine()); } public: DEFINE_MDNODE_GET(DIModule, (DIScope *Scope, StringRef Name, - StringRef ConfigurationMacros, StringRef IncludePath, - StringRef ISysRoot), - (Scope, Name, ConfigurationMacros, IncludePath, ISysRoot)) + StringRef ConfigurationMacros, + StringRef IncludePath, StringRef ISysRoot, + DIFile *File, unsigned Line), + (Scope, Name, ConfigurationMacros, IncludePath, ISysRoot, + File, Line)) DEFINE_MDNODE_GET(DIModule, (Metadata *Scope, MDString *Name, MDString *ConfigurationMacros, - MDString *IncludePath, MDString *ISysRoot), - (Scope, Name, ConfigurationMacros, IncludePath, ISysRoot)) + MDString *IncludePath, MDString *ISysRoot, + Metadata *File, unsigned Line), + (Scope, Name, ConfigurationMacros, IncludePath, ISysRoot, + File, Line)) TempDIModule clone() const { return cloneImpl(); } @@ -2075,12 +2358,15 @@ StringRef getConfigurationMacros() const { return getStringOperand(2); } StringRef getIncludePath() const { return getStringOperand(3); } StringRef getISysRoot() const { return getStringOperand(4); } + unsigned getLine() const { return Line; } + DIFile *getFile() const { return cast_or_null(getRawFile()); } Metadata *getRawScope() const { return getOperand(0); } MDString *getRawName() const { return getOperandAs(1); } MDString *getRawConfigurationMacros() const { return getOperandAs(2); } MDString *getRawIncludePath() const { return getOperandAs(3); } MDString *getRawISysRoot() const { return getOperandAs(4); } + Metadata *getRawFile() const { return getOperand(5); } static bool classof(const Metadata *MD) { return MD->getMetadataID() == DIModuleKind; @@ -2507,12 +2793,14 @@ bool IsLocalToUnit; bool IsDefinition; + DIFlags Flags; DIGlobalVariable(LLVMContext &C, StorageType 
Storage, unsigned Line, - bool IsLocalToUnit, bool IsDefinition, uint32_t AlignInBits, - ArrayRef Ops) + bool IsLocalToUnit, bool IsDefinition, DIFlags Flags, + uint32_t AlignInBits, ArrayRef Ops) : DIVariable(C, DIGlobalVariableKind, Storage, Line, Ops, AlignInBits), - IsLocalToUnit(IsLocalToUnit), IsDefinition(IsDefinition) {} + IsLocalToUnit(IsLocalToUnit), IsDefinition(IsDefinition), + Flags(Flags) {} ~DIGlobalVariable() = default; static DIGlobalVariable *getImpl(LLVMContext &Context, DIScope *Scope, @@ -2520,25 +2808,26 @@ DIFile *File, unsigned Line, DITypeRef Type, bool IsLocalToUnit, bool IsDefinition, DIDerivedType *StaticDataMemberDeclaration, - uint32_t AlignInBits, StorageType Storage, + DIFlags Flags, uint32_t AlignInBits, + StorageType Storage, bool ShouldCreate = true) { return getImpl(Context, Scope, getCanonicalMDString(Context, Name), getCanonicalMDString(Context, LinkageName), File, Line, Type, IsLocalToUnit, IsDefinition, StaticDataMemberDeclaration, - AlignInBits, Storage, ShouldCreate); + Flags, AlignInBits, Storage, ShouldCreate); } static DIGlobalVariable * getImpl(LLVMContext &Context, Metadata *Scope, MDString *Name, MDString *LinkageName, Metadata *File, unsigned Line, Metadata *Type, bool IsLocalToUnit, bool IsDefinition, - Metadata *StaticDataMemberDeclaration, uint32_t AlignInBits, + Metadata *StaticDataMemberDeclaration, DIFlags Flags, uint32_t AlignInBits, StorageType Storage, bool ShouldCreate = true); TempDIGlobalVariable cloneImpl() const { return getTemporary(getContext(), getScope(), getName(), getLinkageName(), getFile(), getLine(), getType(), isLocalToUnit(), isDefinition(), getStaticDataMemberDeclaration(), - getAlignInBits()); + getFlags(), getAlignInBits()); } public: @@ -2547,22 +2836,24 @@ DIFile *File, unsigned Line, DITypeRef Type, bool IsLocalToUnit, bool IsDefinition, DIDerivedType *StaticDataMemberDeclaration, - uint32_t AlignInBits), + DIFlags Flags, uint32_t AlignInBits), (Scope, Name, LinkageName, File, Line, Type, IsLocalToUnit, - IsDefinition, StaticDataMemberDeclaration, AlignInBits)) + IsDefinition, StaticDataMemberDeclaration, Flags, AlignInBits)) DEFINE_MDNODE_GET(DIGlobalVariable, (Metadata * Scope, MDString *Name, MDString *LinkageName, Metadata *File, unsigned Line, Metadata *Type, bool IsLocalToUnit, bool IsDefinition, Metadata *StaticDataMemberDeclaration, - uint32_t AlignInBits), + DIFlags Flags, uint32_t AlignInBits), (Scope, Name, LinkageName, File, Line, Type, IsLocalToUnit, - IsDefinition, StaticDataMemberDeclaration, AlignInBits)) + IsDefinition, StaticDataMemberDeclaration, Flags, AlignInBits)) TempDIGlobalVariable clone() const { return cloneImpl(); } bool isLocalToUnit() const { return IsLocalToUnit; } bool isDefinition() const { return IsDefinition; } + DIFlags getFlags() const { return Flags; } + bool isArtificial() const { return getFlags() & FlagArtificial; } StringRef getDisplayName() const { return getStringOperand(4); } StringRef getLinkageName() const { return getStringOperand(5); } DIDerivedType *getStaticDataMemberDeclaration() const { @@ -2577,6 +2868,68 @@ } }; +class DICommonBlock : public DIScope { + unsigned LineNo; + uint32_t AlignInBits; + + friend class LLVMContextImpl; + friend class MDNode; + + DICommonBlock(LLVMContext &Context, StorageType Storage, unsigned LineNo, + uint32_t AlignInBits, ArrayRef Ops) + : DIScope(Context, DICommonBlockKind, Storage, dwarf::DW_TAG_common_block, + Ops), LineNo(LineNo), AlignInBits(AlignInBits) {} + ~DICommonBlock() = default; + + static DICommonBlock 
*getImpl(LLVMContext &Context, DIScope *Scope, + DIGlobalVariable *Decl, StringRef Name, + DIFile *File, unsigned LineNo, + uint32_t AlignInBits, StorageType Storage, + bool ShouldCreate = true) { + return getImpl(Context, Scope, Decl, getCanonicalMDString(Context, Name), + File, LineNo, AlignInBits, Storage, ShouldCreate); + } + static DICommonBlock *getImpl(LLVMContext &Context, Metadata *Scope, + Metadata *Decl, MDString *Name, Metadata *File, + unsigned LineNo, uint32_t AlignInBits, + StorageType Storage, bool ShouldCreate = true); + + TempDICommonBlock cloneImpl() const { + return getTemporary(getContext(), getScope(), getDecl(), getName(), + getFile(), getLineNo(), getAlignInBits()); + } + +public: + DEFINE_MDNODE_GET(DICommonBlock, + (DIScope *Scope, DIGlobalVariable *Decl, StringRef Name, + DIFile *File, unsigned LineNo, uint32_t AlignInBits), + (Scope, Decl, Name, File, LineNo, AlignInBits)) + DEFINE_MDNODE_GET(DICommonBlock, + (Metadata *Scope, Metadata *Decl, MDString *Name, + Metadata *File, unsigned LineNo, uint32_t AlignInBits), + (Scope, Decl, Name, File, LineNo, AlignInBits)) + + TempDICommonBlock clone() const { return cloneImpl(); } + + DIScope *getScope() const { return cast_or_null<DIScope>(getRawScope()); } + DIGlobalVariable *getDecl() const { + return cast_or_null<DIGlobalVariable>(getRawDecl()); + } + StringRef getName() const { return getStringOperand(2); } + DIFile *getFile() const { return cast_or_null<DIFile>(getRawFile()); } + unsigned getLineNo() const { return LineNo; } + uint32_t getAlignInBits() const { return AlignInBits; } + + Metadata *getRawScope() const { return getOperand(0); } + Metadata *getRawDecl() const { return getOperand(1); } + MDString *getRawName() const { return getOperandAs<MDString>(2); } + Metadata *getRawFile() const { return getOperand(3); } + + static bool classof(const Metadata *MD) { + return MD->getMetadataID() == DICommonBlockKind; + } +}; + /// Local variable. /// /// TODO: Split up flags. Index: include/llvm/IR/DerivedTypes.h =================================================================== --- include/llvm/IR/DerivedTypes.h +++ include/llvm/IR/DerivedTypes.h @@ -391,14 +391,66 @@ /// Class to represent vector types. class VectorType : public SequentialType { - VectorType(Type *ElType, unsigned NumEl); +public: + /// A fully specified VectorType is of the form <n x M x Ty>. M is the + /// minimum number of elements of type Ty contained within the vector, and + /// the actual element count is the result of n * M. However, for all targets + /// n is expected to be either statically unknown or guaranteed to be one. + /// For the latter, the extra complication is discarded, leading to: + /// + /// <4 x i32> - a vector containing 4 i32s + /// <n x 4 x i32> - a vector containing an unknown integer multiple of 4 i32s + class ElementCount { + public: + unsigned Min; // Minimum number of vector elements.
+ bool Scalable; // NumElements != MinNumElements + + constexpr ElementCount(unsigned MinElts, bool IsScalable) + : Min(MinElts), Scalable(IsScalable) {} + + ElementCount operator*(unsigned RHS) const { + return { Min * RHS, Scalable }; + } + ElementCount operator/(unsigned RHS) const { + return { Min / RHS, Scalable }; + } + + bool operator==(const ElementCount& RHS) const { + return Min == RHS.Min && Scalable == RHS.Scalable; + } + + bool operator!=(const ElementCount& RHS) const { + return !(*this == RHS); + } + }; + + constexpr static ElementCount SingleElement() { + return ElementCount(1, false); + } + +private: + bool Scalable; + + VectorType(Type *ElType, unsigned NumEl, bool Scalable=false); public: VectorType(const VectorType &) = delete; VectorType &operator=(const VectorType &) = delete; - /// This static method is the primary way to construct an VectorType. - static VectorType *get(Type *ElementType, unsigned NumElements); + /// VectorType::get - This static method is the primary way to construct an + /// VectorType. + /// + static VectorType *get(Type *ELType, ElementCount EC); + static VectorType *get(Type *ElType, unsigned NumEl, bool Scalable=false) { + return VectorType::get(ElType, { NumEl, Scalable }); + } + + /// This static method gets a VectorType with the same number of elements as + /// the input type, and the element type is an i1. + static VectorType *getBool(VectorType *VTy) { + Type *EltTy = IntegerType::get(VTy->getContext(), 1); + return VectorType::get(EltTy, VTy->getElementCount()); + } /// This static method gets a VectorType with the same number of elements as /// the input type, and the element type is an integer type of the same width @@ -407,7 +459,7 @@ unsigned EltBits = VTy->getElementType()->getPrimitiveSizeInBits(); assert(EltBits && "Element size must be of a non-zero size"); Type *EltTy = IntegerType::get(VTy->getContext(), EltBits); - return VectorType::get(EltTy, VTy->getNumElements()); + return VectorType::get(EltTy, VTy->getElementCount()); } /// This static method is like getInteger except that the element types are @@ -415,7 +467,7 @@ static VectorType *getExtendedElementVectorType(VectorType *VTy) { unsigned EltBits = VTy->getElementType()->getPrimitiveSizeInBits(); Type *EltTy = IntegerType::get(VTy->getContext(), EltBits * 2); - return VectorType::get(EltTy, VTy->getNumElements()); + return VectorType::get(EltTy, VTy->getElementCount()); } /// This static method is like getInteger except that the element types are @@ -424,29 +476,58 @@ unsigned EltBits = VTy->getElementType()->getPrimitiveSizeInBits(); assert((EltBits & 1) == 0 && "Cannot truncate vector element with odd bit-width"); + assert(VTy->getElementType()->isIntegerTy() && + "Cannot truncate vector with non integer elements"); Type *EltTy = IntegerType::get(VTy->getContext(), EltBits / 2); - return VectorType::get(EltTy, VTy->getNumElements()); + return VectorType::get(EltTy, VTy->getElementCount()); + } + + /// This static method is like getInteger except that the element types are + /// half as wide as the elements in the input type. 
+ static VectorType *getNarrowerFpElementVectorType(VectorType *VTy) { + Type *EltTy; + switch(VTy->getElementType()->getTypeID()) { + case DoubleTyID: + EltTy = Type::getFloatTy(VTy->getContext()); + break; + case FloatTyID: + EltTy = Type::getHalfTy(VTy->getContext()); + break; + default: + assert(0 && "Cannot create narrower fp vector element type"); + break; + } + return VectorType::get(EltTy, VTy->getElementCount()); } /// This static method returns a VectorType with half as many elements as the /// input type and the same element type. static VectorType *getHalfElementsVectorType(VectorType *VTy) { - unsigned NumElts = VTy->getNumElements(); - assert ((NumElts & 1) == 0 && + auto NumElts = VTy->getElementCount(); + assert ((NumElts.Min & 1) == 0 && "Cannot halve vector with odd number of elements."); - return VectorType::get(VTy->getElementType(), NumElts/2); + return VectorType::get(VTy->getElementType(), NumElts / 2); } /// This static method returns a VectorType with twice as many elements as the /// input type and the same element type. static VectorType *getDoubleElementsVectorType(VectorType *VTy) { - unsigned NumElts = VTy->getNumElements(); - return VectorType::get(VTy->getElementType(), NumElts*2); + auto NumElts = VTy->getElementCount(); + return VectorType::get(VTy->getElementType(), NumElts * 2); } /// Return true if the specified type is valid as a element type. static bool isValidElementType(Type *ElemTy); + /// Return an ElementCount struct for this VectorType + ElementCount getElementCount() const { + return {(unsigned)getNumElements(), Scalable}; + } + + /// Return true when the number of elements is only known at runtime. + bool isScalable() const { return Scalable; } + + /// TODO: decide whether my callers actually want me /// Return the number of bits in the Vector type. /// Returns zero when the vector is a vector of pointers. unsigned getBitWidth() const { @@ -463,6 +544,10 @@ return cast(this)->getNumElements(); } +unsigned Type::getVectorIsScalable() const { + return cast(this)->isScalable(); +} + /// Class to represent pointers. class PointerType : public Type { explicit PointerType(Type *ElType, unsigned AddrSpace); Index: include/llvm/IR/DiagnosticInfo.h =================================================================== --- include/llvm/IR/DiagnosticInfo.h +++ include/llvm/IR/DiagnosticInfo.h @@ -480,9 +480,19 @@ bool isAnalysis() const { return (getKind() == DK_OptimizationRemarkAnalysis || + getKind() == DK_OptimizationRemarkAnalysisFPCommute || + getKind() == DK_OptimizationRemarkAnalysisAliasing || getKind() == DK_MachineOptimizationRemarkAnalysis); } + typedef SmallVectorImpl::const_iterator iterator; + iterator arg_begin() const { return Args.begin(); }; + iterator arg_end() const { return Args.end(); }; + + const Argument& getArgument(unsigned I) const { + return Args[I]; + } + protected: /// Name of the pass that triggers this report. 
If this matches the /// regular expression given in -Rpass=regexp, then the remark will @@ -623,6 +633,8 @@ Orig.RemarkName, Orig.getFunction(), Orig.getLocation()), CodeRegion(Orig.getCodeRegion()) { *this << Prepend; + if (Orig.FirstExtraArgIndex != -1) + FirstExtraArgIndex = Orig.FirstExtraArgIndex + 1; std::copy(Orig.Args.begin(), Orig.Args.end(), std::back_inserter(Args)); } Index: include/llvm/IR/IRBuilder.h =================================================================== --- include/llvm/IR/IRBuilder.h +++ include/llvm/IR/IRBuilder.h @@ -33,6 +33,7 @@ #include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Intrinsics.h" +#include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" #include "llvm/IR/Operator.h" @@ -395,25 +396,6 @@ // Intrinsic creation methods //===--------------------------------------------------------------------===// - /// Create and insert a memset to the specified pointer and the - /// specified value. - /// - /// If the pointer isn't an i8*, it will be converted. If a TBAA tag is - /// specified, it will be added to the instruction. Likewise with alias.scope - /// and noalias tags. - CallInst *CreateMemSet(Value *Ptr, Value *Val, uint64_t Size, unsigned Align, - bool isVolatile = false, MDNode *TBAATag = nullptr, - MDNode *ScopeTag = nullptr, - MDNode *NoAliasTag = nullptr) { - return CreateMemSet(Ptr, Val, getInt64(Size), Align, isVolatile, - TBAATag, ScopeTag, NoAliasTag); - } - - CallInst *CreateMemSet(Value *Ptr, Value *Val, Value *Size, unsigned Align, - bool isVolatile = false, MDNode *TBAATag = nullptr, - MDNode *ScopeTag = nullptr, - MDNode *NoAliasTag = nullptr); - /// Create and insert an element unordered-atomic memset of the region of /// memory starting at the given pointer to the given value. /// @@ -438,29 +420,6 @@ MDNode *ScopeTag = nullptr, MDNode *NoAliasTag = nullptr); - /// Create and insert a memcpy between the specified pointers. - /// - /// If the pointers aren't i8*, they will be converted. If a TBAA tag is - /// specified, it will be added to the instruction. Likewise with alias.scope - /// and noalias tags. - CallInst *CreateMemCpy(Value *Dst, unsigned DstAlign, Value *Src, - unsigned SrcAlign, uint64_t Size, - bool isVolatile = false, MDNode *TBAATag = nullptr, - MDNode *TBAAStructTag = nullptr, - MDNode *ScopeTag = nullptr, - MDNode *NoAliasTag = nullptr) { - return CreateMemCpy(Dst, DstAlign, Src, SrcAlign, getInt64(Size), - isVolatile, TBAATag, TBAAStructTag, ScopeTag, - NoAliasTag); - } - - CallInst *CreateMemCpy(Value *Dst, unsigned DstAlign, Value *Src, - unsigned SrcAlign, Value *Size, - bool isVolatile = false, MDNode *TBAATag = nullptr, - MDNode *TBAAStructTag = nullptr, - MDNode *ScopeTag = nullptr, - MDNode *NoAliasTag = nullptr); - /// Create and insert an element unordered-atomic memcpy between the /// specified pointers. 
/// @@ -587,6 +546,10 @@ CallInst *CreateMaskedLoad(Value *Ptr, unsigned Align, Value *Mask, Value *PassThru = nullptr, const Twine &Name = ""); + /// \brief Create a call to Masked Speculative Load intrinsic + CallInst *CreateMaskedSpecLoad(Value *Ptr, unsigned Align, Value *Mask, + Value *PassThru = 0, const Twine &Name = ""); + /// Create a call to Masked Store intrinsic CallInst *CreateMaskedStore(Value *Val, Value *Ptr, unsigned Align, Value *Mask); @@ -675,6 +638,8 @@ Type *ResultType, const Twine &Name = ""); + CallInst *CreateCntVPop(Value *PredVec, const Twine &Name); + /// Create a call to intrinsic \p ID with 2 operands which is mangled on the /// first type. CallInst *CreateBinaryIntrinsic(Intrinsic::ID ID, @@ -703,7 +668,7 @@ return CreateBinaryIntrinsic(Intrinsic::maxnum, LHS, RHS, Name); } -private: +protected: /// Create a call to a masked intrinsic with given Id. CallInst *CreateMaskedIntrinsic(Intrinsic::ID Id, ArrayRef Ops, ArrayRef OverloadedTypes, @@ -1960,6 +1925,15 @@ return CreateShuffleVector(V1, V2, Mask, Name); } + Value *CreateSeriesVector(VectorType::ElementCount EC, Value *Start, + Value* Step, const Twine &Name = "", + bool HasNUW = false, bool HasNSW = false) { + auto Ty = VectorType::get(Step->getType(), EC); + auto StartV = CreateVectorSplat(EC, Start); + auto StepV = CreateVectorSplat(EC, Step); + return CreateAdd(StartV, CreateMul(StepV, StepVector::get(Ty)), Name); + } + Value *CreateExtractValue(Value *Agg, ArrayRef Idxs, const Twine &Name = "") { @@ -2071,20 +2045,20 @@ return Fn; } - /// Return a vector value that contains \arg V broadcasted to \p - /// NumElts elements. - Value *CreateVectorSplat(unsigned NumElts, Value *V, const Twine &Name = "") { - assert(NumElts > 0 && "Cannot splat to an empty vector!"); + /// Return a vector value that contains \arg V broadcasted onto \p + /// a vector of the same size as \arg VT + Value *CreateVectorSplat(VectorType::ElementCount EC, Value *V, + const Twine &Name = "") { + assert(EC.Min > 0 && "Cannot splat to an empty vector!"); - // First insert it into an undef vector so we can shuffle it. Type *I32Ty = getInt32Ty(); - Value *Undef = UndefValue::get(VectorType::get(V->getType(), NumElts)); - V = CreateInsertElement(Undef, V, ConstantInt::get(I32Ty, 0), + Value *UndefV = UndefValue::get(VectorType::get(V->getType(), EC)); + V = CreateInsertElement(UndefV, V, ConstantInt::get(I32Ty, 0), Name + ".splatinsert"); // Shuffle the value across the desired number of elements. - Value *Zeros = ConstantAggregateZero::get(VectorType::get(I32Ty, NumElts)); - return CreateShuffleVector(V, Undef, Zeros, Name + ".splat"); + Value *Zeros = ConstantAggregateZero::get(VectorType::get(I32Ty, EC)); + return CreateShuffleVector(V, UndefV, Zeros, Name + ".splat"); } /// Return a value that has been extracted from a larger integer type. @@ -2191,6 +2165,105 @@ return CreateAlignmentAssumptionHelper(DL, PtrValue, Mask, IntPtrTy, OffsetValue); } + + /// \brief Create and insert a memcpy between the specified pointers. + /// + /// If the pointers aren't i8*, they will be converted. If a TBAA tag is + /// specified, it will be added to the instruction. Likewise with alias.scope + /// and noalias tags. 
+ CallInst *CreateMemCpy(Value *Dst, unsigned DstAlign, Value *Src, unsigned SrcAlign, + uint64_t Size, + bool isVolatile = false, MDNode *TBAATag = nullptr, + MDNode *TBAAStructTag = nullptr, + MDNode *ScopeTag = nullptr, + MDNode *NoAliasTag = nullptr) { + return CreateMemCpy(Dst, DstAlign, Src, SrcAlign, getInt64(Size), isVolatile, TBAATag, + TBAAStructTag, ScopeTag, NoAliasTag); + } + + CallInst *CreateMemCpy(Value *Dst, unsigned DstAlign, Value *Src, unsigned SrcAlign, + Value *Size, + bool isVolatile = false, MDNode *TBAATag = nullptr, + MDNode *TBAAStructTag = nullptr, + MDNode *ScopeTag = nullptr, + MDNode *NoAliasTag = nullptr) { + assert((DstAlign == 0 || isPowerOf2_32(DstAlign)) && "Must be 0 or a power of 2"); + assert((SrcAlign == 0 || isPowerOf2_32(SrcAlign)) && "Must be 0 or a power of 2"); + Dst = getCastedInt8PtrValue(Dst); + Src = getCastedInt8PtrValue(Src); + + Value *Ops[] = { Dst, Src, Size, getInt1(isVolatile) }; + Type *Tys[] = { Dst->getType(), Src->getType(), Size->getType() }; + Module *M = BB->getParent()->getParent(); + Value *TheFn = Intrinsic::getDeclaration(M, Intrinsic::memcpy, Tys); + + CallInst *CI = Insert(CallInst::Create(TheFn, Ops)); + + auto* MCI = cast(CI); + if (DstAlign > 0) + MCI->setDestAlignment(DstAlign); + + if (SrcAlign > 0) + MCI->setSourceAlignment(SrcAlign); + + // Set the TBAA info if present. + if (TBAATag) + CI->setMetadata(LLVMContext::MD_tbaa, TBAATag); + + // Set the TBAA Struct info if present. + if (TBAAStructTag) + CI->setMetadata(LLVMContext::MD_tbaa_struct, TBAAStructTag); + + if (ScopeTag) + CI->setMetadata(LLVMContext::MD_alias_scope, ScopeTag); + + if (NoAliasTag) + CI->setMetadata(LLVMContext::MD_noalias, NoAliasTag); + + return CI; + } + + /// \brief Create and insert a memset to the specified pointer and the + /// specified value. + /// + /// If the pointer isn't an i8*, it will be converted. If a TBAA tag is + /// specified, it will be added to the instruction. Likewise with alias.scope + /// and noalias tags. + CallInst *CreateMemSet(Value *Ptr, Value *Val, uint64_t Size, unsigned Align, + bool isVolatile = false, MDNode *TBAATag = nullptr, + MDNode *ScopeTag = nullptr, + MDNode *NoAliasTag = nullptr) { + return CreateMemSet(Ptr, Val, getInt64(Size), Align, isVolatile, + TBAATag, ScopeTag, NoAliasTag); + } + + CallInst *CreateMemSet(Value *Ptr, Value *Val, Value *Size, unsigned Align, + bool isVolatile = false, MDNode *TBAATag = nullptr, + MDNode *ScopeTag = nullptr, + MDNode *NoAliasTag = nullptr) { + Ptr = getCastedInt8PtrValue(Ptr); + Value *Ops[] = { Ptr, Val, Size, getInt1(isVolatile) }; + Type *Tys[] = { Ptr->getType(), Size->getType() }; + Module *M = BB->getParent()->getParent(); + Value *TheFn = Intrinsic::getDeclaration(M, Intrinsic::memset, Tys); + + CallInst *CI = Insert(CallInst::Create(TheFn, Ops)); + + if (Align > 0) + cast(CI)->setDestAlignment(Align); + + // Set the TBAA info if present. + if (TBAATag) + CI->setMetadata(LLVMContext::MD_tbaa, TBAATag); + + if (ScopeTag) + CI->setMetadata(LLVMContext::MD_alias_scope, ScopeTag); + + if (NoAliasTag) + CI->setMetadata(LLVMContext::MD_noalias, NoAliasTag); + + return CI; + } }; // Create wrappers for C Binding types (see CBindingWrapping.h). Index: include/llvm/IR/InlineAsm.h =================================================================== --- include/llvm/IR/InlineAsm.h +++ include/llvm/IR/InlineAsm.h @@ -337,6 +337,12 @@ return (Flag & 0xffff) >> 3; } + /// Decrement the number of registers field of the inline asm operand flag. 
+ static unsigned decrementNumOperandRegisters(unsigned Flag) { + assert(getNumOperandRegisters(Flag) > 0 && "Flag has no operands!"); + return Flag - (1 << 3); + } + /// isUseOperandTiedToDef - Return true if the flag of the inline asm /// operand indicates it is an use operand that's matched to a def operand. static bool isUseOperandTiedToDef(unsigned Flag, unsigned &Idx) { Index: include/llvm/IR/InstrTypes.h =================================================================== --- include/llvm/IR/InstrTypes.h +++ include/llvm/IR/InstrTypes.h @@ -1127,7 +1127,7 @@ static Type* makeCmpResultType(Type* opnd_type) { if (VectorType* vt = dyn_cast(opnd_type)) { return VectorType::get(Type::getInt1Ty(opnd_type->getContext()), - vt->getNumElements()); + vt->getElementCount()); } return Type::getInt1Ty(opnd_type->getContext()); } Index: include/llvm/IR/Instructions.h =================================================================== --- include/llvm/IR/Instructions.h +++ include/llvm/IR/Instructions.h @@ -28,6 +28,7 @@ #include "llvm/IR/BasicBlock.h" #include "llvm/IR/CallingConv.h" #include "llvm/IR/Constant.h" +#include "llvm/IR/Constants.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Function.h" #include "llvm/IR/InstrTypes.h" @@ -1011,12 +1012,14 @@ // Vector GEP if (Ptr->getType()->isVectorTy()) { unsigned NumElem = Ptr->getType()->getVectorNumElements(); - return VectorType::get(PtrTy, NumElem); + bool Scalable = Ptr->getType()->getVectorIsScalable(); + return VectorType::get(PtrTy, NumElem, Scalable); } for (Value *Index : IdxList) if (Index->getType()->isVectorTy()) { unsigned NumElem = Index->getType()->getVectorNumElements(); - return VectorType::get(PtrTy, NumElem); + bool Scalable = Index->getType()->getVectorIsScalable(); + return VectorType::get(PtrTy, NumElem, Scalable); } // Scalar GEP return PtrTy; @@ -2401,7 +2404,7 @@ public: ShuffleVectorInst(Value *V1, Value *V2, Value *Mask, const Twine &NameStr = "", - Instruction *InsertBefor = nullptr); + Instruction *InsertBefore = nullptr); ShuffleVectorInst(Value *V1, Value *V2, Value *Mask, const Twine &NameStr, BasicBlock *InsertAtEnd); @@ -2424,28 +2427,24 @@ /// Transparently provide more efficient getOperand methods. DECLARE_TRANSPARENT_OPERAND_ACCESSORS(Value); - Constant *getMask() const { - return cast(getOperand(2)); + Value *getMask() const { + return getOperand(2); } - /// Return the shuffle mask value for the specified element of the mask. - /// Return -1 if the element is undef. - static int getMaskValue(const Constant *Mask, unsigned Elt); + /// getMaskValue - Return the index from the shuffle mask for the specified + /// output result. This is either -1 if the element is undef or a number less + /// than 2*numelements. + static bool getMaskValue(const Value *Mask, unsigned Elt, int &Result); - /// Return the shuffle mask value of this instruction for the given element - /// index. Return -1 if the element is undef. - int getMaskValue(unsigned Elt) const { - return getMaskValue(getMask(), Elt); + bool getMaskValue(unsigned i, int &Result) const { + return getMaskValue(getMask(), i, Result); } - /// Convert the input shuffle mask operand to a vector of integers. Undefined - /// elements of the mask are returned as -1. - static void getShuffleMask(const Constant *Mask, - SmallVectorImpl &Result); + /// getShuffleMask - Return the full mask for this instruction, where each + /// element is the element number and undef's are returned as -1. 
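+  ///
+  /// A minimal usage sketch via the instance wrapper declared below, assuming
+  /// the boolean result reports whether the mask could be decoded into
+  /// constant element indices (Shuf is an illustrative ShuffleVectorInst*):
+  ///   SmallVector<int, 16> Indices;
+  ///   if (Shuf->getShuffleMask(Indices))
+  ///     ; // Indices now holds one entry per result element, -1 for undef.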
+ static bool getShuffleMask(const Value *Mask, SmallVectorImpl &Result); - /// Return the mask for this instruction as a vector of integers. Undefined - /// elements of the mask are returned as -1. - void getShuffleMask(SmallVectorImpl &Result) const { + bool getShuffleMask(SmallVectorImpl &Result) const { return getShuffleMask(getMask(), Result); } @@ -2455,12 +2454,18 @@ return Mask; } + /// findBroadcastElement - If all defined elements of the mask specify + /// the same element, return that element, otherwise return -1. + static int findBroadcastElement(const Value *Mask); + + int findBroadcastElement() const { return findBroadcastElement(getMask()); } + /// Return true if this shuffle returns a vector with a different number of - /// elements than its source elements. + /// elements than its source vectors. /// Example: shufflevector <4 x n> A, <4 x n> B, <1,2> bool changesLength() const { - unsigned NumSourceElts = Op<0>()->getType()->getVectorNumElements(); - unsigned NumMaskElts = getMask()->getType()->getVectorNumElements(); + auto NumSourceElts = cast(Op<0>()->getType())->getElementCount(); + auto NumMaskElts = cast(getMask()->getType())->getElementCount(); return NumSourceElts != NumMaskElts; } @@ -2481,9 +2486,14 @@ /// Example: shufflevector <4 x n> A, <4 x n> B, <3,0,undef,3> /// TODO: Optionally allow length-changing shuffles. bool isSingleSource() const { - return !changesLength() && isSingleSourceMask(getMask()); + return !getType()->getVectorIsScalable() && !changesLength() && + isSingleSourceMask(getShuffleMask()); } + // Returns true if this shuffle is concatenating elements from sources + // Example: <2 x n> A, <2 x n> B, <0, 1, 2, 3> + bool isConcat() const; + /// Return true if this shuffle mask chooses elements from exactly one source /// vector without lane crossings. A shuffle using this mask is not /// necessarily a no-op because it may change the number of elements from its @@ -2497,15 +2507,22 @@ return isIdentityMask(MaskAsInts); } - /// Return true if this shuffle mask chooses elements from exactly one source + /// Return true if this shuffle chooses elements from exactly one source /// vector without lane crossings and does not change the number of elements /// from its input vectors. /// Example: shufflevector <4 x n> A, <4 x n> B, <4,undef,6,undef> - /// TODO: Optionally allow length-changing shuffles. bool isIdentity() const { return !changesLength() && isIdentityMask(getShuffleMask()); } + /// Return true if this shuffle lengthens exactly one source vector with + /// undefs in the high elements. + bool isIdentityWithPadding() const; + + /// Return true if this shuffle extracts the first N elements of exactly one + /// source vector. + bool isIdentityWithExtract() const; + /// Return true if this shuffle mask chooses elements from its source vectors /// without lane crossings. A shuffle using this mask would be /// equivalent to a vector select with a constant condition operand. @@ -2531,7 +2548,8 @@ /// In that case, the shuffle is better classified as an identity shuffle. /// TODO: Optionally allow length-changing shuffles. bool isSelect() const { - return !changesLength() && isSelectMask(getMask()); + return !getType()->getVectorIsScalable() && !changesLength() && + isSelectMask(getShuffleMask()); } /// Return true if this shuffle mask swaps the order of elements from exactly @@ -2551,7 +2569,8 @@ /// Example: shufflevector <4 x n> A, <4 x n> B, <3,undef,1,undef> /// TODO: Optionally allow length-changing shuffles. 
bool isReverse() const { - return !changesLength() && isReverseMask(getMask()); + return !getType()->getVectorIsScalable() && !changesLength() && + isReverseMask(getShuffleMask()); } /// Return true if this shuffle mask chooses all elements with the same value @@ -2573,7 +2592,8 @@ /// TODO: Optionally allow length-changing shuffles. /// TODO: Optionally allow splats from other elements. bool isZeroEltSplat() const { - return !changesLength() && isZeroEltSplatMask(getMask()); + return !getType()->getVectorIsScalable() && !changesLength() && + isZeroEltSplatMask(getShuffleMask()); } /// Return true if this shuffle mask is a transpose mask. @@ -2622,7 +2642,8 @@ /// exact specification. /// Example: shufflevector <4 x n> A, <4 x n> B, <0,4,2,6> bool isTranspose() const { - return !changesLength() && isTransposeMask(getMask()); + return !getType()->getVectorIsScalable() && !changesLength() && + isTransposeMask(getShuffleMask()); } /// Change values in a shuffle permute mask assuming the two vector operands Index: include/llvm/IR/Intrinsics.h =================================================================== --- include/llvm/IR/Intrinsics.h +++ include/llvm/IR/Intrinsics.h @@ -100,7 +100,9 @@ Void, VarArg, MMX, Token, Metadata, Half, Float, Double, Quad, Integer, Vector, Pointer, Struct, Argument, ExtendArgument, TruncArgument, HalfVecArgument, - SameVecWidthArgument, PtrToArgument, PtrToElt, VecOfAnyPtrsToElt + SameVecWidthArgument, PtrToArgument, PtrToElt, VecOfAnyPtrsToElt, + ScalableVecArgument, VecElementArgument, DoubleVecArgument, + VecOfBitcastsToInt, Subdivide2Argument, Subdivide4Argument } Kind; union { @@ -117,20 +119,28 @@ AK_AnyInteger, AK_AnyFloat, AK_AnyVector, - AK_AnyPointer + AK_AnyPointer, + AK_MatchType = 7 }; unsigned getArgumentNumber() const { assert(Kind == Argument || Kind == ExtendArgument || Kind == TruncArgument || Kind == HalfVecArgument || Kind == SameVecWidthArgument || Kind == PtrToArgument || - Kind == PtrToElt); + Kind == PtrToElt || Kind == VecOfAnyPtrsToElt || + Kind == VecElementArgument || Kind == DoubleVecArgument || + Kind == VecOfBitcastsToInt || Kind == Subdivide2Argument || + Kind == Subdivide4Argument); return Argument_Info >> 3; } ArgKind getArgumentKind() const { assert(Kind == Argument || Kind == ExtendArgument || Kind == TruncArgument || Kind == HalfVecArgument || - Kind == SameVecWidthArgument || Kind == PtrToArgument); + Kind == SameVecWidthArgument || Kind == PtrToArgument || + Kind == PtrToElt || Kind == VecOfAnyPtrsToElt || + Kind == VecElementArgument || Kind == DoubleVecArgument || + Kind == VecOfBitcastsToInt || Kind == Subdivide2Argument || + Kind == Subdivide4Argument); return (ArgKind)(Argument_Info & 7); } @@ -162,14 +172,21 @@ /// of IITDescriptors. void getIntrinsicInfoTableEntries(ID id, SmallVectorImpl &T); - /// Match the specified type (which comes from an intrinsic argument or return - /// value) with the type constraints specified by the .td file. If the given - /// type is an overloaded type it is pushed to the ArgTys vector. + enum MatchIntrinsicTypesResult { + MatchIntrinsicTypes_Match = 0, + MatchIntrinsicTypes_NoMatchRet = 1, + MatchIntrinsicTypes_NoMatchArg = 2, + }; + + /// Match the specified function type with the type constraints specified by + /// the .td file. If the given type is an overloaded type it is pushed to the + /// ArgTys vector. /// /// Returns false if the given type matches with the constraints, true /// otherwise. 
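+///
+/// A hedged sketch of how a caller might consume the tri-state result
+/// declared below (FTy, Infos and ArgTys mirror the parameters; the exact
+/// diagnostics are the caller's choice):
+///   if (Intrinsic::matchIntrinsicSignature(FTy, Infos, ArgTys) !=
+///       Intrinsic::MatchIntrinsicTypes_Match)
+///     ; // report whether the return type or an argument failed to match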
- bool matchIntrinsicType(Type *Ty, ArrayRef &Infos, - SmallVectorImpl &ArgTys); + MatchIntrinsicTypesResult + matchIntrinsicSignature(FunctionType *FTy, ArrayRef &Infos, + SmallVectorImpl &ArgTys); /// Verify if the intrinsic has variable arguments. This method is intended to /// be called after all the fixed arguments have been matched first. Index: include/llvm/IR/Intrinsics.td =================================================================== --- include/llvm/IR/Intrinsics.td +++ include/llvm/IR/Intrinsics.td @@ -160,11 +160,26 @@ class LLVMPointerTo : LLVMMatchType; class LLVMPointerToElt : LLVMMatchType; class LLVMVectorOfAnyPointersToElt : LLVMMatchType; +class LLVMVectorElementType : LLVMMatchType; // Match the type of another intrinsic parameter that is expected to be a -// vector type, but change the element count to be half as many +// vector type, but change the element count to be half as many. class LLVMHalfElementsVectorType : LLVMMatchType; +// Match the type of another intrinsic parameter that is expected to be a +// vector type, but change the element count to be twice as many. +class LLVMDoubleElementsVectorType : LLVMMatchType; + +// Match the type of another intrinsic parameter that is expected to be a +// vector type (i.e. ) but with each element subdivided to +// form a vector with more elements that are smaller than the original. +class LLVMSubdivide2VectorType : LLVMMatchType; +class LLVMSubdivide4VectorType : LLVMMatchType; + +// Match the element count and bit width of another intrinsic parameter, but +// change the element type to an integer. +class LLVMVectorOfBitcastsToInt : LLVMMatchType; + def llvm_void_ty : LLVMType; let isAny = 1 in { def llvm_any_ty : LLVMType; @@ -443,6 +458,50 @@ def int_round : Intrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]>; def int_canonicalize : Intrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>], [IntrNoMem]>; + + // Masked vector versions of the above functions. 
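+  // As a purely illustrative note (the fixed-width type below is just one
+  // possible instantiation of llvm_anyvector_ty), int_masked_sin corresponds
+  // to an IR declaration along the lines of:
+  //   declare <4 x float> @llvm.masked.sin.v4f32(<4 x i1>, <4 x float>)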
+ def int_masked_sin : Intrinsic<[llvm_anyvector_ty], + [LLVMVectorSameWidth<0, llvm_i1_ty>, + LLVMMatchType<0>]>; + def int_masked_cos : Intrinsic<[llvm_anyvector_ty], + [LLVMVectorSameWidth<0, llvm_i1_ty>, + LLVMMatchType<0>]>; + def int_masked_powi : Intrinsic<[llvm_anyvector_ty], + [LLVMVectorSameWidth<0, llvm_i1_ty>, + LLVMMatchType<0>, llvm_anyvector_ty]>; + def int_masked_pow : Intrinsic<[llvm_anyvector_ty], + [LLVMVectorSameWidth<0, llvm_i1_ty>, + LLVMMatchType<0>, LLVMMatchType<0>]>; + def int_masked_maxnum : Intrinsic<[llvm_anyvector_ty], + [LLVMVectorSameWidth<0, llvm_i1_ty>, + LLVMMatchType<0>, LLVMMatchType<0>]>; + def int_masked_minnum : Intrinsic<[llvm_anyvector_ty], + [LLVMVectorSameWidth<0, llvm_i1_ty>, + LLVMMatchType<0>, LLVMMatchType<0>]>; + def int_masked_copysign : Intrinsic<[llvm_anyvector_ty], + [LLVMVectorSameWidth<0, llvm_i1_ty>, + LLVMMatchType<0>, LLVMMatchType<0>]>; + def int_masked_rint : Intrinsic<[llvm_anyvector_ty], + [LLVMVectorSameWidth<0, llvm_i1_ty>, + LLVMMatchType<0>]>; + def int_masked_exp : Intrinsic<[llvm_anyvector_ty], + [LLVMVectorSameWidth<0, llvm_i1_ty>, + LLVMMatchType<0>]>; + def int_masked_exp2 : Intrinsic<[llvm_anyvector_ty], + [LLVMVectorSameWidth<0, llvm_i1_ty>, + LLVMMatchType<0>]>; + def int_masked_log : Intrinsic<[llvm_anyvector_ty], + [LLVMVectorSameWidth<0, llvm_i1_ty>, + LLVMMatchType<0>]>; + def int_masked_log2 : Intrinsic<[llvm_anyvector_ty], + [LLVMVectorSameWidth<0, llvm_i1_ty>, + LLVMMatchType<0>]>; + def int_masked_log10: Intrinsic<[llvm_anyvector_ty], + [LLVMVectorSameWidth<0, llvm_i1_ty>, + LLVMMatchType<0>]>; + def int_masked_fmod : Intrinsic<[llvm_anyvector_ty], + [LLVMMatchType<0>, LLVMMatchType<0>, + LLVMVectorSameWidth<0, llvm_i1_ty>]>; } def int_minnum : Intrinsic<[llvm_anyfloat_ty], @@ -574,6 +633,7 @@ let IntrProperties = [IntrNoMem, IntrSpeculatable] in { def int_bswap: Intrinsic<[llvm_anyint_ty], [LLVMMatchType<0>]>; def int_ctpop: Intrinsic<[llvm_anyint_ty], [LLVMMatchType<0>]>; + def int_ctvpop : Intrinsic<[llvm_i64_ty], [llvm_anyvector_ty], [], "llvm.ctvpop">; def int_ctlz : Intrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, llvm_i1_ty]>; def int_cttz : Intrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, llvm_i1_ty]>; def int_bitreverse : Intrinsic<[llvm_anyint_ty], [LLVMMatchType<0>]>; @@ -863,16 +923,24 @@ LLVMVectorSameWidth<0, llvm_i1_ty>, LLVMMatchType<0>], [IntrReadMem, IntrArgMemOnly]>; +def int_masked_spec_load : Intrinsic<[llvm_anyvector_ty, + LLVMVectorSameWidth<0, llvm_i1_ty>], + [LLVMPointerTo<0>, llvm_i32_ty, + LLVMVectorSameWidth<0, llvm_i1_ty>, + LLVMMatchType<0>], + [IntrReadMem, IntrArgMemOnly]>; + def int_masked_gather: Intrinsic<[llvm_anyvector_ty], [LLVMVectorOfAnyPointersToElt<0>, llvm_i32_ty, LLVMVectorSameWidth<0, llvm_i1_ty>, LLVMMatchType<0>], - [IntrReadMem]>; + [IntrReadMem, IntrArgMemOnly]>; def int_masked_scatter: Intrinsic<[], [llvm_anyvector_ty, LLVMVectorOfAnyPointersToElt<0>, llvm_i32_ty, - LLVMVectorSameWidth<0, llvm_i1_ty>]>; + LLVMVectorSameWidth<0, llvm_i1_ty>], + [IntrArgMemOnly]>; def int_masked_expandload: Intrinsic<[llvm_anyvector_ty], [LLVMPointerToElt<0>, Index: include/llvm/IR/IntrinsicsAArch64.td =================================================================== --- include/llvm/IR/IntrinsicsAArch64.td +++ include/llvm/IR/IntrinsicsAArch64.td @@ -667,3 +667,1316 @@ def int_aarch64_crc32cx : Intrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i64_ty], [IntrNoMem]>; } + +//===----------------------------------------------------------------------===// +// SVE + +def llvm_nxv2i1_ty : 
LLVMType; +def llvm_nxv4i1_ty : LLVMType; +def llvm_nxv8i1_ty : LLVMType; +def llvm_nxv16i1_ty : LLVMType; +def llvm_nxv16i8_ty : LLVMType; +def llvm_nxv4i32_ty : LLVMType; +def llvm_nxv2i64_ty : LLVMType; +def llvm_nxv8f16_ty : LLVMType; +def llvm_nxv4f32_ty : LLVMType; +def llvm_nxv2f64_ty : LLVMType; + +let TargetPrefix = "aarch64" in { // All intrinsics start with "llvm.aarch64.". + class AdvSIMD_1Vec_PredLoad_Intrinsic + : Intrinsic<[llvm_anyvector_ty], + [LLVMVectorSameWidth<0, llvm_i1_ty>, LLVMPointerTo<0>], + [IntrReadMem, IntrArgMemOnly]>; + + class AdvSIMD_2Vec_PredLoad_Intrinsic + : Intrinsic<[llvm_anyvector_ty, LLVMMatchType<0>], + [LLVMVectorSameWidth<0, llvm_i1_ty>, LLVMPointerTo<0>], + [IntrReadMem, IntrArgMemOnly]>; + + class AdvSIMD_3Vec_PredLoad_Intrinsic + : Intrinsic<[llvm_anyvector_ty, LLVMMatchType<0>, LLVMMatchType<0>], + [LLVMVectorSameWidth<0, llvm_i1_ty>, LLVMPointerTo<0>], + [IntrReadMem, IntrArgMemOnly]>; + + class AdvSIMD_4Vec_PredLoad_Intrinsic + : Intrinsic<[llvm_anyvector_ty, LLVMMatchType<0>, + LLVMMatchType<0>, LLVMMatchType<0>], + [LLVMVectorSameWidth<0, llvm_i1_ty>, LLVMPointerTo<0>], + [IntrReadMem, IntrArgMemOnly]>; + + class AdvSIMD_1Vec_PredStore_Intrinsic + : Intrinsic<[], + [llvm_anyvector_ty, + LLVMVectorSameWidth<0, llvm_i1_ty>, LLVMPointerTo<0>], + [IntrArgMemOnly, NoCapture<2>]>; + + class AdvSIMD_2Vec_PredStore_Intrinsic + : Intrinsic<[], + [llvm_anyvector_ty, LLVMMatchType<0>, + LLVMVectorSameWidth<0, llvm_i1_ty>, LLVMPointerTo<0>], + [IntrArgMemOnly, NoCapture<3>]>; + + class AdvSIMD_3Vec_PredStore_Intrinsic + : Intrinsic<[], + [llvm_anyvector_ty, LLVMMatchType<0>, LLVMMatchType<0>, + LLVMVectorSameWidth<0, llvm_i1_ty>, LLVMPointerTo<0>], + [IntrArgMemOnly, NoCapture<4>]>; + + class AdvSIMD_4Vec_PredStore_Intrinsic + : Intrinsic<[], + [llvm_anyvector_ty, LLVMMatchType<0>, LLVMMatchType<0>, + LLVMMatchType<0>, + LLVMVectorSameWidth<0, llvm_i1_ty>, LLVMPointerTo<0>], + [IntrArgMemOnly, NoCapture<5>]>; + + class AdvSIMD_2VectorArgIndexed_Intrinsic + : Intrinsic<[llvm_anyvector_ty], + [LLVMMatchType<0>, + LLVMMatchType<0>, + llvm_i32_ty], + [IntrNoMem]>; + + class AdvSIMD_3VectorArgIndexed_Intrinsic + : Intrinsic<[llvm_anyvector_ty], + [LLVMMatchType<0>, + LLVMMatchType<0>, + LLVMMatchType<0>, + llvm_i32_ty], + [IntrNoMem]>; + + class AdvSIMD_Merged1VectorArg_Intrinsic + : Intrinsic<[llvm_anyvector_ty], + [LLVMMatchType<0>, + LLVMVectorSameWidth<0, llvm_i1_ty>, + LLVMMatchType<0>], + [IntrNoMem]>; + + class AdvSIMD_Pred1VectorArg_Intrinsic + : Intrinsic<[llvm_anyvector_ty], + [LLVMVectorSameWidth<0, llvm_i1_ty>, + LLVMMatchType<0>], + [IntrNoMem]>; + + class AdvSIMD_Pred2VectorArg_Intrinsic + : Intrinsic<[llvm_anyvector_ty], + [LLVMVectorSameWidth<0, llvm_i1_ty>, + LLVMMatchType<0>, + LLVMMatchType<0>], + [IntrNoMem]>; + + class AdvSIMD_Pred3VectorArg_Intrinsic + : Intrinsic<[llvm_anyvector_ty], + [LLVMVectorSameWidth<0, llvm_i1_ty>, + LLVMMatchType<0>, + LLVMMatchType<0>, + LLVMMatchType<0>], + [IntrNoMem]>; + + class AdvSIMD_SVE_Compare_Intrinsic + : Intrinsic<[LLVMVectorSameWidth<0, llvm_i1_ty>], + [LLVMVectorSameWidth<0, llvm_i1_ty>, + llvm_anyvector_ty, + LLVMMatchType<0>], + [IntrNoMem]>; + + class AdvSIMD_SVE_CompareWide_Intrinsic + : Intrinsic<[LLVMVectorSameWidth<0, llvm_i1_ty>], + [LLVMVectorSameWidth<0, llvm_i1_ty>, + llvm_anyvector_ty, + llvm_nxv2i64_ty], + [IntrNoMem]>; + + class AdvSIMD_SVE_Saturating_Intrinsic + : Intrinsic<[llvm_anyvector_ty], + [LLVMMatchType<0>, + LLVMVectorSameWidth<0, llvm_i1_ty>], + [IntrNoMem]>; + + class 
AdvSIMD_SVE_SaturatingWithPattern_Intrinsic + : Intrinsic<[llvm_anyvector_ty], + [LLVMMatchType<0>, + llvm_i32_ty, + llvm_i32_ty], + [IntrNoMem]>; + + class AdvSIMD_SVE_Saturating_N_Intrinsic + : Intrinsic<[T], + [T, llvm_anyvector_ty], + [IntrNoMem]>; + + class AdvSIMD_SVE_SaturatingWithPattern_N_Intrinsic + : Intrinsic<[T], + [T, llvm_i32_ty, llvm_i32_ty], + [IntrNoMem]>; + + class AdvSIMD_SVE_ShiftByImm_Intrinsic + : Intrinsic<[llvm_anyvector_ty], + [LLVMVectorSameWidth<0, llvm_i1_ty>, + LLVMMatchType<0>, + llvm_i32_ty], + [IntrNoMem]>; + + class AdvSIMD_SVE_ShiftWide_Intrinsic + : Intrinsic<[llvm_anyvector_ty], + [LLVMVectorSameWidth<0, llvm_i1_ty>, + LLVMMatchType<0>, + llvm_nxv2i64_ty], + [IntrNoMem]>; + + class AdvSIMD_SVE_Reduce_Intrinsic + : Intrinsic<[LLVMVectorElementType<0>], + [LLVMVectorSameWidth<0, llvm_i1_ty>, + llvm_anyvector_ty], + [IntrNoMem]>; + + class AdvSIMD_SVE_ReduceWithInit_Intrinsic + : Intrinsic<[LLVMVectorElementType<0>], + [LLVMVectorSameWidth<0, llvm_i1_ty>, + LLVMVectorElementType<0>, + llvm_anyvector_ty], + [IntrNoMem]>; + + class AdvSIMD_SVE_Unpack_Intrinsic + : Intrinsic<[llvm_anyvector_ty], + [LLVMSubdivide2VectorType<0>], + [IntrNoMem]>; + + class AdvSIMD_SVE_CADD_Intrinsic + : Intrinsic<[llvm_anyvector_ty], + [LLVMVectorSameWidth<0, llvm_i1_ty>, + LLVMMatchType<0>, + LLVMMatchType<0>, + llvm_i32_ty], + [IntrNoMem]>; + + class AdvSIMD_SVE_CMLA_Intrinsic + : Intrinsic<[llvm_anyvector_ty], + [LLVMVectorSameWidth<0, llvm_i1_ty>, + LLVMMatchType<0>, + LLVMMatchType<0>, + LLVMMatchType<0>, + llvm_i32_ty], + [IntrNoMem]>; + + class AdvSIMD_SVE_CMLA_LANE_Intrinsic + : Intrinsic<[llvm_anyvector_ty], + [LLVMMatchType<0>, + LLVMMatchType<0>, + LLVMMatchType<0>, + llvm_i32_ty, + llvm_i32_ty], + [IntrNoMem]>; + + class AdvSIMD_SVE_CNT_Intrinsic + : Intrinsic<[LLVMVectorOfBitcastsToInt<0>], + [LLVMVectorOfBitcastsToInt<0>, + LLVMVectorSameWidth<0, llvm_i1_ty>, + llvm_anyvector_ty], + [IntrNoMem]>; + + class AdvSIMD_SVE_CNTB_Intrinsic + : Intrinsic<[llvm_i64_ty], + [llvm_i32_ty], + [IntrNoMem]>; + + class AdvSIMD_SVE_CNTP_Intrinsic + : Intrinsic<[llvm_i64_ty], + [LLVMVectorSameWidth<0, llvm_i1_ty>, + llvm_anyvector_ty], + [IntrNoMem]>; + + class AdvSIMD_SVE_DOT_Intrinsic + : Intrinsic<[llvm_anyvector_ty], + [LLVMMatchType<0>, + LLVMSubdivide4VectorType<0>, + LLVMSubdivide4VectorType<0>], + [IntrNoMem]>; + + class AdvSIMD_SVE_DOT_Indexed_Intrinsic + : Intrinsic<[llvm_anyvector_ty], + [LLVMMatchType<0>, + LLVMSubdivide4VectorType<0>, + LLVMSubdivide4VectorType<0>, + llvm_i32_ty], + [IntrNoMem]>; + + class AdvSIMD_SVE_DUP_Intrinsic + : Intrinsic<[llvm_anyvector_ty], + [LLVMMatchType<0>, + LLVMVectorSameWidth<0, llvm_i1_ty>, + LLVMVectorElementType<0>], + [IntrNoMem]>; + + class AdvSIMD_SVE_DUPQ_Intrinsic + : Intrinsic<[llvm_anyvector_ty], + [LLVMMatchType<0>, + llvm_i64_ty], + [IntrNoMem]>; + + class AdvSIMD_SVE_EXPA_Intrinsic + : Intrinsic<[llvm_anyvector_ty], + [LLVMVectorOfBitcastsToInt<0>], + [IntrNoMem]>; + + class AdvSIMD_SVE_FCVT_Intrinsic + : Intrinsic<[llvm_anyvector_ty], + [LLVMMatchType<0>, + LLVMVectorSameWidth<0, llvm_i1_ty>, + llvm_anyvector_ty], + [IntrNoMem]>; + + class AdvSIMD_SVE_FCVTZS_Intrinsic + : Intrinsic<[llvm_anyvector_ty], + [LLVMMatchType<0>, + LLVMVectorSameWidth<0, llvm_i1_ty>, + llvm_anyvector_ty], + [IntrNoMem]>; + + class AdvSIMD_SVE_INSR_Intrinsic + : Intrinsic<[llvm_anyvector_ty], + [LLVMMatchType<0>, + LLVMVectorElementType<0>], + [IntrNoMem]>; + + class AdvSIMD_SVE_PTRUE_Intrinsic + : Intrinsic<[llvm_anyvector_ty], + [llvm_i32_ty], + 
[IntrNoMem]>; + + class AdvSIMD_SVE_PUNPKHI_Intrinsic + : Intrinsic<[llvm_anyvector_ty], + [LLVMDoubleElementsVectorType<0>], + [IntrNoMem]>; + + class AdvSIMD_SVE_SCVTF_Intrinsic + : Intrinsic<[llvm_anyvector_ty], + [LLVMMatchType<0>, + LLVMVectorSameWidth<0, llvm_i1_ty>, + llvm_anyvector_ty], + [IntrNoMem]>; + + class AdvSIMD_SVE_SCALE_Intrinsic + : Intrinsic<[llvm_anyvector_ty], + [LLVMVectorSameWidth<0, llvm_i1_ty>, + LLVMMatchType<0>, + LLVMVectorOfBitcastsToInt<0>], + [IntrNoMem]>; + + class AdvSIMD_SVE_PTEST_Intrinsic + : Intrinsic<[llvm_i1_ty], + [llvm_anyvector_ty, + LLVMMatchType<0>], + [IntrNoMem]>; + + class AdvSIMD_SVE_TBL_Intrinsic + : Intrinsic<[llvm_anyvector_ty], + [LLVMMatchType<0>, + LLVMVectorOfBitcastsToInt<0>], + [IntrNoMem]>; + + class AdvSIMD_SVE_WHILE_Intrinsic + : Intrinsic<[llvm_anyvector_ty], + [llvm_anyint_ty, LLVMMatchType<1>], + [IntrNoMem]>; + + class SVE2_2VectorArg_Wide_Intrinsic + : Intrinsic<[llvm_anyvector_ty], + [LLVMMatchType<0>, + LLVMSubdivide2VectorType<0>], + [IntrNoMem]>; + + class SVE2_2VectorArg_Pred_Long_Intrinsic + : Intrinsic<[llvm_anyvector_ty], + [LLVMVectorSameWidth<0, llvm_i1_ty>, + LLVMMatchType<0>, + LLVMSubdivide2VectorType<0>], + [IntrNoMem]>; + + class SVE2_1VectorArg_Long_Intrinsic + : Intrinsic<[llvm_anyvector_ty], + [LLVMSubdivide2VectorType<0>, + llvm_i32_ty], + [IntrNoMem]>; + + class SVE2_2VectorArg_Long_Intrinsic + : Intrinsic<[llvm_anyvector_ty], + [LLVMSubdivide2VectorType<0>, + LLVMSubdivide2VectorType<0>], + [IntrNoMem]>; + + class SVE2_2VectorArgIndexed_Long_Intrinsic + : Intrinsic<[llvm_anyvector_ty], + [LLVMSubdivide2VectorType<0>, + LLVMSubdivide2VectorType<0>, + llvm_i32_ty], + [IntrNoMem]>; + + class SVE2_3VectorArg_Long_Intrinsic + : Intrinsic<[llvm_anyvector_ty], + [LLVMMatchType<0>, + LLVMSubdivide2VectorType<0>, + LLVMSubdivide2VectorType<0>], + [IntrNoMem]>; + + class SVE2_3VectorArgIndexed_Long_Intrinsic + : Intrinsic<[llvm_anyvector_ty], + [LLVMMatchType<0>, + LLVMSubdivide2VectorType<0>, + LLVMSubdivide2VectorType<0>, + llvm_i32_ty], + [IntrNoMem]>; + + class SVE2_1VectorArg_Narrowing_Intrinsic + : Intrinsic<[LLVMSubdivide2VectorType<0>], + [llvm_anyvector_ty], + [IntrNoMem]>; + + class SVE2_2VectorArg_Narrowing_Intrinsic + : Intrinsic<[LLVMSubdivide2VectorType<0>], + [llvm_anyvector_ty, + LLVMMatchType<0>], + [IntrNoMem]>; + + class SVE2_Merged1VectorArg_Narrowing_Intrinsic + : Intrinsic<[LLVMSubdivide2VectorType<0>], + [LLVMSubdivide2VectorType<0>, + llvm_anyvector_ty], + [IntrNoMem]>; + + class SVE2_Merged2VectorArg_Narrowing_Intrinsic + : Intrinsic<[LLVMSubdivide2VectorType<0>], + [LLVMSubdivide2VectorType<0>, + llvm_anyvector_ty, + LLVMMatchType<0>], + [IntrNoMem]>; + + class SVE2_TBX_Intrinsic + : Intrinsic<[llvm_anyvector_ty], + [LLVMMatchType<0>, + LLVMMatchType<0>, + LLVMVectorOfBitcastsToInt<0>], + [IntrNoMem]>; + + class SVE2_CONFLICT_DETECT_Intrinsic + : Intrinsic<[llvm_anyvector_ty], + [LLVMAnyPointerType, + LLVMMatchType<1>]>; + + class SVE2_CADD_Intrinsic + : Intrinsic<[llvm_anyvector_ty], + [LLVMMatchType<0>, + LLVMMatchType<0>, + llvm_i32_ty], + [IntrNoMem]>; + + class SVE2_1VectorArg_Imm_Narrowing_Intrinsic + : Intrinsic<[LLVMSubdivide2VectorType<0>], + [llvm_anyvector_ty, llvm_i32_ty], [IntrNoMem]>; + + class SVE2_2VectorArg_Imm_Narrowing_Intrinsic + : Intrinsic<[LLVMSubdivide2VectorType<0>], + [LLVMSubdivide2VectorType<0>, + llvm_anyvector_ty, llvm_i32_ty], [IntrNoMem]>; + + class AdvSIMD_SVE_CDOT_Intrinsic + : Intrinsic<[llvm_anyvector_ty], + [LLVMMatchType<0>, + 
LLVMSubdivide4VectorType<0>, + LLVMSubdivide4VectorType<0>, + llvm_i32_ty], + [IntrNoMem]>; + + class AdvSIMD_SVE_CDOT_LANE_Intrinsic + : Intrinsic<[llvm_anyvector_ty], + [LLVMMatchType<0>, + LLVMSubdivide4VectorType<0>, + LLVMSubdivide4VectorType<0>, + llvm_i32_ty, llvm_i32_ty], + [IntrNoMem]>; + + // NOTE: There is no relationship between these intrinsics beyond an attempt + // to reuse currently identical class definitions. + class AdvSIMD_SVE_LOGB_Intrinsic : AdvSIMD_SVE_CNT_Intrinsic; + class AdvSIMD_SVE_SADDV_Intrinsic : AdvSIMD_SVE_CNTP_Intrinsic; + class AdvSIMD_SVE_TMAD_Intrinsic : AdvSIMD_2VectorArgIndexed_Intrinsic; + class AdvSIMD_SVE_TSMUL_Intrinsic : AdvSIMD_SVE_TBL_Intrinsic; + class AdvSIMD_SVE_TSSEL_Intrinsic : AdvSIMD_SVE_TBL_Intrinsic; + + // This class of intrinsics are not intended to be useful within LLVM IR but + // are instead here to support some of the more regid parts of the ACLE. + class Builtin_SVCVT + : GCCBuiltin<"__builtin_sve_" # name>, + Intrinsic<[OUT], [OUT, llvm_nxv16i1_ty, IN], [IntrNoMem]>; +} + +let TargetPrefix = "aarch64" in { // All intrinsics start with "llvm.aarch64.". + +// +// Loads +// + +def int_aarch64_sve_ld2 : AdvSIMD_2Vec_PredLoad_Intrinsic; +def int_aarch64_sve_ld3 : AdvSIMD_3Vec_PredLoad_Intrinsic; +def int_aarch64_sve_ld4 : AdvSIMD_4Vec_PredLoad_Intrinsic; + +def int_aarch64_sve_ld2_legacy : AdvSIMD_2Vec_PredLoad_Intrinsic; +def int_aarch64_sve_ld3_legacy : AdvSIMD_3Vec_PredLoad_Intrinsic; +def int_aarch64_sve_ld4_legacy : AdvSIMD_4Vec_PredLoad_Intrinsic; + +def int_aarch64_sve_ldff1 : Intrinsic<[llvm_anyvector_ty], + [LLVMVectorSameWidth<0, llvm_i1_ty>, + LLVMPointerToElt<0>], + [IntrReadMem, IntrArgMemOnly]>; + +def int_aarch64_sve_ldff1_gather : Intrinsic<[llvm_anyvector_ty], + [LLVMVectorSameWidth<0, llvm_i1_ty>, + LLVMPointerToElt<0>, + LLVMVectorSameWidth<0, llvm_i64_ty>], + [IntrReadMem, IntrArgMemOnly]>; + +def int_aarch64_sve_ldnf1 : Intrinsic<[llvm_anyvector_ty], + [LLVMVectorSameWidth<0, llvm_i1_ty>, + LLVMPointerToElt<0>], + [IntrReadMem, IntrArgMemOnly]>; + +def int_aarch64_sve_ldnt1 : AdvSIMD_1Vec_PredLoad_Intrinsic; + +def int_aarch64_sve_ldnt1_gather : Intrinsic<[llvm_anyvector_ty], + [LLVMVectorSameWidth<0, llvm_i1_ty>, + LLVMPointerToElt<0>, + LLVMVectorSameWidth<0, llvm_i64_ty>], + [IntrReadMem, IntrArgMemOnly]>; + +def int_aarch64_sve_ld1rq : Intrinsic<[llvm_anyvector_ty], + [LLVMVectorSameWidth<0, llvm_i1_ty>, + LLVMPointerToElt<0>], + [IntrReadMem, IntrArgMemOnly]>; + +// +// Stores +// + +def int_aarch64_sve_st2 : AdvSIMD_2Vec_PredStore_Intrinsic; +def int_aarch64_sve_st3 : AdvSIMD_3Vec_PredStore_Intrinsic; +def int_aarch64_sve_st4 : AdvSIMD_4Vec_PredStore_Intrinsic; + +def int_aarch64_sve_stnt1 : AdvSIMD_1Vec_PredStore_Intrinsic; +def int_aarch64_sve_stnt1_scatter : Intrinsic<[], + [llvm_anyvector_ty, + LLVMVectorSameWidth<0, llvm_i1_ty>, + LLVMPointerToElt<0>, + LLVMVectorSameWidth<0, llvm_i64_ty>], + [IntrArgMemOnly, NoCapture<2>]>; + +// +// Prefetches +// Unfortunately, codegen has no distinction between pointer types, and so there +// needs to be some duplication here to encode the data type being prefetched + +def int_aarch64_sve_prf : Intrinsic<[], [llvm_anyvector_ty, + llvm_ptr_ty, + llvm_i32_ty], [IntrArgMemOnly]>; +def int_aarch64_sve_prfb_gather : Intrinsic<[], + [llvm_anyvector_ty, + LLVMPointerType, + LLVMVectorSameWidth<0,llvm_i64_ty>, + llvm_i32_ty], + [IntrArgMemOnly]>; +def int_aarch64_sve_prfh_gather : Intrinsic<[], + [llvm_anyvector_ty, + LLVMPointerType, + LLVMVectorSameWidth<0,llvm_i64_ty>, 
+ llvm_i32_ty], + [IntrArgMemOnly]>; +def int_aarch64_sve_prfw_gather : Intrinsic<[], + [llvm_anyvector_ty, + LLVMPointerType, + LLVMVectorSameWidth<0,llvm_i64_ty>, + llvm_i32_ty], + [IntrArgMemOnly]>; +def int_aarch64_sve_prfd_gather : Intrinsic<[], + [llvm_anyvector_ty, + LLVMPointerType, + LLVMVectorSameWidth<0,llvm_i64_ty>, + llvm_i32_ty], + [IntrArgMemOnly]>; + +// +// Scalar to vector operations +// + +def int_aarch64_sve_dup : AdvSIMD_SVE_DUP_Intrinsic; + +// +// Integer arithmetic +// + +def int_aarch64_sve_abs : AdvSIMD_Merged1VectorArg_Intrinsic; +def int_aarch64_sve_add : AdvSIMD_Pred2VectorArg_Intrinsic; +def int_aarch64_sve_mad : AdvSIMD_Pred3VectorArg_Intrinsic; +def int_aarch64_sve_mla : AdvSIMD_Pred3VectorArg_Intrinsic; +def int_aarch64_sve_mla_lane : AdvSIMD_3VectorArgIndexed_Intrinsic; +def int_aarch64_sve_mls_lane : AdvSIMD_3VectorArgIndexed_Intrinsic; +def int_aarch64_sve_mls : AdvSIMD_Pred3VectorArg_Intrinsic; +def int_aarch64_sve_msb : AdvSIMD_Pred3VectorArg_Intrinsic; +def int_aarch64_sve_mul : AdvSIMD_Pred2VectorArg_Intrinsic; +def int_aarch64_sve_mul_lane : AdvSIMD_2VectorArgIndexed_Intrinsic; +def int_aarch64_sve_neg : AdvSIMD_Merged1VectorArg_Intrinsic; +def int_aarch64_sve_sub : AdvSIMD_Pred2VectorArg_Intrinsic; +def int_aarch64_sve_subr : AdvSIMD_Pred2VectorArg_Intrinsic; + +def int_aarch64_sve_sabd : AdvSIMD_Pred2VectorArg_Intrinsic; +def int_aarch64_sve_sdiv : AdvSIMD_Pred2VectorArg_Intrinsic; +def int_aarch64_sve_sdivr : AdvSIMD_Pred2VectorArg_Intrinsic; +def int_aarch64_sve_sdot : AdvSIMD_SVE_DOT_Intrinsic; +def int_aarch64_sve_sdot_lane : AdvSIMD_SVE_DOT_Indexed_Intrinsic; +def int_aarch64_sve_smax : AdvSIMD_Pred2VectorArg_Intrinsic; +def int_aarch64_sve_smin : AdvSIMD_Pred2VectorArg_Intrinsic; +def int_aarch64_sve_smulh : AdvSIMD_Pred2VectorArg_Intrinsic; +def int_aarch64_sve_sqadd_x : AdvSIMD_2VectorArg_Intrinsic; +def int_aarch64_sve_sqsub_x : AdvSIMD_2VectorArg_Intrinsic; + +def int_aarch64_sve_uabd : AdvSIMD_Pred2VectorArg_Intrinsic; +def int_aarch64_sve_udiv : AdvSIMD_Pred2VectorArg_Intrinsic; +def int_aarch64_sve_udivr : AdvSIMD_Pred2VectorArg_Intrinsic; +def int_aarch64_sve_udot : AdvSIMD_SVE_DOT_Intrinsic; +def int_aarch64_sve_udot_lane : AdvSIMD_SVE_DOT_Indexed_Intrinsic; +def int_aarch64_sve_umax : AdvSIMD_Pred2VectorArg_Intrinsic; +def int_aarch64_sve_umin : AdvSIMD_Pred2VectorArg_Intrinsic; +def int_aarch64_sve_umulh : AdvSIMD_Pred2VectorArg_Intrinsic; +def int_aarch64_sve_uqadd_x : AdvSIMD_2VectorArg_Intrinsic; +def int_aarch64_sve_uqsub_x : AdvSIMD_2VectorArg_Intrinsic; + +// +// Logical operations +// + +def int_aarch64_sve_and : AdvSIMD_Pred2VectorArg_Intrinsic; +def int_aarch64_sve_bic : AdvSIMD_Pred2VectorArg_Intrinsic; +def int_aarch64_sve_cnot : AdvSIMD_Merged1VectorArg_Intrinsic; +def int_aarch64_sve_eor : AdvSIMD_Pred2VectorArg_Intrinsic; +def int_aarch64_sve_not : AdvSIMD_Merged1VectorArg_Intrinsic; +def int_aarch64_sve_orr : AdvSIMD_Pred2VectorArg_Intrinsic; + +// +// Shifts +// + +def int_aarch64_sve_asr : AdvSIMD_Pred2VectorArg_Intrinsic; +def int_aarch64_sve_asr_wide : AdvSIMD_SVE_ShiftWide_Intrinsic; +def int_aarch64_sve_asrd : AdvSIMD_SVE_ShiftByImm_Intrinsic; +def int_aarch64_sve_insr : AdvSIMD_SVE_INSR_Intrinsic; +def int_aarch64_sve_lsl : AdvSIMD_Pred2VectorArg_Intrinsic; +def int_aarch64_sve_lsl_wide : AdvSIMD_SVE_ShiftWide_Intrinsic; +def int_aarch64_sve_lsr : AdvSIMD_Pred2VectorArg_Intrinsic; +def int_aarch64_sve_lsr_wide : AdvSIMD_SVE_ShiftWide_Intrinsic; + +// +// Integer reductions +// + +def int_aarch64_sve_andv : 
AdvSIMD_SVE_Reduce_Intrinsic; +def int_aarch64_sve_eorv : AdvSIMD_SVE_Reduce_Intrinsic; +def int_aarch64_sve_orv : AdvSIMD_SVE_Reduce_Intrinsic; +def int_aarch64_sve_saddv : AdvSIMD_SVE_SADDV_Intrinsic; +def int_aarch64_sve_smaxv : AdvSIMD_SVE_Reduce_Intrinsic; +def int_aarch64_sve_sminv : AdvSIMD_SVE_Reduce_Intrinsic; +def int_aarch64_sve_uaddv : AdvSIMD_SVE_SADDV_Intrinsic; +def int_aarch64_sve_umaxv : AdvSIMD_SVE_Reduce_Intrinsic; +def int_aarch64_sve_uminv : AdvSIMD_SVE_Reduce_Intrinsic; + +// +// Integer comparisons +// + +def int_aarch64_sve_cmpeq : AdvSIMD_SVE_Compare_Intrinsic; +def int_aarch64_sve_cmpge : AdvSIMD_SVE_Compare_Intrinsic; +def int_aarch64_sve_cmpgt : AdvSIMD_SVE_Compare_Intrinsic; +def int_aarch64_sve_cmphi : AdvSIMD_SVE_Compare_Intrinsic; +def int_aarch64_sve_cmphs : AdvSIMD_SVE_Compare_Intrinsic; +def int_aarch64_sve_cmpne : AdvSIMD_SVE_Compare_Intrinsic; + +def int_aarch64_sve_cmpeq_wide : AdvSIMD_SVE_CompareWide_Intrinsic; +def int_aarch64_sve_cmpge_wide : AdvSIMD_SVE_CompareWide_Intrinsic; +def int_aarch64_sve_cmpgt_wide : AdvSIMD_SVE_CompareWide_Intrinsic; +def int_aarch64_sve_cmphi_wide : AdvSIMD_SVE_CompareWide_Intrinsic; +def int_aarch64_sve_cmphs_wide : AdvSIMD_SVE_CompareWide_Intrinsic; +def int_aarch64_sve_cmple_wide : AdvSIMD_SVE_CompareWide_Intrinsic; +def int_aarch64_sve_cmplo_wide : AdvSIMD_SVE_CompareWide_Intrinsic; +def int_aarch64_sve_cmpls_wide : AdvSIMD_SVE_CompareWide_Intrinsic; +def int_aarch64_sve_cmplt_wide : AdvSIMD_SVE_CompareWide_Intrinsic; +def int_aarch64_sve_cmpne_wide : AdvSIMD_SVE_CompareWide_Intrinsic; + +// +// While comparisons +// + +def int_aarch64_sve_whilele : AdvSIMD_SVE_WHILE_Intrinsic; +def int_aarch64_sve_whilelo : AdvSIMD_SVE_WHILE_Intrinsic; +def int_aarch64_sve_whilels : AdvSIMD_SVE_WHILE_Intrinsic; +def int_aarch64_sve_whilelt : AdvSIMD_SVE_WHILE_Intrinsic; +def int_aarch64_sve_whilege : AdvSIMD_SVE_WHILE_Intrinsic; +def int_aarch64_sve_whilegt : AdvSIMD_SVE_WHILE_Intrinsic; +def int_aarch64_sve_whilehi : AdvSIMD_SVE_WHILE_Intrinsic; +def int_aarch64_sve_whilehs : AdvSIMD_SVE_WHILE_Intrinsic; + +// +// Counting bits +// + +def int_aarch64_sve_cls : AdvSIMD_Merged1VectorArg_Intrinsic; +def int_aarch64_sve_clz : AdvSIMD_Merged1VectorArg_Intrinsic; +def int_aarch64_sve_cnt : AdvSIMD_SVE_CNT_Intrinsic; + +// +// Conversion +// + +def int_aarch64_sve_sxtb : AdvSIMD_Merged1VectorArg_Intrinsic; +def int_aarch64_sve_sxth : AdvSIMD_Merged1VectorArg_Intrinsic; +def int_aarch64_sve_sxtw : AdvSIMD_Merged1VectorArg_Intrinsic; +def int_aarch64_sve_uxtb : AdvSIMD_Merged1VectorArg_Intrinsic; +def int_aarch64_sve_uxth : AdvSIMD_Merged1VectorArg_Intrinsic; +def int_aarch64_sve_uxtw : AdvSIMD_Merged1VectorArg_Intrinsic; + +// +// Reversal +// + +def int_aarch64_sve_rbit : AdvSIMD_Merged1VectorArg_Intrinsic; +def int_aarch64_sve_revb : AdvSIMD_Merged1VectorArg_Intrinsic; +def int_aarch64_sve_revh : AdvSIMD_Merged1VectorArg_Intrinsic; +def int_aarch64_sve_revw : AdvSIMD_Merged1VectorArg_Intrinsic; + +// +// Floating-point arithmetic +// + +def int_aarch64_sve_fabd : AdvSIMD_Pred2VectorArg_Intrinsic; +def int_aarch64_sve_fabs : AdvSIMD_Merged1VectorArg_Intrinsic; +def int_aarch64_sve_fadd : AdvSIMD_Pred2VectorArg_Intrinsic; +def int_aarch64_sve_fcadd : AdvSIMD_SVE_CADD_Intrinsic; +def int_aarch64_sve_fcmla : AdvSIMD_SVE_CMLA_Intrinsic; +def int_aarch64_sve_fcmla_lane : AdvSIMD_SVE_CMLA_LANE_Intrinsic; +def int_aarch64_sve_fdiv : AdvSIMD_Pred2VectorArg_Intrinsic; +def int_aarch64_sve_fdivr : AdvSIMD_Pred2VectorArg_Intrinsic; +def 
int_aarch64_sve_fexpa_x : AdvSIMD_SVE_EXPA_Intrinsic; +def int_aarch64_sve_fmad : AdvSIMD_Pred3VectorArg_Intrinsic; +def int_aarch64_sve_fmax : AdvSIMD_Pred2VectorArg_Intrinsic; +def int_aarch64_sve_fmaxnm : AdvSIMD_Pred2VectorArg_Intrinsic; +def int_aarch64_sve_fmin : AdvSIMD_Pred2VectorArg_Intrinsic; +def int_aarch64_sve_fminnm : AdvSIMD_Pred2VectorArg_Intrinsic; +def int_aarch64_sve_fmla : AdvSIMD_Pred3VectorArg_Intrinsic; +def int_aarch64_sve_fmls : AdvSIMD_Pred3VectorArg_Intrinsic; +def int_aarch64_sve_fmla_lane : AdvSIMD_3VectorArgIndexed_Intrinsic; +def int_aarch64_sve_fmls_lane : AdvSIMD_3VectorArgIndexed_Intrinsic; +def int_aarch64_sve_fmsb : AdvSIMD_Pred3VectorArg_Intrinsic; +def int_aarch64_sve_fmul : AdvSIMD_Pred2VectorArg_Intrinsic; +def int_aarch64_sve_fmulx : AdvSIMD_Pred2VectorArg_Intrinsic; +def int_aarch64_sve_fmul_lane : AdvSIMD_2VectorArgIndexed_Intrinsic; +def int_aarch64_sve_fneg : AdvSIMD_Merged1VectorArg_Intrinsic; +def int_aarch64_sve_fnmad : AdvSIMD_Pred3VectorArg_Intrinsic; +def int_aarch64_sve_fnmla : AdvSIMD_Pred3VectorArg_Intrinsic; +def int_aarch64_sve_fnmls : AdvSIMD_Pred3VectorArg_Intrinsic; +def int_aarch64_sve_fnmsb : AdvSIMD_Pred3VectorArg_Intrinsic; +def int_aarch64_sve_frecpe_x : AdvSIMD_1VectorArg_Intrinsic; +def int_aarch64_sve_frecps_x : AdvSIMD_2VectorArg_Intrinsic; +def int_aarch64_sve_frecpx : AdvSIMD_Merged1VectorArg_Intrinsic; +def int_aarch64_sve_frinta : AdvSIMD_Merged1VectorArg_Intrinsic; +def int_aarch64_sve_frinti : AdvSIMD_Merged1VectorArg_Intrinsic; +def int_aarch64_sve_frintm : AdvSIMD_Merged1VectorArg_Intrinsic; +def int_aarch64_sve_frintn : AdvSIMD_Merged1VectorArg_Intrinsic; +def int_aarch64_sve_frintp : AdvSIMD_Merged1VectorArg_Intrinsic; +def int_aarch64_sve_frintx : AdvSIMD_Merged1VectorArg_Intrinsic; +def int_aarch64_sve_frintz : AdvSIMD_Merged1VectorArg_Intrinsic; +def int_aarch64_sve_frsqrte_x : AdvSIMD_1VectorArg_Intrinsic; +def int_aarch64_sve_frsqrts_x : AdvSIMD_2VectorArg_Intrinsic; +def int_aarch64_sve_fscale : AdvSIMD_SVE_SCALE_Intrinsic; +def int_aarch64_sve_fsqrt : AdvSIMD_Merged1VectorArg_Intrinsic; +def int_aarch64_sve_ftmad_x : AdvSIMD_SVE_TMAD_Intrinsic; +def int_aarch64_sve_ftsmul_x : AdvSIMD_SVE_TSMUL_Intrinsic; +def int_aarch64_sve_ftssel_x : AdvSIMD_SVE_TSSEL_Intrinsic; +def int_aarch64_sve_fsub : AdvSIMD_Pred2VectorArg_Intrinsic; +def int_aarch64_sve_fsubr : AdvSIMD_Pred2VectorArg_Intrinsic; + +// +// Floating-point reductions +// + +def int_aarch64_sve_fadda : AdvSIMD_SVE_ReduceWithInit_Intrinsic; +def int_aarch64_sve_faddv : AdvSIMD_SVE_Reduce_Intrinsic; +def int_aarch64_sve_fmaxv : AdvSIMD_SVE_Reduce_Intrinsic; +def int_aarch64_sve_fmaxnmv : AdvSIMD_SVE_Reduce_Intrinsic; +def int_aarch64_sve_fminv : AdvSIMD_SVE_Reduce_Intrinsic; +def int_aarch64_sve_fminnmv : AdvSIMD_SVE_Reduce_Intrinsic; + +// +// Floating-point conversions +// + +def int_aarch64_sve_fcvt : AdvSIMD_SVE_FCVT_Intrinsic; +def int_aarch64_sve_fcvtzs : AdvSIMD_SVE_FCVTZS_Intrinsic; +def int_aarch64_sve_fcvtzu : AdvSIMD_SVE_FCVTZS_Intrinsic; +def int_aarch64_sve_scvtf : AdvSIMD_SVE_SCVTF_Intrinsic; +def int_aarch64_sve_ucvtf : AdvSIMD_SVE_SCVTF_Intrinsic; + +// +// Floating-point comparisons +// + +def int_aarch64_sve_facge : AdvSIMD_SVE_Compare_Intrinsic; +def int_aarch64_sve_facgt : AdvSIMD_SVE_Compare_Intrinsic; + +def int_aarch64_sve_fcmpeq : AdvSIMD_SVE_Compare_Intrinsic; +def int_aarch64_sve_fcmpge : AdvSIMD_SVE_Compare_Intrinsic; +def int_aarch64_sve_fcmpgt : AdvSIMD_SVE_Compare_Intrinsic; +def int_aarch64_sve_fcmpne : 
AdvSIMD_SVE_Compare_Intrinsic;
+def int_aarch64_sve_fcmpuo : AdvSIMD_SVE_Compare_Intrinsic;
+
+// Variants of the above that are closely tied to the ACLE and its rigid
+// definition of "unused" bits when working with unpacked data types.
+// For example, svcvt_f32_s64 expects the top half of active lanes to be zeroed.
+
+def int_aarch64_sve_fcvtzs_i32f16 : Builtin_SVCVT<"svcvt_s32_f16_m", llvm_nxv4i32_ty, llvm_nxv8f16_ty>;
+def int_aarch64_sve_fcvtzs_i32f64 : Builtin_SVCVT<"svcvt_s32_f64_m", llvm_nxv4i32_ty, llvm_nxv2f64_ty>;
+def int_aarch64_sve_fcvtzs_i64f16 : Builtin_SVCVT<"svcvt_s64_f16_m", llvm_nxv2i64_ty, llvm_nxv8f16_ty>;
+def int_aarch64_sve_fcvtzs_i64f32 : Builtin_SVCVT<"svcvt_s64_f32_m", llvm_nxv2i64_ty, llvm_nxv4f32_ty>;
+
+def int_aarch64_sve_fcvtzu_i32f16 : Builtin_SVCVT<"svcvt_u32_f16_m", llvm_nxv4i32_ty, llvm_nxv8f16_ty>;
+def int_aarch64_sve_fcvtzu_i32f64 : Builtin_SVCVT<"svcvt_u32_f64_m", llvm_nxv4i32_ty, llvm_nxv2f64_ty>;
+def int_aarch64_sve_fcvtzu_i64f16 : Builtin_SVCVT<"svcvt_u64_f16_m", llvm_nxv2i64_ty, llvm_nxv8f16_ty>;
+def int_aarch64_sve_fcvtzu_i64f32 : Builtin_SVCVT<"svcvt_u64_f32_m", llvm_nxv2i64_ty, llvm_nxv4f32_ty>;
+
+def int_aarch64_sve_fcvt_f16f32 : Builtin_SVCVT<"svcvt_f16_f32_m", llvm_nxv8f16_ty, llvm_nxv4f32_ty>;
+def int_aarch64_sve_fcvt_f16f64 : Builtin_SVCVT<"svcvt_f16_f64_m", llvm_nxv8f16_ty, llvm_nxv2f64_ty>;
+def int_aarch64_sve_fcvt_f32f64 : Builtin_SVCVT<"svcvt_f32_f64_m", llvm_nxv4f32_ty, llvm_nxv2f64_ty>;
+
+def int_aarch64_sve_fcvt_f32f16 : Builtin_SVCVT<"svcvt_f32_f16_m", llvm_nxv4f32_ty, llvm_nxv8f16_ty>;
+def int_aarch64_sve_fcvt_f64f16 : Builtin_SVCVT<"svcvt_f64_f16_m", llvm_nxv2f64_ty, llvm_nxv8f16_ty>;
+def int_aarch64_sve_fcvt_f64f32 : Builtin_SVCVT<"svcvt_f64_f32_m", llvm_nxv2f64_ty, llvm_nxv4f32_ty>;
+
+def int_aarch64_sve_fcvtlt_f32f16 : Builtin_SVCVT<"svcvtlt_f32_f16_m", llvm_nxv4f32_ty, llvm_nxv8f16_ty>;
+def int_aarch64_sve_fcvtlt_f64f32 : Builtin_SVCVT<"svcvtlt_f64_f32_m", llvm_nxv2f64_ty, llvm_nxv4f32_ty>;
+def int_aarch64_sve_fcvtnt_f16f32 : Builtin_SVCVT<"svcvtnt_f16_f32_m", llvm_nxv8f16_ty, llvm_nxv4f32_ty>;
+def int_aarch64_sve_fcvtnt_f32f64 : Builtin_SVCVT<"svcvtnt_f32_f64_m", llvm_nxv4f32_ty, llvm_nxv2f64_ty>;
+
+def int_aarch64_sve_fcvtx_f32f64 : Builtin_SVCVT<"svcvtx_f32_f64_m", llvm_nxv4f32_ty, llvm_nxv2f64_ty>;
+def int_aarch64_sve_fcvtxnt_f32f64 : Builtin_SVCVT<"svcvtxnt_f32_f64_m", llvm_nxv4f32_ty, llvm_nxv2f64_ty>;
+
+def int_aarch64_sve_scvtf_f16i32 : Builtin_SVCVT<"svcvt_f16_s32_m", llvm_nxv8f16_ty, llvm_nxv4i32_ty>;
+def int_aarch64_sve_scvtf_f16i64 : Builtin_SVCVT<"svcvt_f16_s64_m", llvm_nxv8f16_ty, llvm_nxv2i64_ty>;
+def int_aarch64_sve_scvtf_f32i64 : Builtin_SVCVT<"svcvt_f32_s64_m", llvm_nxv4f32_ty, llvm_nxv2i64_ty>;
+def int_aarch64_sve_scvtf_f64i32 : Builtin_SVCVT<"svcvt_f64_s32_m", llvm_nxv2f64_ty, llvm_nxv4i32_ty>;
+
+def int_aarch64_sve_ucvtf_f16i32 : Builtin_SVCVT<"svcvt_f16_u32_m", llvm_nxv8f16_ty, llvm_nxv4i32_ty>;
+def int_aarch64_sve_ucvtf_f16i64 : Builtin_SVCVT<"svcvt_f16_u64_m", llvm_nxv8f16_ty, llvm_nxv2i64_ty>;
+def int_aarch64_sve_ucvtf_f32i64 : Builtin_SVCVT<"svcvt_f32_u64_m", llvm_nxv4f32_ty, llvm_nxv2i64_ty>;
+def int_aarch64_sve_ucvtf_f64i32 : Builtin_SVCVT<"svcvt_f64_u32_m", llvm_nxv2f64_ty, llvm_nxv4i32_ty>;
+
+//
+// Permutations and selection
+//
+
+// TODO: Consider removing int_aarch64_sve_clasta & int_aarch64_sve_clastb.
+def int_aarch64_sve_clasta : AdvSIMD_Pred2VectorArg_Intrinsic; +def int_aarch64_sve_clasta_n : AdvSIMD_SVE_ReduceWithInit_Intrinsic; +def int_aarch64_sve_clastb : AdvSIMD_Pred2VectorArg_Intrinsic; +def int_aarch64_sve_clastb_n : AdvSIMD_SVE_ReduceWithInit_Intrinsic; +def int_aarch64_sve_compact : AdvSIMD_Pred1VectorArg_Intrinsic; +def int_aarch64_sve_dupq_lane : AdvSIMD_SVE_DUPQ_Intrinsic; +def int_aarch64_sve_ext : AdvSIMD_2VectorArgIndexed_Intrinsic; +def int_aarch64_sve_lasta : AdvSIMD_SVE_Reduce_Intrinsic; +def int_aarch64_sve_lastb : AdvSIMD_SVE_Reduce_Intrinsic; +def int_aarch64_sve_rev : AdvSIMD_1VectorArg_Intrinsic; +def int_aarch64_sve_splice : AdvSIMD_Pred2VectorArg_Intrinsic; +def int_aarch64_sve_sunpkhi : AdvSIMD_SVE_Unpack_Intrinsic; +def int_aarch64_sve_sunpklo : AdvSIMD_SVE_Unpack_Intrinsic; +def int_aarch64_sve_tbl : AdvSIMD_SVE_TBL_Intrinsic; +def int_aarch64_sve_trn1 : AdvSIMD_2VectorArg_Intrinsic; +def int_aarch64_sve_trn2 : AdvSIMD_2VectorArg_Intrinsic; +def int_aarch64_sve_uunpkhi : AdvSIMD_SVE_Unpack_Intrinsic; +def int_aarch64_sve_uunpklo : AdvSIMD_SVE_Unpack_Intrinsic; +def int_aarch64_sve_uzp1 : AdvSIMD_2VectorArg_Intrinsic; +def int_aarch64_sve_uzp2 : AdvSIMD_2VectorArg_Intrinsic; +def int_aarch64_sve_zip1 : AdvSIMD_2VectorArg_Intrinsic; +def int_aarch64_sve_zip2 : AdvSIMD_2VectorArg_Intrinsic; + +// +// Predicate creation +// + +def int_aarch64_sve_ptrue : AdvSIMD_SVE_PTRUE_Intrinsic; + +// +// Predicate operations +// + +def int_aarch64_sve_and_z : AdvSIMD_Pred2VectorArg_Intrinsic; +def int_aarch64_sve_bic_z : AdvSIMD_Pred2VectorArg_Intrinsic; +def int_aarch64_sve_brka : AdvSIMD_Merged1VectorArg_Intrinsic; +def int_aarch64_sve_brka_z : AdvSIMD_Pred1VectorArg_Intrinsic; +def int_aarch64_sve_brkb : AdvSIMD_Merged1VectorArg_Intrinsic; +def int_aarch64_sve_brkb_z : AdvSIMD_Pred1VectorArg_Intrinsic; +def int_aarch64_sve_brkn_z : AdvSIMD_Pred2VectorArg_Intrinsic; +def int_aarch64_sve_brkpa_z : AdvSIMD_Pred2VectorArg_Intrinsic; +def int_aarch64_sve_brkpb_z : AdvSIMD_Pred2VectorArg_Intrinsic; +def int_aarch64_sve_eor_z : AdvSIMD_Pred2VectorArg_Intrinsic; +def int_aarch64_sve_nand_z : AdvSIMD_Pred2VectorArg_Intrinsic; +def int_aarch64_sve_nor_z : AdvSIMD_Pred2VectorArg_Intrinsic; +def int_aarch64_sve_orn_z : AdvSIMD_Pred2VectorArg_Intrinsic; +def int_aarch64_sve_orr_z : AdvSIMD_Pred2VectorArg_Intrinsic; +def int_aarch64_sve_pfirst : AdvSIMD_Pred1VectorArg_Intrinsic; +def int_aarch64_sve_pnext : AdvSIMD_Pred1VectorArg_Intrinsic; +def int_aarch64_sve_punpkhi : AdvSIMD_SVE_PUNPKHI_Intrinsic; +def int_aarch64_sve_punpklo : AdvSIMD_SVE_PUNPKHI_Intrinsic; + +// +// Testing predicates +// + +def int_aarch64_sve_ptest_any : AdvSIMD_SVE_PTEST_Intrinsic; +def int_aarch64_sve_ptest_first : AdvSIMD_SVE_PTEST_Intrinsic; +def int_aarch64_sve_ptest_last : AdvSIMD_SVE_PTEST_Intrinsic; + +// +// FFR manipulation +// + +def int_aarch64_sve_rdffr : GCCBuiltin<"__builtin_sve_svrdffr">, Intrinsic<[llvm_nxv16i1_ty], []>; +def int_aarch64_sve_rdffr_z : GCCBuiltin<"__builtin_sve_svrdffr_z">, Intrinsic<[llvm_nxv16i1_ty], [llvm_nxv16i1_ty]>; +def int_aarch64_sve_setffr : GCCBuiltin<"__builtin_sve_svsetffr">, Intrinsic<[], []>; +def int_aarch64_sve_wrffr : GCCBuiltin<"__builtin_sve_svwrffr">, Intrinsic<[], [llvm_nxv16i1_ty]>; + +// +// Counting elements +// + +def int_aarch64_sve_cntb : AdvSIMD_SVE_CNTB_Intrinsic; +def int_aarch64_sve_cnth : AdvSIMD_SVE_CNTB_Intrinsic; +def int_aarch64_sve_cntw : AdvSIMD_SVE_CNTB_Intrinsic; +def int_aarch64_sve_cntd : AdvSIMD_SVE_CNTB_Intrinsic; + +def 
int_aarch64_sve_cntp : AdvSIMD_SVE_CNTP_Intrinsic;
+
+//
+// Saturating scalar arithmetic
+//
+
+def int_aarch64_sve_sqdech : AdvSIMD_SVE_SaturatingWithPattern_Intrinsic;
+def int_aarch64_sve_sqdecw : AdvSIMD_SVE_SaturatingWithPattern_Intrinsic;
+def int_aarch64_sve_sqdecd : AdvSIMD_SVE_SaturatingWithPattern_Intrinsic;
+def int_aarch64_sve_sqdecp : AdvSIMD_SVE_Saturating_Intrinsic;
+
+def int_aarch64_sve_sqdecb_n32 : AdvSIMD_SVE_SaturatingWithPattern_N_Intrinsic;
+def int_aarch64_sve_sqdecb_n64 : AdvSIMD_SVE_SaturatingWithPattern_N_Intrinsic;
+def int_aarch64_sve_sqdech_n32 : AdvSIMD_SVE_SaturatingWithPattern_N_Intrinsic;
+def int_aarch64_sve_sqdech_n64 : AdvSIMD_SVE_SaturatingWithPattern_N_Intrinsic;
+def int_aarch64_sve_sqdecw_n32 : AdvSIMD_SVE_SaturatingWithPattern_N_Intrinsic;
+def int_aarch64_sve_sqdecw_n64 : AdvSIMD_SVE_SaturatingWithPattern_N_Intrinsic;
+def int_aarch64_sve_sqdecd_n32 : AdvSIMD_SVE_SaturatingWithPattern_N_Intrinsic;
+def int_aarch64_sve_sqdecd_n64 : AdvSIMD_SVE_SaturatingWithPattern_N_Intrinsic;
+def int_aarch64_sve_sqdecp_n32 : AdvSIMD_SVE_Saturating_N_Intrinsic;
+def int_aarch64_sve_sqdecp_n64 : AdvSIMD_SVE_Saturating_N_Intrinsic;
+
+def int_aarch64_sve_sqinch : AdvSIMD_SVE_SaturatingWithPattern_Intrinsic;
+def int_aarch64_sve_sqincw : AdvSIMD_SVE_SaturatingWithPattern_Intrinsic;
+def int_aarch64_sve_sqincd : AdvSIMD_SVE_SaturatingWithPattern_Intrinsic;
+def int_aarch64_sve_sqincp : AdvSIMD_SVE_Saturating_Intrinsic;
+
+def int_aarch64_sve_sqincb_n32 : AdvSIMD_SVE_SaturatingWithPattern_N_Intrinsic;
+def int_aarch64_sve_sqincb_n64 : AdvSIMD_SVE_SaturatingWithPattern_N_Intrinsic;
+def int_aarch64_sve_sqinch_n32 : AdvSIMD_SVE_SaturatingWithPattern_N_Intrinsic;
+def int_aarch64_sve_sqinch_n64 : AdvSIMD_SVE_SaturatingWithPattern_N_Intrinsic;
+def int_aarch64_sve_sqincw_n32 : AdvSIMD_SVE_SaturatingWithPattern_N_Intrinsic;
+def int_aarch64_sve_sqincw_n64 : AdvSIMD_SVE_SaturatingWithPattern_N_Intrinsic;
+def int_aarch64_sve_sqincd_n32 : AdvSIMD_SVE_SaturatingWithPattern_N_Intrinsic;
+def int_aarch64_sve_sqincd_n64 : AdvSIMD_SVE_SaturatingWithPattern_N_Intrinsic;
+def int_aarch64_sve_sqincp_n32 : AdvSIMD_SVE_Saturating_N_Intrinsic;
+def int_aarch64_sve_sqincp_n64 : AdvSIMD_SVE_Saturating_N_Intrinsic;
+
+def int_aarch64_sve_uqdech : AdvSIMD_SVE_SaturatingWithPattern_Intrinsic;
+def int_aarch64_sve_uqdecw : AdvSIMD_SVE_SaturatingWithPattern_Intrinsic;
+def int_aarch64_sve_uqdecd : AdvSIMD_SVE_SaturatingWithPattern_Intrinsic;
+def int_aarch64_sve_uqdecp : AdvSIMD_SVE_Saturating_Intrinsic;
+
+def int_aarch64_sve_uqdecb_n32 : AdvSIMD_SVE_SaturatingWithPattern_N_Intrinsic;
+def int_aarch64_sve_uqdecb_n64 : AdvSIMD_SVE_SaturatingWithPattern_N_Intrinsic;
+def int_aarch64_sve_uqdech_n32 : AdvSIMD_SVE_SaturatingWithPattern_N_Intrinsic;
+def int_aarch64_sve_uqdech_n64 : AdvSIMD_SVE_SaturatingWithPattern_N_Intrinsic;
+def int_aarch64_sve_uqdecw_n32 : AdvSIMD_SVE_SaturatingWithPattern_N_Intrinsic;
+def int_aarch64_sve_uqdecw_n64 : AdvSIMD_SVE_SaturatingWithPattern_N_Intrinsic;
+def int_aarch64_sve_uqdecd_n32 : AdvSIMD_SVE_SaturatingWithPattern_N_Intrinsic;
+def int_aarch64_sve_uqdecd_n64 : AdvSIMD_SVE_SaturatingWithPattern_N_Intrinsic;
+def int_aarch64_sve_uqdecp_n32 : AdvSIMD_SVE_Saturating_N_Intrinsic;
+def int_aarch64_sve_uqdecp_n64 : AdvSIMD_SVE_Saturating_N_Intrinsic;
+
+def int_aarch64_sve_uqinch : AdvSIMD_SVE_SaturatingWithPattern_Intrinsic;
+def int_aarch64_sve_uqincw : AdvSIMD_SVE_SaturatingWithPattern_Intrinsic;
+def int_aarch64_sve_uqincd :
AdvSIMD_SVE_SaturatingWithPattern_Intrinsic; +def int_aarch64_sve_uqincp : AdvSIMD_SVE_Saturating_Intrinsic; + +def int_aarch64_sve_uqincb_n32 : AdvSIMD_SVE_SaturatingWithPattern_N_Intrinsic; +def int_aarch64_sve_uqincb_n64 : AdvSIMD_SVE_SaturatingWithPattern_N_Intrinsic; +def int_aarch64_sve_uqinch_n32 : AdvSIMD_SVE_SaturatingWithPattern_N_Intrinsic; +def int_aarch64_sve_uqinch_n64 : AdvSIMD_SVE_SaturatingWithPattern_N_Intrinsic; +def int_aarch64_sve_uqincw_n32 : AdvSIMD_SVE_SaturatingWithPattern_N_Intrinsic; +def int_aarch64_sve_uqincw_n64 : AdvSIMD_SVE_SaturatingWithPattern_N_Intrinsic; +def int_aarch64_sve_uqincd_n32 : AdvSIMD_SVE_SaturatingWithPattern_N_Intrinsic; +def int_aarch64_sve_uqincd_n64 : AdvSIMD_SVE_SaturatingWithPattern_N_Intrinsic; +def int_aarch64_sve_uqincp_n32 : AdvSIMD_SVE_Saturating_N_Intrinsic; +def int_aarch64_sve_uqincp_n64 : AdvSIMD_SVE_Saturating_N_Intrinsic; + +// +// Reinterpreting data +// + +def int_aarch64_sve_reinterpret_bool_b : Intrinsic<[llvm_nxv16i1_ty], + [llvm_anyvector_ty], + [IntrNoMem]>; + +def int_aarch64_sve_reinterpret_bool_h : Intrinsic<[llvm_nxv8i1_ty], + [llvm_anyvector_ty], + [IntrNoMem]>; + +def int_aarch64_sve_reinterpret_bool_w : Intrinsic<[llvm_nxv4i1_ty], + [llvm_anyvector_ty], + [IntrNoMem]>; + +def int_aarch64_sve_reinterpret_bool_d : Intrinsic<[llvm_nxv2i1_ty], + [llvm_anyvector_ty], + [IntrNoMem]>; + +// +// SVE2 - Uniform DSP operations +// + +def int_aarch64_sve_saba : AdvSIMD_3VectorArg_Intrinsic; +def int_aarch64_sve_shadd : AdvSIMD_Pred2VectorArg_Intrinsic; +def int_aarch64_sve_shsub : AdvSIMD_Pred2VectorArg_Intrinsic; +def int_aarch64_sve_shsubr : AdvSIMD_Pred2VectorArg_Intrinsic; +def int_aarch64_sve_sli : AdvSIMD_2VectorArgIndexed_Intrinsic; +def int_aarch64_sve_sqabs : AdvSIMD_Merged1VectorArg_Intrinsic; +def int_aarch64_sve_sqadd : AdvSIMD_Pred2VectorArg_Intrinsic; +def int_aarch64_sve_sqdmulh : AdvSIMD_2VectorArg_Intrinsic; +def int_aarch64_sve_sqdmulh_lane : AdvSIMD_2VectorArgIndexed_Intrinsic; +def int_aarch64_sve_sqneg : AdvSIMD_Merged1VectorArg_Intrinsic; +def int_aarch64_sve_sqrdmlah : AdvSIMD_3VectorArg_Intrinsic; +def int_aarch64_sve_sqrdmlah_lane : AdvSIMD_3VectorArgIndexed_Intrinsic; +def int_aarch64_sve_sqrdmlsh : AdvSIMD_3VectorArg_Intrinsic; +def int_aarch64_sve_sqrdmlsh_lane : AdvSIMD_3VectorArgIndexed_Intrinsic; +def int_aarch64_sve_sqrdmulh : AdvSIMD_2VectorArg_Intrinsic; +def int_aarch64_sve_sqrdmulh_lane : AdvSIMD_2VectorArgIndexed_Intrinsic; +def int_aarch64_sve_sqrshl : AdvSIMD_Pred2VectorArg_Intrinsic; +def int_aarch64_sve_sqshl : AdvSIMD_Pred2VectorArg_Intrinsic; +def int_aarch64_sve_sqshlu : AdvSIMD_SVE_ShiftByImm_Intrinsic; +def int_aarch64_sve_sqsub : AdvSIMD_Pred2VectorArg_Intrinsic; +def int_aarch64_sve_sqsubr : AdvSIMD_Pred2VectorArg_Intrinsic; +def int_aarch64_sve_srhadd : AdvSIMD_Pred2VectorArg_Intrinsic; +def int_aarch64_sve_sri : AdvSIMD_2VectorArgIndexed_Intrinsic; +def int_aarch64_sve_srshl : AdvSIMD_Pred2VectorArg_Intrinsic; +def int_aarch64_sve_srshr : AdvSIMD_SVE_ShiftByImm_Intrinsic; +def int_aarch64_sve_srsra : AdvSIMD_2VectorArgIndexed_Intrinsic; +def int_aarch64_sve_ssra : AdvSIMD_2VectorArgIndexed_Intrinsic; +def int_aarch64_sve_suqadd : AdvSIMD_Pred2VectorArg_Intrinsic; +def int_aarch64_sve_uaba : AdvSIMD_3VectorArg_Intrinsic; +def int_aarch64_sve_uhadd : AdvSIMD_Pred2VectorArg_Intrinsic; +def int_aarch64_sve_uhsub : AdvSIMD_Pred2VectorArg_Intrinsic; +def int_aarch64_sve_uhsubr : AdvSIMD_Pred2VectorArg_Intrinsic; +def int_aarch64_sve_uqadd : AdvSIMD_Pred2VectorArg_Intrinsic; +def 
int_aarch64_sve_uqrshl : AdvSIMD_Pred2VectorArg_Intrinsic; +def int_aarch64_sve_uqshl : AdvSIMD_Pred2VectorArg_Intrinsic; +def int_aarch64_sve_uqsub : AdvSIMD_Pred2VectorArg_Intrinsic; +def int_aarch64_sve_uqsubr : AdvSIMD_Pred2VectorArg_Intrinsic; +def int_aarch64_sve_urecpe : AdvSIMD_Merged1VectorArg_Intrinsic; +def int_aarch64_sve_urhadd : AdvSIMD_Pred2VectorArg_Intrinsic; +def int_aarch64_sve_urshl : AdvSIMD_Pred2VectorArg_Intrinsic; +def int_aarch64_sve_urshr : AdvSIMD_SVE_ShiftByImm_Intrinsic; +def int_aarch64_sve_ursqrte : AdvSIMD_Merged1VectorArg_Intrinsic; +def int_aarch64_sve_ursra : AdvSIMD_2VectorArgIndexed_Intrinsic; +def int_aarch64_sve_usqadd : AdvSIMD_Pred2VectorArg_Intrinsic; +def int_aarch64_sve_usra : AdvSIMD_2VectorArgIndexed_Intrinsic; + +// +// SVE2 - Widening DSP operations +// + +def int_aarch64_sve_sabalb : SVE2_3VectorArg_Long_Intrinsic; +def int_aarch64_sve_sabalt : SVE2_3VectorArg_Long_Intrinsic; +def int_aarch64_sve_sabdlb : SVE2_2VectorArg_Long_Intrinsic; +def int_aarch64_sve_sabdlt : SVE2_2VectorArg_Long_Intrinsic; +def int_aarch64_sve_saddlb : SVE2_2VectorArg_Long_Intrinsic; +def int_aarch64_sve_saddlt : SVE2_2VectorArg_Long_Intrinsic; +def int_aarch64_sve_saddwb : SVE2_2VectorArg_Wide_Intrinsic; +def int_aarch64_sve_saddwt : SVE2_2VectorArg_Wide_Intrinsic; +def int_aarch64_sve_smlalb : SVE2_3VectorArg_Long_Intrinsic; +def int_aarch64_sve_smlalb_lane : SVE2_3VectorArgIndexed_Long_Intrinsic; +def int_aarch64_sve_smlalt : SVE2_3VectorArg_Long_Intrinsic; +def int_aarch64_sve_smlalt_lane : SVE2_3VectorArgIndexed_Long_Intrinsic; +def int_aarch64_sve_smlslb : SVE2_3VectorArg_Long_Intrinsic; +def int_aarch64_sve_smlslb_lane : SVE2_3VectorArgIndexed_Long_Intrinsic; +def int_aarch64_sve_smlslt : SVE2_3VectorArg_Long_Intrinsic; +def int_aarch64_sve_smlslt_lane : SVE2_3VectorArgIndexed_Long_Intrinsic; +def int_aarch64_sve_smullb : SVE2_2VectorArg_Long_Intrinsic; +def int_aarch64_sve_smullb_lane : SVE2_2VectorArgIndexed_Long_Intrinsic; +def int_aarch64_sve_smullt : SVE2_2VectorArg_Long_Intrinsic; +def int_aarch64_sve_smullt_lane : SVE2_2VectorArgIndexed_Long_Intrinsic; +def int_aarch64_sve_sqdmlalb : SVE2_3VectorArg_Long_Intrinsic; +def int_aarch64_sve_sqdmlalb_lane : SVE2_3VectorArgIndexed_Long_Intrinsic; +def int_aarch64_sve_sqdmlalt : SVE2_3VectorArg_Long_Intrinsic; +def int_aarch64_sve_sqdmlalt_lane : SVE2_3VectorArgIndexed_Long_Intrinsic; +def int_aarch64_sve_sqdmlslb : SVE2_3VectorArg_Long_Intrinsic; +def int_aarch64_sve_sqdmlslb_lane : SVE2_3VectorArgIndexed_Long_Intrinsic; +def int_aarch64_sve_sqdmlslt : SVE2_3VectorArg_Long_Intrinsic; +def int_aarch64_sve_sqdmlslt_lane : SVE2_3VectorArgIndexed_Long_Intrinsic; +def int_aarch64_sve_sqdmullb : SVE2_2VectorArg_Long_Intrinsic; +def int_aarch64_sve_sqdmullb_lane : SVE2_2VectorArgIndexed_Long_Intrinsic; +def int_aarch64_sve_sqdmullt : SVE2_2VectorArg_Long_Intrinsic; +def int_aarch64_sve_sqdmullt_lane : SVE2_2VectorArgIndexed_Long_Intrinsic; +def int_aarch64_sve_sshllb : SVE2_1VectorArg_Long_Intrinsic; +def int_aarch64_sve_sshllt : SVE2_1VectorArg_Long_Intrinsic; +def int_aarch64_sve_ssublb : SVE2_2VectorArg_Long_Intrinsic; +def int_aarch64_sve_ssublt : SVE2_2VectorArg_Long_Intrinsic; +def int_aarch64_sve_ssubwb : SVE2_2VectorArg_Wide_Intrinsic; +def int_aarch64_sve_ssubwt : SVE2_2VectorArg_Wide_Intrinsic; +def int_aarch64_sve_uabalb : SVE2_3VectorArg_Long_Intrinsic; +def int_aarch64_sve_uabalt : SVE2_3VectorArg_Long_Intrinsic; +def int_aarch64_sve_uabdlb : SVE2_2VectorArg_Long_Intrinsic; +def int_aarch64_sve_uabdlt : 
SVE2_2VectorArg_Long_Intrinsic; +def int_aarch64_sve_uaddlb : SVE2_2VectorArg_Long_Intrinsic; +def int_aarch64_sve_uaddlt : SVE2_2VectorArg_Long_Intrinsic; +def int_aarch64_sve_uaddwb : SVE2_2VectorArg_Wide_Intrinsic; +def int_aarch64_sve_uaddwt : SVE2_2VectorArg_Wide_Intrinsic; +def int_aarch64_sve_umlalb : SVE2_3VectorArg_Long_Intrinsic; +def int_aarch64_sve_umlalb_lane : SVE2_3VectorArgIndexed_Long_Intrinsic; +def int_aarch64_sve_umlalt : SVE2_3VectorArg_Long_Intrinsic; +def int_aarch64_sve_umlalt_lane : SVE2_3VectorArgIndexed_Long_Intrinsic; +def int_aarch64_sve_umlslb : SVE2_3VectorArg_Long_Intrinsic; +def int_aarch64_sve_umlslb_lane : SVE2_3VectorArgIndexed_Long_Intrinsic; +def int_aarch64_sve_umlslt : SVE2_3VectorArg_Long_Intrinsic; +def int_aarch64_sve_umlslt_lane : SVE2_3VectorArgIndexed_Long_Intrinsic; +def int_aarch64_sve_umullb : SVE2_2VectorArg_Long_Intrinsic; +def int_aarch64_sve_umullb_lane : SVE2_2VectorArgIndexed_Long_Intrinsic; +def int_aarch64_sve_umullt : SVE2_2VectorArg_Long_Intrinsic; +def int_aarch64_sve_umullt_lane : SVE2_2VectorArgIndexed_Long_Intrinsic; +def int_aarch64_sve_ushllb : SVE2_1VectorArg_Long_Intrinsic; +def int_aarch64_sve_ushllt : SVE2_1VectorArg_Long_Intrinsic; +def int_aarch64_sve_usublb : SVE2_2VectorArg_Long_Intrinsic; +def int_aarch64_sve_usublt : SVE2_2VectorArg_Long_Intrinsic; +def int_aarch64_sve_usubwb : SVE2_2VectorArg_Wide_Intrinsic; +def int_aarch64_sve_usubwt : SVE2_2VectorArg_Wide_Intrinsic; + +// +// SVE2 - Narrowing DSP operations +// + +def int_aarch64_sve_addhnb : SVE2_2VectorArg_Narrowing_Intrinsic; +def int_aarch64_sve_addhnt : SVE2_Merged2VectorArg_Narrowing_Intrinsic; +def int_aarch64_sve_raddhnb : SVE2_2VectorArg_Narrowing_Intrinsic; +def int_aarch64_sve_raddhnt : SVE2_Merged2VectorArg_Narrowing_Intrinsic; +def int_aarch64_sve_rshrnb : SVE2_1VectorArg_Imm_Narrowing_Intrinsic; +def int_aarch64_sve_rshrnt : SVE2_2VectorArg_Imm_Narrowing_Intrinsic; +def int_aarch64_sve_rsubhnb : SVE2_2VectorArg_Narrowing_Intrinsic; +def int_aarch64_sve_rsubhnt : SVE2_Merged2VectorArg_Narrowing_Intrinsic; +def int_aarch64_sve_shrnb : SVE2_1VectorArg_Imm_Narrowing_Intrinsic; +def int_aarch64_sve_shrnt : SVE2_2VectorArg_Imm_Narrowing_Intrinsic; +def int_aarch64_sve_sqrshrnb : SVE2_1VectorArg_Imm_Narrowing_Intrinsic; +def int_aarch64_sve_sqrshrnt : SVE2_2VectorArg_Imm_Narrowing_Intrinsic; +def int_aarch64_sve_sqrshrunb : SVE2_1VectorArg_Imm_Narrowing_Intrinsic; +def int_aarch64_sve_sqrshrunt : SVE2_2VectorArg_Imm_Narrowing_Intrinsic; +def int_aarch64_sve_sqshrnb : SVE2_1VectorArg_Imm_Narrowing_Intrinsic; +def int_aarch64_sve_sqshrnt : SVE2_2VectorArg_Imm_Narrowing_Intrinsic; +def int_aarch64_sve_sqshrunb : SVE2_1VectorArg_Imm_Narrowing_Intrinsic; +def int_aarch64_sve_sqshrunt : SVE2_2VectorArg_Imm_Narrowing_Intrinsic; +def int_aarch64_sve_subhnb : SVE2_2VectorArg_Narrowing_Intrinsic; +def int_aarch64_sve_subhnt : SVE2_Merged2VectorArg_Narrowing_Intrinsic; +def int_aarch64_sve_uqrshrnb : SVE2_1VectorArg_Imm_Narrowing_Intrinsic; +def int_aarch64_sve_uqrshrnt : SVE2_2VectorArg_Imm_Narrowing_Intrinsic; +def int_aarch64_sve_uqshrnb : SVE2_1VectorArg_Imm_Narrowing_Intrinsic; +def int_aarch64_sve_uqshrnt : SVE2_2VectorArg_Imm_Narrowing_Intrinsic; + +// +// SVE2 - Unary narrowing operations +// + +def int_aarch64_sve_sqxtnb : SVE2_1VectorArg_Narrowing_Intrinsic; +def int_aarch64_sve_sqxtnt : SVE2_Merged1VectorArg_Narrowing_Intrinsic; +def int_aarch64_sve_sqxtunb : SVE2_1VectorArg_Narrowing_Intrinsic; +def int_aarch64_sve_sqxtunt : 
SVE2_Merged1VectorArg_Narrowing_Intrinsic; +def int_aarch64_sve_uqxtnb : SVE2_1VectorArg_Narrowing_Intrinsic; +def int_aarch64_sve_uqxtnt : SVE2_Merged1VectorArg_Narrowing_Intrinsic; + +// +// SVE2 - Non-widening pairwise arithmetic +// + +def int_aarch64_sve_addp : AdvSIMD_Pred2VectorArg_Intrinsic; +def int_aarch64_sve_faddp : AdvSIMD_Pred2VectorArg_Intrinsic; +def int_aarch64_sve_fmaxp : AdvSIMD_Pred2VectorArg_Intrinsic; +def int_aarch64_sve_fmaxnmp : AdvSIMD_Pred2VectorArg_Intrinsic; +def int_aarch64_sve_fminp : AdvSIMD_Pred2VectorArg_Intrinsic; +def int_aarch64_sve_fminnmp : AdvSIMD_Pred2VectorArg_Intrinsic; +def int_aarch64_sve_smaxp : AdvSIMD_Pred2VectorArg_Intrinsic; +def int_aarch64_sve_sminp : AdvSIMD_Pred2VectorArg_Intrinsic; +def int_aarch64_sve_umaxp : AdvSIMD_Pred2VectorArg_Intrinsic; +def int_aarch64_sve_uminp : AdvSIMD_Pred2VectorArg_Intrinsic; + +// +// SVE2 - Widening pairwise arithmetic +// + +def int_aarch64_sve_sadalp : SVE2_2VectorArg_Pred_Long_Intrinsic; +def int_aarch64_sve_uadalp : SVE2_2VectorArg_Pred_Long_Intrinsic; + +// +// SVE2 - Bitwise ternary logical instructions +// + +def int_aarch64_sve_bcax : AdvSIMD_3VectorArg_Intrinsic; +def int_aarch64_sve_bsl : AdvSIMD_3VectorArg_Intrinsic; +def int_aarch64_sve_bsl1n : AdvSIMD_3VectorArg_Intrinsic; +def int_aarch64_sve_bsl2n : AdvSIMD_3VectorArg_Intrinsic; +def int_aarch64_sve_eor3 : AdvSIMD_3VectorArg_Intrinsic; +def int_aarch64_sve_nbsl : AdvSIMD_3VectorArg_Intrinsic; +def int_aarch64_sve_xar : AdvSIMD_2VectorArgIndexed_Intrinsic; + +// +// SVE2 - Large integer arithmetic +// + +def int_aarch64_sve_adclb : AdvSIMD_3VectorArg_Intrinsic; +def int_aarch64_sve_adclt : AdvSIMD_3VectorArg_Intrinsic; +def int_aarch64_sve_sbclb : AdvSIMD_3VectorArg_Intrinsic; +def int_aarch64_sve_sbclt : AdvSIMD_3VectorArg_Intrinsic; + +// +// SVE2 - Uniform complex integer arithmetic +// + +def int_aarch64_sve_cadd_x : SVE2_CADD_Intrinsic; +def int_aarch64_sve_sqcadd_x : SVE2_CADD_Intrinsic; +def int_aarch64_sve_cmla_x : AdvSIMD_3VectorArgIndexed_Intrinsic; +def int_aarch64_sve_cmla_lane_x : AdvSIMD_SVE_CMLA_LANE_Intrinsic; +def int_aarch64_sve_sqrdcmlah_x : AdvSIMD_3VectorArgIndexed_Intrinsic; +def int_aarch64_sve_sqrdcmlah_lane_x : AdvSIMD_SVE_CMLA_LANE_Intrinsic; + +// +// SVE2 - Widening complex integer arithmetic +// + +def int_aarch64_sve_saddlbt : SVE2_2VectorArg_Long_Intrinsic; +def int_aarch64_sve_sqdmlalbt : SVE2_3VectorArg_Long_Intrinsic; +def int_aarch64_sve_sqdmlslbt : SVE2_3VectorArg_Long_Intrinsic; +def int_aarch64_sve_ssublbt : SVE2_2VectorArg_Long_Intrinsic; +def int_aarch64_sve_ssubltb : SVE2_2VectorArg_Long_Intrinsic; + +// +// SVE2 - Widening complex integer dot product +// + +def int_aarch64_sve_cdot : AdvSIMD_SVE_CDOT_Intrinsic; +def int_aarch64_sve_cdot_lane : AdvSIMD_SVE_CDOT_LANE_Intrinsic; + +// +// SVE2 - Floating-point widening multiply-accumulate +// + +def int_aarch64_sve_fmlalb : SVE2_3VectorArg_Long_Intrinsic; +def int_aarch64_sve_fmlalb_lane : SVE2_3VectorArgIndexed_Long_Intrinsic; +def int_aarch64_sve_fmlalt : SVE2_3VectorArg_Long_Intrinsic; +def int_aarch64_sve_fmlalt_lane : SVE2_3VectorArgIndexed_Long_Intrinsic; +def int_aarch64_sve_fmlslb : SVE2_3VectorArg_Long_Intrinsic; +def int_aarch64_sve_fmlslb_lane : SVE2_3VectorArgIndexed_Long_Intrinsic; +def int_aarch64_sve_fmlslt : SVE2_3VectorArg_Long_Intrinsic; +def int_aarch64_sve_fmlslt_lane : SVE2_3VectorArgIndexed_Long_Intrinsic; + +// +// SVE2 - Floating-point integer binary logarithm +// + +def int_aarch64_sve_flogb : AdvSIMD_SVE_LOGB_Intrinsic; + +// 
+// SVE2 - Vector histogram count +// + +def int_aarch64_sve_histcnt : AdvSIMD_Pred2VectorArg_Intrinsic; +def int_aarch64_sve_histseg : AdvSIMD_2VectorArg_Intrinsic; + +// +// SVE2 - Character match +// + +def int_aarch64_sve_match : AdvSIMD_SVE_Compare_Intrinsic; +def int_aarch64_sve_nmatch : AdvSIMD_SVE_Compare_Intrinsic; + + +// +// SVE2 - Contiguous conflict detection +// + +def int_aarch64_sve_whilerw_b : SVE2_CONFLICT_DETECT_Intrinsic; +def int_aarch64_sve_whilerw_h : SVE2_CONFLICT_DETECT_Intrinsic; +def int_aarch64_sve_whilerw_s : SVE2_CONFLICT_DETECT_Intrinsic; +def int_aarch64_sve_whilerw_d : SVE2_CONFLICT_DETECT_Intrinsic; +def int_aarch64_sve_whilewr_b : SVE2_CONFLICT_DETECT_Intrinsic; +def int_aarch64_sve_whilewr_h : SVE2_CONFLICT_DETECT_Intrinsic; +def int_aarch64_sve_whilewr_s : SVE2_CONFLICT_DETECT_Intrinsic; +def int_aarch64_sve_whilewr_d : SVE2_CONFLICT_DETECT_Intrinsic; + +// +// SVE2 - Polynomial arithmetic +// + +def int_aarch64_sve_eorbt : AdvSIMD_3VectorArg_Intrinsic; +def int_aarch64_sve_eortb : AdvSIMD_3VectorArg_Intrinsic; +def int_aarch64_sve_pmul : AdvSIMD_2VectorArg_Intrinsic; +def int_aarch64_sve_pmullb_pair : AdvSIMD_2VectorArg_Intrinsic; +def int_aarch64_sve_pmullt_pair : AdvSIMD_2VectorArg_Intrinsic; + +// +// SVE2 - Extended table lookup/permute +// + +def int_aarch64_sve_tbl2 : SVE2_TBX_Intrinsic; +def int_aarch64_sve_tbx : SVE2_TBX_Intrinsic; + +// +// SVE2 - Optional bit permutation +// + +def int_aarch64_sve_bdep_x : AdvSIMD_2VectorArg_Intrinsic; +def int_aarch64_sve_bext_x : AdvSIMD_2VectorArg_Intrinsic; +def int_aarch64_sve_bgrp_x : AdvSIMD_2VectorArg_Intrinsic; + +// +// SVE2 - Optional AES, SHA-3 and SM4 +// + +def int_aarch64_sve_aesd : GCCBuiltin<"__builtin_sve_svaesd_u8">, + Intrinsic<[llvm_nxv16i8_ty], + [llvm_nxv16i8_ty, llvm_nxv16i8_ty], + [IntrNoMem]>; +def int_aarch64_sve_aesimc : GCCBuiltin<"__builtin_sve_svaesimc_u8">, + Intrinsic<[llvm_nxv16i8_ty], + [llvm_nxv16i8_ty], + [IntrNoMem]>; +def int_aarch64_sve_aese : GCCBuiltin<"__builtin_sve_svaese_u8">, + Intrinsic<[llvm_nxv16i8_ty], + [llvm_nxv16i8_ty, llvm_nxv16i8_ty], + [IntrNoMem]>; +def int_aarch64_sve_aesmc : GCCBuiltin<"__builtin_sve_svaesmc_u8">, + Intrinsic<[llvm_nxv16i8_ty], + [llvm_nxv16i8_ty], + [IntrNoMem]>; +def int_aarch64_sve_rax1 : Intrinsic<[llvm_anyvector_ty], + [LLVMMatchType<0>, LLVMMatchType<0>], + [IntrNoMem]>; +def int_aarch64_sve_sm4e : GCCBuiltin<"__builtin_sve_svsm4e_u32">, + Intrinsic<[llvm_nxv4i32_ty], + [llvm_nxv4i32_ty, llvm_nxv4i32_ty], + [IntrNoMem]>; +def int_aarch64_sve_sm4ekey : GCCBuiltin<"__builtin_sve_svsm4ekey_u32">, + Intrinsic<[llvm_nxv4i32_ty], + [llvm_nxv4i32_ty, llvm_nxv4i32_ty], + [IntrNoMem]>; +} Index: include/llvm/IR/Metadata.def =================================================================== --- include/llvm/IR/Metadata.def +++ include/llvm/IR/Metadata.def @@ -114,6 +114,10 @@ HANDLE_SPECIALIZED_MDNODE_BRANCH(DIMacroNode) HANDLE_SPECIALIZED_MDNODE_LEAF_UNIQUABLE(DIMacro) HANDLE_SPECIALIZED_MDNODE_LEAF_UNIQUABLE(DIMacroFile) +HANDLE_SPECIALIZED_MDNODE_LEAF_UNIQUABLE(DIStringType) +HANDLE_SPECIALIZED_MDNODE_LEAF_UNIQUABLE(DIFortranArrayType) +HANDLE_SPECIALIZED_MDNODE_LEAF_UNIQUABLE(DIFortranSubrange) +HANDLE_SPECIALIZED_MDNODE_LEAF_UNIQUABLE(DICommonBlock) #undef HANDLE_METADATA #undef HANDLE_METADATA_LEAF Index: include/llvm/IR/Operator.h =================================================================== --- include/llvm/IR/Operator.h +++ include/llvm/IR/Operator.h @@ -233,6 +233,9 @@ void operator&=(const FastMathFlags &OtherFlags) { 
Flags &= OtherFlags.Flags; } + bool operator!=(const FastMathFlags &OtherFlags) const { + return OtherFlags.Flags != Flags; + } }; /// Utility class for floating point operations which can have Index: include/llvm/IR/PatternMatch.h =================================================================== --- include/llvm/IR/PatternMatch.h +++ include/llvm/IR/PatternMatch.h @@ -37,6 +37,7 @@ #include "llvm/IR/InstrTypes.h" #include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Intrinsics.h" #include "llvm/IR/Operator.h" #include "llvm/IR/Value.h" @@ -90,6 +91,10 @@ /// Match an arbitrary Constant and ignore it. inline class_match m_Constant() { return class_match(); } +inline class_match m_ConstantExpr() { + return class_match(); +} + /// Matching combinators template struct match_combine_or { LTy L; @@ -483,9 +488,15 @@ /// Match a Constant, capturing the value if we match. inline bind_ty m_Constant(Constant *&C) { return C; } +/// Match a ConstantExpr, capturing the value if we match. +inline bind_ty m_ConstantExpr(ConstantExpr *&C) { return C; } + /// Match a ConstantFP, capturing the value if we match. inline bind_ty m_ConstantFP(ConstantFP *&C) { return C; } +/// Match a PHINode, capturing it if we match. +inline bind_ty m_PHI(PHINode *&PHI) { return PHI; } + /// Match a specified Value*. struct specificval_ty { const Value *Val; @@ -972,6 +983,33 @@ } }; +template +struct CmpOp_match { + PredicateTy &Predicate; + LHS_t L; + RHS_t R; + + CmpOp_match(PredicateTy &Pred, const LHS_t &LHS, const RHS_t &RHS) + : Predicate(Pred), L(LHS), R(RHS) {} + + template bool match(OpTy *V) { + if (V->getValueID() == Value::InstructionVal + Opcode) + if (auto *I = dyn_cast(V)) + if (L.match(I->getOperand(0)) && R.match(I->getOperand(1))) { + Predicate = I->getPredicate(); + return true; + } + if (auto *CE = dyn_cast(V)) + if ((CE->getOpcode() == Opcode) && + L.match(CE->getOperand(0)) && + R.match(CE->getOperand(1))) { + Predicate = PredicateTy(CE->getPredicate()); + return true; + } + return false; + } +}; + template inline CmpClass_match m_Cmp(CmpInst::Predicate &Pred, const LHS &L, const RHS &R) { @@ -979,15 +1017,15 @@ } template -inline CmpClass_match +inline CmpOp_match m_ICmp(ICmpInst::Predicate &Pred, const LHS &L, const RHS &R) { - return CmpClass_match(Pred, L, R); + return CmpOp_match(Pred, L, R); } template -inline CmpClass_match +inline CmpOp_match m_FCmp(FCmpInst::Predicate &Pred, const LHS &L, const RHS &R) { - return CmpClass_match(Pred, L, R); + return CmpOp_match(Pred, L, R); } //===----------------------------------------------------------------------===// @@ -1026,33 +1064,6 @@ } //===----------------------------------------------------------------------===// -// Matchers for InsertElementInst classes -// - -template -struct InsertElementClass_match { - Val_t V; - Elt_t E; - Idx_t I; - - InsertElementClass_match(const Val_t &Val, const Elt_t &Elt, const Idx_t &Idx) - : V(Val), E(Elt), I(Idx) {} - - template bool match(OpTy *VV) { - if (auto *II = dyn_cast(VV)) - return V.match(II->getOperand(0)) && E.match(II->getOperand(1)) && - I.match(II->getOperand(2)); - return false; - } -}; - -template -inline InsertElementClass_match -m_InsertElement(const Val_t &Val, const Elt_t &Elt, const Idx_t &Idx) { - return InsertElementClass_match(Val, Elt, Idx); -} - -//===----------------------------------------------------------------------===// // Matchers for ExtractElementInst classes // @@ -1077,33 +1088,6 @@ } 
//===----------------------------------------------------------------------===// -// Matchers for ShuffleVectorInst classes -// - -template -struct ShuffleVectorClass_match { - V1_t V1; - V2_t V2; - Mask_t M; - - ShuffleVectorClass_match(const V1_t &v1, const V2_t &v2, const Mask_t &m) - : V1(v1), V2(v2), M(m) {} - - template bool match(OpTy *V) { - if (auto *SI = dyn_cast(V)) - return V1.match(SI->getOperand(0)) && V2.match(SI->getOperand(1)) && - M.match(SI->getOperand(2)); - return false; - } -}; - -template -inline ShuffleVectorClass_match -m_ShuffleVector(const V1_t &v1, const V2_t &v2, const Mask_t &m) { - return ShuffleVectorClass_match(v1, v2, m); -} - -//===----------------------------------------------------------------------===// // Matchers for CastInst classes // @@ -1125,6 +1109,12 @@ return CastClass_match(Op); } +/// Matches IntToPtr. +template +inline CastClass_match m_IntToPtr(const OpTy &Op) { + return CastClass_match(Op); +} + /// Matches PtrToInt. template inline CastClass_match m_PtrToInt(const OpTy &Op) { @@ -1168,6 +1158,27 @@ return CastClass_match(Op); } +/// Matching combinators +template struct match_any_ext_or_none { + OpTy Op; + + match_any_ext_or_none(const OpTy &_Op) : Op(_Op) {} + + template bool match(ITy *V) { + if (CastClass_match(Op).match(V)) + return true; + if (CastClass_match(Op).match(V)) + return true; + return Op.match(V); + } +}; + +/// Match (V) +template +inline match_any_ext_or_none m_AnyExtOrNone(const OpTy &Op) { + return match_any_ext_or_none(Op); +} + /// Matches FPTrunc template inline CastClass_match m_FPTrunc(const OpTy &Op) { @@ -1184,6 +1195,8 @@ // Matcher for LoadInst classes // +/* + * MERGE we use our own impl template struct LoadClass_match { Op_t Op; @@ -1227,6 +1240,7 @@ m_Store(const ValueOpTy &ValueOp, const PointerOpTy &PointerOp) { return StoreClass_match(ValueOp, PointerOp); } +*/ //===----------------------------------------------------------------------===// // Matchers for control flow. 
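Because m_ICmp and m_FCmp now return the CmpOp_match defined earlier in this hunk, the same pattern covers both compare instructions and equivalent compare constant expressions. A minimal caller-side sketch, illustration only and not part of the patch, assuming the patched matcher semantics:

#include "llvm/IR/PatternMatch.h"
using namespace llvm;
using namespace llvm::PatternMatch;

// Returns true for `icmp slt A, B`, whether V is an ICmpInst or an
// equivalent icmp ConstantExpr (the extra case added by CmpOp_match).
static bool isSignedLessThan(Value *V) {
  ICmpInst::Predicate Pred;
  Value *A, *B;
  return match(V, m_ICmp(Pred, m_Value(A), m_Value(B))) &&
         Pred == ICmpInst::ICMP_SLT;
}

Existing users of m_ICmp/m_FCmp keep compiling unchanged; only the set of values they accept grows.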
@@ -1624,6 +1638,230 @@ return m_Intrinsic(Op0, Op1); } +template +struct AnyOneOp_match { + T0 Op0; + + AnyOneOp_match(const T0 &Op0) + : Op0(Op0) {} + + template bool match(OpTy *V) { + if (V->getValueID() == Value::InstructionVal + Opcode) { + auto *I = cast(V); + return Op0.match(I->getOperand(0)); + } + if (auto *CE = dyn_cast(V)) + return (CE->getOpcode() == Opcode) && Op0.match(CE->getOperand(0)); + return false; + } +}; + +template +struct AnyTwoOp_match { + T0 Op0; + T1 Op1; + + AnyTwoOp_match(const T0 &Op0, const T1 &Op1) + : Op0(Op0), Op1(Op1) {} + + template bool match(OpTy *V) { + if (V->getValueID() == Value::InstructionVal + Opcode) { + auto *I = dyn_cast(V); + return Op0.match(I->getOperand(0)) && Op1.match(I->getOperand(1)); + } + if (auto *CE = dyn_cast(V)) + return (CE->getOpcode() == Opcode) && + Op0.match(CE->getOperand(0)) && Op1.match(CE->getOperand(1)); + return false; + } +}; + +template +struct AnyThreeOp_match { + T0 Op1; + T1 Op2; + T2 Op3; + + AnyThreeOp_match(const T0 &Op1, const T1 &Op2, const T2 &Op3) + : Op1(Op1), Op2(Op2), Op3(Op3) {} + + template bool match(OpTy *V) { + if (V->getValueID() == Value::InstructionVal + Opcode) { + auto *I = dyn_cast(V); + return Op1.match(I->getOperand(0)) && + Op2.match(I->getOperand(1)) && + Op3.match(I->getOperand(2)); + } + if (auto *CE = dyn_cast(V)) + return (CE->getOpcode() == Opcode) && + Op1.match(CE->getOperand(0)) && + Op2.match(CE->getOperand(1)) && + Op3.match(CE->getOperand(2)); + return false; + } +}; + +template +inline AnyThreeOp_match +m_ShuffleVector(const T0 &Op0, const T1 &Op1, const T2 &Op2) { + return AnyThreeOp_match(Op0,Op1,Op2); +} + +/// \brief Match an arbitrary stepvector constant. +inline class_match m_StepVector() { + return class_match(); +} + +template +inline AnyThreeOp_match +m_InsertElement(const T0 &Op0, const T1 &Op1, const T2 &Op2) { + return AnyThreeOp_match(Op0,Op1,Op2); +} + +#define m_SplatVector(X) \ + m_ShuffleVector( \ + m_InsertElement(m_Undef(), X, m_Zero()), \ + m_Value(), \ + m_Zero()) \ + + +/// \brief Match an arbitrary vscale constant. +inline class_match m_VScale() { + return class_match(); +} + +/// \brief Match the expected idioms used to model the original seriesvector +/// instruction. 
+template struct SeriesVector_match { + T0 Start; + T1 Step; + SeriesVector_match(const T0 &Op1, const T1 &Op2) : Start(Op1), Step(Op2) {} + + template bool match(OpTy *V) { + auto Ty = dyn_cast(V->getType()); + if (!Ty || !Ty->getElementType()->isIntegerTy()) + return false; + + Value *X, *Y; + auto One = ConstantInt::get(Ty->getElementType(), 1); + auto Zero = ConstantInt::get(Ty->getElementType(), 0); + + // series_vector(0, 1) + if (m_StepVector().match(V)) + return Start.match(Zero) && Step.match(One); + + auto StartMod = m_SplatVector(m_Value(X)); + auto StepMod = m_SplatVector(m_Value(Y)); + auto ScaledStepModA = m_Mul(m_StepVector(), StepMod); + auto ScaledStepModB = m_Mul(StepMod, m_StepVector()); + + // series_vector(X, 1) + if (m_Add(m_StepVector(), StartMod).match(V) || + m_Add(StartMod, m_StepVector()).match(V)) + return Start.match(X) && Step.match(One); + + // series_vector(0, Y) + if (m_Mul(m_StepVector(), StepMod).match(V) || + m_Mul(StepMod, m_StepVector()).match(V)) + return Start.match(Zero) && Step.match(Y); + + // series_vector(X, Y) + if (m_Add(StartMod, ScaledStepModA).match(V) || + m_Add(StartMod, ScaledStepModB).match(V) || + m_Add(ScaledStepModA, StartMod).match(V) || + m_Add(ScaledStepModB, StartMod).match(V)) + return Start.match(X) && Step.match(Y); + + return false; + } +}; + +template +inline SeriesVector_match +m_SeriesVector(const T0 &Op0, const T1 &Op1) { + return SeriesVector_match(Op0, Op1); +} + +//===----------------------------------------------------------------------===// +// Matchers for loads +// + +template inline AnyOneOp_match +m_Load(const T0 &Op0) { return AnyOneOp_match(Op0); } + +template +inline AnyTwoOp_match m_Store(const T0 &Op0, + const T1 &Op1) { + return AnyTwoOp_match(Op0, Op1); +} + +template +struct AnyLoadOp_match { + T0 Op0; + T1 Op1; + T2 Op2; + T3 Op3; + + AnyLoadOp_match(const T0 &Op0, const T1 &Op1, const T2 &Op2, const T3 &Op3) + : Op0(Op0), Op1(Op1), Op2(Op2), Op3(Op3) {} + + template bool match(OpTy *V) { + if (m_Load(Op0).match(V)) { + if (V->getType()->isVectorTy()) + Op2.match(ConstantInt::getTrue( + VectorType::getBool(cast(V->getType())))); + else + Op2.match(ConstantInt::getTrue(V->getContext())); + Op3.match(UndefValue::get(V->getType())); + return true; + } + + if (m_Intrinsic(Op0, Op1, Op2, Op3).match(V)) + return true; + + return false; + } +}; + +template +struct AnyStoreOp_match { + T0 Op0; + T1 Op1; + T2 Op2; + T3 Op3; + + AnyStoreOp_match(const T0 &Op0, const T1 &Op1, const T2 &Op2, const T3 &Op3) + : Op0(Op0), Op1(Op1), Op2(Op2), Op3(Op3) {} + + template bool match(OpTy *V) { + if (m_Store(Op0, Op1).match(V)) { + auto *Val = V->getOperand(0); + if (Val->getType()->isVectorTy()) + Op3.match(ConstantInt::getTrue( + VectorType::getBool(cast(Val->getType())))); + else + Op3.match(ConstantInt::getTrue(V->getContext())); + return true; + } + + if (m_Intrinsic(Op0, Op1, Op2, Op3).match(V)) + return true; + + return false; + } +}; + +template +inline AnyStoreOp_match +m_AnyStore(const T0 &Op0, const T1 &Op1, const T2 &Op2, const T3 &Op3) { + return AnyStoreOp_match(Op0, Op1, Op2, Op3); +} + +template +inline AnyLoadOp_match m_AnyLoad(const T0 &Op0, const T1 &Op1, + const T2 &Op2, const T3 &Op3) { + return AnyLoadOp_match(Op0, Op1, Op2, Op3); +} //===----------------------------------------------------------------------===// // Matchers for two-operands operators with the operators in either order // Index: include/llvm/IR/Type.h =================================================================== --- 
include/llvm/IR/Type.h +++ include/llvm/IR/Type.h @@ -368,6 +368,7 @@ } inline unsigned getVectorNumElements() const; + inline unsigned getVectorIsScalable() const; Type *getVectorElementType() const { assert(getTypeID() == VectorTyID); return ContainedTys[0]; Index: include/llvm/IR/Value.def =================================================================== --- include/llvm/IR/Value.def +++ include/llvm/IR/Value.def @@ -73,7 +73,9 @@ HANDLE_CONSTANT(ConstantVector) // ConstantData. +HANDLE_CONSTANT(StepVector) HANDLE_CONSTANT(UndefValue) +HANDLE_CONSTANT(VScale) HANDLE_CONSTANT(ConstantAggregateZero) HANDLE_CONSTANT(ConstantDataArray) HANDLE_CONSTANT(ConstantDataVector) Index: include/llvm/InitializePasses.h =================================================================== --- include/llvm/InitializePasses.h +++ include/llvm/InitializePasses.h @@ -80,6 +80,8 @@ void initializeBlockExtractorPass(PassRegistry &); void initializeBlockFrequencyInfoWrapperPassPass(PassRegistry&); void initializeBoundsCheckingLegacyPassPass(PassRegistry&); +void initializeBOSCCPass(PassRegistry&); +void initializeBranchCoalescingPass(PassRegistry&); void initializeBranchFolderPassPass(PassRegistry&); void initializeBranchProbabilityInfoWrapperPassPass(PassRegistry&); void initializeBranchRelaxationPass(PassRegistry&); @@ -103,6 +105,7 @@ void initializeConstantHoistingLegacyPassPass(PassRegistry&); void initializeConstantMergeLegacyPassPass(PassRegistry&); void initializeConstantPropagationPass(PassRegistry&); +void initializeContiguousLoadStorePass(PassRegistry&); void initializeCorrelatedValuePropagationPass(PassRegistry&); void initializeCostModelAnalysisPass(PassRegistry&); void initializeCrossDSOCFIPass(PassRegistry&); @@ -171,6 +174,8 @@ void initializeImplicitNullChecksPass(PassRegistry&); void initializeIndVarSimplifyLegacyPassPass(PassRegistry&); void initializeIndirectBrExpandPassPass(PassRegistry&); +// MERGE our void initializeInductiveRangeCheckEliminationPass(PassRegistry&); +void initializeIRCELegacyPassPass(PassRegistry&); void initializeInferAddressSpacesPass(PassRegistry&); void initializeInferFunctionAttrsLegacyPassPass(PassRegistry&); void initializeInlineCostAnalysisPass(PassRegistry&); @@ -181,6 +186,8 @@ void initializeInstructionCombiningPassPass(PassRegistry&); void initializeInstructionSelectPass(PassRegistry&); void initializeInterleavedAccessPass(PassRegistry&); +void initializeInterleavedGatherScatterPass(PassRegistry &); +void initializeInterleavedGatherScatterStoreSinkPass(PassRegistry &); void initializeInternalizeLegacyPassPass(PassRegistry&); void initializeIntervalPartitionPass(PassRegistry&); void initializeJumpThreadingPass(PassRegistry&); @@ -221,13 +228,17 @@ void initializeLoopPassPass(PassRegistry&); void initializeLoopPredicationLegacyPassPass(PassRegistry&); void initializeLoopRerollPass(PassRegistry&); +void initializeLoopRewriteGEPsPassPass(PassRegistry&); void initializeLoopRotateLegacyPassPass(PassRegistry&); +void initializeLoopExprTreeFactoringPassPass(PassRegistry&); void initializeLoopSimplifyCFGLegacyPassPass(PassRegistry&); +void initializeLoopSpeculativeBoundsCheckPass(PassRegistry&); void initializeLoopSimplifyPass(PassRegistry&); void initializeLoopStrengthReducePass(PassRegistry&); void initializeLoopUnrollAndJamPass(PassRegistry&); void initializeLoopUnrollPass(PassRegistry&); void initializeLoopUnswitchPass(PassRegistry&); +void initializeLoopVectorizationAnalysisPass(PassRegistry&); void initializeLoopVectorizePass(PassRegistry&); void 
initializeLoopVersioningLICMPass(PassRegistry&); void initializeLoopVersioningPassPass(PassRegistry&); @@ -355,7 +366,9 @@ void initializeScalarizerPass(PassRegistry&); void initializeScavengerTestPass(PassRegistry&); void initializeScopedNoAliasAAWrapperPassPass(PassRegistry&); +void initializeSearchLoopVectorizePass(PassRegistry&); void initializeSeparateConstOffsetFromGEPPass(PassRegistry&); +void initializeSeparateInvariantsFromGepOffsetPass(PassRegistry&); void initializeShadowStackGCLoweringPass(PassRegistry&); void initializeShrinkWrapPass(PassRegistry&); void initializeSimpleInlinerPass(PassRegistry&); @@ -379,6 +392,8 @@ void initializeStripNonLineTableDebugInfoPass(PassRegistry&); void initializeStripSymbolsPass(PassRegistry&); void initializeStructurizeCFGPass(PassRegistry&); +void initializeSVELoopVectorizePass(PassRegistry&); +void initializeHWAddressSanitizerPass(PassRegistry&); void initializeTailCallElimPass(PassRegistry&); void initializeTailDuplicatePass(PassRegistry&); void initializeTargetLibraryInfoWrapperPassPass(PassRegistry&); @@ -401,6 +416,8 @@ void initializeWriteThinLTOBitcodePass(PassRegistry&); void initializeXRayInstrumentationPass(PassRegistry&); +void initializeSVEExpandLibCallPass(PassRegistry &); +void initializeSVEPostVectorizePass(PassRegistry &); } // end namespace llvm #endif // LLVM_INITIALIZEPASSES_H Index: include/llvm/LinkAllPasses.h =================================================================== --- include/llvm/LinkAllPasses.h +++ include/llvm/LinkAllPasses.h @@ -129,6 +129,7 @@ (void) llvm::createLoopPredicationPass(); (void) llvm::createLoopSimplifyPass(); (void) llvm::createLoopSimplifyCFGPass(); + (void) llvm::createLoopSpeculativeBoundsCheckPass(); (void) llvm::createLoopStrengthReducePass(); (void) llvm::createLoopRerollPass(); (void) llvm::createLoopUnrollPass(); @@ -204,6 +205,7 @@ (void) llvm::createCorrelatedValuePropagationPass(); (void) llvm::createMemDepPrinter(); (void) llvm::createLoopVectorizePass(); + (void) llvm::createSVELoopVectorizePass(); (void) llvm::createSLPVectorizerPass(); (void) llvm::createLoadStoreVectorizerPass(); (void) llvm::createPartiallyInlineLibCallsPass(); Index: include/llvm/MC/MCDwarf.h =================================================================== --- include/llvm/MC/MCDwarf.h +++ include/llvm/MC/MCDwarf.h @@ -434,6 +434,7 @@ }; private: + std::string Comment; OpType Operation; MCSymbol *Label; unsigned Register; @@ -443,14 +444,15 @@ }; std::vector Values; - MCCFIInstruction(OpType Op, MCSymbol *L, unsigned R, int O, StringRef V) - : Operation(Op), Label(L), Register(R), Offset(O), + MCCFIInstruction(OpType Op, MCSymbol *L, unsigned R, int O, StringRef V, + StringRef Comment = "") + : Comment(Comment), Operation(Op), Label(L), Register(R), Offset(O), Values(V.begin(), V.end()) { assert(Op != OpRegister); } MCCFIInstruction(OpType Op, MCSymbol *L, unsigned R1, unsigned R2) - : Operation(Op), Label(L), Register(R1), Register2(R2) { + : Comment(""), Operation(Op), Label(L), Register(R1), Register2(R2) { assert(Op == OpRegister); } @@ -540,8 +542,9 @@ /// .cfi_escape Allows the user to add arbitrary bytes to the unwind /// info. 
- static MCCFIInstruction createEscape(MCSymbol *L, StringRef Vals) { - return MCCFIInstruction(OpEscape, L, 0, 0, Vals); + static MCCFIInstruction createEscape(MCSymbol *L, StringRef Vals, + StringRef Comment="") { + return MCCFIInstruction(OpEscape, L, 0, 0, Vals, Comment); } /// A special wrapper for .cfi_escape that indicates GNU_ARGS_SIZE @@ -549,6 +552,30 @@ return MCCFIInstruction(OpGnuArgsSize, L, 0, Size, ""); } + /// \brief A special wrapper for .cfi_escape that describes the location + /// where register 'Reg' from the previous frame is saved at, calculated as: + /// SaveLoc(Reg) = Basereg + Offset + Scalereg * Offset2. + static MCCFIInstruction createScaledOffset(MCSymbol *L, unsigned Reg, + unsigned Basereg, int Offset, + unsigned Scalereg, int Offset2, + StringRef Comment = ""); + + /// \brief A special wrapper for .cfi_escape that describes the location + /// where register 'Reg' from the previous frame is saved at, calculated as: + /// SaveLoc(Reg) = CFA + Offset + Scalereg * Offset2. + static MCCFIInstruction createScaledCfaOffset(MCSymbol *L, unsigned Reg, + int Offset, unsigned Scalereg, + int Offset2, + StringRef Comment = ""); + + /// \brief A special wrapper for .cfi_escape that describes the CFA + /// as an expression using a 'scale register'. The CFA is defined as: + /// CFA = Basereg + Offset + Scalereg * Offset2. + static MCCFIInstruction createScaledDefCfa(MCSymbol *L, unsigned Basereg, + int Offset, unsigned Scalereg, + int Offset2, + StringRef Comment = ""); + OpType getOperation() const { return Operation; } MCSymbol *getLabel() const { return Label; } @@ -576,6 +603,10 @@ assert(Operation == OpEscape); return StringRef(&Values[0], Values.size()); } + + StringRef getComment() const { + return Comment; + } }; struct MCDwarfFrameInfo { Index: include/llvm/MC/MCParser/MCParsedAsmOperand.h =================================================================== --- include/llvm/MC/MCParser/MCParsedAsmOperand.h +++ include/llvm/MC/MCParser/MCParsedAsmOperand.h @@ -58,6 +58,8 @@ virtual bool isImm() const = 0; /// isReg - Is this a register operand? virtual bool isReg() const = 0; + // isAnyReg() - Is this any kind of register operand? + virtual bool isAnyReg() const { return isReg(); } virtual unsigned getReg() const = 0; /// isMem - Is this a memory operand? Index: include/llvm/Support/AArch64TargetParser.def =================================================================== --- include/llvm/Support/AArch64TargetParser.def +++ include/llvm/Support/AArch64TargetParser.def @@ -13,6 +13,15 @@ // NOTE: NO INCLUDE GUARD DESIRED! 
+#ifndef AARCH64_FEATURE +#define AARCH64_FEATURE(VAL, PRIORITY, NAME) +#endif + +AARCH64_FEATURE( 1, 1, "neon") +AARCH64_FEATURE( 8, 1, "lse") +AARCH64_FEATURE(22, 2, "sve") + + #ifndef AARCH64_ARCH #define AARCH64_ARCH(NAME, ID, CPU_ATTR, SUB_ARCH, ARCH_ATTR, ARCH_FPU, ARCH_BASE_EXT) #endif @@ -46,24 +55,29 @@ #define AARCH64_ARCH_EXT_NAME(NAME, ID, FEATURE, NEGFEATURE) #endif // FIXME: This would be nicer were it tablegen -AARCH64_ARCH_EXT_NAME("invalid", AArch64::AEK_INVALID, nullptr, nullptr) -AARCH64_ARCH_EXT_NAME("none", AArch64::AEK_NONE, nullptr, nullptr) -AARCH64_ARCH_EXT_NAME("crc", AArch64::AEK_CRC, "+crc", "-crc") -AARCH64_ARCH_EXT_NAME("lse", AArch64::AEK_LSE, "+lse", "-lse") -AARCH64_ARCH_EXT_NAME("rdm", AArch64::AEK_RDM, "+rdm", "-rdm") -AARCH64_ARCH_EXT_NAME("crypto", AArch64::AEK_CRYPTO, "+crypto","-crypto") -AARCH64_ARCH_EXT_NAME("sm4", AArch64::AEK_SM4, "+sm4", "-sm4") -AARCH64_ARCH_EXT_NAME("sha3", AArch64::AEK_SHA3, "+sha3", "-sha3") -AARCH64_ARCH_EXT_NAME("sha2", AArch64::AEK_SHA2, "+sha2", "-sha2") -AARCH64_ARCH_EXT_NAME("aes", AArch64::AEK_AES, "+aes", "-aes") -AARCH64_ARCH_EXT_NAME("dotprod", AArch64::AEK_DOTPROD, "+dotprod","-dotprod") -AARCH64_ARCH_EXT_NAME("fp", AArch64::AEK_FP, "+fp-armv8", "-fp-armv8") -AARCH64_ARCH_EXT_NAME("simd", AArch64::AEK_SIMD, "+neon", "-neon") -AARCH64_ARCH_EXT_NAME("fp16", AArch64::AEK_FP16, "+fullfp16", "-fullfp16") -AARCH64_ARCH_EXT_NAME("profile", AArch64::AEK_PROFILE, "+spe", "-spe") -AARCH64_ARCH_EXT_NAME("ras", AArch64::AEK_RAS, "+ras", "-ras") -AARCH64_ARCH_EXT_NAME("sve", AArch64::AEK_SVE, "+sve", "-sve") -AARCH64_ARCH_EXT_NAME("rcpc", AArch64::AEK_RCPC, "+rcpc", "-rcpc") +AARCH64_ARCH_EXT_NAME("invalid", AArch64::AEK_INVALID, nullptr, nullptr) +AARCH64_ARCH_EXT_NAME("none", AArch64::AEK_NONE, nullptr, nullptr) +AARCH64_ARCH_EXT_NAME("crc", AArch64::AEK_CRC, "+crc", "-crc") +AARCH64_ARCH_EXT_NAME("lse", AArch64::AEK_LSE, "+lse", "-lse") +AARCH64_ARCH_EXT_NAME("rdm", AArch64::AEK_RDM, "+rdm", "-rdm") +AARCH64_ARCH_EXT_NAME("crypto", AArch64::AEK_CRYPTO, "+crypto","-crypto") +AARCH64_ARCH_EXT_NAME("sm4", AArch64::AEK_SM4, "+sm4", "-sm4") +AARCH64_ARCH_EXT_NAME("sha3", AArch64::AEK_SHA3, "+sha3", "-sha3") +AARCH64_ARCH_EXT_NAME("sha2", AArch64::AEK_SHA2, "+sha2", "-sha2") +AARCH64_ARCH_EXT_NAME("aes", AArch64::AEK_AES, "+aes", "-aes") +AARCH64_ARCH_EXT_NAME("dotprod", AArch64::AEK_DOTPROD, "+dotprod","-dotprod") +AARCH64_ARCH_EXT_NAME("fp", AArch64::AEK_FP, "+fp-armv8", "-fp-armv8") +AARCH64_ARCH_EXT_NAME("simd", AArch64::AEK_SIMD, "+neon", "-neon") +AARCH64_ARCH_EXT_NAME("fp16", AArch64::AEK_FP16, "+fullfp16", "-fullfp16") +AARCH64_ARCH_EXT_NAME("profile", AArch64::AEK_PROFILE, "+spe", "-spe") +AARCH64_ARCH_EXT_NAME("ras", AArch64::AEK_RAS, "+ras", "-ras") +AARCH64_ARCH_EXT_NAME("sve", AArch64::AEK_SVE, "+sve", "-sve") +AARCH64_ARCH_EXT_NAME("sve2", AArch64::AEK_SVE2, "+sve2", "-sve2") +AARCH64_ARCH_EXT_NAME("sve2-aes", AArch64::AEK_SVE2AES, "+sve2-aes", "-sve2-aes") +AARCH64_ARCH_EXT_NAME("sve2-sm4", AArch64::AEK_SVE2SM4, "+sve2-sm4", "-sve2-sm4") +AARCH64_ARCH_EXT_NAME("sve2-sha3", AArch64::AEK_SVE2SHA3, "+sve2-sha3", "-sve2-sha3") +AARCH64_ARCH_EXT_NAME("sve2-bitperm", AArch64::AEK_SVE2BITPERM, "+sve2-bitperm", "-sve2-bitperm") +AARCH64_ARCH_EXT_NAME("rcpc", AArch64::AEK_RCPC, "+rcpc", "-rcpc") #undef AARCH64_ARCH_EXT_NAME #ifndef AARCH64_CPU_NAME @@ -112,3 +126,4 @@ // Invalid CPU AARCH64_CPU_NAME("invalid", INVALID, FK_INVALID, true, AArch64::AEK_INVALID) #undef AARCH64_CPU_NAME +#undef AARCH64_FEATURE Index: 
include/llvm/Support/LockFileManager.h =================================================================== --- include/llvm/Support/LockFileManager.h +++ include/llvm/Support/LockFileManager.h @@ -78,7 +78,7 @@ operator LockFileState() const { return getState(); } /// For a shared lock, wait until the owner releases the lock. - WaitForUnlockResult waitForUnlock(); + WaitForUnlockResult waitForUnlock(unsigned MaxSeconds = 5 * 60); /// Remove the lock file. This may delete a different lock file than /// the one previously read if there is a race. Index: include/llvm/Support/MachineValueType.h =================================================================== --- include/llvm/Support/MachineValueType.h +++ include/llvm/Support/MachineValueType.h @@ -149,28 +149,30 @@ v2f16 = 85, // 2 x f16 v4f16 = 86, // 4 x f16 v8f16 = 87, // 8 x f16 - v1f32 = 88, // 1 x f32 - v2f32 = 89, // 2 x f32 - v4f32 = 90, // 4 x f32 - v8f32 = 91, // 8 x f32 - v16f32 = 92, // 16 x f32 - v1f64 = 93, // 1 x f64 - v2f64 = 94, // 2 x f64 - v4f64 = 95, // 4 x f64 - v8f64 = 96, // 8 x f64 + v16f16 = 88, // 16 x f16 + v32f16 = 89, // 32 x f16 + v1f32 = 90, // 1 x f32 + v2f32 = 91, // 2 x f32 + v4f32 = 92, // 4 x f32 + v8f32 = 93, // 8 x f32 + v16f32 = 94, // 16 x f32 + v1f64 = 95, // 1 x f64 + v2f64 = 96, // 2 x f64 + v4f64 = 97, // 4 x f64 + v8f64 = 98, // 8 x f64 - nxv2f16 = 97, // n x 2 x f16 - nxv4f16 = 98, // n x 4 x f16 - nxv8f16 = 99, // n x 8 x f16 - nxv1f32 = 100, // n x 1 x f32 - nxv2f32 = 101, // n x 2 x f32 - nxv4f32 = 102, // n x 4 x f32 - nxv8f32 = 103, // n x 8 x f32 - nxv16f32 = 104, // n x 16 x f32 - nxv1f64 = 105, // n x 1 x f64 - nxv2f64 = 106, // n x 2 x f64 - nxv4f64 = 107, // n x 4 x f64 - nxv8f64 = 108, // n x 8 x f64 + nxv2f16 = 99, // n x 2 x f16 + nxv4f16 = 100, // n x 4 x f16 + nxv8f16 = 101, // n x 8 x f16 + nxv1f32 = 102, // n x 1 x f32 + nxv2f32 = 103, // n x 2 x f32 + nxv4f32 = 104, // n x 4 x f32 + nxv8f32 = 105, // n x 8 x f32 + nxv16f32 = 106, // n x 16 x f32 + nxv1f64 = 107, // n x 1 x f64 + nxv2f64 = 108, // n x 2 x f64 + nxv4f64 = 109, // n x 4 x f64 + nxv8f64 = 110, // n x 8 x f64 FIRST_FP_VECTOR_VALUETYPE = v2f16, LAST_FP_VECTOR_VALUETYPE = nxv8f64, @@ -181,20 +183,20 @@ FIRST_VECTOR_VALUETYPE = v1i1, LAST_VECTOR_VALUETYPE = nxv8f64, - x86mmx = 109, // This is an X86 MMX value + x86mmx = 111, // This is an X86 MMX value - Glue = 110, // This glues nodes together during pre-RA sched + Glue = 112, // This glues nodes together during pre-RA sched - isVoid = 111, // This has no value + isVoid = 113, // This has no value - Untyped = 112, // This value takes a register, but has + Untyped = 114, // This value takes a register, but has // unspecified type. The register class // will be determined by the opcode. - ExceptRef = 113, // WebAssembly's except_ref type + ExceptRef = 115, // WebAssembly's except_ref type FIRST_VALUETYPE = 1, // This is always the beginning of the list. - LAST_VALUETYPE = 114, // This always remains at the end of the list. + LAST_VALUETYPE = 116, // This always remains at the end of the list. // This is the current maximum for LAST_VALUETYPE. // MVT::MAX_ALLOWED_VALUETYPE is used for asserts and to size bit vectors @@ -358,17 +360,18 @@ /// Return true if this is a 256-bit vector type. 
bool is256BitVector() const { - return (SimpleTy == MVT::v8f32 || SimpleTy == MVT::v4f64 || - SimpleTy == MVT::v32i8 || SimpleTy == MVT::v16i16 || - SimpleTy == MVT::v8i32 || SimpleTy == MVT::v4i64); + return (SimpleTy == MVT::v16f16 || SimpleTy == MVT::v8f32 || + SimpleTy == MVT::v4f64 || SimpleTy == MVT::v32i8 || + SimpleTy == MVT::v16i16 || SimpleTy == MVT::v8i32 || + SimpleTy == MVT::v4i64); } /// Return true if this is a 512-bit vector type. bool is512BitVector() const { - return (SimpleTy == MVT::v16f32 || SimpleTy == MVT::v8f64 || - SimpleTy == MVT::v512i1 || SimpleTy == MVT::v64i8 || - SimpleTy == MVT::v32i16 || SimpleTy == MVT::v16i32 || - SimpleTy == MVT::v8i64); + return (SimpleTy == MVT::v32f16 || SimpleTy == MVT::v16f32 || + SimpleTy == MVT::v8f64 || SimpleTy == MVT::v512i1 || + SimpleTy == MVT::v64i8 || SimpleTy == MVT::v32i16 || + SimpleTy == MVT::v16i32 || SimpleTy == MVT::v8i64); } /// Return true if this is a 1024-bit vector type. @@ -491,6 +494,8 @@ case v2f16: case v4f16: case v8f16: + case v16f16: + case v32f16: case nxv2f16: case nxv4f16: case nxv8f16: return f16; @@ -534,6 +539,7 @@ case v32i16: case v32i32: case v32i64: + case v32f16: case nxv32i1: case nxv32i8: case nxv32i16: @@ -544,6 +550,7 @@ case v16i16: case v16i32: case v16i64: + case v16f16: case v16f32: case nxv16i1: case nxv16i8: @@ -716,6 +723,7 @@ case v16i16: case v8i32: case v4i64: + case v16f16: case v8f32: case v4f64: case nxv32i8: @@ -729,6 +737,7 @@ case v32i16: case v16i32: case v8i64: + case v32f16: case v16f32: case v8f64: case nxv32i16: @@ -885,6 +894,8 @@ if (NumElements == 2) return MVT::v2f16; if (NumElements == 4) return MVT::v4f16; if (NumElements == 8) return MVT::v8f16; + if (NumElements == 16) return MVT::v16f16; + if (NumElements == 32) return MVT::v32f16; break; case MVT::f32: if (NumElements == 1) return MVT::v1f32; @@ -1041,6 +1052,16 @@ (MVT::SimpleValueType)(MVT::LAST_FP_VECTOR_VALUETYPE + 1)); } + static mvt_range integer_fixed_vector_valuetypes() { + return mvt_range(MVT::FIRST_INTEGER_VALUETYPE, + MVT::FIRST_INTEGER_SCALABLE_VALUETYPE); + } + + static mvt_range fp_fixed_vector_valuetypes() { + return mvt_range(MVT::FIRST_FP_VECTOR_VALUETYPE, + MVT::FIRST_FP_SCALABLE_VALUETYPE); + } + static mvt_range integer_scalable_vector_valuetypes() { return mvt_range(MVT::FIRST_INTEGER_SCALABLE_VALUETYPE, (MVT::SimpleValueType)(MVT::LAST_INTEGER_SCALABLE_VALUETYPE + 1)); Index: include/llvm/Support/Process.h =================================================================== --- include/llvm/Support/Process.h +++ include/llvm/Support/Process.h @@ -183,6 +183,11 @@ /// Get the result of a process wide random number generator. The /// generator will be automatically seeded in non-deterministic fashion. static unsigned GetRandomNumber(); + + /// States whether our parent process is using the same executable image as + /// the current process. This typically happens when clang spawns itself as + /// "clang -cc1" or "clang -cc1as" + static bool CheckMyParentUsesSameExeImage(); }; } Index: include/llvm/Support/TargetParser.h =================================================================== --- include/llvm/Support/TargetParser.h +++ include/llvm/Support/TargetParser.h @@ -88,6 +88,11 @@ AEK_DOTPROD = 1 << 14, AEK_SHA2 = 1 << 15, AEK_AES = 1 << 16, + AEK_SVE2 = 1 << 17, + AEK_SVE2AES = 1 << 18, + AEK_SVE2SM4 = 1 << 19, + AEK_SVE2SHA3 = 1 << 20, + AEK_SVE2BITPERM = 1 << 21, // Unsupported extensions. 
AEK_OS = 0x8000000, AEK_IWMMXT = 0x10000000, @@ -178,6 +183,11 @@ AEK_SHA3 = 1 << 14, AEK_SHA2 = 1 << 15, AEK_AES = 1 << 16, + AEK_SVE2 = 1 << 17, + AEK_SVE2AES = 1 << 18, + AEK_SVE2SM4 = 1 << 19, + AEK_SVE2SHA3 = 1 << 20, + AEK_SVE2BITPERM = 1 << 21, }; StringRef getCanonicalArchName(StringRef Arch); Index: include/llvm/Target/Target.td =================================================================== --- include/llvm/Target/Target.td +++ include/llvm/Target/Target.td @@ -664,6 +664,17 @@ class unknown_class; def unknown : unknown_class; + +class DiagnosticPredicateClass { + code Predicate; + let Predicate = c; +} + +def DP_None : DiagnosticPredicateClass<[{}]>; +def DP_IsImm : DiagnosticPredicateClass<[{Operand.isImm()}]>; +def DP_IsReg : DiagnosticPredicateClass<[{Operand.isReg()}]>; +def DP_IsTok : DiagnosticPredicateClass<[{Operand.isToken()}]>; + /// AsmOperandClass - Representation for the kinds of operands which the target /// specific parser can create and the assembly matcher may need to distinguish. /// Index: include/llvm/Target/TargetMachine.h =================================================================== --- include/llvm/Target/TargetMachine.h +++ include/llvm/Target/TargetMachine.h @@ -199,7 +199,7 @@ void setOptLevel(CodeGenOpt::Level Level); void setFastISel(bool Enable) { Options.EnableFastISel = Enable; } - bool getO0WantsFastISel() { return O0WantsFastISel; } + virtual bool getO0WantsFastISel() { return O0WantsFastISel; } void setO0WantsFastISel(bool Enable) { O0WantsFastISel = Enable; } void setGlobalISel(bool Enable) { Options.EnableGlobalISel = Enable; } void setMachineOutliner(bool Enable) { Index: include/llvm/Target/TargetSelectionDAG.td =================================================================== --- include/llvm/Target/TargetSelectionDAG.td +++ include/llvm/Target/TargetSelectionDAG.td @@ -238,6 +238,15 @@ def SDTVecShuffle : SDTypeProfile<1, 2, [ SDTCisSameAs<0, 1>, SDTCisSameAs<1, 2> ]>; +def SDTVecShuffleVar : SDTypeProfile<1, 3, [ + SDTCisSameAs<0, 1>, SDTCisSameAs<1, 2>, SDTCisInt<3> +]>; +def SDTSplatVector : SDTypeProfile<1, 1, [ + SDTCisVec<0> +]>; +def SDTSeriesVector : SDTypeProfile<1, 2, [ + SDTCisVec<0>, SDTCisSameAs<1, 2>, SDTCisInt<2> +]>; def SDTVecExtract : SDTypeProfile<1, 2, [ // vector extract SDTCisEltOfVec<0, 1>, SDTCisPtrTy<2> ]>; @@ -251,6 +260,9 @@ def SDTSubVecInsert : SDTypeProfile<1, 3, [ // subvector insert SDTCisSubVecOfVec<2, 1>, SDTCisSameAs<0,1>, SDTCisInt<3> ]>; +def SDTVecElementCount : SDTypeProfile<1, 1, [ // vector element count + SDTCisInt<0>, SDTCisVT<1, OtherVT> +]>; def SDTPrefetch : SDTypeProfile<0, 4, [ // prefetch SDTCisPtrTy<0>, SDTCisSameAs<1, 2>, SDTCisSameAs<1, 3>, SDTCisInt<1> @@ -310,6 +322,7 @@ def bb : SDNode<"ISD::BasicBlock", SDTOther , [], "BasicBlockSDNode">; def cond : SDNode<"ISD::CONDCODE" , SDTOther , [], "CondCodeSDNode">; def undef : SDNode<"ISD::UNDEF" , SDTUNDEF , []>; +def vscale : SDNode<"ISD::VSCALE" , SDTIntUnaryOp, []>; def globaladdr : SDNode<"ISD::GlobalAddress", SDTPtrLeaf, [], "GlobalAddressSDNode">; def tglobaladdr : SDNode<"ISD::TargetGlobalAddress", SDTPtrLeaf, [], @@ -525,9 +538,13 @@ [SDNPHasChain, SDNPMayStore, SDNPMemOperand]>; def vector_shuffle : SDNode<"ISD::VECTOR_SHUFFLE", SDTVecShuffle, []>; +def vector_shuffle_var : SDNode<"ISD::VECTOR_SHUFFLE_VAR", + SDTVecShuffleVar, []>; +def series_vector : SDNode<"ISD::SERIES_VECTOR", SDTSeriesVector, []>; def build_vector : SDNode<"ISD::BUILD_VECTOR", SDTypeProfile<1, -1, []>, []>; def scalar_to_vector : 
SDNode<"ISD::SCALAR_TO_VECTOR", SDTypeProfile<1, 1, []>, []>; +def splat_vector : SDNode<"ISD::SPLAT_VECTOR", SDTSplatVector, []>; // vector_extract/vector_insert are deprecated. extractelt/insertelt // are preferred. @@ -564,7 +581,6 @@ def assertsext : SDNode<"ISD::AssertSext", SDT_assertext>; def assertzext : SDNode<"ISD::AssertZext", SDT_assertext>; - //===----------------------------------------------------------------------===// // Selection DAG Condition Codes @@ -788,18 +804,29 @@ } // extending load fragments. -def extload : PatFrag<(ops node:$ptr), (unindexedload node:$ptr)> { +def extload : PatFrag<(ops node:$ptr), (unindexedload node:$ptr), [{ + return cast(N)->getExtensionType() == ISD::EXTLOAD; +}]> { let IsLoad = 1; let IsAnyExtLoad = 1; } -def sextload : PatFrag<(ops node:$ptr), (unindexedload node:$ptr)> { + +def sextload : PatFrag<(ops node:$ptr), (unindexedload node:$ptr), [{ + return cast(N)->getExtensionType() == ISD::SEXTLOAD; +}]> { let IsLoad = 1; let IsSignExtLoad = 1; } -def zextload : PatFrag<(ops node:$ptr), (unindexedload node:$ptr)> { +def zextload : PatFrag<(ops node:$ptr), (unindexedload node:$ptr), [{ + return cast(N)->getExtensionType() == ISD::ZEXTLOAD; +}]> { let IsLoad = 1; let IsZeroExtLoad = 1; } +def azextload : PatFrag<(ops node:$ptr), (unindexedload node:$ptr), [{ + auto Type = cast(N)->getExtensionType(); + return Type == ISD::ZEXTLOAD || Type == ISD::EXTLOAD; +}]>; def extloadi1 : PatFrag<(ops node:$ptr), (extload node:$ptr)> { let IsLoad = 1; @@ -860,6 +887,20 @@ let MemoryVT = i32; } +def azextloadi1 : PatFrag<(ops node:$ptr), (azextload node:$ptr), [{ + return cast(N)->getMemoryVT() == MVT::i1; +}]>; +def azextloadi8 : PatFrag<(ops node:$ptr), (azextload node:$ptr), [{ + return cast(N)->getMemoryVT() == MVT::i8; +}]>; +def azextloadi16 : PatFrag<(ops node:$ptr), (azextload node:$ptr), [{ + return cast(N)->getMemoryVT() == MVT::i16; +}]>; +def azextloadi32 : PatFrag<(ops node:$ptr), (azextload node:$ptr), [{ + return cast(N)->getMemoryVT() == MVT::i32; +}]>; + + def extloadvi1 : PatFrag<(ops node:$ptr), (extload node:$ptr)> { let IsLoad = 1; let ScalarMemoryVT = i1; @@ -919,6 +960,19 @@ let ScalarMemoryVT = i32; } +def azextloadvi1 : PatFrag<(ops node:$ptr), (azextload node:$ptr), [{ + return cast(N)->getMemoryVT().getScalarType() == MVT::i1; +}]>; +def azextloadvi8 : PatFrag<(ops node:$ptr), (azextload node:$ptr), [{ + return cast(N)->getMemoryVT().getScalarType() == MVT::i8; +}]>; +def azextloadvi16 : PatFrag<(ops node:$ptr), (azextload node:$ptr), [{ + return cast(N)->getMemoryVT().getScalarType() == MVT::i16; +}]>; +def azextloadvi32 : PatFrag<(ops node:$ptr), (azextload node:$ptr), [{ + return cast(N)->getMemoryVT().getScalarType() == MVT::i32; +}]>; + // store fragments. 
def unindexedstore : PatFrag<(ops node:$val, node:$ptr), (st node:$val, node:$ptr)> { @@ -963,14 +1017,22 @@ let MemoryVT = f64; } +def truncstorevi1 : PatFrag<(ops node:$val, node:$ptr), + (truncstore node:$val, node:$ptr), [{ + return cast(N)->getMemoryVT().getScalarType() == MVT::i1; +}]>; def truncstorevi8 : PatFrag<(ops node:$val, node:$ptr), - (truncstore node:$val, node:$ptr)> { + (truncstore node:$val, node:$ptr), [{ + return cast(N)->getMemoryVT().getScalarType() == MVT::i8; +}]> { let IsStore = 1; let ScalarMemoryVT = i8; } def truncstorevi16 : PatFrag<(ops node:$val, node:$ptr), - (truncstore node:$val, node:$ptr)> { + (truncstore node:$val, node:$ptr), [{ + return cast(N)->getMemoryVT().getScalarType() == MVT::i16; +}]> { let IsStore = 1; let ScalarMemoryVT = i16; } Index: include/llvm/Transforms/IPO/PassManagerBuilder.h =================================================================== --- include/llvm/Transforms/IPO/PassManagerBuilder.h +++ include/llvm/Transforms/IPO/PassManagerBuilder.h @@ -147,6 +147,7 @@ bool DisableUnrollLoops; bool SLPVectorize; bool LoopVectorize; + bool SearchLoopVectorize; bool RerollLoops; bool NewGVN; bool DisableGVNLoadPRE; Index: include/llvm/Transforms/LVCommon.h =================================================================== --- /dev/null +++ include/llvm/Transforms/LVCommon.h @@ -0,0 +1,33 @@ +#ifndef LLVM_TRANSFORMS_VECTORIZE_LVCOMMON_H +#define LLVM_TRANSFORMS_VECTORIZE_LVCOMMON_H + +#include "llvm/Support/CommandLine.h" + +namespace llvm { + + +extern cl::opt EnableIfConversion; +extern cl::opt TinyTripCountVectorThreshold; +extern cl::opt MaximizeBandwidth; +extern cl::opt EnableMemAccessVersioning; +extern cl::opt EnableInterleavedMemAccesses; +extern cl::opt MaxInterleaveGroupFactor; +extern cl::opt ForceTargetNumScalarRegs; +extern cl::opt ForceTargetNumVectorRegs; +extern cl::opt ForceTargetMaxScalarInterleaveFactor; +extern cl::opt ForceTargetMaxVectorInterleaveFactor; +extern cl::opt ForceTargetInstructionCost; +extern cl::opt SmallLoopCost; +extern cl::opt LoopVectorizeWithBlockFrequency; +extern cl::opt EnableLoadStoreRuntimeInterleave; +extern cl::opt NumberOfStoresToPredicate; +extern cl::opt EnableIndVarRegisterHeur; +extern cl::opt EnableCondStoresVectorization; +extern cl::opt MaxNestedScalarReductionIC; +extern cl::opt PragmaVectorizeMemoryCheckThreshold; +extern cl::opt VectorizeSCEVCheckThreshold; +extern cl::opt PragmaVectorizeSCEVCheckThreshold; + +} + +#endif Index: include/llvm/Transforms/Scalar.h =================================================================== --- include/llvm/Transforms/Scalar.h +++ include/llvm/Transforms/Scalar.h @@ -204,6 +204,12 @@ //===----------------------------------------------------------------------===// // +// LoopExprTreeFactoring - +// +Pass *createLoopExprTreeFactoringPass(); + +//===----------------------------------------------------------------------===// +// // LoopRotate - This pass is a simple loop rotating pass. // Pass *createLoopRotatePass(int MaxHeaderSize = -1); @@ -216,12 +222,41 @@ //===----------------------------------------------------------------------===// // +// LoopVectorizationAnalysis - Determines whether a loop can be vectorized +// TODO: Is this the right place? +// +Pass *createLVAPass(); + +//===----------------------------------------------------------------------===// +// // LoopVersioningLICM - This pass is a loop versioning pass for LICM. 
// Pass *createLoopVersioningLICMPass(); //===----------------------------------------------------------------------===// // +// LoopSpeculativeBoundsCheck - Determines whether speculatively checking loop +// bounds can be used to enable later optimization if the initial value of a +// load can guarantee no aliasing will occur. +// +Pass *createLoopSpeculativeBoundsCheckPass(); + +//===----------------------------------------------------------------------===// +// +// PromoteMemoryToRegister - This pass is used to promote memory references to +// be register references. A simple example of the transformation performed by +// this pass is: +// +// FROM CODE TO CODE +// %X = alloca i32, i32 1 ret i32 42 +// store i32 42, i32 *%X +// %Y = load i32* %X +// ret i32 %Y +// +FunctionPass *createPromoteMemoryToRegisterPass(); + +//===----------------------------------------------------------------------===// +// // DemoteRegisterToMemoryPass - This pass is used to demote registers to memory // references. In basically undoes the PromoteMemoryToRegister pass to make cfg // hacking easier. @@ -254,7 +289,8 @@ FunctionPass *createCFGSimplificationPass( unsigned Threshold = 1, bool ForwardSwitchCond = false, bool ConvertSwitch = false, bool KeepLoops = true, bool SinkCommon = false, - std::function Ftor = nullptr); + std::function Ftor = nullptr, + unsigned SwitchRemovalThreshold = 0); //===----------------------------------------------------------------------===// // @@ -406,6 +442,13 @@ //===----------------------------------------------------------------------===// // +// SeparateInvariantFromGepOffset - Split Invariants from complex GEP offsets +// +FunctionPass * +createSeparateInvariantsFromGepOffsetPass(); + +//===----------------------------------------------------------------------===// +// // SpeculativeExecution - Aggressively hoist instructions to enable // speculative execution on targets where branches are expensive. // @@ -487,6 +530,10 @@ //===----------------------------------------------------------------------===// // +// LoopRewriteGEPs - +// +Pass *createLoopRewriteGEPsPass(); + // LoopSimplifyCFG - This pass performs basic CFG simplification on loops, // primarily to help other loop passes. // Index: include/llvm/Transforms/Scalar/GVN.h =================================================================== --- include/llvm/Transforms/Scalar/GVN.h +++ include/llvm/Transforms/Scalar/GVN.h @@ -68,6 +68,7 @@ class GVN : public PassInfoMixin { public: struct Expression; + bool DisablePREsWhichIntroduceArtificialLoopDep = false; /// Run the pass over the function. PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM); @@ -152,6 +153,7 @@ friend class gvn::GVNLegacyPass; friend struct DenseMapInfo; + LoopInfo *Loops; MemoryDependenceResults *MD; DominatorTree *DT; const TargetLibraryInfo *TLI; @@ -289,7 +291,9 @@ /// Create a legacy GVN pass. This also allows parameterizing whether or not /// loads are eliminated by the pass. -FunctionPass *createGVNPass(bool NoLoads = false); +FunctionPass * +createGVNPass(bool NoLoads = false, + bool DisablePREsWhichIntroduceArtificialLoopDep = false); /// A simple and fast domtree-based GVN pass to hoist common expressions /// from sibling branches. 
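For the extra createGVNPass parameter added above, a minimal usage sketch; illustration only, the wrapper function and its place in a pipeline are assumptions, not part of the patch:

#include "llvm/IR/LegacyPassManager.h"
#include "llvm/Transforms/Scalar/GVN.h"
using namespace llvm;

void addGVNWithRestrictedPRE(legacy::PassManager &PM) {
  // Keep load elimination enabled, but disable the PREs that would
  // introduce an artificial loop-carried dependence (the new knob).
  PM.add(createGVNPass(/*NoLoads=*/false,
                       /*DisablePREsWhichIntroduceArtificialLoopDep=*/true));
}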
Index: include/llvm/Transforms/Scalar/MemCpyOptimizer.h =================================================================== --- include/llvm/Transforms/Scalar/MemCpyOptimizer.h +++ include/llvm/Transforms/Scalar/MemCpyOptimizer.h @@ -71,6 +71,9 @@ Instruction *tryMergingIntoMemset(Instruction *I, Value *StartPtr, Value *ByteVal); + bool isFortranMZero(CallSite CS); + bool processFortranMZero(CallSite CS, BasicBlock::iterator &BBI); + bool iterateOnFunction(Function &F); }; Index: include/llvm/Transforms/Utils/Local.h =================================================================== --- include/llvm/Transforms/Utils/Local.h +++ include/llvm/Transforms/Utils/Local.h @@ -66,6 +66,7 @@ bool ConvertSwitchToLookupTable; bool NeedCanonicalLoop; bool SinkCommonInsts; + unsigned SwitchRemovalThreshold; AssumptionCache *AC; SimplifyCFGOptions(unsigned BonusThreshold = 1, @@ -78,6 +79,7 @@ ConvertSwitchToLookupTable(SwitchToLookup), NeedCanonicalLoop(CanonicalLoops), SinkCommonInsts(SinkCommon), + SwitchRemovalThreshold(0), AC(AssumpCache) {} // Support 'builder' pattern to set members by name at construction time. Index: include/llvm/Transforms/Utils/LoopUtils.h =================================================================== --- include/llvm/Transforms/Utils/LoopUtils.h +++ include/llvm/Transforms/Utils/LoopUtils.h @@ -74,7 +74,11 @@ RK_IntegerMinMax, ///< Min/max implemented in terms of select(cmp()). RK_FloatAdd, ///< Sum of floats. RK_FloatMult, ///< Product of floats. - RK_FloatMinMax ///< Min/max implemented in terms of select(cmp()). + RK_FloatMinMax, ///< Min/max implemented in terms of select(cmp()). + RK_ConstSelectICmp, ///< select(icmp(), X, Y) where one of (X,Y) is a + ///< constant integer. + RK_ConstSelectFCmp ///< select(fcmp(), X, Y) where one of (X,Y) is a + ///< constant integer. }; // This enum represents the kind of minmax recurrence. @@ -88,14 +92,24 @@ MRK_FloatMax }; + typedef SmallVector ExitInstrList; + RecurrenceDescriptor() = default; - RecurrenceDescriptor(Value *Start, Instruction *Exit, RecurrenceKind K, + // TODO: Is there a nice way to initialize this without the manual loop? + // Maybe, does the insert thing work with SmallVectorImpl? (CastInsts + // was added with the recent merge) + RecurrenceDescriptor(Value *Start, SmallVectorImpl &Exits, + StoreInst* Store, RecurrenceKind K, MinMaxRecurrenceKind MK, Instruction *UAI, Type *RT, - bool Signed, SmallPtrSetImpl &CI) - : StartValue(Start), LoopExitInstr(Exit), Kind(K), MinMaxKind(MK), - UnsafeAlgebraInst(UAI), RecurrenceType(RT), IsSigned(Signed) { + bool Signed, SmallPtrSetImpl &CI, + bool IsOrdered) + : IntermediateStore(Store), StartValue(Start), Kind(K), MinMaxKind(MK), + UnsafeAlgebraInst(UAI), RecurrenceType(RT), IsSigned(Signed), + IsOrdered(IsOrdered) { CastInsts.insert(CI.begin(), CI.end()); + for (Instruction *Exit : Exits) + LoopExitInstrs.push_back(Exit); } /// This POD struct holds information about a potential recurrence operation. @@ -151,6 +165,12 @@ /// or max(X, Y). static InstDesc isMinMaxSelectCmpPattern(Instruction *I, InstDesc &Prev); + /// Returns a struct describing whether the instruction is either a + /// Select(ICmp(A, B), X, Y), or + /// Select(FCmp(A, B), X, Y) + /// where one of (X, Y) is a constant integer + static InstDesc isConstSelectCmpPattern(Instruction *I, InstDesc &Prev); + /// Returns identity corresponding to the RecurrenceKind. static Constant *getRecurrenceIdentity(RecurrenceKind K, Type *Tp); @@ -168,21 +188,27 @@ /// computed. 
static bool AddReductionVar(PHINode *Phi, RecurrenceKind Kind, Loop *TheLoop, bool HasFunNoNaNAttr, + ScalarEvolution *SE, RecurrenceDescriptor &RedDes, + bool AllowMultipleExits, DemandedBits *DB = nullptr, AssumptionCache *AC = nullptr, DominatorTree *DT = nullptr); - /// Returns true if Phi is a reduction in TheLoop. The RecurrenceDescriptor - /// is returned in RedDes. If either \p DB is non-null or \p AC and \p DT are - /// non-null, the minimal bit width needed to compute the reduction will be - /// computed. + /// Returns true if Phi is a reduction in TheLoop. The RecurrenceDescriptor is + /// returned in RedDes. static bool isReductionPHI(PHINode *Phi, Loop *TheLoop, + ScalarEvolution *SE, RecurrenceDescriptor &RedDes, + bool AllowMultipleExits = false, DemandedBits *DB = nullptr, AssumptionCache *AC = nullptr, DominatorTree *DT = nullptr); + RecurrenceKind getRecurrenceKind() const { return Kind; } + + MinMaxRecurrenceKind getMinMaxRecurrenceKind() const { return MinMaxKind; } + /// Returns true if Phi is a first-order recurrence. A first-order recurrence /// is a non-reduction recurrence relation in which the value of the /// recurrence in the current loop iteration equals a value defined in the @@ -195,13 +221,36 @@ DenseMap &SinkAfter, DominatorTree *DT); - RecurrenceKind getRecurrenceKind() { return Kind; } + TrackingVH getRecurrenceStartValue() const { return StartValue; } - MinMaxRecurrenceKind getMinMaxRecurrenceKind() { return MinMaxKind; } + // For the 'normal' loop vectorizer + Instruction *getLoopExitInstr() const { +#ifndef NDEBUG + if (LoopExitInstrs.size() != 1) { + auto NumExits = LoopExitInstrs.size(); + errs() << "getLoopExitInstr: NumExits: " << NumExits << "\n"; + for (auto *E : LoopExitInstrs) + E->dump(); + // Previously asserted here; keeping the print above but allowing + // for progression. + // + // When using debug to examine all problems with loops from + // the vectorizer, we just return a null here to indicate that + // there isn't a *single* instruction here, in much the same + // way the LoopInfo functions to get exiting/exit blocks work. + return nullptr; + } +#endif + return LoopExitInstrs.back(); + } - TrackingVH getRecurrenceStartValue() { return StartValue; } + ExitInstrList *getLoopExitInstrs() { return &LoopExitInstrs; } - Instruction *getLoopExitInstr() { return LoopExitInstr; } + // TODO: Something more than Instruction ptrs ? + // Would make for a nicer iterator than Inst** + iterator_range exitInstrs() { + return make_range(LoopExitInstrs.begin(), LoopExitInstrs.end()); + } /// Returns true if the recurrence has unsafe algebra which requires a relaxed /// floating-point model. @@ -221,7 +270,7 @@ /// Returns the type of the recurrence. This type can be narrower than the /// actual type of the Phi if the recurrence has been type-promoted. - Type *getRecurrenceType() { return RecurrenceType; } + Type *getRecurrenceType() const { return RecurrenceType; } /// Returns a reference to the instructions used for type-promoting the /// recurrence. @@ -230,12 +279,20 @@ /// Returns true if all source operands of the recurrence are SExtInsts. bool isSigned() { return IsSigned; } + /// The list of intermediate stores of reductions + StoreInst * IntermediateStore = nullptr; + + /// Expose an ordered FP reduction to the instance users. + bool isOrdered() const { return IsOrdered; } + private: // The starting value of the recurrence. // It does not have to be zero! TrackingVH StartValue; // The instruction who's value is used outside the loop. 
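// Illustrative consumer (not from the patch) of the multi-exit representation
// above: callers that used to assume a single loop-exit instruction can walk
// the exitInstrs() range instead. RedDes is assumed to be a populated
// RecurrenceDescriptor.
#include "llvm/Transforms/Utils/LoopUtils.h"

static unsigned countRecurrenceExits(llvm::RecurrenceDescriptor &RedDes) {
  unsigned NumExits = 0;
  for (llvm::Instruction *Exit : RedDes.exitInstrs()) {
    (void)Exit; // each Exit is a recurrence value live outside the loop
    ++NumExits;
  }
  return NumExits;
}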
- Instruction *LoopExitInstr = nullptr; + // Instruction *LoopExitInstr; + // TODO: Do we need to match the exit block to the instruction? + ExitInstrList LoopExitInstrs; // The kind of the recurrence. RecurrenceKind Kind = RK_NoRecurrence; // If this a min/max recurrence the kind of recurrence. @@ -248,6 +305,8 @@ bool IsSigned = false; // Instructions used for type-promoting the recurrence. SmallPtrSet CastInsts; + // If this is an ordered reduction + bool IsOrdered = false; }; /// A struct for saving information about induction variables. @@ -541,6 +600,12 @@ RecurrenceDescriptor &Desc, Value *Src, bool NoNaN = false); + +/// Create an ordered reduction intrinsic using the given recurrence +/// descriptor\p Desc. +Value *createOrderedReduction(IRBuilder<> &Builder, RecurrenceDescriptor &Desc, + Value *Src, Value *Start, Value *Predicate); + /// Get the intersection (logical and) of all of the potential IR flags /// of each scalar operation (VL) that will be converted into a vector (I). /// If OpValue is non-null, we only consider operations similar to OpValue @@ -548,6 +613,8 @@ /// Flag set: NSW, NUW, exact, and all of fast-math. void propagateIRFlags(Value *I, ArrayRef VL, Value *OpValue = nullptr); + +bool storeToSameAddress(ScalarEvolution *SE, StoreInst *A, StoreInst *B); } // end namespace llvm #endif // LLVM_TRANSFORMS_UTILS_LOOPUTILS_H Index: include/llvm/Transforms/Utils/SimplifyLibCalls.h =================================================================== --- include/llvm/Transforms/Utils/SimplifyLibCalls.h +++ include/llvm/Transforms/Utils/SimplifyLibCalls.h @@ -133,6 +133,7 @@ Value *optimizeCAbs(CallInst *CI, IRBuilder<> &B); Value *optimizeCos(CallInst *CI, IRBuilder<> &B); Value *optimizePow(CallInst *CI, IRBuilder<> &B); + Value *replacePowWithExp(CallInst *Pow, IRBuilder<> &B); Value *replacePowWithSqrt(CallInst *Pow, IRBuilder<> &B); Value *optimizeExp2(CallInst *CI, IRBuilder<> &B); Value *optimizeFMinFMax(CallInst *CI, IRBuilder<> &B); Index: include/llvm/Transforms/Vectorize.h =================================================================== --- include/llvm/Transforms/Vectorize.h +++ include/llvm/Transforms/Vectorize.h @@ -115,11 +115,31 @@ //===----------------------------------------------------------------------===// // +// LoopVectorize - Create an SVE loop vectorization pass. +// +Pass *createSVELoopVectorizePass(bool NoUnrolling = false, + bool AlwaysVectorize = true); + +//===----------------------------------------------------------------------===// +// +// SearchLoopVectorize - Create a search loop vectorization pass. +// +Pass *createSearchLoopVectorizePass(bool NoUnrolling = false, + bool AlwaysVectorize = true); + +//===----------------------------------------------------------------------===// +// // SLPVectorizer - Create a bottom-up SLP vectorizer pass. // Pass *createSLPVectorizerPass(); //===----------------------------------------------------------------------===// +// +// BOSCC - Create a BOSCC post vectorizer pass. +// +Pass *createBOSCCPass(); + +//===----------------------------------------------------------------------===// /// Vectorize the BasicBlock. 
/// /// @param BB The BasicBlock to be vectorized Index: lib/Analysis/AliasAnalysis.cpp =================================================================== --- lib/Analysis/AliasAnalysis.cpp +++ lib/Analysis/AliasAnalysis.cpp @@ -184,7 +184,7 @@ if (doesAccessArgPointees(MRB)) { for (auto AI = CS.arg_begin(), AE = CS.arg_end(); AI != AE; ++AI) { const Value *Arg = *AI; - if (!Arg->getType()->isPointerTy()) + if (!Arg->getType()->isPtrOrPtrVectorTy()) continue; unsigned ArgIdx = std::distance(CS.arg_begin(), AI); MemoryLocation ArgLoc = MemoryLocation::getForArgument(CS, ArgIdx, TLI); @@ -260,7 +260,7 @@ bool IsMustAlias = true; for (auto I = CS2.arg_begin(), E = CS2.arg_end(); I != E; ++I) { const Value *Arg = *I; - if (!Arg->getType()->isPointerTy()) + if (!Arg->getType()->isPtrOrPtrVectorTy()) continue; unsigned CS2ArgIdx = std::distance(CS2.arg_begin(), I); auto CS2ArgLoc = MemoryLocation::getForArgument(CS2, CS2ArgIdx, TLI); @@ -309,7 +309,7 @@ bool IsMustAlias = true; for (auto I = CS1.arg_begin(), E = CS1.arg_end(); I != E; ++I) { const Value *Arg = *I; - if (!Arg->getType()->isPointerTy()) + if (!Arg->getType()->isPtrOrPtrVectorTy()) continue; unsigned CS1ArgIdx = std::distance(CS1.arg_begin(), I); auto CS1ArgLoc = MemoryLocation::getForArgument(CS1, CS1ArgIdx, TLI); @@ -579,7 +579,7 @@ // Only look at the no-capture or byval pointer arguments. If this // pointer were passed to arguments that were neither of these, then it // couldn't be no-capture. - if (!(*CI)->getType()->isPointerTy() || + if (!(*CI)->getType()->isPtrOrPtrVectorTy() || (!CS.doesNotCapture(ArgNo) && ArgNo < CS.getNumArgOperands() && !CS.isByValArgument(ArgNo))) continue; Index: lib/Analysis/AliasSetTracker.cpp =================================================================== --- lib/Analysis/AliasSetTracker.cpp +++ lib/Analysis/AliasSetTracker.cpp @@ -459,6 +459,36 @@ AS->addUnknownInst(Inst, AA); } +static bool isVectorMemIntrinsic(Instruction *I, bool &IsWrite) { + if (const IntrinsicInst *II = dyn_cast(I)) { + switch (II->getIntrinsicID()) { + case Intrinsic::masked_load: + case Intrinsic::masked_spec_load: + case Intrinsic::masked_gather: + IsWrite = false; + return true; + case Intrinsic::masked_store: + case Intrinsic::masked_scatter: + IsWrite = true; + return true; + default: + return false; + } + } + return false; +} + +void AliasSetTracker::add(IntrinsicInst *I, bool IsWrite) { + AAMDNodes AAInfo; + I->getAAMetadata(AAInfo); + Value *Ptr = I->getArgOperand(IsWrite ? 1 : 0); + AliasSet::AccessLattice Access = + IsWrite ? AliasSet::ModAccess : AliasSet::RefAccess; + // TODO: For fixed-width unmasked contiguous accesses we can find + // the access size, for now just assume unknown. + addPointer(Ptr, MemoryLocation::UnknownSize, AAInfo, Access); +} + void AliasSetTracker::add(Instruction *I) { // Dispatch to one of the other add methods. if (LoadInst *LI = dyn_cast(I)) @@ -467,6 +497,11 @@ return add(SI); if (VAArgInst *VAAI = dyn_cast(I)) return add(VAAI); + // Handle vector masked intrinsics by examining their pointer arguments + // like ordinary load/stores. 
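// Small helper restating, outside the patch, which argument of a masked
// memory intrinsic carries the address; it mirrors the IsWrite ? 1 : 0
// indexing used by AliasSetTracker::add above. llvm.masked.load/gather take
// (ptr, align, mask, passthru); llvm.masked.store/scatter take
// (value, ptr, align, mask). The helper name is illustrative.
#include "llvm/IR/IntrinsicInst.h"

static llvm::Value *maskedIntrinsicPointer(llvm::IntrinsicInst *II) {
  switch (II->getIntrinsicID()) {
  case llvm::Intrinsic::masked_load:
  case llvm::Intrinsic::masked_gather:
    return II->getArgOperand(0);
  case llvm::Intrinsic::masked_store:
  case llvm::Intrinsic::masked_scatter:
    return II->getArgOperand(1);
  default:
    return nullptr;
  }
}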
+ bool IsWrite; + if (isVectorMemIntrinsic(I, IsWrite)) + return add(cast(I), IsWrite); if (AnyMemSetInst *MSI = dyn_cast(I)) return add(MSI); if (AnyMemTransferInst *MTI = dyn_cast(I)) Index: lib/Analysis/BasicAliasAnalysis.cpp =================================================================== --- lib/Analysis/BasicAliasAnalysis.cpp +++ lib/Analysis/BasicAliasAnalysis.cpp @@ -475,6 +475,14 @@ return false; } + // Don't attempt to analyze GEPs with gather/scatter semantics + // TODO: For fixed-width with constant offsets, or for a constant-like + // seriesvector, we may be able to do better than this. + if (GEPOp->getType()->isVectorTy()) { + Decomposed.Base = V; + return false; + } + unsigned AS = GEPOp->getPointerAddressSpace(); // Walk the indices of the GEP, accumulating them into BaseOff/VarIndices. gep_type_iterator GTI = gep_type_begin(GEPOp); @@ -829,9 +837,10 @@ // Only look at the no-capture or byval pointer arguments. If this // pointer were passed to arguments that were neither of these, then it // couldn't be no-capture. - if (!(*CI)->getType()->isPointerTy() || + if (!(*CI)->getType()->isPtrOrPtrVectorTy() || (!CS.doesNotCapture(OperandNo) && - OperandNo < CS.getNumArgOperands() && !CS.isByValArgument(OperandNo))) + OperandNo < CS.getNumArgOperands() && + !CS.isByValArgument(OperandNo))) continue; // Call doesn't access memory through this operand, so we don't care @@ -1227,9 +1236,14 @@ int64_t GEP1BaseOffset = DecompGEP1.StructOffset + DecompGEP1.OtherOffset; int64_t GEP2BaseOffset = DecompGEP2.StructOffset + DecompGEP2.OtherOffset; - assert(DecompGEP1.Base == UnderlyingV1 && DecompGEP2.Base == UnderlyingV2 && - "DecomposeGEPExpression returned a result different from " - "GetUnderlyingObject"); + // DecomposeGEPExpression and GetUnderlyingObject should return the + // same result except when DecomposeGEPExpression has no DataLayout. + // FIXME: They always have a DataLayout, so this should become an + // assert. + // Another possibility is that a vector GEP prevents decomposition. + if (DecompGEP1.Base != UnderlyingV1 || DecompGEP2.Base != UnderlyingV2) { + return MayAlias; + } // If the GEP's offset relative to its base is such that the base would // fall below the start of the object underlying V2, then the GEP and V2 @@ -1653,7 +1667,8 @@ if (isValueEqualInPotentialCycles(V1, V2)) return MustAlias; - if (!V1->getType()->isPointerTy() || !V2->getType()->isPointerTy()) + if (!V1->getType()->isPtrOrPtrVectorTy() || + !V2->getType()->isPtrOrPtrVectorTy()) return NoAlias; // Scalars cannot alias each other // Figure out what objects these things are pointing to if we can. Index: lib/Analysis/CallGraphSCCPass.cpp =================================================================== --- lib/Analysis/CallGraphSCCPass.cpp +++ lib/Analysis/CallGraphSCCPass.cpp @@ -643,7 +643,6 @@ StringRef getPassName() const override { return "Print CallGraph IR"; } }; - } // end anonymous namespace. 
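// Type-level illustration (using an ordinary fixed-width vector as the
// assumed example) of why the argument walks above were relaxed to
// isPtrOrPtrVectorTy(): a gather/scatter argument is a vector of pointers,
// which isPointerTy() rejects even though it still addresses memory.
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/LLVMContext.h"

static bool pointerVectorIsNotAPointer(llvm::LLVMContext &Ctx) {
  llvm::Type *I32Ptr = llvm::Type::getInt32PtrTy(Ctx);
  llvm::Type *PtrVec = llvm::VectorType::get(I32Ptr, 4); // <4 x i32*>
  return !PtrVec->isPointerTy() && PtrVec->isPtrOrPtrVectorTy(); // true
}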
char PrintCallGraphPass::ID = 0; Index: lib/Analysis/CaptureTracking.cpp =================================================================== --- lib/Analysis/CaptureTracking.cpp +++ lib/Analysis/CaptureTracking.cpp @@ -28,6 +28,7 @@ #include "llvm/IR/Dominators.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Intrinsics.h" using namespace llvm; @@ -213,7 +214,7 @@ static int const Threshold = 20; void llvm::PointerMayBeCaptured(const Value *V, CaptureTracker *Tracker) { - assert(V->getType()->isPointerTy() && "Capture is for pointers only!"); + assert(V->getType()->isPtrOrPtrVectorTy() && "Capture is for pointers only!"); SmallVector Worklist; SmallSet Visited; @@ -264,6 +265,31 @@ if (MI->isVolatile()) if (Tracker->captured(U)) return; + if (const auto *F = CS.getCalledFunction()) { + if (const auto IID = F->getIntrinsicID()) { + // Permit masked loads so long as the use isn't the merge value. + if (IID == Intrinsic::masked_load || + IID == Intrinsic::masked_spec_load || + IID == Intrinsic::masked_gather) { + // Initial conservative check that the merge value is undef. + // TODO: Also allow all-zeroes... + if (!isa(I->getOperand(3))) + if (Tracker->captured(U)) + return; + break; + } + + // Disallow masked stores if the data is the value (not really + // possible at this point; splats get discarded by way of hitting + // the default case for vector inserts and shuffles. + if (IID == Intrinsic::masked_store || + IID == Intrinsic::masked_scatter) { + if (V == I->getOperand(0)) + if (Tracker->captured(U)) + return; + } + } + } // Not captured if only passed via 'nocapture' arguments. Note that // calling a function pointer does not in itself cause the pointer to Index: lib/Analysis/ConstantFolding.cpp =================================================================== --- lib/Analysis/ConstantFolding.cpp +++ lib/Analysis/ConstantFolding.cpp @@ -503,7 +503,7 @@ MapTy = Type::getInt32Ty(C->getContext()); else if (LoadTy->isDoubleTy()) MapTy = Type::getInt64Ty(C->getContext()); - else if (LoadTy->isVectorTy()) { + else if (LoadTy->isVectorTy() && !LoadTy->getVectorIsScalable()) { MapTy = PointerType::getIntNTy(C->getContext(), DL.getTypeAllocSizeInBits(LoadTy)); } else @@ -819,9 +819,8 @@ // "inttoptr (sub (ptrtoint Ptr), V)" if (Ops.size() == 2 && ResElemTy->isIntegerTy(8)) { auto *CE = dyn_cast(Ops[1]); - assert((!CE || CE->getType() == IntPtrTy) && - "CastGEPIndices didn't canonicalize index types!"); if (CE && CE->getOpcode() == Instruction::Sub && + CE->getType() == IntPtrTy && CE->getOperand(0)->isNullValue()) { Constant *Res = ConstantExpr::getPtrToInt(Ptr, CE->getType()); Res = ConstantExpr::getSub(Res, CE->getOperand(1)); @@ -997,8 +996,15 @@ Type *DestTy = InstOrCE->getType(); // Handle easy binops first. 
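// Restatement (illustrative only, not patch code) of the conservative rule
// the CaptureTracking hunk above applies to llvm.masked.load: the address
// operand itself does not capture, but if the merge/passthru value
// (operand 3) is anything other than undef, the pointer is still treated as
// captured.
#include "llvm/IR/Constants.h"
#include "llvm/IR/IntrinsicInst.h"

static bool maskedLoadTreatedAsCapturing(const llvm::IntrinsicInst *II) {
  if (II->getIntrinsicID() != llvm::Intrinsic::masked_load)
    return true; // only the masked-load family gets the relaxed treatment
  // Operand layout: (ptr, align, mask, passthru).
  return !llvm::isa<llvm::UndefValue>(II->getArgOperand(3));
}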
- if (Instruction::isBinaryOp(Opcode)) - return ConstantFoldBinaryOpOperands(Opcode, Ops[0], Ops[1], DL); + if (Instruction::isBinaryOp(Opcode)) { + bool NUW = false, NSW = false; + if (auto OBO = dyn_cast(InstOrCE)) { + NUW = OBO->hasNoUnsignedWrap(); + NSW = OBO->hasNoSignedWrap(); + } + + return ConstantFoldBinaryOpOperands(Opcode, Ops[0], Ops[1], DL, NUW, NSW); + } if (Instruction::isCast(Opcode)) return ConstantFoldCastOperand(Opcode, Ops[0], DestTy, DL); @@ -1261,13 +1267,16 @@ Constant *llvm::ConstantFoldBinaryOpOperands(unsigned Opcode, Constant *LHS, Constant *RHS, - const DataLayout &DL) { + const DataLayout &DL, + bool HasNUW, bool HasNSW) { assert(Instruction::isBinaryOp(Opcode)); if (isa(LHS) || isa(RHS)) if (Constant *C = SymbolicallyEvaluateBinop(Opcode, LHS, RHS, DL)) return C; - return ConstantExpr::get(Opcode, LHS, RHS); + unsigned Flags = (HasNUW ? OverflowingBinaryOperator::NoUnsignedWrap : 0) | + (HasNSW ? OverflowingBinaryOperator::NoSignedWrap : 0); + return ConstantExpr::get(Opcode, LHS, RHS, Flags); } Constant *llvm::ConstantFoldCastOperand(unsigned Opcode, Constant *C, @@ -2022,6 +2031,10 @@ SmallVector Lane(Operands.size()); Type *Ty = VTy->getElementType(); + // This function currently only supports non-scalable vectors + if (VTy->isScalable()) + return nullptr; + if (IntrinsicID == Intrinsic::masked_load) { auto *SrcPtr = Operands[0]; auto *Mask = Operands[2]; Index: lib/Analysis/InlineCost.cpp =================================================================== --- lib/Analysis/InlineCost.cpp +++ lib/Analysis/InlineCost.cpp @@ -406,21 +406,24 @@ /// /// Respects any simplified values known during the analysis of this callsite. bool CallAnalyzer::isGEPFree(GetElementPtrInst &GEP) { - SmallVector Operands; - Operands.push_back(GEP.getOperand(0)); + SmallVector Indices; for (User::op_iterator I = GEP.idx_begin(), E = GEP.idx_end(); I != E; ++I) if (Constant *SimpleOp = SimplifiedValues.lookup(*I)) - Operands.push_back(SimpleOp); + Indices.push_back(SimpleOp); else - Operands.push_back(*I); - return TargetTransformInfo::TCC_Free == TTI.getUserCost(&GEP, Operands); + Indices.push_back(*I); + return TargetTransformInfo::TCC_Free == + TTI.getGEPCost(GEP.getSourceElementType(), GEP.getPointerOperand(), + Indices); } bool CallAnalyzer::visitAlloca(AllocaInst &I) { // Check whether inlining will turn a dynamic alloca into a static // alloca and handle that case. 
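// Hedged usage sketch for the flag-aware folding entry point defined above:
// a caller that knows the original instruction's wrap flags can ask for them
// to be reattached to the folded constant expression. The wrapper name and
// the particular flag choice are illustrative.
#include "llvm/Analysis/ConstantFolding.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Instruction.h"

static llvm::Constant *foldAddKeepingNSW(llvm::Constant *LHS,
                                         llvm::Constant *RHS,
                                         const llvm::DataLayout &DL) {
  return llvm::ConstantFoldBinaryOpOperands(llvm::Instruction::Add, LHS, RHS,
                                            DL, /*HasNUW=*/false,
                                            /*HasNSW=*/true);
}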
if (I.isArrayAllocation()) { - Constant *Size = SimplifiedValues.lookup(I.getArraySize()); + auto *Size = I.getArraySize(); + if (!isa(Size)) + Size = SimplifiedValues.lookup(Size); if (auto *AllocSize = dyn_cast_or_null(Size)) { Type *Ty = I.getAllocatedType(); AllocatedSize = SaturatingMultiplyAdd( Index: lib/Analysis/InstructionSimplify.cpp =================================================================== --- lib/Analysis/InstructionSimplify.cpp +++ lib/Analysis/InstructionSimplify.cpp @@ -683,9 +683,8 @@ } while (Visited.insert(V).second); Constant *OffsetIntPtr = ConstantInt::get(IntPtrTy, Offset); - if (V->getType()->isVectorTy()) - return ConstantVector::getSplat(V->getType()->getVectorNumElements(), - OffsetIntPtr); + if (auto *VTy = dyn_cast(V->getType())) + return ConstantVector::getSplat(VTy->getElementCount(), OffsetIntPtr); return OffsetIntPtr; } @@ -1789,6 +1788,10 @@ if (match(Op1, m_AllOnes())) return Op0; + // TODO: m_AllOnes needs to support scalable vectors + if (match(Op1, m_SplatVector(m_AllOnes()))) + return Op0; + // A & ~A = ~A & A = 0 if (match(Op0, m_Not(m_Specific(Op1))) || match(Op1, m_Not(m_Specific(Op0)))) @@ -1922,6 +1925,14 @@ if (Op0 == Op1 || match(Op1, m_Zero())) return Op0; + // X | -1 = -1 + if (match(Op1, m_AllOnes())) + return Op1; + + // TODO: m_AllOnes needs to support scalable vectors + if (match(Op1, m_SplatVector(m_AllOnes()))) + return Op1; + // A | ~A = ~A | A = -1 if (match(Op0, m_Not(m_Specific(Op1))) || match(Op1, m_Not(m_Specific(Op0)))) @@ -3830,6 +3841,10 @@ // select false, X, Y -> Y if (CondC->isNullValue()) return FalseVal; + + // TODO: m_AllOnes needs to support scalable vectors + if (match(CondC, m_SplatVector(m_AllOnes()))) + return TrueVal; } // select ?, X, X -> X @@ -3872,9 +3887,9 @@ Type *LastType = GetElementPtrInst::getIndexedType(SrcTy, Ops.slice(1)); Type *GEPTy = PointerType::get(LastType, AS); if (VectorType *VT = dyn_cast(Ops[0]->getType())) - GEPTy = VectorType::get(GEPTy, VT->getNumElements()); + GEPTy = VectorType::get(GEPTy, VT->getElementCount()); else if (VectorType *VT = dyn_cast(Ops[1]->getType())) - GEPTy = VectorType::get(GEPTy, VT->getNumElements()); + GEPTy = VectorType::get(GEPTy, VT->getElementCount()); if (isa(Ops[0])) return UndefValue::get(GEPTy); @@ -4022,7 +4037,7 @@ // Fold into undef if index is out of bounds. if (auto *CI = dyn_cast(Idx)) { uint64_t NumElements = cast(Vec->getType())->getNumElements(); - if (CI->uge(NumElements)) + if (!cast(Vec->getType())->isScalable() && CI->uge(NumElements)) return UndefValue::get(Vec->getType()); } @@ -4082,7 +4097,8 @@ // If extracting a specified index from the vector, see if we can recursively // find a previously computed scalar that was inserted into the vector. if (auto *IdxC = dyn_cast(Idx)) { - if (IdxC->getValue().uge(Vec->getType()->getVectorNumElements())) + if (!cast(Vec->getType())->isScalable() && + IdxC->getValue().uge(Vec->getType()->getVectorNumElements())) // definitely out of bounds, thus undefined result return UndefValue::get(Vec->getType()->getVectorElementType()); if (Value *Elt = findScalarElement(Vec, IdxC->getZExtValue())) @@ -4180,6 +4196,9 @@ static Value *foldIdentityShuffles(int DestElt, Value *Op0, Value *Op1, int MaskVal, Value *RootVec, unsigned MaxRecurse) { + if (Op0->getType()->getVectorIsScalable()) + return nullptr; + if (!MaxRecurse--) return nullptr; @@ -4200,9 +4219,11 @@ // If the source operand is a shuffle itself, look through it to find the // matching root vector. 
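// Plain-arithmetic illustration (not patch code) of why the constant-index
// bound checks above are skipped for scalable vectors: the static element
// count of <vscale x N x T> is only a minimum, the true count is VScale * N,
// so an index >= N is not necessarily out of bounds.
#include <cstdint>

static bool indexDefinitelyOutOfBounds(uint64_t Idx, uint64_t MinNumElts,
                                       bool IsScalable) {
  return !IsScalable && Idx >= MinNumElts;
}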
if (auto *SourceShuf = dyn_cast(SourceOp)) { - return foldIdentityShuffles( - DestElt, SourceShuf->getOperand(0), SourceShuf->getOperand(1), - SourceShuf->getMaskValue(RootElt), RootVec, MaxRecurse); + int Res; + SourceShuf->getMaskValue(RootElt, Res); + return foldIdentityShuffles(DestElt, SourceShuf->getOperand(0), + SourceShuf->getOperand(1), Res, RootVec, + MaxRecurse); } // TODO: Look through bitcasts? What if the bitcast changes the vector element @@ -4225,7 +4246,7 @@ return RootVec; } -static Value *SimplifyShuffleVectorInst(Value *Op0, Value *Op1, Constant *Mask, +static Value *SimplifyShuffleVectorInst(Value *Op0, Value *Op1, Value *Mask, Type *RetTy, const SimplifyQuery &Q, unsigned MaxRecurse) { if (isa(Mask)) @@ -4235,6 +4256,9 @@ unsigned MaskNumElts = Mask->getType()->getVectorNumElements(); unsigned InVecNumElts = InVecTy->getVectorNumElements(); + if (Mask->getType()->getVectorIsScalable()) + return nullptr; + SmallVector Indices; ShuffleVectorInst::getShuffleMask(Mask, Indices); assert(MaskNumElts == Indices.size() && @@ -4260,8 +4284,9 @@ auto *Op1Const = dyn_cast(Op1); // If all operands are constant, constant fold the shuffle. - if (Op0Const && Op1Const) - return ConstantFoldShuffleVectorInstruction(Op0Const, Op1Const, Mask); + if (Op0Const && Op1Const && isa(Mask)) + return ConstantFoldShuffleVectorInstruction(Op0Const, Op1Const, + cast(Mask)); // Canonicalization: if only one input vector is constant, it shall be the // second one. @@ -4273,8 +4298,8 @@ // A shuffle of a splat is always the splat itself. Legal if the shuffle's // value type is same as the input vectors' type. if (auto *OpShuf = dyn_cast(Op0)) - if (isa(Op1) && RetTy == InVecTy && - OpShuf->getMask()->getSplatValue()) + if (!InVecTy->getVectorIsScalable() && isa(Op1) && + RetTy == InVecTy && cast(OpShuf->getMask())->getSplatValue()) return Op0; // Don't fold a shuffle with undef mask elements. This may get folded in a @@ -4302,7 +4327,7 @@ } /// Given operands for a ShuffleVectorInst, fold the result or return null. -Value *llvm::SimplifyShuffleVectorInst(Value *Op0, Value *Op1, Constant *Mask, +Value *llvm::SimplifyShuffleVectorInst(Value *Op0, Value *Op1, Value *Mask, Type *RetTy, const SimplifyQuery &Q) { return ::SimplifyShuffleVectorInst(Op0, Op1, Mask, RetTy, Q, RecursionLimit); } @@ -5031,6 +5056,10 @@ } case Instruction::ShuffleVector: { auto *SVI = cast(I); + if (cast(SVI->getType())->isScalable()) { + Result = ConstantFoldInstruction(I, Q.DL, Q.TLI); + break; + } Result = SimplifyShuffleVectorInst(SVI->getOperand(0), SVI->getOperand(1), SVI->getMask(), SVI->getType(), Q); break; Index: lib/Analysis/Loads.cpp =================================================================== --- lib/Analysis/Loads.cpp +++ lib/Analysis/Loads.cpp @@ -18,12 +18,15 @@ #include "llvm/IR/GlobalAlias.h" #include "llvm/IR/GlobalVariable.h" #include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Intrinsics.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" #include "llvm/IR/Operator.h" +#include "llvm/IR/PatternMatch.h" #include "llvm/IR/Statepoint.h" using namespace llvm; +using namespace llvm::PatternMatch; static bool isAligned(const Value *Base, const APInt &Offset, unsigned Align, const DataLayout &DL) { @@ -339,6 +342,19 @@ unsigned MaxInstsToScan, AliasAnalysis *AA, bool *IsLoadCSE, unsigned *NumScanedInst) { + auto AllTrueMask = AccessTy->isVectorTy() + ? 
ConstantInt::getTrue( + VectorType::getBool(cast(AccessTy))) + : ConstantInt::getTrue(Ptr->getContext()); + return FindAvailablePtrMaskedLoadStore( + Ptr, AllTrueMask, UndefValue::get(AccessTy), AccessTy, AtLeastAtomic, ScanBB, + ScanFrom, MaxInstsToScan, AA, IsLoadCSE, NumScanedInst); +} + +Value *llvm::FindAvailablePtrMaskedLoadStore( + Value *Ptr, Value *Mask, Value *Passthru, Type *AccessTy, bool AtLeastAtomic, + BasicBlock *ScanBB, BasicBlock::iterator &ScanFrom, unsigned MaxInstsToScan, + AliasAnalysis *AA, bool *IsLoadCSE, unsigned *NumScanedInst) { if (MaxInstsToScan == 0) MaxInstsToScan = ~0U; @@ -367,54 +383,70 @@ return nullptr; --ScanFrom; + // If this is a load of Ptr, the loaded value is available. // (This is true even if the load is volatile or atomic, although // those cases are unlikely.) - if (LoadInst *LI = dyn_cast(Inst)) - if (AreEquivalentAddressValues( - LI->getPointerOperand()->stripPointerCasts(), StrippedPtr) && - CastInst::isBitOrNoopPointerCastable(LI->getType(), AccessTy, DL)) { + Value *InstPtr = nullptr, *InstPassthru = nullptr, *InstMask = nullptr; + if (match(Inst, m_AnyLoad(m_Value(InstPtr), m_Value(), m_Value(InstMask), + m_Value(InstPassthru)))) { + auto *ConstInstMask = dyn_cast_or_null(InstMask); + bool maskAllOnes = ConstInstMask && ConstInstMask->isAllOnesValue(); + + if (AreEquivalentAddressValues(InstPtr->stripPointerCasts(), + StrippedPtr) && + CastInst::isBitOrNoopPointerCastable(Inst->getType(), AccessTy, DL) && + ((maskAllOnes && isa(Passthru)) || + (InstMask == Mask && + (Passthru == InstPassthru || (isa(Passthru)))))) { // We can value forward from an atomic to a non-atomic, but not the // other way around. - if (LI->isAtomic() < AtLeastAtomic) + if (Inst->isAtomic() < AtLeastAtomic) return nullptr; if (IsLoadCSE) - *IsLoadCSE = true; - return LI; - } + *IsLoadCSE = true; - if (StoreInst *SI = dyn_cast(Inst)) { - Value *StorePtr = SI->getPointerOperand()->stripPointerCasts(); + return Inst; + } + } + + Value *InstVal = nullptr; + if (match(Inst, m_AnyStore(m_Value(InstVal), m_Value(InstPtr), m_Value(), + m_Value(InstMask)))) { // If this is a store through Ptr, the value is available! // (This is true even if the store is volatile or atomic, although // those cases are unlikely.) - if (AreEquivalentAddressValues(StorePtr, StrippedPtr) && - CastInst::isBitOrNoopPointerCastable(SI->getValueOperand()->getType(), - AccessTy, DL)) { + auto *ConstInstMask = dyn_cast_or_null(InstMask); + bool maskAllOnes = ConstInstMask && ConstInstMask->isAllOnesValue(); + if (AreEquivalentAddressValues(InstPtr->stripPointerCasts(), + StrippedPtr) && + CastInst::isBitOrNoopPointerCastable(InstVal->getType(), AccessTy, + DL) && + ((InstMask == Mask || maskAllOnes) && isa(Passthru))) { // We can value forward from an atomic to a non-atomic, but not the // other way around. - if (SI->isAtomic() < AtLeastAtomic) + if (Inst->isAtomic() < AtLeastAtomic) return nullptr; if (IsLoadCSE) *IsLoadCSE = false; - return SI->getOperand(0); + return Inst->getOperand(0); } // If both StrippedPtr and StorePtr reach all the way to an alloca or // global and they are different, ignore the store. This is a trivial form // of alias analysis that is important for reg2mem'd code. if ((isa(StrippedPtr) || isa(StrippedPtr)) && - (isa(StorePtr) || isa(StorePtr)) && - StrippedPtr != StorePtr) + (isa(InstPtr) || isa(InstPtr)) && + StrippedPtr != InstPtr) continue; // If we have alias analysis and it says the store won't modify the loaded // value, ignore the store. 
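// Hypothetical caller (not part of the patch) for the masked-aware scan added
// above, assuming the new FindAvailablePtrMaskedLoadStore declaration is
// exposed from Loads.h next to FindAvailableLoadedValue: starting at an
// llvm.masked.load, look backwards through its block for an already-available
// value with the same pointer, mask and passthru. The scan limit of 6 is an
// arbitrary choice for the sketch.
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/Loads.h"
#include "llvm/IR/IntrinsicInst.h"

static llvm::Value *findAvailableForMaskedLoad(llvm::IntrinsicInst *MaskedLoad,
                                               llvm::AliasAnalysis *AA) {
  llvm::BasicBlock *BB = MaskedLoad->getParent();
  llvm::BasicBlock::iterator ScanFrom = MaskedLoad->getIterator();
  return llvm::FindAvailablePtrMaskedLoadStore(
      /*Ptr=*/MaskedLoad->getArgOperand(0),
      /*Mask=*/MaskedLoad->getArgOperand(2),
      /*Passthru=*/MaskedLoad->getArgOperand(3),
      /*AccessTy=*/MaskedLoad->getType(),
      /*AtLeastAtomic=*/false, BB, ScanFrom,
      /*MaxInstsToScan=*/6, AA, /*IsLoadCSE=*/nullptr,
      /*NumScanedInst=*/nullptr);
}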
- if (AA && !isModSet(AA->getModRefInfo(SI, StrippedPtr, AccessSize))) + if (AA && !isModSet(AA->getModRefInfo(Inst, StrippedPtr, AccessSize))) continue; // Otherwise the store that may or may not alias the pointer, bail out. Index: lib/Analysis/LoopAccessAnalysis.cpp =================================================================== --- lib/Analysis/LoopAccessAnalysis.cpp +++ lib/Analysis/LoopAccessAnalysis.cpp @@ -34,6 +34,7 @@ #include "llvm/Analysis/ScalarEvolutionExpander.h" #include "llvm/Analysis/ScalarEvolutionExpressions.h" #include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/Analysis/VectorUtils.h" #include "llvm/IR/BasicBlock.h" @@ -131,6 +132,10 @@ cl::desc("Enable conflict detection in loop-access analysis"), cl::init(true)); +static cl::opt EnableUncountedLAA( + "enable-laa-uncounted-loops", cl::init(false), cl::Hidden, + cl::desc("Enable access analysis of loops without a defined trip count")); + bool VectorizerParams::isInterleaveForced() { return ::VectorizationInterleave.getNumOccurrences() > 0; } @@ -198,12 +203,33 @@ const SCEV *ScStart; const SCEV *ScEnd; - if (SE->isLoopInvariant(Sc, Lp)) - ScStart = ScEnd = Sc; - else { + if (SE->isLoopInvariant(Sc, Lp)) { + auto &DL = Lp->getHeader()->getModule()->getDataLayout(); + auto *Ty = Ptr->getType()->getPointerElementType(); + auto *Size = SE->getConstant(Sc->getType(), DL.getTypeAllocSize(Ty)); + + ScStart = Sc; + ScEnd = SE->getAddExpr(Sc, Size); + } else { const SCEVAddRecExpr *AR = dyn_cast(Sc); - assert(AR && "Invalid addrec expression"); const SCEV *Ex = PSE.getBackedgeTakenCount(); + assert(AR && "Invalid addrec expression"); + + // TODO: Ensure we either have only one counted exit, + // or that we sort based on minimum exit count. + SmallVector ExitingBlocks; + Lp->getExitingBlocks(ExitingBlocks); + + if (ExitingBlocks.size() > 1) { + Ex = nullptr; + for (auto *EB : ExitingBlocks) { + Ex = SE->getExitCount(Lp, EB); + if (Ex != SE->getCouldNotCompute()) + break; + } + assert(Ex != nullptr && Ex != SE->getCouldNotCompute() && + "Lost counted loop exit"); + } ScStart = AR->getStart(); ScEnd = AR->evaluateAtIteration(Ex, *SE); @@ -350,7 +376,11 @@ // checking pointer group for each pointer. This is also required // for correctness, because in this case we can have checking between // pointers to the same underlying object. - if (!UseDependencies) { + // + // If we have strided accesses we want to always group checks, + // otherwise we can have two strided pointers with constant offset + // being compared against each other for the same underlying object. + if (!Strided && !UseDependencies) { for (unsigned I = 0; I < Pointers.size(); ++I) CheckingGroups.push_back(CheckingPtrGroup(I, *this)); return; @@ -500,10 +530,11 @@ typedef PointerIntPair MemAccessInfo; typedef SmallVector MemAccessInfoList; - AccessAnalysis(const DataLayout &Dl, Loop *TheLoop, AliasAnalysis *AA, + AccessAnalysis(const DataLayout &Dl, const TargetTransformInfo *TTI, + Loop *TheLoop, AliasAnalysis *AA, LoopInfo *LI, MemoryDepChecker::DepCandidates &DA, PredicatedScalarEvolution &PSE) - : DL(Dl), TheLoop(TheLoop), AST(*AA), LI(LI), DepCands(DA), + : DL(Dl), TTI(TTI), TheLoop(TheLoop), AST(*AA), LI(LI), DepCands(DA), IsRTCheckAnalysisNeeded(false), PSE(PSE) {} /// Register a load and whether it is only read from. @@ -567,6 +598,11 @@ MemAccessInfoList &getDependenciesToCheck() { return CheckDeps; } + /// Set of uncomputable pointers. 
+ // + // Used when emitting no_vec_unknown_array_bounds insight. + SmallPtrSet UncomputablePtrs; + private: typedef SetVector PtrAccessSet; @@ -579,6 +615,8 @@ const DataLayout &DL; + const TargetTransformInfo *TTI; + /// The loop being checked. const Loop *TheLoop; @@ -592,6 +630,10 @@ //intrinsic property (such as TBAA metadata). AliasSetTracker AST; + /// Similar to above, but instead contains only the alias sets we care about + /// when doing runtime checks. + SmallVector ReducedAST; + LoopInfo *LI; /// Sets of potentially dependent accesses - members of one set share an @@ -620,6 +662,10 @@ static bool hasComputableBounds(PredicatedScalarEvolution &PSE, const ValueToValueMap &Strides, Value *Ptr, Loop *L, bool Assume) { + ScalarEvolution *SE = PSE.getSE(); + if (SE->isLoopInvariant(SE->getSCEV(Ptr), L)) + return true; + const SCEV *PtrScev = replaceSymbolicStrideSCEV(PSE, Strides, Ptr); // The bounds for loop-invariant pointer is trivial. @@ -637,9 +683,31 @@ return AR->isAffine(); } +/// \brief Check whether a loop has a defined trip count +static bool isCountedLoop(ScalarEvolution *SE, Loop *Lp) { + bool IsCounted = false; + // TODO: Cache exitingblocks. + // Alternatively, add a new method to SE. + SmallVector ExitingBlocks; + Lp->getExitingBlocks(ExitingBlocks); + // Lp->dump(); + for (auto *EB : ExitingBlocks) + if (SE->getExitCount(Lp, EB) != SE->getCouldNotCompute()) + IsCounted = true; + + + LLVM_DEBUG(dbgs() << "LAA: Loop has a counted exit? " << IsCounted << "\n"); + return IsCounted; +} + +// Forward reference - defined later in the file. +static bool isNoWrapAddRec(Value *Ptr, const SCEVAddRecExpr *AR, + PredicatedScalarEvolution &PSE, const Loop *L); + /// Check whether a pointer address cannot wrap. static bool isNoWrap(PredicatedScalarEvolution &PSE, - const ValueToValueMap &Strides, Value *Ptr, Loop *L) { + const ValueToValueMap &Strides, Value *Ptr, + const Loop *L) { const SCEV *PtrScev = PSE.getSCEV(Ptr); if (PSE.getSE()->isLoopInvariant(PtrScev, L)) return true; @@ -648,6 +716,13 @@ if (Stride == 1 || PSE.hasNoOverflow(Ptr, SCEVWrapPredicate::IncrementNUSW)) return true; + // This is a specific case of 'hasNoOverflow(Ptr)' but uses the more + // amenable 'isNoWrapAddRec' method that inspects the pointer for + // no wrapping when this cannot be determined from the SCEV alone. + if (auto PtrAddRec = dyn_cast(PtrScev)) + if (isNoWrapAddRec(Ptr, PtrAddRec, PSE, L)) + return true; + return false; } @@ -708,7 +783,7 @@ // We assign a consecutive id to access from different alias sets. // Accesses between different groups doesn't need to be checked. unsigned ASId = 1; - for (auto &AS : AST) { + for (auto &AS : ReducedAST) { int NumReadPtrChecks = 0; int NumWritePtrChecks = 0; bool CanDoAliasSetRT = true; @@ -720,7 +795,7 @@ SmallVector Retries; - for (auto A : AS) { + for (auto A : *AS) { Value *Ptr = A.getValue(); bool IsWrite = Accesses.count(MemAccessInfo(Ptr, true)); MemAccessInfo Access(Ptr, IsWrite); @@ -730,8 +805,53 @@ else ++NumReadPtrChecks; - if (!createCheckForAccess(RtCheck, Access, StridesMap, DepSetId, TheLoop, - RunningDepId, ASId, ShouldCheckWrap, false)) { + // TODO: Investigate commit a87b055656b2e8a8f28ca5fc1c272a250366053c + // Community code previously just checked for a unit stride, but + // has now bundled that with invariant checking in a new function. 
+ // It's possible this non-unit-stride check should also migrate to + // that new function, but it isn't really determining whether the + // pointer will wrap or not, just whether the stride is allowed by + // the backend. + // + // We do keep the check for the loop being counted, though. + bool SupportedStride = false; + if (ShouldCheckWrap) { + int Stride = getPtrStride(PSE, Ptr, TheLoop, StridesMap); + bool IsLoopInv = + PSE.getSE()->isLoopInvariant(PSE.getSE()->getSCEV(Ptr), TheLoop); + if (std::abs(Stride) > 1) { + SupportedStride = TTI->canVectorizeNonUnitStrides(); + RtCheck.Strided = true; + } else + SupportedStride = (IsLoopInv || Stride == 1); + } + + if (hasComputableBounds(PSE, StridesMap, Ptr, TheLoop, /*Assume=*/ false) && + // When we run after a failing dependency check we have to make sure + // we don't have wrapping pointers. + (!ShouldCheckWrap || isNoWrap(PSE, StridesMap, Ptr, TheLoop) || + SupportedStride) && + isCountedLoop(PSE.getSE(), TheLoop)) { + // The id of the dependence set. + unsigned DepId; + + if (IsDepCheckNeeded) { + Value *Leader = DepCands.getLeaderValue(Access).getPointer(); + unsigned &LeaderId = DepSetId[Leader]; + if (!LeaderId) + LeaderId = RunningDepId++; + DepId = LeaderId; + } else + // Each access has its own dependence set. + DepId = RunningDepId++; + + RtCheck.insert(TheLoop, Ptr, IsWrite, DepId, ASId, StridesMap, PSE); + + LLVM_DEBUG(dbgs() << "LAA: Found a runtime check ptr:" << *Ptr << '\n'); + } else if (!createCheckForAccess(RtCheck, Access, StridesMap, DepSetId, + TheLoop, RunningDepId, ASId, + ShouldCheckWrap, false)) { + UncomputablePtrs.insert(Ptr); LLVM_DEBUG(dbgs() << "LAA: Can't find bounds for ptr:" << *Ptr << '\n'); Retries.push_back(Access); CanDoAliasSetRT = false; @@ -767,11 +887,13 @@ } } - CanDoRT &= CanDoAliasSetRT; + // If we don't actually need RT checks for this alias set, then this + // shouldn't hold up vectorisation. + if (NeedsAliasSetRTCheck) + CanDoRT &= CanDoAliasSetRT; NeedRTCheck |= NeedsAliasSetRTCheck; ++ASId; } - // If the pointers that we would use for the bounds comparison have different // address spaces, assume the values aren't directly comparable, so we can't // use them for the runtime check. We also have to assume they could @@ -830,6 +952,7 @@ (A.getInt() ? "write" : (ReadOnlyPtr.count(A.getPointer()) ? "read-only" : "read")) << ")\n"; }); + ReducedAST.clear(); // The AliasSetTracker has nicely partitioned our pointers by metadata // compatibility and potential for underlying-object overlap. As a result, we @@ -841,6 +964,7 @@ // (matching the original instruction order within each set). bool SetHasWrite = false; + bool AnalyzeAS = false; // Map of pointers to last access encountered. 
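// Condensed restatement (illustrative only) of the gate used above when
// deciding whether a pointer can receive a runtime bounds check; the
// TargetHandlesNonUnitStrides flag stands in for the
// canVectorizeNonUnitStrides() query, which is a TTI hook added by this
// patch rather than an upstream API.
#include <cstdlib>

static bool strideSupportedForRTChecks(int Stride, bool IsLoopInvariant,
                                       bool TargetHandlesNonUnitStrides) {
  if (std::abs(Stride) > 1)
    return TargetHandlesNonUnitStrides;
  return IsLoopInvariant || Stride == 1;
}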
typedef DenseMap UnderlyingObjToAccessMap; @@ -897,6 +1021,7 @@ if ((IsWrite || IsReadOnlyPtr) && SetHasWrite) { CheckDeps.push_back(Access); IsRTCheckAnalysisNeeded = true; + AnalyzeAS = true; } if (IsWrite) @@ -930,6 +1055,8 @@ } } } + if (AnalyzeAS) + ReducedAST.push_back(&AS); } } @@ -1221,7 +1348,8 @@ return X == PtrSCEVB; } -bool MemoryDepChecker::Dependence::isSafeForVectorization(DepType Type) { +bool MemoryDepChecker::Dependence::isSafeForVectorization(DepType Type, + const TargetTransformInfo *TTI) { switch (Type) { case NoDep: case Forward: @@ -1229,10 +1357,11 @@ return true; case Unknown: - case ForwardButPreventsForwarding: case Backward: - case BackwardVectorizableButPreventsForwarding: return false; + case ForwardButPreventsForwarding: + case BackwardVectorizableButPreventsForwarding: + return TTI->vectorizePreventedSLForwarding(); } llvm_unreachable("unexpected DepType!"); } @@ -1313,7 +1442,7 @@ if (MaxVFWithoutSLForwardIssues < MaxSafeDepDistBytes && MaxVFWithoutSLForwardIssues != VectorizerParams::MaxVectorWidth * TypeByteSize) - MaxSafeDepDistBytes = MaxVFWithoutSLForwardIssues; + MaxDepDistWithSLF = MaxVFWithoutSLForwardIssues; return false; } @@ -1467,6 +1596,26 @@ // "A[B[i]] += ..." and similar code or pointer arithmetic that could wrap in // the address space. if (!StrideAPtr || !StrideBPtr || StrideAPtr != StrideBPtr){ + bool SrcInvariant = PSE.getSE()->isLoopInvariant(Src, InnermostLoop); + bool SinkInvariant = PSE.getSE()->isLoopInvariant(Sink, InnermostLoop); + + assert(!(StrideAPtr && SrcInvariant) && "Cannot be strided and invariant"); + assert(!(StrideBPtr && SinkInvariant) && "Cannot be strided and invariant"); + + // For cases where the stride is invariant but not constant, the compiler + // can still vectorize cases using gathers/scatters. + // This code checks only for the 'affine' property and not for 'no-wrapping' + // since PSE will add no-wrap checks where needed when creating pointer + // checks in 'createCheckForAccess()'. + bool SrcAffine = + StrideAPtr || (!SrcInvariant && isa(Src) && + cast(Src)->isAffine()); + bool SinkAffine = + StrideBPtr || (!SinkInvariant && isa(Sink) && + cast(Sink)->isAffine()); + if (SrcAffine || SinkAffine) + ShouldRetryWithRuntimeCheck = true; + LLVM_DEBUG(dbgs() << "Pointer access with non-constant stride\n"); return Dependence::Unknown; } @@ -1474,14 +1623,15 @@ Type *ATy = APtr->getType()->getPointerElementType(); Type *BTy = BPtr->getType()->getPointerElementType(); auto &DL = InnermostLoop->getHeader()->getModule()->getDataLayout(); - uint64_t TypeByteSize = DL.getTypeAllocSize(ATy); + uint64_t TypeAByteSize = DL.getTypeAllocSize(ATy); + uint64_t TypeBByteSize = DL.getTypeAllocSize(BTy); uint64_t Stride = std::abs(StrideAPtr); const SCEVConstant *C = dyn_cast(Dist); if (!C) { - if (TypeByteSize == DL.getTypeAllocSize(BTy) && + if (TypeAByteSize == TypeBByteSize && isSafeDependenceDistance(DL, *(PSE.getSE()), *(PSE.getBackedgeTakenCount()), *Dist, Stride, - TypeByteSize)) + TypeAByteSize)) return Dependence::NoDep; LLVM_DEBUG(dbgs() << "LAA: Dependence because of non-constant distance\n"); @@ -1493,8 +1643,8 @@ int64_t Distance = Val.getSExtValue(); // Attempt to prove strided accesses independent. 
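// Illustrative C++ shape of the access pattern the affine-stride relaxation
// above is aimed at: the stride 'n' is loop-invariant but not a compile-time
// constant, so the pointer SCEV is an affine AddRec with a non-constant step;
// such accesses can still be vectorized with gathers/scatters guarded by
// runtime checks.
void scaleColumn(double *A, long n, long rows, double x) {
  for (long i = 0; i < rows; ++i)
    A[i * n] *= x; // stride is n elements, unknown until run time
}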
- if (std::abs(Distance) > 0 && Stride > 1 && ATy == BTy && - areStridedAccessesIndependent(std::abs(Distance), Stride, TypeByteSize)) { + if (std::abs(Distance) > 0 && Stride > 1 && TypeAByteSize == TypeBByteSize && + areStridedAccessesIndependent(std::abs(Distance), Stride, TypeAByteSize)) { LLVM_DEBUG(dbgs() << "LAA: Strided accesses are independent\n"); return Dependence::NoDep; } @@ -1503,7 +1653,7 @@ if (Val.isNegative()) { bool IsTrueDataDependence = (AIsWrite && !BIsWrite); if (IsTrueDataDependence && EnableForwardingConflictDetection && - (couldPreventStoreLoadForward(Val.abs().getZExtValue(), TypeByteSize) || + (couldPreventStoreLoadForward(Val.abs().getZExtValue(), TypeAByteSize) || ATy != BTy)) { LLVM_DEBUG(dbgs() << "LAA: Forward but may prevent st->ld forwarding\n"); return Dependence::ForwardButPreventsForwarding; @@ -1513,10 +1663,11 @@ return Dependence::Forward; } + bool SizesAreSame = (TypeAByteSize == TypeBByteSize); + // Write to the same location with the same size. - // Could be improved to assert type sizes are the same (i32 == float, etc). if (Val == 0) { - if (ATy == BTy) + if (SizesAreSame) return Dependence::Forward; LLVM_DEBUG( dbgs() << "LAA: Zero dependence difference but different types\n"); @@ -1525,7 +1676,7 @@ assert(Val.isStrictlyPositive() && "Expect a positive value"); - if (ATy != BTy) { + if (!SizesAreSame) { LLVM_DEBUG( dbgs() << "LAA: ReadWrite-Write positive dependency with different types\n"); @@ -1567,7 +1718,7 @@ // the minimum distance needed is 28, which is greater than distance. It is // not safe to do vectorization. uint64_t MinDistanceNeeded = - TypeByteSize * Stride * (MinNumIter - 1) + TypeByteSize; + TypeAByteSize * Stride * (MinNumIter - 1) + TypeAByteSize; if (MinDistanceNeeded > static_cast(Distance)) { LLVM_DEBUG(dbgs() << "LAA: Failure because of positive distance " << Distance << '\n'); @@ -1602,13 +1753,13 @@ bool IsTrueDataDependence = (!AIsWrite && BIsWrite); if (IsTrueDataDependence && EnableForwardingConflictDetection && - couldPreventStoreLoadForward(Distance, TypeByteSize)) + couldPreventStoreLoadForward(Distance, TypeAByteSize)) return Dependence::BackwardVectorizableButPreventsForwarding; - uint64_t MaxVF = MaxSafeDepDistBytes / (TypeByteSize * Stride); + uint64_t MaxVF = MaxSafeDepDistBytes / (TypeAByteSize * Stride); LLVM_DEBUG(dbgs() << "LAA: Positive distance " << Val.getSExtValue() << " with max VF = " << MaxVF << '\n'); - uint64_t MaxVFInBits = MaxVF * TypeByteSize * 8; + uint64_t MaxVFInBits = MaxVF * TypeAByteSize * 8; MaxSafeRegisterWidth = std::min(MaxSafeRegisterWidth, MaxVFInBits); return Dependence::BackwardVectorizable; } @@ -1616,8 +1767,10 @@ bool MemoryDepChecker::areDepsSafe(DepCandidates &AccessSets, MemAccessInfoList &CheckDeps, const ValueToValueMap &Strides) { - + // Runtime checks are only feasible if only unknown dependences prevent + // vectorization. MaxSafeDepDistBytes = -1; + MaxDepDistWithSLF = -1U; SmallPtrSet Visited; for (MemAccessInfo CurAccess : CheckDeps) { if (Visited.count(CurAccess)) @@ -1652,7 +1805,18 @@ Dependence::DepType Type = isDependent(*A.first, A.second, *B.first, B.second, Strides); - SafeForVectorization &= Dependence::isSafeForVectorization(Type); + bool DepSafe = Dependence::isSafeForVectorization(Type, TTI); + SafeForVectorization &= DepSafe; + // Runtime checks are only feasible, if all unsafe dependencies are + // unknown. For other unsafe deps, we already know they will fail + // the runtime checks at compile time. 
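// Worked instance of the minimum-distance requirement shown above (the
// numbers are illustrative): with 4-byte elements, stride 2, and at least 4
// iterations per vector, MinDistanceNeeded = 4 * 2 * (4 - 1) + 4 = 28 bytes,
// so a positive dependence distance smaller than 28 is not safe to vectorize.
#include <cstdint>

static uint64_t minDistanceNeeded(uint64_t TypeByteSize, uint64_t Stride,
                                  uint64_t MinNumIter) {
  return TypeByteSize * Stride * (MinNumIter - 1) + TypeByteSize;
}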
+ RuntimeChecksFeasible &= (Type == Dependence::Unknown) || DepSafe; + + if (!SafeForVectorization) { + // TODO: Add minDistanc, actual distance, minIter and type size + // for unsafe dependences to generate better insight + UnsafeDependences.push_back(Dependence(A.second, B.second, Type)); + } // Gather dependences unless we accumulated MaxDependences // dependences. In that case return as soon as we find the first @@ -1669,8 +1833,13 @@ << "Too many dependences, stopped recording\n"); } } - if (!RecordDependences && !SafeForVectorization) + // We do not generate runtime checks for accesses with constant + // strides, so without investigating all dependences, we cannot + // be sure runtime checks are safe. + if (!RecordDependences && !SafeForVectorization) { + RuntimeChecksFeasible = false; return false; + } } ++OI; } @@ -1729,7 +1898,7 @@ } // We must have a single exiting block. - if (!TheLoop->getExitingBlock()) { + if (!TheLoop->getExitingBlock() && !AllowUncountedLoops) { LLVM_DEBUG( dbgs() << "LAA: loop control flow is not understood by analyzer\n"); recordAnalysis("CFGNotUnderstood") @@ -1740,9 +1909,15 @@ // We only handle bottom-tested loops, i.e. loop in which the condition is // checked at the end of each iteration. With that we can assume that all // instructions in the loop are executed the same number of times. - if (TheLoop->getExitingBlock() != TheLoop->getLoopLatch()) { - LLVM_DEBUG( - dbgs() << "LAA: loop control flow is not understood by analyzer\n"); + SmallVector ExitingBlocks; + TheLoop->getExitingBlocks(ExitingBlocks); + // TODO: Remove this limitation -- there's at least a few loops with the + // (single) exit in the middle; just requires better predication and + // we can still plant the vectorized exit at the end. + if (ExitingBlocks.empty() || + ExitingBlocks.back() != TheLoop->getLoopLatch()) { + LLVM_DEBUG(dbgs() + << "LAA: loop control flow is not understood by analyzer\n"); recordAnalysis("CFGNotUnderstood") << "loop control flow is not understood by analyzer"; return false; @@ -1750,7 +1925,7 @@ // ScalarEvolution needs to be able to find the exit count. const SCEV *ExitCount = PSE->getBackedgeTakenCount(); - if (ExitCount == PSE->getSE()->getCouldNotCompute()) { + if (ExitCount == PSE->getSE()->getCouldNotCompute() && !AllowUncountedLoops) { recordAnalysis("CantComputeNumberOfIterations") << "could not determine number of loop iterations"; LLVM_DEBUG(dbgs() << "LAA: SCEV could not compute the loop exit count.\n"); @@ -1768,6 +1943,7 @@ // Holds the Load and Store instructions. SmallVector Loads; SmallVector Stores; + SmallVector MemSets; // Holds all the different accesses in the loop. unsigned NumReads = 0; @@ -1800,10 +1976,19 @@ continue; auto *Ld = dyn_cast(&I); - if (!Ld || (!Ld->isSimple() && !IsAnnotatedParallel)) { + if (Ld && !Ld->isSimple() && !IsAnnotatedParallel) { recordAnalysis("NonSimpleLoad", Ld) << "read with atomic ordering or volatile read"; - LLVM_DEBUG(dbgs() << "LAA: Found a non-simple load.\n"); + LLVM_DEBUG(dbgs() << "LAA: Found a non-simple load " << *Ld << "\n"); + CanVecMem = false; + return; + } + if (!Ld) { + recordAnalysis("CantVectorizeInstruction", &I) + << "instruction cannot be vectorized"; + LLVM_DEBUG(dbgs() << + "LAA: Found memory reading op that isn't a simple load " << + I << "\n"); CanVecMem = false; return; } @@ -1815,19 +2000,29 @@ continue; } + if (auto *MSI = dyn_cast(&I)) { + MemSets.push_back(MSI); + NumStores++; + DepChecker->addAccess(MSI); + continue; + } + // Save 'store' instructions. 
Abort if other instructions write to memory. if (I.mayWriteToMemory()) { auto *St = dyn_cast(&I); if (!St) { - recordAnalysis("CantVectorizeInstruction", St) + recordAnalysis("CantVectorizeInstruction", &I) << "instruction cannot be vectorized"; + LLVM_DEBUG(dbgs() << + "LAA: Found memory writing op that isn't a simple store: " << + I << "\n"); CanVecMem = false; return; } if (!St->isSimple() && !IsAnnotatedParallel) { recordAnalysis("NonSimpleStore", St) << "write with atomic ordering or volatile write"; - LLVM_DEBUG(dbgs() << "LAA: Found a non-simple store.\n"); + LLVM_DEBUG(dbgs() << "LAA: Found a non-simple store " << *St << "\n"); CanVecMem = false; return; } @@ -1845,7 +2040,7 @@ // Check if we see any stores. If there are no stores, then we don't // care if the pointers are *restrict*. - if (!Stores.size()) { + if (!Stores.size() && MemSets.empty()) { LLVM_DEBUG(dbgs() << "LAA: Found a read-only loop!\n"); CanVecMem = true; return; @@ -1853,7 +2048,7 @@ MemoryDepChecker::DepCandidates DependentAccesses; AccessAnalysis Accesses(TheLoop->getHeader()->getModule()->getDataLayout(), - TheLoop, AA, LI, DependentAccesses, *PSE); + TTI, TheLoop, AA, LI, DependentAccesses, *PSE); // Holds the analyzed pointers. We don't want to call GetUnderlyingObjects // multiple times on the same object. If the ptr is accessed twice, once @@ -1865,7 +2060,8 @@ for (StoreInst *ST : Stores) { Value *Ptr = ST->getPointerOperand(); // Check for store to loop invariant address. - StoreToLoopInvariantAddress |= isUniform(Ptr); + if (isUniform(Ptr)) + InvariantStores.push_back(ST); // If we did *not* see this pointer before, insert it to the read-write // list. At this phase it is only a 'write' list. if (Seen.insert(Ptr).second) { @@ -1882,6 +2078,29 @@ } } + // Treat the memset calls as stores. + for (auto I : MemSets) { + auto MSI = cast(I); + Value *Ptr = MSI->getRawDest(); + // Check for store to loop invariant address. + if (isUniform(Ptr)) + InvariantMemSets.push_back(MSI); + // If we did *not* see this pointer before, insert it to the read-write + // list. At this phase it is only a 'write' list. + if (Seen.insert(Ptr).second) { + ++NumReadWrites; + + MemoryLocation Loc = MemoryLocation::get(MSI); + // The TBAA metadata could have a control dependency on the predication + // condition, so we cannot rely on it when determining whether or not we + // need runtime pointer checks. + if (blockNeedsPredication(MSI->getParent(), TheLoop, DT)) + Loc.AATags.TBAA = nullptr; + + Accesses.addStore(Loc); + } + } + if (IsAnnotatedParallel) { LLVM_DEBUG( dbgs() << "LAA: A loop annotated parallel, ignore memory dependency " @@ -1901,10 +2120,22 @@ // read a few words, modify, and write a few words, and some of the // words may be written to the same address. 
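// Illustrative loop shape the memset handling above is meant to admit: the
// call writes through a pointer that varies with the loop, so it is now
// recorded as a store access for dependence analysis instead of rejecting
// the whole loop. The function itself is a made-up example.
#include <cstring>

void clearRows(unsigned char **Rows, long NumRows, unsigned long RowBytes) {
  for (long I = 0; I < NumRows; ++I)
    std::memset(Rows[I], 0, RowBytes);
}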
bool IsReadOnlyPtr = false; - if (Seen.insert(Ptr).second || - !getPtrStride(*PSE, Ptr, TheLoop, SymbolicStrides)) { + if (Seen.insert(Ptr).second) { ++NumReads; IsReadOnlyPtr = true; + } else if (!getPtrStride(*PSE, Ptr, TheLoop, SymbolicStrides)) { + bool LoadAffine = false; + const SCEV *Src = PSE->getSE()->getSCEV(Ptr); + if (auto *SrcAddRec = dyn_cast(Src)) { + bool PtrInvariant = PSE->getSE()->isLoopInvariant(Src, TheLoop); + LoadAffine = !PtrInvariant && SrcAddRec->isAffine() && + isNoWrapAddRec(Ptr, SrcAddRec, *PSE, TheLoop); + } + + if (!LoadAffine) { + ++NumReads; + IsReadOnlyPtr = true; + } } MemoryLocation Loc = MemoryLocation::get(LD); @@ -1938,6 +2169,13 @@ LLVM_DEBUG(dbgs() << "LAA: We can't vectorize because we can't find " << "the array bounds.\n"); CanVecMem = false; + FailReason = FailureReason::UnknownArrayBounds; + UncomputablePtrs = std::move(Accesses.UncomputablePtrs); + // FIXME: + // This does not always hold, because canCheckPtrAtRT also + // returns 'false' if the address spaces between pointers don't match + // + //assert(UncomputablePtrs.size() > 0 && "no uncomputable pointers"); return; } @@ -1970,6 +2208,7 @@ << "cannot check memory dependencies at runtime"; LLVM_DEBUG(dbgs() << "LAA: Can't vectorize with memory checks\n"); CanVecMem = false; + FailReason = FailureReason::UnsafeDataDependenceTriedRT; return; } @@ -1989,6 +2228,7 @@ "to attempt to isolate the offending operations into a separate " "loop"; LLVM_DEBUG(dbgs() << "LAA: unsafe dependent memory operations in loop\n"); + FailReason = FailureReason::UnsafeDataDependence; } } @@ -2032,6 +2272,12 @@ return (SE->isLoopInvariant(SE->getSCEV(V), TheLoop)); } +uint64_t LoopAccessInfo::getMaxSafeDepDistBytes() const { + if (TTI->vectorizePreventedSLForwarding()) + return MaxSafeDepDistBytes; + return std::min(MaxSafeDepDistBytes, MaxDepDistBytesWithSLF); +} + // FIXME: this function is currently a duplicate of the one in // LoopVectorize.cpp. static Instruction *getFirstInst(Instruction *FirstInst, Value *V, @@ -2258,14 +2504,20 @@ StrideSet.insert(Stride); } +// TODO: Lots of shuffling here, we removed StoreToLoopInvariantAddress +// Need to see whether adding TTI causes issues. LoopAccessInfo::LoopAccessInfo(Loop *L, ScalarEvolution *SE, - const TargetLibraryInfo *TLI, AliasAnalysis *AA, + const TargetLibraryInfo *TLI, + const TargetTransformInfo *TTI, + AliasAnalysis *AA, DominatorTree *DT, LoopInfo *LI) : PSE(llvm::make_unique(*SE, *L)), PtrRtChecking(llvm::make_unique(SE)), - DepChecker(llvm::make_unique(*PSE, L)), TheLoop(L), - NumLoads(0), NumStores(0), MaxSafeDepDistBytes(-1), CanVecMem(false), - StoreToLoopInvariantAddress(false) { + DepChecker(llvm::make_unique(*PSE, L, TTI)), TheLoop(L), + TTI(TTI), NumLoads(0), NumStores(0), MaxSafeDepDistBytes(-1), + MaxDepDistBytesWithSLF(-1U), CanVecMem(false), + AllowUncountedLoops(EnableUncountedLAA), + FailReason(FailureReason::Unknown) { if (canAnalyzeLoop()) analyzeLoop(AA, LI, TLI, DT); } @@ -2298,7 +2550,7 @@ OS << "\n"; OS.indent(Depth) << "Store to invariant address was " - << (StoreToLoopInvariantAddress ? "" : "not ") + << (hasStoreToLoopInvariantAddress() ? "" : "not ") << "found in loop.\n"; OS.indent(Depth) << "SCEV assumptions:\n"; @@ -2314,7 +2566,7 @@ auto &LAI = LoopAccessInfoMap[L]; if (!LAI) - LAI = llvm::make_unique(L, SE, TLI, AA, DT, LI); + LAI = llvm::make_unique(L, SE, TLI, TTI, AA, DT, LI); return *LAI.get(); } @@ -2335,6 +2587,7 @@ auto *TLIP = getAnalysisIfAvailable(); TLI = TLIP ? 
&TLIP->getTLI() : nullptr; AA = &getAnalysis().getAAResults(); + TTI = &getAnalysis().getTTI(F); DT = &getAnalysis().getDomTree(); LI = &getAnalysis().getLoopInfo(); @@ -2346,6 +2599,7 @@ AU.addRequired(); AU.addRequired(); AU.addRequired(); + AU.addRequired(); AU.setPreservesAll(); } @@ -2359,13 +2613,14 @@ INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass) INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) INITIALIZE_PASS_END(LoopAccessLegacyAnalysis, LAA_NAME, laa_name, false, true) AnalysisKey LoopAccessAnalysis::Key; LoopAccessInfo LoopAccessAnalysis::run(Loop &L, LoopAnalysisManager &AM, LoopStandardAnalysisResults &AR) { - return LoopAccessInfo(&L, &AR.SE, &AR.TLI, &AR.AA, &AR.DT, &AR.LI); + return LoopAccessInfo(&L, &AR.SE, &AR.TLI, &AR.TTI, &AR.AA, &AR.DT, &AR.LI); } namespace llvm { Index: lib/Analysis/LoopInfo.cpp =================================================================== --- lib/Analysis/LoopInfo.cpp +++ lib/Analysis/LoopInfo.cpp @@ -365,17 +365,55 @@ return LocRange(Start); } + DebugLoc DL; // Try the pre-header first. if (BasicBlock *PHeadBB = getLoopPreheader()) - if (DebugLoc DL = PHeadBB->getTerminator()->getDebugLoc()) - return LocRange(DL); + DL = PHeadBB->getTerminator()->getDebugLoc(); // If we have no pre-header or there are no instructions with debug - // info in it, try the header. + // info in it or the line number in debug info is 0, try the header. if (BasicBlock *HeadBB = getHeader()) - return LocRange(HeadBB->getTerminator()->getDebugLoc()); + if (!DL || (DL.getLine() == 0)) + DL = HeadBB->getTerminator()->getDebugLoc(); - return LocRange(); + return DL ? LocRange(DL) : LocRange(); +} + +void Loop::getAttachedDebugLocations(std::vector &Locs) const { + // If we have a debug location in the loop ID, then use it. + if (MDNode *LoopID = getLoopID()) { + for (unsigned i = 1, ie = LoopID->getNumOperands(); i < ie; ++i) { + if (DILocation *L = dyn_cast(LoopID->getOperand(i))) { + Locs.push_back({L}); + } + } + + if (!Locs.empty()) + return; + } + + DebugLoc DL; + // Try the pre-header first. + if (BasicBlock *PHeadBB = getLoopPreheader()) + DL = PHeadBB->getTerminator()->getDebugLoc(); + + // If we have no pre-header or there are no instructions with debug + // info in it or the line number in debug info is 0, try the header. 
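// Standalone restatement (sketch only) of the fallback order used above when
// choosing a location to describe a loop: prefer the preheader terminator's
// debug location, but fall back to the header when that location is missing
// or has line 0 (typically compiler-generated).
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/DebugLoc.h"
#include "llvm/IR/Instructions.h"

static llvm::DebugLoc pickLoopStartLoc(llvm::BasicBlock *Preheader,
                                       llvm::BasicBlock *Header) {
  llvm::DebugLoc DL;
  if (Preheader)
    DL = Preheader->getTerminator()->getDebugLoc();
  if (Header && (!DL || DL.getLine() == 0))
    DL = Header->getTerminator()->getDebugLoc();
  return DL;
}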
+ if (BasicBlock *HeadBB = getHeader()) + if (!DL || (DL.getLine() == 0)) + DL = HeadBB->getTerminator()->getDebugLoc(); + + Locs.push_back(DL); + + return; +} + +void Loop::getEarlyExitLocations(std::vector &ExitLocs) const { + std::vector Locs; + getAttachedDebugLocations(Locs); + + for (unsigned i = 2; i + 1 < Locs.size(); i += 2) + ExitLocs.emplace_back(Locs[i], Locs[i+1]); } #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) Index: lib/Analysis/MemoryLocation.cpp =================================================================== --- lib/Analysis/MemoryLocation.cpp +++ lib/Analysis/MemoryLocation.cpp @@ -37,6 +37,18 @@ AATags); } +MemoryLocation MemoryLocation::get(const MemSetInst *MSI) { + AAMDNodes AATags; + MSI->getAAMetadata(AATags); + const auto &DL = MSI->getModule()->getDataLayout(); + auto Length = MSI->getLength(); + uint64_t Size = UnknownSize; + if (auto CI = dyn_cast(Length)) + Size = CI->getZExtValue() * DL.getTypeStoreSize(MSI->getValue()->getType()); + + return MemoryLocation(MSI->getRawDest(), Size, AATags); +} + MemoryLocation MemoryLocation::get(const VAArgInst *VI) { AAMDNodes AATags; VI->getAAMetadata(AATags); @@ -120,6 +132,17 @@ switch (II->getIntrinsicID()) { default: break; + // TODO: Improve for fixed width, max reg size, fixed stride, etc. + // Safe to fall through to UnknownSize for now. + case Intrinsic::masked_load: + case Intrinsic::masked_spec_load: + case Intrinsic::masked_gather: + break; + // TODO: Improve for fixed width, max reg size, fixed stride, etc. + // Safe to fall through to UnknownSize for now. + case Intrinsic::masked_store: + case Intrinsic::masked_scatter: + break; case Intrinsic::memset: case Intrinsic::memcpy: case Intrinsic::memmove: Index: lib/Analysis/PHITransAddr.cpp =================================================================== --- lib/Analysis/PHITransAddr.cpp +++ lib/Analysis/PHITransAddr.cpp @@ -413,7 +413,7 @@ return Result; } -#if 0 +#if 1 // FIXME: This code works, but it is unclear that we actually want to insert // a big chain of computation in order to make a value available in a block. // This needs to be evaluated carefully to consider its cost trade offs. Index: lib/Analysis/ScalarEvolution.cpp =================================================================== --- lib/Analysis/ScalarEvolution.cpp +++ lib/Analysis/ScalarEvolution.cpp @@ -82,6 +82,7 @@ #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/ScalarEvolutionExpressions.h" #include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/Config/llvm-config.h" #include "llvm/IR/Argument.h" @@ -3727,6 +3728,12 @@ } const SCEV *ScalarEvolution::getSizeOfExpr(Type *IntTy, Type *AllocTy) { + auto *VTy = dyn_cast(AllocTy); + if (VTy && VTy->isScalable()) { + auto NumElts = getSCEV(ConstantExpr::getRuntimeNumElements(IntTy, VTy)); + return getMulExpr(NumElts, getSizeOfExpr(IntTy, VTy->getElementType())); + } + // We can bypass creating a target-independent // constant expression and then folding it back into a ConstantInt. // This is just a compile-time optimization. 
@@ -5678,6 +5685,16 @@ } if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S)) { + if (isa<VScale>(U->getValue())) { + auto MinBits = TTI.getRegisterBitWidth(true); + if (MinBits > 0) { + auto MaxBits = TTI.getRegisterBitWidthUpperBound(true); + auto LowerBound = APInt(BitWidth, 1); + auto UpperBound = APInt(BitWidth, (MaxBits / MinBits) + 1); + return setRange(U, SignHint, ConstantRange(LowerBound, UpperBound)); + } + } + // Check if the IR explicitly contains !range metadata. Optional<ConstantRange> MDRange = GetRangeFromMetadata(U->getValue()); if (MDRange.hasValue()) @@ -11153,8 +11170,8 @@ ScalarEvolution::ScalarEvolution(Function &F, TargetLibraryInfo &TLI, AssumptionCache &AC, DominatorTree &DT, - LoopInfo &LI) - : F(F), TLI(TLI), AC(AC), DT(DT), LI(LI), + LoopInfo &LI, TargetTransformInfo &TTI) + : F(F), TLI(TLI), AC(AC), DT(DT), LI(LI), TTI(TTI), CouldNotCompute(new SCEVCouldNotCompute()), ValuesAtScopes(64), LoopDispositions(64), BlockDispositions(64) { // To use guards for proving predicates, we need to scan every instruction in @@ -11174,7 +11191,8 @@ ScalarEvolution::ScalarEvolution(ScalarEvolution &&Arg) : F(Arg.F), HasGuards(Arg.HasGuards), TLI(Arg.TLI), AC(Arg.AC), DT(Arg.DT), - LI(Arg.LI), CouldNotCompute(std::move(Arg.CouldNotCompute)), + LI(Arg.LI), TTI(Arg.TTI), + CouldNotCompute(std::move(Arg.CouldNotCompute)), ValueExprMap(std::move(Arg.ValueExprMap)), PendingLoopPredicates(std::move(Arg.PendingLoopPredicates)), PendingPhiRanges(std::move(Arg.PendingPhiRanges)), @@ -11663,7 +11681,7 @@ void ScalarEvolution::verify() const { ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this); - ScalarEvolution SE2(F, TLI, AC, DT, LI); + ScalarEvolution SE2(F, TLI, AC, DT, LI, TTI); SmallVector<Loop *, 8> LoopStack(LI.begin(), LI.end()); @@ -11753,7 +11771,8 @@ return ScalarEvolution(F, AM.getResult<TargetLibraryAnalysis>(F), AM.getResult<AssumptionAnalysis>(F), AM.getResult<DominatorTreeAnalysis>(F), - AM.getResult<LoopAnalysis>(F)); + AM.getResult<LoopAnalysis>(F), + AM.getResult<TargetIRAnalysis>(F)); } PreservedAnalyses @@ -11768,6 +11787,7 @@ INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) INITIALIZE_PASS_END(ScalarEvolutionWrapperPass, "scalar-evolution", "Scalar Evolution Analysis", false, true) @@ -11782,7 +11802,8 @@ F, getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(), getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F), getAnalysis<DominatorTreeWrapperPass>().getDomTree(), - getAnalysis<LoopInfoWrapperPass>().getLoopInfo())); + getAnalysis<LoopInfoWrapperPass>().getLoopInfo(), + getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F))); return false; } @@ -11805,6 +11826,7 @@ AU.addRequiredTransitive<LoopInfoWrapperPass>(); AU.addRequiredTransitive<DominatorTreeWrapperPass>(); AU.addRequiredTransitive<TargetLibraryInfoWrapperPass>(); + AU.addRequiredTransitive<TargetTransformInfoWrapperPass>(); } const SCEVPredicate *ScalarEvolution::getEqualPredicate(const SCEV *LHS, Index: lib/Analysis/ScalarEvolutionExpander.cpp =================================================================== --- lib/Analysis/ScalarEvolutionExpander.cpp +++ lib/Analysis/ScalarEvolutionExpander.cpp @@ -276,14 +276,15 @@ if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(S)) { // Size is known, check if there is a constant operand which is a multiple // of the given factor. If so, we can factor it.
- const SCEVConstant *FC = cast<SCEVConstant>(Factor); - if (const SCEVConstant *C = dyn_cast<SCEVConstant>(M->getOperand(0))) - if (!C->getAPInt().srem(FC->getAPInt())) { - SmallVector<const SCEV *, 4> NewMulOps(M->op_begin(), M->op_end()); - NewMulOps[0] = SE.getConstant(C->getAPInt().sdiv(FC->getAPInt())); - S = SE.getMulExpr(NewMulOps); - return true; - } + if (const SCEVConstant *FC = dyn_cast<SCEVConstant>(Factor)) { + if (const SCEVConstant *C = dyn_cast<SCEVConstant>(M->getOperand(0))) + if (!C->getAPInt().srem(FC->getAPInt())) { + SmallVector<const SCEV *, 4> NewMulOps(M->op_begin(), M->op_end()); + NewMulOps[0] = SE.getConstant(C->getAPInt().sdiv(FC->getAPInt())); + S = SE.getMulExpr(NewMulOps); + return true; + } + } } // In an AddRec, check if both start and step are divisible. Index: lib/Analysis/TargetLibraryInfo.cpp =================================================================== --- lib/Analysis/TargetLibraryInfo.cpp +++ lib/Analysis/TargetLibraryInfo.cpp @@ -17,6 +17,8 @@ #include "llvm/Support/CommandLine.h" using namespace llvm; +#define DEBUG_TYPE "target-library-info" + static cl::opt<TargetLibraryInfoImpl::VectorLibrary> ClVectorLibrary( "vector-library", cl::Hidden, cl::desc("Vector functions library"), cl::init(TargetLibraryInfoImpl::NoLibrary), @@ -24,6 +26,8 @@ "No vector functions library"), clEnumValN(TargetLibraryInfoImpl::Accelerate, "Accelerate", "Accelerate framework"), + clEnumValN(TargetLibraryInfoImpl::SLEEF, "SLEEF", + "SIMD Library for Evaluating Elementary Functions"), clEnumValN(TargetLibraryInfoImpl::SVML, "SVML", "Intel SVML library"))); @@ -32,6 +36,12 @@ #include "llvm/Analysis/TargetLibraryInfo.def" }; +#define LANES(x) VectorType::ElementCount(x, false) +#define SCALABLE_LANES(x) VectorType::ElementCount(x, true) + +#define NOMASK 0 +#define MASKED 1 + static bool hasSinCosPiStret(const Triple &T) { // Only Darwin variants have _stret versions of combined trig functions. if (!T.isOSDarwin()) @@ -514,7 +524,7 @@ TLI.setUnavailable(LibFunc_nvvm_reflect); } - TLI.addVectorizableFunctionsFromVecLib(ClVectorLibrary); + TLI.addVectorizableFunctionsFromVecLib(ClVectorLibrary, T); } TargetLibraryInfoImpl::TargetLibraryInfoImpl() { @@ -538,6 +548,11 @@ memcpy(AvailableArray, TLI.AvailableArray, sizeof(AvailableArray)); VectorDescs = TLI.VectorDescs; ScalarDescs = TLI.ScalarDescs; + // Copy the multimap with a loop instead of the assignment operator to prevent + // build failures on systems with an old C++ STL library. + assert(VectorFunctionInfo.empty() && "Not a valid multimap."); + for (const auto &VFI : TLI.VectorFunctionInfo) + VectorFunctionInfo.insert(VFI); } TargetLibraryInfoImpl::TargetLibraryInfoImpl(TargetLibraryInfoImpl &&TLI) @@ -549,6 +564,11 @@ AvailableArray); VectorDescs = TLI.VectorDescs; ScalarDescs = TLI.ScalarDescs; + // Copy the multimap with a loop instead of the assignment operator to prevent + // build failures on systems with an old C++ STL library.
+ assert(VectorFunctionInfo.empty() && "Not a valid multimap."); + for (const auto &VFI : TLI.VectorFunctionInfo) + VectorFunctionInfo.insert(VFI); } TargetLibraryInfoImpl &TargetLibraryInfoImpl::operator=(const TargetLibraryInfoImpl &TLI) { @@ -1382,11 +1402,15 @@ } static bool compareByScalarFnName(const VecDesc &LHS, const VecDesc &RHS) { - return LHS.ScalarFnName < RHS.ScalarFnName; + return (LHS.ScalarFnName < RHS.ScalarFnName) || + ((LHS.ScalarFnName == RHS.ScalarFnName) && + (LHS.Priority > RHS.Priority)); } static bool compareByVectorFnName(const VecDesc &LHS, const VecDesc &RHS) { - return LHS.VectorFnName < RHS.VectorFnName; + return (LHS.VectorFnName < RHS.VectorFnName) || + ((LHS.VectorFnName == RHS.VectorFnName) && + (LHS.Priority > RHS.Priority)); } static bool compareWithScalarFnName(const VecDesc &LHS, StringRef S) { @@ -1406,155 +1430,206 @@ } void TargetLibraryInfoImpl::addVectorizableFunctionsFromVecLib( - enum VectorLibrary VecLib) { + enum VectorLibrary VecLib, const Triple &T) { switch (VecLib) { case Accelerate: { const VecDesc VecFuncs[] = { // Floating-Point Arithmetic and Auxiliary Functions - {"ceilf", "vceilf", 4}, - {"fabsf", "vfabsf", 4}, - {"llvm.fabs.f32", "vfabsf", 4}, - {"floorf", "vfloorf", 4}, - {"sqrtf", "vsqrtf", 4}, - {"llvm.sqrt.f32", "vsqrtf", 4}, + {"ceilf", "vceilf", LANES(4), NOMASK}, + {"fabsf", "vfabsf", LANES(4), NOMASK}, + {"llvm.fabs.f32", "vfabsf", LANES(4), NOMASK}, + {"floorf", "vfloorf", LANES(4), NOMASK}, + {"sqrtf", "vsqrtf", LANES(4), NOMASK}, + {"llvm.sqrt.f32", "vsqrtf", LANES(4), NOMASK}, // Exponential and Logarithmic Functions - {"expf", "vexpf", 4}, - {"llvm.exp.f32", "vexpf", 4}, - {"expm1f", "vexpm1f", 4}, - {"logf", "vlogf", 4}, - {"llvm.log.f32", "vlogf", 4}, - {"log1pf", "vlog1pf", 4}, - {"log10f", "vlog10f", 4}, - {"llvm.log10.f32", "vlog10f", 4}, - {"logbf", "vlogbf", 4}, + {"expf", "vexpf", LANES(4), NOMASK}, + {"llvm.exp.f32", "vexpf", LANES(4), NOMASK}, + {"expm1f", "vexpm1f", LANES(4), NOMASK}, + {"logf", "vlogf", LANES(4), NOMASK}, + {"llvm.log.f32", "vlogf", LANES(4), NOMASK}, + {"log1pf", "vlog1pf", LANES(4), NOMASK}, + {"log10f", "vlog10f", LANES(4), NOMASK}, + {"llvm.log10.f32", "vlog10f", LANES(4), NOMASK}, + {"logbf", "vlogbf", LANES(4), NOMASK}, // Trigonometric Functions - {"sinf", "vsinf", 4}, - {"llvm.sin.f32", "vsinf", 4}, - {"cosf", "vcosf", 4}, - {"llvm.cos.f32", "vcosf", 4}, - {"tanf", "vtanf", 4}, - {"asinf", "vasinf", 4}, - {"acosf", "vacosf", 4}, - {"atanf", "vatanf", 4}, + {"sinf", "vsinf", LANES(4), NOMASK}, + {"llvm.sin.f32", "vsinf", LANES(4), NOMASK}, + {"cosf", "vcosf", LANES(4), NOMASK}, + {"llvm.cos.f32", "vcosf", LANES(4), NOMASK}, + {"tanf", "vtanf", LANES(4), NOMASK}, + {"asinf", "vasinf", LANES(4), NOMASK}, + {"acosf", "vacosf", LANES(4), NOMASK}, + {"atanf", "vatanf", LANES(4), NOMASK}, // Hyperbolic Functions - {"sinhf", "vsinhf", 4}, - {"coshf", "vcoshf", 4}, - {"tanhf", "vtanhf", 4}, - {"asinhf", "vasinhf", 4}, - {"acoshf", "vacoshf", 4}, - {"atanhf", "vatanhf", 4}, + {"sinhf", "vsinhf", LANES(4), NOMASK}, + {"coshf", "vcoshf", LANES(4), NOMASK}, + {"tanhf", "vtanhf", LANES(4), NOMASK}, + {"asinhf", "vasinhf", LANES(4), NOMASK}, + {"acoshf", "vacoshf", LANES(4), NOMASK}, + {"atanhf", "vatanhf", LANES(4), NOMASK}, }; addVectorizableFunctions(VecFuncs); break; } + case SLEEF: { + // SLEEF mappings are typically inserted by the driver, but this knows + // nothing about LLVM's intrinsics, which we specify manually here. + + // The available SLEEF routines are target specific. 
+ if (T.getArch() == Triple::aarch64) { + const VecDesc VecFuncs[] = { + {"llvm.cos.f64", "_ZGVnN2v_cos", LANES(2), NOMASK}, + {"llvm.cos.f64", "_ZGVsMxv_cos", SCALABLE_LANES(2), MASKED}, + {"llvm.cos.f32", "_ZGVnN4v_cosf", LANES(4), NOMASK}, + {"llvm.cos.f32", "_ZGVsMxv_cosf", SCALABLE_LANES(4), MASKED}, + + {"llvm.exp.f64", "_ZGVnN2v_exp", LANES(2), NOMASK}, + {"llvm.exp.f64", "_ZGVsMxv_exp", SCALABLE_LANES(2), MASKED}, + {"llvm.exp.f32", "_ZGVnN4v_expf", LANES(4), NOMASK}, + {"llvm.exp.f32", "_ZGVsMxv_expf", SCALABLE_LANES(4), MASKED}, + + {"llvm.exp2.f64", "_ZGVnN2v_exp2", LANES(2), NOMASK}, + {"llvm.exp2.f64", "_ZGVsMxv_exp2", SCALABLE_LANES(2), MASKED}, + {"llvm.exp2.f32", "_ZGVnN4v_exp2f", LANES(4), NOMASK}, + {"llvm.exp2.f32", "_ZGVsMxv_exp2f", SCALABLE_LANES(4), MASKED}, + + {"fmod", "_ZGVnN2vv_fmod", LANES(2), NOMASK}, + {"fmod", "_ZGVsMxvv_fmod", SCALABLE_LANES(2), MASKED}, + {"fmodf", "_ZGVnN4vv_fmodf", LANES(4), NOMASK}, + {"fmodf", "_ZGVsMxvv_fmodf", SCALABLE_LANES(4), MASKED}, + + {"llvm.log.f64", "_ZGVnN2v_log", LANES(2), NOMASK}, + {"llvm.log.f64", "_ZGVsMxv_log", SCALABLE_LANES(2), MASKED}, + {"llvm.log.f32", "_ZGVnN4v_logf", LANES(4), NOMASK}, + {"llvm.log.f32", "_ZGVsMxv_logf", SCALABLE_LANES(4), MASKED}, + + {"llvm.log10.f64", "_ZGVnN2v_log10", LANES(2), NOMASK}, + {"llvm.log10.f64", "_ZGVsMxv_log10", SCALABLE_LANES(2), MASKED}, + {"llvm.log10.f32", "_ZGVnN4v_log10f", LANES(4), NOMASK}, + {"llvm.log10.f32", "_ZGVsMxv_log10f", SCALABLE_LANES(4), MASKED}, + + {"llvm.pow.f64", "_ZGVnN2vv_pow", LANES(2), NOMASK}, + {"llvm.pow.f64", "_ZGVsMxvv_pow", SCALABLE_LANES(2), MASKED}, + {"llvm.pow.f32", "_ZGVnN4vv_powf", LANES(4), NOMASK}, + {"llvm.pow.f32", "_ZGVsMxvv_powf", SCALABLE_LANES(4), MASKED}, + + {"llvm.sin.f64", "_ZGVnN2v_sin", LANES(2), NOMASK}, + {"llvm.sin.f64", "_ZGVsMxv_sin", SCALABLE_LANES(2), MASKED}, + {"llvm.sin.f32", "_ZGVnN4v_sinf", LANES(4), NOMASK}, + {"llvm.sin.f32", "_ZGVsMxv_sinf", SCALABLE_LANES(4), MASKED}, + }; + addVectorizableFunctions(VecFuncs); + } + break; + } case SVML: { const VecDesc VecFuncs[] = { - {"sin", "__svml_sin2", 2}, - {"sin", "__svml_sin4", 4}, - {"sin", "__svml_sin8", 8}, + {"sin", "__svml_sin2", LANES(2), NOMASK}, + {"sin", "__svml_sin4", LANES(4), NOMASK}, + {"sin", "__svml_sin8", LANES(8), NOMASK}, - {"sinf", "__svml_sinf4", 4}, - {"sinf", "__svml_sinf8", 8}, - {"sinf", "__svml_sinf16", 16}, + {"sinf", "__svml_sinf4", LANES(4), NOMASK}, + {"sinf", "__svml_sinf8", LANES(8), NOMASK}, + {"sinf", "__svml_sinf16", LANES(16), NOMASK}, - {"llvm.sin.f64", "__svml_sin2", 2}, - {"llvm.sin.f64", "__svml_sin4", 4}, - {"llvm.sin.f64", "__svml_sin8", 8}, + {"llvm.sin.f64", "__svml_sin2", LANES(2), NOMASK}, + {"llvm.sin.f64", "__svml_sin4", LANES(4), NOMASK}, + {"llvm.sin.f64", "__svml_sin8", LANES(8), NOMASK}, - {"llvm.sin.f32", "__svml_sinf4", 4}, - {"llvm.sin.f32", "__svml_sinf8", 8}, - {"llvm.sin.f32", "__svml_sinf16", 16}, + {"llvm.sin.f32", "__svml_sinf4", LANES(4), NOMASK}, + {"llvm.sin.f32", "__svml_sinf8", LANES(8), NOMASK}, + {"llvm.sin.f32", "__svml_sinf16", LANES(16), NOMASK}, - {"cos", "__svml_cos2", 2}, - {"cos", "__svml_cos4", 4}, - {"cos", "__svml_cos8", 8}, + {"cos", "__svml_cos2", LANES(2), NOMASK}, + {"cos", "__svml_cos4", LANES(4), NOMASK}, + {"cos", "__svml_cos8", LANES(8), NOMASK}, - {"cosf", "__svml_cosf4", 4}, - {"cosf", "__svml_cosf8", 8}, - {"cosf", "__svml_cosf16", 16}, + {"cosf", "__svml_cosf4", LANES(4), NOMASK}, + {"cosf", "__svml_cosf8", LANES(8), NOMASK}, + {"cosf", "__svml_cosf16", LANES(16), NOMASK}, - 
{"llvm.cos.f64", "__svml_cos2", 2}, - {"llvm.cos.f64", "__svml_cos4", 4}, - {"llvm.cos.f64", "__svml_cos8", 8}, + {"llvm.cos.f64", "__svml_cos2", LANES(2), NOMASK}, + {"llvm.cos.f64", "__svml_cos4", LANES(4), NOMASK}, + {"llvm.cos.f64", "__svml_cos8", LANES(8), NOMASK}, - {"llvm.cos.f32", "__svml_cosf4", 4}, - {"llvm.cos.f32", "__svml_cosf8", 8}, - {"llvm.cos.f32", "__svml_cosf16", 16}, + {"llvm.cos.f32", "__svml_cosf4", LANES(4), NOMASK}, + {"llvm.cos.f32", "__svml_cosf8", LANES(8), NOMASK}, + {"llvm.cos.f32", "__svml_cosf16", LANES(16), NOMASK}, - {"pow", "__svml_pow2", 2}, - {"pow", "__svml_pow4", 4}, - {"pow", "__svml_pow8", 8}, + {"pow", "__svml_pow2", LANES(2), NOMASK}, + {"pow", "__svml_pow4", LANES(4), NOMASK}, + {"pow", "__svml_pow8", LANES(8), NOMASK}, - {"powf", "__svml_powf4", 4}, - {"powf", "__svml_powf8", 8}, - {"powf", "__svml_powf16", 16}, + {"powf", "__svml_powf4", LANES(4), NOMASK}, + {"powf", "__svml_powf8", LANES(8), NOMASK}, + {"powf", "__svml_powf16", LANES(16), NOMASK}, - { "__pow_finite", "__svml_pow2", 2 }, - { "__pow_finite", "__svml_pow4", 4 }, - { "__pow_finite", "__svml_pow8", 8 }, + {"__pow_finite", "__svml_pow2", LANES(2), NOMASK}, + {"__pow_finite", "__svml_pow4", LANES(4), NOMASK}, + {"__pow_finite", "__svml_pow8", LANES(8), NOMASK}, - { "__powf_finite", "__svml_powf4", 4 }, - { "__powf_finite", "__svml_powf8", 8 }, - { "__powf_finite", "__svml_powf16", 16 }, + {"__powf_finite", "__svml_powf4", LANES(4), NOMASK}, + {"__powf_finite", "__svml_powf8", LANES(8), NOMASK}, + {"__powf_finite", "__svml_powf16", LANES(16), NOMASK}, - {"llvm.pow.f64", "__svml_pow2", 2}, - {"llvm.pow.f64", "__svml_pow4", 4}, - {"llvm.pow.f64", "__svml_pow8", 8}, + {"llvm.pow.f64", "__svml_pow2", LANES(2), NOMASK}, + {"llvm.pow.f64", "__svml_pow4", LANES(4), NOMASK}, + {"llvm.pow.f64", "__svml_pow8", LANES(8), NOMASK}, - {"llvm.pow.f32", "__svml_powf4", 4}, - {"llvm.pow.f32", "__svml_powf8", 8}, - {"llvm.pow.f32", "__svml_powf16", 16}, + {"llvm.pow.f32", "__svml_powf4", LANES(4), NOMASK}, + {"llvm.pow.f32", "__svml_powf8", LANES(8), NOMASK}, + {"llvm.pow.f32", "__svml_powf16", LANES(16), NOMASK}, - {"exp", "__svml_exp2", 2}, - {"exp", "__svml_exp4", 4}, - {"exp", "__svml_exp8", 8}, + {"exp", "__svml_exp2", LANES(2), NOMASK}, + {"exp", "__svml_exp4", LANES(4), NOMASK}, + {"exp", "__svml_exp8", LANES(8), NOMASK}, - {"expf", "__svml_expf4", 4}, - {"expf", "__svml_expf8", 8}, - {"expf", "__svml_expf16", 16}, + {"expf", "__svml_expf4", LANES(4), NOMASK}, + {"expf", "__svml_expf8", LANES(8), NOMASK}, + {"expf", "__svml_expf16", LANES(16), NOMASK}, - { "__exp_finite", "__svml_exp2", 2 }, - { "__exp_finite", "__svml_exp4", 4 }, - { "__exp_finite", "__svml_exp8", 8 }, + {"__exp_finite", "__svml_exp2", LANES(2), NOMASK}, + {"__exp_finite", "__svml_exp4", LANES(4), NOMASK}, + {"__exp_finite", "__svml_exp8", LANES(8), NOMASK}, - { "__expf_finite", "__svml_expf4", 4 }, - { "__expf_finite", "__svml_expf8", 8 }, - { "__expf_finite", "__svml_expf16", 16 }, + {"__expf_finite", "__svml_expf4", LANES(4), NOMASK}, + {"__expf_finite", "__svml_expf8", LANES(8), NOMASK}, + {"__expf_finite", "__svml_expf16", LANES(16), NOMASK}, - {"llvm.exp.f64", "__svml_exp2", 2}, - {"llvm.exp.f64", "__svml_exp4", 4}, - {"llvm.exp.f64", "__svml_exp8", 8}, + {"llvm.exp.f64", "__svml_exp2", LANES(2), NOMASK}, + {"llvm.exp.f64", "__svml_exp4", LANES(4), NOMASK}, + {"llvm.exp.f64", "__svml_exp8", LANES(8), NOMASK}, - {"llvm.exp.f32", "__svml_expf4", 4}, - {"llvm.exp.f32", "__svml_expf8", 8}, - {"llvm.exp.f32", "__svml_expf16", 
16}, + {"llvm.exp.f32", "__svml_expf4", LANES(4), NOMASK}, + {"llvm.exp.f32", "__svml_expf8", LANES(8), NOMASK}, + {"llvm.exp.f32", "__svml_expf16", LANES(16), NOMASK}, - {"log", "__svml_log2", 2}, - {"log", "__svml_log4", 4}, - {"log", "__svml_log8", 8}, + {"log", "__svml_log2", LANES(2), NOMASK}, + {"log", "__svml_log4", LANES(4), NOMASK}, + {"log", "__svml_log8", LANES(8), NOMASK}, - {"logf", "__svml_logf4", 4}, - {"logf", "__svml_logf8", 8}, - {"logf", "__svml_logf16", 16}, + {"logf", "__svml_logf4", LANES(4), NOMASK}, + {"logf", "__svml_logf8", LANES(8), NOMASK}, + {"logf", "__svml_logf16", LANES(16), NOMASK}, - { "__log_finite", "__svml_log2", 2 }, - { "__log_finite", "__svml_log4", 4 }, - { "__log_finite", "__svml_log8", 8 }, + {"__log_finite", "__svml_log2", LANES(2), NOMASK}, + {"__log_finite", "__svml_log4", LANES(4), NOMASK}, + {"__log_finite", "__svml_log8", LANES(8), NOMASK}, - { "__logf_finite", "__svml_logf4", 4 }, - { "__logf_finite", "__svml_logf8", 8 }, - { "__logf_finite", "__svml_logf16", 16 }, + {"__logf_finite", "__svml_logf4", LANES(4), NOMASK}, + {"__logf_finite", "__svml_logf8", LANES(8), NOMASK}, + {"__logf_finite", "__svml_logf16", LANES(16), NOMASK}, - {"llvm.log.f64", "__svml_log2", 2}, - {"llvm.log.f64", "__svml_log4", 4}, - {"llvm.log.f64", "__svml_log8", 8}, + {"llvm.log.f64", "__svml_log2", LANES(2), NOMASK}, + {"llvm.log.f64", "__svml_log4", LANES(4), NOMASK}, + {"llvm.log.f64", "__svml_log8", LANES(8), NOMASK}, - {"llvm.log.f32", "__svml_logf4", 4}, - {"llvm.log.f32", "__svml_logf8", 8}, - {"llvm.log.f32", "__svml_logf16", 16}, + {"llvm.log.f32", "__svml_logf4", LANES(4), NOMASK}, + {"llvm.log.f32", "__svml_logf8", LANES(8), NOMASK}, + {"llvm.log.f32", "__svml_logf16", LANES(16), NOMASK}, }; addVectorizableFunctions(VecFuncs); break; @@ -1572,26 +1647,86 @@ std::vector::const_iterator I = std::lower_bound( VectorDescs.begin(), VectorDescs.end(), funcName, compareWithScalarFnName); - return I != VectorDescs.end() && StringRef(I->ScalarFnName) == funcName; + if (I != VectorDescs.end()) + if (I->ScalarFnName == funcName) + return true; + + return VectorFunctionInfo.find(funcName.str()) != VectorFunctionInfo.end(); } -StringRef TargetLibraryInfoImpl::getVectorizedFunction(StringRef F, - unsigned VF) const { +std::string +TargetLibraryInfoImpl::getVectorizedFunction(StringRef F, + VectorType::ElementCount VF, + bool Masked, + FunctionType *Sign) const { + LLVM_DEBUG(dbgs() << "TLI: getVectorizedFunction\n" + << "\tF: " << F << "\n" + << "\tVF: " << (VF.Scalable ? 
"n x " : "") << VF.Min << "\n" + << "\tMasked: " << Masked << "\n" + << "\tSign: " << *Sign << "\n"); + F = sanitizeFunctionName(F); - if (F.empty()) + if (F.empty() || VF == LANES(1)) return F; std::vector::const_iterator I = std::lower_bound( VectorDescs.begin(), VectorDescs.end(), F, compareWithScalarFnName); while (I != VectorDescs.end() && StringRef(I->ScalarFnName) == F) { - if (I->VectorizationFactor == VF) + if ((I->VectorizationFactor == VF) && (I->Masked == Masked)) { + LLVM_DEBUG(dbgs() << "* Found a vector function in the static list: " + << I->VectorFnName << "\n"); return I->VectorFnName; + } ++I; } - return StringRef(); + + auto Range = llvm::make_range(VectorFunctionInfo.equal_range(F.str())); + std::vector Candidates; + for (auto VecInfo : Range) { + if (VecInfo.second.Signature == Sign) { + assert(!VecInfo.second.Name.empty() && "Empty function name"); + LLVM_DEBUG(dbgs() << "* Found a vector function in the OpenMP dynamic list: " + << VecInfo.second.Name << "\n" + << " with signature: " + << *VecInfo.second.Signature << "\n"); + Candidates.emplace_back(VecInfo.second.Name); + } + } + + // We shouldn't get more than 2 candidate vector functions (one with + // the ABI mangled name, one with the user provided vector name). + assert(Candidates.size() <= 2 && + "We should not get more than 2 candidate vector functions."); + + switch (Candidates.size()) { + // No candidates found. + case 0: + return ""; + // If we find only one candidate in the list of vector functions, + // return it. + case 1: + return Candidates[0]; + // If we found two candidates, prefer the user-provided function + // over the OpenMP mangled one (which starts with "_ZGV)". Notice + // that based on OpenMP classification, the names can be at most + // two: one is the Vector Function ABI mangled name, one is the user + // provided custom name. + case 2: { + if (!StringRef(Candidates[0]).startswith("_ZGV")) + return Candidates[0]; + + if (!StringRef(Candidates[1]).startswith("_ZGV")) + return Candidates[1]; + + // If neither of them is a custom name, return the first match. 
+ return Candidates[0]; + } + default: + llvm_unreachable("Invalid number of candidates."); + } } StringRef TargetLibraryInfoImpl::getScalarizedFunction(StringRef F, - unsigned &VF) const { + VectorType::ElementCount &VF) const { F = sanitizeFunctionName(F); if (F.empty()) return F; @@ -1600,6 +1735,8 @@ ScalarDescs.begin(), ScalarDescs.end(), F, compareWithVectorFnName); if (I == VectorDescs.end() || StringRef(I->VectorFnName) != F) return StringRef(); + if (I->Masked) + return StringRef(); VF = I->VectorizationFactor; return I->ScalarFnName; } @@ -1661,3 +1798,91 @@ char TargetLibraryInfoWrapperPass::ID = 0; void TargetLibraryInfoWrapperPass::anchor() {} + +void TargetLibraryInfoImpl::addOpenMPVectorFunctions(Module *M) { + LLVM_DEBUG(dbgs() << "TLI: List 'declare simd'-generated globals :\n"); + for (auto &GV : M->functions()) { + auto Name = GV.getName(); + // Skip invalid names + if (!isMangledName(Name)) + continue; + + const auto Ty = GV.getType(); + FunctionType *FTy; + // Skip invalid types + if (!isValidSignature(Ty, FTy)) + continue; + + LLVM_DEBUG(dbgs() << GV << "\n"); + + const auto Names = demangle(Name); + const VectorFnInfo VecInfo = {Names.first, FTy}; + VectorFunctionInfo.emplace(Names.second, VecInfo); + } + + LLVM_DEBUG(dbgs() << "TLI: List of functions added with 'declare simd':\n"); + for (auto &VFI : VectorFunctionInfo) { + LLVM_DEBUG(dbgs() << "Scalar name: " << VFI.first << "\n" + << "Vector name: " << VFI.second.Name << "\n" + << "Vector signature: " << *VFI.second.Signature << "\n"); + } +} + +namespace { +bool checkTys(ArrayRef Params) { + for (auto &Ty : Params) { + if (Ty->isVectorTy()) + return true; + } + return false; +} + +const std::string Prefix = "vec_prefix_"; +const std::string Postfix = "_vec_postfix"; +const std::string Midfix = "_vec_midfix_"; +} + +std::pair +TargetLibraryInfoImpl::demangle(const std::string In) { + StringRef Out = StringRef(In).drop_back(Postfix.size()); + StringRef Tmp = Out.drop_front(Prefix.size()); + auto Split = Tmp.split(Midfix); + return std::make_pair(Split.first, Split.second); +} + +bool TargetLibraryInfoImpl::isMangledName(const std::string Name) { + auto RefName = StringRef(Name); + const bool HasPrefix = RefName.startswith(Prefix); + const bool HasSuffix = RefName.endswith(Postfix); + const bool HasMidfix = RefName.contains(Midfix); + + if (HasPrefix && HasMidfix && HasSuffix) { + auto Split = demangle(Name); + return !Split.first.empty() && !Split.second.empty(); + } + + return false; +} + +bool TargetLibraryInfoImpl::isValidSignature(Type *Ty, FunctionType *&FTy) { + if (!Ty->isPointerTy()) + return false; + + FTy = dyn_cast(Ty->getPointerElementType()); + + if (!FTy) + return false; + + auto RetTy = FTy->getReturnType(); + + if (RetTy->isVectorTy()) + return true; + + return RetTy->isVoidTy() && checkTys(FTy->params()); +} + +std::string TargetLibraryInfoImpl::mangle(const std::string VecName, + const std::string ScalarName) { + const std::string Ret = Prefix + VecName + Midfix + ScalarName + Postfix; + return Ret; +} Index: lib/Analysis/TargetTransformInfo.cpp =================================================================== --- lib/Analysis/TargetTransformInfo.cpp +++ lib/Analysis/TargetTransformInfo.cpp @@ -247,7 +247,7 @@ unsigned TargetTransformInfo:: getOperandsScalarizationOverhead(ArrayRef Args, - unsigned VF) const { + VectorType::ElementCount VF) const { return TTIImpl->getOperandsScalarizationOverhead(Args, VF); } @@ -336,6 +336,10 @@ return TTIImpl->getRegisterBitWidth(Vector); } +unsigned 
TargetTransformInfo::getRegisterBitWidthUpperBound(bool Vector) const { + return TTIImpl->getRegisterBitWidthUpperBound(Vector); +} + unsigned TargetTransformInfo::getMinVectorRegisterBitWidth() const { return TTIImpl->getMinVectorRegisterBitWidth(); } @@ -452,6 +456,13 @@ return Cost; } +unsigned TargetTransformInfo::getVectorMemoryOpCost( + unsigned Opcode, Type *Src, Value *Ptr, unsigned Alignment, + unsigned AddressSpace, const MemAccessInfo &Info, Instruction *I) const { + return TTIImpl->getVectorMemoryOpCost(Opcode, Src, Ptr, Alignment, + AddressSpace, Info, I); +} + int TargetTransformInfo::getMaskedMemoryOpCost(unsigned Opcode, Type *Src, unsigned Alignment, unsigned AddressSpace) const { @@ -489,7 +500,8 @@ } int TargetTransformInfo::getIntrinsicInstrCost(Intrinsic::ID ID, Type *RetTy, - ArrayRef Args, FastMathFlags FMF, unsigned VF) const { + ArrayRef Args, FastMathFlags FMF, + VectorType::ElementCount VF) const { int Cost = TTIImpl->getIntrinsicInstrCost(ID, RetTy, Args, FMF, VF); assert(Cost >= 0 && "TTI should not produce negative costs!"); return Cost; @@ -549,6 +561,11 @@ return TTIImpl->getOrCreateResultFromMemIntrinsic(Inst, ExpectedType); } +bool TargetTransformInfo::hasVectorMemoryOp(unsigned Opcode, Type *Ty, + const MemAccessInfo &Info) const { + return TTIImpl->hasVectorMemoryOp(Opcode, Ty, Info); +} + Type *TargetTransformInfo::getMemcpyLoopLoweringType(LLVMContext &Context, Value *Length, unsigned SrcAlign, @@ -569,6 +586,15 @@ return TTIImpl->areInlineCompatible(Caller, Callee); } + +bool TargetTransformInfo::canVectorizeNonUnitStrides(bool forceFixedWidth) const { + return TTIImpl->canVectorizeNonUnitStrides(forceFixedWidth); +} + +bool TargetTransformInfo::vectorizePreventedSLForwarding() const { + return TTIImpl->vectorizePreventedSLForwarding(); +} + bool TargetTransformInfo::isIndexedLoadLegal(MemIndexedMode Mode, Type *Ty) const { return TTIImpl->isIndexedLoadLegal(Mode, Ty); @@ -626,6 +652,11 @@ return TTIImpl->shouldExpandReduction(II); } +bool TargetTransformInfo::canReduceInVector(unsigned Opcode, Type *ScalarTy, + ReductionFlags Flags) const { + return TTIImpl->canReduceInVector(Opcode, ScalarTy, Flags); +} + int TargetTransformInfo::getInstructionLatency(const Instruction *I) const { return TTIImpl->getInstructionLatency(I); } Index: lib/Analysis/ValueTracking.cpp =================================================================== --- lib/Analysis/ValueTracking.cpp +++ lib/Analysis/ValueTracking.cpp @@ -590,14 +590,14 @@ assert(I->getCalledFunction()->getIntrinsicID() == Intrinsic::assume && "must be an assume intrinsic"); - Value *Arg = I->getArgOperand(0); + Value *ArgV = I->getArgOperand(0); - if (Arg == V && isValidAssumeForContext(I, Q.CxtI, Q.DT)) { + if (ArgV == V && isValidAssumeForContext(I, Q.CxtI, Q.DT)) { assert(BitWidth == 1 && "assume operand is not i1?"); Known.setAllOnes(); return; } - if (match(Arg, m_Not(m_Specific(V))) && + if (match(ArgV, m_Not(m_Specific(V))) && isValidAssumeForContext(I, Q.CxtI, Q.DT)) { assert(BitWidth == 1 && "assume operand is not i1?"); Known.setAllZero(); @@ -608,6 +608,10 @@ if (Depth == MaxDepth) continue; + ICmpInst *Arg = dyn_cast(I->getArgOperand(0)); + if (!Arg) + continue; + Value *A, *B; auto m_V = m_CombineOr(m_Specific(V), m_CombineOr(m_PtrToInt(m_Specific(V)), @@ -615,230 +619,243 @@ CmpInst::Predicate Pred; uint64_t C; - // assume(v = a) - if (match(Arg, m_c_ICmp(Pred, m_V, m_Value(A))) && - Pred == ICmpInst::ICMP_EQ && isValidAssumeForContext(I, Q.CxtI, Q.DT)) { - KnownBits RHSKnown(BitWidth); - 
computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I)); - Known.Zero |= RHSKnown.Zero; - Known.One |= RHSKnown.One; - // assume(v & b = a) - } else if (match(Arg, - m_c_ICmp(Pred, m_c_And(m_V, m_Value(B)), m_Value(A))) && - Pred == ICmpInst::ICMP_EQ && - isValidAssumeForContext(I, Q.CxtI, Q.DT)) { - KnownBits RHSKnown(BitWidth); - computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I)); - KnownBits MaskKnown(BitWidth); - computeKnownBits(B, MaskKnown, Depth+1, Query(Q, I)); + if (Arg->getPredicate() == ICmpInst::ICMP_EQ) { + // assume(v = a) + if (match(Arg, m_c_ICmp(Pred, m_V, m_Value(A))) && + Pred == ICmpInst::ICMP_EQ && isValidAssumeForContext(I, Q.CxtI, Q.DT)) { + KnownBits RHSKnown(BitWidth); + computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I)); + Known.Zero |= RHSKnown.Zero; + Known.One |= RHSKnown.One; + // assume(v & b = a) + } else if (match(Arg, + m_c_ICmp(Pred, m_c_And(m_V, m_Value(B)), m_Value(A))) && + Pred == ICmpInst::ICMP_EQ && + isValidAssumeForContext(I, Q.CxtI, Q.DT)) { + KnownBits RHSKnown(BitWidth); + computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I)); + KnownBits MaskKnown(BitWidth); + computeKnownBits(B, MaskKnown, Depth+1, Query(Q, I)); - // For those bits in the mask that are known to be one, we can propagate - // known bits from the RHS to V. - Known.Zero |= RHSKnown.Zero & MaskKnown.One; - Known.One |= RHSKnown.One & MaskKnown.One; - // assume(~(v & b) = a) - } else if (match(Arg, m_c_ICmp(Pred, m_Not(m_c_And(m_V, m_Value(B))), - m_Value(A))) && - Pred == ICmpInst::ICMP_EQ && - isValidAssumeForContext(I, Q.CxtI, Q.DT)) { - KnownBits RHSKnown(BitWidth); - computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I)); - KnownBits MaskKnown(BitWidth); - computeKnownBits(B, MaskKnown, Depth+1, Query(Q, I)); + // For those bits in the mask that are known to be one, we can propagate + // known bits from the RHS to V. + Known.Zero |= RHSKnown.Zero & MaskKnown.One; + Known.One |= RHSKnown.One & MaskKnown.One; + // assume(~(v & b) = a) + } else if (match(Arg, m_c_ICmp(Pred, m_Not(m_c_And(m_V, m_Value(B))), + m_Value(A))) && + Pred == ICmpInst::ICMP_EQ && + isValidAssumeForContext(I, Q.CxtI, Q.DT)) { + KnownBits RHSKnown(BitWidth); + computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I)); + KnownBits MaskKnown(BitWidth); + computeKnownBits(B, MaskKnown, Depth+1, Query(Q, I)); - // For those bits in the mask that are known to be one, we can propagate - // inverted known bits from the RHS to V. - Known.Zero |= RHSKnown.One & MaskKnown.One; - Known.One |= RHSKnown.Zero & MaskKnown.One; - // assume(v | b = a) - } else if (match(Arg, - m_c_ICmp(Pred, m_c_Or(m_V, m_Value(B)), m_Value(A))) && - Pred == ICmpInst::ICMP_EQ && - isValidAssumeForContext(I, Q.CxtI, Q.DT)) { - KnownBits RHSKnown(BitWidth); - computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I)); - KnownBits BKnown(BitWidth); - computeKnownBits(B, BKnown, Depth+1, Query(Q, I)); + // For those bits in the mask that are known to be one, we can propagate + // inverted known bits from the RHS to V. + Known.Zero |= RHSKnown.One & MaskKnown.One; + Known.One |= RHSKnown.Zero & MaskKnown.One; + // assume(v | b = a) + } else if (match(Arg, + m_c_ICmp(Pred, m_c_Or(m_V, m_Value(B)), m_Value(A))) && + Pred == ICmpInst::ICMP_EQ && + isValidAssumeForContext(I, Q.CxtI, Q.DT)) { + KnownBits RHSKnown(BitWidth); + computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I)); + KnownBits BKnown(BitWidth); + computeKnownBits(B, BKnown, Depth+1, Query(Q, I)); - // For those bits in B that are known to be zero, we can propagate known - // bits from the RHS to V. 
- Known.Zero |= RHSKnown.Zero & BKnown.Zero; - Known.One |= RHSKnown.One & BKnown.Zero; - // assume(~(v | b) = a) - } else if (match(Arg, m_c_ICmp(Pred, m_Not(m_c_Or(m_V, m_Value(B))), - m_Value(A))) && - Pred == ICmpInst::ICMP_EQ && - isValidAssumeForContext(I, Q.CxtI, Q.DT)) { - KnownBits RHSKnown(BitWidth); - computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I)); - KnownBits BKnown(BitWidth); - computeKnownBits(B, BKnown, Depth+1, Query(Q, I)); + // For those bits in B that are known to be zero, we can propagate known + // bits from the RHS to V. + Known.Zero |= RHSKnown.Zero & BKnown.Zero; + Known.One |= RHSKnown.One & BKnown.Zero; + // assume(~(v | b) = a) + } else if (match(Arg, m_c_ICmp(Pred, m_Not(m_c_Or(m_V, m_Value(B))), + m_Value(A))) && + Pred == ICmpInst::ICMP_EQ && + isValidAssumeForContext(I, Q.CxtI, Q.DT)) { + KnownBits RHSKnown(BitWidth); + computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I)); + KnownBits BKnown(BitWidth); + computeKnownBits(B, BKnown, Depth+1, Query(Q, I)); - // For those bits in B that are known to be zero, we can propagate - // inverted known bits from the RHS to V. - Known.Zero |= RHSKnown.One & BKnown.Zero; - Known.One |= RHSKnown.Zero & BKnown.Zero; - // assume(v ^ b = a) - } else if (match(Arg, - m_c_ICmp(Pred, m_c_Xor(m_V, m_Value(B)), m_Value(A))) && - Pred == ICmpInst::ICMP_EQ && - isValidAssumeForContext(I, Q.CxtI, Q.DT)) { - KnownBits RHSKnown(BitWidth); - computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I)); - KnownBits BKnown(BitWidth); - computeKnownBits(B, BKnown, Depth+1, Query(Q, I)); + // For those bits in B that are known to be zero, we can propagate + // inverted known bits from the RHS to V. + Known.Zero |= RHSKnown.One & BKnown.Zero; + Known.One |= RHSKnown.Zero & BKnown.Zero; + // assume(v ^ b = a) + } else if (match(Arg, + m_c_ICmp(Pred, m_c_Xor(m_V, m_Value(B)), m_Value(A))) && + Pred == ICmpInst::ICMP_EQ && + isValidAssumeForContext(I, Q.CxtI, Q.DT)) { + KnownBits RHSKnown(BitWidth); + computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I)); + KnownBits BKnown(BitWidth); + computeKnownBits(B, BKnown, Depth+1, Query(Q, I)); - // For those bits in B that are known to be zero, we can propagate known - // bits from the RHS to V. For those bits in B that are known to be one, - // we can propagate inverted known bits from the RHS to V. - Known.Zero |= RHSKnown.Zero & BKnown.Zero; - Known.One |= RHSKnown.One & BKnown.Zero; - Known.Zero |= RHSKnown.One & BKnown.One; - Known.One |= RHSKnown.Zero & BKnown.One; - // assume(~(v ^ b) = a) - } else if (match(Arg, m_c_ICmp(Pred, m_Not(m_c_Xor(m_V, m_Value(B))), - m_Value(A))) && - Pred == ICmpInst::ICMP_EQ && - isValidAssumeForContext(I, Q.CxtI, Q.DT)) { - KnownBits RHSKnown(BitWidth); - computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I)); - KnownBits BKnown(BitWidth); - computeKnownBits(B, BKnown, Depth+1, Query(Q, I)); + // For those bits in B that are known to be zero, we can propagate known + // bits from the RHS to V. For those bits in B that are known to be one, + // we can propagate inverted known bits from the RHS to V. 
+ Known.Zero |= RHSKnown.Zero & BKnown.Zero; + Known.One |= RHSKnown.One & BKnown.Zero; + Known.Zero |= RHSKnown.One & BKnown.One; + Known.One |= RHSKnown.Zero & BKnown.One; + // assume(~(v ^ b) = a) + } else if (match(Arg, m_c_ICmp(Pred, m_Not(m_c_Xor(m_V, m_Value(B))), + m_Value(A))) && + Pred == ICmpInst::ICMP_EQ && + isValidAssumeForContext(I, Q.CxtI, Q.DT)) { + KnownBits RHSKnown(BitWidth); + computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I)); + KnownBits BKnown(BitWidth); + computeKnownBits(B, BKnown, Depth+1, Query(Q, I)); - // For those bits in B that are known to be zero, we can propagate - // inverted known bits from the RHS to V. For those bits in B that are - // known to be one, we can propagate known bits from the RHS to V. - Known.Zero |= RHSKnown.One & BKnown.Zero; - Known.One |= RHSKnown.Zero & BKnown.Zero; - Known.Zero |= RHSKnown.Zero & BKnown.One; - Known.One |= RHSKnown.One & BKnown.One; - // assume(v << c = a) - } else if (match(Arg, m_c_ICmp(Pred, m_Shl(m_V, m_ConstantInt(C)), - m_Value(A))) && - Pred == ICmpInst::ICMP_EQ && - isValidAssumeForContext(I, Q.CxtI, Q.DT) && - C < BitWidth) { - KnownBits RHSKnown(BitWidth); - computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I)); - // For those bits in RHS that are known, we can propagate them to known - // bits in V shifted to the right by C. - RHSKnown.Zero.lshrInPlace(C); - Known.Zero |= RHSKnown.Zero; - RHSKnown.One.lshrInPlace(C); - Known.One |= RHSKnown.One; - // assume(~(v << c) = a) - } else if (match(Arg, m_c_ICmp(Pred, m_Not(m_Shl(m_V, m_ConstantInt(C))), - m_Value(A))) && - Pred == ICmpInst::ICMP_EQ && - isValidAssumeForContext(I, Q.CxtI, Q.DT) && - C < BitWidth) { - KnownBits RHSKnown(BitWidth); - computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I)); - // For those bits in RHS that are known, we can propagate them inverted - // to known bits in V shifted to the right by C. - RHSKnown.One.lshrInPlace(C); - Known.Zero |= RHSKnown.One; - RHSKnown.Zero.lshrInPlace(C); - Known.One |= RHSKnown.Zero; - // assume(v >> c = a) - } else if (match(Arg, - m_c_ICmp(Pred, m_Shr(m_V, m_ConstantInt(C)), - m_Value(A))) && - Pred == ICmpInst::ICMP_EQ && - isValidAssumeForContext(I, Q.CxtI, Q.DT) && - C < BitWidth) { - KnownBits RHSKnown(BitWidth); - computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I)); - // For those bits in RHS that are known, we can propagate them to known - // bits in V shifted to the right by C. - Known.Zero |= RHSKnown.Zero << C; - Known.One |= RHSKnown.One << C; - // assume(~(v >> c) = a) - } else if (match(Arg, m_c_ICmp(Pred, m_Not(m_Shr(m_V, m_ConstantInt(C))), - m_Value(A))) && - Pred == ICmpInst::ICMP_EQ && - isValidAssumeForContext(I, Q.CxtI, Q.DT) && - C < BitWidth) { - KnownBits RHSKnown(BitWidth); - computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I)); - // For those bits in RHS that are known, we can propagate them inverted - // to known bits in V shifted to the right by C. - Known.Zero |= RHSKnown.One << C; - Known.One |= RHSKnown.Zero << C; - // assume(v >=_s c) where c is non-negative - } else if (match(Arg, m_ICmp(Pred, m_V, m_Value(A))) && - Pred == ICmpInst::ICMP_SGE && - isValidAssumeForContext(I, Q.CxtI, Q.DT)) { - KnownBits RHSKnown(BitWidth); - computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I)); - - if (RHSKnown.isNonNegative()) { - // We know that the sign bit is zero. - Known.makeNonNegative(); + // For those bits in B that are known to be zero, we can propagate + // inverted known bits from the RHS to V. 
For those bits in B that are + // known to be one, we can propagate known bits from the RHS to V. + Known.Zero |= RHSKnown.One & BKnown.Zero; + Known.One |= RHSKnown.Zero & BKnown.Zero; + Known.Zero |= RHSKnown.Zero & BKnown.One; + Known.One |= RHSKnown.One & BKnown.One; + // assume(v << c = a) + } else if (match(Arg, m_c_ICmp(Pred, m_Shl(m_V, m_ConstantInt(C)), + m_Value(A))) && + Pred == ICmpInst::ICMP_EQ && + isValidAssumeForContext(I, Q.CxtI, Q.DT) && + C < BitWidth) { + KnownBits RHSKnown(BitWidth); + computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I)); + // For those bits in RHS that are known, we can propagate them to known + // bits in V shifted to the right by C. + RHSKnown.Zero.lshrInPlace(C); + Known.Zero |= RHSKnown.Zero; + RHSKnown.One.lshrInPlace(C); + Known.One |= RHSKnown.One; + // assume(~(v << c) = a) + } else if (match(Arg, m_c_ICmp(Pred, m_Not(m_Shl(m_V, m_ConstantInt(C))), + m_Value(A))) && + Pred == ICmpInst::ICMP_EQ && + isValidAssumeForContext(I, Q.CxtI, Q.DT) && + C < BitWidth) { + KnownBits RHSKnown(BitWidth); + computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I)); + // For those bits in RHS that are known, we can propagate them inverted + // to known bits in V shifted to the right by C. + RHSKnown.One.lshrInPlace(C); + Known.Zero |= RHSKnown.One; + RHSKnown.Zero.lshrInPlace(C); + Known.One |= RHSKnown.Zero; + // assume(v >> c = a) + } else if (match(Arg, + m_c_ICmp(Pred, m_Shr(m_V, m_ConstantInt(C)), + m_Value(A))) && + Pred == ICmpInst::ICMP_EQ && + isValidAssumeForContext(I, Q.CxtI, Q.DT) && + C < BitWidth) { + KnownBits RHSKnown(BitWidth); + computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I)); + // For those bits in RHS that are known, we can propagate them to known + // bits in V shifted to the right by C. + Known.Zero |= RHSKnown.Zero << C; + Known.One |= RHSKnown.One << C; + // assume(~(v >> c) = a) + } else if (match(Arg, m_c_ICmp(Pred, m_Not(m_Shr(m_V, m_ConstantInt(C))), + m_Value(A))) && + Pred == ICmpInst::ICMP_EQ && + isValidAssumeForContext(I, Q.CxtI, Q.DT) && + C < BitWidth) { + KnownBits RHSKnown(BitWidth); + computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I)); + // For those bits in RHS that are known, we can propagate them inverted + // to known bits in V shifted to the right by C. + Known.Zero |= RHSKnown.One << C; + Known.One |= RHSKnown.Zero << C; } - // assume(v >_s c) where c is at least -1. - } else if (match(Arg, m_ICmp(Pred, m_V, m_Value(A))) && - Pred == ICmpInst::ICMP_SGT && - isValidAssumeForContext(I, Q.CxtI, Q.DT)) { - KnownBits RHSKnown(BitWidth); - computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I)); + } else if (Arg->getPredicate() == ICmpInst::ICMP_SGE) { + if (match(Arg, m_ICmp(Pred, m_V, m_Value(A))) && + Pred == ICmpInst::ICMP_SGE && + isValidAssumeForContext(I, Q.CxtI, Q.DT)) { + KnownBits RHSKnown(BitWidth); + computeKnownBits(A, RHSKnown, Depth + 1, Query(Q, I)); - if (RHSKnown.isAllOnes() || RHSKnown.isNonNegative()) { - // We know that the sign bit is zero. - Known.makeNonNegative(); + if (RHSKnown.isNonNegative()) { + // We know that the sign bit is zero. + Known.makeNonNegative(); + } } - // assume(v <=_s c) where c is negative - } else if (match(Arg, m_ICmp(Pred, m_V, m_Value(A))) && - Pred == ICmpInst::ICMP_SLE && - isValidAssumeForContext(I, Q.CxtI, Q.DT)) { - KnownBits RHSKnown(BitWidth); - computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I)); + } else if (Arg->getPredicate() == ICmpInst::ICMP_SGT) { + // assume(v >_s c) where c is at least -1. 
+ if (match(Arg, m_ICmp(Pred, m_V, m_Value(A))) && + Pred == ICmpInst::ICMP_SGT && + isValidAssumeForContext(I, Q.CxtI, Q.DT)) { + KnownBits RHSKnown(BitWidth); + computeKnownBits(A, RHSKnown, Depth + 1, Query(Q, I)); - if (RHSKnown.isNegative()) { - // We know that the sign bit is one. - Known.makeNegative(); + if (RHSKnown.isAllOnes() || RHSKnown.isNonNegative()) { + // We know that the sign bit is zero. + Known.makeNonNegative(); + } } - // assume(v <_s c) where c is non-positive - } else if (match(Arg, m_ICmp(Pred, m_V, m_Value(A))) && - Pred == ICmpInst::ICMP_SLT && - isValidAssumeForContext(I, Q.CxtI, Q.DT)) { - KnownBits RHSKnown(BitWidth); - computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I)); + } else if (Arg->getPredicate() == ICmpInst::ICMP_SLE) { + // assume(v <=_s c) where c is negative + if (match(Arg, m_ICmp(Pred, m_V, m_Value(A))) && + Pred == ICmpInst::ICMP_SLE && + isValidAssumeForContext(I, Q.CxtI, Q.DT)) { + KnownBits RHSKnown(BitWidth); + computeKnownBits(A, RHSKnown, Depth + 1, Query(Q, I)); - if (RHSKnown.isZero() || RHSKnown.isNegative()) { - // We know that the sign bit is one. - Known.makeNegative(); + if (RHSKnown.isNegative()) { + // We know that the sign bit is one. + Known.makeNegative(); + } + } + } else if (Arg->getPredicate() == ICmpInst::ICMP_SLT) { + // assume(v <_s c) where c is non-positive + if (match(Arg, m_ICmp(Pred, m_V, m_Value(A))) && + Pred == ICmpInst::ICMP_SLT && + isValidAssumeForContext(I, Q.CxtI, Q.DT)) { + KnownBits RHSKnown(BitWidth); + computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I)); + + if (RHSKnown.isZero() || RHSKnown.isNegative()) { + // We know that the sign bit is one. + Known.makeNegative(); + } } // assume(v <=_u c) - } else if (match(Arg, m_ICmp(Pred, m_V, m_Value(A))) && - Pred == ICmpInst::ICMP_ULE && - isValidAssumeForContext(I, Q.CxtI, Q.DT)) { - KnownBits RHSKnown(BitWidth); - computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I)); + } else if (Arg->getPredicate() == ICmpInst::ICMP_ULE) { + if (match(Arg, m_ICmp(Pred, m_V, m_Value(A))) && + Pred == ICmpInst::ICMP_ULE && + isValidAssumeForContext(I, Q.CxtI, Q.DT)) { + KnownBits RHSKnown(BitWidth); + computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I)); - // Whatever high bits in c are zero are known to be zero. - Known.Zero.setHighBits(RHSKnown.countMinLeadingZeros()); - // assume(v <_u c) - } else if (match(Arg, m_ICmp(Pred, m_V, m_Value(A))) && - Pred == ICmpInst::ICMP_ULT && - isValidAssumeForContext(I, Q.CxtI, Q.DT)) { - KnownBits RHSKnown(BitWidth); - computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I)); - - // If the RHS is known zero, then this assumption must be wrong (nothing - // is unsigned less than zero). Signal a conflict and get out of here. - if (RHSKnown.isZero()) { - Known.Zero.setAllBits(); - Known.One.setAllBits(); - break; - } - - // Whatever high bits in c are zero are known to be zero (if c is a power - // of 2, then one more). - if (isKnownToBeAPowerOfTwo(A, false, Depth + 1, Query(Q, I))) - Known.Zero.setHighBits(RHSKnown.countMinLeadingZeros() + 1); - else + // Whatever high bits in c are zero are known to be zero. 
Known.Zero.setHighBits(RHSKnown.countMinLeadingZeros()); + } + // assume(v <_u c) + } else if (Arg->getPredicate() == ICmpInst::ICMP_ULT) { + if (match(Arg, m_ICmp(Pred, m_V, m_Value(A))) && + Pred == ICmpInst::ICMP_ULT && + isValidAssumeForContext(I, Q.CxtI, Q.DT)) { + KnownBits RHSKnown(BitWidth); + computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I)); + + // If the RHS is known zero, then this assumption must be wrong (nothing + // is unsigned less than zero). Signal a conflict and get out of here. + if (RHSKnown.isZero()) { + Known.Zero.setAllBits(); + Known.One.setAllBits(); + break; + } + + // Whatever high bits in c are zero are known to be zero (if c is a power + // of 2, then one more). + if (isKnownToBeAPowerOfTwo(A, false, Depth + 1, Query(Q, I))) + Known.Zero.setHighBits(RHSKnown.countMinLeadingZeros() + 1); + else + Known.Zero.setHighBits(RHSKnown.countMinLeadingZeros()); + } } } @@ -1626,7 +1643,13 @@ Known.resetAll(); // We can't imply anything about undefs. - if (isa(V)) + if (isa(V) || isa(V)) + return; + + if (isa(V)) + return; + + if (isa(V)) return; // There's no point in looking through other users of ConstantData for @@ -3446,7 +3469,7 @@ Value *llvm::GetUnderlyingObject(Value *V, const DataLayout &DL, unsigned MaxLookup) { - if (!V->getType()->isPointerTy()) + if (!V->getType()->isPtrOrPtrVectorTy()) return V; for (unsigned Count = 0; MaxLookup == 0 || Count < MaxLookup; ++Count) { if (GEPOperator *GEP = dyn_cast(V)) { @@ -3488,7 +3511,7 @@ return V; } - assert(V->getType()->isPointerTy() && "Unexpected operand type!"); + assert(V->getType()->isPtrOrPtrVectorTy() && "Unexpected operand type!"); } return V; } Index: lib/Analysis/VectorUtils.cpp =================================================================== --- lib/Analysis/VectorUtils.cpp +++ lib/Analysis/VectorUtils.cpp @@ -81,19 +81,83 @@ } } +/// If the given vector intrinsic ID has another version which has a +/// mask input, then return it. +static Intrinsic::ID getMaskedVectorIntrinsic(Intrinsic::ID ID) { + switch (ID) { + case Intrinsic::sin: + return Intrinsic::masked_sin; + case Intrinsic::cos: + return Intrinsic::masked_cos; + case Intrinsic::exp: + return Intrinsic::masked_exp; + case Intrinsic::exp2: + return Intrinsic::masked_exp2; + case Intrinsic::log: + return Intrinsic::masked_log; + case Intrinsic::log2: + return Intrinsic::masked_log2; + case Intrinsic::log10: + return Intrinsic::masked_log10; + case Intrinsic::powi: + return Intrinsic::masked_powi; + case Intrinsic::pow: + return Intrinsic::masked_pow; + case Intrinsic::copysign: + return Intrinsic::masked_copysign; + case Intrinsic::rint: + return Intrinsic::masked_rint; + case Intrinsic::maxnum: + return Intrinsic::masked_maxnum; + case Intrinsic::minnum: + return Intrinsic::masked_minnum; + default: + return Intrinsic::not_intrinsic; + } +} + +/// \brief Returns true if the given vector intrinsic is maskable. 
+std::pair llvm::isMaskedVectorIntrinsic(Intrinsic::ID ID) { + switch (ID) { + case Intrinsic::masked_sin: + case Intrinsic::masked_cos: + case Intrinsic::masked_exp: + case Intrinsic::masked_exp2: + case Intrinsic::masked_log: + case Intrinsic::masked_log2: + case Intrinsic::masked_log10: + case Intrinsic::masked_powi: + case Intrinsic::masked_pow: + case Intrinsic::masked_copysign: + case Intrinsic::masked_rint: + case Intrinsic::masked_maxnum: + case Intrinsic::masked_minnum: + return std::make_pair(true, 0); + case Intrinsic::masked_fmod: + return std::make_pair(true, 2); + default: + return std::make_pair(false, 0); + } +} /// Returns intrinsic ID for call. /// For the input call instruction it finds mapping intrinsic and returns /// its ID, in case it does not found it return not_intrinsic. +/// If UseMask is true, then find a masking vectorized function if available. Intrinsic::ID llvm::getVectorIntrinsicIDForCall(const CallInst *CI, - const TargetLibraryInfo *TLI) { + const TargetLibraryInfo *TLI, + bool UseMask) { Intrinsic::ID ID = getIntrinsicForCallSite(CI, TLI); if (ID == Intrinsic::not_intrinsic) return Intrinsic::not_intrinsic; if (isTriviallyVectorizable(ID) || ID == Intrinsic::lifetime_start || ID == Intrinsic::lifetime_end || ID == Intrinsic::assume || - ID == Intrinsic::sideeffect) + ID == Intrinsic::sideeffect) { + Intrinsic::ID MaskedIntr = getMaskedVectorIntrinsic(ID); + if (UseMask && MaskedIntr != Intrinsic::not_intrinsic) + return MaskedIntr; return ID; + } return Intrinsic::not_intrinsic; } @@ -237,7 +301,10 @@ assert(V->getType()->isVectorTy() && "Not looking at a vector?"); VectorType *VTy = cast(V->getType()); unsigned Width = VTy->getNumElements(); - if (EltNo >= Width) // Out of range access. + + // Out of range access for fixed-width vectors. Scalable vectors can accept + // any index. + if ((EltNo >= Width) && !VTy->isScalable()) return UndefValue::get(VTy->getElementType()); if (Constant *C = dyn_cast(V)) @@ -260,13 +327,15 @@ } if (ShuffleVectorInst *SVI = dyn_cast(V)) { - unsigned LHSWidth = SVI->getOperand(0)->getType()->getVectorNumElements(); - int InEl = SVI->getMaskValue(EltNo); - if (InEl < 0) - return UndefValue::get(VTy->getElementType()); - if (InEl < (int)LHSWidth) - return findScalarElement(SVI->getOperand(0), InEl); - return findScalarElement(SVI->getOperand(1), InEl - LHSWidth); + int InEl; + if (SVI->getMaskValue(EltNo, InEl)) { + unsigned LHSWidth = SVI->getOperand(0)->getType()->getVectorNumElements(); + if (InEl < 0) + return UndefValue::get(VTy->getElementType()); + if (InEl < (int)LHSWidth) + return findScalarElement(SVI->getOperand(0), InEl); + return findScalarElement(SVI->getOperand(1), InEl - LHSWidth); + } } // Extract a value from a vector add operation with a constant zero. @@ -295,7 +364,10 @@ if (!ShuffleInst) return nullptr; // All-zero (or undef) shuffle mask elements. - for (int MaskElt : ShuffleInst->getShuffleMask()) + SmallVector Mask; + if (!ShuffleInst->getShuffleMask(Mask)) + return nullptr; + for (int MaskElt : Mask) if (MaskElt != 0 && MaskElt != -1) return nullptr; // The first shuffle source is 'insertelement' with index 0. 
@@ -450,6 +522,50 @@ return MinBWs; } +static Instruction *getTestReduction(IRBuilder<> Builder, Value *Src, + Intrinsic::ID ID) { + auto *VTy = dyn_cast(Src->getType()); + if (!VTy || !VTy->isScalable() || !VTy->getElementType()->isIntegerTy(1)) + return nullptr; + + // Create an all lanes active predicate + Constant *Pred = ConstantInt::getTrue(VectorType::getBool(VTy)); + Module *M = Builder.GetInsertBlock()->getParent()->getParent(); + Function *Intrinsic = + Intrinsic::getDeclaration(M, ID, Src->getType()); + return Builder.CreateCall(Intrinsic, {Pred, Src}); +} + +Value *llvm::getAnyTrueReduction(IRBuilder<> &Builder, Value *Src, + const Twine &Name) { + auto Res = getTestReduction(Builder, Src, Intrinsic::aarch64_sve_orv); + Res->setName(Name); + return Res; +} + +Value *llvm::getAllTrueReduction(IRBuilder<> &Builder, Value *Src, + const Twine &Name) { + auto Res = getTestReduction(Builder, Src, Intrinsic::aarch64_sve_andv); + Res->setName(Name); + return Res; +} + +Value *llvm::getAllFalseReduction(IRBuilder<> &Builder, Value *Src, + const Twine &Name) { + Src = Builder.CreateNot(Src); + auto Res = getTestReduction(Builder, Src, Intrinsic::aarch64_sve_andv); + Res->setName(Name); + return Res; +} + +Value *llvm::getLastTrueVector(IRBuilder<> &Builder, Value *Src, + const Twine &Name) { + auto IdxTy = Builder.getInt64Ty(); + auto NumElts = ConstantExpr::getRuntimeNumElements(IdxTy, Src->getType()); + auto Idx = ConstantExpr::getSub(NumElts, ConstantInt::get(IdxTy, 1)); + return Builder.CreateExtractElement(Src, Idx, Name); +} + /// \returns \p I after propagating metadata from \p VL. Instruction *llvm::propagateMetadata(Instruction *Inst, ArrayRef VL) { Instruction *I0 = cast(VL[0]); Index: lib/AsmParser/LLLexer.cpp =================================================================== --- lib/AsmParser/LLLexer.cpp +++ lib/AsmParser/LLLexer.cpp @@ -530,8 +530,11 @@ KEYWORD(localexec); KEYWORD(zeroinitializer); KEYWORD(undef); + KEYWORD(vscale); + KEYWORD(stepvector); KEYWORD(null); KEYWORD(none); + KEYWORD(as); KEYWORD(to); KEYWORD(caller); KEYWORD(within); @@ -592,6 +595,7 @@ KEYWORD(arm_apcscc); KEYWORD(arm_aapcscc); KEYWORD(arm_aapcs_vfpcc); + KEYWORD(aarch64_vector_pcs); KEYWORD(msp430_intrcc); KEYWORD(avr_intrcc); KEYWORD(avr_signalcc); @@ -704,6 +708,7 @@ KEYWORD(xchg); KEYWORD(nand); KEYWORD(max); KEYWORD(min); KEYWORD(umax); KEYWORD(umin); + KEYWORD(n); KEYWORD(x); KEYWORD(blockaddress); Index: lib/AsmParser/LLParser.h =================================================================== --- lib/AsmParser/LLParser.h +++ lib/AsmParser/LLParser.h @@ -55,7 +55,9 @@ t_Constant, // Value in ConstantVal. t_InlineAsm, // Value in FTy/StrVal/StrVal2/UIntVal. t_ConstantStruct, // Value in ConstantStructElts. - t_PackedConstantStruct // Value in ConstantStructElts. + t_PackedConstantStruct, // Value in ConstantStructElts. + t_StepVector, // No value. + t_VScale // No value. 
} Kind = t_LocalID; LLLexer::LocTy Loc; Index: lib/AsmParser/LLParser.cpp =================================================================== --- lib/AsmParser/LLParser.cpp +++ lib/AsmParser/LLParser.cpp @@ -1851,6 +1851,7 @@ /// ::= 'arm_apcscc' /// ::= 'arm_aapcscc' /// ::= 'arm_aapcs_vfpcc' +/// ::= 'aarch64_vector_pcs' /// ::= 'msp430_intrcc' /// ::= 'avr_intrcc' /// ::= 'avr_signalcc' @@ -1894,6 +1895,7 @@ case lltok::kw_arm_apcscc: CC = CallingConv::ARM_APCS; break; case lltok::kw_arm_aapcscc: CC = CallingConv::ARM_AAPCS; break; case lltok::kw_arm_aapcs_vfpcc:CC = CallingConv::ARM_AAPCS_VFP; break; + case lltok::kw_aarch64_vector_pcs:CC = CallingConv::AArch64_VectorCall; break; case lltok::kw_msp430_intrcc: CC = CallingConv::MSP430_INTR; break; case lltok::kw_avr_intrcc: CC = CallingConv::AVR_INTR; break; case lltok::kw_avr_signalcc: CC = CallingConv::AVR_SIGNAL; break; @@ -2659,10 +2661,22 @@ /// Type /// ::= '[' APSINTVAL 'x' Types ']' /// ::= '<' APSINTVAL 'x' Types '>' +/// ::= '<' 'n' 'x' APSINTVAL 'x' Types '>' bool LLParser::ParseArrayVectorType(Type *&Result, bool isVector) { + bool Scalable = false; /* assume fixed length vectors */ + + if (isVector && Lex.getKind() == lltok::kw_n) { + Lex.Lex(); // consume the 'n' + + if (ParseToken(lltok::kw_x, "expected 'x' after scalable vector specifier")) + return true; + + Scalable = true; /* scalable vector */ + } + if (Lex.getKind() != lltok::APSInt || Lex.getAPSIntVal().isSigned() || Lex.getAPSIntVal().getBitWidth() > 64) - return TokError("expected number in address space"); + return TokError("expected scalable vector or number in address space"); LocTy SizeLoc = Lex.getLoc(); uint64_t Size = Lex.getAPSIntVal().getZExtValue(); @@ -2686,7 +2700,7 @@ return Error(SizeLoc, "size too large for vector"); if (!VectorType::isValidElementType(EltTy)) return Error(TypeLoc, "invalid vector element type"); - Result = VectorType::get(EltTy, unsigned(Size)); + Result = VectorType::get(EltTy, unsigned(Size), Scalable); } else { if (!ArrayType::isValidElementType(EltTy)) return Error(TypeLoc, "invalid array element type"); @@ -2991,7 +3005,9 @@ ID.Kind = ValID::t_Constant; break; case lltok::kw_null: ID.Kind = ValID::t_Null; break; + case lltok::kw_stepvector: ID.Kind = ValID::t_StepVector; break; case lltok::kw_undef: ID.Kind = ValID::t_Undef; break; + case lltok::kw_vscale: ID.Kind = ValID::t_VScale; break; case lltok::kw_zeroinitializer: ID.Kind = ValID::t_Zero; break; case lltok::kw_none: ID.Kind = ValID::t_None; break; @@ -4258,6 +4274,28 @@ return false; } +/// ParseDIFortranSubrange: +/// ::= !DIFortranSubrange(lowerBound: 2) +bool LLParser::ParseDIFortranSubrange(MDNode *&Result, bool IsDistinct) { +#define VISIT_MD_FIELDS(OPTIONAL, REQUIRED) \ + OPTIONAL(constLowerBound, MDSignedField, (0, INT64_MIN, INT64_MAX)); \ + OPTIONAL(constUpperBound, MDSignedField, (0, INT64_MIN, INT64_MAX)); \ + OPTIONAL(lowerBound, MDField, ); \ + OPTIONAL(lowerBoundExpression, MDField, ); \ + OPTIONAL(upperBound, MDField, ); \ + OPTIONAL(upperBoundExpression, MDField, ); + PARSE_MD_FIELDS(); +#undef VISIT_MD_FIELDS + + Result = GET_OR_DISTINCT(DIFortranSubrange, + (Context, constLowerBound.Val, constUpperBound.Val, + (!constUpperBound.Seen) && (!upperBound.Seen), + lowerBound.Val, lowerBoundExpression.Val, + upperBound.Val, upperBoundExpression.Val)); + return false; +} + + /// ParseDIEnumerator: /// ::= !DIEnumerator(value: 30, isUnsigned: true, name: "SomeKind") bool LLParser::ParseDIEnumerator(MDNode *&Result, bool IsDistinct) { @@ -4297,6 +4335,26 @@ 
return false; } +/// ParseDIStringType: +/// ::= !DIStringType(name: "character(4)", size: 32, align: 32) +bool LLParser::ParseDIStringType(MDNode *&Result, bool IsDistinct) { +#define VISIT_MD_FIELDS(OPTIONAL, REQUIRED) \ + OPTIONAL(tag, DwarfTagField, (dwarf::DW_TAG_string_type)); \ + OPTIONAL(name, MDStringField, ); \ + OPTIONAL(stringLength, MDField, ); \ + OPTIONAL(stringLengthExpression, MDField, ); \ + OPTIONAL(size, MDUnsignedField, (0, UINT64_MAX)); \ + OPTIONAL(align, MDUnsignedField, (0, UINT32_MAX)); \ + OPTIONAL(encoding, DwarfAttEncodingField, ); + PARSE_MD_FIELDS(); +#undef VISIT_MD_FIELDS + + Result = GET_OR_DISTINCT(DIStringType, (Context, tag.Val, name.Val, + stringLength.Val, stringLengthExpression.Val, size.Val, align.Val, + encoding.Val)); + return false; +} + /// ParseDIDerivedType: /// ::= !DIDerivedType(tag: DW_TAG_pointer_type, name: "int", file: !0, /// line: 7, scope: !1, baseType: !2, size: 32, @@ -4374,6 +4432,31 @@ return false; } +bool LLParser::ParseDIFortranArrayType(MDNode *&Result, bool IsDistinct) { +#define VISIT_MD_FIELDS(OPTIONAL, REQUIRED) \ + OPTIONAL(tag, DwarfTagField, (dwarf::DW_TAG_array_type)); \ + OPTIONAL(name, MDStringField, ); \ + OPTIONAL(file, MDField, ); \ + OPTIONAL(line, LineField, ); \ + OPTIONAL(scope, MDField, ); \ + OPTIONAL(baseType, MDField, ); \ + OPTIONAL(size, MDUnsignedField, (0, UINT64_MAX)); \ + OPTIONAL(align, MDUnsignedField, (0, UINT32_MAX)); \ + OPTIONAL(offset, MDUnsignedField, (0, UINT64_MAX)); \ + OPTIONAL(flags, DIFlagField, ); \ + OPTIONAL(elements, MDField, ); + PARSE_MD_FIELDS(); +#undef VISIT_MD_FIELDS + + // Create a new node, and save it in the context if it belongs in the type + // map. + Result = GET_OR_DISTINCT( + DIFortranArrayType, + (Context, tag.Val, name.Val, file.Val, line.Val, scope.Val, baseType.Val, + size.Val, align.Val, offset.Val, flags.Val, elements.Val)); + return false; +} + bool LLParser::ParseDISubroutineType(MDNode *&Result, bool IsDistinct) { #define VISIT_MD_FIELDS(OPTIONAL, REQUIRED) \ OPTIONAL(flags, DIFlagField, ); \ @@ -4538,6 +4621,25 @@ return false; } +/// ParseDICommonBlock: +/// ::= !DICommonBlock(scope: !0, file: !2, name: "SomeNamespace", line: 9) +bool LLParser::ParseDICommonBlock(MDNode *&Result, bool IsDistinct) { +#define VISIT_MD_FIELDS(OPTIONAL, REQUIRED) \ + REQUIRED(scope, MDField, ); \ + OPTIONAL(declaration, MDField, ); \ + OPTIONAL(name, MDStringField, ); \ + OPTIONAL(file, MDField, ); \ + OPTIONAL(line, LineField, ); \ + OPTIONAL(align, MDUnsignedField, (0, UINT32_MAX)); + PARSE_MD_FIELDS(); +#undef VISIT_MD_FIELDS + + Result = GET_OR_DISTINCT(DICommonBlock, + (Context, scope.Val, declaration.Val, name.Val, + file.Val, line.Val, align.Val)); + return false; +} + /// ParseDINamespace: /// ::= !DINamespace(scope: !0, file: !2, name: "SomeNamespace", line: 9) bool LLParser::ParseDINamespace(MDNode *&Result, bool IsDistinct) { @@ -4594,12 +4696,15 @@ REQUIRED(name, MDStringField, ); \ OPTIONAL(configMacros, MDStringField, ); \ OPTIONAL(includePath, MDStringField, ); \ - OPTIONAL(isysroot, MDStringField, ); + OPTIONAL(isysroot, MDStringField, ); \ + OPTIONAL(file, MDField, ); \ + OPTIONAL(line, LineField, ); PARSE_MD_FIELDS(); #undef VISIT_MD_FIELDS Result = GET_OR_DISTINCT(DIModule, (Context, scope.Val, name.Val, - configMacros.Val, includePath.Val, isysroot.Val)); + configMacros.Val, includePath.Val, isysroot.Val, + file.Val, line.Val)); return false; } @@ -4649,6 +4754,7 @@ OPTIONAL(isLocal, MDBoolField, ); \ OPTIONAL(isDefinition, MDBoolField, (true)); \ 
OPTIONAL(declaration, MDField, ); \ + OPTIONAL(flags, DIFlagField, ); \ OPTIONAL(align, MDUnsignedField, (0, UINT32_MAX)); PARSE_MD_FIELDS(); #undef VISIT_MD_FIELDS @@ -4656,7 +4762,8 @@ Result = GET_OR_DISTINCT(DIGlobalVariable, (Context, scope.Val, name.Val, linkageName.Val, file.Val, line.Val, type.Val, isLocal.Val, - isDefinition.Val, declaration.Val, align.Val)); + isDefinition.Val, declaration.Val, flags.Val, + align.Val)); return false; } @@ -4951,12 +5058,22 @@ return Error(ID.Loc, "null must be a pointer type"); V = ConstantPointerNull::get(cast(Ty)); return false; + case ValID::t_StepVector: + if (!Ty->isVectorTy() || !Ty->getVectorElementType()->isIntegerTy()) + return Error(ID.Loc, "stepvector must be an integer vector type"); + V = StepVector::get(Ty); + return false; case ValID::t_Undef: // FIXME: LabelTy should not be a first-class type. if (!Ty->isFirstClassType() || Ty->isLabelTy()) return Error(ID.Loc, "invalid type for undef constant"); V = UndefValue::get(Ty); return false; + case ValID::t_VScale: + if (!Ty->isIntegerTy()) + return Error(ID.Loc, "vscale must be an integer type"); + V = VScale::get(Ty); + return false; case ValID::t_EmptyArray: if (!Ty->isArrayTy() || cast(Ty)->getNumElements() != 0) return Error(ID.Loc, "invalid empty array initializer"); @@ -5013,6 +5130,7 @@ case ValID::t_APSInt: case ValID::t_APFloat: case ValID::t_Undef: + case ValID::t_VScale: case ValID::t_Constant: case ValID::t_ConstantStruct: case ValID::t_PackedConstantStruct: { Index: lib/AsmParser/LLToken.h =================================================================== --- lib/AsmParser/LLToken.h +++ lib/AsmParser/LLToken.h @@ -38,6 +38,7 @@ bar, // | colon, // : + kw_n, kw_x, kw_true, kw_false, @@ -73,9 +74,12 @@ kw_initialexec, kw_localexec, kw_zeroinitializer, + kw_stepvector, kw_undef, + kw_vscale, kw_null, kw_none, + kw_as, kw_to, kw_caller, kw_within, @@ -139,6 +143,7 @@ kw_arm_apcscc, kw_arm_aapcscc, kw_arm_aapcs_vfpcc, + kw_aarch64_vector_pcs, kw_msp430_intrcc, kw_avr_intrcc, kw_avr_signalcc, Index: lib/Bitcode/Reader/BitcodeReader.cpp =================================================================== --- lib/Bitcode/Reader/BitcodeReader.cpp +++ lib/Bitcode/Reader/BitcodeReader.cpp @@ -1728,17 +1728,20 @@ return error("Invalid type"); ResultTy = ArrayType::get(ResultTy, Record[0]); break; - case bitc::TYPE_CODE_VECTOR: // VECTOR: [numelts, eltty] - if (Record.size() < 2) + case bitc::TYPE_CODE_VECTOR: { // VECTOR: [numblks, numelts, eltty] + unsigned Size = Record.size(); + if (Size < 2) return error("Invalid record"); - if (Record[0] == 0) + if (Record[Size - 2] == 0) return error("Invalid vector length"); - ResultTy = getTypeByID(Record[1]); + ResultTy = getTypeByID(Record[Size - 1]); if (!ResultTy || !StructType::isValidElementType(ResultTy)) return error("Invalid type"); - ResultTy = VectorType::get(ResultTy, Record[0]); + ResultTy = VectorType::get(ResultTy, Record[Size - 2], + Size > 2 ? 
Record[Size - 3] : false); break; } + } if (NumRecords >= TypeList.size()) return error("Invalid TYPE table"); @@ -2163,9 +2166,15 @@ unsigned BitCode = Stream.readRecord(Entry.ID, Record); switch (BitCode) { default: // Default behavior: unknown constant + case bitc::CST_CODE_STEPVEC: // STEPVEC + V = StepVector::get(CurTy); + break; case bitc::CST_CODE_UNDEF: // UNDEF V = UndefValue::get(CurTy); break; + case bitc::CST_CODE_VSCALE: // VSCALE + V = VScale::get(CurTy); + break; case bitc::CST_CODE_SETTYPE: // SETTYPE: [typeid] if (Record.empty()) return error("Invalid record"); @@ -2412,7 +2421,7 @@ if (VectorType *VTy = dyn_cast(CurTy)) if (Value *V = ValueList[Record[0]]) if (SelectorTy != V->getType()) - SelectorTy = VectorType::get(SelectorTy, VTy->getNumElements()); + SelectorTy = VectorType::get(SelectorTy, VTy->getElementCount()); V = ConstantExpr::getSelect(ValueList.getConstantFwdRef(Record[0], SelectorTy), @@ -2470,7 +2479,7 @@ Constant *Op0 = ValueList.getConstantFwdRef(Record[0], OpTy); Constant *Op1 = ValueList.getConstantFwdRef(Record[1], OpTy); Type *ShufTy = VectorType::get(Type::getInt32Ty(Context), - OpTy->getNumElements()); + OpTy->getElementCount()); Constant *Op2 = ValueList.getConstantFwdRef(Record[2], ShufTy); V = ConstantExpr::getShuffleVector(Op0, Op1, Op2); break; @@ -2484,7 +2493,7 @@ Constant *Op0 = ValueList.getConstantFwdRef(Record[1], OpTy); Constant *Op1 = ValueList.getConstantFwdRef(Record[2], OpTy); Type *ShufTy = VectorType::get(Type::getInt32Ty(Context), - RTy->getNumElements()); + RTy->getElementCount()); Constant *Op2 = ValueList.getConstantFwdRef(Record[3], ShufTy); V = ConstantExpr::getShuffleVector(Op0, Op1, Op2); break; Index: lib/Bitcode/Reader/MetadataLoader.cpp =================================================================== --- lib/Bitcode/Reader/MetadataLoader.cpp +++ lib/Bitcode/Reader/MetadataLoader.cpp @@ -804,10 +804,13 @@ case bitc::METADATA_LOCATION: case bitc::METADATA_GENERIC_DEBUG: case bitc::METADATA_SUBRANGE: + case bitc::METADATA_FORTRAN_SUBRANGE: case bitc::METADATA_ENUMERATOR: case bitc::METADATA_BASIC_TYPE: + case bitc::METADATA_STRING_TYPE: case bitc::METADATA_DERIVED_TYPE: case bitc::METADATA_COMPOSITE_TYPE: + case bitc::METADATA_FORTRAN_ARRAY_TYPE: case bitc::METADATA_SUBROUTINE_TYPE: case bitc::METADATA_MODULE: case bitc::METADATA_FILE: @@ -816,6 +819,7 @@ case bitc::METADATA_LEXICAL_BLOCK: case bitc::METADATA_LEXICAL_BLOCK_FILE: case bitc::METADATA_NAMESPACE: + case bitc::METADATA_COMMON_BLOCK: case bitc::METADATA_MACRO: case bitc::METADATA_MACRO_FILE: case bitc::METADATA_TEMPLATE_TYPE: @@ -1197,6 +1201,20 @@ NextMetadataNo++; break; } + case bitc::METADATA_FORTRAN_SUBRANGE: { + if (Record.size() != 8) + return error("Invalid record"); + + IsDistinct = Record[0]; + MetadataList.assignValue( + GET_OR_DISTINCT(DIFortranSubrange, + (Context, Record[1], Record[2], Record[3], + getMDOrNull(Record[4]), getMDOrNull(Record[5]), + getMDOrNull(Record[6]), getMDOrNull(Record[7]))), + NextMetadataNo); + NextMetadataNo++; + break; + } case bitc::METADATA_ENUMERATOR: { if (Record.size() != 3) return error("Invalid record"); @@ -1223,6 +1241,20 @@ NextMetadataNo++; break; } + case bitc::METADATA_STRING_TYPE: { + if (Record.size() != 6) + return error("Invalid record"); + + IsDistinct = Record[0]; + MetadataList.assignValue( + GET_OR_DISTINCT(DIStringType, + (Context, Record[1], getMDString(Record[2]), + getMDOrNull(Record[3]), getMDOrNull(Record[4]), + Record[5], Record[6], Record[7])), + NextMetadataNo); + NextMetadataNo++; + break; + 
} case bitc::METADATA_DERIVED_TYPE: { if (Record.size() < 12 || Record.size() > 13) return error("Invalid record"); @@ -1316,6 +1348,38 @@ NextMetadataNo++; break; } + case bitc::METADATA_FORTRAN_ARRAY_TYPE: { + if (Record.size() != 12) + return error("Invalid record"); + + // If we have a UUID and this is not a forward declaration, lookup the + // mapping. + IsDistinct = Record[0] & 0x1; + unsigned Tag = Record[1]; + MDString *Name = getMDString(Record[2]); + Metadata *File = getMDOrNull(Record[3]); + unsigned Line = Record[4]; + Metadata *Scope = getDITypeRefOrNull(Record[5]); + Metadata *BaseType = nullptr; + uint64_t SizeInBits = Record[7]; + if (Record[8] > (uint64_t)std::numeric_limits::max()) + return error("Alignment value is too large"); + uint32_t AlignInBits = Record[8]; + uint64_t OffsetInBits = 0; + DINode::DIFlags Flags = static_cast(Record[10]); + Metadata *Elements = nullptr; + BaseType = getDITypeRefOrNull(Record[6]); + OffsetInBits = Record[9]; + Elements = getMDOrNull(Record[11]); + DIFortranArrayType *CT = + GET_OR_DISTINCT(DIFortranArrayType, + (Context, Tag, Name, File, Line, Scope, BaseType, + SizeInBits, AlignInBits, OffsetInBits, Flags, + Elements)); + MetadataList.assignValue(CT, NextMetadataNo); + NextMetadataNo++; + break; + } case bitc::METADATA_SUBROUTINE_TYPE: { if (Record.size() < 3 || Record.size() > 4) return error("Invalid record"); @@ -1336,15 +1400,20 @@ } case bitc::METADATA_MODULE: { - if (Record.size() != 6) + if (Record.size() != 8) return error("Invalid record"); IsDistinct = Record[0]; MetadataList.assignValue( GET_OR_DISTINCT(DIModule, - (Context, getMDOrNull(Record[1]), - getMDString(Record[2]), getMDString(Record[3]), - getMDString(Record[4]), getMDString(Record[5]))), + (Context, + getMDOrNull(Record[1]), //Scope + getMDString(Record[2]), //Name + getMDString(Record[3]), //ConfigMacros + getMDString(Record[4]), //IncludePath + getMDString(Record[5]), //ISysRoot + getMDOrNull(Record[6]), //File + Record[7])), //Line NextMetadataNo); NextMetadataNo++; break; @@ -1485,6 +1554,17 @@ NextMetadataNo++; break; } + case bitc::METADATA_COMMON_BLOCK: { + IsDistinct = Record[0] & 1; + MetadataList.assignValue( + GET_OR_DISTINCT(DICommonBlock, + (Context, getMDOrNull(Record[1]), + getMDOrNull(Record[2]), getMDString(Record[3]), + getMDOrNull(Record[4]), Record[5], Record[6])), + NextMetadataNo); + NextMetadataNo++; + break; + } case bitc::METADATA_NAMESPACE: { // Newer versions of DINamespace dropped file and line. MDString *Name; @@ -1557,20 +1637,35 @@ break; } case bitc::METADATA_GLOBAL_VAR: { - if (Record.size() < 11 || Record.size() > 12) + if (Record.size() < 11 || Record.size() > 13) return error("Invalid record"); IsDistinct = Record[0] & 1; unsigned Version = Record[0] >> 1; - if (Version == 1) { + if (Version == 3) { + // Add support for DIFlags + MetadataList.assignValue( + GET_OR_DISTINCT( + DIGlobalVariable, + (Context, getMDOrNull(Record[1]), getMDString(Record[2]), + getMDString(Record[3]), getMDOrNull(Record[4]), Record[5], + getDITypeRefOrNull(Record[6]), Record[7], Record[8], + getMDOrNull(Record[10]), + static_cast(Record[11]), Record[12])), + NextMetadataNo); + + NextMetadataNo++; + } else if (Version == 1) { + // No upgrade necessary. A null field will be introduced to indicate + // that no parameter information is available. 
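The Version == 3 and Version == 1 branches above rely on the convention that a METADATA_GLOBAL_VAR record packs its layout version into the same word as the "distinct" bit; the writer side of this patch bumps the version when it appends the new DIFlags field, so the reader can tell which layout it was handed. A reduced, illustrative sketch of that packing, with Record standing in for the bitcode record being built:

    #include <cstdint>
    #include <vector>

    static void packHeader(std::vector<uint64_t> &Record, bool IsDistinct,
                           uint64_t Version) {
      // Bit 0 carries "distinct"; the remaining bits carry the layout
      // version, so the reader can keep accepting older producers.
      Record.push_back(uint64_t(IsDistinct) | (Version << 1));
    }

    static void unpackHeader(uint64_t Word, bool &IsDistinct,
                             uint64_t &Version) {
      IsDistinct = (Word & 1) != 0;
      Version = Word >> 1;
    }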
MetadataList.assignValue( GET_OR_DISTINCT(DIGlobalVariable, (Context, getMDOrNull(Record[1]), getMDString(Record[2]), getMDString(Record[3]), getMDOrNull(Record[4]), Record[5], getDITypeRefOrNull(Record[6]), Record[7], Record[8], - getMDOrNull(Record[10]), Record[11])), + getMDOrNull(Record[10]), DINode::FlagZero, Record[11])), NextMetadataNo); NextMetadataNo++; } else if (Version == 0) { @@ -1602,7 +1697,7 @@ (Context, getMDOrNull(Record[1]), getMDString(Record[2]), getMDString(Record[3]), getMDOrNull(Record[4]), Record[5], getDITypeRefOrNull(Record[6]), Record[7], Record[8], - getMDOrNull(Record[10]), AlignInBits)); + getMDOrNull(Record[10]), DINode::FlagZero, AlignInBits)); DIGlobalVariableExpression *DGVE = nullptr; if (Attach || Expr) Index: lib/Bitcode/Writer/BitcodeWriter.cpp =================================================================== --- lib/Bitcode/Writer/BitcodeWriter.cpp +++ lib/Bitcode/Writer/BitcodeWriter.cpp @@ -294,14 +294,22 @@ SmallVectorImpl &Record, unsigned &Abbrev); void writeDISubrange(const DISubrange *N, SmallVectorImpl &Record, unsigned Abbrev); + void writeDIFortranSubrange(const DIFortranSubrange *N, + SmallVectorImpl &Record, + unsigned Abbrev); void writeDIEnumerator(const DIEnumerator *N, SmallVectorImpl &Record, unsigned Abbrev); void writeDIBasicType(const DIBasicType *N, SmallVectorImpl &Record, unsigned Abbrev); + void writeDIStringType(const DIStringType *N, + SmallVectorImpl &Record, unsigned Abbrev); void writeDIDerivedType(const DIDerivedType *N, SmallVectorImpl &Record, unsigned Abbrev); void writeDICompositeType(const DICompositeType *N, SmallVectorImpl &Record, unsigned Abbrev); + void writeDIFortranArrayType(const DIFortranArrayType *N, + SmallVectorImpl &Record, + unsigned Abbrev); void writeDISubroutineType(const DISubroutineType *N, SmallVectorImpl &Record, unsigned Abbrev); @@ -316,6 +324,8 @@ void writeDILexicalBlockFile(const DILexicalBlockFile *N, SmallVectorImpl &Record, unsigned Abbrev); + void writeDICommonBlock(const DICommonBlock *N, + SmallVectorImpl &Record, unsigned Abbrev); void writeDINamespace(const DINamespace *N, SmallVectorImpl &Record, unsigned Abbrev); void writeDIMacro(const DIMacro *N, SmallVectorImpl &Record, @@ -917,6 +927,7 @@ VectorType *VT = cast(T); // VECTOR [numelts, eltty] Code = bitc::TYPE_CODE_VECTOR; + TypeVals.push_back(VT->isScalable()); TypeVals.push_back(VT->getNumElements()); TypeVals.push_back(VE.getTypeID(VT->getElementType())); break; @@ -1466,6 +1477,22 @@ Record.clear(); } +void ModuleBitcodeWriter::writeDIFortranSubrange( + const DIFortranSubrange *N, SmallVectorImpl &Record, + unsigned Abbrev) { + Record.push_back(N->isDistinct()); + Record.push_back(N->getCLowerBound()); + Record.push_back(N->getCUpperBound()); + Record.push_back(N->noUpperBound()); + Record.push_back(VE.getMetadataOrNullID(N->getLowerBound())); + Record.push_back(VE.getMetadataOrNullID(N->getLowerBoundExp())); + Record.push_back(VE.getMetadataOrNullID(N->getUpperBound())); + Record.push_back(VE.getMetadataOrNullID(N->getUpperBoundExp())); + + Stream.EmitRecord(bitc::METADATA_FORTRAN_SUBRANGE, Record, Abbrev); + Record.clear(); +} + void ModuleBitcodeWriter::writeDIEnumerator(const DIEnumerator *N, SmallVectorImpl &Record, unsigned Abbrev) { @@ -1491,6 +1518,22 @@ Record.clear(); } +void ModuleBitcodeWriter::writeDIStringType(const DIStringType *N, + SmallVectorImpl &Record, + unsigned Abbrev) { + Record.push_back(N->isDistinct()); + Record.push_back(N->getTag()); + Record.push_back(VE.getMetadataOrNullID(N->getRawName())); 
+ Record.push_back(VE.getMetadataOrNullID(N->getStringLength())); + Record.push_back(VE.getMetadataOrNullID(N->getStringLengthExp())); + Record.push_back(N->getSizeInBits()); + Record.push_back(N->getAlignInBits()); + Record.push_back(N->getEncoding()); + + Stream.EmitRecord(bitc::METADATA_STRING_TYPE, Record, Abbrev); + Record.clear(); +} + void ModuleBitcodeWriter::writeDIDerivedType(const DIDerivedType *N, SmallVectorImpl &Record, unsigned Abbrev) { @@ -1544,6 +1587,26 @@ Record.clear(); } +void ModuleBitcodeWriter::writeDIFortranArrayType( + const DIFortranArrayType *N, SmallVectorImpl &Record, + unsigned Abbrev) { + Record.push_back(N->isDistinct()); + Record.push_back(N->getTag()); + Record.push_back(VE.getMetadataOrNullID(N->getRawName())); + Record.push_back(VE.getMetadataOrNullID(N->getFile())); + Record.push_back(N->getLine()); + Record.push_back(VE.getMetadataOrNullID(N->getScope())); + Record.push_back(VE.getMetadataOrNullID(N->getBaseType())); + Record.push_back(N->getSizeInBits()); + Record.push_back(N->getAlignInBits()); + Record.push_back(N->getOffsetInBits()); + Record.push_back(N->getFlags()); + Record.push_back(VE.getMetadataOrNullID(N->getElements().get())); + + Stream.EmitRecord(bitc::METADATA_FORTRAN_ARRAY_TYPE, Record, Abbrev); + Record.clear(); +} + void ModuleBitcodeWriter::writeDISubroutineType( const DISubroutineType *N, SmallVectorImpl &Record, unsigned Abbrev) { @@ -1663,6 +1726,21 @@ Record.clear(); } +void ModuleBitcodeWriter::writeDICommonBlock(const DICommonBlock *N, + SmallVectorImpl &Record, + unsigned Abbrev) { + Record.push_back(N->isDistinct()); + Record.push_back(VE.getMetadataOrNullID(N->getScope())); + Record.push_back(VE.getMetadataOrNullID(N->getDecl())); + Record.push_back(VE.getMetadataOrNullID(N->getRawName())); + Record.push_back(VE.getMetadataOrNullID(N->getFile())); + Record.push_back(N->getLineNo()); + Record.push_back(N->getAlignInBits()); + + Stream.EmitRecord(bitc::METADATA_COMMON_BLOCK, Record, Abbrev); + Record.clear(); +} + void ModuleBitcodeWriter::writeDINamespace(const DINamespace *N, SmallVectorImpl &Record, unsigned Abbrev) { @@ -1704,8 +1782,13 @@ SmallVectorImpl &Record, unsigned Abbrev) { Record.push_back(N->isDistinct()); - for (auto &I : N->operands()) - Record.push_back(VE.getMetadataOrNullID(I)); + Record.push_back(VE.getMetadataOrNullID(N->getRawScope())); + Record.push_back(VE.getMetadataOrNullID(N->getRawName())); + Record.push_back(VE.getMetadataOrNullID(N->getRawConfigurationMacros())); + Record.push_back(VE.getMetadataOrNullID(N->getRawIncludePath())); + Record.push_back(VE.getMetadataOrNullID(N->getRawISysRoot())); + Record.push_back(VE.getMetadataOrNullID(N->getRawFile())); + Record.push_back(N->getLine()); Stream.EmitRecord(bitc::METADATA_MODULE, Record, Abbrev); Record.clear(); @@ -1738,7 +1821,7 @@ void ModuleBitcodeWriter::writeDIGlobalVariable( const DIGlobalVariable *N, SmallVectorImpl &Record, unsigned Abbrev) { - const uint64_t Version = 1 << 1; + const uint64_t Version = 3 << 1; Record.push_back((uint64_t)N->isDistinct() | Version); Record.push_back(VE.getMetadataOrNullID(N->getScope())); Record.push_back(VE.getMetadataOrNullID(N->getRawName())); @@ -1750,6 +1833,7 @@ Record.push_back(N->isDefinition()); Record.push_back(/* expr */ 0); Record.push_back(VE.getMetadataOrNullID(N->getStaticDataMemberDeclaration())); + Record.push_back(N->getFlags()); Record.push_back(N->getAlignInBits()); Stream.EmitRecord(bitc::METADATA_GLOBAL_VAR, Record, Abbrev); @@ -2272,8 +2356,12 @@ unsigned AbbrevToUse = 0; if 
(C->isNullValue()) { Code = bitc::CST_CODE_NULL; + } else if (isa(C)) { + Code = bitc::CST_CODE_STEPVEC; } else if (isa(C)) { Code = bitc::CST_CODE_UNDEF; + } else if (isa(C)) { + Code = bitc::CST_CODE_VSCALE; } else if (const ConstantInt *IV = dyn_cast(C)) { if (IV->getBitWidth() <= 64) { uint64_t V = IV->getSExtValue(); Index: lib/CodeGen/Analysis.cpp =================================================================== --- lib/CodeGen/Analysis.cpp +++ lib/CodeGen/Analysis.cpp @@ -83,26 +83,55 @@ /// void llvm::ComputeValueVTs(const TargetLowering &TLI, const DataLayout &DL, Type *Ty, SmallVectorImpl &ValueVTs, - SmallVectorImpl *Offsets, - uint64_t StartingOffset) { + SmallVectorImpl *Offsets, + FieldOffsets StartingOffset) { // Given a struct type, recursively traverse the elements. if (StructType *STy = dyn_cast(Ty)) { + bool IsSizeLess = false; const StructLayout *SL = DL.getStructLayout(STy); for (StructType::element_iterator EB = STy->element_begin(), EI = EB, EE = STy->element_end(); - EI != EE; ++EI) - ComputeValueVTs(TLI, DL, *EI, ValueVTs, Offsets, - StartingOffset + SL->getElementOffset(EI - EB)); + EI != EE; ++EI) { + FieldOffsets ElementOffset = StartingOffset; + auto *VTy = dyn_cast(*EI); + + if (VTy && VTy->isScalable()) { + ElementOffset.ScaledBytes += SL->getElementOffset(EI - EB); + IsSizeLess = true; + } else { + ElementOffset.UnscaledBytes += SL->getElementOffset(EI - EB); + } + ComputeValueVTs(TLI, DL, *EI, ValueVTs, Offsets, ElementOffset); + } + + // We don't handle padding in sizeless structs yet if we're actually + // trying to store them; if the fields are just being passed around + // in registers we're fine. + // + // getStructLayout will need to return offset as separate values for + // previous elements combined size + padding bytes for this to work. + // + // Maybe we should force sizeless structs to be packed in clang? + if (IsSizeLess && Offsets) + assert(!(SL->hasPadding()) && "Padding in sizeless struct"); + return; } // Given an array type, recursively traverse the elements. if (ArrayType *ATy = dyn_cast(Ty)) { Type *EltTy = ATy->getElementType(); uint64_t EltSize = DL.getTypeAllocSize(EltTy); - for (unsigned i = 0, e = ATy->getNumElements(); i != e; ++i) - ComputeValueVTs(TLI, DL, EltTy, ValueVTs, Offsets, - StartingOffset + i * EltSize); + for (unsigned i = 0, e = ATy->getNumElements(); i != e; ++i) { + FieldOffsets ElementOffset = StartingOffset; + auto *VTy = dyn_cast(EltTy); + + // We don't yet handle arrays of scalable types + assert(!(VTy && VTy->isScalable()) && "Scalable array type"); + + ElementOffset.UnscaledBytes += i * EltSize; + ComputeValueVTs(TLI, DL, EltTy, ValueVTs, Offsets, ElementOffset); + } return; } // Interpret void as zero return values. Index: lib/CodeGen/AsmPrinter/AsmPrinter.cpp =================================================================== --- lib/CodeGen/AsmPrinter/AsmPrinter.cpp +++ lib/CodeGen/AsmPrinter/AsmPrinter.cpp @@ -749,18 +749,30 @@ const MachineFrameInfo &MFI = MF->getFrameInfo(); bool Commented = false; + auto getSize = + [&MFI](const SmallVectorImpl &Accesses) { + unsigned Size = 0; + for (auto A : Accesses) + if (MFI.isSpillSlotObjectIndex( + cast(A->getPseudoValue()) + ->getFrameIndex())) + Size += A->getSize(); + return Size; + }; + // We assume a single instruction only has a spill or reload, not // both. 
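ComputeValueVTs above replaces the single StartingOffset byte count with a FieldOffsets value because a member that sits after a scalable vector has no fixed byte offset at compile time: only the part contributed by fixed-size members is known, while the rest scales with the runtime vscale. The struct below is an illustrative stand-in (the real FieldOffsets is declared in the matching header change, which is not shown here; only the member names UnscaledBytes and ScaledBytes are taken from the code above), and resolveOffset is a hypothetical helper showing how the two parts combine once vscale is known.

    #include <cstdint>

    // Two-part offset: a compile-time byte count plus a byte count that
    // must be multiplied by the runtime vscale (for SVE, the register
    // width in bits divided by 128).
    struct FieldOffsets {
      uint64_t UnscaledBytes = 0;
      uint64_t ScaledBytes = 0;
    };

    static uint64_t resolveOffset(const FieldOffsets &FO, uint64_t VScale) {
      return FO.UnscaledBytes + FO.ScaledBytes * VScale;
    }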
const MachineMemOperand *MMO; + SmallVector Accesses; if (TII->isLoadFromStackSlotPostFE(MI, FI)) { if (MFI.isSpillSlotObjectIndex(FI)) { MMO = *MI.memoperands_begin(); CommentOS << MMO->getSize() << "-byte Reload"; Commented = true; } - } else if (TII->hasLoadFromStackSlot(MI, MMO, FI)) { - if (MFI.isSpillSlotObjectIndex(FI)) { - CommentOS << MMO->getSize() << "-byte Folded Reload"; + } else if (TII->hasLoadFromStackSlot(MI, Accesses)) { + if (auto Size = getSize(Accesses)) { + CommentOS << Size << "-byte Folded Reload"; Commented = true; } } else if (TII->isStoreToStackSlotPostFE(MI, FI)) { @@ -769,9 +781,9 @@ CommentOS << MMO->getSize() << "-byte Spill"; Commented = true; } - } else if (TII->hasStoreToStackSlot(MI, MMO, FI)) { - if (MFI.isSpillSlotObjectIndex(FI)) { - CommentOS << MMO->getSize() << "-byte Folded Spill"; + } else if (TII->hasStoreToStackSlot(MI, Accesses)) { + if (auto Size = getSize(Accesses)) { + CommentOS << Size << "-byte Folded Spill"; Commented = true; } } Index: lib/CodeGen/AsmPrinter/AsmPrinterDwarf.cpp =================================================================== --- lib/CodeGen/AsmPrinter/AsmPrinterDwarf.cpp +++ lib/CodeGen/AsmPrinter/AsmPrinterDwarf.cpp @@ -219,6 +219,7 @@ OutStreamer->EmitCFIGnuArgsSize(Inst.getOffset()); break; case MCCFIInstruction::OpEscape: + OutStreamer->AddComment(Inst.getComment()); OutStreamer->EmitCFIEscape(Inst.getValues()); break; case MCCFIInstruction::OpRestore: Index: lib/CodeGen/AsmPrinter/DebugLocEntry.h =================================================================== --- lib/CodeGen/AsmPrinter/DebugLocEntry.h +++ lib/CodeGen/AsmPrinter/DebugLocEntry.h @@ -151,6 +151,9 @@ /// Lower this entry into a DWARF expression. void finalize(const AsmPrinter &AP, DebugLocStream::ListBuilder &List, const DIBasicType *BT); + + void finalize(const AsmPrinter &AP, DebugLocStream::ListBuilder &List, + const DIStringType *ST); }; /// Compare two Values for equality. Index: lib/CodeGen/AsmPrinter/DebugLocStream.h =================================================================== --- lib/CodeGen/AsmPrinter/DebugLocStream.h +++ lib/CodeGen/AsmPrinter/DebugLocStream.h @@ -157,17 +157,21 @@ DbgVariable &V; const MachineInstr &MI; size_t ListIndex; + bool Finalized; public: ListBuilder(DebugLocStream &Locs, DwarfCompileUnit &CU, AsmPrinter &Asm, DbgVariable &V, const MachineInstr &MI) - : Locs(Locs), Asm(Asm), V(V), MI(MI), ListIndex(Locs.startList(&CU)) {} + : Locs(Locs), Asm(Asm), V(V), MI(MI), ListIndex(Locs.startList(&CU)), + Finalized(false) {} + + void finalize(); /// Finalize the list. /// /// If the list is empty, delete it. Otherwise, finalize it by creating a /// temp symbol in \a Asm and setting up the \a DbgVariable. 
- ~ListBuilder(); + ~ListBuilder() { finalize(); } DebugLocStream &getLocs() { return Locs; } }; Index: lib/CodeGen/AsmPrinter/DebugLocStream.cpp =================================================================== --- lib/CodeGen/AsmPrinter/DebugLocStream.cpp +++ lib/CodeGen/AsmPrinter/DebugLocStream.cpp @@ -38,7 +38,10 @@ "Popped off more entries than are in the list"); } -DebugLocStream::ListBuilder::~ListBuilder() { +void DebugLocStream::ListBuilder::finalize() { + if (Finalized) + return; + Finalized = true; if (!Locs.finalizeList(Asm)) return; V.initializeDbgValue(&MI); Index: lib/CodeGen/AsmPrinter/DwarfCompileUnit.h =================================================================== --- lib/CodeGen/AsmPrinter/DwarfCompileUnit.h +++ lib/CodeGen/AsmPrinter/DwarfCompileUnit.h @@ -132,6 +132,12 @@ getOrCreateGlobalVariableDIE(const DIGlobalVariable *GV, ArrayRef GlobalExprs); + DIE *getOrCreateCommonBlock(const DICommonBlock *CB, + ArrayRef GlobalExprs); + + void addLocationAttribute(DIE *ToDIE, const DIGlobalVariable *GV, + ArrayRef GlobalExprs); + /// addLabelAddress - Add a dwarf label attribute data and value using /// either DW_FORM_addr or DW_FORM_GNU_addr_index. void addLabelAddress(DIE &Die, dwarf::Attribute Attribute, @@ -304,6 +310,12 @@ void setDWOId(uint64_t DwoId) { DWOId = DwoId; } bool hasDwarfPubSections() const; + + void constructDieLocation(DIE &Die, dwarf::Attribute Attribute, + const DbgVariable &DV); + void constructDieLocationAddExpr(DIE &Die, dwarf::Attribute Attribute, + const DbgVariable &DV, + DIExpression *SubExpr); }; } // end namespace llvm Index: lib/CodeGen/AsmPrinter/DwarfCompileUnit.cpp =================================================================== --- lib/CodeGen/AsmPrinter/DwarfCompileUnit.cpp +++ lib/CodeGen/AsmPrinter/DwarfCompileUnit.cpp @@ -108,59 +108,8 @@ File->getSource(), CUID); } -DIE *DwarfCompileUnit::getOrCreateGlobalVariableDIE( - const DIGlobalVariable *GV, ArrayRef GlobalExprs) { - // Check for pre-existence. - if (DIE *Die = getDIE(GV)) - return Die; - - assert(GV); - - auto *GVContext = GV->getScope(); - auto *GTy = DD->resolve(GV->getType()); - - // Construct the context before querying for the existence of the DIE in - // case such construction creates the DIE. - DIE *ContextDIE = getOrCreateContextDIE(GVContext); - - // Add to map. - DIE *VariableDIE = &createAndAddDIE(GV->getTag(), *ContextDIE, GV); - DIScope *DeclContext; - if (auto *SDMDecl = GV->getStaticDataMemberDeclaration()) { - DeclContext = resolve(SDMDecl->getScope()); - assert(SDMDecl->isStaticMember() && "Expected static member decl"); - assert(GV->isDefinition()); - // We need the declaration DIE that is in the static member's class. - DIE *VariableSpecDIE = getOrCreateStaticMemberDIE(SDMDecl); - addDIEEntry(*VariableDIE, dwarf::DW_AT_specification, *VariableSpecDIE); - // If the global variable's type is different from the one in the class - // member type, assume that it's more specific and also emit it. - if (GTy != DD->resolve(SDMDecl->getBaseType())) - addType(*VariableDIE, GTy); - } else { - DeclContext = GV->getScope(); - // Add name and type. - addString(*VariableDIE, dwarf::DW_AT_name, GV->getDisplayName()); - addType(*VariableDIE, GTy); - - // Add scoping info. - if (!GV->isLocalToUnit()) - addFlag(*VariableDIE, dwarf::DW_AT_external); - - // Add line number info. 
- addSourceLine(*VariableDIE, GV); - } - - if (!GV->isDefinition()) - addFlag(*VariableDIE, dwarf::DW_AT_declaration); - else - addGlobalName(GV->getName(), *VariableDIE, DeclContext); - - if (uint32_t AlignInBytes = GV->getAlignInBytes()) - addUInt(*VariableDIE, dwarf::DW_AT_alignment, dwarf::DW_FORM_udata, - AlignInBytes); - - // Add location. +void DwarfCompileUnit::addLocationAttribute( + DIE *ToDIE, const DIGlobalVariable *GV, ArrayRef GlobalExprs) { bool addToAccelTable = false; DIELoc *Loc = nullptr; std::unique_ptr DwarfExpr; @@ -173,7 +122,7 @@ // DW_AT_const_value(X). if (GlobalExprs.size() == 1 && Expr && Expr->isConstant()) { addToAccelTable = true; - addConstantValue(*VariableDIE, /*Unsigned=*/true, Expr->getElement(1)); + addConstantValue(*ToDIE, /*Unsigned=*/true, Expr->getElement(1)); break; } @@ -239,20 +188,105 @@ DwarfExpr->addExpression(Expr); } if (Loc) - addBlock(*VariableDIE, dwarf::DW_AT_location, DwarfExpr->finalize()); + addBlock(*ToDIE, dwarf::DW_AT_location, DwarfExpr->finalize()); if (DD->useAllLinkageNames()) - addLinkageName(*VariableDIE, GV->getLinkageName()); + addLinkageName(*ToDIE, GV->getLinkageName()); if (addToAccelTable) { - DD->addAccelName(GV->getName(), *VariableDIE); + DD->addAccelName(GV->getName(), *ToDIE); // If the linkage name is different than the name, go ahead and output // that as well into the name table. if (GV->getLinkageName() != "" && GV->getName() != GV->getLinkageName() && DD->useAllLinkageNames()) - DD->addAccelName(GV->getLinkageName(), *VariableDIE); + DD->addAccelName(GV->getLinkageName(), *ToDIE); } +} + +DIE *DwarfCompileUnit::getOrCreateCommonBlock( + const DICommonBlock *CB, ArrayRef GlobalExprs) { + // Construct the context before querying for the existence of the DIE in case + // such construction creates the DIE. + DIE *ContextDIE = getOrCreateContextDIE(CB->getScope()); + + if (DIE *NDie = getDIE(CB)) + return NDie; + DIE &NDie = createAndAddDIE(dwarf::DW_TAG_common_block, *ContextDIE, CB); + StringRef Name = CB->getName().empty() ? "_BLNK_" : CB->getName(); + addString(NDie, dwarf::DW_AT_name, Name); + addGlobalName(Name, NDie, CB->getScope()); + if (CB->getFile()) + addSourceLine(NDie, CB->getLineNo(), CB->getFile()); + if (DIGlobalVariable *V = CB->getDecl()) + getCU().addLocationAttribute(&NDie, V, GlobalExprs); + if (uint32_t AlignInBits = CB->getAlignInBits()) { + uint32_t AlignInBytes = AlignInBits >> 3; + addUInt(NDie, dwarf::DW_AT_alignment, dwarf::DW_FORM_udata, AlignInBytes); + } + return &NDie; +} + +DIE *DwarfCompileUnit::getOrCreateGlobalVariableDIE( + const DIGlobalVariable *GV, ArrayRef GlobalExprs) { + // Check for pre-existence. + if (DIE *Die = getDIE(GV)) + return Die; + + assert(GV); + + auto *GVContext = GV->getScope(); + auto *GTy = DD->resolve(GV->getType()); + + // Construct the context before querying for the existence of the DIE in + // case such construction creates the DIE. + auto *CB = GVContext ? dyn_cast(GVContext) : nullptr; + DIE *ContextDIE = CB ? getOrCreateCommonBlock(CB, GlobalExprs) + : getOrCreateContextDIE(GVContext); + + // Add to map. + DIE *VariableDIE = &createAndAddDIE(GV->getTag(), *ContextDIE, GV); + DIScope *DeclContext; + if (auto *SDMDecl = GV->getStaticDataMemberDeclaration()) { + DeclContext = resolve(SDMDecl->getScope()); + assert(SDMDecl->isStaticMember() && "Expected static member decl"); + assert(GV->isDefinition()); + // We need the declaration DIE that is in the static member's class. 
+ DIE *VariableSpecDIE = getOrCreateStaticMemberDIE(SDMDecl); + addDIEEntry(*VariableDIE, dwarf::DW_AT_specification, *VariableSpecDIE); + // If the global variable's type is different from the one in the class + // member type, assume that it's more specific and also emit it. + if (GTy != DD->resolve(SDMDecl->getBaseType())) + addType(*VariableDIE, GTy); + } else { + DeclContext = GV->getScope(); + // Add name and type. + if (!GV->getDisplayName().empty()) + addString(*VariableDIE, dwarf::DW_AT_name, GV->getDisplayName()); + addType(*VariableDIE, GTy); + + // Add scoping info. + if (!GV->isLocalToUnit()) + addFlag(*VariableDIE, dwarf::DW_AT_external); + + // Add line number info. + addSourceLine(*VariableDIE, GV); + } + + if (!GV->isDefinition()) + addFlag(*VariableDIE, dwarf::DW_AT_declaration); + else + addGlobalName(GV->getName(), *VariableDIE, DeclContext); + + if (GV->isArtificial()) + addFlag(*VariableDIE, dwarf::DW_AT_artificial); + + if (uint32_t AlignInBytes = GV->getAlignInBytes()) + addUInt(*VariableDIE, dwarf::DW_AT_alignment, dwarf::DW_FORM_udata, + AlignInBytes); + + // Add location. + addLocationAttribute(VariableDIE, GV, GlobalExprs); return VariableDIE; } @@ -588,6 +622,43 @@ return VariableDie; } +void DwarfCompileUnit::constructDieLocationAddExpr( + DIE &Die, dwarf::Attribute Attribute, const DbgVariable &DV, + DIExpression *SubExpr) { + if (Attribute == dwarf::DW_AT_location) + return; // clients like gdb don't handle location lists correctly + if (DV.getMInsn()) + return; // temp should not have a DBG_VALUE instruction + if (!DV.hasFrameIndexExprs()) + return; // but it should have a frame index expression + + DIELoc *Loc = new (DIEValueAllocator) DIELoc; + DIEDwarfExpression DwarfExpr(*Asm, *this, *Loc); + for (auto &Fragment : DV.getFrameIndexExprs()) { + unsigned FrameReg = 0; + const DIExpression *Expr = Fragment.Expr; + const TargetFrameLowering *TFI = Asm->MF->getSubtarget().getFrameLowering(); + int Offset = TFI->getFrameIndexReference(*Asm->MF, Fragment.FI, FrameReg); + DwarfExpr.addFragmentOffset(Expr); + SmallVector Ops; + Ops.push_back(dwarf::DW_OP_plus_uconst); + Ops.push_back(Offset); + Ops.append(Expr->elements_begin(), Expr->elements_end()); + if (SubExpr) { + for (unsigned SEOp : SubExpr->getElements()) + Ops.push_back(SEOp); + } else { + Ops.push_back(dwarf::DW_OP_deref); + } + DIExpressionCursor Cursor(Ops); + DwarfExpr.setMemoryLocationKind(); + DwarfExpr.addMachineRegExpression( + *Asm->MF->getSubtarget().getRegisterInfo(), Cursor, FrameReg); + DwarfExpr.addExpression(std::move(Cursor)); + } + addBlock(Die, Attribute, DwarfExpr.finalize()); +} + DIE *DwarfCompileUnit::constructVariableDIE(DbgVariable &DV, const LexicalScope &Scope, DIE *&ObjectPointer) { Index: lib/CodeGen/AsmPrinter/DwarfDebug.h =================================================================== --- lib/CodeGen/AsmPrinter/DwarfDebug.h +++ lib/CodeGen/AsmPrinter/DwarfDebug.h @@ -314,6 +314,14 @@ bool SingleCU; bool IsDarwin; + /// Map for tracking Fortran deferred CHARACTER lengths + DenseMap StringTypeLocMap; + + /// Map for tracking Fortran assumed shape array descriptors + DenseMap SubrangeDieMap; + + DenseMap VariableInDependentType; + AddressPool AddrPool; /// Accelerator tables. 
@@ -481,6 +489,12 @@ void collectVariableInfoFromMFTable(DwarfCompileUnit &TheCU, DenseSet &P); + /// Populate dependent type variable map + void populateDependentTypeMap(); + + /// Clear dependent type tracking map + void clearDependentTracking() { VariableInDependentType.clear(); } + /// Emit the reference to the section. void emitSectionReference(const DwarfCompileUnit &CU); @@ -626,6 +640,24 @@ /// going to be null. bool isLexicalScopeDIENull(LexicalScope *Scope); + + bool hasDwarfPubSections(bool includeMinimalInlineScopes) const; + + unsigned getStringTypeLoc(const DIStringType *ST) const { + auto I = StringTypeLocMap.find(ST); + return I != StringTypeLocMap.end() ? I->second : 0; + } + + void addStringTypeLoc(const DIStringType *ST, unsigned Loc) { + assert(ST); + if (Loc) + StringTypeLocMap[ST] = Loc; + } + + DIE *getSubrangeDie(const DIFortranSubrange *SR) const; + void constructSubrangeDie(const DIFortranArrayType *AT, + DbgVariable &DV, DwarfCompileUnit &TheCU); + /// Find the matching DwarfCompileUnit for the given CU DIE. DwarfCompileUnit *lookupCU(const DIE *Die) { return CUDieMap.lookup(Die); } const DwarfCompileUnit *lookupCU(const DIE *Die) const { Index: lib/CodeGen/AsmPrinter/DwarfDebug.cpp =================================================================== --- lib/CodeGen/AsmPrinter/DwarfDebug.cpp +++ lib/CodeGen/AsmPrinter/DwarfDebug.cpp @@ -936,6 +936,55 @@ CU.createAbstractVariable(Cleansed, Scope); } +DIE *DwarfDebug::getSubrangeDie(const DIFortranSubrange *SR) const { + auto I = SubrangeDieMap.find(SR); + return (I == SubrangeDieMap.end()) ? nullptr : I->second; +} + +void DwarfDebug::constructSubrangeDie(const DIFortranArrayType *AT, + DbgVariable &DV, + DwarfCompileUnit &TheCU) { + dwarf::Attribute Attribute; + const DIFortranSubrange *WFS = nullptr; + DIExpression *WEx = nullptr; + const DIVariable *DI = DV.getVariable(); + DINodeArray Elements = AT->getElements(); + + for (unsigned i = 0, N = Elements.size(); i < N; ++i) { + DINode *Element = cast(Elements[i]); + if (const DIFortranSubrange *FS = dyn_cast(Element)) { + if (DIVariable *UBV = FS->getUpperBound()) + if (UBV == DI) { + Attribute = dwarf::DW_AT_upper_bound; + WFS = FS; + WEx = FS->getUpperBoundExp(); + break; + } + if (DIVariable *LBV = FS->getLowerBound()) + if (LBV == DI) { + Attribute = dwarf::DW_AT_lower_bound; + WFS = FS; + WEx = FS->getLowerBoundExp(); + break; + } + } + } + + if (!WFS) + return; + + DIE *Die; + auto I = SubrangeDieMap.find(WFS); + if (I == SubrangeDieMap.end()) { + Die = DIE::get(DIEValueAllocator, dwarf::DW_TAG_subrange_type); + SubrangeDieMap[WFS] = Die; + } else { + Die = I->second; + } + + TheCU.constructDieLocationAddExpr(*Die, Attribute, DV, WEx); +} + // Collect variable information from side table maintained by MF. 
void DwarfDebug::collectVariableInfoFromMFTable( DwarfCompileUnit &TheCU, DenseSet &Processed) { @@ -957,6 +1006,11 @@ ensureAbstractVariableIsCreatedIfScoped(TheCU, Var, Scope->getScopeNode()); auto RegVar = llvm::make_unique(Var.first, Var.second); RegVar->initializeMMI(VI.Expr, VI.Slot); + if (VariableInDependentType.count(VI.Var)) { + const DIType *DT = VariableInDependentType[VI.Var]; + if (const DIFortranArrayType *AT = dyn_cast(DT)) + constructSubrangeDie(AT, *RegVar.get(), TheCU); + } if (DbgVariable *DbgVar = MFVars.lookup(Var)) DbgVar->addMMIEntry(*RegVar); else if (InfoHolder.addScopeVariable(Scope, RegVar.get())) { @@ -1194,10 +1248,35 @@ return false; } +void DwarfDebug::populateDependentTypeMap() { + for (const auto &I : DbgValues) { + InlinedVariable IV = I.first; + if (I.second.empty()) + continue; + + if (const DIStringType *ST = dyn_cast( + static_cast(IV.first->getType()))) + if (const DIVariable *LV = ST->getStringLength()) + VariableInDependentType[LV] = ST; + + if (const DIFortranArrayType *AT = dyn_cast( + static_cast(IV.first->getType()))) + for (const DINode *S : AT->getElements()) + if (const DIFortranSubrange *FS = dyn_cast(S)) { + if (const DIVariable *LBV = FS->getLowerBound()) + VariableInDependentType[LBV] = AT; + if (const DIVariable *UBV = FS->getUpperBound()) + VariableInDependentType[UBV] = AT; + } + } +} + // Find variables for each lexical scope. void DwarfDebug::collectVariableInfo(DwarfCompileUnit &TheCU, const DISubprogram *SP, DenseSet &Processed) { + clearDependentTracking(); + populateDependentTypeMap(); // Grab the variable info that was squirreled away in the MMI side-table. collectVariableInfoFromMFTable(TheCU, Processed); @@ -1252,6 +1331,22 @@ // Finalize the entry by lowering it into a DWARF bytestream. for (auto &Entry : Entries) Entry.finalize(*Asm, List, BT); + List.finalize(); + + if (VariableInDependentType.count(IV.first)) { + const DIType *DT = VariableInDependentType[IV.first]; + if (const DIStringType *ST = dyn_cast(DT)) { + unsigned Offset; + DbgVariable TVar = {IV.first, IV.second}; + DebugLocStream::ListBuilder LB(DebugLocs, TheCU, *Asm, TVar, *MInsn); + for (auto &Entry : Entries) + Entry.finalize(*Asm, LB, ST); + LB.finalize(); + Offset = TVar.getDebugLocListIndex(); + if (Offset != ~0u) + addStringTypeLoc(ST, Offset); + } + } } // Collect info for variables that were optimized out. @@ -1781,6 +1876,20 @@ DwarfExpr.finalize(); } +void DebugLocEntry::finalize(const AsmPrinter &AP, + DebugLocStream::ListBuilder &List, + const DIStringType *ST) { + DebugLocStream::EntryBuilder Entry(List, Begin, End); + BufferByteStreamer Streamer = Entry.getStreamer(); + DebugLocDwarfExpression DwarfExpr(AP.getDwarfVersion(), Streamer); + DebugLocEntry::Value Value = Values[0]; + assert(!Value.isFragment()); + assert(Values.size() == 1 && "only fragments may have >1 value"); + Value.Expression = ST->getStringLengthExp(); + emitDebugLocValue(AP, nullptr, Value, DwarfExpr); + DwarfExpr.finalize(); +} + void DwarfDebug::emitDebugLocEntryLocation(const DebugLocStream::Entry &Entry) { // Emit the size. Asm->OutStreamer->AddComment("Loc expr size"); Index: lib/CodeGen/AsmPrinter/DwarfExpression.h =================================================================== --- lib/CodeGen/AsmPrinter/DwarfExpression.h +++ lib/CodeGen/AsmPrinter/DwarfExpression.h @@ -264,7 +264,7 @@ /// DwarfExpression implementation for singular DW_AT_location. 
class DIEDwarfExpression final : public DwarfExpression { -const AsmPrinter &AP; + const AsmPrinter &AP; DwarfUnit &DU; DIELoc &DIE; Index: lib/CodeGen/AsmPrinter/DwarfExpression.cpp =================================================================== --- lib/CodeGen/AsmPrinter/DwarfExpression.cpp +++ lib/CodeGen/AsmPrinter/DwarfExpression.cpp @@ -346,6 +346,11 @@ emitOp(dwarf::DW_OP_plus_uconst); emitUnsigned(Op->getArg(0)); break; + case dwarf::DW_OP_deref_size: + assert(LocationKind != Register); + emitOp(dwarf::DW_OP_deref_size); + emitUnsigned(Op->getArg(0)); + break; case dwarf::DW_OP_plus: case dwarf::DW_OP_minus: case dwarf::DW_OP_mul: @@ -376,6 +381,18 @@ emitOp(dwarf::DW_OP_constu); emitUnsigned(Op->getArg(0)); break; + case dwarf::DW_OP_bregx: { + unsigned DwarfReg = Op->getArg(0); + int Offset = Op->getArg(1); + if (DwarfReg <= 31) + emitOp(dwarf::DW_OP_breg0 + DwarfReg); + else { + emitOp(dwarf::DW_OP_bregx); + emitUnsigned(DwarfReg); + } + emitSigned(Offset); + break; + } case dwarf::DW_OP_stack_value: LocationKind = Implicit; break; Index: lib/CodeGen/AsmPrinter/DwarfUnit.h =================================================================== --- lib/CodeGen/AsmPrinter/DwarfUnit.h +++ lib/CodeGen/AsmPrinter/DwarfUnit.h @@ -341,10 +341,14 @@ private: void constructTypeDIE(DIE &Buffer, const DIBasicType *BTy); + void constructTypeDIE(DIE &Buffer, const DIStringType *BTy); + void constructTypeDIE(DIE &Buffer, const DIFortranArrayType *ATy); void constructTypeDIE(DIE &Buffer, const DIDerivedType *DTy); void constructTypeDIE(DIE &Buffer, const DISubroutineType *CTy); void constructSubrangeDIE(DIE &Buffer, const DISubrange *SR, DIE *IndexTy); + void constructFortranSubrangeDIE(DIE &Buffer, const DIFortranSubrange *SR); void constructArrayTypeDIE(DIE &Buffer, const DICompositeType *CTy); + void constructArrayTypeDIE(DIE &Buffer, const DIFortranArrayType *ATy); void constructEnumTypeDIE(DIE &Buffer, const DICompositeType *CTy); DIE &constructMemberDIE(DIE &Buffer, const DIDerivedType *DT); void constructTemplateTypeParameterDIE(DIE &Buffer, Index: lib/CodeGen/AsmPrinter/DwarfUnit.cpp =================================================================== --- lib/CodeGen/AsmPrinter/DwarfUnit.cpp +++ lib/CodeGen/AsmPrinter/DwarfUnit.cpp @@ -760,8 +760,12 @@ if (auto *BT = dyn_cast(Ty)) constructTypeDIE(TyDIE, BT); + else if (auto *ST = dyn_cast(Ty)) + constructTypeDIE(TyDIE, ST); else if (auto *STy = dyn_cast(Ty)) constructTypeDIE(TyDIE, STy); + else if (auto *ATy = dyn_cast(Ty)) + constructArrayTypeDIE(TyDIE, ATy); else if (auto *CTy = dyn_cast(Ty)) { if (DD->generateTypeUnits() && !Ty->isForwardDecl()) if (MDString *TypeId = CTy->getRawIdentifier()) { @@ -790,7 +794,7 @@ DD->addAccelType(Ty->getName(), TyDIE, Flags); if (!Context || isa(Context) || isa(Context) || - isa(Context)) + isa(Context) || isa(Context)) addGlobalType(Ty, TyDIE, Context); } } @@ -846,13 +850,39 @@ if (BTy->getTag() == dwarf::DW_TAG_unspecified_type) return; - addUInt(Buffer, dwarf::DW_AT_encoding, dwarf::DW_FORM_data1, - BTy->getEncoding()); + if (BTy->getTag() != dwarf::DW_TAG_string_type) + addUInt(Buffer, dwarf::DW_AT_encoding, dwarf::DW_FORM_data1, + BTy->getEncoding()); uint64_t Size = BTy->getSizeInBits() >> 3; addUInt(Buffer, dwarf::DW_AT_byte_size, None, Size); } +void DwarfUnit::constructTypeDIE(DIE &Buffer, const DIStringType *STy) { + // Get core information. + StringRef Name = STy->getName(); + // Add name if not anonymous or intermediate type. 
+ if (!Name.empty()) + addString(Buffer, dwarf::DW_AT_name, Name); + + if (unsigned LLI = DD->getStringTypeLoc(STy)) { + // DW_TAG_string_type has a DW_AT_string_length location + dwarf::Form Form = (DD->getDwarfVersion() >= 4) + ? dwarf::DW_FORM_sec_offset : dwarf::DW_FORM_data4; + Buffer.addValue(DIEValueAllocator, dwarf::DW_AT_string_length, Form, + DIELocList(LLI)); + } + + uint64_t Size = STy->getSizeInBits() >> 3; + addUInt(Buffer, dwarf::DW_AT_byte_size, None, Size); + + if (STy->getEncoding()) { + // for eventual unicode support + addUInt(Buffer, dwarf::DW_AT_encoding, dwarf::DW_FORM_data1, + STy->getEncoding()); + } +} + void DwarfUnit::constructTypeDIE(DIE &Buffer, const DIDerivedType *DTy) { // Get core information. StringRef Name = DTy->getName(); @@ -1182,6 +1212,8 @@ addString(MDie, dwarf::DW_AT_LLVM_include_path, M->getIncludePath()); if (!M->getISysRoot().empty()) addString(MDie, dwarf::DW_AT_LLVM_isysroot, M->getISysRoot()); + if (M->getLine() && !M->getScope()->getFilename().empty()) + addSourceLine(MDie, M->getLine(), M->getScope()->getFile()); return &MDie; } @@ -1260,6 +1292,14 @@ return true; } +static bool isFortran(uint16_t Language) { + return (Language == dwarf::DW_LANG_Fortran77) || + (Language == dwarf::DW_LANG_Fortran90) || + (Language == dwarf::DW_LANG_Fortran95) || + (Language == dwarf::DW_LANG_Fortran03) || + (Language == dwarf::DW_LANG_Fortran08); +} + void DwarfUnit::applySubprogramAttributes(const DISubprogram *SP, DIE &SPDie, bool SkipSPAttributes) { // If -fdebug-info-for-profiling is enabled, need to emit the subprogram @@ -1294,6 +1334,10 @@ if (const DISubroutineType *SPTy = SP->getType()) { Args = SPTy->getTypeArray(); CC = SPTy->getCC(); + + // Standard recommends fortran MAIN program to use DW_CC_program. + if (isFortran(Language) && SP->isMainSubprogram()) + CC = dwarf::DW_CC_program; } // Add a DW_AT_calling_convention if this has an explicit convention. @@ -1367,6 +1411,12 @@ if (SP->isMainSubprogram()) addFlag(SPDie, dwarf::DW_AT_main_subprogram); + if (SP->isPure()) + addFlag(SPDie, dwarf::DW_AT_pure); + if (SP->isElemental()) + addFlag(SPDie, dwarf::DW_AT_elemental); + if (SP->isRecursive()) + addFlag(SPDie, dwarf::DW_AT_recursive); } void DwarfUnit::constructSubrangeDIE(DIE &Buffer, const DISubrange *SR, @@ -1387,13 +1437,41 @@ if (DefaultLowerBound == -1 || LowerBound != DefaultLowerBound) addUInt(DW_Subrange, dwarf::DW_AT_lower_bound, None, LowerBound); - if (auto *CV = SR->getCount().dyn_cast()) { + if (auto *Expr = SR->getCount().dyn_cast()) { + DIELoc *Loc = new (DIEValueAllocator) DIELoc; + DIEDwarfExpression DExpr(*Asm, *this, *Loc); + DExpr.addExpression(Expr); + DExpr.finalize(); + addBlock(DW_Subrange, dwarf::DW_AT_count, Loc); + } else if (auto *CV = SR->getCount().dyn_cast()) { if (auto *CountVarDIE = getDIE(CV)) addDIEEntry(DW_Subrange, dwarf::DW_AT_count, *CountVarDIE); } else if (Count != -1) + // FIXME: An unbounded array should reference the expression that defines + // the array. 
addUInt(DW_Subrange, dwarf::DW_AT_count, None, Count); } +void DwarfUnit::constructFortranSubrangeDIE(DIE &Buffer, + const DIFortranSubrange *SR) { + DIE *IndexTy = getIndexTyDie(); + DIE *Die = DD->getSubrangeDie(SR); + if ((!Die) || Die->getParent()) + Die = DIE::get(DIEValueAllocator, dwarf::DW_TAG_subrange_type); + DIE &DW_Subrange = Buffer.addChild(Die); + addDIEEntry(DW_Subrange, dwarf::DW_AT_type, *IndexTy); + + if (!SR->getLowerBound()) { + int64_t BVC = SR->getCLowerBound(); + addSInt(DW_Subrange, dwarf::DW_AT_lower_bound, dwarf::DW_FORM_sdata, BVC); + } + + if ((!SR->getUpperBound()) && (!SR->noUpperBound())) { + int64_t BVC = SR->getCUpperBound(); + addSInt(DW_Subrange, dwarf::DW_AT_upper_bound, dwarf::DW_FORM_sdata, BVC); + } +} + DIE *DwarfUnit::getIndexTyDie() { if (IndexTyDie) return IndexTyDie; @@ -1422,6 +1500,7 @@ // Locate the number of elements in the vector. const DINodeArray Elements = CTy->getElements(); + assert(!CTy->isScalableVector() && "Unexpected scalable vector"); assert(Elements.size() == 1 && Elements[0]->getTag() == dwarf::DW_TAG_subrange_type && "Invalid vector element array, expected one element of type subrange"); @@ -1438,11 +1517,14 @@ void DwarfUnit::constructArrayTypeDIE(DIE &Buffer, const DICompositeType *CTy) { if (CTy->isVector()) { addFlag(Buffer, dwarf::DW_AT_GNU_vector); - if (hasVectorBeenPadded(CTy)) + if (!CTy->isScalableVector() && hasVectorBeenPadded(CTy)) addUInt(Buffer, dwarf::DW_AT_byte_size, None, CTy->getSizeInBits() / CHAR_BIT); } + if (isFortran(getLanguage())) + addUInt(Buffer, dwarf::DW_AT_ordering, None, dwarf::DW_ORD_col_major); + // Emit the element type. addType(Buffer, resolve(CTy->getBaseType())); @@ -1461,6 +1543,20 @@ } } +void DwarfUnit::constructArrayTypeDIE(DIE &Buffer, + const DIFortranArrayType *ATy) { + // Emit the element type. + addType(Buffer, resolve(ATy->getBaseType())); + + // Add subranges to array type. + DINodeArray Elements = ATy->getElements(); + for (unsigned i = 0, N = Elements.size(); i < N; ++i) { + DINode *Element = cast(Elements[i]); + if (const DIFortranSubrange *FS = dyn_cast(Element)) + constructFortranSubrangeDIE(Buffer, FS); + } +} + void DwarfUnit::constructEnumTypeDIE(DIE &Buffer, const DICompositeType *CTy) { const DIType *DTy = resolve(CTy->getBaseType()); bool IsUnsigned = DTy && isUnsignedDIType(DD, DTy); Index: lib/CodeGen/BranchRelaxation.cpp =================================================================== --- lib/CodeGen/BranchRelaxation.cpp +++ lib/CodeGen/BranchRelaxation.cpp @@ -132,7 +132,7 @@ unsigned Align = MBB.getAlignment(); unsigned Num = MBB.getNumber(); assert(BlockInfo[Num].Offset % (1u << Align) == 0); - assert(!Num || BlockInfo[PrevNum].postOffset(MBB) <= BlockInfo[Num].Offset); + assert(!Num || BlockInfo[PrevNum].postOffset(MBB) == BlockInfo[Num].Offset); assert(BlockInfo[Num].Size == computeBlockSize(MBB)); PrevNum = Num; } @@ -199,7 +199,7 @@ unsigned PrevNum = Start.getNumber(); for (auto &MBB : make_range(MachineFunction::iterator(Start), MF->end())) { unsigned Num = MBB.getNumber(); - if (!Num) // block zero is never changed from offset zero. + if (Num == PrevNum) // First block's offset is never changed. continue; // Get the offset and known bits at the end of the layout predecessor. // Include the alignment of the current block. 
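One detail in the DwarfExpression.cpp change above (the DW_OP_bregx case) is the compaction to the one-byte shorthands: DWARF reserves DW_OP_breg0 through DW_OP_breg31, so only register numbers above 31 need the two-operand DW_OP_bregx form. The helper below is a standalone sketch of that encoding choice using the standard dwarf constants; it is illustrative and not part of the patch.

    #include "llvm/BinaryFormat/Dwarf.h"
    #include <cstdint>

    // Returns the opcode to emit for a register-relative location and
    // reports whether the register number must follow as a ULEB128 operand
    // (the signed offset operand is emitted in both cases).
    static uint8_t bregOpcodeFor(unsigned DwarfReg, bool &NeedsRegOperand) {
      if (DwarfReg <= 31) {
        NeedsRegOperand = false;
        return uint8_t(llvm::dwarf::DW_OP_breg0 + DwarfReg);
      }
      NeedsRegOperand = true;
      return llvm::dwarf::DW_OP_bregx;
    }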
Index: lib/CodeGen/CMakeLists.txt =================================================================== --- lib/CodeGen/CMakeLists.txt +++ lib/CodeGen/CMakeLists.txt @@ -13,6 +13,7 @@ CFIInstrInserter.cpp CodeGen.cpp CodeGenPrepare.cpp + ContiguousLoadStorePass.cpp CriticalAntiDepBreaker.cpp DeadMachineInstructionElim.cpp DetectDeadLanes.cpp @@ -39,6 +40,9 @@ InlineSpiller.cpp InterferenceCache.cpp InterleavedAccessPass.cpp + InterleavedGatherScatterPass.cpp + InterleavedGatherScatterStoreSinkPass.cpp + InterleavedGatherScatterStrideDescUtils.cpp IntrinsicLowering.cpp LatencyPriorityQueue.cpp LazyMachineBlockFrequencyInfo.cpp Index: lib/CodeGen/CodeGen.cpp =================================================================== --- lib/CodeGen/CodeGen.cpp +++ lib/CodeGen/CodeGen.cpp @@ -25,6 +25,7 @@ initializeBranchRelaxationPass(Registry); initializeCFIInstrInserterPass(Registry); initializeCodeGenPreparePass(Registry); + initializeContiguousLoadStorePass(Registry); initializeDeadMachineInstructionElimPass(Registry); initializeDetectDeadLanesPass(Registry); initializeDwarfEHPreparePass(Registry); Index: lib/CodeGen/CodeGenPrepare.cpp =================================================================== --- lib/CodeGen/CodeGenPrepare.cpp +++ lib/CodeGen/CodeGenPrepare.cpp @@ -221,6 +221,11 @@ cl::init(true), cl::desc("Enable splitting large offset of GEP.")); +static cl::opt EnableCheapIndexVectorStride( + "enable-cheap-indexvector-stride", cl::Hidden, cl::init(false), + cl::desc("Assume an index vector with a stride larger than 1 is cheap to " + "generate")); + namespace { enum ExtType { @@ -334,6 +339,8 @@ Type *AccessTy, unsigned AddrSpace); bool optimizeInlineAsmInst(CallInst *CS); bool optimizeCallInst(CallInst *CI, bool &ModifiedDT); + bool optimizeIndexVector(Instruction *I); + bool optimizeCnt(Instruction *I); bool optimizeExt(Instruction *&I); bool optimizeExtUses(Instruction *I); bool optimizeLoadExt(LoadInst *Load); @@ -5742,18 +5749,6 @@ return true; } -static bool isBroadcastShuffle(ShuffleVectorInst *SVI) { - SmallVector Mask(SVI->getShuffleMask()); - int SplatElem = -1; - for (unsigned i = 0; i < Mask.size(); ++i) { - if (SplatElem != -1 && Mask[i] != -1 && Mask[i] != SplatElem) - return false; - SplatElem = Mask[i]; - } - - return true; -} - /// Some targets have expensive vector shifts if the lanes aren't all the same /// (e.g. x86 only introduced "vpsllvd" and friends with AVX2). In these cases /// it's often worth sinking a shufflevector splat down to its use so that @@ -5767,7 +5762,7 @@ // We only expect better codegen by sinking a shuffle if we can recognise a // constant splat. - if (!isBroadcastShuffle(SVI)) + if (SVI->findBroadcastElement() < 0) return false; // InsertedShuffles - Only insert a shuffle in each block once. @@ -6008,13 +6003,14 @@ UseSplat = true; } - unsigned End = getTransitionType()->getVectorNumElements(); + auto EC = cast(getTransitionType())->getElementCount(); if (UseSplat) - return ConstantVector::getSplat(End, Val); + return ConstantVector::getSplat(EC, Val); SmallVector ConstVec; UndefValue *UndefVal = UndefValue::get(Val->getType()); - for (unsigned Idx = 0; Idx != End; ++Idx) { + + for (unsigned Idx = 0; Idx != EC.Min; ++Idx) { if (Idx == ExtractIdx) ConstVec.push_back(Val); else @@ -6502,6 +6498,100 @@ return true; } +// Push down a splatted base that is added to an index vector closer to the +// construction of the index vector, so that the scalar value can be folded into +// the INDEX instruction. 
Same for the stride (which is multiplied, not added). +bool CodeGenPrepare::optimizeIndexVector(Instruction *I) { + // FIXME: for now only do this for scalable types for which we know + // there must be a cheap INDEX instruction, but preferrably do this + // through some target interface where we query the cost. + if (!I->getType()->isVectorTy() || !I->getType()->getVectorIsScalable()) + return false; + + // Only do this for non-zero base values and for strides > 1 if this is cheap + // to generate. + if (!(I->getOpcode() == Instruction::Add || + (I->getOpcode() == Instruction::Mul && EnableCheapIndexVectorStride))) + return false; + + Instruction *Base; + if (!match(I, m_c_BinOp(m_Instruction(Base), + m_SeriesVector(m_Zero(), m_Value())))) + return false; + + Value *SplatVal; + if (Base->getParent() == I->getParent() || !Base->hasOneUse() || + !match(Base, m_SplatVector(m_Value(SplatVal)))) + return false; + + IRBuilder<> Builder(I); + Value *NewSplat = Builder.CreateVectorSplat( + cast(I->getType())->getElementCount(), SplatVal); + I->replaceUsesOfWith(Base, NewSplat); + return true; +} + +bool IsCntIntrinsic(Instruction *I) { + if (isa(I) && isa(I->getOperand(0))) + I = cast(I->getOperand(0)); + + if (auto *II = dyn_cast(I)) { + switch (II->getIntrinsicID()) { + case Intrinsic::aarch64_sve_cntb: + case Intrinsic::aarch64_sve_cnth: + case Intrinsic::aarch64_sve_cntw: + case Intrinsic::aarch64_sve_cntd: + return true; + default: + return false; + } + } + return false; +} + +// Push down a cnt[bwhd] instruction closer to an add/sub so that they +// can be replaced with inc/dec[bwhd]. This should also apply when the +// cnt instruction is truncated before the add/sub. +bool CodeGenPrepare::optimizeCnt(Instruction *I) { + + if (!(I->getOpcode() == Instruction::Add || + (I->getOpcode() == Instruction::Sub))) + return false; + + Instruction *Operand; + + if (!(match(I, m_BinOp(m_Value(), m_Instruction(Operand))) && + IsCntIntrinsic(Operand)) && + !(match(I, m_BinOp(m_Instruction(Operand), m_Value())) && + IsCntIntrinsic(Operand))) + return false; + + if (isa(Operand)) { + if (Operand->getParent() != I->getParent()) { + Instruction *Trunc = Operand->clone(); + Trunc->insertBefore(I); + I->replaceUsesOfWith(Operand, Trunc); + if (Operand->user_empty()) + Operand->eraseFromParent(); + Operand = Trunc; + } + I = Operand; + Operand = cast(Operand->getOperand(0)); + } + + if (Operand->getParent() == I->getParent()) + return false; + + Instruction *NewCnt = Operand->clone(); + NewCnt->insertBefore(I); + I->replaceUsesOfWith(Operand, NewCnt); + + if (Operand->user_empty()) + Operand->eraseFromParent(); + + return true; +} + bool CodeGenPrepare::optimizeInst(Instruction *I, bool &ModifiedDT) { // Bail out if we inserted the instruction to prevent optimizations from // stepping on each other's toes. @@ -6604,6 +6694,12 @@ return false; } + if (optimizeIndexVector(I)) + return true; + + if (optimizeCnt(I)) + return true; + if (GetElementPtrInst *GEPI = dyn_cast(I)) { if (GEPI->hasAllZeroIndices()) { /// The GEP operand must be a pointer, so must its result -> BitCast Index: lib/CodeGen/ContiguousLoadStorePass.cpp =================================================================== --- /dev/null +++ lib/CodeGen/ContiguousLoadStorePass.cpp @@ -0,0 +1,546 @@ +//===-------------------- ContiguousLoadStorePass.cpp ---------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. 
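The new CodeGenPrepare::optimizeIndexVector hook above is justified by a simple identity: adding a splatted base B to the series 0, S, 2*S, ... gives the series B, B+S, B+2*S, ..., so the scalar base can be folded into the single INDEX instruction that materialises the series (the same reasoning covers a multiplied stride, guarded by -enable-cheap-indexvector-stride). A small standalone check of that identity in plain C++, with illustrative names rather than LLVM IR:

#include <cassert>
#include <cstdint>
#include <vector>

// Build the arithmetic series Base, Base + Step, Base + 2*Step, ... of length N.
std::vector<int64_t> series(int64_t Base, int64_t Step, unsigned N) {
  std::vector<int64_t> V(N);
  for (unsigned I = 0; I < N; ++I)
    V[I] = Base + Step * int64_t(I);
  return V;
}

int main() {
  const unsigned N = 8;
  const int64_t Base = 42, Step = 3;

  // splat(Base) + series(0, Step) ...
  std::vector<int64_t> Sum = series(0, Step, N);
  for (int64_t &Lane : Sum)
    Lane += Base;

  // ... equals series(Base, Step), i.e. the base folds into the start operand.
  assert(Sum == series(Base, Step, N));
  return 0;
}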
+// +//===----------------------------------------------------------------------===// +// +// This file implements the Contiguous Load Store pass, which identifies +// structured loads where each element has the same operation(s) +// applied to it. In this case the order of the elements within the +// vector is irrelevant and it is possible to use contiguous loads. +// +// E.g. +// %vec = load <8 x double>, <8 x double>* %a +// %b1 = shufflevector <8 x double> %vec, +// <8 x double> undef, <2 x i32> +// %b2 = shufflevector <8 x double> %vec, +// <8 x double> undef, <2 x i32> +// %add1 = fadd fast <2 x double> %b1, %broadcast.splat +// %add2 = fadd fast <2 x double> %b2, %broadcast.splat +// %1 = shufflevector <2 x double> %add1, <2 x double> %add2, +// <4 x i32> +// %interleaved.vec = shufflevector <8 x double> %3, <8 x double> undef, +// <8 x i32> +// store <8 x double> %interleaved.vec, <8 x double>* %b +// +// As the same operation is being applied to each of the deinterleaved shuffles, +// they can be replaced with the following: +// +// %b1 = shufflevector <8 x double> %vec, +// <8 x double> undef, <2 x i32> +// %b2 = shufflevector <8 x double> %vec, +// <8 x double> undef, <2 x i32> +// +//===----------------------------------------------------------------------===// + +#include "llvm/ADT/PostOrderIterator.h" +#include "llvm/Analysis/ConstantFolding.h" +#include "llvm/CodeGen/Passes.h" +#include "llvm/CodeGen/TargetLowering.h" +#include "llvm/CodeGen/TargetPassConfig.h" +#include "llvm/CodeGen/TargetSubtargetInfo.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/InstIterator.h" +#include "llvm/IR/PatternMatch.h" +#include "llvm/Support/Debug.h" + +using namespace llvm; +using namespace llvm::PatternMatch; + +#define DEBUG_TYPE "contiguous-load-store" +#define PASS_SHORT_NAME "Contiguous Load Store Pass" + +namespace llvm { + void initializeContiguousLoadStorePass(PassRegistry &); +} + +namespace { +class ContiguousLoadStore : public FunctionPass { + +public: + static char ID; + ContiguousLoadStore() : FunctionPass(ID) { + initializeContiguousLoadStorePass(*PassRegistry::getPassRegistry()); + } + + StringRef getPassName() const override { + return PASS_SHORT_NAME; + } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.setPreservesCFG(); + } + + bool runOnFunction(Function &F) override; + +private: + bool isDeInterleaveMaskOfFactor(ArrayRef Mask, + unsigned Factor, unsigned &Index); + + bool isDeInterleaveMask(ArrayRef Mask, unsigned &Factor, + unsigned &Index, unsigned MaxFactor); + + bool isReInterleaveMask(ArrayRef Mask, unsigned &Factor, + unsigned MaxFactor, unsigned OpNumElts); + + bool checkInstUsers(const SmallVectorImpl &Shuffles, + const SmallVectorImpl &Endpoints); + + bool forwardPathFromLoad(LoadInst *LI, + unsigned Factor, + const SmallVectorImpl &Endpoints); + + void replaceShuffleMasks(SmallVectorImpl &Loads, + ShuffleVectorInst *SVI, + unsigned Factor); + + bool findMatchingInterleaveData(Value *V, + SmallVectorImpl &Loads, + unsigned Index); + + bool extractReinterleaveData(Value *V, + SmallVectorImpl &Endpoints, + unsigned &Index); + + bool structuredLoadStore(ShuffleVectorInst *SVI, + unsigned Factor); + + unsigned MaxFactor; +}; +} + +char ContiguousLoadStore::ID = 0; + +INITIALIZE_PASS_BEGIN(ContiguousLoadStore, + DEBUG_TYPE, + PASS_SHORT_NAME, false, false) +INITIALIZE_PASS_END(ContiguousLoadStore, + DEBUG_TYPE, + PASS_SHORT_NAME, false, false) + +FunctionPass *llvm::createContiguousLoadStorePass() { + 
return new ContiguousLoadStore(); +} + +// Stolen from lib/CodeGen/InterleavedAccessPass.cpp +bool ContiguousLoadStore::isDeInterleaveMaskOfFactor( + ArrayRef Mask, unsigned Factor, unsigned &Index) { + // Check all potential start indices from 0 to (Factor - 1). + for (Index = 0; Index < Factor; Index++) { + unsigned i = 0; + // Check that elements are in ascending order by Factor. Ignore undef + // elements. + for (; i < Mask.size(); i++) + if (Mask[i] >= 0 && static_cast(Mask[i]) != Index + i * Factor) + break; + if (i == Mask.size()) + return true; + } + return false; +} + +// Stolen from lib/CodeGen/InterleavedAccessPass.cpp +bool ContiguousLoadStore::isDeInterleaveMask( + ArrayRef Mask, unsigned &Factor, + unsigned &Index, unsigned MaxFactor) { + if (Mask.size() < 2) + return false; + // Check potential Factors. + for (Factor = 2; Factor <= MaxFactor; Factor++) + if (isDeInterleaveMaskOfFactor(Mask, Factor, Index)) + return true; + return false; +} + +// Stolen from lib/CodeGen/InterleavedAccessPass.cpp +bool ContiguousLoadStore::isReInterleaveMask(ArrayRef Mask, + unsigned &Factor, + unsigned MaxFactor, unsigned OpNumElts) { + unsigned NumElts = Mask.size(); + if (NumElts < 4) + return false; + // Check potential Factors. + for (Factor = 2; Factor <= MaxFactor; Factor++) { + if (NumElts % Factor) + continue; + unsigned LaneLen = NumElts / Factor; + if (!isPowerOf2_32(LaneLen)) + continue; + // Check whether each element matches the general interleaved rule. + // Ignore undef elements, as long as the defined elements match the rule. + // Outer loop processes all factors (x, y, z in the above example) + unsigned I = 0, J; + for (; I < Factor; I++) { + unsigned SavedLaneValue; + unsigned SavedNoUndefs = 0; + // Inner loop processes consecutive accesses (x, x+1... in the example) + for (J = 0; J < LaneLen - 1; J++) { + // Lane computes x's position in the Mask + unsigned Lane = J * Factor + I; + unsigned NextLane = Lane + Factor; + int LaneValue = Mask[Lane]; + int NextLaneValue = Mask[NextLane]; + // If both are defined, values must be sequential + if (LaneValue >= 0 && NextLaneValue >= 0 && + LaneValue + 1 != NextLaneValue) + break; + // If the next value is undef, save the current one as reference + if (LaneValue >= 0 && NextLaneValue < 0) { + SavedLaneValue = LaneValue; + SavedNoUndefs = 1; + } + // Undefs are allowed, but defined elements must still be consecutive: + // i.e.: x,..., undef,..., x + 2,..., undef,..., undef,..., x + 5, .... + // Verify this by storing the last non-undef followed by an undef + // Check that following non-undef masks are incremented with the + // corresponding distance. + if (SavedNoUndefs > 0 && LaneValue < 0) { + SavedNoUndefs++; + if (NextLaneValue >= 0 && + SavedLaneValue + SavedNoUndefs != (unsigned)NextLaneValue) + break; + } + } + if (J < LaneLen - 1) + break; + int StartMask = 0; + if (Mask[I] >= 0) { + // Check that the start of the I range (J=0) is greater than 0 + StartMask = Mask[I]; + } else if (Mask[(LaneLen - 1) * Factor + I] >= 0) { + // StartMask defined by the last value in lane + StartMask = Mask[(LaneLen - 1) * Factor + I] - J; + } else if (SavedNoUndefs > 0) { + // StartMask defined by some non-zero value in the j loop + StartMask = SavedLaneValue - (LaneLen - 1 - SavedNoUndefs); + } + // else StartMask remains set to 0, i.e. all elements are undefs + if (StartMask < 0) + break; + // We must stay within the vectors; This case can happen with undefs. 
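Since isDeInterleaveMaskOfFactor is copied verbatim from InterleavedAccessPass, it may help to see the per-index check in isolation. The sketch below is a standalone plain-C++ version (masks encoded as in ShuffleVectorInst, with -1 for undef lanes), not the pass's own code:

#include <cassert>
#include <cstddef>
#include <vector>

// True if every defined lane I of Mask selects element Index + I * Factor,
// i.e. the shuffle extracts field Index of each structure. Undef lanes (-1)
// are ignored.
bool isDeInterleaveMaskOfFactor(const std::vector<int> &Mask, unsigned Factor,
                                unsigned Index) {
  for (std::size_t I = 0; I < Mask.size(); ++I)
    if (Mask[I] >= 0 && std::size_t(Mask[I]) != Index + I * Factor)
      return false;
  return true;
}

int main() {
  // <0, 2, 4, 6> de-interleaves factor 2, field 0; <1, 3, 5, 7> is field 1.
  assert(isDeInterleaveMaskOfFactor({0, 2, 4, 6}, 2, 0));
  assert(isDeInterleaveMaskOfFactor({1, 3, -1, 7}, 2, 1)); // undef tolerated
  assert(!isDeInterleaveMaskOfFactor({0, 3, 4, 6}, 2, 0));
  return 0;
}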
+ if (StartMask + LaneLen > OpNumElts*2) + break; + } + // Found an interleaved mask of current factor. + if (I == Factor) + return true; + } + return false; +} + +bool ContiguousLoadStore::checkInstUsers( + const SmallVectorImpl &Instructions, + const SmallVectorImpl &Endpoints) { + + // Check the users of the shuffles for each load + // Return false if the users of each shuffle + // (i.e. fmul, fadd, etc) do not match + SmallVector Users; + + // Make sure user is a binary op and not something else + // e.g. it's possible we could see the reinterleave shuffle here + if (!isa(Instructions[0])) + return false; + + if (!Instructions[0]->hasOneUse()) + return false; + + unsigned Opcode = Instructions[0]->getOpcode(); + Value *LHS = Instructions[0]->getOperand(0); + Value *RHS = Instructions[0]->getOperand(1); + + FastMathFlags FMF; + if (isa(Instructions[0])) + FMF = Instructions[0]->getFastMathFlags(); + + Value::user_iterator UI = Instructions[0]->user_begin(); + auto *User = cast(*UI); + + Users.push_back(User); + + for (unsigned i = 1; i < Instructions.size(); i++) { + Instruction *I = Instructions[i]; + + if(!isa(I)) + return false; + + if (!I->hasOneUse()) + return false; + + if (I->getOpcode() != Opcode) { + LLVM_DEBUG(dbgs() << "Opcodes of shuffle users do not match!\n"); + return false; + } + + if (isa(I) && (I->getFastMathFlags() != FMF)) { + LLVM_DEBUG(dbgs() << "FastMathFlags of shuffle users do not match!\n"); + return false; + } + + // Check for splats - these should all be the same for this group + if (auto SVI = dyn_cast(I->getOperand(0))) + if (SVI->findBroadcastElement() != -1) + if (SVI != dyn_cast(LHS)) + return false; + + if (auto SVI = dyn_cast(I->getOperand(1))) + if (SVI->findBroadcastElement() != -1) + if (SVI != dyn_cast(RHS)) + return false; + + // This also goes for constants, e.g. 
+ if (auto CI = dyn_cast(I->getOperand(0))) + if (CI->getSplatValue() != 0) + if (CI != LHS) + return false; + + if (auto CI = dyn_cast(I->getOperand(1))) + if (CI->getSplatValue() != 0) + if (CI != RHS) + return false; + + Value::user_iterator UI = I->user_begin(); + auto *User = cast(*UI); + Users.push_back(User); + } + + if (Instructions == Endpoints) + return true; + + return checkInstUsers(Users, Endpoints); +} + +// Check that following the data path from the load instruction, in program +// order, always ends at the Endpoints whereby element order is maintained +// and the operations performed are consistent across all vector lanes +bool ContiguousLoadStore::forwardPathFromLoad( + LoadInst *LI, + unsigned Factor, + const SmallVectorImpl &Endpoints) { + SmallVector Instructions; + SmallVector Mask; + Instructions.resize(Factor); + + // Check if all users of this load are shufflevectors + for (auto UI = LI->user_begin(), E = LI->user_end(); UI != E; UI++) { + ShuffleVectorInst *SVI = dyn_cast(*UI); + if (!SVI || !isa(SVI->getOperand(1))) + return false; + + if (!SVI->hasOneUse()) + return false; + + Value::user_iterator SUI = SVI->user_begin(); + auto *User = cast(*SUI); + SVI->getShuffleMask(Mask); + + assert(!Instructions[Mask[0]] && "Unexpected Duplicate Shuffle!"); + Instructions[Mask[0]] = User; + } + + assert(all_of(Instructions, [](const Instruction *I) { return I != 0; })); + return checkInstUsers(Instructions, Endpoints); +} + +// Replace (Re)Interleave shuffle masks with identity +// vectors as we've proven the shuffling is redundant +void ContiguousLoadStore::replaceShuffleMasks( + SmallVectorImpl &Loads, + ShuffleVectorInst *SVI, + unsigned Factor) { + + LLVM_DEBUG(dbgs() << "Replacing deinterleave shuffle masks\n"); + + unsigned MaskNumElts = SVI->getType()->getVectorNumElements(); + Type *Ty = IntegerType::get(SVI->getContext(), 32); + + for (unsigned l = 0; l < Loads.size(); l++) { + LoadInst *LI = Loads[l]; + + for (auto UI = LI->user_begin(), E = LI->user_end(); UI != E; UI++) { + auto Shuffle = cast(*UI); + + SmallVector NewMask; + SmallVector ShuffleMask; + Shuffle->getShuffleMask(ShuffleMask); + + unsigned firstElt = ShuffleMask[0] * (MaskNumElts / Factor); + for (unsigned e = 0; e < ShuffleMask.size(); e++) + NewMask.push_back(ConstantInt::get(Ty, firstElt + e)); + + Shuffle->setOperand(2, ConstantVector::get(NewMask)); + } + } + + // Replace the final re-interleave mask + LLVM_DEBUG(dbgs() << "Replacing re-interleave mask\n"); + SmallVector Mask; + + for (unsigned i = 0; i < MaskNumElts; i++) + Mask.push_back(ConstantInt::get(Ty, i)); + + SVI->setOperand(2, ConstantVector::get(Mask)); +} + +bool ContiguousLoadStore::findMatchingInterleaveData(Value *V, + SmallVectorImpl &Loads, + unsigned Index) { + + // If this is a splat or an undef, return true + if (auto CI = dyn_cast(V)) + return CI->getSplatValue() != 0; + + if (ShuffleVectorInst *SVI = dyn_cast(V)) { + // Shuffle instruction found + + // If this shuffle is a splat, return true + if (SVI->findBroadcastElement() != -1) + return true; + + Value *Op0 = SVI->getOperand(0); + Value *Op1 = SVI->getOperand(1); + + // If this is a deinterleaved shuffle, check it's tied to a load + unsigned Factor; + unsigned NewIndex; + SmallVector Mask; + + SVI->getShuffleMask(Mask); + + // We're looking for a load instruction and an undef + if (isa(Op0) && isa(Op1)) { + if (isDeInterleaveMask(Mask, Factor, NewIndex, MaxFactor) + && Index == NewIndex) { + LoadInst *LI = cast(Op0); + // Don't store a Load that's already been found 
+ if (std::find(Loads.begin(), Loads.end(), LI) == Loads.end()) + Loads.push_back(LI); + return true; + } + } + + return false; + } + else if (Instruction *I = dyn_cast(V)) { + // Binary op found, check the operands + Value *Op0 = I->getOperand(0); + Value *Op1 = I->getOperand(1); + + return findMatchingInterleaveData(Op0, Loads, Index) && + findMatchingInterleaveData(Op1, Loads, Index); + } + + // Some other kind of instruction we're not looking for + return false; +} + +// Follow the data path back from the reinterleave shuffle until we find +// binary operators. Preserve the order of these endpoints so that we can +// follow the path forward in program order later until they are reached +bool ContiguousLoadStore::extractReinterleaveData( + Value *V, + SmallVectorImpl &Endpoints, + unsigned &Index) { + // Undefs in shuffles should be ignored + if (isa(V)) + return true; + + if (ShuffleVectorInst *SVI = dyn_cast(V)) { + + // Check if current instruction is the reinterleave or a concat shuffle + if (SVI->isConcat()) + return extractReinterleaveData(SVI->getOperand(0), Endpoints, Index) && + extractReinterleaveData(SVI->getOperand(1), Endpoints, Index); + + LLVM_DEBUG(dbgs() << "Element order changed during shuffle\n"); + } + + // If this is a binary op, we've reached an endpoint + else if (isa(V)) { + Endpoints[Index++] = dyn_cast(V); + return true; + } + + LLVM_DEBUG(dbgs() << "Incompatible Reinterleave Data:" << *V << "\n"); + return false; +} + +bool ContiguousLoadStore::structuredLoadStore( + ShuffleVectorInst *SVI, + unsigned Factor) { + SmallVector Endpoints; + SmallVector Loads; + Endpoints.resize(Factor); + // Get the endpoints + unsigned EndIndex = 0; + if (extractReinterleaveData(SVI->getOperand(0), Endpoints, EndIndex) && + extractReinterleaveData(SVI->getOperand(1), Endpoints, EndIndex)) { + + if (EndIndex != Factor) return false; + + for (unsigned i = 0; i < EndIndex; i++) + if (!findMatchingInterleaveData(Endpoints[i], Loads, i)) + return false; + + for (auto LI = Loads.begin(); LI != Loads.end(); ++LI) { + if (!forwardPathFromLoad(*LI, Factor, Endpoints)) { + // Not safe to use contiguous load/store + LLVM_DEBUG(dbgs() << "Cannot replace with contiguous load/store\n"); + return false; + } + } + + replaceShuffleMasks(Loads, SVI, Factor); + return true; + } + + return false; +} + +bool ContiguousLoadStore::runOnFunction(Function &F) { + auto *TPC = getAnalysisIfAvailable(); + if (!TPC) + return false; + + LLVM_DEBUG(dbgs() << "*** " << getPassName() << ": " << F.getName() << "\n"); + + auto &TM = TPC->getTM(); + const TargetLowering *TLI = TM.getSubtargetImpl(F)->getTargetLowering(); + MaxFactor = TLI->getMaxSupportedInterleaveFactor(); + + bool Changed = false; + + // Iterate over instructions in function + for (auto &I : instructions(F)) { + + // If this is a store instruction, check if it + // is storing the result of a shuffle + if (StoreInst *SI = dyn_cast(&I)){ + Value *Op = SI->getValueOperand(); + + if (ShuffleVectorInst *SVI = dyn_cast(Op)) { + // Storing result of a shuffle - check if the shuffle is a reinterleave + SmallVector Mask; + unsigned Factor = 2; + unsigned OpNumElts = + SVI->getOperand(0)->getType()->getVectorNumElements(); + + if (SVI->getShuffleMask(Mask) && + isReInterleaveMask(Mask, Factor, MaxFactor, OpNumElts)){ + + // Reinterleave found - pass to recursive function and check result + LLVM_DEBUG(dbgs() << "Storing result of reinterleave shuffle " + "- checking for structured load store\n"); + + Changed |= structuredLoadStore(SVI, Factor); + } + } + } 
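The transformation that structuredLoadStore ultimately performs rests on the observation from the file header: when every de-interleaved lane group receives the same operation and the results are re-interleaved before a single wide store, lane order is irrelevant, so the strided shuffle masks can be replaced with contiguous (identity-chunk) masks. A standalone plain-C++ sanity check of that equivalence for factor 2, using illustrative data rather than anything from the patch:

#include <cassert>
#include <cstddef>
#include <vector>

int main() {
  const std::vector<double> Vec = {1, 2, 3, 4, 5, 6, 7, 8};
  const double Splat = 10.0;

  // Original shape: de-interleave with factor 2, apply the identical op to
  // each lane group, then re-interleave into the wide vector that is stored.
  std::vector<double> B1, B2;
  for (std::size_t I = 0; I < Vec.size(); I += 2) {
    B1.push_back(Vec[I] + Splat);
    B2.push_back(Vec[I + 1] + Splat);
  }
  std::vector<double> Stored(Vec.size());
  for (std::size_t I = 0; I < B1.size(); ++I) {
    Stored[2 * I] = B1[I];
    Stored[2 * I + 1] = B2[I];
  }

  // Rewritten shape: contiguous (identity) chunks, the same op per chunk,
  // and an identity re-interleave mask, i.e. plain elementwise work.
  std::vector<double> Contiguous(Vec);
  for (double &E : Contiguous)
    E += Splat;

  assert(Stored == Contiguous);
  return 0;
}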
+ } + + return Changed; +} Index: lib/CodeGen/GlobalISel/IRTranslator.cpp =================================================================== --- lib/CodeGen/GlobalISel/IRTranslator.cpp +++ lib/CodeGen/GlobalISel/IRTranslator.cpp @@ -797,12 +797,13 @@ } else if (const auto *CI = dyn_cast(V)) { MIRBuilder.buildConstDbgValue(*CI, DI.getVariable(), DI.getExpression()); } else { - unsigned Reg = getOrCreateVReg(*V); - // FIXME: This does not handle register-indirect values at offset 0. The - // direct/indirect thing shouldn't really be handled by something as - // implicit as reg+noreg vs reg+imm in the first palce, but it seems - // pretty baked in right now. - MIRBuilder.buildDirectDbgValue(Reg, DI.getVariable(), DI.getExpression()); + for (unsigned Reg : getOrCreateVRegs(*V)) { + // FIXME: This does not handle register-indirect values at offset 0. The + // direct/indirect thing shouldn't really be handled by something as + // implicit as reg+noreg vs reg+imm in the first place, but it seems + // pretty baked in right now. + MIRBuilder.buildDirectDbgValue(Reg, DI.getVariable(), DI.getExpression()); + } } return true; } Index: lib/CodeGen/InterleavedAccessPass.cpp =================================================================== --- lib/CodeGen/InterleavedAccessPass.cpp +++ lib/CodeGen/InterleavedAccessPass.cpp @@ -304,8 +304,9 @@ unsigned Factor, Index; // Check if the first shufflevector is DE-interleave shuffle. - if (!isDeInterleaveMask(Shuffles[0]->getShuffleMask(), Factor, Index, - MaxFactor)) + SmallVector Mask; + if (!Shuffles[0]->getShuffleMask(Mask) || + !isDeInterleaveMask(Mask, Factor, Index, MaxFactor)) return false; // Holds the corresponding index for each DE-interleave shuffle. @@ -320,8 +321,9 @@ if (Shuffles[i]->getType() != VecTy) return false; - if (!isDeInterleaveMaskOfFactor(Shuffles[i]->getShuffleMask(), Factor, - Index)) + SmallVector Mask; + if (!Shuffles[i]->getShuffleMask(Mask) || + !isDeInterleaveMaskOfFactor(Mask, Factor, Index)) return false; Indices.push_back(Index); @@ -421,7 +423,9 @@ // Check if the shufflevector is RE-interleave shuffle. unsigned Factor; unsigned OpNumElts = SVI->getOperand(0)->getType()->getVectorNumElements(); - if (!isReInterleaveMask(SVI->getShuffleMask(), Factor, MaxFactor, OpNumElts)) + SmallVector Mask; + if (!SVI->getShuffleMask(Mask) || + !isReInterleaveMask(Mask, Factor, MaxFactor, OpNumElts)) return false; LLVM_DEBUG(dbgs() << "IA: Found an interleaved store: " << *SI << "\n"); Index: lib/CodeGen/InterleavedGatherScatterPass.cpp =================================================================== --- /dev/null +++ lib/CodeGen/InterleavedGatherScatterPass.cpp @@ -0,0 +1,305 @@ +//=----------------------- InterleavedGatherScatter.cpp----------------------=// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file implements the Interleaved Gather Scatter pass, which identifies +// masked gather and masked scatter intrinsics and transforms them into cheaper, +// target specific interleaved access intrinsics. +// +// When fixed-length vectors are used, this optimisation is achieved by creating +// wide loads followed by shuffles with constant indices, or shuffles followed +// by wide stores, which are then transformed into target specific interleaved +// access intrinsics in InterleavedAccessPass. See that pass for more details. 
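Before the scalable-vector discussion continues below, the address pattern the pass looks for can be spelled out: a group of Factor gathers whose lanes advance by a common byte stride Factor * Size, and whose bases differ by 0, Size, ..., (Factor - 1) * Size, touches exactly the bytes of one contiguous run of structures, which is what a structured (interleaved) access loads or stores. A standalone plain-C++ sketch of that correspondence; all names are illustrative, not taken from the patch:

#include <cassert>
#include <cstddef>
#include <set>

int main() {
  const std::size_t Size = 8;     // element size in bytes
  const std::size_t Factor = 3;   // fields per structure
  const std::size_t NumLanes = 4; // vector lanes per gather
  const std::size_t Stride = Factor * Size;

  std::set<std::size_t> GatherBytes, ContiguousBytes;

  // Bytes touched by the group of gathers (one gather per field).
  for (std::size_t Field = 0; Field < Factor; ++Field)
    for (std::size_t Lane = 0; Lane < NumLanes; ++Lane)
      for (std::size_t B = 0; B < Size; ++B)
        GatherBytes.insert(Field * Size + Lane * Stride + B);

  // Bytes touched by one contiguous access of NumLanes whole structures.
  for (std::size_t B = 0; B < NumLanes * Stride; ++B)
    ContiguousBytes.insert(B);

  assert(GatherBytes == ContiguousBytes);
  return 0;
}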
+// +// The wide-load approach does not work well for scalable vectors, as the +// shuffle masks are very cumbersome to reason with. However, masked gathers +// and scatters can represent the same work quite easily. +// +// Accordingly, the job of this pass is to recognise groups of masked gathers +// or scatters, and replace each group with a single interleaved access where +// appropriate. +// +// This pass also does some rudimentary sinking of scatter stores using +// alias-analysis in order to create the required access groups. +// +//===----------------------------------------------------------------------===// +#include "InterleavedGatherScatterStrideDescUtils.h" +#include "llvm/ADT/DepthFirstIterator.h" +#include "llvm/ADT/MapVector.h" +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/AliasSetTracker.h" +#include "llvm/Analysis/ScalarEvolution.h" +#include "llvm/Analysis/ScalarEvolutionExpressions.h" +#include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/CodeGen/TargetPassConfig.h" +#include "llvm/CodeGen/TargetLowering.h" +#include "llvm/CodeGen/TargetSubtargetInfo.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/PatternMatch.h" +#include "llvm/Pass.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" + +using namespace llvm; +using namespace llvm::PatternMatch; + +#define DEBUG_TYPE "interleaved-gather-scatter" + +// Allow conversion of structured loads where not all elements of the structure +// are loaded by the original code. This isn't generally safe without +// guaranteeing the additional loads will not cause faults that would never have +// occured in the original code. Once we've added checks to guarantee this, +// this opt should be removed. +static cl::opt IGSAllBlocks( + "sve-igs-all-blocks", + cl::desc("Lowering gather/scatters to interleaved intrinsics in all " + "basic blocks, not just loop blocks."), + cl::init(false), cl::Hidden); + +// Allow conversion of structured loads where not all elements of the structure +// are loaded by the original code. This isn't generally safe without +// guaranteeing the additional loads will not cause faults that would never have +// occured in the original code. Once we've added checks to guarantee this, +// this opt should be removed. +static cl::opt IGSAllowUnsafeLoads( + "sve-igs-allow-unsafe-loads", + cl::desc("Enable unsafe load patterns when lowering gather/scatters to " + "interleaved intrinsics"), + cl::init(false), cl::Hidden); + +namespace llvm { +void initializeInterleavedGatherScatterPass(PassRegistry &); +} + +namespace { + +class InterleavedGatherScatter : public FunctionPass { + +public: + static char ID; + InterleavedGatherScatter() : FunctionPass(ID), TLI(nullptr) { + initializeInterleavedGatherScatterPass(*PassRegistry::getPassRegistry()); + } + + StringRef getPassName() const override { + return "Interleaved Gather Scatter Pass"; + } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.setPreservesCFG(); + AU.addRequired(); + AU.addRequired(); + AU.addRequired(); + AU.addRequired(); + AU.addPreserved(); + } + + bool runOnFunction(Function &F) override; + +private: + const TargetLowering *TLI; + const DataLayout *DL; + AliasAnalysis *AA; + LoopInfo *LI; + ScalarEvolution *SE; + TargetTransformInfo *TTI; + StrideDescriptorUtils *SDU; + + // Map of gather/scatter instructions to their scalar base pointers. + // These are inserted in program order for each block. 
+ MapVector GSPtrMap; + + /// \brief Tries to Convert a group of gathers or scatters to a strided access + /// + /// Returns true if anything changed. + bool lowerSDGroup(SDUtils::StrideGroup &Group); + + bool runOnLoop(Loop *L); + bool runOnBlock(BasicBlock *Block); +}; +} // end anonymous namespace. + +char InterleavedGatherScatter::ID = 0; +static const char ia_name[] = + "Lower interleaved gathers/scatters to target specific intrinsics"; + +INITIALIZE_PASS_BEGIN(InterleavedGatherScatter, "interleaved-gather-scatter", + ia_name, false, false) +INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) +INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass) +INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) +INITIALIZE_PASS_END(InterleavedGatherScatter, "interleaved-gather-scatter", + ia_name, false, false) + +FunctionPass * +llvm::createInterleavedGatherScatterPass() { + return new InterleavedGatherScatter(); +} + +bool InterleavedGatherScatter::lowerSDGroup(SDUtils::StrideGroup &Group) { + assert(!Group.empty()); + // Find the first and last SDs in the group. Note the group is in reverse + // instruction order, so FirstSD is at the back of the group, and LastSD is at + // the front + auto FirstSD = Group.back().get(); + auto LastSD = Group.front().get(); + auto ID = FirstSD->Instruction->getIntrinsicID(); + bool IsLoad = (ID == Intrinsic::masked_gather); + assert((IsLoad || (ID == Intrinsic::masked_scatter)) && + "Attempted to lower unhandled intrinsic"); + + // Find the SD with the lowest base address + SDUtils::StrideDescriptor *LowestSD = FirstSD; + for (auto I = Group.begin(), E = Group.end(); I != E; ++I) { + auto SD = I->get(); + if (SD->OffsetFromSD0 < LowestSD->OffsetFromSD0) + LowestSD = SD; + } + + assert(LowestSD->Stride > 0 && "Don't have support for negative strides yet"); + unsigned Factor = unsigned(LowestSD->Stride) / LowestSD->Size; + if (Factor > TLI->getMaxSupportedInterleaveFactor()) { + LLVM_DEBUG(dbgs() << "IGS: Skipping (Factor is too large)\n"); + return false; + } else if (Factor == 1) { + // Fixing this is SC-1270 + LLVM_DEBUG(dbgs() << "IGS: Skipping (This should be a consecutive Load!)\n"); + return false; + } + + // For gathers, this array holds the gather intrinsics. For scatters, it + // holds the values to be stored + Value *Values[] = {nullptr, nullptr, nullptr, nullptr, + nullptr, nullptr, nullptr, nullptr}; + + for (auto I = Group.begin(), E = Group.end(); I != E; ++I) { + auto SD = I->get(); + int Offset = SD->OffsetFromSD0 - LowestSD->OffsetFromSD0; + unsigned Index = unsigned(Offset) / LowestSD->Size; + + assert(!(Offset % LowestSD->Size) && "Offset isn't a multiple of Size"); + assert(Index < Factor && "Index out of range"); + assert(!Values[Index] && "Multiple nodes with the same Index"); + Values[Index] = IsLoad ? 
SD->Instruction : SD->Instruction->getOperand(0); + } + + LLVM_DEBUG(dbgs() << "Group size=" << Group.size() << ", factor=" << Factor + << "\n"); + if ((Group.size() != Factor)) { + if (!IsLoad) { + // A future optimisation would be to spot blocks where there is a matching + // ldN instruction, and use the loads in that to fill in gaps here + LLVM_DEBUG(dbgs() << "IGS: Skipping (Store group is not fully populated)\n"); + return false; + } else if (!IGSAllowUnsafeLoads) { + LLVM_DEBUG(dbgs() << "IGS: Skipping (Load group is not fully populated)\n"); + return false; + } + } + + bool Changed; + if (IsLoad) { + IntrinsicInst *FirstInstr = FirstSD->Instruction; + Changed = TLI->lowerGathersToInterleavedLoad( + Values, FirstInstr, LowestSD->OffsetFromSD0 - FirstSD->OffsetFromSD0, + Factor, TTI); + } else { +#ifndef NDEBUG + for (auto *V : Values) { + if (V) + LLVM_DEBUG(dbgs() << "IGS: Value to store: " << *V << "\n"); + } + LLVM_DEBUG(dbgs() << "IGS: Storing to location: " + << *LowestSD->Instruction->getOperand(1) << "\n"); +#endif + Changed = TLI->lowerScattersToInterleavedStore( + Values, LowestSD->Instruction->getOperand(1), LastSD->Instruction, + Factor, TTI); + } + + if (Changed) + for (auto I = Group.begin(), E = Group.end(); I != E; ++I) { + (*I)->Instruction->eraseFromParent(); + } + else + LLVM_DEBUG(dbgs() << "IGS: Skipping (lowering failed)\n"); + + return Changed; +} + +bool InterleavedGatherScatter::runOnBlock(BasicBlock *Block) { + bool Changed = false; + + // Clear GSPtrMap as optimization may RAUW values in the map. + GSPtrMap.clear(); + // Holds loads & stores (inc gather,scatter) + SmallVector MemAccessList; + + for (auto &Inst : *Block) { + IntrinsicInst *II = nullptr; + if ((II = dyn_cast(&Inst)) && + ((II->getIntrinsicID() == Intrinsic::masked_gather) || + (II->getIntrinsicID() == Intrinsic::masked_scatter))) + MemAccessList.push_back(II); + else if (Inst.mayWriteToMemory()) + MemAccessList.push_back(&Inst); + } + + SmallVector, 4> *StrideGroups; + StrideGroups = SDU->getSDGroups(MemAccessList); + // Lower each group + for (auto I = StrideGroups->begin(), E = StrideGroups->end(); I != E; ++I) + Changed |= lowerSDGroup(*I->get()); + + return Changed; +} + +bool InterleavedGatherScatter::runOnLoop(Loop *L) { + bool Changed = false; + + for (auto &Block : L->blocks()) { + if (LI->getLoopFor(Block) != L) // Ignore blocks in subloop. 
+ continue; + Changed |= runOnBlock(Block); + } + return Changed; +} + +bool InterleavedGatherScatter::runOnFunction(Function &F) { + auto *TPC = getAnalysisIfAvailable(); + if (!TPC) + return false; + + LLVM_DEBUG(dbgs() << "*** " << getPassName() << ": " << F.getName() << "\n"); + + auto &TM = TPC->getTM(); + TLI = TM.getSubtargetImpl(F)->getTargetLowering(); + SE = &getAnalysis().getSE(); + TTI = &getAnalysis().getTTI(F); + AA = &getAnalysis().getAAResults(); + LI = &getAnalysis().getLoopInfo(); + + bool Changed = false; + + DL = &F.getParent()->getDataLayout(); + SDU = new StrideDescriptorUtils(SE, DL); + + if (!IGSAllBlocks) { + for (auto I = LI->begin(), IE = LI->end(); I != IE; ++I) + for (auto L = df_begin(*I), LE = df_end(*I); L != LE; ++L) { + Changed |= runOnLoop(*L); + } + } else { + for (auto &BB : F) { + Changed |= runOnBlock(&BB); + } + } + return Changed; +} Index: lib/CodeGen/InterleavedGatherScatterStoreSinkPass.cpp =================================================================== --- /dev/null +++ lib/CodeGen/InterleavedGatherScatterStoreSinkPass.cpp @@ -0,0 +1,408 @@ +//===------InterleavedGatherScatterStoreSink.cpp----------------------=// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file implements the Interleaved Gather Scatter Store Sink pass +// +// It uses alias-analysis in order to sink stores at the end of bottom of a +// basic block. This enables future passes to group more gather/scatter +// together. +// +//===----------------------------------------------------------------------===// + +#include "InterleavedGatherScatterStrideDescUtils.h" +#include "llvm/ADT/DepthFirstIterator.h" +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/AliasSetTracker.h" +#include "llvm/Analysis/ScalarEvolution.h" +#include "llvm/Analysis/ScalarEvolutionExpressions.h" +#include "llvm/CodeGen/TargetPassConfig.h" +#include "llvm/CodeGen/TargetLowering.h" +#include "llvm/CodeGen/TargetSubtargetInfo.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/PatternMatch.h" +#include "llvm/Pass.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" + +using namespace llvm; +using namespace llvm::PatternMatch; + +#define DEBUG_TYPE "interleaved-gather-scatter-store-sink" + +static cl::opt + IGSStoreSinkAllBlocks("sve-igs-store-sink-all-blocks", + cl::desc("IGS sinking stores in all " + "basic blocks, not just loop blocks."), + cl::init(false), cl::Hidden); + +namespace llvm { +void initializeInterleavedGatherScatterStoreSinkPass(PassRegistry &); +} + +namespace { + +class InterleavedGatherScatterStoreSink : public FunctionPass { + +public: + static char ID; + InterleavedGatherScatterStoreSink() : FunctionPass(ID), TLI(nullptr) { + initializeInterleavedGatherScatterStoreSinkPass( + *PassRegistry::getPassRegistry()); + } + + StringRef getPassName() const override { + return "Interleaved Gather Scatter Store Sink Pass"; + } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.setPreservesCFG(); + AU.addRequired(); + AU.addRequired(); + AU.addRequired(); + AU.addRequired(); + AU.addPreserved(); + } + + bool runOnFunction(Function &F) override; + +private: + const TargetLowering *TLI; + const DataLayout *DL; + AliasAnalysis *AA; + LoopInfo *LI; + ScalarEvolution *SE; + TargetTransformInfo *TTI; + 
StrideDescriptorUtils *SDU; + + // Map of gather/scatter instructions to their scalar base pointers. + // These are inserted in program order for each block. + MapVector GSPtrMap; + + /// \brief Scans the block and populated the alias set tracker given. This + /// also knows how to analyze gather/scatter operations to find the base + /// pointer, assuming they're in the expected canonical form. + /// + /// Returns true if any pointers were added to the AST. + bool buildAliasSets(BasicBlock *BB, AliasSetTracker &AST, + Instruction *StartI); + + /// \brief Uses alias-analysis to try to sink stores down a basic block. + /// Tries to sink sequential stores together + /// + /// Returns true if the block was modified. + bool sinkScatterStores(BasicBlock *BB, AliasSetTracker &AST); + + /// \brief Uses alias-analysis to run checks on single stores within + /// groups of sequential stores + /// + /// Returns true if the block was modified. + bool isStoreSinkable(Instruction *I, Value *BasePtr, Instruction **LastGather, + AliasSetTracker &AST); + + bool runOnLoop(Loop *L); + bool runOnBlock(BasicBlock *Block); +}; +} // end anonymous namespace. + +char InterleavedGatherScatterStoreSink::ID = 0; +static const char ia_name[] = + "Lower interleaved gathers/scatters to target specific intrinsics"; + +INITIALIZE_PASS_BEGIN(InterleavedGatherScatterStoreSink, + "interleaved-gather-scatter-store-sink", ia_name, false, + false) +INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) +INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass) +INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) +INITIALIZE_PASS_END(InterleavedGatherScatterStoreSink, + "interleaved-gather-scatter-store-sink", ia_name, false, + false) + +FunctionPass * +llvm::createInterleavedGatherScatterStoreSinkPass() { + return new InterleavedGatherScatterStoreSink(); +} + +bool InterleavedGatherScatterStoreSink::isStoreSinkable( + Instruction *I, Value *BasePtr, Instruction **LastGather, + AliasSetTracker &AST) { + AAMDNodes AAInfo; + I->getAAMetadata(AAInfo); + AliasSet *AS1 = AST.getAliasSetForPointerIfExists( + BasePtr, MemoryLocation::UnknownSize, AAInfo); + assert(AS1 && "Alias set for gather/scatter not found"); + assert(I->mayWriteToMemory()); + // LocalLastGather is where we can sink the current store + // For the group, we will pick the most conservative (earliest) + // of the LocalLastGather + Instruction *LocalLastGather = nullptr; + BasicBlock::iterator Next(I); + std::advance(Next, 1); + for (auto End = I->getParent()->end(); Next != End; ++Next) { + auto NextI = &*Next; + if (!NextI->mayReadOrWriteMemory()) + continue; + // Extract the pointer from the instruction. + Value *NextPtr; + bool IsGather = false; + uint64_t Size; + bool KnownGatherScatter = false; + if (auto LI = dyn_cast(NextI)) { + if (!LI->isSimple()) { + LLVM_DEBUG(dbgs() << "Re-order: Can't sink past atomic or volatile.\n"); + // Can't sink past atomic or volatile loads. + break; + } + NextPtr = LI->getPointerOperand(); + Size = DL->getTypeStoreSize(LI->getType()); + } else if (auto SI = dyn_cast(NextI)) { + if (!SI->isSimple()) { + LLVM_DEBUG(dbgs() << "Re-order: store instruction is not simple.\n"); + break; + } + NextPtr = SI->getPointerOperand(); + Size = DL->getTypeStoreSize(SI->getValueOperand()->getType()); + } else { + // Check if this is a known gather/scatter in the block which we should + // have cached in the baseptr map. 
+ NextPtr = GSPtrMap[NextI]; + if (!NextPtr) { + LLVM_DEBUG(dbgs() + << "Re-order: this is a known gather/scatter," + << " which we should been have cached in the baseptr map.\n"); + // The base pointer couldn't be found so we have to stop. + break; + } + KnownGatherScatter = true; + IsGather = !NextI->mayWriteToMemory(); + Size = MemoryLocation::UnknownSize; + } + + // We have the next instruction's pointer value, now find the alias set. + NextI->getAAMetadata(AAInfo); + auto AS2 = AST.getAliasSetForPointerIfExists(NextPtr, Size, AAInfo); + assert(AS2 && "Couldn't find alias set for re-ordering"); + + if (AS1 == AS2) { + if (!(KnownGatherScatter && !SDU->isAliasedGatherScatter(I, NextI))) { + LLVM_DEBUG(dbgs() << "Re-order: Detected a potential alias, cant sink " + "scatter below.\n"); + break; // Detected a potential alias, stop. + } + } + + // In case we find LastGather, that means a previous scatter had a more + // conservative sinking place + if (NextI == *LastGather) + return true; + + if (IsGather) + LocalLastGather = NextI; + + LLVM_DEBUG(dbgs() << "Re-order: Can sink scatter below " << *NextI << "\n"); + } + + // We made it this far because we haven't found the group LastGather, + // In that case, the LocalLastGather is more conservative + *LastGather = LocalLastGather; + + // We came out of the loop without finding gathers to sink below + if (!LocalLastGather) + return false; + + return true; +} + +bool InterleavedGatherScatterStoreSink::sinkScatterStores( + BasicBlock *BB, AliasSetTracker &AST) { + bool Changed = false; + // To sink scatter stores, the idea is to start from the end of each block, + // until we find our scatter store, we then scan all subsequent memory + // accesses (scalar and vector), examining the alias sets of the pointers + // to see if we can sink the store past it. + // + // We start from the end so that stores in the example block can be sunk: + // bb: + // scatter.store p1, v + // gather.load p2 + // scatter.store p3, v + // gather.load p4, v + // ... + // ... assuming that p1 & p3 are may-alias, but neither aliases with p2 & p4. + // This ensures we should be able to sink all stores as far as possible + // with one pass, in this case both below p4. + // + // The caveat is that we cannot sink a store past a gather/scatter for + // which we weren't able to find the base pointer. + + SmallVector, 8> GSInBlock; + for (auto II = GSPtrMap.rbegin(), IE = GSPtrMap.rend(); II != IE; ++II) + GSInBlock.push_back(*II); + + for (auto It = GSInBlock.begin(), E = GSInBlock.end(); It != E; ++It) { + Instruction *I = It->first; + Value *BasePtr = It->second; + + // Subsequent stores are evaluated together, to not sink only some of the + // stores in the group + // The would unnecessarily break stride groups + SmallVector, 8> SubsequentStoresInBlock; + while (I->mayWriteToMemory() && It->second) { + SubsequentStoresInBlock.push_back(*It++); + if (It == E) + break; + I = It->first; + } + + if (SubsequentStoresInBlock.empty()) + continue; + --It; + + for (auto CtgIt : SubsequentStoresInBlock) + LLVM_DEBUG(dbgs() << "Re-order: Trying to sink store: " << *(CtgIt.first)); + + // Keep track of the last gather we have analyzed and found safe to + // sink through. Don't bother doing any sinking if we can't find any. 
+ Instruction *LastGather = nullptr; + // Building the AliasSet based on subsequent instructions in the block, + // not on instructions that are executed previously + AliasSetTracker ASTPartial(*AA); + buildAliasSets(BB, ASTPartial, SubsequentStoresInBlock.back().first); + for (auto CtgIt : SubsequentStoresInBlock) { + I = CtgIt.first; + BasePtr = CtgIt.second; + // If one of the stores can't be sunk, we avoid sinking any of the + // stores in the group + if (!isStoreSinkable(I, BasePtr, &LastGather, ASTPartial)) + break; + } + + if (!LastGather) { + LLVM_DEBUG(dbgs() << "Re-order: Not sinking as no gathers to be passed.\n"); + continue; // No point in sinking. + } + + for (auto CtgIt : SubsequentStoresInBlock) { + I = CtgIt.first; + I->removeFromParent(); + I->insertAfter(LastGather); + } + + LLVM_DEBUG(dbgs() << "Re-order: Sunk scatter instruction(s).\n"); + Changed = true; + } + + return Changed; +} + +bool InterleavedGatherScatterStoreSink::runOnBlock(BasicBlock *Block) { + bool Changed = false; + + AliasSetTracker AST(*AA); + if (buildAliasSets(Block, AST, NULL)) { + LLVM_DEBUG(dbgs() << "Dumping AST for block: " << Block->getName() << "\n"); + LLVM_DEBUG(dbgs() << AST << "\n"); + // First we try to sink scatters down the loop block past any gathers, + // so that the grouping phase can work later. + Changed |= sinkScatterStores(Block, AST); + } else + LLVM_DEBUG(dbgs() << "No aliases!\n"); + + // Clear GSPtrMap as optimization may RAUW values in the map. + GSPtrMap.clear(); + + return Changed; +} + +bool InterleavedGatherScatterStoreSink::runOnLoop(Loop *L) { + bool Changed = false; + + for (auto &Block : L->blocks()) { + if (LI->getLoopFor(Block) != L) // Ignore blocks in subloop. + continue; + Changed |= runOnBlock(Block); + } + return Changed; +} + +bool InterleavedGatherScatterStoreSink::buildAliasSets(BasicBlock *BB, + AliasSetTracker &AST, + Instruction *StartI) { + // Look for gather/scatters specifically as AliasSetTracker doesn't handle + // these itself. Keeping this code in this pass because the style of + // addressing is specific for scalable SVE. + bool StartBuilding = false; + for (auto &II : *BB) { + Instruction *I = &II; + + // Building the AliasSet based on subsequent instructions in the block, + // not on instructions that are executed previously + // A store in the middle of the block will not care about aliasing + // with a previous load + if (!StartI || (I == StartI)) + StartBuilding = true; + if (!StartBuilding) + continue; + + // We want to handle the intrinsics specially instead of using the AST. + if (!match(I, m_Intrinsic()) && + !match(I, m_Intrinsic())) { + AST.add(I); // Let AST handle other instructions. + continue; + } + + LLVM_DEBUG(dbgs() << "IGS: Finding base ptr of " << *I << "\n"); + Value *BasePtr = SDU->findBasePtrFromInstruction(I); + GSPtrMap[I] = BasePtr; + + if (!BasePtr) { + LLVM_DEBUG(dbgs() << "IGS: Couldn't find base ptr of gather/scatter.\n"); + AST.add(I); + continue; + } + + // Now we add the pointer to the alias sets manually. 
+ AAMDNodes AAInfo; + I->getAAMetadata(AAInfo); + AST.add(BasePtr, MemoryLocation::UnknownSize, AAInfo); + } + return !AST.getAliasSets().empty(); +} + +bool InterleavedGatherScatterStoreSink::runOnFunction(Function &F) { + auto *TPC = getAnalysisIfAvailable(); + if (!TPC) + return false; + + LLVM_DEBUG(dbgs() << "*** " << getPassName() << ": " << F.getName() << "\n"); + + auto &TM = TPC->getTM(); + TLI = TM.getSubtargetImpl(F)->getTargetLowering(); + SE = &getAnalysis().getSE(); + TTI = &getAnalysis().getTTI(F); + AA = &getAnalysis().getAAResults(); + LI = &getAnalysis().getLoopInfo(); + + bool Changed = false; + + DL = &F.getParent()->getDataLayout(); + SDU = new StrideDescriptorUtils(SE, DL); + + if (!IGSStoreSinkAllBlocks) { + for (auto I = LI->begin(), IE = LI->end(); I != IE; ++I) + for (auto L = df_begin(*I), LE = df_end(*I); L != LE; ++L) { + Changed |= runOnLoop(*L); + } + } else { + for (auto &BB : F) + Changed |= runOnBlock(&BB); + } + return Changed; +} Index: lib/CodeGen/InterleavedGatherScatterStrideDescUtils.h =================================================================== --- /dev/null +++ lib/CodeGen/InterleavedGatherScatterStrideDescUtils.h @@ -0,0 +1,137 @@ +//===---------InterleavedGatherScatterStrideDescUtils.h--------------------=// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===---------------------------------------------------------------------===// +// +// This file implements a set of utilities to handle stride descriptors. +// +// Stride descriptors are a representation for gathers and scatters. This pass +// can group stride descriptors from a list of memory accesses, and handle +// any aliasing concerns. +// +//===---------------------------------------------------------------------===// + +#ifndef LLVM_LIB_CODEGEN_IGSSTRIDEDESCRIPTORUTILS_H +#define LLVM_LIB_CODEGEN_IGSSTRIDEDESCRIPTORUTILS_H + +#include "llvm/Analysis/ScalarEvolution.h" +#include "llvm/Analysis/ScalarEvolutionExpressions.h" +#include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/PatternMatch.h" + +using namespace llvm; +using namespace llvm::PatternMatch; + +namespace SDUtils { +/// \brief The descriptor for a strided memory access. +struct StrideDescriptor { + // The current gather or scatter intrinsic. + IntrinsicInst *Instruction; + // The SCEV expression of the base ptr. + const SCEV *BaseSCEV; + // Base pointer as a value, to be used by alias tracker + Value *BasePtr; + // The start value of the vector of offsets from the base ptr. + Value *Start; + // The predicate for this access. + Value *GP; + // The access's stride in bytes. May be negative. + int Stride; + // The size of the memory object in bytes. + unsigned Size; + // The element type of the vector. + Type *EltTy; + // The alignment of this access in bytes. + unsigned Align; + // The byte offset between this access and the + // first access that was entered into the same + // group as this. + int OffsetFromSD0; + // If set, the type the address pointer was + // bitcast to from an original type. 
+ Type *CastDestTy; + + StrideDescriptor(IntrinsicInst *Instruction, const SCEV *BaseSCEV, + Value *BasePtr, Value *Start, Value *GP, int Stride, + unsigned Size, Type *EltTy, unsigned Align, Type *CastDestTy) + : Instruction(Instruction), BaseSCEV(BaseSCEV), BasePtr(BasePtr), + Start(Start), GP(GP), Stride(Stride), Size(Size), EltTy(EltTy), + Align(Align), OffsetFromSD0(0), CastDestTy(CastDestTy) {} +}; +typedef SmallVector, 4> StrideGroup; +} // namespace SDUtils + +class StrideDescriptorUtils { +public: + StrideDescriptorUtils(ScalarEvolution *SE, const DataLayout *DL) + : SE(SE), DL(DL){}; + + /// \brief Replace masked gather/scatter intrinsics with strided load/stores + /// + /// Takes a vector of all memory accesses in a block, in instruction order + /// Returns true if anything changed + SmallVector, 4> * + getSDGroups(const SmallVectorImpl &MemAccessList); + + /// \brief Create a StrideDescriptor from a gather or scatter intrinsic. + /// + /// Returns nullptr on failure + std::unique_ptr + createStrideDescriptor(IntrinsicInst *Instr); + + /// \brief Finds base pointer for a gather scatter instruction + Value *findBasePtrFromInstruction(Instruction *I); + + /// \brief For 2 gather/scatter instructions, this checks that their + /// element memory accesses are interleaving without overlapping + /// This complements the AliasSetTracker, which can't deal + /// with gather/scatters + bool isAliasedGatherScatter(Instruction *I, Instruction *NextI); + +private: + ScalarEvolution *SE; + const DataLayout *DL; + + /// \brief Finds original source of a vector by removing bitcasts + /// + /// There can be an arbitrary number of bitcasts and pointer/integer + /// conversions between a gather/scatter and the instruction that + /// built the address vector + Value *removeVectorCastings(Value *VecPtr); + + /// \brief Extracts base address from a vector built by a getlementptr + /// + /// Called by analyseGatherScatterBaseAddresses. + bool analyseGSBaseAddressesFromGEP(IntrinsicInst *Instr, + SDUtils::StrideDescriptor &NewSD, + GetElementPtrInst *BaseGEP); + + /// \brief Extracts information from a gather/scatter definition. + /// + /// This function will find in the IR where the vector of addresses + /// was built, and determine the base address and stride from that + /// definition + bool analyseGatherScatterBaseAddresses(IntrinsicInst *Instr, + SDUtils::StrideDescriptor &NewSD, + bool IsLoad); + + /// \brief Tries to add a new stride descriptor to an existing group + /// + /// Will perform checks on stride, offsets, aliasing... to make sure the + /// stride descriptor fits in + bool addSDToGroup(SDUtils::StrideDescriptor &SD, SDUtils::StrideGroup &Group); + + /// \brief Creates groups of stride descriptors from a list of memory + /// accesses with no aliasing concern + void lowerGatherScatterNoAliasGroup( + SmallVectorImpl &NoAliasGroup, + SmallVector, 4> *StrideGroups); +}; + +#endif Index: lib/CodeGen/InterleavedGatherScatterStrideDescUtils.cpp =================================================================== --- /dev/null +++ lib/CodeGen/InterleavedGatherScatterStrideDescUtils.cpp @@ -0,0 +1,449 @@ +//===---------InterleavedGatherScatterStrideDescUtils.cpp------------------=// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. 
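The isAliasedGatherScatter helper declared in this header ("interleaving without overlapping") comes down to a byte-range argument: for two accesses with equal stride and element size, a constant base offset D with Size <= D <= Stride - Size keeps every beat of the second access strictly between the elements of the first. A standalone plain-C++ check of that claim; the names are illustrative, not the patch's API:

#include <cassert>
#include <cstddef>
#include <set>

// Bytes touched by a strided access: NumLanes elements of Size bytes,
// starting at Base, each lane Stride bytes after the previous one.
std::set<std::size_t> touchedBytes(std::size_t Base, std::size_t Stride,
                                   std::size_t Size, std::size_t NumLanes) {
  std::set<std::size_t> Bytes;
  for (std::size_t Lane = 0; Lane < NumLanes; ++Lane)
    for (std::size_t B = 0; B < Size; ++B)
      Bytes.insert(Base + Lane * Stride + B);
  return Bytes;
}

int main() {
  const std::size_t Size = 4, Stride = 16, NumLanes = 8;
  const std::size_t D = 8; // offset between the two bases
  assert(Size <= D && D <= Stride - Size);

  std::set<std::size_t> A = touchedBytes(0, Stride, Size, NumLanes);
  std::set<std::size_t> B = touchedBytes(D, Stride, Size, NumLanes);
  for (std::size_t Byte : B)
    assert(A.count(Byte) == 0); // the two accesses never touch the same byte
  return 0;
}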
+// +//===----------------------------------------------------------------------===// +// +// This file implements a set of utilities to handle stride descriptors. +// +// Stride descriptors are a representation for gathers and scatters. This pass +// can group stride descriptors from a list of memory accesses, and handle +// any aliasing concerns. +// +//===----------------------------------------------------------------------===// + +#include "InterleavedGatherScatterStrideDescUtils.h" +#include "llvm/ADT/DepthFirstIterator.h" +#include "llvm/ADT/MapVector.h" +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/AliasSetTracker.h" +#include "llvm/Analysis/ScalarEvolution.h" +#include "llvm/Analysis/ScalarEvolutionExpressions.h" +#include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/CodeGen/Passes.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/PatternMatch.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/CodeGen/TargetLowering.h" +#include "llvm/CodeGen/TargetSubtargetInfo.h" + +using namespace llvm; +using namespace llvm::PatternMatch; + +#define DEBUG_TYPE "interleaved-gather-scatter-stride-descriptor-utils" + +Value *StrideDescriptorUtils::removeVectorCastings(Value *VecPtr) { + // This function only works on vector types. Other castings + // would not help analysing strides + while (VecPtr->getType()->isVectorTy()) { + auto BCast = dyn_cast(VecPtr); + if (BCast && (isa(BCast) || isa(BCast) || + isa(BCast))) { + VecPtr = BCast->getOperand(0); + continue; + } + // If we find something other than a casting, that's analysed by the + // main function + break; + } + return VecPtr; +} + +bool StrideDescriptorUtils::analyseGSBaseAddressesFromGEP( + IntrinsicInst *Instr, SDUtils::StrideDescriptor &NewSD, + GetElementPtrInst *BaseGEP) { + // We start to get basic info from the intrinsics + Value *Base, *Align, *GP, *Merge, *Start; + bool IsLoad; + if (match(Instr, + m_Intrinsic(m_Value(Base), m_Value(Align), + m_Value(GP), m_Value(Merge)))) + IsLoad = true; + else if (match(Instr, + m_Intrinsic( + m_Value(), m_Value(Base), m_Value(Align), m_Value(GP)))) + IsLoad = false; + else + return false; + + // We only handle the case where getlementptr creates a vector series + if (!BaseGEP || (BaseGEP->getNumOperands() > 2)) + return false; + if (IsLoad && !isa(Merge)) + return false; + ConstantInt *StrideInst; + if (!match(BaseGEP->getOperand(1), + m_SeriesVector(m_Value(Start), m_ConstantInt(StrideInst)))) + return false; + + auto GEPPtr = BaseGEP->getPointerOperand(); + auto GEPPtrNoCast = removeVectorCastings(GEPPtr); + NewSD.BasePtr = GEPPtrNoCast; + // TODO: Is bailing out the correct thing here? Or should we never even + // get this far to begin with? + // Failure context: denseLinearAlgebra, the pointers from a contiguous + // load get passed straight through to a scatter store. + if (!SE->isSCEVable(GEPPtr->getType())) + return false; + auto PtrSCEV = SE->getSCEV(GEPPtr); + + // We attempt to turn Start into an SCEV. If it is a success, we add Start + // as an offset directly into the BaseSCEV. 
+ // If not, we keep Start as it is + // Using BaseSCEV as much as possible enables the pass to mix vectors coming + // from masked load and vectors coming from a getelementptr + const SCEV *OffsetSCEV = nullptr; + if (SE->isSCEVable((Start)->getType())) + OffsetSCEV = SE->getSCEV(Start); + + if (OffsetSCEV && (SE->getEffectiveSCEVType(PtrSCEV->getType()) == + SE->getEffectiveSCEVType(OffsetSCEV->getType()))) { + NewSD.BaseSCEV = SE->getAddExpr(OffsetSCEV, PtrSCEV); + Start = nullptr; + } else + NewSD.BaseSCEV = PtrSCEV; + NewSD.Start = Start; + + unsigned GEPSize = DL->getTypeStoreSize(BaseGEP->getResultElementType()); + NewSD.Stride = StrideInst->getSExtValue() * GEPSize; + // If we have pointers to aggregate types, those are bitcasted to scalar + // types when used in gathers/scatters. The real size of the access is the + // element size of the bitcast type, not the original aggregate type size. + assert(NewSD.Stride != 0 && "Stride (or size) is zero, and shouldn't be"); + return true; +} + +bool StrideDescriptorUtils::analyseGatherScatterBaseAddresses( + IntrinsicInst *Instr, SDUtils::StrideDescriptor &NewSD, bool IsLoad) { + // Recovering the vector of addresses + Value *VecPtr; + if (IsLoad) + match(Instr, m_Intrinsic(m_Value(VecPtr))); + else + match(Instr, + m_Intrinsic(m_Value(), m_Value(VecPtr))); + + GetElementPtrInst *BaseGEP = nullptr; + // We start analysing producers for the vector of addresses. + // We can find 3 types of elements + // -Castings + // -Original producer: getelementptr or load that is regular enough + // -ALU op which adds an offset to the original addresses + while (VecPtr->getType()->isVectorTy()) { + // Removing any of the castings + VecPtr = removeVectorCastings(VecPtr); + + // Detecting offsets. Currently detecting a single case + // of adding a splat to a vector + Value *Add1, *Add2, *SplatCst; + if (match(VecPtr, m_Add(m_Value(Add1), m_Value(Add2)))) { + // if the add isn't using a splat, we give up, because + // we won't be able to detect the offset value + if (!match(Add2, m_SplatVector(m_Value(SplatCst)))) + break; + assert(Add1->getType()->isVectorTy()); + // We store the offset value for future calculations, + // and continue to find the original producer of this vector + VecPtr = Add1; + continue; + } + + // Finding any of the instructions that can create a vector of addresses + BaseGEP = dyn_cast(VecPtr); + if (BaseGEP) + return analyseGSBaseAddressesFromGEP(Instr, NewSD, BaseGEP); + + // If we arrive at the end of the loop it means we have found + // an intrinsic not yet handled. 
Giving up search for base ptr + break; + } + return false; +} + +Value *StrideDescriptorUtils::findBasePtrFromInstruction(Instruction *I) { + if ((!match(I, m_Intrinsic())) && + (!match(I, m_Intrinsic()))) + return nullptr; + + auto NewSD = createStrideDescriptor(dyn_cast(I)); + if (!NewSD) + return nullptr; + + return NewSD->BasePtr; +} + +/// This function is to be used for dealiasing purposes +/// The regular alias set tracker doesn't handle gather/scatter because +/// the size of the memory access is unknown +/// This function detects one specific case, where we know for sure +/// there is no aliasing +bool StrideDescriptorUtils::isAliasedGatherScatter(Instruction *I, + Instruction *NextI) { + auto SD_cur = createStrideDescriptor(dyn_cast(I)); + auto SD_next = createStrideDescriptor(dyn_cast(NextI)); + if (!SD_cur || !SD_next) + return true; + + // SD_next is a stride that could potentially interleave with SD_cur + // without ovelapping + if ((SD_cur->Start == SD_next->Start) && + (SD_cur->Stride == SD_next->Stride) && (SD_cur->Size == SD_next->Size) && + (SD_cur->BaseSCEV != SD_next->BaseSCEV)) { + // We check there is no overlap by comparing size and offset of the first + // beats The equal strides ensure subsequent beats will also not overlap + auto OffsetSCEV = SE->getMinusSCEV(SD_next->BaseSCEV, SD_cur->BaseSCEV); + auto ConstOffsetSCEV = dyn_cast(OffsetSCEV); + if (!ConstOffsetSCEV) { + return true; + } + unsigned OffsetFromCur = + std::abs(ConstOffsetSCEV->getAPInt().getSExtValue()); + if ((OffsetFromCur >= SD_cur->Size) && + (OffsetFromCur <= (SD_cur->Stride - SD_cur->Size))) + return false; + } + // By default we consider the gathers/scatters as aliased + return true; +} + +/// \brief Create a StrideDescriptor from a gather or scatter intrinsic. +/// +/// Returns nullptr on failure +std::unique_ptr +StrideDescriptorUtils::createStrideDescriptor(IntrinsicInst *Instr) { + Value *Base, *Align, *GP, *Merge, *Start; + Type *CastDestTy = nullptr; + + // Basic check that we are working on a Gather/Scatter + bool IsLoad; + if (match(Instr, + m_Intrinsic(m_Value(Base), m_Value(Align), + m_Value(GP), m_Value(Merge)))) + IsLoad = true; + else if (match(Instr, + m_Intrinsic( + m_Value(), m_Value(Base), m_Value(Align), m_Value(GP)))) + IsLoad = false; + else + return nullptr; + + // Element type, size and alignment are found directly in + // the gather scatter intrinsic + auto ConstantAlign = dyn_cast(Align); + if (!ConstantAlign) + return nullptr; + Type *EltTy; + if (IsLoad) + EltTy = Instr->getType()->getScalarType(); + else + EltTy = Instr->getOperand(0)->getType()->getScalarType(); + unsigned Size = DL->getTypeStoreSize(EltTy); + + // for the stride and base address we have to find how + // address pointers have been built + int Stride = 0; + const SCEV *BaseSCEV; + Value *BasePtr = nullptr; + auto NewSD = llvm::make_unique( + Instr, BaseSCEV, BasePtr, Start, GP, Stride, Size, EltTy, + ConstantAlign->getSExtValue(), CastDestTy); + bool success = analyseGatherScatterBaseAddresses(Instr, *NewSD.get(), IsLoad); + // Non success means that we couldn't find base addresses of the right + // format. 
Some GatherScatters are indexed by a vector directly loaded + // from memory, in which case we can't be sure of the strides generated + if (!success) + return nullptr; + + LLVM_DEBUG(dbgs() << "Creating SD " << *BaseSCEV << ", Stride=" << Stride + << ", Size=" << Size << "\n"); + return NewSD; +} + +//===---------------------------------------------------------------------===// +// Extra functions to group stride descriptors in stride groups +//===---------------------------------------------------------------------===// + +bool StrideDescriptorUtils::addSDToGroup(SDUtils::StrideDescriptor &SD, + SDUtils::StrideGroup &Group) { + auto SD0 = Group.front().get(); + assert(SD0->Instruction->getIntrinsicID() == + SD.Instruction->getIntrinsicID() && + "Unexpected mix of loads and stores within a no-alias group"); + LLVM_DEBUG(dbgs() << "Trying to add to group containing " << *SD0->Instruction + << "Group size " << Group.size() << "\n"); + // Initial checks against first member + if ((SD0->Start != SD.Start) || (SD0->Stride != SD.Stride) || + (SD0->Size != SD.Size) || (SD0->GP != SD.GP)) { + LLVM_DEBUG(dbgs() << "IGS: not adding to group (Failed initial checks) " + << *SD.Instruction << "\n"); + return false; + } + + // Bases must have a constant offset + auto OffsetSCEV = SE->getMinusSCEV(SD.BaseSCEV, SD0->BaseSCEV); + auto ConstOffsetSVEV = dyn_cast(OffsetSCEV); + if (!ConstOffsetSVEV) { + LLVM_DEBUG(dbgs() << "IGS: not adding to group (No const offset) " + << *SD.Instruction << "\n"); + return false; + } + int OffsetFromSD0 = ConstOffsetSVEV->getAPInt().getSExtValue(); + + // Offset must be within the stride, and a multiple of size + if ((std::abs(OffsetFromSD0) >= SD.Stride) || (OffsetFromSD0 % SD0->Size)) { + LLVM_DEBUG(dbgs() << "IGS: not adding to group (offset out of range) " + << *SD.Instruction << "\n"); + return false; + } + + // Check that the offsets are within range for every other member + for (auto I = Group.begin(), E = Group.end(); I != E; I++) { + auto GroupSD = I->get(); + int OffsetFromGroupSD = OffsetFromSD0 - GroupSD->OffsetFromSD0; + // If there are multiple accesses to the same offset, bail out. It's + // possible a pass hasn't eliminated redundant memory ops. 
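+    // For example, two gathers that touch the same field of the same strided
+    // record end up with identical offsets; they cannot occupy distinct lanes
+    // of one interleaved access, so the candidate is rejected conservatively.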
+ if (OffsetFromGroupSD == 0) + return false; + if (std::abs(OffsetFromGroupSD) >= SD.Stride) { + LLVM_DEBUG(dbgs() << "IGS: not adding to group (Offset out of range for " + << "another member) " << *SD.Instruction << "\n"); + return false; + } + } + + // Check that the new descriptor doesn't alias with any of the other + // descriptors + for (auto I = Group.begin(), E = Group.end(); I != E; I++) { + auto GroupSD = I->get(); + LLVM_DEBUG(dbgs() << "Checking Gather/Scatter aliasing\n"); + if (isAliasedGatherScatter(SD.Instruction, GroupSD->Instruction)) { + LLVM_DEBUG(dbgs() << "Aliasing detected\n"); + return false; + } + } + + SD.OffsetFromSD0 = OffsetFromSD0; + LLVM_DEBUG(dbgs() << "IGS: Adding to existing StrideGroup: " << *SD.Instruction + << "\n"); + return true; +} + +void StrideDescriptorUtils::lowerGatherScatterNoAliasGroup( + SmallVectorImpl &NoAliasGroup, + SmallVector, 4> *StrideGroups) { + // Each inner vector contains a set of compatible gathers or scatters that + // will be combined into a single strided access + SmallVector, 4> NoAliasStrideGroups; + bool MayStore = NoAliasGroup[0]->mayWriteToMemory(); + // Create StrideGroups from the Intrinsics + for (auto Intrinsic : NoAliasGroup) { + assert(MayStore == Intrinsic->mayWriteToMemory() && + "Mix of loads and stores"); + auto SD = createStrideDescriptor(Intrinsic); + if (SD) { + LLVM_DEBUG(dbgs() << "Attempting to add to existing groups: " + << *SD->Instruction << "\n"); + } else { + LLVM_DEBUG(dbgs() << "IGS: Skipping (failed to Create SD): " << *Intrinsic + << "\n"); + continue; + } + + if (SD->Stride < 0) { + // Don't yet support negative strides + LLVM_DEBUG(dbgs() << "IGS: Skipping (negative stride): " << *SD->Instruction + << "\n"); + continue; + } + + if (((unsigned)SD->Stride) < SD->Size) { + // Don't yet support strides less than size -- happens if the stride + // was calculated against a smaller type but the pointers were then + // bitcasted to a larger type, and the stride is less than the larger + // type. Appears in some DSP code. Memory ops will overlap. + LLVM_DEBUG(dbgs() << "IGS: Skipping Stride: " << SD->Stride + << " less than Size: " << SD->Size << "\n"); + continue; + } + + bool Added = false; + if (!NoAliasStrideGroups.empty()) { + // Iterate backwards through the groups looking for a match. Stores can + // only match against the most recent group, since we don't know they + // don't alias + auto E = MayStore ? 
NoAliasStrideGroups.rbegin() + 1 + : NoAliasStrideGroups.rend(); + for (auto I = NoAliasStrideGroups.rbegin(); I != E; ++I) { + SDUtils::StrideGroup *Group = I->get(); + if (addSDToGroup(*SD.get(), *Group)) { + // Add SD to existing group + Group->push_back(std::move(SD)); + Added = true; + break; + } + } + } + + if (Added) + continue; + + // Create new group + LLVM_DEBUG(dbgs() << "IGS: New StrideGroup: " << *SD->Instruction << "\n"); + auto Group = llvm::make_unique(); + Group->push_back(std::move(SD)); + NoAliasStrideGroups.push_back(std::move(Group)); + } + + if (!NoAliasStrideGroups.empty()) { + for (auto I = NoAliasStrideGroups.rbegin(), E = NoAliasStrideGroups.rend(); + I != E; ++I) { + StrideGroups->push_back(std::move(*I)); + } + } + NoAliasGroup.clear(); +} + +/// Returns true if anything changed +SmallVector, 4> * +StrideDescriptorUtils::getSDGroups( + const SmallVectorImpl &MemAccessList) { + // Holds a group of intrinsics that have no aliasing concerns + SmallVector, 4> *StrideGroups; + StrideGroups = new SmallVector, 4>(); + SmallVector NoAliasGroup; + + // Note this iterates in reverse instruction order + for (auto I = MemAccessList.rbegin(), E = MemAccessList.rend(); I != E; ++I) + if (auto II = dyn_cast(*I)) { + // If this intrinsic is different from the current group, break the + // group + if (!NoAliasGroup.empty() && + (II->getIntrinsicID() != NoAliasGroup[0]->getIntrinsicID())) { + LLVM_DEBUG(dbgs() << "IGS: Breaking group due to different intrinsic:" << *II + << "\n"); + lowerGatherScatterNoAliasGroup(NoAliasGroup, StrideGroups); + } + // Then add this intrinsic to the (possibly empty) group + LLVM_DEBUG(dbgs() << "IGS: Adding intrinsic to current group: " << *II + << "\n"); + NoAliasGroup.push_back(II); + } else if ((*I)->mayWriteToMemory()) { + if (!NoAliasGroup.empty()) { + // Stores break the current group, but don't get added + LLVM_DEBUG(dbgs() << "IGS: Breaking group due to store: " << **I << "\n"); + lowerGatherScatterNoAliasGroup(NoAliasGroup, StrideGroups); + } + } else + llvm_unreachable("Unexpected instruction"); + + // lower the last group + if (!NoAliasGroup.empty()) + lowerGatherScatterNoAliasGroup(NoAliasGroup, StrideGroups); + + return StrideGroups; +} Index: lib/CodeGen/LiveDebugValues.cpp =================================================================== --- lib/CodeGen/LiveDebugValues.cpp +++ lib/CodeGen/LiveDebugValues.cpp @@ -33,6 +33,9 @@ #include "llvm/CodeGen/MachineInstr.h" #include "llvm/CodeGen/MachineInstrBuilder.h" #include "llvm/CodeGen/MachineMemOperand.h" +#include "llvm/CodeGen/MachineModuleInfo.h" +#include "llvm/CodeGen/MachineOperand.h" +#include "llvm/CodeGen/PseudoSourceValue.h" #include "llvm/CodeGen/MachineOperand.h" #include "llvm/CodeGen/PseudoSourceValue.h" #include "llvm/CodeGen/TargetFrameLowering.h" @@ -470,16 +473,21 @@ MachineFunction *MF, unsigned &Reg) { const MachineFrameInfo &FrameInfo = MF->getFrameInfo(); int FI; - const MachineMemOperand *MMO; + SmallVector Accesses; // TODO: Handle multiple stores folded into one. if (!MI.hasOneMemOperand()) return false; // To identify a spill instruction, use the same criteria as in AsmPrinter. 
- if (!((TII->isStoreToStackSlotPostFE(MI, FI) || - TII->hasStoreToStackSlot(MI, MMO, FI)) && - FrameInfo.isSpillSlotObjectIndex(FI))) + if (!((TII->isStoreToStackSlotPostFE(MI, FI) && + FrameInfo.isSpillSlotObjectIndex(FI)) || + (TII->hasStoreToStackSlot(MI, Accesses) && + llvm::any_of(Accesses, [&FrameInfo](const MachineMemOperand *MMO) { + return FrameInfo.isSpillSlotObjectIndex( + cast(MMO->getPseudoValue()) + ->getFrameIndex()); + })))) return false; auto isKilledReg = [&](const MachineOperand MO, unsigned &Reg) { Index: lib/CodeGen/LocalStackSlotAllocation.cpp =================================================================== --- lib/CodeGen/LocalStackSlotAllocation.cpp +++ lib/CodeGen/LocalStackSlotAllocation.cpp @@ -200,19 +200,29 @@ // Make sure that the stack protector comes before the local variables on the // stack. SmallSet ProtectedObjs; - if (MFI.getStackProtectorIndex() >= 0) { + if (MFI.hasStackProtectorIndex()) { + int StackProtectorFI = MFI.getStackProtectorIndex(); + + // We need to make sure we didn't pre-allocate the stack protector when + // doing this. + // If we already have a stack protector, this will re-assign it to a slot + // that is **not** covering the protected objects. + assert(!MFI.isObjectPreAllocated(StackProtectorFI) && + "Stack protector pre-allocated in LocalStackSlotAllocation"); + StackObjSet LargeArrayObjs; StackObjSet SmallArrayObjs; StackObjSet AddrOfObjs; - AdjustStackOffset(MFI, MFI.getStackProtectorIndex(), Offset, - StackGrowsDown, MaxAlign); + AdjustStackOffset(MFI, StackProtectorFI, Offset, StackGrowsDown, MaxAlign); // Assign large stack objects first. for (unsigned i = 0, e = MFI.getObjectIndexEnd(); i != e; ++i) { if (MFI.isDeadObjectIndex(i)) continue; - if (MFI.getStackProtectorIndex() == (int)i) + if (StackProtectorFI == (int)i) + continue; + if (MFI.getStackID(i)) continue; switch (MFI.getObjectSSPLayout(i)) { @@ -248,6 +258,8 @@ continue; if (ProtectedObjs.count(i)) continue; + if (MFI.getStackID(i)) + continue; AdjustStackOffset(MFI, i, Offset, StackGrowsDown, MaxAlign); } @@ -344,6 +356,14 @@ assert(MFI.isObjectPreAllocated(FrameIdx) && "Only pre-allocated locals expected!"); + // We need to keep the references to the stack protector slot through frame + // index operands so that it gets resolved by PEI rather than this pass. + // This avoids accesses to the stack protector though virtual base + // registers, and forces PEI to address it using fp/sp/bp. 
+ if (MFI.hasStackProtectorIndex() && + FrameIdx == MFI.getStackProtectorIndex()) + continue; + LLVM_DEBUG(dbgs() << "Considering: " << MI); unsigned idx = 0; Index: lib/CodeGen/MIRParser/MIRParser.cpp =================================================================== --- lib/CodeGen/MIRParser/MIRParser.cpp +++ lib/CodeGen/MIRParser/MIRParser.cpp @@ -640,7 +640,8 @@ else ObjectIdx = MFI.CreateStackObject( Object.Size, Object.Alignment, - Object.Type == yaml::MachineStackObject::SpillSlot, Alloca); + Object.Type == yaml::MachineStackObject::SpillSlot, Alloca, + Object.StackID); MFI.setObjectOffset(ObjectIdx, Object.Offset); MFI.setStackID(ObjectIdx, Object.StackID); Index: lib/CodeGen/MachineCombiner.cpp =================================================================== --- lib/CodeGen/MachineCombiner.cpp +++ lib/CodeGen/MachineCombiner.cpp @@ -478,8 +478,9 @@ std::tie(NewRootLatency, RootLatency) = getLatenciesForInstrSequences( Root, InsInstrs, DelInstrs, MinInstr->getTrace(MBB)); long CurrentLatencyDiff = ((long)RootLatency) - ((long)NewRootLatency); - assert(CurrentLatencyDiff <= PrevLatencyDiff && - "Current pattern is better than previous pattern."); +// TODO: Fix this when resolving SC-2684 +// assert(CurrentLatencyDiff <= PrevLatencyDiff && +// "Current pattern is better than previous pattern."); PrevLatencyDiff = CurrentLatencyDiff; } } Index: lib/CodeGen/MachineFrameInfo.cpp =================================================================== --- lib/CodeGen/MachineFrameInfo.cpp +++ lib/CodeGen/MachineFrameInfo.cpp @@ -57,7 +57,8 @@ !IsSpillSlot, StackID)); int Index = (int)Objects.size() - NumFixedObjects - 1; assert(Index >= 0 && "Bad frame index!"); - ensureMaxAlignment(Alignment); + if (StackID == 0) + ensureMaxAlignment(Alignment); return Index; } @@ -142,12 +143,16 @@ // should keep in mind that there's tight coupling between the two. for (int i = getObjectIndexBegin(); i != 0; ++i) { + if (getStackID(i)) + continue; int FixedOff = -getObjectOffset(i); if (FixedOff > Offset) Offset = FixedOff; } for (unsigned i = 0, e = getObjectIndexEnd(); i != e; ++i) { if (isDeadObjectIndex(i)) continue; + if (getStackID(i)) + continue; Offset += getObjectSize(i); unsigned Align = getObjectAlignment(i); // Adjust to alignment boundary Index: lib/CodeGen/MachineInstr.cpp =================================================================== --- lib/CodeGen/MachineInstr.cpp +++ lib/CodeGen/MachineInstr.cpp @@ -306,6 +306,16 @@ if (MRI && Operands[OpNo].isReg()) MRI->removeRegOperandFromUseList(Operands + OpNo); + // If within an Asm group we update the group flags to reflect the removal. + if (isInlineAsm()) { + int AsmFlagIdx = findInlineAsmFlagIdx(OpNo); + if (AsmFlagIdx >= 0) { + unsigned Flag = Operands[AsmFlagIdx].getImm(); + Flag = InlineAsm::decrementNumOperandRegisters(Flag); + Operands[AsmFlagIdx].setImm(Flag); + } + } + // Don't call the MachineOperand destructor. A lot of this code depends on // MachineOperand having a trivial destructor anyway, and adding a call here // wouldn't make it 'destructor-correct'. 
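The MachineFrameInfo.cpp hunk above, together with the matching `if (MFI.getStackID(i)) continue;` guards in PrologEpilogInserter.cpp further down, applies one rule: frame objects carrying a non-default stack ID are left out of the ordinary offset accumulation because they live in a separately managed stack region. A minimal standalone sketch of that rule follows; FrameObject and defaultStackSize are illustrative names, not the MachineFrameInfo API.

#include <cstdint>
#include <vector>

// Hypothetical stand-in for a frame object: size, alignment and stack ID.
struct FrameObject {
  uint64_t Size;
  uint64_t Align;   // power of two, non-zero
  unsigned StackID; // 0 = default stack, non-zero = separately managed region
  bool Dead;
};

// Sketch of the accumulation the patch guards: only live StackID-0 objects
// contribute to the default frame; everything else is skipped here and sized
// by its own allocator (e.g. a scalable-vector region).
uint64_t defaultStackSize(const std::vector<FrameObject> &Objects) {
  uint64_t Offset = 0;
  for (const FrameObject &Obj : Objects) {
    if (Obj.Dead || Obj.StackID != 0)
      continue;
    Offset += Obj.Size;
    Offset = (Offset + Obj.Align - 1) & ~(Obj.Align - 1); // align up
  }
  return Offset;
}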
Index: lib/CodeGen/MachineScheduler.cpp =================================================================== --- lib/CodeGen/MachineScheduler.cpp +++ lib/CodeGen/MachineScheduler.cpp @@ -122,6 +122,16 @@ static cl::opt VerifyScheduling("verify-misched", cl::Hidden, cl::desc("Verify machine instrs before and after machine scheduling")); +static cl::opt AlwaysReduceLatency("misched-favour-latency", + cl::desc("Always favour latency over register pressure where possible"), + cl::init(true)); + +static cl::opt RegPressureThreshold("misched-regpressure-threshold", + cl::Hidden, + cl::desc("Don't prioritise latency when scheduling if the register exceeds " + "this threshold"), + cl::init(110)); + // DAG subtrees must have at least this many nodes. static const unsigned MinSubtreeSize = 8; @@ -2581,9 +2591,11 @@ bool tryLatency(GenericSchedulerBase::SchedCandidate &TryCand, GenericSchedulerBase::SchedCandidate &Cand, - SchedBoundary &Zone) { + SchedBoundary &Zone, + bool AlwaysReduceLatencyHeight) { if (Zone.isTop()) { - if (Cand.SU->getDepth() > Zone.getScheduledLatency()) { + if (AlwaysReduceLatencyHeight || + Cand.SU->getDepth() > Zone.getScheduledLatency()) { if (tryLess(TryCand.SU->getDepth(), Cand.SU->getDepth(), TryCand, Cand, GenericSchedulerBase::TopDepthReduce)) return true; @@ -2592,7 +2604,8 @@ TryCand, Cand, GenericSchedulerBase::TopPathReduce)) return true; } else { - if (Cand.SU->getHeight() > Zone.getScheduledLatency()) { + if (AlwaysReduceLatencyHeight || + Cand.SU->getHeight() > Zone.getScheduledLatency()) { if (tryLess(TryCand.SU->getHeight(), Cand.SU->getHeight(), TryCand, Cand, GenericSchedulerBase::BotHeightReduce)) return true; @@ -2664,6 +2677,8 @@ } } + RegionPolicy.AlwaysReduceLatencyHeight = AlwaysReduceLatency; + // For generic targets, we default to bottom-up, because it's simpler and more // compile-time optimizations have been implemented in that direction. RegionPolicy.OnlyBottomUp = true; @@ -2844,6 +2859,10 @@ Cand.RPDelta, DAG->getRegionCriticalPSets(), DAG->getRegPressure().MaxSetPressure); + + unsigned PressureFactor = RPTracker.getHighestUpwardPressureFactor( + &DAG->getPressureDiff(Cand.SU)); + Cand.PressureExceedsLimit = PressureFactor > RegPressureThreshold; } else { if (VerifyScheduling) { TempTracker.getMaxUpwardPressureDelta( @@ -2852,6 +2871,7 @@ Cand.RPDelta, DAG->getRegionCriticalPSets(), DAG->getRegPressure().MaxSetPressure); + Cand.PressureExceedsLimit = false; } else { RPTracker.getUpwardPressureDelta( Cand.SU->getInstr(), @@ -2859,6 +2879,10 @@ Cand.RPDelta, DAG->getRegionCriticalPSets(), DAG->getRegPressure().MaxSetPressure); + + unsigned PressureFactor = RPTracker.getHighestUpwardPressureFactor( + &DAG->getPressureDiff(Cand.SU)); + Cand.PressureExceedsLimit = PressureFactor > RegPressureThreshold; } } } @@ -2887,6 +2911,26 @@ return; } + + // We only compare a subset of features when comparing nodes between + // Top and Bottom boundary. Some properties are simply incomparable, in many + // other instances we should only override the other boundary if something + // is a clear good pick on one boundary. Skip heuristics that are more + // "tie-breaking" in nature. + bool SameBoundary = Zone != nullptr; + if (RegionPolicy.AlwaysReduceLatencyHeight && !TryCand.PressureExceedsLimit) { + if (SameBoundary) { + // For loops that are acyclic path limited, aggressively schedule for + // latency. Within an single cycle, whenever CurrMOps > 0, allow normal + // heuristics to take precedence. 
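+      // Note: with misched-favour-latency enabled, tryLatency is consulted
+      // here even before the zone's scheduled latency is exceeded, but only
+      // while the candidate stays under misched-regpressure-threshold
+      // (PressureExceedsLimit is false).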
+ if (Rem.IsAcyclicLatencyLimited && !Zone->getCurrMOps() && + tryLatency(TryCand, Cand, *Zone, + RegionPolicy.AlwaysReduceLatencyHeight)) { + return; + } + } + } + if (tryGreater(biasPhysRegCopy(TryCand.SU, TryCand.AtTop), biasPhysRegCopy(Cand.SU, Cand.AtTop), TryCand, Cand, PhysRegCopy)) @@ -2911,7 +2955,6 @@ // other instances we should only override the other boundary if something // is a clear good pick on one boundary. Skip heuristics that are more // "tie-breaking" in nature. - bool SameBoundary = Zone != nullptr; if (SameBoundary) { // For loops that are acyclic path limited, aggressively schedule for // latency. Within an single cycle, whenever CurrMOps > 0, allow normal Index: lib/CodeGen/PrologEpilogInserter.cpp =================================================================== --- lib/CodeGen/PrologEpilogInserter.cpp +++ lib/CodeGen/PrologEpilogInserter.cpp @@ -587,10 +587,12 @@ SmallVector AllocatedFrameSlots; // Add fixed objects. for (int i = MFI.getObjectIndexBegin(); i != 0; ++i) - AllocatedFrameSlots.push_back(i); + if (MFI.getStackID(i) == 0) + AllocatedFrameSlots.push_back(i); // Add callee-save objects. for (int i = MinCSFrameIndex; i <= (int)MaxCSFrameIndex; ++i) - AllocatedFrameSlots.push_back(i); + if (MFI.getStackID(i) == 0) + AllocatedFrameSlots.push_back(i); for (int i : AllocatedFrameSlots) { // These are converted from int64_t, but they should always fit in int @@ -735,6 +737,9 @@ // callee saved registers. if (StackGrowsDown) { for (unsigned i = MinCSFrameIndex; i <= MaxCSFrameIndex; ++i) { + if (MFI.getStackID(i)) + continue; + // If the stack grows down, we need to add the size to find the lowest // address of the object. Offset += MFI.getObjectSize(i); @@ -749,6 +754,9 @@ } else if (MaxCSFrameIndex >= MinCSFrameIndex) { // Be careful about underflow in comparisons agains MinCSFrameIndex. for (unsigned i = MaxCSFrameIndex; i != MinCSFrameIndex - 1; --i) { + if (MFI.getStackID(i)) + continue; + if (MFI.isDeadObjectIndex(i)) continue; @@ -817,18 +825,26 @@ // Make sure that the stack protector comes before the local variables on the // stack. SmallSet ProtectedObjs; - if (MFI.getStackProtectorIndex() >= 0) { + if (MFI.hasStackProtectorIndex()) { + int StackProtectorFI = MFI.getStackProtectorIndex(); StackObjSet LargeArrayObjs; StackObjSet SmallArrayObjs; StackObjSet AddrOfObjs; - AdjustStackOffset(MFI, MFI.getStackProtectorIndex(), StackGrowsDown, - Offset, MaxAlign, Skew); + // If we need a stack protector, we need to make sure that + // LocalStackSlotPass didn't already allocate a slot for it. + // If we are told to use the LocalStackAllocationBlock, the stack protector + // is expected to be already pre-allocated. + if (!MFI.getUseLocalStackAllocationBlock()) + AdjustStackOffset(MFI, StackProtectorFI, StackGrowsDown, Offset, MaxAlign, + Skew); + else if (!MFI.isObjectPreAllocated(MFI.getStackProtectorIndex())) + llvm_unreachable( + "Stack protector not pre-allocated by LocalStackSlotPass."); // Assign large stack objects first. 
for (unsigned i = 0, e = MFI.getObjectIndexEnd(); i != e; ++i) { - if (MFI.isObjectPreAllocated(i) && - MFI.getUseLocalStackAllocationBlock()) + if (MFI.isObjectPreAllocated(i) && MFI.getUseLocalStackAllocationBlock()) continue; if (i >= MinCSFrameIndex && i <= MaxCSFrameIndex) continue; @@ -836,8 +852,9 @@ continue; if (MFI.isDeadObjectIndex(i)) continue; - if (MFI.getStackProtectorIndex() == (int)i || - EHRegNodeFrameIndex == (int)i) + if (StackProtectorFI == (int)i || EHRegNodeFrameIndex == (int)i) + continue; + if (MFI.getStackID(i)) continue; switch (MFI.getObjectSSPLayout(i)) { @@ -856,6 +873,15 @@ llvm_unreachable("Unexpected SSPLayoutKind."); } + // We expect **all** the protected stack objects to be pre-allocated by + // LocalStackSlotPass. If it turns out that PEI still has to allocate some + // of them, we may end up messing up the expected order of the objects. + if (MFI.getUseLocalStackAllocationBlock() && + !(LargeArrayObjs.empty() && SmallArrayObjs.empty() && + AddrOfObjs.empty())) + llvm_unreachable("Found protected stack objects not pre-allocated by " + "LocalStackSlotPass."); + AssignProtectedObjSet(LargeArrayObjs, ProtectedObjs, MFI, StackGrowsDown, Offset, MaxAlign, Skew); AssignProtectedObjSet(SmallArrayObjs, ProtectedObjs, MFI, StackGrowsDown, @@ -877,11 +903,12 @@ continue; if (MFI.isDeadObjectIndex(i)) continue; - if (MFI.getStackProtectorIndex() == (int)i || - EHRegNodeFrameIndex == (int)i) + if (MFI.getStackProtectorIndex() == (int)i || EHRegNodeFrameIndex == (int)i) continue; if (ProtectedObjs.count(i)) continue; + if (MFI.getStackID(i)) + continue; // Add the objects that we need to allocate to our working set. ObjectsToAllocate.push_back(i); @@ -1063,7 +1090,8 @@ // Frame indices in debug values are encoded in a target independent // way with simply the frame index and offset rather than any // target-specific addressing mode. - if (MI.isDebugValue()) { + if (MI.isDebugValue() && + MF.getFrameInfo().getStackID(MI.getOperand(i).getIndex()) == 0) { assert(i == 0 && "Frame indices can only appear as the first " "operand of a DBG_VALUE machine instruction"); unsigned Reg; Index: lib/CodeGen/RegAllocGreedy.cpp =================================================================== --- lib/CodeGen/RegAllocGreedy.cpp +++ lib/CodeGen/RegAllocGreedy.cpp @@ -3120,18 +3120,23 @@ // Handle blocks that were not included in subloops. if (Loops->getLoopFor(MBB) == L) for (MachineInstr &MI : *MBB) { - const MachineMemOperand *MMO; + SmallVector Accesses; + auto isSpillSlotAccess = [&MFI](const MachineMemOperand *A) { + return MFI.isSpillSlotObjectIndex( + cast(A->getPseudoValue()) + ->getFrameIndex()); + }; if (TII->isLoadFromStackSlot(MI, FI) && MFI.isSpillSlotObjectIndex(FI)) ++Reloads; - else if (TII->hasLoadFromStackSlot(MI, MMO, FI) && - MFI.isSpillSlotObjectIndex(FI)) + else if (TII->hasLoadFromStackSlot(MI, Accesses) && + llvm::any_of(Accesses, isSpillSlotAccess)) ++FoldedReloads; else if (TII->isStoreToStackSlot(MI, FI) && MFI.isSpillSlotObjectIndex(FI)) ++Spills; - else if (TII->hasStoreToStackSlot(MI, MMO, FI) && - MFI.isSpillSlotObjectIndex(FI)) + else if (TII->hasStoreToStackSlot(MI, Accesses) && + llvm::any_of(Accesses, isSpillSlotAccess)) ++FoldedSpills; } Index: lib/CodeGen/RegisterPressure.cpp =================================================================== --- lib/CodeGen/RegisterPressure.cpp +++ lib/CodeGen/RegisterPressure.cpp @@ -1127,6 +1127,29 @@ #endif } +// Walk through each pressure difference calculating the pressure factor, i.e. 
+// the current register pressure as a percentage of the limit for that register +// set. +unsigned RegPressureTracker:: +getHighestUpwardPressureFactor(const PressureDiff *PDiff) const { + unsigned Factor = 0; + for (PressureDiff::const_iterator + PDiffI = PDiff->begin(), PDiffE = PDiff->end(); + PDiffI != PDiffE && PDiffI->isValid(); ++PDiffI) { + unsigned PSetID = PDiffI->getPSet(); + unsigned POld = CurrSetPressure[PSetID]; + unsigned PNew = POld + PDiffI->getUnitInc(); + unsigned Limit = RCI->getRegPressureSetLimit(PSetID); + if (!LiveThruPressure.empty()) + Limit += LiveThruPressure[PSetID]; + unsigned Percent = (100 * PNew) / Limit; + if (Percent > Factor) + Factor = Percent; + } + + return Factor; +} + /// This is the fast version of querying register pressure that does not /// directly depend on current liveness. /// Index: lib/CodeGen/RegisterScavenging.cpp =================================================================== --- lib/CodeGen/RegisterScavenging.cpp +++ lib/CodeGen/RegisterScavenging.cpp @@ -534,7 +534,7 @@ unsigned RegScavenger::scavengeRegister(const TargetRegisterClass *RC, MachineBasicBlock::iterator I, - int SPAdj) { + int SPAdj, bool SRLiveRangeEndsHere) { MachineInstr &MI = *I; const MachineFunction &MF = *MI.getMF(); // Consider all allocatable registers in the register class initially @@ -542,10 +542,16 @@ // Exclude all the registers being used by the instruction. for (const MachineOperand &MO : MI.operands()) { - if (MO.isReg() && MO.getReg() != 0 && !(MO.isUse() && MO.isUndef()) && - !TargetRegisterInfo::isVirtualRegister(MO.getReg())) + if (MO.isReg() && MO.getReg() != 0 && (MO.isDef() || !MO.isUndef()) && + !TargetRegisterInfo::isVirtualRegister(MO.getReg())) { + // We can reuse the destination register if it is not an earlyclobber + // and 'I' kills the scavenged register + if (MO.isDef() && SRLiveRangeEndsHere && !MO.isEarlyClobber()) + continue; + for (MCRegAliasIterator AI(MO.getReg(), TRI, true); AI.isValid(); ++AI) Candidates.reset(*AI); + } } // Try to find a register that's unused if there is one, as then we won't Index: lib/CodeGen/SafeStack.cpp =================================================================== --- lib/CodeGen/SafeStack.cpp +++ lib/CodeGen/SafeStack.cpp @@ -28,6 +28,7 @@ #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/ScalarEvolutionExpressions.h" +#include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Transforms/Utils/Local.h" #include "llvm/CodeGen/TargetLowering.h" @@ -848,6 +849,7 @@ void getAnalysisUsage(AnalysisUsage &AU) const override { AU.addRequired(); AU.addRequired(); + AU.addRequired(); AU.addRequired(); } @@ -873,6 +875,7 @@ auto *DL = &F.getParent()->getDataLayout(); auto &TLI = getAnalysis().getTLI(); + auto &TTI = getAnalysis().getTTI(F); auto &ACT = getAnalysis().getAssumptionCache(F); // Compute DT and LI only for functions that have the attribute. 
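The RegisterPressure.cpp addition above expresses projected pressure as a percentage of each set's limit, and the MachineScheduler.cpp changes earlier compare the worst percentage against misched-regpressure-threshold (default 110) before keeping the latency-first path enabled. The sketch below is a standalone model of that calculation; PressureSample and highestPressureFactor are illustrative names, not LLVM API.

#include <algorithm>
#include <vector>

// One entry per register pressure set: the projected pressure if the candidate
// is scheduled, and the set's limit (including any live-through pressure).
struct PressureSample {
  unsigned Projected;
  unsigned Limit;
};

// Worst pressure set expressed as a percentage of its limit. A value above the
// threshold (default 110) marks the candidate as exceeding pressure, which
// suppresses the aggressive latency heuristic.
unsigned highestPressureFactor(const std::vector<PressureSample> &Sets) {
  unsigned Factor = 0;
  for (const PressureSample &S : Sets) {
    if (S.Limit == 0)
      continue; // avoid dividing by a zero limit in this sketch
    Factor = std::max(Factor, (100 * S.Projected) / S.Limit);
  }
  return Factor;
}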
@@ -884,7 +887,7 @@ DominatorTree DT(F); LoopInfo LI(DT); - ScalarEvolution SE(F, TLI, ACT, DT, LI); + ScalarEvolution SE(F, TLI, ACT, DT, LI, TTI); return SafeStack(F, *TL, *DL, SE).run(); } Index: lib/CodeGen/SelectionDAG/DAGCombiner.cpp =================================================================== --- lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -386,6 +386,7 @@ SDValue visitMSCATTER(SDNode *N); SDValue visitFP_TO_FP16(SDNode *N); SDValue visitFP16_TO_FP(SDNode *N); + SDValue visitVECREDUCE(SDNode *N); SDValue visitFADDForFMACombine(SDNode *N); SDValue visitFSUBForFMACombine(SDNode *N); @@ -865,6 +866,36 @@ return false; } +// TODO: Move this to the ISD namespace in SelectionDAG.cpp as well +// as the other "isConstantFoo" functions? +static bool isConstantSplatVectorMaskForType(SDNode *N, EVT ScalarTy) { + if (!ScalarTy.isSimple()) + return false; + + uint64_t MaskForTy = 0ull; + switch(ScalarTy.getSimpleVT().SimpleTy) { + case MVT::i8: + MaskForTy = 0xffull; + break; + case MVT::i16: + MaskForTy = 0xffffull; + break; + case MVT::i32: + MaskForTy = 0xffffffffull; + break; + default: + return false; + break; + } + + APInt Val; + if (ISD::isConstantSplatVector(N, Val)) { + return Val.getLimitedValue() == MaskForTy; + } + + return false; +} + static SDValue peekThroughBitcast(SDValue V) { while (V.getOpcode() == ISD::BITCAST) V = V.getOperand(0); @@ -952,9 +983,10 @@ return DAG.getNode(Opc, DL, VT, N0.getOperand(0), OpNode); return SDValue(); } - if (N0.hasOneUse()) { + if (N0.hasOneUse() && N1.getOpcode() != ISD::VSCALE) { // reassoc. (op (op x, c1), y) -> (op (op x, y), c1) iff x+c1 has one - // use + // use and if y is not vscale. Prefer moving out vscale beyond + // constants in expressions. SDValue OpNode = DAG.getNode(Opc, SDLoc(N0), VT, N0.getOperand(0), N1); if (!OpNode.getNode()) return SDValue(); @@ -972,9 +1004,10 @@ return DAG.getNode(Opc, DL, VT, N1.getOperand(0), OpNode); return SDValue(); } - if (N1.hasOneUse()) { + if (N1.hasOneUse() && N0.getOpcode() != ISD::VSCALE) { // reassoc. (op x, (op y, c1)) -> (op (op x, y), c1) iff x+c1 has one - // use + // use and if x is not vscale. Prefer moving out vscale beyond + // constants in expressions. 
SDValue OpNode = DAG.getNode(Opc, SDLoc(N0), VT, N0, N1.getOperand(0)); if (!OpNode.getNode()) return SDValue(); @@ -1596,6 +1629,19 @@ case ISD::MSTORE: return visitMSTORE(N); case ISD::FP_TO_FP16: return visitFP_TO_FP16(N); case ISD::FP16_TO_FP: return visitFP16_TO_FP(N); + case ISD::VECREDUCE_FADD: + case ISD::VECREDUCE_FMUL: + case ISD::VECREDUCE_ADD: + case ISD::VECREDUCE_MUL: + case ISD::VECREDUCE_AND: + case ISD::VECREDUCE_OR: + case ISD::VECREDUCE_XOR: + case ISD::VECREDUCE_SMAX: + case ISD::VECREDUCE_SMIN: + case ISD::VECREDUCE_UMAX: + case ISD::VECREDUCE_UMIN: + case ISD::VECREDUCE_FMAX: + case ISD::VECREDUCE_FMIN: return visitVECREDUCE(N); } return SDValue(); } @@ -2188,6 +2234,26 @@ if (SDValue Combined = visitADDLike(N1, N0, N)) return Combined; + // fold a+vscale(c1)+vscale(c2) -> a+vscale(c1+c2) + if ((N0.getOpcode() == ISD::ADD) && + (N0.getOperand(1).getOpcode() == ISD::VSCALE) && + (N1.getOpcode() == ISD::VSCALE)) { + auto VS0 = cast(N0.getOperand(1).getOperand(0)); + auto VS1 = cast(N1.getOperand(0)); + auto VS = DAG.getVScale(DL, VT, VS0->getSExtValue() + VS1->getSExtValue()); + return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0), VS); + } + + // fold (add (splat x), (series_vector y, z)) -> series_vector x+y, z + if (N0.getOpcode() == ISD::SPLAT_VECTOR && + N1.getOpcode() == ISD::SERIES_VECTOR && + N0.getOperand(0).getValueType() == N1.getOperand(0).getValueType()) { + // Limit to cases where the add is redundant. + if (isNullConstant(N1.getOperand(0))) // y == 0 + return DAG.getNode(ISD::SERIES_VECTOR, DL, VT, N0.getOperand(0), + N1.getOperand(1)); + } + return SDValue(); } @@ -2742,6 +2808,12 @@ } } + // canonicalize (X - (vscale * C)) to (X + (vscale * -C)) + if (N1.getOpcode() == ISD::VSCALE) { + int64_t MulImm = cast(N1.getOperand(0))->getSExtValue(); + return DAG.getNode(ISD::ADD, DL, VT, N0, DAG.getVScale(DL, VT, -MulImm)); + } + // Prefer an add for more folding potential and possibly better codegen: // sub N0, (lshr N10, width-1) --> add N0, (ashr N10, width-1) if (!LegalOperations && N1.getOpcode() == ISD::SRL && N1.hasOneUse()) { @@ -2973,6 +3045,20 @@ DAG.getNode(ISD::MUL, SDLoc(N1), VT, N0.getOperand(1), N1)); + // fold (mul (series_vector x, y), c) -> series_vector x, y*c + if (N1IsConst && N0.getOpcode() == ISD::SERIES_VECTOR) { + auto CStep = dyn_cast(N0.getOperand(1)); + EVT StepVT = N0.getOperand(1).getValueType(); + + // Limit to cases where the mul is redundant. + if (CStep && StepVT.getSizeInBits() == ConstValue1.getBitWidth()) { + SDLoc DL(N); + APInt NewStep = CStep->getAPIntValue() * ConstValue1; + SDValue Op2 = DAG.getConstant(NewStep, DL, StepVT); + return DAG.getNode(ISD::SERIES_VECTOR, DL, VT, N0.getOperand(0), Op2); + } + } + // reassociate mul if (SDValue RMUL = ReassociateOps(ISD::MUL, SDLoc(N), N0, N1)) return RMUL; @@ -4554,6 +4640,45 @@ } } + if (auto *LN0 = dyn_cast(N0)) { + EVT MemVT = LN0->getMemoryVT(); + EVT ScalarVT = MemVT.getScalarType(); + if ( SDValue(LN0, 0).hasOneUse() + && isConstantSplatVectorMaskForType(N1.getNode(), ScalarVT) + && TLI.isLoadExtLegal(ISD::ZEXTLOAD, VT, MemVT)) { + SDValue ZExtLoad = DAG.getMaskedLoad(VT, SDLoc(N), LN0->getChain(), + LN0->getBasePtr(), LN0->getMask(), + LN0->getSrc0(), MemVT, + LN0->getMemOperand(), ISD::ZEXTLOAD); + CombineTo(N, ZExtLoad); + CombineTo(N0.getNode(), ZExtLoad, ZExtLoad.getValue(1)); + AddToWorklist(ZExtLoad.getNode()); + // Avoid recheck of N. 
+ return SDValue(N, 0); + } + } + + if (auto *GN0 = dyn_cast(N0)) { + EVT MemVT = GN0->getMemoryVT(); + EVT ScalarVT = MemVT.getScalarType(); + if ( SDValue(GN0, 0).hasOneUse() + && isConstantSplatVectorMaskForType(N1.getNode(), ScalarVT) + && TLI.isLoadExtLegal(ISD::ZEXTLOAD, VT, MemVT)) { + SDValue Ops[] = { GN0->getChain(), GN0->getSrc0(), GN0->getMask(), + GN0->getBasePtr(), GN0->getIndex(), GN0->getScale() }; + SDValue ZExtLoad = DAG.getMaskedGather(DAG.getVTList(VT, MVT::Other), + MemVT, SDLoc(N), Ops, + GN0->getMemOperand(), + ISD::ZEXTLOAD, + GN0->getIndexType()); + CombineTo(N, ZExtLoad); + CombineTo(N0.getNode(), ZExtLoad, ZExtLoad.getValue(1)); + AddToWorklist(ZExtLoad.getNode()); + // Avoid recheck of N. + return SDValue(N, 0); + } + } + // fold (and (load x), 255) -> (zextload x, i8) // fold (and (extload x, i16), 255) -> (zextload x, i8) // fold (and (any_ext (extload x, i16)), 255) -> (zextload x, i8) @@ -4658,6 +4783,15 @@ return BSwap; } + + // vscale has a target specific range that will likely make the and redundant. + if (N1C && (N0.getOpcode() == ISD::VSCALE)) { + KnownBits Known; + DAG.computeKnownBits(N0, Known); + if ((Known.Zero | ~N1C->getAPIntValue()) == Known.Zero) + return N0; + } + if (SDValue Shifts = unfoldExtremeBitClearingToShifts(N)) return Shifts; @@ -5147,6 +5281,40 @@ if (SDValue Tmp = SimplifyBinOpWithSameOpcodeHands(N)) return Tmp; + // (or (and X, C1), (and Y, C2)) -> (and (or X, Y), C3) if possible. + if (N0.getOpcode() == ISD::AND && + N1.getOpcode() == ISD::AND && + N0.getOperand(1).getOpcode() == ISD::Constant && + N1.getOperand(1).getOpcode() == ISD::Constant && + // Don't increase # computations. + (N0.getNode()->hasOneUse() || N1.getNode()->hasOneUse())) { + // We can only do this xform if we know that bits from X that are set in C2 + // but not in C1 are already zero. Likewise for Y. + const APInt &LHSMask = + cast(N0.getOperand(1))->getAPIntValue(); + const APInt &RHSMask = + cast(N1.getOperand(1))->getAPIntValue(); + + if (DAG.MaskedValueIsZero(N0.getOperand(0), RHSMask&~LHSMask) && + DAG.MaskedValueIsZero(N1.getOperand(0), LHSMask&~RHSMask)) { + SDValue X = DAG.getNode(ISD::OR, SDLoc(N0), VT, + N0.getOperand(0), N1.getOperand(0)); + return DAG.getNode(ISD::AND, SDLoc(N), VT, X, + DAG.getConstant(LHSMask | RHSMask, SDLoc(N0), VT)); + } + } + + // (or (and X, M), (and X, N)) -> (and X, (or M, N)) + if (N0.getOpcode() == ISD::AND && + N1.getOpcode() == ISD::AND && + N0.getOperand(0) == N1.getOperand(0) && + // Don't increase # computations. + (N0.getNode()->hasOneUse() || N1.getNode()->hasOneUse())) { + SDValue X = DAG.getNode(ISD::OR, SDLoc(N0), VT, + N0.getOperand(1), N1.getOperand(1)); + return DAG.getNode(ISD::AND, SDLoc(N), VT, N0.getOperand(0), X); + } + // See if this is some rotate idiom. 
if (SDNode *Rot = MatchRotate(N0, N1, SDLoc(N))) return SDValue(Rot, 0); @@ -7398,6 +7566,7 @@ EVT MemoryVT = MSC->getMemoryVT(); unsigned Alignment = MSC->getOriginalAlignment(); + bool isTruncStore = MSC->isTruncatingStore(); EVT LoMemVT, HiMemVT; std::tie(LoMemVT, HiMemVT) = DAG.GetSplitDestVTs(MemoryVT); @@ -7417,11 +7586,11 @@ SDValue OpsLo[] = { Chain, DataLo, MaskLo, BasePtr, IndexLo, Scale }; Lo = DAG.getMaskedScatter(DAG.getVTList(MVT::Other), DataLo.getValueType(), - DL, OpsLo, MMO); + DL, OpsLo, MMO, isTruncStore, MSC->getIndexType()); SDValue OpsHi[] = { Chain, DataHi, MaskHi, BasePtr, IndexHi, Scale }; Hi = DAG.getMaskedScatter(DAG.getVTList(MVT::Other), DataHi.getValueType(), - DL, OpsHi, MMO); + DL, OpsHi, MMO, isTruncStore, MSC->getIndexType()); AddToWorklist(Lo.getNode()); AddToWorklist(Hi.getNode()); @@ -7535,6 +7704,7 @@ SDValue Chain = MGT->getChain(); EVT MemoryVT = MGT->getMemoryVT(); unsigned Alignment = MGT->getOriginalAlignment(); + ISD::LoadExtType ExtType = MGT->getExtensionType(); EVT LoMemVT, HiMemVT; std::tie(LoMemVT, HiMemVT) = DAG.GetSplitDestVTs(MemoryVT); @@ -7552,11 +7722,11 @@ SDValue OpsLo[] = { Chain, Src0Lo, MaskLo, BasePtr, IndexLo, Scale }; Lo = DAG.getMaskedGather(DAG.getVTList(LoVT, MVT::Other), LoVT, DL, OpsLo, - MMO); + MMO, ExtType, MGT->getIndexType()); SDValue OpsHi[] = { Chain, Src0Hi, MaskHi, BasePtr, IndexHi, Scale }; Hi = DAG.getMaskedGather(DAG.getVTList(HiVT, MVT::Other), HiVT, DL, OpsHi, - MMO); + MMO, ExtType, MGT->getIndexType()); AddToWorklist(Lo.getNode()); AddToWorklist(Hi.getNode()); @@ -8097,6 +8267,11 @@ if (!TLI.isLoadExtLegalOrCustom(ExtType, SplitDstVT, SplitSrcVT)) return SDValue(); + // TODO: It seems sensible to support this but the address calculation needs + // updating to make use of DAG.getVScale(). + if (DstVT.isScalableVector()) + return SDValue(); + SDLoc DL(N); const unsigned NumSplits = DstVT.getVectorNumElements() / SplitDstVT.getVectorNumElements(); @@ -8558,6 +8733,34 @@ if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(N)) return NewVSel; + // fold (sext (series_vector c1, c2)) -> series_vector sext(c1), sext(c2) + if (N0.getOpcode() == ISD::SERIES_VECTOR && + N0->getFlags().hasNoSignedWrap()) { + auto CStart = dyn_cast(N0.getOperand(0)); + auto CStep = dyn_cast(N0.getOperand(1)); + + // Limit to cases where the extension is redundant. + if (CStart && CStep) { + EVT EltVT = VT.getVectorElementType(); + APInt Start = CStart->getAPIntValue(); + APInt Step = CStep->getAPIntValue(); + unsigned OpBitWidth = std::max(EltVT.getSizeInBits(), + Start.getBitWidth()); + + // SERIES_VECTOR operands can be bigger than its result element type. 
+ EVT SrcVT = N0.getValueType().getVectorElementType(); + if (SrcVT.getSizeInBits() < Start.getBitWidth()) { + Start.trunc(SrcVT.getSizeInBits()); + Step.trunc(SrcVT.getSizeInBits()); + } + + EVT OpVT = EVT::getIntegerVT(*DAG.getContext(), OpBitWidth); + SDValue Op1 = DAG.getConstant(Start.sext(OpBitWidth), DL, OpVT); + SDValue Op2 = DAG.getConstant(Step.sext(OpBitWidth), DL, OpVT); + return DAG.getNode(ISD::SERIES_VECTOR, DL, VT, Op1, Op2, N0->getFlags()); + } + } + return SDValue(); } @@ -8660,7 +8863,7 @@ } } - if (!LegalOperations || TLI.isOperationLegal(ISD::AND, VT)) { + if (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::AND, VT)) { SDValue Op = DAG.getAnyExtOrTrunc(N0.getOperand(0), SDLoc(N), VT); AddToWorklist(Op.getNode()); SDValue And = DAG.getZeroExtendInReg(Op, SDLoc(N), MinVT.getScalarType()); @@ -9311,6 +9514,38 @@ AddToWorklist(ExtLoad.getNode()); return SDValue(N, 0); // Return N so it doesn't get rechecked! } + // fold (sext_inreg (masked_load x)) -> (sext_masked_load x) + if (isa(N0) && + EVT == cast(N0)->getMemoryVT() && + ((!LegalOperations && !cast(N0)->isVolatile()) || + TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, EVT))) { + MaskedLoadSDNode *LN0 = cast(N0); + SDValue ExtLoad = DAG.getMaskedLoad(VT, SDLoc(N), LN0->getChain(), + LN0->getBasePtr(), LN0->getMask(), + LN0->getSrc0(), LN0->getMemoryVT(), + LN0->getMemOperand(), ISD::SEXTLOAD); + CombineTo(N, ExtLoad); + CombineTo(N0.getNode(), ExtLoad, ExtLoad.getValue(1)); + AddToWorklist(ExtLoad.getNode()); + return SDValue(N, 0); // Return N so it doesn't get rechecked! + } + // fold (sext_inreg (masked_gather x)) -> (sext_masked_gather x) + if (isa(N0) && + EVT == cast(N0)->getMemoryVT() && + ((!LegalOperations && !cast(N0)->isVolatile()) || + TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, EVT))) { + MaskedGatherSDNode *LN0 = cast(N0); + SDValue Ops[] = { LN0->getChain(), LN0->getSrc0(), LN0->getMask(), + LN0->getBasePtr(), LN0->getIndex(), LN0->getScale() }; + SDValue ExtLoad = DAG.getMaskedGather(DAG.getVTList(VT, MVT::Other), + LN0->getMemoryVT(), SDLoc(N), Ops, + LN0->getMemOperand(), ISD::SEXTLOAD, + LN0->getIndexType()); + CombineTo(N, ExtLoad); + CombineTo(N0.getNode(), ExtLoad, ExtLoad.getValue(1)); + AddToWorklist(ExtLoad.getNode()); + return SDValue(N, 0); // Return N so it doesn't get rechecked! + } // fold (sext_inreg (zextload x)) -> (sextload x) iff load has one use if (ISD::isZEXTLoad(N0.getNode()) && ISD::isUNINDEXEDLoad(N0.getNode()) && N0.hasOneUse() && @@ -9421,10 +9656,11 @@ EVT ExTy = N0.getValueType(); EVT TrTy = N->getValueType(0); - unsigned NumElem = VecTy.getVectorNumElements(); + auto EltCnt = VecTy.getVectorElementCount(); unsigned SizeRatio = ExTy.getSizeInBits()/TrTy.getSizeInBits(); + auto NewEltCnt = EltCnt * SizeRatio; - EVT NVT = EVT::getVectorVT(*DAG.getContext(), TrTy, SizeRatio * NumElem); + EVT NVT = EVT::getVectorVT(*DAG.getContext(), TrTy, NewEltCnt); assert(NVT.getSizeInBits() == VecTy.getSizeInBits() && "Invalid Size"); SDValue EltNo = N0->getOperand(1); @@ -10142,7 +10378,7 @@ return SDValue(); const SelectionDAGTargetInfo *STI = DAG.getSubtarget().getSelectionDAGInfo(); - if (STI && STI->generateFMAsInMachineCombiner(OptLevel)) + if (STI && STI->generateFMAsInMachineCombiner(DAG, OptLevel)) return SDValue(); // Always prefer FMAD to FMA for precision. 
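The sext(SERIES_VECTOR) combine added earlier in this file relies on a simple lane-wise identity: when the narrow series is known not to wrap (the no-signed-wrap flag the fold checks), sign-extending each lane equals building the series from the sign-extended start and step. The snippet below is a small self-contained check of that identity in plain C++, not DAG code; checkFold is an illustrative helper.

#include <cassert>
#include <cstdint>

// Lane L of series_vector(Start, Step) is Start + L * Step. The fold rewrites
//   sext(series_vector(C1, C2)) -> series_vector(sext(C1), sext(C2))
// which is lane-wise exact only while the narrow series stays in range.
static void checkFold(int8_t Start, int8_t Step, unsigned NumLanes) {
  for (unsigned L = 0; L < NumLanes; ++L) {
    int8_t Narrow = static_cast<int8_t>(Start + Step * static_cast<int>(L));
    int64_t LHS = static_cast<int64_t>(Narrow); // sext of the narrow lane
    int64_t RHS = static_cast<int64_t>(Start) + static_cast<int64_t>(Step) * L;
    assert(LHS == RHS && "holds only while the i8 series does not wrap");
  }
}

int main() {
  checkFold(-8, 3, 16); // lanes -8, -5, ..., 37: no i8 wrap, so the fold is exact
  return 0;
}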
@@ -10354,7 +10590,7 @@ return SDValue(); const SelectionDAGTargetInfo *STI = DAG.getSubtarget().getSelectionDAGInfo(); - if (STI && STI->generateFMAsInMachineCombiner(OptLevel)) + if (STI && STI->generateFMAsInMachineCombiner(DAG, OptLevel)) return SDValue(); // Always prefer FMAD to FMA for precision. @@ -14033,6 +14269,12 @@ if (!IsConstantSrc && !IsLoadSrc && !IsExtractVecSrc) return false; + // Don't merge vectors into wider vectors if the source data comes from loads. + // TODO: This restriction can be lifted by using logic similar to the + // ExtractVecSrc case. + if (MemVT.isVector() && IsLoadSrc) + return false; + SmallVector StoreNodes; SDNode *RootNode; // Find potential store merge candidates by searching through chain sub-DAG @@ -14992,6 +15234,14 @@ if (SDValue Shuf = combineInsertEltToShuffle(N, Elt)) return Shuf; + if(VT.isScalableVector()) { + // If EltNo is constant zero, InVec is undef, then return a SCALAR_TO_VECTOR + if((Elt == 0) && InVec.getOpcode() == ISD::UNDEF) + return DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, VT, InVal); + // Otherwise, give up for scalable vectors + return SDValue(); + } + // Canonicalize insert_vector_elt dag nodes. // Example: // (insert_vector_elt (insert_vector_elt A, Idx0), Idx1) @@ -15152,7 +15402,8 @@ ConstantSDNode *ConstEltNo = dyn_cast(EltNo); // extract_vector_elt of out-of-bounds element -> UNDEF - if (ConstEltNo && ConstEltNo->getAPIntValue().uge(VT.getVectorNumElements())) + if (!VT.isScalableVector() && ConstEltNo && + ConstEltNo->getAPIntValue().uge(VT.getVectorNumElements())) return DAG.getUNDEF(NVT); // extract_vector_elt (build_vector x, y), 1 -> y @@ -15242,7 +15493,7 @@ // If only EXTRACT_VECTOR_ELT nodes use the source vector we can // simplify it based on the (valid) extraction indices. - if (llvm::all_of(InVec->uses(), [&](SDNode *Use) { + if (!VT.isScalableVector() && llvm::all_of(InVec->uses(), [&](SDNode *Use) { return Use->getOpcode() == ISD::EXTRACT_VECTOR_ELT && Use->getOperand(0) == InVec && isa(Use->getOperand(1)); @@ -15559,6 +15810,10 @@ EVT InVT1 = VecIn1.getValueType(); EVT InVT2 = VecIn2.getNode() ? VecIn2.getValueType() : InVT1; + if (InVT1.isScalableVector() || InVT2.isScalableVector() || + VT.isScalableVector()) + return SDValue(); + unsigned Vec2Offset = 0; unsigned NumElems = VT.getVectorNumElements(); unsigned ShuffleNumElems = NumElems; @@ -17266,46 +17521,23 @@ EVT VT = N->getValueType(0); // Replace a SCALAR_TO_VECTOR(EXTRACT_VECTOR_ELT(V,C0)) pattern - // with a VECTOR_SHUFFLE and possible truncate. + // with a VECTOR_SHUFFLE. if (InVal.getOpcode() == ISD::EXTRACT_VECTOR_ELT) { SDValue InVec = InVal->getOperand(0); SDValue EltNo = InVal->getOperand(1); - auto InVecT = InVec.getValueType(); - if (ConstantSDNode *C0 = dyn_cast(EltNo)) { - SmallVector NewMask(InVecT.getVectorNumElements(), -1); + + // FIXME: We could support implicit truncation if the shuffle can be + // scaled to a smaller vector scalar type. + ConstantSDNode *C0 = dyn_cast(EltNo); + if (C0 && VT == InVec.getValueType() && + VT.getScalarType() == InVal.getValueType()) { + SmallVector NewMask(VT.getVectorNumElements(), -1); int Elt = C0->getZExtValue(); NewMask[0] = Elt; - SDValue Val; - // If we have an implict truncate do truncate here as long as it's legal. 
- // if it's not legal, this should - if (VT.getScalarType() != InVal.getValueType() && - InVal.getValueType().isScalarInteger() && - isTypeLegal(VT.getScalarType())) { - Val = - DAG.getNode(ISD::TRUNCATE, SDLoc(InVal), VT.getScalarType(), InVal); - return DAG.getNode(ISD::SCALAR_TO_VECTOR, SDLoc(N), VT, Val); - } - if (VT.getScalarType() == InVecT.getScalarType() && - VT.getVectorNumElements() <= InVecT.getVectorNumElements() && - TLI.isShuffleMaskLegal(NewMask, VT)) { - Val = DAG.getVectorShuffle(InVecT, SDLoc(N), InVec, - DAG.getUNDEF(InVecT), NewMask); - // If the initial vector is the correct size this shuffle is a - // valid result. - if (VT == InVecT) - return Val; - // If not we must truncate the vector. - if (VT.getVectorNumElements() != InVecT.getVectorNumElements()) { - MVT IdxTy = TLI.getVectorIdxTy(DAG.getDataLayout()); - SDValue ZeroIdx = DAG.getConstant(0, SDLoc(N), IdxTy); - EVT SubVT = - EVT::getVectorVT(*DAG.getContext(), InVecT.getVectorElementType(), - VT.getVectorNumElements()); - Val = DAG.getNode(ISD::EXTRACT_SUBVECTOR, SDLoc(N), SubVT, Val, - ZeroIdx); - return Val; - } - } + + if (TLI.isShuffleMaskLegal(NewMask, VT)) + return DAG.getVectorShuffle(VT, SDLoc(N), InVec, DAG.getUNDEF(VT), + NewMask); } } @@ -17438,6 +17670,24 @@ return SDValue(); } +SDValue DAGCombiner::visitVECREDUCE(SDNode *N) { + SDValue N0 = N->getOperand(0); + EVT VT = N0.getValueType(); + + // VECREDUCE over 1-element vector is just an extract. + if (VT.getVectorNumElements() == 1) { + SDLoc dl(N); + SDValue Res = DAG.getNode( + ISD::EXTRACT_VECTOR_ELT, dl, VT.getVectorElementType(), N0, + DAG.getConstant(0, dl, TLI.getVectorIdxTy(DAG.getDataLayout()))); + if (Res.getValueType() != N->getValueType(0)) + Res = DAG.getNode(ISD::ANY_EXTEND, dl, N->getValueType(0), Res); + return Res; + } + + return SDValue(); +} + /// Returns a vector_shuffle if it able to transform an AND to a vector_shuffle /// with the destination vector and a zero vector. /// e.g. AND V, <0xffffffff, 0, 0xffffffff, 0>. ==> Index: lib/CodeGen/SelectionDAG/FunctionLoweringInfo.cpp =================================================================== --- lib/CodeGen/SelectionDAG/FunctionLoweringInfo.cpp +++ lib/CodeGen/SelectionDAG/FunctionLoweringInfo.cpp @@ -153,11 +153,15 @@ FrameIndex = MF->getFrameInfo().CreateFixedObject( TySize, 0, /*Immutable=*/false, /*isAliased=*/true); MF->getFrameInfo().setObjectAlignment(FrameIndex, Align); - } else { + } else FrameIndex = MF->getFrameInfo().CreateStackObject(TySize, Align, false, AI); - } + // Let the target determine to which StackID this type should be + // allocated. + MF->getFrameInfo().setStackID(FrameIndex, TFI->getStackIDForType(Ty)); + + // Support for array like structs. StaticAllocaMap[AI] = FrameIndex; // Update the catch handler information. if (Iter != CatchObjects.end()) { Index: lib/CodeGen/SelectionDAG/LegalizeDAG.cpp =================================================================== --- lib/CodeGen/SelectionDAG/LegalizeDAG.cpp +++ lib/CodeGen/SelectionDAG/LegalizeDAG.cpp @@ -353,6 +353,9 @@ SDValue Tmp2 = Val; SDValue Tmp3 = Idx; + assert(!Vec.getValueType().isScalableVector() && + "This code does not yet implement VL scaled spills/fills"); + // If the target doesn't support this, we have to spill the input vector // to a temporary stack slot, update the element, then reload it. This is // badness. 
We could also load the value into a vector register (either @@ -1090,6 +1093,9 @@ return; } break; + case ISD::VSCALE: + Action = TLI.getOperationAction(Node->getOpcode(), Node->getValueType(0)); + break; case ISD::STRICT_FADD: case ISD::STRICT_FSUB: case ISD::STRICT_FMUL: @@ -1114,6 +1120,30 @@ Action = TLI.getStrictFPOperationAction(Node->getOpcode(), Node->getValueType(0)); break; + case ISD::MSCATTER: + Action = TLI.getOperationAction(Node->getOpcode(), + cast(Node)->getValue().getValueType()); + break; + case ISD::MSTORE: + Action = TLI.getOperationAction(Node->getOpcode(), + cast(Node)->getValue().getValueType()); + break; + case ISD::VECREDUCE_FADD: + case ISD::VECREDUCE_FMUL: + case ISD::VECREDUCE_ADD: + case ISD::VECREDUCE_MUL: + case ISD::VECREDUCE_AND: + case ISD::VECREDUCE_OR: + case ISD::VECREDUCE_XOR: + case ISD::VECREDUCE_SMAX: + case ISD::VECREDUCE_SMIN: + case ISD::VECREDUCE_UMAX: + case ISD::VECREDUCE_UMIN: + case ISD::VECREDUCE_FMAX: + case ISD::VECREDUCE_FMIN: + Action = TLI.getOperationAction( + Node->getOpcode(), Node->getOperand(0).getValueType()); + break; default: if (Node->getOpcode() >= ISD::BUILTIN_OP_END) { Action = TargetLowering::Legal; @@ -1233,6 +1263,7 @@ } SDValue SelectionDAGLegalize::ExpandExtractFromVectorThroughStack(SDValue Op) { + assert(!Op.getValueType().isScalableVector() && "WA not yet supported!"); SDValue Vec = Op.getOperand(0); SDValue Idx = Op.getOperand(1); SDLoc dl(Op); @@ -1313,6 +1344,7 @@ SDValue SelectionDAGLegalize::ExpandInsertToVectorThroughStack(SDValue Op) { assert(Op.getValueType().isVector() && "Non-vector insert subvector!"); + assert(!Op.getValueType().isScalableVector() && "WA not yet supported!"); SDValue Vec = Op.getOperand(0); SDValue Part = Op.getOperand(1); @@ -1340,6 +1372,8 @@ } SDValue SelectionDAGLegalize::ExpandVectorBuildThroughStack(SDNode* Node) { + assert(!Node->getValueType(0).isScalableVector() && + "WA not yet supported!"); // We can't handle this case efficiently. Allocate a sufficiently // aligned object on the stack, store each element into it, then load // the result as a vector. @@ -3939,6 +3973,21 @@ break; } + case ISD::VECREDUCE_FADD: + case ISD::VECREDUCE_FMUL: + case ISD::VECREDUCE_ADD: + case ISD::VECREDUCE_MUL: + case ISD::VECREDUCE_AND: + case ISD::VECREDUCE_OR: + case ISD::VECREDUCE_XOR: + case ISD::VECREDUCE_SMAX: + case ISD::VECREDUCE_SMIN: + case ISD::VECREDUCE_UMAX: + case ISD::VECREDUCE_UMIN: + case ISD::VECREDUCE_FMAX: + case ISD::VECREDUCE_FMIN: + Results.push_back(TLI.expandVecReduce(Node, DAG)); + break; case ISD::GLOBAL_OFFSET_TABLE: case ISD::GlobalAddress: case ISD::GlobalTLSAddress: @@ -4597,6 +4646,35 @@ MVT EltVT = OVT.getVectorElementType(); MVT NewEltVT = NVT.getVectorElementType(); + SDValue Vec = Node->getOperand(0); + SDValue Idx = Node->getOperand(1); + SDLoc SL(Node); + + if (OVT.getSizeInBits() != NVT.getSizeInBits()) { + assert(NVT.isVector() && + OVT.getVectorNumElements() == NVT.getVectorNumElements() && + "Invalid promote type for extract."); + assert(NewEltVT.bitsGT(EltVT) && "Expected promoted element type."); + + Vec = DAG.getNode(ISD::ZERO_EXTEND, SL, NVT, Vec); + + // Result type is type of this node. If the element we extract is wider, + // we need to truncate it. In the opposite case, e.g. extracting an i8, + // we leave the zero_extend to the instruction itself. 
+ EVT ResVT = Node->getValueType(0); + if (NewEltVT.bitsGT(ResVT.getSimpleVT())) { + SDValue Ex = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SL, + NewEltVT, Vec, Idx); + SDValue Res = DAG.getNode(ISD::TRUNCATE, SL, ResVT, Ex); + Results.push_back(Res); + } else { + SDValue Res = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SL, ResVT, Vec, Idx); + Results.push_back(Res); + } + + break; + } + // Handle bitcasts to a different vector type with the same total bit size. // // e.g. v2i64 = extract_vector_elt x:v2i64, y:i32 @@ -4615,13 +4693,11 @@ MVT MidVT = getPromotedVectorElementType(TLI, EltVT, NewEltVT); unsigned NewEltsPerOldElt = MidVT.getVectorNumElements(); - SDValue Idx = Node->getOperand(1); EVT IdxVT = Idx.getValueType(); - SDLoc SL(Node); SDValue Factor = DAG.getConstant(NewEltsPerOldElt, SL, IdxVT); SDValue NewBaseIdx = DAG.getNode(ISD::MUL, SL, IdxVT, Idx, Factor); - SDValue CastVec = DAG.getNode(ISD::BITCAST, SL, NVT, Node->getOperand(0)); + SDValue CastVec = DAG.getNode(ISD::BITCAST, SL, NVT, Vec); SmallVector NewOps; for (unsigned I = 0; I < NewEltsPerOldElt; ++I) { @@ -4641,6 +4717,26 @@ MVT EltVT = OVT.getVectorElementType(); MVT NewEltVT = NVT.getVectorElementType(); + SDValue Vec = Node->getOperand(0); + SDValue Val = Node->getOperand(1); + SDValue Idx = Node->getOperand(2); + SDLoc SL(Node); + + if (OVT.getSizeInBits() != NVT.getSizeInBits()) { + assert(NVT.isVector() && + OVT.getVectorNumElements() == NVT.getVectorNumElements() && + "Invalid promote type for insert_vector_elt."); + assert(NewEltVT.bitsGT(EltVT) && "Expected promoted element type."); + + Vec = DAG.getNode(ISD::ANY_EXTEND, SL, NVT, Vec); + if (Val.getValueType().bitsLT(NewEltVT)) + Val = DAG.getNode(ISD::ANY_EXTEND, SL, NewEltVT, Val); + + SDValue Ins = DAG.getNode(ISD::INSERT_VECTOR_ELT, SL, NVT, Vec, Val, Idx); + Results.push_back(DAG.getNode(ISD::TRUNCATE, SL, OVT, Ins)); + break; + } + // Handle bitcasts to a different vector type with the same total bit size // // e.g. 
v2i64 = insert_vector_elt x:v2i64, y:i64, z:i32 @@ -4661,15 +4757,11 @@ MVT MidVT = getPromotedVectorElementType(TLI, EltVT, NewEltVT); unsigned NewEltsPerOldElt = MidVT.getVectorNumElements(); - SDValue Val = Node->getOperand(1); - SDValue Idx = Node->getOperand(2); EVT IdxVT = Idx.getValueType(); - SDLoc SL(Node); - SDValue Factor = DAG.getConstant(NewEltsPerOldElt, SDLoc(), IdxVT); SDValue NewBaseIdx = DAG.getNode(ISD::MUL, SL, IdxVT, Idx, Factor); - SDValue CastVec = DAG.getNode(ISD::BITCAST, SL, NVT, Node->getOperand(0)); + SDValue CastVec = DAG.getNode(ISD::BITCAST, SL, NVT, Vec); SDValue CastVal = DAG.getNode(ISD::BITCAST, SL, MidVT, Val); SDValue NewVec = CastVec; Index: lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp =================================================================== --- lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp +++ lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp @@ -90,17 +90,26 @@ case ISD::TRUNCATE: Res = PromoteIntRes_TRUNCATE(N); break; case ISD::UNDEF: Res = PromoteIntRes_UNDEF(N); break; case ISD::VAARG: Res = PromoteIntRes_VAARG(N); break; + case ISD::VSCALE: Res = PromoteIntRes_VSCALE(N); break; + case ISD::INSERT_SUBVECTOR: + Res = PromoteIntRes_INSERT_SUBVECTOR(N); break; case ISD::EXTRACT_SUBVECTOR: Res = PromoteIntRes_EXTRACT_SUBVECTOR(N); break; case ISD::VECTOR_SHUFFLE: Res = PromoteIntRes_VECTOR_SHUFFLE(N); break; + case ISD::VECTOR_SHUFFLE_VAR: + Res = PromoteInt_VECTOR_SHUFFLE_VAR(N); break; + case ISD::SERIES_VECTOR: + Res = PromoteInt_SERIES_VECTOR(N); break; case ISD::INSERT_VECTOR_ELT: Res = PromoteIntRes_INSERT_VECTOR_ELT(N); break; case ISD::BUILD_VECTOR: Res = PromoteIntRes_BUILD_VECTOR(N); break; case ISD::SCALAR_TO_VECTOR: Res = PromoteIntRes_SCALAR_TO_VECTOR(N); break; + case ISD::SPLAT_VECTOR: + Res = PromoteIntRes_SPLAT_VECTOR(N); break; case ISD::CONCAT_VECTORS: Res = PromoteIntRes_CONCAT_VECTORS(N); break; @@ -162,6 +171,18 @@ case ISD::ATOMIC_CMP_SWAP_WITH_SUCCESS: Res = PromoteIntRes_AtomicCmpSwap(cast(N), ResNo); break; + + case ISD::VECREDUCE_ADD: + case ISD::VECREDUCE_MUL: + case ISD::VECREDUCE_AND: + case ISD::VECREDUCE_OR: + case ISD::VECREDUCE_XOR: + case ISD::VECREDUCE_SMAX: + case ISD::VECREDUCE_SMIN: + case ISD::VECREDUCE_UMAX: + case ISD::VECREDUCE_UMIN: + Res = PromoteIntRes_VECREDUCE(N); + break; } // If the result is null then the sub-method took care of registering it. @@ -485,10 +506,13 @@ EVT NVT = TLI.getTypeToTransformTo(*DAG.getContext(), N->getValueType(0)); SDValue ExtSrc0 = GetPromotedInteger(N->getSrc0()); + ISD::LoadExtType ExtType = N->getExtensionType(); + if (ExtType == ISD::NON_EXTLOAD) + ExtType = ISD::SEXTLOAD; SDLoc dl(N); SDValue Res = DAG.getMaskedLoad(NVT, dl, N->getChain(), N->getBasePtr(), N->getMask(), ExtSrc0, N->getMemoryVT(), - N->getMemOperand(), ISD::SEXTLOAD); + N->getMemOperand(), ExtType); // Legalize the chain result - switch anything that used the old chain to // use the new one. 
ReplaceValueWith(SDValue(N, 1), Res.getValue(1)); @@ -501,12 +525,22 @@ assert(NVT == ExtSrc0.getValueType() && "Gather result type and the passThru agrument type should be the same"); + ISD::LoadExtType ExtType = N->getExtensionType(); + if (ExtType == ISD::NON_EXTLOAD) + ExtType = ISD::SEXTLOAD; + + SDValue Mask = N->getMask(); + EVT NewMaskVT = getSetCCResultType(NVT); + if (NewMaskVT != N->getMask().getValueType()) + Mask = PromoteTargetBoolean(Mask, NewMaskVT); SDLoc dl(N); SDValue Ops[] = {N->getChain(), ExtSrc0, N->getMask(), N->getBasePtr(), N->getIndex(), N->getScale() }; + SDValue Res = DAG.getMaskedGather(DAG.getVTList(NVT, MVT::Other), N->getMemoryVT(), dl, Ops, - N->getMemOperand()); + N->getMemOperand(), ExtType, + N->getIndexType()); // Legalize the chain result - switch anything that used the old chain to // use the new one. ReplaceValueWith(SDValue(N, 1), Res.getValue(1)); @@ -694,17 +728,17 @@ case TargetLowering::TypeSplitVector: { EVT InVT = InOp.getValueType(); assert(InVT.isVector() && "Cannot split scalar types"); - unsigned NumElts = InVT.getVectorNumElements(); - assert(NumElts == NVT.getVectorNumElements() && + auto EltCnt = InVT.getVectorElementCount(); + assert(EltCnt == NVT.getVectorElementCount() && "Dst and Src must have the same number of elements"); - assert(isPowerOf2_32(NumElts) && + assert(isPowerOf2_32(EltCnt.Min) && "Promoted vector type must be a power of two"); SDValue EOp1, EOp2; GetSplitVector(InOp, EOp1, EOp2); EVT HalfNVT = EVT::getVectorVT(*DAG.getContext(), NVT.getScalarType(), - NumElts/2); + EltCnt/2); EOp1 = DAG.getNode(ISD::TRUNCATE, dl, HalfNVT, EOp1); EOp2 = DAG.getNode(ISD::TRUNCATE, dl, HalfNVT, EOp2); @@ -849,6 +883,13 @@ N->getValueType(0))); } +SDValue DAGTypeLegalizer::PromoteIntRes_VSCALE(SDNode *N) { + EVT VT = TLI.getTypeToTransformTo(*DAG.getContext(), N->getValueType(0)); + + int64_t MulImm = cast(N->getOperand(0))->getSExtValue(); + return DAG.getVScale(SDLoc(N), VT, MulImm); +} + SDValue DAGTypeLegalizer::PromoteIntRes_VAARG(SDNode *N) { SDValue Chain = N->getOperand(0); // Get the chain. SDValue Ptr = N->getOperand(1); // Get the pointer. @@ -930,6 +971,12 @@ Res = PromoteIntOp_INSERT_VECTOR_ELT(N, OpNo);break; case ISD::SCALAR_TO_VECTOR: Res = PromoteIntOp_SCALAR_TO_VECTOR(N); break; + case ISD::SPLAT_VECTOR: + Res = PromoteIntOp_SPLAT_VECTOR(N); break; + case ISD::VECTOR_SHUFFLE_VAR: + Res = PromoteInt_VECTOR_SHUFFLE_VAR(N); break; + case ISD::SERIES_VECTOR: + Res = PromoteInt_SERIES_VECTOR(N); break; case ISD::VSELECT: case ISD::SELECT: Res = PromoteIntOp_SELECT(N, OpNo); break; case ISD::SELECT_CC: Res = PromoteIntOp_SELECT_CC(N, OpNo); break; @@ -960,6 +1007,16 @@ case ISD::ADDCARRY: case ISD::SUBCARRY: Res = PromoteIntOp_ADDSUBCARRY(N, OpNo); break; + + case ISD::VECREDUCE_ADD: + case ISD::VECREDUCE_MUL: + case ISD::VECREDUCE_AND: + case ISD::VECREDUCE_OR: + case ISD::VECREDUCE_XOR: + case ISD::VECREDUCE_SMAX: + case ISD::VECREDUCE_SMIN: + case ISD::VECREDUCE_UMAX: + case ISD::VECREDUCE_UMIN: Res = PromoteIntOp_VECREDUCE(N); break; } // If the result is null, the sub-method took care of registering results etc. 
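Aside on the VECREDUCE_* operand-promotion cases registered above: the handler added later in this hunk (PromoteIntOp_VECREDUCE) sign-extends the input for SMAX/SMIN, zero-extends it for UMAX/UMIN, and uses a plain promotion for ADD/MUL/AND/OR/XOR. A small standalone C++ sketch of why the extension kind matters when i8 lanes are promoted to i32; the values are illustrative and this is not the LLVM code itself.

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  std::vector<int8_t> V = {-4, 7, -100, 23};

  // Signed reductions (SMAX/SMIN) need sign-extended lanes.
  std::vector<int32_t> SExt(V.begin(), V.end());
  assert(*std::max_element(SExt.begin(), SExt.end()) == 23);

  // Unsigned reductions (UMAX/UMIN) need zero-extended lanes.
  std::vector<uint32_t> ZExt;
  for (int8_t X : V)
    ZExt.push_back(static_cast<uint8_t>(X)); // 0..255, no smeared sign bits
  assert(*std::max_element(ZExt.begin(), ZExt.end()) == 252); // 0xFC == -4

  // For ADD (and MUL/AND/OR/XOR) only the low 8 bits of the result matter,
  // so any extension of the inputs gives the same truncated answer.
  int32_t Sum = 0;
  for (int32_t X : SExt)
    Sum += X;
  assert(static_cast<int8_t>(Sum) == static_cast<int8_t>(-4 + 7 - 100 + 23));
  return 0;
}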
@@ -1140,6 +1197,11 @@ GetPromotedInteger(N->getOperand(0))), 0); } +SDValue DAGTypeLegalizer::PromoteIntOp_SPLAT_VECTOR(SDNode *N) { + return SDValue(DAG.UpdateNodeOperands(N, + GetPromotedInteger(N->getOperand(0))), 0); +} + SDValue DAGTypeLegalizer::PromoteIntOp_SELECT(SDNode *N, unsigned OpNo) { assert(OpNo == 0 && "Only know how to promote the condition!"); SDValue Cond = N->getOperand(0); @@ -1211,14 +1273,14 @@ } SDValue DAGTypeLegalizer::PromoteIntOp_MSTORE(MaskedStoreSDNode *N, - unsigned OpNo) { - + unsigned OpNo){ + assert((OpNo != 1) && "Not expecting to promote the base pointer!"); SDValue DataOp = N->getValue(); EVT DataVT = DataOp.getValueType(); SDValue Mask = N->getMask(); SDLoc dl(N); - bool TruncateStore = false; + bool TruncateStore = N->isTruncatingStore(); if (OpNo == 2) { // Mask comes before the data operand. If the data operand is legal, we just // promote the mask. @@ -1250,16 +1312,6 @@ TruncateStore, N->isCompressingStore()); } -SDValue DAGTypeLegalizer::PromoteIntOp_MLOAD(MaskedLoadSDNode *N, - unsigned OpNo) { - assert(OpNo == 2 && "Only know how to promote the mask!"); - EVT DataVT = N->getValueType(0); - SDValue Mask = PromoteTargetBoolean(N->getOperand(OpNo), DataVT); - SmallVector NewOps(N->op_begin(), N->op_end()); - NewOps[OpNo] = Mask; - return SDValue(DAG.UpdateNodeOperands(N, NewOps), 0); -} - SDValue DAGTypeLegalizer::PromoteIntOp_MGATHER(MaskedGatherSDNode *N, unsigned OpNo) { @@ -1286,6 +1338,7 @@ SDValue DAGTypeLegalizer::PromoteIntOp_MSCATTER(MaskedScatterSDNode *N, unsigned OpNo) { + bool TruncateStore = N->isTruncatingStore(); SmallVector NewOps(N->op_begin(), N->op_end()); if (OpNo == 2) { // The Mask @@ -1294,8 +1347,23 @@ } else if (OpNo == 4) { // Need to sign extend the index since the bits will likely be used. NewOps[OpNo] = SExtPromotedInteger(N->getOperand(OpNo)); - } else + } else { NewOps[OpNo] = GetPromotedInteger(N->getOperand(OpNo)); + TruncateStore = true; + } + + return DAG.getMaskedScatter(DAG.getVTList(MVT::Other), N->getMemoryVT(), + SDLoc(N), NewOps, N->getMemOperand(), + TruncateStore, N->getIndexType()); +} + +SDValue DAGTypeLegalizer::PromoteIntOp_MLOAD(MaskedLoadSDNode *N, + unsigned OpNo) { + assert(OpNo == 2 && "Only know how to promote the mask!"); + EVT DataVT = N->getValueType(0); + SDValue Mask = PromoteTargetBoolean(N->getOperand(OpNo), DataVT); + SmallVector NewOps(N->op_begin(), N->op_end()); + NewOps[OpNo] = Mask; return SDValue(DAG.UpdateNodeOperands(N, NewOps), 0); } @@ -1342,6 +1410,39 @@ return SDValue(DAG.UpdateNodeOperands(N, LHS, RHS, Carry), 0); } +SDValue DAGTypeLegalizer::PromoteIntOp_VECREDUCE(SDNode *N) { + SDLoc dl(N); + SDValue Op; + switch (N->getOpcode()) { + default: llvm_unreachable("Expected integer vector reduction"); + case ISD::VECREDUCE_ADD: + case ISD::VECREDUCE_MUL: + case ISD::VECREDUCE_AND: + case ISD::VECREDUCE_OR: + case ISD::VECREDUCE_XOR: + Op = GetPromotedInteger(N->getOperand(0)); + break; + case ISD::VECREDUCE_SMAX: + case ISD::VECREDUCE_SMIN: + Op = SExtPromotedInteger(N->getOperand(0)); + break; + case ISD::VECREDUCE_UMAX: + case ISD::VECREDUCE_UMIN: + Op = ZExtPromotedInteger(N->getOperand(0)); + break; + } + + EVT EltVT = Op.getValueType().getVectorElementType(); + EVT VT = N->getValueType(0); + if (VT.bitsGE(EltVT)) + return DAG.getNode(N->getOpcode(), SDLoc(N), VT, Op); + + // Result size must be >= element size. If this is not the case after + // promotion, also promote the result type and then truncate. 
+ SDValue Reduce = DAG.getNode(N->getOpcode(), dl, EltVT, Op); + return DAG.getNode(ISD::TRUNCATE, dl, VT, Reduce); +} + //===----------------------------------------------------------------------===// // Integer Result Expansion //===----------------------------------------------------------------------===// @@ -1475,6 +1576,16 @@ case ISD::USUBO: ExpandIntRes_UADDSUBO(N, Lo, Hi); break; case ISD::UMULO: case ISD::SMULO: ExpandIntRes_XMULO(N, Lo, Hi); break; + + case ISD::VECREDUCE_ADD: + case ISD::VECREDUCE_MUL: + case ISD::VECREDUCE_AND: + case ISD::VECREDUCE_OR: + case ISD::VECREDUCE_XOR: + case ISD::VECREDUCE_SMAX: + case ISD::VECREDUCE_SMIN: + case ISD::VECREDUCE_UMAX: + case ISD::VECREDUCE_UMIN: ExpandIntRes_VECREDUCE(N, Lo, Hi); break; } // If Lo/Hi is null, the sub-method took care of registering results etc. @@ -2881,6 +2992,14 @@ ReplaceValueWith(SDValue(N, 1), Swap.getValue(2)); } +void DAGTypeLegalizer::ExpandIntRes_VECREDUCE(SDNode *N, + SDValue &Lo, SDValue &Hi) { + // TODO For VECREDUCE_(AND|OR|XOR) we could split the vector and calculate + // both halves independently. + SDValue Res = TLI.expandVecReduce(N, DAG); + SplitInteger(Res, Lo, Hi); +} + //===----------------------------------------------------------------------===// // Integer Operand Expansion //===----------------------------------------------------------------------===// @@ -3367,6 +3486,38 @@ } +SDValue DAGTypeLegalizer::PromoteIntRes_INSERT_SUBVECTOR(SDNode *N) { + SDLoc dl(N); + // The first operand has the same type as the result, so must also be + // promoted. The second operand is narrower and so is not necessarily + // handled in the same way. First deal with the simple case in which + // both vectors are promoted. + SDValue Op0 = GetPromotedInteger(N->getOperand(0)); + SDValue Op1 = N->getOperand(1); + if (getTypeAction(Op1.getValueType()) == TargetLowering::TypePromoteInteger) { + Op1 = GetPromotedInteger(Op1); + return DAG.getNode(ISD::INSERT_SUBVECTOR, dl, Op0.getValueType(), + Op0, Op1, N->getOperand(2)); + } + // Handle the more difficult case of a subvector that isn't legalized + // by promotion. The safest fallback is to promote each element + // individually, as for EXTRACT_SUBVECTOR below. + EVT EltVT0 = Op0.getValueType().getVectorElementType(); + EVT EltVT1 = Op1.getValueType().getVectorElementType(); + SDValue BaseIdx = N->getOperand(2); + unsigned NumElts = Op1.getValueType().getVectorNumElements(); + for (unsigned i = 0; i != NumElts; ++i) { + SDValue Index = DAG.getNode(ISD::ADD, dl, BaseIdx.getValueType(), BaseIdx, + DAG.getConstant(i, dl, BaseIdx.getValueType())); + SDValue Ext = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT1, Op1, Index); + Ext = DAG.getNode(ISD::ANY_EXTEND, dl, EltVT0, Ext); + Op0 = DAG.getNode(ISD::INSERT_VECTOR_ELT, dl, Op0.getValueType(), + Op0, Ext, Index); + } + return Op0; +} + + SDValue DAGTypeLegalizer::PromoteIntRes_EXTRACT_SUBVECTOR(SDNode *N) { SDValue InOp0 = N->getOperand(0); EVT InVT = InOp0.getValueType(); @@ -3380,6 +3531,24 @@ SDLoc dl(N); SDValue BaseIdx = N->getOperand(1); + // NOTE: There is nothing particularly "scalable" about this code but it + // causes pain when merging because many of the tests assume BUILD_VECTOR will + // be used to create a whole new vector. For now it is less work to make this + // change SVE only than to continually fix up the tests. 
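Aside on the PromoteIntRes_INSERT_SUBVECTOR fallback above: when the subvector is not itself legalized by promotion, each of its elements is extracted, extended to the promoted element type, and inserted at BaseIdx + i. A standalone C++ sketch of that loop under simplified assumptions (fixed-width std::vector stand-ins, sign-extension standing in for ANY_EXTEND); it is an illustration, not the LLVM implementation.

#include <cassert>
#include <cstdint>
#include <vector>

// Fixed-width model of the element-by-element INSERT_SUBVECTOR fallback:
// every element of the narrow subvector is extended to the promoted element
// type and inserted at BaseIdx + I.
static std::vector<int32_t> insertSubvector(std::vector<int32_t> Dst,
                                            const std::vector<int8_t> &Sub,
                                            unsigned BaseIdx) {
  for (unsigned I = 0; I != Sub.size(); ++I)
    Dst[BaseIdx + I] = Sub[I]; // extend (ANY_EXTEND in the DAG) then insert
  return Dst;
}

int main() {
  std::vector<int32_t> Dst = {10, 11, 12, 13, 14, 15};
  std::vector<int8_t> Sub = {-1, -2};
  std::vector<int32_t> Res = insertSubvector(Dst, Sub, 2);
  assert(Res[2] == -1 && Res[3] == -2 && Res[0] == 10 && Res[5] == 15);
  return 0;
}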
+ if (OutVT.isScalableVector()) { + // Promote operands and see if this is handled by target lowering, + // Otherwise, use the BUILD_VECTOR approach below + if (getTypeAction(InVT) == TargetLowering::TypePromoteInteger) { + // Collect the (promoted) operands + SDValue Ops[] = { GetPromotedInteger(InOp0), BaseIdx }; + + EVT PromEltVT = Ops[0].getValueType().getVectorElementType(); + EVT ExtVT = NOutVT.changeVectorElementType(PromEltVT); + SDValue Ext = DAG.getNode(ISD::EXTRACT_SUBVECTOR, SDLoc(N), ExtVT, Ops); + return DAG.getNode(ISD::ANY_EXTEND, dl, NOutVT, Ext); + } + } + SmallVector Ops; Ops.reserve(OutNumElems); for (unsigned i = 0; i != OutNumElems; ++i) { @@ -3414,6 +3583,42 @@ } +// Handle VECTOR_SHUFFLE_VARs in which the output and/or mask need promotion. +// The first two inputs have the same type as the output and so require +// promotion iff the output does. +SDValue DAGTypeLegalizer::PromoteInt_VECTOR_SHUFFLE_VAR(SDNode *N) { + SDLoc dl(N); + + SDValue V0 = N->getOperand(0); + SDValue V1 = N->getOperand(1); + SDValue Mask = N->getOperand(2); + if (getTypeAction(V0.getValueType()) == TargetLowering::TypePromoteInteger) { + V0 = GetPromotedInteger(V0); + V1 = GetPromotedInteger(V1); + } + if (getTypeAction(Mask.getValueType()) == TargetLowering::TypePromoteInteger) + Mask = ZExtPromotedInteger(Mask); + return DAG.getNode(ISD::VECTOR_SHUFFLE_VAR, dl, V0.getValueType(), + V0, V1, Mask); +} + +SDValue DAGTypeLegalizer::PromoteInt_SERIES_VECTOR(SDNode *N) { + SDLoc dl(N); + + SDValue Initial = N->getOperand(0); + SDValue Step = N->getOperand(1); + EVT VT = TLI.getTypeToTransformTo(*DAG.getContext(), N->getValueType(0)); + + assert(Initial.getValueType() == Step.getValueType()); + + if (getTypeAction(Initial.getValueType()) == TargetLowering::TypePromoteInteger) { + Initial = GetPromotedInteger(Initial); + Step = GetPromotedInteger(Step); + } + + return DAG.getNode(ISD::SERIES_VECTOR, dl, VT, Initial, Step); +} + SDValue DAGTypeLegalizer::PromoteIntRes_BUILD_VECTOR(SDNode *N) { EVT OutVT = N->getValueType(0); EVT NOutVT = TLI.getTypeToTransformTo(*DAG.getContext(), OutVT); @@ -3458,6 +3663,23 @@ return DAG.getNode(ISD::SCALAR_TO_VECTOR, dl, NOutVT, Op); } +SDValue DAGTypeLegalizer::PromoteIntRes_SPLAT_VECTOR(SDNode *N) { + SDLoc dl(N); + + SDValue SplatVal = N->getOperand(0); + + assert(!SplatVal.getValueType().isVector() && "Input must be a scalar"); + + EVT OutVT = N->getValueType(0); + EVT NOutVT = TLI.getTypeToTransformTo(*DAG.getContext(), OutVT); + assert(NOutVT.isVector() && "Type must be promoted to a vector type"); + EVT NOutElemVT = NOutVT.getVectorElementType(); + + SDValue Op = DAG.getNode(ISD::ANY_EXTEND, dl, NOutElemVT, SplatVal); + + return DAG.getNode(ISD::SPLAT_VECTOR, dl, NOutVT, Op); +} + SDValue DAGTypeLegalizer::PromoteIntRes_CONCAT_VECTORS(SDNode *N) { SDLoc dl(N); @@ -3545,6 +3767,14 @@ V0, ConvElem, N->getOperand(2)); } +SDValue DAGTypeLegalizer::PromoteIntRes_VECREDUCE(SDNode *N) { + // The VECREDUCE result size may be larger than the element size, so + // we can simply change the result type. 
+ SDLoc dl(N); + EVT NVT = TLI.getTypeToTransformTo(*DAG.getContext(), N->getValueType(0)); + return DAG.getNode(N->getOpcode(), dl, NVT, N->getOperand(0)); +} + SDValue DAGTypeLegalizer::PromoteIntOp_EXTRACT_VECTOR_ELT(SDNode *N) { SDLoc dl(N); SDValue V0 = GetPromotedInteger(N->getOperand(0)); Index: lib/CodeGen/SelectionDAG/LegalizeTypes.h =================================================================== --- lib/CodeGen/SelectionDAG/LegalizeTypes.h +++ lib/CodeGen/SelectionDAG/LegalizeTypes.h @@ -289,10 +289,12 @@ SDValue PromoteIntRes_Atomic0(AtomicSDNode *N); SDValue PromoteIntRes_Atomic1(AtomicSDNode *N); SDValue PromoteIntRes_AtomicCmpSwap(AtomicSDNode *N, unsigned ResNo); + SDValue PromoteIntRes_INSERT_SUBVECTOR(SDNode *N); SDValue PromoteIntRes_EXTRACT_SUBVECTOR(SDNode *N); SDValue PromoteIntRes_VECTOR_SHUFFLE(SDNode *N); SDValue PromoteIntRes_BUILD_VECTOR(SDNode *N); SDValue PromoteIntRes_SCALAR_TO_VECTOR(SDNode *N); + SDValue PromoteIntRes_SPLAT_VECTOR(SDNode *N); SDValue PromoteIntRes_EXTEND_VECTOR_INREG(SDNode *N); SDValue PromoteIntRes_INSERT_VECTOR_ELT(SDNode *N); SDValue PromoteIntRes_CONCAT_VECTORS(SDNode *N); @@ -329,7 +331,12 @@ SDValue PromoteIntRes_ADDSUBCARRY(SDNode *N, unsigned ResNo); SDValue PromoteIntRes_UNDEF(SDNode *N); SDValue PromoteIntRes_VAARG(SDNode *N); + SDValue PromoteIntRes_VSCALE(SDNode *N); SDValue PromoteIntRes_XMULO(SDNode *N, unsigned ResNo); + // SDValue PromoteIntRes_ADDSUBSAT(SDNode *N); + // SDValue PromoteIntRes_MULFIX(SDNode *N); + // SDValue PromoteIntRes_FLT_ROUNDS(SDNode *N); + SDValue PromoteIntRes_VECREDUCE(SDNode *N); // Integer Operand Promotion. bool PromoteIntegerOperand(SDNode *N, unsigned OpNo); @@ -345,6 +352,7 @@ SDValue PromoteIntOp_EXTRACT_SUBVECTOR(SDNode *N); SDValue PromoteIntOp_CONCAT_VECTORS(SDNode *N); SDValue PromoteIntOp_SCALAR_TO_VECTOR(SDNode *N); + SDValue PromoteIntOp_SPLAT_VECTOR(SDNode *N); SDValue PromoteIntOp_SELECT(SDNode *N, unsigned OpNo); SDValue PromoteIntOp_SELECT_CC(SDNode *N, unsigned OpNo); SDValue PromoteIntOp_SETCC(SDNode *N, unsigned OpNo); @@ -361,6 +369,10 @@ SDValue PromoteIntOp_MGATHER(MaskedGatherSDNode *N, unsigned OpNo); SDValue PromoteIntOp_ADDSUBCARRY(SDNode *N, unsigned OpNo); + SDValue PromoteInt_VECTOR_SHUFFLE_VAR(SDNode *N); + SDValue PromoteInt_SERIES_VECTOR(SDNode *N); + SDValue PromoteIntOp_VECREDUCE(SDNode *N); + void PromoteSetCCOperands(SDValue &LHS,SDValue &RHS, ISD::CondCode Code); //===--------------------------------------------------------------------===// @@ -416,6 +428,7 @@ void ExpandIntRes_XMULO (SDNode *N, SDValue &Lo, SDValue &Hi); void ExpandIntRes_ATOMIC_LOAD (SDNode *N, SDValue &Lo, SDValue &Hi); + void ExpandIntRes_VECREDUCE (SDNode *N, SDValue &Lo, SDValue &Hi); void ExpandShiftByConstant(SDNode *N, const APInt &Amt, SDValue &Lo, SDValue &Hi); @@ -678,6 +691,9 @@ SDValue ScalarizeVecOp_VSETCC(SDNode *N); SDValue ScalarizeVecOp_STORE(StoreSDNode *N, unsigned OpNo); SDValue ScalarizeVecOp_FP_ROUND(SDNode *N, unsigned OpNo); + SDValue ScalarizeVecOp_VECREDUCE(SDNode *N); + + SDValue ScalarizeVec_VECTOR_SHUFFLE_VAR(SDNode *N); //===--------------------------------------------------------------------===// // Vector Splitting Support: LegalizeVectorTypes.cpp @@ -715,9 +731,12 @@ void SplitVecRes_MLOAD(MaskedLoadSDNode *MLD, SDValue &Lo, SDValue &Hi); void SplitVecRes_MGATHER(MaskedGatherSDNode *MGT, SDValue &Lo, SDValue &Hi); void SplitVecRes_SCALAR_TO_VECTOR(SDNode *N, SDValue &Lo, SDValue &Hi); + void SplitVecRes_SERIES_VECTOR(SDNode *N, SDValue &Lo, SDValue &Hi); + 
void SplitVecRes_SPLAT_VECTOR(SDNode *N, SDValue &Lo, SDValue &Hi); void SplitVecRes_SETCC(SDNode *N, SDValue &Lo, SDValue &Hi); void SplitVecRes_VECTOR_SHUFFLE(ShuffleVectorSDNode *N, SDValue &Lo, SDValue &Hi); + void SplitVecRes_VECTOR_SHUFFLE_VAR(SDNode *N, SDValue &Lo, SDValue &Hi); // Vector Operand Splitting: <128 x ty> -> 2 x <64 x ty>. bool SplitVectorOperand(SDNode *N, unsigned OpNo); @@ -735,6 +754,7 @@ SDValue SplitVecOp_MSCATTER(MaskedScatterSDNode *N, unsigned OpNo); SDValue SplitVecOp_MGATHER(MaskedGatherSDNode *MGT, unsigned OpNo); SDValue SplitVecOp_CONCAT_VECTORS(SDNode *N); + SDValue SplitVecOp_TRUNCATE(SDNode *N); SDValue SplitVecOp_VSETCC(SDNode *N); SDValue SplitVecOp_FP_ROUND(SDNode *N); SDValue SplitVecOp_FCOPYSIGN(SDNode *N); @@ -801,6 +821,7 @@ SDValue WidenVecOp_Convert(SDNode *N); SDValue WidenVecOp_FCOPYSIGN(SDNode *N); + SDValue WidenVecOp_VECREDUCE(SDNode *N); //===--------------------------------------------------------------------===// // Vector Widening Utilities Support: LegalizeVectorTypes.cpp Index: lib/CodeGen/SelectionDAG/LegalizeTypes.cpp =================================================================== --- lib/CodeGen/SelectionDAG/LegalizeTypes.cpp +++ lib/CodeGen/SelectionDAG/LegalizeTypes.cpp @@ -19,6 +19,7 @@ #include "llvm/CodeGen/MachineFunction.h" #include "llvm/IR/CallingConv.h" #include "llvm/IR/DataLayout.h" +#include "llvm/IR/DebugInfoMetadata.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/raw_ostream.h" @@ -833,6 +834,8 @@ Op.getValueType().getVectorElementType() && 2*Lo.getValueType().getVectorNumElements() == Op.getValueType().getVectorNumElements() && + Op.getValueType().isScalableVector() == + Lo.getValueType().isScalableVector() && Hi.getValueType() == Lo.getValueType() && "Invalid type for split vector"); // Lo/Hi may have been newly allocated, if so, add nodeid's as relevant. @@ -881,6 +884,8 @@ SDValue DAGTypeLegalizer::CreateStackStoreLoad(SDValue Op, EVT DestVT) { + assert(!DestVT.isScalableVector() && "Can't store-load scalable vector"); + SDLoc dl(Op); // Create the stack frame object. Make sure it is aligned for both // the source and destination types. Index: lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp =================================================================== --- lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp +++ lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp @@ -130,6 +130,8 @@ SDValue ExpandBITREVERSE(SDValue Op); SDValue ExpandCTLZ(SDValue Op); SDValue ExpandCTTZ_ZERO_UNDEF(SDValue Op); + SDValue ExpandVECTOR_SHUFFLE_VAR(SDValue Op); + SDValue ExpandStrictFPOp(SDValue Op); /// Implements vector promotion. 
@@ -283,11 +285,13 @@ } else if (Op.getOpcode() == ISD::MSCATTER || Op.getOpcode() == ISD::MSTORE) HasVectorValue = true; - for (SDNode::value_iterator J = Node->value_begin(), E = Node->value_end(); - J != E; - ++J) - HasVectorValue |= J->isVector(); - if (!HasVectorValue) + bool HasVectorValueOrOp = false; + for (auto J = Node->value_begin(), E = Node->value_end(); J != E; ++J) + HasVectorValueOrOp |= J->isVector(); + for (const SDValue &Op : Node->op_values()) + HasVectorValueOrOp |= Op.getValueType().isVector(); + + if (!HasVectorValueOrOp) return TranslateLegalizeResults(Op, Result); TargetLowering::LegalizeAction Action = TargetLowering::Legal; @@ -387,6 +391,7 @@ case ISD::ANY_EXTEND_VECTOR_INREG: case ISD::SIGN_EXTEND_VECTOR_INREG: case ISD::ZERO_EXTEND_VECTOR_INREG: + case ISD::VECTOR_SHUFFLE_VAR: case ISD::SMIN: case ISD::SMAX: case ISD::UMIN: @@ -402,6 +407,19 @@ break; case ISD::SINT_TO_FP: case ISD::UINT_TO_FP: + case ISD::VECREDUCE_ADD: + case ISD::VECREDUCE_MUL: + case ISD::VECREDUCE_AND: + case ISD::VECREDUCE_OR: + case ISD::VECREDUCE_XOR: + case ISD::VECREDUCE_SMAX: + case ISD::VECREDUCE_SMIN: + case ISD::VECREDUCE_UMAX: + case ISD::VECREDUCE_UMIN: + case ISD::VECREDUCE_FADD: + case ISD::VECREDUCE_FMUL: + case ISD::VECREDUCE_FMAX: + case ISD::VECREDUCE_FMIN: Action = TLI.getOperationAction(Node->getOpcode(), Node->getOperand(0).getValueType()); break; @@ -704,6 +722,70 @@ return TF; } +SDValue VectorLegalizer::ExpandVECTOR_SHUFFLE_VAR(SDValue Op) { + SDLoc dl(Op); + SDValue Op0 = Op.getOperand(0); + SDValue Op1 = Op.getOperand(1); + SDValue Mask = Op.getOperand(2); + EVT VT = Op.getValueType(); + + if (VT.isScalableVector()) { + unsigned VTNumElts = VT.getVectorNumElements(); + unsigned VTEltSize = VT.getVectorElementType().getSizeInBits(); + unsigned NewVTNumElts; + unsigned NewVTEltSize; + // Loop through vector types looking for vectors with the same number of + // elements as VT but with more bits per element + for (MVT NewVT : MVT::integer_scalable_vector_valuetypes()) { + NewVTNumElts = NewVT.getVectorNumElements(); + NewVTEltSize = NewVT.getVectorElementType().getSizeInBits(); + if (VTNumElts == NewVTNumElts && NewVTEltSize > VTEltSize) { + // Use this expanded vector type if it is legal for VECTOR_SHUFFLE_VAR + if (TLI.isOperationLegalOrCustom(ISD::VECTOR_SHUFFLE_VAR, NewVT) && + TLI.isTypeLegal(NewVT)) { + SDValue ExpandedPred0 = DAG.getNode(ISD::ZERO_EXTEND, dl, NewVT, Op0); + SDValue ExpandedPred1 = DAG.getNode(ISD::ZERO_EXTEND, dl, NewVT, Op1); + SDValue ExpandedShuffle = DAG.getNode(ISD::VECTOR_SHUFFLE_VAR, dl, + NewVT, ExpandedPred0, + ExpandedPred1, Mask); + SDValue TruncatedShuffle = DAG.getNode(ISD::TRUNCATE, dl, + Op.getValueType(), + ExpandedShuffle); + return TruncatedShuffle; + } + } + } + llvm_unreachable("Unable to find legal expanded vector type for shuffle!"); + } + + assert(!VT.isScalableVector() && + "This code can't handle scalable vectors"); + EVT SrcVT = Op0.getValueType(); + EVT MaskVT = Mask.getValueType(); + EVT SrcEltVT = SrcVT.getVectorElementType(); + EVT MaskEltVT = MaskVT.getVectorElementType(); + assert(SrcEltVT == VT.getVectorElementType() && "Mismatched element types"); + unsigned SrcNumElts = SrcVT.getVectorNumElements(); + unsigned MaskNumElts = MaskVT.getVectorNumElements(); + EVT IdxVT = TLI.getVectorIdxTy(DAG.getDataLayout()); + SmallVector Ops(MaskNumElts); + for (unsigned i = 0; i < MaskNumElts; ++i) { + SDValue IN = DAG.getConstant(i, dl, IdxVT); + SDValue Idx = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, MaskEltVT, Mask, IN); 
+ Idx = DAG.getSExtOrTrunc(Idx, dl, IdxVT); + SDValue SafeIdx = DAG.getNode(ISD::UREM, dl, IdxVT, + Idx, DAG.getConstant(SrcNumElts, dl, IdxVT)); + SDValue I0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, + SrcEltVT, Op0, SafeIdx); + SDValue I1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, + SrcEltVT, Op1, SafeIdx); + Ops[i] = DAG.getSelectCC(dl, Idx, SafeIdx, I0, I1, ISD::SETEQ); + } + SDValue NewOp = DAG.getNode(ISD::BUILD_VECTOR, dl, VT, Ops); + AddLegalizedOperand(Op, NewOp); + return NewOp; +} + SDValue VectorLegalizer::Expand(SDValue Op) { switch (Op->getOpcode()) { case ISD::SIGN_EXTEND_INREG: @@ -735,6 +817,9 @@ return ExpandCTLZ(Op); case ISD::CTTZ_ZERO_UNDEF: return ExpandCTTZ_ZERO_UNDEF(Op); + case ISD::VECTOR_SHUFFLE_VAR: + return ExpandVECTOR_SHUFFLE_VAR(Op); + case ISD::STRICT_FADD: case ISD::STRICT_FSUB: case ISD::STRICT_FMUL: @@ -753,6 +838,20 @@ case ISD::STRICT_FRINT: case ISD::STRICT_FNEARBYINT: return ExpandStrictFPOp(Op); + case ISD::VECREDUCE_ADD: + case ISD::VECREDUCE_MUL: + case ISD::VECREDUCE_AND: + case ISD::VECREDUCE_OR: + case ISD::VECREDUCE_XOR: + case ISD::VECREDUCE_SMAX: + case ISD::VECREDUCE_SMIN: + case ISD::VECREDUCE_UMAX: + case ISD::VECREDUCE_UMIN: + case ISD::VECREDUCE_FADD: + case ISD::VECREDUCE_FMUL: + case ISD::VECREDUCE_FMAX: + case ISD::VECREDUCE_FMIN: + return TLI.expandVecReduce(Op.getNode(), DAG); default: return DAG.UnrollVectorOp(Op.getNode()); } Index: lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp =================================================================== --- lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp +++ lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp @@ -64,6 +64,8 @@ case ISD::SETCC: R = ScalarizeVecRes_SETCC(N); break; case ISD::UNDEF: R = ScalarizeVecRes_UNDEF(N); break; case ISD::VECTOR_SHUFFLE: R = ScalarizeVecRes_VECTOR_SHUFFLE(N); break; + case ISD::VECTOR_SHUFFLE_VAR: + R = ScalarizeVec_VECTOR_SHUFFLE_VAR(N); break; case ISD::ANY_EXTEND_VECTOR_INREG: case ISD::SIGN_EXTEND_VECTOR_INREG: case ISD::ZERO_EXTEND_VECTOR_INREG: @@ -402,6 +404,44 @@ return GetScalarizedVector(N->getOperand(Op)); } +// Handle VECTOR_SHUFFLE_VARs in which the output and/or mask need +// to be scalarized. The first two inputs have the same type as the +// output and so require scalarization iff the output does. +SDValue DAGTypeLegalizer::ScalarizeVec_VECTOR_SHUFFLE_VAR(SDNode *N) { + SDLoc dl(N); + SDValue Op0 = N->getOperand(0); + SDValue Op1 = N->getOperand(1); + SDValue Mask = N->getOperand(2); + if (getTypeAction(Op0.getValueType()) == + TargetLowering::TypeScalarizeVector) { + Op0 = GetScalarizedVector(Op0); + Op1 = GetScalarizedVector(Op1); + } + if (getTypeAction(Mask.getValueType()) == TargetLowering::TypeScalarizeVector) + Mask = GetScalarizedVector(Mask); + else + Mask = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, + Mask.getValueType().getVectorElementType(), + Mask, DAG.getConstant(0, dl, + TLI.getVectorIdxTy(DAG.getDataLayout()))); + // We want to select Op0 if Mask is 0 and Op1 if Mask is 1. Convert the + // 0/1 Mask into a form that the target can use for a SELECT operation. + EVT SccVT = TLI.getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), + Mask.getValueType()); + switch (TLI.getBooleanContents(false, false)) { + case TargetLowering::ZeroOrNegativeOneBooleanContent: + Mask = DAG.getNode(ISD::SIGN_EXTEND_INREG, dl, SccVT, + Mask, DAG.getValueType(MVT::i1)); + break; + case TargetLowering::ZeroOrOneBooleanContent: + case TargetLowering::UndefinedBooleanContent: + // The mask is already required to be 0 or 1. 
+ Mask = DAG.getZExtOrTrunc(Mask, dl, SccVT); + break; + } + return DAG.getSelect(dl, Op0.getValueType(), Mask, Op1, Op0); +} + SDValue DAGTypeLegalizer::ScalarizeVecRes_SETCC(SDNode *N) { assert(N->getValueType(0).isVector() && N->getOperand(0).getValueType().isVector() && @@ -478,6 +518,9 @@ case ISD::VSELECT: Res = ScalarizeVecOp_VSELECT(N); break; + case ISD::VECTOR_SHUFFLE_VAR: + Res = ScalarizeVec_VECTOR_SHUFFLE_VAR(N); + break; case ISD::SETCC: Res = ScalarizeVecOp_VSETCC(N); break; @@ -487,6 +530,21 @@ case ISD::FP_ROUND: Res = ScalarizeVecOp_FP_ROUND(N, OpNo); break; + case ISD::VECREDUCE_FADD: + case ISD::VECREDUCE_FMUL: + case ISD::VECREDUCE_ADD: + case ISD::VECREDUCE_MUL: + case ISD::VECREDUCE_AND: + case ISD::VECREDUCE_OR: + case ISD::VECREDUCE_XOR: + case ISD::VECREDUCE_SMAX: + case ISD::VECREDUCE_SMIN: + case ISD::VECREDUCE_UMAX: + case ISD::VECREDUCE_UMIN: + case ISD::VECREDUCE_FMAX: + case ISD::VECREDUCE_FMIN: + Res = ScalarizeVecOp_VECREDUCE(N); + break; } } @@ -617,6 +675,14 @@ return DAG.getNode(ISD::SCALAR_TO_VECTOR, SDLoc(N), N->getValueType(0), Res); } +SDValue DAGTypeLegalizer::ScalarizeVecOp_VECREDUCE(SDNode *N) { + SDValue Res = GetScalarizedVector(N->getOperand(0)); + // Result type may be wider than element type. + if (Res.getValueType() != N->getValueType(0)) + Res = DAG.getNode(ISD::ANY_EXTEND, SDLoc(N), N->getValueType(0), Res); + return Res; +} + //===----------------------------------------------------------------------===// // Result Vector Splitting //===----------------------------------------------------------------------===// @@ -658,6 +724,8 @@ case ISD::FCOPYSIGN: SplitVecRes_FCOPYSIGN(N, Lo, Hi); break; case ISD::INSERT_VECTOR_ELT: SplitVecRes_INSERT_VECTOR_ELT(N, Lo, Hi); break; case ISD::SCALAR_TO_VECTOR: SplitVecRes_SCALAR_TO_VECTOR(N, Lo, Hi); break; + case ISD::SERIES_VECTOR: SplitVecRes_SERIES_VECTOR(N, Lo, Hi); break; + case ISD::SPLAT_VECTOR: SplitVecRes_SPLAT_VECTOR(N, Lo, Hi); break; case ISD::SIGN_EXTEND_INREG: SplitVecRes_InregOp(N, Lo, Hi); break; case ISD::LOAD: SplitVecRes_LOAD(cast(N), Lo, Hi); @@ -674,6 +742,9 @@ case ISD::VECTOR_SHUFFLE: SplitVecRes_VECTOR_SHUFFLE(cast(N), Lo, Hi); break; + case ISD::VECTOR_SHUFFLE_VAR: + SplitVecRes_VECTOR_SHUFFLE_VAR(N, Lo, Hi); + break; case ISD::ANY_EXTEND_VECTOR_INREG: case ISD::SIGN_EXTEND_VECTOR_INREG: @@ -924,6 +995,10 @@ SDValue SubVec = N->getOperand(1); SDValue Idx = N->getOperand(2); SDLoc dl(N); + + assert(!Vec.getValueType().isScalableVector() && + "This code does not yet implement VL scaled spills/fills"); + GetSplitVector(Vec, Lo, Hi); EVT VecVT = Vec.getValueType(); @@ -1129,6 +1204,9 @@ if (CustomLowerNode(N, N->getValueType(0), true)) return; + assert(!Vec.getValueType().isScalableVector() && + "This code does not yet implement VL scaled spills/fills"); + // Make the vector elements byte-addressable if they aren't already. 
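Aside on the fixed-width ExpandVECTOR_SHUFFLE_VAR path above: each mask element is reduced with UREM into [0, NumElts) and a compare/select picks between the corresponding lanes of the two sources. A standalone scalar model of those semantics (plain C++, illustrative values, not the LLVM implementation):

#include <cassert>
#include <vector>

// Scalar model of the expansion: mask values in [0, N) pick a lane of Op0,
// values in [N, 2N) pick the corresponding lane of Op1.
static std::vector<int> shuffleVar(const std::vector<int> &Op0,
                                   const std::vector<int> &Op1,
                                   const std::vector<unsigned> &Mask) {
  const unsigned N = static_cast<unsigned>(Op0.size());
  std::vector<int> Result;
  for (unsigned Idx : Mask) {
    unsigned SafeIdx = Idx % N;          // the UREM in the expansion
    bool FromOp0 = (Idx == SafeIdx);     // the SETEQ feeding the select
    Result.push_back(FromOp0 ? Op0[SafeIdx] : Op1[SafeIdx]);
  }
  return Result;
}

int main() {
  std::vector<int> A = {1, 2, 3, 4}, B = {5, 6, 7, 8};
  std::vector<unsigned> Mask = {0, 5, 3, 6}; // lanes A0, B1, A3, B2
  assert(shuffleVar(A, B, Mask) == (std::vector<int>{1, 6, 4, 7}));
  return 0;
}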
EVT VecVT = Vec.getValueType(); EVT EltVT = VecVT.getVectorElementType(); @@ -1191,6 +1269,30 @@ Hi = DAG.getUNDEF(HiVT); } +void DAGTypeLegalizer::SplitVecRes_SERIES_VECTOR(SDNode *N, SDValue &Lo, + SDValue &Hi) { + EVT LoVT, HiVT; + SDLoc dl(N); + std::tie(LoVT, HiVT) = DAG.GetSplitDestVTs(N->getValueType(0)); + SDValue LoStart = N->getOperand(0); + SDValue Step = N->getOperand(1); + EVT CountVT = LoStart.getValueType(); + SDValue EltCount = DAG.getVScale(dl, CountVT, LoVT.getVectorNumElements()); + SDValue Mult = DAG.getNode(ISD::MUL, dl, CountVT, EltCount, Step); + SDValue HiStart = DAG.getNode(ISD::ADD, dl, CountVT, LoStart, Mult); + Lo = DAG.getNode(ISD::SERIES_VECTOR, dl, LoVT, LoStart, Step); + Hi = DAG.getNode(ISD::SERIES_VECTOR, dl, HiVT, HiStart, Step); +} + +void DAGTypeLegalizer::SplitVecRes_SPLAT_VECTOR(SDNode *N, SDValue &Lo, + SDValue &Hi) { + EVT LoVT, HiVT; + SDLoc dl(N); + std::tie(LoVT, HiVT) = DAG.GetSplitDestVTs(N->getValueType(0)); + Lo = DAG.getNode(ISD::SPLAT_VECTOR, dl, LoVT, N->getOperand(0)); + Hi = DAG.getNode(ISD::SPLAT_VECTOR, dl, HiVT, N->getOperand(0)); +} + void DAGTypeLegalizer::SplitVecRes_LOAD(LoadSDNode *LD, SDValue &Lo, SDValue &Hi) { assert(ISD::isUNINDEXEDLoad(LD) && "Indexed load during type legalization!"); @@ -1214,10 +1316,20 @@ LD->getPointerInfo(), LoMemVT, Alignment, MMOFlags, AAInfo); unsigned IncrementSize = LoMemVT.getSizeInBits()/8; - Ptr = DAG.getObjectPtrOffset(dl, Ptr, IncrementSize); + + SDValue BytesIncrement; + MachinePointerInfo MPI; + if (LoVT.isScalableVector()) { + BytesIncrement = DAG.getVScale(dl, Ptr.getValueType(), IncrementSize); + MPI = LD->getPointerInfo(); + } else { + BytesIncrement = DAG.getConstant(IncrementSize, dl, Ptr.getValueType()); + MPI = LD->getPointerInfo().getWithOffset(IncrementSize); + } + + Ptr = DAG.getNode(ISD::ADD, dl, Ptr.getValueType(), Ptr, BytesIncrement); Hi = DAG.getLoad(ISD::UNINDEXED, ExtType, HiVT, dl, Ch, Ptr, Offset, - LD->getPointerInfo().getWithOffset(IncrementSize), HiMemVT, - Alignment, MMOFlags, AAInfo); + MPI, HiMemVT, Alignment, MMOFlags, AAInfo); // Build a factor node to remember that this load is independent of the // other one. @@ -1309,6 +1421,7 @@ SDValue Index = MGT->getIndex(); SDValue Scale = MGT->getScale(); unsigned Alignment = MGT->getOriginalAlignment(); + ISD::LoadExtType ExtType = MGT->getExtensionType(); // Split Mask operand SDValue MaskLo, MaskHi; @@ -1340,12 +1453,12 @@ Alignment, MGT->getAAInfo(), MGT->getRanges()); SDValue OpsLo[] = {Ch, Src0Lo, MaskLo, Ptr, IndexLo, Scale}; - Lo = DAG.getMaskedGather(DAG.getVTList(LoVT, MVT::Other), LoVT, dl, OpsLo, - MMO); + Lo = DAG.getMaskedGather(DAG.getVTList(LoVT, MVT::Other), LoMemVT, dl, OpsLo, + MMO, ExtType, MGT->getIndexType()); SDValue OpsHi[] = {Ch, Src0Hi, MaskHi, Ptr, IndexHi, Scale}; - Hi = DAG.getMaskedGather(DAG.getVTList(HiVT, MVT::Other), HiVT, dl, OpsHi, - MMO); + Hi = DAG.getMaskedGather(DAG.getVTList(HiVT, MVT::Other), HiMemVT, dl, OpsHi, + MMO, ExtType, MGT->getIndexType()); // Build a factor node to remember that this load is independent of the // other one. @@ -1430,8 +1543,8 @@ // more effectively move in the right direction and prevent falling down // to scalarization in many cases due to the input vector being split too // far. 
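Aside on SplitVecRes_SERIES_VECTOR above: the high half of a split series vector is itself a series whose start is offset by (number of low elements) * Step. A fixed-width standalone sketch that checks the identity numerically; vscale is modelled by an ordinary element count and the values are illustrative.

#include <cassert>
#include <cstdint>
#include <vector>

// Fixed-width stand-in for SERIES_VECTOR: Start, Start+Step, ...
static std::vector<int64_t> series(int64_t Start, int64_t Step, unsigned N) {
  std::vector<int64_t> V;
  for (unsigned I = 0; I != N; ++I)
    V.push_back(Start + int64_t(I) * Step);
  return V;
}

int main() {
  const unsigned N = 8; // stands in for the (scaled) element count
  const int64_t Start = 3, Step = 5;

  // SplitVecRes_SERIES_VECTOR: the high half starts Step * (low element
  // count) further along the same sequence.
  std::vector<int64_t> Lo = series(Start, Step, N / 2);
  std::vector<int64_t> Hi = series(Start + int64_t(N / 2) * Step, Step, N / 2);

  Lo.insert(Lo.end(), Hi.begin(), Hi.end());
  assert(Lo == series(Start, Step, N));
  return 0;
}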
- unsigned NumElements = SrcVT.getVectorNumElements(); - if ((NumElements & 1) == 0 && + auto EltCnt = SrcVT.getVectorElementCount(); + if ((EltCnt.Min & 1) == 0 && SrcVT.getSizeInBits() * 2 < DestVT.getSizeInBits()) { LLVMContext &Ctx = *DAG.getContext(); EVT NewSrcVT = SrcVT.widenIntegerVectorElementType(Ctx); @@ -1567,6 +1680,93 @@ } } +/// Check whether the node is an all-zeros value (used to recognize splat +/// masks for VECTOR_SHUFFLE_VAR). +static bool isASplatMask(SDValue &S) { + + if (auto N = dyn_cast(S.getNode())) { + return N->isNullValue(); + } + + // Scalable vector specific code. + unsigned Opcode = S.getOpcode(); + + if (Opcode == ISD::SPLAT_VECTOR) { + auto *COp = dyn_cast(S.getOperand(0)); + if (COp && COp->isNullValue()) + return true; + } + + // TODO: Remove after SPLAT_VECTOR works? + if (Opcode != ISD::SERIES_VECTOR) + return false; + + for (unsigned i = 0; i < 2; ++i) { + auto Op = S.getOperand(i); + auto N = dyn_cast(Op); + if (!(N && N->isNullValue())) + return false; + } + + return true; +} + +void DAGTypeLegalizer::SplitVecRes_VECTOR_SHUFFLE_VAR(SDNode *N, + SDValue &Lo, + SDValue &Hi) { + SDValue Inputs[4], Masks[2]; + SDLoc dl(N); + // The low and high parts of the original input give four input vectors. + GetSplitVector(N->getOperand(0), Inputs[0], Inputs[1]); + GetSplitVector(N->getOperand(1), Inputs[2], Inputs[3]); + // The mask may or may not be split. If it isn't, split it manually. + SDValue Mask = N->getOperand(2); + if (getTypeAction(Mask.getValueType()) == TargetLowering::TypeSplitVector) + GetSplitVector(Mask, Masks[0], Masks[1]); + else + std::tie(Masks[0], Masks[1]) = DAG.SplitVector(Mask, dl); + + EVT VT = Inputs[0].getValueType(); + EVT MaskVT = Masks[0].getValueType(); + + // First check if we can pull out a simple splat. + if (isASplatMask(Mask)) { + auto Undef = DAG.getUNDEF(VT); + auto ZeroMask = DAG.getConstant(0, dl, MaskVT); + // If it is a splat, all that is needed is Inputs[0] and a zero mask. + Lo = DAG.getNode(ISD::VECTOR_SHUFFLE_VAR, dl, VT, + Inputs[0], Undef, ZeroMask); + Hi = DAG.getNode(ISD::VECTOR_SHUFFLE_VAR, dl, VT, + Inputs[0], Undef, ZeroMask); + return; + } + + // Generic code for the non-splat case. + EVT SccVT = TLI.getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), + MaskVT); + auto EltCnt = VT.getVectorElementCount(); + + SDValue Limit = DAG.getConstant(2 * EltCnt.Min, dl, MaskVT); + for (unsigned i = 0; i < 2; ++i) { + SDValue &HalfRes = i == 0 ? Lo : Hi; + // Reduce the mask so that each element is in the range [0, NumElts * 2). + // HalfOp0 contains the right answer for mask elements in that range + // and HalfOp1 contains the right answer for the rest. + SDValue HalfMask = DAG.getNode(ISD::UREM, dl, MaskVT, Masks[i], Limit); + SDValue HalfOp0 = DAG.getNode(ISD::VECTOR_SHUFFLE_VAR, dl, VT, + Inputs[0], Inputs[1], HalfMask); + SDValue HalfOp1 = DAG.getNode(ISD::VECTOR_SHUFFLE_VAR, dl, VT, + Inputs[2], Inputs[3], HalfMask); + // Get a vector of 0/1 selectors, with 1 selecting HalfOp0 and + // 0 selecting HalfOp1. + SDValue WhichOp = DAG.getSetCC(dl, SccVT, HalfMask, Masks[i], ISD::SETEQ); + // Use a VSELECT to get the final result. It requires the selector + // to be a vector of i1s.
+ EVT CondVT = EVT::getVectorVT(*DAG.getContext(), MVT::i1, EltCnt); + WhichOp = DAG.getNode(ISD::TRUNCATE, dl, CondVT, WhichOp); + HalfRes = DAG.getNode(ISD::VSELECT, dl, VT, WhichOp, HalfOp0, HalfOp1); + } +} //===----------------------------------------------------------------------===// // Operand Vector Splitting @@ -1592,8 +1792,7 @@ N->dump(&DAG); dbgs() << "\n"; #endif - report_fatal_error("Do not know how to split this operator's " - "operand!\n"); + report_fatal_error("Do not know how to split this operator's operand!\n"); case ISD::SETCC: Res = SplitVecOp_VSETCC(N); break; case ISD::BITCAST: Res = SplitVecOp_BITCAST(N); break; @@ -1770,7 +1969,7 @@ EVT InVT = Lo.getValueType(); EVT OutVT = EVT::getVectorVT(*DAG.getContext(), ResVT.getVectorElementType(), - InVT.getVectorNumElements()); + InVT.getVectorElementCount()); Lo = DAG.getNode(N->getOpcode(), dl, OutVT, Lo); Hi = DAG.getNode(N->getOpcode(), dl, OutVT, Hi); @@ -1841,6 +2040,19 @@ if (CustomLowerNode(N, N->getValueType(0), true)) return SDValue(); + // TODO: + // At the point of legalization, StackRegions are not yet initialized, + // which means there is no opportunity yet to register a slot to + // a stackregion. Since the object is not created by an alloca + // instruction, the InitializeStackRegions pass does not register + // the slot to any region, because it does not have its type, only a + // block of allocated stack space. This means we should either record + // the EVT in a stack object (which I think was explicitly designed to + // be type agnostic), or just keep the assert and let target lowering + // resolve the problem without using the stack. + assert(!Vec.getValueType().isScalableVector() && + "This code does not yet implement VL scaled spills/fills"); + // Make the vector elements byte-addressable if they aren't already. SDLoc dl(N); EVT EltVT = VecVT.getVectorElementType(); @@ -1888,6 +2100,7 @@ SDValue Mask = MGT->getMask(); SDValue Src0 = MGT->getValue(); unsigned Alignment = MGT->getOriginalAlignment(); + ISD::LoadExtType ExtType = MGT->getExtensionType(); SDValue MaskLo, MaskHi; if (getTypeAction(Mask.getValueType()) == TargetLowering::TypeSplitVector) @@ -1918,8 +2131,9 @@ Alignment, MGT->getAAInfo(), MGT->getRanges()); SDValue OpsLo[] = {Ch, Src0Lo, MaskLo, Ptr, IndexLo, Scale}; - SDValue Lo = DAG.getMaskedGather(DAG.getVTList(LoVT, MVT::Other), LoVT, dl, - OpsLo, MMO); + SDValue Lo = DAG.getMaskedGather(DAG.getVTList(LoVT, MVT::Other), LoMemVT, dl, + OpsLo, MMO, ExtType, MGT->getIndexType()); + MMO = DAG.getMachineFunction(). getMachineMemOperand(MGT->getPointerInfo(), @@ -1928,8 +2142,8 @@ MGT->getRanges()); SDValue OpsHi[] = {Ch, Src0Hi, MaskHi, Ptr, IndexHi, Scale}; - SDValue Hi = DAG.getMaskedGather(DAG.getVTList(HiVT, MVT::Other), HiVT, dl, - OpsHi, MMO); + SDValue Hi = DAG.getMaskedGather(DAG.getVTList(HiVT, MVT::Other), HiMemVT, dl, + OpsHi, MMO, ExtType, MGT->getIndexType()); // Build a factor node to remember that this load is independent of the // other one. 
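Aside on SplitVecRes_VECTOR_SHUFFLE_VAR above (the generic, non-splat path): each mask element is reduced modulo twice the half-length, both candidate half shuffles are formed, and a SETEQ/VSELECT picks the one the original index referred to. A standalone scalar model of that selection logic, abstracting the four split inputs back into the two original operands; values are illustrative and this is not the LLVM code.

#include <cassert>
#include <vector>

int main() {
  // Two unsplit operands of 2N elements; the mask indexes their concatenation.
  const unsigned N = 2; // elements per split half
  std::vector<int> V0 = {10, 11, 12, 13};
  std::vector<int> V1 = {20, 21, 22, 23};
  std::vector<unsigned> Mask = {3, 4, 0, 7}; // values in [0, 4N)

  // Reference semantics of VECTOR_SHUFFLE_VAR on the unsplit operands.
  auto RefAt = [&](unsigned M) { return M < 2 * N ? V0[M] : V1[M - 2 * N]; };

  // Split strategy: reduce each mask element into [0, 2N); the reduced index
  // is correct for V0 when the original value was already in range (SETEQ),
  // and correct for V1 otherwise (the VSELECT picks between the two).
  std::vector<int> Result;
  for (unsigned M : Mask) {
    unsigned HalfMask = M % (2 * N);
    Result.push_back(HalfMask == M ? V0[HalfMask] : V1[HalfMask]);
  }

  for (unsigned I = 0; I != Mask.size(); ++I)
    assert(Result[I] == RefAt(Mask[I]));
  return 0;
}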
@@ -2036,6 +2250,12 @@ else std::tie(MaskLo, MaskHi) = DAG.SplitVector(Mask, DL); + SDValue PtrLo, PtrHi; + if (Ptr.getValueType().isVector()) // gather form vector of pointers + std::tie(PtrLo, PtrHi) = DAG.SplitVector(Ptr, DL); + else + PtrLo = PtrHi = Ptr; + SDValue IndexHi, IndexLo; if (getTypeAction(Index.getValueType()) == TargetLowering::TypeSplitVector) GetSplitVector(Index, IndexLo, IndexHi); @@ -2049,20 +2269,18 @@ Alignment, N->getAAInfo(), N->getRanges()); SDValue OpsLo[] = {Ch, DataLo, MaskLo, Ptr, IndexLo, Scale}; - Lo = DAG.getMaskedScatter(DAG.getVTList(MVT::Other), DataLo.getValueType(), - DL, OpsLo, MMO); + Lo = DAG.getMaskedScatter(DAG.getVTList(MVT::Other), LoMemVT, DL, OpsLo, MMO, + N->isTruncatingStore(), N->getIndexType()); MMO = DAG.getMachineFunction(). getMachineMemOperand(N->getPointerInfo(), MachineMemOperand::MOStore, HiMemVT.getStoreSize(), Alignment, N->getAAInfo(), N->getRanges()); - // The order of the Scatter operation after split is well defined. The "Hi" - // part comes after the "Lo". So these two operations should be chained one - // after another. SDValue OpsHi[] = {Lo, DataHi, MaskHi, Ptr, IndexHi, Scale}; - return DAG.getMaskedScatter(DAG.getVTList(MVT::Other), DataHi.getValueType(), - DL, OpsHi, MMO); + return DAG.getMaskedScatter(DAG.getVTList(MVT::Other), HiMemVT, DL, OpsHi, + MMO, N->isTruncatingStore(), N->getIndexType()); + } SDValue DAGTypeLegalizer::SplitVecOp_STORE(StoreSDNode *N, unsigned OpNo) { @@ -2096,16 +2314,24 @@ Lo = DAG.getStore(Ch, DL, Lo, Ptr, N->getPointerInfo(), Alignment, MMOFlags, AAInfo); + SDValue BytesIncrement; + MachinePointerInfo MPI; + if (LoMemVT.isScalableVector()) { + BytesIncrement = DAG.getVScale(DL, Ptr.getValueType(), IncrementSize); + MPI = N->getPointerInfo(); + } else { + BytesIncrement = DAG.getConstant(IncrementSize, DL, Ptr.getValueType()); + MPI = N->getPointerInfo().getWithOffset(IncrementSize); + } + // Increment the pointer to the other half. - Ptr = DAG.getObjectPtrOffset(DL, Ptr, IncrementSize); + Ptr = DAG.getNode(ISD::ADD, DL, Ptr.getValueType(), Ptr, BytesIncrement); if (isTruncating) - Hi = DAG.getTruncStore(Ch, DL, Hi, Ptr, - N->getPointerInfo().getWithOffset(IncrementSize), + Hi = DAG.getTruncStore(Ch, DL, Hi, Ptr, MPI, HiMemVT, Alignment, MMOFlags, AAInfo); else - Hi = DAG.getStore(Ch, DL, Hi, Ptr, - N->getPointerInfo().getWithOffset(IncrementSize), + Hi = DAG.getStore(Ch, DL, Hi, Ptr, MPI, Alignment, MMOFlags, AAInfo); return DAG.getNode(ISD::TokenFactor, DL, MVT::Other, Lo, Hi); @@ -2154,12 +2380,12 @@ SDValue InVec = N->getOperand(0); EVT InVT = InVec->getValueType(0); EVT OutVT = N->getValueType(0); - unsigned NumElements = OutVT.getVectorNumElements(); + auto EltCnt = OutVT.getVectorElementCount(); bool IsFloat = OutVT.isFloatingPoint(); // Widening should have already made sure this is a power-two vector // if we're trying to split it at all. assert() that's true, just in case. - assert(!(NumElements & 1) && "Splitting vector, but not in half!"); + assert(!(EltCnt.Min & 1) && "Splitting vector, but not in half!"); unsigned InElementSize = InVT.getScalarSizeInBits(); unsigned OutElementSize = OutVT.getScalarSizeInBits(); @@ -2178,12 +2404,11 @@ EVT HalfElementVT = IsFloat ? 
EVT::getFloatingPointVT(InElementSize/2) : EVT::getIntegerVT(*DAG.getContext(), InElementSize/2); - EVT HalfVT = EVT::getVectorVT(*DAG.getContext(), HalfElementVT, - NumElements/2); + EVT HalfVT = EVT::getVectorVT(*DAG.getContext(), HalfElementVT, EltCnt/2); SDValue HalfLo = DAG.getNode(N->getOpcode(), DL, HalfVT, InLoVec); SDValue HalfHi = DAG.getNode(N->getOpcode(), DL, HalfVT, InHiVec); // Concatenate them to get the full intermediate truncation result. - EVT InterVT = EVT::getVectorVT(*DAG.getContext(), HalfElementVT, NumElements); + EVT InterVT = EVT::getVectorVT(*DAG.getContext(), HalfElementVT, EltCnt); SDValue InterVec = DAG.getNode(ISD::CONCAT_VECTORS, DL, InterVT, HalfLo, HalfHi); // Now finish up by truncating all the way down to the original result @@ -2206,9 +2431,11 @@ SDLoc DL(N); GetSplitVector(N->getOperand(0), Lo0, Hi0); GetSplitVector(N->getOperand(1), Lo1, Hi1); - unsigned PartElements = Lo0.getValueType().getVectorNumElements(); - EVT PartResVT = EVT::getVectorVT(*DAG.getContext(), MVT::i1, PartElements); - EVT WideResVT = EVT::getVectorVT(*DAG.getContext(), MVT::i1, 2*PartElements); + auto PartEltCnt = Lo0.getValueType().getVectorElementCount(); + + LLVMContext &Context = *DAG.getContext(); + EVT PartResVT = EVT::getVectorVT(Context, MVT::i1, PartEltCnt); + EVT WideResVT = EVT::getVectorVT(Context, MVT::i1, PartEltCnt*2); LoRes = DAG.getNode(ISD::SETCC, DL, PartResVT, Lo0, Lo1, N->getOperand(2)); HiRes = DAG.getNode(ISD::SETCC, DL, PartResVT, Hi0, Hi1, N->getOperand(2)); @@ -2226,11 +2453,10 @@ EVT InVT = Lo.getValueType(); EVT OutVT = EVT::getVectorVT(*DAG.getContext(), ResVT.getVectorElementType(), - InVT.getVectorNumElements()); + InVT.getVectorElementCount()); Lo = DAG.getNode(ISD::FP_ROUND, DL, OutVT, Lo, N->getOperand(1)); Hi = DAG.getNode(ISD::FP_ROUND, DL, OutVT, Hi, N->getOperand(1)); - return DAG.getNode(ISD::CONCAT_VECTORS, DL, ResVT, Lo, Hi); } @@ -3048,7 +3274,8 @@ SDValue Ops[] = { N->getChain(), Src0, Mask, N->getBasePtr(), Index, Scale }; SDValue Res = DAG.getMaskedGather(DAG.getVTList(WideVT, MVT::Other), N->getMemoryVT(), dl, Ops, - N->getMemOperand()); + N->getMemOperand(), N->getExtensionType(), + N->getIndexType()); // Legalize the chain result - switch anything that used the old chain to // use the new one. @@ -3111,7 +3338,7 @@ // Make a new Mask node, with a legal result VT. SmallVector Ops; - for (unsigned i = 0, e = InMask->getNumOperands(); i < e; ++i) + for (unsigned i = 0; i < InMask->getNumOperands(); ++i) Ops.push_back(InMask->getOperand(i)); SDValue Mask = DAG.getNode(InMask->getOpcode(), SDLoc(InMask), MaskVT, Ops); @@ -3144,9 +3371,12 @@ } else if (CurrMaskNumEls < ToMaskVT.getVectorNumElements()) { unsigned NumSubVecs = (ToMaskVT.getVectorNumElements() / CurrMaskNumEls); EVT SubVT = Mask->getValueType(0); - SmallVector SubOps(NumSubVecs, DAG.getUNDEF(SubVT)); - SubOps[0] = Mask; - Mask = DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(Mask), ToMaskVT, SubOps); + SmallVector SubConcatOps(NumSubVecs); + SubConcatOps[0] = Mask; + for (unsigned i = 1; i < NumSubVecs; ++i) + SubConcatOps[i] = DAG.getUNDEF(SubVT); + Mask = + DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(Mask), ToMaskVT, SubConcatOps); } assert((Mask->getValueType(0) == ToMaskVT) && @@ -3190,6 +3420,9 @@ if (!isPowerOf2_64(VSelVT.getSizeInBits())) return SDValue(); + if (VSelVT.isScalableVector()) + return SDValue(); + // Don't touch if this will be scalarized. 
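Aside on the staged truncation above: rather than truncating the wide source directly, each split half is truncated to an intermediate element width, the halves are concatenated, and the result is truncated again. A standalone C++ model showing the staged result matches a direct truncation; the i64/i32/i8 widths and the values are illustrative only.

#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  // Eight i64 lanes to be truncated to i8.
  std::vector<uint64_t> In = {0x0101, 0x0202, 0x0303, 0x0404,
                              0x0505, 0x0606, 0x0707, 0x0808};

  // Reference: truncate directly to the final element width.
  std::vector<uint8_t> Ref;
  for (uint64_t X : In)
    Ref.push_back(static_cast<uint8_t>(X));

  // Staged: truncate each split half to an intermediate i32 width first...
  std::vector<uint32_t> Inter;
  for (unsigned Half = 0; Half != 2; ++Half)
    for (unsigned I = 0; I != 4; ++I)
      Inter.push_back(static_cast<uint32_t>(In[Half * 4 + I]));

  // ...the concatenation is implicit in Inter; now truncate the rest of the
  // way down to i8.
  std::vector<uint8_t> Out;
  for (uint32_t X : Inter)
    Out.push_back(static_cast<uint8_t>(X));

  assert(Out == Ref);
  return 0;
}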
EVT FinalVT = VSelVT; while (getTypeAction(FinalVT) == TargetLowering::TypeSplitVector) @@ -3432,6 +3665,22 @@ case ISD::TRUNCATE: Res = WidenVecOp_Convert(N); break; + + case ISD::VECREDUCE_FADD: + case ISD::VECREDUCE_FMUL: + case ISD::VECREDUCE_ADD: + case ISD::VECREDUCE_MUL: + case ISD::VECREDUCE_AND: + case ISD::VECREDUCE_OR: + case ISD::VECREDUCE_XOR: + case ISD::VECREDUCE_SMAX: + case ISD::VECREDUCE_SMIN: + case ISD::VECREDUCE_UMAX: + case ISD::VECREDUCE_UMIN: + case ISD::VECREDUCE_FMAX: + case ISD::VECREDUCE_FMIN: + Res = WidenVecOp_VECREDUCE(N); + break; } // If Res is null, the sub-method took care of registering the result. @@ -3711,8 +3960,8 @@ SDValue Ops[] = {MSC->getChain(), WideVal, Mask, MSC->getBasePtr(), Index, Scale}; return DAG.getMaskedScatter(DAG.getVTList(MVT::Other), - MSC->getMemoryVT(), dl, Ops, - MSC->getMemOperand()); + MSC->getMemoryVT(), dl, Ops, MSC->getMemOperand(), + MSC->isTruncatingStore(), MSC->getIndexType()); } SDValue DAGTypeLegalizer::WidenVecOp_SETCC(SDNode *N) { @@ -3748,6 +3997,62 @@ return PromoteTargetBoolean(CC, VT); } +SDValue DAGTypeLegalizer::WidenVecOp_VECREDUCE(SDNode *N) { + SDLoc dl(N); + SDValue Op = GetWidenedVector(N->getOperand(0)); + EVT OrigVT = N->getOperand(0).getValueType(); + EVT WideVT = Op.getValueType(); + EVT ElemVT = OrigVT.getVectorElementType(); + + SDValue NeutralElem; + switch (N->getOpcode()) { + case ISD::VECREDUCE_ADD: + case ISD::VECREDUCE_OR: + case ISD::VECREDUCE_XOR: + case ISD::VECREDUCE_UMAX: + NeutralElem = DAG.getConstant(0, dl, ElemVT); + break; + case ISD::VECREDUCE_MUL: + NeutralElem = DAG.getConstant(1, dl, ElemVT); + break; + case ISD::VECREDUCE_AND: + case ISD::VECREDUCE_UMIN: + NeutralElem = DAG.getAllOnesConstant(dl, ElemVT); + break; + case ISD::VECREDUCE_SMAX: + NeutralElem = DAG.getConstant( + APInt::getSignedMinValue(ElemVT.getSizeInBits()), dl, ElemVT); + break; + case ISD::VECREDUCE_SMIN: + NeutralElem = DAG.getConstant( + APInt::getSignedMaxValue(ElemVT.getSizeInBits()), dl, ElemVT); + break; + case ISD::VECREDUCE_FADD: + NeutralElem = DAG.getConstantFP(0.0, dl, ElemVT); + break; + case ISD::VECREDUCE_FMUL: + NeutralElem = DAG.getConstantFP(1.0, dl, ElemVT); + break; + case ISD::VECREDUCE_FMAX: + NeutralElem = DAG.getConstantFP( + std::numeric_limits::infinity(), dl, ElemVT); + break; + case ISD::VECREDUCE_FMIN: + NeutralElem = DAG.getConstantFP( + -std::numeric_limits::infinity(), dl, ElemVT); + break; + } + + // Pad the vector with the neutral element. + unsigned OrigElts = OrigVT.getVectorNumElements(); + unsigned WideElts = WideVT.getVectorNumElements(); + for (unsigned Idx = OrigElts; Idx < WideElts; Idx++) + Op = DAG.getNode(ISD::INSERT_VECTOR_ELT, dl, WideVT, Op, NeutralElem, + DAG.getConstant(Idx, dl, TLI.getVectorIdxTy(DAG.getDataLayout()))); + + return DAG.getNode(N->getOpcode(), dl, N->getValueType(0), Op, N->getFlags()); +} + //===----------------------------------------------------------------------===// // Vector Widening Utilities @@ -4091,7 +4396,9 @@ } else { // Cast the vector to the scalar type we can store. unsigned NumElts = ValWidth / NewVTWidth; - EVT NewVecVT = EVT::getVectorVT(*DAG.getContext(), NewVT, NumElts); + bool IsScalable = ValVT.isScalableVector(); + EVT NewVecVT = EVT::getVectorVT(*DAG.getContext(), NewVT, NumElts, + IsScalable); SDValue VecOp = DAG.getNode(ISD::BITCAST, dl, NewVecVT, ValOp); // Readjust index position based on new vector type. 
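Aside on WidenVecOp_VECREDUCE above: the lanes introduced by widening are filled with the reduction's neutral element so the reduced value is unchanged (0 for ADD/OR/XOR/UMAX, 1 for MUL, all-ones for AND/UMIN, the signed extremes for SMAX/SMIN, and so on). A standalone sketch checking a few of these identities; plain C++, not the LLVM code, with illustrative values.

#include <algorithm>
#include <cassert>
#include <climits>
#include <functional>
#include <numeric>
#include <vector>

int main() {
  // Three live lanes widened to four; the new lane gets the neutral element.
  std::vector<int> V = {5, -2, 9};
  auto Widen = [&](int Neutral) {
    std::vector<int> W = V;
    W.resize(4, Neutral); // the INSERT_VECTOR_ELT padding loop
    return W;
  };

  std::vector<int> AddW = Widen(0);       // VECREDUCE_ADD  -> 0
  std::vector<int> MulW = Widen(1);       // VECREDUCE_MUL  -> 1
  std::vector<int> MaxW = Widen(INT_MIN); // VECREDUCE_SMAX -> signed minimum

  assert(std::accumulate(AddW.begin(), AddW.end(), 0) == 5 + -2 + 9);
  assert(std::accumulate(MulW.begin(), MulW.end(), 1,
                         std::multiplies<int>()) == 5 * -2 * 9);
  assert(*std::max_element(MaxW.begin(), MaxW.end()) == 9);
  return 0;
}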
Idx = Idx * ValEltWidth / NewVTWidth; Index: lib/CodeGen/SelectionDAG/SelectionDAG.cpp =================================================================== --- lib/CodeGen/SelectionDAG/SelectionDAG.cpp +++ lib/CodeGen/SelectionDAG/SelectionDAG.cpp @@ -130,7 +130,19 @@ // ISD Namespace //===----------------------------------------------------------------------===// -bool ISD::isConstantSplatVector(const SDNode *N, APInt &SplatVal) { +bool ISD::isConstantSplatVector(const SDNode *N, APInt &SplatVal, + bool AllowShrink) { + if (N->getOpcode() == ISD::SPLAT_VECTOR) { + unsigned EltSize = + N->getValueType(0).getVectorElementType().getSizeInBits(); + if (auto *Op0 = dyn_cast(N->getOperand(0))) { + SplatVal = Op0->getAPIntValue(); + if (EltSize < SplatVal.getBitWidth()) + SplatVal = SplatVal.trunc(EltSize); + return true; + } + } + auto *BV = dyn_cast(N); if (!BV) return false; @@ -147,11 +159,23 @@ // FIXME: AllOnes and AllZeros duplicate a lot of code. Could these be // specializations of the more general isConstantSplatVector()? +// FIXME: Moved SPLAT_VECTOR handling here to match similar moves in +// community code in February 2017 merge. May want to rename or +// comment as to why it's not just build vector nodes. + bool ISD::isBuildVectorAllOnes(const SDNode *N) { // Look through a bit convert. while (N->getOpcode() == ISD::BITCAST) N = N->getOperand(0).getNode(); + if (N->getOpcode() == ISD::SPLAT_VECTOR) { + APInt Val; + if (isConstantSplatVector(N, Val)) + return Val.isAllOnesValue(); + + return false; + } + if (N->getOpcode() != ISD::BUILD_VECTOR) return false; unsigned i = 0, e = N->getNumOperands(); @@ -196,6 +220,14 @@ while (N->getOpcode() == ISD::BITCAST) N = N->getOperand(0).getNode(); + if (N->getOpcode() == ISD::SPLAT_VECTOR) { + APInt Val; + if (isConstantSplatVector(N, Val)) + return Val.getLimitedValue() == 0L; + + return false; + } + if (N->getOpcode() != ISD::BUILD_VECTOR) return false; bool IsAllUndef = true; @@ -1278,8 +1310,13 @@ } SDValue Result(N, 0); - if (VT.isVector()) - Result = getSplatBuildVector(VT, DL, Result); + if (VT.isVector()) { + if (VT.isScalableVector()) { + Result = getNode(ISD::SPLAT_VECTOR, DL, VT, Result); + } else { + Result = getSplatBuildVector(VT, DL, Result); + } + } return Result; } @@ -1320,9 +1357,13 @@ } SDValue Result(N, 0); - if (VT.isVector()) - Result = getSplatBuildVector(VT, DL, Result); - NewSDValueDbgMsg(Result, "Creating fp constant: ", this); + if (VT.isVector()) { + if(VT.isScalableVector()) { + Result = getNode(ISD::SPLAT_VECTOR, DL, VT, Result); + } else { + Result = getSplatBuildVector(VT, DL, Result); + } + } return Result; } @@ -1562,8 +1603,6 @@ SDValue N2, ArrayRef Mask) { assert(VT.getVectorNumElements() == Mask.size() && "Must have the same number of vector elements as mask elements!"); - assert(VT == N1.getValueType() && VT == N2.getValueType() && - "Invalid VECTOR_SHUFFLE"); // Canonicalize shuffle undef, undef -> undef if (N1.isUndef() && N2.isUndef()) @@ -3052,6 +3091,7 @@ case ISD::INTRINSIC_WO_CHAIN: case ISD::INTRINSIC_W_CHAIN: case ISD::INTRINSIC_VOID: + case ISD::VSCALE: // Allow the target to implement this method for its nodes. 
TLI->computeKnownBitsForTargetNode(Op, Known, DemandedElts, *this, Depth); break; @@ -4018,6 +4058,22 @@ transferDbgValues(Operand, OpOp); return OpOp; } + } else if (OpOpcode == ISD::BUILD_VECTOR) { + // ext (buildvector [opnd0, .., opndN]) + // -> buildvector [ext(opnd0), .., ext(OpndN)] + // iff ext(opnd_i) is a NOP, foreach i=0..N + SmallVector Ops; + int i = 0, e = VT.getVectorNumElements(); + for (; i != e; ++i) { + unsigned OpndOpcode = Operand.getOperand(i).getNode()->getOpcode(); + if (!isa(Operand.getOperand(i)) && + OpndOpcode != ISD::UNDEF) + break; + Ops.push_back(getNode(ISD::ANY_EXTEND, DL, + VT.getVectorElementType(), Operand.getOperand(i))); + } + if (i == e) + return getNode(ISD::BUILD_VECTOR, DL, VT, Ops); } break; case ISD::TRUNCATE: @@ -4091,6 +4147,17 @@ Operand.getOperand(0).getValueType() == VT) return Operand.getOperand(0); break; + case ISD::SPLAT_VECTOR: + if (!VT.isScalableVector()) { + SmallVector Ops; + for (int i = 0, e = VT.getVectorNumElements(); i != e; ++i) { + + Ops.push_back(Operand); + } + + return getNode(ISD::BUILD_VECTOR, DL, VT, Ops); + } + break; case ISD::FNEG: // -(X-Y) -> (Y-X) is unsafe because when X==Y, -0.0 != +0.0 if ((getTarget().Options.UnsafeFPMath || Flags.hasNoSignedZeros()) && @@ -4408,6 +4475,121 @@ return V; } +bool SelectionDAG::isConstantIntSplat(SDValue Val, APInt* SplatValue) { + if (Val.getOpcode() == ISD::SPLAT_VECTOR) { + auto *CSplatVal = dyn_cast(Val->getOperand(0)); + + if (CSplatVal) { + if (SplatValue) + *SplatValue = CSplatVal->getAPIntValue(); + return true; + } + } + + return false; +} + +SDValue SelectionDAG::FoldSeriesVectorBinOp(unsigned Opcode, SDLoc DL, EVT VT, + SDValue N1, SDValue N2, + const SDNodeFlags Flags) { + // These convert common stepvector idioms to full fat seriesvectors. They're + // potentially short lived as the current objective it to mataintain the + // output produced by the original seriesvector based IR. + + if ((Opcode == ISD::ADD) && + (N1.getOpcode() == ISD::SERIES_VECTOR) && + (N2.getOpcode() == ISD::SPLAT_VECTOR) && + (N1.getOperand(0).getValueType() == N2.getOperand(0).getValueType())) { + EVT ElType = VT.getVectorElementType(); + SDValue Base = getNode(Opcode, SDLoc(), ElType, + N1.getOperand(0), N2.getOperand(0)); + return getNode(ISD::SERIES_VECTOR, SDLoc(), VT, Base, N1.getOperand(1)); + } + + if ((Opcode == ISD::ADD) && + (N1.getOpcode() == ISD::SPLAT_VECTOR) && + (N2.getOpcode() == ISD::SERIES_VECTOR) && + (N1.getOperand(0).getValueType() == N2.getOperand(0).getValueType())) { + EVT ElType = VT.getVectorElementType(); + SDValue Base = getNode(Opcode, SDLoc(), ElType, + N1.getOperand(0), N2.getOperand(0)); + return getNode(ISD::SERIES_VECTOR, SDLoc(), VT, Base, N2.getOperand(1)); + } + + if ((Opcode == ISD::MUL) && + (N1.getOpcode() == ISD::SPLAT_VECTOR) && + (N2.getOpcode() == ISD::SERIES_VECTOR) && + (N1.getOperand(0).getValueType() == N2.getOperand(0).getValueType())) { + EVT ElType = VT.getVectorElementType(); + SDValue Base = getNode(Opcode, SDLoc(), ElType, + N2.getOperand(0), N1.getOperand(0)); + SDValue Step = getNode(Opcode, SDLoc(), ElType, + N2.getOperand(1), N1.getOperand(0)); + + return getNode(ISD::SERIES_VECTOR, SDLoc(), VT, Base, Step); + } + + // We are not checking if N1 is a constant splat here because either 1) the + // constant has been canonicalized to RHS, or 2) it's an operation where we + // only care if RHS is a constant, e.g. ISD::SHL. 
+ if (N1.getOpcode() != ISD::SERIES_VECTOR || !isConstantIntSplat(N2, nullptr)) + return SDValue(); + + switch (Opcode) { + default: + break; + // These opcodes operate on integer types only. + case ISD::MUL: + case ISD::SHL: + if (dyn_cast(N1.getOperand(1))) { + EVT ElType = VT.getVectorElementType(); + + // Now we need to shift the operands to the first series vector + SDValue Base = getNode(Opcode, SDLoc(), ElType, + N1.getOperand(0), N2.getOperand(0)); + SDValue Step = getNode(Opcode, SDLoc(), ElType, + N1.getOperand(1), N2.getOperand(0)); + + return getNode(ISD::SERIES_VECTOR, SDLoc(), VT, Base, Step); + } + break; + } + + return SDValue(); +} + +SDValue SelectionDAG::FoldSplatVectorBinOp(unsigned Opcode, SDLoc DL, EVT VT, + SDValue N1, SDValue N2, + const SDNodeFlags Flags) { + // Only want to check operations on two splats here... + if (N1.getOpcode() != ISD::SPLAT_VECTOR || + N2.getOpcode() != ISD::SPLAT_VECTOR) + return SDValue(); + + EVT EltVT = VT.getVectorElementType(); + SDValue Op1 = N1.getOperand(0); + SDValue Op2 = N2.getOperand(0); + + // ...whose operands don't require extra processing. + if (Op1.getValueType() != EltVT || + Op2.getValueType() != EltVT) + return SDValue(); + + switch (Opcode) { + default: + break; + case ISD::MUL: + case ISD::SHL: { + // Could be lots of others for a splat, if not already taken care + // of elsewhere? + SDValue NewSplatVal = getNode(Opcode, DL, EltVT, Op1, Op2); + return getNode(ISD::SPLAT_VECTOR, DL, VT, NewSplatVal); + } + } + + return SDValue(); +} + SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1, SDValue N2, const SDNodeFlags Flags) { ConstantSDNode *N1C = dyn_cast(N1); @@ -4426,6 +4608,12 @@ } } + if (SDValue SV = FoldSeriesVectorBinOp(Opcode, DL, VT, N1, N2, Flags)) + return SV; + + if (SDValue SV = FoldSplatVectorBinOp(Opcode, DL, VT, N1, N2, Flags)) + return SV; + switch (Opcode) { default: break; case ISD::TokenFactor: @@ -4466,11 +4654,23 @@ if (N2C && N2C->isNullValue()) return N1; break; + case ISD::MUL: + assert(VT.isInteger() && "This operator does not apply to FP types!"); + assert(N1.getValueType() == N2.getValueType() && + N1.getValueType() == VT && "Binary operator types must match!"); + // (X * 0) -> 0. This commonly occurs when legalizing scalable vector + // splats so it's worth handling here. + if (N2C && N2C->isNullValue()) + return N2; + if (N2C && (N1.getOpcode() == ISD::VSCALE)) { + int64_t MulImm = cast(N1->getOperand(0))->getSExtValue(); + return getVScale(DL, VT, MulImm * N2C->getSExtValue()); + } + break; case ISD::UDIV: case ISD::UREM: case ISD::MULHU: case ISD::MULHS: - case ISD::MUL: case ISD::SDIV: case ISD::SREM: case ISD::SMIN: @@ -4497,6 +4697,11 @@ "Invalid FCOPYSIGN!"); break; case ISD::SHL: + if (N2C && (N1.getOpcode() == ISD::VSCALE)) { + int64_t MulImm = cast(N1->getOperand(0))->getSExtValue(); + return getVScale(DL, VT, MulImm << N2C->getSExtValue()); + } + // Intentional fall-through! 
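Aside on FoldSeriesVectorBinOp and FoldSplatVectorBinOp above: the folds rely on the identities series(A, S) + splat(B) == series(A + B, S) and splat(B) * series(A, S) == series(A * B, S * B), with analogous constant folds for VSCALE under MUL and SHL. A fixed-width standalone check of the two series identities; values are illustrative and this is not the LLVM code.

#include <cassert>
#include <cstdint>
#include <vector>

// Fixed-width stand-ins for SERIES_VECTOR and SPLAT_VECTOR.
static std::vector<int64_t> series(int64_t Start, int64_t Step, unsigned N) {
  std::vector<int64_t> V;
  for (unsigned I = 0; I != N; ++I)
    V.push_back(Start + int64_t(I) * Step);
  return V;
}

int main() {
  const unsigned N = 6;
  const int64_t A = 2, S = 3, B = 7;

  // series(A, S) + splat(B) == series(A + B, S)
  std::vector<int64_t> Sum = series(A, S, N);
  for (int64_t &X : Sum)
    X += B;
  assert(Sum == series(A + B, S, N));

  // splat(B) * series(A, S) == series(A * B, S * B)
  std::vector<int64_t> Prod = series(A, S, N);
  for (int64_t &X : Prod)
    X *= B;
  assert(Prod == series(A * B, S * B, N));
  return 0;
}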
case ISD::SRA: case ISD::SRL: case ISD::ROTL: @@ -4583,7 +4788,12 @@ if (N1C) { const APInt &Val = N1C->getAPIntValue(); return SignExtendInReg(Val, VT); + } else { + APInt Val; + if (isConstantIntSplat(N1, &Val)) + return SignExtendInReg(Val, VT); } + if (ISD::isBuildVectorOfConstantSDNodes(N1.getNode())) { SmallVector Ops; llvm::EVT OpVT = N1.getOperand(0).getValueType(); @@ -4610,8 +4820,10 @@ if (N1.isUndef()) return getUNDEF(VT); - // EXTRACT_VECTOR_ELT of out-of-bounds element is an UNDEF - if (N2C && N2C->getAPIntValue().uge(N1.getValueType().getVectorNumElements())) + // EXTRACT_VECTOR_ELT of out-of-bounds element is an UNDEF, + // unless we are dealing with a scalable vector. + if (N2C && N2C->getZExtValue() >= N1.getValueType().getVectorNumElements() + && !N1.getValueType().isScalableVector()) return getUNDEF(VT); // EXTRACT_VECTOR_ELT of CONCAT_VECTORS is often formed while lowering is @@ -4708,6 +4920,9 @@ && "Extract subvector overflow!"); } + if (N1.getOpcode() == ISD::UNDEF) + return getUNDEF(VT); + // Trivial extraction. if (VT.getSimpleVT() == N1.getSimpleValueType()) return N1; @@ -4732,6 +4947,18 @@ return N1.getOperand(1); } break; + case ISD::SERIES_VECTOR: + if (!VT.isScalableVector()) { + EVT ElVT = N2.getValueType(); + + SmallVector Ops; + for (int i = 0, e = VT.getVectorNumElements(); i != e; ++i) { + SDValue Idx = getNode(ISD::MUL, DL, ElVT, N2, getConstant(i, DL, ElVT)); + Ops.push_back(getNode(ISD::ADD, DL, ElVT, N1, Idx)); + } + + return getNode(ISD::BUILD_VECTOR, DL, VT, Ops); + } } // Perform trivial constant folding. @@ -4937,11 +5164,16 @@ break; case ISD::VECTOR_SHUFFLE: llvm_unreachable("should use getVectorShuffle constructor!"); + case ISD::VECTOR_SHUFFLE_VAR: + if ((N1->getOpcode() == ISD::UNDEF) && (N2->getOpcode() == ISD::UNDEF)) + return getUNDEF(VT); + break; case ISD::INSERT_VECTOR_ELT: { ConstantSDNode *N3C = dyn_cast(N3); // INSERT_VECTOR_ELT into out-of-bounds element is an UNDEF if (N3C && N3C->getZExtValue() >= N1.getValueType().getVectorNumElements()) - return getUNDEF(VT); + if (!N1.getValueType().isScalableVector()) + return getUNDEF(VT); break; } case ISD::INSERT_SUBVECTOR: { @@ -6553,16 +6785,18 @@ return V; } -SDValue SelectionDAG::getMaskedGather(SDVTList VTs, EVT VT, const SDLoc &dl, - ArrayRef Ops, - MachineMemOperand *MMO) { +SDValue +SelectionDAG::getMaskedGather(SDVTList VTs, EVT VT, const SDLoc &dl, + ArrayRef Ops, MachineMemOperand *MMO, + ISD::LoadExtType ExtTy, + ISD::MemIndexType IndexType) { assert(Ops.size() == 6 && "Incompatible number of operands"); FoldingSetNodeID ID; AddNodeIDNode(ID, ISD::MGATHER, VTs, Ops); ID.AddInteger(VT.getRawBits()); ID.AddInteger(getSyntheticNodeSubclassData( - dl.getIROrder(), VTs, VT, MMO)); + dl.getIROrder(), VTs, ExtTy, VT, MMO, IndexType)); ID.AddInteger(MMO->getPointerInfo().getAddrSpace()); void *IP = nullptr; if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) { @@ -6571,7 +6805,7 @@ } auto *N = newSDNode(dl.getIROrder(), dl.getDebugLoc(), - VTs, VT, MMO); + VTs, ExtTy, VT, MMO, IndexType); createOperands(N, Ops); assert(N->getValue().getValueType() == N->getValueType(0) && @@ -6595,22 +6829,24 @@ SDValue SelectionDAG::getMaskedScatter(SDVTList VTs, EVT VT, const SDLoc &dl, ArrayRef Ops, - MachineMemOperand *MMO) { + MachineMemOperand *MMO, bool isTrunc, + ISD::MemIndexType IndexType) { assert(Ops.size() == 6 && "Incompatible number of operands"); FoldingSetNodeID ID; AddNodeIDNode(ID, ISD::MSCATTER, VTs, Ops); ID.AddInteger(VT.getRawBits()); ID.AddInteger(getSyntheticNodeSubclassData( - 
dl.getIROrder(), VTs, VT, MMO)); + dl.getIROrder(), VTs, isTrunc, VT, MMO, IndexType)); ID.AddInteger(MMO->getPointerInfo().getAddrSpace()); void *IP = nullptr; if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) { cast(E)->refineAlignment(MMO); return SDValue(E, 0); } + auto *N = newSDNode(dl.getIROrder(), dl.getDebugLoc(), - VTs, VT, MMO); + VTs, isTrunc, VT, MMO, IndexType); createOperands(N, Ops); assert(N->getMask().getValueType().getVectorNumElements() == @@ -6654,6 +6890,8 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT, ArrayRef Ops, const SDNodeFlags Flags) { + assert(!(VT.isScalableVector() && Opcode == ISD::BUILD_VECTOR) && + "BUILD_VECTOR cannot be used when you don't know the element count."); unsigned NumOps = Ops.size(); switch (NumOps) { case 0: return getNode(Opcode, DL, VT); @@ -7371,6 +7609,9 @@ /// TargetOpcode::INSERT_SUBREG nodes. SDValue SelectionDAG::getTargetInsertSubreg(int SRIdx, const SDLoc &DL, EVT VT, SDValue Operand, SDValue Subreg) { + if (Subreg.isUndef() && Operand.isUndef()) + return getUNDEF(VT); + SDValue SRIdxVal = getTargetConstant(SRIdx, DL, MVT::i32); SDNode *Result = getMachineNode(TargetOpcode::INSERT_SUBREG, DL, VT, Operand, Subreg, SRIdxVal); @@ -8319,6 +8560,8 @@ } SDValue SelectionDAG::UnrollVectorOp(SDNode *N, unsigned ResNE) { + assert(!N->getValueType(0).isScalableVector() && + "Cannot unroll a scalable vector."); assert(N->getNumValues() == 1 && "Can't unroll a vector with multiple results!"); Index: lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp =================================================================== --- lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp +++ lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp @@ -87,6 +87,7 @@ #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Metadata.h" #include "llvm/IR/Module.h" +#include "llvm/IR/PatternMatch.h" #include "llvm/IR/Operator.h" #include "llvm/IR/Statepoint.h" #include "llvm/IR/Type.h" @@ -121,6 +122,7 @@ #include using namespace llvm; +using namespace llvm::PatternMatch; #define DEBUG_TYPE "isel" @@ -398,11 +400,11 @@ // Build a vector with BUILD_VECTOR or CONCAT_VECTORS from the // intermediate operands. - EVT BuiltVectorTy = - EVT::getVectorVT(*DAG.getContext(), IntermediateVT.getScalarType(), - (IntermediateVT.isVector() - ? IntermediateVT.getVectorNumElements() * NumParts - : NumIntermediates)); + EVT BuiltVectorTy = IntermediateVT.isVector() + ? EVT::getVectorVT(*DAG.getContext(), IntermediateVT.getScalarType(), + IntermediateVT.getVectorElementCount() * NumParts) + : EVT::getVectorVT(*DAG.getContext(), IntermediateVT.getScalarType(), + NumIntermediates); Val = DAG.getNode(IntermediateVT.isVector() ? ISD::CONCAT_VECTORS : ISD::BUILD_VECTOR, DL, BuiltVectorTy, Ops); @@ -706,7 +708,7 @@ NumIntermediates * (IntermediateVT.isVector() ? IntermediateVT.getVectorNumElements() : 1); EVT BuiltVectorTy = EVT::getVectorVT( - *DAG.getContext(), IntermediateVT.getScalarType(), DestVectorNoElts); + *DAG.getContext(), IntermediateVT.getScalarType(), DestVectorNoElts, ValueVT.isScalableVector()); if (Val.getValueType() != BuiltVectorTy) Val = DAG.getNode(ISD::BITCAST, DL, BuiltVectorTy, Val); @@ -1260,9 +1262,26 @@ if (const ConstantFP *CFP = dyn_cast(C)) return DAG.getConstantFP(*CFP, getCurSDLoc(), VT); + if (isa(C)) { + SDLoc DL = getCurSDLoc(); + + // StepVector is implicitly defined with no wrap. 
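+ // Lane i of a stepvector is simply i, so for e.g. a 4-element i32 result the
+ // SERIES_VECTOR built below is series_vector(0, 1), i.e. <0, 1, 2, 3>.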
+ SDNodeFlags Flags; + Flags.setNoSignedWrap(true); + Flags.setNoUnsignedWrap(true); + + EVT EltVT = VT.getVectorElementType(); + SDValue One = DAG.getConstant(1, DL, EltVT); + SDValue Zero = DAG.getConstant(0, DL, EltVT); + return DAG.getNode(ISD::SERIES_VECTOR, DL, VT, Zero, One, Flags); + } + if (isa(C) && !V->getType()->isAggregateType()) return DAG.getUNDEF(VT); + if (isa(C)) + return DAG.getVScale(getCurSDLoc(), VT); + if (const ConstantExpr *CE = dyn_cast(C)) { visit(CE->getOpcode(), *CE); SDValue N1 = NodeMap[V]; @@ -1342,12 +1361,15 @@ EVT EltVT = TLI.getValueType(DAG.getDataLayout(), VecTy->getElementType()); - SDValue Op; + SDValue SplatVal; if (EltVT.isFloatingPoint()) - Op = DAG.getConstantFP(0, getCurSDLoc(), EltVT); + SplatVal = DAG.getConstantFP(0, getCurSDLoc(), EltVT); else - Op = DAG.getConstant(0, getCurSDLoc(), EltVT); - Ops.assign(NumElements, Op); + SplatVal = DAG.getConstant(0, getCurSDLoc(), EltVT); + if (VT.isScalableVector()) + return DAG.getNode(ISD::SPLAT_VECTOR, getCurSDLoc(), VT, SplatVal); + + Ops.assign(NumElements, SplatVal); } // Create a BUILD_VECTOR node. @@ -1559,15 +1581,19 @@ SDValue RetOp = getValue(I.getOperand(0)); SmallVector ValueVTs; - SmallVector Offsets; + SmallVector Offsets; ComputeValueVTs(TLI, DL, I.getOperand(0)->getType(), ValueVTs, &Offsets); unsigned NumValues = ValueVTs.size(); SmallVector Chains(NumValues); for (unsigned i = 0; i != NumValues; ++i) { - // An aggregate return value cannot wrap around the address space, so + // We shouldn't encounter scalable types here + assert(Offsets[i].ScaledBytes == 0 && + "Scalable type in unlowered return"); + // An aggregate return value cannot wrap around the address space, so // offsets to its parts don't wrap either. - SDValue Ptr = DAG.getObjectPtrOffset(getCurSDLoc(), RetPtr, Offsets[i]); + SDValue Ptr = DAG.getObjectPtrOffset(getCurSDLoc(), RetPtr, + Offsets[i].UnscaledBytes); Chains[i] = DAG.getStore( Chain, getCurSDLoc(), SDValue(RetOp.getNode(), RetOp.getResNo() + i), // FIXME: better loc info would be nice. @@ -2749,12 +2775,19 @@ return false; if (!isa(U->getOperand(1))) return false; - for (unsigned i = 0; i < ElemNumToReduce / 2; ++i) - if (ShufInst->getMaskValue(i) != int(i + ElemNumToReduce / 2)) + int EltIdx = 0; + for (unsigned i = 0; i < ElemNumToReduce / 2; ++i) { + if (!ShufInst->getMaskValue(i, EltIdx)) return false; - for (unsigned i = ElemNumToReduce / 2; i < ElemNum; ++i) - if (ShufInst->getMaskValue(i) != -1) + if (EltIdx!= int(i + ElemNumToReduce / 2)) return false; + } + for (unsigned i = ElemNumToReduce / 2; i < ElemNum; ++i) { + if (!ShufInst->getMaskValue(i, EltIdx)) + return false; + if (EltIdx != -1) + return false; + } // There is only one user of this ShuffleVector instruction, which // must be a reduction operation. 
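The rewritten mask checks above still look for the classic pairwise-reduction shuffle: the upper half of the live lanes is moved down and every remaining lane is undef. A small self-contained sketch of that mask shape (the lane counts below are illustrative, not taken from the patch):

#include <cassert>
#include <vector>

// Build the shuffle mask the matcher expects for one reduction step over
// ElemNum lanes, of which ElemNumToReduce are still live.
static std::vector<int> reductionStepMask(unsigned ElemNum,
                                          unsigned ElemNumToReduce) {
  std::vector<int> Mask(ElemNum, -1);          // -1 == undef lane
  for (unsigned i = 0; i < ElemNumToReduce / 2; ++i)
    Mask[i] = (int)(i + ElemNumToReduce / 2);  // move the upper half down
  return Mask;
}

int main() {
  // First step of reducing an 8-lane vector: {4, 5, 6, 7, -1, -1, -1, -1}.
  std::vector<int> Mask = reductionStepMask(8, 8);
  for (unsigned i = 0; i < 4; ++i)
    assert(Mask[i] == (int)(i + 4));
  for (unsigned i = 4; i < 8; ++i)
    assert(Mask[i] == -1);
  return 0;
}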
@@ -3178,17 +3211,138 @@ void SelectionDAGBuilder::visitShuffleVector(const User &I) { SDValue Src1 = getValue(I.getOperand(0)); SDValue Src2 = getValue(I.getOperand(1)); + Value *MaskV = I.getOperand(2); SDLoc DL = getCurSDLoc(); - SmallVector Mask; - ShuffleVectorInst::getShuffleMask(cast(I.getOperand(2)), Mask); - unsigned MaskNumElts = Mask.size(); - const TargetLowering &TLI = DAG.getTargetLoweringInfo(); EVT VT = TLI.getValueType(DAG.getDataLayout(), I.getType()); + bool IsScalable = VT.isScalableVector(); EVT SrcVT = Src1.getValueType(); unsigned SrcNumElts = SrcVT.getVectorNumElements(); + SmallVector Mask; + if (!ShuffleVectorInst::getShuffleMask(MaskV, Mask)) { + SDValue Mask = getValue(I.getOperand(2)); + unsigned NumElts = VT.getVectorNumElements(); + if (NumElts < SrcNumElts) { + // The result is narrower than the source operands. Create a shuffle + // that is as wide as the source operands, filling the trailing mask + // elements with 0 (to ensure that the indices are in-range, which + // wouldn't be guaranteed if the elements were left undefined). + SDValue ZeroIdx = DAG.getConstant(0, DL, + TLI.getVectorIdxTy(DAG.getDataLayout())); + EVT MaskVT = Mask.getValueType(); + MaskVT = EVT::getVectorVT(*DAG.getContext(), + MaskVT.getVectorElementType(), + SrcNumElts, IsScalable); + Mask = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, MaskVT, + DAG.getConstant(0, DL, MaskVT), + Mask, ZeroIdx); + SDValue Op = DAG.getNode(ISD::VECTOR_SHUFFLE_VAR, DL, SrcVT, + Src1, Src2, Mask); + // Get the low part of the result. + Op = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, Op, ZeroIdx); + setValue(&I, Op); + return; + } + if (NumElts > SrcNumElts) { + // The operands are narrower than the result. Join them together + // and pad with undefs to get a vector that has 2*NumElts elements, + // then extract the two halves. + EVT EltVT = VT.getVectorElementType(); + SrcVT = EVT::getVectorVT(*DAG.getContext(), EltVT, SrcNumElts * 2, + IsScalable); + EVT DoubleVT = EVT::getVectorVT(*DAG.getContext(), EltVT, + NumElts * 2, IsScalable); + EVT IdxVT = TLI.getVectorIdxTy(DAG.getDataLayout()); + SDValue ZeroIdx = DAG.getConstant(0, DL, IdxVT); + SDValue Joined = DAG.getNode(ISD::CONCAT_VECTORS, DL, SrcVT, Src1, Src2); + Joined = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, DoubleVT, + DAG.getUNDEF(DoubleVT), Joined, ZeroIdx); + Src1 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, Joined, ZeroIdx); + Src2 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, Joined, + DAG.getConstant(NumElts, DL, IdxVT)); + } + + if (auto *CMask = dyn_cast(MaskV)) + if (CMask->isNullValue()) { + // Splat of first element. + auto FirstElt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, + SrcVT.getScalarType(), Src1, + DAG.getConstant(0, DL, + TLI.getVectorIdxTy(DAG.getDataLayout()))); + + setValue(&I, DAG.getNode(ISD::SPLAT_VECTOR, DL, VT, + FirstElt)); + return; + } + + setValue(&I, DAG.getNode(ISD::VECTOR_SHUFFLE_VAR, DL, VT, + Src1, Src2, Mask)); + return; + } + + // Case where we extract the fixed lower k elements: + // + // -> < k x ty>, with k <= m. + unsigned MaskNumElts = Mask.size(); + if (SrcVT.isScalableVector() && MaskNumElts <= SrcNumElts) { + // The to extraction is performed in 2 steps + // (notice m >= k here): + // + // 1. First, the widest fixed type is extracted from the scalable + // type, as becomes . + // + // 2. Then, if k < m, a fixed shuffle from to is + // performed. 
+ // + // In case of FP data, we first bitcast the scalable FP vector to + // a Int one, do the scalable to fixed shuffle on integer data, + // and then we bitcast the final fixed Int vector to the epected + // FP fixed one. + assert(*std::max_element(Mask.begin(), Mask.end()) < ((int)SrcNumElts) && + "Index exceeded non-scalable part of vector"); + assert(!VT.isScalableVector() && + "Expected to extract a fixed width vector"); + + SDValue Input = Src1; + EVT TypeForExtraction = VT; + const bool NeedBitcast = SrcVT.isFloatingPoint(); + // If input is FP, bitcast to . + if (NeedBitcast) { + EVT ScalableIntSrcVT = SrcVT.changeTypeToInteger(); + // Convert the FP Scalable input into Scalable Int input. + Input = DAG.getNode(ISD::BITCAST, DL, ScalableIntSrcVT, Input); + TypeForExtraction = ScalableIntSrcVT; + } + + // Extract from . + SDValue ZeroIdx = + DAG.getConstant(0, DL, TLI.getVectorIdxTy(DAG.getDataLayout())); + EVT FixedVT = EVT::getVectorVT(*DAG.getContext(), + TypeForExtraction.getVectorElementType(), + SrcVT.getVectorNumElements(), false); + SDValue NewV = + DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, FixedVT, Input, ZeroIdx); + + // Bitcast back to FP lanes if needed. + if (NeedBitcast) { + // Build the type , because we need it when k < m. + EVT FloatFixedVT = + EVT::getVectorVT(*DAG.getContext(), SrcVT.getVectorElementType(), + SrcVT.getVectorNumElements(), false); + + NewV = DAG.getNode(ISD::BITCAST, DL, FloatFixedVT, NewV); + } + + // If k < m, do an additional extraction from to . + if (MaskNumElts < SrcNumElts) + NewV = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, NewV, ZeroIdx); + + setValue(&I, DAG.getVectorShuffle(VT, DL, NewV, DAG.getUNDEF(VT), Mask)); + return; + } + if (SrcNumElts == MaskNumElts) { setValue(&I, DAG.getVectorShuffle(VT, DL, Src1, Src2, Mask)); return; @@ -3467,14 +3621,19 @@ SDValue N = getValue(Op0); SDLoc dl = getCurSDLoc(); + bool IsScalable = false; + if (auto *VTy = dyn_cast(I.getType())) + IsScalable = VTy->isScalable(); + // Normalize Vector GEP - all scalar operands should be converted to the // splat vector. unsigned VectorWidth = I.getType()->isVectorTy() ? cast(I.getType())->getVectorNumElements() : 0; if (VectorWidth && !N.getValueType().isVector()) { - LLVMContext &Context = *DAG.getContext(); - EVT VT = EVT::getVectorVT(Context, N.getValueType(), VectorWidth); + EVT NVT = N.getValueType(); + + MVT VT = MVT::getVectorVT(NVT.getSimpleVT(), VectorWidth, IsScalable); N = DAG.getSplatBuildVector(VT, dl, N); } @@ -3487,19 +3646,43 @@ // N = N + Offset uint64_t Offset = DL->getStructLayout(StTy)->getElementOffset(Field); + auto *FTy = StTy->getElementType(Field); + auto *VTy = dyn_cast(FTy); + const bool Sizeless = VTy && VTy->isScalable(); + + // Cannot currently handle both scaled and unscaled offsets; we need + // to decide how alignment and packing affect sizeless types. + for (auto *Ty : StTy->elements()) { + if (auto *VTy = dyn_cast(Ty)) { + if (Sizeless) + assert(VTy->isScalable() && "Mixed sizeless struct"); + else + assert(!VTy->isScalable() && "Mixed sizeless struct"); + } + } + // In an inbounds GEP with an offset that is nonnegative even when // interpreted as signed, assume there is no unsigned overflow. + // TODO: Check vs. max size with scaled.. 
SDNodeFlags Flags; if (int64_t(Offset) >= 0 && cast(I).isInBounds()) Flags.setNoUnsignedWrap(true); - N = DAG.getNode(ISD::ADD, dl, N.getValueType(), N, - DAG.getConstant(Offset, dl, N.getValueType()), Flags); + SDValue Increment; + if (Sizeless) + Increment = DAG.getVScale(dl, N.getValueType(), Offset); + else + Increment = DAG.getConstant(Offset, dl, N.getValueType()); + + N = DAG.getNode(ISD::ADD, dl, N.getValueType(), N, Increment, Flags); } } else { - unsigned IdxSize = DAG.getDataLayout().getIndexSizeInBits(AS); - MVT IdxTy = MVT::getIntegerVT(IdxSize); - APInt ElementSize(IdxSize, DL->getTypeAllocSize(GTI.getIndexedType())); + const TargetLowering &TLI = DAG.getTargetLoweringInfo(); + MVT PtrTy = + DAG.getTargetLoweringInfo().getPointerTy(DAG.getDataLayout(), AS); + unsigned PtrSize = PtrTy.getSizeInBits(); + APInt ElementSize(PtrSize, DL->getTypeAllocSize(GTI.getIndexedType())); + // If this is a scalar constant or a splat vector of constants, // handle it quickly. @@ -3508,14 +3691,32 @@ cast(Idx)->getSplatValue()) CI = cast(cast(Idx)->getSplatValue()); - if (CI) { + if (CI) if (CI->isZero()) continue; - APInt Offs = ElementSize * CI->getValue().sextOrTrunc(IdxSize); - LLVMContext &Context = *DAG.getContext(); - SDValue OffsVal = VectorWidth ? - DAG.getConstant(Offs, dl, EVT::getVectorVT(Context, IdxTy, VectorWidth)) : - DAG.getConstant(Offs, dl, IdxTy); + + VectorType *VecTy = dyn_cast(GTI.getIndexedType()); + if (VecTy && VecTy->isScalable()) { + EVT PTy = TLI.getPointerTy(DAG.getDataLayout(), AS); + + // If index is smaller or larger than intptr_t, truncate or extend it. + SDValue IdxN = getValue(Idx); + IdxN = DAG.getSExtOrTrunc(IdxN, dl, N.getValueType()); + + // Calculate the byte offset. + SDValue VS = DAG.getVScale(dl, PTy, DL->getTypeAllocSize(VecTy)); + SDValue OffsVal = DAG.getNode(ISD::MUL, dl, PTy, VS, IdxN); + + N = DAG.getNode(ISD::ADD, dl, N.getValueType(), N, OffsVal); + continue; + } + + if (CI) { + APInt Offs = ElementSize * CI->getValue().sextOrTrunc(PtrSize); + EVT OffsTy = VectorWidth ? EVT::getVectorVT(*DAG.getContext(), PtrTy, + {VectorWidth, IsScalable}) + : PtrTy; + SDValue OffsVal = DAG.getConstant(Offs, dl, OffsTy); // In an inbouds GEP with an offset that is nonnegative even when // interpreted as signed, assume there is no unsigned overflow. 
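For a GEP whose indexed type is a scalable vector, the hunk above forms the byte offset as vscale * TypeAllocSize * Idx rather than folding it to a constant. A rough standalone model of that arithmetic; the element size, index, and runtime vscale values below are arbitrary:

#include <cassert>
#include <cstdint>

int main() {
  // Pretend the indexed type is <vscale x 4 x i32>: 16 bytes per unit of
  // vscale. Try a range of runtime vscale values with a fixed index.
  const uint64_t AllocSizePerVScale = 16;
  for (uint64_t VScale = 1; VScale <= 8; ++VScale) {
    const int64_t Idx = 3;
    // What the lowering emits: OffsVal = MUL(VSCALE(AllocSize), Idx).
    int64_t OffsVal = (int64_t)(AllocSizePerVScale * VScale) * Idx;
    // Equivalent lane-based view: Idx whole vectors of VScale * 4 lanes of
    // 4 bytes each.
    int64_t LaneBytes = Idx * (int64_t)(VScale * 4 * 4);
    assert(OffsVal == LaneBytes);
  }
  return 0;
}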
@@ -3531,7 +3732,7 @@ SDValue IdxN = getValue(Idx); if (!IdxN.getValueType().isVector() && VectorWidth) { - EVT VT = EVT::getVectorVT(*Context, IdxN.getValueType(), VectorWidth); + EVT VT = EVT::getVectorVT(*Context, IdxN.getValueType(), VectorWidth, IsScalable); IdxN = DAG.getSplatBuildVector(VT, dl, IdxN); } @@ -3651,7 +3852,7 @@ const MDNode *Ranges = I.getMetadata(LLVMContext::MD_range); SmallVector ValueVTs; - SmallVector Offsets; + SmallVector Offsets; ComputeValueVTs(TLI, DAG.getDataLayout(), Ty, ValueVTs, &Offsets); unsigned NumValues = ValueVTs.size(); if (NumValues == 0) @@ -3700,10 +3901,19 @@ Root = Chain; ChainI = 0; } + + uint64_t ScaledOffset = Offsets[i].ScaledBytes; + uint64_t UnscaledOffset = Offsets[i].UnscaledBytes; + SDValue A = DAG.getNode(ISD::ADD, dl, PtrVT, Ptr, - DAG.getConstant(Offsets[i], dl, PtrVT), - Flags); + DAG.getConstant(UnscaledOffset, dl, PtrVT), Flags); + + if (ScaledOffset) { + SDValue VS = DAG.getVScale(dl, PtrVT, ScaledOffset); + A = DAG.getNode(ISD::ADD, dl, PtrVT, A, VS, Flags); + } + auto MMOFlags = MachineMemOperand::MONone; if (isVolatile) MMOFlags |= MachineMemOperand::MOVolatile; @@ -3715,9 +3925,10 @@ MMOFlags |= MachineMemOperand::MODereferenceable; MMOFlags |= TLI.getMMOFlags(I); + // TODO: Do we need to modify MachinePointerInfo for this? SDValue L = DAG.getLoad(ValueVTs[i], dl, Root, A, - MachinePointerInfo(SV, Offsets[i]), Alignment, - MMOFlags, AAInfo, Ranges); + MachinePointerInfo(SV, Offsets[i].UnscaledBytes), + Alignment, MMOFlags, AAInfo, Ranges); Values[i] = L; Chains[ChainI] = L.getValue(1); @@ -3741,12 +3952,12 @@ "call visitStoreToSwiftError when backend supports swifterror"); SmallVector ValueVTs; - SmallVector Offsets; + SmallVector Offsets; const Value *SrcV = I.getOperand(0); ComputeValueVTs(DAG.getTargetLoweringInfo(), DAG.getDataLayout(), SrcV->getType(), ValueVTs, &Offsets); - assert(ValueVTs.size() == 1 && Offsets[0] == 0 && - "expect a single EVT for swifterror"); + assert(ValueVTs.size() == 1 && Offsets[0].UnscaledBytes == 0 && + Offsets[0].ScaledBytes == 0 && "expect a single EVT for swifterror"); SDValue Src = getValue(SrcV); // Create a virtual register, then update the virtual register. 
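The Offsets entries above now carry both a fixed and a per-vscale byte count, and an address is materialized as Ptr + UnscaledBytes + vscale * ScaledBytes. A minimal sketch of that split; the FieldOffsets struct below is a stand-in with the same member names, not the patch's actual definition, and the numbers are arbitrary:

#include <cassert>
#include <cstdint>

// Stand-in for the FieldOffsets pairs produced by ComputeValueVTs in this
// patch: a fixed byte count plus one that scales with the runtime vscale.
struct FieldOffsets {
  uint64_t UnscaledBytes;
  uint64_t ScaledBytes;
};

static uint64_t materialize(uint64_t Ptr, FieldOffsets Off, uint64_t VScale) {
  uint64_t A = Ptr + Off.UnscaledBytes;    // ADD(Ptr, UnscaledBytes)
  if (Off.ScaledBytes)
    A += VScale * Off.ScaledBytes;         // ADD(A, VSCALE(ScaledBytes))
  return A;
}

int main() {
  FieldOffsets Off = {8, 16};              // arbitrary fixed + per-vscale bytes
  for (uint64_t VScale = 1; VScale <= 4; ++VScale)
    assert(materialize(0x1000, Off, VScale) == 0x1000 + 8 + 16 * VScale);
  return 0;
}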
@@ -3779,11 +3990,11 @@ "load_from_swift_error should not be constant memory"); SmallVector ValueVTs; - SmallVector Offsets; + SmallVector Offsets; ComputeValueVTs(DAG.getTargetLoweringInfo(), DAG.getDataLayout(), Ty, ValueVTs, &Offsets); - assert(ValueVTs.size() == 1 && Offsets[0] == 0 && - "expect a single EVT for swifterror"); + assert(ValueVTs.size() == 1 && Offsets[0].UnscaledBytes == 0 && + Offsets[0].ScaledBytes == 0 && "expect a single EVT for swifterror"); // Chain, DL, Reg, VT, Glue or Chain, DL, Reg, VT SDValue L = DAG.getCopyFromReg( @@ -3817,7 +4028,7 @@ } SmallVector ValueVTs; - SmallVector Offsets; + SmallVector Offsets; ComputeValueVTs(DAG.getTargetLoweringInfo(), DAG.getDataLayout(), SrcV->getType(), ValueVTs, &Offsets); unsigned NumValues = ValueVTs.size(); @@ -3859,11 +4070,23 @@ Root = Chain; ChainI = 0; } + + uint64_t ScaledOffset = Offsets[i].ScaledBytes; + uint64_t UnscaledOffset = Offsets[i].UnscaledBytes; + SDValue Add = DAG.getNode(ISD::ADD, dl, PtrVT, Ptr, - DAG.getConstant(Offsets[i], dl, PtrVT), Flags); + DAG.getConstant(UnscaledOffset, dl, PtrVT), + Flags); + + if (ScaledOffset) { + SDValue VS = DAG.getVScale(dl, PtrVT, ScaledOffset); + Add = DAG.getNode(ISD::ADD, dl, PtrVT, Add, VS, Flags); + } + SDValue St = DAG.getStore( Root, dl, SDValue(Src.getNode(), Src.getResNo() + i), Add, - MachinePointerInfo(PtrV, Offsets[i]), Alignment, MMOFlags, AAInfo); + MachinePointerInfo(PtrV, Offsets[i].UnscaledBytes), Alignment, + MMOFlags, AAInfo); Chains[ChainI] = St; } @@ -3939,11 +4162,40 @@ // extract the splat value and use it as a uniform base. // In all other cases the function returns 'false'. static bool getUniformBase(const Value* &Ptr, SDValue& Base, SDValue& Index, + ISD::MemIndexType &IndexType, SDValue &Scale, SelectionDAGBuilder* SDB) { + assert (Ptr->getType()->isVectorTy() && "Unexpected pointer type"); SelectionDAG& DAG = SDB->DAG; + + // Look through bitcast instruction iff #elements is same + uint64_t IdxScale = 1; + if (auto *BitCast = dyn_cast(Ptr)) { + Type *BCTy = BitCast->getType(); + Type *BCSrcTy = BitCast->getOperand(0)->getType(); + + if (BCTy->getVectorNumElements() == BCSrcTy->getVectorNumElements()) { + Type *ResPtrTy = + BCTy->getVectorElementType()->getPointerElementType(); + Type *SrcPtrTy = BCSrcTy->getVectorElementType()->getPointerElementType(); + + // Only support this where we need to scale up the stride. + // We cannot safely scale down the stride, because if every + // gather loads from an overlapping address, this is valid + // LLVM IR, but would result in incorrect code. + uint64_t SrcSize = SDB->DL->getTypeStoreSize(SrcPtrTy); + uint64_t ResSize = SDB->DL->getTypeStoreSize(ResPtrTy); + if ((SrcSize >= ResSize) && (SrcSize % ResSize == 0)) { + IdxScale = SrcSize / ResSize; + Ptr = BitCast->getOperand(0); + } + } + } + + // %splat = shuffle(..insert(%ptr)) + // getelementptr %splat, %idx + // -> Base = %ptr, Index = %idx LLVMContext &Context = *DAG.getContext(); - assert(Ptr->getType()->isVectorTy() && "Uexpected pointer type"); const GetElementPtrInst *GEP = dyn_cast(Ptr); if (!GEP) return false; @@ -3975,6 +4227,37 @@ SDB->getCurSDLoc(), TLI.getPointerTy(DL)); Base = SDB->getValue(Ptr); Index = SDB->getValue(IndexVal); + IndexType = ISD::SIGNED_SCALED; + + // Suppress sign extension. + if (auto *Sext = dyn_cast(IndexVal)) { + if (SDB->findValue(Sext->getOperand(0))) { + IndexVal = Sext->getOperand(0); + Index = SDB->getValue(IndexVal); + } + } + // Restrict zero extension to the smallest type that still gets the job done. 
+ else if (auto *Zext = dyn_cast(IndexVal)) { + if (SDB->findValue(Zext->getOperand(0))) { + IndexVal = Zext->getOperand(0); + Index = SDB->getValue(IndexVal); + + EVT VT = Index.getValueType().widenIntegerVectorElementType(Context); + Index = DAG.getNode(ISD::ZERO_EXTEND, SDLoc(Index), VT, Index); + } + } + + // Scale the index if we're looking through a bitcast, + // e.g. to load from <2 x struct.bla*> to <2 x i64*> we + // need to have the offsets scaled in i64's. + if (IdxScale > 1) { + // e.g. <2 x i64*> to <2 x i32*> + SDValue ScaleVal = + DAG.getConstant(IdxScale, SDLoc(Index), Index.getValueType()); + Index = DAG.getNode(ISD::MUL, SDLoc(Index), + Index.getValueType(), Index, ScaleVal); + } + if (!Index.getValueType().isVector()) { unsigned GEPWidth = GEP->getType()->getVectorNumElements(); @@ -4002,9 +4285,10 @@ SDValue Base; SDValue Index; + ISD::MemIndexType IndexType; SDValue Scale; const Value *BasePtr = Ptr; - bool UniformBase = getUniformBase(BasePtr, Base, Index, Scale, this); + bool UniformBase = getUniformBase(BasePtr, Base, Index, IndexType, Scale, this); const Value *MemOpBasePtr = UniformBase ? BasePtr : nullptr; MachineMemOperand *MMO = DAG.getMachineFunction(). @@ -4014,11 +4298,16 @@ if (!UniformBase) { Base = DAG.getConstant(0, sdl, TLI.getPointerTy(DAG.getDataLayout())); Index = getValue(Ptr); + IndexType = ISD::SIGNED_SCALED; + Scale = DAG.getTargetConstant(1, sdl, TLI.getPointerTy(DAG.getDataLayout())); + } else if (VT.isScalableVector()) { + // AC 6.10; the current SVE code already takes scaling into account, + // so just set to 1. Scale = DAG.getTargetConstant(1, sdl, TLI.getPointerTy(DAG.getDataLayout())); } SDValue Ops[] = { getRoot(), Src0, Mask, Base, Index, Scale }; SDValue Scatter = DAG.getMaskedScatter(DAG.getVTList(MVT::Other), VT, sdl, - Ops, MMO); + Ops, MMO, false, IndexType); DAG.setRoot(Scatter); setValue(&I, Scatter); } @@ -4101,9 +4390,10 @@ SDValue Root = DAG.getRoot(); SDValue Base; SDValue Index; + ISD::MemIndexType IndexType; SDValue Scale; const Value *BasePtr = Ptr; - bool UniformBase = getUniformBase(BasePtr, Base, Index, Scale, this); + bool UniformBase = getUniformBase(BasePtr, Base, Index, IndexType, Scale, this); bool ConstantMemory = false; if (UniformBase && AA && AA->pointsToConstantMemory(MemoryLocation( @@ -4123,11 +4413,16 @@ if (!UniformBase) { Base = DAG.getConstant(0, sdl, TLI.getPointerTy(DAG.getDataLayout())); Index = getValue(Ptr); + IndexType = ISD::SIGNED_SCALED; + Scale = DAG.getTargetConstant(1, sdl, TLI.getPointerTy(DAG.getDataLayout())); + } else if (VT.isScalableVector()) { + // AC 6.10; the current SVE code already takes scaling into account, + // so just set to 1. Scale = DAG.getTargetConstant(1, sdl, TLI.getPointerTy(DAG.getDataLayout())); } SDValue Ops[] = { Root, Src0, Mask, Base, Index, Scale }; SDValue Gather = DAG.getMaskedGather(DAG.getVTList(VT, MVT::Other), VT, sdl, - Ops, MMO); + Ops, MMO, ISD::NON_EXTLOAD, IndexType); SDValue OutChain = Gather.getValue(1); if (!ConstantMemory) @@ -8319,23 +8614,26 @@ CLI.Ins.clear(); Type *OrigRetTy = CLI.RetTy; SmallVector RetTys; - SmallVector Offsets; + SmallVector Offsets; auto &DL = CLI.DAG.getDataLayout(); + + // TODO: Handle scaled offsets ComputeValueVTs(*this, DL, CLI.RetTy, RetTys, &Offsets); if (CLI.IsPostTypeLegalization) { // If we are lowering a libcall after legalization, split the return type. 
SmallVector OldRetTys = std::move(RetTys); - SmallVector OldOffsets = std::move(Offsets); + SmallVector OldOffsets = std::move(Offsets); for (size_t i = 0, e = OldRetTys.size(); i != e; ++i) { EVT RetVT = OldRetTys[i]; - uint64_t Offset = OldOffsets[i]; + FieldOffsets Offset = OldOffsets[i]; MVT RegisterVT = getRegisterType(CLI.RetTy->getContext(), RetVT); unsigned NumRegs = getNumRegisters(CLI.RetTy->getContext(), RetVT); unsigned RegisterVTByteSZ = RegisterVT.getSizeInBits() / 8; RetTys.append(NumRegs, RegisterVT); for (unsigned j = 0; j != NumRegs; ++j) - Offsets.push_back(Offset + j * RegisterVTByteSZ); + Offsets.push_back( + {Offset.UnscaledBytes + j * RegisterVTByteSZ, Offset.ScaledBytes}); } } @@ -8606,14 +8904,14 @@ Flags.setNoUnsignedWrap(true); for (unsigned i = 0; i < NumValues; ++i) { - SDValue Add = CLI.DAG.getNode(ISD::ADD, CLI.DL, PtrVT, DemoteStackSlot, - CLI.DAG.getConstant(Offsets[i], CLI.DL, - PtrVT), Flags); - SDValue L = CLI.DAG.getLoad( - RetTys[i], CLI.DL, CLI.Chain, Add, - MachinePointerInfo::getFixedStack(CLI.DAG.getMachineFunction(), - DemoteStackIdx, Offsets[i]), - /* Alignment = */ 1); + SDValue Add = CLI.DAG.getNode( + ISD::ADD, CLI.DL, PtrVT, DemoteStackSlot, + CLI.DAG.getConstant(Offsets[i].UnscaledBytes, CLI.DL, PtrVT), Flags); + SDValue L = CLI.DAG.getLoad(RetTys[i], CLI.DL, CLI.Chain, Add, + MachinePointerInfo::getFixedStack( + CLI.DAG.getMachineFunction(), + DemoteStackIdx, Offsets[i].UnscaledBytes), + /* Alignment = */ 1); ReturnValues[i] = L; Chains[i] = L.getValue(1); } Index: lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp =================================================================== --- lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp +++ lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp @@ -166,6 +166,7 @@ case ISD::CopyToReg: return "CopyToReg"; case ISD::CopyFromReg: return "CopyFromReg"; case ISD::UNDEF: return "undef"; + case ISD::VSCALE: return "vscale"; case ISD::MERGE_VALUES: return "merge_values"; case ISD::INLINEASM: return "inlineasm"; case ISD::EH_LABEL: return "eh_label"; @@ -262,7 +263,10 @@ case ISD::INSERT_SUBVECTOR: return "insert_subvector"; case ISD::EXTRACT_SUBVECTOR: return "extract_subvector"; case ISD::SCALAR_TO_VECTOR: return "scalar_to_vector"; + case ISD::SPLAT_VECTOR: return "splat_vector"; case ISD::VECTOR_SHUFFLE: return "vector_shuffle"; + case ISD::VECTOR_SHUFFLE_VAR: return "vector_shuffle_var"; + case ISD::SERIES_VECTOR: return "series_vector"; case ISD::CARRY_FALSE: return "carry_false"; case ISD::ADDC: return "addc"; case ISD::ADDE: return "adde"; @@ -619,7 +623,25 @@ OS << ", " << AM; OS << ">"; - } else if (const StoreSDNode *ST = dyn_cast(this)) { + } + else if (const MaskedLoadSDNode *LD = dyn_cast(this)) { + OS << "<"; + + printMemOperand(OS, *LD->getMemOperand(), G); + + bool doExt = true; + switch (LD->getExtensionType()) { + default: doExt = false; break; + case ISD::EXTLOAD: OS << ", anyext"; break; + case ISD::SEXTLOAD: OS << ", sext"; break; + case ISD::ZEXTLOAD: OS << ", zext"; break; + } + if (doExt) + OS << " from " << LD->getMemoryVT().getEVTString(); + + OS << ">"; + } + else if (const StoreSDNode *ST = dyn_cast(this)) { OS << "<"; printMemOperand(OS, *ST->getMemOperand(), G); @@ -631,6 +653,14 @@ OS << ", " << AM; OS << ">"; + } else if (const MaskedStoreSDNode *ST = dyn_cast(this)) { + OS << "<"; + printMemOperand(OS, *ST->getMemOperand(), G); + + if (ST->isTruncatingStore()) + OS << ", trunc to " << ST->getMemoryVT().getEVTString(); + + OS << ">"; } else if (const MemSDNode* M = dyn_cast(this)) 
{ OS << "<"; printMemOperand(OS, *M->getMemOperand(), G); Index: lib/CodeGen/SelectionDAG/TargetLowering.cpp =================================================================== --- lib/CodeGen/SelectionDAG/TargetLowering.cpp +++ lib/CodeGen/SelectionDAG/TargetLowering.cpp @@ -4270,6 +4270,8 @@ assert(DataVT.getVectorNumElements() == MaskVT.getVectorNumElements() && "Incompatible types of Data and Mask"); if (IsCompressedMemory) { + assert(!DataVT.isScalableVector() && + "Cannot currently handle compressed memory with scalable vectors"); // Incrementing the pointer according to number of '1's in the mask. EVT MaskIntVT = EVT::getIntegerVT(*DAG.getContext(), MaskVT.getSizeInBits()); SDValue MaskInIntReg = DAG.getBitcast(MaskIntVT, Mask); @@ -4285,6 +4287,8 @@ SDValue Scale = DAG.getConstant(DataVT.getScalarSizeInBits() / 8, DL, AddrVT); Increment = DAG.getNode(ISD::MUL, DL, AddrVT, Increment, Scale); + } else if (DataVT.isScalableVector()) { + Increment = DAG.getVScale(DL, AddrVT, DataVT.getSizeInBits() / 8); } else Increment = DAG.getConstant(DataVT.getStoreSize(), DL, AddrVT); @@ -4399,3 +4403,64 @@ } return SDValue(); } + +SDValue TargetLowering::expandVecReduce(SDNode *Node, SelectionDAG &DAG) const { + assert(!Node->getOperand(0).getValueType().isScalableVector() && + "Expanding reductions for scalable vectors is undefined."); + + SDLoc dl(Node); + bool NoNaN = Node->getFlags().hasNoNaNs(); + unsigned BaseOpcode = 0; + switch (Node->getOpcode()) { + default: llvm_unreachable("Expected VECREDUCE opcode"); + case ISD::VECREDUCE_FADD: BaseOpcode = ISD::FADD; break; + case ISD::VECREDUCE_FMUL: BaseOpcode = ISD::FMUL; break; + case ISD::VECREDUCE_ADD: BaseOpcode = ISD::ADD; break; + case ISD::VECREDUCE_MUL: BaseOpcode = ISD::MUL; break; + case ISD::VECREDUCE_AND: BaseOpcode = ISD::AND; break; + case ISD::VECREDUCE_OR: BaseOpcode = ISD::OR; break; + case ISD::VECREDUCE_XOR: BaseOpcode = ISD::XOR; break; + case ISD::VECREDUCE_SMAX: BaseOpcode = ISD::SMAX; break; + case ISD::VECREDUCE_SMIN: BaseOpcode = ISD::SMIN; break; + case ISD::VECREDUCE_UMAX: BaseOpcode = ISD::UMAX; break; + case ISD::VECREDUCE_UMIN: BaseOpcode = ISD::UMIN; break; + case ISD::VECREDUCE_FMAX: + BaseOpcode = NoNaN ? ISD::FMAXNUM : ISD::FMAXIMUM; + break; + case ISD::VECREDUCE_FMIN: + BaseOpcode = NoNaN ? ISD::FMINNUM : ISD::FMINIMUM; + break; + } + + SDValue Op = Node->getOperand(0); + EVT VT = Op.getValueType(); + + // Try to use a shuffle reduction for power of two vectors. + if (VT.isPow2VectorType()) { + while (VT.getVectorNumElements() > 1) { + EVT HalfVT = VT.getHalfNumVectorElementsVT(*DAG.getContext()); + if (!isOperationLegalOrCustom(BaseOpcode, HalfVT)) + break; + + SDValue Lo, Hi; + std::tie(Lo, Hi) = DAG.SplitVector(Op, dl); + Op = DAG.getNode(BaseOpcode, dl, HalfVT, Lo, Hi); + VT = HalfVT; + } + } + + EVT EltVT = VT.getVectorElementType(); + unsigned NumElts = VT.getVectorNumElements(); + + SmallVector Ops; + DAG.ExtractVectorElements(Op, Ops, 0, NumElts); + + SDValue Res = Ops[0]; + for (unsigned i = 1; i < NumElts; i++) + Res = DAG.getNode(BaseOpcode, dl, EltVT, Res, Ops[i], Node->getFlags()); + + // Result type may be wider than element type. + if (EltVT != Node->getValueType(0)) + Res = DAG.getNode(ISD::ANY_EXTEND, dl, Node->getValueType(0), Res); + return Res; +} Index: lib/CodeGen/StackColoring.cpp =================================================================== --- lib/CodeGen/StackColoring.cpp +++ lib/CodeGen/StackColoring.cpp @@ -1228,7 +1228,12 @@ // We use -1 to denote a uninteresting slot. 
Place these slots at the end. if (LHS == -1) return false; if (RHS == -1) return true; - // Sort according to size. + + // First sort according to Region + if (MFI->getStackID(LHS) != MFI->getStackID(RHS)) + return MFI->getStackID(LHS) > MFI->getStackID(RHS); + + // Then sort according to size. return MFI->getObjectSize(LHS) > MFI->getObjectSize(RHS); }); @@ -1248,6 +1253,11 @@ int FirstSlot = SortedSlots[I]; int SecondSlot = SortedSlots[J]; + + // StackRegions must match + if (MFI->getStackID(FirstSlot) != MFI->getStackID(SecondSlot)) + continue; + LiveInterval *First = &*Intervals[FirstSlot]; LiveInterval *Second = &*Intervals[SecondSlot]; auto &FirstS = LiveStarts[FirstSlot]; Index: lib/CodeGen/StackSlotColoring.cpp =================================================================== --- lib/CodeGen/StackSlotColoring.cpp +++ lib/CodeGen/StackSlotColoring.cpp @@ -76,7 +76,7 @@ // OrigAlignments - Alignments of stack objects before coloring. SmallVector OrigAlignments; - // OrigSizes - Sizess of stack objects before coloring. + // OrigSizes - Sizes of stack objects before coloring. SmallVector OrigSizes; // AllColors - If index is set, it's a spill slot, i.e. color. @@ -230,10 +230,12 @@ OrigAlignments[FI] = MFI->getObjectAlignment(FI); OrigSizes[FI] = MFI->getObjectSize(FI); - auto StackID = MFI->getStackID(FI); + unsigned StackID = MFI->getStackID(FI); if (StackID != 0) { - AllColors.resize(StackID + 1); - UsedColors.resize(StackID + 1); + if (StackID >= AllColors.size()) { + AllColors.resize(StackID+1); + UsedColors.resize(StackID+1); + } AllColors[StackID].resize(LastFI); UsedColors[StackID].resize(LastFI); } Index: lib/CodeGen/TargetInstrInfo.cpp =================================================================== --- lib/CodeGen/TargetInstrInfo.cpp +++ lib/CodeGen/TargetInstrInfo.cpp @@ -339,42 +339,32 @@ return MadeChange; } -bool TargetInstrInfo::hasLoadFromStackSlot(const MachineInstr &MI, - const MachineMemOperand *&MMO, - int &FrameIndex) const { +bool TargetInstrInfo::hasLoadFromStackSlot( + const MachineInstr &MI, + SmallVectorImpl &Accesses) const { + size_t StartSize = Accesses.size(); for (MachineInstr::mmo_iterator o = MI.memoperands_begin(), oe = MI.memoperands_end(); o != oe; ++o) { - if ((*o)->isLoad()) { - if (const FixedStackPseudoSourceValue *Value = - dyn_cast_or_null( - (*o)->getPseudoValue())) { - FrameIndex = Value->getFrameIndex(); - MMO = *o; - return true; - } - } + if ((*o)->isLoad() && + dyn_cast_or_null((*o)->getPseudoValue())) + Accesses.push_back(*o); } - return false; + return Accesses.size() != StartSize; } -bool TargetInstrInfo::hasStoreToStackSlot(const MachineInstr &MI, - const MachineMemOperand *&MMO, - int &FrameIndex) const { +bool TargetInstrInfo::hasStoreToStackSlot( + const MachineInstr &MI, + SmallVectorImpl &Accesses) const { + size_t StartSize = Accesses.size(); for (MachineInstr::mmo_iterator o = MI.memoperands_begin(), oe = MI.memoperands_end(); o != oe; ++o) { - if ((*o)->isStore()) { - if (const FixedStackPseudoSourceValue *Value = - dyn_cast_or_null( - (*o)->getPseudoValue())) { - FrameIndex = Value->getFrameIndex(); - MMO = *o; - return true; - } - } + if ((*o)->isStore() && + dyn_cast_or_null((*o)->getPseudoValue())) + Accesses.push_back(*o); } - return false; + return Accesses.size() != StartSize; } bool TargetInstrInfo::getStackSlotRange(const TargetRegisterClass *RC, Index: lib/CodeGen/TargetLoweringBase.cpp =================================================================== --- lib/CodeGen/TargetLoweringBase.cpp +++ 
lib/CodeGen/TargetLoweringBase.cpp @@ -643,10 +643,28 @@ setOperationAction(ISD::ANY_EXTEND_VECTOR_INREG, VT, Expand); setOperationAction(ISD::SIGN_EXTEND_VECTOR_INREG, VT, Expand); setOperationAction(ISD::ZERO_EXTEND_VECTOR_INREG, VT, Expand); + setOperationAction(ISD::VECTOR_SHUFFLE_VAR, VT, Expand); + setOperationAction(ISD::SPLAT_VECTOR, VT, Expand); + setOperationAction(ISD::SERIES_VECTOR, VT, Expand); } // For most targets @llvm.get.dynamic.area.offset just returns 0. setOperationAction(ISD::GET_DYNAMIC_AREA_OFFSET, VT, Expand); + + // Vector reduction default to expand. + setOperationAction(ISD::VECREDUCE_FADD, VT, Expand); + setOperationAction(ISD::VECREDUCE_FMUL, VT, Expand); + setOperationAction(ISD::VECREDUCE_ADD, VT, Expand); + setOperationAction(ISD::VECREDUCE_MUL, VT, Expand); + setOperationAction(ISD::VECREDUCE_AND, VT, Expand); + setOperationAction(ISD::VECREDUCE_OR, VT, Expand); + setOperationAction(ISD::VECREDUCE_XOR, VT, Expand); + setOperationAction(ISD::VECREDUCE_SMAX, VT, Expand); + setOperationAction(ISD::VECREDUCE_SMIN, VT, Expand); + setOperationAction(ISD::VECREDUCE_UMAX, VT, Expand); + setOperationAction(ISD::VECREDUCE_UMIN, VT, Expand); + setOperationAction(ISD::VECREDUCE_FMAX, VT, Expand); + setOperationAction(ISD::VECREDUCE_FMIN, VT, Expand); } // Most targets ignore the @llvm.prefetch intrinsic. @@ -734,9 +752,10 @@ "Promote may not follow Expand or Promote"); if (LA == TypeSplitVector) - return LegalizeKind(LA, - EVT::getVectorVT(Context, SVT.getVectorElementType(), - SVT.getVectorNumElements() / 2)); + return LegalizeKind(LA,EVT::getVectorVT(Context, + SVT.getVectorElementType(), + SVT.getVectorElementCount()/2)); + if (LA == TypeScalarizeVector) return LegalizeKind(LA, SVT.getVectorElementType()); return LegalizeKind(LA, NVT); @@ -765,9 +784,10 @@ // Handle vector types. unsigned NumElts = VT.getVectorNumElements(); EVT EltVT = VT.getVectorElementType(); + bool IsScalable = VT.isScalableVector(); // Vectors with only one element are always scalarized. - if (NumElts == 1) + if ((NumElts == 1) && !IsScalable) return LegalizeKind(TypeScalarizeVector, EltVT); // Try to widen vector elements until the element type is a power of two and @@ -778,7 +798,7 @@ // widened, for example <3 x i8> -> <4 x i8>. if (!VT.isPow2VectorType()) { NumElts = (unsigned)NextPowerOf2(NumElts); - EVT NVT = EVT::getVectorVT(Context, EltVT, NumElts); + EVT NVT = EVT::getVectorVT(Context, EltVT, NumElts, IsScalable); return LegalizeKind(TypeWidenVector, NVT); } @@ -789,7 +809,8 @@ // <4 x i140> -> <2 x i140> if (LK.first == TypeExpandInteger) return LegalizeKind(TypeSplitVector, - EVT::getVectorVT(Context, EltVT, NumElts / 2)); + EVT::getVectorVT(Context, EltVT, NumElts / 2, + IsScalable)); // Promote the integer element types until a legal vector type is found // or until the element integer type is too big. If a legal type was not @@ -810,11 +831,12 @@ break; // Build a new vector type and check if it is legal. - MVT NVT = MVT::getVectorVT(EltVT.getSimpleVT(), NumElts); + MVT NVT = MVT::getVectorVT(EltVT.getSimpleVT(), NumElts, IsScalable); // Found a legal promoted vector type. if (NVT != MVT() && ValueTypeActions.getTypeAction(NVT) == TypeLegal) return LegalizeKind(TypePromoteInteger, - EVT::getVectorVT(Context, EltVT, NumElts)); + EVT::getVectorVT(Context, EltVT, NumElts, + IsScalable)); } // Reset the type to the unexpanded type if we did not find a legal vector @@ -833,7 +855,8 @@ // there are no skipped intermediate vector types in the simple types. 
if (!EltVT.isSimple()) break; - MVT LargerVector = MVT::getVectorVT(EltVT.getSimpleVT(), NumElts); + MVT LargerVector = MVT::getVectorVT(EltVT.getSimpleVT(), NumElts, + IsScalable); if (LargerVector == MVT()) break; @@ -849,7 +872,8 @@ } // Vectors with illegal element types are expanded. - EVT NVT = EVT::getVectorVT(Context, EltVT, VT.getVectorNumElements() / 2); + EVT NVT = EVT::getVectorVT(Context, EltVT, + VT.getVectorNumElements() / 2, IsScalable); return LegalizeKind(TypeSplitVector, NVT); } @@ -860,6 +884,7 @@ // Figure out the right, legal destination reg to copy into. unsigned NumElts = VT.getVectorNumElements(); MVT EltTy = VT.getVectorElementType(); + bool IsScalable = VT.isScalableVector(); unsigned NumVectorRegs = 1; @@ -872,14 +897,15 @@ // Divide the input until we get to a supported size. This will always // end with a scalar if the target doesn't support vectors. - while (NumElts > 1 && !TLI->isTypeLegal(MVT::getVectorVT(EltTy, NumElts))) { + while (NumElts > 1 && + !TLI->isTypeLegal(MVT::getVectorVT(EltTy, NumElts, IsScalable))) { NumElts >>= 1; NumVectorRegs <<= 1; } NumIntermediates = NumVectorRegs; - MVT NewVT = MVT::getVectorVT(EltTy, NumElts); + MVT NewVT = MVT::getVectorVT(EltTy, NumElts, IsScalable); if (!TLI->isTypeLegal(NewVT)) NewVT = EltTy; IntermediateVT = NewVT; @@ -1162,6 +1188,7 @@ MVT EltVT = VT.getVectorElementType(); unsigned NElts = VT.getVectorNumElements(); + bool IsScalable = VT.isScalableVector(); bool IsLegalWiderType = false; LegalizeTypeAction PreferredAction = getPreferredVectorAction(VT); switch (PreferredAction) { @@ -1173,7 +1200,8 @@ // Promote vectors of integers to vectors with the same number // of elements, with a wider element type. if (SVT.getScalarSizeInBits() > EltVT.getSizeInBits() && - SVT.getVectorNumElements() == NElts && isTypeLegal(SVT)) { + SVT.getVectorNumElements() == NElts && + SVT.isScalableVector() == IsScalable && isTypeLegal(SVT)) { TransformToType[i] = SVT; RegisterTypeForVT[i] = SVT; NumRegistersForVT[i] = 1; @@ -1191,7 +1219,9 @@ for (unsigned nVT = i + 1; nVT <= MVT::LAST_VECTOR_VALUETYPE; ++nVT) { MVT SVT = (MVT::SimpleValueType) nVT; if (SVT.getVectorElementType() == EltVT - && SVT.getVectorNumElements() > NElts && isTypeLegal(SVT)) { + && SVT.getVectorNumElements() > NElts + && SVT.isScalableVector() == IsScalable + && isTypeLegal(SVT)) { TransformToType[i] = SVT; RegisterTypeForVT[i] = SVT; NumRegistersForVT[i] = 1; @@ -1223,8 +1253,9 @@ ValueTypeActions.setTypeAction(VT, TypeSplitVector); else // Set type action according to the number of elements. - ValueTypeActions.setTypeAction(VT, NElts == 1 ? TypeScalarizeVector - : TypeSplitVector); + ValueTypeActions.setTypeAction(VT, + (NElts == 1) && !IsScalable ? TypeScalarizeVector : + TypeSplitVector); } else { TransformToType[i] = NVT; ValueTypeActions.setTypeAction(VT, TypeWidenVector); @@ -1273,6 +1304,7 @@ unsigned &NumIntermediates, MVT &RegisterVT) const { unsigned NumElts = VT.getVectorNumElements(); + bool IsScalable = VT.isScalableVector(); // If there is a wider vector type with the same element type as this one, // or a promoted vector type that has the same number of elements which @@ -1304,15 +1336,15 @@ // Divide the input until we get to a supported size. This will always // end with a scalar if the target doesn't support vectors. 
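// For example, if only 4-element vectors of EltTy are legal, a 16-element
// value halves 16 -> 8 -> 4 below and is broken into four 4-element
// intermediate vectors.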
- while (NumElts > 1 && !isTypeLegal( - EVT::getVectorVT(Context, EltTy, NumElts))) { + while (NumElts > 1 && + !isTypeLegal(EVT::getVectorVT(Context, EltTy, NumElts, IsScalable))) { NumElts >>= 1; NumVectorRegs <<= 1; } NumIntermediates = NumVectorRegs; - EVT NewVT = EVT::getVectorVT(Context, EltTy, NumElts); + EVT NewVT = EVT::getVectorVT(Context, EltTy, NumElts, IsScalable); if (!isTypeLegal(NewVT)) NewVT = EltTy; IntermediateVT = NewVT; Index: lib/CodeGen/TargetPassConfig.cpp =================================================================== --- lib/CodeGen/TargetPassConfig.cpp +++ lib/CodeGen/TargetPassConfig.cpp @@ -1081,6 +1081,9 @@ addPass(&TwoAddressInstructionPassID, false); addPass(&RegisterCoalescerID); + // Allow targets to change the live ranges after coalescing + addPostCoalesce(); + // The machine scheduler may accidentally create disconnected components // when moving subregister definitions around, avoid this by splitting them to // separate vregs before. Splitting can also improve reg. allocation quality. Index: lib/CodeGen/ValueTypes.cpp =================================================================== --- lib/CodeGen/ValueTypes.cpp +++ lib/CodeGen/ValueTypes.cpp @@ -10,6 +10,7 @@ #include "llvm/CodeGen/ValueTypes.h" #include "llvm/ADT/StringExtras.h" #include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/LLVMContext.h" #include "llvm/IR/Type.h" #include "llvm/Support/ErrorHandling.h" using namespace llvm; @@ -22,7 +23,14 @@ EVT EVT::changeExtendedVectorElementTypeToInteger() const { LLVMContext &Context = LLVMTy->getContext(); EVT IntTy = getIntegerVT(Context, getScalarSizeInBits()); - return getVectorVT(Context, IntTy, getVectorNumElements()); + return getVectorVT(Context, IntTy, getVectorNumElements(), + isScalableVector()); +} + +EVT EVT::changeExtendedVectorElementType(EVT EltVT) const { + LLVMContext &Context = LLVMTy->getContext(); + return getVectorVT(Context, EltVT, getVectorNumElements(), + isScalableVector()); } EVT EVT::getExtendedIntegerVT(LLVMContext &Context, unsigned BitWidth) { @@ -33,9 +41,19 @@ } EVT EVT::getExtendedVectorVT(LLVMContext &Context, EVT VT, - unsigned NumElements) { + unsigned NumElements, bool IsScalable) { EVT ResultVT; - ResultVT.LLVMTy = VectorType::get(VT.getTypeForEVT(Context), NumElements); + ResultVT.LLVMTy = VectorType::get(VT.getTypeForEVT(Context), NumElements, + IsScalable); + assert(ResultVT.isExtended() && "Type is not extended!"); + return ResultVT; +} + +EVT EVT::getExtendedVectorVT(LLVMContext &Context, EVT VT, + MVT::ElementCount EC) { + EVT ResultVT; + ResultVT.LLVMTy = VectorType::get(VT.getTypeForEVT(Context), + { EC.Min, EC.Scalable }); assert(ResultVT.isExtended() && "Type is not extended!"); return ResultVT; } @@ -92,6 +110,10 @@ return isExtendedVector() && getExtendedSizeInBits() == 2048; } +bool EVT::isExtendedScalableVector() const { + return isExtendedVector() && cast(LLVMTy)->isScalable(); +} + EVT EVT::getExtendedVectorElementType() const { assert(isExtended() && "Type is not extended!"); return EVT::getEVT(cast(LLVMTy)->getElementType()); @@ -102,6 +124,12 @@ return cast(LLVMTy)->getNumElements(); } +MVT::ElementCount EVT::getExtendedVectorElementCount() const { + assert(isExtended() && "Type is not extended!"); + auto EC = cast(LLVMTy)->getElementCount(); + return { EC.Min, EC.Scalable }; +} + unsigned EVT::getExtendedSizeInBits() const { assert(isExtended() && "Type is not extended!"); if (IntegerType *ITy = dyn_cast(LLVMTy)) @@ -115,9 +143,14 @@ std::string EVT::getEVTString() const { switch 
(V.SimpleTy) { default: - if (isVector()) - return "v" + utostr(getVectorNumElements()) + - getVectorElementType().getEVTString(); + if (isVector()) { + std::string Prefix = "v"; + if (isScalableVector()) + Prefix = "nxv"; + + return Prefix + utostr(getVectorNumElements()) + + getVectorElementType().getEVTString(); + } if (isInteger()) return "i" + utostr(getSizeInBits()); llvm_unreachable("Invalid EVT!"); @@ -183,6 +216,8 @@ case MVT::v2f16: return "v2f16"; case MVT::v4f16: return "v4f16"; case MVT::v8f16: return "v8f16"; + case MVT::v16f16: return "v16f16"; + case MVT::v32f16: return "v32f16"; case MVT::v4f32: return "v4f32"; case MVT::v8f32: return "v8f32"; case MVT::v16f32: return "v16f32"; @@ -190,6 +225,49 @@ case MVT::v2f64: return "v2f64"; case MVT::v4f64: return "v4f64"; case MVT::v8f64: return "v8f64"; + case MVT::nxv2i1: return "nxv2i1"; + case MVT::nxv4i1: return "nxv4i1"; + case MVT::nxv8i1: return "nxv8i1"; + case MVT::nxv16i1: return "nxv16i1"; + case MVT::nxv32i1: return "nxv32i1"; + case MVT::nxv1i8: return "nxv1i8"; + case MVT::nxv2i8: return "nxv2i8"; + case MVT::nxv4i8: return "nxv4i8"; + case MVT::nxv8i8: return "nxv8i8"; + case MVT::nxv16i8: return "nxv16i8"; + case MVT::nxv32i8: return "nxv32i8"; + case MVT::nxv1i16: return "nxv1i16"; + case MVT::nxv2i16: return "nxv2i16"; + case MVT::nxv4i16: return "nxv4i16"; + case MVT::nxv8i16: return "nxv8i16"; + case MVT::nxv16i16:return "nxv16i16"; + case MVT::nxv32i16:return "nxv32i16"; + case MVT::nxv1i32: return "nxv1i32"; + case MVT::nxv2i32: return "nxv2i32"; + case MVT::nxv4i32: return "nxv4i32"; + case MVT::nxv8i32: return "nxv8i32"; + case MVT::nxv16i32:return "nxv16i32"; + case MVT::nxv32i32:return "nxv32i32"; + case MVT::nxv1i64: return "nxv1i64"; + case MVT::nxv2i64: return "nxv2i64"; + case MVT::nxv4i64: return "nxv4i64"; + case MVT::nxv8i64: return "nxv8i64"; + case MVT::nxv16i64:return "nxv16i64"; + case MVT::nxv32i64:return "nxv32i64"; + + case MVT::nxv2f16: return "nxv2f16"; + case MVT::nxv4f16: return "nxv4f16"; + case MVT::nxv8f16: return "nxv8f16"; + case MVT::nxv1f32: return "nxv1f32"; + case MVT::nxv2f32: return "nxv2f32"; + case MVT::nxv4f32: return "nxv4f32"; + case MVT::nxv8f32: return "nxv8f32"; + case MVT::nxv16f32: return "nxv16f32"; + case MVT::nxv1f64: return "nxv1f64"; + case MVT::nxv2f64: return "nxv2f64"; + case MVT::nxv4f64: return "nxv4f64"; + case MVT::nxv8f64: return "nxv8f64"; + case MVT::Metadata:return "Metadata"; case MVT::Untyped: return "Untyped"; case MVT::ExceptRef: return "ExceptRef"; @@ -262,15 +340,61 @@ case MVT::v2f16: return VectorType::get(Type::getHalfTy(Context), 2); case MVT::v4f16: return VectorType::get(Type::getHalfTy(Context), 4); case MVT::v8f16: return VectorType::get(Type::getHalfTy(Context), 8); + case MVT::v16f16: return VectorType::get(Type::getHalfTy(Context), 16); + case MVT::v32f16: return VectorType::get(Type::getHalfTy(Context), 32); case MVT::v1f32: return VectorType::get(Type::getFloatTy(Context), 1); case MVT::v2f32: return VectorType::get(Type::getFloatTy(Context), 2); case MVT::v4f32: return VectorType::get(Type::getFloatTy(Context), 4); case MVT::v8f32: return VectorType::get(Type::getFloatTy(Context), 8); - case MVT::v16f32: return VectorType::get(Type::getFloatTy(Context), 16); + case MVT::v16f32: return VectorType::get(Type::getFloatTy(Context), 16); case MVT::v1f64: return VectorType::get(Type::getDoubleTy(Context), 1); case MVT::v2f64: return VectorType::get(Type::getDoubleTy(Context), 2); case MVT::v4f64: return 
VectorType::get(Type::getDoubleTy(Context), 4); case MVT::v8f64: return VectorType::get(Type::getDoubleTy(Context), 8); + + case MVT::nxv2i1: return VectorType::get(Type::getInt1Ty(Context), 2, true); + case MVT::nxv4i1: return VectorType::get(Type::getInt1Ty(Context), 4, true); + case MVT::nxv8i1: return VectorType::get(Type::getInt1Ty(Context), 8, true); + case MVT::nxv16i1: return VectorType::get(Type::getInt1Ty(Context), 16, true); + case MVT::nxv32i1: return VectorType::get(Type::getInt1Ty(Context), 32, true); + case MVT::nxv1i8: return VectorType::get(Type::getInt8Ty(Context), 1, true); + case MVT::nxv2i8: return VectorType::get(Type::getInt8Ty(Context), 2, true); + case MVT::nxv4i8: return VectorType::get(Type::getInt8Ty(Context), 4, true); + case MVT::nxv8i8: return VectorType::get(Type::getInt8Ty(Context), 8, true); + case MVT::nxv16i8: return VectorType::get(Type::getInt8Ty(Context), 16, true); + case MVT::nxv32i8: return VectorType::get(Type::getInt8Ty(Context), 32, true); + case MVT::nxv1i16: return VectorType::get(Type::getInt16Ty(Context), 1, true); + case MVT::nxv2i16: return VectorType::get(Type::getInt16Ty(Context), 2, true); + case MVT::nxv4i16: return VectorType::get(Type::getInt16Ty(Context), 4, true); + case MVT::nxv8i16: return VectorType::get(Type::getInt16Ty(Context), 8, true); + case MVT::nxv16i16:return VectorType::get(Type::getInt16Ty(Context), 16, true); + case MVT::nxv32i16:return VectorType::get(Type::getInt16Ty(Context), 32, true); + case MVT::nxv1i32: return VectorType::get(Type::getInt32Ty(Context), 1, true); + case MVT::nxv2i32: return VectorType::get(Type::getInt32Ty(Context), 2, true); + case MVT::nxv4i32: return VectorType::get(Type::getInt32Ty(Context), 4, true); + case MVT::nxv8i32: return VectorType::get(Type::getInt32Ty(Context), 8, true); + case MVT::nxv16i32:return VectorType::get(Type::getInt32Ty(Context), 16, true); + case MVT::nxv32i32:return VectorType::get(Type::getInt32Ty(Context), 32, true); + case MVT::nxv1i64: return VectorType::get(Type::getInt64Ty(Context), 1, true); + case MVT::nxv2i64: return VectorType::get(Type::getInt64Ty(Context), 2, true); + case MVT::nxv4i64: return VectorType::get(Type::getInt64Ty(Context), 4, true); + case MVT::nxv8i64: return VectorType::get(Type::getInt64Ty(Context), 8, true); + case MVT::nxv16i64:return VectorType::get(Type::getInt64Ty(Context), 16, true); + case MVT::nxv32i64:return VectorType::get(Type::getInt64Ty(Context), 32, true); + + case MVT::nxv2f16: return VectorType::get(Type::getHalfTy(Context), 2, true); + case MVT::nxv4f16: return VectorType::get(Type::getHalfTy(Context), 4, true); + case MVT::nxv8f16: return VectorType::get(Type::getHalfTy(Context), 8, true); + case MVT::nxv1f32: return VectorType::get(Type::getFloatTy(Context), 1, true); + case MVT::nxv2f32: return VectorType::get(Type::getFloatTy(Context), 2, true); + case MVT::nxv4f32: return VectorType::get(Type::getFloatTy(Context), 4, true); + case MVT::nxv8f32: return VectorType::get(Type::getFloatTy(Context), 8, true); + case MVT::nxv16f32:return VectorType::get(Type::getFloatTy(Context), 16, true); + case MVT::nxv1f64: return VectorType::get(Type::getDoubleTy(Context), 1, true); + case MVT::nxv2f64: return VectorType::get(Type::getDoubleTy(Context), 2, true); + case MVT::nxv4f64: return VectorType::get(Type::getDoubleTy(Context), 4, true); + case MVT::nxv8f64: return VectorType::get(Type::getDoubleTy(Context), 8, true); + case MVT::Metadata: return Type::getMetadataTy(Context); } } @@ -297,8 +421,9 @@ case Type::PointerTyID: return 
MVT(MVT::iPTR); case Type::VectorTyID: { VectorType *VTy = cast(Ty); + auto EC = VTy->getElementCount(); return getVectorVT( - getVT(VTy->getElementType(), false), VTy->getNumElements()); + getVT(VTy->getElementType(), false), { EC.Min, EC.Scalable }); } } } @@ -314,8 +439,9 @@ return getIntegerVT(Ty->getContext(), cast(Ty)->getBitWidth()); case Type::VectorTyID: { VectorType *VTy = cast(Ty); + auto EC = VTy->getElementCount(); return getVectorVT(Ty->getContext(), getEVT(VTy->getElementType(), false), - VTy->getNumElements()); + { EC.Min, EC.Scalable }); } } } Index: lib/IR/AsmWriter.cpp =================================================================== --- lib/IR/AsmWriter.cpp +++ lib/IR/AsmWriter.cpp @@ -363,6 +363,7 @@ case CallingConv::ARM_APCS: Out << "arm_apcscc"; break; case CallingConv::ARM_AAPCS: Out << "arm_aapcscc"; break; case CallingConv::ARM_AAPCS_VFP: Out << "arm_aapcs_vfpcc"; break; + case CallingConv::AArch64_VectorCall: Out << "aarch64_vector_pcs"; break; case CallingConv::MSP430_INTR: Out << "msp430_intrcc"; break; case CallingConv::AVR_INTR: Out << "avr_intrcc "; break; case CallingConv::AVR_SIGNAL: Out << "avr_signalcc "; break; @@ -621,7 +622,12 @@ } case Type::VectorTyID: { VectorType *PTy = cast(Ty); - OS << "<" << PTy->getNumElements() << " x "; + OS << '<'; + + if (PTy->isScalable()) + OS << "n x "; + + OS << PTy->getNumElements() << " x "; print(PTy->getElementType(), OS); OS << '>'; return; @@ -1486,11 +1492,21 @@ return; } + if (isa(CV)) { + Out << "stepvector"; + return; + } + if (isa(CV)) { Out << "undef"; return; } + if (isa(CV)) { + Out << "vscale"; + return; + } + if (const ConstantExpr *CE = dyn_cast(CV)) { Out << CE->getOpcodeName(); WriteOptimizationInfo(Out, CE); @@ -1754,6 +1770,8 @@ MDFieldPrinter Printer(Out, TypePrinter, Machine, Context); if (auto *CE = N->getCount().dyn_cast()) Printer.printInt("count", CE->getSExtValue(), /* ShouldSkipZero */ false); + else if (auto *EE = N->getCount().dyn_cast()) + Printer.printMetadata("count", EE, /*ShouldSkipNull */ false); else Printer.printMetadata("count", N->getCount().dyn_cast(), /*ShouldSkipNull */ false); @@ -1761,6 +1779,24 @@ Out << ")"; } +static void writeDIFortranSubrange(raw_ostream &Out, const DIFortranSubrange *N, + TypePrinting *TypePrinter, + SlotTracker *Machine, + const Module *Context) { + Out << "!DIFortranSubrange("; + MDFieldPrinter Printer(Out, TypePrinter, Machine, Context); + Printer.printInt("constLowerBound", N->getCLowerBound(), false); + if (!N->noUpperBound()) + Printer.printInt("constUpperBound", N->getCUpperBound(), false); + Printer.printMetadata("lowerBound", N->getRawLowerBound()); + Printer.printMetadata("lowerBoundExpression", + N->getRawLowerBoundExpression()); + Printer.printMetadata("upperBound", N->getRawUpperBound()); + Printer.printMetadata("upperBoundExpression", + N->getRawUpperBoundExpression()); + Out << ")"; +} + static void writeDIEnumerator(raw_ostream &Out, const DIEnumerator *N, TypePrinting *, SlotTracker *, const Module *) { Out << "!DIEnumerator("; @@ -1790,6 +1826,23 @@ Out << ")"; } +static void writeDIStringType(raw_ostream &Out, const DIStringType *N, + TypePrinting *TypePrinter, SlotTracker *Machine, + const Module *Context) { + Out << "!DIStringType("; + MDFieldPrinter Printer(Out, TypePrinter, Machine, Context); + if (N->getTag() != dwarf::DW_TAG_string_type) + Printer.printTag(N); + Printer.printString("name", N->getName()); + Printer.printMetadata("stringLength", N->getRawStringLength()); + Printer.printMetadata("stringLengthExpression", 
N->getRawStringLengthExp()); + Printer.printInt("size", N->getSizeInBits()); + Printer.printInt("align", N->getAlignInBits()); + Printer.printDwarfEnum("encoding", N->getEncoding(), + dwarf::AttributeEncodingString); + Out << ")"; +} + static void writeDIDerivedType(raw_ostream &Out, const DIDerivedType *N, TypePrinting *TypePrinter, SlotTracker *Machine, const Module *Context) { @@ -1838,6 +1891,25 @@ Out << ")"; } +static void writeDIFortranArrayType( + raw_ostream &Out, const DIFortranArrayType *N, TypePrinting *TypePrinter, + SlotTracker *Machine, const Module *Context) { + Out << "!DIFortranArrayType("; + MDFieldPrinter Printer(Out, TypePrinter, Machine, Context); + Printer.printTag(N); + Printer.printString("name", N->getName()); + Printer.printMetadata("scope", N->getRawScope()); + Printer.printMetadata("file", N->getRawFile()); + Printer.printInt("line", N->getLine()); + Printer.printMetadata("baseType", N->getRawBaseType()); + Printer.printInt("size", N->getSizeInBits()); + Printer.printInt("align", N->getAlignInBits()); + Printer.printInt("offset", N->getOffsetInBits()); + Printer.printDIFlags("flags", N->getFlags()); + Printer.printMetadata("elements", N->getRawElements()); + Out << ")"; +} + static void writeDISubroutineType(raw_ostream &Out, const DISubroutineType *N, TypePrinting *TypePrinter, SlotTracker *Machine, const Module *Context) { @@ -1962,6 +2034,20 @@ Out << ")"; } +static void writeDICommonBlock(raw_ostream &Out, const DICommonBlock *N, + TypePrinting *TypePrinter, SlotTracker *Machine, + const Module *Context) { + Out << "!DICommonBlock("; + MDFieldPrinter Printer(Out, TypePrinter, Machine, Context); + Printer.printMetadata("scope", N->getRawScope(), false); + Printer.printMetadata("declaration", N->getRawDecl(), false); + Printer.printString("name", N->getName()); + Printer.printMetadata("file", N->getRawFile()); + Printer.printInt("line", N->getLineNo()); + Printer.printInt("align", N->getAlignInBits()); + Out << ")"; +} + static void writeDIMacro(raw_ostream &Out, const DIMacro *N, TypePrinting *TypePrinter, SlotTracker *Machine, const Module *Context) { @@ -1995,6 +2081,8 @@ Printer.printString("configMacros", N->getConfigurationMacros()); Printer.printString("includePath", N->getIncludePath()); Printer.printString("isysroot", N->getISysRoot()); + Printer.printMetadata("file", N->getRawFile()); + Printer.printInt("line", N->getLine()); Out << ")"; } @@ -2040,6 +2128,7 @@ Printer.printBool("isLocal", N->isLocalToUnit()); Printer.printBool("isDefinition", N->isDefinition()); Printer.printMetadata("declaration", N->getRawStaticDataMemberDeclaration()); + Printer.printDIFlags("flags", N->getFlags()); Printer.printInt("align", N->getAlignInBits()); Out << ")"; } Index: lib/IR/AutoUpgrade.cpp =================================================================== --- lib/IR/AutoUpgrade.cpp +++ lib/IR/AutoUpgrade.cpp @@ -430,6 +430,13 @@ return false; } +static void copyMetadata(Instruction *Src, Instruction *Dst) { + SmallVector, 4> MDs; + Src->getAllMetadata(MDs); + for (const auto &MD : MDs) + Dst->setMetadata(MD.first, MD.second); +} + static bool UpgradeIntrinsicFunction1(Function *F, Function *&NewFn) { assert(F && "Illegal to upgrade a non-existent Function."); @@ -1539,7 +1546,7 @@ unsigned NumElts = CI->getType()->getPrimitiveSizeInBits() / ExtTy->getPrimitiveSizeInBits(); Rep = Builder.CreateZExt(CI->getArgOperand(0), ExtTy); - Rep = Builder.CreateVectorSplat(NumElts, Rep); + Rep = Builder.CreateVectorSplat({NumElts, false}, Rep); } else if (IsX86 && (Name == 
"sse.sqrt.ss" || Name == "sse2.sqrt.sd")) { Value *Vec = CI->getArgOperand(0); @@ -1588,7 +1595,7 @@ } else if (IsX86 && (Name.startswith("avx512.mask.pbroadcast"))){ unsigned NumElts = CI->getArgOperand(1)->getType()->getVectorNumElements(); - Rep = Builder.CreateVectorSplat(NumElts, CI->getArgOperand(0)); + Rep = Builder.CreateVectorSplat({NumElts, false}, CI->getArgOperand(0)); Rep = EmitX86Select(Builder, CI->getArgOperand(2), Rep, CI->getArgOperand(1)); } else if (IsX86 && (Name.startswith("avx512.kunpck"))) { @@ -3387,6 +3394,7 @@ SmallVector Args(CI->arg_operands().begin(), CI->arg_operands().end()); NewCall = Builder.CreateCall(NewFn, Args); + copyMetadata(CI, NewCall); break; } @@ -3409,6 +3417,7 @@ Value *Args[4] = {CI->getArgOperand(0), CI->getArgOperand(1), CI->getArgOperand(2), CI->getArgOperand(4)}; NewCall = Builder.CreateCall(NewFn, Args); + copyMetadata(CI, NewCall); auto *MemCI = cast(NewCall); // All mem intrinsics support dest alignment. const ConstantInt *Align = cast(CI->getArgOperand(3)); Index: lib/IR/ConstantFold.cpp =================================================================== --- lib/IR/ConstantFold.cpp +++ lib/IR/ConstantFold.cpp @@ -46,6 +46,7 @@ if (CV->isAllOnesValue()) return Constant::getAllOnesValue(DstTy); if (CV->isNullValue()) return Constant::getNullValue(DstTy); + if (DstTy->isScalable()) return nullptr; // If this cast changes element count then we can't handle it here: // doing so requires endianness information. This should be handled by @@ -693,6 +694,9 @@ return ConstantInt::get(V->getContext(), CI->getValue().zext(BitWidth)); } + // NOTE: StepVector is implicitly defined with NSW/NUW. + if (isa(V)) + return StepVector::get(DestTy); return nullptr; case Instruction::SExt: if (ConstantInt *CI = dyn_cast(V)) { @@ -700,10 +704,50 @@ return ConstantInt::get(V->getContext(), CI->getValue().sext(BitWidth)); } + // NOTE: StepVector is implicitly defined with NSW/NUW. 
+ if (isa(V)) + return StepVector::get(DestTy); return nullptr; case Instruction::Trunc: { - if (V->getType()->isVectorTy()) + if (V->getType()->isVectorTy()) { + if (auto *CE = dyn_cast(V)) { + if (CE->getOpcode() == Instruction::ShuffleVector) { + auto TruncElTy = cast(DestTy)->getScalarType(); + auto OpTy = cast(CE->getOperand(0)->getType()); + auto NewOpTy = VectorType::get(TruncElTy, OpTy->getElementCount()); + + Constant *A = ConstantExpr::getTrunc(CE->getOperand(0), NewOpTy); + Constant *B = ConstantExpr::getTrunc(CE->getOperand(1), NewOpTy); + return ConstantExpr::getShuffleVector(A, B, CE->getOperand(2)); + } + if (CE->getOpcode() == Instruction::InsertElement) { + auto TruncElTy = cast(DestTy)->getScalarType(); + + Constant *A = ConstantExpr::getTrunc(CE->getOperand(0), DestTy); + Constant *B = ConstantExpr::getTrunc(CE->getOperand(1), TruncElTy); + return ConstantExpr::getInsertElement(A, B, CE->getOperand(2)); + } + if (CE->getOpcode() == Instruction::Add) { + Constant *A = ConstantExpr::getTrunc(CE->getOperand(0), DestTy); + Constant *B = ConstantExpr::getTrunc(CE->getOperand(1), DestTy); + return ConstantExpr::getAdd(A, B); + } + if (CE->getOpcode() == Instruction::Sub) { + Constant *A = ConstantExpr::getTrunc(CE->getOperand(0), DestTy); + Constant *B = ConstantExpr::getTrunc(CE->getOperand(1), DestTy); + return ConstantExpr::getSub(A, B); + } + if (CE->getOpcode() == Instruction::Mul) { + Constant *A = ConstantExpr::getTrunc(CE->getOperand(0), DestTy); + Constant *B = ConstantExpr::getTrunc(CE->getOperand(1), DestTy); + return ConstantExpr::getMul(A, B); + } + } else if (isa(V)) { + return StepVector::get(DestTy); + } + return nullptr; + } uint32_t DestBitWidth = cast(DestTy)->getBitWidth(); if (ConstantInt *CI = dyn_cast(V)) { @@ -719,6 +763,13 @@ if (Constant *Res = ExtractConstantBytes(V, 0, DestBitWidth / 8)) return Res; + ConstantInt *Cst = nullptr; + // Transform "i32 trunc (i64 mul (vscale, cst))" -> "i32 mul(vscale, cst)" + if (match(V, m_Mul(m_VScale(), m_ConstantInt(Cst)))) { + return ConstantExpr::getMul( + VScale::get(DestTy), ConstantInt::get(DestTy, Cst->getZExtValue())); + } + return nullptr; } case Instruction::BitCast: @@ -795,8 +846,9 @@ if (ConstantInt *CIdx = dyn_cast(Idx)) { // ee({w,x,y,z}, wrong_value) -> undef - if (CIdx->uge(Val->getType()->getVectorNumElements())) - return UndefValue::get(Val->getType()->getVectorElementType()); + if (!Val->getType()->getVectorIsScalable()) + if (CIdx->uge(Val->getType()->getVectorNumElements())) + return UndefValue::get(Val->getType()->getVectorElementType()); return Val->getAggregateElement(CIdx->getZExtValue()); } return nullptr; @@ -808,6 +860,13 @@ if (isa(Idx)) return UndefValue::get(Val->getType()); + if (Val->isNullValue() && Elt->isNullValue()) + return Val; + + // Everything after this point assumes you can iterate across Val. + if (Val->getType()->getVectorIsScalable()) + return nullptr; + ConstantInt *CIdx = dyn_cast(Idx); if (!CIdx) return nullptr; @@ -835,22 +894,41 @@ Constant *llvm::ConstantFoldShuffleVectorInstruction(Constant *V1, Constant *V2, Constant *Mask) { - unsigned MaskNumElts = Mask->getType()->getVectorNumElements(); + auto *MaskTy = cast(Mask->getType()); + auto MaskNumElts = MaskTy->getElementCount(); Type *EltTy = V1->getType()->getVectorElementType(); + Type *ResultTy = VectorType::get(EltTy, MaskNumElts); // Undefined shuffle mask -> undefined value. 
if (isa(Mask)) - return UndefValue::get(VectorType::get(EltTy, MaskNumElts)); + return UndefValue::get(ResultTy); // Don't break the bitcode reader hack. if (isa(Mask)) return nullptr; + if (cast(Mask->getType())->isScalable()) { + // Is splat? + if (Mask->isNullValue()) { + Constant *Zero = Constant::getNullValue(MaskTy->getElementType()); + Constant* SplatVal = ConstantFoldExtractElementInstruction(V1, Zero); + // Is splat of zero or undef? + if (SplatVal){ + if( SplatVal->isNullValue()) + return Constant::getNullValue(ResultTy); + if( isa(SplatVal)) + return UndefValue::get(ResultTy); + } + } + return nullptr; + } unsigned SrcNumElts = V1->getType()->getVectorNumElements(); // Loop over the shuffle mask, evaluating each element. SmallVector Result; - for (unsigned i = 0; i != MaskNumElts; ++i) { - int Elt = ShuffleVectorInst::getMaskValue(Mask, i); + for (unsigned i = 0; i != MaskNumElts.Min; ++i) { + int Elt; + if (!ShuffleVectorInst::getMaskValue(Mask, i, Elt)) + return nullptr; if (Elt == -1) { Result.push_back(UndefValue::get(EltTy)); continue; @@ -1218,6 +1296,76 @@ } } } else if (VectorType *VTy = dyn_cast(C1->getType())) { + // Constant fold binary op with constant splats, + // e.g.: + // ( splat( )), + // ( splat( )) + // ==> splat ( (C1, C2)) + if (isa(C2->getType())) { + if (Constant *Splat1 = C1->getSplatValue()) { + if (Constant *Splat2 = C2->getSplatValue()) { + auto EC = cast(C1->getType())->getElementCount(); + Constant *Rtrn = ConstantExpr::get(Opcode, Splat1, Splat2); + return ConstantVector::getSplat(EC, Rtrn); + } + } + } + + // A set of simple folds applicable to all vector types. + switch (Opcode) { + default: + break; + + case Instruction::Add: + // X + 0 == X + if (C1->isNullValue()) + return C2; + if (C2->isNullValue()) + return C1; + + // PtrToInt(Splat(X)) + Splat(Y) ==> PtrToInt(Splat(GetElementPtr(X, Y))) + if (isa(C1) && + cast(C1)->getOpcode() == Instruction::PtrToInt) { + Constant *X = cast(C1->getOperand(0))->getSplatValue(); + Constant *Y = C2->getSplatValue(); + + if (X && Y) { + auto AS = X->getType()->getPointerAddressSpace(); + auto BaseTy = Type::getInt8Ty(C1->getContext())->getPointerTo(AS); + + // Convert to byte* so Y doesn't get scaled. + Constant *Base = ConstantExpr::getPointerCast(X, BaseTy); + Constant *GEP = ConstantExpr::getGetElementPtr(nullptr, Base, Y); + + auto EC = cast(C1->getType())->getElementCount(); + Constant *Splat = ConstantVector::getSplat(EC, GEP); + + return ConstantExpr::getPtrToInt(Splat, C1->getType()); + } + } + break; + + case Instruction::Sub: + case Instruction::Or: + case Instruction::Xor: + // X op 0 == X + if (C2->isNullValue()) + return C1; + break; + + case Instruction::Mul: + // X * 1 == X + if (match(C1, m_SplatVector(m_One()))) + return C2; + if (match(C2, m_SplatVector(m_One()))) + return C1; + break; + } + + // Everything after this point assumes you can iterate across C1 & C2. + if (VTy->isScalable()) + return nullptr; + // Perform elementwise folding. 
SmallVector Result; Type *Ty = IntegerType::get(VTy->getContext(), 32); @@ -1248,6 +1396,30 @@ if (!isa(T) || cast(T)->getOpcode() != Opcode) return ConstantExpr::get(Opcode, CE1->getOperand(0), T); } + + unsigned CE1Opcode = CE1->getOpcode(); + ConstantExpr *CE2 = dyn_cast(C2); + if (CE2 && CE1Opcode == CE2->getOpcode() && + Instruction::isBinaryOp(CE1Opcode)) { + Constant* COp1 = CE1->getOperand(0); + Constant* COp2 = CE2->getOperand(0); + if (COp1 == COp2) { + // reduce (Cexpr + C1) - (Cexpr + C2) to C1 - C2 + if (Instruction::Sub == Opcode && CE1Opcode ==Instruction::Add) + return ConstantExpr::get(Opcode, CE1->getOperand(1), + CE2->getOperand(1)); + + ConstantInt *CI1 = dyn_cast(CE1->getOperand(1)); + ConstantInt *CI2 = dyn_cast(CE2->getOperand(1)); + // add/sub ( mul (vscale, c1), mul (vscale, c2)) + // into mul vscale, c1 add/sub c2 + if (CI1 && CI2 && (CE1Opcode == Instruction::Mul) && + (Instruction::Add == Opcode || Instruction::Sub == Opcode)) { + Constant *TI = ConstantExpr::get(Opcode, CI1, CI2); + return ConstantExpr::get(CE1->getOpcode(), COp1, TI); + } + } + } } else if (isa(C2)) { // If C2 is a constant expr and C1 isn't, flop them around and fold the // other way if possible. @@ -1451,9 +1623,9 @@ if (V1 == V2) return ICmpInst::ICMP_EQ; if (!isa(V1) && !isa(V1) && - !isa(V1)) { + !isa(V1) && !isa(V1)) { if (!isa(V2) && !isa(V2) && - !isa(V2)) { + !isa(V2) && !isa(V2)) { // We distilled this down to a simple case, use the standard constant // folder. ConstantInt *R = nullptr; @@ -1532,10 +1704,9 @@ "Canonicalization guarantee!"); return ICmpInst::ICMP_NE; } - } else { + } else if (const ConstantExpr *CE1 = dyn_cast(V1)) { // Ok, the LHS is known to be a constantexpr. The RHS can be any of a // constantexpr, a global, block address, or a simple constant. - ConstantExpr *CE1 = cast(V1); Constant *CE1Op0 = CE1->getOperand(0); switch (CE1->getOpcode()) { @@ -1567,7 +1738,7 @@ break; case Instruction::GetElementPtr: { - GEPOperator *CE1GEP = cast(CE1); + const GEPOperator *CE1GEP = cast(CE1); // Ok, since this is a getelementptr, we know that the constant has a // pointer type. Check the various cases. if (isa(V2)) { @@ -1699,7 +1870,7 @@ Type *ResultTy; if (VectorType *VT = dyn_cast(C1->getType())) ResultTy = VectorType::get(Type::getInt1Ty(C1->getContext()), - VT->getNumElements()); + VT->getElementCount()); else ResultTy = Type::getInt1Ty(C1->getContext()); @@ -1830,6 +2001,9 @@ R==APFloat::cmpEqual); } } else if (C1->getType()->isVectorTy()) { + if (cast(C1->getType())->isScalable()) + return nullptr; + // If we can constant fold the comparison of each element, constant fold // the whole vector comparison. SmallVector ResElts; @@ -2081,7 +2255,7 @@ if (Idxs.size() == 1 && (Idx0->isNullValue() || isa(Idx0))) return GEPTy->isVectorTy() && !C->getType()->isVectorTy() ? ConstantVector::getSplat( - cast(GEPTy)->getNumElements(), C) + cast(GEPTy)->getElementCount(), C) : C; if (C->isNullValue()) { @@ -2100,7 +2274,7 @@ Type *OrigGEPTy = PointerType::get(Ty, PtrTy->getAddressSpace()); Type *GEPTy = PointerType::get(Ty, PtrTy->getAddressSpace()); if (VectorType *VT = dyn_cast(C->getType())) - GEPTy = VectorType::get(OrigGEPTy, VT->getNumElements()); + GEPTy = VectorType::get(GEPTy, VT->getElementCount()); // The GEP returns a vector of pointers when one of more of // its arguments is a vector. 
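NOTE (illustrative sketch, not taken from the patch itself): the vscale-aware folds added
to ConstantFold.cpp above are intended to compose, so arithmetic over VScale constants
collapses back into a single multiply. Assuming an LLVMContext Ctx is in scope, the
expected behaviour is roughly:

    Type *I64 = Type::getInt64Ty(Ctx);
    Constant *VS = VScale::get(I64);
    Constant *FourVS = ConstantExpr::getMul(VS, ConstantInt::get(I64, 4));
    Constant *TwoVS  = ConstantExpr::getMul(VS, ConstantInt::get(I64, 2));
    // With the add/sub(mul(vscale, C1), mul(vscale, C2)) fold above, this is
    // expected to fold to mul(vscale, 6) rather than a nested add expression.
    Constant *Sum = ConstantExpr::getAdd(FourVS, TwoVS);
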
Index: lib/IR/Constants.cpp =================================================================== --- lib/IR/Constants.cpp +++ lib/IR/Constants.cpp @@ -106,19 +106,11 @@ if (const ConstantFP *CFP = dyn_cast(this)) return CFP->getValueAPF().bitcastToAPInt().isAllOnesValue(); - // Check for constant vectors which are splats of -1 values. - if (const ConstantVector *CV = dyn_cast(this)) - if (Constant *Splat = CV->getSplatValue()) + // Check for constant vectors which are splats of -1 + if (this->getType()->isVectorTy()) + if (Constant *Splat = this->getSplatValue()) return Splat->isAllOnesValue(); - // Check for constant vectors which are splats of -1 values. - if (const ConstantDataVector *CV = dyn_cast(this)) { - if (CV->isSplat()) { - if (CV->getElementType()->isFloatingPointTy()) - return CV->getElementAsAPFloat(0).bitcastToAPInt().isAllOnesValue(); - return CV->getElementAsAPInt(0).isAllOnesValue(); - } - } return false; } @@ -314,7 +306,7 @@ // Broadcast a scalar to a vector, if necessary. if (VectorType *VTy = dyn_cast(Ty)) - C = ConstantVector::getSplat(VTy->getNumElements(), C); + C = ConstantVector::getSplat(VTy->getElementCount(), C); return C; } @@ -331,7 +323,7 @@ } VectorType *VTy = cast(Ty); - return ConstantVector::getSplat(VTy->getNumElements(), + return ConstantVector::getSplat(VTy->getElementCount(), getAllOnesValue(VTy->getElementType())); } @@ -348,6 +340,39 @@ if (const ConstantDataSequential *CDS =dyn_cast(this)) return Elt < CDS->getNumElements() ? CDS->getElementAsConstant(Elt) : nullptr; + + if (isa(this)) + return ConstantInt::get(getType()->getVectorElementType(), Elt); + + if (const auto *CE = dyn_cast(this)) { + if (CE->getOpcode() == Instruction::ShuffleVector) { + if (CE->getOperand(2)->isNullValue()) { + // ee(splat(x), ?) -> x + auto *IdxTy = Type::getInt64Ty(CE->getType()->getContext()); + auto *Zero = ConstantInt::get(IdxTy, 0); + return ConstantExpr::getExtractElement(CE->getOperand(0), Zero); + } + } else if (CE->getOpcode() == Instruction::InsertElement) { + if (auto *CIdx = dyn_cast(CE->getOperand(2))) + if (CIdx->getZExtValue() == Elt) + // ee(ei(?,x,idx), idx) -> x + return CE->getOperand(1); + } else if (CE->getType()->isVectorTy()) { + // Can the extract be pushed down to our children? + switch (CE->getOpcode()) { + case Instruction::Add: + case Instruction::Mul: { + auto *IdxTy = Type::getInt64Ty(CE->getType()->getContext()); + auto *Idx = ConstantInt::get(IdxTy, Elt); + auto *LHS = ConstantExpr::getExtractElement(CE->getOperand(0), Idx); + auto *RHS = ConstantExpr::getExtractElement(CE->getOperand(1), Idx); + return ConstantExpr::get(CE->getOpcode(), LHS, RHS); + } + } + } + + } + return nullptr; } @@ -592,7 +617,7 @@ assert(Ty->isIntOrIntVectorTy(1) && "Type not i1 or vector of i1."); ConstantInt *TrueC = ConstantInt::getTrue(Ty->getContext()); if (auto *VTy = dyn_cast(Ty)) - return ConstantVector::getSplat(VTy->getNumElements(), TrueC); + return ConstantVector::getSplat(VTy->getElementCount(), TrueC); return TrueC; } @@ -600,7 +625,7 @@ assert(Ty->isIntOrIntVectorTy(1) && "Type not i1 or vector of i1."); ConstantInt *FalseC = ConstantInt::getFalse(Ty->getContext()); if (auto *VTy = dyn_cast(Ty)) - return ConstantVector::getSplat(VTy->getNumElements(), FalseC); + return ConstantVector::getSplat(VTy->getElementCount(), FalseC); return FalseC; } @@ -623,7 +648,7 @@ // For vectors, broadcast the value. 
if (VectorType *VTy = dyn_cast(Ty)) - return ConstantVector::getSplat(VTy->getNumElements(), C); + return ConstantVector::getSplat(VTy->getElementCount(), C); return C; } @@ -647,7 +672,7 @@ // For vectors, broadcast the value. if (VectorType *VTy = dyn_cast(Ty)) - return ConstantVector::getSplat(VTy->getNumElements(), C); + return ConstantVector::getSplat(VTy->getElementCount(), C); return C; } @@ -692,7 +717,7 @@ // For vectors, broadcast the value. if (VectorType *VTy = dyn_cast(Ty)) - return ConstantVector::getSplat(VTy->getNumElements(), C); + return ConstantVector::getSplat(VTy->getElementCount(), C); return C; } @@ -704,7 +729,7 @@ // For vectors, broadcast the value. if (auto *VTy = dyn_cast(Ty)) - return ConstantVector::getSplat(VTy->getNumElements(), C); + return ConstantVector::getSplat(VTy->getElementCount(), C); return C; } @@ -717,7 +742,7 @@ // For vectors, broadcast the value. if (VectorType *VTy = dyn_cast(Ty)) - return ConstantVector::getSplat(VTy->getNumElements(), C); + return ConstantVector::getSplat(VTy->getElementCount(), C); return C; } @@ -728,7 +753,7 @@ Constant *C = get(Ty->getContext(), NaN); if (VectorType *VTy = dyn_cast(Ty)) - return ConstantVector::getSplat(VTy->getNumElements(), C); + return ConstantVector::getSplat(VTy->getElementCount(), C); return C; } @@ -739,7 +764,7 @@ Constant *C = get(Ty->getContext(), NegZero); if (VectorType *VTy = dyn_cast(Ty)) - return ConstantVector::getSplat(VTy->getNumElements(), C); + return ConstantVector::getSplat(VTy->getElementCount(), C); return C; } @@ -787,7 +812,7 @@ Constant *C = get(Ty->getContext(), APFloat::getInf(Semantics, Negative)); if (VectorType *VTy = dyn_cast(Ty)) - return ConstantVector::getSplat(VTy->getNumElements(), C); + return ConstantVector::getSplat(VTy->getElementCount(), C); return C; } @@ -841,6 +866,38 @@ } //===----------------------------------------------------------------------===// +// StepVector Implementation +//===----------------------------------------------------------------------===// + +Constant *StepVector::get(Type *Ty) { + assert(Ty->isVectorTy() && Ty->getVectorElementType()->isIntegerTy() && + "StepVector must be an integer vector type!"); + + if (!Ty->getVectorIsScalable()) { + // Always return a constant vector when the vector length is known. + Type* EltTy = Ty->getVectorElementType(); + + SmallVector Indices; + for (int i = 0, e = Ty->getVectorNumElements(); i < e; ++i) + Indices.push_back(ConstantInt::get(EltTy, i)); + + return ConstantVector::get(Indices); + } + + std::unique_ptr &Entry = Ty->getContext().pImpl->SVVConstants[Ty]; + if (!Entry) + Entry.reset(new StepVector(Ty)); + + return Entry.get(); +} + +/// Remove the constant from the constant table. +void StepVector::destroyConstantImpl() { + // Free the constant and any dangling references to it. 
+ getContext().pImpl->SVVConstants.erase(getType()); +} + +//===----------------------------------------------------------------------===// // UndefValue Implementation //===----------------------------------------------------------------------===// @@ -872,6 +929,26 @@ } //===----------------------------------------------------------------------===// +// VScale Implementation +//===----------------------------------------------------------------------===// + +Constant *VScale::get(Type *Ty) { + assert(Ty->isIntegerTy() && "VScale must be an integer type!"); + + std::unique_ptr &Entry = Ty->getContext().pImpl->VSVConstants[Ty]; + if (!Entry) + Entry.reset(new VScale(Ty)); + + return Entry.get(); +} + +/// Remove the constant from the constant table. +void VScale::destroyConstantImpl() { + // Free the constant and any dangling references to it. + getContext().pImpl->VSVConstants.erase(getType()); +} + +//===----------------------------------------------------------------------===// // ConstantXXX Classes //===----------------------------------------------------------------------===// @@ -1093,15 +1170,35 @@ return nullptr; } -Constant *ConstantVector::getSplat(unsigned NumElts, Constant *V) { - // If this splat is compatible with ConstantDataVector, use it instead of - // ConstantVector. - if ((isa(V) || isa(V)) && - ConstantDataSequential::isElementTypeCompatible(V->getType())) - return ConstantDataVector::getSplat(NumElts, V); +Constant *ConstantVector::getSplat(VectorType::ElementCount EC, Constant *V) { + if (!EC.Scalable) { + // If this splat is compatible with ConstantDataVector, use it instead of + // ConstantVector. + if ((isa(V) || isa(V)) && + ConstantDataSequential::isElementTypeCompatible(V->getType())) + return ConstantDataVector::getSplat(EC.Min, V); - SmallVector Elts(NumElts, V); - return get(Elts); + SmallVector Elts(EC.Min, V); + return get(Elts); + } + + Type *VTy = VectorType::get(V->getType(), EC); + + if (V->isNullValue()) + return ConstantAggregateZero::get(VTy); + else if (isa(V)) + return UndefValue::get(VTy); + + Type *I32Ty = Type::getInt32Ty(VTy->getContext()); + + // Move scalar into vector. + Constant *UndefV = UndefValue::get(VTy); + V = ConstantExpr::getInsertElement(UndefV, V, ConstantInt::get(I32Ty, 0)); + // Build shuffle mask to perform the splat. + Type *MaskTy = VectorType::get(I32Ty, EC); + Constant *Zeros = ConstantAggregateZero::get(MaskTy); + // Splat. + return ConstantExpr::getShuffleVector(V, UndefV, Zeros); } ConstantTokenNone *ConstantTokenNone::get(LLVMContext &Context) { @@ -1350,6 +1447,17 @@ return CV->getSplatValue(); if (const ConstantVector *CV = dyn_cast(this)) return CV->getSplatValue(); + + // Is scalable vector splat? 
+ // shufflevector(insertelement(-,X,0), -, zeroinitializer) + if (const auto *CE = dyn_cast(this)) + if (CE->getOpcode() == Instruction::ShuffleVector && + CE->getOperand(2)->isNullValue()) + if (const auto *CEOp0 = dyn_cast(CE->getOperand(0))) + if (CEOp0->getOpcode() == Instruction::InsertElement && + CEOp0->getOperand(2)->isNullValue()) + return CEOp0->getOperand(1); + return nullptr; } @@ -1960,15 +2068,16 @@ unsigned AS = C->getType()->getPointerAddressSpace(); Type *ReqTy = DestTy->getPointerTo(AS); - unsigned NumVecElts = 0; - if (C->getType()->isVectorTy()) - NumVecElts = C->getType()->getVectorNumElements(); - else for (auto Idx : Idxs) - if (Idx->getType()->isVectorTy()) - NumVecElts = Idx->getType()->getVectorNumElements(); + VectorType::ElementCount EltCount = {0, false}; + if (VectorType *VecTy = dyn_cast(C->getType())) + EltCount = VecTy->getElementCount(); + else + for (auto Idx : Idxs) + if (VectorType *VecTy = dyn_cast(Idx->getType())) + EltCount = VecTy->getElementCount(); - if (NumVecElts) - ReqTy = VectorType::get(ReqTy, NumVecElts); + if (EltCount.Min != 0) + ReqTy = VectorType::get(ReqTy, EltCount); if (OnlyIfReducedTy == ReqTy) return nullptr; @@ -1979,12 +2088,11 @@ ArgVec.push_back(C); for (unsigned i = 0, e = Idxs.size(); i != e; ++i) { assert((!Idxs[i]->getType()->isVectorTy() || - Idxs[i]->getType()->getVectorNumElements() == NumVecElts) && - "getelementptr index type missmatch"); - + cast(Idxs[i]->getType())->getElementCount() == EltCount) + && "getelementptr index type missmatch"); Constant *Idx = cast(Idxs[i]); - if (NumVecElts && !Idxs[i]->getType()->isVectorTy()) - Idx = ConstantVector::getSplat(NumVecElts, Idx); + if (EltCount.Min != 0 && !Idxs[i]->getType()->isVectorTy()) + Idx = ConstantVector::getSplat(EltCount, Idx); ArgVec.push_back(Idx); } @@ -2017,7 +2125,7 @@ Type *ResultTy = Type::getInt1Ty(LHS->getContext()); if (VectorType *VT = dyn_cast(LHS->getType())) - ResultTy = VectorType::get(ResultTy, VT->getNumElements()); + ResultTy = VectorType::get(ResultTy, VT->getElementCount()); LLVMContextImpl *pImpl = LHS->getType()->getContext().pImpl; return pImpl->ExprConstants.getOrCreate(ResultTy, Key); @@ -2042,7 +2150,7 @@ Type *ResultTy = Type::getInt1Ty(LHS->getContext()); if (VectorType *VT = dyn_cast(LHS->getType())) - ResultTy = VectorType::get(ResultTy, VT->getNumElements()); + ResultTy = VectorType::get(ResultTy, VT->getElementCount()); LLVMContextImpl *pImpl = LHS->getType()->getContext().pImpl; return pImpl->ExprConstants.getOrCreate(ResultTy, Key); @@ -2101,7 +2209,7 @@ if (Constant *FC = ConstantFoldShuffleVectorInstruction(V1, V2, Mask)) return FC; // Fold a few common cases. 
- unsigned NElts = Mask->getType()->getVectorNumElements(); + auto NElts = cast(Mask->getType())->getElementCount(); Type *EltTy = V1->getType()->getVectorElementType(); Type *ShufTy = VectorType::get(EltTy, NElts); @@ -2116,6 +2224,24 @@ return pImpl->ExprConstants.getOrCreate(ShufTy, Key); } +Constant *ConstantExpr::getRuntimeNumElements(Type *Ty, Type *SrcTy) { + assert(Ty->isIntegerTy() && "ElementCount expects to return an int type!"); + auto NumElts = ConstantInt::get(Ty, SrcTy->getVectorNumElements()); + if (SrcTy->getVectorIsScalable()) + NumElts = getMul(VScale::get(Ty), NumElts); + return NumElts; +} + +Constant *ConstantExpr::getSeriesVector(VectorType::ElementCount EC, + Constant *Start, Constant* Step, + bool HasNUW, bool HasNSW, + Type *OnlyIfReducedTy) { + auto Ty = VectorType::get(Step->getType(), EC); + auto StartV = ConstantVector::getSplat(EC, Start); + auto StepV = ConstantVector::getSplat(EC, Step); + return getAdd(getMul(StepVector::get(Ty), StepV), StartV); +} + Constant *ConstantExpr::getInsertValue(Constant *Agg, Constant *Val, ArrayRef Idxs, Type *OnlyIfReducedTy) { @@ -2621,7 +2747,8 @@ return getFP(V->getContext(), Elts); } } - return ConstantVector::getSplat(NumElts, V); + + return ConstantVector::getSplat({NumElts, false}, V); } Index: lib/IR/ConstantsContext.h =================================================================== --- lib/IR/ConstantsContext.h +++ lib/IR/ConstantsContext.h @@ -150,7 +150,7 @@ ShuffleVectorConstantExpr(Constant *C1, Constant *C2, Constant *C3) : ConstantExpr(VectorType::get( cast(C1->getType())->getElementType(), - cast(C3->getType())->getNumElements()), + cast(C3->getType())->getElementCount()), Instruction::ShuffleVector, &Op<0>(), 3) { Op<0>() = C1; Index: lib/IR/DIBuilder.cpp =================================================================== --- lib/IR/DIBuilder.cpp +++ lib/IR/DIBuilder.cpp @@ -262,6 +262,12 @@ 0, Encoding); } +DIStringType *DIBuilder::createStringType(StringRef Name, uint64_t SizeInBits) { + assert(!Name.empty() && "Unable to create type without name"); + return DIStringType::get(VMContext, dwarf::DW_TAG_string_type, Name, + SizeInBits, 0); +} + DIDerivedType *DIBuilder::createQualifiedType(unsigned Tag, DIType *FromTy) { return DIDerivedType::get(VMContext, Tag, "", nullptr, 0, nullptr, FromTy, 0, 0, 0, None, DINode::FlagZero); @@ -525,6 +531,15 @@ return R; } +DIFortranArrayType *DIBuilder::createFortranArrayType( + uint64_t Size, uint32_t AlignInBits, DIType *Ty, DINodeArray Subscripts) { + auto *R = DIFortranArrayType::get(VMContext, dwarf::DW_TAG_array_type, "", + nullptr, 0, nullptr, Ty, Size, AlignInBits, + 0, DINode::FlagZero, Subscripts); + trackIfUnresolved(R); + return R; +} + DICompositeType *DIBuilder::createVectorType(uint64_t Size, uint32_t AlignInBits, DIType *Ty, DINodeArray Subscripts) { @@ -628,6 +643,12 @@ return DISubrange::get(VMContext, CountNode, Lo); } +DIFortranSubrange *DIBuilder::getOrCreateFortranSubrange( + int64_t CLB, int64_t CUB, bool NUB, Metadata *LB, Metadata *LBE, + Metadata *UB, Metadata *UBE) { + return DIFortranSubrange::get(VMContext, CLB, CUB, NUB, LB, LBE, UB, UBE); +} + static void checkGlobalVariableScope(DIScope *Context) { #ifndef NDEBUG if (auto *CT = @@ -640,13 +661,13 @@ DIGlobalVariableExpression *DIBuilder::createGlobalVariableExpression( DIScope *Context, StringRef Name, StringRef LinkageName, DIFile *F, unsigned LineNumber, DIType *Ty, bool isLocalToUnit, DIExpression *Expr, - MDNode *Decl, uint32_t AlignInBits) { + MDNode *Decl, DINode::DIFlags Flags, 
uint32_t AlignInBits) { checkGlobalVariableScope(Context); auto *GV = DIGlobalVariable::getDistinct( VMContext, cast_or_null(Context), Name, LinkageName, F, LineNumber, Ty, isLocalToUnit, true, cast_or_null(Decl), - AlignInBits); + Flags, AlignInBits); if (!Expr) Expr = createExpression(); auto *N = DIGlobalVariableExpression::get(VMContext, GV, Expr); @@ -657,13 +678,13 @@ DIGlobalVariable *DIBuilder::createTempGlobalVariableFwdDecl( DIScope *Context, StringRef Name, StringRef LinkageName, DIFile *F, unsigned LineNumber, DIType *Ty, bool isLocalToUnit, MDNode *Decl, - uint32_t AlignInBits) { + DINode::DIFlags Flags, uint32_t AlignInBits) { checkGlobalVariableScope(Context); return DIGlobalVariable::getTemporary( VMContext, cast_or_null(Context), Name, LinkageName, F, LineNumber, Ty, isLocalToUnit, false, - cast_or_null(Decl), AlignInBits) + cast_or_null(Decl), Flags, AlignInBits) .release(); } @@ -804,6 +825,13 @@ return SP; } +DICommonBlock *DIBuilder::createCommonBlock( + DIScope *Scope, DIGlobalVariable *Decl, StringRef Name, DIFile *File, + unsigned LineNo, uint32_t AlignInBits) { + return DICommonBlock::get( + VMContext, Scope, Decl, Name, File, LineNo, AlignInBits); +} + DINamespace *DIBuilder::createNameSpace(DIScope *Scope, StringRef Name, bool ExportSymbols) { @@ -819,9 +847,12 @@ DIModule *DIBuilder::createModule(DIScope *Scope, StringRef Name, StringRef ConfigurationMacros, StringRef IncludePath, - StringRef ISysRoot) { + StringRef ISysRoot, DIFile *File, + unsigned LineNo) { + return DIModule::get(VMContext, getNonCompileUnitScope(Scope), Name, - ConfigurationMacros, IncludePath, ISysRoot); + ConfigurationMacros, IncludePath, ISysRoot, File, + LineNo); } DILexicalBlockFile *DIBuilder::createLexicalBlockFile(DIScope *Scope, Index: lib/IR/DataLayout.cpp =================================================================== --- lib/IR/DataLayout.cpp +++ lib/IR/DataLayout.cpp @@ -750,7 +750,7 @@ unsigned NumBits = getIndexTypeSizeInBits(Ty); IntegerType *IntTy = IntegerType::get(Ty->getContext(), NumBits); if (VectorType *VecTy = dyn_cast(Ty)) - return VectorType::get(IntTy, VecTy->getNumElements()); + return VectorType::get(IntTy, VecTy->getElementCount()); return IntTy; } Index: lib/IR/DebugInfo.cpp =================================================================== --- lib/IR/DebugInfo.cpp +++ lib/IR/DebugInfo.cpp @@ -1230,12 +1230,13 @@ LLVMBool LocalToUnit, LLVMMetadataRef Expr, LLVMMetadataRef Decl, + LLVMDIFlags Flags, uint32_t AlignInBits) { return wrap(unwrap(Builder)->createGlobalVariableExpression( unwrapDI(Scope), {Name, NameLen}, {Linkage, LinkLen}, unwrapDI(File), LineNo, unwrapDI(Ty), LocalToUnit, unwrap(Expr), - unwrapDI(Decl), AlignInBits)); + unwrapDI(Decl), map_from_llvmDIFlags(Flags), AlignInBits)); } LLVMMetadataRef LLVMTemporaryMDNode(LLVMContextRef Ctx, LLVMMetadataRef *Data, @@ -1265,11 +1266,12 @@ LLVMMetadataRef Ty, LLVMBool LocalToUnit, LLVMMetadataRef Decl, + LLVMDIFlags Flags, uint32_t AlignInBits) { return wrap(unwrap(Builder)->createTempGlobalVariableFwdDecl( unwrapDI(Scope), {Name, NameLen}, {Linkage, LnkLen}, unwrapDI(File), LineNo, unwrapDI(Ty), - LocalToUnit, unwrapDI(Decl), AlignInBits)); + LocalToUnit, unwrapDI(Decl), map_from_llvmDIFlags(Flags), AlignInBits)); } LLVMValueRef LLVMDIBuilderInsertDeclareBefore( Index: lib/IR/DebugInfoMetadata.cpp =================================================================== --- lib/IR/DebugInfoMetadata.cpp +++ lib/IR/DebugInfoMetadata.cpp @@ -160,6 +160,9 @@ if (auto *NS = dyn_cast(this)) return 
NS->getScope(); + if (auto *CB = dyn_cast(this)) + return CB->getScope(); + if (auto *M = dyn_cast(this)) return M->getScope(); @@ -175,6 +178,8 @@ return SP->getName(); if (auto *NS = dyn_cast(this)) return NS->getName(); + if (auto *CB = dyn_cast(this)) + return CB->getName(); if (auto *M = dyn_cast(this)) return M->getName(); assert((isa(this) || isa(this) || @@ -262,6 +267,15 @@ DEFINE_GETIMPL_STORE(DISubrange, (CountNode, Lo), Ops); } +DIFortranSubrange *DIFortranSubrange::getImpl( + LLVMContext &Context, int64_t CLB, int64_t CUB, bool NUB, Metadata *LB, + Metadata *LBE, Metadata *UB, Metadata *UBE, StorageType Storage, + bool ShouldCreate) { + DEFINE_GETIMPL_LOOKUP(DIFortranSubrange, (CLB, CUB, NUB, LB, LBE, UB, UBE)); + Metadata *Ops[] = {LB, LBE, UB, UBE}; + DEFINE_GETIMPL_STORE(DIFortranSubrange, (CLB, CUB, NUB), Ops); +} + DIEnumerator *DIEnumerator::getImpl(LLVMContext &Context, int64_t Value, bool IsUnsigned, MDString *Name, StorageType Storage, bool ShouldCreate) { @@ -283,6 +297,21 @@ Ops); } +DIStringType *DIStringType::getImpl(LLVMContext &Context, unsigned Tag, + MDString *Name, Metadata *StringLength, + Metadata *StringLengthExp, + uint64_t SizeInBits, uint32_t AlignInBits, + unsigned Encoding, StorageType Storage, + bool ShouldCreate) { + assert(isCanonical(Name) && "Expected canonical MDString"); + DEFINE_GETIMPL_LOOKUP(DIStringType, + (Tag, Name, StringLength, StringLengthExp, SizeInBits, + AlignInBits, Encoding)); + Metadata *Ops[] = {nullptr, nullptr, Name, StringLength, StringLengthExp}; + DEFINE_GETIMPL_STORE(DIStringType, (Tag, SizeInBits, AlignInBits, Encoding), + Ops); +} + Optional DIBasicType::getSignedness() const { switch (getEncoding()) { case dwarf::DW_ATE_signed: @@ -335,6 +364,22 @@ Ops); } +DIFortranArrayType *DIFortranArrayType::getImpl( + LLVMContext &Context, unsigned Tag, MDString *Name, Metadata *File, + unsigned Line, Metadata *Scope, Metadata *BaseType, uint64_t SizeInBits, + uint32_t AlignInBits, uint64_t OffsetInBits, DIFlags Flags, + Metadata *Elements, StorageType Storage, bool ShouldCreate) { + assert(isCanonical(Name) && "Expected canonical MDString"); + + // Keep this in sync with buildODRType. + DEFINE_GETIMPL_LOOKUP( + DIFortranArrayType, (Tag, Name, File, Line, Scope, BaseType, SizeInBits, + AlignInBits, OffsetInBits, Flags, Elements)); + Metadata *Ops[] = {File, Scope, Name, BaseType, Elements}; + DEFINE_GETIMPL_STORE(DIFortranArrayType, (Tag, Line, SizeInBits, AlignInBits, + OffsetInBits, Flags), Ops); +} + DICompositeType *DICompositeType::buildODRType( LLVMContext &Context, MDString &Identifier, unsigned Tag, MDString *Name, Metadata *File, unsigned Line, Metadata *Scope, Metadata *BaseType, @@ -573,15 +618,31 @@ DEFINE_GETIMPL_STORE(DINamespace, (ExportSymbols), Ops); } +DICommonBlock *DICommonBlock::getImpl(LLVMContext &Context, Metadata *Scope, + Metadata *Decl, MDString *Name, + Metadata *File, unsigned LineNo, + uint32_t AlignInBits, + StorageType Storage, bool ShouldCreate) { + assert(isCanonical(Name) && "Expected canonical MDString"); + DEFINE_GETIMPL_LOOKUP(DICommonBlock, (Scope, Decl, Name, File, LineNo, + AlignInBits)); + // The nullptr is for DIScope's File operand. This should be refactored. 
+ Metadata *Ops[] = {Scope, Decl, Name, File}; + DEFINE_GETIMPL_STORE(DICommonBlock, (LineNo, AlignInBits), Ops); +} + DIModule *DIModule::getImpl(LLVMContext &Context, Metadata *Scope, MDString *Name, MDString *ConfigurationMacros, MDString *IncludePath, MDString *ISysRoot, + Metadata *File, unsigned Line, StorageType Storage, bool ShouldCreate) { assert(isCanonical(Name) && "Expected canonical MDString"); DEFINE_GETIMPL_LOOKUP( - DIModule, (Scope, Name, ConfigurationMacros, IncludePath, ISysRoot)); - Metadata *Ops[] = {Scope, Name, ConfigurationMacros, IncludePath, ISysRoot}; - DEFINE_GETIMPL_STORE_NO_CONSTRUCTOR_ARGS(DIModule, Ops); + DIModule, (Scope, Name, ConfigurationMacros, IncludePath, ISysRoot, File, + Line)); + Metadata *Ops[] = {Scope, Name, ConfigurationMacros, IncludePath, ISysRoot, + File}; + DEFINE_GETIMPL_STORE(DIModule, (Line), Ops); } DITemplateTypeParameter *DITemplateTypeParameter::getImpl(LLVMContext &Context, @@ -608,7 +669,7 @@ DIGlobalVariable::getImpl(LLVMContext &Context, Metadata *Scope, MDString *Name, MDString *LinkageName, Metadata *File, unsigned Line, Metadata *Type, bool IsLocalToUnit, bool IsDefinition, - Metadata *StaticDataMemberDeclaration, + Metadata *StaticDataMemberDeclaration, DIFlags Flags, uint32_t AlignInBits, StorageType Storage, bool ShouldCreate) { assert(isCanonical(Name) && "Expected canonical MDString"); @@ -616,11 +677,11 @@ DEFINE_GETIMPL_LOOKUP(DIGlobalVariable, (Scope, Name, LinkageName, File, Line, Type, IsLocalToUnit, IsDefinition, - StaticDataMemberDeclaration, AlignInBits)); + StaticDataMemberDeclaration, Flags, AlignInBits)); Metadata *Ops[] = { Scope, Name, File, Type, Name, LinkageName, StaticDataMemberDeclaration}; DEFINE_GETIMPL_STORE(DIGlobalVariable, - (Line, IsLocalToUnit, IsDefinition, AlignInBits), + (Line, IsLocalToUnit, IsDefinition, Flags, AlignInBits), Ops); } @@ -687,10 +748,12 @@ unsigned DIExpression::ExprOperand::getSize() const { switch (getOp()) { + case dwarf::DW_OP_bregx: case dwarf::DW_OP_LLVM_fragment: return 3; case dwarf::DW_OP_constu: case dwarf::DW_OP_plus_uconst: + case dwarf::DW_OP_deref_size: return 2; default: return 1; @@ -734,8 +797,9 @@ case dwarf::DW_OP_constu: case dwarf::DW_OP_plus_uconst: case dwarf::DW_OP_plus: - case dwarf::DW_OP_minus: case dwarf::DW_OP_mul: + case dwarf::DW_OP_bregx: + case dwarf::DW_OP_minus: case dwarf::DW_OP_div: case dwarf::DW_OP_mod: case dwarf::DW_OP_or: @@ -746,6 +810,7 @@ case dwarf::DW_OP_shra: case dwarf::DW_OP_deref: case dwarf::DW_OP_xderef: + case dwarf::DW_OP_deref_size: case dwarf::DW_OP_lit0: case dwarf::DW_OP_not: case dwarf::DW_OP_dup: Index: lib/IR/Function.cpp =================================================================== --- lib/IR/Function.cpp +++ lib/IR/Function.cpp @@ -587,7 +587,9 @@ Result += "vararg"; // Ensure nested function types are distinguishable. Result += "f"; - } else if (isa(Ty)) { + } else if (auto *VTyp = dyn_cast(Ty)) { + if (VTyp->isScalable()) + Result += "nx"; Result += "v" + utostr(Ty->getVectorNumElements()) + getMangledTypeStr(Ty->getVectorElementType()); } else if (Ty) { @@ -675,7 +677,13 @@ IIT_STRUCT6 = 38, IIT_STRUCT7 = 39, IIT_STRUCT8 = 40, - IIT_F128 = 41 + IIT_F128 = 41, + IIT_SCALABLE_VEC = 42, + IIT_VEC_ELEMENT = 43, + IIT_VEC_OF_BITCASTS_TO_INT = 44, + IIT_DOUBLE_VEC_ARG = 45, + IIT_SUBDIVIDE2_ARG = 46, + IIT_SUBDIVIDE4_ARG = 47 }; static void DecodeIITType(unsigned &NextElt, ArrayRef Infos, @@ -794,12 +802,30 @@ ArgInfo)); return; } + case IIT_SUBDIVIDE2_ARG: { + unsigned ArgInfo = (NextElt == Infos.size() ? 
0 : Infos[NextElt++]); + OutputTable.push_back(IITDescriptor::get(IITDescriptor::Subdivide2Argument, + ArgInfo)); + return; + } + case IIT_SUBDIVIDE4_ARG: { + unsigned ArgInfo = (NextElt == Infos.size() ? 0 : Infos[NextElt++]); + OutputTable.push_back(IITDescriptor::get(IITDescriptor::Subdivide4Argument, + ArgInfo)); + return; + } case IIT_HALF_VEC_ARG: { unsigned ArgInfo = (NextElt == Infos.size() ? 0 : Infos[NextElt++]); OutputTable.push_back(IITDescriptor::get(IITDescriptor::HalfVecArgument, ArgInfo)); return; } + case IIT_DOUBLE_VEC_ARG: { + unsigned ArgInfo = (NextElt == Infos.size() ? 0 : Infos[NextElt++]); + OutputTable.push_back(IITDescriptor::get(IITDescriptor::DoubleVecArgument, + ArgInfo)); + return; + } case IIT_SAME_VEC_WIDTH_ARG: { unsigned ArgInfo = (NextElt == Infos.size() ? 0 : Infos[NextElt++]); OutputTable.push_back(IITDescriptor::get(IITDescriptor::SameVecWidthArgument, @@ -824,6 +850,12 @@ IITDescriptor::get(IITDescriptor::VecOfAnyPtrsToElt, ArgNo, RefNo)); return; } + case IIT_VEC_OF_BITCASTS_TO_INT: { + unsigned ArgInfo = (NextElt == Infos.size() ? 0 : Infos[NextElt++]); + OutputTable.push_back(IITDescriptor::get(IITDescriptor::VecOfBitcastsToInt, + ArgInfo)); + return; + } case IIT_EMPTYSTRUCT: OutputTable.push_back(IITDescriptor::get(IITDescriptor::Struct, 0)); return; @@ -840,6 +872,18 @@ DecodeIITType(NextElt, Infos, OutputTable); return; } + case IIT_SCALABLE_VEC: { + OutputTable.push_back(IITDescriptor::get(IITDescriptor::ScalableVecArgument, + 0)); + DecodeIITType(NextElt, Infos, OutputTable); + return; + } + case IIT_VEC_ELEMENT: { + unsigned ArgInfo = (NextElt == Infos.size() ? 0 : Infos[NextElt++]); + OutputTable.push_back(IITDescriptor::get(IITDescriptor::VecElementArgument, + ArgInfo)); + return; + } } llvm_unreachable("unhandled"); } @@ -930,14 +974,47 @@ assert(ITy->getBitWidth() % 2 == 0); return IntegerType::get(Context, ITy->getBitWidth() / 2); } + case IITDescriptor::Subdivide2Argument: { + Type *Ty = Tys[D.getArgumentNumber()]; + if (VectorType *VTy = dyn_cast(Ty)) { + if (VTy->getElementType()->isFloatingPointTy()) { + return VectorType::getDoubleElementsVectorType( + VectorType::getNarrowerFpElementVectorType(VTy)); + } + return VectorType::getDoubleElementsVectorType( + VectorType::getTruncatedElementVectorType(VTy)); + } + + llvm_unreachable("unhandled"); + } + case IITDescriptor::Subdivide4Argument: { + Type *Ty = Tys[D.getArgumentNumber()]; + if (VectorType *VTy = dyn_cast(Ty)) { + if (VTy->getElementType()->isFloatingPointTy()) { + return VectorType::getDoubleElementsVectorType( + VectorType::getNarrowerFpElementVectorType( + VectorType::getDoubleElementsVectorType( + VectorType::getNarrowerFpElementVectorType(VTy)))); + } + return VectorType::getDoubleElementsVectorType( + VectorType::getTruncatedElementVectorType( + VectorType::getDoubleElementsVectorType( + VectorType::getTruncatedElementVectorType(VTy)))); + } + + llvm_unreachable("unhandled"); + } case IITDescriptor::HalfVecArgument: return VectorType::getHalfElementsVectorType(cast( Tys[D.getArgumentNumber()])); + case IITDescriptor::DoubleVecArgument: + return VectorType::getDoubleElementsVectorType(cast( + Tys[D.getArgumentNumber()])); case IITDescriptor::SameVecWidthArgument: { Type *EltTy = DecodeFixedType(Infos, Tys, Context); Type *Ty = Tys[D.getArgumentNumber()]; if (VectorType *VTy = dyn_cast(Ty)) { - return VectorType::get(EltTy, VTy->getNumElements()); + return VectorType::get(EltTy, VTy->getElementCount()); } llvm_unreachable("unhandled"); } @@ -953,6 +1030,23 @@ Type *EltTy 
= VTy->getVectorElementType(); return PointerType::getUnqual(EltTy); } + case IITDescriptor::ScalableVecArgument: { + Type *Ty = DecodeFixedType(Infos, Tys, Context); + return VectorType::get(Ty->getVectorElementType(), + { Ty->getVectorNumElements(), true }); + } + case IITDescriptor::VecElementArgument: { + Type *Ty = Tys[D.getArgumentNumber()]; + if (VectorType *VTy = dyn_cast(Ty)) + return VTy->getElementType(); + llvm_unreachable("Expected an argument of Vector Type"); + } + case IITDescriptor::VecOfBitcastsToInt: { + Type *Ty = Tys[D.getArgumentNumber()]; + if (VectorType *VTy = dyn_cast(Ty)) + return VectorType::getInteger(VTy); + llvm_unreachable("Expected an argument of Vector Type"); + } case IITDescriptor::VecOfAnyPtrsToElt: // Return the overloaded type (which determines the pointers address space) return Tys[D.getOverloadArgNumber()]; @@ -1022,12 +1116,26 @@ #include "llvm/IR/IntrinsicImpl.inc" #undef GET_LLVM_INTRINSIC_FOR_MS_BUILTIN -bool Intrinsic::matchIntrinsicType(Type *Ty, ArrayRef &Infos, - SmallVectorImpl &ArgTys) { +using DeferredIntrinsicMatchPair = + std::pair>; + +static bool matchIntrinsicType( + Type *Ty, ArrayRef &Infos, + SmallVectorImpl &ArgTys, + SmallVectorImpl &DeferredChecks, + bool IsDeferredCheck) { using namespace Intrinsic; // If we ran out of descriptors, there are too many arguments. if (Infos.empty()) return true; + + // Do this before slicing off the 'front' part + auto InfosRef = Infos; + auto DeferCheck = [&DeferredChecks, &InfosRef](Type *T) { + DeferredChecks.emplace_back(T, InfosRef); + return false; + }; + IITDescriptor D = Infos.front(); Infos = Infos.slice(1); @@ -1045,12 +1153,14 @@ case IITDescriptor::Vector: { VectorType *VT = dyn_cast(Ty); return !VT || VT->getNumElements() != D.Vector_Width || - matchIntrinsicType(VT->getElementType(), Infos, ArgTys); + matchIntrinsicType(VT->getElementType(), Infos, ArgTys, + DeferredChecks, IsDeferredCheck); } case IITDescriptor::Pointer: { PointerType *PT = dyn_cast(Ty); return !PT || PT->getAddressSpace() != D.Pointer_AddressSpace || - matchIntrinsicType(PT->getElementType(), Infos, ArgTys); + matchIntrinsicType(PT->getElementType(), Infos, ArgTys, + DeferredChecks, IsDeferredCheck); } case IITDescriptor::Struct: { @@ -1059,35 +1169,40 @@ return true; for (unsigned i = 0, e = D.Struct_NumElements; i != e; ++i) - if (matchIntrinsicType(ST->getElementType(i), Infos, ArgTys)) + if (matchIntrinsicType(ST->getElementType(i), Infos, ArgTys, + DeferredChecks, IsDeferredCheck)) return true; return false; } case IITDescriptor::Argument: - // Two cases here - If this is the second occurrence of an argument, verify - // that the later instance matches the previous instance. + // If this is the second occurrence of an argument, + // verify that the later instance matches the previous instance. if (D.getArgumentNumber() < ArgTys.size()) return Ty != ArgTys[D.getArgumentNumber()]; - // Otherwise, if this is the first instance of an argument, record it and - // verify the "Any" kind. 
- assert(D.getArgumentNumber() == ArgTys.size() && "Table consistency error"); - ArgTys.push_back(Ty); + if (D.getArgumentNumber() > ArgTys.size() || + D.getArgumentKind() == IITDescriptor::AK_MatchType) + return IsDeferredCheck || DeferCheck(Ty); - switch (D.getArgumentKind()) { - case IITDescriptor::AK_Any: return false; // Success - case IITDescriptor::AK_AnyInteger: return !Ty->isIntOrIntVectorTy(); - case IITDescriptor::AK_AnyFloat: return !Ty->isFPOrFPVectorTy(); - case IITDescriptor::AK_AnyVector: return !isa(Ty); - case IITDescriptor::AK_AnyPointer: return !isa(Ty); - } - llvm_unreachable("all argument kinds not covered"); + assert(D.getArgumentNumber() == ArgTys.size() && !IsDeferredCheck && + "Table consistency error"); + ArgTys.push_back(Ty); + + switch (D.getArgumentKind()) { + case IITDescriptor::AK_Any: return false; // Success + case IITDescriptor::AK_AnyInteger: return !Ty->isIntOrIntVectorTy(); + case IITDescriptor::AK_AnyFloat: return !Ty->isFPOrFPVectorTy(); + case IITDescriptor::AK_AnyVector: return !isa(Ty); + case IITDescriptor::AK_AnyPointer: return !isa(Ty); + default: break; + } + llvm_unreachable("all argument kinds not covered"); case IITDescriptor::ExtendArgument: { - // This may only be used when referring to a previous vector argument. + // If this is a forward reference, defer the check for later. if (D.getArgumentNumber() >= ArgTys.size()) - return true; + return IsDeferredCheck || DeferCheck(Ty); Type *NewTy = ArgTys[D.getArgumentNumber()]; if (VectorType *VTy = dyn_cast(NewTy)) @@ -1100,9 +1215,9 @@ return Ty != NewTy; } case IITDescriptor::TruncArgument: { - // This may only be used when referring to a previous vector argument. + // If this is a forward reference, defer the check for later. if (D.getArgumentNumber() >= ArgTys.size()) - return true; + return IsDeferredCheck || DeferCheck(Ty); Type *NewTy = ArgTys[D.getArgumentNumber()]; if (VectorType *VTy = dyn_cast(NewTy)) @@ -1115,34 +1230,48 @@ return Ty != NewTy; } case IITDescriptor::HalfVecArgument: - // This may only be used when referring to a previous vector argument. + // If this is a forward reference, defer the check for later. return D.getArgumentNumber() >= ArgTys.size() || !isa(ArgTys[D.getArgumentNumber()]) || VectorType::getHalfElementsVectorType( cast(ArgTys[D.getArgumentNumber()])) != Ty; + case IITDescriptor::DoubleVecArgument: + // This may only be used when referring to a previous vector argument. + return D.getArgumentNumber() >= ArgTys.size() || + !isa(ArgTys[D.getArgumentNumber()]) || + VectorType::getDoubleElementsVectorType( + cast(ArgTys[D.getArgumentNumber()])) != Ty; case IITDescriptor::SameVecWidthArgument: { - if (D.getArgumentNumber() >= ArgTys.size()) + if (D.getArgumentNumber() >= ArgTys.size()) { + // Defer check and subsequent check for the vector element type. + Infos = Infos.slice(1); + return IsDeferredCheck || DeferCheck(Ty); + } + auto *ReferenceType = dyn_cast(ArgTys[D.getArgumentNumber()]); + auto *ThisArgType = dyn_cast(Ty); + // Both must be vectors of the same number of elements or neither. 
+      if ((ReferenceType != nullptr) != (ThisArgType != nullptr))
         return true;
-      VectorType * ReferenceType =
-        dyn_cast<VectorType>(ArgTys[D.getArgumentNumber()]);
-      VectorType *ThisArgType = dyn_cast<VectorType>(Ty);
-      if (!ThisArgType || !ReferenceType ||
-          (ReferenceType->getVectorNumElements() !=
-           ThisArgType->getVectorNumElements()))
-        return true;
-      return matchIntrinsicType(ThisArgType->getVectorElementType(),
-                                Infos, ArgTys);
+      Type *EltTy = Ty;
+      if (ThisArgType) {
+        if (ReferenceType->getVectorNumElements() !=
+            ThisArgType->getVectorNumElements())
+          return true;
+        EltTy = ThisArgType->getVectorElementType();
+      }
+      return matchIntrinsicType(EltTy, Infos, ArgTys, DeferredChecks,
+                                IsDeferredCheck);
     }
     case IITDescriptor::PtrToArgument: {
       if (D.getArgumentNumber() >= ArgTys.size())
-        return true;
+        return IsDeferredCheck || DeferCheck(Ty);
       Type * ReferenceType = ArgTys[D.getArgumentNumber()];
       PointerType *ThisArgType = dyn_cast<PointerType>(Ty);
       return (!ThisArgType || ThisArgType->getElementType() != ReferenceType);
     }
     case IITDescriptor::PtrToElt: {
       if (D.getArgumentNumber() >= ArgTys.size())
-        return true;
+        return IsDeferredCheck || DeferCheck(Ty);
       VectorType * ReferenceType =
         dyn_cast<VectorType> (ArgTys[D.getArgumentNumber()]);
       PointerType *ThisArgType = dyn_cast<PointerType>(Ty);
@@ -1150,17 +1279,31 @@
       return (!ThisArgType || !ReferenceType ||
               ThisArgType->getElementType() != ReferenceType->getElementType());
     }
+    case IITDescriptor::VecOfBitcastsToInt: {
+      if (D.getArgumentNumber() >= ArgTys.size())
+        return IsDeferredCheck || DeferCheck(Ty);
+      auto *ReferenceType = dyn_cast<VectorType>(ArgTys[D.getArgumentNumber()]);
+      auto *ThisArgVecTy = dyn_cast<VectorType>(Ty);
+      if (!ThisArgVecTy || !ReferenceType)
+        return true;
+      return ThisArgVecTy != VectorType::getInteger(ReferenceType);
+    }
     case IITDescriptor::VecOfAnyPtrsToElt: {
       unsigned RefArgNumber = D.getRefArgNumber();
+      if (RefArgNumber >= ArgTys.size()) {
+        if (IsDeferredCheck)
+          return true;
+        // If forward referencing, already add the pointer-vector type and
+        // defer the checks for later.
+        ArgTys.push_back(Ty);
+        return DeferCheck(Ty);
+      }
-      // This may only be used when referring to a previous argument.
-      if (RefArgNumber >= ArgTys.size())
-        return true;
-
-      // Record the overloaded type
-      assert(D.getOverloadArgNumber() == ArgTys.size() &&
-             "Table consistency error");
-      ArgTys.push_back(Ty);
+      if (!IsDeferredCheck) {
+        assert(D.getOverloadArgNumber() == ArgTys.size() &&
+               "Table consistency error");
+        ArgTys.push_back(Ty);
+      }
       // Verify the overloaded type "matches" the Ref type.
       // i.e. Ty is a vector with the same width as Ref.
@@ -1178,10 +1321,91 @@
       return ThisArgEltTy->getElementType() !=
              ReferenceType->getVectorElementType();
     }
+    case IITDescriptor::ScalableVecArgument: {
+      VectorType *VTy = dyn_cast<VectorType>(Ty);
+      if (!VTy || !VTy->isScalable())
+        return true;
+      return matchIntrinsicType(VTy, Infos, ArgTys, DeferredChecks,
+                                IsDeferredCheck);
+    }
+    case IITDescriptor::VecElementArgument: {
+      if (D.getArgumentNumber() >= ArgTys.size())
+        return IsDeferredCheck || DeferCheck(Ty);
+      auto *ReferenceType = dyn_cast<VectorType>(ArgTys[D.getArgumentNumber()]);
+      return !ReferenceType || Ty != ReferenceType->getElementType();
+    }
+    case IITDescriptor::Subdivide2Argument: {
+      // This may only be used when referring to a previous vector argument.
+ if (D.getArgumentNumber() >= ArgTys.size()) + return IsDeferredCheck || DeferCheck(Ty);; + + Type *NewTy = ArgTys[D.getArgumentNumber()]; + if (VectorType *VTy = dyn_cast(NewTy)) { + if (VTy->getElementType()->isFloatingPointTy()) + NewTy = VectorType::getDoubleElementsVectorType( + VectorType::getNarrowerFpElementVectorType(VTy)); + else + NewTy = VectorType::getDoubleElementsVectorType( + VectorType::getTruncatedElementVectorType(VTy)); + } + else + return true; + + return Ty != NewTy; + } + case IITDescriptor::Subdivide4Argument: { + // This may only be used when referring to a previous vector argument. + if (D.getArgumentNumber() >= ArgTys.size()) + return IsDeferredCheck || DeferCheck(Ty);; + + Type *NewTy = ArgTys[D.getArgumentNumber()]; + if (VectorType *VTy = dyn_cast(NewTy)) { + if (VTy->getElementType()->isFloatingPointTy()) + NewTy = VectorType::getDoubleElementsVectorType( + VectorType::getNarrowerFpElementVectorType( + VectorType::getDoubleElementsVectorType( + VectorType::getNarrowerFpElementVectorType(VTy)))); + else + NewTy = VectorType::getDoubleElementsVectorType( + VectorType::getTruncatedElementVectorType( + VectorType::getDoubleElementsVectorType( + VectorType::getTruncatedElementVectorType(VTy)))); + } + else + return true; + + return Ty != NewTy; + } } llvm_unreachable("unhandled"); } +Intrinsic::MatchIntrinsicTypesResult +Intrinsic::matchIntrinsicSignature(FunctionType *FTy, + ArrayRef &Infos, + SmallVectorImpl &ArgTys) { + SmallVector DeferredChecks; + if (matchIntrinsicType(FTy->getReturnType(), Infos, ArgTys, DeferredChecks, + false)) + return MatchIntrinsicTypes_NoMatchRet; + + unsigned NumDeferredReturnChecks = DeferredChecks.size(); + + for (auto Ty : FTy->params()) + if (matchIntrinsicType(Ty, Infos, ArgTys, DeferredChecks, false)) + return MatchIntrinsicTypes_NoMatchArg; + + for (unsigned I = 0, E = DeferredChecks.size(); I != E; ++I) { + DeferredIntrinsicMatchPair &Check = DeferredChecks[I]; + if (matchIntrinsicType(Check.first, Check.second, ArgTys, DeferredChecks, + true)) + return I < NumDeferredReturnChecks ? MatchIntrinsicTypes_NoMatchRet + : MatchIntrinsicTypes_NoMatchArg; + } + + return MatchIntrinsicTypes_Match; +} + bool Intrinsic::matchIntrinsicVarArg(bool isVarArg, ArrayRef &Infos) { @@ -1215,13 +1439,8 @@ getIntrinsicInfoTableEntries(ID, Table); ArrayRef TableRef = Table; - // If we encounter any problems matching the signature with the descriptor - // just give up remangling. It's up to verifier to report the discrepancy. 
- if (Intrinsic::matchIntrinsicType(FTy->getReturnType(), TableRef, ArgTys)) + if (Intrinsic::matchIntrinsicSignature(FTy, TableRef, ArgTys)) return None; - for (auto Ty : FTy->params()) - if (Intrinsic::matchIntrinsicType(Ty, TableRef, ArgTys)) - return None; if (Intrinsic::matchIntrinsicVarArg(FTy->isVarArg(), TableRef)) return None; } Index: lib/IR/IRBuilder.cpp =================================================================== --- lib/IR/IRBuilder.cpp +++ lib/IR/IRBuilder.cpp @@ -96,34 +96,6 @@ return II; } -CallInst *IRBuilderBase:: -CreateMemSet(Value *Ptr, Value *Val, Value *Size, unsigned Align, - bool isVolatile, MDNode *TBAATag, MDNode *ScopeTag, - MDNode *NoAliasTag) { - Ptr = getCastedInt8PtrValue(Ptr); - Value *Ops[] = {Ptr, Val, Size, getInt1(isVolatile)}; - Type *Tys[] = { Ptr->getType(), Size->getType() }; - Module *M = BB->getParent()->getParent(); - Value *TheFn = Intrinsic::getDeclaration(M, Intrinsic::memset, Tys); - - CallInst *CI = createCallHelper(TheFn, Ops, this); - - if (Align > 0) - cast(CI)->setDestAlignment(Align); - - // Set the TBAA info if present. - if (TBAATag) - CI->setMetadata(LLVMContext::MD_tbaa, TBAATag); - - if (ScopeTag) - CI->setMetadata(LLVMContext::MD_alias_scope, ScopeTag); - - if (NoAliasTag) - CI->setMetadata(LLVMContext::MD_noalias, NoAliasTag); - - return CI; -} - CallInst *IRBuilderBase::CreateElementUnorderedAtomicMemSet( Value *Ptr, Value *Val, Value *Size, unsigned Align, uint32_t ElementSize, MDNode *TBAATag, MDNode *ScopeTag, MDNode *NoAliasTag) { @@ -154,45 +126,6 @@ return CI; } -CallInst *IRBuilderBase:: -CreateMemCpy(Value *Dst, unsigned DstAlign, Value *Src, unsigned SrcAlign, - Value *Size, bool isVolatile, MDNode *TBAATag, - MDNode *TBAAStructTag, MDNode *ScopeTag, MDNode *NoAliasTag) { - assert((DstAlign == 0 || isPowerOf2_32(DstAlign)) && "Must be 0 or a power of 2"); - assert((SrcAlign == 0 || isPowerOf2_32(SrcAlign)) && "Must be 0 or a power of 2"); - Dst = getCastedInt8PtrValue(Dst); - Src = getCastedInt8PtrValue(Src); - - Value *Ops[] = {Dst, Src, Size, getInt1(isVolatile)}; - Type *Tys[] = { Dst->getType(), Src->getType(), Size->getType() }; - Module *M = BB->getParent()->getParent(); - Value *TheFn = Intrinsic::getDeclaration(M, Intrinsic::memcpy, Tys); - - CallInst *CI = createCallHelper(TheFn, Ops, this); - - auto* MCI = cast(CI); - if (DstAlign > 0) - MCI->setDestAlignment(DstAlign); - if (SrcAlign > 0) - MCI->setSourceAlignment(SrcAlign); - - // Set the TBAA info if present. - if (TBAATag) - CI->setMetadata(LLVMContext::MD_tbaa, TBAATag); - - // Set the TBAA Struct info if present. - if (TBAAStructTag) - CI->setMetadata(LLVMContext::MD_tbaa_struct, TBAAStructTag); - - if (ScopeTag) - CI->setMetadata(LLVMContext::MD_alias_scope, ScopeTag); - - if (NoAliasTag) - CI->setMetadata(LLVMContext::MD_noalias, NoAliasTag); - - return CI; -} - CallInst *IRBuilderBase::CreateElementUnorderedAtomicMemCpy( Value *Dst, unsigned DstAlign, Value *Src, unsigned SrcAlign, Value *Size, uint32_t ElementSize, MDNode *TBAATag, MDNode *TBAAStructTag, @@ -473,7 +406,6 @@ auto *PtrTy = cast(Ptr->getType()); Type *DataTy = PtrTy->getElementType(); assert(DataTy->isVectorTy() && "Ptr should point to a vector"); - assert(Mask && "Mask should not be all-ones (null)"); if (!PassThru) PassThru = UndefValue::get(DataTy); Type *OverloadedTypes[] = { DataTy, PtrTy }; @@ -482,6 +414,27 @@ OverloadedTypes, Name); } +/// \brief Create a call to a Masked Speculative Load intrinsic. 
+/// \p Ptr - the base pointer for the load +/// \p Align - alignment of the source location +/// \p Mask - an vector of booleans which indicates what vector lanes should +/// be accessed in memory +/// \p PassThru - a pass-through value that is used to fill the masked-off lanes +/// of the result +/// \p Name - name of the result variable +CallInst *IRBuilderBase::CreateMaskedSpecLoad(Value *Ptr, unsigned Align, + Value *Mask, Value *PassThru, + const Twine &Name) { + assert(Ptr->getType()->isPointerTy() && "Ptr must be of pointer type"); + // DataTy is the overloaded type + Type *DataTy = cast(Ptr->getType())->getElementType(); + assert(DataTy->isVectorTy() && "Ptr should point to a vector"); + if (!PassThru) + PassThru = UndefValue::get(DataTy); + Value *Ops[] = { Ptr, getInt32(Align), Mask, PassThru}; + return CreateMaskedIntrinsic(Intrinsic::masked_spec_load, Ops, DataTy, Name); +} + /// Create a call to a Masked Store intrinsic. /// \p Val - data to be stored, /// \p Ptr - base pointer for the store @@ -493,7 +446,6 @@ auto *PtrTy = cast(Ptr->getType()); Type *DataTy = PtrTy->getElementType(); assert(DataTy->isVectorTy() && "Ptr should point to a vector"); - assert(Mask && "Mask should not be all-ones (null)"); Type *OverloadedTypes[] = { DataTy, PtrTy }; Value *Ops[] = { Val, Ptr, getInt32(Align), Mask }; return CreateMaskedIntrinsic(Intrinsic::masked_store, Ops, OverloadedTypes); @@ -524,13 +476,12 @@ const Twine& Name) { auto PtrsTy = cast(Ptrs->getType()); auto PtrTy = cast(PtrsTy->getElementType()); - unsigned NumElts = PtrsTy->getVectorNumElements(); + auto NumElts = PtrsTy->getElementCount(); Type *DataTy = VectorType::get(PtrTy->getElementType(), NumElts); if (!Mask) - Mask = Constant::getAllOnesValue(VectorType::get(Type::getInt1Ty(Context), + Mask = ConstantInt::getTrue(VectorType::get(Type::getInt1Ty(Context), NumElts)); - if (!PassThru) PassThru = UndefValue::get(DataTy); @@ -554,17 +505,17 @@ unsigned Align, Value *Mask) { auto PtrsTy = cast(Ptrs->getType()); auto DataTy = cast(Data->getType()); - unsigned NumElts = PtrsTy->getVectorNumElements(); + auto NumElts = PtrsTy->getElementCount(); #ifndef NDEBUG auto PtrTy = cast(PtrsTy->getElementType()); - assert(NumElts == DataTy->getVectorNumElements() && + assert(NumElts == DataTy->getElementCount() && PtrTy->getElementType() == DataTy->getElementType() && "Incompatible pointer and data types"); #endif if (!Mask) - Mask = Constant::getAllOnesValue(VectorType::get(Type::getInt1Ty(Context), + Mask = ConstantInt::getTrue(VectorType::get(Type::getInt1Ty(Context), NumElts)); Type *OverloadedTypes[] = {DataTy, PtrsTy}; @@ -730,6 +681,16 @@ return createCallHelper(FnGCRelocate, Args, this, Name); } +CallInst *IRBuilderBase::CreateCntVPop(Value *PredVec, const Twine &Name) { + Value *Ops[] = { PredVec }; + Type *Tys[] = { PredVec->getType() }; + + Module *M = BB->getParent()->getParent(); + Value *Func = Intrinsic::getDeclaration(M, Intrinsic::ctvpop, Tys); + + return createCallHelper(Func, Ops, this, Name); +} + CallInst *IRBuilderBase::CreateBinaryIntrinsic(Intrinsic::ID ID, Value *LHS, Value *RHS, const Twine &Name) { Index: lib/IR/InlineAsm.cpp =================================================================== --- lib/IR/InlineAsm.cpp +++ lib/IR/InlineAsm.cpp @@ -182,6 +182,16 @@ // FIXME: For now assuming these are 2-character constraints. 
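CreateMaskedSpecLoad and the other masked-load builders above emit calls whose per-lane contract is: read the lane from memory if its mask bit is set, otherwise take the lane from the pass-through value. A scalar model of that contract (illustrative only; the lane count, element type, and names are placeholders, and the real intrinsics also carry an alignment operand):

#include <array>
#include <cstdio>

constexpr int NumLanes = 4;

// Per-lane model of a masked load: active lanes come from memory,
// inactive lanes come from the pass-through vector.
std::array<int, NumLanes> maskedLoad(const int *Mem,
                                     const std::array<bool, NumLanes> &Mask,
                                     const std::array<int, NumLanes> &PassThru) {
  std::array<int, NumLanes> Result{};
  for (int I = 0; I < NumLanes; ++I)
    Result[I] = Mask[I] ? Mem[I] : PassThru[I];
  return Result;
}

int main() {
  int Mem[NumLanes] = {10, 20, 30, 40};
  std::array<bool, NumLanes> Mask = {true, false, true, false};
  std::array<int, NumLanes> PassThru = {-1, -1, -1, -1};
  for (int V : maskedLoad(Mem, Mask, PassThru))
    std::printf("%d ", V);   // prints: 10 -1 30 -1
  std::printf("\n");
  return 0;
}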
pCodes->push_back(StringRef(I+1, 2)); I += 3; + } else if (*I == '@') { + // Multi-letter constraint + ++I; + unsigned char C = static_cast(*I); + assert(isdigit(C) && "Not a single digit!"); + int N = C - '0'; + assert(N > 0 && "Found a zero letter constraint!"); + ++I; + pCodes->push_back(std::string(I, I+N)); + I += N; } else { // Single letter constraint. pCodes->push_back(StringRef(I, 1)); Index: lib/IR/Instruction.cpp =================================================================== --- lib/IR/Instruction.cpp +++ lib/IR/Instruction.cpp @@ -513,7 +513,8 @@ case Instruction::CatchRet: return true; case Instruction::Call: - return !cast(this)->doesNotAccessMemory(); + return (!cast(this)->doesNotAccessMemory() + && !isa(this)); case Instruction::Invoke: return !cast(this)->doesNotAccessMemory(); case Instruction::Store: Index: lib/IR/Instructions.cpp =================================================================== --- lib/IR/Instructions.cpp +++ lib/IR/Instructions.cpp @@ -1596,7 +1596,7 @@ const Twine &Name, Instruction *InsertBefore) : Instruction(VectorType::get(cast(V1->getType())->getElementType(), - cast(Mask->getType())->getNumElements()), + cast(Mask->getType())->getElementCount()), ShuffleVector, OperandTraits::op_begin(this), OperandTraits::operands(this), @@ -1613,7 +1613,7 @@ const Twine &Name, BasicBlock *InsertAtEnd) : Instruction(VectorType::get(cast(V1->getType())->getElementType(), - cast(Mask->getType())->getNumElements()), + cast(Mask->getType())->getElementCount()), ShuffleVector, OperandTraits::op_begin(this), OperandTraits::operands(this), @@ -1667,50 +1667,100 @@ // used as the shuffle mask. When this occurs, the shuffle mask will // fall into this case and fail. To avoid this error, do this bit of // ugliness to allow such a mask pass. + // NOTE: Scalable vectors predominantly use ConstantExpr based masks. if (const auto *CE = dyn_cast(Mask)) - if (CE->getOpcode() == Instruction::UserOp1) + if ((CE->getOpcode() == Instruction::UserOp1) || MaskTy->isScalable()) return true; + if (isa(Mask)) + return true; + return false; } -int ShuffleVectorInst::getMaskValue(const Constant *Mask, unsigned i) { +/// getMaskValue - Try to extract the index from the shuffle mask for the +/// specified output result, returning true on success. This is either +/// -1 if the element is undef or a number less than 2*numelements. +bool ShuffleVectorInst::getMaskValue(const Value *Mask, unsigned i, int &Result) { assert(i < Mask->getType()->getVectorNumElements() && "Index out of range"); - if (auto *CDS = dyn_cast(Mask)) - return CDS->getElementAsInteger(i); - Constant *C = Mask->getAggregateElement(i); - if (isa(C)) - return -1; - return cast(C)->getZExtValue(); -} - -void ShuffleVectorInst::getShuffleMask(const Constant *Mask, - SmallVectorImpl &Result) { - unsigned NumElts = Mask->getType()->getVectorNumElements(); - if (auto *CDS = dyn_cast(Mask)) { - for (unsigned i = 0; i != NumElts; ++i) - Result.push_back(CDS->getElementAsInteger(i)); - return; + Result = CDS->getElementAsInteger(i); + return true; } - for (unsigned i = 0; i != NumElts; ++i) { - Constant *C = Mask->getAggregateElement(i); - Result.push_back(isa(C) ? 
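The new '@' branch above accepts a multi-letter inline-asm constraint code written as '@', a single length digit N, then N characters, so "@3abc" yields the code "abc". A standalone sketch of that framing (illustrative parser only; the asserts of the real code are reduced to plain checks, and the sample constraint string is just an example):

#include <cctype>
#include <cstdio>
#include <string>
#include <vector>

// Split a constraint string into codes. Plain characters are one-character
// codes; "@<N><chars>" introduces a multi-letter code of length N.
std::vector<std::string> parseConstraintCodes(const std::string &S) {
  std::vector<std::string> Codes;
  for (std::size_t I = 0; I < S.size();) {
    if (S[I] == '@' && I + 1 < S.size() &&
        std::isdigit((unsigned char)S[I + 1])) {
      std::size_t N = S[I + 1] - '0';
      Codes.push_back(S.substr(I + 2, N));   // the N-character code
      I += 2 + N;
    } else {
      Codes.push_back(S.substr(I, 1));       // single-character code
      ++I;
    }
  }
  return Codes;
}

int main() {
  for (const std::string &C : parseConstraintCodes("r@3Upl"))
    std::printf("[%s] ", C.c_str());          // prints: [r] [Upl]
  std::printf("\n");
  return 0;
}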
-1 : - cast(C)->getZExtValue()); + + if (auto *C = dyn_cast(Mask)) { + C = C->getAggregateElement(i); + if (C) { + if (isa(C)) { + Result = -1; + return true; + } + if (auto *CI = dyn_cast(C)) { + Result = CI->getZExtValue(); + return true; + } + } } + + //if (const auto *CDS = dyn_cast(Mask)) { + //unsigned V1Size = cast(V1->getType())->getNumElements(); + //for (unsigned i = 0, e = MaskTy->getNumElements(); i != e; ++i) + //if (CDS->getElementAsInteger(i) >= V1Size*2) + //return false; + //return true; + //} + + //// The bitcode reader can create a place holder for a forward reference + //// used as the shuffle mask. When this occurs, the shuffle mask will + //// fall into this case and fail. To avoid this error, do this bit of + //// ugliness to allow such a mask pass. + //if (const auto *CE = dyn_cast(Mask)) + //if (CE->getOpcode() == Instruction::UserOp1) + //return true; + + return false; } -bool ShuffleVectorInst::isSingleSourceMask(ArrayRef Mask) { +/// getShuffleMask - Return the full mask for this instruction, where each +/// element is the element number and undef's are returned as -1. +bool ShuffleVectorInst::getShuffleMask(const Value *Mask, + SmallVectorImpl &Result) { + VectorType *VecTy = cast(Mask->getType()); + if (VecTy->isScalable()) + return false; + unsigned NumElts = VecTy->getVectorNumElements(); + Result.resize(NumElts); + for (unsigned i = 0; i != NumElts; ++i) + if (!getMaskValue(Mask, i, Result[i])) + return false; + return true; +} + +int ShuffleVectorInst::findBroadcastElement(const Value *Mask) { + int SplatElem = -1; + SmallVector MaskElems; + if (getShuffleMask(Mask, MaskElems)) { + for (unsigned i = 0; i < MaskElems.size(); ++i) { + if (SplatElem != -1 && MaskElems[i] != -1 && MaskElems[i] != SplatElem) + return -1; + if (SplatElem == -1) + SplatElem = MaskElems[i]; + } + } + return SplatElem; +} + +static bool isSingleSourceMaskImpl(ArrayRef Mask, int NumOpElts) { assert(!Mask.empty() && "Shuffle mask must contain elements"); bool UsesLHS = false; bool UsesRHS = false; - for (int i = 0, NumElts = Mask.size(); i < NumElts; ++i) { + for (int i = 0, NumMaskElts = Mask.size(); i < NumMaskElts; ++i) { if (Mask[i] == -1) continue; - assert(Mask[i] >= 0 && Mask[i] < (NumElts * 2) && + assert(Mask[i] >= 0 && Mask[i] < (NumOpElts * 2) && "Out-of-bounds shuffle mask element"); - UsesLHS |= (Mask[i] < NumElts); - UsesRHS |= (Mask[i] >= NumElts); + UsesLHS |= (Mask[i] < NumOpElts); + UsesRHS |= (Mask[i] >= NumOpElts); if (UsesLHS && UsesRHS) return false; } @@ -1718,18 +1768,30 @@ return true; } -bool ShuffleVectorInst::isIdentityMask(ArrayRef Mask) { - if (!isSingleSourceMask(Mask)) +bool ShuffleVectorInst::isSingleSourceMask(ArrayRef Mask) { + // We don't have vector operand size information, so assume operands are the + // same size as the mask. + return isSingleSourceMaskImpl(Mask, Mask.size()); +} + +static bool isIdentityMaskImpl(ArrayRef Mask, int NumOpElts) { + if (!isSingleSourceMaskImpl(Mask, NumOpElts)) return false; - for (int i = 0, NumElts = Mask.size(); i < NumElts; ++i) { + for (int i = 0, NumMaskElts = Mask.size(); i < NumMaskElts; ++i) { if (Mask[i] == -1) continue; - if (Mask[i] != i && Mask[i] != (NumElts + i)) + if (Mask[i] != i && Mask[i] != (NumOpElts + i)) return false; } return true; } +bool ShuffleVectorInst::isIdentityMask(ArrayRef Mask) { + // We don't have vector operand size information, so assume operands are the + // same size as the mask. 
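findBroadcastElement above reports the single source element that every defined lane of the shuffle mask reads, treating -1 (undef) lanes as wildcards, or -1 when the lanes disagree. A standalone model of the same scan (illustrative only):

#include <cstdio>
#include <vector>

// Return the element index every defined lane selects, or -1 if the mask
// is not a broadcast. Undef lanes (-1) are compatible with any splat.
int findBroadcastElement(const std::vector<int> &Mask) {
  int Splat = -1;
  for (int Elt : Mask) {
    if (Elt == -1)
      continue;                 // undef lane: matches anything
    if (Splat != -1 && Elt != Splat)
      return -1;                // two different sources, not a broadcast
    if (Splat == -1)
      Splat = Elt;
  }
  return Splat;
}

int main() {
  std::printf("%d\n", findBroadcastElement({2, -1, 2, 2}));  // 2
  std::printf("%d\n", findBroadcastElement({0, 1, 0, 0}));   // -1
  return 0;
}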
+ return isIdentityMaskImpl(Mask, Mask.size()); +} + bool ShuffleVectorInst::isReverseMask(ArrayRef Mask) { if (!isSingleSourceMask(Mask)) return false; @@ -1801,6 +1863,48 @@ return true; } +bool ShuffleVectorInst::isConcat() const { + // Will return true if the mask is in the form <0, 1, 2, 3...> + SmallVector Mask = getShuffleMask(); + int NumMaskElts = Mask.size(); + + for (int i = 0; i < NumMaskElts; ++i) { + if (Mask[i] == -1) + continue; + if (Mask[i] != i) + return false; + } + + return true; +} + +bool ShuffleVectorInst::isIdentityWithPadding() const { + int NumOpElts = Op<0>()->getType()->getVectorNumElements(); + int NumMaskElts = getType()->getVectorNumElements(); + if (NumMaskElts <= NumOpElts) + return false; + + // The first part of the mask must choose elements from exactly 1 source op. + ArrayRef Mask = getShuffleMask(); + if (!isIdentityMaskImpl(Mask, NumOpElts)) + return false; + + // All extending must be with undef elements. + for (int i = NumOpElts; i < NumMaskElts; ++i) + if (Mask[i] != -1) + return false; + + return true; +} + +bool ShuffleVectorInst::isIdentityWithExtract() const { + int NumOpElts = Op<0>()->getType()->getVectorNumElements(); + int NumMaskElts = getType()->getVectorNumElements(); + if (NumMaskElts >= NumOpElts) + return false; + + return isIdentityMaskImpl(getShuffleMask(), NumOpElts); +} //===----------------------------------------------------------------------===// // InsertValueInst Class @@ -2740,6 +2844,17 @@ if (SrcTy == DestTy) return true; + bool SrcIsScalable = false; + if (VectorType *SrcVecTy = dyn_cast(SrcTy)) + SrcIsScalable = SrcVecTy->isScalable(); + + bool DestIsScalable = false; + if (VectorType *DestVecTy = dyn_cast(DestTy)) + DestIsScalable = DestVecTy->isScalable(); + + if (SrcIsScalable != DestIsScalable) + return false; + if (VectorType *SrcVecTy = dyn_cast(SrcTy)) { if (VectorType *DestVecTy = dyn_cast(DestTy)) { if (SrcVecTy->getNumElements() == DestVecTy->getNumElements()) { @@ -2914,10 +3029,12 @@ // If these are vector types, get the lengths of the vectors (using zero for // scalar types means that checking that vector lengths match also checks that // scalars are not being converted to vectors or vectors to scalars). - unsigned SrcLength = SrcTy->isVectorTy() ? - cast(SrcTy)->getNumElements() : 0; - unsigned DstLength = DstTy->isVectorTy() ? - cast(DstTy)->getNumElements() : 0; + VectorType::ElementCount SrcLength = + SrcTy->isVectorTy() ? cast(SrcTy)->getElementCount() + : VectorType::ElementCount(0, false); + VectorType::ElementCount DstLength = + DstTy->isVectorTy() ? cast(DstTy)->getElementCount() + : VectorType::ElementCount(0, false); // Switch on the opcode provided switch (op) { @@ -2949,14 +3066,14 @@ if (isa(SrcTy) != isa(DstTy)) return false; if (VectorType *VT = dyn_cast(SrcTy)) - if (VT->getNumElements() != cast(DstTy)->getNumElements()) + if (VT->getElementCount() != cast(DstTy)->getElementCount()) return false; return SrcTy->isPtrOrPtrVectorTy() && DstTy->isIntOrIntVectorTy(); case Instruction::IntToPtr: if (isa(SrcTy) != isa(DstTy)) return false; if (VectorType *VT = dyn_cast(SrcTy)) - if (VT->getNumElements() != cast(DstTy)->getNumElements()) + if (VT->getElementCount() != cast(DstTy)->getElementCount()) return false; return SrcTy->isIntOrIntVectorTy() && DstTy->isPtrOrPtrVectorTy(); case Instruction::BitCast: { @@ -2971,7 +3088,8 @@ // For non-pointer cases, the cast is okay if the source and destination bit // widths are identical. 
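isIdentityWithPadding and isIdentityWithExtract above distinguish a shuffle that widens its operand with trailing undef lanes from one that truncates it, both built on the identity check against the operand's element count. A simplified standalone model follows (illustrative only; it omits the single-source requirement that the real isIdentityMaskImpl inherits from isSingleSourceMaskImpl, and NumOpElts stands for the first operand's element count):

#include <cstdio>
#include <vector>

// Identity check against an operand of NumOpElts elements: every defined
// lane i must read element i of either the first or the second operand.
static bool isIdentityMask(const std::vector<int> &Mask, int NumOpElts) {
  for (int I = 0, E = (int)Mask.size(); I != E; ++I) {
    if (Mask[I] == -1)
      continue;
    if (Mask[I] != I && Mask[I] != NumOpElts + I)
      return false;
  }
  return true;
}

// Wider result than the operand, identity prefix, undef in the extra lanes.
static bool isIdentityWithPadding(const std::vector<int> &Mask, int NumOpElts) {
  if ((int)Mask.size() <= NumOpElts || !isIdentityMask(Mask, NumOpElts))
    return false;
  for (int I = NumOpElts, E = (int)Mask.size(); I != E; ++I)
    if (Mask[I] != -1)
      return false;
  return true;
}

// Narrower result than the operand, identity over the lanes that remain.
static bool isIdentityWithExtract(const std::vector<int> &Mask, int NumOpElts) {
  return (int)Mask.size() < NumOpElts && isIdentityMask(Mask, NumOpElts);
}

int main() {
  std::printf("%d\n", isIdentityWithPadding({0, 1, -1, -1}, 2)); // 1
  std::printf("%d\n", isIdentityWithExtract({0, 1}, 4));         // 1
  return 0;
}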
if (!SrcPtrTy) - return SrcTy->getPrimitiveSizeInBits() == DstTy->getPrimitiveSizeInBits(); + return SrcLength.Scalable == DstLength.Scalable && + SrcTy->getPrimitiveSizeInBits() == DstTy->getPrimitiveSizeInBits(); // If both are pointers then the address spaces must match. if (SrcPtrTy->getAddressSpace() != DstPtrTy->getAddressSpace()) @@ -2980,7 +3098,7 @@ // A vector of pointers must have the same number of elements. if (VectorType *SrcVecTy = dyn_cast(SrcTy)) { if (VectorType *DstVecTy = dyn_cast(DstTy)) - return (SrcVecTy->getNumElements() == DstVecTy->getNumElements()); + return (SrcVecTy->getElementCount() == DstVecTy->getElementCount()); return false; } @@ -3001,7 +3119,7 @@ if (VectorType *SrcVecTy = dyn_cast(SrcTy)) { if (VectorType *DstVecTy = dyn_cast(DstTy)) - return (SrcVecTy->getNumElements() == DstVecTy->getNumElements()); + return (SrcVecTy->getElementCount() == DstVecTy->getElementCount()); return false; } Index: lib/IR/LLVMContextImpl.h =================================================================== --- lib/IR/LLVMContextImpl.h +++ lib/IR/LLVMContextImpl.h @@ -321,9 +321,12 @@ }; template <> struct MDNodeKeyImpl { + int64_t Count; Metadata *CountNode; int64_t LowerBound; + MDNodeKeyImpl(int64_t Count, int64_t LowerBound) + : Count(Count), CountNode(nullptr), LowerBound(LowerBound) {} MDNodeKeyImpl(Metadata *CountNode, int64_t LowerBound) : CountNode(CountNode), LowerBound(LowerBound) {} MDNodeKeyImpl(const DISubrange *N) @@ -351,6 +354,42 @@ } }; +template <> struct MDNodeKeyImpl { + int64_t CLBound; + int64_t CUBound; + bool NoUBound; + Metadata *LowerBound; + Metadata *LowerBoundExp; + Metadata *UpperBound; + Metadata *UpperBoundExp; + + MDNodeKeyImpl(int64_t CLB, int64_t CUB, bool NUB, Metadata *LB, Metadata *LBE, + Metadata *UB, Metadata *UBE) + : CLBound(CLB), CUBound(CUB), NoUBound(NUB), LowerBound(LB), + LowerBoundExp(LBE), UpperBound(UB), UpperBoundExp(UBE) {} + MDNodeKeyImpl(const DIFortranSubrange *N) + : CLBound(N->getCLowerBound()), CUBound(N->getCUpperBound()), + NoUBound(N->noUpperBound()), LowerBound(N->getRawLowerBound()), + LowerBoundExp(N->getRawLowerBoundExpression()), + UpperBound(N->getRawUpperBound()), + UpperBoundExp(N->getRawUpperBoundExpression()) {} + + bool isKeyOf(const DIFortranSubrange *RHS) const { + return CLBound == RHS->getCLowerBound() && + CUBound == RHS->getCUpperBound() && + NoUBound == RHS->noUpperBound() && + LowerBound == RHS->getRawLowerBound() && + LowerBoundExp == RHS->getRawLowerBoundExpression() && + UpperBound == RHS->getRawUpperBound() && + UpperBoundExp == RHS->getRawUpperBoundExpression(); + } + + unsigned getHashValue() const { + return hash_combine(CLBound, CUBound, NoUBound, UpperBound, UpperBoundExp, + LowerBound, LowerBoundExp); + } +}; + template <> struct MDNodeKeyImpl { int64_t Value; MDString *Name; @@ -397,6 +436,39 @@ } }; +template <> struct MDNodeKeyImpl { + unsigned Tag; + MDString *Name; + Metadata *StringLength; + Metadata *StringLengthExp; + uint64_t SizeInBits; + uint32_t AlignInBits; + unsigned Encoding; + + MDNodeKeyImpl(unsigned Tag, MDString *Name, Metadata *StringLength, + Metadata *StringLengthExp, uint64_t SizeInBits, + uint32_t AlignInBits, unsigned Encoding) + : Tag(Tag), Name(Name), StringLength(StringLength), + StringLengthExp(StringLengthExp), SizeInBits(SizeInBits), + AlignInBits(AlignInBits), Encoding(Encoding) {} + MDNodeKeyImpl(const DIStringType *N) + : Tag(N->getTag()), Name(N->getRawName()), + StringLength(N->getRawStringLength()), + StringLengthExp(N->getRawStringLengthExp()), + 
SizeInBits(N->getSizeInBits()), + AlignInBits(N->getAlignInBits()), Encoding(N->getEncoding()) {} + + bool isKeyOf(const DIStringType *RHS) const { + return Tag == RHS->getTag() && Name == RHS->getRawName() && + SizeInBits == RHS->getSizeInBits() && + AlignInBits == RHS->getAlignInBits() && + Encoding == RHS->getEncoding(); + } + unsigned getHashValue() const { + return hash_combine(Tag, Name, SizeInBits, AlignInBits, Encoding); + } +}; + template <> struct MDNodeKeyImpl { unsigned Tag; MDString *Name; @@ -553,6 +625,52 @@ } }; +template <> struct MDNodeKeyImpl { + unsigned Tag; + MDString *Name; + Metadata *File; + unsigned Line; + Metadata *Scope; + Metadata *BaseType; + uint64_t SizeInBits; + uint64_t OffsetInBits; + uint32_t AlignInBits; + unsigned Flags; + Metadata *Elements; + + MDNodeKeyImpl(unsigned Tag, MDString *Name, Metadata *File, unsigned Line, + Metadata *Scope, Metadata *BaseType, uint64_t SizeInBits, + uint32_t AlignInBits, uint64_t OffsetInBits, unsigned Flags, + Metadata *Elements) + : Tag(Tag), Name(Name), File(File), Line(Line), Scope(Scope), + BaseType(BaseType), SizeInBits(SizeInBits), OffsetInBits(OffsetInBits), + AlignInBits(AlignInBits), Flags(Flags), Elements(Elements) {} + MDNodeKeyImpl(const DIFortranArrayType *N) + : Tag(N->getTag()), Name(N->getRawName()), File(N->getRawFile()), + Line(N->getLine()), Scope(N->getRawScope()), + BaseType(N->getRawBaseType()), SizeInBits(N->getSizeInBits()), + OffsetInBits(N->getOffsetInBits()), AlignInBits(N->getAlignInBits()), + Flags(N->getFlags()), Elements(N->getRawElements()) {} + + bool isKeyOf(const DIFortranArrayType *RHS) const { + return Tag == RHS->getTag() && Name == RHS->getRawName() && + File == RHS->getRawFile() && Line == RHS->getLine() && + Scope == RHS->getRawScope() && BaseType == RHS->getRawBaseType() && + SizeInBits == RHS->getSizeInBits() && + AlignInBits == RHS->getAlignInBits() && + OffsetInBits == RHS->getOffsetInBits() && Flags == RHS->getFlags() && + Elements == RHS->getRawElements(); + } + + unsigned getHashValue() const { + // Intentionally computes the hash on a subset of the operands for + // performance reason. The subset has to be significant enough to avoid + // collision "most of the time". There is no correctness issue in case of + // collision because of the full check above. 
+ return hash_combine(Name, File, Line, BaseType, Scope, Elements); + } +}; + template <> struct MDNodeKeyImpl { unsigned Flags; uint8_t CC; @@ -791,32 +909,65 @@ } }; +template <> struct MDNodeKeyImpl { + Metadata *Scope; + Metadata *Decl; + MDString *Name; + Metadata *File; + unsigned LineNo; + uint32_t AlignInBits; + + MDNodeKeyImpl(Metadata *Scope, Metadata *Decl, MDString *Name, + Metadata *File, unsigned LineNo, uint32_t AlignInBits) + : Scope(Scope), Decl(Decl), Name(Name), File(File), LineNo(LineNo), + AlignInBits(AlignInBits) {} + MDNodeKeyImpl(const DICommonBlock *N) + : Scope(N->getRawScope()), Decl(N->getRawDecl()), Name(N->getRawName()), + File(N->getRawFile()), LineNo(N->getLineNo()), + AlignInBits(N->getAlignInBits()) {} + + bool isKeyOf(const DICommonBlock *RHS) const { + return Scope == RHS->getRawScope() && Decl == RHS->getRawDecl() && + Name == RHS->getRawName() && File == RHS->getRawFile() && + LineNo == RHS->getLineNo() && AlignInBits == RHS->getAlignInBits(); + } + + unsigned getHashValue() const { + return hash_combine(Scope, Decl, Name, File, LineNo, AlignInBits); + } +}; + template <> struct MDNodeKeyImpl { Metadata *Scope; MDString *Name; MDString *ConfigurationMacros; MDString *IncludePath; MDString *ISysRoot; + Metadata *File; + unsigned Line; MDNodeKeyImpl(Metadata *Scope, MDString *Name, MDString *ConfigurationMacros, - MDString *IncludePath, MDString *ISysRoot) + MDString *IncludePath, MDString *ISysRoot, Metadata *File, + unsigned Line) : Scope(Scope), Name(Name), ConfigurationMacros(ConfigurationMacros), - IncludePath(IncludePath), ISysRoot(ISysRoot) {} + IncludePath(IncludePath), ISysRoot(ISysRoot), File(File), Line(Line) {} MDNodeKeyImpl(const DIModule *N) : Scope(N->getRawScope()), Name(N->getRawName()), ConfigurationMacros(N->getRawConfigurationMacros()), - IncludePath(N->getRawIncludePath()), ISysRoot(N->getRawISysRoot()) {} + IncludePath(N->getRawIncludePath()), ISysRoot(N->getRawISysRoot()), + File(N->getRawFile()), Line(N->getLine()) {} bool isKeyOf(const DIModule *RHS) const { return Scope == RHS->getRawScope() && Name == RHS->getRawName() && ConfigurationMacros == RHS->getRawConfigurationMacros() && IncludePath == RHS->getRawIncludePath() && - ISysRoot == RHS->getRawISysRoot(); + ISysRoot == RHS->getRawISysRoot() && + File == RHS->getRawFile() && Line == RHS->getLine(); } unsigned getHashValue() const { return hash_combine(Scope, Name, - ConfigurationMacros, IncludePath, ISysRoot); + ConfigurationMacros, IncludePath, ISysRoot, File, Line); } }; @@ -865,16 +1016,19 @@ bool IsLocalToUnit; bool IsDefinition; Metadata *StaticDataMemberDeclaration; + unsigned Flags; uint32_t AlignInBits; MDNodeKeyImpl(Metadata *Scope, MDString *Name, MDString *LinkageName, Metadata *File, unsigned Line, Metadata *Type, bool IsLocalToUnit, bool IsDefinition, - Metadata *StaticDataMemberDeclaration, uint32_t AlignInBits) + Metadata *StaticDataMemberDeclaration, + unsigned Flags, uint32_t AlignInBits) : Scope(Scope), Name(Name), LinkageName(LinkageName), File(File), Line(Line), Type(Type), IsLocalToUnit(IsLocalToUnit), IsDefinition(IsDefinition), StaticDataMemberDeclaration(StaticDataMemberDeclaration), + Flags(Flags), AlignInBits(AlignInBits) {} MDNodeKeyImpl(const DIGlobalVariable *N) : Scope(N->getRawScope()), Name(N->getRawName()), @@ -882,7 +1036,7 @@ Line(N->getLine()), Type(N->getRawType()), IsLocalToUnit(N->isLocalToUnit()), IsDefinition(N->isDefinition()), StaticDataMemberDeclaration(N->getRawStaticDataMemberDeclaration()), - AlignInBits(N->getAlignInBits()) {} + 
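Each MDNodeKeyImpl specialization above follows the same uniquing pattern: a key struct mirroring the node's operands, an isKeyOf comparison against an existing node, and a getHashValue that deliberately hashes only a subset of the fields, relying on the full comparison to resolve collisions. A standalone sketch of that pattern over an ordinary hash container (illustrative only; Node, NodeKey, and their fields are placeholders, not the DI* classes):

#include <cstddef>
#include <cstdio>
#include <memory>
#include <string>
#include <unordered_map>

// Placeholder node with more fields than we bother to hash.
struct Node {
  std::string Name;
  unsigned Line;
  unsigned Flags;   // intentionally not part of the hash
};

struct NodeKey {
  std::string Name;
  unsigned Line;
  unsigned Flags;

  // Full equality against an existing node: this is what keeps uniquing
  // correct even when two different keys land in the same bucket.
  bool isKeyOf(const Node &N) const {
    return Name == N.Name && Line == N.Line && Flags == N.Flags;
  }
  // Hash only a subset of the fields for speed; collisions are acceptable.
  std::size_t getHashValue() const {
    return std::hash<std::string>()(Name) ^ (std::size_t(Line) << 1);
  }
};

// Uniquing table: hash buckets of candidate nodes, resolved by isKeyOf.
struct NodeUniquer {
  std::unordered_multimap<std::size_t, std::unique_ptr<Node>> Table;

  Node *getOrInsert(const NodeKey &K) {
    auto Range = Table.equal_range(K.getHashValue());
    for (auto It = Range.first; It != Range.second; ++It)
      if (K.isKeyOf(*It->second))
        return It->second.get();        // reuse the existing node
    auto N = std::make_unique<Node>(Node{K.Name, K.Line, K.Flags});
    Node *Raw = N.get();
    Table.emplace(K.getHashValue(), std::move(N));
    return Raw;
  }
};

int main() {
  NodeUniquer U;
  Node *A = U.getOrInsert({"blk", 10, 0});
  Node *B = U.getOrInsert({"blk", 10, 0});
  std::printf("%s\n", A == B ? "uniqued" : "distinct");  // uniqued
  return 0;
}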
Flags(N->getFlags()), AlignInBits(N->getAlignInBits()) {} bool isKeyOf(const DIGlobalVariable *RHS) const { return Scope == RHS->getRawScope() && Name == RHS->getRawName() && @@ -892,6 +1046,7 @@ IsDefinition == RHS->isDefinition() && StaticDataMemberDeclaration == RHS->getRawStaticDataMemberDeclaration() && + Flags == RHS->getFlags() && AlignInBits == RHS->getAlignInBits(); } @@ -905,7 +1060,7 @@ // TODO: make hashing work fine with such situations return hash_combine(Scope, Name, LinkageName, File, Line, Type, IsLocalToUnit, IsDefinition, /* AlignInBits, */ - StaticDataMemberDeclaration); + StaticDataMemberDeclaration, Flags); } }; @@ -1272,6 +1427,9 @@ using VectorConstantsTy = ConstantUniqueMap; VectorConstantsTy VectorConstants; + DenseMap> SVVConstants; + DenseMap> VSVConstants; + DenseMap> CPNConstants; DenseMap> UVConstants; @@ -1308,7 +1466,7 @@ unsigned NamedStructTypesUniqueID = 0; DenseMap, ArrayType*> ArrayTypes; - DenseMap, VectorType*> VectorTypes; + DenseMap, VectorType*> VectorTypes; DenseMap PointerTypes; // Pointers in AddrSpace = 0 DenseMap, PointerType*> ASPointerTypes; Index: lib/IR/LLVMContextImpl.cpp =================================================================== --- lib/IR/LLVMContextImpl.cpp +++ lib/IR/LLVMContextImpl.cpp @@ -99,6 +99,8 @@ UVConstants.clear(); IntConstants.clear(); FPConstants.clear(); + SVVConstants.clear(); + VSVConstants.clear(); for (auto &CDSConstant : CDSConstants) delete CDSConstant.second; Index: lib/IR/LegacyPassManager.cpp =================================================================== --- lib/IR/LegacyPassManager.cpp +++ lib/IR/LegacyPassManager.cpp @@ -76,6 +76,15 @@ llvm::cl::desc("Print IR after specified passes"), cl::Hidden); +static cl::opt +PrintBeforePass("print-before-pass", + llvm::cl::desc("Print IR before specified pass"), + cl::Hidden); + +static cl::opt +PrintAfterPass("print-after-pass", + llvm::cl::desc("Print IR after specified pass"), + cl::Hidden); static cl::opt PrintBeforeAll("print-before-all", llvm::cl::desc("Print IR before each pass"), cl::init(false), cl::Hidden); @@ -113,13 +122,15 @@ /// This is a utility to check whether a pass should have IR dumped /// before it. static bool ShouldPrintBeforePass(const PassInfo *PI) { - return PrintBeforeAll || ShouldPrintBeforeOrAfterPass(PI, PrintBefore); + return PrintBeforeAll || ShouldPrintBeforeOrAfterPass(PI, PrintBefore) || + (PrintBeforePass == PI->getPassArgument()); } /// This is a utility to check whether a pass should have IR dumped /// after it. static bool ShouldPrintAfterPass(const PassInfo *PI) { - return PrintAfterAll || ShouldPrintBeforeOrAfterPass(PI, PrintAfter); + return PrintAfterAll || ShouldPrintBeforeOrAfterPass(PI, PrintAfter) || + (PrintAfterPass == PI->getPassArgument()); } bool llvm::forcePrintModuleIR() { return PrintModuleScope; } @@ -129,6 +140,7 @@ PrintFuncsList.end()); return PrintFuncNames.empty() || PrintFuncNames.count(FunctionName); } + /// isPassDebuggingExecutionsOrMore - Return true if -debug-pass=Executions /// or higher is specified. bool PMDataManager::isPassDebuggingExecutionsOrMore() const { Index: lib/IR/Operator.cpp =================================================================== --- lib/IR/Operator.cpp +++ lib/IR/Operator.cpp @@ -38,6 +38,11 @@ DL.getIndexSizeInBits(getPointerAddressSpace()) && "The offset bit width does not match DL specification."); + // This contains a hidden multiplication by a runtime constant. 
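Changing the VectorTypes key from an (element, count) pair to an (element, count, scalable) tuple is what lets a fixed <4 x i32> and a scalable <vscale x 4 x i32> be uniqued as distinct types. A minimal standalone model of such a cache (illustrative only; the element type is reduced to a string and the container is a plain std::map rather than a DenseMap):

#include <cstdio>
#include <map>
#include <string>
#include <tuple>

struct VecTy {
  std::string Elem;
  unsigned MinElts;
  bool Scalable;
};

// Key on all three properties so fixed and scalable vectors with the same
// element type and minimum lane count remain distinct cache entries.
using VecKey = std::tuple<std::string, unsigned, bool>;

VecTy *getVectorType(const std::string &Elem, unsigned MinElts, bool Scalable) {
  static std::map<VecKey, VecTy> Cache;
  auto Res = Cache.try_emplace(VecKey(Elem, MinElts, Scalable),
                               VecTy{Elem, MinElts, Scalable});
  return &Res.first->second;
}

int main() {
  VecTy *Fixed = getVectorType("i32", 4, false);
  VecTy *Scal = getVectorType("i32", 4, true);
  std::printf("%s\n", Fixed == Scal ? "same" : "distinct");  // distinct
  return 0;
}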
+ if (auto *PtrTy = dyn_cast(getSourceElementType())) + if (PtrTy->isScalable()) + return false; + for (gep_type_iterator GTI = gep_type_begin(this), GTE = gep_type_end(this); GTI != GTE; ++GTI) { ConstantInt *OpC = dyn_cast(GTI.getOperand()); Index: lib/IR/Type.cpp =================================================================== --- lib/IR/Type.cpp +++ lib/IR/Type.cpp @@ -587,21 +587,24 @@ // VectorType Implementation //===----------------------------------------------------------------------===// -VectorType::VectorType(Type *ElType, unsigned NumEl) - : SequentialType(VectorTyID, ElType, NumEl) {} +VectorType::VectorType(Type *ElType, unsigned NumEl, bool Scalable) + : SequentialType(VectorTyID, ElType, NumEl), Scalable(Scalable) {} -VectorType *VectorType::get(Type *ElementType, unsigned NumElements) { - assert(NumElements > 0 && "#Elements of a VectorType must be greater than 0"); - assert(isValidElementType(ElementType) && "Element type of a VectorType must " - "be an integer, floating point, or " - "pointer type."); +VectorType *VectorType::get(Type *ElType, VectorType::ElementCount EC) { + Type *ElementType = const_cast(ElType); + assert(EC.Min > 0 && + "#Elements of a VectorType must be greater than 0"); + assert(isValidElementType(ElementType) && + "Element type of a VectorType must be an integer, floating point, or " + "pointer type."); LLVMContextImpl *pImpl = ElementType->getContext().pImpl; VectorType *&Entry = ElementType->getContext().pImpl - ->VectorTypes[std::make_pair(ElementType, NumElements)]; + ->VectorTypes[std::make_tuple(ElementType, EC.Min, EC.Scalable)]; if (!Entry) - Entry = new (pImpl->TypeAllocator) VectorType(ElementType, NumElements); + Entry = new (pImpl->TypeAllocator) VectorType(ElementType, EC.Min, + EC.Scalable); return Entry; } Index: lib/IR/Verifier.cpp =================================================================== --- lib/IR/Verifier.cpp +++ lib/IR/Verifier.cpp @@ -871,20 +871,35 @@ void Verifier::visitDISubrange(const DISubrange &N) { AssertDI(N.getTag() == dwarf::DW_TAG_subrange_type, "invalid tag", &N); auto Count = N.getCount(); - AssertDI(Count, "Count must either be a signed constant or a DIVariable", - &N); + AssertDI(Count, + "Count must either be a signed constant, a DIVariable or a DIExpression", + &N); AssertDI(!Count.is() || Count.get()->getSExtValue() >= -1, "invalid subrange count", &N); } +void Verifier::visitDIFortranSubrange(const DIFortranSubrange &N) { + AssertDI(N.getTag() == dwarf::DW_TAG_subrange_type, "invalid tag", &N); + AssertDI(N.getLowerBound() ? (N.getLowerBoundExp() != nullptr) : true, + "no lower bound", &N); + AssertDI(N.getUpperBound() ? (N.getUpperBoundExp() != nullptr) : true, + "no upper bound", &N); +} + void Verifier::visitDIEnumerator(const DIEnumerator &N) { AssertDI(N.getTag() == dwarf::DW_TAG_enumerator, "invalid tag", &N); } void Verifier::visitDIBasicType(const DIBasicType &N) { AssertDI(N.getTag() == dwarf::DW_TAG_base_type || - N.getTag() == dwarf::DW_TAG_unspecified_type, + N.getTag() == dwarf::DW_TAG_unspecified_type || + N.getTag() == dwarf::DW_TAG_string_type, + "invalid tag", &N); +} + +void Verifier::visitDIStringType(const DIStringType &N) { + AssertDI( N.getTag() == dwarf::DW_TAG_string_type, "invalid tag", &N); } @@ -984,6 +999,22 @@ } } +void Verifier::visitDIFortranArrayType(const DIFortranArrayType &N) { + // Common scope checks. 
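VectorType::get now takes an element count that records both a minimum lane count and a scalable flag; for scalable vectors the runtime lane count is that minimum multiplied by the hardware's vscale. A small standalone model of those semantics (illustrative only; it mirrors the patch's Min/Scalable fields but is not the real ElementCount type):

#include <cstdio>

// Model of an element count that may be scalable: the runtime lane count is
// Min when Scalable is false, and Min * vscale (a runtime quantity) otherwise.
struct ElementCount {
  unsigned Min;
  bool Scalable;
  bool operator==(const ElementCount &O) const {
    return Min == O.Min && Scalable == O.Scalable;
  }
  unsigned runtimeCount(unsigned VScale) const {
    return Scalable ? Min * VScale : Min;
  }
};

int main() {
  ElementCount Fixed{4, false};   // <4 x i32>
  ElementCount Scal{4, true};     // <vscale x 4 x i32>
  std::printf("%d\n", Fixed == Scal);                        // 0: distinct
  std::printf("%u %u\n", Scal.runtimeCount(2),
              Scal.runtimeCount(4));                         // 8 16
  return 0;
}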
+ visitDIScope(N); + + AssertDI(N.getTag() == dwarf::DW_TAG_array_type, "invalid tag", &N); + + AssertDI(isScope(N.getRawScope()), "invalid scope", &N, N.getRawScope()); + AssertDI(isType(N.getRawBaseType()), "invalid base type", &N, + N.getRawBaseType()); + + AssertDI(!N.getRawElements() || isa(N.getRawElements()), + "invalid composite elements", &N, N.getRawElements()); + AssertDI(!hasConflictingReferenceFlags(N.getFlags()), + "invalid reference flags", &N); +} + void Verifier::visitDISubroutineType(const DISubroutineType &N) { AssertDI(N.getTag() == dwarf::DW_TAG_subroutine_type, "invalid tag", &N); if (auto *Types = N.getRawTypeArray()) { @@ -1138,6 +1169,14 @@ visitDILexicalBlockBase(N); } +void Verifier::visitDICommonBlock(const DICommonBlock &N) { + AssertDI(N.getTag() == dwarf::DW_TAG_common_block, "invalid tag", &N); + if (auto *S = N.getRawScope()) + AssertDI(isa(S), "invalid scope ref", &N, S); + if (auto *S = N.getRawDecl()) + AssertDI(isa(S), "invalid declaration", &N, S); +} + void Verifier::visitDINamespace(const DINamespace &N) { AssertDI(N.getTag() == dwarf::DW_TAG_namespace, "invalid tag", &N); if (auto *S = N.getRawScope()) @@ -1206,7 +1245,8 @@ visitDIVariable(N); AssertDI(N.getTag() == dwarf::DW_TAG_variable, "invalid tag", &N); - AssertDI(!N.getName().empty(), "missing global variable name", &N); + AssertDI(!N.getName().empty() || N.isArtificial(), + "missing global variable name", &N); AssertDI(isType(N.getRawType()), "invalid type ref", &N, N.getRawType()); AssertDI(N.getType(), "missing global variable type", &N); if (auto *Member = N.getRawStaticDataMemberDeclaration()) { @@ -4005,14 +4045,14 @@ getIntrinsicInfoTableEntries(ID, Table); ArrayRef TableRef = Table; + // Walk the descriptors to extract overloaded types. SmallVector ArgTys; - Assert(!Intrinsic::matchIntrinsicType(IFTy->getReturnType(), - TableRef, ArgTys), + Intrinsic::MatchIntrinsicTypesResult Res = + Intrinsic::matchIntrinsicSignature(IFTy, TableRef, ArgTys); + Assert(Res != Intrinsic::MatchIntrinsicTypes_NoMatchRet, "Intrinsic has incorrect return type!", IF); - for (unsigned i = 0, e = IFTy->getNumParams(); i != e; ++i) - Assert(!Intrinsic::matchIntrinsicType(IFTy->getParamType(i), - TableRef, ArgTys), - "Intrinsic has incorrect argument type!", IF); + Assert(Res != Intrinsic::MatchIntrinsicTypes_NoMatchArg, + "Intrinsic has incorrect argument type!", IF); // Verify if the intrinsic call matches the vararg property. 
if (IsVarArg) Index: lib/Linker/IRMover.cpp =================================================================== --- lib/Linker/IRMover.cpp +++ lib/Linker/IRMover.cpp @@ -174,6 +174,10 @@ if (DSTy->isLiteral() != SSTy->isLiteral() || DSTy->isPacked() != SSTy->isPacked()) return false; + } else if (auto *DVecTy = dyn_cast(DstTy)) { + if (DVecTy->getElementCount() != + cast(SrcTy)->getElementCount()) + return false; } else if (auto *DSeqTy = dyn_cast(DstTy)) { if (DSeqTy->getNumElements() != cast(SrcTy)->getNumElements()) @@ -306,7 +310,7 @@ cast(Ty)->getNumElements()); case Type::VectorTyID: return *Entry = VectorType::get(ElementTypes[0], - cast(Ty)->getNumElements()); + cast(Ty)->getElementCount()); case Type::PointerTyID: return *Entry = PointerType::get(ElementTypes[0], cast(Ty)->getAddressSpace()); Index: lib/MC/MCDwarf.cpp =================================================================== --- lib/MC/MCDwarf.cpp +++ lib/MC/MCDwarf.cpp @@ -1229,6 +1229,81 @@ MCGenDwarfLabelEntry(Name, FileNumber, LineNumber, Label)); } +// Create Multiply (by reg) and Add expression: +// + Offset + (Scalereg * Offset2) +static void addMulAddExpression(SmallVectorImpl &expr, int Offset, + unsigned Scalereg, int Offset2) { + uint8_t buffer[10]; + + // Add unscaled offset if non-zero + if (Offset) { + expr.push_back(dwarf::DW_OP_consts); + expr.append(buffer, buffer + encodeSLEB128(Offset, buffer)); + expr.push_back((uint8_t)dwarf::DW_OP_plus); + } + + // Add scaled offset if non-zero + if (Offset2) { + expr.push_back((uint8_t)dwarf::DW_OP_consts); + expr.append(buffer, buffer + encodeSLEB128(Offset2, buffer)); + + expr.push_back((uint8_t)dwarf::DW_OP_bregx); + expr.append(buffer, buffer + encodeULEB128(Scalereg, buffer)); + expr.push_back(0); + + expr.push_back((uint8_t)dwarf::DW_OP_mul); + expr.push_back((uint8_t)dwarf::DW_OP_plus); + } +} + +// Create expression to add value of base register on the expression stack. 
+static void getBaseRegExpression(SmallVectorImpl &expr, + unsigned Basereg) { + uint8_t buffer[10]; + + // Create expression for base + offset + offset2 * scalereg + expr.push_back((uint8_t)dwarf::DW_OP_bregx); + expr.append(buffer, buffer + encodeULEB128(Basereg, buffer)); + expr.push_back(0); +} + +// Creates an MCCFIInstruction: +// { DW_CFA_def_cfa_expression, ULEB128 (sizeof expr), expr } +MCCFIInstruction +MCCFIInstruction::createScaledDefCfa(MCSymbol *L, + unsigned Basereg, int Offset, + unsigned Scalereg, int Offset2, + StringRef Comment) { + SmallVector expr; + getBaseRegExpression(expr, Basereg); + addMulAddExpression(expr, Offset, Scalereg, Offset2); + + // Create the def_cfa + expression + uint8_t buffer[10]; + SmallVector expr2 = { dwarf::DW_CFA_def_cfa_expression }; + expr2.append(buffer, buffer + encodeULEB128(expr.size(), buffer)); + expr2.append(expr.begin(), expr.end()); + + return createEscape(L, StringRef(expr2.data(), expr2.size()), Comment); +} + +// Creates an MCCFIInstruction: +// { DW_CFA_expression, ULEB128 (reg), ULEB128 (sizeof expr), expr } +MCCFIInstruction +MCCFIInstruction::createScaledCfaOffset(MCSymbol *L, unsigned Reg, int Offset, + unsigned Scalereg, int Offset2, + StringRef Comment) { + SmallVector expr; + addMulAddExpression(expr, Offset, Scalereg, Offset2); + + uint8_t buffer[10]; + SmallVector expr2 = { dwarf::DW_CFA_expression }; + expr2.append(buffer, buffer + encodeULEB128(Reg, buffer)); + expr2.append(buffer, buffer + encodeULEB128(expr.size(), buffer)); + expr2.append(expr.begin(), expr.end()); + return createEscape(L, StringRef(expr2.data(), expr2.size()), Comment); +} + static int getDataAlignmentFactor(MCStreamer &streamer) { MCContext &context = streamer.getContext(); const MCAsmInfo *asmInfo = context.getAsmInfo(); Index: lib/MC/MCParser/AsmLexer.cpp =================================================================== --- lib/MC/MCParser/AsmLexer.cpp +++ lib/MC/MCParser/AsmLexer.cpp @@ -62,8 +62,6 @@ return (unsigned char)*CurPtr++; } -/// LexFloatLiteral: [0-9]*[.][0-9]*([eE][+-]?[0-9]*)? -/// /// The leading integral digit sequence and dot should have already been /// consumed, some or all of the fractional digit sequence *can* have been /// consumed. @@ -72,14 +70,16 @@ while (isDigit(*CurPtr)) ++CurPtr; - // Check for exponent; we intentionally accept a slighlty wider set of - // literals here and rely on the upstream client to reject invalid ones (e.g., - // "1e+"). - if (*CurPtr == 'e' || *CurPtr == 'E') { + if (*CurPtr == '-' || *CurPtr == '+') + return ReturnError(CurPtr, "Invalid sign in float literal"); + + // Check for exponent + if ((*CurPtr == 'e' || *CurPtr == 'E')) { ++CurPtr; + if (*CurPtr == '-' || *CurPtr == '+') ++CurPtr; - while (isDigit(*CurPtr)) + while (isdigit(*CurPtr)) ++CurPtr; } @@ -136,18 +136,18 @@ /// LexIdentifier: [a-zA-Z_.][a-zA-Z0-9_$.@?]* static bool IsIdentifierChar(char c, bool AllowAt) { - return isAlnum(c) || c == '_' || c == '$' || c == '.' || + return isalnum(c) || c == '_' || c == '$' || c == '.' || (c == '@' && AllowAt) || c == '?'; } - AsmToken AsmLexer::LexIdentifier() { // Check for floating point literals. if (CurPtr[-1] == '.' && isDigit(*CurPtr)) { // Disambiguate a .1243foo identifier from a floating literal. 
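The DWARF CFI helpers added above (addMulAddExpression, getBaseRegExpression, and the createScaled* wrappers) emit a raw expression for "base register + offset + offset2 * scale register", with operands in LEB128 form. The sketch below assembles the same byte sequence standalone; the opcode values are the standard DWARF ones, and the register numbers in the usage are just example values:

#include <cstdint>
#include <cstdio>
#include <vector>

// Standard LEB128 encoders, as used for DWARF operands.
static void encodeULEB128(uint64_t V, std::vector<uint8_t> &Out) {
  do {
    uint8_t Byte = V & 0x7f;
    V >>= 7;
    if (V)
      Byte |= 0x80;
    Out.push_back(Byte);
  } while (V);
}

static void encodeSLEB128(int64_t V, std::vector<uint8_t> &Out) {
  bool More;
  do {
    uint8_t Byte = V & 0x7f;
    V >>= 7; // arithmetic shift assumed for negative values
    More = !((V == 0 && !(Byte & 0x40)) || (V == -1 && (Byte & 0x40)));
    if (More)
      Byte |= 0x80;
    Out.push_back(Byte);
  } while (More);
}

// DWARF expression opcode values.
enum : uint8_t {
  DW_OP_consts = 0x11,
  DW_OP_mul = 0x1e,
  DW_OP_plus = 0x22,
  DW_OP_bregx = 0x92,
};

// Bytes for: <value of Basereg + 0> + Offset + (Offset2 * value of Scalereg)
static std::vector<uint8_t> scaledCfaExpr(unsigned Basereg, int Offset,
                                          unsigned Scalereg, int Offset2) {
  std::vector<uint8_t> E;
  E.push_back(DW_OP_bregx);          // push Basereg + 0
  encodeULEB128(Basereg, E);
  E.push_back(0);
  if (Offset) {                      // ... + Offset
    E.push_back(DW_OP_consts);
    encodeSLEB128(Offset, E);
    E.push_back(DW_OP_plus);
  }
  if (Offset2) {                     // ... + Offset2 * Scalereg
    E.push_back(DW_OP_consts);
    encodeSLEB128(Offset2, E);
    E.push_back(DW_OP_bregx);
    encodeULEB128(Scalereg, E);
    E.push_back(0);
    E.push_back(DW_OP_mul);
    E.push_back(DW_OP_plus);
  }
  return E;
}

int main() {
  for (uint8_t B : scaledCfaExpr(/*base reg*/ 31, -16, /*scale reg*/ 46, -2))
    std::printf("%02x ", unsigned(B));
  std::printf("\n");
  return 0;
}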
while (isDigit(*CurPtr)) ++CurPtr; - if (*CurPtr == 'e' || *CurPtr == 'E' || - !IsIdentifierChar(*CurPtr, AllowAtInIdentifier)) + + if (!IsIdentifierChar(*CurPtr, AllowAtInIdentifier) || + *CurPtr == 'e' || *CurPtr == 'E') return LexFloatLiteral(); } @@ -323,8 +323,9 @@ unsigned Radix = doLookAhead(CurPtr, 10); bool isHex = Radix == 16; // Check for floating point literals. - if (!isHex && (*CurPtr == '.' || *CurPtr == 'e')) { - ++CurPtr; + if (!isHex && (*CurPtr == '.' || *CurPtr == 'e' || *CurPtr == 'E')) { + if (*CurPtr == '.') + ++CurPtr; return LexFloatLiteral(); } Index: lib/Support/APFloat.cpp =================================================================== --- lib/Support/APFloat.cpp +++ lib/Support/APFloat.cpp @@ -199,7 +199,10 @@ const unsigned int overlargeExponent = 24000; /* FIXME. */ StringRef::iterator p = begin; - assert(p != end && "Exponent has no digits"); + // Treat no exponent as 0 to match binutils + if (p == end || ((*p == '-' || *p == '+') && (p + 1) == end)) { + return 0; + } isNegative = (*p == '-'); if (*p == '-' || *p == '+') { Index: lib/Support/LockFileManager.cpp =================================================================== --- lib/Support/LockFileManager.cpp +++ lib/Support/LockFileManager.cpp @@ -291,7 +291,7 @@ sys::DontRemoveFileOnSignal(UniqueLockFileName); } -LockFileManager::WaitForUnlockResult LockFileManager::waitForUnlock() { +LockFileManager::WaitForUnlockResult LockFileManager::waitForUnlock(unsigned MaxSeconds) { if (getState() != LFS_Shared) return Res_Success; @@ -302,9 +302,6 @@ Interval.tv_sec = 0; Interval.tv_nsec = 1000000; #endif - // Don't wait more than 40s per iteration. Total timeout for the file - // to appear is ~1.5 minutes. - const unsigned MaxSeconds = 40; do { // Sleep for the designated interval, to allow the owning process time to // finish up and remove the lock file. 
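The AsmLexer changes above tighten float-literal scanning: after the fractional digits a bare '+' or '-' is rejected, and an optional e/E exponent with an optional sign is consumed. A standalone model of that tail scan (illustrative only; the starting position is assumed to sit just past the '.', as in the real lexer):

#include <cctype>
#include <cstdio>
#include <string>

// Scan the tail of a float literal starting right after the '.' (the integral
// digits and the dot are assumed already consumed). Returns true and sets End
// one past the literal, or false on a malformed tail such as "1.5-2".
bool lexFloatTail(const std::string &S, std::size_t Pos, std::size_t &End) {
  while (Pos < S.size() && std::isdigit((unsigned char)S[Pos]))
    ++Pos;                                   // fractional digits
  if (Pos < S.size() && (S[Pos] == '+' || S[Pos] == '-'))
    return false;                            // sign without an exponent
  if (Pos < S.size() && (S[Pos] == 'e' || S[Pos] == 'E')) {
    ++Pos;
    if (Pos < S.size() && (S[Pos] == '+' || S[Pos] == '-'))
      ++Pos;                                 // optional exponent sign
    while (Pos < S.size() && std::isdigit((unsigned char)S[Pos]))
      ++Pos;                                 // exponent digits
  }
  End = Pos;
  return true;
}

int main() {
  std::string A = "1.25e-4", B = "1.5-2";
  std::size_t End = 0;
  std::printf("%d\n", lexFloatTail(A, 2, End) ? (int)End : -1);  // 7
  std::printf("%d\n", lexFloatTail(B, 2, End) ? (int)End : -1);  // -1
  return 0;
}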
Index: lib/Support/TargetParser.cpp =================================================================== --- lib/Support/TargetParser.cpp +++ lib/Support/TargetParser.cpp @@ -472,6 +472,8 @@ Features.push_back("+rdm"); if (Extensions & AArch64::AEK_SVE) Features.push_back("+sve"); + if (Extensions & AArch64::AEK_SVE2) + Features.push_back("+sve2"); if (Extensions & AArch64::AEK_RCPC) Features.push_back("+rcpc"); Index: lib/Support/Unix/Process.inc =================================================================== --- lib/Support/Unix/Process.inc +++ lib/Support/Unix/Process.inc @@ -458,3 +458,40 @@ return ::rand(); #endif } + +#if defined(__linux__) +#include /* statfs */ +#include /* PROC_SUPER_MAGIC */ +#endif + +bool Process::CheckMyParentUsesSameExeImage() { +#if defined(__linux__) + // Check that /proc is really procfs + struct statfs buf; + if (statfs("/proc", &buf) != 0) + return false; + if (buf.f_type != PROC_SUPER_MAGIC) + return false; + + // Parent + pid_t parent = getppid(); + char proc_path[128]; + snprintf(proc_path, 128, "/proc/%u/exe", parent); + proc_path[127] = '\0'; + + struct stat buf_parent = {}; + if (stat(proc_path, &buf_parent) != 0) + return false; + + // Current process + struct stat buf_me = {}; + if (stat("/proc/self/exe", &buf_me) != 0) + return false; + + // Check that the image file is the same for both + return (buf_me.st_dev == buf_parent.st_dev && + buf_me.st_ino == buf_parent.st_ino); +#else // !defined(__linux__) + return false; +#endif +} Index: lib/Support/Unix/Signals.inc =================================================================== --- lib/Support/Unix/Signals.inc +++ lib/Support/Unix/Signals.inc @@ -47,6 +47,7 @@ #include "llvm/Support/raw_ostream.h" #include #include +#include #ifdef HAVE_BACKTRACE # include BACKTRACE_HEADER // For backtrace(). #endif @@ -334,6 +335,10 @@ if (auto OldInterruptFunction = InterruptFunction.exchange(nullptr)) return OldInterruptFunction(); + // Send a special return code that drivers can check for, from sysexits.h. + if (Sig == SIGPIPE) + exit(EX_IOERR); + raise(Sig); // Execute the default handler. 
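The Signals.inc hunk above turns a SIGPIPE into a clean exit(EX_IOERR) so a driving process can recognise "output consumer went away" from the exit code. A minimal POSIX-only sketch of installing such a handler (illustrative; it uses only standard signal and sysexits facilities):

#include <csignal>
#include <cstdio>
#include <sysexits.h>   // EX_IOERR (74)
#include <unistd.h>     // _exit

// Exit with a conventional code on a broken pipe so a driver process can
// tell a vanished output consumer apart from ordinary failures.
// _exit is async-signal-safe, unlike exit.
static void onSigPipe(int) { _exit(EX_IOERR); }

int main() {
  std::signal(SIGPIPE, onSigPipe);
  // From here on, a write to a closed pipe terminates the process with
  // EX_IOERR instead of the default SIGPIPE disposition.
  std::printf("handler installed\n");
  return 0;
}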
return; } Index: lib/Target/AArch64/AArch64.h =================================================================== --- lib/Target/AArch64/AArch64.h +++ lib/Target/AArch64/AArch64.h @@ -50,6 +50,11 @@ FunctionPass *createAArch64CleanupLocalDynamicTLSPass(); FunctionPass *createAArch64CollectLOHPass(); +FunctionPass *createSVEAddressingModesPass(); +FunctionPass *createSVEPostVectorizePass(); +FunctionPass *createSVEExpandLibCallPass(bool Optimize); +FunctionPass *createSVEIntrinsicOptsPass(); +FunctionPass *createSVEConditionalEarlyClobberPass(); InstructionSelector * createAArch64InstructionSelector(const AArch64TargetMachine &, AArch64Subtarget &, AArch64RegisterBankInfo &); @@ -71,6 +76,7 @@ void initializeFalkorHWPFFixPass(PassRegistry&); void initializeFalkorMarkStridedAccessesLegacyPass(PassRegistry&); void initializeLDTLSCleanupPass(PassRegistry&); +void initializeSVEIntrinsicOptsPass(PassRegistry&); } // end namespace llvm #endif Index: lib/Target/AArch64/AArch64.td =================================================================== --- lib/Target/AArch64/AArch64.td +++ lib/Target/AArch64/AArch64.td @@ -75,7 +75,22 @@ "Enable Statistical Profiling extension">; def FeatureSVE : SubtargetFeature<"sve", "HasSVE", "true", - "Enable Scalable Vector Extension (SVE) instructions">; + "Enable Scalable Vector Extension (SVE) instructions", [FeatureFullFP16]>; + +def FeatureSVE2 : SubtargetFeature<"sve2", "HasSVE2", "true", + "Enable Scalable Vector Extension 2 (SVE2) instructions", [FeatureSVE]>; + +def FeatureSVE2AES : SubtargetFeature<"sve2-aes", "HasSVE2AES", "true", + "Enable AES SVE2 instructions", [FeatureSVE2, FeatureAES]>; + +def FeatureSVE2SM4 : SubtargetFeature<"sve2-sm4", "HasSVE2SM4", "true", + "Enable SM4 SVE2 instructions", [FeatureSVE2, FeatureSM4]>; + +def FeatureSVE2SHA3 : SubtargetFeature<"sve2-sha3", "HasSVE2SHA3", "true", + "Enable SHA3 SVE2 instructions", [FeatureSVE2, FeatureSHA3]>; + +def FeatureSVE2BitPerm : SubtargetFeature<"sve2-bitperm", "HasSVE2BitPerm", "true", + "Enable bit permutation SVE2 instructions", [FeatureSVE2]>; /// Cyclone has register move instructions which are "free". def FeatureZCRegMove : SubtargetFeature<"zcm", "HasZeroCycleRegMove", "true", @@ -179,6 +194,10 @@ "dotprod", "HasDotProd", "true", "Enable dot product support">; +def FeatureIterativeReciprocal : SubtargetFeature< + "use-iterative-reciprocal", "UseIterativeReciprocal", "true", + "Use the iterative reciprocal approximation">; + def FeatureNoNegativeImmediates : SubtargetFeature<"no-neg-immediates", "NegativeImmediates", "false", "Convert immediates and instructions " @@ -238,6 +257,18 @@ //===----------------------------------------------------------------------===// // AArch64 Processors supported. // + +//===----------------------------------------------------------------------===// +// Unsupported features to disable for scheduling models +//===----------------------------------------------------------------------===// + +class AArch64Unsupported { list F; } + +def SVEUnsupported : AArch64Unsupported { + let F = [HasSVE, HasSVE2, HasSVE2AES, HasSVE2SM4, HasSVE2SHA3, + HasSVE2BitPerm]; +} + include "AArch64SchedA53.td" include "AArch64SchedA57.td" include "AArch64SchedCyclone.td" Index: lib/Target/AArch64/AArch64AsmPrinter.cpp =================================================================== --- lib/Target/AArch64/AArch64AsmPrinter.cpp +++ lib/Target/AArch64/AArch64AsmPrinter.cpp @@ -352,6 +352,7 @@ case 's': // Print S register. case 'd': // Print D register. 
case 'q': // Print Q register. + case 'z': // Print Z register. if (MO.isReg()) { const TargetRegisterClass *RC; switch (ExtraCode[0]) { @@ -370,6 +371,9 @@ case 'q': RC = &AArch64::FPR128RegClass; break; + case 'z': + RC = &AArch64::ZPRRegClass; + break; default: return true; } @@ -390,9 +394,21 @@ AArch64::GPR64allRegClass.contains(Reg)) return printAsmMRegister(MO, 'x', O); + bool hasAltName; + const TargetRegisterClass *RegClass; + if (AArch64::ZPRRegClass.contains(Reg)) { + RegClass = &AArch64::ZPRRegClass; + hasAltName = false; + } else if (AArch64::PPRRegClass.contains(Reg)) { + RegClass = &AArch64::PPRRegClass; + hasAltName = false; + } else { + RegClass = &AArch64::FPR128RegClass; + hasAltName = true; + } + // If this is a b, h, s, d, or q register, print it as a v register. - return printAsmRegInClass(MO, &AArch64::FPR128RegClass, true /* vector */, - O); + return printAsmRegInClass(MO, RegClass, hasAltName /* vector */, O); } printOperand(MI, OpNum, O); Index: lib/Target/AArch64/AArch64CallLowering.cpp =================================================================== --- lib/Target/AArch64/AArch64CallLowering.cpp +++ lib/Target/AArch64/AArch64CallLowering.cpp @@ -197,8 +197,8 @@ return; SmallVector SplitVTs; - SmallVector Offsets; - ComputeValueVTs(TLI, DL, OrigArg.Ty, SplitVTs, &Offsets, 0); + SmallVector Offsets; + ComputeValueVTs(TLI, DL, OrigArg.Ty, SplitVTs, &Offsets, {0, 0}); if (SplitVTs.size() == 1) { // No splitting to do, but we want to replace the original type (e.g. [1 x @@ -223,7 +223,7 @@ SplitArgs.back().Flags.setInConsecutiveRegsLast(); for (unsigned i = 0; i < Offsets.size(); ++i) - PerformArgSplit(SplitArgs[FirstRegIdx + i].Reg, Offsets[i] * 8); + PerformArgSplit(SplitArgs[FirstRegIdx + i].Reg, Offsets[i].UnscaledBytes * 8); } bool AArch64CallLowering::lowerReturn(MachineIRBuilder &MIRBuilder, Index: lib/Target/AArch64/AArch64CallingConvention.td =================================================================== --- lib/Target/AArch64/AArch64CallingConvention.td +++ lib/Target/AArch64/AArch64CallingConvention.td @@ -26,6 +26,7 @@ CCIfType<[iPTR], CCBitConvertToType>, CCIfType<[v2f32], CCBitConvertToType>, CCIfType<[v2f64, v4f32], CCBitConvertToType>, + // TODO: Does SVE need a convert like the above? // Big endian vectors must be passed as if they were 1-element vectors so that // their lanes are in a consistent order. @@ -54,6 +55,19 @@ CCIfConsecutiveRegs>, + // TODO: Do we need to have Qn shadow Zn? + CCIfType<[nxv16i8,nxv8i16,nxv4i32,nxv2i64,nxv2f16,nxv4f16,nxv8f16, + nxv2f32,nxv4f32,nxv2f64], + CCAssignToReg<[Z0, Z1, Z2, Z3, Z4, Z5, Z6, Z7]>>, + CCIfType<[nxv16i8,nxv8i16,nxv4i32,nxv2i64,nxv2f16,nxv4f16,nxv8f16, + nxv2f32,nxv4f32,nxv2f64], + CCPassIndirect>, + + CCIfType<[nxv2i1,nxv4i1,nxv8i1,nxv16i1], + CCAssignToReg<[P0, P1, P2, P3]>>, + CCIfType<[nxv2i1,nxv4i1,nxv8i1,nxv16i1], + CCPassIndirect>, + // Handle i1, i8, i16, i32, i64, f32, f64 and v2f64 by passing in registers, // up to eight each of GPR and FPR. CCIfType<[i1, i8, i16], CCPromoteToType>, @@ -93,6 +107,7 @@ CCIfType<[iPTR], CCBitConvertToType>, CCIfType<[v2f32], CCBitConvertToType>, CCIfType<[v2f64, v4f32], CCBitConvertToType>, + // TODO: Does SVE need a convert like the above? 
CCIfSwiftError>>, @@ -118,7 +133,15 @@ CCAssignToRegWithShadow<[D0, D1, D2, D3, D4, D5, D6, D7], [Q0, Q1, Q2, Q3, Q4, Q5, Q6, Q7]>>, CCIfType<[f128, v2i64, v4i32, v8i16, v16i8, v4f32, v2f64, v8f16], - CCAssignToReg<[Q0, Q1, Q2, Q3, Q4, Q5, Q6, Q7]>> + CCAssignToReg<[Q0, Q1, Q2, Q3, Q4, Q5, Q6, Q7]>>, + + // TODO: Do we need to have Qn shadow Zn? + CCIfType<[nxv16i8,nxv8i16,nxv4i32,nxv2i64,nxv2f16,nxv4f16,nxv8f16, + nxv2f32,nxv4f32,nxv2f64], + CCAssignToReg<[Z0, Z1, Z2, Z3, Z4, Z5, Z6, Z7]>>, + + CCIfType<[nxv2i1,nxv4i1,nxv8i1,nxv16i1], + CCAssignToReg<[P0, P1, P2, P3]>> ]>; // Vararg functions on windows pass floats in integer registers @@ -288,6 +311,19 @@ D8, D9, D10, D11, D12, D13, D14, D15)>; +// Functions taking SVE arguments or returning a SVE type +// must (additionally) preserve full Z8-Z31 and predicate registers P4-P15 +def CSR_AArch64_SVE_AAPCS : CalleeSavedRegs<(add LR, FP, X19, X20, X21, X22, + X23, X24, X25, X26, X27, X28, + (sequence "Z%u", 8, 23), + (sequence "P%u", 4, 15))>; + +// AArch64 PCS for vector functions (VPCS) +// must (additionally) preserve full Q8-Q23 registers +def CSR_AArch64_AAVPCS : CalleeSavedRegs<(add LR, FP, X19, X20, X21, X22, + X23, X24, X25, X26, X27, X28, + (sequence "Q%u", 8, 23))>; + // Constructors and destructors return 'this' in the iOS 64-bit C++ ABI; since // 'this' and the pointer return value are both passed in X0 in these cases, // this can be partially modelled by treating X0 as a callee-saved register; @@ -362,5 +398,7 @@ : CalleeSavedRegs<(add CSR_AArch64_AAPCS_SwiftError, X18)>; def CSR_AArch64_RT_MostRegs_SCS : CalleeSavedRegs<(add CSR_AArch64_RT_MostRegs, X18)>; +def CSR_AArch64_AAVPCS_SCS + : CalleeSavedRegs<(add CSR_AArch64_AAVPCS, X18)>; def CSR_AArch64_AAPCS_SCS : CalleeSavedRegs<(add CSR_AArch64_AAPCS, X18)>; Index: lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp =================================================================== --- lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp +++ lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp @@ -1,4 +1,4 @@ -//===- AArch64ExpandPseudoInsts.cpp - Expand pseudo instructions ----------===// +//==-- AArch64ExpandPseudoInsts.cpp - Expand pseudo instructions --*- C++ -*-=// // // The LLVM Compiler Infrastructure // @@ -26,6 +26,8 @@ #include "llvm/CodeGen/MachineFunctionPass.h" #include "llvm/CodeGen/MachineInstr.h" #include "llvm/CodeGen/MachineInstrBuilder.h" +#include "llvm/CodeGen/MachineInstrBundle.h" +#include "llvm/CodeGen/RegisterScavenging.h" #include "llvm/CodeGen/MachineOperand.h" #include "llvm/CodeGen/TargetSubtargetInfo.h" #include "llvm/IR/DebugLoc.h" @@ -56,14 +58,19 @@ initializeAArch64ExpandPseudoPass(*PassRegistry::getPassRegistry()); } + RegScavenger *RS; + bool runOnMachineFunction(MachineFunction &Fn) override; StringRef getPassName() const override { return AARCH64_EXPAND_PSEUDO_NAME; } private: bool expandMBB(MachineBasicBlock &MBB); + bool foldUnary(MachineInstr &MI,MachineBasicBlock &MBB, + MachineBasicBlock::iterator MBBI); bool expandMI(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI, MachineBasicBlock::iterator &NextMBBI); + bool expandMI(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI); bool expandMOVImm(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI, unsigned BitSize); bool expandMOVImmSimple(MachineBasicBlock &MBB, @@ -72,6 +79,12 @@ unsigned OneChunks, unsigned ZeroChunks); + bool expand_DestructiveOp(MachineInstr &MI, MachineBasicBlock &MBB, + MachineBasicBlock::iterator MBBI); + bool expandSVE_SelZero(MachineInstr &MI, 
MachineBasicBlock &MBB, + MachineBasicBlock::iterator MBBI); + bool expandSVE_FMLA(MachineInstr &MI, MachineBasicBlock &MBB, + MachineBasicBlock::iterator MBBI); bool expandCMP_SWAP(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI, unsigned LdarOp, unsigned StlrOp, unsigned CmpOp, unsigned ExtendImm, unsigned ZeroReg, @@ -586,6 +599,39 @@ return true; } +bool AArch64ExpandPseudo::expandSVE_SelZero(MachineInstr &MI, + MachineBasicBlock &MBB, + MachineBasicBlock::iterator MBBI) { + unsigned Opcode; + switch(MI.getOpcode()) { + default: + llvm_unreachable("unsupported opcode"); + case AArch64::SELZERO_B: + Opcode = AArch64::SEL_ZPZZ_B; + break; + case AArch64::SELZERO_H: + Opcode = AArch64::SEL_ZPZZ_H; + break; + case AArch64::SELZERO_S: + Opcode = AArch64::SEL_ZPZZ_S; + break; + case AArch64::SELZERO_D: + Opcode = AArch64::SEL_ZPZZ_D; + break; + } + + MachineInstrBuilder MIB; + MIB = BuildMI(MBB, MBBI, MI.getDebugLoc(), TII->get(Opcode)) + .add(MI.getOperand(0)) + .add(MI.getOperand(1)) + .add(MI.getOperand(2)) + .add(MI.getOperand(3)); + + transferImpOps(MI, MIB, MIB); + MI.eraseFromParent(); + return true; +} + bool AArch64ExpandPseudo::expandCMP_SWAP( MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI, unsigned LdarOp, unsigned StlrOp, unsigned CmpOp, unsigned ExtendImm, unsigned ZeroReg, @@ -666,6 +712,216 @@ return true; } +/// \brief Expand Pseudos to Instructions with destructive operands. +/// +/// This mechanism uses MOVPRFX instructions for zeroing the false lanes +/// or for fixing relaxed register allocation conditions to comply with +/// the instructions register constraints. The latter case may be cheaper +/// than setting the register constraints in the register allocator, +/// since that will insert regular MOV instructions rather than MOVPRFX. +/// +/// Example (after register allocation): +/// +/// FSUB_ZPZZ_ZERO_B Z0, Pg, Z1, Z0 +/// +/// * The Pseudo FSUB_ZPZZ_ZERO_B maps to FSUB_ZPmZ_B. +/// * We cannot map directly to FSUB_ZPmZ_B because the register +/// constraints of the instruction are not met. +/// * Also the _ZERO specifies the false lanes need to be zeroed. +/// +/// We first try to see if the destructive operand == result operand, +/// if not, we try to swap the operands, e.g. +/// +/// FSUB_ZPmZ_B Z0, Pg/m, Z0, Z1 +/// +/// But because FSUB_ZPmZ is not commutative, this is semantically +/// different, so we need a reverse instruction: +/// +/// FSUBR_ZPmZ_B Z0, Pg/m, Z0, Z1 +/// +/// Then we implement the zeroing of the false lanes of Z0 by adding +/// a zeroing MOVPRFX instruction: +/// +/// MOVPRFX_ZPzZ_B Z0, Pg/z, Z0 +/// FSUBR_ZPmZ_B Z0, Pg/m, Z0, Z1 +/// +/// Note that this can only be done for _ZERO or _UNDEF variants where +/// we can guarantee the false lanes to be zeroed (by implementing this) +/// or that they are undef (don't care / not used), otherwise the +/// swapping of operands is illegal because the operation is not +/// (or cannot be emulated to be) fully commutative. 
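The comment block above describes the expansion strategy; the sketch below models only the decision it makes for a predicated destructive binary op: whether to switch to the reversed opcode so the destination can be the destructive operand, and whether a MOVPRFX is needed. Register names and the helper are placeholders, and the real pass additionally covers ternary, immediate, and undef-false-lane forms.

#include <cstdio>
#include <string>

struct ExpandPlan {
  bool UseReverse;    // e.g. FSUB -> FSUBR so Dst can be the destructive op
  bool NeedsMovPrfx;  // prefix needed to zero false lanes or to copy into Dst
};

// Decide how to expand a predicated destructive binary op
//   Dst = OP(Pg, Src1, Src2)
// when the hardware form overwrites its first source operand.
ExpandPlan planDestructiveBinary(const std::string &Dst,
                                 const std::string &Src1,
                                 const std::string &Src2,
                                 bool HasReverseForm, bool ZeroFalseLanes) {
  ExpandPlan P{false, false};
  if (Dst == Src2 && Dst != Src1 && HasReverseForm)
    P.UseReverse = true;  // OP Dst, Pg, Src1, Dst -> OPR Dst, Pg, Dst, Src1
  // A MOVPRFX is required either to zero the inactive lanes or to move the
  // destructive source into Dst when register allocation picked Dst != it.
  const std::string &DOP = P.UseReverse ? Src2 : Src1;
  P.NeedsMovPrfx = ZeroFalseLanes || Dst != DOP;
  return P;
}

int main() {
  // Mirrors the FSUB_ZPZZ_ZERO_B example from the comment above.
  ExpandPlan P = planDestructiveBinary("z0", "z1", "z0",
                                       /*HasReverseForm=*/true,
                                       /*ZeroFalseLanes=*/true);
  std::printf("reverse=%d movprfx=%d\n", P.UseReverse, P.NeedsMovPrfx);
  return 0;
}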
+bool AArch64ExpandPseudo::expand_DestructiveOp( + MachineInstr &MI, + MachineBasicBlock &MBB, + MachineBasicBlock::iterator MBBI) { + unsigned Opcode = AArch64::getSVEPseudoMap(MI.getOpcode()); + uint64_t DType = TII->get(Opcode).TSFlags & AArch64::DestructiveInstTypeMask; + uint64_t FalseLanes = MI.getDesc().TSFlags & AArch64::FalseLanesMask; + bool FalseZero = FalseLanes == AArch64::FalseLanesZero; + + unsigned DstReg = MI.getOperand(0).getReg(); + bool DstIsDead = MI.getOperand(0).isDead(); + + if (DType == AArch64::DestructiveBinary) + assert(DstReg != MI.getOperand(3).getReg()); + + bool UseRev = false; + unsigned PredIdx, DOPIdx, SrcIdx, Src2Idx; + switch (DType) { + case AArch64::DestructiveBinaryComm: + case AArch64::DestructiveBinaryCommWithRev: + if (DstReg == MI.getOperand(3).getReg()) { + // FSUB Zd, Pg, Zs1, Zd ==> FSUBR Zd, Pg/m, Zd, Zs1 + std::tie(PredIdx, DOPIdx, SrcIdx) = std::make_tuple(1, 3, 2); + UseRev = true; + break; + } + // fallthrough + case AArch64::DestructiveBinary: + case AArch64::DestructiveBinaryImm: + std::tie(PredIdx, DOPIdx, SrcIdx) = std::make_tuple(1, 2, 3); + break; + case AArch64::DestructiveBinaryShImmUnpred: + std::tie(DOPIdx, SrcIdx, Src2Idx) = std::make_tuple(1, 2, 3); + break; + case AArch64::DestructiveTernaryCommWithRev: + std::tie(PredIdx, DOPIdx, SrcIdx, Src2Idx) = std::make_tuple(1, 2, 3, 4); + if (DstReg == MI.getOperand(3).getReg()) { + // FMLA Zd, Pg, Za, Zd, Zm ==> FMAD Zdn, Pg, Zm, Za + std::tie(PredIdx, DOPIdx, SrcIdx, Src2Idx) = std::make_tuple(1, 3, 4, 2); + UseRev = true; + } else if (DstReg == MI.getOperand(4).getReg()) { + // FMLA Zd, Pg, Za, Zm, Zd ==> FMAD Zdn, Pg, Zm, Za + std::tie(PredIdx, DOPIdx, SrcIdx, Src2Idx) = std::make_tuple(1, 4, 3, 2); + UseRev = true; + } + break; + default: + llvm_unreachable("Unsupported Destructive Operand type"); + } + +#ifndef NDEBUG + // MOVPRFX can only be used if the destination operand + // is the destructive operand, not as any other operand, + // so the Destructive Operand must be unique. 
+ bool DOPRegIsUnique = false; + switch (DType) { + case AArch64::DestructiveBinaryComm: + case AArch64::DestructiveBinaryCommWithRev: + case AArch64::DestructiveBinary: + DOPRegIsUnique = + DstReg != MI.getOperand(DOPIdx).getReg() || + MI.getOperand(DOPIdx).getReg() != MI.getOperand(SrcIdx).getReg(); + break; + case AArch64::DestructiveBinaryImm: + case AArch64::DestructiveBinaryShImmUnpred: + DOPRegIsUnique = true; + break; + case AArch64::DestructiveTernaryCommWithRev: + DOPRegIsUnique = + DstReg != MI.getOperand(DOPIdx).getReg() || + (MI.getOperand(DOPIdx).getReg() != MI.getOperand(SrcIdx).getReg() && + MI.getOperand(DOPIdx).getReg() != MI.getOperand(Src2Idx).getReg()); + break; + } + + assert (DOPRegIsUnique && "The destructive operand should be unique"); +#endif + + // Resolve the reverse opcode + if (UseRev) { + if (AArch64::getSVERevInstr(Opcode) != -1) + Opcode = AArch64::getSVERevInstr(Opcode); + else if (AArch64::getSVEOrigInstr(Opcode) != -1) + Opcode = AArch64::getSVEOrigInstr(Opcode); + } + + // Get the right MOVPRFX + uint64_t ElementSize = TII->getElementSizeForOpcode(Opcode); + unsigned MovPrfx, MovPrfxZero; + switch (ElementSize) { + case AArch64::ElementSizeNone: + case AArch64::ElementSizeB: + MovPrfx = AArch64::MOVPRFX_ZZ; + MovPrfxZero = AArch64::MOVPRFX_ZPzZ_B; + break; + case AArch64::ElementSizeH: + MovPrfx = AArch64::MOVPRFX_ZZ; + MovPrfxZero = AArch64::MOVPRFX_ZPzZ_H; + break; + case AArch64::ElementSizeS: + MovPrfx = AArch64::MOVPRFX_ZZ; + MovPrfxZero = AArch64::MOVPRFX_ZPzZ_S; + break; + case AArch64::ElementSizeD: + MovPrfx = AArch64::MOVPRFX_ZZ; + MovPrfxZero = AArch64::MOVPRFX_ZPzZ_D; + break; + default: + llvm_unreachable("Unsupported ElementSize"); + } + + // + // Create the destructive operation (if required) + // + MachineInstrBuilder PRFX, DOP; + if (FalseZero) { + assert(ElementSize != AArch64::ElementSizeNone && + "This instruction is unpredicated"); + + // Merge source operand into destination register + PRFX = BuildMI(MBB, MBBI, MI.getDebugLoc(), TII->get(MovPrfxZero)) + .addReg(DstReg, RegState::Define) + .addReg(MI.getOperand(PredIdx).getReg()) + .addReg(MI.getOperand(DOPIdx).getReg()); + + // After the movprfx, the destructive operand is same as Dst + DOPIdx = 0; + } else if (DstReg != MI.getOperand(DOPIdx).getReg()) { + PRFX = BuildMI(MBB, MBBI, MI.getDebugLoc(), TII->get(MovPrfx)) + .addReg(DstReg, RegState::Define) + .addReg(MI.getOperand(DOPIdx).getReg()); + DOPIdx = 0; + } + + // + // Create the destructive operation + // + DOP = BuildMI(MBB, MBBI, MI.getDebugLoc(), TII->get(Opcode)) + .addReg(DstReg, RegState::Define | getDeadRegState(DstIsDead)); + + switch (DType) { + case AArch64::DestructiveBinary: + case AArch64::DestructiveBinaryImm: + case AArch64::DestructiveBinaryComm: + case AArch64::DestructiveBinaryCommWithRev: + DOP.add(MI.getOperand(PredIdx)) + .addReg(MI.getOperand(DOPIdx).getReg(), RegState::Kill) + .add(MI.getOperand(SrcIdx)); + break; + case AArch64::DestructiveBinaryShImmUnpred: + DOP.addReg(MI.getOperand(DOPIdx).getReg(), RegState::Kill) + .add(MI.getOperand(SrcIdx)) + .add(MI.getOperand(Src2Idx)); + break; + case AArch64::DestructiveTernaryCommWithRev: + DOP.add(MI.getOperand(PredIdx)) + .addReg(MI.getOperand(DOPIdx).getReg(), RegState::Kill) + .add(MI.getOperand(SrcIdx)) + .add(MI.getOperand(Src2Idx)); + break; + } + + if (PRFX) { + finalizeBundle(MBB, PRFX->getIterator(), MBBI->getIterator()); + transferImpOps(MI, PRFX, DOP); + } else + transferImpOps(MI, DOP, DOP); + + MI.eraseFromParent(); + return true; +} + bool 
AArch64ExpandPseudo::expandCMP_SWAP_128( MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI, MachineBasicBlock::iterator &NextMBBI) { @@ -759,13 +1015,135 @@ return true; } -/// If MBBI references a pseudo instruction that should be expanded here, +/// \brief Returns the defining instruction for MachineOperand MO, +/// which should be a register. If we are searching for the defining +/// instruction for the purpose of removing it, it only returns +/// the defining instruction if it is not read between Def..*MBBI. +static bool getRegisterDefInstr(MachineBasicBlock::iterator MBBI, + const MachineOperand &MO, + MachineBasicBlock::iterator &Def, + bool ForRemoving=false) { + assert(MO.isReg() && "Operand must be a register"); + unsigned Reg = MO.getReg(); + + if (MBBI == MBBI->getParent()->begin()) + return false; + + MachineBasicBlock::iterator RI = MBBI; + for (--RI; RI != MBBI->getParent()->begin(); --RI) { + // If we want to remove the Def, it cannot be *used* anywhere + // else in between the Def and MBBI + if (ForRemoving && !RI->definesRegister(Reg) && RI->readsRegister(Reg)) + return false; + else if (RI->definesRegister(Reg)) + break; + } + + if (!RI->definesRegister(Reg)) + return false; + + Def = RI; + return true; +} + +/// \brief Replace instructions where the destructive operand is +/// a vector of zeros with a bundled MOVPRFX instruction, e.g. +/// Transform: +/// %X0 = DUP_ZI_S 0, 0 +/// %X0 = FNEG_ZPmZ_S X0, P0, X2 +/// into: +/// X0 = MOVPRFX P0/z, X0 +/// X0 = FNEG_ZPmZ_S X0, P0, X2 +bool AArch64ExpandPseudo::foldUnary(MachineInstr &MI, + MachineBasicBlock &MBB, + MachineBasicBlock::iterator MBBI) { + // Zd != Zn + if (MI.getOperand(0).getReg() == MI.getOperand(3).getReg()) + return false; + + // Zsd must be a DUP_ZI_(B|H|S|D) 0, 0 + MachineBasicBlock::iterator Def; + if (!getRegisterDefInstr(MBBI, MI.getOperand(1), Def, true)) + return false; + + switch (Def->getOpcode()) { + case AArch64::DUP_ZI_B: + case AArch64::DUP_ZI_H: + case AArch64::DUP_ZI_S: + case AArch64::DUP_ZI_D: + break; + default: + return false; + } + + if (!Def->getOperand(1).isImm() || Def->getOperand(1).getImm() != 0) + return false; + + unsigned MovPrfx; + switch (TII->getElementSizeForOpcode(MI.getOpcode())) { + case AArch64::ElementSizeNone: + case AArch64::ElementSizeB: + MovPrfx = AArch64::MOVPRFX_ZPzZ_B; + break; + case AArch64::ElementSizeH: + MovPrfx = AArch64::MOVPRFX_ZPzZ_H; + break; + case AArch64::ElementSizeS: + MovPrfx = AArch64::MOVPRFX_ZPzZ_S; + break; + case AArch64::ElementSizeD: + MovPrfx = AArch64::MOVPRFX_ZPzZ_D; + break; + default: + llvm_unreachable("Unsupported ElementSize"); + } + + // Create a Zeroing MOVPRFX + MachineInstrBuilder PRFX, NewMI; + PRFX = BuildMI(MBB, MBBI, MI.getDebugLoc(), TII->get(MovPrfx)) + .addReg(MI.getOperand(0).getReg(), RegState::Define) + .addReg(MI.getOperand(2).getReg()) + .addReg(MI.getOperand(1).getReg(), RegState::Undef); + + NewMI = BuildMI(MBB, MBBI, MI.getDebugLoc(), TII->get(MI.getOpcode())) + .add(MI.getOperand(0)) + .addReg(MI.getOperand(1).getReg(), RegState::Kill) + .add(MI.getOperand(2)) + .add(MI.getOperand(3)); + + finalizeBundle(MBB, PRFX->getIterator(), MBBI->getIterator()); + transferImpOps(MI, PRFX, NewMI); + + Def->eraseFromParent(); + MBBI->eraseFromParent(); + + return true; +} + +/// \brief If MBBI references a pseudo instruction that should be expanded here, /// do the expansion and return true. Otherwise return false. 
bool AArch64ExpandPseudo::expandMI(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI, MachineBasicBlock::iterator &NextMBBI) { + const TargetRegisterInfo *TRI = + MBB.getParent()->getSubtarget().getRegisterInfo(); MachineInstr &MI = *MBBI; unsigned Opcode = MI.getOpcode(); + + // Check if we can expand the destructive op + int OrigInstr = AArch64::getSVEPseudoMap(MI.getOpcode()); + if (OrigInstr != -1) { + auto &Orig = TII->get(OrigInstr); + if ((Orig.TSFlags & AArch64::DestructiveInstTypeMask) + != AArch64::NotDestructive) { + return expand_DestructiveOp(MI, MBB, MBBI); + } + } + else if ((MI.getDesc().TSFlags & AArch64::DestructiveInstTypeMask) + == AArch64::DestructiveUnary) { + return foldUnary(MI, MBB, MBBI); + } + switch (Opcode) { default: break; @@ -932,6 +1310,13 @@ MI.eraseFromParent(); return true; } + + case AArch64::SELZERO_B: + case AArch64::SELZERO_H: + case AArch64::SELZERO_S: + case AArch64::SELZERO_D: + return expandSVE_SelZero(MI, MBB, MBBI); + case AArch64::CMP_SWAP_8: return expandCMP_SWAP(MBB, MBBI, AArch64::LDAXRB, AArch64::STLXRB, AArch64::SUBSWrx, @@ -955,6 +1340,82 @@ case AArch64::CMP_SWAP_128: return expandCMP_SWAP_128(MBB, MBBI, NextMBBI); +#define EXPAND_SVE_ADR(X) \ + case X: {\ + unsigned Opcode;\ +\ + assert(MI.getOperand(3).isImm() && "Expected immediate operand");\ + switch (MI.getOperand(3).getImm()) {\ + default: llvm_unreachable("Unexpected immediate");\ + case 0: Opcode = X##_0; break;\ + case 1: Opcode = X##_1; break;\ + case 2: Opcode = X##_2; break;\ + case 3: Opcode = X##_3; break;\ + }\ +\ + MachineInstrBuilder MIB =\ + BuildMI(MBB, MBBI, MI.getDebugLoc(), TII->get(Opcode))\ + .add(MI.getOperand(0))\ + .add(MI.getOperand(1))\ + .add(MI.getOperand(2));\ + transferImpOps(MI, MIB, MIB);\ + MI.eraseFromParent();\ + return true;\ + } + EXPAND_SVE_ADR(AArch64::ADR_LSL_ZZZ_S) + EXPAND_SVE_ADR(AArch64::ADR_LSL_ZZZ_D) + EXPAND_SVE_ADR(AArch64::ADR_SXTW_ZZZ_D) + EXPAND_SVE_ADR(AArch64::ADR_UXTW_ZZZ_D) + +#define EXPAND_SVE_SPILLFILL(OPC, MI, OFFSET, KILL) \ + { \ + int ImmOffset = MI.getOperand(2).getImm() + OFFSET; \ + assert(ImmOffset >= -256 && ImmOffset < 256 && \ + "Immediate spill offset out of range"); \ + BuildMI(MBB, MBBI, MI.getDebugLoc(), TII->get(OPC)) \ + .addReg(TRI->getSubReg(MI.getOperand(0).getReg(), \ + AArch64::zsub0 + OFFSET), \ + OPC == AArch64::LDR_ZXI ? 
RegState::Define : 0) \ + .addReg(MI.getOperand(1).getReg(), getKillRegState(KILL)) \ + .addImm(ImmOffset); \ + } + case AArch64::STR_ZZZZXI: + EXPAND_SVE_SPILLFILL(AArch64::STR_ZXI, MI, 0, false); + EXPAND_SVE_SPILLFILL(AArch64::STR_ZXI, MI, 1, false); + EXPAND_SVE_SPILLFILL(AArch64::STR_ZXI, MI, 2, false); + EXPAND_SVE_SPILLFILL(AArch64::STR_ZXI, MI, 3, MI.getOperand(1).isKill()); + MI.eraseFromParent(); + return true; + case AArch64::STR_ZZZXI: + EXPAND_SVE_SPILLFILL(AArch64::STR_ZXI, MI, 0, false); + EXPAND_SVE_SPILLFILL(AArch64::STR_ZXI, MI, 1, false); + EXPAND_SVE_SPILLFILL(AArch64::STR_ZXI, MI, 2, MI.getOperand(1).isKill()); + MI.eraseFromParent(); + return true; + case AArch64::STR_ZZXI: + EXPAND_SVE_SPILLFILL(AArch64::STR_ZXI, MI, 0, false); + EXPAND_SVE_SPILLFILL(AArch64::STR_ZXI, MI, 1, MI.getOperand(1).isKill()); + MI.eraseFromParent(); + return true; + case AArch64::LDR_ZZZZXI: + EXPAND_SVE_SPILLFILL(AArch64::LDR_ZXI, MI, 0, false); + EXPAND_SVE_SPILLFILL(AArch64::LDR_ZXI, MI, 1, false); + EXPAND_SVE_SPILLFILL(AArch64::LDR_ZXI, MI, 2, false); + EXPAND_SVE_SPILLFILL(AArch64::LDR_ZXI, MI, 3, MI.getOperand(1).isKill()); + MI.eraseFromParent(); + return true; + case AArch64::LDR_ZZZXI: + EXPAND_SVE_SPILLFILL(AArch64::LDR_ZXI, MI, 0, false); + EXPAND_SVE_SPILLFILL(AArch64::LDR_ZXI, MI, 1, false); + EXPAND_SVE_SPILLFILL(AArch64::LDR_ZXI, MI, 2, MI.getOperand(1).isKill()); + MI.eraseFromParent(); + return true; + case AArch64::LDR_ZZXI: + EXPAND_SVE_SPILLFILL(AArch64::LDR_ZXI, MI, 0, false); + EXPAND_SVE_SPILLFILL(AArch64::LDR_ZXI, MI, 1, MI.getOperand(1).isKill()); + MI.eraseFromParent(); + return true; +#undef EXPAND_SVE_SPILLFILL case AArch64::AESMCrrTied: case AArch64::AESIMCrrTied: { MachineInstrBuilder MIB = @@ -988,10 +1449,18 @@ bool AArch64ExpandPseudo::runOnMachineFunction(MachineFunction &MF) { TII = static_cast(MF.getSubtarget().getInstrInfo()); + auto UniqueRS = make_unique(); + RS = UniqueRS.get(); bool Modified = false; for (auto &MBB : MF) Modified |= expandMBB(MBB); +#ifndef NDEBUG + // Verify the MachineFunction as we may be missing def/use/kill flags + // on some of the instructions/operands we added. + if (Modified) + MF.verify(nullptr, "In AArch64ExpandPseudo"); +#endif return Modified; } Index: lib/Target/AArch64/AArch64FastISel.cpp =================================================================== --- lib/Target/AArch64/AArch64FastISel.cpp +++ lib/Target/AArch64/AArch64FastISel.cpp @@ -408,10 +408,9 @@ bool Is64Bit = (VT == MVT::f64); // This checks to see if we can use FMOV instructions to materialize // a constant, otherwise we have to materialize via the constant pool. - if (TLI.isFPImmLegal(Val, VT)) { - int Imm = - Is64Bit ? AArch64_AM::getFP64Imm(Val) : AArch64_AM::getFP32Imm(Val); - assert((Imm != -1) && "Cannot encode floating-point constant."); + int Imm = + Is64Bit ? AArch64_AM::getFP64Imm(Val) : AArch64_AM::getFP32Imm(Val); + if (Imm != -1) { unsigned Opc = Is64Bit ? 
AArch64::FMOVDi : AArch64::FMOVSi; return fastEmitInst_i(Opc, TLI.getRegClassFor(VT), Imm); } Index: lib/Target/AArch64/AArch64FrameLowering.h =================================================================== --- lib/Target/AArch64/AArch64FrameLowering.h +++ lib/Target/AArch64/AArch64FrameLowering.h @@ -14,10 +14,18 @@ #ifndef LLVM_LIB_TARGET_AARCH64_AARCH64FRAMELOWERING_H #define LLVM_LIB_TARGET_AARCH64_AARCH64FRAMELOWERING_H +#include "AArch64StackOffset.h" +#include "llvm/CodeGen/MachineFrameInfo.h" #include "llvm/CodeGen/TargetFrameLowering.h" namespace llvm { +bool isSVEScaledImmInstruction(unsigned Opcode); + +namespace AArch64 { +enum FrameRegions { FR_Default = 0, FR_SVE = 1 }; +} // end namespace AArch64 + class AArch64FrameLowering : public TargetFrameLowering { public: explicit AArch64FrameLowering() @@ -36,13 +44,15 @@ void emitPrologue(MachineFunction &MF, MachineBasicBlock &MBB) const override; void emitEpilogue(MachineFunction &MF, MachineBasicBlock &MBB) const override; + int64_t getFurthestStackArgOffset(MachineFunction &MF) const; + bool canUseAsPrologue(const MachineBasicBlock &MBB) const override; int getFrameIndexReference(const MachineFunction &MF, int FI, unsigned &FrameReg) const override; - int resolveFrameIndexReference(const MachineFunction &MF, int FI, - unsigned &FrameReg, - bool PreferFP = false) const; + StackOffset resolveFrameIndexReference(const MachineFunction &MF, int FI, + unsigned &FrameReg, + bool PreferFP = false) const; bool spillCalleeSavedRegisters(MachineBasicBlock &MBB, MachineBasicBlock::iterator MI, const std::vector &CSI, @@ -53,6 +63,9 @@ std::vector &CSI, const TargetRegisterInfo *TRI) const override; + void processFunctionBeforeFrameFinalized(MachineFunction &MF, + RegScavenger *RS) const override; + /// Can this function use the red zone for local allocations. bool canUseRedZone(const MachineFunction &MF) const; @@ -67,11 +80,28 @@ return true; } + void fixupScalableDebugOffsets(MachineFunction &MF) const; + MCCFIInstruction emitDefCFA(const MCRegisterInfo *MRI, unsigned Basereg, + StackOffset Offset) const; + MCCFIInstruction emitCFI(const MCRegisterInfo *MRI, unsigned Reg, + StackOffset Offset) const; + bool enableStackSlotScavenging(const MachineFunction &MF) const override; + unsigned getStackIDForType(const Type *T) const override; + + // Adds StackOffset to the expression in Expr. Scalable parts of the offset + // are scaled using the pseudo Vector Granule (VG) Dwarf Register. + static void addVGScaledOffset(const MCRegisterInfo *MRI, StackOffset Offset, + SmallVectorImpl &Expr); private: bool shouldCombineCSRLocalStackBump(MachineFunction &MF, unsigned StackBumpBytes) const; + + int64_t estimateSVEStackObjectOffsets(MachineFrameInfo &MF, + unsigned &MaxAlign) const; + int64_t assignSVEStackObjectOffsets(MachineFrameInfo &MF, unsigned &MaxAlign, + int64_t &SVECalleeSaveSize) const; }; } // End llvm namespace Index: lib/Target/AArch64/AArch64FrameLowering.cpp =================================================================== --- lib/Target/AArch64/AArch64FrameLowering.cpp +++ lib/Target/AArch64/AArch64FrameLowering.cpp @@ -45,7 +45,19 @@ // | | // |-----------------------------------| // | | -// | prev_fp, prev_lr | +// | SVE callee-saved registers | +// | | +// |-----------------------------------| +// | | +// | SVE stack objects | +// | | +// |-----------------------------------| +// |...................................| +// |........... 
padding ...............| +// |...................................| +// |- - - - - - - - - - - - - - - - - -| <- 16-byte aligned +// | | (Overlays with Frame Record +// | prev_fp, prev_lr | when present) // | (a.k.a. "frame record") | // |-----------------------------------| <- fp(=x29) // | | @@ -95,6 +107,7 @@ #include "AArch64InstrInfo.h" #include "AArch64MachineFunctionInfo.h" #include "AArch64RegisterInfo.h" +#include "AArch64StackOffset.h" #include "AArch64Subtarget.h" #include "AArch64TargetMachine.h" #include "MCTargetDesc/AArch64AddressingModes.h" @@ -117,8 +130,10 @@ #include "llvm/IR/Attributes.h" #include "llvm/IR/CallingConv.h" #include "llvm/IR/DataLayout.h" +#include "llvm/IR/DebugInfoMetadata.h" #include "llvm/IR/DebugLoc.h" #include "llvm/IR/Function.h" +#include "llvm/IR/Metadata.h" #include "llvm/MC/MCDwarf.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" @@ -171,7 +186,7 @@ if (!MO.isFI()) continue; - int Offset = 0; + StackOffset Offset; if (isAArch64FrameOffsetLegal(MI, Offset, nullptr, nullptr, nullptr) == AArch64FrameOffsetCannotUpdate) return 0; @@ -181,6 +196,37 @@ return DefaultSafeSPDisplacement; } +unsigned AArch64FrameLowering::getStackIDForType(const Type *T) const { + // TODO: Recent discussion upstream rejected the idea of having array/struct + // types that consist of scalable types, so be wary not to upstream this for + // the time being. + if (auto *ST = dyn_cast(T)) + if (ST->getNumElements()) { + unsigned FirstID = getStackIDForType(*ST->element_begin()); + if (std::all_of(ST->element_begin(), ST->element_end(), + [this, FirstID](const Type *SubT) { + return getStackIDForType(SubT) == FirstID; + })) + return FirstID; + llvm_unreachable("Mixed Scalable and non-Scalable " + "struct members are not supported"); + } + + if (isa(T)) + return getStackIDForType(cast(T)->getElementType()); + + if (const VectorType *VT = dyn_cast(T)) + if (VT->isScalable()) + return AArch64::FR_SVE; + + return AArch64::FR_Default; +} + +static uint64_t getSVEStackSize(const MachineFunction &MF) { + const AArch64FunctionInfo *AFI = MF.getInfo(); + return AFI->getStackSizeSVE(); +} + bool AArch64FrameLowering::canUseRedZone(const MachineFunction &MF) const { if (!EnableRedZone) return false; @@ -193,7 +239,8 @@ const AArch64FunctionInfo *AFI = MF.getInfo(); unsigned NumBytes = AFI->getLocalStackSize(); - return !(MFI.hasCalls() || hasFP(MF) || NumBytes > 128); + return !(MFI.hasCalls() || hasFP(MF) || NumBytes > 128 || + getSVEStackSize(MF)); } /// hasFP - Return true if the specified function should have a dedicated frame @@ -267,18 +314,66 @@ // Most call frames will be allocated at the start of a function so // this is OK, but it is a limitation that needs dealing with. assert(Amount > -0xffffff && Amount < 0xffffff && "call frame too large"); - emitFrameOffset(MBB, I, DL, AArch64::SP, AArch64::SP, Amount, TII); + emitFrameOffset(MBB, I, DL, AArch64::SP, AArch64::SP, {Amount, MVT::i8}, + TII); } } else if (CalleePopAmount != 0) { // If the calling convention demands that the callee pops arguments from the // stack, we want to add it back if we have a reserved call frame. 
assert(CalleePopAmount < 0xffffff && "call frame too large"); - emitFrameOffset(MBB, I, DL, AArch64::SP, AArch64::SP, -CalleePopAmount, - TII); + emitFrameOffset(MBB, I, DL, AArch64::SP, AArch64::SP, + {-(int64_t)CalleePopAmount, MVT::i8}, TII); } return MBB.erase(I); } +void AArch64FrameLowering::addVGScaledOffset(const MCRegisterInfo *MRI, + StackOffset Offset, + SmallVectorImpl &Expr) { + // Build scaled expression using VG (Vector Granule) + unsigned VG = MRI->getDwarfRegNum(AArch64::VG, true); + + int64_t Bytes, VGSized; + Offset.getForDwarfOffset(Bytes, VGSized); + + if (VGSized) { + assert(VGSized >= 0 && "VGSized offsets should always be positive"); + Expr.append({dwarf::DW_OP_constu, (uint64_t)VGSized}); + Expr.append({dwarf::DW_OP_bregx, VG, 0}); + Expr.append({dwarf::DW_OP_mul, dwarf::DW_OP_plus}); + } + + if (Bytes) { + assert(Bytes >= 0 && "Dwarf byte offsets should always be positive"); + Expr.append({dwarf::DW_OP_constu, (uint64_t)Bytes}); + Expr.append({dwarf::DW_OP_plus}); + } +} + +void AArch64FrameLowering::fixupScalableDebugOffsets(MachineFunction &MF) const { + // The MF.VariableDbgInfo cache of debug info for each variable needs + // to be updated separately from all DBG_VALUE instructions in the IR. + // Here we can add the expression to get it from the VL-scaled region. + // This may not be the best place to do this, but we need to do it + // somehwere. + for (auto &VI : MF.getVariableDbgInfo()) { + if (!VI.Var) + continue; + unsigned FrameReg; + StackOffset Offset = + resolveFrameIndexReference(MF, VI.Slot, FrameReg, true); + + // Handle only the scalable part, as the non-scalable part is handled + // by generic code. + if (int64_t Scalable = Offset.getScalableBytes()) { + SmallVector Buffer; + addVGScaledOffset(MF.getSubtarget().getRegisterInfo(), + StackOffset(Scalable, MVT::nxv1i8), Buffer); + VI.Expr = DIExpression::get(MF.getFunction().getContext(), Buffer); + } + } +} + void AArch64FrameLowering::emitCalleeSavedFrameMoves( MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI) const { MachineFunction &MF = *MBB.getParent(); @@ -295,11 +390,24 @@ for (const auto &Info : CSI) { unsigned Reg = Info.getReg(); - int64_t Offset = - MFI.getObjectOffset(Info.getFrameIdx()) - getOffsetOfLocalArea(); - unsigned DwarfReg = MRI->getDwarfRegNum(Reg, true); - unsigned CFIIndex = MF.addFrameInst( - MCCFIInstruction::createOffset(nullptr, DwarfReg, Offset)); + unsigned StackID = MFI.getStackID(Info.getFrameIdx()); + uint64_t SVEStackSize = getSVEStackSize(MF); + + unsigned CFIIndex; + StackOffset Offset; + if (StackID == AArch64::FR_SVE) + Offset = + StackOffset(MFI.getObjectOffset(Info.getFrameIdx()), MVT::nxv1i8); + else { + int64_t ByteOffset = + MFI.getObjectOffset(Info.getFrameIdx()) - getOffsetOfLocalArea(); + Offset = StackOffset(ByteOffset, MVT::i8) - + StackOffset(SVEStackSize, MVT::nxv1i8); + if (SVEStackSize && hasFP(MF)) + Offset += StackOffset(16, MVT::i8); + } + + CFIIndex = MF.addFrameInst(emitCFI(MRI, Reg, Offset)); BuildMI(MBB, MBBI, DL, TII->get(TargetOpcode::CFI_INSTRUCTION)) .addCFIIndex(CFIIndex) .setMIFlags(MachineInstr::FrameSetup); @@ -405,6 +513,9 @@ if (canUseRedZone(MF)) return false; + if (getSVEStackSize(MF)) + return false; + return true; } @@ -423,37 +534,51 @@ } unsigned NewOpc; - bool NewIsUnscaled = false; + int Scale = 1; switch (MBBI->getOpcode()) { default: llvm_unreachable("Unexpected callee-save save/restore opcode!"); case AArch64::STPXi: NewOpc = AArch64::STPXpre; + Scale = 8; break; case AArch64::STPDi: NewOpc = AArch64::STPDpre; 
+ Scale = 8; + break; + case AArch64::STPQi: + NewOpc = AArch64::STPQpre; + Scale = 16; break; case AArch64::STRXui: NewOpc = AArch64::STRXpre; - NewIsUnscaled = true; break; case AArch64::STRDui: NewOpc = AArch64::STRDpre; - NewIsUnscaled = true; + break; + case AArch64::STRQui: + NewOpc = AArch64::STRQpre; break; case AArch64::LDPXi: NewOpc = AArch64::LDPXpost; + Scale = 8; break; case AArch64::LDPDi: NewOpc = AArch64::LDPDpost; + Scale = 8; + break; + case AArch64::LDPQi: + NewOpc = AArch64::LDPQpost; + Scale = 16; break; case AArch64::LDRXui: NewOpc = AArch64::LDRXpost; - NewIsUnscaled = true; break; case AArch64::LDRDui: NewOpc = AArch64::LDRDpost; - NewIsUnscaled = true; + break; + case AArch64::LDRQui: + NewOpc = AArch64::LDRQpost; break; } @@ -472,11 +597,8 @@ assert(MBBI->getOperand(OpndIdx - 1).getReg() == AArch64::SP && "Unexpected base register in callee-save save/restore instruction!"); // Last operand is immediate offset that needs fixing. - assert(CSStackSizeInc % 8 == 0); - int64_t CSStackSizeIncImm = CSStackSizeInc; - if (!NewIsUnscaled) - CSStackSizeIncImm /= 8; - MIB.addImm(CSStackSizeIncImm); + assert(CSStackSizeInc % Scale == 0); + MIB.addImm(CSStackSizeInc / Scale); MIB.setMIFlags(MBBI->getFlags()); MIB.setMemRefs(MBBI->memoperands_begin(), MBBI->memoperands_end()); @@ -497,12 +619,35 @@ return; } - (void)Opc; - assert((Opc == AArch64::STPXi || Opc == AArch64::STPDi || - Opc == AArch64::STRXui || Opc == AArch64::STRDui || - Opc == AArch64::LDPXi || Opc == AArch64::LDPDi || - Opc == AArch64::LDRXui || Opc == AArch64::LDRDui) && - "Unexpected callee-save save/restore opcode!"); + unsigned Scale; + switch (Opc) { + case AArch64::STPXi: + case AArch64::STRXui: + case AArch64::STPDi: + case AArch64::STRDui: + case AArch64::LDPXi: + case AArch64::LDRXui: + case AArch64::LDPDi: + case AArch64::LDRDui: + Scale = 8; + break; + case AArch64::STPQi: + case AArch64::STRQui: + case AArch64::LDPQi: + case AArch64::LDRQui: + Scale = 16; + break; + case AArch64::STR_PXI: + case AArch64::LDR_PXI: + Scale = 2; + break; + case AArch64::STR_ZXI: + case AArch64::LDR_ZXI: + Scale = 16; + break; + default: + llvm_unreachable("Unexpected callee-save save/restore opcode!"); + } unsigned OffsetIdx = MI.getNumExplicitOperands() - 1; assert(MI.getOperand(OffsetIdx - 1).getReg() == AArch64::SP && @@ -510,8 +655,8 @@ // Last operand is immediate offset that needs fixing. MachineOperand &OffsetOpnd = MI.getOperand(OffsetIdx); // All generated opcodes have scaled offsets. - assert(LocalStackSize % 8 == 0); - OffsetOpnd.setImm(OffsetOpnd.getImm() + LocalStackSize / 8); + assert(LocalStackSize % Scale == 0); + OffsetOpnd.setImm(OffsetOpnd.getImm() + LocalStackSize / Scale); } static void adaptForLdStOpt(MachineBasicBlock &MBB, @@ -546,6 +691,112 @@ // } +void realignFrame(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI, + unsigned Dst, unsigned Src, unsigned Alignment) { + const unsigned NrBitsToZero = countTrailingZeros(Alignment); + assert(NrBitsToZero > 1); + + // SUB X9, SP, NumBytes + // -- X9 is temporary register, so shouldn't contain any live data here, + // -- free to use. This is already produced by emitFrameOffset above. + // AND SP, X9, 0b11111...0000 + // The logical immediates have a non-trivial encoding. The following + // formula computes the encoded immediate with all ones but + // NrBitsToZero zero bits as least significant bits. 
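+  // For example, re-aligning to 32 bytes gives NrBitsToZero = 5, which is
+  // encoded below as N = 1, immr = 59, imms = 58 (andMaskEncoded = 0x1efa)
+  // and decodes to the 64-bit mask 0xffffffffffffffe0, i.e. all ones with
+  // the low five bits clear.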
+ uint32_t andMaskEncoded = (1 << 12) // = N + | ((64 - NrBitsToZero) << 6) // immr + | ((64 - NrBitsToZero - 1) << 0); // imms + + const AArch64Subtarget &Subtarget = + MBB.getParent()->getSubtarget(); + const TargetInstrInfo *TII = Subtarget.getInstrInfo(); + BuildMI(MBB, MBBI, DebugLoc(), TII->get(AArch64::ANDXri), Dst) + .addReg(Src, getKillRegState(Src != AArch64::SP)) + .addImm(andMaskEncoded); +} + +// Returns true if all the callee saves can be accessed from the SP +// after allocating the entire SVE stack area in one go. +static bool canCombineSPBumpSVE(MachineBasicBlock::iterator CalleeSavesBegin, + MachineBasicBlock::iterator CalleeSavesEnd, + int64_t SVECalleeSavedStackSize, + int64_t SVEStackObjectsSize) { + if (!SVECalleeSavedStackSize || !SVEStackObjectsSize) + return false; + + for (auto I = CalleeSavesBegin; I != CalleeSavesEnd; ++I) { + StackOffset S(SVEStackObjectsSize, MVT::nxv1i8); + switch (I->getOpcode()) { + case AArch64::STR_PXI: + case AArch64::LDR_PXI: + case AArch64::STR_ZXI: + case AArch64::LDR_ZXI: + if (isAArch64FrameOffsetLegal(*I, S, nullptr, nullptr, nullptr) & + AArch64FrameOffsetIsLegal) + break; + return false; + default: + llvm_unreachable("Unsupported operation"); + } + } + + return true; +} + +static void AllocateSVE(MachineBasicBlock &MBB, DebugLoc DL, + const TargetInstrInfo *TII, + MachineBasicBlock::iterator CalleeSavesBegin, + MachineBasicBlock::iterator CalleeSavesEnd, + int64_t SVECalleeSavedStackSize, + int64_t SVEStackSize) { + int64_t SVEStackObjectsSize = SVEStackSize - SVECalleeSavedStackSize; + if (canCombineSPBumpSVE(CalleeSavesBegin, CalleeSavesEnd, + SVECalleeSavedStackSize, SVEStackObjectsSize)) { + for (auto I = CalleeSavesBegin; I != CalleeSavesEnd; ++I) + fixupCalleeSaveRestoreStackOffset(*I, SVEStackObjectsSize); + + SVECalleeSavedStackSize = SVEStackSize; + SVEStackObjectsSize = 0; + } + + // Allocate space for the callee saves. + emitFrameOffset(MBB, CalleeSavesBegin, DL, AArch64::SP, AArch64::SP, + {-SVECalleeSavedStackSize, MVT::nxv1i8}, TII, + MachineInstr::FrameSetup); + + // Allocate remaining SVE stack space. + emitFrameOffset(MBB, CalleeSavesEnd, DL, AArch64::SP, AArch64::SP, + {-SVEStackObjectsSize, MVT::nxv1i8}, TII, + MachineInstr::FrameSetup); +} + +static void DeallocateSVE(MachineBasicBlock &MBB, DebugLoc DL, + const TargetInstrInfo *TII, + MachineBasicBlock::iterator CalleeSavesBegin, + MachineBasicBlock::iterator CalleeSavesEnd, + int64_t SVECalleeSavedStackSize, + int64_t SVEStackSize) { + int64_t SVEStackObjectsSize = SVEStackSize - SVECalleeSavedStackSize; + if (canCombineSPBumpSVE(CalleeSavesBegin, CalleeSavesEnd, + SVECalleeSavedStackSize, SVEStackObjectsSize)) { + for (auto I = CalleeSavesBegin; I != CalleeSavesEnd; ++I) + fixupCalleeSaveRestoreStackOffset(*I, SVEStackObjectsSize); + + SVECalleeSavedStackSize = SVEStackSize; + SVEStackObjectsSize = 0; + } + + // Deallocate SVE objects space. + emitFrameOffset(MBB, CalleeSavesBegin, DL, AArch64::SP, AArch64::SP, + {SVEStackObjectsSize, MVT::nxv1i8}, TII, + MachineInstr::FrameDestroy); + + // Deallocate space for the callee saves. 
+ emitFrameOffset(MBB, CalleeSavesEnd, DL, AArch64::SP, AArch64::SP, + {SVECalleeSavedStackSize, MVT::nxv1i8}, TII, + MachineInstr::FrameDestroy); +} + void AArch64FrameLowering::emitPrologue(MachineFunction &MF, MachineBasicBlock &MBB) const { MachineBasicBlock::iterator MBBI = MBB.begin(); @@ -573,10 +824,27 @@ if (MF.getFunction().getCallingConv() == CallingConv::GHC) return; + bool NeedsRealignment = RegInfo->needsStackRealignment(MF); + + // Skip past the SVE callee saves. + int64_t SVECalleeSavedStackSize = AFI->getSVECalleeSavedStackSize(); + if (SVECalleeSavedStackSize) + while ((MBBI->getOpcode() == AArch64::STR_ZXI || + MBBI->getOpcode() == AArch64::STR_PXI) && + !MBBI->getOperand(1).isFI() && MBBI != MBB.getFirstTerminator()) + ++MBBI; + + // Allocate the SVE area. + int64_t SVEStackSize = getSVEStackSize(MF); + assert((!SVEStackSize || AFI->getMaxAlignSVE() <= getStackAlignment()) && + "Alignment of SVE objects expected <= 16 for now"); + AllocateSVE(MBB, DL, TII, MBB.begin(), MBBI, SVECalleeSavedStackSize, + SVEStackSize); + int NumBytes = (int)MFI.getStackSize(); if (!AFI->hasStackFrame() && !windowsRequiresStackProbe(MF, NumBytes)) { assert(!HasFP && "unexpected function without stack frame but with FP"); - + assert(!SVEStackSize && "Must have stack frame with SVE"); // All of the stack allocation is for locals. AFI->setLocalStackSize(NumBytes); @@ -588,8 +856,8 @@ AFI->setHasRedZone(true); ++NumRedZoneFunctions; } else { - emitFrameOffset(MBB, MBBI, DL, AArch64::SP, AArch64::SP, -NumBytes, TII, - MachineInstr::FrameSetup); + emitFrameOffset(MBB, MBBI, DL, AArch64::SP, AArch64::SP, + {-NumBytes, MVT::i8}, TII, MachineInstr::FrameSetup); // Label used to tie together the PROLOG_LABEL and the MachineMoves. MCSymbol *FrameLabel = MMI.getContext().createTempSymbol(); @@ -612,13 +880,17 @@ AFI->setLocalStackSize(NumBytes - PrologueSaveSize); bool CombineSPBump = shouldCombineCSRLocalStackBump(MF, NumBytes); + + // Adjust offset for using the Frame Record with 1 SVE register. + int OverlapOffset = SVEStackSize && HasFP ? 16 : 0; if (CombineSPBump) { - emitFrameOffset(MBB, MBBI, DL, AArch64::SP, AArch64::SP, -NumBytes, TII, + emitFrameOffset(MBB, MBBI, DL, AArch64::SP, AArch64::SP, + {-NumBytes + OverlapOffset, MVT::i8}, TII, MachineInstr::FrameSetup); NumBytes = 0; } else if (PrologueSaveSize != 0) { MBBI = convertCalleeSaveRestoreToSPPrePostIncDec(MBB, MBBI, DL, TII, - -PrologueSaveSize); + -PrologueSaveSize + OverlapOffset); NumBytes -= PrologueSaveSize; } assert(NumBytes >= 0 && "Negative stack allocation size!?"); @@ -643,8 +915,8 @@ // mov fp,sp when FPOffset is zero. // Note: All stores of callee-saved registers are marked as "FrameSetup". // This code marks the instruction(s) that set the FP also. - emitFrameOffset(MBB, MBBI, DL, AArch64::FP, AArch64::SP, FPOffset, TII, - MachineInstr::FrameSetup); + emitFrameOffset(MBB, MBBI, DL, AArch64::FP, AArch64::SP, + {FPOffset, MVT::i8}, TII, MachineInstr::FrameSetup); } if (windowsRequiresStackProbe(MF, NumBytes)) { @@ -687,7 +959,6 @@ // Allocate space for the rest of the frame. if (NumBytes) { - const bool NeedsRealignment = RegInfo->needsStackRealignment(MF); unsigned scratchSPReg = AArch64::SP; if (NeedsRealignment) { @@ -700,29 +971,11 @@ // FIXME: in the case of dynamic re-alignment, NumBytes doesn't have // the correct value here, as NumBytes also includes padding bytes, // which shouldn't be counted here. 
- emitFrameOffset(MBB, MBBI, DL, scratchSPReg, AArch64::SP, -NumBytes, TII, - MachineInstr::FrameSetup); + emitFrameOffset(MBB, MBBI, DL, scratchSPReg, AArch64::SP, + {-NumBytes, MVT::i8}, TII, MachineInstr::FrameSetup); if (NeedsRealignment) { - const unsigned Alignment = MFI.getMaxAlignment(); - const unsigned NrBitsToZero = countTrailingZeros(Alignment); - assert(NrBitsToZero > 1); - assert(scratchSPReg != AArch64::SP); - - // SUB X9, SP, NumBytes - // -- X9 is temporary register, so shouldn't contain any live data here, - // -- free to use. This is already produced by emitFrameOffset above. - // AND SP, X9, 0b11111...0000 - // The logical immediates have a non-trivial encoding. The following - // formula computes the encoded immediate with all ones but - // NrBitsToZero zero bits as least significant bits. - uint32_t andMaskEncoded = (1 << 12) // = N - | ((64 - NrBitsToZero) << 6) // immr - | ((64 - NrBitsToZero - 1) << 0); // imms - - BuildMI(MBB, MBBI, DL, TII->get(AArch64::ANDXri), AArch64::SP) - .addReg(scratchSPReg, RegState::Kill) - .addImm(andMaskEncoded); + realignFrame(MBB, MBBI, AArch64::SP, scratchSPReg, MFI.getMaxAlignment()); AFI->setStackRealigned(true); } } @@ -809,18 +1062,26 @@ // Ltmp5: // .cfi_offset w28, -32 + const MCRegisterInfo *MRI = MF.getSubtarget().getRegisterInfo(); if (HasFP) { // Define the current CFA rule to use the provided FP. unsigned Reg = RegInfo->getDwarfRegNum(FramePtr, true); - unsigned CFIIndex = MF.addFrameInst(MCCFIInstruction::createDefCfa( - nullptr, Reg, 2 * StackGrowth - FixedObject)); + unsigned CFIIndex; + if (SVEStackSize) + CFIIndex = MF.addFrameInst( + emitDefCFA(MRI, FramePtr, StackOffset(SVEStackSize, MVT::nxv1i8))); + else + CFIIndex = MF.addFrameInst(MCCFIInstruction::createDefCfa( + nullptr, Reg, 2 * StackGrowth - FixedObject)); + BuildMI(MBB, MBBI, DL, TII->get(TargetOpcode::CFI_INSTRUCTION)) .addCFIIndex(CFIIndex) .setMIFlags(MachineInstr::FrameSetup); } else { // Encode the stack size of the leaf function. - unsigned CFIIndex = MF.addFrameInst( - MCCFIInstruction::createDefCfaOffset(nullptr, -MFI.getStackSize())); + unsigned CFIIndex = MF.addFrameInst(emitDefCFA( + MRI, AArch64::SP, StackOffset(SVEStackSize, MVT::nxv1i8) + + StackOffset(MFI.getStackSize(), MVT::i8))); BuildMI(MBB, MBBI, DL, TII->get(TargetOpcode::CFI_INSTRUCTION)) .addCFIIndex(CFIIndex) .setMIFlags(MachineInstr::FrameSetup); @@ -829,6 +1090,7 @@ // Now emit the moves for whatever callee saved regs we have (including FP, // LR if those are saved). emitCalleeSavedFrameMoves(MBB, MBBI); + fixupScalableDebugOffsets(MF); } } @@ -849,6 +1111,15 @@ int NumBytes = MFI.getStackSize(); const AArch64FunctionInfo *AFI = MF.getInfo(); + // Set the FirstTerminator to be the first SVE reload. + MachineBasicBlock::iterator FirstTerminator = MBB.getFirstTerminator(); + if (AFI->getSVECalleeSavedStackSize()) + while (FirstTerminator != MBB.begin() && + (std::prev(FirstTerminator)->getOpcode() == AArch64::LDR_ZXI || + std::prev(FirstTerminator)->getOpcode() == AArch64::LDR_PXI) && + !std::prev(FirstTerminator)->getOperand(1).isFI()) + --FirstTerminator; + // All calls are tail calls in GHC calling conv, and functions have no // prologue/epilogue. if (MF.getFunction().getCallingConv() == CallingConv::GHC) @@ -906,30 +1177,32 @@ uint64_t AfterCSRPopSize = ArgumentPopSize; auto PrologueSaveSize = AFI->getCalleeSavedStackSize() + FixedObject; bool CombineSPBump = shouldCombineCSRLocalStackBump(MF, NumBytes); - // Assume we can't combine the last pop with the sp restore. 
+ int64_t SVEStackSize = getSVEStackSize(MF); + unsigned OverlapOffset = SVEStackSize && hasFP(MF) ? 16 : 0; + // Assume we can't combine the last pop with the sp restore. if (!CombineSPBump && PrologueSaveSize != 0) { - MachineBasicBlock::iterator Pop = std::prev(MBB.getFirstTerminator()); + MachineBasicBlock::iterator Pop = std::prev(FirstTerminator); // Converting the last ldp to a post-index ldp is valid only if the last // ldp's offset is 0. const MachineOperand &OffsetOp = Pop->getOperand(Pop->getNumOperands() - 1); // If the offset is 0, convert it to a post-index ldp. if (OffsetOp.getImm() == 0) { convertCalleeSaveRestoreToSPPrePostIncDec(MBB, Pop, DL, TII, - PrologueSaveSize); + PrologueSaveSize - OverlapOffset); } else { // If not, make sure to emit an add after the last ldp. // We're doing this by transfering the size to be restored from the // adjustment *before* the CSR pops to the adjustment *after* the CSR // pops. - AfterCSRPopSize += PrologueSaveSize; + AfterCSRPopSize += PrologueSaveSize - OverlapOffset; } } // Move past the restores of the callee-saved registers. // If we plan on combining the sp bump of the local stack size and the callee // save stack size, we might need to adjust the CSR save and restore offsets. - MachineBasicBlock::iterator LastPopI = MBB.getFirstTerminator(); + MachineBasicBlock::iterator LastPopI = FirstTerminator; MachineBasicBlock::iterator Begin = MBB.begin(); while (LastPopI != Begin) { --LastPopI; @@ -942,9 +1215,12 @@ // If there is a single SP update, insert it before the ret and we're done. if (CombineSPBump) { - emitFrameOffset(MBB, MBB.getFirstTerminator(), DL, AArch64::SP, AArch64::SP, - NumBytes + AfterCSRPopSize, TII, - MachineInstr::FrameDestroy); + emitFrameOffset( + MBB, FirstTerminator, DL, AArch64::SP, AArch64::SP, + {NumBytes + (int64_t)AfterCSRPopSize - OverlapOffset, MVT::i8}, TII, + MachineInstr::FrameDestroy); + DeallocateSVE(MBB, DL, TII, FirstTerminator, MBB.getFirstTerminator(), + AFI->getSVECalleeSavedStackSize(), SVEStackSize); return; } @@ -969,11 +1245,12 @@ // If we're done after this, make sure to help the load store optimizer. if (Done) - adaptForLdStOpt(MBB, MBB.getFirstTerminator(), LastPopI); + adaptForLdStOpt(MBB, FirstTerminator, LastPopI); emitFrameOffset(MBB, LastPopI, DL, AArch64::SP, AArch64::SP, - StackRestoreBytes, TII, MachineInstr::FrameDestroy); - if (Done) + {StackRestoreBytes, MVT::i8}, TII, + MachineInstr::FrameDestroy); + if (!SVEStackSize && Done) return; NumBytes = 0; @@ -985,11 +1262,11 @@ // be able to save any instructions. if (MFI.hasVarSizedObjects() || AFI->isStackRealigned()) emitFrameOffset(MBB, LastPopI, DL, AArch64::SP, AArch64::FP, - -AFI->getCalleeSavedStackSize() + 16, TII, - MachineInstr::FrameDestroy); + {-(int64_t)AFI->getCalleeSavedStackSize() + 16, MVT::i8}, + TII, MachineInstr::FrameDestroy); else if (NumBytes) - emitFrameOffset(MBB, LastPopI, DL, AArch64::SP, AArch64::SP, NumBytes, TII, - MachineInstr::FrameDestroy); + emitFrameOffset(MBB, LastPopI, DL, AArch64::SP, AArch64::SP, + {NumBytes, MVT::i8}, TII, MachineInstr::FrameDestroy); // This must be placed after the callee-save restore code because that code // assumes the SP is at the same location as it was after the callee-save save @@ -998,7 +1275,7 @@ // Find an insertion point for the first ldp so that it goes before the // shadow call stack epilog instruction. This ensures that the restore of // lr from x18 is placed after the restore from sp. 
- auto FirstSPPopI = MBB.getFirstTerminator(); + auto FirstSPPopI = FirstTerminator; while (FirstSPPopI != Begin) { auto Prev = std::prev(FirstSPPopI); if (Prev->getOpcode() != AArch64::LDRXpre || @@ -1010,8 +1287,12 @@ adaptForLdStOpt(MBB, FirstSPPopI, LastPopI); emitFrameOffset(MBB, FirstSPPopI, DL, AArch64::SP, AArch64::SP, - AfterCSRPopSize, TII, MachineInstr::FrameDestroy); + {(int64_t)AfterCSRPopSize, MVT::i8}, TII, + MachineInstr::FrameDestroy); } + + DeallocateSVE(MBB, DL, TII, FirstTerminator, MBB.getFirstTerminator(), + AFI->getSVECalleeSavedStackSize(), SVEStackSize); } /// getFrameIndexReference - Provide a base+offset reference to an FI slot for @@ -1021,12 +1302,13 @@ int AArch64FrameLowering::getFrameIndexReference(const MachineFunction &MF, int FI, unsigned &FrameReg) const { - return resolveFrameIndexReference(MF, FI, FrameReg); + return resolveFrameIndexReference(MF, FI, FrameReg).getBytes(); } -int AArch64FrameLowering::resolveFrameIndexReference(const MachineFunction &MF, - int FI, unsigned &FrameReg, - bool PreferFP) const { +StackOffset +AArch64FrameLowering::resolveFrameIndexReference(const MachineFunction &MF, + int FI, unsigned &FrameReg, + bool PreferFP) const { const MachineFrameInfo &MFI = MF.getFrameInfo(); const AArch64RegisterInfo *RegInfo = static_cast( MF.getSubtarget().getRegisterInfo()); @@ -1059,6 +1341,20 @@ // the CSR area. assert(hasFP(MF) && "Re-aligned stack must have frame pointer"); UseFP = true; + } else if (hasFP(MF) && !RegInfo->hasBasePointer(MF) && + !RegInfo->needsStackRealignment(MF)) { + // Use SP or FP, whichever gives us the best chance of the offset + // being in range for direct access. If the FPOffset is positive, + // that'll always be best, as the SP will be even further away. + // If the FPOffset is negative, we have to keep in mind that the + // available offset range for negative offsets is smaller than for + // positive ones. If we have variable sized objects, we're stuck with + // using the FP regardless, though, as the SP offset is unknown + // and we don't have a base pointer available. If an offset is + // available via the FP and the SP, use whichever is closest. + if (PreferFP || MFI.hasVarSizedObjects() || + FPOffset >= 0 || (FPOffset >= -256 && Offset > -FPOffset)) + UseFP = true; } else if (hasFP(MF) && !RegInfo->needsStackRealignment(MF)) { // If the FPOffset is negative, we have to keep in mind that the // available offset range for negative offsets is smaller than for @@ -1096,9 +1392,38 @@ "In the presence of dynamic stack pointer realignment, " "non-argument/CSR objects cannot be accessed through the frame pointer"); + uint64_t SVEStackSize = getSVEStackSize(MF); + + // Handle objects in an SVE region by returning: + // FP + ScalableOffset + // or + // SP + ScalableOffset + sizeof(NonScalableStack) + auto StackID = MFI.getStackID(FI); + if (StackID == AArch64::FR_SVE) { + int64_t SVEOffset = MFI.getObjectOffset(FI) + SVEStackSize; + if (hasFP(MF)) { + FrameReg = AArch64::FP; + return StackOffset(SVEOffset, MVT::nxv1i8); + } + FrameReg = AArch64::SP; + return StackOffset(SVEOffset, MVT::nxv1i8) + + StackOffset(MFI.getStackSize(), MVT::i8); + } + + assert(StackID == AArch64::FR_Default && "Unsupported Stack ID"); + + // Handle stack arguments above the SVE region by adding + // the size of the Scalable area (minus the overlap of the + // frame-record). + StackOffset Scalable; + if (SVEStackSize && isFixed) { + Scalable = StackOffset(SVEStackSize, MVT::nxv1i8); + FPOffset -= UseFP ? 
16 : 0; + } + if (UseFP) { FrameReg = RegInfo->getFrameRegister(MF); - return FPOffset; + return StackOffset(FPOffset, MVT::i8) + Scalable; } // Use the base pointer if we have one. @@ -1115,7 +1440,7 @@ Offset -= AFI->getLocalStackSize(); } - return Offset; + return StackOffset(Offset, MVT::i8) + Scalable; } static unsigned getPrologueDeath(MachineFunction &MF, unsigned Reg) { @@ -1143,11 +1468,28 @@ unsigned Reg2 = AArch64::NoRegister; int FrameIdx; int Offset; - bool IsGPR; + enum RegType { GPR, FPR64, FPR128, PPR, ZPR } Type; RegPairInfo() = default; bool isPaired() const { return Reg2 != AArch64::NoRegister; } + + unsigned getScale() const { + switch (Type) { + case PPR: + return 2; + case GPR: + case FPR64: + return 8; + case ZPR: + case FPR128: + return 16; + default: + llvm_unreachable("Unsupported type"); + } + } + + bool isScalable() const { return Type == PPR || Type == ZPR; } }; } // end anonymous namespace @@ -1171,22 +1513,46 @@ CC == CallingConv::PreserveMost || (Count & 1) == 0) && "Odd number of callee-saved regs to spill!"); - int Offset = AFI->getCalleeSavedStackSize(); + int ByteOffset = AFI->getCalleeSavedStackSize(); + int ScalableByteOffset = AFI->getSVECalleeSavedStackSize(); for (unsigned i = 0; i < Count; ++i) { RegPairInfo RPI; RPI.Reg1 = CSI[i].getReg(); - assert(AArch64::GPR64RegClass.contains(RPI.Reg1) || - AArch64::FPR64RegClass.contains(RPI.Reg1)); - RPI.IsGPR = AArch64::GPR64RegClass.contains(RPI.Reg1); + if (AArch64::GPR64RegClass.contains(RPI.Reg1)) + RPI.Type = RegPairInfo::GPR; + else if (AArch64::FPR64RegClass.contains(RPI.Reg1)) + RPI.Type = RegPairInfo::FPR64; + else if (AArch64::FPR128RegClass.contains(RPI.Reg1)) + RPI.Type = RegPairInfo::FPR128; + else if (AArch64::ZPRRegClass.contains(RPI.Reg1)) + RPI.Type = RegPairInfo::ZPR; + else if (AArch64::PPRRegClass.contains(RPI.Reg1)) + RPI.Type = RegPairInfo::PPR; + else + llvm_unreachable("Unsupported register class."); // Add the next reg to the pair if it is in the same register class. if (i + 1 < Count) { unsigned NextReg = CSI[i + 1].getReg(); - if ((RPI.IsGPR && AArch64::GPR64RegClass.contains(NextReg)) || - (!RPI.IsGPR && AArch64::FPR64RegClass.contains(NextReg))) - RPI.Reg2 = NextReg; + switch (RPI.Type) { + case RegPairInfo::GPR: + if (AArch64::GPR64RegClass.contains(NextReg)) + RPI.Reg2 = NextReg; + break; + case RegPairInfo::FPR64: + if (AArch64::FPR64RegClass.contains(NextReg)) + RPI.Reg2 = NextReg; + break; + case RegPairInfo::FPR128: + if (AArch64::FPR128RegClass.contains(NextReg)) + RPI.Reg2 = NextReg; + break; + case RegPairInfo::PPR: + case RegPairInfo::ZPR: + break; + } } // If either of the registers to be saved is the lr register, it means that @@ -1219,18 +1585,31 @@ RPI.FrameIdx = CSI[i].getFrameIdx(); - if (Count * 8 != AFI->getCalleeSavedStackSize() && !RPI.isPaired()) { - // Round up size of non-pair to pair size if we need to pad the - // callee-save area to ensure 16-byte alignment. - Offset -= 16; + int Scale = RPI.getScale(); + if (RPI.isScalable()) + ScalableByteOffset -= Scale; + else + ByteOffset -= RPI.isPaired() ? 2 * Scale : Scale; + + assert(!(RPI.isScalable() && RPI.isPaired()) && + "Paired spill/fill instructions don't exist for SVE vectors"); + + // Round up size of non-pair to pair size if we need to pad the + // callee-save area to ensure 16-byte alignment. 
+ if (AFI->hasCalleeSaveStackFreeSpace() && RPI.Type != RegPairInfo::FPR128 && + !RPI.isScalable() && !RPI.isPaired()) { + ByteOffset -= 8; + assert(ByteOffset % 16 == 0); assert(MFI.getObjectAlignment(RPI.FrameIdx) <= 16); MFI.setObjectAlignment(RPI.FrameIdx, 16); - AFI->setCalleeSaveStackHasFreeSpace(true); - } else - Offset -= RPI.isPaired() ? 16 : 8; - assert(Offset % 8 == 0); - RPI.Offset = Offset / 8; - assert((RPI.Offset >= -64 && RPI.Offset <= 63) && + } + + int Offset = RPI.isScalable() ? ScalableByteOffset : ByteOffset; + assert(Offset % Scale == 0); + RPI.Offset = Offset / Scale; + + assert(((!RPI.isScalable() && RPI.Offset >= -64 && RPI.Offset <= 63) || + (RPI.isScalable() && RPI.Offset >= -256 && RPI.Offset <= 255)) && "Offset out of bounds for LDP/STP immediate"); RegPairs.push_back(RPI); @@ -1246,8 +1625,8 @@ MachineFunction &MF = *MBB.getParent(); const TargetInstrInfo &TII = *MF.getSubtarget().getInstrInfo(); DebugLoc DL; - SmallVector RegPairs; + SmallVector RegPairs; bool NeedShadowCallStackProlog = false; computeCalleeSaveRegisterPairs(MF, CSI, TRI, RegPairs, NeedShadowCallStackProlog); @@ -1283,10 +1662,34 @@ // Rationale: This sequence saves uop updates compared to a sequence of // pre-increment spills like stp xi,xj,[sp,#-16]! // Note: Similar rationale and sequence for restores in epilog. - if (RPI.IsGPR) - StrOpc = RPI.isPaired() ? AArch64::STPXi : AArch64::STRXui; - else - StrOpc = RPI.isPaired() ? AArch64::STPDi : AArch64::STRDui; + unsigned Size, Align; + switch (RPI.Type) { + case RegPairInfo::GPR: + StrOpc = RPI.isPaired() ? AArch64::STPXi : AArch64::STRXui; + Size = 8; + Align = 8; + break; + case RegPairInfo::FPR64: + StrOpc = RPI.isPaired() ? AArch64::STPDi : AArch64::STRDui; + Size = 8; + Align = 8; + break; + case RegPairInfo::FPR128: + StrOpc = RPI.isPaired() ? AArch64::STPQi : AArch64::STRQui; + Size = 16; + Align = 16; + break; + case RegPairInfo::ZPR: + StrOpc = AArch64::STR_ZXI; + Size = 16; + Align = 16; + break; + case RegPairInfo::PPR: + StrOpc = AArch64::STR_PXI; + Size = 2; + Align = 2; + break; + } LLVM_DEBUG(dbgs() << "CSR spill: (" << printReg(Reg1, TRI); if (RPI.isPaired()) dbgs() << ", " << printReg(Reg2, TRI); dbgs() << ") -> fi#(" << RPI.FrameIdx; @@ -1302,23 +1705,29 @@ MIB.addReg(Reg2, getPrologueDeath(MF, Reg2)); MIB.addMemOperand(MF.getMachineMemOperand( MachinePointerInfo::getFixedStack(MF, RPI.FrameIdx + 1), - MachineMemOperand::MOStore, 8, 8)); + MachineMemOperand::MOStore, Size, Align)); } MIB.addReg(Reg1, getPrologueDeath(MF, Reg1)) .addReg(AArch64::SP) - .addImm(RPI.Offset) // [sp, #offset*8], where factor*8 is implicit + .addImm(RPI.Offset) // [sp, #offset*scale], + // where factor*scale is implicit .setMIFlag(MachineInstr::FrameSetup); MIB.addMemOperand(MF.getMachineMemOperand( MachinePointerInfo::getFixedStack(MF, RPI.FrameIdx), - MachineMemOperand::MOStore, 8, 8)); + MachineMemOperand::MOStore, Size, Align)); + + // Update the StackIDs of the SVE stack slots. 
+ MachineFrameInfo &MFI = MF.getFrameInfo(); + if (RPI.Type == RegPairInfo::ZPR || RPI.Type == RegPairInfo::PPR) + MFI.setStackID(RPI.FrameIdx, AArch64::FR_SVE); + } return true; } bool AArch64FrameLowering::restoreCalleeSavedRegisters( MachineBasicBlock &MBB, MachineBasicBlock::iterator MI, - std::vector &CSI, - const TargetRegisterInfo *TRI) const { + std::vector &CSI, const TargetRegisterInfo *TRI) const { MachineFunction &MF = *MBB.getParent(); const TargetInstrInfo &TII = *MF.getSubtarget().getInstrInfo(); DebugLoc DL; @@ -1344,10 +1753,34 @@ // ldp x22, x21, [sp, #0] // addImm(+0) // Note: see comment in spillCalleeSavedRegisters() unsigned LdrOpc; - if (RPI.IsGPR) - LdrOpc = RPI.isPaired() ? AArch64::LDPXi : AArch64::LDRXui; - else - LdrOpc = RPI.isPaired() ? AArch64::LDPDi : AArch64::LDRDui; + unsigned Size, Align; + switch (RPI.Type) { + case RegPairInfo::GPR: + LdrOpc = RPI.isPaired() ? AArch64::LDPXi : AArch64::LDRXui; + Size = 8; + Align = 8; + break; + case RegPairInfo::FPR64: + LdrOpc = RPI.isPaired() ? AArch64::LDPDi : AArch64::LDRDui; + Size = 8; + Align = 8; + break; + case RegPairInfo::FPR128: + LdrOpc = RPI.isPaired() ? AArch64::LDPQi : AArch64::LDRQui; + Size = 16; + Align = 16; + break; + case RegPairInfo::ZPR: + LdrOpc = AArch64::LDR_ZXI; + Size = 16; + Align = 16; + break; + case RegPairInfo::PPR: + LdrOpc = AArch64::LDR_PXI; + Size = 2; + Align = 2; + break; + } LLVM_DEBUG(dbgs() << "CSR restore: (" << printReg(Reg1, TRI); if (RPI.isPaired()) dbgs() << ", " << printReg(Reg2, TRI); dbgs() << ") -> fi#(" << RPI.FrameIdx; @@ -1359,22 +1792,30 @@ MIB.addReg(Reg2, getDefRegState(true)); MIB.addMemOperand(MF.getMachineMemOperand( MachinePointerInfo::getFixedStack(MF, RPI.FrameIdx + 1), - MachineMemOperand::MOLoad, 8, 8)); + MachineMemOperand::MOLoad, Size, Align)); } MIB.addReg(Reg1, getDefRegState(true)) .addReg(AArch64::SP) - .addImm(RPI.Offset) // [sp, #offset*8] where the factor*8 is implicit + .addImm(RPI.Offset) // [sp, #offset*scale] + // where factor*scale is implicit .setMIFlag(MachineInstr::FrameDestroy); MIB.addMemOperand(MF.getMachineMemOperand( MachinePointerInfo::getFixedStack(MF, RPI.FrameIdx), - MachineMemOperand::MOLoad, 8, 8)); + MachineMemOperand::MOLoad, Size, Align)); }; - if (ReverseCSRRestoreSeq) + if (ReverseCSRRestoreSeq) { for (const RegPairInfo &RPI : reverse(RegPairs)) - EmitMI(RPI); - else + if (!RPI.isScalable()) + EmitMI(RPI); + } else for (const RegPairInfo &RPI : RegPairs) + if (!RPI.isScalable()) + EmitMI(RPI); + + // SVE objects are always restored in reverse order. + for (const RegPairInfo &RPI : reverse(RegPairs)) + if (RPI.isScalable()) EmitMI(RPI); if (NeedShadowCallStackProlog) { @@ -1390,6 +1831,24 @@ return true; } +int64_t +AArch64FrameLowering::getFurthestStackArgOffset(MachineFunction &MF) const { + MachineFrameInfo &MFI = MF.getFrameInfo(); + + int64_t Offset = 0; + + for (int i = MFI.getObjectIndexBegin(); i != 0; ++i) { + if(MFI.getStackID(i) != 0) { + assert(MFI.getObjectOffset(i) < 0 && + "Does not support SVE stack arguments passed by value"); + continue; + } + Offset = std::max(Offset, MFI.getObjectOffset(i)); + } + + return Offset; +} + void AArch64FrameLowering::determineCalleeSaves(MachineFunction &MF, BitVector &SavedRegs, RegScavenger *RS) const { @@ -1412,35 +1871,21 @@ ? 
RegInfo->getBaseRegister() : (unsigned)AArch64::NoRegister; - unsigned SpillEstimate = SavedRegs.count(); - for (unsigned i = 0; CSRegs[i]; ++i) { - unsigned Reg = CSRegs[i]; - unsigned PairedReg = CSRegs[i ^ 1]; - if (Reg == BasePointerReg) - SpillEstimate++; - if (produceCompactUnwindFrame(MF) && !SavedRegs.test(PairedReg)) - SpillEstimate++; - } - SpillEstimate += 2; // Conservatively include FP+LR in the estimate - unsigned StackEstimate = MFI.estimateStackSize(MF) + 8 * SpillEstimate; - - // The frame record needs to be created by saving the appropriate registers - if (hasFP(MF) || windowsRequiresStackProbe(MF, StackEstimate)) { - SavedRegs.set(AArch64::FP); - SavedRegs.set(AArch64::LR); - } - unsigned ExtraCSSpill = 0; // Figure out which callee-saved registers to save/restore. for (unsigned i = 0; CSRegs[i]; ++i) { const unsigned Reg = CSRegs[i]; - // Add the base pointer register to SavedRegs if it is callee-save. if (Reg == BasePointerReg) SavedRegs.set(Reg); bool RegUsed = SavedRegs.test(Reg); - unsigned PairedReg = CSRegs[i ^ 1]; + unsigned PairedReg = AArch64::NoRegister; + if (AArch64::GPR64RegClass.contains(Reg) || + AArch64::FPR64RegClass.contains(Reg) || + AArch64::FPR128RegClass.contains(Reg)) + PairedReg = CSRegs[i ^ 1]; + if (!RegUsed) { if (AArch64::GPR64RegClass.contains(Reg) && !RegInfo->isReservedReg(MF, Reg)) { @@ -1453,7 +1898,8 @@ // MachO's compact unwind format relies on all registers being stored in // pairs. // FIXME: the usual format is actually better if unwinding isn't needed. - if (produceCompactUnwindFrame(MF) && !SavedRegs.test(PairedReg)) { + if (produceCompactUnwindFrame(MF) && PairedReg != AArch64::NoRegister && + !SavedRegs.test(PairedReg)) { SavedRegs.set(PairedReg); if (AArch64::GPR64RegClass.contains(PairedReg) && !RegInfo->isReservedReg(MF, PairedReg)) @@ -1461,22 +1907,62 @@ } } + // Calculates the callee saved stack size. + unsigned CSStackSize = 0; + unsigned SVECSStackSize = 0; + const TargetRegisterInfo *TRI = MF.getSubtarget().getRegisterInfo(); + const MachineRegisterInfo &MRI = MF.getRegInfo(); + for (unsigned Reg : SavedRegs.set_bits()) { + auto RegSize = TRI->getRegSizeInBits(Reg, MRI) / 8; + if (AArch64::PPRRegClass.contains(Reg) || + AArch64::ZPRRegClass.contains(Reg)) + SVECSStackSize += RegSize; + else + CSStackSize += RegSize; + } + + // Save number of saved regs, so we can easily update CSStackSize later. + unsigned NumSavedRegs = SavedRegs.count(); + + // The frame record needs to be created by saving the appropriate registers + unsigned EstimatedStackSize = MFI.estimateStackSize(MF); + if (hasFP(MF) || + windowsRequiresStackProbe(MF, EstimatedStackSize + CSStackSize + 16)) { + SavedRegs.set(AArch64::FP); + SavedRegs.set(AArch64::LR); + } + LLVM_DEBUG(dbgs() << "*** determineCalleeSaves\nUsed CSRs:"; for (unsigned Reg : SavedRegs.set_bits()) dbgs() << ' ' << printReg(Reg, RegInfo); dbgs() << "\n";); - // If any callee-saved registers are used, the frame cannot be eliminated. - unsigned NumRegsSpilled = SavedRegs.count(); - bool CanEliminateFrame = NumRegsSpilled == 0; + unsigned MaxAlignSVE = 16; + uint64_t StackRegionsSize = + alignTo(SVECSStackSize + estimateSVEStackObjectOffsets(MFI, MaxAlignSVE), + MaxAlignSVE); - // The CSR spill slots have not been allocated yet, so estimateStackSize - // won't include them. 
- unsigned CFSize = MFI.estimateStackSize(MF) + 8 * NumRegsSpilled; - LLVM_DEBUG(dbgs() << "Estimated stack frame size: " << CFSize << " bytes.\n"); + // If any callee-saved registers are used, the frame cannot be eliminated. + bool CanEliminateFrame = + (SavedRegs.count() == 0) && !(StackRegionsSize || CSStackSize); + + // In case of SVE regions, we know a scratch register is needed + // to calculate the address of a SVE vector on the stack when: + // - There is no framepointer and accessing a SVE vector requires + // first calculating the base of the SVE region from SP. + // - There *is* a framepointer and we need to step over the + // 16 byte framerecord (x29,x30) just above FP. + // - The address of the furthest SVE vector would not fit the + // immediate field. + unsigned FurthestStackObj = getFurthestStackArgOffset(MF); unsigned EstimatedStackSizeLimit = estimateRSStackSizeLimit(MF); - bool BigStack = (CFSize > EstimatedStackSizeLimit); + bool BigStack = + ((EstimatedStackSize + CSStackSize + FurthestStackObj) > + EstimatedStackSizeLimit) || + ((StackRegionsSize / 16) > 31) || + (StackRegionsSize > 0 && !hasFP(MF) && + (EstimatedStackSize + CSStackSize) > 0); if (BigStack || !CanEliminateFrame || RegInfo->cannotEliminateFrame(MF)) AFI->setHasStackFrame(true); @@ -1497,7 +1983,6 @@ if (produceCompactUnwindFrame(MF)) SavedRegs.set(UnspilledCSGPRPaired); ExtraCSSpill = UnspilledCSGPRPaired; - NumRegsSpilled = SavedRegs.count(); } // If we didn't find an extra callee-saved register to spill, create @@ -1514,13 +1999,234 @@ } } + // Adding the size of additional 64bit GPR saves. + CSStackSize += 8 * (SavedRegs.count() - NumSavedRegs); + unsigned AlignedCSStackSize = alignTo(CSStackSize, 16); + LLVM_DEBUG(dbgs() << "Estimated stack frame size: " + << EstimatedStackSize + AlignedCSStackSize + << " bytes.\n"); + // Round up to register pair alignment to avoid additional SP adjustment // instructions. - AFI->setCalleeSavedStackSize(alignTo(8 * NumRegsSpilled, 16)); + AFI->setCalleeSavedStackSize(AlignedCSStackSize); + AFI->setCalleeSaveStackHasFreeSpace(AlignedCSStackSize != CSStackSize); + + unsigned AlignedSVECSStackSize = alignTo(SVECSStackSize, 16); + AFI->setSVECalleeSavedStackSize(AlignedSVECSStackSize); + AFI->setSVECalleeSaveStackHasFreeSpace(AlignedSVECSStackSize != + SVECSStackSize); } +// When stack realignment is required the size of the stack frame becomes +// runtime variable. This manifests itself as a runtime variable gap between the +// local and callee-save regions. If FP is used to build the callee-save region +// then FP must be used to access any locals placed within it. This does not +// happen today causing corruption to callee saves by SP relative stores whose +// offset is calculated assuming FP-SP is a compile time constant. +// +// An option is to force all callee-save region accesses to be relative to the +// same base (SP or FP) but given the alignment is likely to reduce the benefit +// of slot scavenging it's simpler to disable the optimisation. bool AArch64FrameLowering::enableStackSlotScavenging( const MachineFunction &MF) const { const AArch64FunctionInfo *AFI = MF.getInfo(); return AFI->hasCalleeSaveStackFreeSpace(); } + +/// returns true if there are any SVE callee saves. 
+static bool getSVECalleeSaveSlotRange(const MachineFrameInfo &MFI, + int &Min, int &Max) { + if (!MFI.isCalleeSavedInfoValid()) + return false; + + Min = std::numeric_limits::max(); + Max = std::numeric_limits::min(); + const std::vector &CSI = MFI.getCalleeSavedInfo(); + for (auto &CS : CSI) { + if (AArch64::ZPRRegClass.contains(CS.getReg()) || + AArch64::PPRRegClass.contains(CS.getReg())) { + assert((Max == std::numeric_limits::min() || + Max + 1 == CS.getFrameIdx()) && + "SVE CalleeSaves are not consecutive"); + + Min = std::min(Min, CS.getFrameIdx()); + Max = CS.getFrameIdx(); + } + } + return Min != std::numeric_limits::max(); +} + +// Process all the SVE stack objects and determine offsets for each +// object. If AssignOffsets is true, the offsets get assigned. +// Returns the size of the stack. +static int64_t determineSVEStackObjectOffsets(MachineFrameInfo &MFI, + unsigned &MaxAlign, + int64_t &SVECalleeSaveSize, + bool AssignOffsets) { + // First process all fixed stack objects. + int64_t Offset = 0; + for (int I = MFI.getObjectIndexBegin(); I != 0; ++I) { + unsigned StackID = MFI.getStackID(I); + if (StackID == AArch64::FR_SVE) { + int64_t FixedOffset = -MFI.getObjectOffset(I); + if (FixedOffset > Offset) Offset = FixedOffset; + } + } + + // Allocation function + auto Assign = [&MFI](int FI, int64_t Offset) { + LLVM_DEBUG(dbgs() << "alloc FI(" << FI << ") at SP[" << Offset << "]\n"); + MFI.setObjectOffset(FI, Offset); + }; + + // Then process all callee saved slots. + int MinCSFrameIndex = -1, MaxCSFrameIndex = -1; + if (getSVECalleeSaveSlotRange(MFI, MinCSFrameIndex, MaxCSFrameIndex)) { + // Align the last callee save slot. + MFI.setObjectAlignment( + MaxCSFrameIndex, + std::max(MFI.getObjectAlignment(MaxCSFrameIndex), 16U)); + + // Assign offsets to the callee save slots. + for (int I = MinCSFrameIndex; I <= MaxCSFrameIndex; ++I) { + Offset += MFI.getObjectSize(I); + Offset = alignTo(Offset, MFI.getObjectAlignment(I)); + if (AssignOffsets) + Assign(I, -Offset); + } + + // When there are allocatable fixed stack objects, the (SVE) + // callee saves are merged and we need to record their combined + // size in order to correctly allocate the space. + SVECalleeSaveSize = Offset; + } + + // Create a buffer of SVE objects to allocate and sort it. + SmallVector ObjectsToAllocate; + for (int I = 0, E = MFI.getObjectIndexEnd(); I != E; ++I) { + unsigned StackID = MFI.getStackID(I); + if (StackID != AArch64::FR_SVE) + continue; + if (MaxCSFrameIndex >= I && I >= MinCSFrameIndex) + continue; + if (MFI.isDeadObjectIndex(I)) + continue; + + ObjectsToAllocate.push_back(I); + } + llvm::sort( + ObjectsToAllocate.begin(), ObjectsToAllocate.end(), [&MFI](int A, int B) { + if (MFI.isSpillSlotObjectIndex(A) != MFI.isSpillSlotObjectIndex(B)) + return !MFI.isSpillSlotObjectIndex(A); + if (MFI.getObjectSize(A) != MFI.getObjectSize(B)) + return MFI.getObjectSize(A) > MFI.getObjectSize(B); + return A > B; + }); + + // Dynamically allocate all SVE objects + for (unsigned FI : ObjectsToAllocate) { + // The callee-save area might not be 16-byte aligned, so align first object. 
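+    // For example, with no SVE callee saves, two Z-sized (16 scalable-byte)
+    // spill slots followed by one P-sized (2 scalable-byte) slot are placed
+    // at offsets -16, -32 and -34; resolveFrameIndexReference later rebases
+    // these against the total SVE stack size.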
+ unsigned Align = MFI.getObjectAlignment(FI); + Offset = alignTo(Offset + MFI.getObjectSize(FI), Align); + MaxAlign = std::max(Align, MaxAlign); + if (AssignOffsets) + Assign(FI, -Offset); + } + + return Offset; +} + +int64_t +AArch64FrameLowering::estimateSVEStackObjectOffsets(MachineFrameInfo &MFI, + unsigned &MaxAlign) const { + int64_t SVECalleeSaveSize; + return determineSVEStackObjectOffsets(MFI, MaxAlign, SVECalleeSaveSize, + false); +} + +int64_t AArch64FrameLowering::assignSVEStackObjectOffsets( + MachineFrameInfo &MFI, unsigned &MaxAlign, + int64_t &SVECalleeSaveSize) const { + return determineSVEStackObjectOffsets(MFI, MaxAlign, SVECalleeSaveSize, true); +} + +void AArch64FrameLowering::processFunctionBeforeFrameFinalized( + MachineFunction &MF, RegScavenger *RS) const { + auto &MFI = MF.getFrameInfo(); + unsigned MaxAlign = getStackAlignment(); + int64_t SVECalleeSaveSize = 0; + int64_t SVEStackSize = + assignSVEStackObjectOffsets(MFI, MaxAlign, SVECalleeSaveSize); + + // EXTRASTACKREGION: We allocate an extra empty VL sized spill slot which is + // overlaid by the frame-record (x29, x30) to prevent having to materialise + // the base of the SVE region for each spill/fill by calculating (FP + 16). + // From this, we can just access any SVE object directly from: + // + // FP + (#offset + 1) * VL + // + // Note that this favours performance over stack size, which may change in + // the future. The offset starts at (n x) 16 bytes and requires the size of + // a full SVE vector, since the FP needs to be 16 byte aligned. + if (hasFP(MF) && SVEStackSize) + SVEStackSize += 16; + + AArch64FunctionInfo *AFI = MF.getInfo(); + AFI->setStackSizeSVE((uint64_t) alignTo(SVEStackSize, MaxAlign)); + AFI->setSVECalleeSavedStackSize(SVECalleeSaveSize); + AFI->setMaxAlignSVE(MaxAlign); +} + +static std::string getDwarfCommentForOffset(StackOffset O) { + int64_t Bytes, VGSized; + O.getForDwarfOffset(Bytes, VGSized); + + std::string Comment = ""; + if (Bytes) + Comment += " + " + std::to_string(Bytes); + if (VGSized) + Comment += " + " + std::to_string(VGSized) + " * VG"; + return Comment; +} + +MCCFIInstruction AArch64FrameLowering::emitCFI(const MCRegisterInfo *MRI, + unsigned Reg, + StackOffset Offset) const { + int64_t Bytes, VGSized; + Offset.getForDwarfOffset(Bytes, VGSized); + + unsigned DwarfReg = MRI->getDwarfRegNum(Reg, true); + if (!VGSized) + return MCCFIInstruction::createOffset(nullptr, DwarfReg, Bytes); + + // If the offset is (partially) scalable, generate a comment and + // complex Dwarf expression. + std::string Comment = "cfi("; + Comment += MRI->getName(Reg); + Comment += ") = cfa" + getDwarfCommentForOffset(Offset); + unsigned VG = MRI->getDwarfRegNum(AArch64::VG, true); + return MCCFIInstruction::createScaledCfaOffset(nullptr, DwarfReg, Bytes, VG, + VGSized, Comment); +} + +MCCFIInstruction AArch64FrameLowering::emitDefCFA(const MCRegisterInfo *MRI, + unsigned Basereg, + StackOffset Offset) const { + assert((Basereg == AArch64::SP || Basereg == AArch64::FP) && + "Unsupported base register to define CFA"); + + int64_t Bytes, VGSized; + Offset.getForDwarfOffset(Bytes, VGSized); + + if (!VGSized) + return MCCFIInstruction::createDefCfaOffset(nullptr, -Bytes); + + // Return the CFI instruction + std::string Comment = "cfa = "; + Comment += (Basereg == AArch64::SP ? 
"sp" : "fp"); + Comment += getDwarfCommentForOffset(Offset); + unsigned VG = MRI->getDwarfRegNum(AArch64::VG, true); + unsigned DwarfBase = MRI->getDwarfRegNum(Basereg, true); + return MCCFIInstruction::createScaledDefCfa(nullptr, DwarfBase, Bytes, VG, + VGSized, Comment); +} Index: lib/Target/AArch64/AArch64ISelDAGToDAG.cpp =================================================================== --- lib/Target/AArch64/AArch64ISelDAGToDAG.cpp +++ lib/Target/AArch64/AArch64ISelDAGToDAG.cpp @@ -122,6 +122,24 @@ bool SelectAddrModeUnscaled128(SDValue N, SDValue &Base, SDValue &OffImm) { return SelectAddrModeUnscaled(N, 16, Base, OffImm); } + template + bool SelectAddrModeIndexedUImm(SDValue N, SDValue &Base, SDValue &OffImm) { + bool Found = SelectAddrModeIndexed(N, Size, Base, OffImm); + if (Found) { + if (auto CI = dyn_cast(OffImm)) { + int64_t C = CI->getSExtValue(); + if (C <= Max) + return true; + } + } + + return false; + } + bool SelectAddrModePred(SDValue N, SDValue &Base, SDValue &OffImm); + template + bool SelectAddrModeIndexedSVE(SDNode *Root, SDValue N, SDValue &Base, + SDValue &OffImm); + int64_t generateSVEScaledOffset(SDValue Node, bool &BaseFound, bool &ECFound); template bool SelectAddrModeWRO(SDValue N, SDValue &Base, SDValue &Offset, @@ -135,6 +153,111 @@ return SelectAddrModeXRO(N, Width / 8, Base, Offset, SignExtend, DoShift); } + bool SelectDupZeroOrUndef(SDValue N) { + switch(N->getOpcode()) { + case ISD::UNDEF: + return true; + case AArch64ISD::DUP: + case ISD::SPLAT_VECTOR: { + auto Opnd0 = N->getOperand(0); + if (auto CN = dyn_cast(Opnd0)) + if (CN->isNullValue()) + return true; + if (auto CN = dyn_cast(Opnd0)) + if (CN->isZero()) + return true; + } + default: + break; + } + + return false; + } + + bool SelectDupZero(SDValue N) { + switch(N->getOpcode()) { + case AArch64ISD::DUP: + case ISD::SPLAT_VECTOR: { + auto Opnd0 = N->getOperand(0); + if (auto CN = dyn_cast(Opnd0)) + if (CN->isNullValue()) + return true; + if (auto CN = dyn_cast(Opnd0)) + if (CN->isZero()) + return true; + } + default: + break; + } + + return false; + } + + /// SVE Reg+Reg address mode + template + bool SelectSVERegRegAddrMode(SDValue N, SDValue &Base, SDValue &Offset) { + return SelectSVERegRegAddrMode(N, Scale, Base, Offset); + } + + template + bool SelectVectorLslImm(SDValue N, SDValue &UnscaledOp) { + return SelectVectorLslImm(N, Scale, UnscaledOp); + } + + template + bool SelectVectorUxtwLslImm(SDValue N, SDValue &Offsets) { + return SelectVectorUxtwLslImm(N, Scale, Offsets); + } + + template + bool SelectSVEUIntArithImm(SDValue N, SDValue &Imm, SDValue &Shift) { + return SelectSVEUIntArithImm(N, VT, Imm, Shift); + } + + template + bool SelectSVELogicalImm(SDValue N, SDValue &Imm) { + return SelectSVELogicalImm(N, VT, Imm); + } + + template + bool SelectSVEShiftImm64(SDValue N, SDValue &Imm) { + return SelectSVEShiftImm64(N, Low, High, Imm); + } + + // Returns a suitable CNT/INC/DEC/RDVL multiplier to calculate VSCALE*N. 
+ template + bool SelectRDVLImm(SDValue N, SDValue &Imm) { + if (!isa(N)) + return false; + + int64_t MulImm = cast(N)->getSExtValue(); + if ((MulImm % std::abs(Scale)) == 0) { + int64_t RDVLImm = MulImm / Scale; + if ((RDVLImm >= Low) && (RDVLImm <= High)) { + Imm = CurDAG->getTargetConstant(RDVLImm, SDLoc(N), MVT::i32); + return true; + } + } + + return false; + } + + template + bool SelectShiftImm(SDValue N, SDValue &Imm) { + if (!isa(N)) + return false; + + int64_t MulImm = (1 << cast(N)->getZExtValue()); + if ((MulImm % std::abs(Scale)) == 0) { + int64_t ShiftImm = MulImm / Scale; + if ((ShiftImm >= Low) && (ShiftImm <= High)) { + Imm = CurDAG->getTargetConstant(ShiftImm, SDLoc(N), MVT::i32); + return true; + } + } + + return false; + } /// Form sequences of consecutive 64/128-bit registers for use in NEON /// instructions making use of a vector-list (e.g. ldN, tbl). Vecs must have @@ -142,14 +265,20 @@ /// unchanged; otherwise a REG_SEQUENCE value is returned. SDValue createDTuple(ArrayRef Vecs); SDValue createQTuple(ArrayRef Vecs); + // Same thing for SVE instructions making use of lists of Z registers + SDValue createZTuple(ArrayRef Vecs); - /// Generic helper for the createDTuple/createQTuple + /// Generic helper for the createDTuple/createQTuple/createZTuple /// functions. Those should almost always be called instead. SDValue createTuple(ArrayRef Vecs, const unsigned RegClassIDs[], const unsigned SubRegs[]); void SelectTable(SDNode *N, unsigned NumVecs, unsigned Opc, bool isExt); + void SelectTableSVE2(SDNode *N, unsigned NumVecs, unsigned Opc); + + void SelectBitwiseSVE2(SDNode *N, unsigned Opc); + bool tryIndexedLoad(SDNode *N); void SelectLoad(SDNode *N, unsigned NumVecs, unsigned Opc, @@ -163,6 +292,11 @@ void SelectPostStore(SDNode *N, unsigned NumVecs, unsigned Opc); void SelectStoreLane(SDNode *N, unsigned NumVecs, unsigned Opc); void SelectPostStoreLane(SDNode *N, unsigned NumVecs, unsigned Opc); + void SelectPredicatedLoad(SDNode *N, unsigned NumVecs, + const unsigned Opc_rr, + const unsigned Opc_ri, unsigned SubRegIdx); + void SelectPredicatedStore(SDNode *N, unsigned NumVecs, + const unsigned Opc_rr, const unsigned Opc_ri); bool tryBitfieldExtractOp(SDNode *N); bool tryBitfieldExtractOpFromSExt(SDNode *N); @@ -191,7 +325,7 @@ bool SelectAddrModeXRO(SDValue N, unsigned Size, SDValue &Base, SDValue &Offset, SDValue &SignExtend, SDValue &DoShift); - bool isWorthFolding(SDValue V) const; + bool isWorthFolding(SDValue V, unsigned MaxUses = 1) const; bool SelectExtendedSHL(SDValue N, unsigned Size, bool WantExtend, SDValue &Offset, SDValue &SignExtend); @@ -201,9 +335,16 @@ } bool SelectCVTFixedPosOperand(SDValue N, SDValue &FixedPos, unsigned Width); + bool SelectSVERegRegAddrMode(SDValue N, unsigned Scale, SDValue &Base, SDValue &Offset); + bool Select8BitLslImm(SDValue N, SDValue &Imm, SDValue &Shift); + bool SelectVectorLslImm(SDValue N, unsigned Scale, SDValue &UnscaledOp); + bool SelectVectorUxtwLslImm(SDValue N, unsigned Scale, SDValue &Offsets); + bool SelectSVEUIntArithImm(SDValue N, MVT VT, SDValue &Imm, SDValue &Shift); + bool SelectSVELogicalImm(SDValue N, MVT VT, SDValue &Imm); + bool SelectSVEShiftImm64(SDValue N, uint64_t Low, uint64_t High, + SDValue &Imm); bool SelectCMP_SWAP(SDNode *N); - }; } // end anonymous namespace @@ -361,8 +502,8 @@ return true; } -/// Determine whether it is worth to fold V into an extended register. 
-bool AArch64DAGToDAGISel::isWorthFolding(SDValue V) const { +/// \brief Determine whether it is worth to fold V into an extended register. +bool AArch64DAGToDAGISel::isWorthFolding(SDValue V, unsigned MaxUses) const { // Trivial if we are optimizing for code size or if there is only // one use of the value. if (ForCodeSize || V.hasOneUse()) @@ -381,6 +522,18 @@ return true; } + // If it has more than one use, check they're all loads/stores + // from/to the same memory type (e.g. if you can fold for one + // addressing mode, you can fold for the others as well). + EVT VT; + for (auto *Use : V.getNode()->uses()) + if (auto *MemNode = dyn_cast(Use)) + if (MemNode->getMemoryVT() != VT && VT != EVT()) + return false; + + if (V.getNode()->use_size() <= MaxUses) + return true; + // It hurts otherwise, since the value will be reused. return false; } @@ -786,6 +939,123 @@ return true; } +bool AArch64DAGToDAGISel::SelectAddrModePred(SDValue N, SDValue &Base, + SDValue &OffImm) { + SDLoc dl(N); + + // If this is not a frame index, load directly from this address + if (N->getOpcode() != ISD::FrameIndex) { + Base = N; + OffImm = CurDAG->getTargetConstant(0, dl, MVT::i64); + return true; + } + + // Otherwise, match it for the frame address + const DataLayout &DL = CurDAG->getDataLayout(); + const TargetLowering *TLI = getTargetLowering(); + int FI = cast(N)->getIndex(); + Base = CurDAG->getTargetFrameIndex(FI, TLI->getPointerTy(DL)); + OffImm = CurDAG->getTargetConstant(0, dl, MVT::i64); + return true; +} + +/// SelectAddrModeIndexedSVE - Attempt selection of the addressing mode: +/// Base + OffImm * sizeof(MemVT) // for Min >= OffImm <= Max +/// where Root is the memory access using N for its address. +template +bool AArch64DAGToDAGISel::SelectAddrModeIndexedSVE(SDNode *Root, SDValue N, + SDValue &Base, + SDValue &OffImm) { + DEBUG_WITH_TYPE("isel_sve_addr_mode", + dbgs() << "*** SelectAddrModeIndexedSVE ***\n" + << "Machine function: " + << CurDAG->getMachineFunction().getName() << "\n" + << "Attempting better address generation for:\n"; + Root->dumpr();); + + if (N->getOpcode() == ISD::FrameIndex) { + SDLoc dl(N); + const DataLayout &DL = CurDAG->getDataLayout(); + const TargetLowering *TLI = getTargetLowering(); + + int FI = cast(N)->getIndex(); + Base = CurDAG->getTargetFrameIndex(FI, TLI->getPointerTy(DL)); + OffImm = CurDAG->getTargetConstant(0, dl, MVT::i64); + return true; + } + + // The offset's range is proportional to the extent of the load/store. 
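// Worked example for the VSCALE folding below; a sketch that assumes a full
// nxv4i32 access (minimum memory size 128 bits, i.e. 16 bytes per unit of
// vscale) and the usual -8..7 immediate range of LD1/ST1 (scalar plus imm):
//   address       = base + VSCALE(32)   // 32 * vscale bytes
//   MemByteWidth  = 128 / 8 = 16
//   Offset        = 32 / 16 = 2         // in range -> [xN, #2, mul vl]
static bool foldVScaleOffsetSketch(long long MulImm, long long MemByteWidth,
                                   long long Min, long long Max,
                                   long long &Offset) {
  if (MulImm % MemByteWidth != 0)
    return false;
  Offset = MulImm / MemByteWidth;
  return Offset >= Min && Offset <= Max;
}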
+ EVT MemVT; + bool Invalid = false; + if (isa(Root)) + MemVT = cast(Root)->getMemoryVT(); + else { + switch (Root->getOpcode()) { + case AArch64ISD::LDNF1: + case AArch64ISD::LDNF1S: + MemVT = cast(Root->getOperand(3))->getVT(); + break; + case AArch64ISD::LDNT1: + case AArch64ISD::STNT1: + MemVT = Root->getOperand(3)->getValueType(0); + break; + case ISD::INTRINSIC_VOID: { + unsigned IntNo = + cast(Root->getOperand(1))->getZExtValue(); + if (IntNo == Intrinsic::aarch64_sve_prf) { + // Type must be inferred from the width of the predicate + EVT PredVT = Root->getOperand(2)->getValueType(0); + if (PredVT == MVT::nxv16i1) + MemVT = MVT::nxv16i8; + else if (PredVT == MVT::nxv8i1) + MemVT = MVT::nxv8i16; + else if (PredVT == MVT::nxv4i1) + MemVT = MVT::nxv4i32; + else if (PredVT == MVT::nxv2i1) + MemVT = MVT::nxv2i64; + else + Invalid = true; + } + break; + } + default: + Invalid = true; + } + } + + if (Invalid) { + DEBUG_WITH_TYPE("isel_sve_addr_mode", + dbgs() << "Unexpected root node - "; Root->dumpr()); + return false; + } + + if (N.getOpcode() != ISD::ADD) + return false; + + SDValue VS = N.getOperand(1); + if (VS.getOpcode() != ISD::VSCALE) + return false; + + unsigned MemByteWidth = MemVT.getSizeInBits() / 8; + int64_t MulImm = cast(VS.getOperand(0))->getSExtValue(); + if ((MulImm % MemByteWidth) == 0) { + signed Offset = MulImm / MemByteWidth; + + if ((Offset >= Min) && (Offset <= Max)) { + Base = N.getOperand(0); + OffImm = CurDAG->getTargetConstant(Offset, SDLoc(N), MVT::i64); + + DEBUG_WITH_TYPE("isel_sve_addr_mode", + dbgs() << "Match found, rewritting as:\n" + << "BASE:\n"; Base.dumpr(); + dbgs() << "OFFSET:\n"; OffImm.dumpr()); + return true; + } + } + + return false; +} + /// SelectAddrModeUnscaled - Select a "register plus unscaled signed 9-bit /// immediate" address. This should only match when there is an offset that /// is not valid for a scaled immediate addressing mode. The "Size" argument @@ -1017,6 +1287,11 @@ return true; } + if (!Subtarget->hasFreeBasePlusRegAddrMode() && + Node->use_size() > 1) { + return false; + } + // Match any non-shifted, non-extend, non-immediate add expression. Base = LHS; Offset = RHS; @@ -1026,6 +1301,389 @@ return true; } +bool AArch64DAGToDAGISel::SelectSVERegRegAddrMode(SDValue N, unsigned Scale, + SDValue &Base, + SDValue &Offset) { + const unsigned Opcode = N.getOpcode(); + const SDLoc dl(N); + + DEBUG_WITH_TYPE("isel_sve_addr_mode", + dbgs() << "*** SelectSVERegRegAddrMode\n" + << "Machine function: " + << CurDAG->getMachineFunction().getName() << "\n" + << "Attempting better address generation for:\n"; + N.dumpr();); + + if (Opcode != ISD::ADD) + return false; + + // Process an ADD node + const SDValue LHS = N.getOperand(0); + const SDValue RHS = N.getOperand(1); + + if (auto C = dyn_cast(RHS)) { + int64_t ImmOff = (int64_t)C->getZExtValue(); + unsigned Size = 1<> Scale) > 0xffffu)) + return false; + + // Convert: + // MOV x0, Offset + // ADD x1, BaseReg, x0 + // LD1 z0.?, [x1, 0] + // To: + // MOV x0, Offset>>Scale + // LD1 z0.?, [BaseReg, x0 lsl Scale] + + Base = LHS; + Offset = CurDAG->getTargetConstant(ImmOff >> Scale, dl, MVT::i64); + SDValue Ops[] = { Offset }; + SDNode *MI = CurDAG->getMachineNode(AArch64::MOVi64imm, dl, MVT::i64, Ops); + Offset = SDValue(MI, 0); + return true; + } + + // We don't match addition to constants + if (isa(RHS) || isa(LHS)) + return false; + + // Check if this particular node is reused in any non-memory related + // operation. 
If yes, do not try to fold this node into the address + // computation, since the computation will be kept. + const SDNode *Node = N.getNode(); + for (SDNode *UI : Node->uses()) { + switch (UI->getOpcode()) { + default: + if (!isa(*UI)) + return false; + break; + + case AArch64ISD::LDFF1: + case AArch64ISD::LDFF1S: + case AArch64ISD::LDNT1: + break; + + case ISD::INTRINSIC_W_CHAIN: + case ISD::INTRINSIC_VOID: + switch (cast(UI->getOperand(1))->getZExtValue()) { + default: + return false; + + case Intrinsic::aarch64_sve_ld2: + case Intrinsic::aarch64_sve_ld3: + case Intrinsic::aarch64_sve_ld4: + case Intrinsic::aarch64_sve_ld2_legacy: + case Intrinsic::aarch64_sve_ld3_legacy: + case Intrinsic::aarch64_sve_ld4_legacy: + case Intrinsic::aarch64_sve_prf: + case Intrinsic::aarch64_sve_st2: + case Intrinsic::aarch64_sve_st3: + case Intrinsic::aarch64_sve_st4: + break; + } + } + } + + // Remember if it is worth folding N when it produces extended register. + // We pass (unsigned) '-1' max uses to indicate that there is no maximum, + // for SVE we prefer to fold all add/multiplies into the addressing mode + // to benefit our current cost model. + bool IsRegisterWorthFolding = isWorthFolding(N, UINT_MAX); + if (!IsRegisterWorthFolding) { + DEBUG_WITH_TYPE("isel_sve_addr_mode", dbgs() << "not worth folding\n"); + return false; + } + + // 8 bit data don't have the SHL node, so we treat it separately + if (Scale == 0) { + Base = LHS; + Offset = RHS; + return true; + } + + // Check if the RHS is a shift node with a constant. + if (RHS.getOpcode() == ISD::SHL) { + const SDValue SRHS = RHS.getOperand(1); + if (auto C = dyn_cast(SRHS)) { + const uint64_t Shift = C->getZExtValue(); + if (Shift == Scale) { + Base = LHS; + Offset = RHS.getOperand(0); + DEBUG_WITH_TYPE("isel_sve_addr_mode", + dbgs() << "Match found, rewritting as:\n" + << "BASE:\n"; Base.dumpr(); + dbgs() << "OFFSET:\n"; Offset.dumpr()); + return true; + } + + // Decompose a (LSL Y, #n) into Inner/Outer shifts as follows: + // (LSL (LSL Y #(n-Scale)) #Scale) + if ((Shift > Scale)) { + auto val = Shift - Scale; + unsigned Immr = (-val % 64); + unsigned Imms = 63-val; + assert((Imms + 1 == Immr) && "Invalid values for UBMX LSL alias."); + SDValue ImmrInner = CurDAG->getTargetConstant(Immr, dl, MVT::i64); + SDValue ImmsInner = CurDAG->getTargetConstant(Imms, dl, MVT::i64); + SDNode *Inner = CurDAG->getMachineNode(AArch64::UBFMXri, dl, + MVT::i64, RHS.getOperand(0), + ImmrInner, ImmsInner); + Base = LHS; + Offset = SDValue(Inner,0); + DEBUG_WITH_TYPE("isel_sve_addr_mode", + dbgs() << "Match found, rewritting as:\n" + << "BASE:\n"; Base.dumpr(); + dbgs() << "OFFSET:\n"; Offset.dumpr()); + return true; + } + } + } + + return false; +} + +bool AArch64DAGToDAGISel::Select8BitLslImm(SDValue N, SDValue &Base, + SDValue &Offset) { + auto C = dyn_cast(N); + if (!C) + return false; + + auto Ty = N->getValueType(0); + + int64_t Imm = C->getSExtValue(); + SDLoc DL(N); + + if ((Imm >= -128) && (Imm <= 127)) { + Base = CurDAG->getTargetConstant(Imm, DL, Ty); + Offset = CurDAG->getTargetConstant(0, DL, Ty); + return true; + } + + if (((Imm % 256) == 0) && (Imm >= -32768) && (Imm <= 32512)) { + Base = CurDAG->getTargetConstant(Imm/256, DL, Ty); + Offset = CurDAG->getTargetConstant(8, DL, Ty); + return true; + } + + return false; +} + +bool AArch64DAGToDAGISel::SelectSVEUIntArithImm(SDValue N, MVT VT, SDValue &Imm, + SDValue &Shift) { + if (auto CNode = dyn_cast(N)) { + const int64_t ImmVal = CNode->getSExtValue(); + SDLoc DL(N); + + switch (VT.SimpleTy) { + 
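// Worked example for the immediate splitting handled by the switch below
// (a sketch; the value 0x1200 is arbitrary): a 32-bit operand of 0x1200 does
// not fit the plain 8-bit form, but (0x1200 & 0xFF00) == 0x1200, so it is
// encoded as Imm = 0x12 with Shift = 8, matching the shifted (lsl #8) form
// of the SVE arithmetic immediate.
static_assert((0x1200 & 0xFF) != 0x1200, "does not fit the unshifted form");
static_assert((0x1200 & 0xFF00) == 0x1200 && (0x1200 >> 8) == 0x12,
              "fits the lsl #8 form as #0x12");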
case MVT::i8: + // Can always select i8s, no shift, mask the immediate value to + // deal with sign-extended value from lowering. + Shift = CurDAG->getTargetConstant(0, DL, MVT::i32); + Imm = CurDAG->getTargetConstant(ImmVal & 0xFF, DL, MVT::i32); + return true; + case MVT::i16: + // i16 values get sign-extended during lowering, so need to check for + // "negative" values when shifting. + if ((ImmVal & 0xFF) == ImmVal) { + Shift = CurDAG->getTargetConstant(0, DL, MVT::i32); + Imm = CurDAG->getTargetConstant(ImmVal, DL, MVT::i32); + return true; + } else if (((ImmVal & 0xFF) == 0) && + (ImmVal >= -32768) && + (ImmVal <= 32512)) { + Shift = CurDAG->getTargetConstant(8, DL, MVT::i32); + Imm = CurDAG->getTargetConstant((ImmVal/256) & 0xFF, DL, MVT::i32); + return true; + } + break; + case MVT::i32: + case MVT::i64: + // Range of immediate won't trigger signedness problems for 32/64b. + if ((ImmVal & 0xFF) == ImmVal) { + Shift = CurDAG->getTargetConstant(0, DL, MVT::i32); + Imm = CurDAG->getTargetConstant(ImmVal, DL, MVT::i32); + return true; + } else if ((ImmVal & 0xFF00) == ImmVal) { + Shift = CurDAG->getTargetConstant(8, DL, MVT::i32); + Imm = CurDAG->getTargetConstant(ImmVal >> 8, DL, MVT::i32); + return true; + } + default: + break; + } + } + + return false; +} + +bool AArch64DAGToDAGISel::SelectSVELogicalImm(SDValue N, MVT VT, SDValue &Imm) { + if (auto *CNode = dyn_cast(N)) { + uint64_t ImmVal = CNode->getZExtValue(); + SDLoc DL(N); + + // If smaller than i64, replicate until it is. + // Fall through to avoid duplicated code. + switch(VT.SimpleTy) { + default: + break; + case MVT::i8: + ImmVal &= 0xFF; + ImmVal |= (ImmVal << 8); + LLVM_FALLTHROUGH; + case MVT::i16: + ImmVal &= 0xFFFF; + ImmVal |= (ImmVal << 16); + LLVM_FALLTHROUGH; + case MVT::i32: + ImmVal &= 0xFFFFFFFF; + ImmVal |= (ImmVal << 32); + break; + } + + uint64_t encoding; + // Check and see if we now have a valid logical immediate + if (AArch64_AM::processLogicalImmediate(ImmVal, 64, encoding)) { + Imm = CurDAG->getTargetConstant(encoding, DL, MVT::i64); + return true; + } + } + + return false; +} + +// This method is only needed to "cast" i64s into i32s when the value +// is a valid shift which has been splatted into a vector with i64 elements. +// Every other type is fine in tablegen. +bool AArch64DAGToDAGISel::SelectSVEShiftImm64(SDValue N, uint64_t Low, + uint64_t High, SDValue &Imm) { + if (auto *CN = dyn_cast(N)) { + uint64_t ImmVal = CN->getZExtValue(); + SDLoc DL(N); + + if (ImmVal >= Low && ImmVal <= High) { + Imm = CurDAG->getTargetConstant(ImmVal, DL, MVT::i32); + return true; + } + } + + return false; +} + +// Mainly used to match against scaled offsets for gather/scatter. +bool AArch64DAGToDAGISel::SelectVectorLslImm(SDValue N, unsigned Scale, + SDValue &UnscaledOp) { + if (N.getOpcode() == ISD::SERIES_VECTOR && + N.getOperand(0).getOpcode() == ISD::SHL) { + SDValue Start = N.getOperand(0); + ConstantSDNode *Step = dyn_cast(N.getOperand(1)); + ConstantSDNode *StartShift = dyn_cast(Start.getOperand(1)); + if (StartShift && Step && + StartShift->getSExtValue() == Scale && + Step->getSExtValue() == (1 << Scale)) { + SDValue StartBase = Start.getOperand(0); + SDLoc DL(N); + EVT VT = N->getValueType(0); + EVT ET = VT.getVectorElementType(); + + bool StartBaseIsConstant = isa(StartBase); + unsigned Opc = 0; + switch (ET.getSizeInBits()) { + case 8: + Opc = StartBaseIsConstant ? AArch64::INDEX_II_B : AArch64::INDEX_RI_B; + break; + case 16: + Opc = StartBaseIsConstant ? 
AArch64::INDEX_II_H : AArch64::INDEX_RI_H; + break; + case 32: + Opc = StartBaseIsConstant ? AArch64::INDEX_II_S : AArch64::INDEX_RI_S; + break; + case 64: + Opc = StartBaseIsConstant ? AArch64::INDEX_II_D : AArch64::INDEX_RI_D; + break; + default: + llvm_unreachable("Unexpected element size"); + } + + SmallVector Ops; + Ops.push_back(StartBase); + Ops.push_back(CurDAG->getTargetConstant(1, DL, ET)); + + SDNode *NewN = CurDAG->getMachineNode(Opc, DL, VT, Ops); + UnscaledOp = SDValue(NewN, 0); + return true; + } + } + + if (N.getOpcode() == ISD::MUL) { + SDValue Mod = N.getOperand(1); + if (Mod.getOpcode() == ISD::SPLAT_VECTOR || + Mod.getOpcode() == AArch64ISD::DUP) { + auto *SplatVal = dyn_cast(Mod.getOperand(0)); + + if (SplatVal && (SplatVal->getSExtValue() == (1 << Scale))) { + UnscaledOp = N.getOperand(0); + return true; + } + } + } + + if (N.getOpcode() == ISD::SHL) { + SDValue Mod = N.getOperand(1); + if (Mod.getOpcode() == ISD::SPLAT_VECTOR || + Mod.getOpcode() == AArch64ISD::DUP) { + auto *SplatVal = dyn_cast(Mod.getOperand(0)); + + if (SplatVal && (SplatVal->getSExtValue() == Scale)) { + UnscaledOp = N.getOperand(0); + return true; + } + } + } + + return false; +} + +// Mainly used to match against scaled 32bit offsets for gather/scatter. +bool AArch64DAGToDAGISel::SelectVectorUxtwLslImm(SDValue N, unsigned Scale, + SDValue &Offsets) { + if (N.getOpcode() == ISD::MUL) { + SDValue Mod = N.getOperand(1); + if (Mod.getOpcode() != ISD::SPLAT_VECTOR) + return false; + + // we only care about constant splats + ConstantSDNode *SplatVal = dyn_cast(Mod.getOperand(0)); + if (!SplatVal) + return false; + + if (SplatVal->getSExtValue() != (1 << Scale)) + return false; + + if (N.getOperand(0).getOpcode() != ISD::AND) + return false; + + SDValue Mask = N.getOperand(0).getOperand(1); + if (Mask.getOpcode() != ISD::SPLAT_VECTOR) + return false; + + // we only care about constant splats + SplatVal = dyn_cast(Mask.getOperand(0)); + if (!SplatVal) + return false; + + if (SplatVal->getZExtValue() == 0xffffffff) { + Offsets = N.getOperand(0).getOperand(0); + return true; + } + } + + return false; +} + SDValue AArch64DAGToDAGISel::createDTuple(ArrayRef Regs) { static const unsigned RegClassIDs[] = { AArch64::DDRegClassID, AArch64::DDDRegClassID, AArch64::DDDDRegClassID}; @@ -1044,6 +1702,15 @@ return createTuple(Regs, RegClassIDs, SubRegs); } +SDValue AArch64DAGToDAGISel::createZTuple(ArrayRef Regs) { + static const unsigned RegClassIDs[] = { + AArch64::ZPR2RegClassID, AArch64::ZPR3RegClassID, AArch64::ZPR4RegClassID}; + static const unsigned SubRegs[] = {AArch64::zsub0, AArch64::zsub1, + AArch64::zsub2, AArch64::zsub3}; + + return createTuple(Regs, RegClassIDs, SubRegs); +} + SDValue AArch64DAGToDAGISel::createTuple(ArrayRef Regs, const unsigned RegClassIDs[], const unsigned SubRegs[]) { @@ -1094,6 +1761,28 @@ ReplaceNode(N, CurDAG->getMachineNode(Opc, dl, VT, Ops)); } +void AArch64DAGToDAGISel::SelectTableSVE2(SDNode *N, unsigned NumVecs, + unsigned Opc) { + SDLoc dl(N); + EVT VT = N->getValueType(0); + + // Form a REG_SEQUENCE to force register allocation. 
+ SmallVector Regs(N->op_begin() + 1, N->op_begin() + 1 + NumVecs); + SDValue RegSeq = createZTuple(Regs); + + SmallVector Ops; + Ops.push_back(RegSeq); + Ops.push_back(N->getOperand(NumVecs + 1)); + ReplaceNode(N, CurDAG->getMachineNode(Opc, dl, VT, Ops)); +} + +void AArch64DAGToDAGISel::SelectBitwiseSVE2(SDNode *N, unsigned Opc) { + SDLoc dl(N); + EVT VT = N->getValueType(0); + SmallVector Ops(N->op_begin() + 1, N->op_begin() + 1 + 3); + ReplaceNode(N, CurDAG->getMachineNode(Opc, dl, VT, Ops)); +} + bool AArch64DAGToDAGISel::tryIndexedLoad(SDNode *N) { LoadSDNode *LD = cast(N); if (LD->isUnindexed()) @@ -1247,6 +1936,57 @@ CurDAG->RemoveDeadNode(N); } +void AArch64DAGToDAGISel::SelectPredicatedLoad(SDNode *N, unsigned NumVecs, + const unsigned Opc_rr, + const unsigned Opc_ri, + unsigned SubRegIdx) { + SDLoc dl(N); + EVT VT = N->getValueType(0); + SDValue Chain = N->getOperand(0); + + EVT ElTy = VT.getVectorElementType(); + unsigned Bits = ElTy.getSimpleVT().getSizeInBits(); + + // Assume we use reg+imm with zero shift. + SDValue Base = N->getOperand(3); + SDValue Offset = CurDAG->getTargetConstant(0, dl, MVT::i64); + + // Detect a possible reg+reg addressing mode. + unsigned Scale = APInt(32, Bits).exactLogBase2() - 3; + const bool IsRegReg = SelectSVERegRegAddrMode(Base, Scale, Base, Offset); + + // Select the instruction. + const unsigned Opc = (IsRegReg) ? Opc_rr : Opc_ri; + + // TODO: if (!IsRegReg), optimize reg+imm addressing model. The method is + // already there, all it need is to add a default additional template + // parameter that checks the immediate value to be a multiple of the one + // allowed by the ISA. The correct AArch64 opcode has been selected already + // by teh previous switch case. + // + // if (!IsRegReg) + // template + // bool AArch64DAGToDAGISel::SelectAddrModeIndexedSVE(SDValue N, + // SDValue &Base, + // SDValue &OffImm) { + + SDValue Ops[] = {N->getOperand(2), // Predicate + Base, // Memory operand + Offset, + Chain}; + + const EVT ResTys[] = {MVT::Untyped, MVT::Other}; + + SDNode *Ld = CurDAG->getMachineNode(Opc, dl, ResTys, Ops); + SDValue SuperReg = SDValue(Ld, 0); + for (unsigned i = 0; i < NumVecs; ++i) + ReplaceUses(SDValue(N, i), + CurDAG->getTargetExtractSubreg(SubRegIdx + i, dl, VT, SuperReg)); + + ReplaceUses(SDValue(N, NumVecs), SDValue(Ld, 1)); + CurDAG->RemoveDeadNode(N); +} + void AArch64DAGToDAGISel::SelectStore(SDNode *N, unsigned NumVecs, unsigned Opc) { SDLoc dl(N); @@ -1268,6 +2008,39 @@ ReplaceNode(N, St); } +void AArch64DAGToDAGISel::SelectPredicatedStore(SDNode *N, unsigned NumVecs, + const unsigned Opc_rr, + const unsigned Opc_ri) { + SDLoc dl(N); + EVT VT = N->getOperand(2).getValueType(); + EVT ElTy = VT.getVectorElementType(); + unsigned Bits = ElTy.getSimpleVT().getSizeInBits(); + + // Form a REG_SEQUENCE to force register allocation. + SmallVector Regs(N->op_begin() + 2, N->op_begin() + 2 + NumVecs); + SDValue RegSeq = createZTuple(Regs); + + // Assume we use reg+imm with zero shift. + SDValue Base = N->getOperand(NumVecs + 3); + SDValue Offset = CurDAG->getTargetConstant(0, dl, MVT::i64); + + // Detect a possible reg+reg addressing mode. + unsigned Scale = APInt(32, Bits).exactLogBase2() - 3; + const bool IsRegReg = SelectSVERegRegAddrMode(Base, Scale, Base, Offset); + + // Select the instruction. + const unsigned Opc = (IsRegReg) ? 
Opc_rr : Opc_ri; + + SDValue Ops[] = {RegSeq, N->getOperand(NumVecs + 2), // predicate + Base, // address + Offset, // offset + N->getOperand(0)}; // chain + SDNode *St = CurDAG->getMachineNode(Opc, dl, N->getValueType(0), Ops); + + + ReplaceNode(N, St); +} + void AArch64DAGToDAGISel::SelectPostStore(SDNode *N, unsigned NumVecs, unsigned Opc) { SDLoc dl(N); @@ -1300,9 +2073,9 @@ SDValue operator()(SDValue V64Reg) { EVT VT = V64Reg.getValueType(); - unsigned NarrowSize = VT.getVectorNumElements(); + auto NarrowSize = VT.getVectorElementCount(); MVT EltTy = VT.getVectorElementType().getSimpleVT(); - MVT WideTy = MVT::getVectorVT(EltTy, 2 * NarrowSize); + MVT WideTy = MVT::getVectorVT(EltTy, NarrowSize * 2); SDLoc DL(V64Reg); SDValue Undef = @@ -1316,7 +2089,7 @@ /// equivalent value in the V64 register class. static SDValue NarrowVector(SDValue V128Reg, SelectionDAG &DAG) { EVT VT = V128Reg.getValueType(); - unsigned WideSize = VT.getVectorNumElements(); + auto WideSize = VT.getVectorElementCount(); MVT EltTy = VT.getVectorElementType().getSimpleVT(); MVT NarrowTy = MVT::getVectorVT(EltTy, WideSize / 2); @@ -2836,10 +3609,12 @@ // the rest of the compiler, especially the register allocator and copyi // propagation, to reason about, so is preferred when it's possible to // use it. - ConstantSDNode *LaneNode = cast(Node->getOperand(1)); - // Bail and use the default Select() for non-zero lanes. - if (LaneNode->getZExtValue() != 0) + + // Bail and use the default Select() if the index is not constant zero + auto *LaneNode = dyn_cast(Node->getOperand(1)); + if (!LaneNode || (LaneNode->getZExtValue() != 0)) break; + // If the element type is not the same as the result type, likewise // bail and use the default Select(), as there's more to do than just // a cross-class COPY. 
This catches extracts of i8 and i16 elements @@ -3059,6 +3834,66 @@ return; } break; + case Intrinsic::aarch64_sve_ld2: + case Intrinsic::aarch64_sve_ld2_legacy: + if (VT == MVT::nxv16i8) { + SelectPredicatedLoad(Node, 2, AArch64::LD2B, AArch64::LD2B_IMM, + AArch64::zsub0); + return; + } else if (VT == MVT::nxv8i16 || VT == MVT::nxv8f16) { + SelectPredicatedLoad(Node, 2, AArch64::LD2H, AArch64::LD2H_IMM, + AArch64::zsub0); + return; + } else if (VT == MVT::nxv4i32 || VT == MVT::nxv4f32) { + SelectPredicatedLoad(Node, 2, AArch64::LD2W, AArch64::LD2W_IMM, + AArch64::zsub0); + return; + } else if (VT == MVT::nxv2i64 || VT == MVT::nxv2f64) { + SelectPredicatedLoad(Node, 2, AArch64::LD2D, AArch64::LD2D_IMM, + AArch64::zsub0); + return; + } + break; + case Intrinsic::aarch64_sve_ld3: + case Intrinsic::aarch64_sve_ld3_legacy: + if (VT == MVT::nxv16i8) { + SelectPredicatedLoad(Node, 3, AArch64::LD3B, AArch64::LD3B_IMM, + AArch64::zsub0); + return; + } else if (VT == MVT::nxv8i16 || VT == MVT::nxv8f16) { + SelectPredicatedLoad(Node, 3, AArch64::LD3H, AArch64::LD3H_IMM, + AArch64::zsub0); + return; + } else if (VT == MVT::nxv4i32 || VT == MVT::nxv4f32) { + SelectPredicatedLoad(Node, 3, AArch64::LD3W, AArch64::LD3W_IMM, + AArch64::zsub0); + return; + } else if (VT == MVT::nxv2i64 || VT == MVT::nxv2f64) { + SelectPredicatedLoad(Node, 3, AArch64::LD3D, AArch64::LD3D_IMM, + AArch64::zsub0); + return; + } + break; + case Intrinsic::aarch64_sve_ld4: + case Intrinsic::aarch64_sve_ld4_legacy: + if (VT == MVT::nxv16i8) { + SelectPredicatedLoad(Node, 4, AArch64::LD4B, AArch64::LD4B_IMM, + AArch64::zsub0); + return; + } else if (VT == MVT::nxv8i16 || VT == MVT::nxv8f16) { + SelectPredicatedLoad(Node, 4, AArch64::LD4H, AArch64::LD4H_IMM, + AArch64::zsub0); + return; + } else if (VT == MVT::nxv4i32 || VT == MVT::nxv4f32) { + SelectPredicatedLoad(Node, 4, AArch64::LD4W, AArch64::LD4W_IMM, + AArch64::zsub0); + return; + } else if (VT == MVT::nxv2i64 || VT == MVT::nxv2f64) { + SelectPredicatedLoad(Node, 4, AArch64::LD4D, AArch64::LD4D_IMM, + AArch64::zsub0); + return; + } + break; case Intrinsic::aarch64_neon_ld3: if (VT == MVT::v8i8) { SelectLoad(Node, 3, AArch64::LD3Threev8b, AArch64::dsub0); @@ -3290,6 +4125,65 @@ if (tryMULLV64LaneV128(IntNo, Node)) return; break; + case Intrinsic::aarch64_sve_bcax: + if (VT == MVT::nxv16i8 || VT == MVT::nxv8i16 || + VT == MVT::nxv4i32 || VT == MVT::nxv2i64) { + SelectBitwiseSVE2(Node, AArch64::BCAX_ZZZZ_D); + return; + } + break; + case Intrinsic::aarch64_sve_bsl: + if (VT == MVT::nxv16i8 || VT == MVT::nxv8i16 || + VT == MVT::nxv4i32 || VT == MVT::nxv2i64) { + SelectBitwiseSVE2(Node, AArch64::BSL_ZZZZ_D); + return; + } + break; + case Intrinsic::aarch64_sve_bsl1n: + if (VT == MVT::nxv16i8 || VT == MVT::nxv8i16 || + VT == MVT::nxv4i32 || VT == MVT::nxv2i64) { + SelectBitwiseSVE2(Node, AArch64::BSL1N_ZZZZ_D); + return; + } + break; + case Intrinsic::aarch64_sve_bsl2n: + if (VT == MVT::nxv16i8 || VT == MVT::nxv8i16 || + VT == MVT::nxv4i32 || VT == MVT::nxv2i64) { + SelectBitwiseSVE2(Node, AArch64::BSL2N_ZZZZ_D); + return; + } + break; + case Intrinsic::aarch64_sve_nbsl: + if (VT == MVT::nxv16i8 || VT == MVT::nxv8i16 || + VT == MVT::nxv4i32 || VT == MVT::nxv2i64) { + SelectBitwiseSVE2(Node, AArch64::NBSL_ZZZZ_D); + return; + } + break; + case Intrinsic::aarch64_sve_eor3: + if (VT == MVT::nxv16i8 || VT == MVT::nxv8i16 || + VT == MVT::nxv4i32 || VT == MVT::nxv2i64) { + SelectBitwiseSVE2(Node, AArch64::EOR3_ZZZZ_D); + return; + } + break; + case Intrinsic::aarch64_sve_tbl2: + if (VT 
== MVT::nxv16i8) { + SelectTableSVE2(Node, 2, AArch64::TBL_ZZZZ_B); + return; + } + if (VT == MVT::nxv8i16 || VT == MVT::nxv8f16) { + SelectTableSVE2(Node, 2, AArch64::TBL_ZZZZ_H); + return; + } + if (VT == MVT::nxv4i32 || VT == MVT::nxv4f32) { + SelectTableSVE2(Node, 2, AArch64::TBL_ZZZZ_S); + return; + } + if (VT == MVT::nxv2i64 || VT == MVT::nxv2f64) { + SelectTableSVE2(Node, 2, AArch64::TBL_ZZZZ_D); + return; + } } break; } @@ -3525,6 +4419,54 @@ } break; } + case Intrinsic::aarch64_sve_st2: { + if (VT == MVT::nxv16i8) { + SelectPredicatedStore(Node, 2, AArch64::ST2B, AArch64::ST2B_IMM); + return; + } else if (VT == MVT::nxv8i16 || VT == MVT::nxv8f16) { + SelectPredicatedStore(Node, 2, AArch64::ST2H, AArch64::ST2H_IMM); + return; + } else if (VT == MVT::nxv4i32 || VT == MVT::nxv4f32) { + SelectPredicatedStore(Node, 2, AArch64::ST2W, AArch64::ST2W_IMM); + return; + } else if (VT == MVT::nxv2i64 || VT == MVT::nxv2f64) { + SelectPredicatedStore(Node, 2, AArch64::ST2D, AArch64::ST2D_IMM); + return; + } + break; + } + case Intrinsic::aarch64_sve_st3: { + if (VT == MVT::nxv16i8) { + SelectPredicatedStore(Node, 3, AArch64::ST3B, AArch64::ST3B_IMM); + return; + } else if (VT == MVT::nxv8i16 || VT == MVT::nxv8f16) { + SelectPredicatedStore(Node, 3, AArch64::ST3H, AArch64::ST3H_IMM); + return; + } else if (VT == MVT::nxv4i32 || VT == MVT::nxv4f32) { + SelectPredicatedStore(Node, 3, AArch64::ST3W, AArch64::ST3W_IMM); + return; + } else if (VT == MVT::nxv2i64 || VT == MVT::nxv2f64) { + SelectPredicatedStore(Node, 3, AArch64::ST3D, AArch64::ST3D_IMM); + return; + } + break; + } + case Intrinsic::aarch64_sve_st4: { + if (VT == MVT::nxv16i8) { + SelectPredicatedStore(Node, 4, AArch64::ST4B, AArch64::ST4B_IMM); + return; + } else if (VT == MVT::nxv8i16 || VT == MVT::nxv8f16) { + SelectPredicatedStore(Node, 4, AArch64::ST4H, AArch64::ST4H_IMM); + return; + } else if (VT == MVT::nxv4i32 || VT == MVT::nxv4f32) { + SelectPredicatedStore(Node, 4, AArch64::ST4W, AArch64::ST4W_IMM); + return; + } else if (VT == MVT::nxv2i64 || VT == MVT::nxv2f64) { + SelectPredicatedStore(Node, 4, AArch64::ST4D, AArch64::ST4D_IMM); + return; + } + break; + } } break; } Index: lib/Target/AArch64/AArch64ISelLowering.h =================================================================== --- lib/Target/AArch64/AArch64ISelLowering.h +++ lib/Target/AArch64/AArch64ISelLowering.h @@ -21,6 +21,7 @@ #include "llvm/CodeGen/TargetLowering.h" #include "llvm/IR/CallingConv.h" #include "llvm/IR/Instruction.h" +#include "llvm/IR/IntrinsicInst.h" namespace llvm { @@ -104,6 +105,7 @@ UZP2, TRN1, TRN2, + REV, REV16, REV32, REV64, @@ -191,6 +193,92 @@ FRECPE, FRECPS, FRSQRTE, FRSQRTS, + SUNPKHI, + SUNPKLO, + UUNPKHI, + UUNPKLO, + + TBL, + + // SVE specific operations. + ANDV_PRED, + BRKA, + CLASTA_N, + CLASTB_N, + DUP_PRED, + EORV_PRED, + FADDA_PRED, + FADDV_PRED, + FMAXV_PRED, + FMAXNMV_PRED, + FMINV_PRED, + FMINNMV_PRED, + INSR, + LASTA, + LASTB, + LD1RQ, + LDNT1, + ORV_PRED, + PTEST, + PTRUE, + + // Unsigned first faulting gather loads. + LDFF1, + LDNF1, + GLDFF1, + GLDFF1_SCALED, + GLDFF1_SXTW, + GLDFF1_SXTW_SCALED, + GLDFF1_UXTW, + GLDFF1_UXTW_SCALED, + + // Signed first faulting gather loads. + LDFF1S, + LDNF1S, + GLDFF1S, + GLDFF1S_SCALED, + GLDFF1S_SXTW, + GLDFF1S_SXTW_SCALED, + GLDFF1S_UXTW, + GLDFF1S_UXTW_SCALED, + + // Unsigned non temporal gather loads. + GLDNT1, + GLDNT1_UXTW, + + // Signed non temporal gather loads. 
+ GLDNT1S, + GLDNT1S_UXTW, + + // SVE gather prefetches + GPRF_S_IMM, + GPRF_D_IMM, + GPRF_D_SCALED, + GPRF_S_SXTW_SCALED, + GPRF_S_UXTW_SCALED, + GPRF_D_SXTW_SCALED, + GPRF_D_UXTW_SCALED, + + RDFFR, + RDFFR_PRED, + SETFFR, + WRFFR, + + REINTERPRET_CAST, + SADDV_PRED, + SMAXV_PRED, + SMINV_PRED, + STNT1, + SSTNT1, + SSTNT1_UXTW, + UADDV_PRED, + UMAXV_PRED, + UMINV_PRED, + FMIN_PRED, + FMINNM_PRED, + FMAX_PRED, + FMAXNM_PRED, + // NEON Load/Store with post-increment base updates LD2post = ISD::FIRST_TARGET_MEMORY_OPCODE, LD3post, @@ -268,6 +356,8 @@ /// Provide custom lowering hooks for some operations. SDValue LowerOperation(SDValue Op, SelectionDAG &DAG) const override; + void LowerOperationWrapper(SDNode *N, SmallVectorImpl &Results, + SelectionDAG &DAG) const override; const char *getTargetNodeName(unsigned Opcode) const override; @@ -323,7 +413,7 @@ bool hasPairedLoad(EVT LoadedType, unsigned &RequiredAligment) const override; - unsigned getMaxSupportedInterleaveFactor() const override { return 4; } + unsigned getMaxSupportedInterleaveFactor() const override { return 6; } bool lowerInterleavedLoad(LoadInst *LI, ArrayRef Shuffles, @@ -331,6 +421,16 @@ unsigned Factor) const override; bool lowerInterleavedStore(StoreInst *SI, ShuffleVectorInst *SVI, unsigned Factor) const override; + bool lowerGathersToInterleavedLoad(ArrayRef Gathers, + IntrinsicInst *FirstGather, + int OffsetFirstGather, + unsigned Factor, + TargetTransformInfo *TTI) const override; + bool lowerScattersToInterleavedStore(ArrayRef ValuesToStore, + Value *FirstScatterAddress, + IntrinsicInst *ReplaceNode, + unsigned Factor, + TargetTransformInfo *TTI) const override; bool isLegalAddImmediate(int64_t) const override; bool isLegalICmpImmediate(int64_t) const override; @@ -617,8 +717,15 @@ SDValue LowerEXTRACT_VECTOR_ELT(SDValue Op, SelectionDAG &DAG) const; SDValue LowerSCALAR_TO_VECTOR(SDValue Op, SelectionDAG &DAG) const; SDValue LowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG) const; + SDValue LowerCONCAT_VECTOR(SDValue Op, SelectionDAG &DAG) const; SDValue LowerVECTOR_SHUFFLE(SDValue Op, SelectionDAG &DAG) const; + SDValue LowerVECTOR_SHUFFLE_VAR(SDValue Op, SelectionDAG &DAG, + unsigned Factor, EVT NewVT) const; + SDValue LowerVECREDUCE_SVE(SDValue Op, SelectionDAG &DAG) const; + SDValue LowerSERIES_VECTOR(SDValue Op, SelectionDAG &DAG) const; + SDValue LowerSPLAT_VECTOR(SDValue Op, SelectionDAG &DAG) const; SDValue LowerEXTRACT_SUBVECTOR(SDValue Op, SelectionDAG &DAG) const; + SDValue LowerINSERT_SUBVECTOR(SDValue Op, SelectionDAG &DAG) const; SDValue LowerVectorSRA_SRL_SHL(SDValue Op, SelectionDAG &DAG) const; SDValue LowerShiftLeftParts(SDValue Op, SelectionDAG &DAG) const; SDValue LowerShiftRightParts(SDValue Op, SelectionDAG &DAG) const; @@ -635,6 +742,13 @@ SDValue LowerVectorOR(SDValue Op, SelectionDAG &DAG) const; SDValue LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const; SDValue LowerFSINCOS(SDValue Op, SelectionDAG &DAG) const; + SDValue LowerINTRINSIC_W_CHAIN(SDValue Op, SelectionDAG &DAG) const; + SDValue LowerTRUNCATE(SDValue Op, SelectionDAG &DAG) const; + SDValue LowerLASTX(SDValue Op, SelectionDAG &DAG) const; + SDValue LowerDUPQLane(SDValue Op, SelectionDAG &DAG) const; + SDValue LowerBITCAST(SDValue Op, SelectionDAG &DAG) const; + SDValue LowerVectorBITCAST(SDValue Op, SelectionDAG &DAG) const; + SDValue LowerVSCALE(SDValue Op, SelectionDAG &DAG) const; SDValue LowerVECREDUCE(SDValue Op, SelectionDAG &DAG) const; SDValue LowerATOMIC_LOAD_SUB(SDValue Op, SelectionDAG &DAG) const; SDValue 
LowerATOMIC_LOAD_AND(SDValue Op, SelectionDAG &DAG) const; @@ -681,6 +795,7 @@ return TargetLowering::getInlineAsmMemConstraint(ConstraintCode); } + bool isVectorLoadExtDesirable(SDValue ExtVal) const override; bool isUsedByReturnOnly(SDNode *N, SDValue &Chain) const override; bool mayBeEmittedAsTailCall(const CallInst *CI) const override; bool getIndexedAddressParts(SDNode *Op, SDValue &Base, SDValue &Offset, @@ -693,6 +808,33 @@ SDValue &Offset, ISD::MemIndexedMode &AM, SelectionDAG &DAG) const override; + void ReplaceExtensionResults(SDNode *N, SmallVectorImpl &Results, + SelectionDAG &DAG, unsigned HiOpcode, + unsigned LoOpcode) const; + void ReplaceExtractSubVectorResults(SDNode *N, + SmallVectorImpl &Results, + SelectionDAG &DAG) const; + void ReplaceInsertSubVectorResults(SDNode *N, + SmallVectorImpl &Results, + SelectionDAG &DAG) const; + void ReplaceInsertVectorElementResults(SDNode *N, + SmallVectorImpl &Results, + SelectionDAG &DAG) const; + void ReplaceFP_EXTENDResults(SDNode *N, SmallVectorImpl &Results, + SelectionDAG &DAG) const; + void ReplaceMaskedSpecLoadResults(SDNode *, SmallVectorImpl &, + SelectionDAG &) const; + void ReplaceVectorShuffleVarResults(SDNode *, SmallVectorImpl &, + SelectionDAG &) const; + void ReplaceSplatVectorResults(SDNode *, SmallVectorImpl &, + SelectionDAG &) const; + + void ReplaceMergeVecCpyResults(SDNode *N, SmallVectorImpl &Results, + SelectionDAG &DAG) const; + void ReplaceBITCASTResults(SDNode *, SmallVectorImpl &, + SelectionDAG &) const; + void ReplaceVectorBITCASTResults(SDNode *, SmallVectorImpl &, + SelectionDAG &) const; void ReplaceNodeResults(SDNode *N, SmallVectorImpl &Results, SelectionDAG &DAG) const override; Index: lib/Target/AArch64/AArch64ISelLowering.cpp =================================================================== --- lib/Target/AArch64/AArch64ISelLowering.cpp +++ lib/Target/AArch64/AArch64ISelLowering.cpp @@ -29,6 +29,7 @@ #include "llvm/ADT/StringSwitch.h" #include "llvm/ADT/Triple.h" #include "llvm/ADT/Twine.h" +#include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/Analysis/VectorUtils.h" #include "llvm/CodeGen/CallingConvLower.h" #include "llvm/CodeGen/MachineBasicBlock.h" @@ -87,6 +88,7 @@ #include using namespace llvm; +using namespace AArch64SVEPredPattern; #define DEBUG_TYPE "aarch64-lower" @@ -108,6 +110,13 @@ cl::init(false)); static cl::opt +EnableSGToContiguousXForm("aarch64-sve-sg-to-contiguous-xform", cl::Hidden, + cl::desc("Allows AArch64 SVE to transform gathers " + "or scatters with a constant stride of two " + "into contiguous loads with shuffles and " + "masking"), + cl::init(true)); +static cl::opt EnableOptimizeLogicalImm("aarch64-enable-logical-imm", cl::Hidden, cl::desc("Enable AArch64 logical imm instruction " "optimization"), @@ -116,6 +125,22 @@ /// Value type used for condition codes. static const MVT MVT_CC = MVT::i32; +static unsigned getIntrinsicID(const SDNode *N) { + unsigned Opcode = N->getOpcode(); + switch (Opcode) { + default: + return Intrinsic::not_intrinsic; + case ISD::INTRINSIC_WO_CHAIN: { + unsigned IID = cast(N->getOperand(0))->getZExtValue(); + if (IID < Intrinsic::num_intrinsics) + return IID; + return Intrinsic::not_intrinsic; + } + } +} + +#include "SVEISelLowering.inc.h" + AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, const AArch64Subtarget &STI) : TargetLowering(TM), Subtarget(&STI) { @@ -158,6 +183,26 @@ addQRTypeForNEON(MVT::v8f16); } + // SVE Vector Registers. 
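// The mapping established by the addRegisterClass calls below, in miniature
// (illustrative only, not an LLVM API): scalable predicate types (i1
// elements) are assigned to the P registers, every other scalable vector
// type to the Z registers.
static const char *sveRegClassSketch(bool IsPredicateElement) {
  return IsPredicateElement ? "PPR" : "ZPR";
}
// e.g. nxv4i32 and nxv2f64 -> "ZPR", nxv16i1 and nxv2i1 -> "PPR".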
+ if (Subtarget->hasSVE()) { + addRegisterClass(MVT::nxv16i8, &AArch64::ZPRRegClass); + addRegisterClass(MVT::nxv8i16, &AArch64::ZPRRegClass); + addRegisterClass(MVT::nxv4i32, &AArch64::ZPRRegClass); + addRegisterClass(MVT::nxv2i64, &AArch64::ZPRRegClass); + + addRegisterClass(MVT::nxv16i1, &AArch64::PPRRegClass); + addRegisterClass(MVT::nxv8i1, &AArch64::PPRRegClass); + addRegisterClass(MVT::nxv4i1, &AArch64::PPRRegClass); + addRegisterClass(MVT::nxv2i1, &AArch64::PPRRegClass); + + addRegisterClass(MVT::nxv2f16, &AArch64::ZPRRegClass); + addRegisterClass(MVT::nxv4f16, &AArch64::ZPRRegClass); + addRegisterClass(MVT::nxv8f16, &AArch64::ZPRRegClass); + addRegisterClass(MVT::nxv2f32, &AArch64::ZPRRegClass); + addRegisterClass(MVT::nxv4f32, &AArch64::ZPRRegClass); + addRegisterClass(MVT::nxv2f64, &AArch64::ZPRRegClass); + } + // Compute derived properties from the register classes computeRegisterProperties(Subtarget->getRegisterInfo()); @@ -551,6 +596,7 @@ setTargetDAGCombine(ISD::ADD); setTargetDAGCombine(ISD::SUB); setTargetDAGCombine(ISD::SRL); + setTargetDAGCombine(ISD::AND); setTargetDAGCombine(ISD::XOR); setTargetDAGCombine(ISD::SINT_TO_FP); setTargetDAGCombine(ISD::UINT_TO_FP); @@ -564,6 +610,7 @@ setTargetDAGCombine(ISD::ANY_EXTEND); setTargetDAGCombine(ISD::ZERO_EXTEND); setTargetDAGCombine(ISD::SIGN_EXTEND); + setTargetDAGCombine(ISD::SIGN_EXTEND_INREG); setTargetDAGCombine(ISD::BITCAST); setTargetDAGCombine(ISD::CONCAT_VECTORS); setTargetDAGCombine(ISD::STORE); @@ -571,6 +618,7 @@ setTargetDAGCombine(ISD::LOAD); setTargetDAGCombine(ISD::MUL); + setTargetDAGCombine(ISD::SHL); setTargetDAGCombine(ISD::SELECT); setTargetDAGCombine(ISD::VSELECT); @@ -578,6 +626,18 @@ setTargetDAGCombine(ISD::INTRINSIC_VOID); setTargetDAGCombine(ISD::INTRINSIC_W_CHAIN); setTargetDAGCombine(ISD::INSERT_VECTOR_ELT); + setTargetDAGCombine(ISD::EXTRACT_VECTOR_ELT); + + setTargetDAGCombine(ISD::MGATHER); + setTargetDAGCombine(ISD::MSCATTER); + setTargetDAGCombine(ISD::EXTRACT_SUBVECTOR); + setTargetDAGCombine(ISD::VECTOR_SHUFFLE_VAR); + + if (Subtarget->hasSVE()) { + setTargetDAGCombine(ISD::SETCC); + setTargetDAGCombine(ISD::TRUNCATE); + setTargetDAGCombine(ISD::VSCALE); + } setTargetDAGCombine(ISD::GlobalAddress); @@ -614,6 +674,8 @@ setHasExtractBitsInsn(true); setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::Other, Custom); + setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::i8, Custom); + setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::i16, Custom); if (Subtarget->hasNEON()) { // FIXME: v1f64 shouldn't be legal if we can avoid it, because it leads to @@ -694,14 +756,16 @@ setOperationAction(ISD::MUL, MVT::v2i64, Custom); // Vector reductions - for (MVT VT : MVT::integer_valuetypes()) { + for (MVT VT : { MVT::v8i8, MVT::v4i16, MVT::v2i32, + MVT::v16i8, MVT::v8i16, MVT::v4i32, MVT::v2i64 }) { setOperationAction(ISD::VECREDUCE_ADD, VT, Custom); setOperationAction(ISD::VECREDUCE_SMAX, VT, Custom); setOperationAction(ISD::VECREDUCE_SMIN, VT, Custom); setOperationAction(ISD::VECREDUCE_UMAX, VT, Custom); setOperationAction(ISD::VECREDUCE_UMIN, VT, Custom); } - for (MVT VT : MVT::fp_valuetypes()) { + for (MVT VT : { MVT::v4f16, MVT::v2f32, + MVT::v8f16, MVT::v4f32, MVT::v2f64 }) { setOperationAction(ISD::VECREDUCE_FMAX, VT, Custom); setOperationAction(ISD::VECREDUCE_FMIN, VT, Custom); } @@ -747,6 +811,199 @@ } PredictableSelectIsExpensive = Subtarget->predictableSelectIsExpensive(); + + setOperationAction(ISD::INTRINSIC_W_CHAIN, MVT::Other, Custom); + + if (Subtarget->hasSVE()) { + // This is needed when 
extracting fixed width vectors from scalable + // vectors, to avoid the promotion of the illegal output types. + setOperationAction(ISD::EXTRACT_SUBVECTOR, MVT::v2i16, Custom); + setOperationAction(ISD::EXTRACT_SUBVECTOR, MVT::v2i8, Custom); + setOperationAction(ISD::EXTRACT_SUBVECTOR, MVT::v4i8, Custom); + + setOperationAction(ISD::ConstantFP, MVT::f16, Legal); + setOperationAction(ISD::VSCALE, MVT::i32, Custom); + + // Vector operation legalization checks the result type of + // SIGN_EXTEND_INREG, overall legalization checks the inner type. + setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::nxv2i64, Legal); + setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::nxv2i32, Legal); + setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::nxv2i16, Legal); + setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::nxv2i8, Legal); + setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::nxv4i32, Legal); + setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::nxv4i16, Legal); + setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::nxv4i8, Legal); + setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::nxv8i16, Legal); + setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::nxv8i8, Legal); + + setOperationAction(ISD::SDIV, MVT::nxv16i8, Custom); + setOperationAction(ISD::SDIV, MVT::nxv8i16, Custom); + setOperationAction(ISD::UDIV, MVT::nxv16i8, Custom); + setOperationAction(ISD::UDIV, MVT::nxv8i16, Custom); + + setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::nxv2i1, Promote); + AddPromotedToType(ISD::INSERT_VECTOR_ELT, MVT::nxv2i1, MVT::nxv2i64); + setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::nxv4i1, Promote); + AddPromotedToType(ISD::INSERT_VECTOR_ELT, MVT::nxv4i1, MVT::nxv4i32); + setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::nxv8i1, Promote); + AddPromotedToType(ISD::INSERT_VECTOR_ELT, MVT::nxv8i1, MVT::nxv8i16); + setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::nxv16i1, Promote); + AddPromotedToType(ISD::INSERT_VECTOR_ELT, MVT::nxv16i1, MVT::nxv16i8); + + setOperationAction(ISD::EXTRACT_VECTOR_ELT, MVT::nxv2i1, Promote); + AddPromotedToType(ISD::EXTRACT_VECTOR_ELT, MVT::nxv2i1, MVT::nxv2i64); + setOperationAction(ISD::EXTRACT_VECTOR_ELT, MVT::nxv4i1, Promote); + AddPromotedToType(ISD::EXTRACT_VECTOR_ELT, MVT::nxv4i1, MVT::nxv4i32); + setOperationAction(ISD::EXTRACT_VECTOR_ELT, MVT::nxv8i1, Promote); + AddPromotedToType(ISD::EXTRACT_VECTOR_ELT, MVT::nxv8i1, MVT::nxv8i16); + setOperationAction(ISD::EXTRACT_VECTOR_ELT, MVT::nxv16i1, Promote); + AddPromotedToType(ISD::EXTRACT_VECTOR_ELT, MVT::nxv16i1, MVT::nxv16i8); + + setOperationAction(ISD::SINT_TO_FP, MVT::nxv2i1, Promote); + AddPromotedToType(ISD::SINT_TO_FP, MVT::nxv2i1, MVT::nxv2i64); + setOperationAction(ISD::SINT_TO_FP, MVT::nxv4i1, Promote); + AddPromotedToType(ISD::SINT_TO_FP, MVT::nxv4i1, MVT::nxv4i32); + + setOperationAction(ISD::UINT_TO_FP, MVT::nxv2i1, Promote); + AddPromotedToType(ISD::UINT_TO_FP, MVT::nxv2i1, MVT::nxv2i64); + setOperationAction(ISD::UINT_TO_FP, MVT::nxv4i1, Promote); + AddPromotedToType(ISD::UINT_TO_FP, MVT::nxv4i1, MVT::nxv4i32); + + // Use SVE to implement fixed-width masked loads & stores. + for (auto VT : { MVT::v2i32, MVT::v2i64, MVT::v2f32, MVT::v2f64, + MVT::v4i16, MVT::v4i32, MVT::v4f32, + MVT::v8i8, MVT::v8i16, + MVT::v16i8 }) { + setOperationAction(ISD::MLOAD, VT, Custom); + setOperationAction(ISD::MSTORE, VT, Custom); + setOperationAction(ISD::MGATHER, VT, Custom); + setOperationAction(ISD::MSCATTER, VT, Custom); + } + } + + // Handle SVE operations. 
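// Scalar model of the UNPK{LO,HI}-based extension lowering configured in the
// loop below (a sketch, not the patch's implementation): widening sixteen i8
// lanes to i16 yields two half-length results, one from the low lanes and
// one from the high lanes, each lane zero-extended (the signed SUNPK forms
// would sign-extend instead).
static void uunpkSketch(const unsigned char Src[16], unsigned short Dst[8],
                        bool Hi) {
  for (int I = 0; I < 8; ++I)
    Dst[I] = Src[Hi ? 8 + I : I]; // zero-extend each selected lane
}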
+ for (MVT VT : MVT::integer_scalable_vector_valuetypes()) { + setOperationAction(ISD::INTRINSIC_W_CHAIN, VT, Custom); + setOperationAction(ISD::SELECT_CC, VT, Expand); + + if (isTypeLegal(VT)) { + if (VT.getVectorElementType() != MVT::i1) { + setOperationAction(ISD::VECREDUCE_AND, VT, Custom); + setOperationAction(ISD::VECREDUCE_OR, VT, Custom); + setOperationAction(ISD::VECREDUCE_XOR, VT, Custom); + setOperationAction(ISD::VECREDUCE_ADD, VT, Custom); + setOperationAction(ISD::VECREDUCE_UMIN, VT, Custom); + setOperationAction(ISD::VECREDUCE_UMAX, VT, Custom); + setOperationAction(ISD::VECREDUCE_SMIN, VT, Custom); + setOperationAction(ISD::VECREDUCE_SMAX, VT, Custom); + setOperationAction(ISD::AND, VT, Legal); + setOperationAction(ISD::OR, VT, Legal); + setOperationAction(ISD::XOR, VT, Legal); + setOperationAction(ISD::UMIN, VT, Legal); + setOperationAction(ISD::UMAX, VT, Legal); + setOperationAction(ISD::SMIN, VT, Legal); + setOperationAction(ISD::SMAX, VT, Legal); + setOperationAction(ISD::BSWAP, VT, Legal); + setOperationAction(ISD::SERIES_VECTOR, VT, Custom); + setOperationAction(ISD::SPLAT_VECTOR, VT, Custom); + setOperationAction(ISD::TRUNCATE, VT, Custom); + setOperationAction(ISD::VECTOR_SHUFFLE_VAR, VT, Custom); + + // No remainder instructions, need to expand + setOperationAction(ISD::SREM, VT, Custom); + setOperationAction(ISD::UREM, VT, Custom); + setOperationAction(ISD::SDIVREM, VT, Expand); + setOperationAction(ISD::UDIVREM, VT, Expand); + setOperationAction(ISD::EXTRACT_VECTOR_ELT, VT, Custom); + setOperationAction(ISD::SETCC, VT, Custom); + + for (MVT InnerVT : MVT::integer_scalable_vector_valuetypes()) { + if (InnerVT.getVectorNumElements() == VT.getVectorNumElements() && + InnerVT.getVectorElementType() != MVT::i1) { + setTruncStoreAction(VT, InnerVT, Legal); + setLoadExtAction(ISD::SEXTLOAD, VT, InnerVT, Legal); + setLoadExtAction(ISD::ZEXTLOAD, VT, InnerVT, Legal); + setLoadExtAction(ISD::EXTLOAD, VT, InnerVT, Legal); + } + } + } else { + setOperationAction(ISD::CONCAT_VECTORS, VT, Custom); + setOperationAction(ISD::SERIES_VECTOR, VT, Custom); + setOperationAction(ISD::SPLAT_VECTOR, VT, Custom); + } + } else { + if (VT.getVectorElementType() != MVT::i1) { + // Use UNPK{LO,HI} sequences to lower extensions from legal SVE + // types to wider-than-legal types. + setOperationAction(ISD::SIGN_EXTEND, VT, Custom); + setOperationAction(ISD::ZERO_EXTEND, VT, Custom); + setOperationAction(ISD::ANY_EXTEND, VT, Custom); + + setOperationAction(ISD::CONCAT_VECTORS, VT, Custom); + setOperationAction(ISD::EXTRACT_SUBVECTOR, VT, Custom); + setOperationAction(ISD::INSERT_SUBVECTOR, VT, Custom); + setOperationAction(ISD::INSERT_VECTOR_ELT, VT, Custom); + setOperationAction(ISD::EXTRACT_VECTOR_ELT, VT, Custom); + + // We want to custom-lower this for both legal and too-wide types, + // since we can do a better job than the target-independent code + // by exploiting the fact that TBL produces zero for out-of-range + // indices. + setOperationAction(ISD::VECTOR_SHUFFLE_VAR, VT, Custom); + } + } + } + + for (MVT VT : MVT::fp_scalable_vector_valuetypes()) { + setOperationAction(ISD::INTRINSIC_VOID, VT, Custom); + setOperationAction(ISD::INTRINSIC_W_CHAIN, VT, Custom); + setOperationAction(ISD::SELECT_CC, VT, Expand); + setOperationAction(ISD::SERIES_VECTOR, VT, Custom); + + // Marking these intrinsics as expand should trigger the "we cannot unroll a + // scalable vector" assert. 
+ setOperationAction(ISD::FSIN, VT, Expand); + setOperationAction(ISD::FCOS, VT, Expand); + setOperationAction(ISD::FPOWI, VT, Expand); + setOperationAction(ISD::FPOW, VT, Expand); + setOperationAction(ISD::FLOG, VT, Expand); + setOperationAction(ISD::FLOG2, VT, Expand); + setOperationAction(ISD::FLOG10, VT, Expand); + setOperationAction(ISD::FEXP, VT, Expand); + setOperationAction(ISD::FEXP2, VT, Expand); + + if (isTypeLegal(VT)) { + setOperationAction(ISD::FFLOOR, VT, Legal); + setOperationAction(ISD::FNEARBYINT, VT, Legal); + setOperationAction(ISD::FCEIL, VT, Legal); + setOperationAction(ISD::FRINT, VT, Legal); + setOperationAction(ISD::FTRUNC, VT, Legal); + setOperationAction(ISD::FROUND, VT, Legal); + setOperationAction(ISD::FMINNAN, VT, Legal); + setOperationAction(ISD::FMINNUM, VT, Legal); + setOperationAction(ISD::FMAXNAN, VT, Legal); + setOperationAction(ISD::FMAXNUM, VT, Legal); + + setOperationAction(ISD::CONCAT_VECTORS, VT, Custom); + setOperationAction(ISD::SETCC, VT, Custom); + setOperationAction(ISD::VECTOR_SHUFFLE_VAR, VT, Custom); + setOperationAction(ISD::SPLAT_VECTOR, VT, Custom); + setOperationAction(ISD::VECREDUCE_FADD, VT, Custom); + setOperationAction(ISD::VECREDUCE_FMIN, VT, Custom); + setOperationAction(ISD::VECREDUCE_FMAX, VT, Custom); + + // int<->fp bitcasts of unpacked fp datatypes require special handling + MVT IntVT = EVT(VT).changeVectorElementTypeToInteger().getSimpleVT(); + if (!isTypeLegal(IntVT)) + setOperationAction(ISD::BITCAST, IntVT, Custom); + } else { + setOperationAction(ISD::EXTRACT_VECTOR_ELT, VT, Custom); + setOperationAction(ISD::INSERT_VECTOR_ELT, VT, Custom); + setOperationAction(ISD::EXTRACT_SUBVECTOR, VT, Custom); + setOperationAction(ISD::FP_EXTEND, VT, Custom); + setOperationAction(ISD::INTRINSIC_WO_CHAIN, VT, Custom); + } + } } void AArch64TargetLowering::addTypeForNEON(MVT VT, MVT PromotedBitwiseVT) { @@ -843,6 +1100,11 @@ EVT VT) const { if (!VT.isVector()) return MVT::i32; + + // SVE has predicate regs + if (VT.isScalableVector()) + return MVT::getVectorVT(MVT::i1, VT.getVectorElementCount()); + return VT.changeVectorElementTypeToInteger(); } @@ -1039,8 +1301,40 @@ Known.Zero |= Mask; } break; - } break; } + } + break; + } + case ISD::VSCALE: { + assert(isPowerOf2_64(AArch64::SVEBitsPerBlock) && + isPowerOf2_64(AArch64::SVEMaxBitsPerVector) && + "Incorrect RDVL range calculation."); + + unsigned BitWidth = Op.getValueType().getScalarSizeInBits(); + uint64_t HiVal = AArch64::SVEMaxBitsPerVector / AArch64::SVEBitsPerBlock; + Known.Zero = APInt(BitWidth, ~((HiVal * 2) - 1)); + + KnownBits Known2; + DAG.computeKnownBits(Op->getOperand(0), Known2, Depth + 1); + + // NOTE: Taken from SelectionDAG::computeKnownBits case ISD::MUL + // + // If low bits are zero in either operand, output low known-0 bits. + // Also compute a conservative estimate for high known-0 bits. + // More trickiness is possible, but this is sufficient for the + // interesting case of alignment computation. 
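// Worked example for the VSCALE known-bits seeding above (a sketch using the
// constants named in the assert): with SVEBitsPerBlock = 128 and
// SVEMaxBitsPerVector = 2048, HiVal = 16, so Known.Zero starts as ~31 and
// every bit from bit 5 upwards is initially known zero, i.e. the unscaled
// vscale multiplier is known to be below 32 before the operand's known bits
// are folded in.
static_assert(2048 / 128 == 16, "HiVal for a 2048-bit maximum vector");
static_assert((16 * 2) - 1 == 31, "only the low five bits may be set");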
+ Known.One.clearAllBits(); + unsigned TrailZ = Known.Zero.countTrailingOnes() + + Known2.Zero.countTrailingOnes(); + unsigned LeadZ = std::max(Known.Zero.countLeadingOnes() + + Known2.Zero.countLeadingOnes(), + BitWidth) - BitWidth; + + TrailZ = std::min(TrailZ, BitWidth); + LeadZ = std::min(LeadZ, BitWidth); + Known.Zero = APInt::getLowBitsSet(BitWidth, TrailZ) | + APInt::getHighBitsSet(BitWidth, LeadZ); + break; } } } @@ -1078,7 +1372,9 @@ FastISel * AArch64TargetLowering::createFastISel(FunctionLoweringInfo &funcInfo, const TargetLibraryInfo *libInfo) const { - return AArch64::createFastISel(funcInfo, libInfo); + bool EnableSVE = getTargetMachine().getTargetFeatureString().count("sve") > 0; + bool UseFastISel = !EnableSVE; + return UseFastISel ? AArch64::createFastISel(funcInfo, libInfo) : nullptr; } const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const { @@ -1131,6 +1427,7 @@ case AArch64ISD::UZP2: return "AArch64ISD::UZP2"; case AArch64ISD::TRN1: return "AArch64ISD::TRN1"; case AArch64ISD::TRN2: return "AArch64ISD::TRN2"; + case AArch64ISD::REV: return "AArch64ISD::REV"; case AArch64ISD::REV16: return "AArch64ISD::REV16"; case AArch64ISD::REV32: return "AArch64ISD::REV32"; case AArch64ISD::REV64: return "AArch64ISD::REV64"; @@ -1157,11 +1454,20 @@ case AArch64ISD::FCMLEz: return "AArch64ISD::FCMLEz"; case AArch64ISD::FCMLTz: return "AArch64ISD::FCMLTz"; case AArch64ISD::SADDV: return "AArch64ISD::SADDV"; + case AArch64ISD::SADDV_PRED: return "AArch64ISD::SADDV_PRED"; case AArch64ISD::UADDV: return "AArch64ISD::UADDV"; + case AArch64ISD::UADDV_PRED: return "AArch64ISD::UADDV_PRED"; case AArch64ISD::SMINV: return "AArch64ISD::SMINV"; + case AArch64ISD::SMINV_PRED: return "AArch64ISD::SMINV_PRED"; case AArch64ISD::UMINV: return "AArch64ISD::UMINV"; + case AArch64ISD::UMINV_PRED: return "AArch64ISD::UMINV_PRED"; case AArch64ISD::SMAXV: return "AArch64ISD::SMAXV"; + case AArch64ISD::SMAXV_PRED: return "AArch64ISD::SMAXV_PRED"; case AArch64ISD::UMAXV: return "AArch64ISD::UMAXV"; + case AArch64ISD::UMAXV_PRED: return "AArch64ISD::UMAXV_PRED"; + case AArch64ISD::ANDV_PRED: return "AArch64ISD::ANDV_PRED"; + case AArch64ISD::EORV_PRED: return "AArch64ISD::EORV_PRED"; + case AArch64ISD::ORV_PRED: return "AArch64ISD::ORV_PRED"; case AArch64ISD::NOT: return "AArch64ISD::NOT"; case AArch64ISD::BIT: return "AArch64ISD::BIT"; case AArch64ISD::CBZ: return "AArch64ISD::CBZ"; @@ -1179,6 +1485,67 @@ case AArch64ISD::URSHR_I: return "AArch64ISD::URSHR_I"; case AArch64ISD::SQSHLU_I: return "AArch64ISD::SQSHLU_I"; case AArch64ISD::WrapperLarge: return "AArch64ISD::WrapperLarge"; + case AArch64ISD::PTEST: return "AArch64ISD::PTEST"; + case AArch64ISD::PTRUE: return "AArch64ISD::PTRUE"; + + case AArch64ISD::GLDNT1: return "AArch64ISD::GLDNT1"; + case AArch64ISD::GLDNT1_UXTW: return "AArch64ISD::GLDNT1_UXTW"; + case AArch64ISD::GLDFF1: return "AArch64ISD::GLDFF1"; + case AArch64ISD::GLDFF1_SCALED: return "AArch64ISD::GLDFF1_SCALED"; + case AArch64ISD::GLDFF1_SXTW: return "AArch64ISD::GLDFF1_SXTW"; + case AArch64ISD::GLDFF1_SXTW_SCALED:return "AArch64ISD::GLDFF1_SXTW_SCALED"; + case AArch64ISD::GLDFF1_UXTW: return "AArch64ISD::GLDFF1_UXTW"; + case AArch64ISD::GLDFF1_UXTW_SCALED:return "AArch64ISD::GLDFF1_UXTW_SCALED"; + + case AArch64ISD::GLDNT1S: return "AArch64ISD::GLDNT1S"; + case AArch64ISD::GLDNT1S_UXTW: return "AArch64ISD::GLDNT1S_UXTW"; + case AArch64ISD::GLDFF1S: return "AArch64ISD::GLDFF1S"; + case AArch64ISD::GLDFF1S_SCALED: return "AArch64ISD::GLDFF1S_SCALED"; + case 
AArch64ISD::GLDFF1S_SXTW: return "AArch64ISD::GLDFF1S_SXTW"; + case AArch64ISD::GLDFF1S_SXTW_SCALED:return "AArch64ISD::GLDFF1S_SXTW_SCALED"; + case AArch64ISD::GLDFF1S_UXTW: return "AArch64ISD::GLDFF1S_UXTW"; + case AArch64ISD::GLDFF1S_UXTW_SCALED:return "AArch64ISD::GLDFF1S_UXTW_SCALED"; + + case AArch64ISD::GPRF_S_IMM: return "AArch64ISD::GPRF_S_IMM"; + case AArch64ISD::GPRF_D_IMM: return "AArch64ISD::GPRF_D_IMM"; + case AArch64ISD::GPRF_D_SCALED: return "AArch64ISD::GPRF_D_SCALED"; + case AArch64ISD::GPRF_S_SXTW_SCALED: return "AArch64ISD::GPRF_S_SXTW_SCALED"; + case AArch64ISD::GPRF_S_UXTW_SCALED: return "AArch64ISD::GPRF_S_UXTW_SCALED"; + case AArch64ISD::GPRF_D_SXTW_SCALED: return "AArch64ISD::GPRF_D_SXTW_SCALED"; + case AArch64ISD::GPRF_D_UXTW_SCALED: return "AArch64ISD::GPRF_D_UXTW_SCALED"; + + case AArch64ISD::LD1RQ: return "AArch64ISD::LD1RQ"; + case AArch64ISD::LDFF1: return "AArch64ISD::LDFF1"; + case AArch64ISD::LDFF1S: return "AArch64ISD::LDFF1S"; + case AArch64ISD::LDNF1: return "AArch64ISD::LDNF1"; + case AArch64ISD::LDNF1S: return "AArch64ISD::LDNF1S"; + case AArch64ISD::RDFFR: return "AArch64ISD::RDFFR"; + case AArch64ISD::RDFFR_PRED: return "AArch64ISD::RDFFR_PRED"; + case AArch64ISD::SETFFR: return "AArch64ISD::SETFFR"; + case AArch64ISD::WRFFR: return "AArch64ISD::WRFFR"; + case AArch64ISD::REINTERPRET_CAST: return "AArch64ISD::REINTERPRET_CAST"; + case AArch64ISD::TBL: return "AArch64ISD::TBL"; + case AArch64ISD::BRKA: return "AArch64ISD::BRKA"; + case AArch64ISD::CLASTA_N: return "AArch64ISD::CLASTA_N"; + case AArch64ISD::CLASTB_N: return "AArch64ISD::CLASTB_N"; + case AArch64ISD::DUP_PRED: return "AArch64ISD::DUP_PRED"; + case AArch64ISD::FADDA_PRED: return "AArch64ISD::FADDA_PRED"; + case AArch64ISD::FADDV_PRED: return "AArch64ISD::FADDV_PRED"; + case AArch64ISD::FMAXV_PRED: return "AArch64ISD::FMAXV_PRED"; + case AArch64ISD::FMAXNMV_PRED: return "AArch64ISD::FMAXNMV_PRED"; + case AArch64ISD::FMINV_PRED: return "AArch64ISD::FMINV_PRED"; + case AArch64ISD::FMINNMV_PRED: return "AArch64ISD::FMINNMV_PRED"; + case AArch64ISD::FMIN_PRED: return "AArch64ISD::FMIN_PRED"; + case AArch64ISD::FMINNM_PRED: return "AArch64ISD::FMINNM_PRED"; + case AArch64ISD::FMAX_PRED: return "AArch64ISD::FMAX_PRED"; + case AArch64ISD::FMAXNM_PRED: return "AArch64ISD::FMAXNM_PRED"; + case AArch64ISD::INSR: return "AArch64ISD::INSR"; + case AArch64ISD::LASTA: return "AArch64ISD::LASTA"; + case AArch64ISD::LASTB: return "AArch64ISD::LASTB"; + case AArch64ISD::SUNPKHI: return "AArch64ISD::SUNPKHI"; + case AArch64ISD::SUNPKLO: return "AArch64ISD::SUNPKLO"; + case AArch64ISD::UUNPKHI: return "AArch64ISD::UUNPKHI"; + case AArch64ISD::UUNPKLO: return "AArch64ISD::UUNPKLO"; case AArch64ISD::LD2post: return "AArch64ISD::LD2post"; case AArch64ISD::LD3post: return "AArch64ISD::LD3post"; case AArch64ISD::LD4post: return "AArch64ISD::LD4post"; @@ -1208,6 +1575,10 @@ case AArch64ISD::FRECPS: return "AArch64ISD::FRECPS"; case AArch64ISD::FRSQRTE: return "AArch64ISD::FRSQRTE"; case AArch64ISD::FRSQRTS: return "AArch64ISD::FRSQRTS"; + case AArch64ISD::LDNT1: return "AArch64ISD::LDNT1"; + case AArch64ISD::STNT1: return "AArch64ISD::STNT1"; + case AArch64ISD::SSTNT1: return "AArch64ISD::SSTNT1"; + case AArch64ISD::SSTNT1_UXTW: return "AArch64ISD::SSTNT1_UXTW"; } return nullptr; } @@ -1567,6 +1938,7 @@ AArch64CC::CondCode Predicate, AArch64CC::CondCode OutCC, const SDLoc &DL, SelectionDAG &DAG) { + assert(LHS.getValueType() != MVT::f128 && "Cannot emit 128-bit floats"); unsigned Opcode = 0; const bool FullFP16 
= static_cast(DAG.getSubtarget()).hasFullFP16(); @@ -2243,11 +2615,10 @@ // in the cost tables. EVT InVT = Op.getOperand(0).getValueType(); EVT VT = Op.getValueType(); - unsigned NumElts = InVT.getVectorNumElements(); // f16 vectors are promoted to f32 before a conversion. if (InVT.getVectorElementType() == MVT::f16) { - MVT NewVT = MVT::getVectorVT(MVT::f32, NumElts); + MVT NewVT = MVT::getVectorVT(MVT::f32, InVT.getVectorElementCount()); SDLoc dl(Op); return DAG.getNode( Op.getOpcode(), dl, Op.getValueType(), @@ -2266,7 +2637,7 @@ SDLoc dl(Op); MVT ExtVT = MVT::getVectorVT(MVT::getFloatingPointVT(VT.getScalarSizeInBits()), - VT.getVectorNumElements()); + VT.getVectorElementCount()); SDValue Ext = DAG.getNode(ISD::FP_EXTEND, dl, ExtVT, Op.getOperand(0)); return DAG.getNode(Op.getOpcode(), dl, VT, Ext); } @@ -2316,7 +2687,7 @@ if (VT.getSizeInBits() < InVT.getSizeInBits()) { MVT CastVT = MVT::getVectorVT(MVT::getFloatingPointVT(InVT.getScalarSizeInBits()), - InVT.getVectorNumElements()); + InVT.getVectorElementCount()); In = DAG.getNode(Op.getOpcode(), dl, CastVT, In); return DAG.getNode(ISD::FP_ROUND, dl, VT, In, DAG.getIntPtrConstant(0, dl)); } @@ -2399,7 +2770,11 @@ return CallResult.first; } -static SDValue LowerBITCAST(SDValue Op, SelectionDAG &DAG) { +SDValue AArch64TargetLowering::LowerBITCAST(SDValue Op, + SelectionDAG &DAG) const { + if (Op.getValueType().isScalableVector()) + return LowerVectorBITCAST(Op, DAG); + if (Op.getValueType() != MVT::f16) return SDValue(); @@ -2704,7 +3079,120 @@ case Intrinsic::aarch64_neon_umin: return DAG.getNode(ISD::UMIN, dl, Op.getValueType(), Op.getOperand(1), Op.getOperand(2)); + + case Intrinsic::aarch64_sve_fadda: + return DAG.getNode(AArch64ISD::FADDA_PRED, dl, Op.getValueType(), + Op.getOperand(1), Op.getOperand(2), Op.getOperand(3)); + case Intrinsic::aarch64_sve_faddv: + return DAG.getNode(AArch64ISD::FADDV_PRED, dl, Op.getValueType(), + Op.getOperand(1), Op.getOperand(2)); + case Intrinsic::aarch64_sve_clasta_n: + return DAG.getNode(AArch64ISD::CLASTA_N, dl, Op.getValueType(), + Op.getOperand(1), Op.getOperand(2), Op.getOperand(3)); + case Intrinsic::aarch64_sve_clastb_n: + return DAG.getNode(AArch64ISD::CLASTB_N, dl, Op.getValueType(), + Op.getOperand(1), Op.getOperand(2), Op.getOperand(3)); + case Intrinsic::aarch64_sve_dupq_lane: + return LowerDUPQLane(Op, DAG); + case Intrinsic::aarch64_sve_insr: { + SDValue Scalar = Op.getOperand(2); + EVT ScalarTy = Scalar.getValueType(); + if ((ScalarTy == MVT::i8) || (ScalarTy == MVT::i16)) + Scalar = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i32, Scalar); + + return DAG.getNode(AArch64ISD::INSR, dl, Op.getValueType(), + Op.getOperand(1), Scalar); } + case Intrinsic::aarch64_sve_fmaxv: + return DAG.getNode(AArch64ISD::FMAXV_PRED, dl, Op.getValueType(), + Op.getOperand(1), Op.getOperand(2)); + case Intrinsic::aarch64_sve_fmaxnmv: + return DAG.getNode(AArch64ISD::FMAXNMV_PRED, dl, Op.getValueType(), + Op.getOperand(1), Op.getOperand(2)); + case Intrinsic::aarch64_sve_fminv: + return DAG.getNode(AArch64ISD::FMINV_PRED, dl, Op.getValueType(), + Op.getOperand(1), Op.getOperand(2)); + case Intrinsic::aarch64_sve_fminnmv: + return DAG.getNode(AArch64ISD::FMINNMV_PRED, dl, Op.getValueType(), + Op.getOperand(1), Op.getOperand(2)); + case Intrinsic::aarch64_sve_fmin: + return DAG.getNode(AArch64ISD::FMIN_PRED, dl, Op.getValueType(), + Op.getOperand(1), Op.getOperand(2), Op.getOperand(3)); + case Intrinsic::aarch64_sve_fminnm: + return DAG.getNode(AArch64ISD::FMINNM_PRED, dl, Op.getValueType(), + 
Op.getOperand(1), Op.getOperand(2), Op.getOperand(3)); + case Intrinsic::aarch64_sve_fmax: + return DAG.getNode(AArch64ISD::FMAX_PRED, dl, Op.getValueType(), + Op.getOperand(1), Op.getOperand(2), Op.getOperand(3)); + case Intrinsic::aarch64_sve_fmaxnm: + return DAG.getNode(AArch64ISD::FMAXNM_PRED, dl, Op.getValueType(), + Op.getOperand(1), Op.getOperand(2), Op.getOperand(3)); + case Intrinsic::aarch64_sve_lasta: + return DAG.getNode(AArch64ISD::LASTA, dl, Op.getValueType(), + Op.getOperand(1), Op.getOperand(2)); + case Intrinsic::aarch64_sve_lastb: + return DAG.getNode(AArch64ISD::LASTB, dl, Op.getValueType(), + Op.getOperand(1), Op.getOperand(2)); + case Intrinsic::aarch64_sve_ptrue: + return DAG.getNode(AArch64ISD::PTRUE, dl, Op.getValueType(), + Op.getOperand(1)); + case Intrinsic::aarch64_sve_sunpkhi: + return DAG.getNode(AArch64ISD::SUNPKHI, dl, Op.getValueType(), + Op.getOperand(1)); + case Intrinsic::aarch64_sve_sunpklo: + return DAG.getNode(AArch64ISD::SUNPKLO, dl, Op.getValueType(), + Op.getOperand(1)); + case Intrinsic::aarch64_sve_reinterpret_bool_b: + return DAG.getNode(AArch64ISD::REINTERPRET_CAST, dl, Op.getValueType(), + Op.getOperand(1)); + case Intrinsic::aarch64_sve_reinterpret_bool_h: + return DAG.getNode(AArch64ISD::REINTERPRET_CAST, dl, Op.getValueType(), + Op.getOperand(1)); + case Intrinsic::aarch64_sve_reinterpret_bool_w: + return DAG.getNode(AArch64ISD::REINTERPRET_CAST, dl, Op.getValueType(), + Op.getOperand(1)); + case Intrinsic::aarch64_sve_reinterpret_bool_d: + return DAG.getNode(AArch64ISD::REINTERPRET_CAST, dl, Op.getValueType(), + Op.getOperand(1)); + case Intrinsic::aarch64_sve_rev: + return DAG.getNode(AArch64ISD::REV, dl, Op.getValueType(), + Op.getOperand(1)); + case Intrinsic::aarch64_sve_tbl: + return DAG.getNode(AArch64ISD::TBL, dl, Op.getValueType(), + Op.getOperand(1), Op.getOperand(2)); + case Intrinsic::aarch64_sve_trn1: + return DAG.getNode(AArch64ISD::TRN1, dl, Op.getValueType(), + Op.getOperand(1), Op.getOperand(2)); + case Intrinsic::aarch64_sve_trn2: + return DAG.getNode(AArch64ISD::TRN2, dl, Op.getValueType(), + Op.getOperand(1), Op.getOperand(2)); + case Intrinsic::aarch64_sve_uunpkhi: + return DAG.getNode(AArch64ISD::UUNPKHI, dl, Op.getValueType(), + Op.getOperand(1)); + case Intrinsic::aarch64_sve_uunpklo: + return DAG.getNode(AArch64ISD::UUNPKLO, dl, Op.getValueType(), + Op.getOperand(1)); + case Intrinsic::aarch64_sve_uzp1: + return DAG.getNode(AArch64ISD::UZP1, dl, Op.getValueType(), + Op.getOperand(1), Op.getOperand(2)); + case Intrinsic::aarch64_sve_uzp2: + return DAG.getNode(AArch64ISD::UZP2, dl, Op.getValueType(), + Op.getOperand(1), Op.getOperand(2)); + case Intrinsic::aarch64_sve_zip1: + return DAG.getNode(AArch64ISD::ZIP1, dl, Op.getValueType(), + Op.getOperand(1), Op.getOperand(2)); + case Intrinsic::aarch64_sve_zip2: + return DAG.getNode(AArch64ISD::ZIP2, dl, Op.getValueType(), + Op.getOperand(1), Op.getOperand(2)); + } +} + +void AArch64TargetLowering::LowerOperationWrapper(SDNode *N, + SmallVectorImpl &Results, SelectionDAG &DAG) const { + SDValue Res = LowerOperation(SDValue(N, 0), DAG); + if (Res.getNode()) + for (unsigned I = 0, E = Res->getNumValues(); I != E; ++I) + Results.push_back(Res.getValue(I)); } // Custom lower trunc store for v4i8 vectors, since it is promoted to v4i16. @@ -2778,6 +3266,15 @@ default: llvm_unreachable("unimplemented operand"); return SDValue(); + case ISD::ANY_EXTEND: + case ISD::SIGN_EXTEND: + case ISD::ZERO_EXTEND: + // Needed because we have selected custom lowering for illegal SVE types. 
+ // The cases we actually want to handle are where the operand is legal + // and the result isn't, which go through ReplaceNodeResults instead. + // This code only sees cases where the result is legal and the operand + // isn't. + return SDValue(); case ISD::BITCAST: return LowerBITCAST(Op, DAG); case ISD::GlobalAddress: @@ -2838,10 +3335,20 @@ return LowerEXTRACT_VECTOR_ELT(Op, DAG); case ISD::BUILD_VECTOR: return LowerBUILD_VECTOR(Op, DAG); + case ISD::CONCAT_VECTORS: + return LowerCONCAT_VECTORS(Op, DAG); case ISD::VECTOR_SHUFFLE: return LowerVECTOR_SHUFFLE(Op, DAG); + case ISD::VECTOR_SHUFFLE_VAR: + return LowerVECTOR_SHUFFLE_VAR(Op, DAG, 1, Op->getValueType(0)); + case ISD::SERIES_VECTOR: + return LowerSERIES_VECTOR(Op, DAG); + case ISD::SPLAT_VECTOR: + return LowerSPLAT_VECTOR(Op, DAG); case ISD::EXTRACT_SUBVECTOR: return LowerEXTRACT_SUBVECTOR(Op, DAG); + case ISD::INSERT_SUBVECTOR: + return LowerINSERT_SUBVECTOR(Op, DAG); case ISD::SRA: case ISD::SRL: case ISD::SHL: @@ -2875,18 +3382,45 @@ return LowerFLT_ROUNDS_(Op, DAG); case ISD::MUL: return LowerMUL(Op, DAG); + case ISD::SDIV: + case ISD::UDIV: + return LowerDIV(Op, DAG); + case ISD::SREM: + case ISD::UREM: + return LowerREM(Op, DAG); + case ISD::INTRINSIC_W_CHAIN: + return LowerINTRINSIC_W_CHAIN(Op, DAG); case ISD::MULHS: case ISD::MULHU: return LowerMULH(Op, DAG); case ISD::INTRINSIC_WO_CHAIN: return LowerINTRINSIC_WO_CHAIN(Op, DAG); + case ISD::TRUNCATE: + return LowerTRUNCATE(Op, DAG); + case AArch64ISD::LASTA: + case AArch64ISD::LASTB: + return LowerLASTX(Op, DAG); + case ISD::MLOAD: + return LowerMLOAD(Op, DAG); + case ISD::MSTORE: + return LowerMSTORE(Op, DAG); + case ISD::MGATHER: + return LowerMGATHER(Op, DAG); + case ISD::MSCATTER: + return LowerMSCATTER(Op, DAG); + case ISD::VSCALE: + return LowerVSCALE(Op, DAG); case ISD::STORE: return LowerSTORE(Op, DAG); + case ISD::VECREDUCE_AND: + case ISD::VECREDUCE_OR: + case ISD::VECREDUCE_XOR: case ISD::VECREDUCE_ADD: case ISD::VECREDUCE_SMAX: case ISD::VECREDUCE_SMIN: case ISD::VECREDUCE_UMAX: case ISD::VECREDUCE_UMIN: + case ISD::VECREDUCE_FADD: case ISD::VECREDUCE_FMAX: case ISD::VECREDUCE_FMIN: return LowerVECREDUCE(Op, DAG); @@ -2927,6 +3461,8 @@ return IsVarArg ? CC_AArch64_DarwinPCS_VarArg : CC_AArch64_DarwinPCS; case CallingConv::Win64: return IsVarArg ? CC_AArch64_Win64_VarArg : CC_AArch64_AAPCS; + case CallingConv::AArch64_VectorCall: + return CC_AArch64_AAPCS; } } @@ -3002,11 +3538,11 @@ continue; } + SDValue ArgValue; if (VA.isRegLoc()) { // Arguments stored in registers. 
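The ANY/SIGN/ZERO_EXTEND cases above bail out because the direction that actually needs work (legal SVE operand, wider-than-legal result) is handled in ReplaceNodeResults, which the earlier type-setup comment says is built from UNPK{LO,HI} sequences. A scalar model of that unpack-based widening, assuming the usual UUNPKLO/UUNPKHI semantics (an illustration, not the DAG code itself):

  #include <array>
  #include <cassert>
  #include <cstdint>

  // A full vector of 8 narrow lanes is widened into two half-vectors of 4
  // double-width lanes: UUNPKLO takes the low half, UUNPKHI the high half.
  static std::array<uint16_t, 4> uunpklo(const std::array<uint8_t, 8> &V) {
    std::array<uint16_t, 4> R{};
    for (unsigned I = 0; I < 4; ++I)
      R[I] = V[I];        // zero-extended; the SUNPK* forms would sign-extend
    return R;
  }
  static std::array<uint16_t, 4> uunpkhi(const std::array<uint8_t, 8> &V) {
    std::array<uint16_t, 4> R{};
    for (unsigned I = 0; I < 4; ++I)
      R[I] = V[I + 4];
    return R;
  }

  int main() {
    std::array<uint8_t, 8> V{1, 2, 3, 4, 250, 251, 252, 253};
    assert(uunpklo(V)[0] == 1 && uunpkhi(V)[3] == 253);
    return 0;
  }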
EVT RegVT = VA.getLocVT(); - SDValue ArgValue; const TargetRegisterClass *RC; if (RegVT == MVT::i32) @@ -3021,6 +3557,11 @@ RC = &AArch64::FPR64RegClass; else if (RegVT == MVT::f128 || RegVT.is128BitVector()) RC = &AArch64::FPR128RegClass; + else if (RegVT.isScalableVector() && + RegVT.getVectorElementType() == MVT::i1) + RC = &AArch64::PPRRegClass; + else if (RegVT.isScalableVector()) + RC = &AArch64::ZPRRegClass; else llvm_unreachable("RegVT not supported by FORMAL_ARGUMENTS Lowering"); @@ -3035,6 +3576,7 @@ default: llvm_unreachable("Unknown loc info!"); case CCValAssign::Full: + case CCValAssign::Indirect: break; case CCValAssign::BCvt: ArgValue = DAG.getNode(ISD::BITCAST, DL, VA.getValVT(), ArgValue); @@ -3047,13 +3589,12 @@ assert(RegVT == Ins[i].VT && "incorrect register location selected"); break; } - - InVals.push_back(ArgValue); - - } else { // VA.isRegLoc() - assert(VA.isMemLoc() && "CCValAssign is neither reg nor mem"); + } else { + assert(VA.isMemLoc() && "CCValAssign is neither reg or mem"); unsigned ArgOffset = VA.getLocMemOffset(); unsigned ArgSize = VA.getValVT().getSizeInBits() / 8; + if (VA.getLocInfo() == CCValAssign::Indirect) + ArgSize = VA.getLocVT().getSizeInBits() / 8; uint32_t BEAlign = 0; if (!Subtarget->isLittleEndian() && ArgSize < 8 && @@ -3064,7 +3605,6 @@ // Create load nodes to retrieve arguments from the stack. SDValue FIN = DAG.getFrameIndex(FI, getPointerTy(DAG.getDataLayout())); - SDValue ArgValue; // For NON_EXTLOAD, generic code in getLoad assert(ValVT == MemVT) ISD::LoadExtType ExtType = ISD::NON_EXTLOAD; @@ -3074,6 +3614,7 @@ default: break; case CCValAssign::BCvt: + case CCValAssign::Indirect: MemVT = VA.getLocVT(); break; case CCValAssign::SExt: @@ -3091,9 +3632,17 @@ ExtType, DL, VA.getLocVT(), Chain, FIN, MachinePointerInfo::getFixedStack(DAG.getMachineFunction(), FI), MemVT); - - InVals.push_back(ArgValue); } + + if (VA.getLocInfo() == CCValAssign::Indirect) { + assert(VA.getValVT().isScalableVector() && + "Only scalable vectors can be passed indirectly"); + // If value is passed via pointer - do a load. + ArgValue = + DAG.getLoad(VA.getValVT(), DL, Chain, ArgValue, MachinePointerInfo()); + } + + InVals.push_back(ArgValue); } // varargs @@ -3599,6 +4148,18 @@ case CCValAssign::FPExt: Arg = DAG.getNode(ISD::FP_EXTEND, DL, VA.getLocVT(), Arg); break; + case CCValAssign::Indirect: { + assert(VA.getValVT().isScalableVector() && + "Only scalable vectors can be passed indirectly"); + SDValue SpillSlot = DAG.CreateStackTemporary(VA.getValVT()); + int FI = cast(SpillSlot)->getIndex(); + MachineFrameInfo &MFI = DAG.getMachineFunction().getFrameInfo(); + MFI.setStackID(FI, AArch64::FR_SVE); + Chain = DAG.getStore( + Chain, DL, Arg, SpillSlot, + MachinePointerInfo::getFixedStack(DAG.getMachineFunction(), FI)); + Arg = SpillSlot; + } } if (VA.isRegLoc()) { @@ -3745,6 +4306,20 @@ Ops.push_back(DAG.getRegister(RegToPass.first, RegToPass.second.getValueType())); + // SVE mask.... + if (CallConv == CallingConv::C) { + // Check callee args/returns + bool CalleeOutSVE = any_of(Outs, [](ISD::OutputArg &Out){ + return Out.VT.isScalableVector(); + }); + bool CalleeInSVE = any_of(Ins, [](ISD::InputArg &In){ + return In.VT.isScalableVector(); + }); + + if (CalleeInSVE || CalleeOutSVE) + CallConv = CallingConv::AArch64_SVE_VectorCall; + } + // Add a register mask operand representing the call-preserved registers. 
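Both the formal-argument and the call-lowering hunks above implement CCValAssign::Indirect for scalable vectors: the caller spills the value to a stack temporary and passes its address, and the callee loads through that pointer. A trivial by-pointer model of the contract (the type and function names are illustrative only):

  #include <cassert>

  struct BigVec { int Lane[64]; };  // stand-in for a scalable vector value

  // Callee side of CCValAssign::Indirect: the incoming "argument" is really a
  // pointer, and the value is loaded through it.
  static int calleeByPtr(const BigVec *Arg) { return Arg->Lane[0]; }

  // Caller side: store the value to a stack slot and pass the slot's address.
  static int caller(const BigVec &V) {
    BigVec Slot = V;                // the CreateStackTemporary + store above
    return calleeByPtr(&Slot);      // the pointer becomes the outgoing argument
  }

  int main() {
    BigVec V{};
    V.Lane[0] = 42;
    assert(caller(V) == 42);
    return 0;
  }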
const uint32_t *Mask; const AArch64RegisterInfo *TRI = Subtarget->getRegisterInfo(); @@ -5204,45 +5779,30 @@ } bool AArch64TargetLowering::isFPImmLegal(const APFloat &Imm, EVT VT) const { - // We can materialize #0.0 as fmov $Rd, XZR for 64-bit and 32-bit cases. - // FIXME: We should be able to handle f128 as well with a clever lowering. - if (Imm.isPosZero() && (VT == MVT::f64 || VT == MVT::f32 || - (VT == MVT::f16 && Subtarget->hasFullFP16()))) { - LLVM_DEBUG( - dbgs() << "Legal fp imm: materialize 0 using the zero register\n"); - return true; - } - - StringRef FPType; bool IsLegal = false; - SmallString<128> ImmStrVal; - Imm.toString(ImmStrVal); + // We can materialize #0.0 as fmov $Rd, XZR for 64-bit, 32-bit cases, and + // 16-bit case when target has full fp16 support. + // FIXME: We should be able to handle f128 as well with a clever lowering. + const APInt ImmInt = Imm.bitcastToAPInt(); + if (VT == MVT::f64) + IsLegal = AArch64_AM::getFP64Imm(ImmInt) != -1 || Imm.isPosZero(); + else if (VT == MVT::f32) + IsLegal = AArch64_AM::getFP32Imm(ImmInt) != -1 || Imm.isPosZero(); + else if (VT == MVT::f16 && Subtarget->hasFullFP16()) + IsLegal = AArch64_AM::getFP16Imm(ImmInt) != -1 || Imm.isPosZero(); + // TODO: fmov h0, w0 is also legal, however on't have an isel pattern to + // generate that fmov. - if (VT == MVT::f64) { - FPType = "f64"; - IsLegal = AArch64_AM::getFP64Imm(Imm) != -1; - } else if (VT == MVT::f32) { - FPType = "f32"; - IsLegal = AArch64_AM::getFP32Imm(Imm) != -1; - } else if (VT == MVT::f16 && Subtarget->hasFullFP16()) { - FPType = "f16"; - IsLegal = AArch64_AM::getFP16Imm(Imm) != -1; - } + // If we can not materialize in immediate field for fmov, check if the + // value can be encoded as the immediate operand of a logical instruction. + // The immediate value will be created with either MOVZ, MOVN, or ORR. + if (!IsLegal && (VT == MVT::f64 || VT == MVT::f32)) + IsLegal = AArch64_AM::isAnyMOVWMovAlias(ImmInt.getZExtValue(), + VT.getSizeInBits()); - if (IsLegal) { - LLVM_DEBUG(dbgs() << "Legal " << FPType << " imm value: " << ImmStrVal - << "\n"); - return true; - } - - if (!FPType.empty()) - LLVM_DEBUG(dbgs() << "Illegal " << FPType << " imm value: " << ImmStrVal - << "\n"); - else - LLVM_DEBUG(dbgs() << "Illegal fp imm " << ImmStrVal - << ": unsupported fp type\n"); - - return false; + LLVM_DEBUG(dbgs() << (IsLegal ? "Legal " : "Illegal ") << VT.getEVTString() + << " imm value: "; Imm.dump();); + return IsLegal; } //===----------------------------------------------------------------------===// @@ -5253,10 +5813,11 @@ SDValue Operand, SelectionDAG &DAG, int &ExtraSteps) { EVT VT = Operand.getValueType(); - if (ST->hasNEON() && - (VT == MVT::f64 || VT == MVT::v1f64 || VT == MVT::v2f64 || - VT == MVT::f32 || VT == MVT::v1f32 || - VT == MVT::v2f32 || VT == MVT::v4f32)) { + if ((ST->hasNEON() && + (VT == MVT::f64 || VT == MVT::v1f64 || VT == MVT::v2f64 || + VT == MVT::f32 || VT == MVT::v1f32 || + VT == MVT::v2f32 || VT == MVT::v4f32)) || + (ST->hasSVE() && (VT == MVT::nxv4f32 || VT == MVT::nxv2f64))) { if (ExtraSteps == TargetLoweringBase::ReciprocalEstimate::Unspecified) // For the reciprocal estimates, convergence is quadratic, so the number // of digits is doubled after each iteration. 
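The reworked isFPImmLegal above asks whether the constant fits the FMOV 8-bit immediate (or is +0.0, or can be produced by a single MOVZ/MOVN/ORR alias). Assuming the usual description of that encoding, the representable values are +/-(n/16) * 2^r with n in [16, 31] and r in [-3, 4]; a brute-force sketch in that spirit (not the AArch64_AM helpers themselves):

  #include <cassert>
  #include <cmath>

  // True if D is exactly one of the values an FMOV (immediate) can materialize,
  // under the +/-(n/16) * 2^r, n in [16,31], r in [-3,4] description.
  static bool fitsFMovImm8(double D) {
    for (int N = 16; N <= 31; ++N)
      for (int R = -3; R <= 4; ++R) {
        double V = std::ldexp(N / 16.0, R);
        if (D == V || D == -V)
          return true;
      }
    return false;
  }

  int main() {
    assert(fitsFMovImm8(1.0));    // n=16, r=0
    assert(fitsFMovImm8(-0.5));   // n=16, r=-1
    assert(!fitsFMovImm8(0.1));   // not a dyadic rational, needs other lowering
    assert(!fitsFMovImm8(0.0));   // +0.0 is covered separately (fmov Dd, xzr)
    return 0;
  }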
In ARMv8, the accuracy of @@ -5316,7 +5877,9 @@ SDValue AArch64TargetLowering::getRecipEstimate(SDValue Operand, SelectionDAG &DAG, int Enabled, int &ExtraSteps) const { - if (Enabled == ReciprocalEstimate::Enabled) + if (Enabled == ReciprocalEstimate::Enabled || + (Enabled == ReciprocalEstimate::Unspecified && + Subtarget->useIterativeReciprocal())) if (SDValue Estimate = getEstimate(Subtarget, AArch64ISD::FRECPE, Operand, DAG, ExtraSteps)) { SDLoc DL(Operand); @@ -5401,6 +5964,7 @@ return C_Other; case 'x': case 'w': + case 'y': return C_RegisterClass; // An address with a single base register. Due to the way we // currently handle addresses it is the same as 'r'. @@ -5409,6 +5973,10 @@ case 'S': // A symbolic address return C_Other; } + } else if (Constraint.size() == 3 && Constraint[0] == 'U' + && Constraint[1] == 'p' + && (Constraint[2] == 'l' || Constraint[2] == 'a')) { + return C_RegisterClass; } return TargetLowering::getConstraintType(Constraint); } @@ -5433,12 +6001,17 @@ break; case 'x': case 'w': + case 'y': if (type->isFloatingPointTy() || type->isVectorTy()) weight = CW_Register; break; case 'z': weight = CW_Constant; break; + case 'U': + if (constraint[1] == 'p' && (constraint[2] == 'l' || constraint[2] == 'a')) + weight = CW_Register; + break; } return weight; } @@ -5449,10 +6022,14 @@ if (Constraint.size() == 1) { switch (Constraint[0]) { case 'r': + if (VT.isScalableVector()) + break; if (VT.getSizeInBits() == 64) return std::make_pair(0U, &AArch64::GPR64commonRegClass); return std::make_pair(0U, &AArch64::GPR32commonRegClass); case 'w': + if (VT.isScalableVector()) + return std::make_pair(0U, &AArch64::ZPRRegClass); if (VT.getSizeInBits() == 16) return std::make_pair(0U, &AArch64::FPR16RegClass); if (VT.getSizeInBits() == 32) @@ -5465,9 +6042,23 @@ // The instructions that this constraint is designed for can // only take 128-bit registers so just use that regclass. case 'x': + if (VT.isScalableVector()) + return std::make_pair(0U, &AArch64::ZPR_4bRegClass); if (VT.getSizeInBits() == 128) return std::make_pair(0U, &AArch64::FPR128_loRegClass); break; + case 'y': + if (VT.isScalableVector()) + return std::make_pair(0U, &AArch64::ZPR_3bRegClass); + break; + } + } else if (Constraint.size() == 3 && Constraint[0] == 'U' + && Constraint[1] == 'p' + && (Constraint[2] == 'l' || Constraint[2] == 'a')) { + if (VT.isScalableVector() && (VT.getVectorElementType() == MVT::i1)) { + bool restricted = (Constraint[2] == 'l'); + return restricted ? 
std::make_pair(0U, &AArch64::PPR_3bRegClass) + : std::make_pair(0U, &AArch64::PPRRegClass); } } if (StringRef("{cc}").equals_lower(Constraint)) @@ -5733,12 +6324,13 @@ if (V.isUndef()) continue; else if (V.getOpcode() != ISD::EXTRACT_VECTOR_ELT || - !isa(V.getOperand(1))) { + !isa(V.getOperand(1)) || + V.getOperand(0).getValueType().isScalableVector()) { LLVM_DEBUG( dbgs() << "Reshuffle failed: " "a shuffle can only come from building a vector from " "various elements of other vectors, provided their " - "indices are constant\n"); + "indices are constant, and vector is fixed width\n"); return SDValue(); } @@ -6168,7 +6760,7 @@ return SDValue(); EVT CastVT = EVT::getVectorVT(*DAG.getContext(), VT.getVectorElementType(), - VT.getVectorNumElements() / 2); + VT.getVectorElementCount() / 2); if (SplitV0) { V0 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, CastVT, V0, DAG.getConstant(0, DL, MVT::i64)); @@ -6785,20 +7377,6 @@ return true; } -static unsigned getIntrinsicID(const SDNode *N) { - unsigned Opcode = N->getOpcode(); - switch (Opcode) { - default: - return Intrinsic::not_intrinsic; - case ISD::INTRINSIC_WO_CHAIN: { - unsigned IID = cast(N->getOperand(0))->getZExtValue(); - if (IID < Intrinsic::num_intrinsics) - return IID; - return Intrinsic::not_intrinsic; - } - } -} - // Attempt to form a vector S[LR]I from (or (and X, BvecC1), (lsl Y, C2)), // to (SLI X, Y, C2), where X and Y have matching vector types, BvecC1 is a // BUILD_VECTORs with constant element C1, C2 is a constant, and C1 == ~C2. @@ -7289,13 +7867,53 @@ AArch64TargetLowering::LowerEXTRACT_VECTOR_ELT(SDValue Op, SelectionDAG &DAG) const { assert(Op.getOpcode() == ISD::EXTRACT_VECTOR_ELT && "Unknown opcode!"); + SDLoc DL(Op); - // Check for non-constant or out of range lane. + // Check for out of range lane. EVT VT = Op.getOperand(0).getValueType(); - ConstantSDNode *CI = dyn_cast(Op.getOperand(1)); - if (!CI || CI->getZExtValue() >= VT.getVectorNumElements()) + auto *CI = dyn_cast(Op.getOperand(1)); + + // Lower extract on an SVE vector to LASTA or LASTB + if (VT.isScalableVector()) { + if (CI) + return isTypeLegal(VT) ? Op : SDValue(); + + SDValue InVec = Op.getOperand(0); + SDValue InIdx = Op.getOperand(1); + + unsigned NumElts = VT.getVectorNumElements(); + unsigned EltBits = VT.getVectorElementType().getSizeInBits(); + if (NumElts * EltBits < AArch64::SVEBitsPerBlock) return SDValue(); + // Legal are handled directly + if (isTypeLegal(VT)) + return Op; + + // Otherwise, create a LastB + EVT IntVT = VT.changeVectorElementTypeToInteger(); + + // Splat the index + SDValue SplatIdx = DAG.getNode(ISD::SPLAT_VECTOR, DL, IntVT, + DAG.getZExtOrTrunc(InIdx, DL, MVT::i32)); + + // Create a 0..EC sequence vector + SDValue Zero = DAG.getConstant(0, DL, MVT::i32); + SDValue One = DAG.getConstant(1, DL, MVT::i32); + SDValue Seq = DAG.getNode(ISD::SERIES_VECTOR, DL, IntVT, Zero, One); + + // Compare idx and sequence + EVT BoolVT = EVT::getIntegerVT(*DAG.getContext(), 1); + EVT PredVT = Seq.getValueType().changeVectorElementType(BoolVT); + SDValue Pred = DAG.getNode(ISD::SETCC, DL, PredVT, Seq, SplatIdx, + DAG.getCondCode(ISD::SETEQ)); + + return DAG.getNode(AArch64ISD::LASTB, DL, Op.getValueType(), Pred, InVec); + } + + // Only SVE supports non constant extracts + if (!CI || (CI && CI->getZExtValue() >= VT.getVectorNumElements())) + return SDValue(); // Insertion/extraction are legal for V128 types. 
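For a variable lane index, the EXTRACT_VECTOR_ELT path above splats the index, builds a 0,1,2,... SERIES_VECTOR, compares the two to obtain a one-lane-true predicate, and extracts with LASTB (last active element). A scalar model of that sequence for an in-range index (a sketch of the idea, not the lowering itself):

  #include <cassert>
  #include <cstdint>
  #include <vector>

  static uint32_t extractViaLastB(const std::vector<uint32_t> &Vec, uint32_t Idx) {
    // Predicate = (seriesvector 0,1 == splat Idx): true exactly at lane Idx.
    std::vector<bool> Pred(Vec.size());
    for (uint32_t I = 0; I < Vec.size(); ++I)
      Pred[I] = (I == Idx);
    // LASTB: value of the last lane whose predicate bit is set.
    uint32_t Result = 0;
    for (uint32_t I = 0; I < Vec.size(); ++I)
      if (Pred[I])
        Result = Vec[I];
    return Result;
  }

  int main() {
    std::vector<uint32_t> V{10, 20, 30, 40};
    for (uint32_t I = 0; I < V.size(); ++I)
      assert(extractViaLastB(V, I) == V[I]);
    return 0;
  }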
if (VT == MVT::v16i8 || VT == MVT::v8i16 || VT == MVT::v4i32 || @@ -7307,9 +7925,11 @@ VT != MVT::v1i64 && VT != MVT::v2f32 && VT != MVT::v4f16) return SDValue(); + assert(VT.getVectorElementType() != MVT::i1 && + "Predicate extract should not be custom lowered"); + // For V64 types, we perform extraction by expanding the value // to a V128 type and perform the extraction on that. - SDLoc DL(Op); SDValue WideVec = WidenVector(Op.getOperand(0), DAG); EVT WideTy = WideVec.getValueType(); @@ -7324,6 +7944,10 @@ SDValue AArch64TargetLowering::LowerEXTRACT_SUBVECTOR(SDValue Op, SelectionDAG &DAG) const { + // If there is anything to do it will be handled by ReplaceNodeResult. + if (Op.getValueType().isScalableVector()) + return SDValue(); + EVT VT = Op.getOperand(0).getValueType(); SDLoc dl(Op); // Just in case... @@ -7337,6 +7961,34 @@ unsigned Size = Op.getValueSizeInBits(); + // Handle the shuffle of an unpacked scalable vector into a + // fixed-width vector by promoting the operands to a wider type. + const unsigned M = VT.getVectorNumElements(); + const unsigned SizeTy = Op.getValueType().getScalarSizeInBits(); + if (!Op.getValueType().isScalableVector() && VT.isScalableVector() && + (M * SizeTy < 128) && Val == 0) { + assert(Op.getValueType().isInteger() && + "EXTRACT_SUBREG from scalable vectors must \ + return vectors with integer elements"); + assert(Op.getValueType().getVectorNumElements() == M && + "Invalid vector size."); + SDValue Step = Op.getOperand(0); + assert(128 % M == 0 && "Invalid number of lanes."); + const unsigned SizeWideTy = 128 / M; + const EVT nxMxWideTy = EVT::getVectorVT( + *DAG.getContext(), MVT::getIntegerVT(SizeWideTy), M, true); + + Step = DAG.getNode(ISD::ANY_EXTEND, dl, nxMxWideTy, Step); + + const EVT MxWideTy = EVT::getVectorVT( + *DAG.getContext(), MVT::getIntegerVT(SizeWideTy), M, false); + Step = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, MxWideTy, Step, + DAG.getConstant(0, dl, MVT::i64)); + const EVT MxTy = EVT::getVectorVT(*DAG.getContext(), + MVT::getIntegerVT(SizeTy), M, false); + return DAG.getNode(ISD::TRUNCATE, dl, MxTy, Step); + } + // This will get lowered to an appropriate EXTRACT_SUBREG in ISel. if (Val == 0) return Op; @@ -7482,7 +8134,8 @@ AArch64CC::CondCode CC, bool NoNans, EVT VT, const SDLoc &dl, SelectionDAG &DAG) { EVT SrcVT = LHS.getValueType(); - assert(VT.getSizeInBits() == SrcVT.getSizeInBits() && + assert((VT.getSizeInBits() == SrcVT.getSizeInBits() || + VT.isScalableVector()) && "function only supposed to emit natural comparisons"); BuildVectorSDNode *BVN = dyn_cast(RHS.getNode()); @@ -7515,6 +8168,11 @@ if (IsZero) return DAG.getNode(AArch64ISD::FCMGTz, dl, VT, LHS); return DAG.getNode(AArch64ISD::FCMGT, dl, VT, LHS, RHS); + case AArch64CC::LE: + if (!NoNans) + return SDValue(); + // If we ignore NaNs we can use to the LS implementation. + // Fallthrough. 
case AArch64CC::LS: if (IsZero) return DAG.getNode(AArch64ISD::FCMLEz, dl, VT, LHS); @@ -7579,9 +8237,25 @@ SDValue LHS = Op.getOperand(0); SDValue RHS = Op.getOperand(1); EVT CmpVT = LHS.getValueType().changeVectorElementTypeToInteger(); + EVT VT = Op.getValueType(); SDLoc dl(Op); + bool IsInt = LHS.getValueType().getVectorElementType().isInteger(); - if (LHS.getValueType().getVectorElementType().isInteger()) { + // Override to i1s for SVE + if (CmpVT.isScalableVector()) { + if (VT.getScalarSizeInBits() != 1 && IsInt) { + EVT ExtVT = VT; + VT = MVT::getVectorVT(MVT::i1, VT.getVectorElementCount()); + SDValue NewCmp = + DAG.getNode(ISD::SETCC, dl, VT, LHS, RHS, Op.getOperand(2)); + return DAG.getZExtOrTrunc(NewCmp, dl, ExtVT); + } + if (IsInt) + return Op; + CmpVT = MVT::getVectorVT(MVT::i1, CmpVT.getVectorElementCount()); + } + + if (IsInt) { assert(LHS.getValueType() == RHS.getValueType()); AArch64CC::CondCode AArch64CC = changeIntCCToAArch64CC(CC); SDValue Cmp = @@ -7594,18 +8268,21 @@ // Make v4f16 (only) fcmp operations utilise vector instructions // v8f16 support will be a litle more complicated - if (LHS.getValueType().getVectorElementType() == MVT::f16) { - if (!FullFP16 && LHS.getValueType().getVectorNumElements() == 4) { - LHS = DAG.getNode(ISD::FP_EXTEND, dl, MVT::v4f32, LHS); - RHS = DAG.getNode(ISD::FP_EXTEND, dl, MVT::v4f32, RHS); - SDValue NewSetcc = DAG.getSetCC(dl, MVT::v4i16, LHS, RHS, CC); - DAG.ReplaceAllUsesWith(Op, NewSetcc); - CmpVT = MVT::v4i32; - } else - return SDValue(); + if (!LHS.getValueType().isScalableVector()) { + if (LHS.getValueType().getVectorElementType() == MVT::f16) { + if (!FullFP16 && LHS.getValueType().getVectorNumElements() == 4) { + LHS = DAG.getNode(ISD::FP_EXTEND, dl, MVT::v4f32, LHS); + RHS = DAG.getNode(ISD::FP_EXTEND, dl, MVT::v4f32, RHS); + SDValue NewSetcc = DAG.getSetCC(dl, MVT::v4i16, LHS, RHS, CC); + DAG.ReplaceAllUsesWith(Op, NewSetcc); + CmpVT = MVT::v4i32; + } else + return SDValue(); + } } - assert(LHS.getValueType().getVectorElementType() == MVT::f32 || + assert(LHS.getValueType().getVectorElementType() == MVT::f16 || + LHS.getValueType().getVectorElementType() == MVT::f32 || LHS.getValueType().getVectorElementType() == MVT::f64); // Unfortunately, the mapping of LLVM FP CC's onto AArch64 CC's isn't totally @@ -7647,10 +8324,16 @@ SDValue AArch64TargetLowering::LowerVECREDUCE(SDValue Op, SelectionDAG &DAG) const { + + if (Op.getOperand(0).getValueType().isScalableVector()) + return LowerVECREDUCE_SVE(Op, DAG); + SDLoc dl(Op); switch (Op.getOpcode()) { case ISD::VECREDUCE_ADD: return getReductionSDNode(AArch64ISD::UADDV, dl, Op, DAG); + case ISD::VECREDUCE_FADD: + llvm_unreachable("Custom lowering of fadd reduction expects SVE"); case ISD::VECREDUCE_SMAX: return getReductionSDNode(AArch64ISD::SMAXV, dl, Op, DAG); case ISD::VECREDUCE_SMIN: @@ -8110,6 +8793,10 @@ assert(Shuffles.size() == Indices.size() && "Unmatched number of shufflevectors and indices"); + // Only gather/scatter supports the complete range of factors. + if (Factor > 4) + return false; + const DataLayout &DL = LI->getModule()->getDataLayout(); VectorType *VecTy = Shuffles[0]->getType(); @@ -8238,6 +8925,10 @@ assert(VecTy->getVectorNumElements() % Factor == 0 && "Invalid interleaved store"); + // Only gather/scatter supports the complete range of factors. 
+ if (Factor > 4) + return false; + unsigned LaneLen = VecTy->getVectorNumElements() / Factor; Type *EltTy = VecTy->getVectorElementType(); VectorType *SubVecTy = VectorType::get(EltTy, LaneLen); @@ -8260,7 +8951,8 @@ // vectors to integer vectors. if (EltTy->isPointerTy()) { Type *IntTy = DL.getIntPtrType(EltTy); - unsigned NumOpElts = Op0->getType()->getVectorNumElements(); + unsigned NumOpElts = + dyn_cast(Op0->getType())->getVectorNumElements(); // Convert to the corresponding integer vector. Type *IntVecTy = VectorType::get(IntTy, NumOpElts); @@ -8287,7 +8979,8 @@ SI->getPointerAddressSpace())); } - auto Mask = SVI->getShuffleMask(); + SmallVector Mask; + SVI->getShuffleMask(Mask); Type *PtrTy = SubVecTy->getPointerTo(SI->getPointerAddressSpace()); Type *Tys[2] = {SubVecTy, PtrTy}; @@ -8474,6 +9167,8 @@ return false; switch (VT.getSimpleVT().SimpleTy) { + case MVT::f16: + return Subtarget->hasFullFP16(); case MVT::f32: case MVT::f64: return true; @@ -8652,6 +9347,17 @@ return DAG.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT), SRA); } +static bool IsSVECntIntrinsic(SDValue S) { + switch(getIntrinsicID(S.getNode())) { + case Intrinsic::aarch64_sve_cntb: + case Intrinsic::aarch64_sve_cnth: + case Intrinsic::aarch64_sve_cntw: + case Intrinsic::aarch64_sve_cntd: + return true; + } + return false; +} + static SDValue performMulCombine(SDNode *N, SelectionDAG &DAG, TargetLowering::DAGCombinerInfo &DCI, const AArch64Subtarget *Subtarget) { @@ -8662,9 +9368,17 @@ if (!isa(N->getOperand(1))) return SDValue(); + SDValue N0 = N->getOperand(0); ConstantSDNode *C = cast(N->getOperand(1)); const APInt &ConstValue = C->getAPIntValue(); + // There's nothing to do if the immediate can be packed into an instruction. + if (IsSVECntIntrinsic(N0) || + (N0->getOpcode() == ISD::TRUNCATE && + (IsSVECntIntrinsic(N0->getOperand(0))))) + if (ConstValue.sge(1) && ConstValue.sle(16)) + return SDValue(); + // Multiplication of a power of two plus/minus one can be done more // cheaply as as shift+add/sub. For now, this is true unilaterally. If // future CPUs have a cheaper MADD instruction, this may need to be @@ -8675,7 +9389,7 @@ // e.g. 6=3*2=(2+1)*2. // TODO: consider lowering more cases, e.g. C = 14, -6, -14 or even 45 // which equals to (1+2)*16-(1+2). - SDValue N0 = N->getOperand(0); + // TrailingZeroes is used to test if the mul can be lowered to // shift+add+shift. unsigned TrailingZeroes = ConstValue.countTrailingZeros(); @@ -9104,6 +9818,9 @@ if (SDValue Res = tryCombineToBSL(N, DCI)) return Res; + if (SDValue Res = tryCombineVecOrNot(N, DCI)) + return Res; + return SDValue(); } @@ -9182,6 +9899,7 @@ // type, we know this is an extract of the high or low half of the vector. 
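The early-out in performMulCombine above leaves multiplies of the SVE CNTB/CNTH/CNTW/CNTD intrinsics by constants in [1, 16] alone, because those instructions take a MUL #imm field directly. Assuming the ACLE spellings from <arm_sve.h>, a source pattern that should therefore stay a single instruction rather than being split into shift+add:

  // Sketch only: with the combine above, the multiply by 4 is expected to fold
  // into something like "cntd x0, all, mul #4" instead of cntd plus a shift.
  #include <arm_sve.h>
  #include <cstdint>

  uint64_t doublewordsInFourVectors() {
    return 4 * svcntd();  // constant in [1,16] fits the instruction's MUL field
  }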
EVT SVT = Source->getValueType(0); if (!SVT.isVector() || + SVT.getVectorElementType() != VT.getVectorElementType() || SVT.getVectorNumElements() != VT.getVectorNumElements() * 2) return SDValue(); @@ -9535,6 +10253,9 @@ if (DCI.isBeforeLegalizeOps()) return SDValue(); + if (SDValue NewAdd = performSVEIndexedAddressingCombine(N, DCI, DAG)) + return NewAdd; + MVT VT = N->getSimpleValueType(0); if (!VT.is128BitVector()) { if (N->getOpcode() == ISD::ADD) @@ -9745,6 +10466,70 @@ case Intrinsic::aarch64_crc32h: case Intrinsic::aarch64_crc32ch: return tryCombineCRC32(0xffff, N, DAG); + + case Intrinsic::aarch64_sve_andv: + return LowerSVEIntReduction(N, AArch64ISD::ANDV_PRED, DAG); + case Intrinsic::aarch64_sve_dup: + return LowerSVEIntrinsicDUP(N, DAG); + case Intrinsic::aarch64_sve_eorv: + return LowerSVEIntReduction(N, AArch64ISD::EORV_PRED, DAG); + case Intrinsic::aarch64_sve_ext: + return LowerSVEIntrinsicEXT(N, DAG); + case Intrinsic::aarch64_sve_orv: + return LowerSVEIntReduction(N, AArch64ISD::ORV_PRED, DAG); + case Intrinsic::aarch64_sve_saddv: + // There is no i64 version of SADDV because the sign is irrelevant. + if (N->getOperand(2)->getValueType(0).getVectorElementType() == MVT::i64) + return LowerSVEIntReduction(N, AArch64ISD::UADDV_PRED, DAG); + else + return LowerSVEIntReduction(N, AArch64ISD::SADDV_PRED, DAG); + case Intrinsic::aarch64_sve_smaxv: + return LowerSVEIntReduction(N, AArch64ISD::SMAXV_PRED, DAG); + case Intrinsic::aarch64_sve_sminv: + return LowerSVEIntReduction(N, AArch64ISD::SMINV_PRED, DAG); + case Intrinsic::aarch64_sve_uaddv: + return LowerSVEIntReduction(N, AArch64ISD::UADDV_PRED, DAG); + case Intrinsic::aarch64_sve_umaxv: + return LowerSVEIntReduction(N, AArch64ISD::UMAXV_PRED, DAG); + case Intrinsic::aarch64_sve_uminv: + return LowerSVEIntReduction(N, AArch64ISD::UMINV_PRED, DAG); + case Intrinsic::aarch64_sve_cmpeq_wide: + return tryConvertSVEWideCompare(N, Intrinsic::aarch64_sve_cmpeq, + false, DAG); + case Intrinsic::aarch64_sve_cmpne_wide: + return tryConvertSVEWideCompare(N, Intrinsic::aarch64_sve_cmpne, + false, DAG); + case Intrinsic::aarch64_sve_cmpge_wide: + return tryConvertSVEWideCompare(N, Intrinsic::aarch64_sve_cmpge, + false, DAG); + case Intrinsic::aarch64_sve_cmpgt_wide: + return tryConvertSVEWideCompare(N, Intrinsic::aarch64_sve_cmpgt, + false, DAG); + case Intrinsic::aarch64_sve_cmplt_wide: + return tryConvertSVEWideCompare(N, Intrinsic::aarch64_sve_cmpgt, + true, DAG); + case Intrinsic::aarch64_sve_cmple_wide: + return tryConvertSVEWideCompare(N, Intrinsic::aarch64_sve_cmpge, + true, DAG); + case Intrinsic::aarch64_sve_cmphs_wide: + return tryConvertSVEWideCompare(N, Intrinsic::aarch64_sve_cmphs, + false, DAG); + case Intrinsic::aarch64_sve_cmphi_wide: + return tryConvertSVEWideCompare(N, Intrinsic::aarch64_sve_cmphi, + false, DAG); + case Intrinsic::aarch64_sve_cmplo_wide: + return tryConvertSVEWideCompare(N, Intrinsic::aarch64_sve_cmphi, true, DAG); + case Intrinsic::aarch64_sve_cmpls_wide: + return tryConvertSVEWideCompare(N, Intrinsic::aarch64_sve_cmphs, true, DAG); + case Intrinsic::aarch64_sve_ptest_any: + return getPTest(DAG, N->getValueType(0), N->getOperand(1), N->getOperand(2), + AArch64CC::ANY_ACTIVE); + case Intrinsic::aarch64_sve_ptest_first: + return getPTest(DAG, N->getValueType(0), N->getOperand(1), N->getOperand(2), + AArch64CC::FIRST_ACTIVE); + case Intrinsic::aarch64_sve_ptest_last: + return getPTest(DAG, N->getValueType(0), N->getOperand(1), N->getOperand(2), + AArch64CC::LAST_ACTIVE); } return SDValue(); } @@ -9818,9 
+10603,7 @@ if (SrcVT.getSizeInBits() != 64) return SDValue(); - unsigned SrcEltSize = SrcVT.getScalarSizeInBits(); - unsigned ElementCount = SrcVT.getVectorNumElements(); - SrcVT = MVT::getVectorVT(MVT::getIntegerVT(SrcEltSize * 2), ElementCount); + SrcVT = SrcVT.widenIntegerVectorElementType(*DAG.getContext()); SDLoc DL(N); Src = DAG.getNode(N->getOpcode(), DL, SrcVT, Src); @@ -9828,13 +10611,13 @@ // bit source. EVT LoVT, HiVT; SDValue Lo, Hi; - unsigned NumElements = ResVT.getVectorNumElements(); - assert(!(NumElements & 1) && "Splitting vector, but not in half!"); + auto ResEltCnt = ResVT.getVectorElementCount(); + assert(!(ResEltCnt.Min & 1) && "Splitting vector, but not in half!"); LoVT = HiVT = EVT::getVectorVT(*DAG.getContext(), - ResVT.getVectorElementType(), NumElements / 2); + ResVT.getVectorElementType(), ResEltCnt/2); EVT InNVT = EVT::getVectorVT(*DAG.getContext(), SrcVT.getVectorElementType(), - LoVT.getVectorNumElements()); + LoVT.getVectorElementCount()); Lo = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, InNVT, Src, DAG.getConstant(0, DL, MVT::i64)); Hi = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, InNVT, Src, @@ -10088,6 +10871,10 @@ SelectionDAG &DAG = DCI.DAG; EVT VT = N->getValueType(0); + // If scalable, skip this + if (VT.isScalableVector()) + return SDValue(); + unsigned LoadIdx = IsLaneOp ? 1 : 0; SDNode *LD = N->getOperand(LoadIdx).getNode(); // If it is not LOAD, can not do such combine. @@ -10660,6 +11447,36 @@ return SDValue(); } +static SDValue performSETCCCombine(SDNode *N, SelectionDAG &DAG) { + assert(N->getOpcode() == ISD::SETCC && "Unexpected opcode!"); + SDValue LHS = N->getOperand(0); + SDValue RHS = N->getOperand(1); + ISD::CondCode Cond = cast(N->getOperand(2))->get(); + + // setcc (csel 0, 1, cond, X), 1, ne ==> csel 0, 1, !cond, X + if (Cond == ISD::SETNE && isOneConstant(RHS) && + LHS->getOpcode() == AArch64ISD::CSEL && + isNullConstant(LHS->getOperand(0)) && + isOneConstant(LHS->getOperand(1)) && + LHS->hasOneUse()) { + SDLoc DL(N); + + // Invert CSEL's condition. + auto OpCC = cast(LHS.getOperand(2)); + auto OldCond = static_cast(OpCC->getZExtValue()); + auto NewCond = getInvertedCondCode(OldCond); + + // csel 0, 1, !cond, X + SDValue CSEL = DAG.getNode(AArch64ISD::CSEL, DL, LHS.getValueType(), + LHS.getOperand(0), LHS.getOperand(1), + DAG.getConstant(NewCond, DL, MVT::i32), + LHS.getOperand(3)); + return DAG.getZExtOrTrunc(CSEL, DL, N->getValueType(0)); + } + + return SDValue(); +} + // Optimize some simple tbz/tbnz cases. Returns the new operand and bit to test // as well as whether the test should be inverted. This code is required to // catch these cases (as opposed to standard dag combines) because @@ -10763,6 +11580,10 @@ // condition. If it can legalize "VSELECT v1i1" correctly, no need to combine // such VSELECT. static SDValue performVSelectCombine(SDNode *N, SelectionDAG &DAG) { + SDValue MergeRdx; + if ((MergeRdx = performVSelectMinMaxRdxCombine(N, DAG))) + return MergeRdx; + SDValue N0 = N->getOperand(0); EVT CCVT = N0.getValueType(); @@ -10800,6 +11621,9 @@ if (N0.getOpcode() != ISD::SETCC) return SDValue(); + if (ResVT.isScalableVector()) + return SDValue(); + // Make sure the SETCC result is either i1 (initial DAG), or i32, the lowered // scalar SetCCResultType. We also don't expect vectors, because we assume // that selects fed by vector SETCCs are canonicalized to VSELECT. @@ -10851,7 +11675,7 @@ return DAG.getSelect(DL, ResVT, Mask, N->getOperand(1), N->getOperand(2)); } -/// Get rid of unnecessary NVCASTs (that don't change the type). 
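performSETCCCombine above rewrites setcc (csel 0, 1, cond, X), 1, ne into csel 0, 1, !cond, X: selecting 0/1 on a condition and then testing for "not equal to 1" is the same as selecting on the inverted condition. A two-case check of that equivalence, with csel modelled as a plain ternary:

  #include <cassert>

  // AArch64ISD::CSEL(A, B, Cond): A if Cond holds, otherwise B.
  static int csel(int A, int B, bool Cond) { return Cond ? A : B; }

  int main() {
    for (bool Cond : {false, true}) {
      bool Before = csel(0, 1, Cond) != 1;  // setcc(csel(0,1,cond), 1, ne)
      bool After = csel(0, 1, !Cond) != 0;  // csel(0,1,!cond) used as an i1
      assert(Before == After);
    }
    return 0;
  }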
+/// Get rid of unnecessary NVCAST/REINTERPRET_CAST (that don't change the type). static SDValue performNVCASTCombine(SDNode *N) { if (N->getValueType(0) == N->getOperand(0).getValueType()) return N->getOperand(0); @@ -10922,6 +11746,8 @@ case ISD::ADD: case ISD::SUB: return performAddSubLongCombine(N, DCI, DAG); + case ISD::AND: + return performAndCombine(N, DCI, Subtarget); case ISD::XOR: return performXorCombine(N, DAG, DCI, Subtarget); case ISD::MUL: @@ -10944,6 +11770,10 @@ case ISD::ZERO_EXTEND: case ISD::SIGN_EXTEND: return performExtendCombine(N, DCI, DAG); + case ISD::SIGN_EXTEND_INREG: + return performExtendInRegCombine(N, DCI, DAG); + case ISD::TRUNCATE: + return performTRUNCATECombine(N, DCI, DAG); case ISD::BITCAST: return performBitcastCombine(N, DCI, DAG); case ISD::CONCAT_VECTORS: @@ -10965,12 +11795,21 @@ return performTBZCombine(N, DCI, DAG); case AArch64ISD::CSEL: return performCONDCombine(N, DCI, DAG, 2, 3); + case ISD::SETCC: + return performSETCCCombine(N, DAG); case AArch64ISD::DUP: return performPostLD1Combine(N, DCI, false); case AArch64ISD::NVCAST: + case AArch64ISD::REINTERPRET_CAST: return performNVCASTCombine(N); case ISD::INSERT_VECTOR_ELT: return performPostLD1Combine(N, DCI, true); + case ISD::EXTRACT_VECTOR_ELT: { + SDValue Res = performFirstTrueTestVectorCombine(N, DCI, Subtarget); + if (Res == SDValue()) + Res = performLastTrueTestVectorCombine(N, DCI, Subtarget); + return Res; + } case ISD::INTRINSIC_VOID: case ISD::INTRINSIC_W_CHAIN: switch (cast(N->getOperand(1))->getZExtValue()) { @@ -10996,9 +11835,96 @@ case Intrinsic::aarch64_neon_st3lane: case Intrinsic::aarch64_neon_st4lane: return performNEONPostLDSTCombine(N, DCI, DAG); + case Intrinsic::aarch64_sve_ld1rq: + return performLD1RQCombine(N, DAG); + case Intrinsic::aarch64_sve_ldnt1: + return performLDNT1Combine(N, DAG); + case Intrinsic::aarch64_sve_ldnt1_gather: + return performLDNTGatherCombine(N, DAG); + case Intrinsic::aarch64_sve_stnt1: + return performSTNT1Combine(N, DAG); + case Intrinsic::aarch64_sve_stnt1_scatter: + return performSTNT1ScatterCombine(N, DAG); + case Intrinsic::aarch64_sve_rdffr: { + SDVTList Tys = DAG.getVTList(N->getValueType(0), MVT::Other); + return DAG.getNode(AArch64ISD::RDFFR, SDLoc(N), Tys, N->getOperand(0)); + } + case Intrinsic::aarch64_sve_rdffr_z: { + SDVTList Tys = DAG.getVTList(N->getValueType(0), MVT::Other); + return DAG.getNode(AArch64ISD::RDFFR_PRED, SDLoc(N), Tys, + N->getOperand(0), N->getOperand(2)); + } + case Intrinsic::aarch64_sve_setffr: { + SDVTList Tys = DAG.getVTList(MVT::Other, MVT::Glue); + return DAG.getNode(AArch64ISD::SETFFR, SDLoc(N), Tys, N->getOperand(0)); + } + case Intrinsic::aarch64_sve_wrffr: { + SDVTList Tys = DAG.getVTList(MVT::Other, MVT::Glue); + return DAG.getNode(AArch64ISD::WRFFR, SDLoc(N), Tys, N->getOperand(0), + N->getOperand(2)); + } + case Intrinsic::aarch64_sve_ldff1: + return performLDFF1Combine(N, DAG, false /* first-faulting */); + case Intrinsic::aarch64_sve_ldnf1: + return performLDFF1Combine(N, DAG, true /* non-faulting */); + case Intrinsic::aarch64_sve_ldff1_gather: + return performLDFF1GatherCombine(N, DAG); + case Intrinsic::aarch64_sve_prfb_gather: + return performPrefetchGatherCombine(N, DAG, MVT::nxv16i8); + case Intrinsic::aarch64_sve_prfh_gather: + return performPrefetchGatherCombine(N, DAG, MVT::nxv8i16); + case Intrinsic::aarch64_sve_prfw_gather: + return performPrefetchGatherCombine(N, DAG, MVT::nxv4i32); + case Intrinsic::aarch64_sve_prfd_gather: + return performPrefetchGatherCombine(N, DAG, 
MVT::nxv2i64); default: break; } + break; + case ISD::MGATHER: + return performMGATHERCombine(cast(N), DCI); + case ISD::MSCATTER: + return performMSCATTERCombine(cast(N), DCI); + case ISD::VSCALE: + return performVScaleCombine(N, DCI, DAG); + case ISD::EXTRACT_SUBVECTOR: { + EVT InVT = N->getOperand(0).getValueType(); + if (InVT.isScalableVector() && InVT.isFloatingPoint() && + DCI.isBeforeLegalize()) { + SDLoc DL(N); + // Bitcast the input + SDValue VecOp = N->getOperand(0); + VecOp = DAG.getNode(ISD::BITCAST, DL, InVT.changeTypeToInteger(), + VecOp); + // Perform extract in integer type + EVT OutVT = N->getValueType(0); + SDValue Extract = + DAG.getNode(N->getOpcode(), DL, OutVT.changeTypeToInteger(), VecOp, + N->getOperand(1)); + // Bitcast back to fp type + return DAG.getNode(ISD::BITCAST, DL, OutVT, Extract); + } + return SDValue(); + } + case ISD::VECTOR_SHUFFLE_VAR: { + auto VT = N->getValueType(0); + if (VT.isScalableVector() && VT.isFloatingPoint() && + DCI.isBeforeLegalize()) { + SDLoc DL(N); + SDValue Op1 = N->getOperand(0); + SDValue Op2 = N->getOperand(1); + SDValue Sel = N->getOperand(2); + EVT OutVT = VT.changeTypeToInteger(); + EVT InVT = Op1.getValueType().changeTypeToInteger(); + + Op1 = DAG.getNode(ISD::BITCAST, DL, InVT, Op1); + Op2 = DAG.getNode(ISD::BITCAST, DL, InVT, Op2); + SDValue Shuffle = DAG.getNode(ISD::VECTOR_SHUFFLE_VAR, DL, + OutVT, Op1, Op2, Sel); + return DAG.getNode(ISD::BITCAST, DL, VT, Shuffle); + } + return SDValue(); + } case ISD::GlobalAddress: return performGlobalAddressCombine(N, DAG, Subtarget, getTargetMachine()); } @@ -11121,11 +12047,17 @@ return true; } -static void ReplaceBITCASTResults(SDNode *N, SmallVectorImpl &Results, - SelectionDAG &DAG) { +void AArch64TargetLowering::ReplaceBITCASTResults(SDNode *N, + SmallVectorImpl &Results, + SelectionDAG &DAG) const { SDLoc DL(N); SDValue Op = N->getOperand(0); + if (N->getValueType(0).isScalableVector()) { + ReplaceVectorBITCASTResults(N, Results, DAG); + return; + } + if (N->getValueType(0) != MVT::i16 || Op.getValueType() != MVT::f16) return; @@ -11257,17 +12189,15 @@ switch (N->getOpcode()) { default: llvm_unreachable("Don't know how to custom expand this"); + case ISD::FP_EXTEND: + ReplaceFP_EXTENDResults(N, Results, DAG); + return; case ISD::BITCAST: ReplaceBITCASTResults(N, Results, DAG); return; - case ISD::VECREDUCE_ADD: - case ISD::VECREDUCE_SMAX: - case ISD::VECREDUCE_SMIN: - case ISD::VECREDUCE_UMAX: - case ISD::VECREDUCE_UMIN: - Results.push_back(LowerVECREDUCE(SDValue(N, 0), DAG)); + case ISD::VECTOR_SHUFFLE_VAR: + ReplaceVectorShuffleVarResults(N, Results, DAG); return; - case AArch64ISD::SADDV: ReplaceReductionResults(N, Results, DAG, ISD::ADD, AArch64ISD::SADDV); return; @@ -11291,9 +12221,85 @@ assert(N->getValueType(0) == MVT::i128 && "unexpected illegal conversion"); // Let normal code take care of it by not adding anything to Results. return; + case ISD::SIGN_EXTEND: + ReplaceExtensionResults(N, Results, DAG, + AArch64ISD::SUNPKHI, AArch64ISD::SUNPKLO); + return; + case ISD::ZERO_EXTEND: + case ISD::ANY_EXTEND: + ReplaceExtensionResults(N, Results, DAG, + AArch64ISD::UUNPKHI, AArch64ISD::UUNPKLO); + return; + case ISD::EXTRACT_SUBVECTOR: + ReplaceExtractSubVectorResults(N, Results, DAG); + return; + case ISD::INSERT_SUBVECTOR: + ReplaceInsertSubVectorResults(N, Results, DAG); + return; + case ISD::INSERT_VECTOR_ELT: + ReplaceInsertVectorElementResults(N, Results, DAG); + return; + case ISD::CONCAT_VECTORS: + // The real work is done by LowerCONCAT_VECTORS. 
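The EXTRACT_SUBVECTOR and VECTOR_SHUFFLE_VAR combines above handle scalable floating-point vectors by bitcasting to the same-width integer type, doing the lane movement there, and bitcasting back. That is safe because these operations only move lanes, and a bit-for-bit round trip reproduces every value exactly (signed zeros and NaN payloads included). A scalar sketch of the round trip:

  #include <cassert>
  #include <cstdint>
  #include <cstring>

  // Push a double through its integer bit pattern, as the combines above do at
  // the vector level, and check that nothing is lost.
  static double roundTrip(double D) {
    uint64_t Bits;
    std::memcpy(&Bits, &D, sizeof(Bits));   // "bitcast" fp -> int
    double Out;
    std::memcpy(&Out, &Bits, sizeof(Out));  // "bitcast" int -> fp
    return Out;
  }

  int main() {
    for (double D : {0.0, -0.0, 1.5, -3.25e300}) {
      double R = roundTrip(D);
      uint64_t A, B;
      std::memcpy(&A, &D, sizeof(A));
      std::memcpy(&B, &R, sizeof(B));
      assert(A == B);                       // identical bit patterns
    }
    return 0;
  }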
+ return; + case ISD::INTRINSIC_WO_CHAIN: { + ConstantSDNode *CN = cast(N->getOperand(0)); + Intrinsic::ID IntID = static_cast(CN->getZExtValue()); + switch (IntID) { + default: + return; + + case Intrinsic::aarch64_sve_clasta_n: { + SDLoc DL(N); + auto Op2 = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i32, N->getOperand(2)); + auto V = DAG.getNode(AArch64ISD::CLASTA_N, DL, MVT::i32, + N->getOperand(1), Op2, N->getOperand(3)); + Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, N->getValueType(0), V)); + return; + } + case Intrinsic::aarch64_sve_clastb_n: { + SDLoc DL(N); + auto Op2 = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i32, N->getOperand(2)); + auto V = DAG.getNode(AArch64ISD::CLASTB_N, DL, MVT::i32, + N->getOperand(1), Op2, N->getOperand(3)); + Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, N->getValueType(0), V)); + return; + } + + case Intrinsic::aarch64_sve_lasta: { + SDLoc DL(N); + auto V = DAG.getNode(AArch64ISD::LASTA, DL, MVT::i32, + N->getOperand(1), N->getOperand(2)); + Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, N->getValueType(0), V)); + return; + } + case Intrinsic::aarch64_sve_lastb: { + SDLoc DL(N); + auto V = DAG.getNode(AArch64ISD::LASTB, DL, MVT::i32, + N->getOperand(1), N->getOperand(2)); + Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, N->getValueType(0), V)); + return; + } + } + } + case ISD::INTRINSIC_W_CHAIN: { + ConstantSDNode *CN = cast(N->getOperand(1)); + Intrinsic::ID IntID = static_cast(CN->getZExtValue()); + switch (IntID) { + default: + return; + + case Intrinsic::masked_spec_load: + ReplaceMaskedSpecLoadResults(N, Results, DAG); + return; + } + } case ISD::ATOMIC_CMP_SWAP: ReplaceCMP_SWAP_128Results(N, Results, DAG, Subtarget); return; + case AArch64ISD::DUP_PRED: + ReplaceMergeVecCpyResults(N, Results, DAG); + return; } } @@ -11318,6 +12324,9 @@ || SVT == MVT::v1f32) return TypeWidenVector; + if (VT.isScalableVector() && (VT.getVectorNumElements() == 1)) + return TypeWidenVector; + return TargetLoweringBase::getPreferredVectorAction(VT); } @@ -11493,7 +12502,7 @@ ConstantInt* Mask = dyn_cast(AndI.getOperand(1)); if (!Mask) return false; - return Mask->getValue().isPowerOf2(); + return Mask->getUniqueInteger().isPowerOf2(); } void AArch64TargetLowering::initializeSplitCSR(MachineBasicBlock *Entry) const { Index: lib/Target/AArch64/AArch64InstrFormats.td =================================================================== --- lib/Target/AArch64/AArch64InstrFormats.td +++ lib/Target/AArch64/AArch64InstrFormats.td @@ -21,6 +21,28 @@ def PseudoFrm : Format<0>; def NormalFrm : Format<1>; // Do we need any others? +// Enum describing whether an instruction is +// destructive in its first source operand. +class DestructiveInstTypeEnum val> { + bits<4> Value = val; +} +def NotDestructive : DestructiveInstTypeEnum<0>; +def DestructiveOther : DestructiveInstTypeEnum<1>; +def DestructiveUnary : DestructiveInstTypeEnum<2>; +def DestructiveBinaryImm : DestructiveInstTypeEnum<3>; +def DestructiveBinaryShImmUnpred : DestructiveInstTypeEnum<4>; +def DestructiveBinary : DestructiveInstTypeEnum<5>; +def DestructiveBinaryComm : DestructiveInstTypeEnum<6>; +def DestructiveBinaryCommWithRev : DestructiveInstTypeEnum<7>; +def DestructiveTernaryCommWithRev : DestructiveInstTypeEnum<8>; + +class FalseLanesEnum val> { + bits<2> Value = val; +} +def FalseLanesNone : FalseLanesEnum<0>; +def FalseLanesZero : FalseLanesEnum<1>; +def FalseLanesUndef : FalseLanesEnum<2>; + // AArch64 Instruction Format class AArch64Inst : Instruction { field bits<32> Inst; // Instruction encoding. 
@@ -35,6 +57,20 @@ let Namespace = "AArch64"; Format F = f; bits<2> Form = F.Value; + + // Defaults + bit isWhile = 0; + bit isPTestLike = 0; + FalseLanesEnum FalseLanes = FalseLanesNone; + DestructiveInstTypeEnum DestructiveInstType = NotDestructive; + ElementSizeEnum ElementSize = ElementSizeNone; + + let TSFlags{10} = isPTestLike; + let TSFlags{9} = isWhile; + let TSFlags{8-7} = FalseLanes.Value; + let TSFlags{6-3} = DestructiveInstType.Value; + let TSFlags{2-0} = ElementSize.Value; + let Pattern = []; let Constraints = cstr; } @@ -49,6 +85,7 @@ dag InOperandList = iops; let Pattern = pattern; let isCodeGenOnly = 1; + let isPseudo = 1; } // Real instructions (have encoding information) @@ -57,14 +94,6 @@ let Size = 4; } -// Enum describing whether an instruction is -// destructive in its first source operand. -class DestructiveInstTypeEnum val> { - bits<1> Value = val; -} -def NotDestructive : DestructiveInstTypeEnum<0>; -def Destructive : DestructiveInstTypeEnum<1>; - // Normal instructions class I pattern> @@ -72,13 +101,6 @@ dag OutOperandList = oops; dag InOperandList = iops; let AsmString = !strconcat(asm, operands); - - // Destructive operations (SVE) - DestructiveInstTypeEnum DestructiveInstType = NotDestructive; - ElementSizeEnum ElementSize = ElementSizeB; - - let TSFlags{3} = DestructiveInstType.Value; - let TSFlags{2-0} = ElementSize.Value; } class TriOpFrag : PatFrag<(ops node:$LHS, node:$MHS, node:$RHS), res>; @@ -288,7 +310,7 @@ } def SImm8Operand : SImmOperand<8>; -def simm8 : Operand, ImmLeaf= -128 && Imm < 127; }]> { +def simm8 : Operand, ImmLeaf= -128 && Imm < 128; }]> { let ParserMatchClass = SImm8Operand; let DecoderMethod = "DecodeSImm<8>"; } @@ -338,6 +360,17 @@ def am_indexed7s64 : ComplexPattern; def am_indexed7s128 : ComplexPattern; +def UImmS2XForm : SDNodeXFormgetTargetConstant(N->getZExtValue() / 2, SDLoc(N), MVT::i64); +}]>; +def UImmS4XForm : SDNodeXFormgetTargetConstant(N->getZExtValue() / 4, SDLoc(N), MVT::i64); +}]>; +def UImmS8XForm : SDNodeXFormgetTargetConstant(N->getZExtValue() / 8, SDLoc(N), MVT::i64); +}]>; + + // uimm5sN predicate - True if the immediate is a multiple of N in the range // [0 * N, 32 * N]. 
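The TSFlags assignments above pack ElementSize into bits 2-0, DestructiveInstType into bits 6-3, FalseLanes into bits 8-7, isWhile into bit 9 and isPTestLike into bit 10. A sketch of how a consumer would decode that word (the helper names are illustrative, not the in-tree accessors):

  #include <cassert>
  #include <cstdint>

  // Decoders matching the TSFlags layout declared above.
  static unsigned elementSize(uint64_t F)     { return F & 0x7; }         // bits 2-0
  static unsigned destructiveType(uint64_t F) { return (F >> 3) & 0xF; }  // bits 6-3
  static unsigned falseLanes(uint64_t F)      { return (F >> 7) & 0x3; }  // bits 8-7
  static bool isWhile(uint64_t F)             { return (F >> 9) & 0x1; }  // bit 9
  static bool isPTestLike(uint64_t F)         { return (F >> 10) & 0x1; } // bit 10

  int main() {
    // Example: ElementSize=2, DestructiveBinaryImm (3), FalseLanesZero (1),
    // not a WHILE, but PTEST-like.
    uint64_t Flags = 2u | (3u << 3) | (1u << 7) | (1u << 10);
    assert(elementSize(Flags) == 2 && destructiveType(Flags) == 3);
    assert(falseLanes(Flags) == 1 && !isWhile(Flags) && isPTestLike(Flags));
    return 0;
  }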
def UImm5s2Operand : UImmScaledMemoryIndexed<5, 2>; @@ -345,17 +378,20 @@ def UImm5s8Operand : UImmScaledMemoryIndexed<5, 8>; def uimm5s2 : Operand, ImmLeaf= 0 && Imm < (32*2) && ((Imm % 2) == 0); }]> { + [{ return Imm >= 0 && Imm < (32*2) && ((Imm % 2) == 0); }], + UImmS2XForm> { let ParserMatchClass = UImm5s2Operand; let PrintMethod = "printImmScale<2>"; } def uimm5s4 : Operand, ImmLeaf= 0 && Imm < (32*4) && ((Imm % 4) == 0); }]> { + [{ return Imm >= 0 && Imm < (32*4) && ((Imm % 4) == 0); }], + UImmS4XForm> { let ParserMatchClass = UImm5s4Operand; let PrintMethod = "printImmScale<4>"; } def uimm5s8 : Operand, ImmLeaf= 0 && Imm < (32*8) && ((Imm % 8) == 0); }]> { + [{ return Imm >= 0 && Imm < (32*8) && ((Imm % 8) == 0); }], + UImmS8XForm> { let ParserMatchClass = UImm5s8Operand; let PrintMethod = "printImmScale<8>"; } @@ -371,17 +407,17 @@ let ParserMatchClass = UImm6s1Operand; } def uimm6s2 : Operand, ImmLeaf= 0 && Imm < (64*2) && ((Imm % 2) == 0); }]> { +[{ return Imm >= 0 && Imm < (64*2) && ((Imm % 2) == 0); }], UImmS2XForm> { let PrintMethod = "printImmScale<2>"; let ParserMatchClass = UImm6s2Operand; } def uimm6s4 : Operand, ImmLeaf= 0 && Imm < (64*4) && ((Imm % 4) == 0); }]> { +[{ return Imm >= 0 && Imm < (64*4) && ((Imm % 4) == 0); }], UImmS4XForm> { let PrintMethod = "printImmScale<4>"; let ParserMatchClass = UImm6s4Operand; } def uimm6s8 : Operand, ImmLeaf= 0 && Imm < (64*8) && ((Imm % 8) == 0); }]> { +[{ return Imm >= 0 && Imm < (64*8) && ((Imm % 8) == 0); }], UImmS8XForm> { let PrintMethod = "printImmScale<8>"; let ParserMatchClass = UImm6s8Operand; } @@ -394,6 +430,19 @@ let DecoderMethod = "DecodeSImm<6>"; } +def SImmS2XForm : SDNodeXFormgetTargetConstant(N->getSExtValue() / 2, SDLoc(N), MVT::i64); +}]>; +def SImmS3XForm : SDNodeXFormgetTargetConstant(N->getSExtValue() / 3, SDLoc(N), MVT::i64); +}]>; +def SImmS4XForm : SDNodeXFormgetTargetConstant(N->getSExtValue() / 4, SDLoc(N), MVT::i64); +}]>; +def SImmS16XForm : SDNodeXFormgetTargetConstant(N->getSExtValue() / 16, SDLoc(N), MVT::i64); +}]>; + // simm4sN predicate - True if the immediate is a multiple of N in the range // [ -8* N, 7 * N]. 
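The uimm5sN/uimm6sN operands above accept a byte offset only when it is non-negative, a multiple of the scale, and small enough that the scaled value fits the encoding field; the UImmS*XForm nodes then emit Offset / Scale. A standalone sketch of that check, with illustrative names:

#include <cassert>
#include <cstdint>
#include <optional>

std::optional<int64_t> encodeScaledUImm(int64_t Offset, int64_t Scale,
                                        unsigned FieldBits) {
  if (Offset < 0 || Offset % Scale != 0)
    return std::nullopt;
  int64_t Encoded = Offset / Scale;             // what UImmS*XForm emits
  if (Encoded >= (int64_t(1) << FieldBits))     // e.g. 5 bits -> [0, 31]
    return std::nullopt;
  return Encoded;
}

int main() {
  assert(*encodeScaledUImm(48, 8, 5) == 6);   // uimm5s8: 48 = 6 * 8
  assert(!encodeScaledUImm(50, 8, 5));        // not a multiple of 8
  assert(!encodeScaledUImm(8 * 32, 8, 5));    // scaled value needs 6 bits
  return 0;
}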
def SImm4s1Operand : SImmScaledMemoryIndexed<4, 1>; @@ -409,27 +458,27 @@ } def simm4s2 : Operand, ImmLeaf=-16 && Imm <= 14 && (Imm % 2) == 0x0; }]> { +[{ return Imm >=-16 && Imm <= 14 && (Imm % 2) == 0x0; }], SImmS2XForm> { let PrintMethod = "printImmScale<2>"; let ParserMatchClass = SImm4s2Operand; let DecoderMethod = "DecodeSImm<4>"; } def simm4s3 : Operand, ImmLeaf=-24 && Imm <= 21 && (Imm % 3) == 0x0; }]> { +[{ return Imm >=-24 && Imm <= 21 && (Imm % 3) == 0x0; }], SImmS3XForm> { let PrintMethod = "printImmScale<3>"; let ParserMatchClass = SImm4s3Operand; let DecoderMethod = "DecodeSImm<4>"; } def simm4s4 : Operand, ImmLeaf=-32 && Imm <= 28 && (Imm % 4) == 0x0; }]> { +[{ return Imm >=-32 && Imm <= 28 && (Imm % 4) == 0x0; }], SImmS4XForm> { let PrintMethod = "printImmScale<4>"; let ParserMatchClass = SImm4s4Operand; let DecoderMethod = "DecodeSImm<4>"; } def simm4s16 : Operand, ImmLeaf=-128 && Imm <= 112 && (Imm % 16) == 0x0; }]> { +[{ return Imm >=-128 && Imm <= 112 && (Imm % 16) == 0x0; }], SImmS16XForm> { let PrintMethod = "printImmScale<16>"; let ParserMatchClass = SImm4s16Operand; let DecoderMethod = "DecodeSImm<4>"; @@ -608,6 +657,7 @@ } def Imm0_1Operand : AsmImmRange<0, 1>; +def Imm0_3Operand : AsmImmRange<0, 3>; def Imm0_7Operand : AsmImmRange<0, 7>; def Imm0_15Operand : AsmImmRange<0, 15>; def Imm0_31Operand : AsmImmRange<0, 31>; @@ -642,7 +692,6 @@ let ParserMatchClass = Imm0_63Operand; } - // Crazy immediate formats used by 32-bit and 64-bit logical immediate // instructions for splatting repeating bit patterns across the immediate. def logical_imm32_XFORM : SDNodeXForm, ImmLeaf { + let ParserMatchClass = Imm0_127Operand; + let PrintMethod = "printImm"; +} // NOTE: These imm0_N operands have to be of type i64 because i64 is the size // for all shift-amounts. 
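The simm4sN operands above are the signed counterpart: a 4-bit two's complement field gives the asymmetric scaled range [-8*Scale, 7*Scale] (so simm4s16 spans [-128, 112]), and the SImmS*XForm again stores Offset / Scale. A standalone sketch under those assumptions:

#include <cassert>
#include <cstdint>
#include <optional>

std::optional<int64_t> encodeSImm4Scaled(int64_t Offset, int64_t Scale) {
  if (Offset % Scale != 0)
    return std::nullopt;
  int64_t Encoded = Offset / Scale;      // exact, since Offset % Scale == 0
  if (Encoded < -8 || Encoded > 7)       // 4-bit two's complement field
    return std::nullopt;
  return Encoded;
}

int main() {
  assert(*encodeSImm4Scaled(-128, 16) == -8);  // simm4s16 lower bound
  assert(*encodeSImm4Scaled(112, 16) == 7);    // simm4s16 upper bound
  assert(!encodeSImm4Scaled(120, 16));         // not a multiple of 16
  assert(!encodeSImm4Scaled(128, 16));         // would need encoding 8
  return 0;
}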
@@ -768,6 +823,20 @@ let ParserMatchClass = Imm0_7Operand; } +// imm0_3 predicate - True if the immediate is in the range [0,3] +def imm0_3 : Operand, ImmLeaf { + let ParserMatchClass = Imm0_3Operand; +} + +// imm32_0_7 predicate - True if the 32-bit immediate is in the range [0,7] +def imm32_0_7 : Operand, ImmLeaf { + let ParserMatchClass = Imm0_7Operand; +} + // imm32_0_15 predicate - True if the 32-bit immediate is in the range [0,15] def imm32_0_15 : Operand, ImmLeaf; +def fpimm_half : FPImmLeaf; + +def fpimm_one : FPImmLeaf; + +def fpimm_two : FPImmLeaf; + // Vector lane operands class AsmVectorIndex : AsmOperandClass { let Name = NamePrefix # "IndexRange" # Min # "_" # Max; @@ -982,8 +1063,8 @@ let RenderMethod = "addVectorIndexOperands"; } -class AsmVectorIndexOpnd - : Operand, ImmLeaf { +class AsmVectorIndexOpnd + : Operand, ImmLeaf { let ParserMatchClass = mc; let PrintMethod = "printVectorIndex"; } @@ -994,11 +1075,17 @@ def VectorIndexSOperand : AsmVectorIndex<0, 3>; def VectorIndexDOperand : AsmVectorIndex<0, 1>; -def VectorIndex1 : AsmVectorIndexOpnd; -def VectorIndexB : AsmVectorIndexOpnd; -def VectorIndexH : AsmVectorIndexOpnd; -def VectorIndexS : AsmVectorIndexOpnd; -def VectorIndexD : AsmVectorIndexOpnd; +def VectorIndex1 : AsmVectorIndexOpnd; +def VectorIndexB : AsmVectorIndexOpnd; +def VectorIndexH : AsmVectorIndexOpnd; +def VectorIndexS : AsmVectorIndexOpnd; +def VectorIndexD : AsmVectorIndexOpnd; + +def VectorIndex132b : AsmVectorIndexOpnd; +def VectorIndexB32b : AsmVectorIndexOpnd; +def VectorIndexH32b : AsmVectorIndexOpnd; +def VectorIndexS32b : AsmVectorIndexOpnd; +def VectorIndexD32b : AsmVectorIndexOpnd; def SVEVectorIndexExtDupBOperand : AsmVectorIndex<0, 63, "SVE">; def SVEVectorIndexExtDupHOperand : AsmVectorIndex<0, 31, "SVE">; @@ -1007,15 +1094,15 @@ def SVEVectorIndexExtDupQOperand : AsmVectorIndex<0, 3, "SVE">; def sve_elm_idx_extdup_b - : AsmVectorIndexOpnd; + : AsmVectorIndexOpnd; def sve_elm_idx_extdup_h - : AsmVectorIndexOpnd; + : AsmVectorIndexOpnd; def sve_elm_idx_extdup_s - : AsmVectorIndexOpnd; + : AsmVectorIndexOpnd; def sve_elm_idx_extdup_d - : AsmVectorIndexOpnd; + : AsmVectorIndexOpnd; def sve_elm_idx_extdup_q - : AsmVectorIndexOpnd; + : AsmVectorIndexOpnd; // 8-bit immediate for AdvSIMD where 64-bit values of the form: // aaaaaaaa bbbbbbbb cccccccc dddddddd eeeeeeee ffffffff gggggggg hhhhhhhh @@ -1035,7 +1122,6 @@ let PrintMethod = "printSIMDType10Operand"; } - //--- // System management //--- @@ -2719,6 +2805,13 @@ def am_indexed64 : ComplexPattern; def am_indexed128 : ComplexPattern; +// (unsigned immediate) +// Indexed for 8-bit registers. offset is in range [0,63]. 
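The SVE extended-dup index operands above bound the lane index by element size (0-63 for bytes, 0-31 for halfwords, down to 0-3 for 128-bit quads). A minimal sketch of that relationship, assuming the bound is 64 / sizeof(element) - 1, i.e. the largest lane addressable within 512 bits; only the explicit bounds 63, 31 and 3 appear in the patch itself:

#include <cassert>

constexpr int maxExtDupIndex(int EltBytes) { return 64 / EltBytes - 1; }

static_assert(maxExtDupIndex(1) == 63, "B, matches AsmVectorIndex<0, 63>");
static_assert(maxExtDupIndex(2) == 31, "H, matches AsmVectorIndex<0, 31>");
static_assert(maxExtDupIndex(16) == 3, "Q, matches AsmVectorIndex<0, 3>");

int main() { return 0; }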
+def am_indexed8_6b : ComplexPattern", []>; +def am_indexed16_6b : ComplexPattern", []>; +def am_indexed32_6b : ComplexPattern", []>; +def am_indexed64_6b : ComplexPattern", []>; + def gi_am_indexed8 : GIComplexOperandMatcher">, GIComplexPatternEquiv; @@ -3383,6 +3476,11 @@ def am_unscaled64 : ComplexPattern; def am_unscaled128 :ComplexPattern; +def am_sve_pred : ComplexPattern; +def am_sve_indexed_s9 :ComplexPattern", [], [SDNPWantRoot]>; +def am_sve_indexed_s6 :ComplexPattern", [], [SDNPWantRoot]>; +def am_sve_indexed_s4 :ComplexPattern", [], [SDNPWantRoot]>; + def gi_am_unscaled8 : GIComplexOperandMatcher, GIComplexPatternEquiv; @@ -7696,8 +7794,8 @@ asm#"2", ".4s", ".4s", ".8h", ".h", [(set (v4i32 V128:$Rd), (OpNode (extract_high_v8i16 V128:$Rn), - (extract_high_v8i16 (AArch64duplane16 (v8i16 V128_lo:$Rm), - VectorIndexH:$idx))))]> { + (extract_high_v8i16 (v8i16 (AArch64duplane16 (v8i16 V128_lo:$Rm), + VectorIndexH:$idx)))))]> { bits<3> idx; let Inst{11} = idx{2}; @@ -7723,8 +7821,8 @@ asm#"2", ".2d", ".2d", ".4s", ".s", [(set (v2i64 V128:$Rd), (OpNode (extract_high_v4i32 V128:$Rn), - (extract_high_v4i32 (AArch64duplane32 (v4i32 V128:$Rm), - VectorIndexS:$idx))))]> { + (extract_high_v4i32 (v4i32 (AArch64duplane32 (v4i32 V128:$Rm), + VectorIndexS:$idx)))))]> { bits<2> idx; let Inst{11} = idx{1}; let Inst{21} = idx{0}; @@ -7789,8 +7887,8 @@ (v4i32 (int_aarch64_neon_sqdmull (extract_high_v8i16 V128:$Rn), (extract_high_v8i16 - (AArch64duplane16 (v8i16 V128_lo:$Rm), - VectorIndexH:$idx))))))]> { + (v8i16 (AArch64duplane16 (v8i16 V128_lo:$Rm), + VectorIndexH:$idx)))))))]> { bits<3> idx; let Inst{11} = idx{2}; let Inst{21} = idx{1}; @@ -7821,8 +7919,8 @@ (v2i64 (int_aarch64_neon_sqdmull (extract_high_v4i32 V128:$Rn), (extract_high_v4i32 - (AArch64duplane32 (v4i32 V128:$Rm), - VectorIndexS:$idx))))))]> { + (v4i32 (AArch64duplane32 (v4i32 V128:$Rm), + VectorIndexS:$idx)))))))]> { bits<2> idx; let Inst{11} = idx{1}; let Inst{21} = idx{0}; @@ -7876,8 +7974,8 @@ asm#"2", ".4s", ".4s", ".8h", ".h", [(set (v4i32 V128:$Rd), (OpNode (extract_high_v8i16 V128:$Rn), - (extract_high_v8i16 (AArch64duplane16 (v8i16 V128_lo:$Rm), - VectorIndexH:$idx))))]> { + (extract_high_v8i16 (v8i16 (AArch64duplane16 (v8i16 V128_lo:$Rm), + VectorIndexH:$idx)))))]> { bits<3> idx; let Inst{11} = idx{2}; @@ -7903,8 +8001,8 @@ asm#"2", ".2d", ".2d", ".4s", ".s", [(set (v2i64 V128:$Rd), (OpNode (extract_high_v4i32 V128:$Rn), - (extract_high_v4i32 (AArch64duplane32 (v4i32 V128:$Rm), - VectorIndexS:$idx))))]> { + (extract_high_v4i32 (v4i32 (AArch64duplane32 (v4i32 V128:$Rm), + VectorIndexS:$idx)))))]> { bits<2> idx; let Inst{11} = idx{1}; let Inst{21} = idx{0}; @@ -7935,8 +8033,8 @@ [(set (v4i32 V128:$dst), (OpNode (v4i32 V128:$Rd), (extract_high_v8i16 V128:$Rn), - (extract_high_v8i16 (AArch64duplane16 (v8i16 V128_lo:$Rm), - VectorIndexH:$idx))))]> { + (extract_high_v8i16 (v8i16 (AArch64duplane16 (v8i16 V128_lo:$Rm), + VectorIndexH:$idx)))))]> { bits<3> idx; let Inst{11} = idx{2}; let Inst{21} = idx{1}; @@ -7962,8 +8060,8 @@ [(set (v2i64 V128:$dst), (OpNode (v2i64 V128:$Rd), (extract_high_v4i32 V128:$Rn), - (extract_high_v4i32 (AArch64duplane32 (v4i32 V128:$Rm), - VectorIndexS:$idx))))]> { + (extract_high_v4i32 (v4i32 (AArch64duplane32 (v4i32 V128:$Rm), + VectorIndexS:$idx)))))]> { bits<2> idx; let Inst{11} = idx{1}; let Inst{21} = idx{0}; @@ -9711,11 +9809,17 @@ let DiagnosticType = "InvalidComplexRotation" # Type; let Name = "ComplexRotation" # Type; } -def complexrotateop : Operand { +def complexrotateop : Operand, ImmLeaf= 0 
&& Imm <= 270; }], + SDNodeXFormgetTargetConstant((N->getSExtValue() / 90), SDLoc(N), MVT::i64); +}]>> { let ParserMatchClass = ComplexRotationOperand<90, 0, "Even">; let PrintMethod = "printComplexRotationOp<90, 0>"; } -def complexrotateopodd : Operand { +def complexrotateopodd : Operand, ImmLeaf= 0 && Imm <= 270; }], + SDNodeXFormgetTargetConstant(((N->getSExtValue() - 90) / 180), SDLoc(N), MVT::i64); +}]>> { let ParserMatchClass = ComplexRotationOperand<180, 90, "Odd">; let PrintMethod = "printComplexRotationOp<180, 90>"; } Index: lib/Target/AArch64/AArch64InstrInfo.h =================================================================== --- lib/Target/AArch64/AArch64InstrInfo.h +++ lib/Target/AArch64/AArch64InstrInfo.h @@ -16,6 +16,7 @@ #include "AArch64.h" #include "AArch64RegisterInfo.h" +#include "AArch64StackOffset.h" #include "llvm/CodeGen/MachineCombinerPattern.h" #include "llvm/CodeGen/TargetInstrInfo.h" @@ -259,8 +260,18 @@ /// Returns true if the instruction has a shift by immediate that can be /// executed in one cycle less. bool isFalkorShiftExtFast(const MachineInstr &MI) const; + /// Returns the vector element size (B, H, S or D) of an SVE opcode. + uint64_t getElementSizeForOpcode(unsigned Opc) const; + /// Returns true if the opcode is for an SVE instruction that sets the + /// condition codes as if it's results had been fed to a PTEST instruction + /// along with the same general predicate. + bool isPTestLikeOpcode(unsigned Opc) const; + /// Returns true if the opcode is for an SVE WHILE## instruction. + bool isWhileOpcode(unsigned Opc) const; private: + unsigned getInstBundleLength(const MachineInstr &MI) const; + /// Sets the offsets on outlined instructions in \p MBB which use SP /// so that they will be valid post-outlining. /// @@ -276,6 +287,12 @@ /// Returns an unused general-purpose register which can be used for /// constructing an outlined call if one exists. Returns 0 otherwise. unsigned findRegisterToSaveLRTo(const outliner::Candidate &C) const; + + /// Remove a ptest of a predicate-generating operation that already sets, or + /// can be made to set, the condition codes in an identical manner + bool optimizePTestInstr(MachineInstr *PTest, unsigned MaskReg, + unsigned PredReg, + const MachineRegisterInfo *MRI) const; }; /// emitFrameOffset - Emit instructions as needed to set DestReg to SrcReg @@ -284,7 +301,7 @@ /// if necessary, to be replaced by the scavenger at the end of PEI. void emitFrameOffset(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI, const DebugLoc &DL, unsigned DestReg, unsigned SrcReg, - int Offset, const TargetInstrInfo *TII, + StackOffset, const TargetInstrInfo *TII, MachineInstr::MIFlag = MachineInstr::NoFlags, bool SetNZCV = false); @@ -292,7 +309,7 @@ /// FP. Return false if the offset could not be handled directly in MI, and /// return the left-over portion by reference. bool rewriteAArch64FrameIndex(MachineInstr &MI, unsigned FrameRegIdx, - unsigned FrameReg, int &Offset, + unsigned FrameReg, StackOffset &Offset, const AArch64InstrInfo *TII); /// Use to report the frame offset status in isAArch64FrameOffsetLegal. @@ -316,7 +333,7 @@ /// If set, @p EmittableOffset contains the amount that can be set in @p MI /// (possibly with @p OutUnscaledOp if OutUseUnscaledOp is true) and that /// is a legal offset. 
-int isAArch64FrameOffsetLegal(const MachineInstr &MI, int &Offset, +int isAArch64FrameOffsetLegal(const MachineInstr &MI, StackOffset &Offset, bool *OutUseUnscaledOp = nullptr, unsigned *OutUnscaledOp = nullptr, int *EmittableOffset = nullptr); @@ -344,9 +361,23 @@ return Opc == AArch64::BR; } +static inline bool isPTrueOpcode(unsigned Opc) { + switch (Opc) { + case AArch64::PTRUE_B: + case AArch64::PTRUE_H: + case AArch64::PTRUE_S: + case AArch64::PTRUE_D: + return true; + default: + return false; + } +} + // struct TSFlags { #define TSFLAG_ELEMENT_SIZE_TYPE(X) (X) // 3-bits -#define TSFLAG_DESTRUCTIVE_INST_TYPE(X) ((X) << 3) // 1-bit +#define TSFLAG_DESTRUCTIVE_INST_TYPE(X) ((X) << 3) // 4-bits +#define TSFLAG_FALSE_LANE_TYPE(X) ((X) << 7) // 2-bits +#define TSFLAG_INSTR_FLAGS(X) ((X) << 9) // 2-bits // } namespace AArch64 { @@ -361,13 +392,36 @@ }; enum DestructiveInstType { - DestructiveInstTypeMask = TSFLAG_DESTRUCTIVE_INST_TYPE(0x1), - NotDestructive = TSFLAG_DESTRUCTIVE_INST_TYPE(0x0), - Destructive = TSFLAG_DESTRUCTIVE_INST_TYPE(0x1), + DestructiveInstTypeMask = TSFLAG_DESTRUCTIVE_INST_TYPE(0xf), + NotDestructive = TSFLAG_DESTRUCTIVE_INST_TYPE(0x0), + DestructiveOther = TSFLAG_DESTRUCTIVE_INST_TYPE(0x1), + DestructiveUnary = TSFLAG_DESTRUCTIVE_INST_TYPE(0x2), + DestructiveBinaryImm = TSFLAG_DESTRUCTIVE_INST_TYPE(0x3), + DestructiveBinaryShImmUnpred = TSFLAG_DESTRUCTIVE_INST_TYPE(0x4), + DestructiveBinary = TSFLAG_DESTRUCTIVE_INST_TYPE(0x5), + DestructiveBinaryComm = TSFLAG_DESTRUCTIVE_INST_TYPE(0x6), + DestructiveBinaryCommWithRev = TSFLAG_DESTRUCTIVE_INST_TYPE(0x7), + DestructiveTernaryCommWithRev = TSFLAG_DESTRUCTIVE_INST_TYPE(0x8), }; +enum FalseLaneType { + FalseLanesMask = TSFLAG_FALSE_LANE_TYPE(0x3), + FalseLanesZero = TSFLAG_FALSE_LANE_TYPE(0x1), + FalseLanesUndef = TSFLAG_FALSE_LANE_TYPE(0x2), +}; + +// NOTE: This is a bit field. +static const uint64_t InstrFlagIsWhile = TSFLAG_INSTR_FLAGS(0x1); +static const uint64_t InstrFlagIsPTestLike = TSFLAG_INSTR_FLAGS(0x2); + #undef TSFLAG_ELEMENT_SIZE_TYPE #undef TSFLAG_DESTRUCTIVE_INST_TYPE +#undef TSFLAG_FALSE_LANE_TYPE +#undef TSFLAG_INSTR_FLAGS + +int getSVEPseudoMap(uint16_t Opcode); +int getSVERevInstr(uint16_t Opcode); +int getSVEOrigInstr(uint16_t Opcode); } } // end namespace llvm Index: lib/Target/AArch64/AArch64InstrInfo.cpp =================================================================== --- lib/Target/AArch64/AArch64InstrInfo.cpp +++ lib/Target/AArch64/AArch64InstrInfo.cpp @@ -46,11 +46,13 @@ #include #include #include +#include #include using namespace llvm; #define GET_INSTRINFO_CTOR_DTOR +#define GET_INSTRMAP_INFO #include "AArch64GenInstrInfo.inc" static cl::opt TBZDisplacementBits( @@ -79,6 +81,10 @@ if (MI.getOpcode() == AArch64::INLINEASM) return getInlineAsmLength(MI.getOperand(0).getSymbolName(), *MAI); + // Meta-instructions emit no code. + if (MI.isMetaInstruction()) + return 0; + // FIXME: We currently only handle pseudoinstructions that don't get expanded // before the assembly printer. unsigned NumBytes = 0; @@ -88,12 +94,6 @@ // Anything not explicitly designated otherwise is a normal 4-byte insn. 
NumBytes = 4; break; - case TargetOpcode::DBG_VALUE: - case TargetOpcode::EH_LABEL: - case TargetOpcode::IMPLICIT_DEF: - case TargetOpcode::KILL: - NumBytes = 0; - break; case TargetOpcode::STACKMAP: // The upper bound for a stackmap intrinsic is the full length of its shadow NumBytes = StackMapOpers(&MI).getNumPatchBytes(); @@ -108,11 +108,25 @@ // This gets lowered to an instruction sequence which takes 16 bytes NumBytes = 16; break; + case TargetOpcode::BUNDLE: + NumBytes = getInstBundleLength(MI); + break; } return NumBytes; } +unsigned AArch64InstrInfo::getInstBundleLength(const MachineInstr &MI) const { + unsigned Size = 0; + MachineBasicBlock::const_instr_iterator I = MI.getIterator(); + MachineBasicBlock::const_instr_iterator E = MI.getParent()->instr_end(); + while (++I != E && I->isInsideBundle()) { + assert(!I->isBundle() && "No nested bundle!"); + Size += getInstSizeInBytes(*I); + } + return Size; +} + static void parseCondBranch(MachineInstr *LastInst, MachineBasicBlock *&Target, SmallVectorImpl &Cond) { // Block ends with fall-through condbranch. @@ -717,6 +731,24 @@ case AArch64::ORRXrr: return true; + // logical ops on register with shift + case AArch64::ANDWrs: + case AArch64::ANDXrs: + case AArch64::BICWrs: + case AArch64::BICXrs: + case AArch64::EONWrs: + case AArch64::EONXrs: + case AArch64::EORWrs: + case AArch64::EORXrs: + case AArch64::ORNWrs: + case AArch64::ORNXrs: + case AArch64::ORRWrs: + case AArch64::ORRXrs: { + unsigned Imm = MI.getOperand(3).getImm(); + return (Subtarget.getProcFamily() == AArch64Subtarget::ExynosM1 && + AArch64_AM::getShiftValue(Imm) < 4 && + AArch64_AM::getShiftType(Imm) == AArch64_AM::LSL); + } // If MOVi32imm or MOVi64imm can be expanded into ORRWri or // ORRXri, it is as cheap as MOV case AArch64::MOVi32imm: @@ -1122,6 +1154,13 @@ switch (MI.getOpcode()) { default: break; + case AArch64::PTEST_PP: + SrcReg = MI.getOperand(0).getReg(); + SrcReg2 = MI.getOperand(1).getReg(); + // Not sure about the mask and value for now... + CmpMask = ~0; + CmpValue = 0; + return true; case AArch64::SUBSWrr: case AArch64::SUBSWrs: case AArch64::SUBSWrx: @@ -1294,6 +1333,125 @@ return false; } +/// optimizePTestInstr - Attempt to remove a ptest of a predicate-generating +/// operation which could set the flags in an identical manner +/// +bool AArch64InstrInfo::optimizePTestInstr(MachineInstr *PTest, unsigned MaskReg, + unsigned PredReg, const MachineRegisterInfo *MRI) const { + auto *Mask = MRI->getUniqueVRegDef(MaskReg); + auto *Pred = MRI->getUniqueVRegDef(PredReg); + auto NewOp = Pred->getOpcode(); + bool OpChanged = false; + + unsigned MaskOpcode = Mask->getOpcode(); + unsigned PredOpcode = Pred->getOpcode(); + bool PredIsPTestLike = isPTestLikeOpcode(PredOpcode); + bool PredIsWhileLike = isWhileOpcode(PredOpcode); + + if (isPTrueOpcode(MaskOpcode) && (PredIsPTestLike || PredIsWhileLike)) { + // For PTEST(PTRUE, OTHER_INST), PTEST is redundant when PTRUE doesn't + // deactivate any lanes OTHER_INST might set. + uint64_t MaskElementSize = getElementSizeForOpcode(MaskOpcode); + uint64_t PredElementSize = getElementSizeForOpcode(PredOpcode); + + // Must be an all active predicate of matching element size. + if ((PredElementSize != MaskElementSize) || + (Mask->getOperand(1).getImm() != 31)) + return false; + + // Fallthough to simply remove the PTEST. + } else if ((Mask == Pred) && (PredIsPTestLike || PredIsWhileLike)) { + // For PTEST(PG, PG), PTEST is redundant when PG is the result of an + // instruction that sets the flags as PTEST would. 
+ + // Fallthough to simply remove the PTEST. + } else if (PredIsPTestLike) { + // For PTEST(PG_1, PTEST_LIKE(PG2, ...)), PTEST is redundant when both + // instructions use the same predicate. + auto PTestLikeMask = MRI->getUniqueVRegDef(Pred->getOperand(1).getReg()); + if (Mask != PTestLikeMask) + return false; + + // Fallthough to simply remove the PTEST. + } else { + switch (Pred->getOpcode()) { + case AArch64::BRKB_PPzP: + case AArch64::BRKPB_PPzPP: { + // Op 0 is chain, 1 is the mask, 2 the previous predicate to + // propagate, 3 the new predicate. + + // Check to see if our mask is the same as the brkpb's. If + // not the resulting flag bits may be different and we + // can't remove the ptest. + auto *PredMask = MRI->getUniqueVRegDef(Pred->getOperand(1).getReg()); + if (Mask != PredMask) + return false; + + // Switch to the new opcode + NewOp = Pred->getOpcode() == AArch64::BRKB_PPzP ? + AArch64::BRKBS_PPzP : AArch64::BRKPBS_PPzPP; + OpChanged = true; + break; + } + case AArch64::BRKN_PPzP: { + auto *PredMask = MRI->getUniqueVRegDef(Pred->getOperand(1).getReg()); + if (Mask != PredMask) + return false; + + NewOp = AArch64::BRKNS_PPzP; + OpChanged = true; + break; + } + default: + // Bail out if we don't recognize the input + return false; + } + } + + const TargetRegisterInfo *TRI = &getRegisterInfo(); + + // If the predicate is in a different block (possibly because its been + // hoisted out), then assume the flags are set in between statements. + if (Pred->getParent() != PTest->getParent()) + return false; + + // If another instruction between the propagation and test sets the + // flags, don't remove the ptest. + MachineBasicBlock::iterator I = Pred, E = PTest; + ++I; // Skip past the predicate op itself. + for (; I != E; ++I) { + const MachineInstr &Inst = *I; + + // TODO: If the ptest flags are unused, we could still remove it. + if (Inst.modifiesRegister(AArch64::NZCV, TRI)) + return false; + } + + // If we've gotten past all the checks, it's safe to remove the ptest + // and use the flag-setting form of brkpb. + Pred->setDesc(get(NewOp)); + PTest->eraseFromParent(); + if (OpChanged) { + bool succeeded = UpdateOperandRegClass(*Pred); + (void)succeeded; + assert(succeeded && "Operands have incompatible register classes!"); + Pred->addRegisterDefined(AArch64::NZCV, TRI); + } + + // Ensure that the flags def is live. + if (Pred->registerDefIsDead(AArch64::NZCV, TRI)) { + unsigned i = 0, e = Pred->getNumOperands(); + for (; i != e; ++i) { + MachineOperand &MO = Pred->getOperand(i); + if (MO.isReg() && MO.isDef() && MO.getReg() == AArch64::NZCV) { + MO.setIsDead(false); + break; + } + } + } + return true; +} + /// Try to optimize a compare instruction. A compare instruction is an /// instruction which produces AArch64::NZCV. It can be truly compare /// instruction @@ -1332,6 +1490,9 @@ return true; } + if (CmpInstr.getOpcode() == AArch64::PTEST_PP) + return optimizePTestInstr(&CmpInstr, SrcReg, SrcReg2, MRI); + // Continue only if we have a "ri" where immediate is zero. // FIXME:CmpValue has already been converted to 0 or 1 in analyzeCompare // function. 
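A standalone distillation (plain C++, illustrative field names, not MachineInstr code) of the redundancy cases optimizePTestInstr recognises above, leaving aside the BRK* opcode rewrites and the later same-block and NZCV-clobber checks:

#include <cassert>

struct PTestQuery {
  bool MaskIsPTrue;          // ptest mask produced by a PTRUE
  bool PTrueIsAllActive;     // that PTRUE uses pattern #31 (ALL)
  bool ElementSizesMatch;    // mask element size == producer element size
  bool MaskEqualsPred;       // PTEST(PG, PG)
  bool MaskEqualsPredsMask;  // producer is governed by the same PG
  bool PredIsPTestLike;
  bool PredIsWhile;
};

bool ptestIsRedundant(const PTestQuery &Q) {
  bool FlagSetting = Q.PredIsPTestLike || Q.PredIsWhile;
  if (Q.MaskIsPTrue && FlagSetting)
    // PTEST(PTRUE, OTHER_INST): only safe when the PTRUE deactivates no lane
    // OTHER_INST might set, i.e. the ALL pattern of matching element size.
    return Q.PTrueIsAllActive && Q.ElementSizesMatch;
  if (Q.MaskEqualsPred && FlagSetting)
    return true;                      // PTEST(PG, PG)
  if (Q.PredIsPTestLike)
    return Q.MaskEqualsPredsMask;     // PTEST(PG, PTEST_LIKE(PG, ...))
  return false;                       // BRK* cases rewrite the producer
}                                     // instead (not modelled here)

int main() {
  PTestQuery Q{true, true, true, false, false, false, true};
  assert(ptestIsRedundant(Q));        // WHILE under a matching ptrue.all
  Q.ElementSizesMatch = false;
  assert(!ptestIsRedundant(Q));       // element size mismatch keeps the ptest
  return 0;
}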
@@ -1593,11 +1754,46 @@ } bool AArch64InstrInfo::expandPostRAPseudo(MachineInstr &MI) const { - if (MI.getOpcode() != TargetOpcode::LOAD_STACK_GUARD) - return false; - + const TargetRegisterInfo *TRI = &getRegisterInfo(); MachineBasicBlock &MBB = *MI.getParent(); DebugLoc DL = MI.getDebugLoc(); + unsigned Opcode = MI.getOpcode(); + + if ((Opcode == AArch64::DUP_ZV_H) || + (Opcode == AArch64::DUP_ZV_S) || + (Opcode == AArch64::DUP_ZV_D)) { + auto RC = &AArch64::ZPRRegClass; + unsigned Dst = MI.getOperand(0).getReg(); + unsigned Src = MI.getOperand(1).getReg(); + bool SrcIsKill = MI.getOperand(1).isKill(); + + unsigned NewOpcode; + unsigned NewSrc; + switch (Opcode) { + case AArch64::DUP_ZV_H: + NewOpcode = AArch64::DUP_ZZI_H; + NewSrc = TRI->getMatchingSuperReg(Src, AArch64::hsub, RC); + break; + case AArch64::DUP_ZV_S: + NewOpcode = AArch64::DUP_ZZI_S; + NewSrc = TRI->getMatchingSuperReg(Src, AArch64::ssub, RC); + break; + case AArch64::DUP_ZV_D: + NewOpcode = AArch64::DUP_ZZI_D; + NewSrc = TRI->getMatchingSuperReg(Src, AArch64::dsub, RC); + break; + } + + BuildMI(MBB, MI, DL, get(NewOpcode), Dst) + .addReg(NewSrc, getKillRegState(SrcIsKill)) + .addImm(0); + MBB.erase(MI); + return true; + } + + if (Opcode != TargetOpcode::LOAD_STACK_GUARD) + return false; + unsigned Reg = MI.getOperand(0).getReg(); const GlobalValue *GV = cast((*MI.memoperands_begin())->getValue()); @@ -1801,6 +1997,11 @@ case AArch64::LDRSui: case AArch64::LDRDui: case AArch64::LDRQui: + case AArch64::LDR_PXI: + case AArch64::LDR_ZXI: + case AArch64::LDR_ZZXI: + case AArch64::LDR_ZZZXI: + case AArch64::LDR_ZZZZXI: if (MI.getOperand(0).getSubReg() == 0 && MI.getOperand(1).isFI() && MI.getOperand(2).isImm() && MI.getOperand(2).getImm() == 0) { FrameIndex = MI.getOperand(1).getIndex(); @@ -1824,6 +2025,11 @@ case AArch64::STRSui: case AArch64::STRDui: case AArch64::STRQui: + case AArch64::STR_PXI: + case AArch64::STR_ZXI: + case AArch64::STR_ZZXI: + case AArch64::STR_ZZZXI: + case AArch64::STR_ZZZZXI: if (MI.getOperand(0).getSubReg() == 0 && MI.getOperand(1).isFI() && MI.getOperand(2).isImm() && MI.getOperand(2).getImm() == 0) { FrameIndex = MI.getOperand(1).getIndex(); @@ -2507,6 +2713,27 @@ return; } + // Copy a Predicate register by ORRing with itself. + if (AArch64::PPRRegClass.contains(DestReg) && + AArch64::PPRRegClass.contains(SrcReg)) { + assert(Subtarget.hasSVE() && "Unexpected SVE register."); + BuildMI(MBB, I, DL, get(AArch64::ORR_PPzPP), DestReg) + .addReg(SrcReg) // Pg + .addReg(SrcReg) + .addReg(SrcReg, getKillRegState(KillSrc)); + return; + } + + // Copy a Z register by ORRing with itself. + if (AArch64::ZPRRegClass.contains(DestReg) && + AArch64::ZPRRegClass.contains(SrcReg)) { + assert(Subtarget.hasSVE() && "Unexpected SVE register."); + BuildMI(MBB, I, DL, get(AArch64::ORR_ZZZ), DestReg) + .addReg(SrcReg) + .addReg(SrcReg, getKillRegState(KillSrc)); + return; + } + if (AArch64::GPR64spRegClass.contains(DestReg) && (AArch64::GPR64spRegClass.contains(SrcReg) || SrcReg == AArch64::XZR)) { if (DestReg == AArch64::SP || SrcReg == AArch64::SP) { @@ -2586,6 +2813,35 @@ return; } + // Copy a ZPR2 register pair by copying the individual sub-registers. + if (AArch64::ZPR2RegClass.contains(DestReg) && + AArch64::ZPR2RegClass.contains(SrcReg)) { + static const unsigned Indices[] = { AArch64::zsub0, AArch64::zsub1 }; + copyPhysRegTuple(MBB, I, DL, DestReg, SrcReg, KillSrc, AArch64::ORR_ZZZ, + Indices); + return; + } + + // Copy a ZPR3 register pair by copying the individual sub-registers. 
+ if (AArch64::ZPR3RegClass.contains(DestReg) && + AArch64::ZPR3RegClass.contains(SrcReg)) { + static const unsigned Indices[] = { AArch64::zsub0, AArch64::zsub1, + AArch64::zsub2 }; + copyPhysRegTuple(MBB, I, DL, DestReg, SrcReg, KillSrc, AArch64::ORR_ZZZ, + Indices); + return; + } + + // Copy a ZPR4 register pair by copying the individual sub-registers. + if (AArch64::ZPR4RegClass.contains(DestReg) && + AArch64::ZPR4RegClass.contains(SrcReg)) { + static const unsigned Indices[] = { AArch64::zsub0, AArch64::zsub1, + AArch64::zsub2, AArch64::zsub3 }; + copyPhysRegTuple(MBB, I, DL, DestReg, SrcReg, KillSrc, AArch64::ORR_ZZZ, + Indices); + return; + } + if (AArch64::FPR128RegClass.contains(DestReg) && AArch64::FPR128RegClass.contains(SrcReg)) { if (Subtarget.hasNEON()) { @@ -2827,7 +3083,34 @@ } break; } + unsigned StackID = AArch64::FR_Default; + if (AArch64::PPRRegClass.hasSubClassEq(RC)) { + assert(Subtarget.hasSVE() && "Unexpected register store without SVE"); + Opc = AArch64::STR_PXI; + StackID = AArch64::FR_SVE; + } else if (AArch64::ZPRRegClass.hasSubClassEq(RC)) { + assert(Subtarget.hasSVE() && "Unexpected register store without SVE"); + Opc = AArch64::STR_ZXI; + StackID = AArch64::FR_SVE; + } else if (AArch64::ZPR2RegClass.hasSubClassEq(RC)) { + // This function relies on spills constructing a single MI (see + // InlineSpiller), but ST2,ST3,ST4 require a predicate input, which we don't + // currently have. STR_ZZXI and friends are pseudo instructions which get + // expanded to a PTRUE + ST2 later + assert(Subtarget.hasSVE() && "Unexpected register store without SVE"); + Opc = AArch64::STR_ZZXI; + StackID = AArch64::FR_SVE; + } else if (AArch64::ZPR3RegClass.hasSubClassEq(RC)) { + assert(Subtarget.hasSVE() && "Unexpected register store without SVE"); + Opc = AArch64::STR_ZZZXI; + StackID = AArch64::FR_SVE; + } else if (AArch64::ZPR4RegClass.hasSubClassEq(RC)) { + assert(Subtarget.hasSVE() && "Unexpected register store without SVE"); + Opc = AArch64::STR_ZZZZXI; + StackID = AArch64::FR_SVE; + } assert(Opc && "Unknown register class"); + MFI.setStackID(FI, StackID); const MachineInstrBuilder MI = BuildMI(MBB, MBBI, DL, get(Opc)) .addReg(SrcReg, getKillRegState(isKill)) @@ -2935,8 +3218,39 @@ } break; } + + unsigned StackID = AArch64::FR_Default; + if (AArch64::PPRRegClass.hasSubClassEq(RC)) { + assert(Subtarget.hasSVE() && "Unexpected register load without SVE"); + Opc = AArch64::LDR_PXI; + StackID = AArch64::FR_SVE; + } else if (AArch64::ZPRRegClass.hasSubClassEq(RC)) { + assert(Subtarget.hasSVE() && "Unexpected register load without SVE"); + Opc = AArch64::LDR_ZXI; + StackID = AArch64::FR_SVE; + } else if (AArch64::ZPR2RegClass.hasSubClassEq(RC)) { + assert(Subtarget.hasSVE() && "Unexpected register load without SVE"); + // This function relies on fills constructing a single MI (see + // InlineSpiller), but LD2,LD3,LD4 require a predicate input, which we + // don't + // currently have. 
LDR_ZZXI and friends are pseudo instructions which get + // expanded to a PTRUE + LD2 later + Opc = AArch64::LDR_ZZXI; + StackID = AArch64::FR_SVE; + } else if (AArch64::ZPR3RegClass.hasSubClassEq(RC)) { + assert(Subtarget.hasSVE() && "Unexpected register load without SVE"); + Opc = AArch64::LDR_ZZZXI; + StackID = AArch64::FR_SVE; + } else if (AArch64::ZPR4RegClass.hasSubClassEq(RC)) { + assert(Subtarget.hasSVE() && "Unexpected register load without SVE"); + Opc = AArch64::LDR_ZZZZXI; + StackID = AArch64::FR_SVE; + } assert(Opc && "Unknown register class"); + // SVE registers are allocated to their own stack regions. + MFI.setStackID(FI, StackID); + const MachineInstrBuilder MI = BuildMI(MBB, MBBI, DL, get(Opc)) .addReg(DestReg, getDefRegState(true)) .addFrameIndex(FI); @@ -2945,20 +3259,44 @@ MI.addMemOperand(MMO); } -void llvm::emitFrameOffset(MachineBasicBlock &MBB, - MachineBasicBlock::iterator MBBI, const DebugLoc &DL, - unsigned DestReg, unsigned SrcReg, int Offset, - const TargetInstrInfo *TII, - MachineInstr::MIFlag Flag, bool SetNZCV) { +// Helper function to emit a frame offset adjustment from a given +// pointer (SrcReg), stored into DestReg. This function is explicit +// in that it requires the opcode. +static void emitFrameOffsetAdj(MachineBasicBlock &MBB, + MachineBasicBlock::iterator MBBI, + const DebugLoc &DL, unsigned DestReg, + unsigned SrcReg, int64_t Offset, unsigned Opc, + const TargetInstrInfo *TII, + MachineInstr::MIFlag Flag) { + // If we write to the same register as SrcReg, we always kill it + bool KillSrcReg = (DestReg == SrcReg); + if (DestReg == SrcReg && Offset == 0) return; - assert((DestReg != AArch64::SP || Offset % 16 == 0) && - "SP increment/decrement not 16-byte aligned"); - - bool isSub = Offset < 0; - if (isSub) - Offset = -Offset; + int Sign = 1; + unsigned MaxEncoding, ShiftSize; + switch (Opc) { + case AArch64::ADDXri: + case AArch64::ADDSXri: + case AArch64::SUBXri: + case AArch64::SUBSXri: + MaxEncoding = 0xfff; + ShiftSize = 12; + break; + case AArch64::ADDVL_XXI: + case AArch64::ADDPL_XXI: + MaxEncoding = 31; + ShiftSize = 0; + if (Offset < 0) { + MaxEncoding = 32; + Sign = -1; + Offset = -Offset; + } + break; + default: + llvm_unreachable("Unsupported opcode"); + } // FIXME: If the offset won't fit in 24-bits, compute the offset into a // scratch register. If DestReg is a virtual register, use it as the @@ -2971,39 +3309,66 @@ // of code. // assert(Offset < (1 << 24) && "unimplemented reg plus immediate"); - unsigned Opc; - if (SetNZCV) - Opc = isSub ? AArch64::SUBSXri : AArch64::ADDSXri; - else - Opc = isSub ? AArch64::SUBXri : AArch64::ADDXri; - const unsigned MaxEncoding = 0xfff; - const unsigned ShiftSize = 12; const unsigned MaxEncodableValue = MaxEncoding << ShiftSize; - while (((unsigned)Offset) >= (1 << ShiftSize)) { - unsigned ThisVal; - if (((unsigned)Offset) > MaxEncodableValue) { - ThisVal = MaxEncodableValue; - } else { - ThisVal = Offset & MaxEncodableValue; + do { + unsigned ThisVal = std::min(Offset, MaxEncodableValue); + unsigned LocalShiftSize = 0; + if (ThisVal > MaxEncoding) { + ThisVal = ThisVal >> ShiftSize; + LocalShiftSize = ShiftSize; } assert((ThisVal >> ShiftSize) <= MaxEncoding && "Encoding cannot handle value that big"); - BuildMI(MBB, MBBI, DL, TII->get(Opc), DestReg) - .addReg(SrcReg) - .addImm(ThisVal >> ShiftSize) - .addImm(AArch64_AM::getShifterImm(AArch64_AM::LSL, ShiftSize)) - .setMIFlag(Flag); + auto MBI = BuildMI(MBB, MBBI, DL, TII->get(Opc), DestReg) + .addReg(SrcReg, KillSrcReg ? 
RegState::Kill : 0) + .addImm(Sign * (int)ThisVal); + if (ShiftSize) + MBI = MBI.addImm( + AArch64_AM::getShifterImm(AArch64_AM::LSL, LocalShiftSize)); + MBI = MBI.setMIFlag(Flag); SrcReg = DestReg; - Offset -= ThisVal; - if (Offset == 0) - return; + Offset -= ThisVal << LocalShiftSize; + KillSrcReg = true; + } while (Offset); +} + +void llvm::emitFrameOffset(MachineBasicBlock &MBB, + MachineBasicBlock::iterator MBBI, const DebugLoc &DL, + unsigned DestReg, unsigned SrcReg, + StackOffset Offset, const TargetInstrInfo *TII, + MachineInstr::MIFlag Flag, bool SetNZCV) { + int64_t Bytes, PLSized, VLSized; + Offset.getForFrameOffset(Bytes, PLSized, VLSized); + + // First emit non-scalable frame offsets, or a simple 'mov'. + if (Bytes || (Offset.isZero() && SrcReg != DestReg)) { + assert((DestReg != AArch64::SP || Bytes % 16 == 0) && + "SP increment/decrement not 16-byte aligned"); + unsigned Opc = SetNZCV ? AArch64::ADDSXri : AArch64::ADDXri; + if (Bytes < 0) { + Bytes = -Bytes; + Opc = SetNZCV ? AArch64::SUBSXri : AArch64::SUBXri; + } + emitFrameOffsetAdj(MBB, MBBI, DL, DestReg, SrcReg, Bytes, Opc, TII, + Flag); + SrcReg = DestReg; } - BuildMI(MBB, MBBI, DL, TII->get(Opc), DestReg) - .addReg(SrcReg) - .addImm(Offset) - .addImm(AArch64_AM::getShifterImm(AArch64_AM::LSL, 0)) - .setMIFlag(Flag); + + assert(!(SetNZCV && (PLSized || VLSized)) && + "SetNZCV not supported with SVE vectors"); + + if (VLSized) { + emitFrameOffsetAdj(MBB, MBBI, DL, DestReg, SrcReg, VLSized, + AArch64::ADDVL_XXI, TII, Flag); + SrcReg = DestReg; + } + + if (PLSized) { + assert(DestReg != AArch64::SP && "Unaligned access to SP"); + emitFrameOffsetAdj(MBB, MBBI, DL, DestReg, SrcReg, PLSized, + AArch64::ADDPL_XXI, TII, Flag); + } } MachineInstr *AArch64InstrInfo::foldMemoryOperandImpl( @@ -3084,7 +3449,7 @@ if (DstMO.getSubReg() == 0 && SrcMO.getSubReg() == 0) { assert(TRI.getRegSizeInBits(*getRegClass(DstReg)) == - TRI.getRegSizeInBits(*getRegClass(SrcReg)) && + TRI.getRegSizeInBits(*getRegClass(SrcReg)) && "Mismatched register size in non subreg COPY"); if (IsSpill) storeRegToStackSlot(MBB, InsertPt, SrcReg, SrcMO.isKill(), FrameIndex, @@ -3120,14 +3485,18 @@ SpillRC = &AArch64::GPR64RegClass; SpillSubreg = AArch64::sub_32; } else if (AArch64::FPR32RegClass.contains(SrcReg)) { - SpillRC = &AArch64::FPR64RegClass; + SpillRC = getRegClass(DstReg) == &AArch64::ZPRRegClass + ? &AArch64::ZPRRegClass + : &AArch64::FPR128RegClass; SpillSubreg = AArch64::ssub; } else SpillRC = nullptr; break; case AArch64::dsub: if (AArch64::FPR64RegClass.contains(SrcReg)) { - SpillRC = &AArch64::FPR128RegClass; + SpillRC = getRegClass(DstReg) == &AArch64::ZPRRegClass + ? 
&AArch64::ZPRRegClass + : &AArch64::FPR128RegClass; SpillSubreg = AArch64::dsub; } else SpillRC = nullptr; @@ -3188,12 +3557,69 @@ return nullptr; } -int llvm::isAArch64FrameOffsetLegal(const MachineInstr &MI, int &Offset, +bool llvm::isSVEScaledImmInstruction(unsigned Opcode) { + switch (Opcode) { + case AArch64::LD1B_IMM: + case AArch64::LD1B_H_IMM: + case AArch64::LD1B_S_IMM: + case AArch64::LD1B_D_IMM: + case AArch64::LD1SB_H_IMM: + case AArch64::LD1SB_S_IMM: + case AArch64::LD1SB_D_IMM: + case AArch64::LD1H_IMM: + case AArch64::LD1H_S_IMM: + case AArch64::LD1H_D_IMM: + case AArch64::LD1SH_S_IMM: + case AArch64::LD1SH_D_IMM: + case AArch64::LD1W_IMM: + case AArch64::LD1W_D_IMM: + case AArch64::LD1SW_D_IMM: + case AArch64::LD1D_IMM: + case AArch64::LDNT1B_ZRI: + case AArch64::LDNT1H_ZRI: + case AArch64::LDNT1W_ZRI: + case AArch64::LDNT1D_ZRI: + case AArch64::STNT1B_ZRI: + case AArch64::STNT1H_ZRI: + case AArch64::STNT1W_ZRI: + case AArch64::STNT1D_ZRI: + case AArch64::ST1B_IMM: + case AArch64::ST1B_H_IMM: + case AArch64::ST1B_S_IMM: + case AArch64::ST1B_D_IMM: + case AArch64::ST1H_IMM: + case AArch64::ST1H_S_IMM: + case AArch64::ST1H_D_IMM: + case AArch64::ST1W_IMM: + case AArch64::ST1W_D_IMM: + case AArch64::ST1D_IMM: + case AArch64::LDR_ZXI: + case AArch64::LDR_ZZXI: + case AArch64::LDR_ZZZXI: + case AArch64::LDR_ZZZZXI: + case AArch64::STR_ZXI: + case AArch64::STR_ZZXI: + case AArch64::STR_ZZZXI: + case AArch64::STR_ZZZZXI: + case AArch64::LDR_PXI: + case AArch64::STR_PXI: + return true; + default: + return false; + } + return false; +} + +int llvm::isAArch64FrameOffsetLegal(const MachineInstr &MI, + StackOffset &SOffset, bool *OutUseUnscaledOp, unsigned *OutUnscaledOp, int *EmittableOffset) { + unsigned MaskBits; int Scale = 1; + bool IsSVE = false; bool IsSigned = false; + // The ImmIdx should be changed case by case if it is not 2. 
unsigned ImmIdx = 2; unsigned UnscaledOp = 0; @@ -3377,37 +3803,162 @@ case AArch64::STURHHi: Scale = 1; break; + case AArch64::LD1B_IMM: + case AArch64::LD1B_H_IMM: + case AArch64::LD1B_S_IMM: + case AArch64::LD1B_D_IMM: + case AArch64::LD1SB_H_IMM: + case AArch64::LD1SB_S_IMM: + case AArch64::LD1SB_D_IMM: + case AArch64::LD1H_IMM: + case AArch64::LD1H_S_IMM: + case AArch64::LD1H_D_IMM: + case AArch64::LD1SH_S_IMM: + case AArch64::LD1SH_D_IMM: + case AArch64::LD1W_IMM: + case AArch64::LD1W_D_IMM: + case AArch64::LD1SW_D_IMM: + case AArch64::LD1D_IMM: + case AArch64::LDNT1B_ZRI: + case AArch64::LDNT1H_ZRI: + case AArch64::LDNT1W_ZRI: + case AArch64::LDNT1D_ZRI: + case AArch64::ST1B_IMM: + case AArch64::ST1B_H_IMM: + case AArch64::ST1B_S_IMM: + case AArch64::ST1B_D_IMM: + case AArch64::ST1H_IMM: + case AArch64::ST1H_S_IMM: + case AArch64::ST1H_D_IMM: + case AArch64::ST1W_IMM: + case AArch64::ST1W_D_IMM: + case AArch64::ST1D_IMM: + case AArch64::STNT1B_ZRI: + case AArch64::STNT1H_ZRI: + case AArch64::STNT1W_ZRI: + case AArch64::STNT1D_ZRI: + IsSVE = true; + IsSigned = true; + MaskBits = 4; + Scale = 16; + ImmIdx = 3; + break; + case AArch64::LDR_PXI: + case AArch64::STR_PXI: + IsSVE = true; + IsSigned = true; + MaskBits = 9; + Scale = 2; + break; + case AArch64::STR_ZZZZXI: + case AArch64::LDR_ZZZZXI: + case AArch64::STR_ZZZXI: + case AArch64::LDR_ZZZXI: + case AArch64::STR_ZZXI: + case AArch64::LDR_ZZXI: + case AArch64::LDR_ZXI: + case AArch64::STR_ZXI: + IsSVE = true; + IsSigned = true; + MaskBits = 9; + Scale = 16; + break; + case AArch64::LD1RB_IMM: + case AArch64::LD1RB_H_IMM: + case AArch64::LD1RB_S_IMM: + case AArch64::LD1RB_D_IMM: + case AArch64::LD1RSB_H_IMM: + case AArch64::LD1RSB_S_IMM: + case AArch64::LD1RSB_D_IMM: + IsSVE = true; + IsSigned = false; + MaskBits = 6; + Scale = 1; + ImmIdx = 3; + break; + case AArch64::LD1RH_IMM: + case AArch64::LD1RH_S_IMM: + case AArch64::LD1RH_D_IMM: + case AArch64::LD1RSH_S_IMM: + case AArch64::LD1RSH_D_IMM: + IsSVE = true; + IsSigned = false; + MaskBits = 6; + Scale = 2; + ImmIdx = 3; + break; + case AArch64::LD1RW_IMM: + case AArch64::LD1RW_D_IMM: + case AArch64::LD1RSW_IMM: + IsSVE = true; + IsSigned = false; + MaskBits = 6; + Scale = 4; + ImmIdx = 3; + break; + case AArch64::LD1RD_IMM: + IsSVE = true; + IsSigned = false; + MaskBits = 6; + Scale = 8; + ImmIdx = 3; + break; } + bool IsMulVL = isSVEScaledImmInstruction(MI.getOpcode()); + int64_t Offset = + IsMulVL ? (SOffset.getScalableBytes()) : (SOffset.getBytes()); Offset += MI.getOperand(ImmIdx).getImm() * Scale; + int64_t Remainder = Offset % Scale; - bool useUnscaledOp = false; // If the offset doesn't match the scale, we rewrite the instruction to // use the unscaled instruction instead. Likewise, if we have a negative // offset (and have an unscaled op to use). - if ((Offset & (Scale - 1)) != 0 || (Offset < 0 && UnscaledOp != 0)) - useUnscaledOp = true; + bool useUnscaledOp = !IsSVE && (Remainder || (Offset < 0 && UnscaledOp != 0)); // Use an unscaled addressing mode if the instruction has a negative offset // (or if the instruction is already using an unscaled addressing mode). - unsigned MaskBits; - if (IsSigned) { - // ldp/stp instructions. - MaskBits = 7; - Offset /= Scale; - } else if (UnscaledOp == 0 || useUnscaledOp) { - MaskBits = 9; - IsSigned = true; - Scale = 1; - } else { - MaskBits = 12; - IsSigned = false; - Offset /= Scale; + if (!IsSVE) { + if (IsSigned) { + // ldp/stp instructions. 
+ MaskBits = 7; + } else if (UnscaledOp == 0 || useUnscaledOp) { + MaskBits = 9; + IsSigned = true; + Scale = 1; + Remainder = 0; + } else { + MaskBits = 12; + IsSigned = false; + } } + Offset /= Scale; // Attempt to fold address computation. int MaxOff = (1 << (MaskBits - IsSigned)) - 1; int MinOff = (IsSigned ? (-MaxOff - 1) : 0); + + // The multi-vector spills/fills are expanded into a series of single + // STR_ZXI/LDR_ZXI instructions, so we need to verify that all the frame + // offsets are legal for each individual STR/LDR instruction. + switch (MI.getOpcode()) { + case AArch64::STR_ZZZZXI: + case AArch64::LDR_ZZZZXI: + MaxOff -= 3; + break; + case AArch64::STR_ZZZXI: + case AArch64::LDR_ZZZXI: + MaxOff -= 2; + break; + case AArch64::STR_ZZXI: + case AArch64::LDR_ZZXI: + MaxOff -= 1; + break; + default: + break; + } + assert(MinOff < MaxOff && "Unexpected Min/Max offsets"); + if (Offset >= MinOff && Offset <= MaxOff) { if (EmittableOffset) *EmittableOffset = Offset; @@ -3418,27 +3969,38 @@ *EmittableOffset = NewOff; Offset = (Offset - NewOff) * Scale; } + assert(!(Remainder && useUnscaledOp) && + "Cannot have remainder when using unscaled op"); + Offset += Remainder; if (OutUseUnscaledOp) *OutUseUnscaledOp = useUnscaledOp; if (OutUnscaledOp) *OutUnscaledOp = UnscaledOp; + + if (IsMulVL) + SOffset = StackOffset(Offset, MVT::nxv1i8) + + StackOffset(SOffset.getBytes(), MVT::i8); + else + SOffset = StackOffset(Offset, MVT::i8) + + StackOffset(SOffset.getScalableBytes(), MVT::nxv1i8); + return AArch64FrameOffsetCanUpdate | - (Offset == 0 ? AArch64FrameOffsetIsLegal : 0); + (SOffset.isZero() ? AArch64FrameOffsetIsLegal : 0); } bool llvm::rewriteAArch64FrameIndex(MachineInstr &MI, unsigned FrameRegIdx, - unsigned FrameReg, int &Offset, + unsigned FrameReg, StackOffset &Offset, const AArch64InstrInfo *TII) { unsigned Opcode = MI.getOpcode(); unsigned ImmIdx = FrameRegIdx + 1; if (Opcode == AArch64::ADDSXri || Opcode == AArch64::ADDXri) { - Offset += MI.getOperand(ImmIdx).getImm(); + Offset += StackOffset(MI.getOperand(ImmIdx).getImm(), MVT::i8); emitFrameOffset(*MI.getParent(), MI, MI.getDebugLoc(), MI.getOperand(0).getReg(), FrameReg, Offset, TII, MachineInstr::NoFlags, (Opcode == AArch64::ADDSXri)); MI.eraseFromParent(); - Offset = 0; + Offset = StackOffset(); return true; } @@ -3450,12 +4012,12 @@ if (Status & AArch64FrameOffsetCanUpdate) { if (Status & AArch64FrameOffsetIsLegal) // Replace the FrameIndex with FrameReg. - MI.getOperand(FrameRegIdx).ChangeToRegister(FrameReg, false); + MI.getOperand(FrameRegIdx).ChangeToRegister(FrameReg, false, false); if (UseUnscaledOp) MI.setDesc(TII->get(UnscaledOp)); MI.getOperand(ImmIdx).ChangeToImmediate(NewOffset); - return Offset == 0; + return Offset.isZero(); } return false; @@ -3557,8 +4119,8 @@ // Utility routine that checks if \param MO is defined by an // \param CombineOpc instruction in the basic block \param MBB static bool canCombine(MachineBasicBlock &MBB, MachineOperand &MO, - unsigned CombineOpc, unsigned ZeroReg = 0, - bool CheckZeroReg = false) { + unsigned CombineOpc, bool AggressiveCombine, + unsigned ZeroReg = 0, bool CheckZeroReg = false) { MachineRegisterInfo &MRI = MBB.getParent()->getRegInfo(); MachineInstr *MI = nullptr; @@ -3567,8 +4129,9 @@ // And it needs to be in the trace (otherwise, it won't have a depth). if (!MI || MI->getParent() != &MBB || (unsigned)MI->getOpcode() != CombineOpc) return false; - // Must only used by the user we combine with. 
- if (!MRI.hasOneNonDBGUse(MI->getOperand(0).getReg())) + // Must only used by the user we combine with, unless we are aggressively + // combining with any available operand + if (!AggressiveCombine && !MRI.hasOneNonDBGUse(MI->getOperand(0).getReg())) return false; if (CheckZeroReg) { @@ -3587,14 +4150,16 @@ // Is \param MO defined by an integer multiply and can be combined? static bool canCombineWithMUL(MachineBasicBlock &MBB, MachineOperand &MO, unsigned MulOpc, unsigned ZeroReg) { - return canCombine(MBB, MO, MulOpc, ZeroReg, true); + return canCombine(MBB, MO, MulOpc, false, ZeroReg, true); } // // Is \param MO defined by a floating-point multiply and can be combined? static bool canCombineWithFMUL(MachineBasicBlock &MBB, MachineOperand &MO, unsigned MulOpc) { - return canCombine(MBB, MO, MulOpc); + MachineFunction &MF = *MBB.getParent(); + bool AggressiveFMA = MF.getSubtarget().hasAggressiveFMA(); + return canCombine(MBB, MO, MulOpc, AggressiveFMA); } // TODO: There are many more machine instruction opcodes to match: @@ -3733,6 +4298,52 @@ } /// Floating-Point Support +static int calcOperandDist(const MachineInstr &Root, int OpdIdx) { + const MachineBasicBlock &MBB = *Root.getParent(); + const MachineRegisterInfo &MRI = MBB.getParent()->getRegInfo(); + + const auto *OpdInstr = MRI.getUniqueVRegDef(Root.getOperand(OpdIdx).getReg()); + if (!OpdInstr || OpdInstr->getParent() != &MBB) + return INT_MAX; + + // We have found a fmul instruction. A first order approximation would + // involve calculating the distance from the root to the fmul, however + // this doesn't take into account the effect of scheduling. If the fmul + // operands themselves are defined far away from the root then that + // matters more as the fmul is likely to be scheduled further away. + const auto *FmulOp1Instr = + MRI.getUniqueVRegDef(OpdInstr->getOperand(1).getReg()); + const auto *FmulOp2Instr = + MRI.getUniqueVRegDef(OpdInstr->getOperand(2).getReg()); + + int Distance1, Distance2; + if (!FmulOp1Instr || FmulOp1Instr->getParent() != &MBB) + Distance1 = INT_MAX; + else + Distance1 = std::distance(FmulOp1Instr->getIterator(), Root.getIterator()); + + if (!FmulOp2Instr || FmulOp2Instr->getParent() != &MBB) + Distance2 = INT_MAX; + else + Distance2 = std::distance(FmulOp2Instr->getIterator(), Root.getIterator()); + + return std::min(Distance1, Distance2); +} + +struct DistPatternPair { + int Distance; + MachineCombinerPattern Pattern; + + DistPatternPair(int Distance, MachineCombinerPattern Pattern) + : Distance(Distance), Pattern(Pattern) {} + + bool operator<(const DistPatternPair &Other) const noexcept { + // Note comparison is intentionally reversed, since we want a min-heap + // rather than the priority_queue default of a max-heap. + return Distance > Other.Distance; + } +}; + /// Find instructions that can be turned into madd. 
static bool getFMAPatterns(MachineInstr &Root, SmallVectorImpl &Patterns) { @@ -3741,7 +4352,13 @@ return false; MachineBasicBlock &MBB = *Root.getParent(); - bool Found = false; + + std::priority_queue DistPatterns; + + auto AddOpdPattern = [&](int Operand, MachineCombinerPattern Pattern) { + int Distance = calcOperandDist(Root, Operand); + DistPatterns.emplace(Distance, Pattern); + }; switch (Root.getOpcode()) { default: @@ -3751,198 +4368,170 @@ assert(Root.getOperand(1).isReg() && Root.getOperand(2).isReg() && "FADDWrr does not have register operands"); if (canCombineWithFMUL(MBB, Root.getOperand(1), AArch64::FMULSrr)) { - Patterns.push_back(MachineCombinerPattern::FMULADDS_OP1); - Found = true; + AddOpdPattern(1, MachineCombinerPattern::FMULADDS_OP1); } else if (canCombineWithFMUL(MBB, Root.getOperand(1), AArch64::FMULv1i32_indexed)) { - Patterns.push_back(MachineCombinerPattern::FMLAv1i32_indexed_OP1); - Found = true; + AddOpdPattern(1, MachineCombinerPattern::FMLAv1i32_indexed_OP1); } if (canCombineWithFMUL(MBB, Root.getOperand(2), AArch64::FMULSrr)) { - Patterns.push_back(MachineCombinerPattern::FMULADDS_OP2); - Found = true; + AddOpdPattern(2, MachineCombinerPattern::FMULADDS_OP2); } else if (canCombineWithFMUL(MBB, Root.getOperand(2), AArch64::FMULv1i32_indexed)) { - Patterns.push_back(MachineCombinerPattern::FMLAv1i32_indexed_OP2); - Found = true; + AddOpdPattern(2, MachineCombinerPattern::FMLAv1i32_indexed_OP2); } break; case AArch64::FADDDrr: if (canCombineWithFMUL(MBB, Root.getOperand(1), AArch64::FMULDrr)) { - Patterns.push_back(MachineCombinerPattern::FMULADDD_OP1); - Found = true; + AddOpdPattern(1, MachineCombinerPattern::FMULADDD_OP1); } else if (canCombineWithFMUL(MBB, Root.getOperand(1), AArch64::FMULv1i64_indexed)) { - Patterns.push_back(MachineCombinerPattern::FMLAv1i64_indexed_OP1); - Found = true; + AddOpdPattern(1, MachineCombinerPattern::FMLAv1i64_indexed_OP1); } if (canCombineWithFMUL(MBB, Root.getOperand(2), AArch64::FMULDrr)) { - Patterns.push_back(MachineCombinerPattern::FMULADDD_OP2); - Found = true; + AddOpdPattern(2, MachineCombinerPattern::FMULADDD_OP2); } else if (canCombineWithFMUL(MBB, Root.getOperand(2), AArch64::FMULv1i64_indexed)) { - Patterns.push_back(MachineCombinerPattern::FMLAv1i64_indexed_OP2); - Found = true; + AddOpdPattern(2, MachineCombinerPattern::FMLAv1i64_indexed_OP2); } break; case AArch64::FADDv2f32: if (canCombineWithFMUL(MBB, Root.getOperand(1), AArch64::FMULv2i32_indexed)) { - Patterns.push_back(MachineCombinerPattern::FMLAv2i32_indexed_OP1); - Found = true; + AddOpdPattern(1, MachineCombinerPattern::FMLAv2i32_indexed_OP1); } else if (canCombineWithFMUL(MBB, Root.getOperand(1), AArch64::FMULv2f32)) { - Patterns.push_back(MachineCombinerPattern::FMLAv2f32_OP1); - Found = true; + AddOpdPattern(1, MachineCombinerPattern::FMLAv2f32_OP1); } if (canCombineWithFMUL(MBB, Root.getOperand(2), AArch64::FMULv2i32_indexed)) { - Patterns.push_back(MachineCombinerPattern::FMLAv2i32_indexed_OP2); - Found = true; + AddOpdPattern(2, MachineCombinerPattern::FMLAv2i32_indexed_OP2); } else if (canCombineWithFMUL(MBB, Root.getOperand(2), AArch64::FMULv2f32)) { - Patterns.push_back(MachineCombinerPattern::FMLAv2f32_OP2); - Found = true; + AddOpdPattern(2, MachineCombinerPattern::FMLAv2f32_OP2); } break; case AArch64::FADDv2f64: if (canCombineWithFMUL(MBB, Root.getOperand(1), AArch64::FMULv2i64_indexed)) { - Patterns.push_back(MachineCombinerPattern::FMLAv2i64_indexed_OP1); - Found = true; + AddOpdPattern(1, MachineCombinerPattern::FMLAv2i64_indexed_OP1); 
} else if (canCombineWithFMUL(MBB, Root.getOperand(1), AArch64::FMULv2f64)) { - Patterns.push_back(MachineCombinerPattern::FMLAv2f64_OP1); - Found = true; + AddOpdPattern(1, MachineCombinerPattern::FMLAv2f64_OP1); } if (canCombineWithFMUL(MBB, Root.getOperand(2), AArch64::FMULv2i64_indexed)) { - Patterns.push_back(MachineCombinerPattern::FMLAv2i64_indexed_OP2); - Found = true; + AddOpdPattern(2, MachineCombinerPattern::FMLAv2i64_indexed_OP2); } else if (canCombineWithFMUL(MBB, Root.getOperand(2), AArch64::FMULv2f64)) { - Patterns.push_back(MachineCombinerPattern::FMLAv2f64_OP2); - Found = true; + AddOpdPattern(2, MachineCombinerPattern::FMLAv2f64_OP2); } break; case AArch64::FADDv4f32: if (canCombineWithFMUL(MBB, Root.getOperand(1), AArch64::FMULv4i32_indexed)) { - Patterns.push_back(MachineCombinerPattern::FMLAv4i32_indexed_OP1); - Found = true; + AddOpdPattern(1, MachineCombinerPattern::FMLAv4i32_indexed_OP1); } else if (canCombineWithFMUL(MBB, Root.getOperand(1), AArch64::FMULv4f32)) { - Patterns.push_back(MachineCombinerPattern::FMLAv4f32_OP1); - Found = true; + AddOpdPattern(1, MachineCombinerPattern::FMLAv4f32_OP1); } if (canCombineWithFMUL(MBB, Root.getOperand(2), AArch64::FMULv4i32_indexed)) { - Patterns.push_back(MachineCombinerPattern::FMLAv4i32_indexed_OP2); - Found = true; + AddOpdPattern(2, MachineCombinerPattern::FMLAv4i32_indexed_OP2); } else if (canCombineWithFMUL(MBB, Root.getOperand(2), AArch64::FMULv4f32)) { - Patterns.push_back(MachineCombinerPattern::FMLAv4f32_OP2); - Found = true; + AddOpdPattern(2, MachineCombinerPattern::FMLAv4f32_OP2); } break; case AArch64::FSUBSrr: if (canCombineWithFMUL(MBB, Root.getOperand(1), AArch64::FMULSrr)) { - Patterns.push_back(MachineCombinerPattern::FMULSUBS_OP1); - Found = true; + AddOpdPattern(1, MachineCombinerPattern::FMULSUBS_OP1); } if (canCombineWithFMUL(MBB, Root.getOperand(2), AArch64::FMULSrr)) { - Patterns.push_back(MachineCombinerPattern::FMULSUBS_OP2); - Found = true; + AddOpdPattern(2, MachineCombinerPattern::FMULSUBS_OP2); } else if (canCombineWithFMUL(MBB, Root.getOperand(2), AArch64::FMULv1i32_indexed)) { - Patterns.push_back(MachineCombinerPattern::FMLSv1i32_indexed_OP2); - Found = true; + AddOpdPattern(2, MachineCombinerPattern::FMLSv1i32_indexed_OP2); } if (canCombineWithFMUL(MBB, Root.getOperand(1), AArch64::FNMULSrr)) { - Patterns.push_back(MachineCombinerPattern::FNMULSUBS_OP1); - Found = true; + AddOpdPattern(1, MachineCombinerPattern::FNMULSUBS_OP1); } break; case AArch64::FSUBDrr: if (canCombineWithFMUL(MBB, Root.getOperand(1), AArch64::FMULDrr)) { - Patterns.push_back(MachineCombinerPattern::FMULSUBD_OP1); - Found = true; + AddOpdPattern(1, MachineCombinerPattern::FMULSUBD_OP1); } if (canCombineWithFMUL(MBB, Root.getOperand(2), AArch64::FMULDrr)) { - Patterns.push_back(MachineCombinerPattern::FMULSUBD_OP2); - Found = true; + AddOpdPattern(2, MachineCombinerPattern::FMULSUBD_OP2); } else if (canCombineWithFMUL(MBB, Root.getOperand(2), AArch64::FMULv1i64_indexed)) { - Patterns.push_back(MachineCombinerPattern::FMLSv1i64_indexed_OP2); - Found = true; + AddOpdPattern(2, MachineCombinerPattern::FMLSv1i64_indexed_OP2); } if (canCombineWithFMUL(MBB, Root.getOperand(1), AArch64::FNMULDrr)) { - Patterns.push_back(MachineCombinerPattern::FNMULSUBD_OP1); - Found = true; + AddOpdPattern(1, MachineCombinerPattern::FNMULSUBD_OP1); } break; case AArch64::FSUBv2f32: - if (canCombineWithFMUL(MBB, Root.getOperand(2), - AArch64::FMULv2i32_indexed)) { - Patterns.push_back(MachineCombinerPattern::FMLSv2i32_indexed_OP2); - Found = 
true; - } else if (canCombineWithFMUL(MBB, Root.getOperand(2), - AArch64::FMULv2f32)) { - Patterns.push_back(MachineCombinerPattern::FMLSv2f32_OP2); - Found = true; - } if (canCombineWithFMUL(MBB, Root.getOperand(1), AArch64::FMULv2i32_indexed)) { - Patterns.push_back(MachineCombinerPattern::FMLSv2i32_indexed_OP1); - Found = true; + AddOpdPattern(1, MachineCombinerPattern::FMLSv2i32_indexed_OP1); } else if (canCombineWithFMUL(MBB, Root.getOperand(1), AArch64::FMULv2f32)) { - Patterns.push_back(MachineCombinerPattern::FMLSv2f32_OP1); - Found = true; + AddOpdPattern(1, MachineCombinerPattern::FMLSv2f32_OP1); + } + if (canCombineWithFMUL(MBB, Root.getOperand(2), + AArch64::FMULv2i32_indexed)) { + AddOpdPattern(2, MachineCombinerPattern::FMLSv2i32_indexed_OP2); + } else if (canCombineWithFMUL(MBB, Root.getOperand(2), + AArch64::FMULv2f32)) { + AddOpdPattern(2, MachineCombinerPattern::FMLSv2f32_OP2); } break; case AArch64::FSUBv2f64: - if (canCombineWithFMUL(MBB, Root.getOperand(2), - AArch64::FMULv2i64_indexed)) { - Patterns.push_back(MachineCombinerPattern::FMLSv2i64_indexed_OP2); - Found = true; - } else if (canCombineWithFMUL(MBB, Root.getOperand(2), - AArch64::FMULv2f64)) { - Patterns.push_back(MachineCombinerPattern::FMLSv2f64_OP2); - Found = true; - } if (canCombineWithFMUL(MBB, Root.getOperand(1), AArch64::FMULv2i64_indexed)) { - Patterns.push_back(MachineCombinerPattern::FMLSv2i64_indexed_OP1); - Found = true; + AddOpdPattern(1, MachineCombinerPattern::FMLSv2i64_indexed_OP1); } else if (canCombineWithFMUL(MBB, Root.getOperand(1), AArch64::FMULv2f64)) { - Patterns.push_back(MachineCombinerPattern::FMLSv2f64_OP1); - Found = true; + AddOpdPattern(1, MachineCombinerPattern::FMLSv2f64_OP1); + } + if (canCombineWithFMUL(MBB, Root.getOperand(2), + AArch64::FMULv2i64_indexed)) { + AddOpdPattern(2, MachineCombinerPattern::FMLSv2i64_indexed_OP2); + } else if (canCombineWithFMUL(MBB, Root.getOperand(2), + AArch64::FMULv2f64)) { + AddOpdPattern(2, MachineCombinerPattern::FMLSv2f64_OP2); } break; case AArch64::FSUBv4f32: - if (canCombineWithFMUL(MBB, Root.getOperand(2), - AArch64::FMULv4i32_indexed)) { - Patterns.push_back(MachineCombinerPattern::FMLSv4i32_indexed_OP2); - Found = true; - } else if (canCombineWithFMUL(MBB, Root.getOperand(2), - AArch64::FMULv4f32)) { - Patterns.push_back(MachineCombinerPattern::FMLSv4f32_OP2); - Found = true; - } if (canCombineWithFMUL(MBB, Root.getOperand(1), AArch64::FMULv4i32_indexed)) { - Patterns.push_back(MachineCombinerPattern::FMLSv4i32_indexed_OP1); - Found = true; + AddOpdPattern(1, MachineCombinerPattern::FMLSv4i32_indexed_OP1); } else if (canCombineWithFMUL(MBB, Root.getOperand(1), AArch64::FMULv4f32)) { - Patterns.push_back(MachineCombinerPattern::FMLSv4f32_OP1); - Found = true; + AddOpdPattern(1, MachineCombinerPattern::FMLSv4f32_OP1); + } + if (canCombineWithFMUL(MBB, Root.getOperand(2), + AArch64::FMULv4i32_indexed)) { + AddOpdPattern(2, MachineCombinerPattern::FMLSv4i32_indexed_OP2); + } else if (canCombineWithFMUL(MBB, Root.getOperand(2), + AArch64::FMULv4f32)) { + AddOpdPattern(2, MachineCombinerPattern::FMLSv4f32_OP2); } break; } + + bool Found = false; + + // Add patterns in order from closest to furthest to ensure that operands + // near the root are suggested first, rather than attempting to combine with + // something that would have already finished executing. 
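As context for the drain loop that follows: std::priority_queue is a max-heap on operator<, so reversing the comparison (as DistPatternPair does above) makes the smallest distance pop first. A minimal standalone sketch of that ordering, with a plain int standing in for MachineCombinerPattern:

#include <cassert>
#include <queue>
#include <vector>

struct DistItem {
  int Distance;
  int Pattern;                        // stand-in for MachineCombinerPattern
  bool operator<(const DistItem &Other) const {
    return Distance > Other.Distance; // reversed on purpose: min-heap
  }
};

int main() {
  std::priority_queue<DistItem> Q;
  Q.push({7, 100});
  Q.push({2, 200});
  Q.push({5, 300});

  std::vector<int> Order;
  while (!Q.empty()) {                // same drain shape as getFMAPatterns
    Order.push_back(Q.top().Pattern);
    Q.pop();
  }
  assert((Order == std::vector<int>{200, 300, 100})); // closest first
  return 0;
}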
+ while (!DistPatterns.empty()) { + Patterns.push_back(DistPatterns.top().Pattern); + DistPatterns.pop(); + Found = true; + } + return Found; } @@ -4228,15 +4817,16 @@ } uint64_t UImm = SignExtend64(Imm, BitSize); uint64_t Encoding; - if (AArch64_AM::processLogicalImmediate(UImm, BitSize, Encoding)) { - MachineInstrBuilder MIB1 = - BuildMI(MF, Root.getDebugLoc(), TII->get(OrrOpc), NewVR) - .addReg(ZeroReg) - .addImm(Encoding); - InsInstrs.push_back(MIB1); - InstrIdxForVirtReg.insert(std::make_pair(NewVR, 0)); - MUL = genMaddR(MF, MRI, TII, Root, InsInstrs, 1, Opc, NewVR, RC); - } + if (!AArch64_AM::processLogicalImmediate(UImm, BitSize, Encoding)) + return; + + MachineInstrBuilder MIB1 = + BuildMI(MF, Root.getDebugLoc(), TII->get(OrrOpc), NewVR) + .addReg(ZeroReg) + .addImm(Encoding); + InsInstrs.push_back(MIB1); + InstrIdxForVirtReg.insert(std::make_pair(NewVR, 0)); + MUL = genMaddR(MF, MRI, TII, Root, InsInstrs, 1, Opc, NewVR, RC); break; } case MachineCombinerPattern::MULSUBW_OP1: @@ -4319,15 +4909,16 @@ } uint64_t UImm = SignExtend64(-Imm, BitSize); uint64_t Encoding; - if (AArch64_AM::processLogicalImmediate(UImm, BitSize, Encoding)) { - MachineInstrBuilder MIB1 = - BuildMI(MF, Root.getDebugLoc(), TII->get(OrrOpc), NewVR) - .addReg(ZeroReg) - .addImm(Encoding); - InsInstrs.push_back(MIB1); - InstrIdxForVirtReg.insert(std::make_pair(NewVR, 0)); - MUL = genMaddR(MF, MRI, TII, Root, InsInstrs, 1, Opc, NewVR, RC); - } + if (!AArch64_AM::processLogicalImmediate(UImm, BitSize, Encoding)) + return; + + MachineInstrBuilder MIB1 = + BuildMI(MF, Root.getDebugLoc(), TII->get(OrrOpc), NewVR) + .addReg(ZeroReg) + .addImm(Encoding); + InsInstrs.push_back(MIB1); + InstrIdxForVirtReg.insert(std::make_pair(NewVR, 0)); + MUL = genMaddR(MF, MRI, TII, Root, InsInstrs, 1, Opc, NewVR, RC); break; } // Floating Point Support @@ -4638,7 +5229,8 @@ } } // end switch (Pattern) // Record MUL and ADD/SUB for deletion - DelInstrs.push_back(MUL); + if (MRI.hasOneNonDBGUse(MUL->getOperand(0).getReg())) + DelInstrs.push_back(MUL); DelInstrs.push_back(&Root); } @@ -5562,3 +6154,15 @@ MachineFunction &MF) const { return MF.getFunction().optForMinSize(); } + +uint64_t AArch64InstrInfo::getElementSizeForOpcode(unsigned Opc) const { + return get(Opc).TSFlags & AArch64::ElementSizeMask; +} + +bool AArch64InstrInfo::isPTestLikeOpcode(unsigned Opc) const { + return get(Opc).TSFlags & AArch64::InstrFlagIsPTestLike; +} + +bool AArch64InstrInfo::isWhileOpcode(unsigned Opc) const { + return get(Opc).TSFlags & AArch64::InstrFlagIsWhile; +} Index: lib/Target/AArch64/AArch64InstrInfo.td =================================================================== --- lib/Target/AArch64/AArch64InstrInfo.td +++ lib/Target/AArch64/AArch64InstrInfo.td @@ -56,6 +56,16 @@ "fuse-aes">; def HasSVE : Predicate<"Subtarget->hasSVE()">, AssemblerPredicate<"FeatureSVE", "sve">; +def HasSVE2 : Predicate<"Subtarget->hasSVE2()">, + AssemblerPredicate<"FeatureSVE2", "sve2">; +def HasSVE2AES : Predicate<"Subtarget->hasSVE2AES()">, + AssemblerPredicate<"FeatureSVE2AES", "sve2-aes">; +def HasSVE2SM4 : Predicate<"Subtarget->hasSVE2SM4()">, + AssemblerPredicate<"FeatureSVE2SM4", "sve2-sm4">; +def HasSVE2SHA3 : Predicate<"Subtarget->hasSVE2SHA3()">, + AssemblerPredicate<"FeatureSVE2SHA3", "sve2-sha3">; +def HasSVE2BitPerm : Predicate<"Subtarget->hasSVE2BitPerm()">, + AssemblerPredicate<"FeatureSVE2BitPerm", "sve2-bitperm">; def HasRCPC : Predicate<"Subtarget->hasRCPC()">, AssemblerPredicate<"FeatureRCPC", "rcpc">; @@ -171,6 +181,247 @@ SDTCisSameAs<1, 2>, 
SDTCisSameAs<1, 3>, SDTCisSameAs<1, 4>]>; +def SDT_AArch64TBL : SDTypeProfile<1, 2, [ + SDTCisVec<0>, SDTCisSameAs<0, 1>, SDTCisInt<2> +]>; + +// TODO: these are patterns that everybody could benefit from but sadly they +// cannot live in TargetSelectionDAG.td because compilation fails when targets +// without vector registers are enabled. + +// non-extending masked load fragment. +def nonext_masked_load : + PatFrag<(ops node:$ptr, node:$pred, node:$def), + (masked_load node:$ptr, node:$pred, node:$def), [{ + return cast(N)->getExtensionType() == ISD::NON_EXTLOAD; +}]>; + +// sign extending masked load fragments. +def sext_masked_load : + PatFrag<(ops node:$ptr, node:$pred, node:$def), + (masked_load node:$ptr, node:$pred, node:$def),[{ + return cast(N)->getExtensionType() == ISD::SEXTLOAD; +}]>; +def sext_masked_load_i8 : + PatFrag<(ops node:$ptr, node:$pred, node:$def), + (sext_masked_load node:$ptr, node:$pred, node:$def), [{ + return cast(N)->getMemoryVT().getScalarType() == MVT::i8; +}]>; +def sext_masked_load_i16 : + PatFrag<(ops node:$ptr, node:$pred, node:$def), + (sext_masked_load node:$ptr, node:$pred, node:$def), [{ + return cast(N)->getMemoryVT().getScalarType() == MVT::i16; +}]>; +def sext_masked_load_i32 : + PatFrag<(ops node:$ptr, node:$pred, node:$def), + (sext_masked_load node:$ptr, node:$pred, node:$def), [{ + return cast(N)->getMemoryVT().getScalarType() == MVT::i32; +}]>; + +// zero extending masked load fragments. +def zext_masked_load : + PatFrag<(ops node:$ptr, node:$pred, node:$def), + (masked_load node:$ptr, node:$pred, node:$def), [{ + return cast(N)->getExtensionType() == ISD::ZEXTLOAD; +}]>; +def zext_masked_load_i8 : + PatFrag<(ops node:$ptr, node:$pred, node:$def), + (zext_masked_load node:$ptr, node:$pred, node:$def), [{ + return cast(N)->getMemoryVT().getScalarType() == MVT::i8; +}]>; +def zext_masked_load_i16 : + PatFrag<(ops node:$ptr, node:$pred, node:$def), + (zext_masked_load node:$ptr, node:$pred, node:$def), [{ + return cast(N)->getMemoryVT().getScalarType() == MVT::i16; +}]>; +def zext_masked_load_i32 : + PatFrag<(ops node:$ptr, node:$pred, node:$def), + (zext_masked_load node:$ptr, node:$pred, node:$def), [{ + return cast(N)->getMemoryVT().getScalarType() == MVT::i32; +}]>; + +// non-truncating masked store fragment. +def nontrunc_masked_store : + PatFrag<(ops node:$ptr, node:$pred, node:$val), + (masked_store node:$ptr, node:$pred, node:$val), [{ + return !cast(N)->isTruncatingStore(); +}]>; + +// truncating masked store fragments. +def trunc_masked_store : + PatFrag<(ops node:$ptr, node:$pred, node:$val), + (masked_store node:$ptr, node:$pred, node:$val), [{ + return cast(N)->isTruncatingStore(); +}]>; +def trunc_masked_store_i8 : + PatFrag<(ops node:$ptr, node:$pred, node:$val), + (trunc_masked_store node:$ptr, node:$pred, node:$val), [{ + return cast(N)->getMemoryVT().getScalarType() == MVT::i8; +}]>; +def trunc_masked_store_i16 : + PatFrag<(ops node:$ptr, node:$pred, node:$val), + (trunc_masked_store node:$ptr, node:$pred, node:$val), [{ + return cast(N)->getMemoryVT().getScalarType() == MVT::i16; +}]>; +def trunc_masked_store_i32 : + PatFrag<(ops node:$ptr, node:$pred, node:$val), + (trunc_masked_store node:$ptr, node:$pred, node:$val), [{ + return cast(N)->getMemoryVT().getScalarType() == MVT::i32; +}]>; + +// NOTE: This is a clone of SDTMaskedGather as the "common" version looks rather +// target specific. 
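Before the gather and scatter node definitions that follow, it is worth spelling out what the PatFrag predicates in this hunk check. The template arguments of the casts did not survive in this copy of the patch (they appear as a bare `cast(N)`), so the node classes named below are an assumption based on the usual SelectionDAG types; the conditions themselves are taken directly from the predicate bodies.

```cpp
#include "llvm/CodeGen/SelectionDAGNodes.h"
using namespace llvm;

// sext_masked_load_i8: a masked load that sign-extends from i8 elements.
static bool isSExtI8MaskedLoad(const SDNode *N) {
  const auto *Ld = cast<MaskedLoadSDNode>(N);
  return Ld->getExtensionType() == ISD::SEXTLOAD &&
         Ld->getMemoryVT().getScalarType() == MVT::i8;
}

// trunc_masked_store_i16: a masked store that truncates to i16 elements.
static bool isTruncI16MaskedStore(const SDNode *N) {
  const auto *St = cast<MaskedStoreSDNode>(N);
  return St->isTruncatingStore() &&
         St->getMemoryVT().getScalarType() == MVT::i16;
}

// masked_gather_signed_scaled (defined next): a gather whose index operand
// uses signed, scaled offsets.
static bool isSignedScaledGather(const SDNode *N) {
  const auto *G = cast<MaskedGatherSDNode>(N);
  return G->getIndexType() == ISD::SIGNED_SCALED;
}
```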
+def SDT_SVEMaskedGather : SDTypeProfile<1, 4, [ // masked gather + SDTCisVec<0>, SDTCisSameAs<0, 1>, SDTCisVec<2>, SDTCisPtrTy<3>, SDTCisVec<4>, + SDTCVecEltisVT<2, i1>, SDTCisSameNumEltsAs<0, 2>, SDTCisSameNumEltsAs<0, 4> +]>; +def sve_masked_gather : SDNode<"ISD::MGATHER", SDT_SVEMaskedGather, + [SDNPHasChain, SDNPMayLoad, SDNPMemOperand]>; + +// masked gather (signed scaled offsets). +def masked_gather_signed_scaled : + PatFrag<(ops node:$def, node:$pred, node:$ptr, node:$idx), + (sve_masked_gather node:$def, node:$pred, node:$ptr, node:$idx),[{ + return cast(N)->getIndexType() == ISD::SIGNED_SCALED; +}]>; +// masked gather (signed unscaled offsets). +def masked_gather_signed_unscaled : + PatFrag<(ops node:$def, node:$pred, node:$ptr, node:$idx), + (sve_masked_gather node:$def, node:$pred, node:$ptr, node:$idx),[{ + return cast(N)->getIndexType() == ISD::SIGNED_UNSCALED; +}]>; +// masked gather (unsigned scaled offsets). +def masked_gather_unsigned_scaled : + PatFrag<(ops node:$def, node:$pred, node:$ptr, node:$idx), + (sve_masked_gather node:$def, node:$pred, node:$ptr, node:$idx),[{ + return cast(N)->getIndexType() == ISD::UNSIGNED_SCALED; +}]>; +// masked gather (unsigned unscaled offsets). +def masked_gather_unsigned_unscaled : + PatFrag<(ops node:$def, node:$pred, node:$ptr, node:$idx), + (sve_masked_gather node:$def, node:$pred, node:$ptr, node:$idx),[{ + return cast(N)->getIndexType() == ISD::UNSIGNED_UNSCALED; +}]>; + +multiclass masked_gather { + def nonext_#NAME : + PatFrag<(ops node:$def, node:$pred, node:$ptr, node:$idx), + (GatherOp node:$def, node:$pred, node:$ptr, node:$idx), [{ + return cast(N)->getExtensionType() == ISD::NON_EXTLOAD; + }]>; + + // Sign extending masked gather fragments. + def sext_#NAME#_i8 : + PatFrag<(ops node:$def, node:$pred, node:$ptr, node:$idx), + (GatherOp node:$def, node:$pred, node:$ptr, node:$idx), [{ + return cast(N)->getExtensionType() == ISD::SEXTLOAD && + cast(N)->getMemoryVT().getScalarType() == MVT::i8; + }]>; + def sext_#NAME#_i16 : + PatFrag<(ops node:$def, node:$pred, node:$ptr, node:$idx), + (GatherOp node:$def, node:$pred, node:$ptr, node:$idx), [{ + return cast(N)->getExtensionType() == ISD::SEXTLOAD && + cast(N)->getMemoryVT().getScalarType() == MVT::i16; + }]>; + def sext_#NAME#_i32 : + PatFrag<(ops node:$def, node:$pred, node:$ptr, node:$idx), + (GatherOp node:$def, node:$pred, node:$ptr, node:$idx), [{ + return cast(N)->getExtensionType() == ISD::SEXTLOAD && + cast(N)->getMemoryVT().getScalarType() == MVT::i32; + }]>; + + // Zero extending masked gather fragments. 
+ def zext_#NAME#_i8 : + PatFrag<(ops node:$def, node:$pred, node:$ptr, node:$idx), + (GatherOp node:$def, node:$pred, node:$ptr, node:$idx), [{ + return cast(N)->getExtensionType() == ISD::ZEXTLOAD && + cast(N)->getMemoryVT().getScalarType() == MVT::i8; + }]>; + def zext_#NAME#_i16 : + PatFrag<(ops node:$def, node:$pred, node:$ptr, node:$idx), + (GatherOp node:$def, node:$pred, node:$ptr, node:$idx), [{ + return cast(N)->getExtensionType() == ISD::ZEXTLOAD && + cast(N)->getMemoryVT().getScalarType() == MVT::i16; + }]>; + def zext_#NAME#_i32 : + PatFrag<(ops node:$def, node:$pred, node:$ptr, node:$idx), + (GatherOp node:$def, node:$pred, node:$ptr, node:$idx), [{ + return cast(N)->getExtensionType() == ISD::ZEXTLOAD && + cast(N)->getMemoryVT().getScalarType() == MVT::i32; + }]>; +} + +defm masked_gather_signed_scaled : masked_gather; +defm masked_gather_signed_unscaled : masked_gather; +defm masked_gather_unsigned_scaled : masked_gather; +defm masked_gather_unsigned_unscaled : masked_gather; + +// NOTE: This is a clone of SDTMaskedScatter as the "common" version looks +// rather target specific. +def SDT_SVEMaskedScatter : SDTypeProfile<0, 4, [ // masked scatter + SDTCisVec<0>, SDTCisVec<1>, SDTCisPtrTy<2>, SDTCisVec<3>, + SDTCVecEltisVT<1, i1>, SDTCisSameNumEltsAs<0, 1>, SDTCisSameNumEltsAs<0, 3> +]>; +def sve_masked_scatter : SDNode<"ISD::MSCATTER", SDT_SVEMaskedScatter, + [SDNPHasChain, SDNPMayStore, SDNPMemOperand]>; + +// masked scatter fragment (signed scaled offsets). +def masked_scatter_signed_scaled : + PatFrag<(ops node:$val, node:$pred, node:$ptr, node:$idx), + (sve_masked_scatter node:$val, node:$pred, node:$ptr, node:$idx), [{ + return cast(N)->getIndexType() == ISD::SIGNED_SCALED; +}]>; +// masked scatter fragment (signed unscaled offsets). +def masked_scatter_signed_unscaled : + PatFrag<(ops node:$val, node:$pred, node:$ptr, node:$idx), + (sve_masked_scatter node:$val, node:$pred, node:$ptr, node:$idx), [{ + return cast(N)->getIndexType() == ISD::SIGNED_UNSCALED; +}]>; +// masked scatter fragment (unsigned scaled offsets). +def masked_scatter_unsigned_scaled : + PatFrag<(ops node:$val, node:$pred, node:$ptr, node:$idx), + (sve_masked_scatter node:$val, node:$pred, node:$ptr, node:$idx), [{ + return cast(N)->getIndexType() == ISD::UNSIGNED_SCALED; +}]>; +// masked scatter fragment (unsigned unscaled offsets). +def masked_scatter_unsigned_unscaled : + PatFrag<(ops node:$val, node:$pred, node:$ptr, node:$idx), + (sve_masked_scatter node:$val, node:$pred, node:$ptr, node:$idx), [{ + return cast(N)->getIndexType() == ISD::UNSIGNED_UNSCALED; +}]>; + +multiclass masked_scatter { + def nontrunc_#NAME : + PatFrag<(ops node:$val, node:$pred, node:$ptr, node:$idx), + (ScatterOp node:$val, node:$pred, node:$ptr, node:$idx), [{ + return !cast(N)->isTruncatingStore(); + }]>; + + // Truncating masked scatter fragments. 
+ def trunc_#NAME#_i8 : + PatFrag<(ops node:$val, node:$pred, node:$ptr, node:$idx), + (ScatterOp node:$val, node:$pred, node:$ptr, node:$idx), [{ + return cast(N)->isTruncatingStore() && + cast(N)->getMemoryVT().getScalarType() == MVT::i8; + }]>; + def trunc_#NAME#_i16 : + PatFrag<(ops node:$val, node:$pred, node:$ptr, node:$idx), + (ScatterOp node:$val, node:$pred, node:$ptr, node:$idx), [{ + return cast(N)->isTruncatingStore() && + cast(N)->getMemoryVT().getScalarType() == MVT::i16; + }]>; + def trunc_#NAME#_i32 : + PatFrag<(ops node:$val, node:$pred, node:$ptr, node:$idx), + (ScatterOp node:$val, node:$pred, node:$ptr, node:$idx), [{ + return cast(N)->isTruncatingStore() && + cast(N)->getMemoryVT().getScalarType() == MVT::i32; + }]>; +} + +defm masked_scatter_signed_scaled : masked_scatter; +defm masked_scatter_signed_unscaled : masked_scatter; +defm masked_scatter_unsigned_scaled : masked_scatter; +defm masked_scatter_unsigned_unscaled : masked_scatter; // Node definitions. def AArch64adrp : SDNode<"AArch64ISD::ADRP", SDTIntUnaryOp, []>; @@ -329,6 +580,16 @@ def AArch64smaxv : SDNode<"AArch64ISD::SMAXV", SDT_AArch64UnaryVec>; def AArch64umaxv : SDNode<"AArch64ISD::UMAXV", SDT_AArch64UnaryVec>; +def AArch64TBL : SDNode<"AArch64ISD::TBL", SDT_AArch64TBL>; + +def SDT_AArch64unpk : SDTypeProfile<1, 1, [ + SDTCisInt<0>, SDTCisInt<1>, SDTCisOpSmallerThanOp<1, 0> +]>; +def AArch64sunpkhi : SDNode<"AArch64ISD::SUNPKHI", SDT_AArch64unpk>; +def AArch64sunpklo : SDNode<"AArch64ISD::SUNPKLO", SDT_AArch64unpk>; +def AArch64uunpkhi : SDNode<"AArch64ISD::UUNPKHI", SDT_AArch64unpk>; +def AArch64uunpklo : SDNode<"AArch64ISD::UUNPKLO", SDT_AArch64unpk>; + //===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===// @@ -759,6 +1020,8 @@ }]>; +def : Pat<(f16 fpimm:$in), + (EXTRACT_SUBREG (f32 (COPY_TO_REGCLASS (MOVi32imm (bitcast_fpimm_to_i32 f16:$in)), FPR32)), hsub)>; def : Pat<(f32 fpimm:$in), (COPY_TO_REGCLASS (MOVi32imm (bitcast_fpimm_to_i32 f32:$in)), FPR32)>; def : Pat<(f64 fpimm:$in), @@ -1436,7 +1699,20 @@ // FIXME: maybe the scratch register used shouldn't be fixed to X1? // FIXME: can "hasSideEffects be dropped? -let isCall = 1, Defs = [LR, X0, X1], hasSideEffects = 1, +let isCall = 1, Defs = [LR, X0, X1, + P0, P1, P2, P3, + P4, P5, P6, P7, + P8, P9, P10, P11, + P12, P13, P14, P15, + Z0_HI, Z1_HI, Z2_HI, Z3_HI, + Z4_HI, Z5_HI, Z6_HI, Z7_HI, + Z8_HI, Z9_HI, Z10_HI, Z11_HI, + Z12_HI, Z13_HI, Z14_HI, Z15_HI, + Z16_HI, Z17_HI, Z18_HI, Z19_HI, + Z20_HI, Z21_HI, Z22_HI, Z23_HI, + Z24_HI, Z25_HI, Z26_HI, Z27_HI, + Z28_HI, Z29_HI, Z30_HI, Z31_HI], + hasSideEffects = 1, isCodeGenOnly = 1 in def TLSDESC_CALLSEQ : Pseudo<(outs), (ins i64imm:$sym), @@ -4490,8 +4766,8 @@ // If none did, fallback to the explicit patterns, consuming the vector_extract. 
-def : Pat<(i32 (vector_extract (insert_subvector undef, (v8i8 (opNode V64:$Rn)), - (i32 0)), (i64 0))), +def : Pat<(i32 (vector_extract (v16i8 (insert_subvector undef, (v8i8 (opNode V64:$Rn)), + (i32 0))), (i64 0))), (EXTRACT_SUBREG (INSERT_SUBREG (v8i8 (IMPLICIT_DEF)), (!cast(!strconcat(baseOpc, "v8i8v")) V64:$Rn), bsub), ssub)>; @@ -4499,8 +4775,8 @@ (EXTRACT_SUBREG (INSERT_SUBREG (v16i8 (IMPLICIT_DEF)), (!cast(!strconcat(baseOpc, "v16i8v")) V128:$Rn), bsub), ssub)>; -def : Pat<(i32 (vector_extract (insert_subvector undef, - (v4i16 (opNode V64:$Rn)), (i32 0)), (i64 0))), +def : Pat<(i32 (vector_extract (v8i16 (insert_subvector undef, + (v4i16 (opNode V64:$Rn)), (i32 0))), (i64 0))), (EXTRACT_SUBREG (INSERT_SUBREG (v4i16 (IMPLICIT_DEF)), (!cast(!strconcat(baseOpc, "v4i16v")) V64:$Rn), hsub), ssub)>; @@ -4520,20 +4796,20 @@ : SIMDAcrossLanesIntrinsic { // If there is a sign extension after this intrinsic, consume it as smov already // performed it -def : Pat<(i32 (sext_inreg (i32 (vector_extract (insert_subvector undef, - (opNode (v8i8 V64:$Rn)), (i32 0)), (i64 0))), i8)), - (i32 (SMOVvi8to32 - (INSERT_SUBREG (v16i8 (IMPLICIT_DEF)), - (!cast(!strconcat(baseOpc, "v8i8v")) V64:$Rn), bsub), - (i64 0)))>; +def : Pat<(i32 (sext_inreg (i32 (vector_extract (v16i8 (insert_subvector undef, + (opNode (v8i8 V64:$Rn)), (i32 0))), (i64 0))), i8)), + (i32 (SMOVvi8to32 + (INSERT_SUBREG (v16i8 (IMPLICIT_DEF)), + (!cast(!strconcat(baseOpc, "v8i8v")) V64:$Rn), bsub), + (i64 0)))>; def : Pat<(i32 (sext_inreg (i32 (vector_extract (opNode (v16i8 V128:$Rn)), (i64 0))), i8)), (i32 (SMOVvi8to32 (INSERT_SUBREG (v16i8 (IMPLICIT_DEF)), (!cast(!strconcat(baseOpc, "v16i8v")) V128:$Rn), bsub), (i64 0)))>; -def : Pat<(i32 (sext_inreg (i32 (vector_extract (insert_subvector undef, - (opNode (v4i16 V64:$Rn)), (i32 0)), (i64 0))), i16)), +def : Pat<(i32 (sext_inreg (i32 (vector_extract (v8i16 (insert_subvector undef, + (opNode (v4i16 V64:$Rn)), (i32 0))), (i64 0))), i16)), (i32 (SMOVvi16to32 (INSERT_SUBREG (v16i8 (IMPLICIT_DEF)), (!cast(!strconcat(baseOpc, "v4i16v")) V64:$Rn), hsub), @@ -4551,8 +4827,8 @@ : SIMDAcrossLanesIntrinsic { // If there is a masking operation keeping only what has been actually // generated, consume it. 
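The comment just above describes patterns of the form `and (vector_extract ...), maski8_or_more`: the across-lanes result is already zero-extended to 32 bits, so any mask that covers at least the element width changes nothing and can be consumed by the pattern. A scalar analogue of the same fold, with made-up values:

```cpp
#include <cassert>
#include <cstdint>

int main() {
  uint8_t Reduced = 200;                // e.g. the result of an 8-bit UADDV
  uint32_t Widened = Reduced;           // zero-extended into a W register
  assert((Widened & 0xffu) == Widened); // maski8_or_more is a no-op here
  return 0;
}
```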
-def : Pat<(i32 (and (i32 (vector_extract (insert_subvector undef, - (opNode (v8i8 V64:$Rn)), (i32 0)), (i64 0))), maski8_or_more)), +def : Pat<(i32 (and (i32 (vector_extract (v16i8 (insert_subvector undef, + (opNode (v8i8 V64:$Rn)), (i32 0))), (i64 0))), maski8_or_more)), (i32 (EXTRACT_SUBREG (INSERT_SUBREG (v16i8 (IMPLICIT_DEF)), (!cast(!strconcat(baseOpc, "v8i8v")) V64:$Rn), bsub), @@ -4563,8 +4839,8 @@ (INSERT_SUBREG (v16i8 (IMPLICIT_DEF)), (!cast(!strconcat(baseOpc, "v16i8v")) V128:$Rn), bsub), ssub))>; -def : Pat<(i32 (and (i32 (vector_extract (insert_subvector undef, - (opNode (v4i16 V64:$Rn)), (i32 0)), (i64 0))), maski16_or_more)), +def : Pat<(i32 (and (i32 (vector_extract (v8i16 (insert_subvector undef, + (opNode (v4i16 V64:$Rn)), (i32 0))), (i64 0))), maski16_or_more)), (i32 (EXTRACT_SUBREG (INSERT_SUBREG (v16i8 (IMPLICIT_DEF)), (!cast(!strconcat(baseOpc, "v4i16v")) V64:$Rn), hsub), @@ -4769,6 +5045,9 @@ def : Pat<(v8i16 immAllOnesV), (MOVIv2d_ns (i32 255))>; def : Pat<(v16i8 immAllOnesV), (MOVIv2d_ns (i32 255))>; +def : Pat<(v2f64 (AArch64dup (f64 fpimm0))), (MOVIv2d_ns (i32 0))>; +def : Pat<(v4f32 (AArch64dup (f32 fpimm0))), (MOVIv2d_ns (i32 0))>; + // EDIT per word & halfword: 2s, 4h, 4s, & 8h let isReMaterializable = 1, isAsCheapAsAMove = 1 in defm MOVI : SIMDModifiedImmVectorShift<0, 0b10, 0b00, "movi">; Index: lib/Target/AArch64/AArch64LoadStoreOptimizer.cpp =================================================================== --- lib/Target/AArch64/AArch64LoadStoreOptimizer.cpp +++ lib/Target/AArch64/AArch64LoadStoreOptimizer.cpp @@ -541,7 +541,7 @@ } } -static const MachineOperand &getLdStRegOp(const MachineInstr &MI, +static MachineOperand &getLdStRegOp(MachineInstr &MI, unsigned PairedRegOp = 0) { assert(PairedRegOp < 2 && "Unexpected register operand idx."); unsigned Idx = isPairedLdSt(MI) ? PairedRegOp : 0; @@ -679,6 +679,16 @@ const MachineOperand &BaseRegOp = MergeForward ? getLdStBaseOp(*MergeMI) : getLdStBaseOp(*I); + if (!MergeForward) { + // Add KILL MI for any registers killed by MergeMI + for (MachineOperand &MO : MergeMI->operands()) + if (MO.isReg() && MO.isKill()) { + BuildMI(*MergeMI->getParent(), MergeMI, MergeMI->getDebugLoc(), + TII->get(TargetOpcode::KILL), MO.getReg()) + .addReg(MO.getReg()); + MO.setIsKill(false); + } + } // Which register is Rt and which is Rt2 depends on the offset order. MachineInstr *RtMI; if (getLdStOffsetOp(*I).getImm() == @@ -886,6 +896,10 @@ unsigned LdRt = getLdStRegOp(*LoadI).getReg(); const MachineOperand &StMO = getLdStRegOp(*StoreI); unsigned StRt = getLdStRegOp(*StoreI).getReg(); + // We're adding an additional use of StRt, so the isKill flag may need to move + // down to this new use. Removal of isKill from the store is delayed so the + // DEBUG() printed below reflects the prior state correctly + bool StRtIsKill = getLdStRegOp(*StoreI).isKill(); bool IsStoreXReg = TRI->getRegClass(AArch64::GPR64RegClassID)->contains(StRt); assert((IsStoreXReg || @@ -908,6 +922,8 @@ LLVM_DEBUG(LoadI->print(dbgs())); LLVM_DEBUG(dbgs() << "\n"); LoadI->eraseFromParent(); + if (StRtIsKill) + getLdStRegOp(*StoreI).setIsKill(false); return NextI; } // Replace the load with a mov if the load and store are in the same size. @@ -984,6 +1000,15 @@ LLVM_DEBUG(StoreI->print(dbgs())); LLVM_DEBUG(dbgs() << " "); LLVM_DEBUG(LoadI->print(dbgs())); + + // Clear isKill for the StoreI Reg we reused above. 
It may have an implicit + // definition of the x-form that needs isKill clearing as well, so the easiest + // thing is to look over all operands + for (MachineOperand &MO : StoreI->operands()) { + if (MO.isReg() && MO.isKill() && TRI->regsOverlap(MO.getReg(), StRt)) + MO.setIsKill(false); + } + LLVM_DEBUG(dbgs() << " with instructions:\n "); LLVM_DEBUG(StoreI->print(dbgs())); LLVM_DEBUG(dbgs() << " "); Index: lib/Target/AArch64/AArch64MachineFunctionInfo.h =================================================================== --- lib/Target/AArch64/AArch64MachineFunctionInfo.h +++ lib/Target/AArch64/AArch64MachineFunctionInfo.h @@ -54,6 +54,7 @@ /// Amount of stack frame size used for saving callee-saved registers. unsigned CalleeSavedStackSize; + unsigned SVECalleeSavedStackSize; /// Number of TLS accesses using the special (combinable) /// _TLS_MODULE_BASE_ symbol. @@ -90,6 +91,16 @@ /// True when the callee-save stack area has unused gaps that may be used for /// other stack allocations. bool CalleeSaveStackHasFreeSpace = false; + bool SVECalleeSaveStackHasFreeSpace = false; + + /// SVE stack size (for predicates and data vectors) are maintained here + /// rather than in FrameInfo, as the placement and Stack IDs are target + /// specific. + uint64_t StackSizeSVE = 0; + + /// The SVE region gets its own alignment, separate from the regular area + /// on the stack. This means we may align the SVE region separately. + unsigned MaxAlignSVE = 16; /// Has a value when it is known whether or not the function uses a /// redzone, and no value otherwise. @@ -117,6 +128,12 @@ ArgumentStackToRestore = bytes; } + void setStackSizeSVE(uint64_t S) { StackSizeSVE = S; } + uint64_t getStackSizeSVE() const { return StackSizeSVE; } + + void setMaxAlignSVE(unsigned A) { MaxAlignSVE = A; } + unsigned getMaxAlignSVE() const { return MaxAlignSVE; } + bool hasStackFrame() const { return HasStackFrame; } void setHasStackFrame(bool s) { HasStackFrame = s; } @@ -129,6 +146,12 @@ void setCalleeSaveStackHasFreeSpace(bool s) { CalleeSaveStackHasFreeSpace = s; } + bool hasSVECalleeSaveStackFreeSpace() const { + return SVECalleeSaveStackHasFreeSpace; + } + void setSVECalleeSaveStackHasFreeSpace(bool s) { + SVECalleeSaveStackHasFreeSpace = s; + } bool isSplitCSR() const { return IsSplitCSR; } void setIsSplitCSR(bool s) { IsSplitCSR = s; } @@ -139,6 +162,13 @@ void setCalleeSavedStackSize(unsigned Size) { CalleeSavedStackSize = Size; } unsigned getCalleeSavedStackSize() const { return CalleeSavedStackSize; } + void setSVECalleeSavedStackSize(unsigned Size) { + SVECalleeSavedStackSize = Size; + } + unsigned getSVECalleeSavedStackSize() const { + return SVECalleeSavedStackSize; + } + void incNumLocalDynamicTLSAccesses() { ++NumLocalDynamicTLSAccesses; } unsigned getNumLocalDynamicTLSAccesses() const { return NumLocalDynamicTLSAccesses; Index: lib/Target/AArch64/AArch64RegisterBankInfo.cpp =================================================================== --- lib/Target/AArch64/AArch64RegisterBankInfo.cpp +++ lib/Target/AArch64/AArch64RegisterBankInfo.cpp @@ -73,6 +73,8 @@ // GR64all + its subclasses. 
assert(RBFPR.covers(*TRI.getRegClass(AArch64::QQRegClassID)) && "Subclass not added?"); + assert(RBFPR.covers(*TRI.getRegClass(AArch64::ZPR2RegClassID)) && + "Subclass not added?"); assert(RBFPR.covers(*TRI.getRegClass(AArch64::FPR64RegClassID)) && "Subclass not added?"); assert(RBFPR.getSize() == 512 && @@ -238,6 +240,10 @@ case AArch64::QQRegClassID: case AArch64::QQQRegClassID: case AArch64::QQQQRegClassID: + case AArch64::ZPRRegClassID: + case AArch64::ZPR2RegClassID: + case AArch64::ZPR3RegClassID: + case AArch64::ZPR4RegClassID: return getRegBank(AArch64::FPRRegBankID); case AArch64::GPR32commonRegClassID: case AArch64::GPR32RegClassID: Index: lib/Target/AArch64/AArch64RegisterBanks.td =================================================================== --- lib/Target/AArch64/AArch64RegisterBanks.td +++ lib/Target/AArch64/AArch64RegisterBanks.td @@ -13,8 +13,8 @@ /// General Purpose Registers: W, X. def GPRRegBank : RegisterBank<"GPR", [GPR64all]>; -/// Floating Point/Vector Registers: B, H, S, D, Q. -def FPRRegBank : RegisterBank<"FPR", [QQQQ]>; +/// Floating Point/Vector Registers: B, H, S, D, Q, Z. +def FPRRegBank : RegisterBank<"FPR", [QQQQ, ZPR4]>; /// Conditional register: NZCV. def CCRegBank : RegisterBank<"CC", [CCR]>; Index: lib/Target/AArch64/AArch64RegisterInfo.cpp =================================================================== --- lib/Target/AArch64/AArch64RegisterInfo.cpp +++ lib/Target/AArch64/AArch64RegisterInfo.cpp @@ -16,6 +16,7 @@ #include "AArch64FrameLowering.h" #include "AArch64InstrInfo.h" #include "AArch64MachineFunctionInfo.h" +#include "AArch64StackOffset.h" #include "AArch64Subtarget.h" #include "MCTargetDesc/AArch64AddressingModes.h" #include "llvm/ADT/BitVector.h" @@ -24,9 +25,10 @@ #include "llvm/CodeGen/MachineInstrBuilder.h" #include "llvm/CodeGen/MachineRegisterInfo.h" #include "llvm/CodeGen/RegisterScavenging.h" +#include "llvm/CodeGen/TargetFrameLowering.h" +#include "llvm/IR/DebugInfoMetadata.h" #include "llvm/IR/Function.h" #include "llvm/Support/raw_ostream.h" -#include "llvm/CodeGen/TargetFrameLowering.h" #include "llvm/Target/TargetOptions.h" using namespace llvm; @@ -39,6 +41,60 @@ AArch64_MC::initLLVMToCVRegMapping(this); } +// Returns if Function has arguments or return value of SVE type. +// +// The (proposed) SVE AAPCS states: +// +// z0-z7 are used to pass SVE vector arguments to a subroutine and to +// return SVE vector results from a function. If a subroutine takes arguments +// in scalable vector or predicate registers, or if it is a function that +// returns results in such registers, it must ensure that the entire contents +// of z8-z31 are preserved across the call. In other cases it need only +// preserve the low 64 bits of z8-z15, as described in §5.1.2. +// +// p0-p3 are used to pass scalable predicate arguments to a subroutine and to +// return scalable predicate results from a function. If a subroutine takes +// arguments in scalable vector or predicate registers, or if it is a +// function that returns results in these registers, it must ensure that +// p4-p15 are preserved across the call. In other cases it need not preserve +// any scalable predicate register contents. +// +// We can determine this by investigating the (C/C++) function and test +// if one of the arguments (or return value) is of SVE vector +// type given some language binding for a SVE vector type. +// +// Note that this is different from checking the TargetLowering of the +// argument and testing if this is a SVE 'nxv..i..' type. 
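The comment block above, together with the definition that follows, derives the calling convention from the IR signature: if any argument or the return value is (or, in the sizeless-struct case, wraps) a scalable vector, the function gets the SVE AAPCS save list. A compact restatement of that check; the stripped template arguments of the casts are presumably StructType and VectorType, as reconstructed here.

```cpp
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
using namespace llvm;

// Unwrap a sizeless-struct wrapper, then test for a scalable vector type,
// mirroring the checks in isSVEFunction() below.
static bool typeIsScalableVector(Type *Ty) {
  if (auto *STy = dyn_cast<StructType>(Ty))
    if (STy->getNumElements() > 0)
      Ty = STy->getTypeAtIndex(0u);
  if (auto *VT = dyn_cast<VectorType>(Ty))
    return VT->isScalable();
  return false;
}

static bool hasSVEArgsOrReturn(const Function &F) {
  for (const Argument &A : F.args())
    if (typeIsScalableVector(A.getType()))
      return true;
  return typeIsScalableVector(F.getReturnType());
}
```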
+ +bool +isSVEFunction(const MachineFunction& MF) { + // Loop all arguments + for (const auto &arg : MF.getFunction().args()) { + Type *Ty = arg.getType(); + // If this is a struct where the first element is a scalable + // vec, then this is considered to be an SVE function. + // This currently only occurs for 'sizeless structs' + if (auto *STy = dyn_cast(Ty)) + if (STy->getNumElements() > 0) + Ty = STy->getTypeAtIndex((unsigned)0); + + // Non-vector types cannot be SVE + if (auto *VT = dyn_cast(Ty)) + // Check if Vector is scalable (ergo, SVE) + if (VT->isScalable()) + return true; + } + + // If not yet conclusive, check return type + Type *Ty = MF.getFunction().getReturnType(); + if (auto *STy = dyn_cast(Ty)) + if (STy->getNumElements() > 0) + Ty = STy->getTypeAtIndex((unsigned)0); + + auto *VT = dyn_cast(Ty); + return VT && VT->isScalable(); +} + const MCPhysReg * AArch64RegisterInfo::getCalleeSavedRegs(const MachineFunction *MF) const { assert(MF && "Invalid MachineFunction pointer."); @@ -48,6 +104,10 @@ return CSR_AArch64_NoRegs_SaveList; if (MF->getFunction().getCallingConv() == CallingConv::AnyReg) return CSR_AArch64_AllRegs_SaveList; + else if (isSVEFunction(*MF)) + return CSR_AArch64_SVE_AAPCS_SaveList; + if (MF->getFunction().getCallingConv() == CallingConv::AArch64_VectorCall) + return CSR_AArch64_AAVPCS_SaveList; if (MF->getFunction().getCallingConv() == CallingConv::CXX_FAST_TLS) return MF->getInfo()->isSplitCSR() ? CSR_AArch64_CXX_TLS_Darwin_PE_SaveList : @@ -97,6 +157,10 @@ if (CC == CallingConv::CXX_FAST_TLS) return SCS ? CSR_AArch64_CXX_TLS_Darwin_SCS_RegMask : CSR_AArch64_CXX_TLS_Darwin_RegMask; + if (CC == CallingConv::AArch64_SVE_VectorCall) + return CSR_AArch64_SVE_AAPCS_RegMask; + if (CC == CallingConv::AArch64_VectorCall) + return SCS ? CSR_AArch64_AAVPCS_SCS_RegMask : CSR_AArch64_AAVPCS_RegMask; if (MF.getSubtarget().getTargetLowering() ->supportSwiftError() && MF.getFunction().getAttributes().hasAttrSomewhere(Attribute::SwiftError)) @@ -271,7 +335,9 @@ const MachineFrameInfo &MFI = MF.getFrameInfo(); if (MF.getTarget().Options.DisableFramePointerElim(MF) && MFI.adjustsStack()) return true; - return MFI.hasVarSizedObjects() || MFI.isFrameAddressTaken(); + const AArch64FunctionInfo *AFI = MF.getInfo(); + return MFI.hasVarSizedObjects() || MFI.isFrameAddressTaken() || + AFI->getStackSizeSVE(); } /// needsFrameBaseReg - Returns true if the instruction's frame index @@ -343,7 +409,7 @@ int64_t Offset) const { assert(Offset <= INT_MAX && "Offset too big to fit in int."); assert(MI && "Unable to get the legal offset for nil instruction."); - int SaveOffset = Offset; + StackOffset SaveOffset(Offset, MVT::i8); return isAArch64FrameOffsetLegal(*MI, SaveOffset) & AArch64FrameOffsetIsLegal; } @@ -373,9 +439,10 @@ void AArch64RegisterInfo::resolveFrameIndex(MachineInstr &MI, unsigned BaseReg, int64_t Offset) const { - int Off = Offset; // ARM doesn't need the general 64-bit offsets - unsigned i = 0; + // ARM doesn't need the general 64-bit offsets + StackOffset Off(Offset, MVT::i8); + unsigned i = 0; while (!MI.getOperand(i).isFI()) { ++i; assert(i < MI.getNumOperands() && "Instr doesn't have FrameIndex operand!"); @@ -399,36 +466,61 @@ const AArch64InstrInfo *TII = MF.getSubtarget().getInstrInfo(); const AArch64FrameLowering *TFI = getFrameLowering(MF); - int FrameIndex = MI.getOperand(FIOperandNum).getIndex(); unsigned FrameReg; - int Offset; // Special handling of dbg_value, stackmap and patchpoint instructions. 
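The special handling below is rewritten in terms of the StackOffset type from the newly included AArch64StackOffset.h, which keeps the fixed byte offset and the part that scales with the SVE vector length as separate components until the final addressing sequence (or DWARF expression) is emitted. A simplified mock of that idea; the class and member names here are illustrative, not the patch's actual API.

```cpp
#include <cstdint>
#include <cstdio>

class SimpleStackOffset {
  int64_t Bytes = 0;         // fixed-size portion of the offset
  int64_t ScalableBytes = 0; // portion that scales with the vector length
public:
  SimpleStackOffset() = default;
  SimpleStackOffset(int64_t B, int64_t SB) : Bytes(B), ScalableBytes(SB) {}

  SimpleStackOffset &operator+=(const SimpleStackOffset &O) {
    Bytes += O.Bytes;
    ScalableBytes += O.ScalableBytes;
    return *this;
  }

  int64_t getBytes() const { return Bytes; }
  int64_t getScalableBytes() const { return ScalableBytes; }
};

int main() {
  SimpleStackOffset Off(16, 0);        // e.g. past the callee-save area
  Off += SimpleStackOffset(0, 2 * 16); // plus two SVE vectors' worth of slots
  std::printf("fixed=%lld scalable=%lld\n",
              (long long)Off.getBytes(), (long long)Off.getScalableBytes());
  return 0;
}
```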
if (MI.isDebugValue() || MI.getOpcode() == TargetOpcode::STACKMAP || MI.getOpcode() == TargetOpcode::PATCHPOINT) { - Offset = TFI->resolveFrameIndexReference(MF, FrameIndex, FrameReg, - /*PreferFP=*/true); - Offset += MI.getOperand(FIOperandNum + 1).getImm(); - MI.getOperand(FIOperandNum).ChangeToRegister(FrameReg, false /*isDef*/); - MI.getOperand(FIOperandNum + 1).ChangeToImmediate(Offset); + StackOffset Offset = + TFI->resolveFrameIndexReference(MF, FrameIndex, FrameReg, + /*PreferFP=*/true); + + Offset += StackOffset(MI.getOperand(FIOperandNum + 1).getImm(), MVT::i8); + MI.getOperand(FIOperandNum).ChangeToRegister(FrameReg, false /*isDef */); + MI.getOperand(FIOperandNum).setIsDebug(); + + if (MI.getOpcode() == TargetOpcode::STACKMAP || + MI.getOpcode() == TargetOpcode::PATCHPOINT) + MI.getOperand(FIOperandNum + 1) + .ChangeToImmediate(Offset.getBytes()); + + if (MI.isDebugValue()) { + const MCRegisterInfo *MRI = MF.getSubtarget().getRegisterInfo(); + + SmallVector Buffer; + AArch64FrameLowering::addVGScaledOffset(MRI, Offset, Buffer); + + auto MDExpr = cast(MI.getOperand(3).getMetadata()); + Buffer.append(MDExpr->elements_begin(), MDExpr->elements_end()); + auto *NewMD = DIExpression::get(MF.getFunction().getContext(), Buffer); + + // Set immediate 0, since the non-scalable offset is folded into the + // attached dwarf expression. + MI.RemoveOperand(3); + MI.addOperand(MachineOperand::CreateMetadata(NewMD)); + MI.getOperand(FIOperandNum + 1).ChangeToImmediate(0); + } return; } // Modify MI as necessary to handle as much of 'Offset' as possible - Offset = TFI->resolveFrameIndexReference(MF, FrameIndex, FrameReg); + StackOffset Offset = + TFI->resolveFrameIndexReference(MF, FrameIndex, FrameReg); + if (rewriteAArch64FrameIndex(MI, FIOperandNum, FrameReg, Offset, TII)) return; assert((!RS || !RS->isScavengingFrameIndex(FrameIndex)) && "Emergency spill slot is out of reach"); - // If we get here, the immediate doesn't fit into the instruction. We folded - // as much as possible above. Handle the rest, providing a register that is - // SP+LargeImm. + // If we get here, the immediate doesn't fit into the instruction. + // We folded as much as possible above. Handle the rest, providing a + // register that is SP+LargeImm. 
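Per the comment above, anything that cannot be folded into the instruction's own immediate field is materialized into a scratch register as SP plus a large immediate via emitFrameOffset. As a rough illustration of why a scratch register is needed at all: AArch64 ADD/SUB immediates are limited to 12 bits, optionally shifted left by 12, so a large byte offset has to be split into pieces of that shape (the helper below is hypothetical, not the patch's code).

```cpp
#include <cstdint>
#include <cstdio>

// Print an ADD sequence building SP + Off into a scratch register, assuming
// Off fits in 24 bits. x16 stands in for the virtual scratch register.
static void splitAddImm(uint64_t Off) {
  uint64_t Hi = (Off >> 12) & 0xFFF; // "#imm, lsl #12" part
  uint64_t Lo = Off & 0xFFF;         // plain "#imm" part
  std::printf("add x16, sp, #%llu, lsl #12\n", (unsigned long long)Hi);
  std::printf("add x16, x16, #%llu\n", (unsigned long long)Lo);
}

int main() {
  splitAddImm(0x12345); // 74565 bytes: add #18, lsl #12, then add #837
  return 0;
}
```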
unsigned ScratchReg = MF.getRegInfo().createVirtualRegister(&AArch64::GPR64RegClass); - emitFrameOffset(MBB, II, MI.getDebugLoc(), ScratchReg, FrameReg, Offset, TII); + emitFrameOffset(MBB, II, MI.getDebugLoc(), ScratchReg, FrameReg, Offset, + TII, MachineInstr::NoFlags, false); MI.getOperand(FIOperandNum).ChangeToRegister(ScratchReg, false, false, true); } @@ -459,6 +551,7 @@ case AArch64::FPR32RegClassID: case AArch64::FPR64RegClassID: case AArch64::FPR128RegClassID: + case AArch64::ZPRRegClassID: return 32; case AArch64::DDRegClassID: @@ -467,6 +560,9 @@ case AArch64::QQRegClassID: case AArch64::QQQRegClassID: case AArch64::QQQQRegClassID: + case AArch64::ZPR2RegClassID: + case AArch64::ZPR3RegClassID: + case AArch64::ZPR4RegClassID: return 32; case AArch64::FPR128_loRegClassID: Index: lib/Target/AArch64/AArch64RegisterInfo.td =================================================================== --- lib/Target/AArch64/AArch64RegisterInfo.td +++ lib/Target/AArch64/AArch64RegisterInfo.td @@ -134,6 +134,9 @@ // First fault status register def FFR : AArch64Reg<0, "ffr">, DwarfRegNum<[47]>; +// Purely virtual Vector Granule (VG) Dwarf register +def VG : AArch64Reg<0, "vg">, DwarfRegNum<[46]>; + // GPR register classes with the intersections of GPR32/GPR32sp and // GPR64/GPR64sp for use by the coalescer. def GPR32common : RegisterClass<"AArch64", [i32], 32, (sequence "W%u", 0, 30)> { @@ -228,6 +231,14 @@ let isAllocatable = 0; } +// First Fault regclass +def FFRC : RegisterClass<"AArch64", [i32], 32, (add FFR)> { + let CopyCost = -1; // Don't allow copying of status registers. + + // FFR is not allocatable. + let isAllocatable = 0; +} + //===----------------------------------------------------------------------===// // Floating Point Scalar Registers //===----------------------------------------------------------------------===// @@ -618,7 +629,10 @@ let RenderMethod = "addRegOperands"; } -// Register operand versions of the scalar FP registers. +// Register operand versions of the scalar Int/FP registers. +def GPR32Op : RegisterOperand; +def GPR64Op : RegisterOperand; + def FPR8Op : RegisterOperand { let ParserMatchClass = FPRAsmOperand<"FPR8">; } @@ -842,33 +856,25 @@ //****************************************************************************** -// SVE vector register class -def ZPR : RegisterClass<"AArch64", - [nxv16i8, nxv8i16, nxv4i32, nxv2i64, - nxv2f16, nxv4f16, nxv8f16, - nxv1f32, nxv2f32, nxv4f32, - nxv1f64, nxv2f64], - 128, (sequence "Z%u", 0, 31)> { +// SVE vector register classes +class ZPRClass : RegisterClass<"AArch64", + [nxv16i8, nxv8i16, nxv4i32, nxv2i64, + nxv2f16, nxv4f16, nxv8f16, + nxv2f32, nxv4f32, + nxv2f64], + 128, (sequence "Z%u", 0, lastreg)> { let Size = 128; } -// SVE restricted 4 bit scalable vector register class -def ZPR_4b : RegisterClass<"AArch64", - [nxv16i8, nxv8i16, nxv4i32, nxv2i64, - nxv2f16, nxv4f16, nxv8f16, - nxv1f32, nxv2f32, nxv4f32, - nxv1f64, nxv2f64], - 128, (sequence "Z%u", 0, 15)> { - let Size = 128; -} +def ZPR : ZPRClass<31>; +def ZPR_4b : ZPRClass<15>; // Restricted 4 bit SVE vector register class. +def ZPR_3b : ZPRClass<7>; // Restricted 3 bit SVE vector register class. -// SVE restricted 3 bit scalable vector register class -def ZPR_3b : RegisterClass<"AArch64", - [nxv16i8, nxv8i16, nxv4i32, nxv2i64, - nxv2f16, nxv4f16, nxv8f16, - nxv1f32, nxv2f32, nxv4f32, - nxv1f64, nxv2f64], - 128, (sequence "Z%u", 0, 7)> { +// The part of SVE vector registers that don't overlap Neon registers. 
+// NOTE: Type needed to build but should ever be used directly. +def ZPR_HI : RegisterClass<"AArch64", + [untyped], + 128, (sequence "Z%u_HI", 0, 31)> { let Size = 128; } Index: lib/Target/AArch64/AArch64SVEInstrInfo.td =================================================================== --- lib/Target/AArch64/AArch64SVEInstrInfo.td +++ lib/Target/AArch64/AArch64SVEInstrInfo.td @@ -11,62 +11,278 @@ // //===----------------------------------------------------------------------===// +def SVEAddrModeRegReg8 : ComplexPattern", []>; +def SVEAddrModeRegReg16 : ComplexPattern", []>; +def SVEAddrModeRegReg32 : ComplexPattern", []>; +def SVEAddrModeRegReg64 : ComplexPattern", []>; + +def nxv2i64LslBy1 :ComplexPattern", []>; +def nxv2i64LslBy2 :ComplexPattern", []>; +def nxv2i64LslBy3 :ComplexPattern", []>; + +def SVE8BitLslImm : ComplexPattern; + +def nxv2i64UxtwLslBy3 :ComplexPattern", []>; + +def SVEUIntArithImm8 : ComplexPattern", []>; +def SVEUIntArithImm16 : ComplexPattern", []>; +def SVEUIntArithImm32 : ComplexPattern", []>; +def SVEUIntArithImm64 : ComplexPattern", []>; + +def SVELogicalImm8 : ComplexPattern", []>; +def SVELogicalImm16 : ComplexPattern", []>; +def SVELogicalImm32 : ComplexPattern", []>; +def SVELogicalImm64 : ComplexPattern", []>; + +def SVELShiftImm64 : ComplexPattern", []>; +def SVERShiftImm64 : ComplexPattern", []>; + +// Wide pseudo-immediates for pattern matching to shift-by-immediate +def SVEWideLShiftImm8 : ComplexPattern", []>; +def SVEWideLShiftImm16 : ComplexPattern", []>; +def SVEWideLShiftImm32 : ComplexPattern", []>; + +def SVEWideRShiftImm8 : ComplexPattern", []>; +def SVEWideRShiftImm16 : ComplexPattern", []>; +def SVEWideRShiftImm32 : ComplexPattern", []>; + +// SVE CNT/INC/RDVL +def sve_rdvl_imm : ComplexPattern">; +def sve_cnth_imm : ComplexPattern">; +def sve_cntw_imm : ComplexPattern">; +def sve_cntd_imm : ComplexPattern">; + +// SVE DEC +def sve_cnth_imm_neg : ComplexPattern">; +def sve_cntw_imm_neg : ComplexPattern">; +def sve_cntd_imm_neg : ComplexPattern">; + +let AddedComplexity = 1 in { + class LD1RPat : + Pat<(vt (AArch64dup (index_vt (operator (CP GPR64:$base, immtype:$offset))))), + (load (ptrue 31), GPR64:$base, $offset)>; +} + +multiclass SVETruncStore { + def : Pat<(operator (vt ZPR:$val), (CP GPR64:$base, GPR64:$offset)), + (store ZPR:$val, (ptrue 31), GPR64:$base, GPR64:$offset)>; + def : Pat<(operator (vt ZPR:$val), GPR64:$base), + (store_imm ZPR:$val, (ptrue 31), GPR64:$base, (i64 0))>; +} + +def SDT_AArch64DUP_PRED : SDTypeProfile<1, 3, [SDTCisVec<0>, SDTCisSameAs<0,1>, SDTCisVec<2>, SDTCVecEltisVT<2,i1>]>; +def AArch64dup_pred : SDNode<"AArch64ISD::DUP_PRED", SDT_AArch64DUP_PRED>; + +def SDT_AArch64Insr : SDTypeProfile<1, 2, [SDTCisVec<0>]>; +def AArch64insr : SDNode<"AArch64ISD::INSR", SDT_AArch64Insr>; + +def SDT_AArch64PTest : SDTypeProfile<0, 2, [SDTCisVec<0>, SDTCisSameAs<0,1>]>; +def AArch64ptest : SDNode<"AArch64ISD::PTEST", SDT_AArch64PTest>; + +def SDT_AArch64RDFFR : SDTypeProfile<1, 0, [SDTCisVec<0>, SDTCVecEltisVT<0,i1>]>; +def AArch64rdffr : SDNode<"AArch64ISD::RDFFR", SDT_AArch64RDFFR, [SDNPHasChain]>; + +def SDT_AArch64RDFFR_PRED : SDTypeProfile<1, 1, [SDTCisVec<0>, SDTCisSameAs<0,1>, SDTCVecEltisVT<0,i1>]>; +def AArch64rdffr_pred : SDNode<"AArch64ISD::RDFFR_PRED", SDT_AArch64RDFFR_PRED, [SDNPHasChain]>; + +def SDT_AArch64WRFFR : SDTypeProfile<0, 1, [SDTCisVT<0, nxv16i1>]>; +def AArch64wrffr : SDNode<"AArch64ISD::WRFFR", SDT_AArch64WRFFR, [SDNPHasChain, SDNPOutGlue]>; + +def SDT_AArch64SETFFR : SDTypeProfile<0, 0, []>; +def 
AArch64setffr : SDNode<"AArch64ISD::SETFFR", SDT_AArch64SETFFR, [SDNPHasChain, SDNPOutGlue]>; + +def SDT_AArch64Rev : SDTypeProfile<1, 1, [SDTCisVec<0>, SDTCisSameAs<0,1>]>; +def AArch64rev : SDNode<"AArch64ISD::REV", SDT_AArch64Rev>; + +def SDT_AArch64BRKA : SDTypeProfile<1, 2, [SDTCisVec<0>, SDTCisSameAs<0,1>, + SDTCisSameAs<0,2>, SDTCVecEltisVT<0,i1>]>; +def AArch64brka : SDNode<"AArch64ISD::BRKA", SDT_AArch64BRKA>; + +def reinterpret_cast : SDNode<"AArch64ISD::REINTERPRET_CAST", SDTUnaryOp>; + +def SDT_AArch64ReduceWithInit : SDTypeProfile<1, 3, [SDTCisVec<1>, SDTCisVec<3>]>; +def AArch64clasta_n : SDNode<"AArch64ISD::CLASTA_N", SDT_AArch64ReduceWithInit>; +def AArch64clastb_n : SDNode<"AArch64ISD::CLASTB_N", SDT_AArch64ReduceWithInit>; +def AArch64fadda_pred : SDNode<"AArch64ISD::FADDA_PRED", SDT_AArch64ReduceWithInit>; + +def SDT_AArch64Reduce : SDTypeProfile<1, 2, [SDTCisVec<1>, SDTCisVec<2>]>; +def AArch64andv_pred : SDNode<"AArch64ISD::ANDV_PRED", SDT_AArch64Reduce>; +def AArch64eorv_pred : SDNode<"AArch64ISD::EORV_PRED", SDT_AArch64Reduce>; +def AArch64faddv_pred : SDNode<"AArch64ISD::FADDV_PRED", SDT_AArch64Reduce>; +def AArch64fmaxv_pred : SDNode<"AArch64ISD::FMAXV_PRED", SDT_AArch64Reduce>; +def AArch64fmaxnmv_pred : SDNode<"AArch64ISD::FMAXNMV_PRED", SDT_AArch64Reduce>; +def AArch64fminv_pred : SDNode<"AArch64ISD::FMINV_PRED", SDT_AArch64Reduce>; +def AArch64fminnmv_pred : SDNode<"AArch64ISD::FMINNMV_PRED", SDT_AArch64Reduce>; +def AArch64lasta : SDNode<"AArch64ISD::LASTA", SDT_AArch64Reduce>; +def AArch64lastb : SDNode<"AArch64ISD::LASTB", SDT_AArch64Reduce>; +def AArch64orv_pred : SDNode<"AArch64ISD::ORV_PRED", SDT_AArch64Reduce>; +def AArch64saddv_pred : SDNode<"AArch64ISD::SADDV_PRED", SDT_AArch64Reduce>; +def AArch64smaxv_pred : SDNode<"AArch64ISD::SMAXV_PRED", SDT_AArch64Reduce>; +def AArch64sminv_pred : SDNode<"AArch64ISD::SMINV_PRED", SDT_AArch64Reduce>; +def AArch64uaddv_pred : SDNode<"AArch64ISD::UADDV_PRED", SDT_AArch64Reduce>; +def AArch64umaxv_pred : SDNode<"AArch64ISD::UMAXV_PRED", SDT_AArch64Reduce>; +def AArch64uminv_pred : SDNode<"AArch64ISD::UMINV_PRED", SDT_AArch64Reduce>; + +def SDT_AArch64FMinMax :SDTypeProfile<1, 3, [SDTCisVec<1>, SDTCisVec<2>, + SDTCisSameAs<2, 3>]>; +def AArch64fmin_pred : SDNode<"AArch64ISD::FMIN_PRED", SDT_AArch64FMinMax>; +def AArch64fminnm_pred : SDNode<"AArch64ISD::FMINNM_PRED", SDT_AArch64FMinMax>; +def AArch64fmax_pred : SDNode<"AArch64ISD::FMAX_PRED", SDT_AArch64FMinMax>; +def AArch64fmaxnm_pred : SDNode<"AArch64ISD::FMAXNM_PRED", SDT_AArch64FMinMax>; + +def SDT_AArch64_LDFF1 : SDTypeProfile<1, 3, [ + SDTCisVec<0>, SDTCisVec<1>, SDTCisPtrTy<2>, + SDTCVecEltisVT<1,i1>, SDTCisSameNumEltsAs<0,1> +]>; + +def SDT_AArch64_GLDFF1 : SDTypeProfile<1, 4, [ + SDTCisVec<0>, SDTCisVec<1>, SDTCisPtrTy<2>, SDTCisVec<3>, SDTCisVT<4, OtherVT>, + SDTCVecEltisVT<1,i1>, SDTCisSameNumEltsAs<0,1> +]>; + +def AArch64ldff1 : SDNode<"AArch64ISD::LDFF1", SDT_AArch64_LDFF1, [SDNPHasChain, SDNPMayLoad, SDNPOptInGlue]>; +def AArch64ldnf1 : SDNode<"AArch64ISD::LDNF1", SDT_AArch64_LDFF1, [SDNPHasChain, SDNPMayLoad, SDNPOptInGlue]>; +def AArch64ldff1_gather : SDNode<"AArch64ISD::GLDFF1", SDT_AArch64_GLDFF1, [SDNPHasChain, SDNPMayLoad, SDNPOptInGlue]>; +def AArch64ldff1_gather_scaled : SDNode<"AArch64ISD::GLDFF1_SCALED", SDT_AArch64_GLDFF1, [SDNPHasChain, SDNPMayLoad, SDNPOptInGlue]>; +def AArch64ldff1_gather_sxtw : SDNode<"AArch64ISD::GLDFF1_SXTW", SDT_AArch64_GLDFF1, [SDNPHasChain, SDNPMayLoad, SDNPOptInGlue]>; +def AArch64ldff1_gather_uxtw : 
SDNode<"AArch64ISD::GLDFF1_UXTW", SDT_AArch64_GLDFF1, [SDNPHasChain, SDNPMayLoad, SDNPOptInGlue]>; +def AArch64ldff1_gather_sxtw_scaled : SDNode<"AArch64ISD::GLDFF1_SXTW_SCALED", SDT_AArch64_GLDFF1, [SDNPHasChain, SDNPMayLoad, SDNPOptInGlue]>; +def AArch64ldff1_gather_uxtw_scaled : SDNode<"AArch64ISD::GLDFF1_UXTW_SCALED", SDT_AArch64_GLDFF1, [SDNPHasChain, SDNPMayLoad, SDNPOptInGlue]>; +// SVE2 non temporal gather loads +def AArch64ldnt1_gather : SDNode<"AArch64ISD::GLDNT1", SDT_AArch64_GLDFF1, [SDNPHasChain, SDNPMayLoad]>; +def AArch64ldnt1_gather_uxtw : SDNode<"AArch64ISD::GLDNT1_UXTW", SDT_AArch64_GLDFF1, [SDNPHasChain, SDNPMayLoad]>; + +def AArch64ldff1s : SDNode<"AArch64ISD::LDFF1S", SDT_AArch64_LDFF1, [SDNPHasChain, SDNPMayLoad, SDNPOptInGlue]>; +def AArch64ldnf1s : SDNode<"AArch64ISD::LDNF1S", SDT_AArch64_LDFF1, [SDNPHasChain, SDNPMayLoad, SDNPOptInGlue]>; +def AArch64ldff1s_gather : SDNode<"AArch64ISD::GLDFF1S", SDT_AArch64_GLDFF1, [SDNPHasChain, SDNPMayLoad, SDNPOptInGlue]>; +def AArch64ldff1s_gather_scaled : SDNode<"AArch64ISD::GLDFF1S_SCALED", SDT_AArch64_GLDFF1, [SDNPHasChain, SDNPMayLoad, SDNPOptInGlue]>; +def AArch64ldff1s_gather_sxtw : SDNode<"AArch64ISD::GLDFF1S_SXTW", SDT_AArch64_GLDFF1, [SDNPHasChain, SDNPMayLoad, SDNPOptInGlue]>; +def AArch64ldff1s_gather_uxtw : SDNode<"AArch64ISD::GLDFF1S_UXTW", SDT_AArch64_GLDFF1, [SDNPHasChain, SDNPMayLoad, SDNPOptInGlue]>; +def AArch64ldff1s_gather_sxtw_scaled : SDNode<"AArch64ISD::GLDFF1S_SXTW_SCALED", SDT_AArch64_GLDFF1, [SDNPHasChain, SDNPMayLoad, SDNPOptInGlue]>; +def AArch64ldff1s_gather_uxtw_scaled : SDNode<"AArch64ISD::GLDFF1S_UXTW_SCALED", SDT_AArch64_GLDFF1, [SDNPHasChain, SDNPMayLoad, SDNPOptInGlue]>; +// SVE2 non temporal gather loads +def AArch64ldnt1s_gather : SDNode<"AArch64ISD::GLDNT1S", SDT_AArch64_GLDFF1, [SDNPHasChain, SDNPMayLoad]>; +def AArch64ldnt1s_gather_uxtw : SDNode<"AArch64ISD::GLDNT1S_UXTW", SDT_AArch64_GLDFF1, [SDNPHasChain, SDNPMayLoad]>; + +def SDT_AArch64_LD1RQ : SDTypeProfile<1, 2, [ + SDTCisVec<0>, SDTCisVec<1>, SDTCisPtrTy<2>, + SDTCVecEltisVT<1,i1>, SDTCisSameNumEltsAs<0,1> +]>; +def AArch64ld1rq : SDNode<"AArch64ISD::LD1RQ", SDT_AArch64_LD1RQ, [SDNPHasChain, SDNPMayLoad]>; + +def SDT_AArch64_LDNT1 : SDTypeProfile<1, 3, [ + SDTCisVec<0>, SDTCisPtrTy<1>, SDTCisVec<2>, SDTCisSameAs<0,3>, + SDTCVecEltisVT<2,i1>, SDTCisSameNumEltsAs<0,2> +]>; +def AArch64ldnt1 : SDNode<"AArch64ISD::LDNT1", SDT_AArch64_LDNT1, [SDNPHasChain, SDNPMayLoad]>; + +def SDT_AArch64_STNT1 : SDTypeProfile<0, 3, [ + SDTCisPtrTy<0>, SDTCisVec<1>, SDTCisVec<2>, + SDTCVecEltisVT<1,i1>, SDTCisSameNumEltsAs<1,2> +]>; +def AArch64stnt1 : SDNode<"AArch64ISD::STNT1", SDT_AArch64_STNT1, [SDNPHasChain, SDNPMayStore]>; + +// SVE2 non temporal scatter stores +def SDT_AArch64_SSTNT1 : SDTypeProfile<0, 5, [ + SDTCisVec<0>, SDTCisVec<1>, SDTCisPtrTy<2>, SDTCisVec<3>, SDTCisVT<4, OtherVT>, + SDTCVecEltisVT<1,i1>, SDTCisSameNumEltsAs<0,1>, SDTCisSameNumEltsAs<0,3> + ]>; + +def AArch64stnt1_scatter : SDNode<"AArch64ISD::SSTNT1", SDT_AArch64_SSTNT1, [SDNPHasChain, SDNPMayStore]>; +def AArch64stnt1_scatter_uxtw : SDNode<"AArch64ISD::SSTNT1_UXTW", SDT_AArch64_SSTNT1, [SDNPHasChain, SDNPMayStore]>; + +def SDT_AArch64_GPRF : SDTypeProfile< 0, 5, [ + SDTCisVec<0>, SDTCisPtrTy<1>, SDTCisVec<2>, SDTCisInt<3>, + SDTCVecEltisVT<0,i1>, SDTCisSameNumEltsAs<0, 2> +]>; + +def AArch64prf_gather_s_imm : SDNode<"AArch64ISD::GPRF_S_IMM", SDT_AArch64_GPRF, [SDNPHasChain, SDNPMayLoad]>; +def AArch64prf_gather_d_imm : SDNode<"AArch64ISD::GPRF_D_IMM", SDT_AArch64_GPRF, 
[SDNPHasChain, SDNPMayLoad]>; +def AArch64prf_gather_d_scaled : SDNode<"AArch64ISD::GPRF_D_SCALED", SDT_AArch64_GPRF, [SDNPHasChain, SDNPMayLoad]>; +def AArch64prf_gather_s_sxtw_scaled : SDNode<"AArch64ISD::GPRF_S_SXTW_SCALED", SDT_AArch64_GPRF, [SDNPHasChain, SDNPMayLoad]>; +def AArch64prf_gather_s_uxtw_scaled : SDNode<"AArch64ISD::GPRF_S_UXTW_SCALED", SDT_AArch64_GPRF, [SDNPHasChain, SDNPMayLoad]>; +def AArch64prf_gather_d_sxtw_scaled : SDNode<"AArch64ISD::GPRF_D_SXTW_SCALED", SDT_AArch64_GPRF, [SDNPHasChain, SDNPMayLoad]>; +def AArch64prf_gather_d_uxtw_scaled : SDNode<"AArch64ISD::GPRF_D_UXTW_SCALED", SDT_AArch64_GPRF, [SDNPHasChain, SDNPMayLoad]>; + let Predicates = [HasSVE] in { - def RDFFR_PPz : sve_int_rdffr_pred<0b0, "rdffr">; + defm RDFFR_PPz : sve_int_rdffr_pred<0b0, "rdffr", AArch64rdffr_pred>; def RDFFRS_PPz : sve_int_rdffr_pred<0b1, "rdffrs">; - def RDFFR_P : sve_int_rdffr_unpred<"rdffr">; - def SETFFR : sve_int_setffr<"setffr">; - def WRFFR : sve_int_wrffr<"wrffr">; + defm RDFFR_P : sve_int_rdffr_unpred<"rdffr", AArch64rdffr>; + def SETFFR : sve_int_setffr<"setffr", AArch64setffr>; + def WRFFR : sve_int_wrffr<"wrffr", AArch64wrffr>; - defm ADD_ZZZ : sve_int_bin_cons_arit_0<0b000, "add">; - defm SUB_ZZZ : sve_int_bin_cons_arit_0<0b001, "sub">; - defm SQADD_ZZZ : sve_int_bin_cons_arit_0<0b100, "sqadd">; - defm UQADD_ZZZ : sve_int_bin_cons_arit_0<0b101, "uqadd">; - defm SQSUB_ZZZ : sve_int_bin_cons_arit_0<0b110, "sqsub">; - defm UQSUB_ZZZ : sve_int_bin_cons_arit_0<0b111, "uqsub">; + defm ADD_ZZZ : sve_int_bin_cons_arit_0<0b000, "add", add>; + defm SUB_ZZZ : sve_int_bin_cons_arit_0<0b001, "sub", sub>; + defm SQADD_ZZZ : sve_int_bin_cons_arit_0<0b100, "sqadd", int_aarch64_sve_sqadd_x>; + defm UQADD_ZZZ : sve_int_bin_cons_arit_0<0b101, "uqadd", int_aarch64_sve_uqadd_x>; + defm SQSUB_ZZZ : sve_int_bin_cons_arit_0<0b110, "sqsub", int_aarch64_sve_sqsub_x>; + defm UQSUB_ZZZ : sve_int_bin_cons_arit_0<0b111, "uqsub", int_aarch64_sve_uqsub_x>; - def AND_ZZZ : sve_int_bin_cons_log<0b00, "and">; - def ORR_ZZZ : sve_int_bin_cons_log<0b01, "orr">; - def EOR_ZZZ : sve_int_bin_cons_log<0b10, "eor">; - def BIC_ZZZ : sve_int_bin_cons_log<0b11, "bic">; + defm AND_ZZZ : sve_int_bin_cons_log<0b00, "and">; + defm ORR_ZZZ : sve_int_bin_cons_log<0b01, "orr">; + defm EOR_ZZZ : sve_int_bin_cons_log<0b10, "eor">; + defm BIC_ZZZ : sve_int_bin_cons_log<0b11, "bic">; - defm ADD_ZPmZ : sve_int_bin_pred_arit_0<0b000, "add">; - defm SUB_ZPmZ : sve_int_bin_pred_arit_0<0b001, "sub">; - defm SUBR_ZPmZ : sve_int_bin_pred_arit_0<0b011, "subr">; + defm ADD_ZPmZ : sve_int_bin_pred_arit_0<0b000, "add", "ADD_ZPZZ", int_aarch64_sve_add, DestructiveBinaryComm>; + defm SUB_ZPmZ : sve_int_bin_pred_arit_0<0b001, "sub", "SUB_ZPZZ", int_aarch64_sve_sub, DestructiveBinaryCommWithRev, "SUBR_ZPmZ", 1>; + defm SUBR_ZPmZ : sve_int_bin_pred_arit_0<0b011, "subr", "SUBR_ZPZZ", int_aarch64_sve_subr, DestructiveBinaryCommWithRev, "SUB_ZPmZ", 0>; - defm ORR_ZPmZ : sve_int_bin_pred_log<0b000, "orr">; - defm EOR_ZPmZ : sve_int_bin_pred_log<0b001, "eor">; - defm AND_ZPmZ : sve_int_bin_pred_log<0b010, "and">; - defm BIC_ZPmZ : sve_int_bin_pred_log<0b011, "bic">; + defm ADD_ZPZZ : sve_int_bin_pred_zx; + defm SUB_ZPZZ : sve_int_bin_pred_zx; + defm SUBR_ZPZZ : sve_int_bin_pred_zx; - defm ADD_ZI : sve_int_arith_imm0<0b000, "add">; - defm SUB_ZI : sve_int_arith_imm0<0b001, "sub">; - defm SUBR_ZI : sve_int_arith_imm0<0b011, "subr">; - defm SQADD_ZI : sve_int_arith_imm0<0b100, "sqadd">; - defm UQADD_ZI : sve_int_arith_imm0<0b101, "uqadd">; - defm 
SQSUB_ZI : sve_int_arith_imm0<0b110, "sqsub">; - defm UQSUB_ZI : sve_int_arith_imm0<0b111, "uqsub">; + defm ORR_ZPmZ : sve_int_bin_pred_log<0b000, "orr", "ORR_ZPZZ", int_aarch64_sve_orr, DestructiveBinaryComm>; + defm EOR_ZPmZ : sve_int_bin_pred_log<0b001, "eor", "EOR_ZPZZ", int_aarch64_sve_eor, DestructiveBinaryComm>; + defm AND_ZPmZ : sve_int_bin_pred_log<0b010, "and", "AND_ZPZZ", int_aarch64_sve_and, DestructiveBinaryComm>; + defm BIC_ZPmZ : sve_int_bin_pred_log<0b011, "bic", "BIC_ZPZZ", int_aarch64_sve_bic, DestructiveBinary>; - defm MAD_ZPmZZ : sve_int_mladdsub_vvv_pred<0b0, "mad">; - defm MSB_ZPmZZ : sve_int_mladdsub_vvv_pred<0b1, "msb">; - defm MLA_ZPmZZ : sve_int_mlas_vvv_pred<0b0, "mla">; - defm MLS_ZPmZZ : sve_int_mlas_vvv_pred<0b1, "mls">; + defm ORR_ZPZZ : sve_int_bin_pred_zx; + defm EOR_ZPZZ : sve_int_bin_pred_zx; + defm AND_ZPZZ : sve_int_bin_pred_zx; + defm BIC_ZPZZ : sve_int_bin_pred_noncomm_zx; + + defm ADD_ZI : sve_int_arith_imm0<0b000, "ADD_ZZI", "add">; + defm SUB_ZI : sve_int_arith_imm0<0b001, "SUB_ZZI", "sub">; + defm SUBR_ZI : sve_int_arith_imm0<0b011, "SUBR_ZZI", "subr">; + defm SQADD_ZI : sve_int_arith_imm0<0b100, "SQADD_ZZI", "sqadd">; + defm UQADD_ZI : sve_int_arith_imm0<0b101, "UQADD_ZZI", "uqadd">; + defm SQSUB_ZI : sve_int_arith_imm0<0b110, "SQSUB_ZZI", "sqsub">; + defm UQSUB_ZI : sve_int_arith_imm0<0b111, "UQSUB_ZZI", "uqsub">; + + defm ADD_ZZI : sve_int_arith_imm0_zzi; + defm SUB_ZZI : sve_int_arith_imm0_zzi; + defm SUBR_ZZI : sve_int_arith_imm0_zzi; + defm SQADD_ZZI : sve_int_arith_imm0_zzi; + defm UQADD_ZZI : sve_int_arith_imm0_zzi; + defm SQSUB_ZZI : sve_int_arith_imm0_zzi; + defm UQSUB_ZZI : sve_int_arith_imm0_zzi; + + defm MAD_ZPmZZ : sve_int_mladdsub_vvv_pred<0b0, "mad", int_aarch64_sve_mad, add, "MLA_ZPmZZ", 1>; + defm MSB_ZPmZZ : sve_int_mladdsub_vvv_pred<0b1, "msb", int_aarch64_sve_msb, sub, "MLS_ZPmZZ", 1>; + defm MLA_ZPmZZ : sve_int_mlas_vvv_pred<0b0, "mla", "MLA_ZPZZZ", int_aarch64_sve_mla, add, "MAD_ZPmZZ", 0>; + defm MLS_ZPmZZ : sve_int_mlas_vvv_pred<0b1, "mls", "MLS_ZPZZZ", int_aarch64_sve_mls, sub, "MSB_ZPmZZ", 0>; + + defm MLA_ZPZZZ : sve_int_ternary_pred_zx; + defm MLS_ZPZZZ : sve_int_ternary_pred_zx; // SVE predicated integer reductions. 
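The reductions in the next hunk are rewired onto the AArch64*_pred ISD nodes defined earlier in this file. As a reminder of the semantics being modelled, a predicated reduction combines only the lanes whose governing predicate bit is set; a scalar model with invented values:

```cpp
#include <array>
#include <cstdint>
#include <cstdio>

int main() {
  // Eight-lane stand-in for an SVE vector and its governing predicate.
  std::array<uint32_t, 8> Z = {1, 2, 3, 4, 5, 6, 7, 8};
  std::array<bool, 8> P = {true, false, true, false, true, false, true, false};

  uint64_t Acc = 0; // UADDV widens the accumulation to 64 bits
  for (size_t I = 0; I < Z.size(); ++I)
    if (P[I])
      Acc += Z[I]; // inactive lanes simply do not contribute

  std::printf("uaddv = %llu\n", (unsigned long long)Acc); // 1+3+5+7 = 16
  return 0;
}
```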
- defm SADDV_VPZ : sve_int_reduce_0_saddv<0b000, "saddv">; - defm UADDV_VPZ : sve_int_reduce_0_uaddv<0b001, "uaddv">; - defm SMAXV_VPZ : sve_int_reduce_1<0b000, "smaxv">; - defm UMAXV_VPZ : sve_int_reduce_1<0b001, "umaxv">; - defm SMINV_VPZ : sve_int_reduce_1<0b010, "sminv">; - defm UMINV_VPZ : sve_int_reduce_1<0b011, "uminv">; - defm ORV_VPZ : sve_int_reduce_2<0b000, "orv">; - defm EORV_VPZ : sve_int_reduce_2<0b001, "eorv">; - defm ANDV_VPZ : sve_int_reduce_2<0b010, "andv">; + defm SADDV_VPZ : sve_int_reduce_0_saddv<0b000, "saddv", AArch64saddv_pred>; + defm UADDV_VPZ : sve_int_reduce_0_uaddv<0b001, "uaddv", AArch64uaddv_pred>; + defm SMAXV_VPZ : sve_int_reduce_1<0b000, "smaxv", AArch64smaxv_pred>; + defm UMAXV_VPZ : sve_int_reduce_1<0b001, "umaxv", AArch64umaxv_pred>; + defm SMINV_VPZ : sve_int_reduce_1<0b010, "sminv", AArch64sminv_pred>; + defm UMINV_VPZ : sve_int_reduce_1<0b011, "uminv", AArch64uminv_pred>; + defm ORV_VPZ : sve_int_reduce_2<0b000, "orv", AArch64orv_pred>; + defm EORV_VPZ : sve_int_reduce_2<0b001, "eorv", AArch64eorv_pred>; + defm ANDV_VPZ : sve_int_reduce_2<0b010, "andv", AArch64andv_pred>; - defm ORR_ZI : sve_int_log_imm<0b00, "orr", "orn">; - defm EOR_ZI : sve_int_log_imm<0b01, "eor", "eon">; - defm AND_ZI : sve_int_log_imm<0b10, "and", "bic">; + defm ORR_ZI : sve_int_log_imm<0b00, "orr", or, "orn">; + defm EOR_ZI : sve_int_log_imm<0b01, "eor", xor, "eon">; + defm AND_ZI : sve_int_log_imm<0b10, "and", and, "bic">; defm SMAX_ZI : sve_int_arith_imm1<0b00, "smax", simm8>; defm SMIN_ZI : sve_int_arith_imm1<0b10, "smin", simm8>; @@ -74,108 +290,152 @@ defm UMIN_ZI : sve_int_arith_imm1<0b11, "umin", imm0_255>; defm MUL_ZI : sve_int_arith_imm2<"mul">; - defm MUL_ZPmZ : sve_int_bin_pred_arit_2<0b000, "mul">; - defm SMULH_ZPmZ : sve_int_bin_pred_arit_2<0b010, "smulh">; - defm UMULH_ZPmZ : sve_int_bin_pred_arit_2<0b011, "umulh">; + defm MUL_ZPmZ : sve_int_bin_pred_arit_2<0b000, "mul", "MUL_ZPZZ", int_aarch64_sve_mul>; + defm SMULH_ZPmZ : sve_int_bin_pred_arit_2<0b010, "smulh", "SMULH_ZPZZ", int_aarch64_sve_smulh>; + defm UMULH_ZPmZ : sve_int_bin_pred_arit_2<0b011, "umulh", "UMULH_ZPZZ", int_aarch64_sve_umulh>; - defm SDIV_ZPmZ : sve_int_bin_pred_arit_2_div<0b100, "sdiv">; - defm UDIV_ZPmZ : sve_int_bin_pred_arit_2_div<0b101, "udiv">; - defm SDIVR_ZPmZ : sve_int_bin_pred_arit_2_div<0b110, "sdivr">; - defm UDIVR_ZPmZ : sve_int_bin_pred_arit_2_div<0b111, "udivr">; + defm MUL_ZPZZ : sve_int_bin_pred_zx; + defm SMULH_ZPZZ : sve_int_bin_pred_zx; + defm UMULH_ZPZZ : sve_int_bin_pred_zx; - defm SDOT_ZZZ : sve_intx_dot<0b0, "sdot">; - defm UDOT_ZZZ : sve_intx_dot<0b1, "udot">; + defm SDIV_ZPmZ : sve_int_bin_pred_arit_2_div<0b100, "sdiv", "SDIV_ZPZZ", int_aarch64_sve_sdiv, "SDIVR_ZPmZ", 1>; + defm UDIV_ZPmZ : sve_int_bin_pred_arit_2_div<0b101, "udiv", "UDIV_ZPZZ", int_aarch64_sve_udiv, "UDIVR_ZPmZ", 1>; + defm SDIVR_ZPmZ : sve_int_bin_pred_arit_2_div<0b110, "sdivr", "SDIVR_ZPZZ", int_aarch64_sve_sdivr, "SDIV_ZPmZ", 0>; + defm UDIVR_ZPmZ : sve_int_bin_pred_arit_2_div<0b111, "udivr", "UDIVR_ZPZZ", int_aarch64_sve_udivr, "UDIV_ZPmZ", 0>; - defm SDOT_ZZZI : sve_intx_dot_by_indexed_elem<0b0, "sdot">; - defm UDOT_ZZZI : sve_intx_dot_by_indexed_elem<0b1, "udot">; + defm SDIV_ZPZZ : sve_int_bin_pred_arit_2_div_zx; + defm UDIV_ZPZZ : sve_int_bin_pred_arit_2_div_zx; + defm SDIVR_ZPZZ : sve_int_bin_pred_arit_2_div_zx; + defm UDIVR_ZPZZ : sve_int_bin_pred_arit_2_div_zx; - defm SXTB_ZPmZ : sve_int_un_pred_arit_0_h<0b000, "sxtb">; - defm UXTB_ZPmZ : sve_int_un_pred_arit_0_h<0b001, "uxtb">; - defm 
SXTH_ZPmZ : sve_int_un_pred_arit_0_w<0b010, "sxth">; - defm UXTH_ZPmZ : sve_int_un_pred_arit_0_w<0b011, "uxth">; - defm SXTW_ZPmZ : sve_int_un_pred_arit_0_d<0b100, "sxtw">; - defm UXTW_ZPmZ : sve_int_un_pred_arit_0_d<0b101, "uxtw">; - defm ABS_ZPmZ : sve_int_un_pred_arit_0< 0b110, "abs">; - defm NEG_ZPmZ : sve_int_un_pred_arit_0< 0b111, "neg">; + defm SDOT_ZZZ : sve_intx_dot<0b0, "sdot", int_aarch64_sve_sdot>; + defm UDOT_ZZZ : sve_intx_dot<0b1, "udot", int_aarch64_sve_udot>; - defm CLS_ZPmZ : sve_int_un_pred_arit_1< 0b000, "cls">; - defm CLZ_ZPmZ : sve_int_un_pred_arit_1< 0b001, "clz">; - defm CNT_ZPmZ : sve_int_un_pred_arit_1< 0b010, "cnt">; - defm CNOT_ZPmZ : sve_int_un_pred_arit_1< 0b011, "cnot">; - defm NOT_ZPmZ : sve_int_un_pred_arit_1< 0b110, "not">; - defm FABS_ZPmZ : sve_int_un_pred_arit_1_fp<0b100, "fabs">; - defm FNEG_ZPmZ : sve_int_un_pred_arit_1_fp<0b101, "fneg">; + defm SDOT_ZZZI : sve_intx_dot_by_indexed_elem<0b0, "sdot", int_aarch64_sve_sdot_lane>; + defm UDOT_ZZZI : sve_intx_dot_by_indexed_elem<0b1, "udot", int_aarch64_sve_udot_lane>; - defm SMAX_ZPmZ : sve_int_bin_pred_arit_1<0b000, "smax">; - defm UMAX_ZPmZ : sve_int_bin_pred_arit_1<0b001, "umax">; - defm SMIN_ZPmZ : sve_int_bin_pred_arit_1<0b010, "smin">; - defm UMIN_ZPmZ : sve_int_bin_pred_arit_1<0b011, "umin">; - defm SABD_ZPmZ : sve_int_bin_pred_arit_1<0b100, "sabd">; - defm UABD_ZPmZ : sve_int_bin_pred_arit_1<0b101, "uabd">; + defm SXTB_ZPmZ : sve_int_un_pred_arit_0_h<0b000, "sxtb", int_aarch64_sve_sxtb>; + defm UXTB_ZPmZ : sve_int_un_pred_arit_0_h<0b001, "uxtb", int_aarch64_sve_uxtb>; + defm SXTH_ZPmZ : sve_int_un_pred_arit_0_w<0b010, "sxth", int_aarch64_sve_sxth>; + defm UXTH_ZPmZ : sve_int_un_pred_arit_0_w<0b011, "uxth", int_aarch64_sve_uxth>; + defm SXTW_ZPmZ : sve_int_un_pred_arit_0_d<0b100, "sxtw", int_aarch64_sve_sxtw>; + defm UXTW_ZPmZ : sve_int_un_pred_arit_0_d<0b101, "uxtw", int_aarch64_sve_uxtw>; + defm ABS_ZPmZ : sve_int_un_pred_arit_0< 0b110, "abs", int_aarch64_sve_abs>; + defm NEG_ZPmZ : sve_int_un_pred_arit_0< 0b111, "neg", int_aarch64_sve_neg>; - defm FRECPE_ZZ : sve_fp_2op_u_zd<0b110, "frecpe">; - defm FRSQRTE_ZZ : sve_fp_2op_u_zd<0b111, "frsqrte">; + defm CLS_ZPmZ : sve_int_un_pred_arit_1< 0b000, "cls", int_aarch64_sve_cls>; + defm CLZ_ZPmZ : sve_int_un_pred_arit_1< 0b001, "clz", int_aarch64_sve_clz>; + defm CNT_ZPmZ : sve_int_un_pred_arit_1< 0b010, "cnt", int_aarch64_sve_cnt>; + defm CNOT_ZPmZ : sve_int_un_pred_arit_1< 0b011, "cnot", int_aarch64_sve_cnot>; + defm NOT_ZPmZ : sve_int_un_pred_arit_1< 0b110, "not", int_aarch64_sve_not>; + defm FABS_ZPmZ : sve_int_un_pred_arit_1_fp<0b100, "fabs", int_aarch64_sve_fabs>; + defm FNEG_ZPmZ : sve_int_un_pred_arit_1_fp<0b101, "fneg", int_aarch64_sve_fneg>; - defm FADD_ZPmI : sve_fp_2op_i_p_zds<0b000, "fadd", sve_fpimm_half_one>; - defm FSUB_ZPmI : sve_fp_2op_i_p_zds<0b001, "fsub", sve_fpimm_half_one>; - defm FMUL_ZPmI : sve_fp_2op_i_p_zds<0b010, "fmul", sve_fpimm_half_two>; - defm FSUBR_ZPmI : sve_fp_2op_i_p_zds<0b011, "fsubr", sve_fpimm_half_one>; - defm FMAXNM_ZPmI : sve_fp_2op_i_p_zds<0b100, "fmaxnm", sve_fpimm_zero_one>; - defm FMINNM_ZPmI : sve_fp_2op_i_p_zds<0b101, "fminnm", sve_fpimm_zero_one>; - defm FMAX_ZPmI : sve_fp_2op_i_p_zds<0b110, "fmax", sve_fpimm_zero_one>; - defm FMIN_ZPmI : sve_fp_2op_i_p_zds<0b111, "fmin", sve_fpimm_zero_one>; + defm SMAX_ZPmZ : sve_int_bin_pred_arit_1<0b000, "smax", "SMAX_ZPZZ", int_aarch64_sve_smax>; + defm UMAX_ZPmZ : sve_int_bin_pred_arit_1<0b001, "umax", "UMAX_ZPZZ", int_aarch64_sve_umax>; + defm SMIN_ZPmZ : 
sve_int_bin_pred_arit_1<0b010, "smin", "SMIN_ZPZZ", int_aarch64_sve_smin>; + defm UMIN_ZPmZ : sve_int_bin_pred_arit_1<0b011, "umin", "UMIN_ZPZZ", int_aarch64_sve_umin>; + defm SABD_ZPmZ : sve_int_bin_pred_arit_1<0b100, "sabd", "SABD_ZPZZ", int_aarch64_sve_sabd>; + defm UABD_ZPmZ : sve_int_bin_pred_arit_1<0b101, "uabd", "UABD_ZPZZ", int_aarch64_sve_uabd>; - defm FADD_ZPmZ : sve_fp_2op_p_zds<0b0000, "fadd">; - defm FSUB_ZPmZ : sve_fp_2op_p_zds<0b0001, "fsub">; - defm FMUL_ZPmZ : sve_fp_2op_p_zds<0b0010, "fmul">; - defm FSUBR_ZPmZ : sve_fp_2op_p_zds<0b0011, "fsubr">; - defm FMAXNM_ZPmZ : sve_fp_2op_p_zds<0b0100, "fmaxnm">; - defm FMINNM_ZPmZ : sve_fp_2op_p_zds<0b0101, "fminnm">; - defm FMAX_ZPmZ : sve_fp_2op_p_zds<0b0110, "fmax">; - defm FMIN_ZPmZ : sve_fp_2op_p_zds<0b0111, "fmin">; - defm FABD_ZPmZ : sve_fp_2op_p_zds<0b1000, "fabd">; - defm FSCALE_ZPmZ : sve_fp_2op_p_zds<0b1001, "fscale">; - defm FMULX_ZPmZ : sve_fp_2op_p_zds<0b1010, "fmulx">; - defm FDIVR_ZPmZ : sve_fp_2op_p_zds<0b1100, "fdivr">; - defm FDIV_ZPmZ : sve_fp_2op_p_zds<0b1101, "fdiv">; + defm SMAX_ZPZZ : sve_int_bin_pred_zx; + defm UMAX_ZPZZ : sve_int_bin_pred_zx; + defm SMIN_ZPZZ : sve_int_bin_pred_zx; + defm UMIN_ZPZZ : sve_int_bin_pred_zx; + defm SABD_ZPZZ : sve_int_bin_pred_zx; + defm UABD_ZPZZ : sve_int_bin_pred_zx; - defm FADD_ZZZ : sve_fp_3op_u_zd<0b000, "fadd">; - defm FSUB_ZZZ : sve_fp_3op_u_zd<0b001, "fsub">; - defm FMUL_ZZZ : sve_fp_3op_u_zd<0b010, "fmul">; - defm FTSMUL_ZZZ : sve_fp_3op_u_zd<0b011, "ftsmul">; - defm FRECPS_ZZZ : sve_fp_3op_u_zd<0b110, "frecps">; - defm FRSQRTS_ZZZ : sve_fp_3op_u_zd<0b111, "frsqrts">; + defm FRECPE_ZZ : sve_fp_2op_u_zd<0b110, "frecpe", int_aarch64_sve_frecpe_x>; + defm FRSQRTE_ZZ : sve_fp_2op_u_zd<0b111, "frsqrte", int_aarch64_sve_frsqrte_x>; - defm FTSSEL_ZZZ : sve_int_bin_cons_misc_0_b<"ftssel">; + defm FADD_ZPmI : sve_fp_2op_i_p_zds<0b000, "fadd", "FADD_ZPZI", sve_fpimm_half_one>; + defm FSUB_ZPmI : sve_fp_2op_i_p_zds<0b001, "fsub", "FSUB_ZPZI", sve_fpimm_half_one>; + defm FMUL_ZPmI : sve_fp_2op_i_p_zds<0b010, "fmul", "FMUL_ZPZI", sve_fpimm_half_two>; + defm FSUBR_ZPmI : sve_fp_2op_i_p_zds<0b011, "fsubr", "FSUBR_ZPZI", sve_fpimm_half_one>; + defm FMAXNM_ZPmI : sve_fp_2op_i_p_zds<0b100, "fmaxnm", "FMAXNM_ZPZI", sve_fpimm_zero_one>; + defm FMINNM_ZPmI : sve_fp_2op_i_p_zds<0b101, "fminnm", "FMINNM_ZPZI", sve_fpimm_zero_one>; + defm FMAX_ZPmI : sve_fp_2op_i_p_zds<0b110, "fmax", "FMAX_ZPZI", sve_fpimm_zero_one>; + defm FMIN_ZPmI : sve_fp_2op_i_p_zds<0b111, "fmin", "FMIN_ZPZI", sve_fpimm_zero_one>; - defm FCADD_ZPmZ : sve_fp_fcadd<"fcadd">; - defm FCMLA_ZPmZZ : sve_fp_fcmla<"fcmla">; + defm FADD_ZPZI : sve_fp_2op_i_p_zds_zx; + defm FSUB_ZPZI : sve_fp_2op_i_p_zds_zx; + defm FMUL_ZPZI : sve_fp_2op_i_p_zds_zx; + defm FSUBR_ZPZI : sve_fp_2op_i_p_zds_zx; + defm FMAXNM_ZPZI : sve_fp_2op_i_p_zds_zx; + defm FMINNM_ZPZI : sve_fp_2op_i_p_zds_zx; + defm FMAX_ZPZI : sve_fp_2op_i_p_zds_zx; + defm FMIN_ZPZI : sve_fp_2op_i_p_zds_zx; - defm FMLA_ZPmZZ : sve_fp_3op_p_zds_a<0b00, "fmla">; - defm FMLS_ZPmZZ : sve_fp_3op_p_zds_a<0b01, "fmls">; - defm FNMLA_ZPmZZ : sve_fp_3op_p_zds_a<0b10, "fnmla">; - defm FNMLS_ZPmZZ : sve_fp_3op_p_zds_a<0b11, "fnmls">; + defm FADD_ZPmZ : sve_fp_2op_p_zds<0b0000, "fadd", "FADD_ZPZZ", int_aarch64_sve_fadd, DestructiveBinaryComm>; + defm FSUB_ZPmZ : sve_fp_2op_p_zds<0b0001, "fsub", "FSUB_ZPZZ", int_aarch64_sve_fsub, DestructiveBinaryCommWithRev, "FSUBR_ZPmZ", 1>; + defm FMUL_ZPmZ : sve_fp_2op_p_zds<0b0010, "fmul", "FMUL_ZPZZ", int_aarch64_sve_fmul, DestructiveBinaryComm>; + 
defm FSUBR_ZPmZ : sve_fp_2op_p_zds<0b0011, "fsubr", "FSUBR_ZPZZ", int_aarch64_sve_fsubr, DestructiveBinaryCommWithRev, "FSUB_ZPmZ", 0>; + defm FMAXNM_ZPmZ : sve_fp_2op_p_zds<0b0100, "fmaxnm", "FMAXNM_ZPZZ", AArch64fmaxnm_pred, DestructiveBinaryComm>; + defm FMINNM_ZPmZ : sve_fp_2op_p_zds<0b0101, "fminnm", "FMINNM_ZPZZ", AArch64fminnm_pred, DestructiveBinaryComm>; + defm FMAX_ZPmZ : sve_fp_2op_p_zds<0b0110, "fmax", "FMAX_ZPZZ", AArch64fmax_pred, DestructiveBinaryComm>; + defm FMIN_ZPmZ : sve_fp_2op_p_zds<0b0111, "fmin", "FMIN_ZPZZ", AArch64fmin_pred, DestructiveBinaryComm>; + defm FABD_ZPmZ : sve_fp_2op_p_zds<0b1000, "fabd", "FABD_ZPZZ", int_aarch64_sve_fabd, DestructiveBinaryComm>; + defm FSCALE_ZPmZ : sve_fp_2op_p_zds_fscale<0b1001, "fscale", "FSCALE_ZPZZ", int_aarch64_sve_fscale>; + defm FMULX_ZPmZ : sve_fp_2op_p_zds<0b1010, "fmulx", "FMULX_ZPZZ", int_aarch64_sve_fmulx, DestructiveBinaryComm>; + defm FDIVR_ZPmZ : sve_fp_2op_p_zds<0b1100, "fdivr", "FDIVR_ZPZZ", int_aarch64_sve_fdivr, DestructiveBinaryCommWithRev, "FDIV_ZPmZ", 0>; + defm FDIV_ZPmZ : sve_fp_2op_p_zds<0b1101, "fdiv", "FDIV_ZPZZ", int_aarch64_sve_fdiv, DestructiveBinaryCommWithRev, "FDIVR_ZPmZ", 1>; - defm FMAD_ZPmZZ : sve_fp_3op_p_zds_b<0b00, "fmad">; - defm FMSB_ZPmZZ : sve_fp_3op_p_zds_b<0b01, "fmsb">; - defm FNMAD_ZPmZZ : sve_fp_3op_p_zds_b<0b10, "fnmad">; - defm FNMSB_ZPmZZ : sve_fp_3op_p_zds_b<0b11, "fnmsb">; + defm FADD_ZPZZ : sve_fp_2op_p_zds_zx; + defm FSUB_ZPZZ : sve_fp_2op_p_zds_zx; + defm FMUL_ZPZZ : sve_fp_2op_p_zds_zx; + defm FSUBR_ZPZZ : sve_fp_2op_p_zds_zx; + defm FMAXNM_ZPZZ : sve_fp_2op_p_zds_zx; + defm FMINNM_ZPZZ : sve_fp_2op_p_zds_zx; + defm FMAX_ZPZZ : sve_fp_2op_p_zds_zx; + defm FMIN_ZPZZ : sve_fp_2op_p_zds_zx; + defm FABD_ZPZZ : sve_fp_2op_p_zds_zx; + defm FSCALE_ZPZZ : sve_fp_2op_p_zds_fscale_zx; + defm FMULX_ZPZZ : sve_fp_2op_p_zds_zx; + defm FDIVR_ZPZZ : sve_fp_2op_p_zds_zx; + defm FDIV_ZPZZ : sve_fp_2op_p_zds_zx; - defm FTMAD_ZZI : sve_fp_ftmad<"ftmad">; + defm FADD_ZZZ : sve_fp_3op_u_zd<0b000, "fadd", fadd>; + defm FSUB_ZZZ : sve_fp_3op_u_zd<0b001, "fsub", fsub>; + defm FMUL_ZZZ : sve_fp_3op_u_zd<0b010, "fmul", fmul>; + defm FTSMUL_ZZZ : sve_fp_3op_u_zd_ftsmul<0b011, "ftsmul", int_aarch64_sve_ftsmul_x>; + defm FRECPS_ZZZ : sve_fp_3op_u_zd<0b110, "frecps", int_aarch64_sve_frecps_x>; + defm FRSQRTS_ZZZ : sve_fp_3op_u_zd<0b111, "frsqrts", int_aarch64_sve_frsqrts_x>; - defm FMLA_ZZZI : sve_fp_fma_by_indexed_elem<0b0, "fmla">; - defm FMLS_ZZZI : sve_fp_fma_by_indexed_elem<0b1, "fmls">; + defm FTSSEL_ZZZ : sve_int_bin_cons_misc_0_b<"ftssel", int_aarch64_sve_ftssel_x>; - defm FCMLA_ZZZI : sve_fp_fcmla_by_indexed_elem<"fcmla">; - defm FMUL_ZZZI : sve_fp_fmul_by_indexed_elem<"fmul">; + defm FCADD_ZPmZ : sve_fp_fcadd<"fcadd", int_aarch64_sve_fcadd>; + defm FCMLA_ZPmZZ : sve_fp_fcmla<"fcmla", int_aarch64_sve_fcmla>; + + defm FMLA_ZPmZZ : sve_fp_3op_p_zds_a<0b00, "fmla", "FMLA_ZPZZZ", int_aarch64_sve_fmla, "FMAD_ZPmZZ", 1>; + defm FMLS_ZPmZZ : sve_fp_3op_p_zds_a<0b01, "fmls", "FMLS_ZPZZZ", int_aarch64_sve_fmls, "FMSB_ZPmZZ", 1>; + defm FNMLA_ZPmZZ : sve_fp_3op_p_zds_a<0b10, "fnmla", "FNMLA_ZPZZZ", int_aarch64_sve_fnmla, "FNMAD_ZPmZZ", 1>; + defm FNMLS_ZPmZZ : sve_fp_3op_p_zds_a<0b11, "fnmls", "FNMLS_ZPZZZ", int_aarch64_sve_fnmls, "FNMSB_ZPmZZ", 1>; + + defm FMAD_ZPmZZ : sve_fp_3op_p_zds_b<0b00, "fmad", int_aarch64_sve_fmad, "FMLA_ZPmZZ", 0>; + defm FMSB_ZPmZZ : sve_fp_3op_p_zds_b<0b01, "fmsb", int_aarch64_sve_fmsb, "FMLS_ZPmZZ", 0>; + defm FNMAD_ZPmZZ : sve_fp_3op_p_zds_b<0b10, "fnmad", int_aarch64_sve_fnmad, 
"FNMLA_ZPmZZ", 0>; + defm FNMSB_ZPmZZ : sve_fp_3op_p_zds_b<0b11, "fnmsb", int_aarch64_sve_fnmsb, "FNMLS_ZPmZZ", 0>; + + defm FMLA_ZPZZZ : sve_fp_3op_p_zds_zx; + defm FMLS_ZPZZZ : sve_fp_3op_p_zds_zx; + defm FNMLA_ZPZZZ : sve_fp_3op_p_zds_zx; + defm FNMLS_ZPZZZ : sve_fp_3op_p_zds_zx; + + defm FTMAD_ZZI : sve_fp_ftmad<"ftmad", int_aarch64_sve_ftmad_x>; + + defm FMLA_ZZZI : sve_fp_fma_by_indexed_elem<0b0, "fmla", int_aarch64_sve_fmla_lane>; + defm FMLS_ZZZI : sve_fp_fma_by_indexed_elem<0b1, "fmls", int_aarch64_sve_fmls_lane>; + + defm FCMLA_ZZZI : sve_fp_fcmla_by_indexed_elem<"fcmla", int_aarch64_sve_fcmla_lane>; + defm FMUL_ZZZI : sve_fp_fmul_by_indexed_elem<"fmul", int_aarch64_sve_fmul_lane>; // SVE floating point reductions. - defm FADDA_VPZ : sve_fp_2op_p_vd<0b000, "fadda">; - defm FADDV_VPZ : sve_fp_fast_red<0b000, "faddv">; - defm FMAXNMV_VPZ : sve_fp_fast_red<0b100, "fmaxnmv">; - defm FMINNMV_VPZ : sve_fp_fast_red<0b101, "fminnmv">; - defm FMAXV_VPZ : sve_fp_fast_red<0b110, "fmaxv">; - defm FMINV_VPZ : sve_fp_fast_red<0b111, "fminv">; + defm FADDA_VPZ : sve_fp_2op_p_vd<0b000, "fadda", AArch64fadda_pred>; + defm FADDV_VPZ : sve_fp_fast_red<0b000, "faddv", AArch64faddv_pred>; + defm FMAXNMV_VPZ : sve_fp_fast_red<0b100, "fmaxnmv", AArch64fmaxnmv_pred>; + defm FMINNMV_VPZ : sve_fp_fast_red<0b101, "fminnmv", AArch64fminnmv_pred>; + defm FMAXV_VPZ : sve_fp_fast_red<0b110, "fmaxv", AArch64fmaxv_pred>; + defm FMINV_VPZ : sve_fp_fast_red<0b111, "fminv", AArch64fminv_pred>; // Splat immediate (unpredicated) defm DUP_ZI : sve_int_dup_imm<"dup">; @@ -192,33 +452,33 @@ defm DUP_ZZI : sve_int_perm_dup_i<"dup">; // Splat scalar register (predicated) - defm CPY_ZPmR : sve_int_perm_cpy_r<"cpy">; - defm CPY_ZPmV : sve_int_perm_cpy_v<"cpy">; + defm CPY_ZPmR : sve_int_perm_cpy_r<"cpy", AArch64dup_pred>; + defm CPY_ZPmV : sve_int_perm_cpy_v<"cpy", AArch64dup_pred>; // Select elements from either vector (predicated) - defm SEL_ZPZZ : sve_int_sel_vvv<"sel">; + defm SEL_ZPZZ : sve_int_sel_vvv<"sel", vselect>; - defm SPLICE_ZPZ : sve_int_perm_splice<"splice">; - defm COMPACT_ZPZ : sve_int_perm_compact<"compact">; - defm INSR_ZR : sve_int_perm_insrs<"insr">; - defm INSR_ZV : sve_int_perm_insrv<"insr">; - def EXT_ZZI : sve_int_perm_extract_i<"ext">; + defm SPLICE_ZPZ : sve_int_perm_splice<"splice", int_aarch64_sve_splice>; + defm COMPACT_ZPZ : sve_int_perm_compact<"compact", int_aarch64_sve_compact>; + defm INSR_ZR : sve_int_perm_insrs<"insr", AArch64insr>; + defm INSR_ZV : sve_int_perm_insrv<"insr", AArch64insr>; + def EXT_ZZI : sve_int_perm_extract_i<"ext", AArch64ext>; - defm RBIT_ZPmZ : sve_int_perm_rev_rbit<"rbit">; - defm REVB_ZPmZ : sve_int_perm_rev_revb<"revb">; - defm REVH_ZPmZ : sve_int_perm_rev_revh<"revh">; - defm REVW_ZPmZ : sve_int_perm_rev_revw<"revw">; + defm RBIT_ZPmZ : sve_int_perm_rev_rbit<"rbit", int_aarch64_sve_rbit>; + defm REVB_ZPmZ : sve_int_perm_rev_revb<"revb", int_aarch64_sve_revb, bswap>; + defm REVH_ZPmZ : sve_int_perm_rev_revh<"revh", int_aarch64_sve_revh>; + defm REVW_ZPmZ : sve_int_perm_rev_revw<"revw", int_aarch64_sve_revw>; - defm REV_PP : sve_int_perm_reverse_p<"rev">; - defm REV_ZZ : sve_int_perm_reverse_z<"rev">; + defm REV_PP : sve_int_perm_reverse_p<"rev", AArch64rev>; + defm REV_ZZ : sve_int_perm_reverse_z<"rev", AArch64rev>; - defm SUNPKLO_ZZ : sve_int_perm_unpk<0b00, "sunpklo">; - defm SUNPKHI_ZZ : sve_int_perm_unpk<0b01, "sunpkhi">; - defm UUNPKLO_ZZ : sve_int_perm_unpk<0b10, "uunpklo">; - defm UUNPKHI_ZZ : sve_int_perm_unpk<0b11, "uunpkhi">; + defm SUNPKLO_ZZ : 
sve_int_perm_unpk<0b00, "sunpklo", AArch64sunpklo>; + defm SUNPKHI_ZZ : sve_int_perm_unpk<0b01, "sunpkhi", AArch64sunpkhi>; + defm UUNPKLO_ZZ : sve_int_perm_unpk<0b10, "uunpklo", AArch64uunpklo>; + defm UUNPKHI_ZZ : sve_int_perm_unpk<0b11, "uunpkhi", AArch64uunpkhi>; - def PUNPKLO_PP : sve_int_perm_punpk<0b0, "punpklo">; - def PUNPKHI_PP : sve_int_perm_punpk<0b1, "punpkhi">; + defm PUNPKLO_PP : sve_int_perm_punpk<0b0, "punpklo", int_aarch64_sve_punpklo>; + defm PUNPKHI_PP : sve_int_perm_punpk<0b1, "punpkhi", int_aarch64_sve_punpkhi>; defm MOVPRFX_ZPzZ : sve_int_movprfx_pred_zero<0b000, "movprfx">; defm MOVPRFX_ZPmZ : sve_int_movprfx_pred_merge<0b001, "movprfx">; @@ -227,53 +487,53 @@ def FEXPA_ZZ_S : sve_int_bin_cons_misc_0_c<0b10000000, "fexpa", ZPR32>; def FEXPA_ZZ_D : sve_int_bin_cons_misc_0_c<0b11000000, "fexpa", ZPR64>; - def BRKPA_PPzPP : sve_int_brkp<0b00, "brkpa">; - def BRKPAS_PPzPP : sve_int_brkp<0b10, "brkpas">; - def BRKPB_PPzPP : sve_int_brkp<0b01, "brkpb">; - def BRKPBS_PPzPP : sve_int_brkp<0b11, "brkpbs">; + defm BRKPA_PPzPP : sve_int_brkp<0b00, "brkpa", int_aarch64_sve_brkpa_z>; + defm BRKPAS_PPzPP : sve_int_brkp<0b10, "brkpas", null_frag>; + defm BRKPB_PPzPP : sve_int_brkp<0b01, "brkpb", int_aarch64_sve_brkpb_z>; + defm BRKPBS_PPzPP : sve_int_brkp<0b11, "brkpbs", null_frag>; - def BRKN_PPzP : sve_int_brkn<0b0, "brkn">; - def BRKNS_PPzP : sve_int_brkn<0b1, "brkns">; + defm BRKN_PPzP : sve_int_brkn<0b0, "brkn", int_aarch64_sve_brkn_z>; + defm BRKNS_PPzP : sve_int_brkn<0b1, "brkns", null_frag>; - defm BRKA_PPzP : sve_int_break_z<0b000, "brka">; - defm BRKA_PPmP : sve_int_break_m<0b001, "brka">; - defm BRKAS_PPzP : sve_int_break_z<0b010, "brkas">; - defm BRKB_PPzP : sve_int_break_z<0b100, "brkb">; - defm BRKB_PPmP : sve_int_break_m<0b101, "brkb">; - defm BRKBS_PPzP : sve_int_break_z<0b110, "brkbs">; + defm BRKA_PPzP : sve_int_break_z<0b000, "brka", int_aarch64_sve_brka_z>; + defm BRKA_PPmP : sve_int_break_m<0b001, "brka", int_aarch64_sve_brka>; + defm BRKAS_PPzP : sve_int_break_z<0b010, "brkas", null_frag>; + defm BRKB_PPzP : sve_int_break_z<0b100, "brkb", int_aarch64_sve_brkb_z>; + defm BRKB_PPmP : sve_int_break_m<0b101, "brkb", int_aarch64_sve_brkb>; + defm BRKBS_PPzP : sve_int_break_z<0b110, "brkbs", null_frag>; def PTEST_PP : sve_int_ptest<0b010000, "ptest">; def PFALSE : sve_int_pfalse<0b000000, "pfalse">; - defm PFIRST : sve_int_pfirst<0b00000, "pfirst">; - defm PNEXT : sve_int_pnext<0b00110, "pnext">; + defm PFIRST : sve_int_pfirst<0b00000, "pfirst", int_aarch64_sve_pfirst>; + defm PNEXT : sve_int_pnext<0b00110, "pnext", int_aarch64_sve_pnext>; - def AND_PPzPP : sve_int_pred_log<0b0000, "and">; - def BIC_PPzPP : sve_int_pred_log<0b0001, "bic">; - def EOR_PPzPP : sve_int_pred_log<0b0010, "eor">; - def SEL_PPPP : sve_int_pred_log<0b0011, "sel">; - def ANDS_PPzPP : sve_int_pred_log<0b0100, "ands">; - def BICS_PPzPP : sve_int_pred_log<0b0101, "bics">; - def EORS_PPzPP : sve_int_pred_log<0b0110, "eors">; - def ORR_PPzPP : sve_int_pred_log<0b1000, "orr">; - def ORN_PPzPP : sve_int_pred_log<0b1001, "orn">; - def NOR_PPzPP : sve_int_pred_log<0b1010, "nor">; - def NAND_PPzPP : sve_int_pred_log<0b1011, "nand">; - def ORRS_PPzPP : sve_int_pred_log<0b1100, "orrs">; - def ORNS_PPzPP : sve_int_pred_log<0b1101, "orns">; - def NORS_PPzPP : sve_int_pred_log<0b1110, "nors">; - def NANDS_PPzPP : sve_int_pred_log<0b1111, "nands">; + defm AND_PPzPP : sve_int_pred_log<0b0000, "and", int_aarch64_sve_and_z>; + defm BIC_PPzPP : sve_int_pred_log<0b0001, "bic", int_aarch64_sve_bic_z>; + defm 
EOR_PPzPP : sve_int_pred_log<0b0010, "eor", int_aarch64_sve_eor_z>; + defm SEL_PPPP : sve_int_pred_log<0b0011, "sel", vselect>; + defm ANDS_PPzPP : sve_int_pred_log<0b0100, "ands">; + defm BICS_PPzPP : sve_int_pred_log<0b0101, "bics">; + defm EORS_PPzPP : sve_int_pred_log<0b0110, "eors">; + defm ORR_PPzPP : sve_int_pred_log<0b1000, "orr", int_aarch64_sve_orr_z>; + defm ORN_PPzPP : sve_int_pred_log<0b1001, "orn", int_aarch64_sve_orn_z>; + defm NOR_PPzPP : sve_int_pred_log<0b1010, "nor", int_aarch64_sve_nor_z>; + defm NAND_PPzPP : sve_int_pred_log<0b1011, "nand", int_aarch64_sve_nand_z>; + defm ORRS_PPzPP : sve_int_pred_log<0b1100, "orrs">; + defm ORNS_PPzPP : sve_int_pred_log<0b1101, "orns">; + defm NORS_PPzPP : sve_int_pred_log<0b1110, "nors">; + defm NANDS_PPzPP : sve_int_pred_log<0b1111, "nands">; - defm CLASTA_RPZ : sve_int_perm_clast_rz<0, "clasta">; - defm CLASTB_RPZ : sve_int_perm_clast_rz<1, "clastb">; - defm CLASTA_VPZ : sve_int_perm_clast_vz<0, "clasta">; - defm CLASTB_VPZ : sve_int_perm_clast_vz<1, "clastb">; - defm CLASTA_ZPZ : sve_int_perm_clast_zz<0, "clasta">; - defm CLASTB_ZPZ : sve_int_perm_clast_zz<1, "clastb">; + defm CLASTA_RPZ : sve_int_perm_clast_rz<0, "clasta", AArch64clasta_n>; + defm CLASTB_RPZ : sve_int_perm_clast_rz<1, "clastb", AArch64clastb_n>; + defm CLASTA_VPZ : sve_int_perm_clast_vz<0, "clasta", AArch64clasta_n>; + defm CLASTB_VPZ : sve_int_perm_clast_vz<1, "clastb", AArch64clastb_n>; + defm CLASTA_ZPZ : sve_int_perm_clast_zz<0, "clasta", int_aarch64_sve_clasta>; + defm CLASTB_ZPZ : sve_int_perm_clast_zz<1, "clastb", int_aarch64_sve_clastb>; - defm LASTA_RPZ : sve_int_perm_last_r<0, "lasta">; - defm LASTB_RPZ : sve_int_perm_last_r<1, "lastb">; - defm LASTA_VPZ : sve_int_perm_last_v<0, "lasta">; - defm LASTB_VPZ : sve_int_perm_last_v<1, "lastb">; + defm LASTA_RPZ : sve_int_perm_last_r<0, "lasta", AArch64lasta>; + defm LASTB_RPZ : sve_int_perm_last_r<1, "lastb", AArch64lastb>; + defm LASTA_VPZ : sve_int_perm_last_v<0, "lasta", AArch64lasta>; + defm LASTB_VPZ : sve_int_perm_last_v<1, "lastb", AArch64lastb>; // continuous load with reg+immediate defm LD1B_IMM : sve_mem_cld_si<0b0000, "ld1b", Z_b, ZPR8>; @@ -405,115 +665,115 @@ // Gathers using unscaled 32-bit offsets, e.g. 
// ld1h z0.s, p0/z, [x0, z0.s, uxtw] - defm GLD1SB_S : sve_mem_32b_gld_vs_32_unscaled<0b0000, "ld1sb", ZPR32ExtSXTW8Only, ZPR32ExtUXTW8Only>; - defm GLDFF1SB_S : sve_mem_32b_gld_vs_32_unscaled<0b0001, "ldff1sb", ZPR32ExtSXTW8Only, ZPR32ExtUXTW8Only>; - defm GLD1B_S : sve_mem_32b_gld_vs_32_unscaled<0b0010, "ld1b", ZPR32ExtSXTW8Only, ZPR32ExtUXTW8Only>; - defm GLDFF1B_S : sve_mem_32b_gld_vs_32_unscaled<0b0011, "ldff1b", ZPR32ExtSXTW8Only, ZPR32ExtUXTW8Only>; - defm GLD1SH_S : sve_mem_32b_gld_vs_32_unscaled<0b0100, "ld1sh", ZPR32ExtSXTW8, ZPR32ExtUXTW8>; - defm GLDFF1SH_S : sve_mem_32b_gld_vs_32_unscaled<0b0101, "ldff1sh", ZPR32ExtSXTW8, ZPR32ExtUXTW8>; - defm GLD1H_S : sve_mem_32b_gld_vs_32_unscaled<0b0110, "ld1h", ZPR32ExtSXTW8, ZPR32ExtUXTW8>; - defm GLDFF1H_S : sve_mem_32b_gld_vs_32_unscaled<0b0111, "ldff1h", ZPR32ExtSXTW8, ZPR32ExtUXTW8>; - defm GLD1W : sve_mem_32b_gld_vs_32_unscaled<0b1010, "ld1w", ZPR32ExtSXTW8, ZPR32ExtUXTW8>; - defm GLDFF1W : sve_mem_32b_gld_vs_32_unscaled<0b1011, "ldff1w", ZPR32ExtSXTW8, ZPR32ExtUXTW8>; + defm GLD1SB_S : sve_mem_32b_gld_vs_32_unscaled<0b0000, "ld1sb", null_frag, null_frag, ZPR32ExtSXTW8Only, ZPR32ExtUXTW8Only, nxv4i8>; + defm GLDFF1SB_S : sve_mem_32b_gld_vs_32_unscaled<0b0001, "ldff1sb", AArch64ldff1s_gather_sxtw, AArch64ldff1s_gather_uxtw, ZPR32ExtSXTW8Only, ZPR32ExtUXTW8Only, nxv4i8>; + defm GLD1B_S : sve_mem_32b_gld_vs_32_unscaled<0b0010, "ld1b", null_frag, null_frag, ZPR32ExtSXTW8Only, ZPR32ExtUXTW8Only, nxv4i8>; + defm GLDFF1B_S : sve_mem_32b_gld_vs_32_unscaled<0b0011, "ldff1b", AArch64ldff1_gather_sxtw, AArch64ldff1_gather_uxtw, ZPR32ExtSXTW8Only, ZPR32ExtUXTW8Only, nxv4i8>; + defm GLD1SH_S : sve_mem_32b_gld_vs_32_unscaled<0b0100, "ld1sh", null_frag, null_frag, ZPR32ExtSXTW8, ZPR32ExtUXTW8, nxv4i16>; + defm GLDFF1SH_S : sve_mem_32b_gld_vs_32_unscaled<0b0101, "ldff1sh", AArch64ldff1s_gather_sxtw, AArch64ldff1s_gather_uxtw, ZPR32ExtSXTW8, ZPR32ExtUXTW8, nxv4i16>; + defm GLD1H_S : sve_mem_32b_gld_vs_32_unscaled<0b0110, "ld1h", null_frag, null_frag, ZPR32ExtSXTW8, ZPR32ExtUXTW8, nxv4i16>; + defm GLDFF1H_S : sve_mem_32b_gld_vs_32_unscaled<0b0111, "ldff1h", AArch64ldff1_gather_sxtw, AArch64ldff1_gather_uxtw, ZPR32ExtSXTW8, ZPR32ExtUXTW8, nxv4i16>; + defm GLD1W : sve_mem_32b_gld_vs_32_unscaled<0b1010, "ld1w", null_frag, null_frag, ZPR32ExtSXTW8, ZPR32ExtUXTW8, nxv4i32>; + defm GLDFF1W : sve_mem_32b_gld_vs_32_unscaled<0b1011, "ldff1w", AArch64ldff1_gather_sxtw, AArch64ldff1_gather_uxtw, ZPR32ExtSXTW8, ZPR32ExtUXTW8, nxv4i32>; // Gathers using scaled 32-bit offsets, e.g. 
// ld1h z0.s, p0/z, [x0, z0.s, uxtw #1] - defm GLD1SH_S : sve_mem_32b_gld_sv_32_scaled<0b0100, "ld1sh", ZPR32ExtSXTW16, ZPR32ExtUXTW16>; - defm GLDFF1SH_S : sve_mem_32b_gld_sv_32_scaled<0b0101, "ldff1sh", ZPR32ExtSXTW16, ZPR32ExtUXTW16>; - defm GLD1H_S : sve_mem_32b_gld_sv_32_scaled<0b0110, "ld1h", ZPR32ExtSXTW16, ZPR32ExtUXTW16>; - defm GLDFF1H_S : sve_mem_32b_gld_sv_32_scaled<0b0111, "ldff1h", ZPR32ExtSXTW16, ZPR32ExtUXTW16>; - defm GLD1W : sve_mem_32b_gld_sv_32_scaled<0b1010, "ld1w", ZPR32ExtSXTW32, ZPR32ExtUXTW32>; - defm GLDFF1W : sve_mem_32b_gld_sv_32_scaled<0b1011, "ldff1w", ZPR32ExtSXTW32, ZPR32ExtUXTW32>; + defm GLD1SH_S : sve_mem_32b_gld_sv_32_scaled<0b0100, "ld1sh", null_frag, null_frag, ZPR32ExtSXTW16, ZPR32ExtUXTW16, nxv4i16>; + defm GLDFF1SH_S : sve_mem_32b_gld_sv_32_scaled<0b0101, "ldff1sh", AArch64ldff1s_gather_sxtw_scaled, AArch64ldff1s_gather_uxtw_scaled, ZPR32ExtSXTW16, ZPR32ExtUXTW16, nxv4i16>; + defm GLD1H_S : sve_mem_32b_gld_sv_32_scaled<0b0110, "ld1h", null_frag, null_frag, ZPR32ExtSXTW16, ZPR32ExtUXTW16, nxv4i16>; + defm GLDFF1H_S : sve_mem_32b_gld_sv_32_scaled<0b0111, "ldff1h", AArch64ldff1_gather_sxtw_scaled, AArch64ldff1_gather_uxtw_scaled, ZPR32ExtSXTW16, ZPR32ExtUXTW16, nxv4i16>; + defm GLD1W : sve_mem_32b_gld_sv_32_scaled<0b1010, "ld1w", null_frag, null_frag, ZPR32ExtSXTW32, ZPR32ExtUXTW32, nxv4i32>; + defm GLDFF1W : sve_mem_32b_gld_sv_32_scaled<0b1011, "ldff1w", AArch64ldff1_gather_sxtw_scaled, AArch64ldff1_gather_uxtw_scaled, ZPR32ExtSXTW32, ZPR32ExtUXTW32, nxv4i32>; // Gathers using scaled 32-bit pointers with offset, e.g. // ld1h z0.s, p0/z, [z0.s, #16] - defm GLD1SB_S : sve_mem_32b_gld_vi_32_ptrs<0b0000, "ld1sb", imm0_31>; - defm GLDFF1SB_S : sve_mem_32b_gld_vi_32_ptrs<0b0001, "ldff1sb", imm0_31>; - defm GLD1B_S : sve_mem_32b_gld_vi_32_ptrs<0b0010, "ld1b", imm0_31>; - defm GLDFF1B_S : sve_mem_32b_gld_vi_32_ptrs<0b0011, "ldff1b", imm0_31>; - defm GLD1SH_S : sve_mem_32b_gld_vi_32_ptrs<0b0100, "ld1sh", uimm5s2>; - defm GLDFF1SH_S : sve_mem_32b_gld_vi_32_ptrs<0b0101, "ldff1sh", uimm5s2>; - defm GLD1H_S : sve_mem_32b_gld_vi_32_ptrs<0b0110, "ld1h", uimm5s2>; - defm GLDFF1H_S : sve_mem_32b_gld_vi_32_ptrs<0b0111, "ldff1h", uimm5s2>; - defm GLD1W : sve_mem_32b_gld_vi_32_ptrs<0b1010, "ld1w", uimm5s4>; - defm GLDFF1W : sve_mem_32b_gld_vi_32_ptrs<0b1011, "ldff1w", uimm5s4>; + defm GLD1SB_S : sve_mem_32b_gld_vi_32_ptrs<0b0000, "ld1sb", imm0_31, null_frag, nxv4i8>; + defm GLDFF1SB_S : sve_mem_32b_gld_vi_32_ptrs<0b0001, "ldff1sb", imm0_31, AArch64ldff1s_gather_uxtw, nxv4i8>; + defm GLD1B_S : sve_mem_32b_gld_vi_32_ptrs<0b0010, "ld1b", imm0_31, null_frag, nxv4i8>; + defm GLDFF1B_S : sve_mem_32b_gld_vi_32_ptrs<0b0011, "ldff1b", imm0_31, AArch64ldff1_gather_uxtw, nxv4i8>; + defm GLD1SH_S : sve_mem_32b_gld_vi_32_ptrs<0b0100, "ld1sh", uimm5s2, null_frag, nxv4i16>; + defm GLDFF1SH_S : sve_mem_32b_gld_vi_32_ptrs<0b0101, "ldff1sh", uimm5s2, AArch64ldff1s_gather_uxtw, nxv4i16>; + defm GLD1H_S : sve_mem_32b_gld_vi_32_ptrs<0b0110, "ld1h", uimm5s2, null_frag, nxv4i16>; + defm GLDFF1H_S : sve_mem_32b_gld_vi_32_ptrs<0b0111, "ldff1h", uimm5s2, AArch64ldff1_gather_uxtw, nxv4i16>; + defm GLD1W : sve_mem_32b_gld_vi_32_ptrs<0b1010, "ld1w", uimm5s4, null_frag, nxv4i32>; + defm GLDFF1W : sve_mem_32b_gld_vi_32_ptrs<0b1011, "ldff1w", uimm5s4, AArch64ldff1_gather_uxtw, nxv4i32>; // Gathers using scaled 64-bit pointers with offset, e.g. 
// ld1h z0.d, p0/z, [z0.d, #16] - defm GLD1SB_D : sve_mem_64b_gld_vi_64_ptrs<0b0000, "ld1sb", imm0_31>; - defm GLDFF1SB_D : sve_mem_64b_gld_vi_64_ptrs<0b0001, "ldff1sb", imm0_31>; - defm GLD1B_D : sve_mem_64b_gld_vi_64_ptrs<0b0010, "ld1b", imm0_31>; - defm GLDFF1B_D : sve_mem_64b_gld_vi_64_ptrs<0b0011, "ldff1b", imm0_31>; - defm GLD1SH_D : sve_mem_64b_gld_vi_64_ptrs<0b0100, "ld1sh", uimm5s2>; - defm GLDFF1SH_D : sve_mem_64b_gld_vi_64_ptrs<0b0101, "ldff1sh", uimm5s2>; - defm GLD1H_D : sve_mem_64b_gld_vi_64_ptrs<0b0110, "ld1h", uimm5s2>; - defm GLDFF1H_D : sve_mem_64b_gld_vi_64_ptrs<0b0111, "ldff1h", uimm5s2>; - defm GLD1SW_D : sve_mem_64b_gld_vi_64_ptrs<0b1000, "ld1sw", uimm5s4>; - defm GLDFF1SW_D : sve_mem_64b_gld_vi_64_ptrs<0b1001, "ldff1sw", uimm5s4>; - defm GLD1W_D : sve_mem_64b_gld_vi_64_ptrs<0b1010, "ld1w", uimm5s4>; - defm GLDFF1W_D : sve_mem_64b_gld_vi_64_ptrs<0b1011, "ldff1w", uimm5s4>; - defm GLD1D : sve_mem_64b_gld_vi_64_ptrs<0b1110, "ld1d", uimm5s8>; - defm GLDFF1D : sve_mem_64b_gld_vi_64_ptrs<0b1111, "ldff1d", uimm5s8>; + defm GLD1SB_D : sve_mem_64b_gld_vi_64_ptrs<0b0000, "ld1sb", imm0_31, null_frag, nxv2i8>; + defm GLDFF1SB_D : sve_mem_64b_gld_vi_64_ptrs<0b0001, "ldff1sb", imm0_31, AArch64ldff1s_gather, nxv2i8>; + defm GLD1B_D : sve_mem_64b_gld_vi_64_ptrs<0b0010, "ld1b", imm0_31, null_frag, nxv2i8>; + defm GLDFF1B_D : sve_mem_64b_gld_vi_64_ptrs<0b0011, "ldff1b", imm0_31, AArch64ldff1_gather, nxv2i8>; + defm GLD1SH_D : sve_mem_64b_gld_vi_64_ptrs<0b0100, "ld1sh", uimm5s2, null_frag, nxv2i16>; + defm GLDFF1SH_D : sve_mem_64b_gld_vi_64_ptrs<0b0101, "ldff1sh", uimm5s2, AArch64ldff1s_gather, nxv2i16>; + defm GLD1H_D : sve_mem_64b_gld_vi_64_ptrs<0b0110, "ld1h", uimm5s2, null_frag, nxv2i16>; + defm GLDFF1H_D : sve_mem_64b_gld_vi_64_ptrs<0b0111, "ldff1h", uimm5s2, AArch64ldff1_gather, nxv2i16>; + defm GLD1SW_D : sve_mem_64b_gld_vi_64_ptrs<0b1000, "ld1sw", uimm5s4, null_frag, nxv2i32>; + defm GLDFF1SW_D : sve_mem_64b_gld_vi_64_ptrs<0b1001, "ldff1sw", uimm5s4, AArch64ldff1s_gather, nxv2i32>; + defm GLD1W_D : sve_mem_64b_gld_vi_64_ptrs<0b1010, "ld1w", uimm5s4, null_frag, nxv2i32>; + defm GLDFF1W_D : sve_mem_64b_gld_vi_64_ptrs<0b1011, "ldff1w", uimm5s4, AArch64ldff1_gather, nxv2i32>; + defm GLD1D : sve_mem_64b_gld_vi_64_ptrs<0b1110, "ld1d", uimm5s8, null_frag, nxv2i64>; + defm GLDFF1D : sve_mem_64b_gld_vi_64_ptrs<0b1111, "ldff1d", uimm5s8, AArch64ldff1_gather, nxv2i64>; // Gathers using unscaled 64-bit offsets, e.g. 
// ld1h z0.d, p0/z, [x0, z0.d] - defm GLD1SB_D : sve_mem_64b_gld_vs2_64_unscaled<0b0000, "ld1sb">; - defm GLDFF1SB_D : sve_mem_64b_gld_vs2_64_unscaled<0b0001, "ldff1sb">; - defm GLD1B_D : sve_mem_64b_gld_vs2_64_unscaled<0b0010, "ld1b">; - defm GLDFF1B_D : sve_mem_64b_gld_vs2_64_unscaled<0b0011, "ldff1b">; - defm GLD1SH_D : sve_mem_64b_gld_vs2_64_unscaled<0b0100, "ld1sh">; - defm GLDFF1SH_D : sve_mem_64b_gld_vs2_64_unscaled<0b0101, "ldff1sh">; - defm GLD1H_D : sve_mem_64b_gld_vs2_64_unscaled<0b0110, "ld1h">; - defm GLDFF1H_D : sve_mem_64b_gld_vs2_64_unscaled<0b0111, "ldff1h">; - defm GLD1SW_D : sve_mem_64b_gld_vs2_64_unscaled<0b1000, "ld1sw">; - defm GLDFF1SW_D : sve_mem_64b_gld_vs2_64_unscaled<0b1001, "ldff1sw">; - defm GLD1W_D : sve_mem_64b_gld_vs2_64_unscaled<0b1010, "ld1w">; - defm GLDFF1W_D : sve_mem_64b_gld_vs2_64_unscaled<0b1011, "ldff1w">; - defm GLD1D : sve_mem_64b_gld_vs2_64_unscaled<0b1110, "ld1d">; - defm GLDFF1D : sve_mem_64b_gld_vs2_64_unscaled<0b1111, "ldff1d">; + defm GLD1SB_D : sve_mem_64b_gld_vs2_64_unscaled<0b0000, "ld1sb", null_frag, nxv2i8>; + defm GLDFF1SB_D : sve_mem_64b_gld_vs2_64_unscaled<0b0001, "ldff1sb", AArch64ldff1s_gather, nxv2i8>; + defm GLD1B_D : sve_mem_64b_gld_vs2_64_unscaled<0b0010, "ld1b", null_frag, nxv2i8>; + defm GLDFF1B_D : sve_mem_64b_gld_vs2_64_unscaled<0b0011, "ldff1b", AArch64ldff1_gather, nxv2i8>; + defm GLD1SH_D : sve_mem_64b_gld_vs2_64_unscaled<0b0100, "ld1sh", null_frag, nxv2i16>; + defm GLDFF1SH_D : sve_mem_64b_gld_vs2_64_unscaled<0b0101, "ldff1sh", AArch64ldff1s_gather, nxv2i16>; + defm GLD1H_D : sve_mem_64b_gld_vs2_64_unscaled<0b0110, "ld1h", null_frag, nxv2i16>; + defm GLDFF1H_D : sve_mem_64b_gld_vs2_64_unscaled<0b0111, "ldff1h", AArch64ldff1_gather, nxv2i16>; + defm GLD1SW_D : sve_mem_64b_gld_vs2_64_unscaled<0b1000, "ld1sw", null_frag, nxv2i32>; + defm GLDFF1SW_D : sve_mem_64b_gld_vs2_64_unscaled<0b1001, "ldff1sw", AArch64ldff1s_gather, nxv2i32>; + defm GLD1W_D : sve_mem_64b_gld_vs2_64_unscaled<0b1010, "ld1w", null_frag, nxv2i32>; + defm GLDFF1W_D : sve_mem_64b_gld_vs2_64_unscaled<0b1011, "ldff1w", AArch64ldff1_gather, nxv2i32>; + defm GLD1D : sve_mem_64b_gld_vs2_64_unscaled<0b1110, "ld1d", null_frag, nxv2i64>; + defm GLDFF1D : sve_mem_64b_gld_vs2_64_unscaled<0b1111, "ldff1d", AArch64ldff1_gather, nxv2i64>; // Gathers using scaled 64-bit offsets, e.g. 
// ld1h z0.d, p0/z, [x0, z0.d, lsl #1] - defm GLD1SH_D : sve_mem_64b_gld_sv2_64_scaled<0b0100, "ld1sh", ZPR64ExtLSL16>; - defm GLDFF1SH_D : sve_mem_64b_gld_sv2_64_scaled<0b0101, "ldff1sh", ZPR64ExtLSL16>; - defm GLD1H_D : sve_mem_64b_gld_sv2_64_scaled<0b0110, "ld1h", ZPR64ExtLSL16>; - defm GLDFF1H_D : sve_mem_64b_gld_sv2_64_scaled<0b0111, "ldff1h", ZPR64ExtLSL16>; - defm GLD1SW_D : sve_mem_64b_gld_sv2_64_scaled<0b1000, "ld1sw", ZPR64ExtLSL32>; - defm GLDFF1SW_D : sve_mem_64b_gld_sv2_64_scaled<0b1001, "ldff1sw", ZPR64ExtLSL32>; - defm GLD1W_D : sve_mem_64b_gld_sv2_64_scaled<0b1010, "ld1w", ZPR64ExtLSL32>; - defm GLDFF1W_D : sve_mem_64b_gld_sv2_64_scaled<0b1011, "ldff1w", ZPR64ExtLSL32>; - defm GLD1D : sve_mem_64b_gld_sv2_64_scaled<0b1110, "ld1d", ZPR64ExtLSL64>; - defm GLDFF1D : sve_mem_64b_gld_sv2_64_scaled<0b1111, "ldff1d", ZPR64ExtLSL64>; + defm GLD1SH_D : sve_mem_64b_gld_sv2_64_scaled<0b0100, "ld1sh", null_frag, ZPR64ExtLSL16, nxv2i16>; + defm GLDFF1SH_D : sve_mem_64b_gld_sv2_64_scaled<0b0101, "ldff1sh", AArch64ldff1s_gather_scaled, ZPR64ExtLSL16, nxv2i16>; + defm GLD1H_D : sve_mem_64b_gld_sv2_64_scaled<0b0110, "ld1h", null_frag, ZPR64ExtLSL16, nxv2i16>; + defm GLDFF1H_D : sve_mem_64b_gld_sv2_64_scaled<0b0111, "ldff1h", AArch64ldff1_gather_scaled, ZPR64ExtLSL16, nxv2i16>; + defm GLD1SW_D : sve_mem_64b_gld_sv2_64_scaled<0b1000, "ld1sw", null_frag, ZPR64ExtLSL32, nxv2i32>; + defm GLDFF1SW_D : sve_mem_64b_gld_sv2_64_scaled<0b1001, "ldff1sw", AArch64ldff1s_gather_scaled, ZPR64ExtLSL32, nxv2i32>; + defm GLD1W_D : sve_mem_64b_gld_sv2_64_scaled<0b1010, "ld1w", null_frag, ZPR64ExtLSL32, nxv2i32>; + defm GLDFF1W_D : sve_mem_64b_gld_sv2_64_scaled<0b1011, "ldff1w", AArch64ldff1_gather_scaled, ZPR64ExtLSL32, nxv2i32>; + defm GLD1D : sve_mem_64b_gld_sv2_64_scaled<0b1110, "ld1d", null_frag, ZPR64ExtLSL64, nxv2i64>; + defm GLDFF1D : sve_mem_64b_gld_sv2_64_scaled<0b1111, "ldff1d", AArch64ldff1_gather_scaled, ZPR64ExtLSL64, nxv2i64>; // Gathers using unscaled 32-bit offsets unpacked in 64-bits elements, e.g. 
// ld1h z0.d, p0/z, [x0, z0.d, uxtw] - defm GLD1SB_D : sve_mem_64b_gld_vs_32_unscaled<0b0000, "ld1sb", ZPR64ExtSXTW8Only, ZPR64ExtUXTW8Only>; - defm GLDFF1SB_D : sve_mem_64b_gld_vs_32_unscaled<0b0001, "ldff1sb", ZPR64ExtSXTW8Only, ZPR64ExtUXTW8Only>; - defm GLD1B_D : sve_mem_64b_gld_vs_32_unscaled<0b0010, "ld1b", ZPR64ExtSXTW8Only, ZPR64ExtUXTW8Only>; - defm GLDFF1B_D : sve_mem_64b_gld_vs_32_unscaled<0b0011, "ldff1b", ZPR64ExtSXTW8Only, ZPR64ExtUXTW8Only>; - defm GLD1SH_D : sve_mem_64b_gld_vs_32_unscaled<0b0100, "ld1sh", ZPR64ExtSXTW8, ZPR64ExtUXTW8>; - defm GLDFF1SH_D : sve_mem_64b_gld_vs_32_unscaled<0b0101, "ldff1sh", ZPR64ExtSXTW8, ZPR64ExtUXTW8>; - defm GLD1H_D : sve_mem_64b_gld_vs_32_unscaled<0b0110, "ld1h", ZPR64ExtSXTW8, ZPR64ExtUXTW8>; - defm GLDFF1H_D : sve_mem_64b_gld_vs_32_unscaled<0b0111, "ldff1h", ZPR64ExtSXTW8, ZPR64ExtUXTW8>; - defm GLD1SW_D : sve_mem_64b_gld_vs_32_unscaled<0b1000, "ld1sw", ZPR64ExtSXTW8, ZPR64ExtUXTW8>; - defm GLDFF1SW_D : sve_mem_64b_gld_vs_32_unscaled<0b1001, "ldff1sw", ZPR64ExtSXTW8, ZPR64ExtUXTW8>; - defm GLD1W_D : sve_mem_64b_gld_vs_32_unscaled<0b1010, "ld1w", ZPR64ExtSXTW8, ZPR64ExtUXTW8>; - defm GLDFF1W_D : sve_mem_64b_gld_vs_32_unscaled<0b1011, "ldff1w", ZPR64ExtSXTW8, ZPR64ExtUXTW8>; - defm GLD1D : sve_mem_64b_gld_vs_32_unscaled<0b1110, "ld1d", ZPR64ExtSXTW8, ZPR64ExtUXTW8>; - defm GLDFF1D : sve_mem_64b_gld_vs_32_unscaled<0b1111, "ldff1d", ZPR64ExtSXTW8, ZPR64ExtUXTW8>; + defm GLD1SB_D : sve_mem_64b_gld_vs_32_unscaled<0b0000, "ld1sb", null_frag, ZPR64ExtSXTW8Only, ZPR64ExtUXTW8Only, nxv2i8>; + defm GLDFF1SB_D : sve_mem_64b_gld_vs_32_unscaled<0b0001, "ldff1sb", AArch64ldff1s_gather, ZPR64ExtSXTW8Only, ZPR64ExtUXTW8Only, nxv2i8>; + defm GLD1B_D : sve_mem_64b_gld_vs_32_unscaled<0b0010, "ld1b", null_frag, ZPR64ExtSXTW8Only, ZPR64ExtUXTW8Only, nxv2i8>; + defm GLDFF1B_D : sve_mem_64b_gld_vs_32_unscaled<0b0011, "ldff1b", AArch64ldff1_gather, ZPR64ExtSXTW8Only, ZPR64ExtUXTW8Only, nxv2i8>; + defm GLD1SH_D : sve_mem_64b_gld_vs_32_unscaled<0b0100, "ld1sh", null_frag, ZPR64ExtSXTW8, ZPR64ExtUXTW8, nxv2i16>; + defm GLDFF1SH_D : sve_mem_64b_gld_vs_32_unscaled<0b0101, "ldff1sh", AArch64ldff1s_gather, ZPR64ExtSXTW8, ZPR64ExtUXTW8, nxv2i16>; + defm GLD1H_D : sve_mem_64b_gld_vs_32_unscaled<0b0110, "ld1h", null_frag, ZPR64ExtSXTW8, ZPR64ExtUXTW8, nxv2i16>; + defm GLDFF1H_D : sve_mem_64b_gld_vs_32_unscaled<0b0111, "ldff1h", AArch64ldff1_gather, ZPR64ExtSXTW8, ZPR64ExtUXTW8, nxv2i16>; + defm GLD1SW_D : sve_mem_64b_gld_vs_32_unscaled<0b1000, "ld1sw", null_frag, ZPR64ExtSXTW8, ZPR64ExtUXTW8, nxv2i32>; + defm GLDFF1SW_D : sve_mem_64b_gld_vs_32_unscaled<0b1001, "ldff1sw", AArch64ldff1s_gather, ZPR64ExtSXTW8, ZPR64ExtUXTW8, nxv2i32>; + defm GLD1W_D : sve_mem_64b_gld_vs_32_unscaled<0b1010, "ld1w", null_frag, ZPR64ExtSXTW8, ZPR64ExtUXTW8, nxv2i32>; + defm GLDFF1W_D : sve_mem_64b_gld_vs_32_unscaled<0b1011, "ldff1w", AArch64ldff1_gather, ZPR64ExtSXTW8, ZPR64ExtUXTW8, nxv2i32>; + defm GLD1D : sve_mem_64b_gld_vs_32_unscaled<0b1110, "ld1d", null_frag, ZPR64ExtSXTW8, ZPR64ExtUXTW8, nxv2i64>; + defm GLDFF1D : sve_mem_64b_gld_vs_32_unscaled<0b1111, "ldff1d", AArch64ldff1_gather, ZPR64ExtSXTW8, ZPR64ExtUXTW8, nxv2i64>; // Gathers using scaled 32-bit offsets unpacked in 64-bits elements, e.g. 
// ld1h z0.d, p0/z, [x0, z0.d, uxtw #1] - defm GLD1SH_D : sve_mem_64b_gld_sv_32_scaled<0b0100, "ld1sh", ZPR64ExtSXTW16, ZPR64ExtUXTW16>; - defm GLDFF1SH_D : sve_mem_64b_gld_sv_32_scaled<0b0101, "ldff1sh",ZPR64ExtSXTW16, ZPR64ExtUXTW16>; - defm GLD1H_D : sve_mem_64b_gld_sv_32_scaled<0b0110, "ld1h", ZPR64ExtSXTW16, ZPR64ExtUXTW16>; - defm GLDFF1H_D : sve_mem_64b_gld_sv_32_scaled<0b0111, "ldff1h", ZPR64ExtSXTW16, ZPR64ExtUXTW16>; - defm GLD1SW_D : sve_mem_64b_gld_sv_32_scaled<0b1000, "ld1sw", ZPR64ExtSXTW32, ZPR64ExtUXTW32>; - defm GLDFF1SW_D : sve_mem_64b_gld_sv_32_scaled<0b1001, "ldff1sw",ZPR64ExtSXTW32, ZPR64ExtUXTW32>; - defm GLD1W_D : sve_mem_64b_gld_sv_32_scaled<0b1010, "ld1w", ZPR64ExtSXTW32, ZPR64ExtUXTW32>; - defm GLDFF1W_D : sve_mem_64b_gld_sv_32_scaled<0b1011, "ldff1w", ZPR64ExtSXTW32, ZPR64ExtUXTW32>; - defm GLD1D : sve_mem_64b_gld_sv_32_scaled<0b1110, "ld1d", ZPR64ExtSXTW64, ZPR64ExtUXTW64>; - defm GLDFF1D : sve_mem_64b_gld_sv_32_scaled<0b1111, "ldff1d", ZPR64ExtSXTW64, ZPR64ExtUXTW64>; + defm GLD1SH_D : sve_mem_64b_gld_sv_32_scaled<0b0100, "ld1sh", null_frag, ZPR64ExtSXTW16, ZPR64ExtUXTW16, nxv2i16>; + defm GLDFF1SH_D : sve_mem_64b_gld_sv_32_scaled<0b0101, "ldff1sh",AArch64ldff1s_gather_scaled, ZPR64ExtSXTW16, ZPR64ExtUXTW16, nxv2i16>; + defm GLD1H_D : sve_mem_64b_gld_sv_32_scaled<0b0110, "ld1h", null_frag, ZPR64ExtSXTW16, ZPR64ExtUXTW16, nxv2i16>; + defm GLDFF1H_D : sve_mem_64b_gld_sv_32_scaled<0b0111, "ldff1h", AArch64ldff1_gather_scaled, ZPR64ExtSXTW16, ZPR64ExtUXTW16, nxv2i16>; + defm GLD1SW_D : sve_mem_64b_gld_sv_32_scaled<0b1000, "ld1sw", null_frag, ZPR64ExtSXTW32, ZPR64ExtUXTW32, nxv2i32>; + defm GLDFF1SW_D : sve_mem_64b_gld_sv_32_scaled<0b1001, "ldff1sw",AArch64ldff1s_gather_scaled, ZPR64ExtSXTW32, ZPR64ExtUXTW32, nxv2i32>; + defm GLD1W_D : sve_mem_64b_gld_sv_32_scaled<0b1010, "ld1w", null_frag, ZPR64ExtSXTW32, ZPR64ExtUXTW32, nxv2i32>; + defm GLDFF1W_D : sve_mem_64b_gld_sv_32_scaled<0b1011, "ldff1w", AArch64ldff1_gather_scaled, ZPR64ExtSXTW32, ZPR64ExtUXTW32, nxv2i32>; + defm GLD1D : sve_mem_64b_gld_sv_32_scaled<0b1110, "ld1d", null_frag, ZPR64ExtSXTW64, ZPR64ExtUXTW64, nxv2i64>; + defm GLDFF1D : sve_mem_64b_gld_sv_32_scaled<0b1111, "ldff1d", AArch64ldff1_gather_scaled, ZPR64ExtSXTW64, ZPR64ExtUXTW64, nxv2i64>; // Non-temporal contiguous loads (register + immediate) defm LDNT1B_ZRI : sve_mem_cldnt_si<0b00, "ldnt1b", Z_b, ZPR8>; @@ -657,112 +917,112 @@ // Gather prefetch using scaled 32-bit offsets, e.g. 
// prfh pldl1keep, p0, [x0, z0.s, uxtw #1] - defm PRFB_S : sve_mem_32b_prfm_sv_scaled<0b00, "prfb", ZPR32ExtSXTW8Only, ZPR32ExtUXTW8Only>; - defm PRFH_S : sve_mem_32b_prfm_sv_scaled<0b01, "prfh", ZPR32ExtSXTW16, ZPR32ExtUXTW16>; - defm PRFW_S : sve_mem_32b_prfm_sv_scaled<0b10, "prfw", ZPR32ExtSXTW32, ZPR32ExtUXTW32>; - defm PRFD_S : sve_mem_32b_prfm_sv_scaled<0b11, "prfd", ZPR32ExtSXTW64, ZPR32ExtUXTW64>; + defm PRFB_S : sve_mem_32b_prfm_sv_scaled<0b00, "prfb", AArch64prf_gather_s_sxtw_scaled, AArch64prf_gather_s_uxtw_scaled, ZPR32ExtSXTW8Only, ZPR32ExtUXTW8Only , i8>; + defm PRFH_S : sve_mem_32b_prfm_sv_scaled<0b01, "prfh", AArch64prf_gather_s_sxtw_scaled, AArch64prf_gather_s_uxtw_scaled, ZPR32ExtSXTW16, ZPR32ExtUXTW16, i16>; + defm PRFW_S : sve_mem_32b_prfm_sv_scaled<0b10, "prfw", AArch64prf_gather_s_sxtw_scaled, AArch64prf_gather_s_uxtw_scaled, ZPR32ExtSXTW32, ZPR32ExtUXTW32, i32>; + defm PRFD_S : sve_mem_32b_prfm_sv_scaled<0b11, "prfd", AArch64prf_gather_s_sxtw_scaled, AArch64prf_gather_s_uxtw_scaled, ZPR32ExtSXTW64, ZPR32ExtUXTW64, i64>; // Gather prefetch using unpacked, scaled 32-bit offsets, e.g. // prfh pldl1keep, p0, [x0, z0.d, uxtw #1] - defm PRFB_D : sve_mem_64b_prfm_sv_ext_scaled<0b00, "prfb", ZPR64ExtSXTW8Only, ZPR64ExtUXTW8Only>; - defm PRFH_D : sve_mem_64b_prfm_sv_ext_scaled<0b01, "prfh", ZPR64ExtSXTW16, ZPR64ExtUXTW16>; - defm PRFW_D : sve_mem_64b_prfm_sv_ext_scaled<0b10, "prfw", ZPR64ExtSXTW32, ZPR64ExtUXTW32>; - defm PRFD_D : sve_mem_64b_prfm_sv_ext_scaled<0b11, "prfd", ZPR64ExtSXTW64, ZPR64ExtUXTW64>; + defm PRFB_D : sve_mem_64b_prfm_ext_scaled<0b00, "prfb", AArch64prf_gather_d_sxtw_scaled, AArch64prf_gather_d_uxtw_scaled, ZPR64ExtSXTW8Only, ZPR64ExtUXTW8Only, i8>; + defm PRFH_D : sve_mem_64b_prfm_ext_scaled<0b01, "prfh", AArch64prf_gather_d_sxtw_scaled, AArch64prf_gather_d_uxtw_scaled, ZPR64ExtSXTW16, ZPR64ExtUXTW16, i16>; + defm PRFW_D : sve_mem_64b_prfm_ext_scaled<0b10, "prfw", AArch64prf_gather_d_sxtw_scaled, AArch64prf_gather_d_uxtw_scaled, ZPR64ExtSXTW32, ZPR64ExtUXTW32, i32>; + defm PRFD_D : sve_mem_64b_prfm_ext_scaled<0b11, "prfd", AArch64prf_gather_d_sxtw_scaled, AArch64prf_gather_d_uxtw_scaled, ZPR64ExtSXTW64, ZPR64ExtUXTW64, i64>; // Gather prefetch using scaled 64-bit offsets, e.g. // prfh pldl1keep, p0, [x0, z0.d, lsl #1] - defm PRFB_D_SCALED : sve_mem_64b_prfm_sv_lsl_scaled<0b00, "prfb", ZPR64ExtLSL8>; - defm PRFH_D_SCALED : sve_mem_64b_prfm_sv_lsl_scaled<0b01, "prfh", ZPR64ExtLSL16>; - defm PRFW_D_SCALED : sve_mem_64b_prfm_sv_lsl_scaled<0b10, "prfw", ZPR64ExtLSL32>; - defm PRFD_D_SCALED : sve_mem_64b_prfm_sv_lsl_scaled<0b11, "prfd", ZPR64ExtLSL64>; + defm PRFB_D_SCALED : sve_mem_64b_prfm_sv_lsl_scaled<0b00, "prfb", AArch64prf_gather_d_scaled, ZPR64ExtLSL8, i8>; + defm PRFH_D_SCALED : sve_mem_64b_prfm_sv_lsl_scaled<0b01, "prfh", AArch64prf_gather_d_scaled, ZPR64ExtLSL16, i16>; + defm PRFW_D_SCALED : sve_mem_64b_prfm_sv_lsl_scaled<0b10, "prfw", AArch64prf_gather_d_scaled, ZPR64ExtLSL32, i32>; + defm PRFD_D_SCALED : sve_mem_64b_prfm_sv_lsl_scaled<0b11, "prfd", AArch64prf_gather_d_scaled, ZPR64ExtLSL64, i64>; // Gather prefetch using 32/64-bit pointers with offset, e.g. 
// prfh pldl1keep, p0, [z0.s, #16] // prfh pldl1keep, p0, [z0.d, #16] - defm PRFB_S_PZI : sve_mem_32b_prfm_vi<0b00, "prfb", imm0_31>; - defm PRFH_S_PZI : sve_mem_32b_prfm_vi<0b01, "prfh", uimm5s2>; - defm PRFW_S_PZI : sve_mem_32b_prfm_vi<0b10, "prfw", uimm5s4>; - defm PRFD_S_PZI : sve_mem_32b_prfm_vi<0b11, "prfd", uimm5s8>; + defm PRFB_S_PZI : sve_mem_32b_prfm_vi<0b00, "prfb", imm0_31, AArch64prf_gather_s_imm, i8>; + defm PRFH_S_PZI : sve_mem_32b_prfm_vi<0b01, "prfh", uimm5s2, AArch64prf_gather_s_imm, i16>; + defm PRFW_S_PZI : sve_mem_32b_prfm_vi<0b10, "prfw", uimm5s4, AArch64prf_gather_s_imm, i32>; + defm PRFD_S_PZI : sve_mem_32b_prfm_vi<0b11, "prfd", uimm5s8, AArch64prf_gather_s_imm, i64>; - defm PRFB_D_PZI : sve_mem_64b_prfm_vi<0b00, "prfb", imm0_31>; - defm PRFH_D_PZI : sve_mem_64b_prfm_vi<0b01, "prfh", uimm5s2>; - defm PRFW_D_PZI : sve_mem_64b_prfm_vi<0b10, "prfw", uimm5s4>; - defm PRFD_D_PZI : sve_mem_64b_prfm_vi<0b11, "prfd", uimm5s8>; + defm PRFB_D_PZI : sve_mem_64b_prfm_vi<0b00, "prfb", imm0_31, AArch64prf_gather_d_imm, i8>; + defm PRFH_D_PZI : sve_mem_64b_prfm_vi<0b01, "prfh", uimm5s2, AArch64prf_gather_d_imm, i16>; + defm PRFW_D_PZI : sve_mem_64b_prfm_vi<0b10, "prfw", uimm5s4, AArch64prf_gather_d_imm, i32>; + defm PRFD_D_PZI : sve_mem_64b_prfm_vi<0b11, "prfd", uimm5s8, AArch64prf_gather_d_imm, i64>; defm ADR_SXTW_ZZZ_D : sve_int_bin_cons_misc_0_a_sxtw<0b00, "adr">; defm ADR_UXTW_ZZZ_D : sve_int_bin_cons_misc_0_a_uxtw<0b01, "adr">; defm ADR_LSL_ZZZ_S : sve_int_bin_cons_misc_0_a_32_lsl<0b10, "adr">; defm ADR_LSL_ZZZ_D : sve_int_bin_cons_misc_0_a_64_lsl<0b11, "adr">; - defm TBL_ZZZ : sve_int_perm_tbl<"tbl">; + defm TBL_ZZZ : sve_int_perm_tbl<"tbl", AArch64TBL>; - defm ZIP1_ZZZ : sve_int_perm_bin_perm_zz<0b000, "zip1">; - defm ZIP2_ZZZ : sve_int_perm_bin_perm_zz<0b001, "zip2">; - defm UZP1_ZZZ : sve_int_perm_bin_perm_zz<0b010, "uzp1">; - defm UZP2_ZZZ : sve_int_perm_bin_perm_zz<0b011, "uzp2">; - defm TRN1_ZZZ : sve_int_perm_bin_perm_zz<0b100, "trn1">; - defm TRN2_ZZZ : sve_int_perm_bin_perm_zz<0b101, "trn2">; + defm ZIP1_ZZZ : sve_int_perm_bin_perm_zz<0b000, "zip1", AArch64zip1>; + defm ZIP2_ZZZ : sve_int_perm_bin_perm_zz<0b001, "zip2", AArch64zip2>; + defm UZP1_ZZZ : sve_int_perm_bin_perm_zz<0b010, "uzp1", AArch64uzp1>; + defm UZP2_ZZZ : sve_int_perm_bin_perm_zz<0b011, "uzp2", AArch64uzp2>; + defm TRN1_ZZZ : sve_int_perm_bin_perm_zz<0b100, "trn1", AArch64trn1>; + defm TRN2_ZZZ : sve_int_perm_bin_perm_zz<0b101, "trn2", AArch64trn2>; - defm ZIP1_PPP : sve_int_perm_bin_perm_pp<0b000, "zip1">; - defm ZIP2_PPP : sve_int_perm_bin_perm_pp<0b001, "zip2">; - defm UZP1_PPP : sve_int_perm_bin_perm_pp<0b010, "uzp1">; - defm UZP2_PPP : sve_int_perm_bin_perm_pp<0b011, "uzp2">; - defm TRN1_PPP : sve_int_perm_bin_perm_pp<0b100, "trn1">; - defm TRN2_PPP : sve_int_perm_bin_perm_pp<0b101, "trn2">; + defm ZIP1_PPP : sve_int_perm_bin_perm_pp<0b000, "zip1", AArch64zip1>; + defm ZIP2_PPP : sve_int_perm_bin_perm_pp<0b001, "zip2", AArch64zip2>; + defm UZP1_PPP : sve_int_perm_bin_perm_pp<0b010, "uzp1", AArch64uzp1>; + defm UZP2_PPP : sve_int_perm_bin_perm_pp<0b011, "uzp2", AArch64uzp2>; + defm TRN1_PPP : sve_int_perm_bin_perm_pp<0b100, "trn1", AArch64trn1>; + defm TRN2_PPP : sve_int_perm_bin_perm_pp<0b101, "trn2", AArch64trn2>; - defm CMPHS_PPzZZ : sve_int_cmp_0<0b000, "cmphs">; - defm CMPHI_PPzZZ : sve_int_cmp_0<0b001, "cmphi">; - defm CMPGE_PPzZZ : sve_int_cmp_0<0b100, "cmpge">; - defm CMPGT_PPzZZ : sve_int_cmp_0<0b101, "cmpgt">; - defm CMPEQ_PPzZZ : sve_int_cmp_0<0b110, "cmpeq">; - defm CMPNE_PPzZZ : 
sve_int_cmp_0<0b111, "cmpne">; + defm CMPHS_PPzZZ : sve_int_cmp_0<0b000, "cmphs", int_aarch64_sve_cmphs, SETUGE>; + defm CMPHI_PPzZZ : sve_int_cmp_0<0b001, "cmphi", int_aarch64_sve_cmphi, SETUGT>; + defm CMPGE_PPzZZ : sve_int_cmp_0<0b100, "cmpge", int_aarch64_sve_cmpge, SETGE>; + defm CMPGT_PPzZZ : sve_int_cmp_0<0b101, "cmpgt", int_aarch64_sve_cmpgt, SETGT>; + defm CMPEQ_PPzZZ : sve_int_cmp_0<0b110, "cmpeq", int_aarch64_sve_cmpeq, SETEQ>; + defm CMPNE_PPzZZ : sve_int_cmp_0<0b111, "cmpne", int_aarch64_sve_cmpne, SETNE>; - defm CMPEQ_WIDE_PPzZZ : sve_int_cmp_0_wide<0b010, "cmpeq">; - defm CMPNE_WIDE_PPzZZ : sve_int_cmp_0_wide<0b011, "cmpne">; - defm CMPGE_WIDE_PPzZZ : sve_int_cmp_1_wide<0b000, "cmpge">; - defm CMPGT_WIDE_PPzZZ : sve_int_cmp_1_wide<0b001, "cmpgt">; - defm CMPLT_WIDE_PPzZZ : sve_int_cmp_1_wide<0b010, "cmplt">; - defm CMPLE_WIDE_PPzZZ : sve_int_cmp_1_wide<0b011, "cmple">; - defm CMPHS_WIDE_PPzZZ : sve_int_cmp_1_wide<0b100, "cmphs">; - defm CMPHI_WIDE_PPzZZ : sve_int_cmp_1_wide<0b101, "cmphi">; - defm CMPLO_WIDE_PPzZZ : sve_int_cmp_1_wide<0b110, "cmplo">; - defm CMPLS_WIDE_PPzZZ : sve_int_cmp_1_wide<0b111, "cmpls">; + defm CMPEQ_WIDE_PPzZZ : sve_int_cmp_0_wide<0b010, "cmpeq", int_aarch64_sve_cmpeq_wide>; + defm CMPNE_WIDE_PPzZZ : sve_int_cmp_0_wide<0b011, "cmpne", int_aarch64_sve_cmpne_wide>; + defm CMPGE_WIDE_PPzZZ : sve_int_cmp_1_wide<0b000, "cmpge", int_aarch64_sve_cmpge_wide>; + defm CMPGT_WIDE_PPzZZ : sve_int_cmp_1_wide<0b001, "cmpgt", int_aarch64_sve_cmpgt_wide>; + defm CMPLT_WIDE_PPzZZ : sve_int_cmp_1_wide<0b010, "cmplt", int_aarch64_sve_cmplt_wide>; + defm CMPLE_WIDE_PPzZZ : sve_int_cmp_1_wide<0b011, "cmple", int_aarch64_sve_cmple_wide>; + defm CMPHS_WIDE_PPzZZ : sve_int_cmp_1_wide<0b100, "cmphs", int_aarch64_sve_cmphs_wide>; + defm CMPHI_WIDE_PPzZZ : sve_int_cmp_1_wide<0b101, "cmphi", int_aarch64_sve_cmphi_wide>; + defm CMPLO_WIDE_PPzZZ : sve_int_cmp_1_wide<0b110, "cmplo", int_aarch64_sve_cmplo_wide>; + defm CMPLS_WIDE_PPzZZ : sve_int_cmp_1_wide<0b111, "cmpls", int_aarch64_sve_cmpls_wide>; - defm CMPGE_PPzZI : sve_int_scmp_vi<0b000, "cmpge">; - defm CMPGT_PPzZI : sve_int_scmp_vi<0b001, "cmpgt">; - defm CMPLT_PPzZI : sve_int_scmp_vi<0b010, "cmplt">; - defm CMPLE_PPzZI : sve_int_scmp_vi<0b011, "cmple">; - defm CMPEQ_PPzZI : sve_int_scmp_vi<0b100, "cmpeq">; - defm CMPNE_PPzZI : sve_int_scmp_vi<0b101, "cmpne">; - defm CMPHS_PPzZI : sve_int_ucmp_vi<0b00, "cmphs">; - defm CMPHI_PPzZI : sve_int_ucmp_vi<0b01, "cmphi">; - defm CMPLO_PPzZI : sve_int_ucmp_vi<0b10, "cmplo">; - defm CMPLS_PPzZI : sve_int_ucmp_vi<0b11, "cmpls">; + defm CMPGE_PPzZI : sve_int_scmp_vi<0b000, "cmpge", SETGE, int_aarch64_sve_cmpge>; + defm CMPGT_PPzZI : sve_int_scmp_vi<0b001, "cmpgt", SETGT, int_aarch64_sve_cmpgt>; + defm CMPLT_PPzZI : sve_int_scmp_vi<0b010, "cmplt", SETLT, null_frag, int_aarch64_sve_cmpgt>; + defm CMPLE_PPzZI : sve_int_scmp_vi<0b011, "cmple", SETLE, null_frag, int_aarch64_sve_cmpge>; + defm CMPEQ_PPzZI : sve_int_scmp_vi<0b100, "cmpeq", SETEQ, int_aarch64_sve_cmpeq>; + defm CMPNE_PPzZI : sve_int_scmp_vi<0b101, "cmpne", SETNE, int_aarch64_sve_cmpne>; + defm CMPHS_PPzZI : sve_int_ucmp_vi<0b00, "cmphs", SETUGE, int_aarch64_sve_cmphs>; + defm CMPHI_PPzZI : sve_int_ucmp_vi<0b01, "cmphi", SETUGT, int_aarch64_sve_cmphi>; + defm CMPLO_PPzZI : sve_int_ucmp_vi<0b10, "cmplo", SETULT, null_frag, int_aarch64_sve_cmphi>; + defm CMPLS_PPzZI : sve_int_ucmp_vi<0b11, "cmpls", SETULE, null_frag, int_aarch64_sve_cmphs>; - defm FCMGE_PPzZZ : sve_fp_3op_p_pd<0b000, "fcmge">; - defm FCMGT_PPzZZ : 
sve_fp_3op_p_pd<0b001, "fcmgt">; - defm FCMEQ_PPzZZ : sve_fp_3op_p_pd<0b010, "fcmeq">; - defm FCMNE_PPzZZ : sve_fp_3op_p_pd<0b011, "fcmne">; - defm FCMUO_PPzZZ : sve_fp_3op_p_pd<0b100, "fcmuo">; - defm FACGE_PPzZZ : sve_fp_3op_p_pd<0b101, "facge">; - defm FACGT_PPzZZ : sve_fp_3op_p_pd<0b111, "facgt">; + defm FCMGE_PPzZZ : sve_fp_3op_p_pd<0b000, "fcmge", int_aarch64_sve_fcmpge, AArch64fcmge>; + defm FCMGT_PPzZZ : sve_fp_3op_p_pd<0b001, "fcmgt", int_aarch64_sve_fcmpgt, AArch64fcmgt>; + defm FCMEQ_PPzZZ : sve_fp_3op_p_pd<0b010, "fcmeq", int_aarch64_sve_fcmpeq, AArch64fcmeq>; + defm FCMNE_PPzZZ : sve_fp_3op_p_pd<0b011, "fcmne", int_aarch64_sve_fcmpne>; + defm FCMUO_PPzZZ : sve_fp_3op_p_pd<0b100, "fcmuo", int_aarch64_sve_fcmpuo>; + defm FACGE_PPzZZ : sve_fp_3op_p_pd<0b101, "facge", int_aarch64_sve_facge>; + defm FACGT_PPzZZ : sve_fp_3op_p_pd<0b111, "facgt", int_aarch64_sve_facgt>; - defm FCMGE_PPzZ0 : sve_fp_2op_p_pd<0b000, "fcmge">; - defm FCMGT_PPzZ0 : sve_fp_2op_p_pd<0b001, "fcmgt">; - defm FCMLT_PPzZ0 : sve_fp_2op_p_pd<0b010, "fcmlt">; - defm FCMLE_PPzZ0 : sve_fp_2op_p_pd<0b011, "fcmle">; - defm FCMEQ_PPzZ0 : sve_fp_2op_p_pd<0b100, "fcmeq">; - defm FCMNE_PPzZ0 : sve_fp_2op_p_pd<0b110, "fcmne">; + defm FCMGE_PPzZ0 : sve_fp_2op_p_pd<0b000, "fcmge", int_aarch64_sve_fcmpge, AArch64fcmge>; + defm FCMGT_PPzZ0 : sve_fp_2op_p_pd<0b001, "fcmgt", int_aarch64_sve_fcmpgt, AArch64fcmgt>; + defm FCMLT_PPzZ0 : sve_fp_2op_p_pd<0b010, "fcmlt", null_frag, null_frag, int_aarch64_sve_fcmpgt, AArch64fcmgt>; + defm FCMLE_PPzZ0 : sve_fp_2op_p_pd<0b011, "fcmle", null_frag, null_frag, int_aarch64_sve_fcmpge, AArch64fcmge>; + defm FCMEQ_PPzZ0 : sve_fp_2op_p_pd<0b100, "fcmeq", int_aarch64_sve_fcmpeq, AArch64fcmeq>; + defm FCMNE_PPzZ0 : sve_fp_2op_p_pd<0b110, "fcmne", int_aarch64_sve_fcmpne>; - defm WHILELT_PWW : sve_int_while4_rr<0b010, "whilelt">; - defm WHILELE_PWW : sve_int_while4_rr<0b011, "whilele">; - defm WHILELO_PWW : sve_int_while4_rr<0b110, "whilelo">; - defm WHILELS_PWW : sve_int_while4_rr<0b111, "whilels">; + defm WHILELT_PWW : sve_int_while4_rr<0b010, "whilelt", int_aarch64_sve_whilelt>; + defm WHILELE_PWW : sve_int_while4_rr<0b011, "whilele", int_aarch64_sve_whilele>; + defm WHILELO_PWW : sve_int_while4_rr<0b110, "whilelo", int_aarch64_sve_whilelo>; + defm WHILELS_PWW : sve_int_while4_rr<0b111, "whilels", int_aarch64_sve_whilels>; - defm WHILELT_PXX : sve_int_while8_rr<0b010, "whilelt">; - defm WHILELE_PXX : sve_int_while8_rr<0b011, "whilele">; - defm WHILELO_PXX : sve_int_while8_rr<0b110, "whilelo">; - defm WHILELS_PXX : sve_int_while8_rr<0b111, "whilels">; + defm WHILELT_PXX : sve_int_while8_rr<0b010, "whilelt", int_aarch64_sve_whilelt>; + defm WHILELE_PXX : sve_int_while8_rr<0b011, "whilele", int_aarch64_sve_whilele>; + defm WHILELO_PXX : sve_int_while8_rr<0b110, "whilelo", int_aarch64_sve_whilelo>; + defm WHILELS_PXX : sve_int_while8_rr<0b111, "whilels", int_aarch64_sve_whilels>; def CTERMEQ_WW : sve_int_cterm<0b0, 0b0, "ctermeq", GPR32>; def CTERMNE_WW : sve_int_cterm<0b0, 0b1, "ctermne", GPR32>; @@ -773,169 +1033,174 @@ def ADDVL_XXI : sve_int_arith_vl<0b0, "addvl">; def ADDPL_XXI : sve_int_arith_vl<0b1, "addpl">; - defm CNTB_XPiI : sve_int_count<0b000, "cntb">; - defm CNTH_XPiI : sve_int_count<0b010, "cnth">; - defm CNTW_XPiI : sve_int_count<0b100, "cntw">; - defm CNTD_XPiI : sve_int_count<0b110, "cntd">; - defm CNTP_XPP : sve_int_pcount_pred<0b0000, "cntp">; + defm CNTB_XPiI : sve_int_count<0b000, "cntb", int_aarch64_sve_cntb>; + defm CNTH_XPiI : sve_int_count<0b010, "cnth", 
int_aarch64_sve_cnth>; + defm CNTW_XPiI : sve_int_count<0b100, "cntw", int_aarch64_sve_cntw>; + defm CNTD_XPiI : sve_int_count<0b110, "cntd", int_aarch64_sve_cntd>; + defm CNTP_XPP : sve_int_pcount_pred<0b0000, "cntp", int_aarch64_sve_cntp, int_ctvpop>; - defm INCB_XPiI : sve_int_pred_pattern_a<0b000, "incb">; - defm DECB_XPiI : sve_int_pred_pattern_a<0b001, "decb">; - defm INCH_XPiI : sve_int_pred_pattern_a<0b010, "inch">; - defm DECH_XPiI : sve_int_pred_pattern_a<0b011, "dech">; - defm INCW_XPiI : sve_int_pred_pattern_a<0b100, "incw">; - defm DECW_XPiI : sve_int_pred_pattern_a<0b101, "decw">; - defm INCD_XPiI : sve_int_pred_pattern_a<0b110, "incd">; - defm DECD_XPiI : sve_int_pred_pattern_a<0b111, "decd">; + defm INCB_XPiI : sve_int_pred_pattern_a<0b000, "incb", add, int_aarch64_sve_cntb>; + defm DECB_XPiI : sve_int_pred_pattern_a<0b001, "decb", sub, int_aarch64_sve_cntb>; + defm INCH_XPiI : sve_int_pred_pattern_a<0b010, "inch", add, int_aarch64_sve_cnth>; + defm DECH_XPiI : sve_int_pred_pattern_a<0b011, "dech", sub, int_aarch64_sve_cnth>; + defm INCW_XPiI : sve_int_pred_pattern_a<0b100, "incw", add, int_aarch64_sve_cntw>; + defm DECW_XPiI : sve_int_pred_pattern_a<0b101, "decw", sub, int_aarch64_sve_cntw>; + defm INCD_XPiI : sve_int_pred_pattern_a<0b110, "incd", add, int_aarch64_sve_cntd>; + defm DECD_XPiI : sve_int_pred_pattern_a<0b111, "decd", sub, int_aarch64_sve_cntd>; - defm SQINCB_XPiWdI : sve_int_pred_pattern_b_s32<0b00000, "sqincb">; - defm UQINCB_WPiI : sve_int_pred_pattern_b_u32<0b00001, "uqincb">; - defm SQDECB_XPiWdI : sve_int_pred_pattern_b_s32<0b00010, "sqdecb">; - defm UQDECB_WPiI : sve_int_pred_pattern_b_u32<0b00011, "uqdecb">; - defm SQINCB_XPiI : sve_int_pred_pattern_b_x64<0b00100, "sqincb">; - defm UQINCB_XPiI : sve_int_pred_pattern_b_x64<0b00101, "uqincb">; - defm SQDECB_XPiI : sve_int_pred_pattern_b_x64<0b00110, "sqdecb">; - defm UQDECB_XPiI : sve_int_pred_pattern_b_x64<0b00111, "uqdecb">; + defm SQINCB_XPiWdI : sve_int_pred_pattern_b_s32<0b00000, "sqincb", int_aarch64_sve_sqincb_n32>; + defm UQINCB_WPiI : sve_int_pred_pattern_b_u32<0b00001, "uqincb", int_aarch64_sve_uqincb_n32>; + defm SQDECB_XPiWdI : sve_int_pred_pattern_b_s32<0b00010, "sqdecb", int_aarch64_sve_sqdecb_n32>; + defm UQDECB_WPiI : sve_int_pred_pattern_b_u32<0b00011, "uqdecb", int_aarch64_sve_uqdecb_n32>; + defm SQINCB_XPiI : sve_int_pred_pattern_b_x64<0b00100, "sqincb", int_aarch64_sve_sqincb_n64>; + defm UQINCB_XPiI : sve_int_pred_pattern_b_x64<0b00101, "uqincb", int_aarch64_sve_uqincb_n64>; + defm SQDECB_XPiI : sve_int_pred_pattern_b_x64<0b00110, "sqdecb", int_aarch64_sve_sqdecb_n64>; + defm UQDECB_XPiI : sve_int_pred_pattern_b_x64<0b00111, "uqdecb", int_aarch64_sve_uqdecb_n64>; - defm SQINCH_XPiWdI : sve_int_pred_pattern_b_s32<0b01000, "sqinch">; - defm UQINCH_WPiI : sve_int_pred_pattern_b_u32<0b01001, "uqinch">; - defm SQDECH_XPiWdI : sve_int_pred_pattern_b_s32<0b01010, "sqdech">; - defm UQDECH_WPiI : sve_int_pred_pattern_b_u32<0b01011, "uqdech">; - defm SQINCH_XPiI : sve_int_pred_pattern_b_x64<0b01100, "sqinch">; - defm UQINCH_XPiI : sve_int_pred_pattern_b_x64<0b01101, "uqinch">; - defm SQDECH_XPiI : sve_int_pred_pattern_b_x64<0b01110, "sqdech">; - defm UQDECH_XPiI : sve_int_pred_pattern_b_x64<0b01111, "uqdech">; + defm SQINCH_XPiWdI : sve_int_pred_pattern_b_s32<0b01000, "sqinch", int_aarch64_sve_sqinch_n32>; + defm UQINCH_WPiI : sve_int_pred_pattern_b_u32<0b01001, "uqinch", int_aarch64_sve_uqinch_n32>; + defm SQDECH_XPiWdI : sve_int_pred_pattern_b_s32<0b01010, "sqdech", 
int_aarch64_sve_sqdech_n32>; + defm UQDECH_WPiI : sve_int_pred_pattern_b_u32<0b01011, "uqdech", int_aarch64_sve_uqdech_n32>; + defm SQINCH_XPiI : sve_int_pred_pattern_b_x64<0b01100, "sqinch", int_aarch64_sve_sqinch_n64>; + defm UQINCH_XPiI : sve_int_pred_pattern_b_x64<0b01101, "uqinch", int_aarch64_sve_uqinch_n64>; + defm SQDECH_XPiI : sve_int_pred_pattern_b_x64<0b01110, "sqdech", int_aarch64_sve_sqdech_n64>; + defm UQDECH_XPiI : sve_int_pred_pattern_b_x64<0b01111, "uqdech", int_aarch64_sve_uqdech_n64>; - defm SQINCW_XPiWdI : sve_int_pred_pattern_b_s32<0b10000, "sqincw">; - defm UQINCW_WPiI : sve_int_pred_pattern_b_u32<0b10001, "uqincw">; - defm SQDECW_XPiWdI : sve_int_pred_pattern_b_s32<0b10010, "sqdecw">; - defm UQDECW_WPiI : sve_int_pred_pattern_b_u32<0b10011, "uqdecw">; - defm SQINCW_XPiI : sve_int_pred_pattern_b_x64<0b10100, "sqincw">; - defm UQINCW_XPiI : sve_int_pred_pattern_b_x64<0b10101, "uqincw">; - defm SQDECW_XPiI : sve_int_pred_pattern_b_x64<0b10110, "sqdecw">; - defm UQDECW_XPiI : sve_int_pred_pattern_b_x64<0b10111, "uqdecw">; + defm SQINCW_XPiWdI : sve_int_pred_pattern_b_s32<0b10000, "sqincw", int_aarch64_sve_sqincw_n32>; + defm UQINCW_WPiI : sve_int_pred_pattern_b_u32<0b10001, "uqincw", int_aarch64_sve_uqincw_n32>; + defm SQDECW_XPiWdI : sve_int_pred_pattern_b_s32<0b10010, "sqdecw", int_aarch64_sve_sqdecw_n32>; + defm UQDECW_WPiI : sve_int_pred_pattern_b_u32<0b10011, "uqdecw", int_aarch64_sve_uqdecw_n32>; + defm SQINCW_XPiI : sve_int_pred_pattern_b_x64<0b10100, "sqincw", int_aarch64_sve_sqincw_n64>; + defm UQINCW_XPiI : sve_int_pred_pattern_b_x64<0b10101, "uqincw", int_aarch64_sve_uqincw_n64>; + defm SQDECW_XPiI : sve_int_pred_pattern_b_x64<0b10110, "sqdecw", int_aarch64_sve_sqdecw_n64>; + defm UQDECW_XPiI : sve_int_pred_pattern_b_x64<0b10111, "uqdecw", int_aarch64_sve_uqdecw_n64>; - defm SQINCD_XPiWdI : sve_int_pred_pattern_b_s32<0b11000, "sqincd">; - defm UQINCD_WPiI : sve_int_pred_pattern_b_u32<0b11001, "uqincd">; - defm SQDECD_XPiWdI : sve_int_pred_pattern_b_s32<0b11010, "sqdecd">; - defm UQDECD_WPiI : sve_int_pred_pattern_b_u32<0b11011, "uqdecd">; - defm SQINCD_XPiI : sve_int_pred_pattern_b_x64<0b11100, "sqincd">; - defm UQINCD_XPiI : sve_int_pred_pattern_b_x64<0b11101, "uqincd">; - defm SQDECD_XPiI : sve_int_pred_pattern_b_x64<0b11110, "sqdecd">; - defm UQDECD_XPiI : sve_int_pred_pattern_b_x64<0b11111, "uqdecd">; + defm SQINCD_XPiWdI : sve_int_pred_pattern_b_s32<0b11000, "sqincd", int_aarch64_sve_sqincd_n32>; + defm UQINCD_WPiI : sve_int_pred_pattern_b_u32<0b11001, "uqincd", int_aarch64_sve_uqincd_n32>; + defm SQDECD_XPiWdI : sve_int_pred_pattern_b_s32<0b11010, "sqdecd", int_aarch64_sve_sqdecd_n32>; + defm UQDECD_WPiI : sve_int_pred_pattern_b_u32<0b11011, "uqdecd", int_aarch64_sve_uqdecd_n32>; + defm SQINCD_XPiI : sve_int_pred_pattern_b_x64<0b11100, "sqincd", int_aarch64_sve_sqincd_n64>; + defm UQINCD_XPiI : sve_int_pred_pattern_b_x64<0b11101, "uqincd", int_aarch64_sve_uqincd_n64>; + defm SQDECD_XPiI : sve_int_pred_pattern_b_x64<0b11110, "sqdecd", int_aarch64_sve_sqdecd_n64>; + defm UQDECD_XPiI : sve_int_pred_pattern_b_x64<0b11111, "uqdecd", int_aarch64_sve_uqdecd_n64>; - defm SQINCH_ZPiI : sve_int_countvlv<0b01000, "sqinch", ZPR16>; - defm UQINCH_ZPiI : sve_int_countvlv<0b01001, "uqinch", ZPR16>; - defm SQDECH_ZPiI : sve_int_countvlv<0b01010, "sqdech", ZPR16>; - defm UQDECH_ZPiI : sve_int_countvlv<0b01011, "uqdech", ZPR16>; + defm SQINCH_ZPiI : sve_int_countvlv<0b01000, "sqinch", ZPR16, int_aarch64_sve_sqinch, nxv8i16>; + defm UQINCH_ZPiI : sve_int_countvlv<0b01001, 
"uqinch", ZPR16, int_aarch64_sve_uqinch, nxv8i16>; + defm SQDECH_ZPiI : sve_int_countvlv<0b01010, "sqdech", ZPR16, int_aarch64_sve_sqdech, nxv8i16>; + defm UQDECH_ZPiI : sve_int_countvlv<0b01011, "uqdech", ZPR16, int_aarch64_sve_uqdech, nxv8i16>; defm INCH_ZPiI : sve_int_countvlv<0b01100, "inch", ZPR16>; defm DECH_ZPiI : sve_int_countvlv<0b01101, "dech", ZPR16>; - defm SQINCW_ZPiI : sve_int_countvlv<0b10000, "sqincw", ZPR32>; - defm UQINCW_ZPiI : sve_int_countvlv<0b10001, "uqincw", ZPR32>; - defm SQDECW_ZPiI : sve_int_countvlv<0b10010, "sqdecw", ZPR32>; - defm UQDECW_ZPiI : sve_int_countvlv<0b10011, "uqdecw", ZPR32>; + defm SQINCW_ZPiI : sve_int_countvlv<0b10000, "sqincw", ZPR32, int_aarch64_sve_sqincw, nxv4i32>; + defm UQINCW_ZPiI : sve_int_countvlv<0b10001, "uqincw", ZPR32, int_aarch64_sve_uqincw, nxv4i32>; + defm SQDECW_ZPiI : sve_int_countvlv<0b10010, "sqdecw", ZPR32, int_aarch64_sve_sqdecw, nxv4i32>; + defm UQDECW_ZPiI : sve_int_countvlv<0b10011, "uqdecw", ZPR32, int_aarch64_sve_uqdecw, nxv4i32>; defm INCW_ZPiI : sve_int_countvlv<0b10100, "incw", ZPR32>; defm DECW_ZPiI : sve_int_countvlv<0b10101, "decw", ZPR32>; - defm SQINCD_ZPiI : sve_int_countvlv<0b11000, "sqincd", ZPR64>; - defm UQINCD_ZPiI : sve_int_countvlv<0b11001, "uqincd", ZPR64>; - defm SQDECD_ZPiI : sve_int_countvlv<0b11010, "sqdecd", ZPR64>; - defm UQDECD_ZPiI : sve_int_countvlv<0b11011, "uqdecd", ZPR64>; + defm SQINCD_ZPiI : sve_int_countvlv<0b11000, "sqincd", ZPR64, int_aarch64_sve_sqincd, nxv2i64>; + defm UQINCD_ZPiI : sve_int_countvlv<0b11001, "uqincd", ZPR64, int_aarch64_sve_uqincd, nxv2i64>; + defm SQDECD_ZPiI : sve_int_countvlv<0b11010, "sqdecd", ZPR64, int_aarch64_sve_sqdecd, nxv2i64>; + defm UQDECD_ZPiI : sve_int_countvlv<0b11011, "uqdecd", ZPR64, int_aarch64_sve_uqdecd, nxv2i64>; defm INCD_ZPiI : sve_int_countvlv<0b11100, "incd", ZPR64>; defm DECD_ZPiI : sve_int_countvlv<0b11101, "decd", ZPR64>; - defm SQINCP_XPWd : sve_int_count_r_s32<0b00000, "sqincp">; - defm SQINCP_XP : sve_int_count_r_x64<0b00010, "sqincp">; - defm UQINCP_WP : sve_int_count_r_u32<0b00100, "uqincp">; - defm UQINCP_XP : sve_int_count_r_x64<0b00110, "uqincp">; - defm SQDECP_XPWd : sve_int_count_r_s32<0b01000, "sqdecp">; - defm SQDECP_XP : sve_int_count_r_x64<0b01010, "sqdecp">; - defm UQDECP_WP : sve_int_count_r_u32<0b01100, "uqdecp">; - defm UQDECP_XP : sve_int_count_r_x64<0b01110, "uqdecp">; + defm SQINCP_XPWd : sve_int_count_r_s32<0b00000, "sqincp", int_aarch64_sve_sqincp_n32>; + defm SQINCP_XP : sve_int_count_r_x64<0b00010, "sqincp", int_aarch64_sve_sqincp_n64>; + defm UQINCP_WP : sve_int_count_r_u32<0b00100, "uqincp", int_aarch64_sve_uqincp_n32>; + defm UQINCP_XP : sve_int_count_r_x64<0b00110, "uqincp", int_aarch64_sve_uqincp_n64>; + defm SQDECP_XPWd : sve_int_count_r_s32<0b01000, "sqdecp", int_aarch64_sve_sqdecp_n32>; + defm SQDECP_XP : sve_int_count_r_x64<0b01010, "sqdecp", int_aarch64_sve_sqdecp_n64>; + defm UQDECP_WP : sve_int_count_r_u32<0b01100, "uqdecp", int_aarch64_sve_uqdecp_n32>; + defm UQDECP_XP : sve_int_count_r_x64<0b01110, "uqdecp", int_aarch64_sve_uqdecp_n64>; defm INCP_XP : sve_int_count_r_x64<0b10000, "incp">; defm DECP_XP : sve_int_count_r_x64<0b10100, "decp">; - defm SQINCP_ZP : sve_int_count_v<0b00000, "sqincp">; - defm UQINCP_ZP : sve_int_count_v<0b00100, "uqincp">; - defm SQDECP_ZP : sve_int_count_v<0b01000, "sqdecp">; - defm UQDECP_ZP : sve_int_count_v<0b01100, "uqdecp">; + defm SQINCP_ZP : sve_int_count_v<0b00000, "sqincp", int_aarch64_sve_sqincp>; + defm UQINCP_ZP : sve_int_count_v<0b00100, "uqincp", 
int_aarch64_sve_uqincp>; + defm SQDECP_ZP : sve_int_count_v<0b01000, "sqdecp", int_aarch64_sve_sqdecp>; + defm UQDECP_ZP : sve_int_count_v<0b01100, "uqdecp", int_aarch64_sve_uqdecp>; defm INCP_ZP : sve_int_count_v<0b10000, "incp">; defm DECP_ZP : sve_int_count_v<0b10100, "decp">; - defm INDEX_RR : sve_int_index_rr<"index">; - defm INDEX_IR : sve_int_index_ir<"index">; - defm INDEX_RI : sve_int_index_ri<"index">; - defm INDEX_II : sve_int_index_ii<"index">; + defm INDEX_RR : sve_int_index_rr<"index", series_vector>; + defm INDEX_IR : sve_int_index_ir<"index", series_vector>; + defm INDEX_RI : sve_int_index_ri<"index", series_vector>; + defm INDEX_II : sve_int_index_ii<"index", series_vector>; // Unpredicated shifts - defm ASR_ZZI : sve_int_bin_cons_shift_imm_right<0b00, "asr">; - defm LSR_ZZI : sve_int_bin_cons_shift_imm_right<0b01, "lsr">; - defm LSL_ZZI : sve_int_bin_cons_shift_imm_left< 0b11, "lsl">; + defm ASR_ZZI : sve_int_bin_cons_shift_imm_right<0b00, "asr", sra>; + defm LSR_ZZI : sve_int_bin_cons_shift_imm_right<0b01, "lsr", srl>; + defm LSL_ZZI : sve_int_bin_cons_shift_imm_left< 0b11, "lsl", shl>; defm ASR_WIDE_ZZZ : sve_int_bin_cons_shift_wide<0b00, "asr">; defm LSR_WIDE_ZZZ : sve_int_bin_cons_shift_wide<0b01, "lsr">; defm LSL_WIDE_ZZZ : sve_int_bin_cons_shift_wide<0b11, "lsl">; // Predicated shifts - defm ASR_ZPmI : sve_int_bin_pred_shift_imm_right<0b000, "asr">; - defm LSR_ZPmI : sve_int_bin_pred_shift_imm_right<0b001, "lsr">; - defm LSL_ZPmI : sve_int_bin_pred_shift_imm_left< 0b011, "lsl">; - defm ASRD_ZPmI : sve_int_bin_pred_shift_imm_right<0b100, "asrd">; + defm ASR_ZPmI : sve_int_bin_pred_shift_imm_right<0b0000, "asr", "ASR_ZPZI">; + defm LSR_ZPmI : sve_int_bin_pred_shift_imm_right<0b0001, "lsr", "LSR_ZPZI">; + defm LSL_ZPmI : sve_int_bin_pred_shift_imm_left< 0b0011, "lsl">; + defm ASRD_ZPmI : sve_int_bin_pred_shift_imm_right<0b0100, "asrd", "ASRD_ZPZI", int_aarch64_sve_asrd>; - defm ASR_ZPmZ : sve_int_bin_pred_shift<0b000, "asr">; - defm LSR_ZPmZ : sve_int_bin_pred_shift<0b001, "lsr">; - defm LSL_ZPmZ : sve_int_bin_pred_shift<0b011, "lsl">; - defm ASRR_ZPmZ : sve_int_bin_pred_shift<0b100, "asrr">; - defm LSRR_ZPmZ : sve_int_bin_pred_shift<0b101, "lsrr">; - defm LSLR_ZPmZ : sve_int_bin_pred_shift<0b111, "lslr">; + defm ASR_ZPZZ : sve_int_bin_pred_zx; + defm LSR_ZPZZ : sve_int_bin_pred_zx; + defm LSL_ZPZZ : sve_int_bin_pred_zx; + defm ASRD_ZPZI : sve_int_bin_pred_shift_0_right_zx; - defm ASR_WIDE_ZPmZ : sve_int_bin_pred_shift_wide<0b000, "asr">; - defm LSR_WIDE_ZPmZ : sve_int_bin_pred_shift_wide<0b001, "lsr">; - defm LSL_WIDE_ZPmZ : sve_int_bin_pred_shift_wide<0b011, "lsl">; + defm ASR_ZPmZ : sve_int_bin_pred_shift<0b000, "asr", "ASR_ZPZZ", int_aarch64_sve_asr, "ASRR_ZPmZ", 1>; + defm LSR_ZPmZ : sve_int_bin_pred_shift<0b001, "lsr", "LSR_ZPZZ", int_aarch64_sve_lsr, "LSRR_ZPmZ", 1>; + defm LSL_ZPmZ : sve_int_bin_pred_shift<0b011, "lsl", "LSL_ZPZZ", int_aarch64_sve_lsl, "LSLR_ZPmZ", 1>; + defm ASRR_ZPmZ : sve_int_bin_pred_shift<0b100, "asrr", "ASRR_ZPZZ", null_frag, "ASR_ZPmZ", 0>; + defm LSRR_ZPmZ : sve_int_bin_pred_shift<0b101, "lsrr", "LSRR_ZPZZ", null_frag, "LSR_ZPmZ", 0>; + defm LSLR_ZPmZ : sve_int_bin_pred_shift<0b111, "lslr", "LSLR_ZPZZ", null_frag, "LSL_ZPmZ", 0>; - def FCVT_ZPmZ_StoH : sve_fp_2op_p_zd<0b1001000, "fcvt", ZPR32, ZPR16, ElementSizeS>; - def FCVT_ZPmZ_HtoS : sve_fp_2op_p_zd<0b1001001, "fcvt", ZPR16, ZPR32, ElementSizeS>; - def SCVTF_ZPmZ_HtoH : sve_fp_2op_p_zd<0b0110010, "scvtf", ZPR16, ZPR16, ElementSizeH>; - def SCVTF_ZPmZ_StoS : 
sve_fp_2op_p_zd<0b1010100, "scvtf", ZPR32, ZPR32, ElementSizeS>; - def UCVTF_ZPmZ_StoS : sve_fp_2op_p_zd<0b1010101, "ucvtf", ZPR32, ZPR32, ElementSizeS>; - def UCVTF_ZPmZ_HtoH : sve_fp_2op_p_zd<0b0110011, "ucvtf", ZPR16, ZPR16, ElementSizeH>; - def FCVTZS_ZPmZ_HtoH : sve_fp_2op_p_zd<0b0111010, "fcvtzs", ZPR16, ZPR16, ElementSizeH>; - def FCVTZS_ZPmZ_StoS : sve_fp_2op_p_zd<0b1011100, "fcvtzs", ZPR32, ZPR32, ElementSizeS>; - def FCVTZU_ZPmZ_HtoH : sve_fp_2op_p_zd<0b0111011, "fcvtzu", ZPR16, ZPR16, ElementSizeH>; - def FCVTZU_ZPmZ_StoS : sve_fp_2op_p_zd<0b1011101, "fcvtzu", ZPR32, ZPR32, ElementSizeS>; - def FCVT_ZPmZ_DtoH : sve_fp_2op_p_zd<0b1101000, "fcvt", ZPR64, ZPR16, ElementSizeD>; - def FCVT_ZPmZ_HtoD : sve_fp_2op_p_zd<0b1101001, "fcvt", ZPR16, ZPR64, ElementSizeD>; - def FCVT_ZPmZ_DtoS : sve_fp_2op_p_zd<0b1101010, "fcvt", ZPR64, ZPR32, ElementSizeD>; - def FCVT_ZPmZ_StoD : sve_fp_2op_p_zd<0b1101011, "fcvt", ZPR32, ZPR64, ElementSizeD>; - def SCVTF_ZPmZ_StoD : sve_fp_2op_p_zd<0b1110000, "scvtf", ZPR32, ZPR64, ElementSizeD>; - def UCVTF_ZPmZ_StoD : sve_fp_2op_p_zd<0b1110001, "ucvtf", ZPR32, ZPR64, ElementSizeD>; - def UCVTF_ZPmZ_StoH : sve_fp_2op_p_zd<0b0110101, "ucvtf", ZPR32, ZPR16, ElementSizeS>; - def SCVTF_ZPmZ_DtoS : sve_fp_2op_p_zd<0b1110100, "scvtf", ZPR64, ZPR32, ElementSizeD>; - def SCVTF_ZPmZ_StoH : sve_fp_2op_p_zd<0b0110100, "scvtf", ZPR32, ZPR16, ElementSizeS>; - def SCVTF_ZPmZ_DtoH : sve_fp_2op_p_zd<0b0110110, "scvtf", ZPR64, ZPR16, ElementSizeD>; - def UCVTF_ZPmZ_DtoS : sve_fp_2op_p_zd<0b1110101, "ucvtf", ZPR64, ZPR32, ElementSizeD>; - def UCVTF_ZPmZ_DtoH : sve_fp_2op_p_zd<0b0110111, "ucvtf", ZPR64, ZPR16, ElementSizeD>; - def SCVTF_ZPmZ_DtoD : sve_fp_2op_p_zd<0b1110110, "scvtf", ZPR64, ZPR64, ElementSizeD>; - def UCVTF_ZPmZ_DtoD : sve_fp_2op_p_zd<0b1110111, "ucvtf", ZPR64, ZPR64, ElementSizeD>; - def FCVTZS_ZPmZ_DtoS : sve_fp_2op_p_zd<0b1111000, "fcvtzs", ZPR64, ZPR32, ElementSizeD>; - def FCVTZU_ZPmZ_DtoS : sve_fp_2op_p_zd<0b1111001, "fcvtzu", ZPR64, ZPR32, ElementSizeD>; - def FCVTZS_ZPmZ_StoD : sve_fp_2op_p_zd<0b1111100, "fcvtzs", ZPR32, ZPR64, ElementSizeD>; - def FCVTZS_ZPmZ_HtoS : sve_fp_2op_p_zd<0b0111100, "fcvtzs", ZPR16, ZPR32, ElementSizeS>; - def FCVTZS_ZPmZ_HtoD : sve_fp_2op_p_zd<0b0111110, "fcvtzs", ZPR16, ZPR64, ElementSizeD>; - def FCVTZU_ZPmZ_HtoS : sve_fp_2op_p_zd<0b0111101, "fcvtzu", ZPR16, ZPR32, ElementSizeS>; - def FCVTZU_ZPmZ_HtoD : sve_fp_2op_p_zd<0b0111111, "fcvtzu", ZPR16, ZPR64, ElementSizeD>; - def FCVTZU_ZPmZ_StoD : sve_fp_2op_p_zd<0b1111101, "fcvtzu", ZPR32, ZPR64, ElementSizeD>; - def FCVTZS_ZPmZ_DtoD : sve_fp_2op_p_zd<0b1111110, "fcvtzs", ZPR64, ZPR64, ElementSizeD>; - def FCVTZU_ZPmZ_DtoD : sve_fp_2op_p_zd<0b1111111, "fcvtzu", ZPR64, ZPR64, ElementSizeD>; + defm ASR_WIDE_ZPmZ : sve_int_bin_pred_shift_wide<0b000, "asr", int_aarch64_sve_asr_wide>; + defm LSR_WIDE_ZPmZ : sve_int_bin_pred_shift_wide<0b001, "lsr", int_aarch64_sve_lsr_wide>; + defm LSL_WIDE_ZPmZ : sve_int_bin_pred_shift_wide<0b011, "lsl", int_aarch64_sve_lsl_wide>; - defm FRINTN_ZPmZ : sve_fp_2op_p_zd_HSD<0b00000, "frintn">; - defm FRINTP_ZPmZ : sve_fp_2op_p_zd_HSD<0b00001, "frintp">; - defm FRINTM_ZPmZ : sve_fp_2op_p_zd_HSD<0b00010, "frintm">; - defm FRINTZ_ZPmZ : sve_fp_2op_p_zd_HSD<0b00011, "frintz">; - defm FRINTA_ZPmZ : sve_fp_2op_p_zd_HSD<0b00100, "frinta">; - defm FRINTX_ZPmZ : sve_fp_2op_p_zd_HSD<0b00110, "frintx">; - defm FRINTI_ZPmZ : sve_fp_2op_p_zd_HSD<0b00111, "frinti">; - defm FRECPX_ZPmZ : sve_fp_2op_p_zd_HSD<0b01100, "frecpx">; - defm FSQRT_ZPmZ : 
sve_fp_2op_p_zd_HSD<0b01101, "fsqrt">; + defm FCVT_ZPmZ_StoH : sve_fp_2op_p_zd<0b1001000, "fcvt", ZPR32, ZPR16, int_aarch64_sve_fcvt_f16f32, nxv8f16, nxv16i1, nxv4f32, ElementSizeS>; + defm FCVT_ZPmZ_HtoS : sve_fp_2op_p_zd<0b1001001, "fcvt", ZPR16, ZPR32, int_aarch64_sve_fcvt_f32f16, nxv4f32, nxv16i1, nxv8f16, ElementSizeS>; + defm SCVTF_ZPmZ_HtoH : sve_fp_2op_p_zd<0b0110010, "scvtf", ZPR16, ZPR16, int_aarch64_sve_scvtf, nxv8f16, nxv8i1, nxv8i16, ElementSizeH>; + defm SCVTF_ZPmZ_StoS : sve_fp_2op_p_zd<0b1010100, "scvtf", ZPR32, ZPR32, int_aarch64_sve_scvtf, nxv4f32, nxv4i1, nxv4i32, ElementSizeS>; + defm UCVTF_ZPmZ_StoS : sve_fp_2op_p_zd<0b1010101, "ucvtf", ZPR32, ZPR32, int_aarch64_sve_ucvtf, nxv4f32, nxv4i1, nxv4i32, ElementSizeS>; + defm UCVTF_ZPmZ_HtoH : sve_fp_2op_p_zd<0b0110011, "ucvtf", ZPR16, ZPR16, int_aarch64_sve_ucvtf, nxv8f16, nxv8i1, nxv8i16, ElementSizeH>; + defm FCVTZS_ZPmZ_HtoH : sve_fp_2op_p_zd<0b0111010, "fcvtzs", ZPR16, ZPR16, int_aarch64_sve_fcvtzs, nxv8i16, nxv8i1, nxv8f16, ElementSizeH>; + defm FCVTZS_ZPmZ_StoS : sve_fp_2op_p_zd<0b1011100, "fcvtzs", ZPR32, ZPR32, int_aarch64_sve_fcvtzs, nxv4i32, nxv4i1, nxv4f32, ElementSizeS>; + defm FCVTZU_ZPmZ_HtoH : sve_fp_2op_p_zd<0b0111011, "fcvtzu", ZPR16, ZPR16, int_aarch64_sve_fcvtzu, nxv8i16, nxv8i1, nxv8f16, ElementSizeH>; + defm FCVTZU_ZPmZ_StoS : sve_fp_2op_p_zd<0b1011101, "fcvtzu", ZPR32, ZPR32, int_aarch64_sve_fcvtzu, nxv4i32, nxv4i1, nxv4f32, ElementSizeS>; + defm FCVT_ZPmZ_DtoH : sve_fp_2op_p_zd<0b1101000, "fcvt", ZPR64, ZPR16, int_aarch64_sve_fcvt_f16f64, nxv8f16, nxv16i1, nxv2f64, ElementSizeD>; + defm FCVT_ZPmZ_HtoD : sve_fp_2op_p_zd<0b1101001, "fcvt", ZPR16, ZPR64, int_aarch64_sve_fcvt_f64f16, nxv2f64, nxv16i1, nxv8f16, ElementSizeD>; + defm FCVT_ZPmZ_DtoS : sve_fp_2op_p_zd<0b1101010, "fcvt", ZPR64, ZPR32, int_aarch64_sve_fcvt_f32f64, nxv4f32, nxv16i1, nxv2f64, ElementSizeD>; + defm FCVT_ZPmZ_StoD : sve_fp_2op_p_zd<0b1101011, "fcvt", ZPR32, ZPR64, int_aarch64_sve_fcvt_f64f32, nxv2f64, nxv16i1, nxv4f32, ElementSizeD>; + defm SCVTF_ZPmZ_StoD : sve_fp_2op_p_zd<0b1110000, "scvtf", ZPR32, ZPR64, int_aarch64_sve_scvtf_f64i32, nxv2f64, nxv16i1, nxv4i32, ElementSizeD>; + defm UCVTF_ZPmZ_StoD : sve_fp_2op_p_zd<0b1110001, "ucvtf", ZPR32, ZPR64, int_aarch64_sve_ucvtf_f64i32, nxv2f64, nxv16i1, nxv4i32, ElementSizeD>; + defm UCVTF_ZPmZ_StoH : sve_fp_2op_p_zd<0b0110101, "ucvtf", ZPR32, ZPR16, int_aarch64_sve_ucvtf_f16i32, nxv8f16, nxv16i1, nxv4i32, ElementSizeS>; + defm SCVTF_ZPmZ_DtoS : sve_fp_2op_p_zd<0b1110100, "scvtf", ZPR64, ZPR32, int_aarch64_sve_scvtf_f32i64, nxv4f32, nxv16i1, nxv2i64, ElementSizeD>; + defm SCVTF_ZPmZ_StoH : sve_fp_2op_p_zd<0b0110100, "scvtf", ZPR32, ZPR16, int_aarch64_sve_scvtf_f16i32, nxv8f16, nxv16i1, nxv4i32, ElementSizeS>; + defm SCVTF_ZPmZ_DtoH : sve_fp_2op_p_zd<0b0110110, "scvtf", ZPR64, ZPR16, int_aarch64_sve_scvtf_f16i64, nxv8f16, nxv16i1, nxv2i64, ElementSizeD>; + defm UCVTF_ZPmZ_DtoS : sve_fp_2op_p_zd<0b1110101, "ucvtf", ZPR64, ZPR32, int_aarch64_sve_ucvtf_f32i64, nxv4f32, nxv16i1, nxv2i64, ElementSizeD>; + defm UCVTF_ZPmZ_DtoH : sve_fp_2op_p_zd<0b0110111, "ucvtf", ZPR64, ZPR16, int_aarch64_sve_ucvtf_f16i64, nxv8f16, nxv16i1, nxv2i64, ElementSizeD>; + defm SCVTF_ZPmZ_DtoD : sve_fp_2op_p_zd<0b1110110, "scvtf", ZPR64, ZPR64, int_aarch64_sve_scvtf, nxv2f64, nxv2i1, nxv2i64, ElementSizeD>; + defm UCVTF_ZPmZ_DtoD : sve_fp_2op_p_zd<0b1110111, "ucvtf", ZPR64, ZPR64, int_aarch64_sve_ucvtf, nxv2f64, nxv2i1, nxv2i64, ElementSizeD>; + defm FCVTZS_ZPmZ_DtoS : sve_fp_2op_p_zd<0b1111000, "fcvtzs", ZPR64, 
ZPR32, int_aarch64_sve_fcvtzs_i32f64, nxv4i32, nxv16i1, nxv2f64, ElementSizeD>; + defm FCVTZU_ZPmZ_DtoS : sve_fp_2op_p_zd<0b1111001, "fcvtzu", ZPR64, ZPR32, int_aarch64_sve_fcvtzu_i32f64, nxv4i32, nxv16i1, nxv2f64, ElementSizeD>; + defm FCVTZS_ZPmZ_StoD : sve_fp_2op_p_zd<0b1111100, "fcvtzs", ZPR32, ZPR64, int_aarch64_sve_fcvtzs_i64f32, nxv2i64, nxv16i1, nxv4f32, ElementSizeD>; + defm FCVTZS_ZPmZ_HtoS : sve_fp_2op_p_zd<0b0111100, "fcvtzs", ZPR16, ZPR32, int_aarch64_sve_fcvtzs_i32f16, nxv4i32, nxv16i1, nxv8f16, ElementSizeS>; + defm FCVTZS_ZPmZ_HtoD : sve_fp_2op_p_zd<0b0111110, "fcvtzs", ZPR16, ZPR64, int_aarch64_sve_fcvtzs_i64f16, nxv2i64, nxv16i1, nxv8f16, ElementSizeD>; + defm FCVTZU_ZPmZ_HtoS : sve_fp_2op_p_zd<0b0111101, "fcvtzu", ZPR16, ZPR32, int_aarch64_sve_fcvtzu_i32f16, nxv4i32, nxv16i1, nxv8f16, ElementSizeS>; + defm FCVTZU_ZPmZ_HtoD : sve_fp_2op_p_zd<0b0111111, "fcvtzu", ZPR16, ZPR64, int_aarch64_sve_fcvtzu_i64f16, nxv2i64, nxv16i1, nxv8f16, ElementSizeD>; + defm FCVTZU_ZPmZ_StoD : sve_fp_2op_p_zd<0b1111101, "fcvtzu", ZPR32, ZPR64, int_aarch64_sve_fcvtzu_i64f32, nxv2i64, nxv16i1, nxv4f32, ElementSizeD>; + defm FCVTZS_ZPmZ_DtoD : sve_fp_2op_p_zd<0b1111110, "fcvtzs", ZPR64, ZPR64, int_aarch64_sve_fcvtzs, nxv2i64, nxv2i1, nxv2f64, ElementSizeD>; + defm FCVTZU_ZPmZ_DtoD : sve_fp_2op_p_zd<0b1111111, "fcvtzu", ZPR64, ZPR64, int_aarch64_sve_fcvtzu, nxv2i64, nxv2i1, nxv2f64, ElementSizeD>; + + defm FRINTN_ZPmZ : sve_fp_2op_p_zd_HSD<0b00000, "frintn", int_aarch64_sve_frintn>; + defm FRINTP_ZPmZ : sve_fp_2op_p_zd_HSD<0b00001, "frintp", int_aarch64_sve_frintp>; + defm FRINTM_ZPmZ : sve_fp_2op_p_zd_HSD<0b00010, "frintm", int_aarch64_sve_frintm>; + defm FRINTZ_ZPmZ : sve_fp_2op_p_zd_HSD<0b00011, "frintz", int_aarch64_sve_frintz>; + defm FRINTA_ZPmZ : sve_fp_2op_p_zd_HSD<0b00100, "frinta", int_aarch64_sve_frinta>; + defm FRINTX_ZPmZ : sve_fp_2op_p_zd_HSD<0b00110, "frintx", int_aarch64_sve_frintx>; + defm FRINTI_ZPmZ : sve_fp_2op_p_zd_HSD<0b00111, "frinti", int_aarch64_sve_frinti>; + defm FRECPX_ZPmZ : sve_fp_2op_p_zd_HSD<0b01100, "frecpx", int_aarch64_sve_frecpx>; + defm FSQRT_ZPmZ : sve_fp_2op_p_zd_HSD<0b01101, "fsqrt", int_aarch64_sve_fsqrt>; // InstAliases def : InstAlias<"mov $Zd, $Zn", @@ -1021,4 +1286,2243 @@ (FCMGT_PPzZZ_S PPR32:$Zd, PPR3bAny:$Pg, ZPR32:$Zn, ZPR32:$Zm), 0>; def : InstAlias<"fcmlt $Zd, $Pg/z, $Zm, $Zn", (FCMGT_PPzZZ_D PPR64:$Zd, PPR3bAny:$Pg, ZPR64:$Zn, ZPR64:$Zm), 0>; + + // Pseudo instructions representing unpredicated LDR and STR for ZPR2,3,4. + // These later get expanded to PTRUE of ST2/LD2 etc + let mayLoad = 1, hasSideEffects = 0 in { + def LDR_ZZXI : Pseudo<(outs ZZ_b:$Zd), (ins GPR64sp:$sp, simm4s1:$offset),[]>, Sched<[]>; + def LDR_ZZZXI : Pseudo<(outs ZZZ_b:$Zd), (ins GPR64sp:$sp, simm4s1:$offset),[]>, Sched<[]>; + def LDR_ZZZZXI : Pseudo<(outs ZZZZ_b:$Zd), (ins GPR64sp:$sp, simm4s1:$offset),[]>, Sched<[]>; + } + let mayStore = 1, hasSideEffects = 0 in { + def STR_ZZXI : Pseudo<(outs), (ins ZZ_b:$Zs, GPR64sp:$sp, simm4s1:$offset),[]>, Sched<[]>; + def STR_ZZZXI : Pseudo<(outs), (ins ZZZ_b:$Zs, GPR64sp:$sp, simm4s1:$offset),[]>, Sched<[]>; + def STR_ZZZZXI : Pseudo<(outs), (ins ZZZZ_b:$Zs, GPR64sp:$sp, simm4s1:$offset),[]>, Sched<[]>; + } + + // Create variants of DUP_ZZI (I=0) that can be used without INSERT_SUBREG. 
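+ // Each pseudo takes an FP scalar in a V register and broadcasts it to every
+ // element of the destination Z register; they are selected by the
+ // AArch64dup-of-FPR patterns under "Duplicate FP scalar into all vector
+ // elements" further down.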
+ def DUP_ZV_H : Pseudo<(outs ZPR16:$Zd), (ins FPR16:$Vn), []>, Sched<[]>; + def DUP_ZV_S : Pseudo<(outs ZPR32:$Zd), (ins FPR32:$Vn), []>, Sched<[]>; + def DUP_ZV_D : Pseudo<(outs ZPR64:$Zd), (ins FPR64:$Vn), []>, Sched<[]>; + + let hasNoSchedulingInfo = 1 in { + class PredSelZeroOpPseudo + : Pseudo<(outs zprty:$Zd), (ins PPR3bAny:$Pg, zprty:$Zs1, zprty:$Zs2), []>; + } + + multiclass selzero { + def _B : PredSelZeroOpPseudo; + def _H : PredSelZeroOpPseudo; + def _S : PredSelZeroOpPseudo; + def _D : PredSelZeroOpPseudo; + + def : Pat<(vselect nxv16i1:$Op1, nxv16i8:$Op2, (nxv16i8 (AArch64dup (i32 0)))), + (!cast(NAME # "_B") $Op1, $Op2, (DUP_ZI_B 0, 0))>; + def : Pat<(vselect nxv8i1:$Op1, nxv8i16:$Op2, (nxv8i16 (AArch64dup (i32 0)))), + (!cast(NAME # "_H") $Op1, $Op2, (DUP_ZI_H 0, 0))>; + def : Pat<(vselect nxv4i1:$Op1, nxv4i32:$Op2, (nxv4i32 (AArch64dup (i32 0)))), + (!cast(NAME # "_S") $Op1, $Op2, (DUP_ZI_S 0, 0))>; + def : Pat<(vselect nxv2i1:$Op1, nxv2i64:$Op2, (nxv2i64 (AArch64dup (i64 0)))), + (!cast(NAME # "_D") $Op1, $Op2, (DUP_ZI_D 0, 0))>; + def : Pat<(vselect nxv8i1:$Op1, nxv8f16:$Op2, (nxv8f16 (AArch64dup (f16 fpimm0)))), + (!cast(NAME # "_H") $Op1, $Op2, (DUP_ZI_H 0, 0))>; + def : Pat<(vselect nxv4i1:$Op1, nxv4f32:$Op2, (nxv4f32 (AArch64dup (f32 fpimm0)))), + (!cast(NAME # "_S") $Op1, $Op2, (DUP_ZI_S 0, 0))>; + def : Pat<(vselect nxv2i1:$Op1, nxv2f64:$Op2, (nxv2f64 (AArch64dup (f64 fpimm0)))), + (!cast(NAME # "_D") $Op1, $Op2, (DUP_ZI_D 0, 0))>; + } + defm SELZERO : selzero; + + // PATTERNS + multiclass sve_int_arith_immediates { + def : Pat<(nxv16i8 (op (nxv16i8 ZPR:$Zs1), + (nxv16i8 (AArch64dup (i32 (SVEUIntArithImm8 i32:$imm, i32:$shift)))))), + (!cast(I # "_B") ZPR:$Zs1, i32:$imm, i32:$shift)>; + def : Pat<(nxv8i16 (op (nxv8i16 ZPR:$Zs1), + (nxv8i16 (AArch64dup (i32 (SVEUIntArithImm16 i32:$imm, i32:$shift)))))), + (!cast(I # "_H") ZPR:$Zs1, i32:$imm, i32:$shift)>; + def : Pat<(nxv4i32 (op (nxv4i32 ZPR:$Zs1), + (nxv4i32 (AArch64dup (i32 (SVEUIntArithImm32 i32:$imm, i32:$shift)))))), + (!cast(I # "_S") ZPR:$Zs1, i32:$imm, i32:$shift)>; + def : Pat<(nxv2i64 (op (nxv2i64 ZPR:$Zs1), + (nxv2i64 (AArch64dup (i64 (SVEUIntArithImm64 i32:$imm, i32:$shift)))))), + (!cast(I # "_D") ZPR:$Zs1, i32:$imm, i32:$shift)>; + } + + defm : sve_int_arith_immediates<"ADD_ZZI", add>; + defm : sve_int_arith_immediates<"SUB_ZZI", sub>; + // Skipping subr? + defm : sve_int_arith_immediates<"SQADD_ZZI", int_aarch64_sve_sqadd_x>; + defm : sve_int_arith_immediates<"SQSUB_ZZI", int_aarch64_sve_sqsub_x>; + defm : sve_int_arith_immediates<"UQADD_ZZI", int_aarch64_sve_uqadd_x>; + defm : sve_int_arith_immediates<"UQSUB_ZZI", int_aarch64_sve_uqsub_x>; + + // TODO: Binopfrag? 
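+ // and(x, xor(y, splat(-1))) is an and-not, so it maps straight onto the
+ // unpredicated BIC_ZZZ; one pattern per element type because the all-ones
+ // splat is typed.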
+ def : Pat<(and (nxv16i8 ZPR:$Zs1), (xor (nxv16i8 ZPR:$Zs2), + (nxv16i8 (AArch64dup (i32 -1))))), + (BIC_ZZZ ZPR:$Zs1, ZPR:$Zs2)>; + + def : Pat<(and (nxv8i16 ZPR:$Zs1), (xor (nxv8i16 ZPR:$Zs2), + (nxv8i16 (AArch64dup (i32 -1))))), + (BIC_ZZZ ZPR:$Zs1, ZPR:$Zs2)>; + + def : Pat<(and (nxv4i32 ZPR:$Zs1), (xor (nxv4i32 ZPR:$Zs2), + (nxv4i32 (AArch64dup (i32 -1))))), + (BIC_ZZZ ZPR:$Zs1, ZPR:$Zs2)>; + + def : Pat<(and (nxv2i64 ZPR:$Zs1), (xor (nxv2i64 ZPR:$Zs2), + (nxv2i64 (AArch64dup (i64 -1))))), + (BIC_ZZZ ZPR:$Zs1, ZPR:$Zs2)>; + + + multiclass sve_int_logical_immediates { + def : Pat<(nxv16i8 (op (nxv16i8 ZPR:$Zs1), + (nxv16i8 (AArch64dup (i32 (SVELogicalImm8 i64:$imm)))))), + (Inst ZPR:$Zs1, i64:$imm)>; + def : Pat<(nxv8i16 (op (nxv8i16 ZPR:$Zs1), + (nxv8i16 (AArch64dup (i32 (SVELogicalImm16 i64:$imm)))))), + (Inst ZPR:$Zs1, i64:$imm)>; + def : Pat<(nxv4i32 (op (nxv4i32 ZPR:$Zs1), + (nxv4i32 (AArch64dup (i32 (SVELogicalImm32 i64:$imm)))))), + (Inst ZPR:$Zs1, i64:$imm)>; + def : Pat<(nxv2i64 (op (nxv2i64 ZPR:$Zs1), + (nxv2i64 (AArch64dup (i64 (SVELogicalImm64 i64:$imm)))))), + (Inst ZPR:$Zs1, i64:$imm)>; + } + + defm : sve_int_logical_immediates; + defm : sve_int_logical_immediates; + defm : sve_int_logical_immediates; + + def : Pat<(sext_inreg (nxv2i64 ZPR:$Zs), nxv2i32), (SXTW_ZPmZ_D (IMPLICIT_DEF), (PTRUE_D 31), ZPR:$Zs)>; + def : Pat<(sext_inreg (nxv2i64 ZPR:$Zs), nxv2i16), (SXTH_ZPmZ_D (IMPLICIT_DEF), (PTRUE_D 31), ZPR:$Zs)>; + def : Pat<(sext_inreg (nxv2i64 ZPR:$Zs), nxv2i8), (SXTB_ZPmZ_D (IMPLICIT_DEF), (PTRUE_D 31), ZPR:$Zs)>; + def : Pat<(sext_inreg (nxv4i32 ZPR:$Zs), nxv4i16), (SXTH_ZPmZ_S (IMPLICIT_DEF), (PTRUE_S 31), ZPR:$Zs)>; + def : Pat<(sext_inreg (nxv4i32 ZPR:$Zs), nxv4i8), (SXTB_ZPmZ_S (IMPLICIT_DEF), (PTRUE_S 31), ZPR:$Zs)>; + def : Pat<(sext_inreg (nxv8i16 ZPR:$Zs), nxv8i8), (SXTB_ZPmZ_H (IMPLICIT_DEF), (PTRUE_H 31), ZPR:$Zs)>; + + // zext_inreg - 8,16,32 -> 64 // NOTE: covered by general AND of immediate + def : Pat<(and (nxv2i64 ZPR:$Zs), (nxv2i64 (AArch64dup (i64 logical_imm64:$imms1)))), + (AND_ZI ZPR:$Zs, logical_imm64:$imms1)>; + + // zext_inreg - 16 -> 32 + def : Pat<(and (nxv4i32 ZPR:$Zs), (nxv4i32 (AArch64dup (i32 0xFFFF)))), + (AND_ZI ZPR:$Zs, (logical_imm64_XFORM (i64 0x0000FFFF0000FFFF)))>; + + // zext_inreg - 8 -> 32 + def : Pat<(and (nxv4i32 ZPR:$Zs), (nxv4i32 (AArch64dup (i32 0xFF)))), + (AND_ZI ZPR:$Zs, (logical_imm64_XFORM (i64 0x000000FF000000FF)))>; + + // zext_inreg - 8 -> 16 + def : Pat<(and (nxv8i16 ZPR:$Zs), (nxv8i16 (AArch64dup (i32 0xFF)))), + (AND_ZI ZPR:$Zs, (logical_imm64_XFORM (i64 0x00FF00FF00FF00FF)))>; + + + class AddSubFoldPattern + : Pat<(op (Zty ZPR:$Za), + (vselect (PTy PPR:$gp), (Zty ZPR:$Zb), (SVEDup0))), + (!cast(I # "_ZPmZ_" # size) PPR:$gp, ZPR:$Za, ZPR:$Zb)>; + multiclass add_sub_folds { + def : AddSubFoldPattern; + def : AddSubFoldPattern; + def : AddSubFoldPattern; + def : AddSubFoldPattern; + } + defm : add_sub_folds; + defm : add_sub_folds; + + // min and max fold like this: + // select($gp AND ($a > $b), $a, $b) => max($gp, $b, $a) + class MinMaxFoldPattern1 + : Pat<(ZTy (vselect (and (PTy PPR:$gp), + (setcc (ZTy ZPR:$Zs1), (ZTy ZPR:$Zs2), Op)), + (ZTy ZPR:$Zs1), (ZTy ZPR:$Zs2))), + (!cast(I # "_ZPmZ_" # size) PPR:$gp, ZPR:$Zs2, ZPR:$Zs1)>; + // Same as above but for inverted operands and condition code. 
+ class MinMaxFoldPattern2 + : Pat<(ZTy (vselect (and (PTy PPR:$gp), + (setcc (ZTy ZPR:$Zs1), (ZTy ZPR:$Zs2), Op)), + (ZTy ZPR:$Zs2), (ZTy ZPR:$Zs1))), + (!cast(I # "_ZPmZ_" # size) PPR:$gp, ZPR:$Zs1, ZPR:$Zs2)>; + multiclass min_max_folds { + def : MinMaxFoldPattern1; + def : MinMaxFoldPattern1; + def : MinMaxFoldPattern1; + def : MinMaxFoldPattern1; + def : MinMaxFoldPattern2; + def : MinMaxFoldPattern2; + def : MinMaxFoldPattern2; + def : MinMaxFoldPattern2; + } + defm : min_max_folds; + defm : min_max_folds; + defm : min_max_folds; + defm : min_max_folds; + + class MLASPattern + : Pat<(ZTy (op (ZTy ZPR:$Zs3), (mul (ZTy ZPR:$Zs1), (ZTy ZPR:$Zs2)))), + (!cast(I # size) + (!cast("PTRUE_" ## size) 31), ZPR:$Zs3, ZPR:$Zs1, ZPR:$Zs2)>; + multiclass mlas_folds { + def : MLASPattern; + def : MLASPattern; + def : MLASPattern; + def : MLASPattern; + } + defm : mlas_folds; + defm : mlas_folds; + + multiclass unpred_from_pred_two_op { + def : Pat<(nxv16i8 (N (nxv16i8 ZPR:$Zs1), (nxv16i8 ZPR:$Zs2))), + (!cast(I # "_B") (PTRUE_B 31), ZPR:$Zs1, ZPR:$Zs2)>; + def : Pat<(nxv8i16 (N (nxv8i16 ZPR:$Zs1), (nxv8i16 ZPR:$Zs2))), + (!cast(I # "_H") (PTRUE_H 31), ZPR:$Zs1, ZPR:$Zs2)>; + def : Pat<(nxv4i32 (N (nxv4i32 ZPR:$Zs1), (nxv4i32 ZPR:$Zs2))), + (!cast(I # "_S") (PTRUE_S 31), ZPR:$Zs1, ZPR:$Zs2)>; + def : Pat<(nxv2i64 (N (nxv2i64 ZPR:$Zs1), (nxv2i64 ZPR:$Zs2))), + (!cast(I # "_D") (PTRUE_D 31), ZPR:$Zs1, ZPR:$Zs2)>; + } + defm : unpred_from_pred_two_op; + defm : unpred_from_pred_two_op; + defm : unpred_from_pred_two_op; + defm : unpred_from_pred_two_op; + defm : unpred_from_pred_two_op; + defm : unpred_from_pred_two_op; + defm : unpred_from_pred_two_op; + defm : unpred_from_pred_two_op; + defm : unpred_from_pred_two_op; + + multiclass unpred_from_pred_two_op_fp { + def : Pat<(nxv8f16 (N (nxv8f16 ZPR:$Zs1), (nxv8f16 ZPR:$Zs2))), + (!cast(I # "_H") (PTRUE_H 31), ZPR:$Zs1, ZPR:$Zs2)>; + def : Pat<(nxv4f32 (N (nxv4f32 ZPR:$Zs1), (nxv4f32 ZPR:$Zs2))), + (!cast(I # "_S") (PTRUE_S 31), ZPR:$Zs1, ZPR:$Zs2)>; + def : Pat<(nxv2f64 (N (nxv2f64 ZPR:$Zs1), (nxv2f64 ZPR:$Zs2))), + (!cast(I # "_D") (PTRUE_D 31), ZPR:$Zs1, ZPR:$Zs2)>; + } + defm : unpred_from_pred_two_op_fp; + defm : unpred_from_pred_two_op_fp; + defm : unpred_from_pred_two_op_fp; + defm : unpred_from_pred_two_op_fp; + + multiclass unpred_from_pred_two_op_div { + def : Pat<(nxv4i32 (N (nxv4i32 ZPR:$Zs1), (nxv4i32 ZPR:$Zs2))), + (!cast(I # "_S") (PTRUE_S 31), ZPR:$Zs1, ZPR:$Zs2)>; + def : Pat<(nxv2i64 (N (nxv2i64 ZPR:$Zs1), (nxv2i64 ZPR:$Zs2))), + (!cast(I # "_D") (PTRUE_D 31), ZPR:$Zs1, ZPR:$Zs2)>; + } + defm : unpred_from_pred_two_op_div; + defm : unpred_from_pred_two_op_div; + + foreach type = ["nxv16i8", "nxv8i16", "nxv4i32", "nxv2i64"] in { + def : Pat<(and (!cast(type) ZPR:$src1), + (!cast(type) ZPR:$src2)), + (AND_ZZZ ZPR:$src1, ZPR:$src2)>; + def : Pat<(or (!cast(type) ZPR:$src1), + (!cast(type) ZPR:$src2)), + (ORR_ZZZ ZPR:$src1, ZPR:$src2)>; + def : Pat<(xor (!cast(type) ZPR:$src1), + (!cast(type) ZPR:$src2)), + (EOR_ZZZ ZPR:$src1, ZPR:$src2)>; + } + + // LDR1 of 8-bit data + def : LD1RPat; + def : LD1RPat; + def : LD1RPat; + def : LD1RPat; + def : LD1RPat; + def : LD1RPat; + def : LD1RPat; + + // LDR1 of 16-bit data + def : LD1RPat; + def : LD1RPat; + def : LD1RPat; + def : LD1RPat; + def : LD1RPat; + + // LDR1 of 32-bit data + def : LD1RPat; + def : LD1RPat; + def : LD1RPat; + + // LDR1 of 64-bit data + def : LD1RPat; + + // LD1R of FP data + def : Pat<(nxv4f32 (AArch64dup (f32 (load (am_indexed32_6b GPR64:$Rn, uimm6s4:$offset))))), + 
(LD1RW_IMM (PTRUE_S 31), GPR64:$Rn, $offset)>; + def : Pat<(nxv2f32 (AArch64dup (f32 (load (am_indexed32_6b GPR64:$Rn, uimm6s4:$offset))))), + (LD1RW_D_IMM (PTRUE_D 31), GPR64:$Rn, $offset)>; + def : Pat<(nxv2f64 (AArch64dup (f64 (load (am_indexed64_6b GPR64:$Rn, uimm6s8:$offset))))), + (LD1RD_IMM (PTRUE_D 31), GPR64:$Rn, $offset)>; + + // LD1R of 128-bit masked data + def : Pat<(nxv16i8 (AArch64ld1rq PPR:$gp, GPR64:$base)), + (LD1RQ_B_IMM $gp, $base, (i64 0))>; + def : Pat<(nxv8i16 (AArch64ld1rq PPR:$gp, GPR64:$base)), + (LD1RQ_H_IMM $gp, $base, (i64 0))>; + def : Pat<(nxv4i32 (AArch64ld1rq PPR:$gp, GPR64:$base)), + (LD1RQ_W_IMM $gp, $base, (i64 0))>; + def : Pat<(nxv2i64 (AArch64ld1rq PPR:$gp, GPR64:$base)), + (LD1RQ_D_IMM $gp, $base, (i64 0))>; + + def : Pat<(nxv16i8 (AArch64ld1rq PPR:$gp, (add GPR64:$base, (i64 simm4s16:$imm)))), + (LD1RQ_B_IMM $gp, $base, simm4s16:$imm)>; + def : Pat<(nxv8i16 (AArch64ld1rq PPR:$gp, (add GPR64:$base, (i64 simm4s16:$imm)))), + (LD1RQ_H_IMM $gp, $base, simm4s16:$imm)>; + def : Pat<(nxv4i32 (AArch64ld1rq PPR:$gp, (add GPR64:$base, (i64 simm4s16:$imm)))), + (LD1RQ_W_IMM $gp, $base, simm4s16:$imm)>; + def : Pat<(nxv2i64 (AArch64ld1rq PPR:$gp, (add GPR64:$base, (i64 simm4s16:$imm)))), + (LD1RQ_D_IMM $gp, $base, simm4s16:$imm)>; + + defm : SVETruncStore; + defm : SVETruncStore; + defm : SVETruncStore; + defm : SVETruncStore; + defm : SVETruncStore; + defm : SVETruncStore; + + def : Pat<(nxv2i64 (AArch64dup (i64 logical_imm64:$imms1))), + (DUPM_ZI logical_imm64:$imms1)>; + + def : Pat<(i64 (add GPR64:$Xd, (int_ctvpop (nxv16i1 PPR:$Ps)))), + (INCP_XP_B $Ps, $Xd)>; + + // General case that we ideally never want to match. + def : Pat<(vscale GPR64:$scale), (MADDXrrr (UBFMXri (RDVLI_XI 1), 4, 63), $scale, XZR)>; + + let AddedComplexity = 5 in { + def : Pat<(vscale (sve_rdvl_imm i32:$imm)), (RDVLI_XI $imm)>; + def : Pat<(vscale (sve_cnth_imm i32:$imm)), (CNTH_XPiI 31, $imm)>; + def : Pat<(vscale (sve_cntw_imm i32:$imm)), (CNTW_XPiI 31, $imm)>; + def : Pat<(vscale (sve_cntd_imm i32:$imm)), (CNTD_XPiI 31, $imm)>; + + def : Pat<(vscale (sve_cnth_imm_neg i32:$imm)), (SUBXrs XZR, (CNTH_XPiI 31, $imm), 0)>; + def : Pat<(vscale (sve_cntw_imm_neg i32:$imm)), (SUBXrs XZR, (CNTW_XPiI 31, $imm), 0)>; + def : Pat<(vscale (sve_cntd_imm_neg i32:$imm)), (SUBXrs XZR, (CNTD_XPiI 31, $imm), 0)>; + + def : Pat<(add GPR64:$op, (vscale (sve_rdvl_imm i32:$imm))), + (ADDVL_XXI GPR64:$op, $imm)>; + def : Pat<(add GPR64:$op, (vscale (sve_cnth_imm i32:$imm))), + (INCH_XPiI GPR64:$op, 31, $imm)>; + def : Pat<(add GPR64:$op, (vscale (sve_cntw_imm i32:$imm))), + (INCW_XPiI GPR64:$op, 31, $imm)>; + def : Pat<(add GPR64:$op, (vscale (sve_cntd_imm i32:$imm))), + (INCD_XPiI GPR64:$op, 31, $imm)>; + + def : Pat<(add GPR64:$op, (vscale (sve_cnth_imm_neg i32:$imm))), + (DECH_XPiI GPR64:$op, 31, $imm)>; + def : Pat<(add GPR64:$op, (vscale (sve_cntw_imm_neg i32:$imm))), + (DECW_XPiI GPR64:$op, 31, $imm)>; + def : Pat<(add GPR64:$op, (vscale (sve_cntd_imm_neg i32:$imm))), + (DECD_XPiI GPR64:$op, 31, $imm)>; + + def : Pat<(add GPR32:$op, (i32 (trunc (vscale (sve_rdvl_imm i32:$imm))))), + (i32 (EXTRACT_SUBREG (ADDVL_XXI (INSERT_SUBREG (i64 (IMPLICIT_DEF)), + GPR32:$op, sub_32), $imm), + sub_32))>; + def : Pat<(add GPR32:$op, (i32 (trunc (vscale (sve_cnth_imm i32:$imm))))), + (i32 (EXTRACT_SUBREG (INCH_XPiI (INSERT_SUBREG (i64 (IMPLICIT_DEF)), + GPR32:$op, sub_32), 31, $imm), + sub_32))>; + def : Pat<(add GPR32:$op, (i32 (trunc (vscale (sve_cntw_imm i32:$imm))))), + (i32 (EXTRACT_SUBREG (INCW_XPiI 
(INSERT_SUBREG (i64 (IMPLICIT_DEF)), + GPR32:$op, sub_32), 31, $imm), + sub_32))>; + def : Pat<(add GPR32:$op, (i32 (trunc (vscale (sve_cntd_imm i32:$imm))))), + (i32 (EXTRACT_SUBREG (INCD_XPiI (INSERT_SUBREG (i64 (IMPLICIT_DEF)), + GPR32:$op, sub_32), 31, $imm), + sub_32))>; + + def : Pat<(nxv8i16 (add ZPR:$op, (nxv8i16 (AArch64dup (i32 (trunc (vscale (sve_cnth_imm i32:$imm)))))))), + (INCH_ZPiI ZPR:$op, 31, $imm)>; + def : Pat<(nxv4i32 (add ZPR:$op, (nxv4i32 (AArch64dup (i32 (trunc (vscale (sve_cntw_imm i32:$imm)))))))), + (INCW_ZPiI ZPR:$op, 31, $imm)>; + def : Pat<(nxv2i64 (add ZPR:$op, (nxv2i64 (AArch64dup (i64 (vscale (sve_cntd_imm i32:$imm))))))), + (INCD_ZPiI ZPR:$op, 31, $imm)>; + + def : Pat<(nxv8i16 (sub ZPR:$op, (nxv8i16 (AArch64dup (i32 (trunc (vscale (sve_cnth_imm i32:$imm)))))))), + (DECH_ZPiI ZPR:$op, 31, $imm)>; + def : Pat<(nxv4i32 (sub ZPR:$op, (nxv4i32 (AArch64dup (i32 (trunc (vscale (sve_cntw_imm i32:$imm)))))))), + (DECW_ZPiI ZPR:$op, 31, $imm)>; + def : Pat<(nxv2i64 (sub ZPR:$op, (nxv2i64 (AArch64dup (i64 (vscale (sve_cntd_imm i32:$imm))))))), + (DECD_ZPiI ZPR:$op, 31, $imm)>; + } + + def : Pat<(nxv8i1 (AArch64rdffr_pred (nxv8i1 PPR:$Pg))), (RDFFR_PPz PPR:$Pg)>; + def : Pat<(nxv4i1 (AArch64rdffr_pred (nxv4i1 PPR:$Pg))), (RDFFR_PPz PPR:$Pg)>; + def : Pat<(nxv2i1 (AArch64rdffr_pred (nxv2i1 PPR:$Pg))), (RDFFR_PPz PPR:$Pg)>; + + multiclass unpred_from_pred_one_op { + def : Pat<(nxv16i8 (N (nxv16i8 ZPR:$Zs))), + (!cast(I # "_B") (IMPLICIT_DEF), (PTRUE_B 31), ZPR:$Zs)>; + def : Pat<(nxv8i16 (N (nxv8i16 ZPR:$Zs))), + (!cast(I # "_H") (IMPLICIT_DEF), (PTRUE_H 31), ZPR:$Zs)>; + def : Pat<(nxv4i32 (N (nxv4i32 ZPR:$Zs))), + (!cast(I # "_S") (IMPLICIT_DEF), (PTRUE_S 31), ZPR:$Zs)>; + def : Pat<(nxv2i64 (N (nxv2i64 ZPR:$Zs))), + (!cast(I # "_D") (IMPLICIT_DEF), (PTRUE_D 31), ZPR:$Zs)>; + } + + defm : unpred_from_pred_one_op; + defm : unpred_from_pred_one_op; + + def : Pat<(nxv16i8 (cttz (nxv16i8 ZPR:$Zs))), + (CLZ_ZPmZ_B (IMPLICIT_DEF), (PTRUE_B 31), (RBIT_ZPmZ_B (IMPLICIT_DEF), (PTRUE_B 31), ZPR:$Zs))>; + def : Pat<(nxv8i16 (cttz (nxv8i16 ZPR:$Zs))), + (CLZ_ZPmZ_H (IMPLICIT_DEF), (PTRUE_H 31), (RBIT_ZPmZ_H (IMPLICIT_DEF), (PTRUE_H 31), ZPR:$Zs))>; + def : Pat<(nxv4i32 (cttz (nxv4i32 ZPR:$Zs))), + (CLZ_ZPmZ_S (IMPLICIT_DEF), (PTRUE_S 31), (RBIT_ZPmZ_S (IMPLICIT_DEF), (PTRUE_S 31), ZPR:$Zs))>; + def : Pat<(nxv2i64 (cttz (nxv2i64 ZPR:$Zs))), + (CLZ_ZPmZ_D (IMPLICIT_DEF), (PTRUE_D 31), (RBIT_ZPmZ_D (IMPLICIT_DEF), (PTRUE_D 31), ZPR:$Zs))>; + + def : SVE_3_Op_Pat; + def : SVE_3_Op_Pat; + def : SVE_3_Op_Pat; + + def : Pat<(nxv4f32 (AArch64frecps (nxv4f32 ZPR:$Zs1), (nxv4f32 ZPR:$Zs2))), + (FRECPS_ZZZ_S ZPR:$Zs1, ZPR:$Zs2)>; + def : Pat<(nxv2f64 (AArch64frecps (nxv2f64 ZPR:$Zs1), (nxv2f64 ZPR:$Zs2))), + (FRECPS_ZZZ_D ZPR:$Zs1, ZPR:$Zs2)>; + def : Pat<(nxv4f32 (AArch64frsqrts (nxv4f32 ZPR:$Zs1), (nxv4f32 ZPR:$Zs2))), + (FRSQRTS_ZZZ_S ZPR:$Zs1, ZPR:$Zs2)>; + def : Pat<(nxv2f64 (AArch64frsqrts (nxv2f64 ZPR:$Zs1), (nxv2f64 ZPR:$Zs2))), + (FRSQRTS_ZZZ_D ZPR:$Zs1, ZPR:$Zs2)>; + + multiclass unpred_from_pred_2op_destructive_fp { + def : Pat<(nxv8f16 (N (nxv8f16 ZPR:$Zs1), (nxv8f16 ZPR:$Zs2))), + (!cast(I # "_H") (PTRUE_H 31), ZPR:$Zs1, ZPR:$Zs2)>; + def : Pat<(nxv2f32 (N (nxv2f32 ZPR:$Zs1), (nxv2f32 ZPR:$Zs2))), + (!cast(I # "_S") (PTRUE_D 31), ZPR:$Zs1, ZPR:$Zs2)>; + def : Pat<(nxv4f32 (N (nxv4f32 ZPR:$Zs1), (nxv4f32 ZPR:$Zs2))), + (!cast(I # "_S") (PTRUE_S 31), ZPR:$Zs1, ZPR:$Zs2)>; + def : Pat<(nxv2f64 (N (nxv2f64 ZPR:$Zs1), (nxv2f64 ZPR:$Zs2))), + (!cast(I # "_D") (PTRUE_D 31), ZPR:$Zs1, 
ZPR:$Zs2)>; + } + + defm : unpred_from_pred_2op_destructive_fp; + + def : SVE_1_Op_Pat; + def : SVE_1_Op_Pat; + def : SVE_1_Op_Pat; + + def : Pat<(nxv4f32 (AArch64frecpe (nxv4f32 ZPR:$Zs))), + (FRECPE_ZZ_S ZPR:$Zs)>; + def : Pat<(nxv2f64 (AArch64frecpe (nxv2f64 ZPR:$Zs))), + (FRECPE_ZZ_D ZPR:$Zs)>; + def : Pat<(nxv4f32 (AArch64frsqrte (nxv4f32 ZPR:$Zs))), + (FRSQRTE_ZZ_S ZPR:$Zs)>; + def : Pat<(nxv2f64 (AArch64frsqrte (nxv2f64 ZPR:$Zs))), + (FRSQRTE_ZZ_D ZPR:$Zs)>; + + multiclass unpred_from_pred_one_op_fp { + def : Pat<(nxv8f16 (N (nxv8f16 ZPR:$Zs))), + (!cast(I # "_H") (IMPLICIT_DEF), (PTRUE_H 31), ZPR:$Zs)>; + def : Pat<(nxv2f32 (N (nxv2f32 ZPR:$Zs))), + (!cast(I # "_S") (IMPLICIT_DEF), (PTRUE_D 31), ZPR:$Zs)>; + def : Pat<(nxv4f32 (N (nxv4f32 ZPR:$Zs))), + (!cast(I # "_S") (IMPLICIT_DEF), (PTRUE_S 31), ZPR:$Zs)>; + def : Pat<(nxv2f64 (N (nxv2f64 ZPR:$Zs))), + (!cast(I # "_D") (IMPLICIT_DEF), (PTRUE_D 31), ZPR:$Zs)>; + } + defm : unpred_from_pred_one_op_fp; + defm : unpred_from_pred_one_op_fp; + defm : unpred_from_pred_one_op_fp; + defm : unpred_from_pred_one_op_fp; + defm : unpred_from_pred_one_op_fp; + defm : unpred_from_pred_one_op_fp; + defm : unpred_from_pred_one_op_fp; + defm : unpred_from_pred_one_op_fp; + defm : unpred_from_pred_one_op_fp; + + // Use the standard packed operations on our unpacked types + // TODO: if/when we care about FP exceptions these must use predication + def : Pat<(nxv2f16 (fadd (nxv2f16 ZPR:$Zs1), (nxv2f16 ZPR:$Zs2))), + (FADD_ZZZ_H ZPR:$Zs1, ZPR:$Zs2)>; + + def : Pat<(nxv2f32 (fadd (nxv2f32 ZPR:$Zs1), (nxv2f32 ZPR:$Zs2))), + (FADD_ZZZ_S ZPR:$Zs1, ZPR:$Zs2)>; + def : Pat<(nxv2f32 (fsub (nxv2f32 ZPR:$Zs1), (nxv2f32 ZPR:$Zs2))), + (FSUB_ZZZ_S ZPR:$Zs1, ZPR:$Zs2)>; + def : Pat<(nxv2f32 (fmul (nxv2f32 ZPR:$Zs1), (nxv2f32 ZPR:$Zs2))), + (FMUL_ZZZ_S ZPR:$Zs1, ZPR:$Zs2)>; + + // Zd = Za + Zn * Zm + def : Pat<(nxv8f16 (fma ZPR:$Zn, ZPR:$Zm, ZPR:$Za)), + (FMLA_ZPZZZ_UNDEF_H (PTRUE_H 31), ZPR:$Za, ZPR:$Zn, ZPR:$Zm)>; + def : Pat<(nxv4f32 (fma ZPR:$Zn, ZPR:$Zm, ZPR:$Za)), + (FMLA_ZPZZZ_UNDEF_S (PTRUE_S 31), ZPR:$Za, ZPR:$Zn, ZPR:$Zm)>; + def : Pat<(nxv2f32 (fma ZPR:$Zn, ZPR:$Zm, ZPR:$Za)), + (FMLA_ZPZZZ_UNDEF_S (PTRUE_D 31), ZPR:$Za, ZPR:$Zn, ZPR:$Zm)>; + def : Pat<(nxv2f64 (fma ZPR:$Zn, ZPR:$Zm, ZPR:$Za)), + (FMLA_ZPZZZ_UNDEF_D (PTRUE_D 31), ZPR:$Za, ZPR:$Zn, ZPR:$Zm)>; + + // Zd = Za + -Zn * Zm + def : Pat<(nxv8f16 (fma (fneg ZPR:$Zn), ZPR:$Zm, ZPR:$Za)), + (FMLS_ZPZZZ_UNDEF_H (PTRUE_H 31), ZPR:$Za, ZPR:$Zn, ZPR:$Zm)>; + def : Pat<(nxv4f32 (fma (fneg ZPR:$Zn), ZPR:$Zm, ZPR:$Za)), + (FMLS_ZPZZZ_UNDEF_S (PTRUE_S 31), ZPR:$Za, ZPR:$Zn, ZPR:$Zm)>; + def : Pat<(nxv2f32 (fma (fneg ZPR:$Zn), ZPR:$Zm, ZPR:$Za)), + (FMLS_ZPZZZ_UNDEF_S (PTRUE_D 31), ZPR:$Za, ZPR:$Zn, ZPR:$Zm)>; + def : Pat<(nxv2f64 (fma (fneg ZPR:$Zn), ZPR:$Zm, ZPR:$Za)), + (FMLS_ZPZZZ_UNDEF_D (PTRUE_D 31), ZPR:$Za, ZPR:$Zn, ZPR:$Zm)>; + + // Zd = -Za + Zn * Zm + def : Pat<(nxv8f16 (fma ZPR:$Zn, ZPR:$Zm, (fneg ZPR:$Za))), + (FNMLS_ZPZZZ_UNDEF_H (PTRUE_H 31), ZPR:$Za, ZPR:$Zn, ZPR:$Zm)>; + def : Pat<(nxv4f32 (fma ZPR:$Zn, ZPR:$Zm, (fneg ZPR:$Za))), + (FNMLS_ZPZZZ_UNDEF_S (PTRUE_S 31), ZPR:$Za, ZPR:$Zn, ZPR:$Zm)>; + def : Pat<(nxv2f32 (fma ZPR:$Zn, ZPR:$Zm, (fneg ZPR:$Za))), + (FNMLS_ZPZZZ_UNDEF_S (PTRUE_D 31), ZPR:$Za, ZPR:$Zn, ZPR:$Zm)>; + def : Pat<(nxv2f64 (fma ZPR:$Zn, ZPR:$Zm, (fneg ZPR:$Za))), + (FNMLS_ZPZZZ_UNDEF_D (PTRUE_D 31), ZPR:$Za, ZPR:$Zn, ZPR:$Zm)>; + + // Zd = -Za + -Zn * Zm + def : Pat<(nxv8f16 (fma (fneg ZPR:$Zn), ZPR:$Zm, (fneg ZPR:$Za))), + (FNMLA_ZPZZZ_UNDEF_H (PTRUE_H 31), ZPR:$Za, 
ZPR:$Zn, ZPR:$Zm)>; + def : Pat<(nxv4f32 (fma (fneg ZPR:$Zn), ZPR:$Zm, (fneg ZPR:$Za))), + (FNMLA_ZPZZZ_UNDEF_S (PTRUE_S 31), ZPR:$Za, ZPR:$Zn, ZPR:$Zm)>; + def : Pat<(nxv2f32 (fma (fneg ZPR:$Zn), ZPR:$Zm, (fneg ZPR:$Za))), + (FNMLA_ZPZZZ_UNDEF_S (PTRUE_D 31), ZPR:$Za, ZPR:$Zn, ZPR:$Zm)>; + def : Pat<(nxv2f64 (fma (fneg ZPR:$Zn), ZPR:$Zm, (fneg ZPR:$Za))), + (FNMLA_ZPZZZ_UNDEF_D (PTRUE_D 31), ZPR:$Za, ZPR:$Zn, ZPR:$Zm)>; + + // Zda = Zda + Zn * Zm + def : Pat<(vselect (nxv8i1 PPR:$Pg), (nxv8f16 (fma ZPR:$Zn, ZPR:$Zm, ZPR:$Za)), ZPR:$Za), + (FMLA_ZPmZZ_H PPR:$Pg, ZPR:$Za, ZPR:$Zn, ZPR:$Zm)>; + def : Pat<(vselect (nxv4i1 PPR:$Pg), (nxv4f32 (fma ZPR:$Zn, ZPR:$Zm, ZPR:$Za)), ZPR:$Za), + (FMLA_ZPmZZ_S PPR:$Pg, ZPR:$Za, ZPR:$Zn, ZPR:$Zm)>; + def : Pat<(vselect (nxv2i1 PPR:$Pg), (nxv2f32 (fma ZPR:$Zn, ZPR:$Zm, ZPR:$Za)), ZPR:$Za), + (FMLA_ZPmZZ_S PPR:$Pg, ZPR:$Za, ZPR:$Zn, ZPR:$Zm)>; + def : Pat<(vselect (nxv2i1 PPR:$Pg), (nxv2f64 (fma ZPR:$Zn, ZPR:$Zm, ZPR:$Za)), ZPR:$Za), + (FMLA_ZPmZZ_D PPR:$Pg, ZPR:$Za, ZPR:$Zn, ZPR:$Zm)>; + + // Zda = Zda + -Zn * Zm + def : Pat<(vselect (nxv8i1 PPR:$Pg), (nxv8f16 (fma (fneg ZPR:$Zn), ZPR:$Zm, ZPR:$Za)), ZPR:$Za), + (FMLS_ZPmZZ_H PPR:$Pg, ZPR:$Za, ZPR:$Zn, ZPR:$Zm)>; + def : Pat<(vselect (nxv4i1 PPR:$Pg), (nxv4f32 (fma (fneg ZPR:$Zn), ZPR:$Zm, ZPR:$Za)), ZPR:$Za), + (FMLS_ZPmZZ_S PPR:$Pg, ZPR:$Za, ZPR:$Zn, ZPR:$Zm)>; + def : Pat<(vselect (nxv2i1 PPR:$Pg), (nxv2f32 (fma (fneg ZPR:$Zn), ZPR:$Zm, ZPR:$Za)), ZPR:$Za), + (FMLS_ZPmZZ_S PPR:$Pg, ZPR:$Za, ZPR:$Zn, ZPR:$Zm)>; + def : Pat<(vselect (nxv2i1 PPR:$Pg), (nxv2f64 (fma (fneg ZPR:$Zn), ZPR:$Zm, ZPR:$Za)), ZPR:$Za), + (FMLS_ZPmZZ_D PPR:$Pg, ZPR:$Za, ZPR:$Zn, ZPR:$Zm)>; + + // as above but with the resulting compare operands switched + multiclass SVEInvCmpPat { + def : Pat<(nxv16i8 (setcc (nxv16i8 ZPR:$Zs1), (nxv16i8 ZPR:$Zs2), IN_CMP)), + (CPY_ZPzI_B (!cast(!strconcat(OUT_CMP, "_B")) + (PTRUE_B 31), ZPR:$Zs2, ZPR:$Zs1), -1, 0)>; + + def : Pat<(nxv8i16 (setcc (nxv8i16 ZPR:$Zs1), (nxv8i16 ZPR:$Zs2), IN_CMP)), + (CPY_ZPzI_H (!cast(!strconcat(OUT_CMP, "_H")) + (PTRUE_H 31), ZPR:$Zs2, ZPR:$Zs1), -1, 0)>; + + def : Pat<(nxv4i32 (setcc (nxv4i32 ZPR:$Zs1), (nxv4i32 ZPR:$Zs2), IN_CMP)), + (CPY_ZPzI_S (!cast(!strconcat(OUT_CMP, "_S")) + (PTRUE_S 31), ZPR:$Zs2, ZPR:$Zs1), -1, 0)>; + + def : Pat<(nxv2i64 (setcc (nxv2i64 ZPR:$Zs1), (nxv2i64 ZPR:$Zs2), IN_CMP)), + (CPY_ZPzI_D (!cast(!strconcat(OUT_CMP, "_D")) + (PTRUE_D 31), ZPR:$Zs2, ZPR:$Zs1), -1, 0)>; + + def : Pat<(nxv16i1 (setcc (nxv16i8 ZPR:$Zs1), (nxv16i8 ZPR:$Zs2), IN_CMP)), + (!cast(!strconcat(OUT_CMP, "_B")) + (PTRUE_B 31), ZPR:$Zs2, ZPR:$Zs1)>; + + def : Pat<(nxv8i1 (setcc (nxv8i16 ZPR:$Zs1), (nxv8i16 ZPR:$Zs2), IN_CMP)), + (!cast(!strconcat(OUT_CMP, "_H")) + (PTRUE_H 31), ZPR:$Zs2, ZPR:$Zs1)>; + + def : Pat<(nxv4i1 (setcc (nxv4i32 ZPR:$Zs1), (nxv4i32 ZPR:$Zs2), IN_CMP)), + (!cast(!strconcat(OUT_CMP, "_S")) + (PTRUE_S 31), ZPR:$Zs2, ZPR:$Zs1)>; + + def : Pat<(nxv2i1 (setcc (nxv2i64 ZPR:$Zs1), (nxv2i64 ZPR:$Zs2), IN_CMP)), + (!cast(!strconcat(OUT_CMP, "_D")) + (PTRUE_D 31), ZPR:$Zs2, ZPR:$Zs1)>; + } + + // FUTURE: find out why this happens and stop it? 
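+ // Until then: a vselect whose condition arrives as a data vector rather than
+ // a predicate is handled by comparing that vector against zero (CMPNE under
+ // an all-true PTRUE) to recover a governing predicate for SEL.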
+ def : Pat<(nxv16i8 (vselect (nxv16i8 ZPR:$P), ZPR:$Zs1, ZPR:$Zs2)), + (SEL_ZPZZ_B (CMPNE_PPzZI_B (PTRUE_B 31), ZPR:$P, 0), ZPR:$Zs1, ZPR:$Zs2)>; + def : Pat<(nxv8i16 (vselect (nxv8i16 ZPR:$P), ZPR:$Zs1, ZPR:$Zs2)), + (SEL_ZPZZ_H (CMPNE_PPzZI_H (PTRUE_H 31), ZPR:$P, 0), ZPR:$Zs1, ZPR:$Zs2)>; + def : Pat<(nxv4i32 (vselect (nxv4i32 ZPR:$P), ZPR:$Zs1, ZPR:$Zs2)), + (SEL_ZPZZ_S (CMPNE_PPzZI_S (PTRUE_S 31), ZPR:$P, 0), ZPR:$Zs1, ZPR:$Zs2)>; + def : Pat<(nxv2i64 (vselect (nxv2i64 ZPR:$P), ZPR:$Zs1, ZPR:$Zs2)), + (SEL_ZPZZ_D (CMPNE_PPzZI_D (PTRUE_D 31), ZPR:$P, 0), ZPR:$Zs1, ZPR:$Zs2)>; + + // Whole vector selects. + def : Pat<(nxv16i8 (select GPR32:$cond, ZPR:$Zs1, ZPR:$Zs2)), + (SEL_ZPZZ_S (CMPNE_PPzZI_S (PTRUE_S 31), (DUP_ZR_S $cond), 0), ZPR:$Zs1, ZPR:$Zs2)>; + def : Pat<(nxv8i16 (select GPR32:$cond, ZPR:$Zs1, ZPR:$Zs2)), + (SEL_ZPZZ_S (CMPNE_PPzZI_S (PTRUE_S 31), (DUP_ZR_S $cond), 0), ZPR:$Zs1, ZPR:$Zs2)>; + def : Pat<(nxv4i32 (select GPR32:$cond, ZPR:$Zs1, ZPR:$Zs2)), + (SEL_ZPZZ_S (CMPNE_PPzZI_S (PTRUE_S 31), (DUP_ZR_S $cond), 0), ZPR:$Zs1, ZPR:$Zs2)>; + def : Pat<(nxv2i64 (select GPR32:$cond, ZPR:$Zs1, ZPR:$Zs2)), + (SEL_ZPZZ_S (CMPNE_PPzZI_S (PTRUE_S 31), (DUP_ZR_S $cond), 0), ZPR:$Zs1, ZPR:$Zs2)>; + def : Pat<(nxv8f16 (select GPR32:$cond, ZPR:$Zs1, ZPR:$Zs2)), + (SEL_ZPZZ_S (CMPNE_PPzZI_S (PTRUE_S 31), (DUP_ZR_S $cond), 0), ZPR:$Zs1, ZPR:$Zs2)>; + def : Pat<(nxv2f32 (select GPR32:$cond, ZPR:$Zs1, ZPR:$Zs2)), + (SEL_ZPZZ_S (CMPNE_PPzZI_S (PTRUE_S 31), (DUP_ZR_S $cond), 0), ZPR:$Zs1, ZPR:$Zs2)>; + def : Pat<(nxv4f32 (select GPR32:$cond, ZPR:$Zs1, ZPR:$Zs2)), + (SEL_ZPZZ_S (CMPNE_PPzZI_S (PTRUE_S 31), (DUP_ZR_S $cond), 0), ZPR:$Zs1, ZPR:$Zs2)>; + def : Pat<(nxv2f64 (select GPR32:$cond, ZPR:$Zs1, ZPR:$Zs2)), + (SEL_ZPZZ_S (CMPNE_PPzZI_S (PTRUE_S 31), (DUP_ZR_S $cond), 0), ZPR:$Zs1, ZPR:$Zs2)>; + def : Pat<(nxv16i1 (select GPR32:$cond, PPR:$Ps1, PPR:$Ps2)), + (SEL_PPPP (CMPNE_PPzZI_B (PTRUE_B 31), (DUP_ZR_B $cond), 0), PPR:$Ps1, PPR:$Ps2)>; + def : Pat<(nxv8i1 (select GPR32:$cond, PPR:$Ps1, PPR:$Ps2)), + (SEL_PPPP (CMPNE_PPzZI_B (PTRUE_B 31), (DUP_ZR_B $cond), 0), PPR:$Ps1, PPR:$Ps2)>; + def : Pat<(nxv4i1 (select GPR32:$cond, PPR:$Ps1, PPR:$Ps2)), + (SEL_PPPP (CMPNE_PPzZI_B (PTRUE_B 31), (DUP_ZR_B $cond), 0), PPR:$Ps1, PPR:$Ps2)>; + def : Pat<(nxv2i1 (select GPR32:$cond, PPR:$Ps1, PPR:$Ps2)), + (SEL_PPPP (CMPNE_PPzZI_B (PTRUE_B 31), (DUP_ZR_B $cond), 0), PPR:$Ps1, PPR:$Ps2)>; + + // Extract element from vector with immediate index + def : Pat<(i32 (vector_extract (nxv16i8 ZPR:$vec), sve_elm_idx_extdup_b:$index)), + (EXTRACT_SUBREG (DUP_ZZI_B ZPR:$vec, sve_elm_idx_extdup_b:$index), ssub)>; + def : Pat<(i32 (vector_extract (nxv8i16 ZPR:$vec), sve_elm_idx_extdup_h:$index)), + (EXTRACT_SUBREG (DUP_ZZI_H ZPR:$vec, sve_elm_idx_extdup_h:$index), ssub)>; + def : Pat<(i32 (vector_extract (nxv4i32 ZPR:$vec), sve_elm_idx_extdup_s:$index)), + (EXTRACT_SUBREG (DUP_ZZI_S ZPR:$vec, sve_elm_idx_extdup_s:$index), ssub)>; + def : Pat<(i64 (vector_extract (nxv2i64 ZPR:$vec), sve_elm_idx_extdup_d:$index)), + (EXTRACT_SUBREG (DUP_ZZI_D ZPR:$vec, sve_elm_idx_extdup_d:$index), dsub)>; + def : Pat<(f16 (vector_extract (nxv8f16 ZPR:$vec), sve_elm_idx_extdup_h:$index)), + (EXTRACT_SUBREG (DUP_ZZI_H ZPR:$vec, sve_elm_idx_extdup_h:$index), hsub)>; + def : Pat<(f32 (vector_extract (nxv4f32 ZPR:$vec), sve_elm_idx_extdup_s:$index)), + (EXTRACT_SUBREG (DUP_ZZI_S ZPR:$vec, sve_elm_idx_extdup_s:$index), ssub)>; + def : Pat<(f64 (vector_extract (nxv2f64 ZPR:$vec), sve_elm_idx_extdup_d:$index)), + (EXTRACT_SUBREG (DUP_ZZI_D 
ZPR:$vec, sve_elm_idx_extdup_d:$index), dsub)>; + def : Pat<(f16 (vector_extract (nxv4f16 ZPR:$vec), sve_elm_idx_extdup_s:$index)), + (EXTRACT_SUBREG (DUP_ZZI_S ZPR:$vec, sve_elm_idx_extdup_s:$index), hsub)>; + def : Pat<(f16 (vector_extract (nxv2f16 ZPR:$vec), sve_elm_idx_extdup_d:$index)), + (EXTRACT_SUBREG (DUP_ZZI_D ZPR:$vec, sve_elm_idx_extdup_d:$index), hsub)>; + + // Extract element from vector with scalar index + def : Pat<(i32 (vector_extract (nxv16i8 ZPR:$vec), GPR64:$index)), + (LASTB_RPZ_B (CMPEQ_PPzZZ_B (PTRUE_B 31), + (INDEX_II_B 0, 1), + (DUP_ZR_B (i32 (EXTRACT_SUBREG GPR64:$index, sub_32)))), + ZPR:$vec)>; + def : Pat<(i32 (vector_extract (nxv8i16 ZPR:$vec), GPR64:$index)), + (LASTB_RPZ_H (CMPEQ_PPzZZ_H (PTRUE_H 31), + (INDEX_II_H 0, 1), + (DUP_ZR_H (i32 (EXTRACT_SUBREG GPR64:$index, sub_32)))), + ZPR:$vec)>; + def : Pat<(i32 (vector_extract (nxv4i32 ZPR:$vec), GPR64:$index)), + (LASTB_RPZ_S (CMPEQ_PPzZZ_S (PTRUE_S 31), + (INDEX_II_S 0, 1), + (DUP_ZR_S (i32 (EXTRACT_SUBREG GPR64:$index, sub_32)))), + ZPR:$vec)>; + def : Pat<(i64 (vector_extract (nxv2i64 ZPR:$vec), GPR64:$index)), + (LASTB_RPZ_D (CMPEQ_PPzZZ_D (PTRUE_D 31), + (INDEX_II_D 0, 1), + (DUP_ZR_D GPR64:$index)), + ZPR:$vec)>; + + def : Pat<(f16 (vector_extract (nxv8f16 ZPR:$vec), GPR64:$index)), + (LASTB_VPZ_H (CMPEQ_PPzZZ_H (PTRUE_H 31), + (INDEX_II_H 0, 1), + (DUP_ZR_H (i32 (EXTRACT_SUBREG GPR64:$index, sub_32)))), + ZPR:$vec)>; + def : Pat<(f32 (vector_extract (nxv4f32 ZPR:$vec), GPR64:$index)), + (LASTB_VPZ_S (CMPEQ_PPzZZ_S (PTRUE_S 31), + (INDEX_II_S 0, 1), + (DUP_ZR_S (i32 (EXTRACT_SUBREG GPR64:$index, sub_32)))), + ZPR:$vec)>; + def : Pat<(f64 (vector_extract (nxv2f64 ZPR:$vec), GPR64:$index)), + (LASTB_VPZ_D (CMPEQ_PPzZZ_D (PTRUE_D 31), + (INDEX_II_D 0, 1), + (DUP_ZR_D $index)), + ZPR:$vec)>; + def : Pat<(f32 (vector_extract (nxv2f32 ZPR:$vec), GPR64:$index)), + (LASTB_VPZ_S (CMPEQ_PPzZZ_D (PTRUE_D 31), + (INDEX_II_D 0, 1), + (DUP_ZR_D $index)), + ZPR:$vec)>; + + // Shift by immediate patterns. Allowed immediate range is different for right + // vs. left shifts, so the patterns have to be different. 
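+ // (Left shifts take an immediate in the range [0, esize-1], right shifts in
+ // [1, esize], hence the separate vecshiftLxx/vecshiftRxx operand classes.)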
+ def : Pat<(nxv16i8 (shl (nxv16i8 ZPR:$Zs1), + (nxv16i8 (AArch64dup (vecshiftL8:$imm))))), + (LSL_ZZI_B ZPR:$Zs1, vecshiftL8:$imm)>; + def : Pat<(nxv8i16 (shl (nxv8i16 ZPR:$Zs1), + (nxv8i16 (AArch64dup (vecshiftL16:$imm))))), + (LSL_ZZI_H ZPR:$Zs1, vecshiftL16:$imm)>; + def : Pat<(nxv4i32 (shl (nxv4i32 ZPR:$Zs1), + (nxv4i32 (AArch64dup (vecshiftL32:$imm))))), + (LSL_ZZI_S ZPR:$Zs1, vecshiftL32:$imm)>; + def : Pat<(nxv2i64 (shl (nxv2i64 ZPR:$Zs1), + (nxv2i64 (AArch64dup (i64 (SVELShiftImm64 i32:$imm)))))), + (LSL_ZZI_D ZPR:$Zs1, vecshiftL64:$imm)>; + + def : Pat<(nxv16i8 (int_aarch64_sve_lsl (nxv16i1 PPR_3b:$Pg), + (nxv16i8 ZPR:$Zs1), + (nxv16i8 (AArch64dup (vecshiftL8:$imm))))), + (LSL_ZPmI_B PPR_3b:$Pg, ZPR:$Zs1, vecshiftL8:$imm)>; + def : Pat<(nxv8i16 (int_aarch64_sve_lsl (nxv8i1 PPR_3b:$Pg), + (nxv8i16 ZPR:$Zs1), + (nxv8i16 (AArch64dup (vecshiftL16:$imm))))), + (LSL_ZPmI_H PPR_3b:$Pg, ZPR:$Zs1, vecshiftL16:$imm)>; + def : Pat<(nxv4i32 (int_aarch64_sve_lsl (nxv4i1 PPR_3b:$Pg), + (nxv4i32 ZPR:$Zs1), + (nxv4i32 (AArch64dup (vecshiftL32:$imm))))), + (LSL_ZPmI_S PPR_3b:$Pg, ZPR:$Zs1, vecshiftL32:$imm)>; + def : Pat<(nxv2i64 (int_aarch64_sve_lsl (nxv2i1 PPR_3b:$Pg), + (nxv2i64 ZPR:$Zs1), + (nxv2i64 (AArch64dup (i64 (SVELShiftImm64 i32:$imm)))))), + (LSL_ZPmI_D PPR_3b:$Pg, ZPR:$Zs1, vecshiftL64:$imm)>; + + // Wide shifts + def : Pat<(nxv16i8 (int_aarch64_sve_lsl_wide (nxv16i1 PPR_3b:$Pg), + (nxv16i8 ZPR:$Zs1), + (nxv2i64 (AArch64dup (i64 (SVEWideLShiftImm8 i32:$imm)))))), + (LSL_ZPmI_B PPR_3b:$Pg, ZPR:$Zs1, vecshiftL8:$imm)>; + def : Pat<(nxv8i16 (int_aarch64_sve_lsl_wide (nxv8i1 PPR_3b:$Pg), + (nxv8i16 ZPR:$Zs1), + (nxv2i64 (AArch64dup (i64 (SVEWideLShiftImm16 i32:$imm)))))), + (LSL_ZPmI_H PPR_3b:$Pg, ZPR:$Zs1, vecshiftL16:$imm)>; + def : Pat<(nxv4i32 (int_aarch64_sve_lsl_wide (nxv4i1 PPR_3b:$Pg), + (nxv4i32 ZPR:$Zs1), + (nxv2i64 (AArch64dup (i64 (SVEWideLShiftImm32 i32:$imm)))))), + (LSL_ZPmI_S PPR_3b:$Pg, ZPR:$Zs1, vecshiftL32:$imm)>; + + + multiclass sve_unpred_rshift_immediates { + def : Pat<(nxv16i8 (ir_op (nxv16i8 ZPR:$Zs1), + (nxv16i8 (AArch64dup (vecshiftR8:$imm))))), + (!cast(ir_inst # "_B") ZPR:$Zs1, vecshiftR8:$imm)>; + def : Pat<(nxv8i16 (ir_op (nxv8i16 ZPR:$Zs1), + (nxv8i16 (AArch64dup (vecshiftR16:$imm))))), + (!cast(ir_inst # "_H") ZPR:$Zs1, vecshiftR16:$imm)>; + def : Pat<(nxv4i32 (ir_op (nxv4i32 ZPR:$Zs1), + (nxv4i32 (AArch64dup (vecshiftR32:$imm))))), + (!cast(ir_inst # "_S") ZPR:$Zs1, vecshiftR32:$imm)>; + def : Pat<(nxv2i64 (ir_op (nxv2i64 ZPR:$Zs1), + (nxv2i64 (AArch64dup (i64 (SVERShiftImm64 i32:$imm)))))), + (!cast(ir_inst # "_D") ZPR:$Zs1, vecshiftR64:$imm)>; + + def : Pat<(nxv16i8 (int_op (nxv16i1 PPR_3b:$Pg), + (nxv16i8 ZPR:$Zs1), + (nxv16i8 (AArch64dup (vecshiftR8:$imm))))), + (!cast(int_inst # "_B") PPR_3b:$Pg, ZPR:$Zs1, vecshiftR8:$imm)>; + def : Pat<(nxv8i16 (int_op (nxv8i1 PPR_3b:$Pg), + (nxv8i16 ZPR:$Zs1), + (nxv8i16 (AArch64dup (vecshiftR16:$imm))))), + (!cast(int_inst # "_H") PPR_3b:$Pg, ZPR:$Zs1, vecshiftR16:$imm)>; + def : Pat<(nxv4i32 (int_op (nxv4i1 PPR_3b:$Pg), + (nxv4i32 ZPR:$Zs1), + (nxv4i32 (AArch64dup (vecshiftR32:$imm))))), + (!cast(int_inst # "_S") PPR_3b:$Pg, ZPR:$Zs1, vecshiftR32:$imm)>; + def : Pat<(nxv2i64 (int_op (nxv2i1 PPR_3b:$Pg), + (nxv2i64 ZPR:$Zs1), + (nxv2i64 (AArch64dup (i64 (SVERShiftImm64 i32:$imm)))))), + (!cast(int_inst # "_D") PPR_3b:$Pg, ZPR:$Zs1, vecshiftR64:$imm)>; + + // Wide shifts + def : Pat<(nxv16i8 (wide_op (nxv16i1 PPR_3b:$Pg), + (nxv16i8 ZPR:$Zs1), + (nxv2i64 (AArch64dup (i64 (SVEWideRShiftImm8 i32:$imm)))))), + 
(!cast(int_inst # "_B") PPR_3b:$Pg, ZPR:$Zs1, vecshiftR8:$imm)>; + def : Pat<(nxv8i16 (wide_op (nxv8i1 PPR_3b:$Pg), + (nxv8i16 ZPR:$Zs1), + (nxv2i64 (AArch64dup (i64 (SVEWideRShiftImm16 i32:$imm)))))), + (!cast(int_inst # "_H") PPR_3b:$Pg, ZPR:$Zs1, vecshiftR16:$imm)>; + def : Pat<(nxv4i32 (wide_op (nxv4i1 PPR_3b:$Pg), + (nxv4i32 ZPR:$Zs1), + (nxv2i64 (AArch64dup (i64 (SVEWideRShiftImm32 i32:$imm)))))), + (!cast(int_inst # "_S") PPR_3b:$Pg, ZPR:$Zs1, vecshiftR32:$imm)>; + } + + defm : sve_unpred_rshift_immediates<"ASR_ZZI", "ASR_ZPmI", sra, int_aarch64_sve_asr, int_aarch64_sve_asr_wide>; + defm : sve_unpred_rshift_immediates<"LSR_ZZI", "LSR_ZPmI", srl, int_aarch64_sve_lsr, int_aarch64_sve_lsr_wide>; + + def : Pat<(nxv16i1 (trunc (nxv16i8 ZPR:$Zs))), + (CMPNE_PPzZI_B (PTRUE_B 31), (LSL_ZZI_B ZPR:$Zs, 7), 0)>; + def : Pat<(nxv8i1 (trunc (nxv8i16 ZPR:$Zs))), + (CMPNE_PPzZI_H (PTRUE_H 31), (LSL_ZZI_H ZPR:$Zs, 15), 0)>; + def : Pat<(nxv4i1 (trunc (nxv4i32 ZPR:$Zs))), + (CMPNE_PPzZI_S (PTRUE_S 31), (LSL_ZZI_S ZPR:$Zs, 31), 0)>; + def : Pat<(nxv2i1 (trunc (nxv2i64 ZPR:$Zs))), + (CMPNE_PPzZI_D (PTRUE_D 31), (LSL_ZZI_D ZPR:$Zs, 63), 0)>; + + def : Pat<(nxv16i1 (and (trunc (nxv16i8 ZPR:$Zs1)), PPR:$Ps2)), + (CMPNE_PPzZI_B PPR:$Ps2, (LSL_ZZI_B ZPR:$Zs1, 7), 0)>; + def : Pat<(nxv8i1 (and (trunc (nxv8i16 ZPR:$Zs1)), PPR:$Ps2)), + (CMPNE_PPzZI_H PPR:$Ps2, (LSL_ZZI_H ZPR:$Zs1, 15), 0)>; + def : Pat<(nxv4i1 (and (trunc (nxv4i32 ZPR:$Zs1)), PPR:$Ps2)), + (CMPNE_PPzZI_S PPR:$Ps2, (LSL_ZZI_S ZPR:$Zs1, 31), 0)>; + def : Pat<(nxv2i1 (and (trunc (nxv2i64 ZPR:$Zs1)), PPR:$Ps2)), + (CMPNE_PPzZI_D PPR:$Ps2, (LSL_ZZI_D ZPR:$Zs1, 63), 0)>; + + defm : SVEInvCmpPat; + defm : SVEInvCmpPat; + defm : SVEInvCmpPat; + defm : SVEInvCmpPat; + + // per-element any extend + def : Pat<(nxv16i8 (anyext (nxv16i1 PPR:$Ps1))), + (CPY_ZPzI_B PPR:$Ps1, 0x1, 0)>; + def : Pat<(nxv8i16 (anyext (nxv8i1 PPR:$Ps1))), + (CPY_ZPzI_H PPR:$Ps1, 0x1, 0)>; + def : Pat<(nxv4i32 (anyext (nxv4i1 PPR:$Ps1))), + (CPY_ZPzI_S PPR:$Ps1, 0x1, 0)>; + def : Pat<(nxv2i64 (anyext (nxv2i1 PPR:$Ps1))), + (CPY_ZPzI_D PPR:$Ps1, 0x1, 0)>; + + // per-element sign extend + def : Pat<(nxv16i8 (sext (nxv16i1 PPR:$Ps1))), + (CPY_ZPzI_B PPR:$Ps1, -1, 0)>; + def : Pat<(nxv8i16 (sext (nxv8i1 PPR:$Ps1))), + (CPY_ZPzI_H PPR:$Ps1, -1, 0)>; + def : Pat<(nxv4i32 (sext (nxv4i1 PPR:$Ps1))), + (CPY_ZPzI_S PPR:$Ps1, -1, 0)>; + def : Pat<(nxv2i64 (sext (nxv2i1 PPR:$Ps1))), + (CPY_ZPzI_D PPR:$Ps1, -1, 0)>; + + // per-element zero extend + def : Pat<(nxv16i8 (zext (nxv16i1 PPR:$Ps1))), + (CPY_ZPzI_B PPR:$Ps1, 0x1, 0)>; + def : Pat<(nxv8i16 (zext (nxv8i1 PPR:$Ps1))), + (CPY_ZPzI_H PPR:$Ps1, 0x1, 0)>; + def : Pat<(nxv4i32 (zext (nxv4i1 PPR:$Ps1))), + (CPY_ZPzI_S PPR:$Ps1, 0x1, 0)>; + def : Pat<(nxv2i64 (zext (nxv2i1 PPR:$Ps1))), + (CPY_ZPzI_D PPR:$Ps1, 0x1, 0)>; + + // brka + def : Pat<(nxv16i1 (AArch64brka (nxv16i1 PPR:$Pg), (nxv16i1 PPR:$Src1))), + (BRKA_PPzP $Pg, $Src1)>; + def : Pat<(nxv8i1 (AArch64brka (nxv8i1 PPR:$Pg), (nxv8i1 PPR:$Src1))), + (BRKA_PPzP $Pg, $Src1)>; + def : Pat<(nxv4i1 (AArch64brka (nxv4i1 PPR:$Pg), (nxv4i1 PPR:$Src1))), + (BRKA_PPzP $Pg, $Src1)>; + def : Pat<(nxv2i1 (AArch64brka (nxv2i1 PPR:$Pg), (nxv2i1 PPR:$Src1))), + (BRKA_PPzP $Pg, $Src1)>; + + def : Pat<(nxv16i8 (scalar_to_vector (i32 FPR32:$src))), + (INSERT_SUBREG (nxv16i8 (IMPLICIT_DEF)), FPR32:$src, ssub)>; + + def : Pat<(nxv8i16 (scalar_to_vector (i32 FPR32:$src))), + (INSERT_SUBREG (nxv8i16 (IMPLICIT_DEF)), FPR32:$src, ssub)>; + + def : Pat<(nxv4i32 (scalar_to_vector (i32 FPR32:$src))), + 
(INSERT_SUBREG (nxv4i32 (IMPLICIT_DEF)), FPR32:$src, ssub)>; + + def : Pat<(nxv2i64 (scalar_to_vector (i64 FPR64:$src))), + (INSERT_SUBREG (nxv2i64 (IMPLICIT_DEF)), FPR64:$src, dsub)>; + + def : Pat<(nxv8f16 (scalar_to_vector (f16 FPR16:$src))), + (INSERT_SUBREG (nxv8f16 (IMPLICIT_DEF)), FPR16:$src, hsub)>; + + def : Pat<(nxv4f32 (scalar_to_vector (f32 FPR32:$src))), + (INSERT_SUBREG (nxv4f32 (IMPLICIT_DEF)), FPR32:$src, ssub)>; + + def : Pat<(nxv2f64 (scalar_to_vector (f64 FPR64:$src))), + (INSERT_SUBREG (nxv2f64 (IMPLICIT_DEF)), FPR64:$src, dsub)>; + + def : Pat<(nxv2f32 (scalar_to_vector (f32 FPR32:$src))), + (INSERT_SUBREG (nxv2f32 (IMPLICIT_DEF)), FPR32:$src, ssub)>; + + def : Pat<(nxv16i1 (scalar_to_vector GPR32:$src)), + (CMPNE_PPzZI_B (PTRUE_B 31), (LSL_ZZI_B (DUP_ZR_B $src), 7), 0)>; + + def : Pat<(nxv8i1 (scalar_to_vector GPR32:$src)), + (CMPNE_PPzZI_H (PTRUE_H 31), (LSL_ZZI_H (DUP_ZR_H $src), 15), 0)>; + + def : Pat<(nxv4i1 (scalar_to_vector GPR32:$src)), + (CMPNE_PPzZI_S (PTRUE_S 31), (LSL_ZZI_S (DUP_ZR_S $src), 31), 0)>; + + def : Pat<(nxv2i1 (scalar_to_vector GPR32:$src)), + (CMPNE_PPzZI_D (PTRUE_D 31), (LSL_ZZI_S (DUP_ZR_S $src), 31), 0)>; + + // Insert scalar into vector[0] + def : Pat<(nxv16i8 (vector_insert (nxv16i8 (undef)), (i32 FPR32:$src), 0)), + (INSERT_SUBREG (nxv16i8 (IMPLICIT_DEF)), FPR32:$src, ssub)>; + def : Pat<(nxv8i16 (vector_insert (nxv8i16 (undef)), (i32 FPR32:$src), 0)), + (INSERT_SUBREG (nxv8i16 (IMPLICIT_DEF)), FPR32:$src, ssub)>; + def : Pat<(nxv4i32 (vector_insert (nxv4i32 (undef)), (i32 FPR32:$src), 0)), + (INSERT_SUBREG (nxv4i32 (IMPLICIT_DEF)), FPR32:$src, ssub)>; + def : Pat<(nxv2i64 (vector_insert (nxv2i64 (undef)), (i64 FPR64:$src), 0)), + (INSERT_SUBREG (nxv2i64 (IMPLICIT_DEF)), FPR64:$src, dsub)>; + + def : Pat<(nxv16i8 (vector_insert (nxv16i8 ZPR:$vec), (i32 GPR32:$src), 0)), + (CPY_ZPmR_B ZPR:$vec, (PTRUE_B 1), GPR32:$src)>; + def : Pat<(nxv8i16 (vector_insert (nxv8i16 ZPR:$vec), (i32 GPR32:$src), 0)), + (CPY_ZPmR_H ZPR:$vec, (PTRUE_H 1), GPR32:$src)>; + def : Pat<(nxv4i32 (vector_insert (nxv4i32 ZPR:$vec), (i32 GPR32:$src), 0)), + (CPY_ZPmR_S ZPR:$vec, (PTRUE_S 1), GPR32:$src)>; + def : Pat<(nxv2i64 (vector_insert (nxv2i64 ZPR:$vec), (i64 GPR64:$src), 0)), + (CPY_ZPmR_D ZPR:$vec, (PTRUE_D 1), GPR64:$src)>; + + def : Pat<(nxv8f16 (vector_insert (nxv8f16 ZPR:$vec), (f16 FPR16:$src), 0)), + (SEL_ZPZZ_H (PTRUE_H 1), (INSERT_SUBREG (IMPLICIT_DEF), FPR16:$src, hsub), ZPR:$vec)>; + def : Pat<(nxv4f32 (vector_insert (nxv4f32 ZPR:$vec), (f32 FPR32:$src), 0)), + (SEL_ZPZZ_S (PTRUE_S 1), (INSERT_SUBREG (IMPLICIT_DEF), FPR32:$src, ssub), ZPR:$vec)>; + def : Pat<(nxv2f64 (vector_insert (nxv2f64 ZPR:$vec), (f64 FPR64:$src), 0)), + (SEL_ZPZZ_D (PTRUE_D 1), (INSERT_SUBREG (IMPLICIT_DEF), FPR64:$src, dsub), ZPR:$vec)>; + + // Insert scalar into vector with scalar index + def : Pat<(nxv16i8 (vector_insert (nxv16i8 ZPR:$vec), GPR32:$src, GPR64:$index)), + (CPY_ZPmR_B ZPR:$vec, + (CMPEQ_PPzZZ_B (PTRUE_B 31), + (INDEX_II_B 0, 1), + (DUP_ZR_B (i32 (EXTRACT_SUBREG GPR64:$index, sub_32)))), + GPR32:$src)>; + def : Pat<(nxv8i16 (vector_insert (nxv8i16 ZPR:$vec), GPR32:$src, GPR64:$index)), + (CPY_ZPmR_H ZPR:$vec, + (CMPEQ_PPzZZ_H (PTRUE_H 31), + (INDEX_II_H 0, 1), + (DUP_ZR_H (i32 (EXTRACT_SUBREG GPR64:$index, sub_32)))), + GPR32:$src)>; + def : Pat<(nxv4i32 (vector_insert (nxv4i32 ZPR:$vec), GPR32:$src, GPR64:$index)), + (CPY_ZPmR_S ZPR:$vec, + (CMPEQ_PPzZZ_S (PTRUE_S 31), + (INDEX_II_S 0, 1), + (DUP_ZR_S (i32 (EXTRACT_SUBREG GPR64:$index, sub_32)))), + 
GPR32:$src)>; + def : Pat<(nxv2i64 (vector_insert (nxv2i64 ZPR:$vec), GPR64:$src, GPR64:$index)), + (CPY_ZPmR_D ZPR:$vec, + (CMPEQ_PPzZZ_D (PTRUE_D 31), + (INDEX_II_D 0, 1), + (DUP_ZR_D GPR64:$index)), + GPR64:$src)>; + + // Insert FP scalar into vector with scalar index + def : Pat<(nxv8f16 (vector_insert (nxv8f16 ZPR:$vec), (f16 FPR16:$src), GPR64:$index)), + (CPY_ZPmV_H ZPR:$vec, + (CMPEQ_PPzZZ_H (PTRUE_H 31), + (INDEX_II_H 0, 1), + (DUP_ZR_H (i32 (EXTRACT_SUBREG GPR64:$index, sub_32)))), + $src)>; + def : Pat<(nxv4f32 (vector_insert (nxv4f32 ZPR:$vec), (f32 FPR32:$src), GPR64:$index)), + (CPY_ZPmV_S ZPR:$vec, + (CMPEQ_PPzZZ_S (PTRUE_S 31), + (INDEX_II_S 0, 1), + (DUP_ZR_S (i32 (EXTRACT_SUBREG GPR64:$index, sub_32)))), + $src)>; + def : Pat<(nxv2f64 (vector_insert (nxv2f64 ZPR:$vec), (f64 FPR64:$src), GPR64:$index)), + (CPY_ZPmV_D ZPR:$vec, + (CMPEQ_PPzZZ_D (PTRUE_D 31), + (INDEX_II_D 0, 1), + (DUP_ZR_D $index)), + $src)>; + def : Pat<(nxv2f32 (vector_insert (nxv2f32 ZPR:$vec), (f32 FPR32:$src), GPR64:$index)), + (CPY_ZPmV_S ZPR:$vec, + (CMPEQ_PPzZZ_D (PTRUE_D 31), + (INDEX_II_D 0, 1), + (DUP_ZR_D $index)), + $src)>; + + // Duplicate FP scalar into all vector elements + def : Pat<(nxv8f16 (AArch64dup (f16 FPR16:$src))), (DUP_ZV_H $src)>; + def : Pat<(nxv4f16 (AArch64dup (f16 FPR16:$src))), (DUP_ZV_H $src)>; + def : Pat<(nxv2f16 (AArch64dup (f16 FPR16:$src))), (DUP_ZV_H $src)>; + def : Pat<(nxv4f32 (AArch64dup (f32 FPR32:$src))), (DUP_ZV_S $src)>; + def : Pat<(nxv2f32 (AArch64dup (f32 FPR32:$src))), (DUP_ZV_S $src)>; + def : Pat<(nxv2f64 (AArch64dup (f64 FPR64:$src))), (DUP_ZV_D $src)>; + + // Duplicate +0.0 into all vector elements + def : Pat<(nxv8f16 (AArch64dup (f16 fpimm0))), (DUP_ZI_H 0, 0)>; + def : Pat<(nxv4f16 (AArch64dup (f16 fpimm0))), (DUP_ZI_H 0, 0)>; + def : Pat<(nxv2f16 (AArch64dup (f16 fpimm0))), (DUP_ZI_H 0, 0)>; + def : Pat<(nxv4f32 (AArch64dup (f32 fpimm0))), (DUP_ZI_S 0, 0)>; + def : Pat<(nxv2f32 (AArch64dup (f32 fpimm0))), (DUP_ZI_S 0, 0)>; + def : Pat<(nxv2f64 (AArch64dup (f64 fpimm0))), (DUP_ZI_D 0, 0)>; + + // Duplicate immediate in all vector elements + def : Pat<(nxv16i8 (AArch64dup (i32 (SVE8BitLslImm i32:$a, i32:$b)))), + (DUP_ZI_B $a, $b)>; + def : Pat<(nxv8i16 (AArch64dup (i32 (SVE8BitLslImm i32:$a, i32:$b)))), + (DUP_ZI_H $a, $b)>; + def : Pat<(nxv4i32 (AArch64dup (i32 (SVE8BitLslImm i32:$a, i32:$b)))), + (DUP_ZI_S $a, $b)>; + def : Pat<(nxv2i64 (AArch64dup (i64 (SVE8BitLslImm i64:$a, i64:$b)))), + (DUP_ZI_D $a, $b)>; + + // Duplicate GPR in all vector elements + def : Pat<(nxv16i8 (AArch64dup GPR32:$a)), (DUP_ZR_B $a)>; + def : Pat<(nxv8i16 (AArch64dup GPR32:$a)), (DUP_ZR_H $a)>; + def : Pat<(nxv4i32 (AArch64dup GPR32:$a)), (DUP_ZR_S $a)>; + def : Pat<(nxv2i64 (AArch64dup GPR64:$a)), (DUP_ZR_D $a)>; + + // Dup of fp immediate pattern. 
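+ // FDUP_ZI takes the 8-bit encodable FP immediates; the AddedComplexity bump
+ // below should make these win over the generic DUP_ZV_* register patterns
+ // above when the constant is encodable.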
+ let AddedComplexity = 2 in { + def : Pat<(nxv8f16 (AArch64dup fpimm16:$imm8)), + (FDUP_ZI_H fpimm16:$imm8)>; + def : Pat<(nxv4f32 (AArch64dup fpimm32:$imm8)), + (FDUP_ZI_S fpimm32:$imm8)>; + def : Pat<(nxv2f32 (AArch64dup fpimm32:$imm8)), + (FDUP_ZI_S fpimm32:$imm8)>; + def : Pat<(nxv2f64 (AArch64dup fpimm64:$imm8)), + (FDUP_ZI_D fpimm64:$imm8)>; + } + + + /// Compact single bit fp immediates + multiclass intrinsic_compact_fp_immediates { + def : Pat<(nxv8f16 (op (nxv8i1 PPR_3b:$Pg), + (nxv8f16 ZPR:$Zs1), + (nxv8f16 (AArch64dup (f16 A))))), + (!cast(I # "_H") PPR_3b:$Pg, ZPR:$Zs1, 0)>; + def : Pat<(nxv8f16 (op (nxv8i1 PPR_3b:$Pg), + (nxv8f16 ZPR:$Zs1), + (nxv8f16 (AArch64dup (f16 B))))), + (!cast(I # "_H") PPR_3b:$Pg, ZPR:$Zs1, 1)>; + def : Pat<(nxv4f32 (op (nxv4i1 PPR_3b:$Pg), + (nxv4f32 ZPR:$Zs1), + (nxv4f32 (AArch64dup (f32 A))))), + (!cast(I # "_S") PPR_3b:$Pg, ZPR:$Zs1, 0)>; + def : Pat<(nxv4f32 (op (nxv4i1 PPR_3b:$Pg), + (nxv4f32 ZPR:$Zs1), + (nxv4f32 (AArch64dup (f32 B))))), + (!cast(I # "_S") PPR_3b:$Pg, ZPR:$Zs1, 1)>; + def : Pat<(nxv2f64 (op (nxv2i1 PPR_3b:$Pg), + (nxv2f64 ZPR:$Zs1), + (nxv2f64 (AArch64dup (f64 A))))), + (!cast(I # "_D") PPR_3b:$Pg, ZPR:$Zs1, 0)>; + def : Pat<(nxv2f64 (op (nxv2i1 PPR_3b:$Pg), + (nxv2f64 ZPR:$Zs1), + (nxv2f64 (AArch64dup (f64 B))))), + (!cast(I # "_D") PPR_3b:$Pg, ZPR:$Zs1, 1)>; + + def : Pat<(nxv8f16 (ir_op (nxv8f16 ZPR:$Zs1), + (nxv8f16 (AArch64dup (f16 A))))), + (!cast(IX # "_H") (PTRUE_H 31), ZPR:$Zs1, 0)>; + def : Pat<(nxv8f16 (ir_op (nxv8f16 ZPR:$Zs1), + (nxv8f16 (AArch64dup (f16 B))))), + (!cast(IX # "_H") (PTRUE_H 31), ZPR:$Zs1, 1)>; + def : Pat<(nxv4f32 (ir_op (nxv4f32 ZPR:$Zs1), + (nxv4f32 (AArch64dup (f32 A))))), + (!cast(IX # "_S") (PTRUE_S 31), ZPR:$Zs1, 0)>; + def : Pat<(nxv4f32 (ir_op (nxv4f32 ZPR:$Zs1), + (nxv4f32 (AArch64dup (f32 B))))), + (!cast(IX # "_S") (PTRUE_S 31), ZPR:$Zs1, 1)>; + def : Pat<(nxv2f64 (ir_op (nxv2f64 ZPR:$Zs1), + (nxv2f64 (AArch64dup (f64 A))))), + (!cast(IX # "_D") (PTRUE_D 31), ZPR:$Zs1, 0)>; + def : Pat<(nxv2f64 (ir_op (nxv2f64 ZPR:$Zs1), + (nxv2f64 (AArch64dup (f64 B))))), + (!cast(IX # "_D") (PTRUE_D 31), ZPR:$Zs1, 1)>; + + let AddedComplexity = 2 in { + // When Intrinsic combined with SELECT + def : Pat<(nxv8f16 (op nxv8i1:$Pg, + (vselect nxv8i1:$Pg, nxv8f16:$Zs1, (SVEDup0)), + (nxv8f16 (AArch64dup (f16 A))))), + (!cast(IZ # "_H") $Pg, $Zs1, 0)>; + def : Pat<(nxv8f16 (op nxv8i1:$Pg, + (vselect nxv8i1:$Pg, nxv8f16:$Zs1, (SVEDup0)), + (nxv8f16 (AArch64dup (f16 B))))), + (!cast(IZ # "_H") $Pg, $Zs1, 1)>; + def : Pat<(nxv4f32 (op nxv4i1:$Pg, + (vselect nxv4i1:$Pg, nxv4f32:$Zs1, (SVEDup0)), + (nxv4f32 (AArch64dup (f32 A))))), + (!cast(IZ # "_S") $Pg, $Zs1, 0)>; + def : Pat<(nxv4f32 (op nxv4i1:$Pg, + (vselect nxv4i1:$Pg, nxv4f32:$Zs1, (SVEDup0)), + (nxv4f32 (AArch64dup (f32 B))))), + (!cast(IZ # "_S") $Pg, $Zs1, 1)>; + def : Pat<(nxv2f64 (op nxv2i1:$Pg, + (vselect nxv2i1:$Pg, nxv2f64:$Zs1, (SVEDup0)), + (nxv2f64 (AArch64dup (f64 A))))), + (!cast(IZ # "_D") $Pg, $Zs1, 0)>; + def : Pat<(nxv2f64 (op nxv2i1:$Pg, + (vselect nxv2i1:$Pg, nxv2f64:$Zs1, (SVEDup0)), + (nxv2f64 (AArch64dup (f64 B))))), + (!cast(IZ # "_D") $Pg, $Zs1, 1)>; + } + } + + defm : intrinsic_compact_fp_immediates<"FADD_ZPmI", "FADD_ZPZI_ZERO", "FADD_ZPZI_UNDEF", fpimm_half, fpimm_one, int_aarch64_sve_fadd, fadd>; + defm : intrinsic_compact_fp_immediates<"FSUB_ZPmI", "FSUB_ZPZI_ZERO", "FSUB_ZPZI_UNDEF", fpimm_half, fpimm_one, int_aarch64_sve_fsub, fsub>; + defm : intrinsic_compact_fp_immediates<"FSUBR_ZPmI", "FSUBR_ZPZI_ZERO", 
"FSUBR_ZPZI_UNDEF", fpimm_half, fpimm_one, int_aarch64_sve_fsubr>; + defm : intrinsic_compact_fp_immediates<"FMUL_ZPmI", "FMUL_ZPZI_ZERO", "FMUL_ZPZI_UNDEF", fpimm_half, fpimm_two, int_aarch64_sve_fmul, fmul>; + defm : intrinsic_compact_fp_immediates<"FMAX_ZPmI", "FMAX_ZPZI_ZERO", "FMAX_ZPZI_UNDEF", fpimm0, fpimm_one, AArch64fmax_pred>; + defm : intrinsic_compact_fp_immediates<"FMIN_ZPmI", "FMIN_ZPZI_ZERO", "FMIN_ZPZI_UNDEF", fpimm0, fpimm_one, AArch64fmin_pred>; + defm : intrinsic_compact_fp_immediates<"FMAXNM_ZPmI","FMAXNM_ZPZI_ZERO","FMAXNM_ZPZI_UNDEF", fpimm0, fpimm_one, AArch64fmaxnm_pred>; + defm : intrinsic_compact_fp_immediates<"FMINNM_ZPmI","FMINNM_ZPZI_ZERO","FMINNM_ZPZI_UNDEF", fpimm0, fpimm_one, AArch64fminnm_pred>; + + foreach type = ["nxv16i1", "nxv8i1", "nxv4i1", "nxv2i1"] in { + def : Pat< (!cast(type) + (load (am_sve_pred GPR64sp:$base, simm9:$offset))), + (LDR_PXI GPR64sp:$base, simm9:$offset)>; + def : Pat<(store (!cast(type) PPR:$val), + (am_sve_pred GPR64sp:$base, simm9:$offset)), + (STR_PXI PPR:$val, GPR64sp:$base, simm9:$offset)>; + } + + def : Pat<(nxv2f32 (fpextend (nxv2f16 ZPR:$Zs))), + (FCVT_ZPmZ_HtoS (IMPLICIT_DEF), (PTRUE_D 31), ZPR:$Zs)>; + def : Pat<(nxv4f32 (fpextend (nxv4f16 ZPR:$Zs))), + (FCVT_ZPmZ_HtoS (IMPLICIT_DEF), (PTRUE_S 31), ZPR:$Zs)>; + def : Pat<(nxv2f64 (fpextend (nxv2f16 ZPR:$Zs))), + (FCVT_ZPmZ_HtoD (IMPLICIT_DEF), (PTRUE_D 31), ZPR:$Zs)>; + def : Pat<(nxv2f64 (fpextend (nxv2f32 ZPR:$Zs))), + (FCVT_ZPmZ_StoD (IMPLICIT_DEF), (PTRUE_D 31), ZPR:$Zs)>; + + def : Pat<(nxv2f16 (fpround (nxv2f32 ZPR:$Zs))), + (FCVT_ZPmZ_StoH (IMPLICIT_DEF), (PTRUE_D 31), ZPR:$Zs)>; + def : Pat<(nxv4f16 (fpround (nxv4f32 ZPR:$Zs))), + (FCVT_ZPmZ_StoH (IMPLICIT_DEF), (PTRUE_S 31), ZPR:$Zs)>; + def : Pat<(nxv2f16 (fpround (nxv2f64 ZPR:$Zs))), + (FCVT_ZPmZ_DtoH (IMPLICIT_DEF), (PTRUE_D 31), ZPR:$Zs)>; + def : Pat<(nxv2f32 (fpround (nxv2f64 ZPR:$Zs))), + (FCVT_ZPmZ_DtoS (IMPLICIT_DEF), (PTRUE_D 31), ZPR:$Zs)>; + + // floating-point -> s16 + def : Pat<(nxv8i16 (fp_to_sint (nxv8f16 ZPR:$Zs))), + (FCVTZS_ZPmZ_HtoH (IMPLICIT_DEF), (PTRUE_H 31), ZPR:$Zs)>; + + // floating-point -> s32 + def : Pat<(nxv4i32 (fp_to_sint (nxv4f16 ZPR:$Zs))), + (FCVTZS_ZPmZ_HtoS (IMPLICIT_DEF), (PTRUE_S 31), ZPR:$Zs)>; + def : Pat<(nxv4i32 (fp_to_sint (nxv4f32 ZPR:$Zs))), + (FCVTZS_ZPmZ_StoS (IMPLICIT_DEF), (PTRUE_S 31), ZPR:$Zs)>; + + // floating-point -> s64 + def : Pat<(nxv2i64 (fp_to_sint (nxv2f16 ZPR:$Zs))), + (FCVTZS_ZPmZ_HtoD (IMPLICIT_DEF), (PTRUE_D 31), ZPR:$Zs)>; + def : Pat<(nxv2i64 (fp_to_sint (nxv2f32 ZPR:$Zs))), + (FCVTZS_ZPmZ_StoD (IMPLICIT_DEF), (PTRUE_D 31), ZPR:$Zs)>; + def : Pat<(nxv2i64 (fp_to_sint (nxv2f64 ZPR:$Zs))), + (FCVTZS_ZPmZ_DtoD (IMPLICIT_DEF), (PTRUE_D 31), ZPR:$Zs)>; + + // floating-point -> u16 + def : Pat<(nxv8i16 (fp_to_uint (nxv8f16 ZPR:$Zs))), + (FCVTZU_ZPmZ_HtoH (IMPLICIT_DEF), (PTRUE_H 31), ZPR:$Zs)>; + + // floating-point -> u32 + def : Pat<(nxv4i32 (fp_to_uint (nxv4f16 ZPR:$Zs))), + (FCVTZU_ZPmZ_HtoS (IMPLICIT_DEF), (PTRUE_S 31), ZPR:$Zs)>; + def : Pat<(nxv4i32 (fp_to_uint (nxv4f32 ZPR:$Zs))), + (FCVTZU_ZPmZ_StoS (IMPLICIT_DEF), (PTRUE_S 31), ZPR:$Zs)>; + + // floating-point -> u64 + def : Pat<(nxv2i64 (fp_to_uint (nxv2f16 ZPR:$Zs))), + (FCVTZU_ZPmZ_HtoD (IMPLICIT_DEF), (PTRUE_D 31), ZPR:$Zs)>; + def : Pat<(nxv2i64 (fp_to_uint (nxv2f32 ZPR:$Zs))), + (FCVTZU_ZPmZ_StoD (IMPLICIT_DEF), (PTRUE_D 31), ZPR:$Zs)>; + def : Pat<(nxv2i64 (fp_to_uint (nxv2f64 ZPR:$Zs))), + (FCVTZU_ZPmZ_DtoD (IMPLICIT_DEF), (PTRUE_D 31), ZPR:$Zs)>; + + // signed int -> fp16 + def : 
Pat<(nxv2f16 (sint_to_fp (sext_inreg (nxv2i64 ZPR:$Zs), nxv2i16))), + (SCVTF_ZPmZ_HtoH (IMPLICIT_DEF), (PTRUE_D 31), ZPR:$Zs)>; + def : Pat<(nxv2f16 (sint_to_fp (sext_inreg (nxv2i64 ZPR:$Zs), nxv2i32))), + (SCVTF_ZPmZ_StoH (IMPLICIT_DEF), (PTRUE_D 31), ZPR:$Zs)>; + def : Pat<(nxv2f16 (sint_to_fp (nxv2i64 ZPR:$Zs))), + (SCVTF_ZPmZ_DtoH (IMPLICIT_DEF), (PTRUE_D 31), ZPR:$Zs)>; + def : Pat<(nxv4f16 (sint_to_fp (sext_inreg (nxv4i32 ZPR:$Zs), nxv4i16))), + (SCVTF_ZPmZ_HtoH (IMPLICIT_DEF), (PTRUE_S 31), ZPR:$Zs)>; + def : Pat<(nxv4f16 (sint_to_fp (nxv4i32 ZPR:$Zs))), + (SCVTF_ZPmZ_StoH (IMPLICIT_DEF), (PTRUE_S 31), ZPR:$Zs)>; + def : Pat<(nxv8f16 (sint_to_fp (nxv8i16 ZPR:$Zs))), + (SCVTF_ZPmZ_HtoH (IMPLICIT_DEF), (PTRUE_H 31), ZPR:$Zs)>; + + // signed int -> fp32 + def : Pat<(nxv2f32 (sint_to_fp (sext_inreg (nxv2i64 ZPR:$Zs), nxv2i32))), + (SCVTF_ZPmZ_StoS (IMPLICIT_DEF), (PTRUE_D 31), ZPR:$Zs)>; + def : Pat<(nxv2f32 (sint_to_fp (nxv2i64 ZPR:$Zs))), + (SCVTF_ZPmZ_DtoS (IMPLICIT_DEF), (PTRUE_D 31), ZPR:$Zs)>; + def : Pat<(nxv4f32 (sint_to_fp (nxv4i32 ZPR:$Zs))), + (SCVTF_ZPmZ_StoS (IMPLICIT_DEF), (PTRUE_S 31), ZPR:$Zs)>; + + // signed int -> fp64 + def : Pat<(nxv2f64 (sint_to_fp (sext_inreg (nxv2i64 ZPR:$Zs), nxv2i32))), + (SCVTF_ZPmZ_StoD (IMPLICIT_DEF), (PTRUE_D 31), ZPR:$Zs)>; + def : Pat<(nxv2f64 (sint_to_fp (nxv2i64 ZPR:$Zs))), + (SCVTF_ZPmZ_DtoD (IMPLICIT_DEF), (PTRUE_D 31), ZPR:$Zs)>; + + // unsigned int -> fp16 + def : Pat<(nxv2f16 (uint_to_fp (and (nxv2i64 ZPR:$Zs), (nxv2i64 (AArch64dup (i64 0xFFFF)))))), + (UCVTF_ZPmZ_HtoH (IMPLICIT_DEF), (PTRUE_D 31), ZPR:$Zs)>; + def : Pat<(nxv2f16 (uint_to_fp (and (nxv2i64 ZPR:$Zs), (nxv2i64 (AArch64dup (i64 0xFFFFFFFF)))))), + (UCVTF_ZPmZ_StoH (IMPLICIT_DEF), (PTRUE_D 31), ZPR:$Zs)>; + def : Pat<(nxv2f16 (uint_to_fp (nxv2i64 ZPR:$Zs))), + (UCVTF_ZPmZ_DtoH (IMPLICIT_DEF), (PTRUE_D 31), ZPR:$Zs)>; + def : Pat<(nxv4f16 (uint_to_fp (and (nxv4i32 ZPR:$Zs), (nxv4i32 (AArch64dup (i32 0xFFFF)))))), + (UCVTF_ZPmZ_HtoH (IMPLICIT_DEF), (PTRUE_S 31), ZPR:$Zs)>; + def : Pat<(nxv4f16 (uint_to_fp (nxv4i32 ZPR:$Zs))), + (UCVTF_ZPmZ_StoH (IMPLICIT_DEF), (PTRUE_S 31), ZPR:$Zs)>; + def : Pat<(nxv8f16 (uint_to_fp (nxv8i16 ZPR:$Zs))), + (UCVTF_ZPmZ_HtoH (IMPLICIT_DEF), (PTRUE_H 31), ZPR:$Zs)>; + + // unsigned int -> fp32 + def : Pat<(nxv2f32 (uint_to_fp (and (nxv2i64 ZPR:$Zs), (nxv2i64 (AArch64dup (i64 0xFFFFFFFF)))))), + (UCVTF_ZPmZ_StoS (IMPLICIT_DEF), (PTRUE_D 31), ZPR:$Zs)>; + def : Pat<(nxv2f32 (uint_to_fp (nxv2i64 ZPR:$Zs))), + (UCVTF_ZPmZ_DtoS (IMPLICIT_DEF), (PTRUE_D 31), ZPR:$Zs)>; + def : Pat<(nxv4f32 (uint_to_fp (nxv4i32 ZPR:$Zs))), + (UCVTF_ZPmZ_StoS (IMPLICIT_DEF), (PTRUE_S 31), ZPR:$Zs)>; + + // unsigned int -> fp64 + def : Pat<(nxv2f64 (uint_to_fp (and (nxv2i64 ZPR:$Zs), (nxv2i64 (AArch64dup (i64 0xFFFFFFFF)))))), + (UCVTF_ZPmZ_StoD (IMPLICIT_DEF), (PTRUE_D 31), ZPR:$Zs)>; + def : Pat<(nxv2f64 (uint_to_fp (nxv2i64 ZPR:$Zs))), + (UCVTF_ZPmZ_DtoD (IMPLICIT_DEF), (PTRUE_D 31), ZPR:$Zs)>; + + def : Pat<(AArch64ptest (nxv16i1 PPR:$pg), (nxv16i1 PPR:$src)), + (PTEST_PP PPR:$pg, PPR:$src)>; + def : Pat<(AArch64ptest (nxv8i1 PPR:$pg), (nxv8i1 PPR:$src)), + (PTEST_PP PPR:$pg, PPR:$src)>; + def : Pat<(AArch64ptest (nxv4i1 PPR:$pg), (nxv4i1 PPR:$src)), + (PTEST_PP PPR:$pg, PPR:$src)>; + def : Pat<(AArch64ptest (nxv2i1 PPR:$pg), (nxv2i1 PPR:$src)), + (PTEST_PP PPR:$pg, PPR:$src)>; + + def : Pat<(AArch64not (nxv16i1 PPR:$src)), + (NOR_PPzPP (PTRUE_B 31), PPR:$src, PPR:$src)>; + def : Pat<(AArch64not (nxv8i1 PPR:$src)), + (NOR_PPzPP (PTRUE_H 31), PPR:$src, PPR:$src)>; 
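; --- Illustrative LLVM IR (editor's sketch, not part of this patch) ---
; Examples of the scalable-vector conversions covered by the fcvt / scvtf /
; ucvtf / fcvtz* patterns above, assuming +sve; names are illustrative and
; the remaining predicate NOT patterns continue right after this sketch.
define <vscale x 4 x float> @s32_to_f32(<vscale x 4 x i32> %v) {
  ; sitofp -> SCVTF_ZPmZ_StoS under an all-true PTRUE_S.
  %r = sitofp <vscale x 4 x i32> %v to <vscale x 4 x float>
  ret <vscale x 4 x float> %r
}

define <vscale x 2 x i64> @f64_to_u64(<vscale x 2 x double> %v) {
  ; fptoui -> FCVTZU_ZPmZ_DtoD.
  %r = fptoui <vscale x 2 x double> %v to <vscale x 2 x i64>
  ret <vscale x 2 x i64> %r
}

define <vscale x 2 x double> @f32_to_f64(<vscale x 2 x float> %v) {
  ; fpextend on the unpacked nxv2f32 type -> FCVT_ZPmZ_StoD.
  %r = fpext <vscale x 2 x float> %v to <vscale x 2 x double>
  ret <vscale x 2 x double> %r
}
; --- end of sketch ---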
+  def : Pat<(AArch64not (nxv4i1 PPR:$src)),
+            (NOR_PPzPP (PTRUE_S 31), PPR:$src, PPR:$src)>;
+  def : Pat<(AArch64not (nxv2i1 PPR:$src)),
+            (NOR_PPzPP (PTRUE_D 31), PPR:$src, PPR:$src)>;
+
+  def : Pat<(nxv16i8 (bitconvert (nxv8i16 ZPR:$src))), (nxv16i8 ZPR:$src)>;
+  def : Pat<(nxv16i8 (bitconvert (nxv4i32 ZPR:$src))), (nxv16i8 ZPR:$src)>;
+  def : Pat<(nxv16i8 (bitconvert (nxv2i64 ZPR:$src))), (nxv16i8 ZPR:$src)>;
+  def : Pat<(nxv16i8 (bitconvert (nxv8f16 ZPR:$src))), (nxv16i8 ZPR:$src)>;
+  def : Pat<(nxv16i8 (bitconvert (nxv4f32 ZPR:$src))), (nxv16i8 ZPR:$src)>;
+  def : Pat<(nxv16i8 (bitconvert (nxv2f64 ZPR:$src))), (nxv16i8 ZPR:$src)>;
+
+  def : Pat<(nxv8i16 (bitconvert (nxv16i8 ZPR:$src))), (nxv8i16 ZPR:$src)>;
+  def : Pat<(nxv8i16 (bitconvert (nxv4i32 ZPR:$src))), (nxv8i16 ZPR:$src)>;
+  def : Pat<(nxv8i16 (bitconvert (nxv2i64 ZPR:$src))), (nxv8i16 ZPR:$src)>;
+  def : Pat<(nxv8i16 (bitconvert (nxv8f16 ZPR:$src))), (nxv8i16 ZPR:$src)>;
+  def : Pat<(nxv8i16 (bitconvert (nxv4f32 ZPR:$src))), (nxv8i16 ZPR:$src)>;
+  def : Pat<(nxv8i16 (bitconvert (nxv2f64 ZPR:$src))), (nxv8i16 ZPR:$src)>;
+
+  def : Pat<(nxv4i32 (bitconvert (nxv16i8 ZPR:$src))), (nxv4i32 ZPR:$src)>;
+  def : Pat<(nxv4i32 (bitconvert (nxv8i16 ZPR:$src))), (nxv4i32 ZPR:$src)>;
+  def : Pat<(nxv4i32 (bitconvert (nxv2i64 ZPR:$src))), (nxv4i32 ZPR:$src)>;
+  def : Pat<(nxv4i32 (bitconvert (nxv8f16 ZPR:$src))), (nxv4i32 ZPR:$src)>;
+  def : Pat<(nxv4i32 (bitconvert (nxv4f32 ZPR:$src))), (nxv4i32 ZPR:$src)>;
+  def : Pat<(nxv4i32 (bitconvert (nxv2f64 ZPR:$src))), (nxv4i32 ZPR:$src)>;
+
+  def : Pat<(nxv2i64 (bitconvert (nxv16i8 ZPR:$src))), (nxv2i64 ZPR:$src)>;
+  def : Pat<(nxv2i64 (bitconvert (nxv8i16 ZPR:$src))), (nxv2i64 ZPR:$src)>;
+  def : Pat<(nxv2i64 (bitconvert (nxv4i32 ZPR:$src))), (nxv2i64 ZPR:$src)>;
+  def : Pat<(nxv2i64 (bitconvert (nxv8f16 ZPR:$src))), (nxv2i64 ZPR:$src)>;
+  def : Pat<(nxv2i64 (bitconvert (nxv4f32 ZPR:$src))), (nxv2i64 ZPR:$src)>;
+  def : Pat<(nxv2i64 (bitconvert (nxv2f64 ZPR:$src))), (nxv2i64 ZPR:$src)>;
+
+  def : Pat<(nxv8f16 (bitconvert (nxv16i8 ZPR:$src))), (nxv8f16 ZPR:$src)>;
+  def : Pat<(nxv8f16 (bitconvert (nxv8i16 ZPR:$src))), (nxv8f16 ZPR:$src)>;
+  def : Pat<(nxv8f16 (bitconvert (nxv4i32 ZPR:$src))), (nxv8f16 ZPR:$src)>;
+  def : Pat<(nxv8f16 (bitconvert (nxv2i64 ZPR:$src))), (nxv8f16 ZPR:$src)>;
+  def : Pat<(nxv8f16 (bitconvert (nxv4f32 ZPR:$src))), (nxv8f16 ZPR:$src)>;
+  def : Pat<(nxv8f16 (bitconvert (nxv2f64 ZPR:$src))), (nxv8f16 ZPR:$src)>;
+
+  def : Pat<(nxv4f32 (bitconvert (nxv16i8 ZPR:$src))), (nxv4f32 ZPR:$src)>;
+  def : Pat<(nxv4f32 (bitconvert (nxv8i16 ZPR:$src))), (nxv4f32 ZPR:$src)>;
+  def : Pat<(nxv4f32 (bitconvert (nxv4i32 ZPR:$src))), (nxv4f32 ZPR:$src)>;
+  def : Pat<(nxv4f32 (bitconvert (nxv2i64 ZPR:$src))), (nxv4f32 ZPR:$src)>;
+  def : Pat<(nxv4f32 (bitconvert (nxv8f16 ZPR:$src))), (nxv4f32 ZPR:$src)>;
+  def : Pat<(nxv4f32 (bitconvert (nxv2f64 ZPR:$src))), (nxv4f32 ZPR:$src)>;
+
+  def : Pat<(nxv2f64 (bitconvert (nxv16i8 ZPR:$src))), (nxv2f64 ZPR:$src)>;
+  def : Pat<(nxv2f64 (bitconvert (nxv8i16 ZPR:$src))), (nxv2f64 ZPR:$src)>;
+  def : Pat<(nxv2f64 (bitconvert (nxv4i32 ZPR:$src))), (nxv2f64 ZPR:$src)>;
+  def : Pat<(nxv2f64 (bitconvert (nxv2i64 ZPR:$src))), (nxv2f64 ZPR:$src)>;
+  def : Pat<(nxv2f64 (bitconvert (nxv8f16 ZPR:$src))), (nxv2f64 ZPR:$src)>;
+  def : Pat<(nxv2f64 (bitconvert (nxv4f32 ZPR:$src))), (nxv2f64 ZPR:$src)>;
+
+  // These are effectively bitconvert for predicates but due to the differently
+  // sized input/output ValueTypes we cannot simply return $src.
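; --- Illustrative LLVM IR (editor's sketch, not part of this patch) ---
; The full-width bitconvert patterns above make plain IR bitcasts between
; equally sized scalable vector types free, since every such type lives in
; a Z register; the predicate reinterpret casts the comment above refers to
; follow right after this sketch. Assumes +sve; names are illustrative.
define <vscale x 16 x i8> @as_bytes(<vscale x 2 x i64> %v) {
  %r = bitcast <vscale x 2 x i64> %v to <vscale x 16 x i8>
  ret <vscale x 16 x i8> %r
}
; --- end of sketch ---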
+ def : Pat<(nxv16i1 (reinterpret_cast (nxv8i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>; + def : Pat<(nxv16i1 (reinterpret_cast (nxv4i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>; + def : Pat<(nxv16i1 (reinterpret_cast (nxv2i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>; + def : Pat<(nxv8i1 (reinterpret_cast (nxv16i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>; + def : Pat<(nxv8i1 (reinterpret_cast (nxv4i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>; + def : Pat<(nxv8i1 (reinterpret_cast (nxv2i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>; + def : Pat<(nxv4i1 (reinterpret_cast (nxv16i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>; + def : Pat<(nxv4i1 (reinterpret_cast (nxv8i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>; + def : Pat<(nxv4i1 (reinterpret_cast (nxv2i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>; + def : Pat<(nxv2i1 (reinterpret_cast (nxv16i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>; + def : Pat<(nxv2i1 (reinterpret_cast (nxv8i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>; + def : Pat<(nxv2i1 (reinterpret_cast (nxv4i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>; + + // These allow bitcasts between unpacked_fp and packed_int datatypes. + def : Pat<(nxv2f16 (reinterpret_cast (nxv2i64 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>; + def : Pat<(nxv2i64 (reinterpret_cast (nxv2f16 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>; + def : Pat<(nxv4f16 (reinterpret_cast (nxv4i32 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>; + def : Pat<(nxv4i32 (reinterpret_cast (nxv4f16 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>; + def : Pat<(nxv2f32 (reinterpret_cast (nxv2i64 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>; + def : Pat<(nxv2i64 (reinterpret_cast (nxv2f32 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>; + + // These allow bitcasts between unpacked_fp datatypes. 
+ def : Pat<(nxv2f16 (reinterpret_cast (nxv4f16 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>; + def : Pat<(nxv4f16 (reinterpret_cast (nxv2f16 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>; + def : Pat<(nxv4f16 (reinterpret_cast (nxv8f16 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>; + def : Pat<(nxv8f16 (reinterpret_cast (nxv4f16 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>; + def : Pat<(nxv2f32 (reinterpret_cast (nxv4f32 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>; + def : Pat<(nxv4f32 (reinterpret_cast (nxv2f32 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>; + + multiclass unpred_load { + // reg + reg + let AddedComplexity = 1 in { + def _reg_reg : Pat<(Ty (Load (AddrCP GPR64sp:$base, GPR64:$offset))), + (RegRegInst (PTrue 31), GPR64sp:$base, GPR64:$offset)>; + } + // reg + imm + let AddedComplexity = 2 in { + def _reg_imm : Pat<(Ty (Load (am_sve_indexed_s4 GPR64sp:$base, simm4s1:$offset))), + (RegImmInst (PTrue 31), GPR64sp:$base, simm4s1:$offset)>; + } + // default fallback + def _default : Pat<(Ty (Load GPR64sp:$base)), + (RegImmInst (PTrue 31), GPR64sp:$base, (i64 0))>; + } + + defm : unpred_load; + defm : unpred_load; + defm : unpred_load; + defm : unpred_load; + defm : unpred_load; + defm : unpred_load; + defm : unpred_load; + + defm : unpred_load; + defm : unpred_load; + defm : unpred_load; + defm : unpred_load; + defm : unpred_load; + defm : unpred_load; + + defm : unpred_load; + defm : unpred_load; + defm : unpred_load; + defm : unpred_load; + defm : unpred_load; + defm : unpred_load; + + multiclass pred_load { + // reg + reg + let AddedComplexity = 1 in { + def _reg_reg_z : Pat<(Ty (Load (AddrCP GPR64:$base, GPR64:$offset), (PredTy PPR:$gp), (SVEDup0Undef))), + (RegRegInst PPR:$gp, GPR64:$base, GPR64:$offset)>; + } + // reg + imm + let AddedComplexity = 2 in { + def _reg_imm_z : Pat<(Ty (Load (am_sve_indexed_s4 GPR64sp:$base, simm4s1:$offset), (PredTy PPR:$gp), (SVEDup0Undef))), + (RegImmInst PPR:$gp, GPR64:$base, simm4s1:$offset)>; + } + // default fallback + def _default_z : Pat<(Ty (Load GPR64:$base, (PredTy PPR:$gp), (SVEDup0Undef))), + (RegImmInst PPR:$gp, GPR64:$base, (i64 0))>; + } + + // 2-element contiguous loads + defm : pred_load; + defm : pred_load; + defm : pred_load; + defm : pred_load; + defm : pred_load; + defm : pred_load; + defm : pred_load; + defm : pred_load; + defm : pred_load; + defm : pred_load; + + // 4-element contiguous loads + defm : pred_load; + defm : pred_load; + defm : pred_load; + defm : pred_load; + defm : pred_load; + defm : pred_load; + defm : pred_load; + + // 8-element contiguous loads + defm : pred_load; + defm : pred_load; + defm : pred_load; + defm : pred_load; + + // 16-element contiguous loads + defm : pred_load; + + multiclass ldff1 { + // base + index + def : Pat<(Ty (Load (PredTy PPR:$gp), (AddrCP GPR64sp:$base, GPR64:$offset), MemVT)), + (I PPR:$gp, GPR64sp:$base, GPR64:$offset)>; + // base + def : Pat<(Ty (Load (PredTy PPR:$gp), GPR64:$base, MemVT)), + (I PPR:$gp, GPR64sp:$base, XZR)>; + } + + // 2-element contiguous first faulting loads + defm : ldff1; + defm : ldff1; + defm : ldff1; + defm : ldff1; + defm : ldff1; + defm : ldff1; + defm : ldff1; + defm : ldff1; + defm : ldff1; + + // 4-element contiguous first faulting loads + defm : ldff1; + defm : ldff1; + defm : ldff1; + defm : ldff1; + defm : ldff1; + defm : ldff1; + + // 8-element contiguous first faulting loads + defm : ldff1; + defm : ldff1; + defm : ldff1; + defm : ldff1; + + // 16-element contiguous first faulting loads + defm : ldff1; + + // 2-element gather 
first faulting loads + // TODO: Handle floating point loads within DAGCombine + def : Pat<(nxv2f64 (AArch64ldff1_gather (nxv2i1 PPR:$gp), GPR64sp:$base, (nxv2i64 ZPR:$offsets), nxv2f64)), + (GLDFF1D PPR:$gp, GPR64sp:$base, ZPR:$offsets)>; + + def : Pat<(nxv2f64 (AArch64ldff1_gather (nxv2i1 PPR:$gp), uimm5s8:$index, (nxv2i64 ZPR:$ptrs), nxv2f64)), + (GLDFF1D_IMM PPR:$gp, ZPR:$ptrs, uimm5s8:$index)>; + + def : Pat<(nxv2f64 (AArch64ldff1_gather_scaled (nxv2i1 PPR:$gp), GPR64sp:$base, (nxv2i64 ZPR:$indices), nxv2f64)), + (GLDFF1D_SCALED PPR:$gp, GPR64sp:$base, ZPR:$indices)>; + + def : Pat<(nxv2f64 (AArch64ldff1_gather (nxv2i1 PPR:$gp), GPR64sp:$base, (sext_inreg (nxv2i64 ZPR:$offsets), nxv2i32), nxv2f64)), + (GLDFF1D_SXTW PPR:$gp, GPR64sp:$base, ZPR:$offsets)>; + + def : Pat<(nxv2f64 (AArch64ldff1_gather_scaled (nxv2i1 PPR:$gp), GPR64sp:$base, (sext_inreg (nxv2i64 ZPR:$indices), nxv2i32), nxv2f64)), + (GLDFF1D_SXTW_SCALED PPR:$gp, GPR64sp:$base, ZPR:$indices)>; + + def : Pat<(nxv2f64 (AArch64ldff1_gather (nxv2i1 PPR:$gp), GPR64sp:$base, (and (nxv2i64 ZPR:$offsets), (nxv2i64 (AArch64dup (i64 0xFFFFFFFF)))), nxv2f64)), + (GLDFF1D_UXTW PPR:$gp, GPR64sp:$base, ZPR:$offsets)>; + + def : Pat<(nxv2f64 (AArch64ldff1_gather_scaled (nxv2i1 PPR:$gp), GPR64sp:$base, (and (nxv2i64 ZPR:$indices), (nxv2i64 (AArch64dup (i64 0xFFFFFFFF)))), nxv2f64)), + (GLDFF1D_UXTW_SCALED PPR:$gp, GPR64sp:$base, ZPR:$indices)>; + + // 4-element gather first faulting loads + // TODO: Handle floating point loads within DAGCombine + def : Pat<(nxv4f32 (AArch64ldff1_gather_uxtw (nxv4i1 PPR:$gp), uimm5s4:$index, (nxv4i32 ZPR:$ptrs), nxv4f32)), + (GLDFF1W_IMM PPR:$gp, ZPR:$ptrs, uimm5s4:$index)>; + + def : Pat<(nxv4f32 (AArch64ldff1_gather_sxtw (nxv4i1 PPR:$gp), GPR64sp:$base, (nxv4i32 ZPR:$offsets), nxv4f32)), + (GLDFF1W_SXTW PPR:$gp, GPR64sp:$base, ZPR:$offsets)>; + + def : Pat<(nxv4f32 (AArch64ldff1_gather_sxtw_scaled (nxv4i1 PPR:$gp), GPR64sp:$base, (nxv4i32 ZPR:$indices), nxv4f32)), + (GLDFF1W_SXTW_SCALED PPR:$gp, GPR64sp:$base, ZPR:$indices)>; + + def : Pat<(nxv4f32 (AArch64ldff1_gather_uxtw (nxv4i1 PPR:$gp), GPR64sp:$base, (nxv4i32 ZPR:$offsets), nxv4f32)), + (GLDFF1W_UXTW PPR:$gp, GPR64sp:$base, ZPR:$offsets)>; + + def : Pat<(nxv4f32 (AArch64ldff1_gather_uxtw_scaled (nxv4i1 PPR:$gp), GPR64sp:$base, (nxv4i32 ZPR:$indices), nxv4f32)), + (GLDFF1W_UXTW_SCALED PPR:$gp, GPR64sp:$base, ZPR:$indices)>; + + multiclass sve_masked_gather_x2 { + // vector of pointers + def : Pat<(vt (load undef, (nxv2i1 PPR:$gp), 0, (nxv2i64 ZPR:$ptrs))), + (!cast(I # "_IMM") PPR:$gp, ZPR:$ptrs, 0)>; + // vector of pointers + immediate offset + def : Pat<(vt (load undef, (nxv2i1 PPR:$gp), 0, (add (nxv2i64 ZPR:$ptrs), (nxv2i64 (AArch64dup (i64 imm_ty:$imm)))))), + (!cast(I # "_IMM") PPR:$gp, ZPR:$ptrs, imm_ty:$imm)>; + // vector of pointers + scalar offset + def : Pat<(vt (load undef, (nxv2i1 PPR:$gp), 0, (add (nxv2i64 ZPR:$ptrs), (nxv2i64 (AArch64dup GPR64:$offset))))), + (!cast(I) PPR:$gp, GPR64:$offset, ZPR:$ptrs)>; + // base + vector of scaled offsets + def : Pat<(vt (load undef, (nxv2i1 PPR:$gp), GPR64:$base, (nxv2i64 ZPR:$offs))), + (!cast(I # SCALED) PPR:$gp, GPR64:$base, ZPR:$offs)>; + // base + vector of signed 32bit scaled offsets + def : Pat<(vt (load undef, (nxv2i1 PPR:$gp), GPR64:$base, (sext_inreg (nxv2i64 ZPR:$offs), nxv2i32))), + (!cast(I # "_SXTW" # SCALED) PPR:$gp, GPR64:$base, ZPR:$offs)>; + // base + vector of unsigned 32bit scaled offsets + def : Pat<(vt (load undef, (nxv2i1 PPR:$gp), GPR64:$base, (and (nxv2i64 ZPR:$offs), 
(nxv2i64 (AArch64dup (i64 0xFFFFFFFF)))))), + (!cast(I # "_UXTW" # SCALED) PPR:$gp, GPR64:$base, ZPR:$offs)>; + // base + vector of signed 32bit offsets + def : Pat<(vt (load undef, (nxv2i1 PPR:$gp), 0, (add (nxv2i64 (AArch64dup GPR64:$base)), (sext_inreg (nxv2i64 ZPR:$offs), nxv2i32)))), + (!cast(I # "_SXTW") PPR:$gp, GPR64:$base, ZPR:$offs)>; + // base + vector of unsigned 32bit offsets + def : Pat<(vt (load undef, (nxv2i1 PPR:$gp), 0, (add (nxv2i64 (AArch64dup GPR64:$base)), (and (nxv2i64 ZPR:$offs), (nxv2i64 (AArch64dup (i64 0xFFFFFFFF))))))), + (!cast(I # "_UXTW") PPR:$gp, GPR64:$base, ZPR:$offs)>; + } + + multiclass sve_masked_gather_x4 { + def : Pat<(vt (load undef, (nxv4i1 PPR:$gp), GPR64:$base, (nxv4i32 ZPR:$offs))), + (I PPR:$gp, GPR64:$base, ZPR:$offs)>; + } + + defm : sve_masked_gather_x2; + defm : sve_masked_gather_x2; + defm : sve_masked_gather_x2; + defm : sve_masked_gather_x2; + defm : sve_masked_gather_x2; + defm : sve_masked_gather_x2; + defm : sve_masked_gather_x2; + defm : sve_masked_gather_x2; + defm : sve_masked_gather_x2; + defm : sve_masked_gather_x2; + + defm : sve_masked_gather_x4; + defm : sve_masked_gather_x4; + defm : sve_masked_gather_x4; + defm : sve_masked_gather_x4; + defm : sve_masked_gather_x4; + defm : sve_masked_gather_x4; + defm : sve_masked_gather_x4; + + defm : sve_masked_gather_x4; + defm : sve_masked_gather_x4; + defm : sve_masked_gather_x4; + defm : sve_masked_gather_x4; + defm : sve_masked_gather_x4; + defm : sve_masked_gather_x4; + defm : sve_masked_gather_x4; + + defm : sve_masked_gather_x4; + defm : sve_masked_gather_x4; + defm : sve_masked_gather_x4; + defm : sve_masked_gather_x4; + defm : sve_masked_gather_x4; + defm : sve_masked_gather_x4; + defm : sve_masked_gather_x4; + + defm : sve_masked_gather_x4; + defm : sve_masked_gather_x4; + defm : sve_masked_gather_x4; + defm : sve_masked_gather_x4; + defm : sve_masked_gather_x4; + defm : sve_masked_gather_x4; + defm : sve_masked_gather_x4; + + multiclass ldnf1 { + // base + offset_mul_vl + def : Pat<(Ty (Load (PredTy PPR:$gp), (am_sve_indexed_s4 GPR64sp:$base, simm4s1:$offset), MemVT)), + (I PPR:$gp, GPR64sp:$base, simm4s1:$offset)>; + // base + def : Pat<(Ty (Load (PredTy PPR:$gp), GPR64:$base, MemVT)), + (I PPR:$gp, GPR64sp:$base, (i64 0))>; + } + + // 2-element contiguous non-faulting loads + defm : ldnf1; + defm : ldnf1; + defm : ldnf1; + defm : ldnf1; + defm : ldnf1; + defm : ldnf1; + defm : ldnf1; + defm : ldnf1; + + // 4-element contiguous non-faulting loads + defm : ldnf1; + defm : ldnf1; + defm : ldnf1; + defm : ldnf1; + defm : ldnf1; + defm : ldnf1; + + // 8-element contiguous non-faulting loads + defm : ldnf1; + defm : ldnf1; + defm : ldnf1; + defm : ldnf1; + + // 16-element contiguous non-faulting loads + defm : ldnf1; + + multiclass unpred_store { + // reg + reg + let AddedComplexity = 1 in { + def _reg_reg : Pat<(store (Ty ZPR:$val), (AddrCP GPR64sp:$base, GPR64:$offset)), + (RegRegInst ZPR:$val, (PTrue 31), GPR64sp:$base, GPR64:$offset)>; + // reg + imm + } + let AddedComplexity = 2 in { + def _reg_imm : Pat<(store (Ty ZPR:$val), (am_sve_indexed_s4 GPR64sp:$base, simm4s1:$offset)), + (RegImmInst ZPR:$val, (PTrue 31), GPR64sp:$base, simm4s1:$offset)>; + } + // default fallback + def _default : Pat<(store (Ty ZPR:$val), GPR64sp:$base), + (RegImmInst ZPR:$val, (PTrue 31), GPR64sp:$base, (i64 0))>; + } + + defm Pat_ST1B : unpred_store; + defm Pat_ST1H : unpred_store; + defm Pat_ST1W : unpred_store; + defm Pat_ST1D : unpred_store; + defm Pat_ST1H_float16: unpred_store; + defm 
Pat_ST1W_float : unpred_store; + defm Pat_ST1D_double : unpred_store; + + + multiclass pred_store { + // reg + reg + let AddedComplexity = 1 in { + def _reg_reg : Pat<(Store (AddrCP GPR64:$base, GPR64:$offset), (PredTy PPR:$gp), (Ty ZPR:$vec)), + (RegRegInst ZPR:$vec, PPR:$gp, GPR64:$base, GPR64:$offset)>; + } + // reg + imm + let AddedComplexity = 2 in { + // TODO: write LL test for this pattern + def _reg_imm : Pat<(Store (am_sve_indexed_s4 GPR64sp:$base, simm4s1:$offset), (PredTy PPR:$gp), (Ty ZPR:$vec)), + (RegImmInst ZPR:$vec, PPR:$gp, GPR64:$base, simm4s1:$offset)>; + } + // default fallback + def _default : Pat<(Store GPR64:$base, (PredTy PPR:$gp), (Ty ZPR:$vec)), + (RegImmInst ZPR:$vec, PPR:$gp, GPR64:$base, (i64 0))>; + } + + // 2-element contiguous stores + defm : pred_store; + defm : pred_store; + defm : pred_store; + defm : pred_store; + defm : pred_store; + defm : pred_store; + defm : pred_store; + + // 4-element contiguous stores + defm : pred_store; + defm : pred_store; + defm : pred_store; + defm : pred_store; + defm : pred_store; + + // 8-element contiguous stores + defm : pred_store; + defm : pred_store; + defm : pred_store; + + // 16-element contiguous stores + defm : pred_store; + + multiclass sve_masked_scatter_x2 { + // vector of pointers + def : Pat<(store (vt ZPR:$vec), (nxv2i1 PPR:$gp), 0, (nxv2i64 ZPR:$ptrs)), + (!cast(I # "_IMM") ZPR:$vec, PPR:$gp, ZPR:$ptrs, 0)>; + // vector of pointers + immediate offset + def : Pat<(store (vt ZPR:$vec), (nxv2i1 PPR:$gp), 0, (add (nxv2i64 ZPR:$ptrs), (nxv2i64 (AArch64dup (i64 imm_ty:$imm))))), + (!cast(I # "_IMM") ZPR:$vec, PPR:$gp, ZPR:$ptrs, imm_ty:$imm)>; + // vector of pointers + scalar offset + def : Pat<(store (vt ZPR:$vec), (nxv2i1 PPR:$gp), 0, (add (nxv2i64 ZPR:$ptrs), (nxv2i64 (AArch64dup GPR64:$offset)))), + (!cast(I) ZPR:$vec, PPR:$gp, GPR64:$offset, ZPR:$ptrs)>; + // base + vector of scaled offsets + def : Pat<(store (vt ZPR:$vec), (nxv2i1 PPR:$gp), GPR64:$base, (nxv2i64 ZPR:$offs)), + (!cast(I # SCALED) ZPR:$vec, PPR:$gp, GPR64:$base, ZPR:$offs)>; + // base + vector of signed 32bit scaled offsets + def : Pat<(store (vt ZPR:$vec), (nxv2i1 PPR:$gp), GPR64:$base, (sext_inreg (nxv2i64 ZPR:$offs), nxv2i32)), + (!cast(I # "_SXTW" # SCALED) ZPR:$vec, PPR:$gp, GPR64:$base, ZPR:$offs)>; + // base + vector of unsigned 32bit scaled offsets + def : Pat<(store (vt ZPR:$vec), (nxv2i1 PPR:$gp), GPR64:$base, (and (nxv2i64 ZPR:$offs), (nxv2i64 (AArch64dup (i64 0xFFFFFFFF))))), + (!cast(I # "_UXTW" # SCALED) ZPR:$vec, PPR:$gp, GPR64:$base, ZPR:$offs)>; + // base + vector of signed 32bit offsets + def : Pat<(store (vt ZPR:$vec), (nxv2i1 PPR:$gp), 0, (add (nxv2i64 (AArch64dup GPR64:$base)), (sext_inreg (nxv2i64 ZPR:$offs), nxv2i32))), + (!cast(I # "_SXTW") ZPR:$vec, PPR:$gp, GPR64:$base, ZPR:$offs)>; + // base + vector of unsigned 32bit offsets + def : Pat<(store (vt ZPR:$vec), (nxv2i1 PPR:$gp), 0, (add (nxv2i64 (AArch64dup GPR64:$base)), (and (nxv2i64 ZPR:$offs), (nxv2i64 (AArch64dup (i64 0xFFFFFFFF)))))), + (!cast(I # "_UXTW") ZPR:$vec, PPR:$gp, GPR64:$base, ZPR:$offs)>; + } + + multiclass sve_masked_scatter_x4 { + def : Pat<(store (vt ZPR:$vec), (nxv4i1 PPR:$gp), GPR64:$base, (nxv4i32 ZPR:$offs)), + (I ZPR:$vec, PPR:$gp, GPR64:$base, ZPR:$offs)>; + } + + defm : sve_masked_scatter_x2; + defm : sve_masked_scatter_x2; + defm : sve_masked_scatter_x2; + defm : sve_masked_scatter_x2; + defm : sve_masked_scatter_x2; + defm : sve_masked_scatter_x2; + defm : sve_masked_scatter_x2; + + defm : sve_masked_scatter_x4; + defm : 
sve_masked_scatter_x4; + defm : sve_masked_scatter_x4; + defm : sve_masked_scatter_x4; + defm : sve_masked_scatter_x4; + + defm : sve_masked_scatter_x4; + defm : sve_masked_scatter_x4; + defm : sve_masked_scatter_x4; + defm : sve_masked_scatter_x4; + defm : sve_masked_scatter_x4; + + defm : sve_masked_scatter_x4; + defm : sve_masked_scatter_x4; + defm : sve_masked_scatter_x4; + defm : sve_masked_scatter_x4; + defm : sve_masked_scatter_x4; + + defm : sve_masked_scatter_x4; + defm : sve_masked_scatter_x4; + defm : sve_masked_scatter_x4; + defm : sve_masked_scatter_x4; + defm : sve_masked_scatter_x4; + + // Unpacked floating point loads + def : Pat<(nxv2f32 (load (i64 GPR64:$base))), + (LD1W_D_IMM (PTRUE_D 31), GPR64:$base, (i64 0))>; + + // Unpacked floating point stores + def : Pat<(store (nxv2f32 ZPR:$Zs), (i64 GPR64:$base)), + (ST1W_D_IMM ZPR:$Zs, (PTRUE_D 31), GPR64:$base, (i64 0))>; + + + def : Pat<(nxv16i1 (and PPR:$Ps1, PPR:$Ps2)), + (AND_PPzPP (PTRUE_B 31), PPR:$Ps1, PPR:$Ps2)>; + def : Pat<(nxv8i1 (and PPR:$Ps1, PPR:$Ps2)), + (AND_PPzPP (PTRUE_H 31), PPR:$Ps1, PPR:$Ps2)>; + def : Pat<(nxv4i1 (and PPR:$Ps1, PPR:$Ps2)), + (AND_PPzPP (PTRUE_S 31), PPR:$Ps1, PPR:$Ps2)>; + def : Pat<(nxv2i1 (and PPR:$Ps1, PPR:$Ps2)), + (AND_PPzPP (PTRUE_D 31), PPR:$Ps1, PPR:$Ps2)>; + + def : Pat<(nxv2i1 (and (nxv2i1 (and PPR:$Ps1, PPR:$Ps2)), PPR:$Ps3)), + (AND_PPzPP PPR:$Ps1, PPR:$Ps2, PPR:$Ps3)>; + def : Pat<(nxv2i1 (and PPR:$Ps1, (nxv2i1 (and PPR:$Ps2, PPR:$Ps3)))), + (AND_PPzPP PPR:$Ps1, PPR:$Ps2, PPR:$Ps3)>; + + def : Pat<(nxv16i1 (or PPR:$Ps1, PPR:$Ps2)), + (ORR_PPzPP (PTRUE_B 31), PPR:$Ps1, PPR:$Ps2)>; + def : Pat<(nxv8i1 (or PPR:$Ps1, PPR:$Ps2)), + (ORR_PPzPP (PTRUE_H 31), PPR:$Ps1, PPR:$Ps2)>; + def : Pat<(nxv4i1 (or PPR:$Ps1, PPR:$Ps2)), + (ORR_PPzPP (PTRUE_S 31), PPR:$Ps1, PPR:$Ps2)>; + def : Pat<(nxv2i1 (or PPR:$Ps1, PPR:$Ps2)), + (ORR_PPzPP (PTRUE_D 31), PPR:$Ps1, PPR:$Ps2)>; + + def : Pat<(nxv16i1 (xor PPR:$Ps1, PPR:$Ps2)), + (EOR_PPzPP (PTRUE_B 31), PPR:$Ps1, PPR:$Ps2)>; + def : Pat<(nxv8i1 (xor PPR:$Ps1, PPR:$Ps2)), + (EOR_PPzPP (PTRUE_H 31), PPR:$Ps1, PPR:$Ps2)>; + def : Pat<(nxv4i1 (xor PPR:$Ps1, PPR:$Ps2)), + (EOR_PPzPP (PTRUE_S 31), PPR:$Ps1, PPR:$Ps2)>; + def : Pat<(nxv2i1 (xor PPR:$Ps1, PPR:$Ps2)), + (EOR_PPzPP (PTRUE_D 31), PPR:$Ps1, PPR:$Ps2)>; + + // UNE compares + def : Pat<(nxv8i1 (AArch64not (AArch64fcmeq (nxv8f16 ZPR:$Zs1), (nxv8f16 ZPR:$Zs2)))), + (FCMNE_PPzZZ_H (PTRUE_H 31), ZPR:$Zs1, ZPR:$Zs2)>; + def : Pat<(nxv4i1 (AArch64not (AArch64fcmeq (nxv4f32 ZPR:$Zs1), (nxv4f32 ZPR:$Zs2)))), + (FCMNE_PPzZZ_S (PTRUE_S 31), ZPR:$Zs1, ZPR:$Zs2)>; + def : Pat<(nxv2i1 (AArch64not (AArch64fcmeq (nxv2f64 ZPR:$Zs1), (nxv2f64 ZPR:$Zs2)))), + (FCMNE_PPzZZ_D (PTRUE_D 31), ZPR:$Zs1, ZPR:$Zs2)>; + + def : Pat<(v16i8 (extract_subvector ZPR:$Zs, (i64 0))), + (EXTRACT_SUBREG ZPR:$Zs, zsub)>; + def : Pat<(v8i16 (extract_subvector ZPR:$Zs, (i64 0))), + (EXTRACT_SUBREG ZPR:$Zs, zsub)>; + def : Pat<(v4i32 (extract_subvector ZPR:$Zs, (i64 0))), + (EXTRACT_SUBREG ZPR:$Zs, zsub)>; + def : Pat<(v2i64 (extract_subvector ZPR:$Zs, (i64 0))), + (EXTRACT_SUBREG ZPR:$Zs, zsub)>; + def : Pat<(v4i16 (extract_subvector ZPR:$Zs, (i64 0))), + (EXTRACT_SUBREG ZPR:$Zs, dsub)>; + def : Pat<(v2i32 (extract_subvector ZPR:$Zs, (i64 0))), + (EXTRACT_SUBREG ZPR:$Zs, dsub)>; + + // Extract lo/hi halves of legal predicate types. 
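; --- Illustrative LLVM IR (editor's sketch, not part of this patch) ---
; Examples for the predicated load, scatter and predicate-logic patterns in
; this region, assuming +sve. The llvm.masked.* overload names below follow
; the usual mangling of the time and are the editor's assumption, not
; something this diff defines; the predicate-extract patterns referenced by
; the comment above continue right after this sketch.
define <vscale x 4 x i32> @masked_ld(<vscale x 4 x i32>* %p, <vscale x 4 x i1> %pg) {
  ; masked.load with an undef passthru -> LD1W via the pred_load multiclass.
  %r = call <vscale x 4 x i32> @llvm.masked.load.nxv4i32.p0nxv4i32(<vscale x 4 x i32>* %p, i32 4, <vscale x 4 x i1> %pg, <vscale x 4 x i32> undef)
  ret <vscale x 4 x i32> %r
}

define void @scatter_i64(<vscale x 2 x i64> %vals, <vscale x 2 x i64*> %ptrs, <vscale x 2 x i1> %pg) {
  ; masked.scatter over a vector of pointers -> the ST1D "_IMM" form of
  ; sve_masked_scatter_x2.
  call void @llvm.masked.scatter.nxv2i64.nxv2p0i64(<vscale x 2 x i64> %vals, <vscale x 2 x i64*> %ptrs, i32 8, <vscale x 2 x i1> %pg)
  ret void
}

define <vscale x 4 x i1> @cmp_and(<vscale x 4 x float> %a, <vscale x 4 x float> %b, <vscale x 4 x i1> %pg) {
  ; fcmp une -> FCMNE_PPzZZ_S; and of predicates -> AND_PPzPP.
  %c = fcmp une <vscale x 4 x float> %a, %b
  %r = and <vscale x 4 x i1> %c, %pg
  ret <vscale x 4 x i1> %r
}

declare <vscale x 4 x i32> @llvm.masked.load.nxv4i32.p0nxv4i32(<vscale x 4 x i32>*, i32, <vscale x 4 x i1>, <vscale x 4 x i32>)
declare void @llvm.masked.scatter.nxv2i64.nxv2p0i64(<vscale x 2 x i64>, <vscale x 2 x i64*>, i32, <vscale x 2 x i1>)
; --- end of sketch ---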
+ def : Pat<(nxv2i1 (extract_subvector (nxv4i1 PPR:$Ps), (i64 0))), + (ZIP1_PPP_S PPR:$Ps, (PFALSE))>; + def : Pat<(nxv2i1 (extract_subvector (nxv4i1 PPR:$Ps), (i64 2))), + (ZIP2_PPP_S PPR:$Ps, (PFALSE))>; + def : Pat<(nxv4i1 (extract_subvector (nxv8i1 PPR:$Ps), (i64 0))), + (ZIP1_PPP_H PPR:$Ps, (PFALSE))>; + def : Pat<(nxv4i1 (extract_subvector (nxv8i1 PPR:$Ps), (i64 4))), + (ZIP2_PPP_H PPR:$Ps, (PFALSE))>; + def : Pat<(nxv8i1 (extract_subvector (nxv16i1 PPR:$Ps), (i64 0))), + (ZIP1_PPP_B PPR:$Ps, (PFALSE))>; + def : Pat<(nxv8i1 (extract_subvector (nxv16i1 PPR:$Ps), (i64 8))), + (ZIP2_PPP_B PPR:$Ps, (PFALSE))>; + + def : Pat<(nxv4f16 (extract_subvector (nxv8f16 ZPR:$Zs), (i64 0))), + (UUNPKLO_ZZ_S ZPR:$Zs)>; + def : Pat<(nxv4f16 (extract_subvector (nxv8f16 ZPR:$Zs), (i64 4))), + (UUNPKHI_ZZ_S ZPR:$Zs)>; + + def : Pat<(nxv2f16 (extract_subvector (nxv4f16 ZPR:$Zs), (i64 0))), + (UUNPKLO_ZZ_D ZPR:$Zs)>; + def : Pat<(nxv2f16 (extract_subvector (nxv4f16 ZPR:$Zs), (i64 2))), + (UUNPKHI_ZZ_D ZPR:$Zs)>; + def : Pat<(nxv2f32 (extract_subvector (nxv4f32 ZPR:$Zs), (i64 0))), + (UUNPKLO_ZZ_D ZPR:$Zs)>; + def : Pat<(nxv2f32 (extract_subvector (nxv4f32 ZPR:$Zs), (i64 2))), + (UUNPKHI_ZZ_D ZPR:$Zs)>; + + def : Pat<(nxv2i64 (add (nxv2i64 ZPR:$Zn), (nxv2i64LslBy3 (nxv2i64 ZPR:$Zm)))), + (ADR_LSL_ZZZ_D ZPR:$Zn, ZPR:$Zm, 3)>; + def : Pat<(nxv2i64 (add (nxv2i64LslBy3 (nxv2i64 ZPR:$Zm)), (nxv2i64 ZPR:$Zn))), + (ADR_LSL_ZZZ_D ZPR:$Zn, ZPR:$Zm, 3)>; + + // Contiguous SVE prefetches + multiclass sve_prefetch { + // reg + imm + let AddedComplexity = 2 in { + def _reg_imm : Pat<(prefetch (PredTy PPR_3b:$gp), (am_sve_indexed_s6 GPR64sp:$base, simm6s1:$offset), (i32 sve_prfop:$prfop)), + (RegImmInst sve_prfop:$prfop, PPR_3b:$gp, GPR64:$base, simm6s1:$offset)>; + } + + let AddedComplexity = 1 in { + def _reg_reg : Pat<(prefetch (PredTy PPR_3b:$gp), (AddrCP GPR64sp:$base, GPR64:$index), (i32 sve_prfop:$prfop)), + (RegRegInst sve_prfop:$prfop, PPR_3b:$gp, GPR64:$base, GPR64:$index)>; + } + + // default fallback + def _default : Pat<(prefetch (PredTy PPR_3b:$gp), GPR64:$base, (i32 sve_prfop:$prfop)), + (RegImmInst sve_prfop:$prfop, PPR_3b:$gp, GPR64:$base, (i64 0))>; + } + + defm : sve_prefetch; + defm : sve_prefetch; + defm : sve_prefetch; + defm : sve_prefetch; + + defm : pred_load; + defm : pred_load; + defm : pred_load; + defm : pred_load; + + defm : pred_store; + defm : pred_store; + defm : pred_store; + defm : pred_store; + +} // Predicates = [HasSVE] + +let Predicates = [HasSVE2] in { + // SVE2 integer multiply-add (indexed) + defm MLA_ZZZI : sve2_int_mla_by_indexed_elem<0b01, 0b0, "mla", int_aarch64_sve_mla_lane>; + defm MLS_ZZZI : sve2_int_mla_by_indexed_elem<0b01, 0b1, "mls", int_aarch64_sve_mls_lane>; + + // SVE2 saturating multiply-add high (indexed) + defm SQRDMLAH_ZZZI : sve2_int_mla_by_indexed_elem<0b10, 0b0, "sqrdmlah", int_aarch64_sve_sqrdmlah_lane>; + defm SQRDMLSH_ZZZI : sve2_int_mla_by_indexed_elem<0b10, 0b1, "sqrdmlsh", int_aarch64_sve_sqrdmlsh_lane>; + + // SVE2 saturating multiply-add high (vectors, unpredicated) + defm SQRDMLAH_ZZZ : sve2_int_mla<0b0, "sqrdmlah", int_aarch64_sve_sqrdmlah>; + defm SQRDMLSH_ZZZ : sve2_int_mla<0b1, "sqrdmlsh", int_aarch64_sve_sqrdmlsh>; + + // SVE2 integer multiply (indexed) + defm MUL_ZZZI : sve2_int_mul_by_indexed_elem<0b1110, "mul", int_aarch64_sve_mul_lane>; + + // SVE2 saturating multiply high (indexed) + defm SQDMULH_ZZZI : sve2_int_mul_by_indexed_elem<0b1100, "sqdmulh", int_aarch64_sve_sqdmulh_lane>; + defm SQRDMULH_ZZZI : sve2_int_mul_by_indexed_elem<0b1101, 
"sqrdmulh", int_aarch64_sve_sqrdmulh_lane>; + + // SVE2 signed saturating doubling multiply high (unpredicated) + defm SQDMULH_ZZZ : sve2_int_mul<0b100, "sqdmulh", int_aarch64_sve_sqdmulh>; + defm SQRDMULH_ZZZ : sve2_int_mul<0b101, "sqrdmulh", int_aarch64_sve_sqrdmulh>; + + // SVE2 integer multiply vectors (unpredicated) + let AddedComplexity = 1 in { + defm MUL_ZZZ : sve2_int_mul<0b000, "mul", mul>; + } + + defm SMULH_ZZZ : sve2_int_mul<0b010, "smulh">; + defm UMULH_ZZZ : sve2_int_mul<0b011, "umulh">; + defm PMUL_ZZZ_B : sve2_int_mul_single<0b00, 0b001, "pmul", int_aarch64_sve_pmul, ZPR8, nxv16i8>; + + // SVE2 complex integer dot product (indexed) + defm CDOT_ZZZI : sve2_cintx_dot_by_indexed_elem<"cdot", int_aarch64_sve_cdot_lane>; + + // SVE2 complex integer dot product + defm CDOT_ZZZ : sve2_cintx_dot<"cdot", int_aarch64_sve_cdot>; + + // SVE2 complex integer multiply-add (indexed) + defm CMLA_ZZZI : sve2_cmla_by_indexed_elem<0b0, "cmla", int_aarch64_sve_cmla_lane_x>; + // SVE2 complex saturating multiply-add (indexed) + defm SQRDCMLAH_ZZZI : sve2_cmla_by_indexed_elem<0b1, "sqrdcmlah", int_aarch64_sve_sqrdcmlah_lane_x>; + + // SVE2 complex integer multiply-add + defm CMLA_ZZZ : sve2_int_cmla<0b0, "cmla", int_aarch64_sve_cmla_x>; + defm SQRDCMLAH_ZZZ : sve2_int_cmla<0b1, "sqrdcmlah", int_aarch64_sve_sqrdcmlah_x>; + + // SVE2 integer multiply long (indexed) + defm SMULLB_ZZZI : sve2_int_mul_long_by_indexed_elem<0b000, "smullb", int_aarch64_sve_smullb_lane>; + defm SMULLT_ZZZI : sve2_int_mul_long_by_indexed_elem<0b001, "smullt", int_aarch64_sve_smullt_lane>; + defm UMULLB_ZZZI : sve2_int_mul_long_by_indexed_elem<0b010, "umullb", int_aarch64_sve_umullb_lane>; + defm UMULLT_ZZZI : sve2_int_mul_long_by_indexed_elem<0b011, "umullt", int_aarch64_sve_umullt_lane>; + + // SVE2 saturating multiply (indexed) + defm SQDMULLB_ZZZI : sve2_int_mul_long_by_indexed_elem<0b100, "sqdmullb", int_aarch64_sve_sqdmullb_lane>; + defm SQDMULLT_ZZZI : sve2_int_mul_long_by_indexed_elem<0b101, "sqdmullt", int_aarch64_sve_sqdmullt_lane>; + + // SVE2 integer multiply-add long (indexed) + defm SMLALB_ZZZI : sve2_int_mla_long_by_indexed_elem<0b1000, "smlalb", int_aarch64_sve_smlalb_lane>; + defm SMLALT_ZZZI : sve2_int_mla_long_by_indexed_elem<0b1001, "smlalt", int_aarch64_sve_smlalt_lane>; + defm UMLALB_ZZZI : sve2_int_mla_long_by_indexed_elem<0b1010, "umlalb", int_aarch64_sve_umlalb_lane>; + defm UMLALT_ZZZI : sve2_int_mla_long_by_indexed_elem<0b1011, "umlalt", int_aarch64_sve_umlalt_lane>; + defm SMLSLB_ZZZI : sve2_int_mla_long_by_indexed_elem<0b1100, "smlslb", int_aarch64_sve_smlslb_lane>; + defm SMLSLT_ZZZI : sve2_int_mla_long_by_indexed_elem<0b1101, "smlslt", int_aarch64_sve_smlslt_lane>; + defm UMLSLB_ZZZI : sve2_int_mla_long_by_indexed_elem<0b1110, "umlslb", int_aarch64_sve_umlslb_lane>; + defm UMLSLT_ZZZI : sve2_int_mla_long_by_indexed_elem<0b1111, "umlslt", int_aarch64_sve_umlslt_lane>; + + // SVE2 integer multiply-add long (vectors, unpredicated) + defm SMLALB_ZZZ : sve2_int_mla_long<0b10000, "smlalb", int_aarch64_sve_smlalb>; + defm SMLALT_ZZZ : sve2_int_mla_long<0b10001, "smlalt", int_aarch64_sve_smlalt>; + defm UMLALB_ZZZ : sve2_int_mla_long<0b10010, "umlalb", int_aarch64_sve_umlalb>; + defm UMLALT_ZZZ : sve2_int_mla_long<0b10011, "umlalt", int_aarch64_sve_umlalt>; + defm SMLSLB_ZZZ : sve2_int_mla_long<0b10100, "smlslb", int_aarch64_sve_smlslb>; + defm SMLSLT_ZZZ : sve2_int_mla_long<0b10101, "smlslt", int_aarch64_sve_smlslt>; + defm UMLSLB_ZZZ : sve2_int_mla_long<0b10110, "umlslb", int_aarch64_sve_umlslb>; 
+ defm UMLSLT_ZZZ : sve2_int_mla_long<0b10111, "umlslt", int_aarch64_sve_umlslt>; + + // SVE2 saturating multiply-add long (indexed) + defm SQDMLALB_ZZZI : sve2_int_mla_long_by_indexed_elem<0b0100, "sqdmlalb", int_aarch64_sve_sqdmlalb_lane>; + defm SQDMLALT_ZZZI : sve2_int_mla_long_by_indexed_elem<0b0101, "sqdmlalt", int_aarch64_sve_sqdmlalt_lane>; + defm SQDMLSLB_ZZZI : sve2_int_mla_long_by_indexed_elem<0b0110, "sqdmlslb", int_aarch64_sve_sqdmlslb_lane>; + defm SQDMLSLT_ZZZI : sve2_int_mla_long_by_indexed_elem<0b0111, "sqdmlslt", int_aarch64_sve_sqdmlslt_lane>; + + // SVE2 saturating multiply-add long (vectors, unpredicated) + defm SQDMLALB_ZZZ : sve2_int_mla_long<0b11000, "sqdmlalb", int_aarch64_sve_sqdmlalb>; + defm SQDMLALT_ZZZ : sve2_int_mla_long<0b11001, "sqdmlalt", int_aarch64_sve_sqdmlalt>; + defm SQDMLSLB_ZZZ : sve2_int_mla_long<0b11010, "sqdmlslb", int_aarch64_sve_sqdmlslb>; + defm SQDMLSLT_ZZZ : sve2_int_mla_long<0b11011, "sqdmlslt", int_aarch64_sve_sqdmlslt>; + + // SVE2 saturating multiply-add interleaved long + defm SQDMLALBT_ZZZ : sve2_int_mla_long<0b00010, "sqdmlalbt", int_aarch64_sve_sqdmlalbt>; + defm SQDMLSLBT_ZZZ : sve2_int_mla_long<0b00011, "sqdmlslbt", int_aarch64_sve_sqdmlslbt>; + + // SVE2 integer halving add/subtract (predicated) + defm SHADD_ZPmZ : sve2_int_arith_pred<0b100000, "shadd", "SHADD_ZPZZ", int_aarch64_sve_shadd>; + defm UHADD_ZPmZ : sve2_int_arith_pred<0b100010, "uhadd", "UHADD_ZPZZ", int_aarch64_sve_uhadd>; + defm SHSUB_ZPmZ : sve2_int_arith_pred<0b100100, "shsub", "SHSUB_ZPZZ", int_aarch64_sve_shsub>; + defm UHSUB_ZPmZ : sve2_int_arith_pred<0b100110, "uhsub", "UHSUB_ZPZZ", int_aarch64_sve_uhsub>; + defm SRHADD_ZPmZ : sve2_int_arith_pred<0b101000, "srhadd", "SRHADD_ZPZZ", int_aarch64_sve_srhadd>; + defm URHADD_ZPmZ : sve2_int_arith_pred<0b101010, "urhadd", "URHADD_ZPZZ", int_aarch64_sve_urhadd>; + defm SHSUBR_ZPmZ : sve2_int_arith_pred<0b101100, "shsubr", "SHSUBR_ZPZZ", int_aarch64_sve_shsubr>; + defm UHSUBR_ZPmZ : sve2_int_arith_pred<0b101110, "uhsubr", "UHSUBR_ZPZZ", int_aarch64_sve_uhsubr>; + + // SVE2 integer pairwise add and accumulate long + defm SADALP_ZPmZ : sve2_int_sadd_long_accum_pairwise<0, "sadalp", "SADALP_ZPZZ", int_aarch64_sve_sadalp>; + defm UADALP_ZPmZ : sve2_int_sadd_long_accum_pairwise<1, "uadalp", "UADALP_ZPZZ", int_aarch64_sve_uadalp>; + + // SVE2 integer pairwise arithmetic + defm ADDP_ZPmZ : sve2_int_arith_pred<0b100011, "addp", "ADDP_ZPZZ", int_aarch64_sve_addp>; + defm SMAXP_ZPmZ : sve2_int_arith_pred<0b101001, "smaxp", "SMAXP_ZPZZ", int_aarch64_sve_smaxp>; + defm UMAXP_ZPmZ : sve2_int_arith_pred<0b101011, "umaxp", "UMAXP_ZPZZ", int_aarch64_sve_umaxp>; + defm SMINP_ZPmZ : sve2_int_arith_pred<0b101101, "sminp", "SMINP_ZPZZ", int_aarch64_sve_sminp>; + defm UMINP_ZPmZ : sve2_int_arith_pred<0b101111, "uminp", "UMINP_ZPZZ", int_aarch64_sve_uminp>; + + // SVE2 integer unary operations (predicated) + defm URECPE_ZPmZ : sve2_int_un_pred_arit_s<0b000, "urecpe", int_aarch64_sve_urecpe>; + defm URSQRTE_ZPmZ : sve2_int_un_pred_arit_s<0b001, "ursqrte", int_aarch64_sve_ursqrte>; + defm SQABS_ZPmZ : sve2_int_un_pred_arit<0b100, "sqabs", int_aarch64_sve_sqabs>; + defm SQNEG_ZPmZ : sve2_int_un_pred_arit<0b101, "sqneg", int_aarch64_sve_sqneg>; + + // SVE2 saturating add/subtract + defm SQADD_ZPmZ : sve2_int_arith_pred<0b110000, "sqadd", "SQADD_ZPZZ", int_aarch64_sve_sqadd>; + defm UQADD_ZPmZ : sve2_int_arith_pred<0b110010, "uqadd", "UQADD_ZPZZ", int_aarch64_sve_uqadd>; + defm SQSUB_ZPmZ : sve2_int_arith_pred<0b110100, "sqsub", "SQSUB_ZPZZ", 
int_aarch64_sve_sqsub>; + defm UQSUB_ZPmZ : sve2_int_arith_pred<0b110110, "uqsub", "UQSUB_ZPZZ", int_aarch64_sve_uqsub>; + defm SUQADD_ZPmZ : sve2_int_arith_pred<0b111000, "suqadd", "SUQADD_ZPZZ", int_aarch64_sve_suqadd>; + defm USQADD_ZPmZ : sve2_int_arith_pred<0b111010, "usqadd", "USQADD_ZPZZ", int_aarch64_sve_usqadd>; + defm SQSUBR_ZPmZ : sve2_int_arith_pred<0b111100, "sqsubr", "SQSUBR_ZPZZ", int_aarch64_sve_sqsubr>; + defm UQSUBR_ZPmZ : sve2_int_arith_pred<0b111110, "uqsubr", "UQSUBR_ZPZZ", int_aarch64_sve_uqsubr>; + + // SVE2 saturating/rounding bitwise shift left (predicated) + defm SRSHL_ZPmZ : sve2_int_arith_pred<0b000100, "srshl", "SRSHL_ZPZZ", int_aarch64_sve_srshl>; + defm URSHL_ZPmZ : sve2_int_arith_pred<0b000110, "urshl", "URSHL_ZPZZ", int_aarch64_sve_urshl>; + defm SRSHLR_ZPmZ : sve2_int_arith_pred<0b001100, "srshlr", "SRSHLR_ZPZZ">; + defm URSHLR_ZPmZ : sve2_int_arith_pred<0b001110, "urshlr", "URSHLR_ZPZZ">; + defm SQSHL_ZPmZ : sve2_int_arith_pred<0b010000, "sqshl", "SQSHL_ZPZZ", int_aarch64_sve_sqshl>; + defm UQSHL_ZPmZ : sve2_int_arith_pred<0b010010, "uqshl", "UQSHL_ZPZZ", int_aarch64_sve_uqshl>; + defm SQRSHL_ZPmZ : sve2_int_arith_pred<0b010100, "sqrshl", "SQRSHL_ZPZZ", int_aarch64_sve_sqrshl>; + defm UQRSHL_ZPmZ : sve2_int_arith_pred<0b010110, "uqrshl", "UQRSHL_ZPZZ", int_aarch64_sve_uqrshl>; + defm SQSHLR_ZPmZ : sve2_int_arith_pred<0b011000, "sqshlr", "SQSHLR_ZPZZ">; + defm UQSHLR_ZPmZ : sve2_int_arith_pred<0b011010, "uqshlr", "UQSHLR_ZPZZ">; + defm SQRSHLR_ZPmZ : sve2_int_arith_pred<0b011100, "sqrshlr", "SQRSHLR_ZPZZ">; + defm UQRSHLR_ZPmZ : sve2_int_arith_pred<0b011110, "uqrshlr", "UQRSHLR_ZPZZ">; + + // SVE2 predicated shifts + defm SQSHL_ZPmI : sve_int_bin_pred_shift_imm_left< 0b0110, "sqshl">; + defm UQSHL_ZPmI : sve_int_bin_pred_shift_imm_left< 0b0111, "uqshl">; + defm SQSHLU_ZPmI : sve2_int_bin_pred_shift_imm_left< 0b1111, "sqshlu", "SQSHLU_ZPZI", int_aarch64_sve_sqshlu>; + defm SRSHR_ZPmI : sve_int_bin_pred_shift_imm_right<0b1100, "srshr", "SRSHR_ZPZI", int_aarch64_sve_srshr>; + defm URSHR_ZPmI : sve_int_bin_pred_shift_imm_right<0b1101, "urshr", "URSHR_ZPZI", int_aarch64_sve_urshr>; + defm SRSHR_ZPZI : sve_int_bin_pred_shift_0_right_zx; + defm URSHR_ZPZI : sve_int_bin_pred_shift_0_right_zx; + + // SVE2 integer add/subtract long + defm SADDLB_ZZZ : sve2_wide_int_arith_long<0b00000, "saddlb", int_aarch64_sve_saddlb>; + defm SADDLT_ZZZ : sve2_wide_int_arith_long<0b00001, "saddlt", int_aarch64_sve_saddlt>; + defm UADDLB_ZZZ : sve2_wide_int_arith_long<0b00010, "uaddlb", int_aarch64_sve_uaddlb>; + defm UADDLT_ZZZ : sve2_wide_int_arith_long<0b00011, "uaddlt", int_aarch64_sve_uaddlt>; + defm SSUBLB_ZZZ : sve2_wide_int_arith_long<0b00100, "ssublb", int_aarch64_sve_ssublb>; + defm SSUBLT_ZZZ : sve2_wide_int_arith_long<0b00101, "ssublt", int_aarch64_sve_ssublt>; + defm USUBLB_ZZZ : sve2_wide_int_arith_long<0b00110, "usublb", int_aarch64_sve_usublb>; + defm USUBLT_ZZZ : sve2_wide_int_arith_long<0b00111, "usublt", int_aarch64_sve_usublt>; + defm SABDLB_ZZZ : sve2_wide_int_arith_long<0b01100, "sabdlb", int_aarch64_sve_sabdlb>; + defm SABDLT_ZZZ : sve2_wide_int_arith_long<0b01101, "sabdlt", int_aarch64_sve_sabdlt>; + defm UABDLB_ZZZ : sve2_wide_int_arith_long<0b01110, "uabdlb", int_aarch64_sve_uabdlb>; + defm UABDLT_ZZZ : sve2_wide_int_arith_long<0b01111, "uabdlt", int_aarch64_sve_uabdlt>; + + // SVE2 integer add/subtract wide + defm SADDWB_ZZZ : sve2_wide_int_arith_wide<0b000, "saddwb", int_aarch64_sve_saddwb>; + defm SADDWT_ZZZ : sve2_wide_int_arith_wide<0b001, "saddwt", 
int_aarch64_sve_saddwt>; + defm UADDWB_ZZZ : sve2_wide_int_arith_wide<0b010, "uaddwb", int_aarch64_sve_uaddwb>; + defm UADDWT_ZZZ : sve2_wide_int_arith_wide<0b011, "uaddwt", int_aarch64_sve_uaddwt>; + defm SSUBWB_ZZZ : sve2_wide_int_arith_wide<0b100, "ssubwb", int_aarch64_sve_ssubwb>; + defm SSUBWT_ZZZ : sve2_wide_int_arith_wide<0b101, "ssubwt", int_aarch64_sve_ssubwt>; + defm USUBWB_ZZZ : sve2_wide_int_arith_wide<0b110, "usubwb", int_aarch64_sve_usubwb>; + defm USUBWT_ZZZ : sve2_wide_int_arith_wide<0b111, "usubwt", int_aarch64_sve_usubwt>; + + // SVE2 integer multiply long + defm SQDMULLB_ZZZ : sve2_wide_int_arith_long<0b11000, "sqdmullb", int_aarch64_sve_sqdmullb>; + defm SQDMULLT_ZZZ : sve2_wide_int_arith_long<0b11001, "sqdmullt", int_aarch64_sve_sqdmullt>; + defm SMULLB_ZZZ : sve2_wide_int_arith_long<0b11100, "smullb", int_aarch64_sve_smullb>; + defm SMULLT_ZZZ : sve2_wide_int_arith_long<0b11101, "smullt", int_aarch64_sve_smullt>; + defm UMULLB_ZZZ : sve2_wide_int_arith_long<0b11110, "umullb", int_aarch64_sve_umullb>; + defm UMULLT_ZZZ : sve2_wide_int_arith_long<0b11111, "umullt", int_aarch64_sve_umullt>; + defm PMULLB_ZZZ : sve2_pmul_long<0b0, "pmullb", int_aarch64_sve_pmullb_pair>; + defm PMULLT_ZZZ : sve2_pmul_long<0b1, "pmullt", int_aarch64_sve_pmullt_pair>; + + // SVE2 bitwise shift and insert + defm SRI_ZZI : sve2_int_bin_shift_imm_right<0b0, "sri", int_aarch64_sve_sri>; + defm SLI_ZZI : sve2_int_bin_shift_imm_left< 0b1, "sli", int_aarch64_sve_sli>; + + // SVE2 bitwise shift right and accumulate + defm SSRA_ZZI : sve2_int_bin_accum_shift_imm_right<0b00, "ssra", int_aarch64_sve_ssra>; + defm USRA_ZZI : sve2_int_bin_accum_shift_imm_right<0b01, "usra", int_aarch64_sve_usra>; + defm SRSRA_ZZI : sve2_int_bin_accum_shift_imm_right<0b10, "srsra", int_aarch64_sve_srsra>; + defm URSRA_ZZI : sve2_int_bin_accum_shift_imm_right<0b11, "ursra", int_aarch64_sve_ursra>; + + // SVE2 complex integer add + defm CADD_ZZI : sve2_int_cadd<0b0, "cadd", int_aarch64_sve_cadd_x>; + defm SQCADD_ZZI : sve2_int_cadd<0b1, "sqcadd", int_aarch64_sve_sqcadd_x>; + + // SVE2 integer absolute difference and accumulate + defm SABA_ZZZ : sve2_int_absdiff_accum<0b0, "saba", int_aarch64_sve_saba>; + defm UABA_ZZZ : sve2_int_absdiff_accum<0b1, "uaba", int_aarch64_sve_uaba >; + + // SVE2 integer absolute difference and accumulate long + defm SABALB_ZZZ : sve2_int_absdiff_accum_long<0b00, "sabalb", int_aarch64_sve_sabalb>; + defm SABALT_ZZZ : sve2_int_absdiff_accum_long<0b01, "sabalt", int_aarch64_sve_sabalt>; + defm UABALB_ZZZ : sve2_int_absdiff_accum_long<0b10, "uabalb", int_aarch64_sve_uabalb>; + defm UABALT_ZZZ : sve2_int_absdiff_accum_long<0b11, "uabalt", int_aarch64_sve_uabalt>; + + // SVE2 integer add/subtract long with carry + defm ADCLB_ZZZ : sve2_int_addsub_long_carry<0b00, "adclb", int_aarch64_sve_adclb>; + defm ADCLT_ZZZ : sve2_int_addsub_long_carry<0b01, "adclt", int_aarch64_sve_adclt>; + defm SBCLB_ZZZ : sve2_int_addsub_long_carry<0b10, "sbclb", int_aarch64_sve_sbclb>; + defm SBCLT_ZZZ : sve2_int_addsub_long_carry<0b11, "sbclt", int_aarch64_sve_sbclt>; + + // SVE2 bitwise shift right narrow (bottom) + defm SQSHRUNB_ZZI : sve2_int_bin_shift_imm_right_narrow_bottom<0b000, "sqshrunb", int_aarch64_sve_sqshrunb>; + defm SQRSHRUNB_ZZI : sve2_int_bin_shift_imm_right_narrow_bottom<0b001, "sqrshrunb", int_aarch64_sve_sqrshrunb>; + defm SHRNB_ZZI : sve2_int_bin_shift_imm_right_narrow_bottom<0b010, "shrnb", int_aarch64_sve_shrnb>; + defm RSHRNB_ZZI : sve2_int_bin_shift_imm_right_narrow_bottom<0b011, "rshrnb", 
int_aarch64_sve_rshrnb>; + defm SQSHRNB_ZZI : sve2_int_bin_shift_imm_right_narrow_bottom<0b100, "sqshrnb", int_aarch64_sve_sqshrnb>; + defm SQRSHRNB_ZZI : sve2_int_bin_shift_imm_right_narrow_bottom<0b101, "sqrshrnb", int_aarch64_sve_sqrshrnb>; + defm UQSHRNB_ZZI : sve2_int_bin_shift_imm_right_narrow_bottom<0b110, "uqshrnb", int_aarch64_sve_uqshrnb>; + defm UQRSHRNB_ZZI : sve2_int_bin_shift_imm_right_narrow_bottom<0b111, "uqrshrnb", int_aarch64_sve_uqrshrnb>; + + // SVE2 bitwise shift right narrow (top) + defm SQSHRUNT_ZZI : sve2_int_bin_shift_imm_right_narrow_top<0b000, "sqshrunt", int_aarch64_sve_sqshrunt>; + defm SQRSHRUNT_ZZI : sve2_int_bin_shift_imm_right_narrow_top<0b001, "sqrshrunt", int_aarch64_sve_sqrshrunt>; + defm SHRNT_ZZI : sve2_int_bin_shift_imm_right_narrow_top<0b010, "shrnt", int_aarch64_sve_shrnt>; + defm RSHRNT_ZZI : sve2_int_bin_shift_imm_right_narrow_top<0b011, "rshrnt", int_aarch64_sve_rshrnt>; + defm SQSHRNT_ZZI : sve2_int_bin_shift_imm_right_narrow_top<0b100, "sqshrnt", int_aarch64_sve_sqshrnt>; + defm SQRSHRNT_ZZI : sve2_int_bin_shift_imm_right_narrow_top<0b101, "sqrshrnt", int_aarch64_sve_sqrshrnt>; + defm UQSHRNT_ZZI : sve2_int_bin_shift_imm_right_narrow_top<0b110, "uqshrnt", int_aarch64_sve_uqshrnt>; + defm UQRSHRNT_ZZI : sve2_int_bin_shift_imm_right_narrow_top<0b111, "uqrshrnt", int_aarch64_sve_uqrshrnt>; + + // SVE2 integer add/subtract narrow high part (bottom) + defm ADDHNB_ZZZ : sve2_int_addsub_narrow_high_bottom<0b00, "addhnb", int_aarch64_sve_addhnb>; + defm RADDHNB_ZZZ : sve2_int_addsub_narrow_high_bottom<0b01, "raddhnb", int_aarch64_sve_raddhnb>; + defm SUBHNB_ZZZ : sve2_int_addsub_narrow_high_bottom<0b10, "subhnb", int_aarch64_sve_subhnb>; + defm RSUBHNB_ZZZ : sve2_int_addsub_narrow_high_bottom<0b11, "rsubhnb", int_aarch64_sve_rsubhnb>; + + // SVE2 integer add/subtract narrow high part (top) + defm ADDHNT_ZZZ : sve2_int_addsub_narrow_high_top<0b00, "addhnt", int_aarch64_sve_addhnt>; + defm RADDHNT_ZZZ : sve2_int_addsub_narrow_high_top<0b01, "raddhnt", int_aarch64_sve_raddhnt>; + defm SUBHNT_ZZZ : sve2_int_addsub_narrow_high_top<0b10, "subhnt", int_aarch64_sve_subhnt>; + defm RSUBHNT_ZZZ : sve2_int_addsub_narrow_high_top<0b11, "rsubhnt", int_aarch64_sve_rsubhnt>; + + // SVE2 saturating extract narrow (bottom) + defm SQXTNB_ZZ : sve2_int_sat_extract_narrow_bottom<0b00, "sqxtnb", int_aarch64_sve_sqxtnb>; + defm UQXTNB_ZZ : sve2_int_sat_extract_narrow_bottom<0b01, "uqxtnb", int_aarch64_sve_uqxtnb>; + defm SQXTUNB_ZZ : sve2_int_sat_extract_narrow_bottom<0b10, "sqxtunb", int_aarch64_sve_sqxtunb>; + + // SVE2 saturating extract narrow (top) + defm SQXTNT_ZZ : sve2_int_sat_extract_narrow_top<0b00, "sqxtnt", int_aarch64_sve_sqxtnt>; + defm UQXTNT_ZZ : sve2_int_sat_extract_narrow_top<0b01, "uqxtnt", int_aarch64_sve_uqxtnt>; + defm SQXTUNT_ZZ : sve2_int_sat_extract_narrow_top<0b10, "sqxtunt", int_aarch64_sve_sqxtunt>; + + // SVE2 character match + defm MATCH_PPzZZ : sve2_char_match<0b0, "match", int_aarch64_sve_match>; + defm NMATCH_PPzZZ : sve2_char_match<0b1, "nmatch", int_aarch64_sve_nmatch>; + + // SVE2 bitwise exclusive-or interleaved + defm EORBT_ZZZ : sve2_bitwise_xor_interleaved<0b0, "eorbt", int_aarch64_sve_eorbt>; + defm EORTB_ZZZ : sve2_bitwise_xor_interleaved<0b1, "eortb", int_aarch64_sve_eortb>; + + // SVE2 bitwise shift left long + defm SSHLLB_ZZI : sve2_bitwise_shift_left_long<0b00, "sshllb", int_aarch64_sve_sshllb>; + defm SSHLLT_ZZI : sve2_bitwise_shift_left_long<0b01, "sshllt", int_aarch64_sve_sshllt>; + defm USHLLB_ZZI : 
sve2_bitwise_shift_left_long<0b10, "ushllb", int_aarch64_sve_ushllb>; + defm USHLLT_ZZI : sve2_bitwise_shift_left_long<0b11, "ushllt", int_aarch64_sve_ushllt>; + + // SVE2 integer add/subtract interleaved long + defm SADDLBT_ZZZ : sve2_misc_int_addsub_long_interleaved<0b00, "saddlbt", int_aarch64_sve_saddlbt>; + defm SSUBLBT_ZZZ : sve2_misc_int_addsub_long_interleaved<0b10, "ssublbt", int_aarch64_sve_ssublbt>; + defm SSUBLTB_ZZZ : sve2_misc_int_addsub_long_interleaved<0b11, "ssubltb", int_aarch64_sve_ssubltb>; + + // SVE2 histogram generation (segment) + def HISTSEG_ZZZ : sve2_hist_gen_segment<"histseg", int_aarch64_sve_histseg>; + + // SVE2 histogram generation (vector) + defm HISTCNT_ZPzZZ : sve2_hist_gen_vector<"histcnt", int_aarch64_sve_histcnt>; + + // SVE2 floating-point base 2 logarithm as integer + defm FLOGB_ZPmZ : sve2_fp_flogb<"flogb", int_aarch64_sve_flogb>; + + // SVE2 floating-point convert precision + defm FCVTXNT_ZPmZ : sve2_fp_convert_down_odd_rounding_top<"fcvtxnt", "int_aarch64_sve_fcvtxnt">; + defm FCVTX_ZPmZ : sve2_fp_convert_down_odd_rounding<"fcvtx", "int_aarch64_sve_fcvtx">; + defm FCVTNT_ZPmZ : sve2_fp_convert_down_narrow<"fcvtnt", "int_aarch64_sve_fcvtnt">; + defm FCVTLT_ZPmZ : sve2_fp_convert_up_long<"fcvtlt", "int_aarch64_sve_fcvtlt">; + + // SVE2 floating-point pairwise operations + defm FADDP_ZPmZZ : sve2_fp_pairwise_pred<0b000, "faddp", "FADDP_ZPZZZ", int_aarch64_sve_faddp>; + defm FMAXNMP_ZPmZZ : sve2_fp_pairwise_pred<0b100, "fmaxnmp", "FMAXNMP_ZPZZ", int_aarch64_sve_fmaxnmp>; + defm FMINNMP_ZPmZZ : sve2_fp_pairwise_pred<0b101, "fminnmp", "FMINNMP_ZPZZZ", int_aarch64_sve_fminnmp>; + defm FMAXP_ZPmZZ : sve2_fp_pairwise_pred<0b110, "fmaxp", "FMAXP_ZPZZZ", int_aarch64_sve_fmaxp>; + defm FMINP_ZPmZZ : sve2_fp_pairwise_pred<0b111, "fminp", "FMINP_ZPZZZ", int_aarch64_sve_fminp>; + + // SVE2 floating-point multiply-add long (indexed) + defm FMLALB_ZZZI_SHH : sve2_fp_mla_long_by_indexed_elem<0b00, "fmlalb", int_aarch64_sve_fmlalb_lane>; + defm FMLALT_ZZZI_SHH : sve2_fp_mla_long_by_indexed_elem<0b01, "fmlalt", int_aarch64_sve_fmlalt_lane>; + defm FMLSLB_ZZZI_SHH : sve2_fp_mla_long_by_indexed_elem<0b10, "fmlslb", int_aarch64_sve_fmlslb_lane>; + defm FMLSLT_ZZZI_SHH : sve2_fp_mla_long_by_indexed_elem<0b11, "fmlslt", int_aarch64_sve_fmlslt_lane>; + + // SVE2 floating-point multiply-add long + defm FMLALB_ZZZ_SHH : sve2_fp_mla_long<0b00, "fmlalb", int_aarch64_sve_fmlalb>; + defm FMLALT_ZZZ_SHH : sve2_fp_mla_long<0b01, "fmlalt", int_aarch64_sve_fmlalt>; + defm FMLSLB_ZZZ_SHH : sve2_fp_mla_long<0b10, "fmlslb", int_aarch64_sve_fmlslb>; + defm FMLSLT_ZZZ_SHH : sve2_fp_mla_long<0b11, "fmlslt", int_aarch64_sve_fmlslt>; + + // SVE2 bitwise ternary operations + defm EOR3_ZZZZ_D : sve2_int_bitwise_ternary_op<0b000, "eor3">; + defm BCAX_ZZZZ_D : sve2_int_bitwise_ternary_op<0b010, "bcax">; + def BSL_ZZZZ_D : sve2_int_bitwise_ternary_op_d<0b001, "bsl">; + def BSL1N_ZZZZ_D : sve2_int_bitwise_ternary_op_d<0b011, "bsl1n">; + def BSL2N_ZZZZ_D : sve2_int_bitwise_ternary_op_d<0b101, "bsl2n">; + def NBSL_ZZZZ_D : sve2_int_bitwise_ternary_op_d<0b111, "nbsl">; + + // SVE2 bitwise xor and rotate right by immediate + defm XAR_ZZZI : sve2_int_rotate_right_imm<"xar", int_aarch64_sve_xar>; + + // SVE2 extract vector (immediate offset, constructive) + def EXT_ZZI_B : sve2_int_perm_extract_i_cons<"ext">; + + // SVE2 non-temporal gather loads + defm LDNT1SB_ZZR_S : sve2_mem_gldnt_vs<0b00000, "ldnt1sb", Z_s, ZPR32, AArch64ldnt1s_gather_uxtw, nxv4i32, nxv4i1, nxv4i8>; + defm LDNT1B_ZZR_S : 
sve2_mem_gldnt_vs<0b00001, "ldnt1b", Z_s, ZPR32, AArch64ldnt1_gather_uxtw, nxv4i32, nxv4i1, nxv4i8>; + defm LDNT1SH_ZZR_S : sve2_mem_gldnt_vs<0b00100, "ldnt1sh", Z_s, ZPR32, AArch64ldnt1s_gather_uxtw, nxv4i32, nxv4i1, nxv4i16>; + defm LDNT1H_ZZR_S : sve2_mem_gldnt_vs<0b00101, "ldnt1h", Z_s, ZPR32, AArch64ldnt1_gather_uxtw, nxv4i32, nxv4i1, nxv4i16>; + defm LDNT1W_ZZR_S : sve2_mem_gldnt_vs<0b01001, "ldnt1w", Z_s, ZPR32, AArch64ldnt1_gather_uxtw, nxv4i32, nxv4i1, nxv4i32>; + + defm LDNT1SB_ZZR_D : sve2_mem_gldnt_vs<0b10000, "ldnt1sb", Z_d, ZPR64, AArch64ldnt1s_gather, nxv2i64, nxv2i1, nxv2i8>; + defm LDNT1B_ZZR_D : sve2_mem_gldnt_vs<0b10010, "ldnt1b", Z_d, ZPR64, AArch64ldnt1_gather, nxv2i64, nxv2i1, nxv2i8>; + defm LDNT1SH_ZZR_D : sve2_mem_gldnt_vs<0b10100, "ldnt1sh", Z_d, ZPR64, AArch64ldnt1s_gather, nxv2i64, nxv2i1, nxv2i16>; + defm LDNT1H_ZZR_D : sve2_mem_gldnt_vs<0b10110, "ldnt1h", Z_d, ZPR64, AArch64ldnt1_gather, nxv2i64, nxv2i1, nxv2i16>; + defm LDNT1SW_ZZR_D : sve2_mem_gldnt_vs<0b11000, "ldnt1sw", Z_d, ZPR64, AArch64ldnt1s_gather, nxv2i64, nxv2i1, nxv2i32>; + defm LDNT1W_ZZR_D : sve2_mem_gldnt_vs<0b11010, "ldnt1w", Z_d, ZPR64, AArch64ldnt1_gather, nxv2i64, nxv2i1, nxv2i32>; + defm LDNT1D_ZZR_D : sve2_mem_gldnt_vs<0b11110, "ldnt1d", Z_d, ZPR64, AArch64ldnt1_gather, nxv2i64, nxv2i1, nxv2i64>; + + def : Pat <(nxv4f16 (AArch64ldnt1_gather_uxtw (nxv4i1 PPR_3b:$Pg), (i64 GPR64:$Rm), (nxv4i32 ZPR32:$Zd), nxv4f16)), + (!cast(LDNT1H_ZZR_S_REAL) PPR:$Pg, ZPR:$Zd, GPR64:$Rm)>; + def : Pat <(nxv4f32 (AArch64ldnt1_gather_uxtw (nxv4i1 PPR_3b:$Pg), (i64 GPR64:$Rm), (nxv4i32 ZPR32:$Zd), nxv4f32)), + (!cast(LDNT1W_ZZR_S_REAL) PPR:$Pg, ZPR:$Zd, GPR64:$Rm)>; + def : Pat <(nxv2f16 (AArch64ldnt1_gather (nxv2i1 PPR_3b:$Pg), (i64 GPR64:$Rm), (nxv2i64 ZPR64:$Zd), nxv2f16)), + (!cast(LDNT1H_ZZR_D_REAL) PPR:$Pg, ZPR:$Zd, GPR64:$Rm)>; + def : Pat <(nxv2f32 (AArch64ldnt1_gather (nxv2i1 PPR_3b:$Pg), (i64 GPR64:$Rm), (nxv2i64 ZPR64:$Zd), nxv2f32)), + (!cast(LDNT1W_ZZR_D_REAL) PPR:$Pg, ZPR:$Zd, GPR64:$Rm)>; + def : Pat <(nxv2f64 (AArch64ldnt1_gather (nxv2i1 PPR_3b:$Pg), (i64 GPR64:$Rm), (nxv2i64 ZPR64:$Zd), nxv2f64)), + (!cast(LDNT1D_ZZR_D_REAL) PPR:$Pg, ZPR:$Zd, GPR64:$Rm)>; + + // SVE2 vector splice (constructive) + defm SPLICE_ZPZZ : sve2_int_perm_splice_cons<"splice">; + + // SVE2 non-temporal scatter stores + defm STNT1B_ZZR_S : sve2_mem_sstnt_vs<0b001, "stnt1b", Z_s, ZPR32, AArch64stnt1_scatter_uxtw, nxv4i32, nxv4i1, nxv4i8>; + defm STNT1H_ZZR_S : sve2_mem_sstnt_vs<0b011, "stnt1h", Z_s, ZPR32, AArch64stnt1_scatter_uxtw, nxv4i32, nxv4i1, nxv4i16>; + defm STNT1W_ZZR_S : sve2_mem_sstnt_vs<0b101, "stnt1w", Z_s, ZPR32, AArch64stnt1_scatter_uxtw, nxv4i32, nxv4i1, nxv4i32>; + + defm STNT1B_ZZR_D : sve2_mem_sstnt_vs<0b000, "stnt1b", Z_d, ZPR64, AArch64stnt1_scatter, nxv2i64, nxv2i1, nxv2i8>; + defm STNT1H_ZZR_D : sve2_mem_sstnt_vs<0b010, "stnt1h", Z_d, ZPR64, AArch64stnt1_scatter, nxv2i64, nxv2i1, nxv2i16>; + defm STNT1W_ZZR_D : sve2_mem_sstnt_vs<0b100, "stnt1w", Z_d, ZPR64, AArch64stnt1_scatter, nxv2i64, nxv2i1, nxv2i32>; + defm STNT1D_ZZR_D : sve2_mem_sstnt_vs<0b110, "stnt1d", Z_d, ZPR64, AArch64stnt1_scatter, nxv2i64, nxv2i1, nxv2i64>; + + // SVE2 table lookup (three sources) + defm TBL_ZZZZ : sve2_int_perm_tbl<"tbl">; + defm TBX_ZZZ : sve2_int_perm_tbx<"tbx", int_aarch64_sve_tbx>; + + // SVE2 integer compare scalar count and limit + defm WHILEGE_PWW : sve_int_while4_rr<0b000, "whilege", int_aarch64_sve_whilege>; + defm WHILEGT_PWW : sve_int_while4_rr<0b001, "whilegt", int_aarch64_sve_whilegt>; + defm WHILEHS_PWW 
: sve_int_while4_rr<0b100, "whilehs", int_aarch64_sve_whilehs>; + defm WHILEHI_PWW : sve_int_while4_rr<0b101, "whilehi", int_aarch64_sve_whilehi>; + + defm WHILEGE_PXX : sve_int_while8_rr<0b000, "whilege", int_aarch64_sve_whilege>; + defm WHILEGT_PXX : sve_int_while8_rr<0b001, "whilegt", int_aarch64_sve_whilegt>; + defm WHILEHS_PXX : sve_int_while8_rr<0b100, "whilehs", int_aarch64_sve_whilehs>; + defm WHILEHI_PXX : sve_int_while8_rr<0b101, "whilehi", int_aarch64_sve_whilehi>; + + // SVE2 pointer conflict compare + defm WHILEWR_PXX : sve2_int_while_rr<0b0, "whilewr", "int_aarch64_sve_whilewr">; + defm WHILERW_PXX : sve2_int_while_rr<0b1, "whilerw", "int_aarch64_sve_whilerw">; +} + +let Predicates = [HasSVE2AES] in { + // SVE2 crypto destructive binary operations + defm AESE_ZZZ_B : sve2_crypto_des_bin_op<0b00, "aese", ZPR8, int_aarch64_sve_aese, nxv16i8>; + defm AESD_ZZZ_B : sve2_crypto_des_bin_op<0b01, "aesd", ZPR8, int_aarch64_sve_aesd, nxv16i8>; + + // SVE2 crypto unary operations + defm AESMC_ZZ_B : sve2_crypto_unary_op<0b0, "aesmc", int_aarch64_sve_aesmc>; + defm AESIMC_ZZ_B : sve2_crypto_unary_op<0b1, "aesimc", int_aarch64_sve_aesimc>; + + // PMULLB and PMULLT instructions which operate with 64-bit source and + // 128-bit destination elements are enabled with crypto extensions, similar + // to NEON PMULL2 instruction. + defm PMULLB_ZZZ_Q : sve2_wide_int_arith_pmul<0b00, 0b11010, "pmullb", int_aarch64_sve_pmullb_pair>; + defm PMULLT_ZZZ_Q : sve2_wide_int_arith_pmul<0b00, 0b11011, "pmullt", int_aarch64_sve_pmullt_pair>; +} + +let Predicates = [HasSVE2SM4] in { + // SVE2 crypto constructive binary operations + defm SM4EKEY_ZZZ_S : sve2_crypto_cons_bin_op<0b0, "sm4ekey", ZPR32, int_aarch64_sve_sm4ekey, nxv4i32>; + // SVE2 crypto destructive binary operations + defm SM4E_ZZZ_S : sve2_crypto_des_bin_op<0b10, "sm4e", ZPR32, int_aarch64_sve_sm4e, nxv4i32>; +} + +let Predicates = [HasSVE2SHA3] in { + // SVE2 crypto constructive binary operations + defm RAX1_ZZZ_D : sve2_crypto_cons_bin_op<0b1, "rax1", ZPR64, int_aarch64_sve_rax1, nxv2i64>; +} + +let Predicates = [HasSVE2BitPerm] in { + // SVE2 bitwise permute + defm BEXT_ZZZ : sve2_misc_bitwise<0b1100, "bext", int_aarch64_sve_bext_x>; + defm BDEP_ZZZ : sve2_misc_bitwise<0b1101, "bdep", int_aarch64_sve_bdep_x>; + defm BGRP_ZZZ : sve2_misc_bitwise<0b1110, "bgrp", int_aarch64_sve_bgrp_x>; } Index: lib/Target/AArch64/AArch64SchedA53.td =================================================================== --- lib/Target/AArch64/AArch64SchedA53.td +++ lib/Target/AArch64/AArch64SchedA53.td @@ -27,7 +27,7 @@ // v 1.0 Spreadsheet let CompleteModel = 1; - list UnsupportedFeatures = [HasSVE]; + list UnsupportedFeatures = SVEUnsupported.F; } Index: lib/Target/AArch64/AArch64SchedA57.td =================================================================== --- lib/Target/AArch64/AArch64SchedA57.td +++ lib/Target/AArch64/AArch64SchedA57.td @@ -32,7 +32,7 @@ let LoopMicroOpBufferSize = 16; let CompleteModel = 1; - list UnsupportedFeatures = [HasSVE]; + list UnsupportedFeatures = SVEUnsupported.F; } //===----------------------------------------------------------------------===// Index: lib/Target/AArch64/AArch64SchedCyclone.td =================================================================== --- lib/Target/AArch64/AArch64SchedCyclone.td +++ lib/Target/AArch64/AArch64SchedCyclone.td @@ -19,7 +19,7 @@ let MispredictPenalty = 16; // 14-19 cycles are typical. 
let CompleteModel = 1; - list UnsupportedFeatures = [HasSVE]; + list UnsupportedFeatures = SVEUnsupported.F; } //===----------------------------------------------------------------------===// Index: lib/Target/AArch64/AArch64SchedExynosM1.td =================================================================== --- lib/Target/AArch64/AArch64SchedExynosM1.td +++ lib/Target/AArch64/AArch64SchedExynosM1.td @@ -25,7 +25,7 @@ let MispredictPenalty = 14; // Minimum branch misprediction penalty. let CompleteModel = 1; // Use the default model otherwise. - list UnsupportedFeatures = [HasSVE]; + list UnsupportedFeatures = SVEUnsupported.F; } //===----------------------------------------------------------------------===// Index: lib/Target/AArch64/AArch64SchedExynosM3.td =================================================================== --- lib/Target/AArch64/AArch64SchedExynosM3.td +++ lib/Target/AArch64/AArch64SchedExynosM3.td @@ -25,7 +25,7 @@ let MispredictPenalty = 16; // Minimum branch misprediction penalty. let CompleteModel = 1; // Use the default model otherwise. - list UnsupportedFeatures = [HasSVE]; + list UnsupportedFeatures = SVEUnsupported.F; // FIXME: Remove when all errors have been fixed. let FullInstRWOverlapCheck = 0; Index: lib/Target/AArch64/AArch64SchedFalkor.td =================================================================== --- lib/Target/AArch64/AArch64SchedFalkor.td +++ lib/Target/AArch64/AArch64SchedFalkor.td @@ -24,7 +24,7 @@ let MispredictPenalty = 11; // Minimum branch misprediction penalty. let CompleteModel = 1; - list UnsupportedFeatures = [HasSVE]; + list UnsupportedFeatures = SVEUnsupported.F; // FIXME: Remove when all errors have been fixed. let FullInstRWOverlapCheck = 0; Index: lib/Target/AArch64/AArch64SchedKryo.td =================================================================== --- lib/Target/AArch64/AArch64SchedKryo.td +++ lib/Target/AArch64/AArch64SchedKryo.td @@ -28,7 +28,7 @@ let LoopMicroOpBufferSize = 16; let CompleteModel = 1; - list UnsupportedFeatures = [HasSVE]; + list UnsupportedFeatures = SVEUnsupported.F; // FIXME: Remove when all errors have been fixed. let FullInstRWOverlapCheck = 0; Index: lib/Target/AArch64/AArch64SchedThunderX.td =================================================================== --- lib/Target/AArch64/AArch64SchedThunderX.td +++ lib/Target/AArch64/AArch64SchedThunderX.td @@ -26,7 +26,7 @@ let PostRAScheduler = 1; // Use PostRA scheduler. let CompleteModel = 1; - list UnsupportedFeatures = [HasSVE]; + list UnsupportedFeatures = SVEUnsupported.F; // FIXME: Remove when all errors have been fixed. let FullInstRWOverlapCheck = 0; Index: lib/Target/AArch64/AArch64SchedThunderX2T99.td =================================================================== --- lib/Target/AArch64/AArch64SchedThunderX2T99.td +++ lib/Target/AArch64/AArch64SchedThunderX2T99.td @@ -26,7 +26,7 @@ let PostRAScheduler = 1; // Using PostRA sched. let CompleteModel = 1; - list UnsupportedFeatures = [HasSVE]; + list UnsupportedFeatures = SVEUnsupported.F; // FIXME: Remove when all errors have been fixed. 
let FullInstRWOverlapCheck = 0; Index: lib/Target/AArch64/AArch64SelectionDAGInfo.h =================================================================== --- lib/Target/AArch64/AArch64SelectionDAGInfo.h +++ lib/Target/AArch64/AArch64SelectionDAGInfo.h @@ -24,7 +24,8 @@ SDValue Chain, SDValue Dst, SDValue Src, SDValue Size, unsigned Align, bool isVolatile, MachinePointerInfo DstPtrInfo) const override; - bool generateFMAsInMachineCombiner(CodeGenOpt::Level OptLevel) const override; + bool generateFMAsInMachineCombiner(SelectionDAG &DAG, + CodeGenOpt::Level OptLevel) const override; }; } Index: lib/Target/AArch64/AArch64SelectionDAGInfo.cpp =================================================================== --- lib/Target/AArch64/AArch64SelectionDAGInfo.cpp +++ lib/Target/AArch64/AArch64SelectionDAGInfo.cpp @@ -54,6 +54,7 @@ return SDValue(); } bool AArch64SelectionDAGInfo::generateFMAsInMachineCombiner( - CodeGenOpt::Level OptLevel) const { - return OptLevel >= CodeGenOpt::Aggressive; + SelectionDAG &DAG, CodeGenOpt::Level OptLevel) const { + const auto &STI = DAG.getMachineFunction().getSubtarget(); + return (OptLevel >= CodeGenOpt::Aggressive) && !STI.hasSVE(); } Index: lib/Target/AArch64/AArch64StackOffset.h =================================================================== --- /dev/null +++ lib/Target/AArch64/AArch64StackOffset.h @@ -0,0 +1,137 @@ +//==--AArch64StackOffset.h ---------------------------------------*- C++ -*-==// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// The StackOffset class is a wrapper around scalable and non-scalable +// offsets and is used in several functions such as 'isAArch64FrameOffsetLegal' +// and 'emitFrameOffset()'. Objects can only be created based on their MVTs, +// e.g. +// +// StackOffset(1, MVT::nxv16i8) +// would describe an offset as being the size of a single SVE vector. +// +// The class also implements simple arithmetic (addition/subtraction) on these +// offsets, e.g. +// +// StackOffset(1, MVT::nxv16i8) + StackOffset(1, MVT::i64) +// describes an offset that spans the combined storage required for an SVE +// vector and a 64bit GPR. 
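For illustration, here is a minimal usage sketch of the class described above (not part of the patch): the function name stackOffsetExample and the concrete values are made up, and the expected results simply follow the arithmetic documented in the class and in the getForFrameOffset/getForDwarfOffset comments further down.

#include "AArch64StackOffset.h"
#include <cstdint>

static void stackOffsetExample() {
  using namespace llvm;
  // One SVE data vector worth of stack: 16 scalable bytes.
  StackOffset Offset(1, MVT::nxv16i8);
  // Plus two 64-bit GPR spill slots: 16 ordinary bytes.
  Offset += StackOffset(2, MVT::i64);

  int64_t Bytes, PLs, VLs;
  Offset.getForFrameOffset(Bytes, PLs, VLs);
  // Bytes == 16, PLs == 0, VLs == 1: the scalable part (eight predicate-sized
  // units) is returned as a single ADDVL-sized unit and no ADDPL-sized units.

  int64_t DwarfBytes, VGs;
  Offset.getForDwarfOffset(DwarfBytes, VGs);
  // DwarfBytes == 16, VGs == 8: DWARF expresses the scalable part as a
  // multiple of VG, the number of 64-bit granules in one vector.
}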
+// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_LIB_TARGET_AARCH64_AARCH64STACKOFFSET_H +#define LLVM_LIB_TARGET_AARCH64_AARCH64STACKOFFSET_H + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/MachineValueType.h" + +namespace llvm { + +class StackOffset { + int64_t Bytes; + int64_t ScalableBytes; + +public: + using Part = std::pair; + + StackOffset() : Bytes(0), ScalableBytes(0) {} + + StackOffset(int64_t Offset, MVT::SimpleValueType T) : StackOffset() { + *this += Part(Offset, T); + } + + StackOffset(const StackOffset &Other) + : Bytes(Other.Bytes), + ScalableBytes(Other.ScalableBytes) {} + + StackOffset &operator+=(const StackOffset::Part &Other) { + assert(Other.second.getSizeInBits() % 8 == 0 && + "Offset type is not a multiple of bytes"); + int64_t OffsetInBytes = Other.first * (Other.second.getSizeInBits() / 8); + if (Other.second.isScalableVector()) + ScalableBytes += OffsetInBytes; + else + Bytes += OffsetInBytes; + return *this; + } + + StackOffset &operator=(const StackOffset &) = default; + + StackOffset &operator+=(const StackOffset &Other) { + Bytes += Other.Bytes; + ScalableBytes += Other.ScalableBytes; + return *this; + } + + StackOffset operator+(const StackOffset &Other) { + StackOffset Res(*this); + Res += Other; + return Res; + } + + StackOffset &operator-=(const StackOffset &Other) { + Bytes -= Other.Bytes; + ScalableBytes -= Other.ScalableBytes; + return *this; + } + + StackOffset operator-(const StackOffset &Other) { + StackOffset Res(*this); + Res -= Other; + return Res; + } + + /// Returns the scalable part of the offset in bytes. + int64_t getScalableBytes() const { return ScalableBytes; } + + /// Returns the non-scalable part of the offset in bytes. + int64_t getBytes() const { return Bytes; } + + void getForFrameOffset(int64_t &ByteSized, int64_t &PLSized, + int64_t &VLSized) const { + assert(isValid() && "Invalid frame offset"); + + ByteSized = Bytes; + VLSized = 0; + PLSized = ScalableBytes / 2; + // This method is used to get the offsets to adjust the frame offset. + // If the function requires ADDPL to be used and needs more than two ADDPL + // instructions, part of the offset is folded into VLSized so that it uses + // ADDVL for part of it, reducing the number of ADDPL instructions. + if (PLSized % 8 == 0 || PLSized < -64 || PLSized > 62) { + VLSized = PLSized / 8; + PLSized -= VLSized * 8; + } + } + + void getForDwarfOffset(int64_t &ByteSized, int64_t &VGSized) const { + assert(isValid() && "Invalid frame offset"); + // VGSized offsets are divided by '2', because the VG register is + // the number of 64bit granules as opposed to 128bit vector chunks, + // which is how the 'n' in e.g. MVT::nxv1i8 is modelled. + // So, for a stack offset of 16 MVT::nxv1i8's, the size is n x 16 bytes. + // VG = n * 2 and the dwarf offset must be VG * 8 bytes. + ByteSized = Bytes; + VGSized = ScalableBytes / 2; + } + + /// Returns whether the offset is known zero. + bool isZero() const { return !Bytes && !ScalableBytes; } + + bool isValid() const { + // The smallest scalable elements supported by scaled SVE addressing + // modes are predicates, which are 2 scalable bytes in size. So the scalable + // byte offset must always be a multiple of 2.
+ return ScalableBytes % 2 == 0; + } +}; + +} // end namespace llvm + +#endif Index: lib/Target/AArch64/AArch64Subtarget.h =================================================================== --- lib/Target/AArch64/AArch64Subtarget.h +++ lib/Target/AArch64/AArch64Subtarget.h @@ -89,9 +89,16 @@ bool HasLSLFast = false; bool HasSVE = false; + bool HasSVE2 = false; bool HasRCPC = false; bool HasAggressiveFMA = false; + // Arm SVE2 extensions + bool HasSVE2AES = false; + bool HasSVE2SM4 = false; + bool HasSVE2SHA3 = false; + bool HasSVE2BitPerm = false; + // HasZeroCycleRegMove - Has zero-cycle register mov instructions. bool HasZeroCycleRegMove = false; @@ -126,6 +133,8 @@ bool HasFuseLiterals = false; bool DisableLatencySchedHeuristic = false; bool UseRSqrt = false; + bool UseIterativeReciprocal = false; + bool HasFreeBasePlusRegAddrMode = true; uint8_t MaxInterleaveFactor = 2; uint8_t VectorInsertExtractBaseCost = 3; uint16_t CacheLineSize = 0; @@ -268,6 +277,8 @@ } bool useRSqrt() const { return UseRSqrt; } + bool useIterativeReciprocal() const { return UseIterativeReciprocal; } + bool hasFreeBasePlusRegAddrMode() const { return HasFreeBasePlusRegAddrMode; } unsigned getMaxInterleaveFactor() const { return MaxInterleaveFactor; } unsigned getVectorInsertExtractBaseCost() const { return VectorInsertExtractBaseCost; @@ -294,9 +305,16 @@ bool hasSPE() const { return HasSPE; } bool hasLSLFast() const { return HasLSLFast; } bool hasSVE() const { return HasSVE; } + bool hasSVE2() const { return HasSVE2; } bool hasRCPC() const { return HasRCPC; } bool hasAggressiveFMA() const { return HasAggressiveFMA; } + // Arm SVE2 extensions + bool hasSVE2AES() const { return HasSVE2AES; } + bool hasSVE2SM4() const { return HasSVE2SM4; } + bool hasSVE2SHA3() const { return HasSVE2SHA3; } + bool hasSVE2BitPerm() const { return HasSVE2BitPerm; } + bool isLittleEndian() const { return IsLittle; } bool isTargetDarwin() const { return TargetTriple.isOSDarwin(); } Index: lib/Target/AArch64/AArch64TargetMachine.h =================================================================== --- lib/Target/AArch64/AArch64TargetMachine.h +++ lib/Target/AArch64/AArch64TargetMachine.h @@ -50,6 +50,8 @@ return TLOF.get(); } + bool getO0WantsFastISel() override { return false; } + private: bool isLittle; }; Index: lib/Target/AArch64/AArch64TargetMachine.cpp =================================================================== --- lib/Target/AArch64/AArch64TargetMachine.cpp +++ lib/Target/AArch64/AArch64TargetMachine.cpp @@ -38,6 +38,7 @@ #include "llvm/Target/TargetLoweringObjectFile.h" #include "llvm/Target/TargetOptions.h" #include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/InstCombine/InstCombine.h" #include #include @@ -114,6 +115,11 @@ cl::desc("Work around Cortex-A53 erratum 835769"), cl::init(false)); +static cl::opt LowerGatherScatterToInterleaved( + "sve-lower-gather-scatter-to-interleaved", + cl::desc("Enable lowering gather/scatters to interleaved intrinsics"), + cl::init(true), cl::Hidden); + static cl::opt EnableGEPOpt("aarch64-enable-gep-opt", cl::Hidden, cl::desc("Enable optimizations on complex GEPs"), @@ -129,6 +135,17 @@ cl::desc("Enable the global merge pass")); static cl::opt + EnableSVEPostVec("aarch64-sve-postvec", cl::init(true), cl::Hidden, + cl::desc("Enable the SVE post vectorization pass.")); + +// Enable the clean up of unnecessary setffr instruction planted when lowering +// the load first-fault instructions +static cl::opt + EnableAArch64SetFFROptimize("aarch64-setffr-optimize", cl::Hidden, + 
cl::desc("Remove unnecessary 'setffr'."), + cl::init(false)); + +static cl::opt EnableLoopDataPrefetch("aarch64-enable-loop-data-prefetch", cl::Hidden, cl::desc("Enable the loop data prefetch pass"), cl::init(true)); @@ -136,11 +153,19 @@ static cl::opt EnableGlobalISelAtO( "aarch64-enable-global-isel-at-O", cl::Hidden, cl::desc("Enable GlobalISel at or below an opt level (-1 to disable)"), - cl::init(0)); + cl::init(-1)); + +static cl::opt +EnableSVEIntrinsicOpts("aarch64-sve-intrinsic-opts", cl::Hidden, + cl::desc("Enable SVE intrinsic opts"), + cl::init(true)); static cl::opt EnableFalkorHWPFFix("aarch64-enable-falkor-hwpf-fix", cl::init(true), cl::Hidden); +static cl::opt EnableContLoadStore("aarch64-enable-contiguous-load-store", + cl::init(true), cl::Hidden); + extern "C" void LLVMInitializeAArch64Target() { // Register the target. RegisterTargetMachine X(getTheAArch64leTarget()); @@ -164,6 +189,7 @@ initializeFalkorHWPFFixPass(*PR); initializeFalkorMarkStridedAccessesLegacyPass(*PR); initializeLDTLSCleanupPass(*PR); + initializeSVEIntrinsicOptsPass(*PR); } //===----------------------------------------------------------------------===// @@ -343,6 +369,7 @@ } void addIRPasses() override; + bool addPostCoalesce() override; bool addPreISel() override; bool addInstSelector() override; bool addIRTranslator() override; @@ -373,6 +400,19 @@ // ourselves. addPass(createAtomicExpandPass()); + // Make use of SVE intrinsics in place of common vector operations that span + // multiple basic blocks. + if (TM->getOptLevel() != CodeGenOpt::None && EnableSVEPostVec) + addPass(createSVEPostVectorizePass()); + + // Expand any SVE vector library calls that we can't code generate directly. + bool ExpandToOptimize = (TM->getOptLevel() != CodeGenOpt::None); + if (EnableSVEIntrinsicOpts && TM->getOptLevel() == CodeGenOpt::Aggressive) { + addPass(createSVEIntrinsicOptsPass()); + addPass(createInstructionCombiningPass()); + } + addPass(createSVEExpandLibCallPass(ExpandToOptimize)); + // Cmpxchg instructions are often used with a subsequent comparison to // determine whether it succeeded. We can exploit existing control-flow in // ldrex/strex loops to simplify this, but it needs tidying up. @@ -393,8 +433,25 @@ TargetPassConfig::addIRPasses(); // Match interleaved memory accesses to ldN/stN intrinsics. 
- if (TM->getOptLevel() != CodeGenOpt::None) + if (TM->getOptLevel() != CodeGenOpt::None) { + if (EnableContLoadStore) + addPass(createContiguousLoadStorePass()); addPass(createInterleavedAccessPass()); + } + + // Match interleaved gathers and scatters to ldN/stN intrinsics + if (TM->getOptLevel() == CodeGenOpt::Aggressive && + LowerGatherScatterToInterleaved) { + // Call EarlyCSE pass to ensure seriesvectors that looks the same are the + // same + addPass(createEarlyCSEPass()); + + addPass(createInterleavedGatherScatterStoreSinkPass()); + addPass(createInterleavedGatherScatterPass()); + + // Simplify the address calculation of any new interleaved accesses + addPass(createInstructionCombiningPass()); + } if (TM->getOptLevel() == CodeGenOpt::Aggressive && EnableGEPOpt) { // Call SeparateConstOffsetFromGEP pass to extract constants within indices @@ -411,6 +468,15 @@ } // Pass Pipeline Configuration + +bool AArch64PassConfig::addPostCoalesce() { + // Add a pass that transforms SVE MOVPRFXable Pseudo instructions + // to add an 'earlyclobber' under certain conditions + addPass(createSVEConditionalEarlyClobberPass()); + + return false; +} + bool AArch64PassConfig::addPreISel() { // Run promote constant before global merge, so that the promoted constants // get a chance to be merged @@ -427,6 +493,11 @@ addPass(createGlobalMergePass(TM, 4095, OnlyOptimizeForSize)); } + if (TM->getOptLevel() != CodeGenOpt::None && EnableSVEPostVec) { + addPass(createSVEAddressingModesPass()); + addPass(createDeadCodeEliminationPass()); + } + return false; } @@ -532,4 +603,7 @@ if (TM->getOptLevel() != CodeGenOpt::None && EnableCollectLOH && TM->getTargetTriple().isOSBinFormatMachO()) addPass(createAArch64CollectLOHPass()); + + // SVE bundles move prefixes with destructive operations. + addPass(createUnpackMachineBundles(nullptr)); } Index: lib/Target/AArch64/AArch64TargetTransformInfo.h =================================================================== --- lib/Target/AArch64/AArch64TargetTransformInfo.h +++ lib/Target/AArch64/AArch64TargetTransformInfo.h @@ -68,6 +68,13 @@ bool areInlineCompatible(const Function *Caller, const Function *Callee) const; + enum FixedWidthMode { + NEON, + SVE128, + SVE256, + SVE512 + }; + /// \name Scalar TTI Implementations /// @{ @@ -84,7 +91,7 @@ /// \name Vector TTI Implementations /// @{ - bool enableInterleavedAccessVectorization() { return true; } + bool enableInterleavedAccessVectorization() { return !ST->hasSVE(); } unsigned getNumberOfRegisters(bool Vector) { if (Vector) { @@ -104,6 +111,12 @@ return 64; } + unsigned getRegisterBitWidthUpperBound(bool Vector) { + if (Vector && ST->hasSVE()) + return AArch64::SVEMaxBitsPerVector; + return getRegisterBitWidth(Vector); + } + unsigned getMinVectorRegisterBitWidth() { return ST->getMinVectorRegisterBitWidth(); } @@ -131,6 +144,12 @@ int getCmpSelInstrCost(unsigned Opcode, Type *ValTy, Type *CondTy, const Instruction *I = nullptr); + bool isNeonVector(Type *Ty, FixedWidthMode FW); + + unsigned getVectorModeBitWidth(FixedWidthMode Mode); + + Type *getTypeForVectorMode(Type *Ty, FixedWidthMode FW); + int getMemoryOpCost(unsigned Opcode, Type *Src, unsigned Alignment, unsigned AddressSpace, const Instruction *I = nullptr); @@ -144,10 +163,53 @@ bool getTgtMemIntrinsic(IntrinsicInst *Inst, MemIntrinsicInfo &Info); + bool isLegalMaskedLoadStore(Type *DataType) { + if (ST->hasSVE()) + return true; + + // When the IR vectorizer asks the question about whether a masked + // load/store is legal it passes a scalar type. 
We want to tell the + // vectorizer to use the masked load/store intrinsic instead of + // rewriting the loop in terms of if/else blocks. In the backend we + // ask the same question, but with a vector type and so we effectively + // let the backend do the scalarising - it does a much better job! + return DataType->isDoubleTy() || DataType->isIntegerTy(64); + } + + bool isLegalMaskedLoad(Type *DataType) { + return isLegalMaskedLoadStore(DataType); + } + bool isLegalMaskedStore(Type *DataType) { + return isLegalMaskedLoadStore(DataType); + } + bool isLegalMaskedGather(Type *DataType); + + bool isLegalMaskedScatter(Type *DataType); + + bool hasVectorMemoryOp(unsigned Opcode, Type *Ty, const MemAccessInfo &Info) { + if (ST->hasSVE()) + return true; + return BaseT::hasVectorMemoryOp(Opcode, Ty, Info); + } + + unsigned getVectorMemoryOpCost(unsigned Opcode, Type *Src, Value *Ptr, + unsigned Alignment, unsigned AddressSpace, + const MemAccessInfo &Info, Instruction *I); + int getInterleavedMemoryOpCost(unsigned Opcode, Type *VecTy, unsigned Factor, ArrayRef Indices, unsigned Alignment, unsigned AddressSpace); + bool canVectorizeNonUnitStrides(bool forceFixedWidth = false) const { + if (forceFixedWidth) + return false; + return ST->hasSVE(); + } + + bool vectorizePreventedSLForwarding(void) const { + return ST->hasSVE(); + } + bool shouldConsiderAddressTypePromotion(const Instruction &I, bool &AllowPromotionWithoutCommonHeader); @@ -166,11 +228,20 @@ bool useReductionIntrinsic(unsigned Opcode, Type *Ty, TTI::ReductionFlags Flags) const; + bool canReduceInVector(unsigned Opcode, Type *ScalarTy, + TTI::ReductionFlags Flags) const; int getArithmeticReductionCost(unsigned Opcode, Type *Ty, bool IsPairwiseForm); int getShuffleCost(TTI::ShuffleKind Kind, Type *Tp, int Index, Type *SubTp); + + int getIntrinsicInstrCost(Intrinsic::ID IID, Type *RetTy, + ArrayRef Tys, FastMathFlags FMF, + unsigned ScalarizationCostPassed = UINT_MAX); + int getIntrinsicInstrCost(Intrinsic::ID IID, Type *RetTy, + ArrayRef Args, FastMathFlags FMF, + VectorType::ElementCount VF = VectorType::SingleElement()); /// @} }; Index: lib/Target/AArch64/AArch64TargetTransformInfo.cpp =================================================================== --- lib/Target/AArch64/AArch64TargetTransformInfo.cpp +++ lib/Target/AArch64/AArch64TargetTransformInfo.cpp @@ -24,6 +24,19 @@ static cl::opt EnableFalkorHWPFUnrollFix("enable-falkor-hwpf-unroll-fix", cl::init(true), cl::Hidden); +cl::opt DisableSveGatherScatter( + "disable-sve-gather-scatter", cl::init(false), cl::Hidden, + cl::ZeroOrMore, + cl::desc("Disable use of sve gather/scatter instructions")); + +static cl::opt FixedWidthOpt( + "fixed-width-mode", cl::init(AArch64TTIImpl::NEON), + cl::desc("Specify Neon or SVE + width, e.g. sve256"), + cl::values(clEnumValN(AArch64TTIImpl::NEON, "neon", "Neon (128 bit)"), + clEnumValN(AArch64TTIImpl::SVE128, "sve128", "SVE (128 bit)"), + clEnumValN(AArch64TTIImpl::SVE256, "sve256", "SVE (256 bit)"), + clEnumValN(AArch64TTIImpl::SVE512, "sve512", "SVE (512 bit)"))); + bool AArch64TTIImpl::areInlineCompatible(const Function *Caller, const Function *Callee) const { const TargetMachine &TM = getTLI()->getTargetMachine(); @@ -391,6 +404,72 @@ SrcTy.getSimpleVT())) return Entry->Cost; + static const TypeConversionCostTblEntry SVEConversionTbl[] = { + // Truncating two illegal vectors into one legal vector, + // can be done with one instruction on SVE e.g. 
+ // UZP1([ .....v7 | .....v6 | .....v5 | .....v4 ], + // [ .....v3 | .....v2 | .....v1 | .....v0 ]) + // = [ v7 | v6 | v5 | v4 | v3 | v2 | v1 | v0 ] + // If type is smaller than half the size, ANDI + // can be used to fill zeros in high bits. + { ISD::TRUNCATE, MVT::nxv8i32, MVT::nxv8i64, 2 * 1 }, + { ISD::TRUNCATE, MVT::nxv4i32, MVT::nxv4i64, 1 }, + { ISD::TRUNCATE, MVT::nxv4i16, MVT::nxv4i64, 2 }, + { ISD::TRUNCATE, MVT::nxv8i16, MVT::nxv8i32, 1 }, + { ISD::TRUNCATE, MVT::nxv8i8, MVT::nxv8i32, 2 }, + { ISD::TRUNCATE, MVT::nxv16i8, MVT::nxv16i16, 1 }, + { ISD::TRUNCATE, MVT::nxv16i8, MVT::nxv16i32, 3 }, // uzp1(uzp1(A),uzp1(B)) + + // Zero extend happens with unpack, possibly with 'AND' + // to zero the high bits. + { ISD::ZERO_EXTEND, MVT::nxv8i64, MVT::nxv8i32, 4 }, + { ISD::ZERO_EXTEND, MVT::nxv4i64, MVT::nxv4i32, 2 }, + { ISD::ZERO_EXTEND, MVT::nxv4i64, MVT::nxv4i16, 4 }, + { ISD::ZERO_EXTEND, MVT::nxv8i32, MVT::nxv8i16, 2 }, + { ISD::ZERO_EXTEND, MVT::nxv8i32, MVT::nxv8i8, 4 }, + { ISD::ZERO_EXTEND, MVT::nxv16i16, MVT::nxv16i8, 2 }, + { ISD::ZERO_EXTEND, MVT::nxv16i32, MVT::nxv16i8, 4 }, + + // Zero extending requires zeroing high bits, + // can be done with a ANDI instruction that fills + // in zeros in high bits. Because it is so cheap, + // we override BaseT's implementation. + { ISD::ZERO_EXTEND, MVT::nxv2i64, MVT::nxv2i32, 1 }, + { ISD::ZERO_EXTEND, MVT::nxv2i64, MVT::nxv2i16, 1 }, + { ISD::ZERO_EXTEND, MVT::nxv4i32, MVT::nxv4i16, 1 }, + { ISD::ZERO_EXTEND, MVT::nxv2i64, MVT::nxv2i8, 1 }, + { ISD::ZERO_EXTEND, MVT::nxv4i32, MVT::nxv4i8, 1 }, + { ISD::ZERO_EXTEND, MVT::nxv8i16, MVT::nxv8i8, 1 }, + + // Truncating legal type to smaller type + // is free because the vector is unpacked. + { ISD::TRUNCATE, MVT::nxv2i32, MVT::nxv2i64, 0 }, + { ISD::TRUNCATE, MVT::nxv2i16, MVT::nxv2i64, 0 }, + { ISD::TRUNCATE, MVT::nxv2i8, MVT::nxv2i64, 0 }, + { ISD::TRUNCATE, MVT::nxv4i16, MVT::nxv4i32, 0 }, + { ISD::TRUNCATE, MVT::nxv4i8, MVT::nxv4i32, 0 }, + { ISD::TRUNCATE, MVT::nxv8i8, MVT::nxv8i16, 0 }, + + // Floating point extend / truncate + { ISD::FP_ROUND, MVT::nxv2f32, MVT::nxv2f64, 1 }, + { ISD::FP_ROUND, MVT::nxv4f32, MVT::nxv4f64, 2 }, + { ISD::FP_EXTEND, MVT::nxv2f64, MVT::nxv2f32, 1 }, + { ISD::FP_EXTEND, MVT::nxv4f64, MVT::nxv4f32, 2 }, + + { ISD::SINT_TO_FP, MVT::nxv4f32, MVT::nxv4i32, 1 }, + { ISD::SINT_TO_FP, MVT::nxv2f64, MVT::nxv2i64, 1 }, + { ISD::SINT_TO_FP, MVT::nxv2f32, MVT::nxv2i32, 1 }, + { ISD::UINT_TO_FP, MVT::nxv2f32, MVT::nxv2i32, 1 }, + { ISD::UINT_TO_FP, MVT::nxv4f32, MVT::nxv4i32, 1 }, + { ISD::UINT_TO_FP, MVT::nxv2f64, MVT::nxv2i64, 1 }, + }; + + if (getST()->hasSVE()) + if (const auto *Entry = ConvertCostTableLookup(SVEConversionTbl, ISD, + DstTy.getSimpleVT(), + SrcTy.getSimpleVT())) + return Entry->Cost; + return BaseT::getCastInstrCost(Opcode, Dst, Src); } @@ -593,6 +672,12 @@ // We don't lower some vector selects well that are wider than the register // width. if (ValTy->isVectorTy() && ISD == ISD::SELECT) { + if (ST->hasSVE()) { + EVT SelValTy = TLI->getValueType(DL, ValTy); + if (SelValTy.isScalableVector()) { + return SelValTy.getSizeInBits() / 128; + } + } // We would need this many instructions to hide the scalarization happening. 
const int AmortizationCost = 20; static const TypeConversionCostTblEntry @@ -617,9 +702,108 @@ return BaseT::getCmpSelInstrCost(Opcode, ValTy, CondTy, I); } +unsigned AArch64TTIImpl::getVectorMemoryOpCost(unsigned Opcode, Type *Src, + Value *Ptr, unsigned Alignment, + unsigned AddressSpace, + const MemAccessInfo &Info, + Instruction *I) { + if (!ST->hasSVE()) + return getMemoryOpCost(Opcode, Src, Alignment, AddressSpace); + + std::pair LT = TLI->getTypeLegalizationCost(DL, Src); + + if (Info.isUniform() || (Info.isStrided() && std::abs(Info.getStride())==1)) + return LT.first; + + unsigned NumElems = LT.second.getVectorNumElements(); + + // In the general case, strided loads of stride=1,2,3 are 'cheap' + // by using the LD(1|2|3) instructions, because they load a full + // vector in one operation. Strided stores on the other hand are + // expensive.. even though we have ST(2|3) instructions available, it + // also needs to store 1|2 different values to fill the gaps between + // the stride. Also add cost for the use of more result registers + // of LD2 (two result regs) and LD3 (three result regs). + if (Info.isStrided()) { + switch (std::abs(Info.getStride())) { + case 2: + case 3: + // Assume gather is used for double-word vectors. + if (Opcode == Instruction::Load && + Src->getVectorElementType()->getScalarSizeInBits() <= 32) + return LT.first + (std::abs(Info.getStride()) - 1); + default: + break; + } + } + + unsigned GatherWeight = 2; + if (Info.isNonStrided()) { + switch (Info.getIndexType()->getScalarSizeInBits()) { + case 8: + case 16: + case 32: + // Only handle worst case, because having + // index type < 64 bit is same as having + // strided not 1,2,3,4. + break; + default: + unsigned NumOps = std::max(1U, NumElems/2); + return GatherWeight * LT.first * NumOps; + } + } + + // With a gather offset < 64bits, we can load/store 4 elements at a time, + // so number of operations is NumElems divided by 4. + unsigned NumOps = std::max(1U,NumElems/4); + return GatherWeight * LT.first * NumOps; +} + +bool AArch64TTIImpl::isNeonVector(Type *Ty, FixedWidthMode FW) { + return !Ty->getVectorIsScalable() && FW == FixedWidthMode::NEON; +} + +unsigned AArch64TTIImpl::getVectorModeBitWidth(FixedWidthMode Mode) { + switch (Mode) { + default: + llvm_unreachable("Unhandled FixedWidthMode!"); + case FixedWidthMode::NEON: + case FixedWidthMode::SVE128: + return 128; + case FixedWidthMode::SVE256: + return 256; + case FixedWidthMode::SVE512: + return 512; + } +} + +// Returns Ty if Ty is scalable, or otherwise returns the +// appropriate VectorType that should be passed to the +// CodeGenerator, depending on the FixedWidthMode FW. 
+Type *AArch64TTIImpl::getTypeForVectorMode(Type *Ty, FixedWidthMode FW) { + if (Ty->getVectorIsScalable() || isNeonVector(Ty, FW)) + return Ty; + + unsigned Elts = 128 / Ty->getScalarSizeInBits(); + unsigned VectorWidth = getVectorModeBitWidth(FW); + + if (Ty->getPrimitiveSizeInBits() <= VectorWidth) { + Ty = VectorType::get(Ty->getVectorElementType(), Elts, true); + } else { + unsigned NumVectors = Ty->getPrimitiveSizeInBits() / VectorWidth; + Ty = VectorType::get(Ty->getVectorElementType(), Elts * NumVectors, true); + } + + return Ty; +} + int AArch64TTIImpl::getMemoryOpCost(unsigned Opcode, Type *Ty, unsigned Alignment, unsigned AddressSpace, const Instruction *I) { + + if (Ty->isVectorTy()) + Ty = getTypeForVectorMode(Ty, FixedWidthOpt); + auto LT = TLI->getTypeLegalizationCost(DL, Ty); if (ST->isMisaligned128StoreSlow() && Opcode == Instruction::Store && @@ -634,7 +818,8 @@ return LT.first * 2 * AmortizationCost; } - if (Ty->isVectorTy() && Ty->getVectorElementType()->isIntegerTy(8)) { + if (Ty->isVectorTy() && Ty->getVectorElementType()->isIntegerTy(8) + && isNeonVector(Ty, FixedWidthOpt)) { unsigned ProfitableNumElements; if (Opcode == Instruction::Store) // We use a custom trunc store lowering so v.4b should be profitable. @@ -893,6 +1078,11 @@ bool AArch64TTIImpl::useReductionIntrinsic(unsigned Opcode, Type *Ty, TTI::ReductionFlags Flags) const { assert(isa(Ty) && "Expected Ty to be a vector type"); + // For SVE, we must use intrinsics for representing reductions. + // Whether or not we can actually implement this reduction with SVE is handled + // by canReduceInVector(). + if (Ty->getVectorIsScalable()) + return true; unsigned ScalarBits = Ty->getScalarSizeInBits(); switch (Opcode) { case Instruction::FAdd: @@ -984,3 +1174,112 @@ return BaseT::getShuffleCost(Kind, Tp, Index, SubTp); } + +int AArch64TTIImpl::getIntrinsicInstrCost(Intrinsic::ID IID, Type *RetTy, + ArrayRef Tys, + FastMathFlags FMF, + unsigned ScalarizationCostPassed) { + // Until we have more accurate information let's just assume all vector + // operations take 4 cycles. + static const CostTblEntry SVECostTbl[] = { + { ISD::CTPOP, MVT::nxv2i64, 4 }, + { ISD::CTPOP, MVT::nxv4i32, 4 }, + { ISD::CTPOP, MVT::nxv8i16, 4 }, + { ISD::CTPOP, MVT::nxv16i8, 4 }, + + { ISD::FCOPYSIGN, MVT::nxv2f16, 4 }, + { ISD::FCOPYSIGN, MVT::nxv2f32, 4 }, + { ISD::FCOPYSIGN, MVT::nxv2f64, 4 }, + { ISD::FCOPYSIGN, MVT::nxv4f16, 4 }, + { ISD::FCOPYSIGN, MVT::nxv4f32, 4 }, + { ISD::FCOPYSIGN, MVT::nxv8f16, 4 }, + + { ISD::FMAXNUM, MVT::nxv2f16, 4 }, + { ISD::FMAXNUM, MVT::nxv2f32, 4 }, + { ISD::FMAXNUM, MVT::nxv2f64, 4 }, + { ISD::FMAXNUM, MVT::nxv4f16, 4 }, + { ISD::FMAXNUM, MVT::nxv4f32, 4 }, + { ISD::FMAXNUM, MVT::nxv8f16, 4 }, + + { ISD::FMINNUM, MVT::nxv2f16, 4 }, + { ISD::FMINNUM, MVT::nxv2f32, 4 }, + { ISD::FMINNUM, MVT::nxv2f64, 4 }, + { ISD::FMINNUM, MVT::nxv4f16, 4 }, + { ISD::FMINNUM, MVT::nxv4f32, 4 }, + { ISD::FMINNUM, MVT::nxv8f16, 4 } + }; + + unsigned ISD = ISD::DELETED_NODE; + switch (IID) { + default: + break; + case Intrinsic::copysign: + ISD = ISD::FCOPYSIGN; + break; + case Intrinsic::ctpop: + ISD = ISD::CTPOP; + break; + case Intrinsic::maxnum: + ISD = ISD::FMAXNUM; + break; + case Intrinsic::minnum: + ISD = ISD::FMINNUM; + break; + } + + // Legalize the type. + std::pair LT = TLI->getTypeLegalizationCost(DL, RetTy); + MVT MTy = LT.second; + + // Attempt to lookup cost. 
+ if (ST->hasSVE()) + if (const auto *Entry = CostTableLookup(SVECostTbl, ISD, MTy)) + return LT.first * Entry->Cost; + + return BaseT::getIntrinsicInstrCost(IID, RetTy, Tys, FMF, + ScalarizationCostPassed); +} + +int AArch64TTIImpl::getIntrinsicInstrCost(Intrinsic::ID IID, Type *RetTy, + ArrayRef Args, + FastMathFlags FMF, + VectorType::ElementCount VF) { + return BaseT::getIntrinsicInstrCost(IID, RetTy, Args, FMF, VF); +} + +bool AArch64TTIImpl::canReduceInVector(unsigned Opcode, Type *ScalarTy, + TTI::ReductionFlags Flags) const { + // NEON can handle any kind of reduction by using shuffles, except ordered + // reductions. + if (!getST()->hasSVE()) + return !Flags.IsOrdered; + + if (ScalarTy->isFP128Ty()) + return false; + + switch (Opcode) { + default: + return false; + case Instruction::Add: + case Instruction::And: + case Instruction::Or: + case Instruction::Xor: + case Instruction::FAdd: + case Instruction::ICmp: + case Instruction::FCmp: + return true; + } + return false; +} + +bool AArch64TTIImpl::isLegalMaskedGather(Type *DataType) { + if (DisableSveGatherScatter) + return false; + return ST->hasSVE(); +} + +bool AArch64TTIImpl::isLegalMaskedScatter(Type *DataType) { + if (DisableSveGatherScatter) + return false; + return ST->hasSVE(); +} Index: lib/Target/AArch64/AsmParser/AArch64AsmParser.cpp =================================================================== --- lib/Target/AArch64/AsmParser/AArch64AsmParser.cpp +++ lib/Target/AArch64/AsmParser/AArch64AsmParser.cpp @@ -150,7 +150,8 @@ bool parseSysAlias(StringRef Name, SMLoc NameLoc, OperandVector &Operands); void createSysAlias(uint16_t Encoding, OperandVector &Operands, SMLoc S); - AArch64CC::CondCode parseCondCodeString(StringRef Cond); + AArch64CC::CondCode parseCondCodeString(StringRef Cond, + std::string &Suggestion); bool parseCondCode(OperandVector &Operands, bool invertCondCode); unsigned matchRegisterNameAlias(StringRef Name, RegKind Kind); bool parseRegister(OperandVector &Operands); @@ -164,6 +165,7 @@ OperandVector &Operands); bool parseDirectiveArch(SMLoc L); + bool parseDirectiveArchExtension(SMLoc L); bool parseDirectiveCPU(SMLoc L); bool parseDirectiveInst(SMLoc L); @@ -1053,7 +1055,7 @@ return DiagnosticPredicateTy::NoMatch; if (isSVEVectorReg() && - (ElementWidth == 0 || Reg.ElementWidth == ElementWidth)) + (Reg.ElementWidth == ElementWidth)) return DiagnosticPredicateTy::Match; return DiagnosticPredicateTy::NearMatch; @@ -1064,8 +1066,7 @@ if (Kind != k_Register || Reg.Kind != RegKind::SVEDataVector) return DiagnosticPredicateTy::NoMatch; - if (isSVEVectorReg() && - (ElementWidth == 0 || Reg.ElementWidth == ElementWidth)) + if (isSVEVectorReg() && (Reg.ElementWidth == ElementWidth)) return DiagnosticPredicateTy::Match; return DiagnosticPredicateTy::NearMatch; @@ -2583,8 +2584,10 @@ return MatchOperand_Success; } -/// parseCondCodeString - Parse a Condition Code string. -AArch64CC::CondCode AArch64AsmParser::parseCondCodeString(StringRef Cond) { +/// parseCondCodeString - Parse a Condition Code string, optionally returning a +/// suggestion to help common typos. 
+AArch64CC::CondCode +AArch64AsmParser::parseCondCodeString(StringRef Cond, std::string &Suggestion) { AArch64CC::CondCode CC = StringSwitch(Cond.lower()) .Case("eq", AArch64CC::EQ) .Case("ne", AArch64CC::NE) @@ -2607,7 +2610,7 @@ .Default(AArch64CC::Invalid); if (CC == AArch64CC::Invalid && - getSTI().getFeatureBits()[AArch64::FeatureSVE]) + getSTI().getFeatureBits()[AArch64::FeatureSVE]) { CC = StringSwitch(Cond.lower()) .Case("none", AArch64CC::EQ) .Case("any", AArch64CC::NE) @@ -2621,6 +2624,9 @@ .Case("tstop", AArch64CC::LT) .Default(AArch64CC::Invalid); + if (CC == AArch64CC::Invalid && Cond.lower() == "nfirst") + Suggestion = "nfrst"; + } return CC; } @@ -2633,9 +2639,14 @@ assert(Tok.is(AsmToken::Identifier) && "Token is not an Identifier"); StringRef Cond = Tok.getString(); - AArch64CC::CondCode CC = parseCondCodeString(Cond); - if (CC == AArch64CC::Invalid) - return TokError("invalid condition code"); + std::string Suggestion; + AArch64CC::CondCode CC = parseCondCodeString(Cond, Suggestion); + if (CC == AArch64CC::Invalid) { + std::string Msg = "invalid condition code"; + if (!Suggestion.empty()) + Msg += ", you probably meant " + Suggestion; + return TokError(Msg); + } Parser.Lex(); // Eat identifier token. if (invertCondCode) { @@ -3034,6 +3045,11 @@ return MatchOperand_NoMatch; unsigned ElementWidth = KindRes->second; + if (ElementWidth > 64) { + Error(S, "invalid element width"); + return MatchOperand_ParseFail; + } + Operands.push_back(AArch64Operand::CreateVectorReg( RegNum, RegKind::SVEPredicateVector, ElementWidth, S, getLoc(), getContext())); @@ -3646,9 +3662,14 @@ SMLoc SuffixLoc = SMLoc::getFromPointer(NameLoc.getPointer() + (Head.data() - Name.data())); - AArch64CC::CondCode CC = parseCondCodeString(Head); - if (CC == AArch64CC::Invalid) - return Error(SuffixLoc, "invalid condition code"); + std::string Suggestion; + AArch64CC::CondCode CC = parseCondCodeString(Head, Suggestion); + if (CC == AArch64CC::Invalid) { + std::string Msg = "invalid condition code"; + if (!Suggestion.empty()) + Msg += ", you probably meant " + Suggestion; + return Error(SuffixLoc, Msg); + } Operands.push_back( AArch64Operand::CreateToken(".", true, SuffixLoc, getContext())); Operands.push_back( @@ -4315,11 +4336,15 @@ case Match_InvalidSVEPredicateDReg: return Error(Loc, "invalid predicate register."); case Match_InvalidSVEPredicate3bAnyReg: + return Error(Loc, "invalid restricted predicate register, expected p0..p7 (without element suffix)"); case Match_InvalidSVEPredicate3bBReg: + return Error(Loc, "invalid restricted predicate register, expected p0.b..p7.b"); case Match_InvalidSVEPredicate3bHReg: + return Error(Loc, "invalid restricted predicate register, expected p0.h..p7.h"); case Match_InvalidSVEPredicate3bSReg: + return Error(Loc, "invalid restricted predicate register, expected p0.s..p7.s"); case Match_InvalidSVEPredicate3bDReg: - return Error(Loc, "restricted predicate has range [0, 7]."); + return Error(Loc, "invalid restricted predicate register, expected p0.d..p7.d"); case Match_InvalidSVEExactFPImmOperandHalfOne: return Error(Loc, "Invalid floating point constant, expected 0.5 or 1.0."); case Match_InvalidSVEExactFPImmOperandHalfTwo: @@ -4874,6 +4899,8 @@ parseDirectiveUnreq(Loc); else if (IDVal == ".inst") parseDirectiveInst(Loc); + else if (IDVal == ".arch_extension") + parseDirectiveArchExtension(Loc); else if (IsMachO) { if (IDVal == MCLOHDirectiveName()) parseDirectiveLOH(IDVal, Loc); @@ -4898,6 +4925,12 @@ { "simd", {AArch64::FeatureNEON} }, { "ras", {AArch64::FeatureRAS} 
}, { "lse", {AArch64::FeatureLSE} }, + { "sve", {AArch64::FeatureSVE} }, + { "sve2", {AArch64::FeatureSVE2} }, + { "sve2-aes", {AArch64::FeatureSVE2AES} }, + { "sve2-sm4", {AArch64::FeatureSVE2SM4} }, + { "sve2-sha3", {AArch64::FeatureSVE2SHA3} }, + { "sve2-bitperm", {AArch64::FeatureSVE2BitPerm} }, // FIXME: Unsupported extensions { "pan", {} }, @@ -5014,6 +5047,44 @@ return false; } +/// parseDirectiveArchExtension +/// ::= .arch_extension [no]feature +bool AArch64AsmParser::parseDirectiveArchExtension(SMLoc L) { + SMLoc ExtLoc = getLoc(); + + StringRef Name = getParser().parseStringToEndOfStatement().trim(); + + if (parseToken(AsmToken::EndOfStatement, + "unexpected token in '.arch_extension' directive")) + return true; + + bool EnableFeature = true; + if (Name.startswith_lower("no")) { + EnableFeature = false; + Name = Name.substr(2); + } + + MCSubtargetInfo &STI = copySTI(); + FeatureBitset Features = STI.getFeatureBits(); + for (const auto &Extension : ExtensionMap) { + if (Extension.Name != Name) + continue; + + if (Extension.Features.none()) + return Error(ExtLoc, "unsupported architectural extension: " + Name); + + FeatureBitset ToggleFeatures = EnableFeature + ? (~Features & Extension.Features) + : (Features & Extension.Features); + uint64_t Features = + ComputeAvailableFeatures(STI.ToggleFeature(ToggleFeatures)); + setAvailableFeatures(Features); + return false; + } + + return Error(ExtLoc, "unknown architectural extension: " + Name); +} + static SMLoc incrementLoc(SMLoc L, int Offset) { return SMLoc::getFromPointer(L.getPointer() + Offset); } Index: lib/Target/AArch64/CMakeLists.txt =================================================================== --- lib/Target/AArch64/CMakeLists.txt +++ lib/Target/AArch64/CMakeLists.txt @@ -53,6 +53,11 @@ AArch64TargetMachine.cpp AArch64TargetObjectFile.cpp AArch64TargetTransformInfo.cpp + SVEAddressingModes.cpp + SVEPostVectorize.cpp + SVEExpandLibCall.cpp + SVEConditionalEarlyClobberPass.cpp + SVEIntrinsicOpts.cpp AArch64SIMDInstrOpt.cpp DEPENDS Index: lib/Target/AArch64/LLVMBuild.txt =================================================================== --- lib/Target/AArch64/LLVMBuild.txt +++ lib/Target/AArch64/LLVMBuild.txt @@ -31,5 +31,5 @@ type = Library name = AArch64CodeGen parent = AArch64 -required_libraries = AArch64AsmPrinter AArch64Desc AArch64Info AArch64Utils Analysis AsmPrinter CodeGen Core MC Scalar SelectionDAG Support Target GlobalISel +required_libraries = AArch64AsmPrinter AArch64Desc AArch64Info AArch64Utils Analysis AsmPrinter CodeGen Core MC Scalar SelectionDAG Support Target TransformUtils InstCombine GlobalISel add_to_library_groups = AArch64 Index: lib/Target/AArch64/MCTargetDesc/AArch64AsmBackend.cpp =================================================================== --- lib/Target/AArch64/MCTargetDesc/AArch64AsmBackend.cpp +++ lib/Target/AArch64/MCTargetDesc/AArch64AsmBackend.cpp @@ -347,13 +347,24 @@ } bool AArch64AsmBackend::writeNopData(raw_ostream &OS, uint64_t Count) const { + bool Aligned = (Count % 4) == 0; + // If the count is not 4-byte aligned, we must be writing data into the text // section (otherwise we have unaligned instructions, and thus have far // bigger problems), so just write zeros instead. OS.write_zeros(Count % 4); - // We are properly aligned, so write NOPs as requested. + // We are properly aligned and can start talking in terms of instructions. Count /= 4; + + // Plant a branch rather than NOPs when the former is cheaper. 
The upper limit + guards against Count being beyond the branch range (however unlikely). + if (Aligned && (Count > 1) && (Count <= 128)) { + support::endian::write(OS, 0x14000000 + Count, Endian); + --Count; + } + + // Fill the remaining space with NOPs. for (uint64_t i = 0; i != Count; ++i) support::endian::write(OS, 0xd503201f, Endian); return true; } Index: lib/Target/AArch64/SVEAddressingModes.cpp =================================================================== --- /dev/null +++ lib/Target/AArch64/SVEAddressingModes.cpp @@ -0,0 +1,656 @@ +//===- SVEAddressingModes - An SVE Optimizer ---------------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This pass attempts to work around artefacts introduced by loop strength +// reduction. After LSR, vectorised loops tend to look as follows: +// +// vector.body: +// %idx = phi i64 [ %idx.n, %vector.body ], [ 0, %vector.ph ] +// %1 = bitcast double* %0 to i8* +// %uglygep580 = getelementptr i8, i8* %1, i64 %idx +// ...... +// %idx.n = add i64 %idx, mul (i64 elementcount ( undef), i64 8) +// br i1 %cond, label vector.body +// +// which requires a reg+reg addressing mode that SVE does not possess. +// This pass transforms the above into: +// +// vector.body: +// %idx = phi i64 [ %idx.n, %vector.body ], [ 0, %vector.ph ] +// %idx.new = mul i64 %idx, 8 +// %1 = bitcast double* %0 to i8* +// %uglygep580 = getelementptr i8, i8* %1, i64 %idx.new +// ...... +// %idx.n = add i64 %idx, i64 elementcount ( undef) +// br i1 %cond, label vector.body +// +// which allows us to make use of SVE's reg+scaled_reg addressing mode.
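As a rough sketch of the simplest case this pass handles (a scaled induction variable that starts at zero, rewritten later in this file with a single multiply), the core of the transformation could be expressed as below. The helper name rewriteScaledIV and its parameters are illustrative only; the real pass additionally validates the PHIs via findBetterPHI, hoists invariant GEPs, and handles non-zero start values and pointer inductions.

#include "llvm/IR/Constants.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include <cstdint>

// Replace the byte-scaled LSR induction variable with "element index * size"
// so that the GEP index is expressed in terms of the unscaled PHI and a
// constant scale, which the backend can then fold into a reg+scaled_reg
// addressing mode.
static void rewriteScaledIV(llvm::PHINode *ScaledPHI,
                            llvm::PHINode *UnscaledPHI, uint64_t ElementSize) {
  llvm::IRBuilder<> Builder(ScaledPHI->getParent()->getFirstNonPHI());
  llvm::Value *Scale =
      llvm::ConstantInt::get(ScaledPHI->getType(), ElementSize);
  llvm::Value *NewIdx = Builder.CreateMul(UnscaledPHI, Scale, "idx.new");
  ScaledPHI->replaceAllUsesWith(NewIdx);
}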
+// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Scalar.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/LoopPass.h" +#include "llvm/IR/CFG.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Metadata.h" +#include "llvm/IR/PatternMatch.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Utils.h" +#include "llvm/Transforms/Utils/Local.h" +#include "llvm/Transforms/Utils/LoopUtils.h" +#include +using namespace llvm; +using namespace llvm::PatternMatch; + +#define DEBUG_TYPE "sve-addressing-modes" + +STATISTIC(NumVisited, "Number of loops visited."); +STATISTIC(NumOptimized, "Number of loops optimized."); + +namespace llvm { + void initializeSVEAddressingModesPass(PassRegistry &); +} + +namespace { +struct SVEAddressingModes : public FunctionPass { + static char ID; // Pass identification, replacement for typeid + SVEAddressingModes() : FunctionPass(ID) { + initializeSVEAddressingModesPass(*PassRegistry::getPassRegistry()); + } + + bool runOnFunction(Function &F) override { + this->F = &F; + LI = &getAnalysis().getLoopInfo(); + + bool Changed = false; + for (auto I = LI->begin(), IE = LI->end(); I != IE; ++I) + for (auto L = df_begin(*I), LE = df_end(*I); L != LE; ++L) + Changed |= runOnLoop(*L); + + return Changed; + } + + bool runOnLoop(Loop *L); + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequiredID(LoopSimplifyID); + AU.addRequired(); + } + +private: + Function *F; + LoopInfo *LI; + + bool OptimizeBlock(BasicBlock *); + bool onlyUsedForAddrModes(PHINode*, Instruction*, ConstantInt*); + PHINode* findBetterPHI(PHINode*, BasicBlock*, BasicBlock*, ConstantInt*, + APInt&); +}; +} + +char SVEAddressingModes::ID = 0; +static const char *name = "SVE Addressing Modes"; +INITIALIZE_PASS_BEGIN(SVEAddressingModes, DEBUG_TYPE, name, false, false) +INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(LoopSimplify) +INITIALIZE_PASS_END(SVEAddressingModes, DEBUG_TYPE, name, false, false) + +namespace llvm { +FunctionPass *createSVEAddressingModesPass() { + return new SVEAddressingModes(); +} +} + +bool SVEAddressingModes::runOnLoop(Loop *L) { + bool LoopOptimized = false; + + for (auto BB = L->block_begin(), BE = L->block_end(); BB != BE; ++BB) + LoopOptimized |= OptimizeBlock(*BB); + + ++NumVisited; + if (LoopOptimized) + ++NumOptimized; + + return LoopOptimized; +} + +/// Try to find the original PHI used before LSR introduced the new one. 
+/// +PHINode* SVEAddressingModes::findBetterPHI(PHINode *OrigPHI, + BasicBlock *EntryBB, + BasicBlock *BackEdgeBB, + ConstantInt *OrigScale, + APInt &ScaleDisp) { + assert(OrigPHI->getNumIncomingValues() == 2 && "Unexpected PHI"); + + for (Instruction &I : *OrigPHI->getParent()) { + auto PHI = dyn_cast(&I); + if (!PHI) + continue; + + if (PHI == OrigPHI) + continue; + + if (PHI->getNumIncomingValues() != 2) + continue; + + int Idx = PHI->getBasicBlockIndex(EntryBB); + if (Idx < 0) + continue; + if (!match(PHI->getIncomingValue(Idx), m_Zero())) + continue; + + Idx = PHI->getBasicBlockIndex(BackEdgeBB); + if (Idx < 0) + continue; + + ConstantInt *Scale; + if (!match(PHI->getIncomingValue(Idx), + m_Add(m_Specific(PHI), m_Mul(m_VScale(), m_ConstantInt(Scale))))) + continue; + + APInt ScaleVal = Scale->getValue(); + APInt OrigScaleVal = OrigScale->getValue(); + if (ScaleVal.getBitWidth() != OrigScaleVal.getBitWidth()) + continue; + if (!ScaleVal.isPowerOf2()) + continue; + if (!ScaleVal.slt(OrigScaleVal)) + continue; + ScaleDisp = OrigScaleVal.sdiv(ScaleVal); + + return PHI; + } + + return OrigPHI; +} + + +// Only fold into an addressing mode when the PHI is used for loads/stores +// where the conversion between load/store-type and EC type equals Scale. +// If this cannot be determined, than we may not safely be allowed to re-scale +// the start value of the PHI before the loop: we must know for sure that +// the start value is a multiple of Scale. +// +// For example: +// body: +// %lsr.iv = phi i64 [ %lsr.iv.next, %body ], [ %start, %preheader ] +// : +// %uglygep = getelementptr i8, i8* %43, i64 %lsr.iv +// ^^^^^^ +// Uses Scaled indvar (scale = 8) +// %uglygep2 = bitcast i8* %uglygep to * +// ^^ ^^^^^^^^^^^^^^ +// sizeof()/sizeof(i8) == n * Scale +// %wide.masked.load = +// call @llvm.masked.load(* %uglygep2, ..) +// : +// %lsr.iv.next = add i64 %lsr.iv, mul (i64 vscale, i64 16) +bool SVEAddressingModes::onlyUsedForAddrModes(PHINode *PHI, + Instruction *PHIUpdate, + ConstantInt *Scale) { + auto &DL = PHI->getModule()->getDataLayout(); + + SmallVector WorkList; + WorkList.append(PHI->user_begin(), PHI->user_end()); + + while (!WorkList.empty()) { + auto *U = WorkList.pop_back_val(); + + // Do not consider induction var update + if (&*U == PHIUpdate) + continue; + + // Make sure we load/store of EC type + if (isa(U) || isa(U)) + continue; + + if (const auto *II = dyn_cast(U)) { + switch (II->getIntrinsicID()) { + case Intrinsic::masked_load: + case Intrinsic::masked_store: + continue; + default: + break; + } + } + + // Scaling must match up + if (auto *BC = dyn_cast(U)) { + Type *DstETy = BC->getDestTy()->getPointerElementType(); + Type *SrcETy = BC->getSrcTy()->getPointerElementType(); + SrcETy = SrcETy->isArrayTy() ? 
SrcETy->getArrayElementType() : SrcETy; + int64_t BCScale = DL.getTypeStoreSize(DstETy) / + DL.getTypeStoreSize(SrcETy); + + if (BCScale == Scale->getSExtValue()) { + WorkList.append(BC->user_begin(), BC->user_end()); + continue; + } + } + + // Make sure this is a GEP that uses our PHI + if (auto *Gep = dyn_cast(U)) { + if (Gep->getOperand(Gep->getNumIndices()) == PHI) { + WorkList.append(Gep->user_begin(), Gep->user_end()); + continue; + } + } + + // One of the criteria failed + return false; + } + + // All of the criteria passed + return true; +} + +static bool isIndVarPHIUpdateValue(Instruction *I, PHINode *PHI, + BasicBlock *&EntryBB, + BasicBlock *&BackEdgeBB) { + if (I->getType() != PHI->getType()) + return false; + + if (PHI->getNumIncomingValues() != 2) + return false; + + // phi [ , %vector.ph ], [ %idx.next, %vector.body ] + if (PHI->getIncomingValue(1) == I) { + EntryBB = PHI->getIncomingBlock(0); + BackEdgeBB = PHI->getIncomingBlock(1); + } + // phi [ %idx.next, %vector.body ], [ , %vector.ph ] + else if (PHI->getIncomingValue(0) == I) { + EntryBB = PHI->getIncomingBlock(1); + BackEdgeBB = PHI->getIncomingBlock(0); + } + // invalid phi + else + return false; + + return true; +} + +static bool isScaledElementCountIVUpdate( + Instruction *I, SmallVectorImpl &PHINodes, + ConstantInt *&Scale, + BasicBlock *&EntryBB, + BasicBlock *&BackEdgeBB) { + PHINode *PHI; + + // e.g. + // %update = add %phi, mul (i64 vscale, i64 32) + if (match(I, m_Add(m_PHI(PHI), + m_Mul(m_VScale(), m_ConstantInt(Scale))))) { + if (!Scale->getValue().isPowerOf2()) + return false; + + if (!isIndVarPHIUpdateValue(I, PHI, EntryBB, BackEdgeBB)) + return false; + + PHINodes.push_back(PHI); + return true; + } + + // e.g. + // %phi = i32* phi [ %bc2, %vector.body ] , [..] + // : + // %bc = bitcast i32* %phi to i1* + // %update = getelementptr i1, i1* %bc, mul (i64 vscale, i64 32) + // %bc2 = bitcast i1* %update to i32* + if (auto *Gep = dyn_cast(I)) { + Value *Opnd = Gep->getOperand(Gep->getNumIndices()); + if (!match(Opnd, m_Mul(m_VScale(), m_ConstantInt(Scale)))) + return false; + + Value *Ptr = Gep->getPointerOperand(); + if (!match(Ptr, m_PHI(PHI)) && + !match(Ptr, m_BitCast(m_PHI(PHI)))) + return false; + + // If the pointer is a bitcast then only consider a single + // user so that we won't 'break' any other addr modes. + Instruction *Update = I; + if (auto *BC = dyn_cast(Ptr)) { + if (!BC->hasOneUse()) + return false; + } + + // An update may be used in multiple PHIs, go through + // all users of the update. 
+ BasicBlock *NewEntryBB = nullptr; + BasicBlock *NewBackEdgeBB = nullptr; + for (auto *User : Gep->users()) { + auto *UserPN = dyn_cast(User); + + // If the user is a bitcast, look through it + if (auto *BC = dyn_cast(User)) { + if (!BC->hasOneUse()) + return false; + Update = BC; + UserPN = dyn_cast(*BC->user_begin()); + } + + // The update must be used in a PHI + if (!UserPN) + return false; + + // And it must be a proper PHI update + if (!isIndVarPHIUpdateValue(Update, UserPN, NewEntryBB, NewBackEdgeBB)) + return false; + + // The PHI must be in the same block as the other updates + if (EntryBB == nullptr) { + EntryBB = NewEntryBB; + BackEdgeBB = NewBackEdgeBB; + } else if (EntryBB != NewEntryBB || BackEdgeBB != NewBackEdgeBB) + return false; + + PHINodes.push_back(UserPN); + } + + auto &DL = PHI->getModule()->getDataLayout(); + int64_t ElemTypeSize = + DL.getTypeStoreSize(Gep->getType()->getPointerElementType()); + + // Scale must be in bytes + Scale = ConstantInt::get(Scale->getType(), + Scale->getSExtValue() * ElemTypeSize, true); + return true; + } + + return false; +} + +// Returns whether U is a GEP that can be trivially hoisted out of the loop +static bool IsHoistableGEPCandidate(User *U, BasicBlock *EntryBB, Value *PHI, + GetElementPtrInst *&Gep, + Value *&PointerOpnd, + BitCastInst *&BC) { + Gep = dyn_cast(U); + if (!Gep) + return false; + + // Function to determine if all operands are invariant or the IV + auto AllInvariantOrIV = [EntryBB,PHI](Value *V){ + if (V == PHI || isa(V) || isa(V)) + return true; + if (auto VI = dyn_cast(V)) + return VI->getParent() == EntryBB; + return false; + }; + + if (!std::all_of(Gep->idx_begin(), Gep->idx_end(), AllInvariantOrIV)) + return false; + + PointerOpnd = Gep->getPointerOperand(); + auto *PointerOpndI = dyn_cast(PointerOpnd); + if (PointerOpndI && PointerOpndI->getParent() != EntryBB) + if ((BC = dyn_cast(PointerOpndI))) + PointerOpnd = BC->getOperand(0); + PointerOpndI = dyn_cast(PointerOpnd); + + return !PointerOpndI || (PointerOpndI->getParent() == EntryBB); +} + +#define MAX_LOOPSTRENGTH_LOOPVARS 5 + +/// Attempt to restructure vectorised loops so that we can better utilise SVE +/// reg+scaled_reg addressing modes. +/// +bool SVEAddressingModes::OptimizeBlock(BasicBlock *BB) { + bool LoopChanged = false; + + auto &DL = BB->getModule()->getDataLayout(); + IRBuilder<> Builder(BB); + + for (BasicBlock::iterator it = BB->begin(), e = BB->end(); it != e;) { + Instruction *I = &*it++; + + // Find the scaled elementcount update + SmallVector PHINodes; + ConstantInt *OrigScale; + BasicBlock *EntryBB = nullptr; + BasicBlock *BackEdgeBB = nullptr; + APInt ScaleDisp; + + if (!isScaledElementCountIVUpdate(I, PHINodes, OrigScale, EntryBB, BackEdgeBB)) + continue; + + for (auto *PHI : PHINodes) { + // Try to find an unscaled PHI with same elementcount. + auto *OtherPHI = findBetterPHI(PHI, EntryBB, BackEdgeBB, OrigScale, + ScaleDisp); + + // If it cannot find a suitable one, continue + if (OtherPHI == PHI) + continue; + + // Transform patterns like: + // body: + // %lsr.idx = phi(..) + // %ptr = getelementpr %addr, %lsr.idx + // %ptr2 = getelementptr %ptr, i64 4 + // + // Into: + // ph: + // %addr2 = getelementptr %addr, i64 4 + // body: + // %lsr.idx = phi(..) 
+ // %ptr = getelementpr %addr2, %lsr.idx + for (auto *U : PHI->users()) { + // Limit number of new pointers created + if (U->hasNUsesOrMore(3)) + continue; + + GetElementPtrInst *Gep; + Value *PointerOpnd; + BitCastInst *BC = nullptr; + // Check if we can split the base pointer from + // the offsets and make it partly loop invariant. + if (!IsHoistableGEPCandidate(U, EntryBB, PHI, + /* out args: */ Gep, PointerOpnd, BC)) + continue; + + SmallVector ToRemove; + for (auto *UU : U->users()) { + // User must be a GEP + auto *OffsetGep = dyn_cast(UU); + if (!OffsetGep) + continue; + + // With only constant offsets + if (!OffsetGep->hasAllConstantIndices()) + continue; + + // It must be an offset to the same type + if (OffsetGep->getType() != OffsetGep->getPointerOperand()->getType()) + continue; + + bool AnyStructTy = false; + for (auto GTI = gep_type_begin(Gep), E = gep_type_end(Gep); GTI != E; + ++GTI) + if (GTI.isStruct()) + AnyStructTy = true; + + if (AnyStructTy) + continue; + + // Cast the pointer to the type we're expecting (in preheader) + Builder.SetInsertPoint(EntryBB->getTerminator()); + auto *NewPointerOpnd = + Builder.CreateBitCast(PointerOpnd, Gep->getType()); + + // And create a new base pointer with offset, + // and cast the result to the right pointer type + auto *NewBase = Builder.Insert(OffsetGep->clone()); + NewBase->replaceUsesOfWith(U, NewPointerOpnd); + NewBase = cast( + Builder.CreateBitCast(NewBase, + Gep->getPointerOperand()->getType())); + + // Copy the original GEP and use the new base + Builder.SetInsertPoint(Gep); + auto *NewGep = cast(Builder.Insert(Gep->clone())); + NewGep->replaceUsesOfWith(Gep->getPointerOperand(), NewBase); + + // Replace all uses and cleanup + OffsetGep->replaceAllUsesWith(NewGep); + ToRemove.push_back(OffsetGep); + } + + // Remove original nodes from parent + for (auto *OffsetGep : ToRemove) + OffsetGep->eraseFromParent(); + } + + // Start value of the LSR variable + Value *Start = PHI->getIncomingValueForBlock(EntryBB); + + // First handle scalar types + if (!Start->getType()->isPointerTy()) { + auto CI = dyn_cast(Start); + + // Simple case, scaled indvar starting at '0' + if (CI && CI->isNullValue()) { + Builder.SetInsertPoint(PHI->getParent()->getFirstNonPHI()); + auto *Scale = ConstantInt::get(PHI->getType(), ScaleDisp); + auto *Mul = Builder.CreateMul(OtherPHI, Scale); + PHI->replaceAllUsesWith(Mul); + } else { + // Check it is only used in suitable addressing modes + if (!onlyUsedForAddrModes(PHI, I, OrigScale)) + continue; + + // Re-scale the start value + Builder.SetInsertPoint(EntryBB->getTerminator()); + unsigned LogBase = ScaleDisp.logBase2(); + auto *ShiftAmount = + ConstantInt::get(Start->getType(), LogBase, false); + auto *StartI = Builder.CreateAShr(Start, ShiftAmount); + + // Add the unscaled PHI + Builder.SetInsertPoint(PHI->getParent()->getFirstNonPHI()); + auto *NewIV = Builder.CreateAdd(StartI, OtherPHI); + NewIV = Builder.CreateShl(NewIV, ShiftAmount); + + // If the start value is constant, create a new pointer in + // preheader for each user (GEP) + if (CI) { + + // Create a new scaled induction var without constant offset + Builder.SetInsertPoint(PHI->getParent()->getFirstNonPHI()); + auto *NewIVWithoutOffset = Builder.CreateShl(OtherPHI, ShiftAmount); + unsigned Counter = 0; + SmallVector PhiUsers(PHI->users()); + for (auto *U : PhiUsers) { + BitCastInst *BC = nullptr; + GetElementPtrInst *Gep; + Value *PointerOpnd; + + // Test if the GEP + offset can be hoisted + if (Counter < MAX_LOOPSTRENGTH_LOOPVARS && + 
IsHoistableGEPCandidate(U, EntryBB, PHI, + Gep, PointerOpnd, BC)) { + Counter++; + + // Clone the GEP in the preheader + Builder.SetInsertPoint(EntryBB->getTerminator()); + auto *Clone = Builder.Insert(Gep->clone()); + + // If the pointer is a bitcast, reconstruct it in the preheader + if (BC) { + Builder.SetInsertPoint(Clone); + auto *NewPointerOpnd = + Builder.CreateBitCast(PointerOpnd, BC->getType()); + Clone->replaceUsesOfWith(BC, NewPointerOpnd); + } + + for (unsigned J = 1; J < Gep->getNumOperands(); ++J) { + Value *Opnd = Clone->getOperand(J); + Value *NewOpnd = ConstantInt::get(Opnd->getType(), + Opnd == PHI ? CI->getSExtValue() : 0); + Clone->setOperand(J, NewOpnd); + } + + // Finally, bitcast the new pointer to the original pointer type + Builder.SetInsertPoint(EntryBB->getTerminator()); + Clone = cast(Builder.CreateBitCast( + Clone, Gep->getPointerOperand()->getType())); + + // Create a new GEP in the vector body and use new IV + Gep->replaceUsesOfWith(Gep->getPointerOperand(), Clone); + Gep->replaceUsesOfWith(PHI, NewIVWithoutOffset); + } + } + } + + // Insert the new indvar and scaled start value + PHI->replaceAllUsesWith(NewIV); + } + } + + // Handle pointer inductions + if (Start->getType()->isPointerTy()) { + auto *GepI = cast(I); + Builder.SetInsertPoint(PHI->getParent()->getFirstNonPHI()); + + // Find scalar element type + Type *BaseTy = Start->getType()->getPointerElementType(); + uint64_t ElemTypeSize = DL.getTypeStoreSize(BaseTy); + int64_t ScaleVal = ScaleDisp.getSExtValue(); + if (!BaseTy->isSingleValueType() || (ScaleVal % ElemTypeSize) != 0) { + + // Do everything in largest possible type that fits the Scale + LLVMContext &Context = BB->getContext(); + Type *NewBaseType = Type::getInt8Ty(Context); + if (ScaleVal % 8 == 0) + NewBaseType = Type::getInt64Ty(Context); + else if (ScaleVal % 4 == 0) + NewBaseType = Type::getInt32Ty(Context); + else if (ScaleVal % 2 == 0) + NewBaseType = Type::getInt16Ty(Context); + + ElemTypeSize = DL.getTypeStoreSize(NewBaseType); + Start = Builder.CreateBitCast(Start, NewBaseType->getPointerTo()); + } + + // Rescale + int64_t NewScaleVal = ScaleVal / ElemTypeSize; + auto *Scale = ConstantInt::get(OtherPHI->getType(), NewScaleVal, true); + + // Multiply with scalar indvar + auto *Idx = Builder.CreateMul(OtherPHI, Scale); + + // Calculate address + auto *NewGep = GepI->isInBounds() ? + Builder.CreateInBoundsGEP(Start, Idx) : + Builder.CreateGEP(Start, Idx); + + // Bitcast to original type + auto *BitCast = Builder.CreateBitCast(NewGep, PHI->getType()); + + // Replace + PHI->replaceAllUsesWith(BitCast); + } + + LoopChanged = true; + } + } + + return LoopChanged; +} Index: lib/Target/AArch64/SVEConditionalEarlyClobberPass.cpp =================================================================== --- /dev/null +++ lib/Target/AArch64/SVEConditionalEarlyClobberPass.cpp @@ -0,0 +1,183 @@ +//==-- SVEConditionalEarlyClobberPass.cpp - Conditionally add early clobber ==// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This pass solves an issue with MOVPRFXable instructions that +// have the restriction that the destination register of a MOVPRFX +// cannot be used in any operand of the next instruction, except for +// the destructive operand. 
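+// For example:
+//   movprfx z0.s, p0/z, z1.s
+//   add     z0.s, p0/m, z0.s, z2.s    (valid: z0 is only the destructive operand)
+// is an accepted pair, whereas
+//   movprfx z0.s, p0/z, z1.s
+//   add     z0.s, p0/m, z0.s, z0.s    (invalid: z0 is also used as a normal source)
+// is not.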
+// +// We chose to create Pseudo instructions to implement false-lane zeroing, +// where we specifically tried not to use the '$Zd = $Zs1' restriction +// so that the register allocator doesn't insert normal +// MOV instructions. The downside of doing that, is that the register +// allocation of: +// vreg1 = OP_ZEROING vreg0, vreg0 +// may result in: +// Z8 = OP_ZEROING Z8, Z8 +// +// At expand time, the OP_ZEROING will either need a scratch register to +// implement an actual 'MOV(DUP(0))', or will need to use a MOVPRFX Pg/z +// with a dummy ('nop'-like) MOVPRFXable instruction, like LSL #0. +// +// This is better handled by the register allocator creating an allocation +// that takes the above restriction into account, e.g. +// Z3 = OP_ZEROING Z8, Z8 +// which can be correctly expanded into: +// Z3 = MOVPRFX Pg/z, Z8 +// Z3 = OP Z3, Z8 +// +// After Coalescing of virtual registers, we know whether the input operands +// to the instruction will be in the same register or not. +// For our example: +// vreg1 = OP_ZEROING vreg0, vreg0 +// we know that vreg0 and vreg0 will be equal, but we don't know the +// register allocation of vreg1. We want to force that vreg1 will be different +// from vreg0, which can be done using an 'earlyclobber'. +// +// This pass adds the earlyclobber to the machine operand, and also updates +// the cache of live ranges so that subsequent passes don't need to +// recalculate those for the newly added earlyclobber. +// +//===----------------------------------------------------------------------===// + +#include "AArch64InstrInfo.h" +#include "llvm/CodeGen/LiveIntervals.h" +#include "llvm/CodeGen/LivePhysRegs.h" +#include "llvm/CodeGen/MachineFunctionPass.h" +#include "llvm/CodeGen/SlotIndexes.h" +#include "llvm/CodeGen/TargetSubtargetInfo.h" +using namespace llvm; + +#define PASS_SHORT_NAME "SOME PASS NAME" + +namespace llvm { + void initializeSVEConditionalEarlyClobberPassPass(PassRegistry &); +} + +namespace { +class SVEConditionalEarlyClobberPass : public MachineFunctionPass { +public: + static char ID; + SVEConditionalEarlyClobberPass() : MachineFunctionPass(ID) { + initializeSVEConditionalEarlyClobberPassPass( + *PassRegistry::getPassRegistry()); + } + + bool runOnMachineFunction(MachineFunction &Fn) override; + + StringRef getPassName() const override { return PASS_SHORT_NAME; } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.setPreservesCFG(); + AU.addRequired(); + AU.addPreserved(); + AU.addRequired(); + AU.addPreserved(); + MachineFunctionPass::getAnalysisUsage(AU); + } +private: + const TargetInstrInfo *TII; + LiveIntervals *LIS; + + bool addConditionalEC(MachineInstr &MI); + bool hasConditionalClobber(const MachineInstr &MI); +}; +char SVEConditionalEarlyClobberPass::ID = 0; +} + +INITIALIZE_PASS(SVEConditionalEarlyClobberPass, + "aarch64-conditional-early-clobber", + PASS_SHORT_NAME, false, false) + +FunctionPass *llvm::createSVEConditionalEarlyClobberPass() { + return new SVEConditionalEarlyClobberPass(); +} + +// We could also choose to do this with a new instruction annotation +// like 'earlyclobberif($Zd=$Zs1)', but because this is so specific to SVE +// it should be fine to explicitly check the type of SVE operation where +// we know what the conditions are. 
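+// For example, after coalescing we may see (in the notation used above):
+//   vreg1 = OP_ZEROING vreg0, vreg0
+// Both source operands are known to be the same register, so the check below
+// returns true and addConditionalEC() marks the destination operand as
+// earlyclobber, forcing the allocator to keep vreg1 distinct from vreg0.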
+bool SVEConditionalEarlyClobberPass::hasConditionalClobber( + const MachineInstr &MI) { + int Instr = AArch64::getSVEPseudoMap(MI.getOpcode()); + if (Instr == -1) + return false; + + uint64_t DType = + TII->get(Instr).TSFlags & AArch64::DestructiveInstTypeMask; + auto mo_equals = [&](const MachineOperand &MO1, const MachineOperand &MO2) { + if (MO1.getReg() == MO2.getReg() && MO1.getSubReg() == MO2.getSubReg()) { + // This is needed to deal with cases where subreg assignment means that + // the earlyclobber isn't necessary. + return MI.getOperand(0).getSubReg() == MO1.getSubReg() || + ((MO1.getSubReg() == 0) ^ (MI.getOperand(0).getSubReg() == 0)); + } + return false; + }; + switch (DType) { + case AArch64::DestructiveBinary: + case AArch64::DestructiveBinaryComm: + case AArch64::DestructiveBinaryCommWithRev: + return mo_equals(MI.getOperand(2), MI.getOperand(3)); + case AArch64::DestructiveTernaryCommWithRev: + return mo_equals(MI.getOperand(2), MI.getOperand(3)) || + mo_equals(MI.getOperand(2), MI.getOperand(4)) || + mo_equals(MI.getOperand(3), MI.getOperand(4)); + case AArch64::NotDestructive: + case AArch64::DestructiveBinaryImm: + case AArch64::DestructiveBinaryShImmUnpred: + return false; + default: + break; + } + + llvm_unreachable("Not a known destructive operand type"); +} + +bool SVEConditionalEarlyClobberPass::addConditionalEC(MachineInstr &MI) { + // If the operand is already 'earlyclobber' or it doesn't require + // adding a conditional one (based on instruction), then don't bother. + if (!hasConditionalClobber(MI)) + return false; + + if (MI.getOperand(0).isEarlyClobber()) + return false; + + assert(MI.getOperand(0).isDef()); + + // Set the 'EarlyClobber' attribute for when the live ranges need + // to be recalculated. + MI.getOperand(0).setIsEarlyClobber(true); + + SlotIndex Index = LIS->getInstructionIndex(MI); + SlotIndex DefSlot = Index.getRegSlot(0); + + // Update the LiveRange cache by extending the liferange of the + // 'Def' register to be live earlier, so it overlaps with the + // live ranges of the input operands. + unsigned Reg = MI.getOperand(0).getReg(); + auto *Seg = LIS->getInterval(Reg).getSegmentContaining(DefSlot); + assert(Seg && "Expected Def operand to be live with instruction"); + Seg->start = Index.getRegSlot(true); + Seg->valno->def = Seg->start; + + return true; +} + +bool SVEConditionalEarlyClobberPass::runOnMachineFunction(MachineFunction &MF) { + LIS = &getAnalysis(); + TII = MF.getSubtarget().getInstrInfo(); + + bool Modified = false; + for (auto &MBB : MF) + for (auto &MI : MBB) + Modified |= addConditionalEC(MI); + + return Modified; +} Index: lib/Target/AArch64/SVEExpandLibCall.cpp =================================================================== --- /dev/null +++ lib/Target/AArch64/SVEExpandLibCall.cpp @@ -0,0 +1,890 @@ +//===----- SVEExpandLibCall - SVE Lib Call Expansion ----------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This pass performs two optimizations: +// +// 1. Scalarizes scalable vector function calls by using the SVE +// pnext predicate generating instruction to efficiently loop through the +// vector arguments for the call, before merging the scalar result into the +// destination vector. +// +// 2. Expands memset and memcpy intrinsics to SVE loops. 
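+// For example, a call to @llvm.sin on a scalable vector of floats becomes a
+// loop that walks the active lanes with PNEXT, extracts each lane with LASTB,
+// calls the scalar @llvm.sin.f32 and merges the result back into the
+// destination vector with a predicated copy (see ExpandCallToLoop), while a
+// sufficiently large memset/memcpy becomes a WHILELO-controlled loop of
+// predicated loads/stores (see ExpandMemCallToLoop).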
+// +//===----------------------------------------------------------------------===// + +#include "Utils/AArch64BaseInfo.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/VectorUtils.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/InstIterator.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/PatternMatch.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Utils/Local.h" +#include + +using namespace llvm; +using namespace llvm::PatternMatch; + +#define DEBUG_TYPE "sve-expandlibcall" + +STATISTIC(NumExpandedCalls, "Number of SVE libcalls expanded"); + +static cl::opt EnableMemCallInlining( + "sve-enable-mem-call-inlining", cl::init(true), cl::Hidden, + cl::desc("Replace calls to memsets/memcpy with an inline SVE loop")); + +static cl::opt ExpandMemCallThreshold( + "sve-expand-mem-call-threshold", cl::init(128), cl::Hidden, + cl::desc("Size threshold for expanding memset/memcpy calls to SVE loops")); + +static cl::opt EnableMemCallRuntimeCheck( + "sve-enable-mem-call-rtcheck", cl::init(true), cl::Hidden, + cl::desc("Enable runtime check for small memsets/memcpy calls")); + +namespace llvm { + void initializeSVEExpandLibCallPass(PassRegistry &); +} + +namespace { +struct SVEExpandLibCall : public FunctionPass { + static char ID; // Pass identification, replacement for typeid + SVEExpandLibCall(bool Optimize = true) + : FunctionPass(ID), Optimize(Optimize) { + initializeSVEExpandLibCallPass(*PassRegistry::getPassRegistry()); + } + + bool runOnFunction(Function &F) override; +private: + Function *F; + + /// If Optimize is true, it will also expand calls for optimization purposes + /// as opposed to lowering only (needed for calls to some vector intrinsics). + bool Optimize; + + bool ExpandCallToLoop(IntrinsicInst *II); + bool ExpandMemCallToLoop(MemIntrinsic *II); + bool ReplaceReduction(IntrinsicInst *II); + Instruction *CreatePNext(Type* Ty, Value* GP, Value* Pred); + Instruction *CreateWhile(Intrinsic::ID ID, Type* Ty, Value* Op1, Value* Op2); + Instruction *CreateLastB(Value* GP, Value* Vec); + Instruction *CreateMergeCpy(Value *Merge, Value *GP, Value *Scalar); + Instruction *CreateCall(Intrinsic::ID Id, Type *RetTy, + ArrayRef Ops); + bool optimizeExperimentalReduction(IntrinsicInst *I); +}; +} + +char SVEExpandLibCall::ID = 0; +static const char *name = "SVE Vector Lib Call Expansion"; +INITIALIZE_PASS_BEGIN(SVEExpandLibCall, DEBUG_TYPE, name, false, false) +INITIALIZE_PASS_END(SVEExpandLibCall, DEBUG_TYPE, name, false, false) + +namespace llvm { +FunctionPass *createSVEExpandLibCallPass(bool Optimize) { + return new SVEExpandLibCall(Optimize); +} +} + +// Check if the given instruction is a memory lib call that needs expanding. +static bool isMemLibCall(Instruction *I) { + IntrinsicInst *II = dyn_cast(I); + if (!II) + return false; + + switch (II->getIntrinsicID()) { + case Intrinsic::memset: + case Intrinsic::memcpy: + return true; + default: + break; + } + + return false; +} + +// Check if the given instruction is a vector lib call that needs expanding. 
+static bool isVectorLibCall(Instruction *I) { + IntrinsicInst *II; + if (!(II = dyn_cast(I))) + return false; + + switch (II->getIntrinsicID()) { + case Intrinsic::sin: + case Intrinsic::cos: + case Intrinsic::exp: + case Intrinsic::exp2: + case Intrinsic::log: + case Intrinsic::log2: + case Intrinsic::log10: + case Intrinsic::pow: + case Intrinsic::powi: + case Intrinsic::masked_sin: + case Intrinsic::masked_cos: + case Intrinsic::masked_copysign: + case Intrinsic::masked_exp: + case Intrinsic::masked_exp2: + case Intrinsic::masked_log: + case Intrinsic::masked_log2: + case Intrinsic::masked_log10: + case Intrinsic::masked_maxnum: + case Intrinsic::masked_minnum: + case Intrinsic::masked_pow: + case Intrinsic::masked_powi: + case Intrinsic::masked_rint: + case Intrinsic::masked_fmod: + break; + default: + return false; + } + + // Check the types of the intrinsic call. + auto RetTy = II->getFunctionType()->getReturnType(); + auto VTy = dyn_cast(RetTy); + return VTy && VTy->isScalable(); +} + +static bool isReduction(Instruction *I) { + auto *II = dyn_cast(I); + if (!II) + return false; + + unsigned VecArg; + switch (II->getIntrinsicID()) { + default: + return false; + case Intrinsic::experimental_vector_reduce_fadd: + VecArg = 1; + break; + case Intrinsic::experimental_vector_reduce_add: + case Intrinsic::experimental_vector_reduce_and: + case Intrinsic::experimental_vector_reduce_or: + case Intrinsic::experimental_vector_reduce_xor: + case Intrinsic::experimental_vector_reduce_smax: + case Intrinsic::experimental_vector_reduce_smin: + case Intrinsic::experimental_vector_reduce_umax: + case Intrinsic::experimental_vector_reduce_umin: + case Intrinsic::experimental_vector_reduce_fmax: + case Intrinsic::experimental_vector_reduce_fmin: + VecArg = 0; + break; + } + + // Check the types of the intrinsic call. + auto ParamTy = II->getFunctionType()->getParamType(VecArg); + auto VTy = dyn_cast(ParamTy); + return VTy && VTy->isScalable(); +} + +static Intrinsic::ID getUnmaskedIntrinsic(Intrinsic::ID ID) { + switch (ID) { + case Intrinsic::masked_sin: + return Intrinsic::sin; + case Intrinsic::masked_cos: + return Intrinsic::cos; + case Intrinsic::masked_copysign: + return Intrinsic::copysign; + case Intrinsic::masked_exp: + return Intrinsic::exp; + case Intrinsic::masked_exp2: + return Intrinsic::exp2; + case Intrinsic::masked_log: + return Intrinsic::log; + case Intrinsic::masked_log2: + return Intrinsic::log2; + case Intrinsic::masked_log10: + return Intrinsic::log10; + case Intrinsic::masked_maxnum: + return Intrinsic::maxnum; + case Intrinsic::masked_minnum: + return Intrinsic::minnum; + case Intrinsic::masked_pow: + return Intrinsic::pow; + case Intrinsic::masked_powi: + return Intrinsic::powi; + case Intrinsic::masked_rint: + return Intrinsic::rint; + default: + llvm_unreachable("Unexpected intrinsic ID"); + } +} + +Instruction *SVEExpandLibCall::CreatePNext(Type* Ty, Value* GP, Value* Pred) { + SmallVector Types = { Ty }; + SmallVector Args { GP, Pred }; + + Intrinsic::ID IntID = Intrinsic::aarch64_sve_pnext; + Function *Intr = Intrinsic::getDeclaration(F->getParent(), IntID, Types); + return CallInst::Create(Intr->getFunctionType(), Intr, Args); +} + +/// Create call to specified WHILE intrinsic. 
+/// +Instruction* SVEExpandLibCall::CreateWhile(Intrinsic::ID IntID, Type* Ty, + Value* Op1, Value* Op2) { + SmallVector Types = { Ty, Op1->getType() }; + SmallVector Args { Op1, Op2 }; + + Function *Intrinsic = Intrinsic::getDeclaration(F->getParent(), IntID, Types); + return CallInst::Create(Intrinsic->getFunctionType(), Intrinsic, Args); +} + + +Instruction *SVEExpandLibCall::CreateLastB(Value* GP, Value* Vec) { + SmallVector Args { GP, Vec }; + + Intrinsic::ID IntID = Intrinsic::aarch64_sve_lastb; + auto *Intr = Intrinsic::getDeclaration(F->getParent(), IntID, Vec->getType()); + return CallInst::Create(Intr->getFunctionType(), Intr, Args); +} + +Instruction *SVEExpandLibCall::CreateMergeCpy(Value *Merge, Value *GP, + Value *Scalar) { + auto VecTy = Merge->getType(); + SmallVector Args { Merge, GP, Scalar }; + + Intrinsic::ID IntID = Intrinsic::aarch64_sve_dup; + Function *Intr = Intrinsic::getDeclaration(F->getParent(), IntID, { VecTy }); + return CallInst::Create(Intr->getFunctionType(), Intr, Args); +} + +Instruction *SVEExpandLibCall::CreateCall(Intrinsic::ID Id, + Type *RetTy, ArrayRef Ops) { + auto Callee = Intrinsic::getDeclaration(F->getParent(), Id, + Ops[0]->getType()); + auto CI = CallInst::Create(Callee->getFunctionType(), Callee, Ops); + CI->setCallingConv(Callee->getCallingConv()); + return CI; +} + +// Expand a call to memcpy or memset into an optimized SVE loop +bool SVEExpandLibCall::ExpandMemCallToLoop(MemIntrinsic *II) { + if (II->isVolatile()) + return false; + + bool IndexMayOverflow = true; + + if (auto *CI = dyn_cast(II->getLength())) { + if (CI->getZExtValue() < ExpandMemCallThreshold) + return false; + + // Determine if constant is in range, so that index increment + // will never overflow. + if (CI->getZExtValue() < (UINT64_MAX - AArch64::SVEMaxBitsPerVector/8)) + IndexMayOverflow = false; + } + + LLVM_DEBUG(dbgs() << "SVEExpandLib: Expanding call to: " << + II->getCalledFunction()->getName() << "\n"); + + // Splitting basic block into expand (loop body) block and resume block. 
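+  // The resulting control flow (with the runtime check enabled) is:
+  //   parent --(VL is the 128-bit minimum)--> mem.intrinsic --> mem.resume
+  //   parent --(otherwise)--> mem.ph --> mem.exploop --(loop)--> mem.resume
+  // where mem.intrinsic keeps the original library call and mem.exploop
+  // holds the predicated SVE loop built below.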
+ auto *ParentBlock = II->getParent(); + auto *PHBlock = ParentBlock->splitBasicBlock(II, "mem.ph"); + auto *LoopBlock = PHBlock->splitBasicBlock(II, "mem.exploop"); + auto *MemIntrBlock = LoopBlock->splitBasicBlock(II, "mem.intrinsic"); + auto *ResumeBlock = MemIntrBlock->splitBasicBlock(II, "mem.resume"); + + // Fill in the preheader + BasicBlock::iterator InsertPt(ParentBlock->getTerminator()); + IRBuilder<> Builder(&*InsertPt); + + // Always use a 64bit iteration counter + Type *IdxTy = Builder.getInt64Ty(); + + // Possibly zero-extend the length + Value *Length = II->getLength(); + if (!Length->getType()->isIntegerTy(64)) { + Length = Builder.CreateZExt(Length, IdxTy); + IndexMayOverflow = false; + } + + // If it was already zero/sign extended, there is no wrap + if (isa(Length)) + IndexMayOverflow = false; + + // Create runtime check based on runtime vector length + if (EnableMemCallRuntimeCheck) { + auto VT = VectorType::get(Builder.getInt8Ty(), 16, true); + auto MinElts = Builder.getInt64(16); + auto NumElts = ConstantExpr::getRuntimeNumElements(MinElts->getType(), VT); + auto ScalarCompare = Builder.CreateICmpULE(NumElts, MinElts); + Builder.CreateCondBr(ScalarCompare, MemIntrBlock, PHBlock); + ParentBlock->getTerminator()->eraseFromParent(); + } + + // Create the splat (memset) + Builder.SetInsertPoint(PHBlock->getTerminator()); + auto *ValTy = VectorType::get(Builder.getInt8Ty(), 16, true); + Value *SetVal = nullptr; + if (auto *MS = dyn_cast(II)) + SetVal = + Builder.CreateVectorSplat(ValTy->getElementCount(), MS->getValue()); + + // Expand the compare (0 < N) + Value *Zero = ConstantInt::get(IdxTy, 0, false); + auto *PredTy = VectorType::get(Builder.getInt1Ty(), 16, true); + auto *PredPH = CreateWhile(Intrinsic::aarch64_sve_whilelo, + PredTy, Zero, Length); + Builder.Insert(PredPH); + // Fall through into LoopBlock + + // Set Insert point to loop body + Builder.SetInsertPoint(LoopBlock->getTerminator()); + + // Create PHI node for Induction and Predicate + auto *Pred = Builder.CreatePHI(PredTy, 2); + Pred->addIncoming(PredPH, PHBlock); + + auto *Ind = Builder.CreatePHI(IdxTy, 2); + Ind->addIncoming(Zero, PHBlock); + SmallVector Indices = { Ind }; + + // Create the load (in case of memcpy) + if (auto *MC = dyn_cast(II)) { + auto SrcAddr = Builder.CreateGEP(MC->getRawSource(), Indices); + SrcAddr = Builder.CreateBitCast(SrcAddr, ValTy->getPointerTo()); + SetVal = Builder.CreateMaskedLoad(SrcAddr, II->getDestAlignment(), Pred); + } + + assert(SetVal && "No Value to store"); + + // Create store + auto *Addr = Builder.CreateGEP(II->getRawDest(), Indices); + Addr = Builder.CreateBitCast(Addr, ValTy->getPointerTo()); + Builder.CreateMaskedStore(SetVal, Addr, II->getDestAlignment(), Pred); + + // Create next.index + Value *NextInd = nullptr; + if (IndexMayOverflow) { + Value *CntVPop = Builder.CreateCntVPop(Pred, "popcnt"); + CntVPop = Builder.CreateZExtOrTrunc(CntVPop, IdxTy); + NextInd = Builder.CreateNUWAdd(Ind, CntVPop); + } else { + auto NumElts = ConstantExpr::getRuntimeNumElements(IdxTy, ValTy); + NextInd = Builder.CreateNUWAdd(Ind, NumElts); + } + + Ind->addIncoming(NextInd, LoopBlock); + + // Create next.predicate + auto *NextPred = CreateWhile(Intrinsic::aarch64_sve_whilelo, + PredTy, NextInd, Length); + Builder.Insert(NextPred); + Pred->addIncoming(NextPred, LoopBlock); + + // Create test and conditional branch + Value *Continue = Builder.CreateExtractElement(NextPred, Builder.getInt64(0)); + Builder.CreateCondBr(Continue, LoopBlock, ResumeBlock); + 
LoopBlock->getTerminator()->eraseFromParent(); + + // Remove the original memset + II->moveBefore(MemIntrBlock->getTerminator()); + + NumExpandedCalls++; + return true; +} + + +// Replace generic reduction intrinsics with SVE specific ones. +// We could handle this at the codegen level, but we do it here because the +// VECREDUCE_* SDNodes don't take a predicate, like the IR intrinsics, so we +// have to recover the original predicate from the input vector if it's a +// select. As the original mask may be potentially hoisted out of the block +// it can be more optimal to translate to target specific intrinsics here. +bool SVEExpandLibCall::ReplaceReduction(IntrinsicInst *II) { + Intrinsic::ID NewID; + Value *SrcV = II->getArgOperand(0); + bool Ordered = false; + FastMathFlags FMF; + if (isa(II)) + FMF = II->getFastMathFlags(); + + switch (II->getIntrinsicID()) { + default: + llvm_unreachable("Unhandled intrinsic ID"); + case Intrinsic::experimental_vector_reduce_fadd: + // The common code currently only knows about unordered fp-add-reductions, + // hence the need to differentiate. For ordered reductions we need to use + // target custom code. + // TODO Revisit once ordered fp-reductions are added to the common code. + Ordered = !II->getFastMathFlags().allowReassoc(); + if (!Ordered) + return false; + + NewID = Intrinsic::aarch64_sve_fadda; + SrcV = II->getArgOperand(1); + break; + case Intrinsic::experimental_vector_reduce_add: + case Intrinsic::experimental_vector_reduce_smax: + case Intrinsic::experimental_vector_reduce_smin: + case Intrinsic::experimental_vector_reduce_umax: + case Intrinsic::experimental_vector_reduce_umin: + case Intrinsic::experimental_vector_reduce_fmax: + case Intrinsic::experimental_vector_reduce_fmin: + return false; + case Intrinsic::experimental_vector_reduce_and: + if (SrcV->getType()->getVectorElementType()->isIntegerTy(1)) { + // TODO Implement lowering for i1 + NewID = Intrinsic::aarch64_sve_andv; + break; + } + return false; + case Intrinsic::experimental_vector_reduce_xor: + if (SrcV->getType()->getVectorElementType()->isIntegerTy(1)) { + // TODO Implement lowering for i1 + NewID = Intrinsic::aarch64_sve_eorv; + break; + } + return false; + case Intrinsic::experimental_vector_reduce_or: + if (SrcV->getType()->getVectorElementType()->isIntegerTy(1)) { + // TODO Implement lowering for i1 + NewID = Intrinsic::aarch64_sve_orv; + break; + } + return false; + } + // Look for a select between a value and an identity element for a reduction. 
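+  // e.g.
+  //   %masked = select %pred, %src, zeroinitializer
+  //   %red    = call @llvm.experimental.vector.reduce.*(... %masked ...)
+  // Here %pred becomes the governing predicate of the SVE reduction and %src
+  // its data operand; otherwise an all-true predicate is used.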
+ Value *Predicate = nullptr; + Value *MaskedSrc; + if (match(SrcV, m_Select(m_Value(Predicate), m_Value(MaskedSrc), m_Zero()))) { + SrcV = MaskedSrc; + } else { + auto EC = cast(SrcV->getType())->getElementCount(); + Type *PredTy = + VectorType::get(Type::getInt1Ty(SrcV->getType()->getContext()), EC); + Predicate = ConstantInt::getTrue(PredTy); + } + + SmallVector Args { Predicate }; + SmallVector Tys { SrcV->getType() }; + if (Ordered) + Args.push_back(II->getArgOperand(0)); + Args.push_back(SrcV); + Function *Decl = Intrinsic::getDeclaration(F->getParent(), NewID, Tys); + Instruction *Rdx = CallInst::Create(Decl->getFunctionType(), Decl, Args); + Rdx->insertAfter(II); + if (II->getType() != Rdx->getType()) { + assert(Rdx->getType()->isIntegerTy(64) && "Unexpected type mismatch"); + auto *CI = CastInst::CreateTruncOrBitCast(Rdx, II->getType()); + CI->insertAfter(Rdx); + Rdx = CI; + } + II->replaceAllUsesWith(Rdx); + II->eraseFromParent(); + return true; +} + +// Expand the vector intrinsic call to an SVE loop with +bool SVEExpandLibCall::ExpandCallToLoop(IntrinsicInst *II) { + // We expand intrinsic call in the following code pattern for unpredicated + // intrinsics. If the intrinsic has a predicate then we need to extract it + // and use it as the gp for the pnext loop. + // bb: + // %vecparam = + // %vecparam2 = + // %retval = call @libfunc(%vecparam, %vecparam2) + // to: + // bb: + // %vecparam = + // %vecparam2 = + // %gp = sve_pnext( vecsplat(false)) + // %output_val = vecsplat(false) + // br exploop + // exploop: + // %gp' = phi [%gp, bb], [%gpnext, exploop] + // %output_val' = phi [%output_val, bb], [%mergeout, exploop] + // %argval = sve_lastb(gp', vecparam) + // %argval2 = sve_lastb(gp', vecparam2) + // %callret = call float @libfunc(%argval, %argval2) + // %mergeout = sve_cpy(output_val', gp', callret) + // %gpnext = sve_pnext(%gp') + // %test = test any true %gpnext + // br i1 %test exploop, resume + // resume: + // %newretval = phi [%mergeout, exploop] ; RAUW old %retval + auto PHBlock = II->getParent(); + auto LoopBlock = II->getParent()->splitBasicBlock(II, "exploop"); + auto ResumeBlock = LoopBlock->splitBasicBlock(&LoopBlock->front(), "resume"); + + const auto IsThereAMaskParam = isMaskedVectorIntrinsic(II->getIntrinsicID()); + const bool IsPredicated = IsThereAMaskParam.first; + const unsigned MaskPosition = IsThereAMaskParam.second; + + BasicBlock::iterator InsertPt(PHBlock->getTerminator()); + IRBuilder<> Builder(&*InsertPt); + + auto *VecTy = cast(II->getFunctionType()->getReturnType()); + auto EC = VecTy->getElementCount(); + auto ScalarTy = VecTy->getScalarType(); + auto PredTy = VectorType::get(Builder.getInt1Ty(), EC); + + auto PTrue = Builder.CreateVectorSplat(EC, Builder.getTrue()); + auto PFalse = Builder.CreateVectorSplat(EC, Builder.getFalse()); + Value *Predicate = IsPredicated ? II->getArgOperand(MaskPosition) : PTrue; + + auto InitGP = CreatePNext(PredTy, Predicate, PFalse); + Builder.Insert(InitGP); + auto InitOutVal = Builder.CreateVectorSplat(EC, + Builder.getIntN( + ScalarTy->getPrimitiveSizeInBits(), 0)); + InitOutVal = Builder.CreateBitCast(InitOutVal, VecTy); + + // Generate code for the expansion loop block. 
+ Builder.SetInsertPoint(LoopBlock->getTerminator()); + auto GP = Builder.CreatePHI(PredTy, 2, "gp"); + GP->addIncoming(InitGP, PHBlock); + auto OutVal = Builder.CreatePHI(VecTy, 2, "outval"); + OutVal->addIncoming(InitOutVal, PHBlock); + + LLVM_DEBUG(dbgs() << "SVEExpandLib: Expanding call to: " << + II->getCalledFunction() << "\n"); + + SmallVector ScalarArgs; + + // Emit code to find the the scalar argument values from original vectors. + for (unsigned ArgIdx = 0; ArgIdx < II->getNumArgOperands(); ++ArgIdx) { + // Skip the mask parameter + if (IsPredicated && (MaskPosition == ArgIdx)) + continue; + + auto Param = II->getOperand(ArgIdx); + // Optimize cases where the vector is a splat of a scalar, in which case + // just use the original scalar without re-extracting it using lastb. + Value *SplatVal; + Value *LastArg; + if (!Param->getType()->isVectorTy()) { + // Scalar arg, pass through + LastArg = Param; + } else if (match(Param, m_SplatVector(m_Value(SplatVal)))) + LastArg = SplatVal; + else { + LastArg = CreateLastB(GP, Param); + Builder.Insert(cast(LastArg)); + } + ScalarArgs.push_back(LastArg); + } + + // We need to generate a call to the scalar version of the function. + Intrinsic::ID IID = II->getIntrinsicID(); + Instruction *NewCall; + if (IID == Intrinsic::masked_fmod) { + NewCall = BinaryOperator::CreateFRem(ScalarArgs[0], ScalarArgs[1]); + } else { + if (IsPredicated) + IID = getUnmaskedIntrinsic(IID); + + NewCall = CreateCall(IID, ScalarTy, ScalarArgs); + } + NewCall->setName("newcall"); + Builder.Insert(NewCall); + + // Merge in the new scalar call value into the result vector. + auto MergeCpy = CreateMergeCpy(OutVal, GP, NewCall); + OutVal->addIncoming(MergeCpy, LoopBlock); + Builder.Insert(MergeCpy); + + // Generate 'next' predicate for subsequent element. + auto GPNext = CreatePNext(PredTy, Predicate, GP); + GP->addIncoming(GPNext, LoopBlock); + Builder.Insert(GPNext); + + auto Test = getAnyTrueReduction(Builder, GPNext); + Builder.CreateCondBr(Test, LoopBlock, ResumeBlock); + + // Now remove old terminator br. + LoopBlock->getTerminator()->eraseFromParent(); + + // Replace all uses of the old vector call with the new merged vector. + Builder.SetInsertPoint(&ResumeBlock->front()); + auto ResumeVec = Builder.CreatePHI(VecTy, 1); + ResumeVec->addIncoming(MergeCpy, LoopBlock); + II->replaceAllUsesWith(ResumeVec); + II->eraseFromParent(); + NumExpandedCalls++; + return true; +} + +namespace { +std::pair +getReductionInfo(IntrinsicInst *II) { + switch (II->getIntrinsicID()) { + case Intrinsic::experimental_vector_reduce_fmul: + return {Instruction::BinaryOps::FMul, 1}; + case Intrinsic::experimental_vector_reduce_mul: + return {Instruction::BinaryOps::Mul, 0}; + default: + llvm_unreachable("Can not handle this reduction."); + } +} + +bool isExperimentalScalableReduction(Instruction *I) { + IntrinsicInst *II = dyn_cast(I); + if (!II) + return false; + + // We allow tree reductions only if the intrinstic has the fast math + // flag. + auto FPMathOp = dyn_cast(I); + if (FPMathOp && !(FPMathOp->isFast())) + return false; + + switch (II->getIntrinsicID()) { + case Intrinsic::experimental_vector_reduce_fmul: + case Intrinsic::experimental_vector_reduce_mul: { + // The candidate for the tranformation must be operating on + // scalable vectors. 
+ const auto VecArgPos = getReductionInfo(II).second; + auto VecTy = dyn_cast(II->getArgOperand(VecArgPos)->getType()); + assert(VecTy && "A vector is expected."); + return VecTy->isScalable(); + } break; + default: + break; + } + + return false; +} +} // namespace + +bool SVEExpandLibCall::runOnFunction(Function &F) { + this->F = &F; + bool Changed = false; + SmallVector VectorWorkList; + SmallVector MemWorkList; + SmallVector ReductionWorkList; + SmallVector ExperimentalReductionWL; + + for (auto I = inst_begin(F), E = inst_end(F); I != E; ++I) { + if (isVectorLibCall(&*I)) + VectorWorkList.push_back(&*I); + else if (EnableMemCallInlining && isMemLibCall(&*I)) + MemWorkList.push_back(&*I); + else if (isReduction(&*I)) + ReductionWorkList.push_back(&*I); + else if (isExperimentalScalableReduction(&*I)) + ExperimentalReductionWL.push_back(&*I); + } + + for (auto *I : VectorWorkList) + Changed |= ExpandCallToLoop(cast(I)); + + for (auto *I : ReductionWorkList) + Changed |= ReplaceReduction(cast(I)); + + for (auto *I : ExperimentalReductionWL) + Changed |= optimizeExperimentalReduction(cast(I)); + + // If the target-feature for SVE is not set, we can't generate + // explicit SVE intrinsics to optimize memsets. + bool HasSVEAttribute = F.getAttributes() + .getFnAttributes() + .getAttribute("target-features") + .getValueAsString() + .contains("+sve"); + + if (Optimize && HasSVEAttribute) { + for (auto I : MemWorkList) + Changed |= ExpandMemCallToLoop(cast(I)); + } + + return Changed; +} + +// Expand scalable experimental reductions that do not have direct +// support in the SVE instruction set into a loop that accumulate +// lanes in fixed-width vectors. The loop iterate on the Scalable +// vector by shifting the lanes to the left with the EXT instruction +// of SVE. The result of the fixed-width accumulation is then +// accumulated into a scalar with a fixed number of scalar reductions. +// +// Note that the transformation is enabled for FP data only when the +// intrinsics are invoked with the "fast" attribute. As per the +// LangRef documentation fo the IR, this means that the scalar +// argument is always ignored. +// +// The algorithm loops on all sub-vectors of the scalable +// vector and accumulates them into a fixed-length vector +// of shape . Each sub-vector is produced by 1) +// shifting to the left m lanes of the scalable one using the +// llvm.aarch64.ext intrinsic, and 2) extracting the vector of the +// first m-lanes with a shuffle vector. +// +// The number of fixed-length sub-vector that needs to be processed, +// which is the n in , is a runtime value that is computed +// using the @llvm.aarch64.sve.cnt[bhwd] intrinsics: +// +// * @llvm.aarch64.sve.cntb for m = 16, +// * @llvm.aarch64.sve.cnth for m = 8, +// * @llvm.aarch64.sve.cntw for m = 4, +// * @llvm.aarch64.sve.cntd for m = 2. +// +// The first fixed-length extraction is performed outside the loop +// (without a shift of m lanes), as we can always assume a runtime +// value of n greater or equal to 1. The fixed-length accumulator +// vector is then reduced to a scalar with an +// llvm.experimental.reduce.* intrinsic. 
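+// For example, for a scalable vector of floats (m = 4) on a 512-bit
+// implementation, @llvm.aarch64.sve.cntw(31) returns 16, so one sub-vector is
+// reduced before the loop and the loop body then runs three more times, each
+// iteration shifting the scalable vector down by four lanes with EXT before
+// accumulating its low four lanes into the fixed-width accumulator.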
+// +// Original pseudo-code with IR types (assuming only one parameter in +// input, as the scalar argument is ignored because of the fast math +// requirements): +// +// %scalable_vec +// ty %acc = call @llvm.experimental.reduce.OP(%scalable_vec) +// +// The output produced by the transformation looks as follow: +// +// %scalable_vec +// %vec_acc = shuffle_vector %scalable_vec, +// %scalable_vec, +// <0, ..., m - 1> +// %i32 runtime_n = call @llvm.aarch64.sve.cnt[bhwd](i32 31) +// for (k = 2; k < n; ++k) +// { +// %scalable_vec = call @llvm.aarch64.sve.ext(%scalable_vec, +// %scalable_vec, m) +// %tmp = shuffle_vector %scalable_vec, +// %scalable_vec, +// <0, ..., m - 1> +// %vec_acc = %tmp OP %vec_acc +// } +// ty %acc = llvm.experimental.reduce.OP(%vec_acc) +bool SVEExpandLibCall::optimizeExperimentalReduction(IntrinsicInst *I) { + auto FPMathOp = dyn_cast(I); + assert(((FPMathOp && FPMathOp->isFast()) || !FPMathOp) && + "This transformation requires fast math when dealing with FP data."); + + const auto CallInfo = getReductionInfo(I); + const unsigned VectorArgPos = CallInfo.second; + const Instruction::BinaryOps BinOp = CallInfo.first; + + auto M = F->getParent(); + + auto PH = I->getParent(); + auto ReductionLoopBlock = PH->splitBasicBlock(I, "sve.reduction"); + auto NeonReduction = ReductionLoopBlock->splitBasicBlock(I, "neon.reduction"); + + auto OldPHTerminator = PH->getTerminator(); + BasicBlock::iterator InsertPt(OldPHTerminator); + IRBuilder<> Builder(&*InsertPt); + + // First extraction of a Neon vector in the pre-header PH. + auto ScalableData = I->getArgOperand(VectorArgPos); + auto ScalableVecTy = dyn_cast(ScalableData->getType()); + assert(ScalableVecTy && "Not a vector type."); + const unsigned Lanes = ScalableVecTy->getNumElements(); + auto ElTy = ScalableVecTy->getElementType(); + auto FixedWidthVecTy = VectorType::get(ElTy, Lanes); + + Type *PatternTy = Builder.getInt32Ty(); + Value *Ops[] = {ConstantInt::get(PatternTy, 31, false)}; + Intrinsic::ID GetCNT; + switch (Lanes) { + case 16: + GetCNT = Intrinsic::aarch64_sve_cntb; + break; + case 8: + GetCNT = Intrinsic::aarch64_sve_cnth; + break; + case 4: + GetCNT = Intrinsic::aarch64_sve_cntw; + break; + case 2: + GetCNT = Intrinsic::aarch64_sve_cntd; + break; + default: + llvm_unreachable("Invalid size."); + break; + } + auto Fn = Intrinsic::getDeclaration(M, GetCNT); + auto CI = Builder.CreateCall(Fn->getFunctionType(), Fn, Ops, "VL"); + auto IdxTy = CI->getType(); + auto NextVL = + Builder.CreateSub(CI, ConstantInt::get(IdxTy, Lanes, false), "NextVL"); + auto CmpPH = Builder.CreateICmpEQ(NextVL, ConstantInt::get(IdxTy, 0, false)); + + SmallVector Indexes; + for (unsigned i = 0; i < Lanes; ++i) + Indexes.push_back(i); + auto Shuffle = + Builder.CreateShuffleVector(ScalableData, ScalableData, Indexes, "Init"); + Builder.CreateCondBr(CmpPH, NeonReduction, ReductionLoopBlock); + OldPHTerminator->eraseFromParent(); + + // SVE reduction block + Builder.SetInsertPoint(ReductionLoopBlock->getFirstNonPHI()); + // Store the unconditional terminator that we will erase once the + // loop is created. 
+ auto OldRLBTerminator = ReductionLoopBlock->getTerminator(); + auto PHI = Builder.CreatePHI(IdxTy, 2); + auto ScalablePHI = Builder.CreatePHI(ScalableVecTy, 2); + auto FixedWidthPHI = Builder.CreatePHI(FixedWidthVecTy, 2); + + Value *ExtOps[] = {ScalablePHI, ScalablePHI, + ConstantInt::get(PatternTy, Lanes, false)}; + auto ExtCI = CreateCall(Intrinsic::aarch64_sve_ext, ScalableVecTy, ExtOps); + ExtCI->setName("EXT"); + Builder.Insert(ExtCI); + auto ExtShuffle = Builder.CreateShuffleVector(ExtCI, ExtCI, Indexes); + auto Acc = Builder.CreateBinOp(BinOp, ExtShuffle, FixedWidthPHI, "Acc"); + ScalablePHI->addIncoming(ExtCI, ReductionLoopBlock); + ScalablePHI->addIncoming(ScalableData, PH); + + FixedWidthPHI->addIncoming(Acc, ReductionLoopBlock); + FixedWidthPHI->addIncoming(Shuffle, PH); + + auto Sub = Builder.CreateSub(PHI, ConstantInt::get(IdxTy, Lanes, false)); + auto Cmp = Builder.CreateICmpEQ(Sub, ConstantInt::get(IdxTy, 0, false)); + Builder.CreateCondBr(Cmp, NeonReduction, ReductionLoopBlock); + // Remove the unconditional terminator now that we have created the + // loop. + OldRLBTerminator->eraseFromParent(); + PHI->addIncoming(Sub, ReductionLoopBlock); + PHI->addIncoming(NextVL, PH); + + // NEON reduction block + Builder.SetInsertPoint(NeonReduction->getFirstNonPHI()); + auto FixedWidthReductionPHI = Builder.CreatePHI(FixedWidthVecTy, 2); + FixedWidthReductionPHI->addIncoming(Acc, ReductionLoopBlock); + FixedWidthReductionPHI->addIncoming(Shuffle, PH); + + // Create the fixed-width reduction on the accumulator. + CallInst *FixedRed = nullptr; + auto IID = I->getIntrinsicID(); + if (VectorArgPos == 0) { + auto FixedRedFn = + Intrinsic::getDeclaration(M, IID, {ElTy, FixedWidthVecTy}); + FixedRed = Builder.CreateCall(FixedRedFn->getFunctionType(), FixedRedFn, + {FixedWidthReductionPHI}, "FixedRed"); + } else if (VectorArgPos == 1) { + auto FixedRedFn = + Intrinsic::getDeclaration(M, IID, {ElTy, ElTy, FixedWidthVecTy}); + FixedRed = Builder.CreateCall( + FixedRedFn->getFunctionType(), FixedRedFn, + {UndefValue::get(ElTy), FixedWidthReductionPHI}, "FixedRed"); + + } else + llvm_unreachable("Invalid parameter position"); + assert(FixedRed && "Unable to generate a fixed-length reduction."); + + // Attach original fast math flags when needed. + if (FPMathOp) + FixedRed->setFastMathFlags(I->getFastMathFlags()); + + // Remove original instruction. + I->replaceAllUsesWith(FixedRed); + if (I->use_empty()) + I->eraseFromParent(); + + return true; +} Index: lib/Target/AArch64/SVEISelLowering.inc.h =================================================================== --- /dev/null +++ lib/Target/AArch64/SVEISelLowering.inc.h @@ -0,0 +1,3363 @@ +// SVE include file for AArch64ISelLowering.cpp +// These are new additional functions added to the lowering code separated out +// to reduce the impact of merge conflicts. 
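+// For example, getDoubleWidthHalfCountVectorVT(nxv8i16) returns nxv4i32 and
+// getHalfWidthDoubleCountVectorVT(nxv4i32) returns nxv8i16; LowerDIV below
+// uses this pair to widen nxv16i8/nxv8i16 divisions to nxv4i32 (the narrowest
+// element size for which SVE provides SDIV/UDIV) and then narrows the results
+// back with UZP1.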
+ +static EVT getDoubleWidthHalfCountVectorVT(SelectionDAG &DAG, EVT VT) { + unsigned EltBits = VT.getVectorElementType().getSizeInBits(); + EVT EltVT = EVT::getIntegerVT(*DAG.getContext(), EltBits * 2); + auto HalfEC = VT.getVectorElementCount() / 2; + return EVT::getVectorVT(*DAG.getContext(), EltVT, HalfEC); +} + +static EVT getHalfWidthDoubleCountVectorVT(SelectionDAG &DAG, EVT VT) { + unsigned EltBits = VT.getVectorElementType().getSizeInBits(); + EVT EltVT = EVT::getIntegerVT(*DAG.getContext(), EltBits / 2); + auto DoubleEC = VT.getVectorElementCount() * 2; + return EVT::getVectorVT(*DAG.getContext(), EltVT, DoubleEC); +} + +static inline EVT getNaturalIntSVETypeWithMatchingElementCount(EVT VT) { + if (!VT.isScalableVector()) + return EVT(); + + switch (VT.getVectorNumElements()) { + default: return EVT(); + case 16: return VT.changeVectorElementType(MVT::i8); + case 8: return VT.changeVectorElementType(MVT::i16); + case 4: return VT.changeVectorElementType(MVT::i32); + case 2: return VT.changeVectorElementType(MVT::i64); + } +} + +static inline EVT getNaturalIntSVETypeWithMatchingElementType(EVT VT) { + if (!VT.isScalableVector()) + return EVT(); + + switch (VT.getVectorElementType().getSimpleVT().SimpleTy) { + default: return EVT(); + case MVT::i8: return MVT::nxv16i8; + case MVT::i16: return MVT::nxv8i16; + case MVT::i32: return MVT::nxv4i32; + case MVT::i64: return MVT::nxv2i64; + } +} + +static inline EVT getNaturalPredSVETypeWithMatchingElementType(EVT VT) { + if (!VT.isScalableVector()) + return EVT(); + + switch (VT.getVectorElementType().getSimpleVT().SimpleTy) { + default: return EVT(); + case MVT::i8: return MVT::nxv16i1; + case MVT::i16: return MVT::nxv8i1; + case MVT::i32: return MVT::nxv4i1; + case MVT::i64: return MVT::nxv2i1; + } +} + +static inline SDValue getPTrue(SelectionDAG &DAG, SDLoc DL, EVT VT, + int Pattern) { + return DAG.getNode(AArch64ISD::PTRUE, DL, VT, + DAG.getConstant(Pattern, DL, MVT::i32)); +} + +static SDValue getPTest(SelectionDAG &DAG, EVT VT, SDValue Pg, SDValue Op, + AArch64CC::CondCode Cond) { + const TargetLowering &TLI = DAG.getTargetLoweringInfo(); + + SDLoc DL(Op); + EVT OpVT = Op.getValueType(); + assert(OpVT.isScalableVector() && TLI.isTypeLegal(OpVT) && + "Expected legal scalable vector type!"); + + // Ensure target specific opcodes are using legal type. + EVT OutVT = TLI.getTypeToTransformTo(*DAG.getContext(), VT); + SDValue TVal = DAG.getConstant(1, DL, OutVT); + SDValue FVal = DAG.getConstant(0, DL, OutVT); + + // Set condition code (CC) flags. + SDValue Test = DAG.getNode(AArch64ISD::PTEST, DL, MVT::Other, Pg, Op); + + // Convert CC to integer based on requested condition. + // NOTE: Cond is inverted to promote CSEL's removal when it feeds a compare. 
+ SDValue CC = DAG.getConstant(getInvertedCondCode(Cond), DL, MVT::i32); + SDValue Res = DAG.getNode(AArch64ISD::CSEL, DL, OutVT, FVal, TVal, CC, Test); + return DAG.getZExtOrTrunc(Res, DL, VT); +} + +SDValue AArch64TargetLowering::LowerINTRINSIC_W_CHAIN(SDValue Op, + SelectionDAG &DAG) const { + SDLoc dl(Op); + SmallVector Results; + + ConstantSDNode *CN = cast(Op->getOperand(1)); + Intrinsic::ID IntID = static_cast(CN->getZExtValue()); + switch (IntID) { + default: + return SDValue(); + + case Intrinsic::masked_spec_load: + ReplaceMaskedSpecLoadResults(Op.getNode(), Results, DAG); + if (Results.size()) + return DAG.getMergeValues(Results, dl); + else + return SDValue(); + break; + } + + return SDValue(); +} + +static SDValue LowerDIV(SDValue Op, SelectionDAG &DAG) { + assert(Op.getOpcode() == ISD::SDIV || Op.getOpcode() == ISD::UDIV); + + EVT VT = Op.getValueType(); + if ((VT != MVT::nxv16i8) && (VT != MVT::nxv8i16)) + return SDValue(); + + SDLoc DL(Op); + bool isSigned = Op.getOpcode() == ISD::SDIV; + unsigned ExtHiOpc = isSigned ? AArch64ISD::SUNPKHI : AArch64ISD::UUNPKHI; + unsigned ExtLoOpc = isSigned ? AArch64ISD::SUNPKLO : AArch64ISD::UUNPKLO; + + // Factors required to perform the operation using 32bit type (i.e. nxv4i32). + unsigned MaxFactor = VT.getVectorNumElements() / 4; + + SmallVector Res(MaxFactor); + SmallVector Op0(MaxFactor); + SmallVector Op1(MaxFactor); + + unsigned Factors = 1; + Op0[0] = Op.getOperand(0); + Op1[0] = Op.getOperand(1); + + // Extend the operands until suitable for the operation. + for (; Factors < MaxFactor; Factors *= 2) { + VT = getDoubleWidthHalfCountVectorVT(DAG, VT); + + for (unsigned i = Factors; i > 0; --i) { + Op0[2*i-1] = DAG.getNode(ExtHiOpc, DL, VT, Op0[i-1]); + Op0[2*i-2] = DAG.getNode(ExtLoOpc, DL, VT, Op0[i-1]); + Op1[2*i-1] = DAG.getNode(ExtHiOpc, DL, VT, Op1[i-1]); + Op1[2*i-2] = DAG.getNode(ExtLoOpc, DL, VT, Op1[i-1]); + } + } + + // Perform the operation using extended operands. + for (unsigned i = 0; i < Factors; ++i) + Res[i] = DAG.getNode(Op.getOpcode(), DL, VT, Op0[i], Op1[i]); + + // Truncate the result. + for (; Factors > 1; Factors /= 2) { + VT = getHalfWidthDoubleCountVectorVT(DAG, VT); + + for (unsigned i = 0; i < Factors; i += 2) + Res[i/2] = DAG.getNode(AArch64ISD::UZP1, DL, VT, Res[i], Res[i+1]); + } + + assert(Factors == 1); + return Res[0]; +} + +static SDValue LowerREM(SDValue Op, SelectionDAG &DAG) { + assert(Op.getOpcode() == ISD::SREM || Op.getOpcode() == ISD::UREM); + SDLoc DL(Op); + EVT VT = Op.getValueType(); + unsigned DivOp = Op.getOpcode() == ISD::SREM ? 
ISD::SDIV : ISD::UDIV; + + SDValue Div = DAG.getNode(DivOp, DL, VT, Op.getOperand(0), Op.getOperand(1)); + SDValue Mul = DAG.getNode(ISD::MUL, DL, VT, Div, Op.getOperand(1)); + SDValue Rem = DAG.getNode(ISD::SUB, DL, VT, Op.getOperand(0), Mul); + return Rem; +} + +bool AArch64TargetLowering::isVectorLoadExtDesirable(SDValue ExtVal) const { + if (ExtVal.getValueType().isScalableVector()) + return true; + return false; +} + +// Lowering of LASTA and LASTB +// Recursively perform: +// { 1, 1, 1, 1, 1, 1, 0, 0 } +// {ab, cd, ef, gh }{ ij, kl, mn, op } +// uzp_lo => { b, d, f, h, j, l, n, p} +// uzp_hi => { a, c, e, g, i, k, m, o} +// lastb_lo => l +// lastb_hi => k +// result => (lastb_hi << 8) | lastb_lo => kl +SDValue AArch64TargetLowering::LowerLASTX(SDValue Op, SelectionDAG &DAG) const { + SDLoc DL(Op); + SDValue Pred = Op.getOperand(0); + SDValue InVec = Op.getOperand(1); + + if (isTypeLegal(InVec.getValueType())) + return Op; + + EVT VT = Op.getValueType(); + if (VT.isFloatingPoint()) { + EVT IntResVT = VT.changeTypeToInteger(); + EVT IntVecVT = InVec.getValueType().changeVectorElementTypeToInteger(); + SDValue IntInVec = DAG.getNode(ISD::BITCAST, DL, IntVecVT, InVec); + SDValue Res = DAG.getNode(Op.getOpcode(), DL, IntResVT, Pred, IntInVec); + return DAG.getNode(ISD::BITCAST, DL, VT, Res); + } + + assert(isTypeLegal(Pred.getValueType()) && "Need a legal predicate type"); + + // Type #elements stays the same + EVT EltVT = InVec.getValueType().getVectorElementType(); + EVT NewEltVT = EVT::getIntegerVT(*DAG.getContext(), EltVT.getSizeInBits()/2); + EVT SplitVT = InVec.getValueType().changeVectorElementType(NewEltVT); + + // Split sequence + SDValue InVecLo, InVecHi; + std::tie(InVecLo, InVecHi) = DAG.SplitVector(InVec, DL); + + // Bitcast to different type + InVecLo = DAG.getNode(ISD::BITCAST, DL, SplitVT, InVecLo); + InVecHi = DAG.getNode(ISD::BITCAST, DL, SplitVT, InVecHi); + + // Unzip (because #elements is same, it splits up elements in two parts) + SDValue Zero = DAG.getConstant(0, DL, MVT::i32); + SDValue One = DAG.getConstant(1, DL, MVT::i32); + SDValue Step = DAG.getConstant(2, DL, MVT::i32); + SDValue SVEven = DAG.getNode(ISD::SERIES_VECTOR, DL, SplitVT, Zero, Step); + SDValue SVOdd = DAG.getNode(ISD::SERIES_VECTOR, DL, SplitVT, One, Step); + + SDValue Even = DAG.getNode(ISD::VECTOR_SHUFFLE_VAR, DL, SplitVT, + InVecLo, InVecHi, SVEven); + SDValue Odd = DAG.getNode(ISD::VECTOR_SHUFFLE_VAR, DL, SplitVT, + InVecLo, InVecHi, SVOdd); + + // Do a LAST(A|B) for even and uneven + SDValue ResEven = DAG.getNode(Op.getOpcode(), DL, MVT::i32, Pred, Even); + SDValue ResOdd = DAG.getNode(Op.getOpcode(), DL, MVT::i32, Pred, Odd); + + // Possibly extend to i64 type to do the final combine + if (VT != MVT::i32) { + ResEven = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, ResEven); + ResOdd = DAG.getNode(ISD::ANY_EXTEND, DL, VT, ResOdd); + } + + SDValue NewEltBits = DAG.getConstant(NewEltVT.getSizeInBits(), DL, MVT::i32); + SDValue Res = DAG.getNode(ISD::SHL, DL, VT, ResOdd, NewEltBits); + Res = DAG.getNode(ISD::OR, DL, VT, Res, ResEven); + + return Res; +} + +// Use SVE to implement fixed-width masked loads. 
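+// For example, a v4i32 masked load is widened to an nxv4i32 SVE masked load:
+// the NEON mask and passthru are inserted into the low 128 bits of scalable
+// registers, the mask is truncated to a predicate and ANDed with ptrue(vl4)
+// so that only the first four lanes are active, and the loaded data is then
+// extracted back out of the zsub subregister as a v4i32.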
+static SDValue LowerMLOAD(SDValue Op, SelectionDAG &DAG) { + auto MLN = cast(Op.getNode()); + SDValue Mask = MLN->getMask(); + SDValue Src0 = MLN->getSrc0(); + + EVT VT = Op.getValueType(); + if (!DAG.getTargetLoweringInfo().isTypeLegal(VT) || + VT.isScalableVector()) + return SDValue(); + + SDLoc DL(Op); + EVT DataVT, MaskVT; + + switch (VT.getVectorElementType().getSimpleVT().SimpleTy) { + default: return SDValue(); + case MVT::i8: DataVT = MVT::nxv16i8; MaskVT = MVT::nxv16i8; break; + case MVT::i16: DataVT = MVT::nxv8i16; MaskVT = MVT::nxv8i16; break; + case MVT::i32: DataVT = MVT::nxv4i32; MaskVT = MVT::nxv4i32; break; + case MVT::f32: DataVT = MVT::nxv4f32; MaskVT = MVT::nxv4i32; break; + case MVT::i64: DataVT = MVT::nxv2i64; MaskVT = MVT::nxv2i64; break; + case MVT::f64: DataVT = MVT::nxv2f64; MaskVT = MVT::nxv2i64; break; + } + + // TODO: Rather than mask InsertSubReg with a fixed length predicate we could + // just insert the original mask into a zero'd register. + int PgPattern; + switch (VT.getVectorNumElements()) { + default: return SDValue(); + case 16: PgPattern = AArch64SVEPredPattern::vl16; break; + case 8: PgPattern = AArch64SVEPredPattern::vl8; break; + case 4: PgPattern = AArch64SVEPredPattern::vl4; break; + case 2: PgPattern = AArch64SVEPredPattern::vl2; break; + } + + // Widen the NEON operands to SVE. + int Idx = VT.is64BitVector() ? AArch64::dsub : AArch64::zsub; + Mask = DAG.getTargetInsertSubreg(Idx, DL, MaskVT, DAG.getUNDEF(MaskVT), Mask); + Src0 = DAG.getTargetInsertSubreg(Idx, DL, DataVT, DAG.getUNDEF(DataVT), Src0); + + // Create a predicate restricted to the size of the NEON input. + // + // For the case of v2i32/v2f32, this works due to the fact the + // load is contiguous and doesn't need to expand to unpacked + // types like the gather lowering below. + EVT PredVT = MaskVT.changeVectorElementType(MVT::i1); + SDValue Pg = getPTrue(DAG, DL, PredVT, PgPattern); + Mask = DAG.getNode(ISD::TRUNCATE, DL, PredVT, Mask); + Mask = DAG.getNode(ISD::AND, DL, PredVT, Mask, Pg); + + SDValue MLoad = DAG.getMaskedLoad(DataVT, DL, MLN->getChain(), + MLN->getBasePtr(), Mask, Src0, + MLN->getMemoryVT(), MLN->getMemOperand(), + MLN->getExtensionType()); + + SDValue Res = DAG.getTargetExtractSubreg(Idx, DL, VT, MLoad.getValue(0)); + return DAG.getMergeValues({ Res, MLoad.getValue(1) }, DL); +} + +// Use SVE to implement fixed-width masked stores. +static SDValue LowerMSTORE(SDValue Op, SelectionDAG &DAG) { + auto MSN = cast(Op.getNode()); + SDValue Mask = MSN->getMask(); + SDValue Data = MSN->getValue(); + + EVT VT = Data.getValueType(); + if (!DAG.getTargetLoweringInfo().isTypeLegal(VT) || + VT.isScalableVector()) + return SDValue(); + + SDLoc DL(Op); + EVT DataVT, MaskVT; + + switch (VT.getVectorElementType().getSimpleVT().SimpleTy) { + default: return SDValue(); + case MVT::i8: DataVT = MVT::nxv16i8; MaskVT = MVT::nxv16i8; break; + case MVT::i16: DataVT = MVT::nxv8i16; MaskVT = MVT::nxv8i16; break; + case MVT::i32: DataVT = MVT::nxv4i32; MaskVT = MVT::nxv4i32; break; + case MVT::f32: DataVT = MVT::nxv4f32; MaskVT = MVT::nxv4i32; break; + case MVT::i64: DataVT = MVT::nxv2i64; MaskVT = MVT::nxv2i64; break; + case MVT::f64: DataVT = MVT::nxv2f64; MaskVT = MVT::nxv2i64; break; + } + + // TODO: Rather than mask InsertSubReg with a fixed length predicate we could + // just insert the original into a zero'd register. 
+ int PgPattern; + switch (VT.getVectorNumElements()) { + default: return SDValue(); + case 16: PgPattern = AArch64SVEPredPattern::vl16; break; + case 8: PgPattern = AArch64SVEPredPattern::vl8; break; + case 4: PgPattern = AArch64SVEPredPattern::vl4; break; + case 2: PgPattern = AArch64SVEPredPattern::vl2; break; + } + + // Widen the NEON operands to SVE. + int Idx = VT.is64BitVector() ? AArch64::dsub : AArch64::zsub; + Mask = DAG.getTargetInsertSubreg(Idx, DL, MaskVT, DAG.getUNDEF(MaskVT), Mask); + Data = DAG.getTargetInsertSubreg(Idx, DL, DataVT, DAG.getUNDEF(DataVT), Data); + + // Create a predicate restricted to the size of the NEON input. + EVT PredVT = MaskVT.changeVectorElementType(MVT::i1); + SDValue Pg = getPTrue(DAG, DL, PredVT, PgPattern); + Mask = DAG.getNode(ISD::TRUNCATE, DL, PredVT, Mask); + Mask = DAG.getNode(ISD::AND, DL, PredVT, Mask, Pg); + + return DAG.getMaskedStore(MSN->getChain(), DL, Data, MSN->getBasePtr(), Mask, + MSN->getMemoryVT(), MSN->getMemOperand(), + MSN->isTruncatingStore()); +} + +// Use SVE to implement fixed-width masked gathers. +static SDValue LowerMGATHER(SDValue Op, SelectionDAG &DAG) { + const TargetLowering &TLI = DAG.getTargetLoweringInfo(); + auto MGN = cast(Op.getNode()); + SDValue Mask = MGN->getMask(); + SDValue Src0 = MGN->getSrc0(); + SDValue Idxs = MGN->getIndex(); + + EVT VT = Op.getValueType(); + if (!TLI.isTypeLegal(VT) || VT.isScalableVector()) + return SDValue(); + + SDLoc DL(Op); + unsigned NumEls = VT.getVectorNumElements(); + EVT SveVT = (NumEls == 2) ? MVT::nxv2i64 : MVT::nxv4i32; + EVT NeonVT = (NumEls == 2) ? MVT::v2i64 : MVT::v4i32; + SDValue Undef = DAG.getUNDEF(SveVT); + + // We don't care about the actual data, just the number of bits. + if (VT.isFloatingPoint()) + Src0 = DAG.getNode(ISD::BITCAST, DL, VT.changeTypeToInteger(), Src0); + + // Promote all vectors to 128bit... + Src0 = DAG.getNode(ISD::ANY_EXTEND, DL, NeonVT, Src0); + Mask = DAG.getNode(ISD::ANY_EXTEND, DL, NeonVT, Mask); + Idxs = DAG.getNode(ISD::SIGN_EXTEND, DL, NeonVT, Idxs); + + // ...then make them scalable. + Src0 = DAG.getTargetInsertSubreg(AArch64::zsub, DL, SveVT, Undef, Src0); + Mask = DAG.getTargetInsertSubreg(AArch64::zsub, DL, SveVT, Undef, Mask); + Idxs = DAG.getTargetInsertSubreg(AArch64::zsub, DL, SveVT, Undef, Idxs); + + // Convert the mask to a real predicate. + EVT PredVT = SveVT.changeVectorElementType(MVT::i1); + SDValue Pg = getPTrue(DAG, DL, PredVT, (NumEls == 2) + ? AArch64SVEPredPattern::vl2 + : AArch64SVEPredPattern::vl4); + Mask = DAG.getNode(ISD::TRUNCATE, DL, PredVT, Mask); + Mask = DAG.getNode(ISD::AND, DL, PredVT, Mask, Pg); + + // Extension is introduced when we promote the data operand. + EVT MemVT = MGN->getMemoryVT().changeTypeToInteger(); + ISD::LoadExtType Ext = MGN->getExtensionType(); + if ((Ext == ISD::NON_EXTLOAD) && (MemVT != NeonVT)) + Ext = ISD::SEXTLOAD; + + auto Scale = DAG.getTargetConstant(1, DL, MVT::i64); + SDValue Ops[] = { MGN->getChain(), Src0, Mask, MGN->getBasePtr(), Idxs, Scale }; + SDValue MGather = DAG.getMaskedGather(DAG.getVTList(SveVT, MVT::Other), + MemVT, DL, Ops, MGN->getMemOperand(), + Ext, MGN->getIndexType()); + + // Convert load data back into a fixed-width type... + SDValue Data = MGather.getValue(0); + Data = DAG.getTargetExtractSubreg(AArch64::zsub, DL, NeonVT, Data); + Data = DAG.getNode(ISD::TRUNCATE, DL, VT.changeTypeToInteger(), Data); + + // ...of the correct denomination. 
+ if (VT.isFloatingPoint()) + Data = DAG.getNode(ISD::BITCAST, DL, VT, Data); + + return DAG.getMergeValues({ Data, MGather.getValue(1) }, DL); +} + +// Use SVE to implement fixed-width masked scatters. +static SDValue LowerMSCATTER(SDValue Op, SelectionDAG &DAG) { + const TargetLowering &TLI = DAG.getTargetLoweringInfo(); + auto MSN = cast(Op.getNode()); + SDValue Mask = MSN->getMask(); + SDValue Data = MSN->getValue(); + SDValue Idxs = MSN->getIndex(); + + EVT VT = Data.getValueType(); + if (!TLI.isTypeLegal(VT) || VT.isScalableVector()) + return SDValue(); + + SDLoc DL(Op); + unsigned NumEls = VT.getVectorNumElements(); + EVT SveVT = (NumEls == 2) ? MVT::nxv2i64 : MVT::nxv4i32; + EVT NeonVT = (NumEls == 2) ? MVT::v2i64 : MVT::v4i32; + SDValue Undef = DAG.getUNDEF(SveVT); + + // We don't care about the actual data, just the number of bits. + if (VT.isFloatingPoint()) + Data = DAG.getNode(ISD::BITCAST, DL, VT.changeTypeToInteger(), Data); + + // Promote all vectors to 128bit... + Data = DAG.getNode(ISD::ANY_EXTEND, DL, NeonVT, Data); + Mask = DAG.getNode(ISD::ANY_EXTEND, DL, NeonVT, Mask); + Idxs = DAG.getNode(ISD::SIGN_EXTEND, DL, NeonVT, Idxs); + + // ...then make them scalable. + Data = DAG.getTargetInsertSubreg(AArch64::zsub, DL, SveVT, Undef, Data); + Mask = DAG.getTargetInsertSubreg(AArch64::zsub, DL, SveVT, Undef, Mask); + Idxs = DAG.getTargetInsertSubreg(AArch64::zsub, DL, SveVT, Undef, Idxs); + + // Convert the mask to a real predicate. + EVT PredVT = SveVT.changeVectorElementType(MVT::i1); + SDValue Pg = getPTrue(DAG, DL, PredVT, (NumEls == 2) + ? AArch64SVEPredPattern::vl2 + : AArch64SVEPredPattern::vl4); + Mask = DAG.getNode(ISD::TRUNCATE, DL, PredVT, Mask); + Mask = DAG.getNode(ISD::AND, DL, PredVT, Mask, Pg); + + // Truncation is introduced when we promote the data operand. + EVT MemVT = MSN->getMemoryVT().changeTypeToInteger(); + bool isTrunc = (MemVT != NeonVT); + + auto Scale = DAG.getTargetConstant(1, DL, MVT::i64); + SDValue Ops[] = { MSN->getChain(), Data, Mask, MSN->getBasePtr(), Idxs, Scale }; + return DAG.getMaskedScatter(DAG.getVTList(MVT::Other), MemVT, DL, Ops, + MSN->getMemOperand(), isTrunc, + MSN->getIndexType()); +} + +static SDValue getElementCountVector(SelectionDAG &DAG, SDLoc dl, EVT ResVT, + EVT VecVT) { + EVT EltType = ResVT.getVectorElementType(); + if (EltType.getSizeInBits() < 32) + EltType = MVT::i32; + SDValue EltCount = DAG.getVScale(dl, MVT::i64, VecVT.getVectorNumElements()); + EltCount = DAG.getZExtOrTrunc(EltCount, dl, EltType); + return DAG.getNode(AArch64ISD::DUP, dl, ResVT, EltCount); +} + +// Use Opcode to join consecutive vector operands in Ops together, +// then use CONCAT_VECTORS to join the results into a single vector. +static SDValue getSVEChainShuffle(unsigned Opcode, ArrayRef Ops, + unsigned Factor, SelectionDAG &DAG, + SDLoc dl, EVT VT, bool reverse = false) { + SmallVector Nodes(Factor); + if (!reverse) { + for (unsigned i = 0; i < Factor; ++i) { + SDValue Op0 = Ops[i * 2]; + SDValue Op1 = Ops[i * 2 + 1]; + Nodes[i] = DAG.getNode(Opcode, dl, Op0.getValueType(), Op0, Op1); + } + } else { + for (unsigned i = 0; i < Factor; ++i) { + SDValue Op0 = Ops[i]; + Nodes[Factor-1-i] = DAG.getNode(Opcode, dl, Op0.getValueType(), Op0); + } + } + return DAG.getNode(ISD::CONCAT_VECTORS, dl, VT, Nodes); +} + +// Lower a VECTOR_SHUFFLE_VAR node that operates on vectors that are Factor +// times wider than a legal SVE vector by using smaller shuffles of type +// NewVT. 
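+//
+// The selector is a vector of lane indices into concat(Op0, Op1). After the
+// special cases below (splats, UZP1/UZP2 and REV patterns), each NewVT-sized
+// piece of the result is built with TBL and the pieces are rejoined with
+// CONCAT_VECTORS.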
+SDValue AArch64TargetLowering::LowerVECTOR_SHUFFLE_VAR( + SDValue Op, SelectionDAG &DAG, unsigned Factor, EVT NewVT) const { + SDLoc dl(Op); + EVT VT = Op.getValueType(); + SDValue Op0 = Op.getOperand(0); + SDValue Op1 = Op.getOperand(1); + SDValue Sel = Op.getOperand(2); + auto NewEltCnt = NewVT.getVectorElementCount(); + + // Divide concat(Op0, Op1) into NewVT pieces. + SmallVector Ops(Factor * 2); + if (Factor == 1) { + Ops[0] = Op0; + Ops[1] = Op1; + } else { + for (unsigned i = 0; i < Factor; ++i) { + SDValue Index = DAG.getConstant(i * NewEltCnt.Min, dl, MVT::i32); + Ops[i] = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, NewVT, Op0, Index); + Ops[i + Factor] = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, NewVT, + Op1, Index); + } + } + + if (Sel.getOpcode() == ISD::SPLAT_VECTOR) { + // Check for a splat of a vector insert. + auto CSplatElt = dyn_cast(Sel.getOperand(0)); + if (CSplatElt && CSplatElt->getZExtValue() == 0) { + if (Op0.getOpcode() == ISD::SCALAR_TO_VECTOR) + return DAG.getNode(ISD::SPLAT_VECTOR, dl, VT, Op0.getOperand(0)); + else if (Op0.getOpcode() == ISD::ZERO_EXTEND && + Op0.getOperand(0).getOpcode() == ISD::SCALAR_TO_VECTOR) { + SDValue Scalar = Op0.getOperand(0).getOperand(0); + EVT SplatVT = Op0.getOperand(0).getValueType(); + SDValue Splat = DAG.getNode(ISD::SPLAT_VECTOR, dl, SplatVT, Scalar); + return DAG.getNode(ISD::ZERO_EXTEND, dl, VT, Splat); + } + } + } + + // Find a potential ISD::SERIES_VECTOR, looking through truncations that are + // redundent given the restricted ranges of start/step we care about. + SDValue SeriesVector = Sel; + if (SeriesVector.getOpcode() == ISD::AND) { + SDValue Mask = SeriesVector.getOperand(1); + + if (Mask.getOpcode() == ISD::SPLAT_VECTOR) { + auto SplatVal = dyn_cast(Mask.getOperand(0)); + if (SplatVal && SplatVal->getZExtValue() == 0xfffffffful) + SeriesVector = SeriesVector->getOperand(0); + } + } + + if (SeriesVector.getOpcode() == ISD::SERIES_VECTOR) { + // Look for patterns that have SVE instructions associated with them. + SDValue Start = SeriesVector.getOperand(0); + SDValue Step = SeriesVector.getOperand(1); + if (isa(Start) && isa(Step)) { + uint64_t CStart = cast(Start)->getZExtValue(); + uint64_t CStep = cast(Step)->getZExtValue(); + // Remove shuffles that don't shuffle anything. + if (CStart == 0 && CStep == 1) + return Op0; + if (CStart == 0 && CStep == 2) + return getSVEChainShuffle(AArch64ISD::UZP1, Ops, Factor, DAG, dl, VT); + if (CStart == 1 && CStep == 2) + return getSVEChainShuffle(AArch64ISD::UZP2, Ops, Factor, DAG, dl, VT); + // Look for a splat. + if (CStart == 0 && CStep == 0 && + Op0.getOpcode() == ISD::SCALAR_TO_VECTOR) + return DAG.getNode(ISD::SPLAT_VECTOR, dl, VT, Op0.getOperand(0)); + // Look for a zero-extended splat. 
+ if (CStart == 0 && CStep == 0 && + Op0.getOpcode() == ISD::ZERO_EXTEND && + Op0.getOperand(0).getOpcode() == ISD::SCALAR_TO_VECTOR) { + SDValue Scalar = Op0.getOperand(0).getOperand(0); + EVT SplatVT = Op0.getOperand(0).getValueType(); + SDValue Splat = DAG.getNode(ISD::SPLAT_VECTOR, dl, SplatVT, Scalar); + return DAG.getNode(ISD::ZERO_EXTEND, dl, VT, Splat); + } + } else { + // seriesvector((vscale * NumElts) - 1, -1) + if ((Start.getOpcode() == ISD::ADD) && + (Start.getOperand(0).getOpcode() == ISD::VSCALE) && + (isa(Start.getOperand(1))) && + (isa(Step))) { + SDValue VSImm = Start.getOperand(0).getOperand(0); + int64_t NumElts = SeriesVector.getValueType().getVectorNumElements(); + + if ((cast(VSImm)->getSExtValue() == NumElts) && + (cast(Start.getOperand(1))->getSExtValue() == -1) && + (cast(Step)->getSExtValue() == -1)) { + return getSVEChainShuffle(AArch64ISD::REV, Ops, Factor, DAG, dl, VT, + true); + } + } + } + } + + if (NewVT.getVectorElementType() == MVT::i1) { + EVT Elt; + switch (Op0.getValueType().getSimpleVT().SimpleTy) { + case MVT::nxv2i1: Elt = MVT::i64; break; + case MVT::nxv4i1: Elt = MVT::i32; break; + case MVT::nxv8i1: Elt = MVT::i16; break; + case MVT::nxv16i1: Elt = MVT::i8; break; + default: + llvm_unreachable("unexpected predicate type"); + } + + EVT VT1 = EVT::getVectorVT(*DAG.getContext(), Elt, + Op0.getValueType().getVectorElementCount()); + EVT VT2 = EVT::getVectorVT(*DAG.getContext(), Elt, NewEltCnt ); + + Op0 = DAG.getNode(ISD::ANY_EXTEND, dl, VT1, Op0); + Op1 = DAG.getNode(ISD::ANY_EXTEND, dl, VT1, Op1); + SDValue VS = DAG.getNode(ISD::VECTOR_SHUFFLE_VAR, dl, VT2, Op0, Op1, Sel); + return DAG.getNode(ISD::TRUNCATE, dl, NewVT, VS); + } + + EVT SelVT = Sel.getValueType(); + unsigned SelEltBits = SelVT.getVectorElementType().getSizeInBits(); + unsigned EltBits = VT.getVectorElementType().getSizeInBits(); + + bool FloatCast = NewVT.isFloatingPoint(); + EVT SavedType = Op.getValueType(); + EVT TruncType = SavedType; + if (FloatCast) { + assert(NewVT == VT && + "Float bitcasts in vecshufvar currently rely on types being equal"); + NewVT = VT = VT.changeVectorElementTypeToInteger(); + Op0 = DAG.getNode(ISD::BITCAST, dl, VT, Op0); + Op1 = DAG.getNode(ISD::BITCAST, dl, VT, Op1); + TruncType = VT; + } + + if (EltBits == 8 && SelEltBits > 8 && Op1.getOpcode() != ISD::UNDEF) { + // We're using a 16-bit mask to select 8-bit elements. Since the + // architecture limits the size of SVE vectors to 2048 bits, + // an i8 mask element is big enough to select any element of the + // first vector but might be too small to select an element of + // the second vector. We therefore need to do the selection + // on i16s and then truncate the result back to i8s. + EltBits *= 2; + EVT EltVT = EVT::getIntegerVT(*DAG.getContext(), EltBits); + VT = EVT::getVectorVT(*DAG.getContext(), EltVT, + VT.getVectorElementCount()); + // Widen each element of Ops by a factor of two and then split each + // widened value to get a pair of legal SVE vectors. + EVT DoubleVT = EVT::getVectorVT(*DAG.getContext(), EltVT, NewEltCnt); + NewEltCnt /= 2; + NewVT = EVT::getVectorVT(*DAG.getContext(), EltVT, NewEltCnt); + Factor *= 2; + Ops.resize(Factor * 2); + for (unsigned i = Factor; i-- > 0;) { + SDValue Ext = DAG.getNode(ISD::ANY_EXTEND, dl, DoubleVT, Ops[i]); + std::tie(Ops[i * 2], Ops[i * 2 + 1]) = DAG.SplitVector(Ext, dl, + NewVT, NewVT); + } + } + // Convert the mask to the element size. This is always safe if the + // result is i16 or wider, since i16 can hold any value in the range + // [0, 2*VL). 
The code above dealt with the cases where truncating + // to i8 would be unsafe. + if (EltBits != SelEltBits) { + Sel = DAG.getZExtOrTrunc(Sel, dl, VT); + } + + // Get the number of elements in each legally-typed piece. + SDValue Width = getElementCountVector(DAG, dl, NewVT, NewVT); + SmallVector Nodes(Factor); + SmallVector SubNodes(Factor * 2); + for (unsigned i = 0; i < Factor; ++i) { + // SubSel is the selector for piece i of the result. The part contributed + // by piece j of Ops is given by: + // + // TBL (Ops[j], SubSel - j * Width) + // + // in which the elements not taken from Ops[j] are zero. We can then + // OR up the results to get the final value. + SDValue Index = DAG.getConstant(i * NewEltCnt.Min, dl, MVT::i32); + SDValue SubSel = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, NewVT, Sel, Index); + for (unsigned j = 0; j < Factor * 2; ++j) { + if (j != 0) + // Calculate SubSel - j * Width using a series of subtractions. + SubSel = DAG.getNode(ISD::SUB, dl, NewVT, SubSel, Width); + SubNodes[j] = DAG.getNode(AArch64ISD::TBL, dl, NewVT, Ops[j], SubSel); + } + // Create a tree of Factor*2-1 ORs. + for (unsigned j = Factor; j > 0; j /= 2) + for (unsigned k = 0; k < j; ++k) + SubNodes[k] = DAG.getNode(ISD::OR, dl, NewVT, SubNodes[k], + SubNodes[k + j]); + Nodes[i] = SubNodes[0]; + } + SDValue Concat = DAG.getNode(ISD::CONCAT_VECTORS, dl, VT, Nodes); + // If we had to widen the elements above, convert the result back + // to the original type. + SDValue Shuffled = DAG.getNode(ISD::TRUNCATE, dl, TruncType, Concat); + + if (FloatCast) + Shuffled = DAG.getNode(ISD::BITCAST, dl, SavedType, Shuffled); + + return Shuffled; +} + +// Lower SVE vector reductions for INT types +static SDValue LowerVECREDUCE_INT(SDValue Op, unsigned NewOpc, EVT RedResVT, + EVT LegalRedResVT, SelectionDAG &DAG) { + SDLoc DL(Op); + const SDValue &VecToReduce = Op.getOperand(0); + + // First, lower the reduction + SDValue Ptrue = getPTrue( + DAG, DL, VecToReduce.getValueType().changeVectorElementType(MVT::i1), + AArch64SVEPredPattern::all); + SDValue Zero = DAG.getConstant(0, DL, MVT::i32); + SDValue RdxVecVal = DAG.getNode(NewOpc, DL, RedResVT, {Ptrue, VecToReduce}); + + // Second, legalize the return value. In some cases bitcast or truncation + // (or both) will be folded away. + RdxVecVal = DAG.getNode(ISD::BITCAST, DL, LegalRedResVT, RdxVecVal); + SDValue ScalarRes = + DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, + LegalRedResVT.getVectorElementType(), RdxVecVal, Zero); + return DAG.getNode(ISD::TRUNCATE, DL, Op.getValueType(), ScalarRes); +} + +// Lower SVE vector reductions for FP types +static SDValue LowerVECREDUCE_FP(SDValue Op, unsigned NewOpc, + SelectionDAG &DAG) { + + SDLoc DL(Op); + SDValue VecToReduce = Op.getOperand(0); + EVT InVecVT = VecToReduce.getValueType(); + + // Lower the reduction. As per TableGen definitions, the return VT always + // maps 1-1 to an FP register. + SDValue Ptrue = getPTrue(DAG, DL, InVecVT.changeVectorElementType(MVT::i1), + AArch64SVEPredPattern::all); + return DAG.getNode(NewOpc, DL, InVecVT.getVectorElementType().getSimpleVT(), + {Ptrue, VecToReduce}); +} + +SDValue AArch64TargetLowering::LowerVECREDUCE_SVE(SDValue Op, + SelectionDAG &DAG) const { + + // As per TableGen definitions, some integer SVE reductions return vectors of + // illegal types (e.g. v8i16 or v16i8). Such vectors have to be legalized + // before the scalar result can be extracted. Get the original (RedResVT) and + // legal (LegalRedResVT) reduction result VT and let LowerVECREDUCE_INT do + // the work. 
For floating point reductions the return VT always maps 1-1 to + // an FP register, so RedResVT and LegalRedResVT are irrelevant. + const SDValue &VecToReduce = Op.getOperand(0); + EVT RedResVT = EVT::getVectorVT( + *DAG.getContext(), VecToReduce.getValueType().getVectorElementType(), + VecToReduce.getValueType().getVectorNumElements()); + + EVT ResEltVT = Op.getValueType(); + unsigned NumElts = 128 / ResEltVT.getScalarSizeInBits(); + EVT LegalRedResVT = EVT::getVectorVT(*DAG.getContext(), ResEltVT, NumElts); + + switch (Op.getOpcode()) { + case ISD::VECREDUCE_ADD: + return LowerVECREDUCE_INT(Op, AArch64ISD::UADDV_PRED, MVT::v2i64, + MVT::v2i64, DAG); + case ISD::VECREDUCE_AND: + return LowerVECREDUCE_INT(Op, AArch64ISD::ANDV_PRED, RedResVT, + LegalRedResVT, DAG); + case ISD::VECREDUCE_OR: + return LowerVECREDUCE_INT(Op, AArch64ISD::ORV_PRED, RedResVT, + LegalRedResVT, DAG); + case ISD::VECREDUCE_XOR: + return LowerVECREDUCE_INT(Op, AArch64ISD::EORV_PRED, RedResVT, + LegalRedResVT, DAG); + case ISD::VECREDUCE_SMAX: + return LowerVECREDUCE_INT(Op, AArch64ISD::SMAXV_PRED, RedResVT, + LegalRedResVT, DAG); + case ISD::VECREDUCE_SMIN: + return LowerVECREDUCE_INT(Op, AArch64ISD::SMINV_PRED, RedResVT, + LegalRedResVT, DAG); + case ISD::VECREDUCE_UMAX: + return LowerVECREDUCE_INT(Op, AArch64ISD::UMAXV_PRED, RedResVT, + LegalRedResVT, DAG); + case ISD::VECREDUCE_UMIN: + return LowerVECREDUCE_INT(Op, AArch64ISD::UMINV_PRED, RedResVT, + LegalRedResVT, DAG); + case ISD::VECREDUCE_FADD: + return LowerVECREDUCE_FP(Op, AArch64ISD::FADDV_PRED, DAG); + case ISD::VECREDUCE_FMAX: + if (Op->getFlags().hasNoNaNs()) + return LowerVECREDUCE_FP(Op, AArch64ISD::FMAXNMV_PRED, DAG); + else + return LowerVECREDUCE_FP(Op, AArch64ISD::FMAXV_PRED, DAG); + case ISD::VECREDUCE_FMIN: + if (Op->getFlags().hasNoNaNs()) + return LowerVECREDUCE_FP(Op, AArch64ISD::FMINNMV_PRED, DAG); + else + return LowerVECREDUCE_FP(Op, AArch64ISD::FMINV_PRED, DAG); + default: + llvm_unreachable("Unhandled reduction"); + } +} + +SDValue AArch64TargetLowering::LowerSERIES_VECTOR(SDValue Op, + SelectionDAG &DAG) const { + SDLoc dl(Op); + EVT VT = Op.getValueType(); + EVT ElemVT = VT.getScalarType(); + + SDValue Start = Op.getOperand(0); + SDValue Step = Op.getOperand(1); + + if (VT.getVectorElementType().isFloatingPoint()) + return SDValue(); + + if (ElemVT == MVT::i1) { + auto *CStart = dyn_cast(Op.getOperand(0)); + auto *CStep = dyn_cast(Op.getOperand(1)); + + if (CStart && CStep && CStep->isNullValue()) { + if (CStart->isNullValue()) + return SDValue(DAG.getMachineNode(AArch64::PFALSE, dl, VT), 0); + else + return DAG.getNode(AArch64ISD::PTRUE, dl, VT, + DAG.getConstant(31, dl, MVT::i32)); + } else { + auto EltCnt = VT.getVectorElementCount(); + int ExtBits = AArch64::SVEBitsPerBlock / EltCnt.Min; + EVT ExtVT = EVT::getIntegerVT(*DAG.getContext(), ExtBits); + + EVT SplatVT = EVT::getVectorVT(*DAG.getContext(), ExtVT, EltCnt); + SDValue Splat = DAG.getNode(ISD::SERIES_VECTOR, dl, SplatVT, Start, Step); + return DAG.getNode(ISD::TRUNCATE, dl, VT, Splat); + } + } else if (ElemVT == MVT::i64) { + SDValue Op1 = DAG.getSExtOrTrunc(Op.getOperand(0), dl, ElemVT); + SDValue Op2 = DAG.getSExtOrTrunc(Op.getOperand(1), dl, ElemVT); + return DAG.getNode(ISD::SERIES_VECTOR, dl, VT, Op1, Op2); + } + + // use default expansion for everything else + return SDValue(); +} + +SDValue AArch64TargetLowering::LowerSPLAT_VECTOR(SDValue Op, + SelectionDAG &DAG) const { + SDLoc dl(Op); + EVT VT = Op.getValueType(); + EVT ElemVT = VT.getScalarType(); + + SDValue 
SplatVal = Op.getOperand(0); + + switch (ElemVT.getSimpleVT().SimpleTy) { + case MVT::i1: + if (auto CSplatVal = dyn_cast(SplatVal)) { + if (CSplatVal->isNullValue()) + return SDValue(DAG.getMachineNode(AArch64::PFALSE, dl, VT), 0); + else + return DAG.getNode(AArch64ISD::PTRUE, dl, VT, + DAG.getConstant(31, dl, MVT::i32)); + } else { + auto EltCnt = VT.getVectorElementCount(); + int ExtBits = AArch64::SVEBitsPerBlock / EltCnt.Min; + EVT ExtVT = EVT::getIntegerVT(*DAG.getContext(), ExtBits); + + switch (ExtVT.getSimpleVT().SimpleTy) { + case MVT::i64: + SplatVal = DAG.getAnyExtOrTrunc(SplatVal, dl, MVT::i64); + break; + case MVT::i32: + case MVT::i16: + case MVT::i8: + SplatVal = DAG.getAnyExtOrTrunc(SplatVal, dl, MVT::i32); + break; + default: + llvm_unreachable("Unusable extended element type for i1 splat!"); + break; + } + + EVT SplatVT = EVT::getVectorVT(*DAG.getContext(), ExtVT, EltCnt); + SDValue Splat = DAG.getNode(AArch64ISD::DUP, dl, SplatVT, SplatVal); + return DAG.getNode(ISD::TRUNCATE, dl, VT, Splat); + } + break; + case MVT::i8: + case MVT::i16: + SplatVal = DAG.getAnyExtOrTrunc(SplatVal, dl, MVT::i32); + break; + case MVT::i64: + SplatVal = DAG.getAnyExtOrTrunc(SplatVal, dl, MVT::i64); + break; + case MVT::i32: + case MVT::f16: + case MVT::f32: + case MVT::f64: + // Fine as is + break; + default: + llvm_unreachable("Unsupported SPLAT_VECTOR input operand type"); + break; + } + + return DAG.getNode(AArch64ISD::DUP, dl, VT, SplatVal); +} + +SDValue AArch64TargetLowering::LowerCONCAT_VECTORS(SDValue Op, + SelectionDAG &DAG) const { + EVT VT = Op.getValueType(); + assert(VT.isScalableVector() && "can only lower WA vectors"); + assert(isTypeLegal(VT) && "expected type legal result"); + + EVT OpVT = Op.getOperand(0).getValueType(); + if (OpVT != Op.getOperand(1).getValueType()) + return SDValue(); + + SDLoc dl(Op); + SDValue OpLHS = Op.getOperand(0); + SDValue OpRHS = Op.getOperand(1); + + assert(VT.getVectorNumElements() == (OpVT.getVectorNumElements()*2) && + "total source operand element count does not match result"); + + if (OpVT.getVectorElementType().isInteger() && + OpVT.getVectorElementType() != MVT::i1) { + EVT WideVT = OpVT.widenIntegerVectorElementType(*DAG.getContext()); + OpLHS = DAG.getNode(ISD::ANY_EXTEND, dl, WideVT, OpLHS); + OpRHS = DAG.getNode(ISD::ANY_EXTEND, dl, WideVT, OpRHS); + } else if (OpVT.getVectorElementType().isFloatingPoint()) { + OpLHS = DAG.getNode(AArch64ISD::REINTERPRET_CAST, dl, VT, OpLHS); + OpRHS = DAG.getNode(AArch64ISD::REINTERPRET_CAST, dl, VT, OpRHS); + } + + // Are we extracting the hi half of a double legal width vector? 
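+  // That case shows up as both operands being SRLs by the element size, for
+  // which UZP2 on the unshifted operands gives the result directly.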
+  if ((OpLHS->getOpcode() == ISD::SRL) && (OpRHS->getOpcode() == ISD::SRL) &&
+      (OpLHS->getOperand(1) == OpRHS->getOperand(1))) {
+    APInt ShiftAmount;
+    if (DAG.isConstantIntSplat(OpLHS->getOperand(1), &ShiftAmount)) {
+      if (ShiftAmount == VT.getVectorElementType().getSizeInBits())
+        return DAG.getNode(AArch64ISD::UZP2, dl, VT, OpLHS->getOperand(0),
+                           OpRHS->getOperand(0));
+    }
+  }
+
+  return DAG.getNode(AArch64ISD::UZP1, dl, VT, OpLHS, OpRHS);
+}
+
+SDValue AArch64TargetLowering::LowerINSERT_SUBVECTOR(SDValue Op,
+                                                     SelectionDAG &DAG) const {
+  SDLoc DL(Op);
+  EVT VT = Op.getValueType();
+  EVT Op0VT = Op.getOperand(0).getValueType();
+  EVT Op1VT = Op.getOperand(1).getValueType();
+
+  if (!Op0VT.isScalableVector() || !Op0VT.isInteger())
+    return SDValue();
+
+  unsigned NumElts = Op1VT.getVectorNumElements();
+
+  SDValue Vec0 = Op.getOperand(0);
+  SDValue Vec1 = Op.getOperand(1);
+  SDValue Idx = Op.getOperand(2);
+
+  // Idx needs to be a constant, either 0 or 1 for WA vectors for now.
+  ConstantSDNode *CIdx = dyn_cast<ConstantSDNode>(Idx);
+  if (!CIdx)
+    return SDValue();
+
+  unsigned IdxVal = CIdx->getZExtValue();
+
+  // Ensure the subvector is half the size of the main vector.
+  if (Op0VT.getVectorNumElements() != (Op1VT.getVectorNumElements() * 2))
+    return SDValue();
+
+  // Extend elements of smaller vector...
+  EVT WideVT = Op1VT.widenIntegerVectorElementType(*(DAG.getContext()));
+  SDValue ExtVec = DAG.getNode(ISD::ANY_EXTEND, DL, WideVT, Vec1);
+
+  // Can only handle upper/lower half right now.
+  if (IdxVal == 0) {
+    SDValue HiVec0 = DAG.getNode(AArch64ISD::UUNPKHI, DL, WideVT, Vec0);
+    return DAG.getNode(AArch64ISD::UZP1, DL, VT, ExtVec, HiVec0);
+  } else if (IdxVal == NumElts) {
+    SDValue LoVec0 = DAG.getNode(AArch64ISD::UUNPKLO, DL, WideVT, Vec0);
+    return DAG.getNode(AArch64ISD::UZP1, DL, VT, LoVec0, ExtVec);
+  }
+
+  return SDValue();
+}
+
+SDValue AArch64TargetLowering::LowerTRUNCATE(SDValue Op,
+                                             SelectionDAG &DAG) const {
+  SDLoc dl(Op);
+
+  EVT VT = Op.getValueType();
+  if (!isTypeLegal(VT) || !VT.isScalableVector())
+    return SDValue();
+
+  if (Op.getOperand(0).getOpcode() == ISD::CONCAT_VECTORS) {
+    SDValue CC = Op.getOperand(0);
+    SDValue Lo = CC.getOperand(0);
+    SDValue Hi = CC.getOperand(1);
+
+    EVT CCVT = CC.getValueType();
+    EVT LoVT = Lo.getValueType();
+
+    if (CCVT.getVectorNumElements() == (LoVT.getVectorNumElements() * 2))
+      return DAG.getNode(AArch64ISD::UZP1, dl, VT, Lo, Hi);
+  }
+
+  return SDValue();
+}
+
+SDValue AArch64TargetLowering::LowerDUPQLane(SDValue Op,
+                                             SelectionDAG &DAG) const {
+  SDLoc DL(Op);
+
+  EVT VT = Op.getValueType();
+  if (!isTypeLegal(VT) || !VT.isScalableVector())
+    return SDValue();
+
+  // Current lowering only supports the SVE-ACLE types.
+  if (VT.getSizeInBits() != AArch64::SVEBitsPerBlock)
+    return SDValue();
+
+  // The DUPQ operation is independent of element type so normalise to i64s.
+  auto V = DAG.getNode(ISD::BITCAST, DL, MVT::nxv2i64, Op.getOperand(1));
+  auto Idx128 = Op.getOperand(2);
+
+  // DUPQ can be used when idx is in range.
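+  // The immediate form used below only encodes indices 0-3; out-of-range
+  // indices fall back to the TBL-based expansion further down.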
+ auto CIdx = dyn_cast(Idx128); + if (CIdx && (CIdx->getZExtValue() <= 3)) { + auto CI = DAG.getTargetConstant(CIdx->getZExtValue(), DL, MVT::i64); + auto DUPQ = DAG.getMachineNode(AArch64::DUP_ZZI_Q, DL, MVT::nxv2i64, V, CI); + return DAG.getNode(ISD::BITCAST, DL, VT, SDValue(DUPQ, 0)); + } + + // The ACLE says this must produce the same result as: + // svtbl(data, svadd_x(svptrue_b64(), + // svand_x(svptrue_b64(), svindex_u64(0, 1), 1), + // index * 2)) + auto Zero = DAG.getConstant(0, DL, MVT::i64); + auto One = DAG.getConstant(1, DL, MVT::i64); + auto SplatOne = DAG.getNode(ISD::SPLAT_VECTOR, DL, MVT::nxv2i64, One); + + // create the vector 0,1,0,1,... + auto SV = DAG.getNode(ISD::SERIES_VECTOR, DL, MVT::nxv2i64, Zero, One); + SV = DAG.getNode(ISD::AND, DL, MVT::nxv2i64, SV, SplatOne); + + // create the vector idx64,idx64+1,idx64,idx64+1,... + auto Idx64 = DAG.getNode(ISD::ADD, DL, MVT::i64, Idx128, Idx128); + auto SplatIdx64 = DAG.getNode(ISD::SPLAT_VECTOR, DL, MVT::nxv2i64, Idx64); + auto ShufffleMask = DAG.getNode(ISD::ADD, DL, MVT::nxv2i64, SV, SplatIdx64); + + // create the vector Val[idx64],Val[idx64+1],Val[idx64],Val[idx64+1],... + auto TBL = DAG.getNode(AArch64ISD::TBL, DL, MVT::nxv2i64, V, ShufffleMask); + return DAG.getNode(ISD::BITCAST, DL, VT, TBL); +} + +SDValue AArch64TargetLowering::LowerVectorBITCAST(SDValue Op, + SelectionDAG &DAG) const { + SDLoc DL(Op); + EVT VT = Op.getValueType(); + auto Src = Op.getOperand(0); + + if (!VT.isScalableVector()) + return SDValue(); + + if (isTypeLegal(VT) && !isTypeLegal(Src.getValueType())) { + assert(VT.isFloatingPoint() && "Expected int->fp bitcast!"); + EVT ContainerVT = getNaturalIntSVETypeWithMatchingElementCount(VT); + auto Tmp = DAG.getNode(ISD::ANY_EXTEND, DL, ContainerVT, Src); + return DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, VT, Tmp); + } + + return SDValue(); +} + +SDValue AArch64TargetLowering::LowerVSCALE(SDValue Op, + SelectionDAG &DAG) const { + SDLoc DL(Op); + EVT VT = Op.getValueType(); + + if (VT != MVT::i64) { + int64_t MulImm = cast(Op.getOperand(0))->getSExtValue(); + return DAG.getZExtOrTrunc(DAG.getVScale(DL, MVT::i64, MulImm), DL, VT); + } + + return SDValue(); +} + +// zip([a, b, c, d], [e, f, g, h]) = +// = {[a, e, b, f], [c, g, d, h]} +// ^^^^^^^^^^ ^^^^^^^^^^ +// zip lo zip hi +static std::tuple getZip(Value *A, Value *B, + IRBuilder<> &Builder) { + Intrinsic::ID zips[] = { Intrinsic::aarch64_sve_zip1, + Intrinsic::aarch64_sve_zip2 }; + Module *M = Builder.GetInsertPoint()->getModule(); + auto ZipLoIntr = Intrinsic::getDeclaration(M, zips[0], {A->getType()}); + auto ZipHiIntr = Intrinsic::getDeclaration(M, zips[1], {A->getType()}); + Value *Ops[] = {A, B}; + auto ZipLo = Builder.CreateCall(ZipLoIntr, Ops, "ZipLo"); + auto ZipHi = Builder.CreateCall(ZipHiIntr, Ops, "ZipHi"); + + return std::make_tuple(ZipLo, ZipHi); +} + +// Unzip([a, b, c, d], [e, f, g, h]) +// = {[a, c, e, g], [b, d, f, h]} +// ^^^^^^^^^^ ^^^^^^^^^^ +// even odd +static std::tuple getUnzip(Value *A, Value *B, + IRBuilder<> &Builder) { + auto Zero = Builder.getInt32(0); + auto One = Builder.getInt32(1); + auto Two = Builder.getInt32(2); + + auto EC = cast(A->getType())->getElementCount(); + + auto EvenIdx = Builder.CreateSeriesVector(EC, Zero, Two); + auto OddIdx = Builder.CreateSeriesVector(EC, One, Two); + + auto UnzipEven = Builder.CreateShuffleVector(A, B, EvenIdx); + auto UnzipOdd = Builder.CreateShuffleVector(A, B, OddIdx); + + return std::make_tuple(UnzipEven, UnzipOdd); +} + +bool 
AArch64TargetLowering::lowerGathersToInterleavedLoad( + ArrayRef Gathers, IntrinsicInst *FirstGather, + int OffsetFirstGather, unsigned Factor, TargetTransformInfo *TTI) const { + assert(Factor >= 2 && Factor <= getMaxSupportedInterleaveFactor() && + "Invalid interleave factor"); + assert(!Gathers.empty() && "Empty gather input"); + + auto LowestGather = cast(Gathers[0]); + const DataLayout &DL = LowestGather->getModule()->getDataLayout(); + + VectorType *VecTy = dyn_cast(LowestGather->getType()); + if (!VecTy) + return false; + // Test if factor can be efficiently vectorised with LD(2|3|4) + // instructions. + switch (Factor){ + case 2: + case 3: + case 4: + case 6: + case 8: + break; + default: + return false; + } + // This code assumes scalable types + if (!Subtarget->hasSVE() || !VecTy->isScalable()) + return false; + // We only support legal types for now + if (!TTI->isTypeLegal(VecTy)) + return false; + // We don't handle . SC-1277 will remove this + // limitation + if ((VecTy->getNumElements() == 2) && VecTy->getElementType()->isFloatTy()) + return false; + + // A pointer vector can not be the return type of the ldN intrinsics. Need to + // load integer vectors first and then convert to pointer vectors. + Type *EltTy = VecTy->getElementType(); + if (EltTy->isPointerTy()) + VecTy = VectorType::get(DL.getIntPtrType(EltTy), VecTy->getElementCount()); + + Type *Tys = {VecTy}; + static const Intrinsic::ID LoadInts[3] = {Intrinsic::aarch64_sve_ld2_legacy, + Intrinsic::aarch64_sve_ld3_legacy, + Intrinsic::aarch64_sve_ld4_legacy}; + bool IsNativeStride = (Factor <= 4); + auto FactorIdx = IsNativeStride ? Factor : Factor/2; + + Function *LdNFunc = Intrinsic::getDeclaration(LowestGather->getModule(), + LoadInts[FactorIdx - 2], Tys); + + IRBuilder<> Builder(FirstGather); + auto FirstAddr = Builder.CreateExtractElement(FirstGather->getOperand(0), + Builder.getInt32(0)); + if (OffsetFirstGather) { + FirstAddr = Builder.CreateBitCast(FirstAddr, Builder.getInt8PtrTy()); + Value *Offset = + ConstantInt::getSigned(Builder.getInt32Ty(), OffsetFirstGather); + FirstAddr = Builder.CreateGEP(FirstAddr, Offset); + } + auto PtrType = VecTy->getPointerTo(); + auto Ptr = Builder.CreateBitCast(FirstAddr, PtrType); + + // Create loads + CallInst *LdN, *LdN2; + if (!IsNativeStride) { + Value *Pred = LowestGather->getOperand(2); + + // Create zip for the predicates, e.g. (1,1,1,0) + // -> (1,1,1,1),(1,1,0,0) for the two loads + Value *ZipPredLo, *ZipPredHi; + std::tie(ZipPredLo, ZipPredHi) = getZip(Pred, Pred, Builder); + + // Create first load + Value *Ops[] = {ZipPredLo, Ptr}; + LdN = Builder.CreateCall(LdNFunc, Ops, "ldN"); + + // Create second load + auto Ptr2 = Builder.CreateGEP(Ptr, Builder.getInt64(Factor/2)); + Value *Ops2[] = {ZipPredHi, Ptr2}; + LdN2 = Builder.CreateCall(LdNFunc, Ops2, "ldN"); + } else { + Value *Ops[] = {LowestGather->getOperand(2), Ptr}; + LdN = Builder.CreateCall(LdNFunc, Ops, "ldN"); + } + + // Replace uses of each shufflevector with the corresponding vector loaded + // by ldN. + for (unsigned i = 0; i < Gathers.size(); i++) { + if (!IsNativeStride && i >= Factor/2) + break; + + auto Gather1 = Gathers[i]; + auto Gather2 = IsNativeStride ? nullptr : Gathers[Factor/2+i]; + auto ActiveGather = Gather1 ? 
Gather1 : Gather2; + if (!ActiveGather) + continue; + + auto II = cast(ActiveGather); + assert((II->getOperand(2) == LowestGather->getOperand(2)) && + "Expected gathers to have the same predicate arguments"); + + Value *SubVec = Builder.CreateExtractValue(LdN, i); + if (EltTy->isPointerTy()) + SubVec = Builder.CreateIntToPtr(SubVec, II->getType()); + SubVec = Builder.CreateBitCast(SubVec,ActiveGather->getType()); + + if (IsNativeStride) + ActiveGather->replaceAllUsesWith(SubVec); + else { + auto SubVec2 = Builder.CreateExtractValue(LdN2, i); + if (EltTy->isPointerTy()) + SubVec2 = Builder.CreateIntToPtr(SubVec2, II->getType()); + SubVec2 = Builder.CreateBitCast(SubVec2,ActiveGather->getType()); + + // Do an unzip + Value *Even, *Odd; + std::tie(Even, Odd) = getUnzip(SubVec, SubVec2, Builder); + + // Replace uses + if (Gather1) + Gather1->replaceAllUsesWith(Even); + if (Gather2) + Gather2->replaceAllUsesWith(Odd); + } + } + + return true; +} + +bool AArch64TargetLowering::lowerScattersToInterleavedStore( + ArrayRef ValuesToStore, + Value *FirstScatterAddress, + IntrinsicInst *ReplaceNode, + unsigned Factor, + TargetTransformInfo *TTI) const { + assert(Factor >= 2 && Factor <= getMaxSupportedInterleaveFactor() && + "Invalid interleave factor"); + + const DataLayout &DL = ReplaceNode->getModule()->getDataLayout(); + + VectorType *VecTy = dyn_cast(ValuesToStore[0]->getType()); + if (!VecTy) + return false; + // StN intrinsics don't support pointer vectors as arguments. + if (VecTy->isPtrOrPtrVectorTy()) + VecTy = cast(DL.getIntPtrType(VecTy)); + + // Test if factor can be efficiently vectorised with LD(2|3|4) + // instructions. + switch (Factor){ + case 2: + case 3: + case 4: + case 6: + case 8: + break; + default: + return false; + } + // This code assumes scalable types + if (!Subtarget->hasSVE() || !VecTy->isScalable()) + return false; + // We only support legal types for now + if (!TTI->isTypeLegal(VecTy)) + return false; + // We don't handle unpacked datatypes. SC-1277 will remove this limitation. + if ((VecTy->getNumElements() != 4) && VecTy->getElementType()->isFloatTy()) + return false; + if ((VecTy->getNumElements() != 8) && VecTy->getElementType()->isHalfTy()) + return false; + + bool IsNativeStride = (Factor <= 4); + + SmallVector, 6> Ops = { {}, {} }; + IRBuilder<> Builder(ReplaceNode); + // Push the values to be stored, casting to int if necessary + for (unsigned i = 0; i < Factor; ++i) { + if (!IsNativeStride && i >= Factor/2) + break; + + Value *StoreVal = ValuesToStore[i]; + // Casting is necessary when stored values are a mix of equally sized + // int, float and pointer types. + if (StoreVal->getType()->isPtrOrPtrVectorTy()) { + Type *IntTy = DL.getIntPtrType(StoreVal->getType()); + StoreVal = Builder.CreatePtrToInt(StoreVal, IntTy); + } + if (StoreVal->getType() != VecTy) + StoreVal = Builder.CreateBitCast(StoreVal, VecTy); + + if (!IsNativeStride) { + Value *StoreSecondVal = ValuesToStore[Factor/2+i]; + // Casting is necessary when stored values are a mix of equally sized + // int, float and pointer types. 
+ if (StoreSecondVal->getType()->isPtrOrPtrVectorTy()) { + Type *IntTy = DL.getIntPtrType(StoreSecondVal->getType()); + StoreSecondVal = Builder.CreatePtrToInt(StoreSecondVal, IntTy); + } + if (StoreSecondVal->getType() != VecTy) + StoreSecondVal = Builder.CreateBitCast(StoreSecondVal, VecTy); + + // Create zips for vectors + Value *ZipLoVec, *ZipHiVec; + std::tie(ZipLoVec, ZipHiVec) = getZip(StoreVal, StoreSecondVal, Builder); + + Ops[0].push_back(ZipLoVec); + Ops[1].push_back(ZipHiVec); + } else + Ops[0].push_back(StoreVal); + } + + static const Intrinsic::ID StoreInts[3] = {Intrinsic::aarch64_sve_st2, + Intrinsic::aarch64_sve_st3, + Intrinsic::aarch64_sve_st4}; + + auto FactorIdx = IsNativeStride ? Factor : Factor/2; + Function *StNFunc = Intrinsic::getDeclaration(ReplaceNode->getModule(), + StoreInts[FactorIdx - 2], + {VecTy}); + + // Calculate the base address (this extract will be folded later) + auto FirstAddr = Builder.CreateExtractElement(FirstScatterAddress, + Builder.getInt32(0)); + auto PtrType = VecTy->getPointerTo(); + auto Ptr = Builder.CreateBitCast(FirstAddr, PtrType); + + // Push the predicate arg (must be the same for all scatters) + if (!IsNativeStride) { + Value *Pred = ReplaceNode->getOperand(3); + + // Create zip for the predicates, e.g. (1,1,1,0) + // -> (1,1,1,1),(1,1,0,0) for the two stores + Value *ZipPredLo, *ZipPredHi; + std::tie(ZipPredLo,ZipPredHi) = getZip(Pred, Pred, Builder); + Ops[0].push_back(ZipPredLo); + Ops[1].push_back(ZipPredHi); + + // Create first store + Ops[0].push_back(Ptr); + Builder.CreateCall(StNFunc, Ops[0]); + + // Create second store + auto Ptr2 = Builder.CreateGEP(Ptr, Builder.getInt64(Factor/2)); + Ops[1].push_back(Ptr2); + Builder.CreateCall(StNFunc, Ops[1]); + } else { + Ops[0].push_back(ReplaceNode->getOperand(3)); + Ops[0].push_back(Ptr); + Builder.CreateCall(StNFunc, Ops[0]); + } + + return true; +} + +static SDValue tryCombineVecOrNot(SDNode *N, + TargetLowering::DAGCombinerInfo &DCI) { + SelectionDAG &DAG = DCI.DAG; + SDLoc Dl(N); + EVT VT = N->getValueType(0); + + assert(N->getOpcode() == ISD::OR && "Unexpected root"); + + if (!VT.isVector() || VT.getVectorElementType() != MVT::i1) + return SDValue(); + // Try to combine or (not a), a => ptrue. + // Represented as: or (xor a, splat(true)), a + auto LHS = N->getOperand(0); + auto RHS = N->getOperand(1); + + if (LHS->getOpcode() != ISD::XOR) + std::swap(LHS, RHS); // Try swapping operands. + if (LHS->getOpcode() != ISD::XOR) + return SDValue(); + + auto Splat = LHS->getOperand(1); + if (Splat->getOpcode() != ISD::SPLAT_VECTOR) + return SDValue(); + if (LHS->getOperand(0) != RHS) + return SDValue(); + auto SplatVal = dyn_cast(Splat->getOperand(0)); + if (SplatVal && (SplatVal->getZExtValue() == 1)) // && + return getPTrue(DAG, Dl, VT, AArch64SVEPredPattern::all); + return SDValue(); +} + +/// Checks if a given node is a first faulting load. 
+static bool isLoadFF(const SDNode *N) { + switch (N->getOpcode()) { + case AArch64ISD::LDFF1: + case AArch64ISD::LDFF1S: + return true; + } + + return false; +} + +static bool isConstantSplatVectorMaskForType(SDNode *N, EVT MemVT) { + if (!MemVT.getVectorElementType().isSimple()) + return false; + + uint64_t MaskForTy = 0ull; + switch(MemVT.getVectorElementType().getSimpleVT().SimpleTy) { + case MVT::i8: + MaskForTy = 0xffull; + break; + case MVT::i16: + MaskForTy = 0xffffull; + break; + case MVT::i32: + MaskForTy = 0xffffffffull; + break; + default: + return false; + break; + } + + if (N->getOpcode() == AArch64ISD::DUP || + N->getOpcode() == ISD::SPLAT_VECTOR) + if (auto *Op0 = dyn_cast(N->getOperand(0))) + return Op0->getAPIntValue().getLimitedValue() == MaskForTy; + + return false; +} + +/// Checks whether the node \param S is the end node of a loadff chain that +/// starts with a SetFFR instruction. If so, it sets the value of \param SetFFR +/// with that setffr node. The parameter \param Pred is added to make sure that +/// changes in the lowering of llvm.load.ff that would allow the generation of +/// chains in the form (1) would not be recognised as valid and be converted +/// into teh chain in (2), which is clearly not equivalent to (1). +/// +/// (1) S->ld(p1)->ld(p2)->R->S->ld(p2)->R +/// (2) S->ld(p1)->ld(p2)->ld(p2)->R +static bool chainIsValid(SDValue S, SDValue Pred, SDValue &SetFFR) { + + if (isLoadFF(S.getNode()) && S.getOperand(1) == Pred) + return chainIsValid(S.getOperand(0), Pred, SetFFR); + + if (S.getOpcode() == AArch64ISD::SETFFR) { + SetFFR = S; + return true; + } + + return false; +} + + +static SDValue performAndCombine(SDNode *N, + TargetLowering::DAGCombinerInfo &DCI, + const AArch64Subtarget *Subtarget) { + SDLoc DL(N); + SelectionDAG &DAG = DCI.DAG; + assert(N->getOpcode() == ISD::AND && "Unexpected root"); + + EVT VT = N->getValueType(0); + if (VT.isVector()) { + if (!DCI.isBeforeLegalizeOps() && + (N->getOperand(0)->getOpcode() == ISD::SERIES_VECTOR) && + (VT.getVectorElementType() == MVT::i64)) { + + // Try to remove pointless scalar extends of INDEX operands. + // + // zext_inreg_w(index_d(sext(A), sext(B)) + // ==> zext_inreg_w(index_d(aext(A), aext(B)) + + auto Data = N->getOperand(0); + auto Mask = N->getOperand(1); + + // NOTE: zext_inreg_w is modelled as AND A, 0xffffffff + bool MaskIsZeroExtend = false; + if (Mask->getOpcode() == ISD::SPLAT_VECTOR || + Mask->getOpcode() == AArch64ISD::DUP) { + auto SplatVal = dyn_cast(Mask->getOperand(0)); + if (SplatVal && SplatVal->getZExtValue() == 0xfffffffful) + MaskIsZeroExtend = true; + } + + auto DStart = Data->getOperand(0); + bool DStartExtended = (DStart.getOpcode() == ISD::SIGN_EXTEND) && + (DStart.getOperand(0).getValueType() == MVT::i32); + + auto DStep = Data->getOperand(1); + bool DStepExtended = (DStep.getOpcode() == ISD::SIGN_EXTEND) && + (DStep.getOperand(0).getValueType() == MVT::i32); + + if (MaskIsZeroExtend && (DStartExtended || DStepExtended)) { + if (DStartExtended) + DStart = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i64, + DStart.getOperand(0)); + if (DStepExtended) + DStep = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i64, + DStep.getOperand(0)); + + SDValue SV = DAG.getNode(ISD::SERIES_VECTOR, DL, VT, DStart, DStep); + return DAG.getNode(ISD::AND, DL, VT, SV, Mask); + } + } + + if (!DCI.isBeforeLegalizeOps()) { + SDValue Data = N->getOperand(0); + SDValue Mask = N->getOperand(1); + unsigned Opc = Data->getOpcode(); + + // This instruction performs an implicit zero-extend. 
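+  // An AND with a splat of the memory type's mask is therefore a no-op, so
+  // the load result can be used directly.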
+ if ((Opc == AArch64ISD::LDFF1) || (Opc == AArch64ISD::LDNF1)) { + EVT MemVT = cast(Data->getOperand(3))->getVT(); + + if (isConstantSplatVectorMaskForType(Mask.getNode(), MemVT)) + return Data; + } + + // These instructions perform an implicit zero-extend. + if ((Opc == AArch64ISD::GLDNT1) || + (Opc == AArch64ISD::GLDNT1_UXTW) || + (Opc == AArch64ISD::GLDFF1) || + (Opc == AArch64ISD::GLDFF1_SCALED) || + (Opc == AArch64ISD::GLDFF1_SXTW) || + (Opc == AArch64ISD::GLDFF1_SXTW_SCALED) || + (Opc == AArch64ISD::GLDFF1_UXTW) || + (Opc == AArch64ISD::GLDFF1_UXTW_SCALED)) { + EVT MemVT = cast(Data->getOperand(4))->getVT(); + + if (isConstantSplatVectorMaskForType(Mask.getNode(), MemVT)) + return Data; + } + } + + // Check for AND of the FFR of two loadff that can be chained + if (VT.isScalableVector()) { + auto FFR_L = N->getOperand(0); + if (FFR_L.getOpcode() != AArch64ISD::RDFFR_PRED) + return SDValue(); + + auto FFR_R = N->getOperand(1); + if (FFR_R.getOpcode() != AArch64ISD::RDFFR_PRED) + return SDValue(); + + if (!FFR_L.hasOneUse() || !FFR_R.hasOneUse()) + return SDValue(); + + SDValue LoadVal_L = FFR_L.getOperand(0); + SDValue LoadVal_R = FFR_R.getOperand(0); + auto Load_L = LoadVal_L.getNode(); + auto Load_R = LoadVal_R.getNode(); + // We can accept only LoadFF nodes, + if (!isLoadFF(Load_L) || !isLoadFF(Load_R)) + return SDValue(); + + // Check if the chains are valid chains made of a setffr followed by some + // loadff (and nothing else). + SDValue SetFFR_L, SetFFR_R; + if (!chainIsValid(LoadVal_L.getOperand(0), + LoadVal_L.getOperand(1), SetFFR_L) || + !chainIsValid(LoadVal_R.getOperand(0), + LoadVal_R.getOperand(1), SetFFR_R)) + return SDValue(); + + // Now that each of the RHS and LHS chain are valid, we have to make sure + // that the SETFFR happen simultaneously, i.e. that they are parallel. + if (SetFFR_L.getOperand(0) != SetFFR_R.getOperand(0)) + return SDValue(); + + bool isLoad_L_SetFFR = + Load_L->getOperand(0).getOpcode() == AArch64ISD::SETFFR; + bool isLoad_R_SetFFR = + Load_R->getOperand(0).getOpcode() == AArch64ISD::SETFFR; + + // At least one of the two nodes must be preceded by a setffr, because we + // can chain only a single load and not a chain of loads. 
+ if (!isLoad_L_SetFFR && !isLoad_R_SetFFR) + return SDValue(); + + // Swap the nodes if the one marked as the Left one is not the + // one with the SETFFR + if (!isLoad_L_SetFFR) { + std::swap(Load_L, Load_R); + std::swap(FFR_L, FFR_R); + } + + // Both loads must have the same predicate + if (Load_L->getOperand(1) != Load_R->getOperand(1)) + return SDValue(); + + //create the replacement of the left hand side load + SDValue Chain = SDValue(Load_R, 1); + SDValue LoadOps[] = { Chain, // Chain in after the RHS load + Load_L->getOperand(1), // GP + Load_L->getOperand(2), // Address + Load_L->getOperand(3) }; // MemVT + SDVTList LoadVTs = DAG.getVTList(Load_L->getValueType(0), MVT::Other); + SDValue Load = DAG.getNode(Load_L->getOpcode(), DL, LoadVTs, LoadOps); + SDValue LoadChain = SDValue(Load.getNode(), 1); + SDValue PredOps[] = { LoadChain, FFR_R->getOperand(1) }; // Chain in, GP + SDVTList PredVTs = DAG.getVTList(FFR_R->getValueType(0), MVT::Other); + SDValue FaultPred = DAG.getNode(AArch64ISD::RDFFR_PRED, DL, PredVTs, + PredOps); + + // Replace the use of the original values of Load_L and of FFR_L + // and FFR_R with the new ones created before, Load and + // FaultPred + DAG.ReplaceAllUsesOfValueWith(SDValue(Load_L, 0), + SDValue(Load.getNode(), 0)); + DAG.ReplaceAllUsesOfValueWith(SDValue(Load_L, 1), + SDValue(Load.getNode(), 1)); + DAG.ReplaceAllUsesOfValueWith(SDValue(FFR_L.getNode(), 1), + SDValue(FaultPred.getNode(), 1)); + DAG.ReplaceAllUsesOfValueWith(SDValue(FFR_R.getNode(), 1), + SDValue(FaultPred.getNode(), 1)); + + return FaultPred; + } + + return SDValue(); + } + + return SDValue(); +} + +static SDValue tryConvertSVEWideCompare(SDNode *N, unsigned ReplacementIID, + bool Invert, SelectionDAG &DAG) { + SDValue Comparator = N->getOperand(3); + if (Comparator.getOpcode() == AArch64ISD::DUP || + Comparator.getOpcode() == ISD::SPLAT_VECTOR) { + unsigned IID = getIntrinsicID(N); + EVT VT = N->getValueType(0); + EVT CmpVT = N->getOperand(2).getValueType(); + SDValue Pred = N->getOperand(1); + SDLoc DL(N); + + switch (IID) { + default: + llvm_unreachable("Called with wrong intrinsic!"); + break; + + // Signed comparisons + case Intrinsic::aarch64_sve_cmpeq_wide: + case Intrinsic::aarch64_sve_cmpne_wide: + case Intrinsic::aarch64_sve_cmpge_wide: + case Intrinsic::aarch64_sve_cmpgt_wide: + case Intrinsic::aarch64_sve_cmplt_wide: + case Intrinsic::aarch64_sve_cmple_wide: { + if (auto *CN = dyn_cast(Comparator.getOperand(0))) { + int64_t ImmVal = CN->getSExtValue(); + + if (ImmVal >= -16 && ImmVal <= 15) { + SDValue Imm = DAG.getConstant(ImmVal, DL, MVT::i32); + SDValue Splat = DAG.getNode(AArch64ISD::DUP, DL, CmpVT, Imm); + SDValue ID = DAG.getTargetConstant(ReplacementIID, DL, MVT::i64); + SDValue Op0, Op1; + if (Invert) { + Op0 = Splat; + Op1 = N->getOperand(2); + } else { + Op0 = N->getOperand(2); + Op1 = Splat; + } + return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, VT, + ID, Pred, Op0, Op1); + } + } + break; + } + // Unsigned comparisons + case Intrinsic::aarch64_sve_cmphs_wide: + case Intrinsic::aarch64_sve_cmphi_wide: + case Intrinsic::aarch64_sve_cmplo_wide: + case Intrinsic::aarch64_sve_cmpls_wide: { + if (auto *CN = dyn_cast(Comparator.getOperand(0))) { + uint64_t ImmVal = CN->getZExtValue(); + + if (ImmVal <= 127) { + SDValue Imm = DAG.getConstant(ImmVal, DL, MVT::i32); + SDValue Splat = DAG.getNode(AArch64ISD::DUP, DL, CmpVT, Imm); + SDValue ID = DAG.getTargetConstant(ReplacementIID, DL, MVT::i64); + SDValue Op0, Op1; + if (Invert) { + Op0 = Splat; + Op1 = N->getOperand(2); + } 
else { + Op0 = N->getOperand(2); + Op1 = Splat; + } + return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, VT, + ID, Pred, Op0, Op1); + } + } + break; + } + } + } + + return SDValue(); +} + +static SDValue LowerSVEIntrinsicDUP(SDNode *N, SelectionDAG &DAG) { + SDLoc dl(N); + SDValue Scalar = N->getOperand(3); + EVT ScalarTy = Scalar.getValueType(); + + if ((ScalarTy == MVT::i8) || (ScalarTy == MVT::i16)) + Scalar = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i32, Scalar); + + return DAG.getNode(AArch64ISD::DUP_PRED, dl, N->getValueType(0), + N->getOperand(1), N->getOperand(2), Scalar); +} + +static SDValue LowerSVEIntrinsicEXT(SDNode *N, SelectionDAG &DAG) { + SDLoc dl(N); + LLVMContext &Ctx = *DAG.getContext(); + EVT VT = N->getValueType(0); + + assert(VT.isScalableVector() && "Expected a scalable vector."); + + const unsigned M = VT.getVectorNumElements(); + const unsigned WideLaneBits = 128 / M; + assert(128 % M == 0 && "Invalid number of lanes."); + + SDValue Op0 = N->getOperand(1); + SDValue Op1 = N->getOperand(2); + SDValue Op2 = N->getOperand(3); + + // Convert any FP Type to integer. + const EVT MxIntVT = VT.changeVectorElementTypeToInteger(); + Op0 = DAG.getNode(ISD::BITCAST, dl, MxIntVT, Op0); + Op1 = DAG.getNode(ISD::BITCAST, dl, MxIntVT, Op1); + + // Widen to a packed vector if needed. + const EVT MxWideIntVT = + EVT::getVectorVT(Ctx, EVT::getIntegerVT(Ctx, WideLaneBits), {M, true}); + Op0 = DAG.getNode(ISD::ANY_EXTEND, dl, MxWideIntVT, Op0); + Op1 = DAG.getNode(ISD::ANY_EXTEND, dl, MxWideIntVT, Op1); + + // Generate EXT on packed 8-bit lane vector. + const EVT ExtVT = EVT::getVectorVT(Ctx, MVT::i8, {16, true}); + Op0 = DAG.getNode(ISD::BITCAST, dl, ExtVT, Op0); + Op1 = DAG.getNode(ISD::BITCAST, dl, ExtVT, Op1); + Op2 = DAG.getNode(ISD::MUL, dl, MVT::i32, Op2, + DAG.getConstant(WideLaneBits / 8, dl, MVT::i32)); + SDValue Ext = DAG.getNode(AArch64ISD::EXT, dl, ExtVT, Op0, Op1, Op2); + + // Narrow to the unpacked vector if needed. + Ext = DAG.getNode(ISD::BITCAST, dl, MxWideIntVT, Ext); + Ext = DAG.getNode(ISD::TRUNCATE, dl, MxIntVT, Ext); + + // Bitcast to the original type. 
+ return DAG.getNode(ISD::BITCAST, dl, VT, Ext); +} + +static bool isSplatConstVector(SDValue Vec, unsigned Val) { + if (Vec.getOpcode() != ISD::SPLAT_VECTOR) + return false; + ConstantSDNode *COp; + if (!(COp = dyn_cast(Vec.getOperand(0)))) + return false; + if (COp->getZExtValue() != Val) + return false; + return true; +} + +static SDValue tryLowerPredTestReduction(SDNode *N, unsigned Opc, + SelectionDAG &DAG) { + SDValue Pg = N->getOperand(1); + SDValue Op = N->getOperand(2); + EVT OpVT = Op.getValueType(); + + if (OpVT.getVectorElementType() != MVT::i1) + return SDValue(); + + if (!isSplatConstVector(Pg, 1)) + return SDValue(); + + AArch64CC::CondCode Cond; + switch (Opc) { + default: + return SDValue(); + case AArch64ISD::ORV_PRED: + Cond = AArch64CC::ANY_ACTIVE; + break; + case AArch64ISD::ANDV_PRED: + // ANDV(X) = !ORV(!X) + Cond = AArch64CC::NONE_ACTIVE; + Op = DAG.getNode(ISD::XOR, SDLoc(N), OpVT, Op, Pg); + break; + } + + return getPTest(DAG, N->getValueType(0), Pg, Op, Cond); +} + +static SDValue LowerSVEIntReduction(SDNode *N, unsigned Opc, + SelectionDAG &DAG) { + SDLoc dl(N); + LLVMContext &Ctx = *DAG.getContext(); + const TargetLowering &TLI = DAG.getTargetLoweringInfo(); + + EVT VT = N->getValueType(0); + SDValue IID = N->getOperand(0); + SDValue Pred = N->getOperand(1); + SDValue Data = N->getOperand(2); + + EVT DataVT = Data.getValueType(); + + // Bitwise OR/AND reductions of i1s can be expressed using a PTEST. + if (SDValue Res = tryLowerPredTestReduction(N, Opc, DAG)) + return Res; + + // If we still have failed i1 reduction->ptest, then promote the result + // and operands for legalization to handle later. + if (DataVT.getVectorElementType() == MVT::i1) { + EVT NewDataVT = + getNaturalIntSVETypeWithMatchingElementCount(DataVT); + auto Extend = DAG.getNode(ISD::SIGN_EXTEND, dl, NewDataVT, Data); + auto NewOp = + DAG.getNode(N->getOpcode(), dl, NewDataVT.getVectorElementType(), IID, + Pred, Extend); + return DAG.getNode(ISD::TRUNCATE, dl, VT, NewOp); + } + + if (DataVT.getSizeInBits() < AArch64::SVEBitsPerBlock) { + // The following does no real work but will allow instruction selection. + + // Promote the element type. + EVT VT1 = getNaturalIntSVETypeWithMatchingElementCount(DataVT); + Data = DAG.getNode(ISD::ANY_EXTEND, dl, VT1, Data); + + // Cast back to the original element type. + EVT VT2 = getNaturalIntSVETypeWithMatchingElementType(DataVT); + Data = DAG.getNode(ISD::BITCAST, dl, VT2, Data); + + // Cast predicate to match the original element type. 
+ EVT VT3 = getNaturalPredSVETypeWithMatchingElementType(DataVT); + Pred = DAG.getNode(AArch64ISD::REINTERPRET_CAST, dl, VT3, Pred); + + DataVT = Data.getValueType(); + } + + if (!TLI.isTypeLegal(DataVT)) + return SDValue(); + + EVT ReduceVT = EVT::getVectorVT(Ctx, VT, 128 / VT.getSizeInBits()); + SDValue Reduce = DAG.getNode(Opc, dl, ReduceVT, Pred, Data); + + SDValue Zero = DAG.getConstant(0, dl, MVT::i64); + return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, VT, Reduce, Zero); +} + +static SDValue performExtendInRegCombine(SDNode *N, + TargetLowering::DAGCombinerInfo &DCI, + SelectionDAG &DAG) { + if (DCI.isBeforeLegalizeOps()) + return SDValue(); + + SDLoc DL(N); + EVT VT = N->getValueType(0); + SDValue N0 = N->getOperand(0); + SDValue N1 = N->getOperand(1); + unsigned Opc = N0->getOpcode(); + + if ((Opc == AArch64ISD::LDFF1) || (Opc == AArch64ISD::LDNF1)) { + EVT SrcVT = cast(N1)->getVT(); + EVT MemVT = cast(N0->getOperand(3))->getVT(); + + if ((SrcVT == MemVT) && N0.hasOneUse()) { + SDVTList VTs = DAG.getVTList(VT, MVT::Other); + SDValue Ops[] = { N0->getOperand(0), + N0->getOperand(1), + N0->getOperand(2), + N0->getOperand(3) }; + + unsigned SignedOpc = (Opc == AArch64ISD::LDFF1) ? AArch64ISD::LDFF1S + : AArch64ISD::LDNF1S; + SDValue ExtLoad = DAG.getNode(SignedOpc, DL, VTs, Ops); + DCI.CombineTo(N, ExtLoad); + DCI.CombineTo(N0.getNode(), ExtLoad, ExtLoad.getValue(1)); + return SDValue(N, 0); // Return N so it doesn't get rechecked! + } + } + + if ((Opc == AArch64ISD::GLDNT1) || + (Opc == AArch64ISD::GLDNT1_UXTW) || + (Opc == AArch64ISD::GLDFF1) || + (Opc == AArch64ISD::GLDFF1_SCALED) || + (Opc == AArch64ISD::GLDFF1_SXTW) || + (Opc == AArch64ISD::GLDFF1_SXTW_SCALED) || + (Opc == AArch64ISD::GLDFF1_UXTW) || + (Opc == AArch64ISD::GLDFF1_UXTW_SCALED)) { + EVT SrcVT = cast(N1)->getVT(); + EVT MemVT = cast(N0->getOperand(4))->getVT(); + + if ((SrcVT == MemVT) && N0.hasOneUse()) { + SDVTList VTs = DAG.getVTList(VT, MVT::Other); + SDValue Ops[] = { N0->getOperand(0), + N0->getOperand(1), + N0->getOperand(2), + N0->getOperand(3), + N0->getOperand(4) }; + + unsigned Opc; + switch (N0->getOpcode()) { + case AArch64ISD::GLDNT1: + Opc = AArch64ISD::GLDNT1S; + break; + case AArch64ISD::GLDNT1_UXTW: + Opc = AArch64ISD::GLDNT1S_UXTW; + break; + case AArch64ISD::GLDFF1: + Opc = AArch64ISD::GLDFF1S; + break; + case AArch64ISD::GLDFF1_SCALED: + Opc = AArch64ISD::GLDFF1S_SCALED; + break; + case AArch64ISD::GLDFF1_SXTW: + Opc = AArch64ISD::GLDFF1S_SXTW; + break; + case AArch64ISD::GLDFF1_SXTW_SCALED: + Opc = AArch64ISD::GLDFF1S_SXTW_SCALED; + break; + case AArch64ISD::GLDFF1_UXTW: + Opc = AArch64ISD::GLDFF1S_UXTW; + break; + case AArch64ISD::GLDFF1_UXTW_SCALED: + Opc = AArch64ISD::GLDFF1S_UXTW_SCALED; + break; + } + + SDValue ExtLoad = DAG.getNode(Opc, DL, VTs, Ops); + DCI.CombineTo(N, ExtLoad); + DCI.CombineTo(N0.getNode(), ExtLoad, ExtLoad.getValue(1)); + return SDValue(N, 0); // Return N so it doesn't get rechecked! + } + } + + return SDValue(); +} + +static SDValue performTRUNCATECombine(SDNode *N, + TargetLowering::DAGCombinerInfo &DCI, + SelectionDAG &DAG) { + if (DCI.isBeforeLegalizeOps()) + return SDValue(); + + EVT VT = N->getValueType(0); + SDValue N0 = N->getOperand(0); + + // Vector shuffles of predicates are first expanded into data registers where + // TBL can be used to perform the shuffle. Sometimes a TBL is not required, + // with a dedicated instruction (REV for example) used instead. Here we check + // for a predicate equivalent thus removing the need to use data registers. 
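+  // For example, trunc(REV(zext(p))) becomes REV(p).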
+ if (VT.isScalableVector() && (VT.getVectorElementType() == MVT::i1)) { + unsigned Opc = N0->getOpcode(); + + if ((Opc == AArch64ISD::REV) || + (Opc == AArch64ISD::UZP1) || + (Opc == AArch64ISD::UZP2)) { + SmallVector Ops; + + // Check all operands are predicate extensions. + for (const SDValue &Ext : N0->ops()) { + unsigned ExtOpc = Ext.getOpcode(); + if (((ExtOpc == ISD::ANY_EXTEND) || + (ExtOpc == ISD::SIGN_EXTEND) || + (ExtOpc == ISD::ZERO_EXTEND)) && + (Ext.getOperand(0).getValueType() == VT)) + Ops.push_back(Ext.getOperand(0)); + } + + if (Ops.size() == N0->getNumOperands()) + return DAG.getNode(Opc, SDLoc(N), VT, Ops); + } + } + + return SDValue(); +} + +static SDValue performLD1RQCombine(SDNode *N, SelectionDAG &DAG) { + SDLoc DL(N); + EVT VT = N->getValueType(0); + + EVT LoadVT = VT; + if (VT.isFloatingPoint()) + LoadVT = VT.changeTypeToInteger(); + + SDValue Ops[] = { N->getOperand(0), N->getOperand(2), N->getOperand(3) }; + SDValue L = DAG.getNode(AArch64ISD::LD1RQ, DL, { LoadVT, MVT::Other }, Ops); + + if (VT.isFloatingPoint()) { + SDValue Ops[] = { DAG.getNode(ISD::BITCAST, DL, VT, L), L.getValue(1) }; + return DAG.getMergeValues(Ops, DL); + } + + return L; +} + +static SDValue performLDNT1Combine(SDNode *N, SelectionDAG &DAG) { + SDLoc DL(N); + EVT VT = N->getValueType(0); + + EVT LoadVT = VT; + if (VT.isFloatingPoint()) + LoadVT = VT.changeTypeToInteger(); + + // NOTE: It's important we match ISD::MLOAD's operand order. + SDValue Ops[] = { N->getOperand(0), N->getOperand(3), N->getOperand(2), + DAG.getUNDEF(LoadVT) }; + SDValue L = DAG.getNode(AArch64ISD::LDNT1, DL, { LoadVT, MVT::Other }, Ops); + + if (VT.isFloatingPoint()) { + SDValue Ops[] = { DAG.getNode(ISD::BITCAST, DL, VT, L), L.getValue(1) }; + return DAG.getMergeValues(Ops, DL); + } + + return L; +} + +static SDValue performSTNT1Combine(SDNode *N, SelectionDAG &DAG) { + SDLoc DL(N); + + SDValue Data = N->getOperand(2); + EVT DataVT = Data.getValueType(); + + if (DataVT.isFloatingPoint()) + Data = DAG.getNode(ISD::BITCAST, DL, DataVT.changeTypeToInteger(), Data); + + // NOTE: It's important we match ISD::MSTORE's operand order. + SDValue Ops[] = { N->getOperand(0), N->getOperand(4), N->getOperand(3), + Data }; + return DAG.getNode(AArch64ISD::STNT1, DL, N->getValueType(0), Ops); +} + +static SDValue performSTNT1ScatterCombine(SDNode *N, SelectionDAG &DAG) { + SDLoc DL(N); + SDValue Data = N->getOperand(2); + EVT DataVT = Data.getValueType(); + if (DataVT.getSizeInBits() > AArch64::SVEBitsPerBlock) + return SDValue(); + const SDValue Offsets = N->getOperand(5); + SDValue Base = N->getOperand(4); + SDValue Vector = Offsets; // 64bit Offsets + unsigned Opcode = AArch64ISD::SSTNT1; + + // Is a better addressing mode available? 
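+  // Only the unsigned 32-bit offset form (SSTNT1_UXTW) is matched here.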
+  if (Offsets.getValueType() == MVT::nxv4i64) {
+    if (Offsets.getOpcode() == ISD::ZERO_EXTEND) {
+      Vector = Offsets.getOperand(0); // Unsigned 32bit Offsets
+      Opcode = AArch64ISD::SSTNT1_UXTW;
+      if (Vector.getValueType() != MVT::nxv4i32)
+        Vector = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::nxv4i32, Vector);
+    }
+  }
+
+  if (!DAG.getTargetLoweringInfo().isTypeLegal(Base.getValueType()) ||
+      !DAG.getTargetLoweringInfo().isTypeLegal(Vector.getValueType()))
+    return SDValue();
+
+  if (DataVT.isFloatingPoint()) {
+    Data = DAG.getNode(ISD::BITCAST, DL, DataVT.changeTypeToInteger(), Data);
+    DataVT = Data.getValueType();
+  }
+
+  EVT ContainerVT;
+  if (DataVT.isInteger()) {
+    switch (DataVT.getVectorNumElements()) {
+    default: return SDValue();
+    case 4: ContainerVT = MVT::nxv4i32; break;
+    case 2: ContainerVT = MVT::nxv2i64; break;
+    }
+  } else
+    return SDValue();
+
+  if (DataVT != ContainerVT)
+    Data = DAG.getNode(ISD::ANY_EXTEND, DL, ContainerVT, Data);
+
+  // Order is: chain, data, predicate, scalar base, vector offset, data VT.
+  SDValue Ops[] = { N->getOperand(0), Data, N->getOperand(3), Base,
+                    Vector, DAG.getValueType(DataVT) };
+  return DAG.getNode(Opcode, DL, N->getValueType(0), Ops);
+}
+
+static SDValue performLDFF1Combine(SDNode *N, SelectionDAG &DAG, bool isNF) {
+  SDLoc DL(N);
+  EVT VT = N->getValueType(0);
+
+  if (VT.getSizeInBits() > AArch64::SVEBitsPerBlock)
+    return SDValue();
+
+  EVT ContainerVT = VT;
+  if (ContainerVT.isInteger()) {
+    switch (VT.getVectorNumElements()) {
+    default: return SDValue();
+    case 16: ContainerVT = MVT::nxv16i8; break;
+    case 8: ContainerVT = MVT::nxv8i16; break;
+    case 4: ContainerVT = MVT::nxv4i32; break;
+    case 2: ContainerVT = MVT::nxv2i64; break;
+    }
+  }
+
+  SDVTList VTs = DAG.getVTList(ContainerVT, MVT::Other);
+  SDValue Ops[] = { N->getOperand(0), // Chain
+                    N->getOperand(2), // Pg
+                    N->getOperand(3), // Base
+                    DAG.getValueType(VT) };
+
+  unsigned Opc = isNF ? AArch64ISD::LDNF1 : AArch64ISD::LDFF1;
+  SDValue Load = DAG.getNode(Opc, DL, VTs, Ops);
+  SDValue LoadChain = SDValue(Load.getNode(), 1);
+
+  if (ContainerVT.isInteger() && (VT != ContainerVT))
+    Load = DAG.getNode(ISD::TRUNCATE, DL, VT, Load.getValue(0));
+
+  return DAG.getMergeValues({ Load, LoadChain }, DL);
+}
+
+static SDValue performLDFF1GatherCombine(SDNode *N, SelectionDAG &DAG) {
+  SDLoc DL(N);
+  EVT VT = N->getValueType(0);
+  if (VT.getSizeInBits() > AArch64::SVEBitsPerBlock)
+    return SDValue();
+
+  const SDValue Base = N->getOperand(3);
+  const SDValue Offsets = N->getOperand(4);
+  unsigned EltBits = VT.getVectorElementType().getSizeInBits();
+
+  // These are to be changed based on available addressing modes.
+  unsigned Opcode = AArch64ISD::GLDFF1;
+  SDValue Scalar = Base;
+  SDValue Vector = Offsets; // 64bit Offsets
+  APInt ShiftAmount;
+
+  // Is a better addressing mode available?
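+  // 32-bit offsets that are sign/zero-extended to 64 bits, optionally scaled
+  // by the element size, map onto the SXTW/UXTW forms matched below; shifted
+  // 64-bit indices use the plain scaled form.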
+ if (Offsets.getValueType() == MVT::nxv4i64) { + if (Offsets.getOpcode() == ISD::SIGN_EXTEND) { + Opcode = AArch64ISD::GLDFF1_SXTW; + Vector = Offsets.getOperand(0); // Signed 32bit Offsets + } else if (Offsets.getOpcode() == ISD::ZERO_EXTEND) { + Opcode = AArch64ISD::GLDFF1_UXTW; + Vector = Offsets.getOperand(0); // Unsigned 32bit Offsets + } else if ((Offsets.getOpcode() == ISD::SHL) && + DAG.isConstantIntSplat(Offsets.getOperand(1), &ShiftAmount) && + ((8u << ShiftAmount.getZExtValue()) == EltBits)) { + if (Offsets.getOperand(0).getOpcode() == ISD::SIGN_EXTEND) { + Opcode = AArch64ISD::GLDFF1_SXTW_SCALED; + Vector = Offsets.getOperand(0).getOperand(0); // Signed 32bit Indices + } else if (Offsets.getOperand(0).getOpcode() == ISD::ZERO_EXTEND) { + Opcode = AArch64ISD::GLDFF1_UXTW_SCALED; + Vector = Offsets.getOperand(0).getOperand(0); // Unsigned 32bit Indices + } + } + + // We must reapply extensions for offsets that are now smaller than i32. + if (Vector.getValueType() != MVT::nxv4i32) { + switch (Opcode) { + default: + break; + case AArch64ISD::GLDFF1_SXTW: + case AArch64ISD::GLDFF1_SXTW_SCALED: + Vector = DAG.getNode(ISD::SIGN_EXTEND, DL, MVT::nxv4i32, Vector); + break; + case AArch64ISD::GLDFF1_UXTW: + case AArch64ISD::GLDFF1_UXTW_SCALED: + Vector = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::nxv4i32, Vector); + break; + } + } + } else if (Offsets.getValueType() == MVT::nxv2i64) { + if ((Offsets.getOpcode() == ISD::SHL) && + DAG.isConstantIntSplat(Offsets.getOperand(1), &ShiftAmount) && + ((8u << ShiftAmount.getZExtValue()) == EltBits)) { + Opcode = AArch64ISD::GLDFF1_SCALED; + Vector = Offsets.getOperand(0); // 64bit Indices + } + } + + if (!DAG.getTargetLoweringInfo().isTypeLegal(Scalar.getValueType()) || + !DAG.getTargetLoweringInfo().isTypeLegal(Vector.getValueType())) + return SDValue(); + + EVT ContainerVT = VT; + if (ContainerVT.isInteger()) { + switch (VT.getVectorNumElements()) { + default: return SDValue(); + case 16: ContainerVT = MVT::nxv16i8; break; + case 8: ContainerVT = MVT::nxv8i16; break; + case 4: ContainerVT = MVT::nxv4i32; break; + case 2: ContainerVT = MVT::nxv2i64; break; + } + } + + SDVTList VTs = DAG.getVTList(ContainerVT, MVT::Other); + SDValue Ops[] = { N->getOperand(0), // Chain + N->getOperand(2), // Pg + Scalar, + Vector, + DAG.getValueType(VT) }; + + SDValue Load = DAG.getNode(Opcode, DL, VTs, Ops); + SDValue LoadChain = SDValue(Load.getNode(), 1); + + if (ContainerVT.isInteger() && (VT != ContainerVT)) + Load = DAG.getNode(ISD::TRUNCATE, DL, VT, Load.getValue(0)); + + return DAG.getMergeValues({ Load, LoadChain }, DL); +} + +static SDValue performLDNTGatherCombine(SDNode *N, SelectionDAG &DAG) { + SDLoc DL(N); + EVT VT = N->getValueType(0); + if (VT.getSizeInBits() > AArch64::SVEBitsPerBlock) + return SDValue(); + + const SDValue Base = N->getOperand(3); + const SDValue Offsets = N->getOperand(4); + + // These are to be changed based on available addressing modes. + unsigned Opcode = AArch64ISD::GLDNT1; + SDValue Scalar = Base; + SDValue Vector = Offsets; // 64bit Offsets + + // Is a better addressing mode available? 
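// --- Illustrative aside (not part of the patch): the scaled-index test used
// by the gather combines above, written as a standalone check. A splatted
// left-shift on the offsets only describes a "scaled index" addressing mode
// when 8 << shift equals the element size in bits, i.e. each byte offset is
// exactly index * sizeof(element).
#include <cassert>

static bool isScaledIndexForm(unsigned ShiftAmount, unsigned EltBits) {
  return (8u << ShiftAmount) == EltBits;
}

int main() {
  assert(isScaledIndexForm(2, 32));  // offsets = indices << 2 for 32-bit data
  assert(isScaledIndexForm(3, 64));  // offsets = indices << 3 for 64-bit data
  assert(!isScaledIndexForm(1, 32)); // a 2-byte stride is not a scaled index
  return 0;
}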
+ if (Offsets.getValueType() == MVT::nxv4i64) { + if (Offsets.getOpcode() == ISD::ZERO_EXTEND) { + Vector = Offsets.getOperand(0); // Unsigned 32bit Offsets + Opcode = AArch64ISD::GLDNT1_UXTW; + if (Vector.getValueType() != MVT::nxv4i32) + Vector = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::nxv4i32, Vector); + } + } + + if (!DAG.getTargetLoweringInfo().isTypeLegal(Scalar.getValueType()) || + !DAG.getTargetLoweringInfo().isTypeLegal(Vector.getValueType())) + return SDValue(); + + EVT ContainerVT = VT; + if (ContainerVT.isInteger()) { + switch (VT.getVectorNumElements()) { + default: return SDValue(); + case 16: ContainerVT = MVT::nxv16i8; break; + case 8: ContainerVT = MVT::nxv8i16; break; + case 4: ContainerVT = MVT::nxv4i32; break; + case 2: ContainerVT = MVT::nxv2i64; break; + } + } + + SDVTList VTs = DAG.getVTList(ContainerVT, MVT::Other); + SDValue Ops[] = { N->getOperand(0), // Chain + N->getOperand(2), // Pg + Scalar, + Vector, + DAG.getValueType(VT) }; + + SDValue Load = DAG.getNode(Opcode, DL, VTs, Ops); + SDValue LoadChain = SDValue(Load.getNode(), 1); + + if (ContainerVT.isInteger() && (VT != ContainerVT)) + Load = DAG.getNode(ISD::TRUNCATE, DL, VT, Load.getValue(0)); + + return DAG.getMergeValues({ Load, LoadChain }, DL); +} + +static SDValue performPrefetchGatherCombine(SDNode *N, SelectionDAG &DAG, + EVT VT) { + assert(VT.getSizeInBits() == AArch64::SVEBitsPerBlock && "Unexpected VT"); + SDLoc DL(N); + + const SDValue Base = N->getOperand(3); + const SDValue Offsets = N->getOperand(4); + const SDValue PrfOp = N->getOperand(5); + const EVT VectorTy = Offsets.getValueType(); + + // These are to be changed based on available addressing modes. + unsigned Opcode = 0; + SDValue Scalar = Base; + SDValue Vector = Offsets; + EVT EltTy = VT.getVectorElementType(); + const unsigned EltBits = EltTy.getSizeInBits(); + + // Is there a constant base that would fit in the ZI form? (0-31, scaled by + // number of bytes) + SDValue SmallConstantBase = SDValue(); + if (auto *CBaseSDNode = dyn_cast(Base)) { + int64_t CBaseOffset = CBaseSDNode->getSExtValue(); + int64_t EltBytes = EltBits / 8; + int64_t Index = CBaseOffset / EltBytes; + int64_t Rem = CBaseOffset % EltBytes; + if ((Rem == 0) && (Index >= 0) && (Index < 32)) + SmallConstantBase = Base; + } + + // Is this a scaled vector of indices? 
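// --- Illustrative aside (not part of the patch): a standalone version of the
// "ZI form" test above. The immediate variant of the gather prefetch only
// accepts a constant base that is a whole number of elements whose element
// index fits the unsigned 5-bit range 0-31.
#include <cassert>
#include <cstdint>

static bool fitsZIForm(int64_t BaseBytes, unsigned EltBits) {
  int64_t EltBytes = EltBits / 8;
  if (BaseBytes % EltBytes != 0)
    return false;
  int64_t Index = BaseBytes / EltBytes;
  return Index >= 0 && Index < 32;
}

int main() {
  assert(fitsZIForm(248, 64));  // 248 / 8 == 31, the largest accepted index
  assert(!fitsZIForm(256, 64)); // index 32 is out of range
  assert(!fitsZIForm(6, 32));   // not a multiple of the 4-byte element size
  return 0;
}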
+ SDValue ScaledVector = SDValue(); + APInt ShiftAmount; + if ((Offsets.getOpcode() == ISD::SHL) && + DAG.isConstantIntSplat(Offsets.getOperand(1), &ShiftAmount) && + ((8u << ShiftAmount.getZExtValue()) == EltBits)) + ScaledVector = Offsets.getOperand(0); + else if (EltBits == 8) + // Scaled and unscaled are identical for byte addressing + ScaledVector = Offsets; + + if (ScaledVector) { + if (VectorTy == MVT::nxv4i64) { + if (ScaledVector.getOpcode() == ISD::SIGN_EXTEND) { + // Scaled 32-bit signed offsets + Opcode = AArch64ISD::GPRF_S_SXTW_SCALED; + Vector = ScaledVector.getOperand(0); + } else if (ScaledVector.getOpcode() == ISD::ZERO_EXTEND) { + Vector = ScaledVector.getOperand(0); + if ((EltBits == 8) && SmallConstantBase) { + // We can use the unscaled ZI form for byte addresses + Opcode = AArch64ISD::GPRF_S_IMM; + Scalar = SmallConstantBase; + } else { + // Scaled 32-bit unsigned offsets + Opcode = AArch64ISD::GPRF_S_UXTW_SCALED; + } + } + } else { // VectorTy == MVT::nxv2i64 + if ((ScaledVector.getOpcode() == ISD::SIGN_EXTEND) && + (ScaledVector.getOperand(0).getOpcode() == ISD::TRUNCATE)) { + // scaled unpacked 32-bit signed offsets + Opcode = AArch64ISD::GPRF_D_SXTW_SCALED; + Vector = ScaledVector.getOperand(0).getOperand(0); + } else if ((ScaledVector.getOpcode() == ISD::ZERO_EXTEND) && + (ScaledVector.getOperand(0).getOpcode() == ISD::TRUNCATE)) { + // scaled unpacked 32-bit unsigned offsets + Opcode = AArch64ISD::GPRF_D_UXTW_SCALED; + Vector = ScaledVector.getOperand(0).getOperand(0); + } else if ((EltBits == 8) && SmallConstantBase) { + // We can use the unscaled ZI form for byte addresses + Opcode = AArch64ISD::GPRF_D_IMM; + Scalar = SmallConstantBase; + } else { + // scaled 64-bit offsets + Opcode = AArch64ISD::GPRF_D_SCALED; + Vector = ScaledVector; + } + } + } else { // Unscaled vector + if ((VectorTy == MVT::nxv4i64) && + (Offsets.getOpcode() == ISD::ZERO_EXTEND)) { + if (SmallConstantBase) { + // Unscaled unsigned 32-bit bases with in-range constant index + Opcode = AArch64ISD::GPRF_S_IMM; + Vector = Offsets.getOperand(0); + Scalar = SmallConstantBase; + } else { + // Unscaled 32-bit bases plus out of range (or non-constant) index. + // Only byte prefetches (prfb) supports this addressing mode. 
The SVE + // architecture team feel, if a user requests such a prefetch, it is + // better to use prfb, regardless of the requested element size, rather + // than legalising and generating two prefetches + Opcode = AArch64ISD::GPRF_S_UXTW_SCALED; + Vector = Offsets.getOperand(0); + EltTy = MVT::i8; + } + } else { // VectorTy == MVT::nxv2i64 + if (SmallConstantBase) { + // Unscaled 64-bit bases with in-range constant index + Opcode = AArch64ISD::GPRF_D_IMM; + Scalar = SmallConstantBase; + } else { + // splat the base onto the vector + auto Splat = DAG.getNode(ISD::SPLAT_VECTOR, DL, VectorTy, Base); + Vector = DAG.getNode(ISD::ADD, DL, VectorTy, {Vector, Splat}); + Scalar = DAG.getConstant(0, DL, MVT::i64); + Opcode = AArch64ISD::GPRF_D_IMM; + } + } + } + + if ((Opcode == 0) || + !DAG.getTargetLoweringInfo().isTypeLegal(Scalar.getValueType()) || + !DAG.getTargetLoweringInfo().isTypeLegal(Vector.getValueType())) + return SDValue(); + + SDVTList VTs = DAG.getVTList(MVT::Other); + SDValue Ops[] = { N->getOperand(0), // Chain + N->getOperand(2), // Pg + Scalar, + Vector, + PrfOp, + DAG.getValueType(EltTy) }; + + return DAG.getNode(Opcode, DL, VTs, Ops); +} + +// Analyse the specified address returning true if a change of IndexType will +// yield something that's more efficient to legalise. When returning true all +// parameters are updated to reflect their recommended values. +static bool FindMoreOptimalIndexType(SDValue &BasePtr, SDValue &Index, + ISD::MemIndexType &IndexType) { + if (IndexType != ISD::SIGNED_SCALED) + return false; + + // This is the only illegal type we can do much with. Smaller types can be + // easily promoted and bigger types will require splitting regardless. + if (Index.getValueType() != MVT::nxv4i64) + return false; + + // If we're starting with a vector_of_pointers. + if (isNullConstant(BasePtr) && Index.getOpcode() == ISD::ADD) { + SDValue Base = Index.getOperand(0); + SDValue Offset = Index.getOperand(1); + + if (Base.getOpcode() != ISD::SPLAT_VECTOR) + std::swap(Base, Offset); + + if (Base.getOpcode() != ISD::SPLAT_VECTOR) + return false; + + if (Offset.getOpcode() == ISD::SIGN_EXTEND) + IndexType = ISD::SIGNED_UNSCALED; + else if (Offset.getOpcode() == ISD::ZERO_EXTEND) + IndexType = ISD::UNSIGNED_UNSCALED; + + // Only recommend actual changes of addressing mode. + if (IndexType == ISD::SIGNED_SCALED) + return false; + + BasePtr = Base.getOperand(0); + Index = Offset.getOperand(0); + return true; + } + + // If we're starting with a base plus vector_of_unsigned_indices. + if (!isNullConstant(BasePtr) && + (Index.getOpcode() == ISD::ZERO_EXTEND) && + (Index.getOperand(0).getValueType() == MVT::nxv4i32)) { + IndexType = ISD::UNSIGNED_SCALED; + Index = Index.getOperand(0); + return true; + } + + return false; +} + +static SDValue performMGATHERCombine(MaskedGatherSDNode *N, + TargetLowering::DAGCombinerInfo &DCI) { + SDLoc DL(N); + SelectionDAG &DAG = DCI.DAG; + + SDValue Index = N->getIndex(); + SDValue Scale = N->getScale(); + SDValue Chain = N->getChain(); + SDValue PassThrough = N->getSrc0(); + SDValue Mask = N->getMask(); + SDValue BasePtr = N->getBasePtr(); + + EVT VT = N->getValueType(0); + EVT IVT = Index.getValueType(); + EVT PVT = Mask.getValueType(); + EVT BVT = BasePtr.getValueType(); + EVT IEltVT = IVT.getVectorElementType(); + EVT MemVT = N->getMemoryVT(); + ISD::MemIndexType IndexType = N->getIndexType(); + + if (DCI.isBeforeLegalize()) { + // SVE gather/scatter requires indices of i32/i64. 
Promote anything smaller + // prior to legalisation so the result can be split if required. + if ((IVT.getVectorElementType() == MVT::i8) || + (IVT.getVectorElementType() == MVT::i16)) { + EVT NewIVT = IVT.changeVectorElementType(MVT::i32); + if (N->isIndexSigned()) + Index = DAG.getNode(ISD::SIGN_EXTEND, DL, NewIVT, Index); + else + Index = DAG.getNode(ISD::ZERO_EXTEND, DL, NewIVT, Index); + + SDValue Ops[] = { Chain, PassThrough, Mask, BasePtr, Index, Scale }; + return DAG.getMaskedGather(DAG.getVTList(VT, MVT::Other), + N->getMemoryVT(), DL, Ops, + N->getMemOperand(), N->getExtensionType(), + N->getIndexType()); + } + + // Some uses of MGATHER naturally lead to illegal types. For example, + // unsigned indicies are typically zero extended because MGATHER's + // default IndexType is SIGNED_SCALED. Unsigned offsets within nxv4i32 + // become signed offsets within nxv4i64. The latter is illegal and triggers + // splitvec style legalisation that's very difficult to undo. + // + // Here we catch such cases early and change MGATHER's IndexType to allow + // the use of an Index that's more legalisation friendly. + if (FindMoreOptimalIndexType(BasePtr, Index, IndexType)) { + SDValue Ops[] = { Chain, PassThrough, Mask, BasePtr, Index, Scale }; + return DAG.getMaskedGather(DAG.getVTList(VT, MVT::Other), + N->getMemoryVT(), DL, Ops, N->getMemOperand(), + N->getExtensionType(), IndexType); + } + } else if (EnableSGToContiguousXForm) { + // After everything has been legalized, we want to recognize gathers with + // a constant stride of two and convert to two contiguous loads with + // shuffles and masking. This hurts a little bit on 128b, but is quite + // beneficial for greater register widths. + // + // We don't currently consider strides greater than two because that + // wouldn't fit nicely in 128b. A stride of four could work with 32b + // element types, but that should only be implemented if found to be + // common enough to warrant it. + // + // TODO: If we try this on a vector legalized from a wider type, things + // seem to go wrong; not sure why this is yet. For now, we bail out by + // checking whether the mask was extracted (SC-1425). + if (Index->getOpcode() == ISD::SERIES_VECTOR && + isa(Index->getOperand(1)) && + PassThrough.isUndef() && + Mask->getOpcode() != ISD::EXTRACT_SUBVECTOR && + IndexType == ISD::SIGNED_SCALED) { + auto Step = cast(Index->getOperand(1))->getSExtValue(); + if (Step == 2) { + unsigned ShiftAmt; + + switch(MemVT.getSimpleVT().SimpleTy) { + case MVT::nxv2i64: + case MVT::nxv2f64: + ShiftAmt = 3; + break; + case MVT::nxv4i32: + case MVT::nxv4f32: + ShiftAmt = 2; + break; + default: + return SDValue(); + } + + bool ExtendIdx; + switch (IEltVT.getSimpleVT().SimpleTy) { + case MVT::i64: + ExtendIdx = false; + break; + case MVT::i32: + ExtendIdx = true; + break; + default: + return SDValue(); + } + + // Get the first index value. + auto IdxVal = Index->getOperand(0); + + // If the indices are 32b, we need to extend to 64b. + if (ExtendIdx) + IdxVal = DAG.getNode(ISD::SIGN_EXTEND, DL, MVT::i64, IdxVal); + + // Split the mask into two parts and interleave with false + auto PFalse = SDValue(DAG.getMachineNode(AArch64::PFALSE, DL, PVT), 0); + auto PredL = DAG.getNode(AArch64ISD::ZIP1, DL, PVT, Mask, PFalse); + auto PredH = DAG.getNode(AArch64ISD::ZIP2, DL, PVT, Mask, PFalse); + + // Figure out the offsets needed for the two loads -- scale the + // first index up based on element size, then scale up a count of + // the number of elements in the vector. 
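// --- Illustrative aside (not part of the patch): a standalone model of the
// stride-2 gather rewrite described in the comment block above. Loading
// indices {b, b+2, b+4, ...} is equivalent to two contiguous masked loads
// over [b, b+2N) whose even lanes are then packed together, which is the job
// of the interleaved ZIP1/ZIP2 predicates and the final UZP1.
#include <cassert>
#include <vector>

static std::vector<int> stride2Gather(const std::vector<int> &Mem,
                                      unsigned Base, unsigned NumLanes) {
  // Two contiguous "loads" covering [Base, Base + 2 * NumLanes).
  std::vector<int> Contig(Mem.begin() + Base,
                          Mem.begin() + Base + 2 * NumLanes);
  // UZP1: keep the even lanes, i.e. the original stride-2 elements.
  std::vector<int> Res;
  for (unsigned I = 0; I < 2 * NumLanes; I += 2)
    Res.push_back(Contig[I]);
  return Res;
}

int main() {
  std::vector<int> Mem = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
  assert((stride2Gather(Mem, 1, 4) == std::vector<int>{1, 3, 5, 7}));
  return 0;
}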
+ auto EltCount = DAG.getVScale(DL, MVT::i64, + VT.getVectorNumElements() << ShiftAmt); + IdxVal = DAG.getNode(ISD::SHL, DL, BVT, IdxVal, + DAG.getConstant(ShiftAmt, DL, MVT::i64)); + + // Perform the actual loads, making sure we use the chain value + // from the first in the second. + BasePtr = DAG.getNode(ISD::ADD, DL, BVT, BasePtr, IdxVal); + auto LoadL = DAG.getMaskedLoad(VT, DL, Chain, BasePtr, PredL, + PassThrough, MemVT, + N->getMemOperand(), + N->getExtensionType()); + + BasePtr = DAG.getNode(ISD::ADD, DL, BVT, BasePtr, EltCount); + auto LoadH = DAG.getMaskedLoad(VT, DL, LoadL.getValue(1), BasePtr, + PredH, PassThrough, MemVT, + N->getMemOperand(), + N->getExtensionType()); + + // Combine the loaded lanes into the single vector we'd get from + // the original gather. + auto Res = DAG.getNode(AArch64ISD::UZP1, DL, VT, + LoadL.getValue(0), LoadH.getValue(0)); + + // Make sure we return both the loaded values and the chain from the + // last load. + return DAG.getMergeValues({ Res, LoadH.getValue(1) }, DL); + } + } + } + + return SDValue(); +} + +static SDValue performMSCATTERCombine(MaskedScatterSDNode *N, + TargetLowering::DAGCombinerInfo &DCI) { + SDLoc DL(N); + SelectionDAG &DAG = DCI.DAG; + + SDValue Index = N->getIndex(); + SDValue Scale = N->getScale(); + SDValue Chain = N->getChain(); + SDValue Data = N->getValue(); + SDValue Mask = N->getMask(); + SDValue BasePtr = N->getBasePtr(); + + EVT VT = Data.getValueType(); + EVT IVT = Index.getValueType(); + EVT PVT = Mask.getValueType(); + EVT BVT = BasePtr.getValueType(); + EVT IEltVT = IVT.getVectorElementType(); + EVT MemVT = N->getMemoryVT(); + ISD::MemIndexType IndexType = N->getIndexType(); + + if (DCI.isBeforeLegalize()) { + // SVE gather/scatter requires indices of i32/i64. Promote anything smaller + // prior to legalisation so the result can be split if required. + if ((IVT.getVectorElementType() == MVT::i8) || + (IVT.getVectorElementType() == MVT::i16)) { + EVT NewIVT = IVT.changeVectorElementType(MVT::i32); + if (N->isIndexSigned()) + Index = DAG.getNode(ISD::SIGN_EXTEND, DL, NewIVT, Index); + else + Index = DAG.getNode(ISD::ZERO_EXTEND, DL, NewIVT, Index); + + SDValue Ops[] = { Chain, Data, Mask, BasePtr, Index, Scale }; + return DAG.getMaskedScatter(DAG.getVTList(MVT::Other), + N->getMemoryVT(), DL, Ops, + N->getMemOperand(), N->isTruncatingStore(), + N->getIndexType()); + } + + // Some uses of MSCATTER naturally lead to illegal types. For example, + // unsigned indicies are typically zero extended because MSCATTER's + // default IndexType is SIGNED_SCALED. Unsigned offsets within nxv4i32 + // become signed offsets within nxv4i64. The latter is illegal and triggers + // splitvec style legalisation that's very difficult to undo. + // + // Here we catch such cases early and change MSCATTER's IndexType to allow + // the use of an Index that's more legalisation friendly. + if (FindMoreOptimalIndexType(BasePtr, Index, IndexType)) { + SDValue Ops[] = { Chain, Data, Mask, BasePtr, Index, Scale }; + return DAG.getMaskedScatter(DAG.getVTList(MVT::Other), + N->getMemoryVT(), DL, Ops, + N->getMemOperand(), N->isTruncatingStore(), + IndexType); + } + } else if (EnableSGToContiguousXForm) { + // After everything has been legalized, we want to recognize scatters with + // a constant stride of two and convert to two contiguous stores with + // shuffles and masking. This hurts a little bit on 128b, but is quite + // beneficial for greater register widths. 
+ // + // We don't currently consider strides greater than two because that + // wouldn't fit nicely in 128b. A stride of four could work with 32b + // element types, but that should only be implemented if found to be + // common enough to warrant it. + // + // TODO: If we try this on a vector legalized from a wider type, things + // seem to go wrong; not sure why this is yet. For now, we bail out by + // checking whether the mask was extracted (SC-1425). + if (Index->getOpcode() == ISD::SERIES_VECTOR && + isa(Index->getOperand(1)) && + Mask->getOpcode() != ISD::EXTRACT_SUBVECTOR && + IndexType == ISD::SIGNED_SCALED) { + auto Step = cast(Index->getOperand(1))->getSExtValue(); + if (Step == 2) { + unsigned ShiftAmt; + + switch(MemVT.getSimpleVT().SimpleTy) { + case MVT::nxv2i64: + case MVT::nxv2f64: + ShiftAmt = 3; + break; + case MVT::nxv4i32: + case MVT::nxv4f32: + ShiftAmt = 2; + break; + default: + return SDValue(); + } + + bool ExtendIdx; + switch (IEltVT.getSimpleVT().SimpleTy) { + case MVT::i64: + ExtendIdx = false; + break; + case MVT::i32: + ExtendIdx = true; + break; + default: + return SDValue(); + } + + // Get the first index value. + auto IdxVal = Index->getOperand(0); + + // If the indices are 32b, we need to extend to 64b. + if (ExtendIdx) + IdxVal = DAG.getNode(ISD::SIGN_EXTEND, DL, MVT::i64, IdxVal); + + // Split the mask into two parts and interleave with false + auto PFalse = SDValue(DAG.getMachineNode(AArch64::PFALSE, DL, PVT), 0); + auto PredL = DAG.getNode(AArch64ISD::ZIP1, DL, PVT, Mask, PFalse); + auto PredH = DAG.getNode(AArch64ISD::ZIP2, DL, PVT, Mask, PFalse); + + // Figure out the offsets needed for the two stores -- scale the + // first index up based on element size, then scale up a count of + // the number of elements in the vector. + auto EltCount = DAG.getVScale(DL, MVT::i64, + VT.getVectorNumElements() << ShiftAmt); + IdxVal = DAG.getNode(ISD::SHL, DL, BVT, IdxVal, + DAG.getConstant(ShiftAmt, DL, MVT::i64)); + + // As with the mask, split the data into two parts and interleave; + // Here we can just use the data itself for the other lanes, since + // the inactive lanes won't be stored. + auto DataL = DAG.getNode(AArch64ISD::ZIP1, DL, VT, Data, Data); + auto DataH = DAG.getNode(AArch64ISD::ZIP2, DL, VT, Data, Data); + + // Perform the actual stores, making sure we use the chain value + // from the first in the second. + BasePtr = DAG.getNode(ISD::ADD, DL, BVT, BasePtr, IdxVal); + auto StoreL = DAG.getMaskedStore(Chain, DL, DataL, BasePtr, PredL, + MemVT, N->getMemOperand(), + N->isTruncatingStore()); + + BasePtr = DAG.getNode(ISD::ADD, DL, BVT, BasePtr, EltCount); + auto StoreH = DAG.getMaskedStore(StoreL, DL, DataH, BasePtr, PredH, + MemVT, N->getMemOperand(), + N->isTruncatingStore()); + + return StoreH; + } + } + } + + return SDValue(); +} + +static SDValue performVSelectMinMaxRdxCombine(SDNode *N, SelectionDAG &DAG) { + // A vselect ((setcc op1, op2, gt), op2, op1) is doing a min selection, + // while a cc of 'lt' will do a max. + // There will be an AND as well if this is coming from vectorized code. + assert(N->getOpcode() == ISD::VSELECT && "Unexpected opcode"); + SDValue Pred = N->getOperand(0); + SDValue True = N->getOperand(1); + SDValue False = N->getOperand(2); + EVT ResVT = N->getValueType(0); + + if (Pred->getOpcode() != ISD::AND) + return SDValue(); + + SDValue PredOp1 = Pred->getOperand(0); + SDValue PredOp2 = Pred->getOperand(1); + // One of the operands should be a compare. 
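// --- Illustrative aside (not part of the patch): a scalar model of the
// min/max pattern matched above. select(a > b, b, a) always yields the
// smaller value, so a vselect whose condition compares its own select
// operands with 'gt'/'ge' is a lane-wise fmin, while the 'lt'/'le' forms
// are fmax.
#include <cassert>

static double selectGt(double A, double B) { return A > B ? B : A; }

int main() {
  assert(selectGt(3.0, 1.0) == 1.0); // behaves as fmin
  assert(selectGt(1.0, 3.0) == 1.0);
  return 0;
}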
+ SDValue SetCC; + SDValue GP; + if (PredOp1->getOpcode() == ISD::SETCC) { + SetCC = PredOp1; + GP = PredOp2; + } else if (PredOp2->getOpcode() == ISD::SETCC) { + SetCC = PredOp2; + GP = PredOp1; + } else + return SDValue(); + + SDValue CmpOp1 = SetCC->getOperand(0); + SDValue CmpOp2 = SetCC->getOperand(1); + SDValue CC = SetCC->getOperand(2); + + // The compare operands must match the select sources. + if (False != CmpOp1 || True != CmpOp2) + return SDValue(); + + ISD::CondCode CondCode = cast(CC)->get(); + + assert(CmpOp1.getValueType().isVector() && "Unexpected vt"); + EVT ScalarVT = CmpOp1.getValueType().getVectorElementType(); + if (ScalarVT != MVT::f32 && ScalarVT != MVT::f64) + return SDValue(); + + bool IsMin = false; + bool NoNaN = DAG.getTarget().Options.NoNaNsFPMath; + switch (CondCode) { + case ISD::SETGT: + case ISD::SETGE: + IsMin = true; + break; + case ISD::SETLT: + case ISD::SETLE: + break; + default: + return SDValue(); + } + unsigned Opcode = IsMin ? AArch64ISD::FMIN_PRED : AArch64ISD::FMAX_PRED; + + // Use faster NM versions if we have NoNaNs. + if (NoNaN) { + if (Opcode == AArch64ISD::FMIN_PRED) + Opcode = AArch64ISD::FMINNM_PRED; + else + Opcode = AArch64ISD::FMAXNM_PRED; + } + return DAG.getNode(Opcode, SDLoc(N), ResVT, GP, False, True); +} + +// If the node is an extension from a legal SVE type to something wider, +// use HiOpcode and LoOpcode to extend each half individually, then +// concatenate them together. +void AArch64TargetLowering::ReplaceExtensionResults( + SDNode *N, SmallVectorImpl &Results, SelectionDAG &DAG, + unsigned HiOpcode, unsigned LoOpcode) const { + SDValue In = N->getOperand(0); + EVT InVT = In.getValueType(); + assert(InVT.isScalableVector() && "Can only lower WA vectors"); + if (!isTypeLegal(InVT)) + return; + + EVT InEltVT = InVT.getVectorElementType(); + auto EltCnt = InVT.getVectorElementCount(); + unsigned InEltBits = InEltVT.getSizeInBits(); + if (InEltBits != 8 && InEltBits != 16 && InEltBits != 32) + return; + + // The result must be at least twice as wide as the input in order for + // this to work. + EVT VT = N->getValueType(0); + EVT EltVT = VT.getVectorElementType(); + if (EltVT.getSizeInBits() < InEltBits * 2) + return; + + // Extend In to a double-width vector. + SDLoc dl(N); + EVT NewEltVT = EVT::getIntegerVT(*DAG.getContext(), InEltBits * 2); + EVT NewVT = EVT::getVectorVT(*DAG.getContext(), NewEltVT, EltCnt/2); + SDValue Lo = DAG.getNode(LoOpcode, dl, NewVT, In); + SDValue Hi = DAG.getNode(HiOpcode, dl, NewVT, In); + assert(isTypeLegal(NewVT) && "Extension result should be legal"); + + // If necessary, extend again using the original code. Such extensions + // will also need legalizing, but at least we're making forward progress. + NewVT = EVT::getVectorVT(*DAG.getContext(), EltVT, EltCnt/2); + Lo = DAG.getNode(N->getOpcode(), dl, NewVT, Lo); + Hi = DAG.getNode(N->getOpcode(), dl, NewVT, Hi); + Results.push_back(DAG.getNode(ISD::CONCAT_VECTORS, dl, VT, Lo, Hi)); +} + +void AArch64TargetLowering::ReplaceExtractSubVectorResults( + SDNode *N, SmallVectorImpl &Results, SelectionDAG &DAG) const { + SDValue In = N->getOperand(0); + EVT InVT = In.getValueType(); + + // Common code will handle these just fine. + if (!InVT.isScalableVector() || !InVT.isInteger()) + return; + + SDLoc dl(N); + EVT VT = N->getValueType(0); + + if (!isTypeLegal(InVT)) { + // Bubble truncates to illegal types to the surface. 
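// --- Illustrative aside (not part of the patch): a simplified model of the
// half-by-half widening used by ReplaceExtensionResults above. When extending
// the whole vector in one step is illegal, the low and high halves are
// extended separately (unpack-lo / unpack-hi style) and the two wide halves
// are concatenated, producing the same lanes as a single wide extension.
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

static std::vector<int32_t> extendHalf(const std::vector<int16_t> &In,
                                       bool HighHalf) {
  std::size_t Half = In.size() / 2;
  std::size_t Begin = HighHalf ? Half : 0;
  return std::vector<int32_t>(In.begin() + Begin, In.begin() + Begin + Half);
}

int main() {
  std::vector<int16_t> In = {1, -2, 3, -4};
  std::vector<int32_t> Lo = extendHalf(In, false); // {1, -2}
  std::vector<int32_t> Hi = extendHalf(In, true);  // {3, -4}
  Lo.insert(Lo.end(), Hi.begin(), Hi.end());       // concat_vectors
  assert((Lo == std::vector<int32_t>{1, -2, 3, -4}));
  return 0;
}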
+ if (In->getOpcode() == ISD::TRUNCATE) { + EVT TruncOpVT = In->getOperand(0)->getValueType(0); + if (!isTypeLegal(TruncOpVT)) + return; + + EVT EltVT = TruncOpVT.getVectorElementType(); + EVT SubVecVT = VT.changeVectorElementType(EltVT); + + SDValue SubVec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, SubVecVT, + In->getOperand(0), N->getOperand(1)); + + Results.push_back(DAG.getNode(ISD::TRUNCATE, dl, VT, SubVec)); + return; + } + + return; + } + + // The following checks bail if this is not a halving operation. + + if (InVT.getVectorNumElements() != (VT.getVectorNumElements()*2)) + return; + + auto *CIndex = dyn_cast(N->getOperand(1)); + if (!CIndex) + return; + + unsigned Index = CIndex->getZExtValue(); + if ((Index != 0) && (Index != VT.getVectorNumElements())) + return; + + unsigned Opcode = (Index == 0) ? AArch64ISD::UUNPKLO : AArch64ISD::UUNPKHI; + EVT ExtendedHalfVT = VT.widenIntegerVectorElementType(*DAG.getContext()); + + SDValue Half = DAG.getNode(Opcode, dl, ExtendedHalfVT, N->getOperand(0)); + Results.push_back(DAG.getNode(ISD::TRUNCATE, dl, VT, Half)); +} + +void AArch64TargetLowering::ReplaceInsertSubVectorResults( + SDNode *N, SmallVectorImpl &Results, SelectionDAG &DAG) const { + SDLoc DL(N); + SDValue Vec0 = N->getOperand(0); + SDValue Vec1 = N->getOperand(1); + SDValue Idx = N->getOperand(2); + EVT VT = N->getValueType(0); + EVT Vec0VT = Vec0.getValueType(); + EVT Vec1VT = Vec1.getValueType(); + + if (!VT.isScalableVector() || !VT.isInteger()) + return; + + unsigned NumElts = Vec1VT.getVectorNumElements(); + + // Can only handle double width + if (Vec0VT.getVectorNumElements() != (Vec1VT.getVectorNumElements() * 2)) + return; + + // Can only handle upper/lower half + auto *CIdx = dyn_cast(Idx); + if (!CIdx) + return; + + unsigned IdxVal = CIdx->getZExtValue(); + + // Extract appropriate half of larger vector, then concat with smaller vector. + if (IdxVal == 0) { + SDValue HiVec0 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, Vec1VT, Vec0, + DAG.getConstant(NumElts, DL, + Idx.getValueType())); + SDValue ConcatVec = DAG.getNode(ISD::CONCAT_VECTORS, DL, + Vec0VT, Vec1, HiVec0); + Results.push_back(ConcatVec); + } else if (IdxVal == NumElts) { + SDValue LoVec0 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, Vec1VT, Vec0, + DAG.getConstant(0, DL, Idx.getValueType())); + SDValue ConcatVec = DAG.getNode(ISD::CONCAT_VECTORS, DL, + Vec0VT, LoVec0, Vec1); + Results.push_back(ConcatVec); + } + + return; +} + +// Lower illegal vector element insert into compare + select instructions. +// The types/operations that we create here are illegal and will be legalized +// separately. 
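// --- Illustrative aside (not part of the patch): a standalone model of the
// compare + select lowering implemented below. Lane i of the result takes
// the new element exactly when the step vector {0, 1, 2, ...} matches the
// splatted insertion index; every other lane keeps its original value.
#include <cassert>
#include <vector>

static std::vector<int> insertViaSelect(std::vector<int> Vec, int Elt,
                                        unsigned Idx) {
  for (unsigned I = 0; I < Vec.size(); ++I)
    if (I == Idx)   // setcc(seriesvector, splat(Idx), eq)
      Vec[I] = Elt; // vselect picks the splatted element for this lane
  return Vec;
}

int main() {
  assert((insertViaSelect({10, 20, 30, 40}, 99, 2) ==
          std::vector<int>{10, 20, 99, 40}));
  return 0;
}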
+void AArch64TargetLowering::ReplaceInsertVectorElementResults( + SDNode *N, SmallVectorImpl &Results, SelectionDAG &DAG) const { + SDLoc DL(N); + + // Get all operands + SDValue InVec = N->getOperand(0); + SDValue InElem = N->getOperand(1); + SDValue InIdx = N->getOperand(2); + + assert(InVec.getValueType().isScalableVector()); + + // Element type + EVT VT = InVec.getValueType(); + EVT EltVT = VT.getVectorElementType(); + + // #Elements #Bits + unsigned NumElts = VT.getVectorNumElements(); + unsigned EltBits = EltVT.getSizeInBits(); + + // Only support lowering for scalable types and index + if (isa(InIdx) || + NumElts * EltBits <= AArch64::SVEBitsPerBlock) + return; + + // Get types for splats + EVT SplatVT = VT.changeVectorElementTypeToInteger(); + EVT IntEltVT = SplatVT.getVectorElementType(); + EVT BoolVT = EVT::getIntegerVT(*DAG.getContext(), 1); + EVT PredVT = SplatVT.changeVectorElementType(BoolVT); + + // Create splats + SDValue Splat = DAG.getNode(ISD::SPLAT_VECTOR, DL, SplatVT, + DAG.getNode(ISD::BITCAST, DL, IntEltVT, InElem)); + SDValue SplatIdx = DAG.getNode(ISD::SPLAT_VECTOR, DL, SplatVT, + DAG.getZExtOrTrunc(InIdx, DL, IntEltVT)); + + // Insert using compare mask and select + SDValue Zero = DAG.getConstant(0, DL, IntEltVT); + SDValue One = DAG.getConstant(1, DL, IntEltVT); + SDValue Seq = DAG.getNode(ISD::SERIES_VECTOR, DL, SplatVT, Zero, One); + SDValue Cmp = DAG.getNode(ISD::SETCC, DL, PredVT, + SplatIdx, Seq, DAG.getCondCode(ISD::SETEQ)); + SDValue Res = DAG.getNode(ISD::VSELECT, DL, SplatVT, Cmp, Splat, InVec); + Res = DAG.getNode(ISD::BITCAST, DL, InVec.getValueType(), Res); + + // Leave splitting of illegal types for seriesvector and vselect + // to LLVM for further processing. + Results.push_back(Res); +} +void AArch64TargetLowering::ReplaceMergeVecCpyResults(SDNode *N, + SmallVectorImpl &Results, SelectionDAG &DAG) const { + assert(N->getValueType(0).isScalableVector() && "Can only lower WA vectors"); + + SDValue Pred = N->getOperand(1); + SDValue Scalar = N->getOperand(2); + EVT ScalarVT = Scalar->getValueType(0); + + EVT VecLoVT, VecHiVT; + SDValue VecLo, VecHi; + SDLoc dl(N); + std::tie(VecLoVT, VecHiVT) = DAG.GetSplitDestVTs(N->getValueType(0)); + std::tie(VecLo, VecHi) = DAG.SplitVectorOperand(N, 0); + + EVT PredLoVT, PredHiVT; + SDValue PredLo, PredHi; + std::tie(PredLoVT, PredHiVT) = DAG.GetSplitDestVTs(Pred->getValueType(0)); + std::tie(PredLo, PredHi) = DAG.SplitVectorOperand(N, 1); + + SDVTList LoVTs = DAG.getVTList(VecLoVT, PredLoVT, ScalarVT); + SDValue LoOps[] = { VecLo, PredLo, Scalar }; + SDValue LoCpy = DAG.getNode(AArch64ISD::DUP_PRED, dl, LoVTs, LoOps); + + SDVTList HiVTs = DAG.getVTList(VecHiVT, PredHiVT, ScalarVT); + SDValue HiOps[] = { VecHi, PredHi, Scalar }; + SDValue HiCpy = DAG.getNode(AArch64ISD::DUP_PRED, dl, HiVTs, HiOps); + + SDValue Concat = DAG.getNode(ISD::CONCAT_VECTORS, dl, N->getValueType(0), + LoCpy, HiCpy); + Results.push_back(Concat); +} + +void AArch64TargetLowering::ReplaceMaskedSpecLoadResults(SDNode *N, + SmallVectorImpl &Results, SelectionDAG &DAG) const { + EVT ValVT = N->getValueType(0); + EVT PredVT = N->getValueType(1); + assert(PredVT.isScalableVector() && "Can only lower WA vectors"); + assert(ValVT.isScalableVector() && "Can only lower WA vectors"); + + // not expecting a pass through value + if (N->getOperand(5).getOpcode() != ISD::UNDEF) + return; + + EVT EltVT = ValVT.getVectorElementType(); + unsigned NumElts = ValVT.getVectorNumElements(); + + EVT NewVT = ValVT; + if (EltVT.isInteger()) { + int NewEltBits = 
AArch64::SVEBitsPerBlock / NumElts; + EVT NewEltVT = EVT::getIntegerVT(*DAG.getContext(), NewEltBits); + NewVT = ValVT.changeVectorElementType(NewEltVT); + } + + if (!isTypeLegal(NewVT)) + return; + + SDLoc dl(N); + SDVTList NodeTys = DAG.getVTList(MVT::Other, MVT::Glue); + SDValue SetFFR = DAG.getNode(AArch64ISD::SETFFR, dl, + NodeTys, N->getOperand(0)); + SDValue LoadOps[] = { SetFFR, // Chain in + N->getOperand(4), // GP + N->getOperand(2), // Address + DAG.getValueType(ValVT), // MemVT + SDValue(SetFFR.getNode(),1) }; // Glue in + SDVTList LoadVTs = DAG.getVTList(NewVT, MVT::Other); + SDValue Load = DAG.getNode(AArch64ISD::LDFF1, dl, LoadVTs, LoadOps); + SDValue LoadChain = SDValue(Load.getNode(), 1); + SDValue PredOps[] = { LoadChain, N->getOperand(4) }; // Chain in, GP + SDVTList PredVTs = DAG.getVTList(PredVT, MVT::Other); + SDValue FaultPred = DAG.getNode(AArch64ISD::RDFFR_PRED, dl, PredVTs, PredOps); + SDValue PredChain = SDValue(FaultPred.getNode(), 1); + + if (EltVT.isInteger()) + Results.push_back(DAG.getNode(ISD::TRUNCATE, dl, ValVT, Load)); + else + Results.push_back(Load); + + Results.push_back(FaultPred); + Results.push_back(PredChain); + return; +} + +void AArch64TargetLowering::ReplaceVectorShuffleVarResults( + SDNode *N, SmallVectorImpl &Results, SelectionDAG &DAG) const { + EVT VT = N->getValueType(0); + assert(VT.isScalableVector() && "Can only lower WA vectors"); + EVT EltVT = VT.getVectorElementType(); + unsigned NumElts = VT.getVectorNumElements(); + unsigned EltBits = EltVT.getSizeInBits(); + if (NumElts * EltBits <= AArch64::SVEBitsPerBlock || + AArch64::SVEBitsPerBlock % EltBits != 0) + return; + + unsigned NewNumElts = AArch64::SVEBitsPerBlock / EltBits; + EVT NewVT = EVT::getVectorVT(*DAG.getContext(), EltVT, { NewNumElts, true }); + if (!isTypeLegal(NewVT)) + return; + + Results.push_back(LowerVECTOR_SHUFFLE_VAR(SDValue(N, 0), DAG, + NumElts / NewNumElts, NewVT)); +} + +void AArch64TargetLowering::ReplaceFP_EXTENDResults(SDNode *N, + SmallVectorImpl &Results, + SelectionDAG &DAG) const { + EVT VT = N->getValueType(0); + EVT InVT = N->getOperand(0)->getValueType(0); + + // Let normal code split the vector before we get involved. 
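// --- Illustrative aside (not part of the patch): a conceptual model of the
// first-faulting sequence created by ReplaceMaskedSpecLoadResults above. The
// FFR is primed (SETFFR), the load fills lanes until one would fault (LDFF1),
// and the FFR is read back (RDFFR) as a predicate marking the lanes that hold
// valid data. The types below are invented for the illustration.
#include <cassert>
#include <cstddef>
#include <vector>

struct SpecLoadResult {
  std::vector<int> Data;
  std::vector<bool> ValidLanes; // models the predicate read back via RDFFR
};

static SpecLoadResult speculativeLoad(const std::vector<int> &Mem,
                                      std::size_t Base, std::size_t NumLanes) {
  SpecLoadResult R{std::vector<int>(NumLanes, 0),
                   std::vector<bool>(NumLanes, false)};
  for (std::size_t I = 0; I < NumLanes && Base + I < Mem.size(); ++I) {
    R.Data[I] = Mem[Base + I]; // this lane loads successfully...
    R.ValidLanes[I] = true;    // ...so its FFR bit stays set
  }
  return R;
}

int main() {
  SpecLoadResult R = speculativeLoad({1, 2, 3}, 1, 4);
  assert((R.ValidLanes == std::vector<bool>{true, true, false, false}));
  return 0;
}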
+ if (!isTypeLegal(InVT)) + return; + + assert(VT.isScalableVector() && "Can only lower WA vectors"); + SDLoc DL(N); + SDValue Op = N->getOperand(0); + EVT SplitVT = VT.getHalfNumVectorElementsVT(*DAG.getContext()); + EVT SplitInVT = InVT.getHalfNumVectorElementsVT(*DAG.getContext()); + + SDValue Zip1 = DAG.getNode(AArch64ISD::ZIP1, DL, InVT, Op, Op); + SDValue Zip2 = DAG.getNode(AArch64ISD::ZIP2, DL, InVT, Op, Op); + + SDValue Lo = DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, SplitInVT, Zip1); + SDValue Hi = DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, SplitInVT, Zip2); + + SDValue FCvtLo = DAG.getNode(N->getOpcode(), DL, SplitVT, Lo); + SDValue FCvtHi = DAG.getNode(N->getOpcode(), DL, SplitVT, Hi); + + Results.push_back(DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, FCvtLo, FCvtHi)); +} + +static SDValue +performFirstTrueTestVectorCombine(SDNode *N, + TargetLowering::DAGCombinerInfo &DCI, + const AArch64Subtarget *Subtarget) { + assert(N->getOpcode() == ISD::EXTRACT_VECTOR_ELT); + SelectionDAG &DAG = DCI.DAG; + + if (!Subtarget->hasSVE() || !DCI.isBeforeLegalize()) + return SDValue(); + + SDValue Op = N->getOperand(0); + EVT OpVT = Op.getValueType(); + + if (!OpVT.isScalableVector() || OpVT.getVectorElementType() != MVT::i1) + return SDValue(); + + auto *Idx = dyn_cast(N->getOperand(1)); + if (!Idx || Idx->getZExtValue() != 0) + return SDValue(); + + // Extracts of lane 0 for SVE can be expressed as PTEST(Op, FIRST) ? 1 : 0 + SDValue Pg = getPTrue(DAG, SDLoc(N), OpVT, AArch64SVEPredPattern::all); + return getPTest(DAG, N->getValueType(0), Pg, Op, AArch64CC::FIRST_ACTIVE); +} + +void AArch64TargetLowering::ReplaceVectorBITCASTResults(SDNode *N, + SmallVectorImpl &Results, + SelectionDAG &DAG) const { + SDLoc DL(N); + EVT VT = N->getValueType(0); + auto Src = N->getOperand(0); + + if (!VT.isScalableVector()) + return; + + if (!isTypeLegal(VT) && isTypeLegal(Src.getValueType())) { + assert(!VT.isFloatingPoint() && "Expected fp->int bitcast!"); + EVT ContainerVT = getNaturalIntSVETypeWithMatchingElementCount(VT); + auto Tmp = DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, ContainerVT, Src); + Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, VT, Tmp)); + } +} + +static SDValue +performLastTrueTestVectorCombine(SDNode *N, + TargetLowering::DAGCombinerInfo &DCI, + const AArch64Subtarget *Subtarget) { + assert(N->getOpcode() == ISD::EXTRACT_VECTOR_ELT); + SelectionDAG &DAG = DCI.DAG; + + if (!Subtarget->hasSVE() || !DCI.isBeforeLegalize()) + return SDValue(); + + SDValue Op = N->getOperand(0); + EVT OpVT = Op.getValueType(); + + if (!OpVT.isScalableVector() || OpVT.getVectorElementType() != MVT::i1) + return SDValue(); + + // Idx == (vscale * NumEls) - 1 + + SDValue Idx = N->getOperand(1); + if (Idx.getOpcode() != ISD::ADD) + return SDValue(); + + SDValue VS = Idx.getOperand(0); + if (VS.getOpcode() != ISD::VSCALE) + return SDValue(); + + unsigned NumEls = OpVT.getVectorNumElements(); + if (cast(VS.getOperand(0))->getSExtValue() != NumEls) + return SDValue(); + + auto *CI = dyn_cast(Idx.getOperand(1)); + if (!CI || CI->getSExtValue() != -1) + return SDValue(); + + // Extracts of lane EC-1 for SVE can be expressed as PTEST(Op, LAST) ? 
1 : 0
+  SDValue Pg = getPTrue(DAG, SDLoc(N), OpVT, AArch64SVEPredPattern::all);
+  return getPTest(DAG, N->getValueType(0), Pg, Op, AArch64CC::LAST_ACTIVE);
+}
+
+static SDValue performVScaleCombine(SDNode *N,
+                                    TargetLowering::DAGCombinerInfo &DCI,
+                                    SelectionDAG &DAG) {
+  assert(N->getOpcode() == ISD::VSCALE && "Expected ISD::VSCALE!");
+  if (!DCI.isAfterLegalizeDAG())
+    return SDValue();
+
+  // ISD::VSCALE has an embedded constant multiplier to simplify legalisation.
+  // However, for values that don't map to VL based arithmetic instructions
+  // (including loads and stores with a *VL based offset) a multiply is
+  // selected. This means we miss out on logic that replaces such multiplies
+  // with shifts, chained adds, etc. Here we expand the likely cases to make
+  // the multiplication explicit and thus optimisable.
+
+  int64_t MulImm = cast<ConstantSDNode>(N->getOperand(0))->getSExtValue();
+  if (MulImm % 2)
+    return SDValue();
+
+  if ((MulImm < (-32 * 16)) || (MulImm > (31 * 16))) {
+    SDLoc DL(N);
+    SDValue One = DAG.getTargetConstant(1, DL, MVT::i32);
+    SDValue All = DAG.getTargetConstant(31, DL, MVT::i32);
+
+    SDNode *Op;
+    if ((MulImm % 16) == 0) {
+      Op = DAG.getMachineNode(AArch64::RDVLI_XI, DL, MVT::i64, One);
+      MulImm = MulImm / 16;
+    } else if ((MulImm % 8) == 0) {
+      Op = DAG.getMachineNode(AArch64::CNTH_XPiI, DL, MVT::i64, All, One);
+      MulImm = MulImm / 8;
+    } else if ((MulImm % 4) == 0) {
+      Op = DAG.getMachineNode(AArch64::CNTW_XPiI, DL, MVT::i64, All, One);
+      MulImm = MulImm / 4;
+    } else {
+      assert((MulImm % 2) == 0);
+      Op = DAG.getMachineNode(AArch64::CNTD_XPiI, DL, MVT::i64, All, One);
+      MulImm = MulImm / 2;
+    }
+
+    SDValue Cnt = SDValue(Op, 0);
+    SDValue CntImm = DAG.getConstant(MulImm, DL, MVT::i64);
+    return DAG.getNode(ISD::MUL, DL, MVT::i64, Cnt, CntImm);
+  }
+
+  return SDValue();
+}
+
+// Here we perform late DAG transformations to make address generation more
+// amenable for SVE load/store instruction selection.
+static SDValue performSVEIndexedAddressingCombine(SDNode *N,
+                                                  TargetLowering::DAGCombinerInfo &DCI,
+                                                  SelectionDAG &DAG) {
+  if (!DCI.isAfterLegalizeDAG())
+    return SDValue();
+
+  MVT VT = N->getSimpleValueType(0);
+  if (!VT.isScalarInteger())
+    return SDValue();
+
+  if (N->getOpcode() != ISD::ADD)
+    return SDValue();
+
+  bool Skip = false;
+  MemSDNode *User = nullptr;
+
+  // Only split VSCALE multiplication when used for address generation.
+  for (auto UI = N->use_begin(), UE = N->use_end(); UI != UE; ++UI) {
+    auto MemAccess = dyn_cast<MemSDNode>(*UI);
+    if (MemAccess && MemAccess->getMemoryVT().isScalableVector()) {
+      if (User == nullptr)
+        User = MemAccess;
+
+      if (User->getMemoryVT() == MemAccess->getMemoryVT())
+        continue;
+    }
+
+    Skip = true;
+    break;
+  }
+
+  if (Skip)
+    return SDValue();
+
+  EVT MemVT = User->getMemoryVT();
+  unsigned MemEltSize = MemVT.getVectorElementType().getSizeInBits() / 8;
+
+  // Byte accesses can handle VSCALE multiplication as is.
+  if (MemEltSize == 1)
+    return SDValue();
+
+  // load(x + (vscale * C1)) -> load(x + (vscale * C2) * sizeof(elt))
+  if ((N->getOperand(1).getOpcode() == ISD::MUL) &&
+      (N->getOperand(1).getOperand(0).getOpcode() == ISD::VSCALE)) {
+    SDValue Mul = N->getOperand(1);
+    SDValue VS = N->getOperand(1).getOperand(0);
+
+    int64_t MulImm = cast<ConstantSDNode>(VS.getOperand(0))->getSExtValue();
+    if ((MulImm % MemEltSize) == 0) {
+      // Peel off the part of the VSCALE multiplication that, when combined
+      // with the add, will match against an indexed load/store.
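// --- Illustrative aside (not part of the patch): the arithmetic behind the
// peel-off described in the comment above, as a standalone check. When the
// VSCALE multiplier C1 is a multiple of the element size in bytes, the byte
// offset vscale * C1 can be re-expressed as (vscale * (C1 / EltBytes)) <<
// log2(EltBytes), so the shift can later fold into an indexed load/store.
#include <cassert>
#include <cstdint>

static int64_t peeledOffset(int64_t VScale, int64_t C1, int64_t EltBytes) {
  assert(C1 % EltBytes == 0 && "the combine only fires for whole elements");
  int64_t Log2EltBytes = 0;
  for (int64_t B = EltBytes; B > 1; B >>= 1)
    ++Log2EltBytes;
  return (VScale * (C1 / EltBytes)) << Log2EltBytes;
}

int main() {
  // vscale == 4, C1 == 64, 8-byte elements: both forms give 256 bytes.
  assert(peeledOffset(4, 64, 8) == 4 * 64);
  return 0;
}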
+ SDLoc DL(N); + SDValue Scale = DAG.getConstant(countTrailingZeros(MemEltSize), DL, + MVT::i64); + + SDValue NewVS = DAG.getVScale(DL, VT, MulImm / MemEltSize); + SDValue NewMul = DAG.getNode(ISD::MUL, DL, VT, NewVS, Mul.getOperand(1)); + SDValue NewShl = DAG.getNode(ISD::SHL, DL, VT, NewMul, Scale); + return DAG.getNode(ISD::ADD, DL, VT, N->getOperand(0), NewShl); + } + } + + // for: y = (idx * sizeof(vec)) + // loads((x + y) + z) -> load((x + z) + y) + if ((N->getOperand(0).getOpcode() == ISD::ADD) && + (N->getOperand(0).getOperand(1).getOpcode() == ISD::VSCALE)) { + SDValue X = N->getOperand(0).getOperand(0); + SDValue Y = N->getOperand(0).getOperand(1); + SDValue Z = N->getOperand(1); + + SDValue XplusZ = DAG.getNode(ISD::ADD, SDLoc(N), VT, X, Z); + return DAG.getNode(ISD::ADD, SDLoc(N), VT, XplusZ, Y); + } + + return SDValue(); +} Index: lib/Target/AArch64/SVEInstrFormats.td =================================================================== --- lib/Target/AArch64/SVEInstrFormats.td +++ lib/Target/AArch64/SVEInstrFormats.td @@ -167,39 +167,39 @@ def SVEAddSubImmOperand32 : SVEShiftedImmOperand<32, "AddSub", "isSVEAddSubImm">; def SVEAddSubImmOperand64 : SVEShiftedImmOperand<64, "AddSub", "isSVEAddSubImm">; -class imm8_opt_lsl - : Operand, ImmLeaf { + : Operand, ImmLeaf { let EncoderMethod = "getImm8OptLsl"; let DecoderMethod = "DecodeImm8OptLsl<" # ElementWidth # ">"; let PrintMethod = "printImm8OptLsl<" # printType # ">"; let ParserMatchClass = OpndClass; - let MIOperandInfo = (ops i32imm, i32imm); + let MIOperandInfo = (ops Op, Op); } -def cpy_imm8_opt_lsl_i8 : imm8_opt_lsl<8, "int8_t", SVECpyImmOperand8, [{ +def cpy_imm8_opt_lsl_i8 : imm8_opt_lsl(Imm); }]>; -def cpy_imm8_opt_lsl_i16 : imm8_opt_lsl<16, "int16_t", SVECpyImmOperand16, [{ +def cpy_imm8_opt_lsl_i16 : imm8_opt_lsl(Imm); }]>; -def cpy_imm8_opt_lsl_i32 : imm8_opt_lsl<32, "int32_t", SVECpyImmOperand32, [{ +def cpy_imm8_opt_lsl_i32 : imm8_opt_lsl(Imm); }]>; -def cpy_imm8_opt_lsl_i64 : imm8_opt_lsl<64, "int64_t", SVECpyImmOperand64, [{ +def cpy_imm8_opt_lsl_i64 : imm8_opt_lsl(Imm); }]>; -def addsub_imm8_opt_lsl_i8 : imm8_opt_lsl<8, "uint8_t", SVEAddSubImmOperand8, [{ +def addsub_imm8_opt_lsl_i8 : imm8_opt_lsl(Imm); }]>; -def addsub_imm8_opt_lsl_i16 : imm8_opt_lsl<16, "uint16_t", SVEAddSubImmOperand16, [{ +def addsub_imm8_opt_lsl_i16 : imm8_opt_lsl(Imm); }]>; -def addsub_imm8_opt_lsl_i32 : imm8_opt_lsl<32, "uint32_t", SVEAddSubImmOperand32, [{ +def addsub_imm8_opt_lsl_i32 : imm8_opt_lsl(Imm); }]>; -def addsub_imm8_opt_lsl_i64 : imm8_opt_lsl<64, "uint64_t", SVEAddSubImmOperand64, [{ +def addsub_imm8_opt_lsl_i64 : imm8_opt_lsl(Imm); }]>; @@ -234,16 +234,21 @@ let DecoderMethod = "DecodeSVEIncDecImm"; } +// This allows i32 immediate extraction from i64 based arithmetic. +def sve_cnt_mul_imm : ComplexPattern">; +def sve_cnt_shl_imm : ComplexPattern">; + //===----------------------------------------------------------------------===// // SVE PTrue - These are used extensively throughout the pattern matching so // it's important we define them first. 
//===----------------------------------------------------------------------===// -class sve_int_ptrue sz8_64, bits<3> opc, string asm, PPRRegOp pprty> +class sve_int_ptrue sz8_64, bits<3> opc, string asm, PPRRegOp pprty, + ValueType vt, SDPatternOperator op> : I<(outs pprty:$Pd), (ins sve_pred_enum:$pattern), asm, "\t$Pd, $pattern", "", - []>, Sched<[]> { + [(set (vt pprty:$Pd), (op sve_pred_enum:$pattern))]>, Sched<[]> { bits<4> Pd; bits<5> pattern; let Inst{31-24} = 0b00100101; @@ -257,13 +262,15 @@ let Inst{3-0} = Pd; let Defs = !if(!eq (opc{0}, 1), [NZCV], []); + let ElementSize = pprty.ElementSize; + let isReMaterializable = 1; } -multiclass sve_int_ptrue opc, string asm> { - def _B : sve_int_ptrue<0b00, opc, asm, PPR8>; - def _H : sve_int_ptrue<0b01, opc, asm, PPR16>; - def _S : sve_int_ptrue<0b10, opc, asm, PPR32>; - def _D : sve_int_ptrue<0b11, opc, asm, PPR64>; +multiclass sve_int_ptrue opc, string asm, SDPatternOperator op> { + def _B : sve_int_ptrue<0b00, opc, asm, PPR8, nxv16i1, op>; + def _H : sve_int_ptrue<0b01, opc, asm, PPR16, nxv8i1, op>; + def _S : sve_int_ptrue<0b10, opc, asm, PPR32, nxv4i1, op>; + def _D : sve_int_ptrue<0b11, opc, asm, PPR64, nxv2i1, op>; def : InstAlias(NAME # _B) PPR8:$Pd, 0b11111), 1>; @@ -275,11 +282,182 @@ (!cast(NAME # _D) PPR64:$Pd, 0b11111), 1>; } +def SDT_AArch64PTrue : SDTypeProfile<1, 1, [SDTCisVec<0>, SDTCisVT<1, i32>]>; +def AArch64ptrue : SDNode<"AArch64ISD::PTRUE", SDT_AArch64PTrue>; + let Predicates = [HasSVE] in { - defm PTRUE : sve_int_ptrue<0b000, "ptrue">; - defm PTRUES : sve_int_ptrue<0b001, "ptrues">; + defm PTRUE : sve_int_ptrue<0b000, "ptrue", AArch64ptrue>; + defm PTRUES : sve_int_ptrue<0b001, "ptrues", null_frag>; } +//===----------------------------------------------------------------------===// +// SVE pattern match helpers. +//===----------------------------------------------------------------------===// + +class SVE_1_Op_Pat +: Pat<(vtd (op vt1:$Op1)), + (inst $Op1)>; + +class SVE_2_Op_Imm_Pat +: Pat<(vtd (op vt1:$Op1, (vt2 ImmTy:$Op2))), + (inst $Op1, ImmTy:$Op2)>; + +class SVE_2_Op_Pat +: Pat<(vtd (op vt1:$Op1, vt2:$Op2)), + (inst $Op1, $Op2)>; + +class SVE_3_Op_Imm_Pat +: Pat<(vtd (op vt1:$Op1, vt2:$Op2, (vt3 ImmTy:$Op3))), + (inst $Op1, $Op2, ImmTy:$Op3)>; + +class SVE_3_Op_Pat +: Pat<(vtd (op vt1:$Op1, vt2:$Op2, vt3:$Op3)), + (inst $Op1, $Op2, $Op3)>; + +class SVE_4_Op_Imm_Pat +: Pat<(vtd (op vt1:$Op1, vt2:$Op2, vt3:$Op3, (vt4 ImmTy:$Op4))), + (inst $Op1, $Op2, $Op3, ImmTy:$Op4)>; + +class SVE_4_Op_Pat +: Pat<(vtd (op vt1:$Op1, vt2:$Op2, vt3:$Op3, vt4:$Op4)), + (inst $Op1, $Op2, $Op3, $Op4)>; + +def SVEDup0 : ComplexPattern; +def SVEDup0Undef : ComplexPattern; + +let AddedComplexity = 1 in { +class SVE_3_Op_Pat_SelZero +: Pat<(vtd (vtd (op vt1:$Op1, (vselect vt1:$Op1, vt2:$Op2, (SVEDup0)), vt3:$Op3))), + (inst $Op1, $Op2, $Op3)>; + +class SVE_3_Op_Pat_Shift_Imm_SelZero +: Pat<(vtd (op vt1:$Op1, (vselect vt1:$Op1, vt2:$Op2, (SVEDup0)), (i32 (vt3:$Op3)))), + (inst $Op1, $Op2, vt3:$Op3)>; + +class SVE_4_Op_Pat_SelZero +: Pat<(vtd (op vt1:$Op1, (vselect vt1:$Op1, vt2:$Op2, (SVEDup0)), vt3:$Op3, vt4:$Op4)), + (inst $Op1, $Op2, $Op3, $Op4)>; +} + +// +// Common but less generic patterns. +// + +class SVE_1_Op_AllActive_Pat +: Pat<(vtd (op vt1:$Op1)), + (inst (IMPLICIT_DEF), (ptrue 31), $Op1)>; + +class SVE_2_Op_AllActive_Pat +: Pat<(vtd (op vt1:$Op1, vt2:$Op2)), + (inst (ptrue 31), $Op1, $Op2)>; + +// +// Instruction specific patterns. 
+// + +class SVE_Cmp_Pat0 +: Pat<(pt (setcc vt:$Zn, vt:$Zm, cc)), + (inst (ptrue 31), $Zn, $Zm)>; + +class SVE_Cmp_Pat1 +: Pat<(pt (and pt:$Pg, (setcc vt:$Zn, vt:$Zm, cc))), + (inst $Pg, $Zn, $Zm)>; + +// +// Pseudo -> Instruction mappings +// +def getSVEPseudoMap : InstrMapping { + let FilterClass = "SVEPseudo2Instr"; + let RowFields = ["PseudoName"]; + let ColFields = ["IsInstr"]; + let KeyCol = ["0"]; + let ValueCols = [["1"]]; +} + +class SVEPseudo2Instr { + string PseudoName = name; + bit IsInstr = instr; +} + +def getSVERevInstr : InstrMapping { + let FilterClass = "SVEInstr2Rev"; + let RowFields = ["InstrName"]; + let ColFields = ["IsOrig"]; + let KeyCol = ["1"]; + let ValueCols = [["0"]]; +} + +def getSVEOrigInstr : InstrMapping { + let FilterClass = "SVEInstr2Rev"; + let RowFields = ["InstrName"]; + let ColFields = ["IsOrig"]; + let KeyCol = ["0"]; + let ValueCols = [["1"]]; +} + +class SVEInstr2Rev { + string InstrName = !if(nameIsOrig, name, revname); + bit IsOrig = nameIsOrig; +} + +// +// Pseudos for destructive operands +// +let hasNoSchedulingInfo = 1 in { + class UnpredTwoOpImmPseudo + : SVEPseudo2Instr, + Pseudo<(outs zprty:$Zd), (ins zprty:$Zs1, immty:$imm), []> { + let FalseLanes = flags; + } + + class PredTwoOpImmPseudo + : SVEPseudo2Instr, + Pseudo<(outs zprty:$Zd), (ins PPR3bAny:$Pg, zprty:$Zs1, immty:$imm), []> { + let FalseLanes = flags; + } + + class PredTwoOpPseudo + : SVEPseudo2Instr, + Pseudo<(outs zprty:$Zd), (ins PPR3bAny:$Pg, zprty:$Zs1, zprty:$Zs2), []> { + let FalseLanes = flags; + } + + class PredTwoOpConstrainedPseudo + : SVEPseudo2Instr, + Pseudo<(outs zprty:$Zd), (ins PPR3bAny:$Pg, zprty:$Zs1, zprty:$Zs2), []> { + let Constraints = "$Zd = $Zs1"; + let FalseLanes = flags; + } + + class PredThreeOpPseudo + : SVEPseudo2Instr, + Pseudo<(outs zprty:$Zd), (ins PPR3bAny:$Pg, zprty:$Zs1, zprty:$Zs2, zprty:$Zs3), []> { + let FalseLanes = flags; + } +} //===----------------------------------------------------------------------===// // SVE Predicate Misc Group @@ -319,14 +497,15 @@ let Inst{4-0} = 0b00000; let Defs = [NZCV]; + let isCompare = 1; } class sve_int_pfirst_next sz8_64, bits<5> opc, string asm, - PPRRegOp pprty> + PPRRegOp pprty, ValueType vt, SDPatternOperator op> : I<(outs pprty:$Pdn), (ins PPRAny:$Pg, pprty:$_Pdn), asm, "\t$Pdn, $Pg, $_Pdn", "", - []>, Sched<[]> { + [(set (vt pprty:$Pdn), (op (vt PPRAny:$Pg), (vt pprty:$_Pdn)))]>, Sched<[]> { bits<4> Pdn; bits<4> Pg; let Inst{31-24} = 0b00100101; @@ -343,15 +522,15 @@ let Defs = [NZCV]; } -multiclass sve_int_pfirst opc, string asm> { - def : sve_int_pfirst_next<0b01, opc, asm, PPR8>; +multiclass sve_int_pfirst opc, string asm, SDPatternOperator op> { + def : sve_int_pfirst_next<0b01, opc, asm, PPR8, nxv16i1, op>; } -multiclass sve_int_pnext opc, string asm> { - def _B : sve_int_pfirst_next<0b00, opc, asm, PPR8>; - def _H : sve_int_pfirst_next<0b01, opc, asm, PPR16>; - def _S : sve_int_pfirst_next<0b10, opc, asm, PPR32>; - def _D : sve_int_pfirst_next<0b11, opc, asm, PPR64>; +multiclass sve_int_pnext opc, string asm, SDPatternOperator op> { + def _B : sve_int_pfirst_next<0b00, opc, asm, PPR8, nxv16i1, op>; + def _H : sve_int_pfirst_next<0b01, opc, asm, PPR16, nxv8i1, op>; + def _S : sve_int_pfirst_next<0b10, opc, asm, PPR32, nxv4i1, op>; + def _D : sve_int_pfirst_next<0b11, opc, asm, PPR64, nxv2i1, op>; } //===----------------------------------------------------------------------===// @@ -382,25 +561,58 @@ let Constraints = "$Rdn = $_Rdn"; } -multiclass sve_int_count_r_s32 opc, string asm> { 
+multiclass sve_int_count_r_s32 opc, string asm, + SDPatternOperator op> { def _B : sve_int_count_r<0b00, opc, asm, GPR64z, PPR8, GPR64as32>; def _H : sve_int_count_r<0b01, opc, asm, GPR64z, PPR16, GPR64as32>; def _S : sve_int_count_r<0b10, opc, asm, GPR64z, PPR32, GPR64as32>; def _D : sve_int_count_r<0b11, opc, asm, GPR64z, PPR64, GPR64as32>; + + // NOTE: Register allocation doesn't like tied operands of differing register + // class, hence the extra INSERT_SUBREG complication. + + def : Pat<(i32 (op GPR32:$Rn, (nxv16i1 PPRAny:$Pg))), + (EXTRACT_SUBREG (!cast(NAME # _B) PPRAny:$Pg, (INSERT_SUBREG (IMPLICIT_DEF), $Rn, sub_32)), sub_32)>; + def : Pat<(i32 (op GPR32:$Rn, (nxv8i1 PPRAny:$Pg))), + (EXTRACT_SUBREG (!cast(NAME # _H) PPRAny:$Pg, (INSERT_SUBREG (IMPLICIT_DEF), $Rn, sub_32)), sub_32)>; + def : Pat<(i32 (op GPR32:$Rn, (nxv4i1 PPRAny:$Pg))), + (EXTRACT_SUBREG (!cast(NAME # _S) PPRAny:$Pg, (INSERT_SUBREG (IMPLICIT_DEF), $Rn, sub_32)), sub_32)>; + def : Pat<(i32 (op GPR32:$Rn, (nxv2i1 PPRAny:$Pg))), + (EXTRACT_SUBREG (!cast(NAME # _D) PPRAny:$Pg, (INSERT_SUBREG (IMPLICIT_DEF), $Rn, sub_32)), sub_32)>; } -multiclass sve_int_count_r_u32 opc, string asm> { +multiclass sve_int_count_r_u32 opc, string asm, + SDPatternOperator op> { def _B : sve_int_count_r<0b00, opc, asm, GPR32z, PPR8, GPR32z>; def _H : sve_int_count_r<0b01, opc, asm, GPR32z, PPR16, GPR32z>; def _S : sve_int_count_r<0b10, opc, asm, GPR32z, PPR32, GPR32z>; def _D : sve_int_count_r<0b11, opc, asm, GPR32z, PPR64, GPR32z>; + + def : Pat<(i32 (op GPR32:$Rn, (nxv16i1 PPRAny:$Pg))), + (!cast(NAME # _B) PPRAny:$Pg, $Rn)>; + def : Pat<(i32 (op GPR32:$Rn, (nxv8i1 PPRAny:$Pg))), + (!cast(NAME # _H) PPRAny:$Pg, $Rn)>; + def : Pat<(i32 (op GPR32:$Rn, (nxv4i1 PPRAny:$Pg))), + (!cast(NAME # _S) PPRAny:$Pg, $Rn)>; + def : Pat<(i32 (op GPR32:$Rn, (nxv2i1 PPRAny:$Pg))), + (!cast(NAME # _D) PPRAny:$Pg, $Rn)>; } -multiclass sve_int_count_r_x64 opc, string asm> { +multiclass sve_int_count_r_x64 opc, string asm, + SDPatternOperator op = null_frag> { def _B : sve_int_count_r<0b00, opc, asm, GPR64z, PPR8, GPR64z>; def _H : sve_int_count_r<0b01, opc, asm, GPR64z, PPR16, GPR64z>; def _S : sve_int_count_r<0b10, opc, asm, GPR64z, PPR32, GPR64z>; def _D : sve_int_count_r<0b11, opc, asm, GPR64z, PPR64, GPR64z>; + + def : Pat<(i64 (op GPR64:$Rn, (nxv16i1 PPRAny:$Pg))), + (!cast(NAME # _B) PPRAny:$Pg, $Rn)>; + def : Pat<(i64 (op GPR64:$Rn, (nxv8i1 PPRAny:$Pg))), + (!cast(NAME # _H) PPRAny:$Pg, $Rn)>; + def : Pat<(i64 (op GPR64:$Rn, (nxv4i1 PPRAny:$Pg))), + (!cast(NAME # _S) PPRAny:$Pg, $Rn)>; + def : Pat<(i64 (op GPR64:$Rn, (nxv2i1 PPRAny:$Pg))), + (!cast(NAME # _D) PPRAny:$Pg, $Rn)>; } class sve_int_count_v sz8_64, bits<5> opc, string asm, @@ -421,14 +633,18 @@ let Inst{4-0} = Zdn; let Constraints = "$Zdn = $_Zdn"; - let DestructiveInstType = Destructive; - let ElementSize = ElementSizeNone; + let DestructiveInstType = DestructiveOther; } -multiclass sve_int_count_v opc, string asm> { +multiclass sve_int_count_v opc, string asm, + SDPatternOperator op = null_frag> { def _H : sve_int_count_v<0b01, opc, asm, ZPR16>; def _S : sve_int_count_v<0b10, opc, asm, ZPR32>; def _D : sve_int_count_v<0b11, opc, asm, ZPR64>; + + def : SVE_2_Op_Pat(NAME # _H)>; + def : SVE_2_Op_Pat(NAME # _S)>; + def : SVE_2_Op_Pat(NAME # _D)>; } class sve_int_pcount_pred sz8_64, bits<4> opc, string asm, @@ -451,11 +667,23 @@ let Inst{4-0} = Rd; } -multiclass sve_int_pcount_pred opc, string asm> { +multiclass sve_int_pcount_pred opc, string asm, + SDPatternOperator int_op, + 
SDPatternOperator ir_op> { def _B : sve_int_pcount_pred<0b00, opc, asm, PPR8>; def _H : sve_int_pcount_pred<0b01, opc, asm, PPR16>; def _S : sve_int_pcount_pred<0b10, opc, asm, PPR32>; def _D : sve_int_pcount_pred<0b11, opc, asm, PPR64>; + + def : SVE_2_Op_Pat(NAME # _B)>; + def : SVE_2_Op_Pat(NAME # _H)>; + def : SVE_2_Op_Pat(NAME # _S)>; + def : SVE_2_Op_Pat(NAME # _D)>; + + def : Pat<(i64 (ir_op (nxv16i1 PPR:$Pn))), (!cast(NAME # _B) (PTRUE_B 31), PPR:$Pn)>; + def : Pat<(i64 (ir_op (nxv8i1 PPR:$Pn))), (!cast(NAME # _H) (PTRUE_H 31), PPR:$Pn)>; + def : Pat<(i64 (ir_op (nxv4i1 PPR:$Pn))), (!cast(NAME # _S) (PTRUE_S 31), PPR:$Pn)>; + def : Pat<(i64 (ir_op (nxv2i1 PPR:$Pn))), (!cast(NAME # _D) (PTRUE_D 31), PPR:$Pn)>; } //===----------------------------------------------------------------------===// @@ -478,15 +706,26 @@ let Inst{10} = opc{0}; let Inst{9-5} = pattern; let Inst{4-0} = Rd; + + let isReMaterializable = 1; } -multiclass sve_int_count opc, string asm> { +multiclass sve_int_count opc, string asm, SDPatternOperator op> { def NAME : sve_int_count; def : InstAlias(NAME) GPR64:$Rd, sve_pred_enum:$pattern, 1), 1>; def : InstAlias(NAME) GPR64:$Rd, 0b11111, 1), 2>; + + def : Pat<(i64 (mul (op sve_pred_enum:$pattern), (sve_cnt_mul_imm i32:$imm))), + (!cast(NAME) sve_pred_enum:$pattern, sve_incdec_imm:$imm)>; + + def : Pat<(i64 (shl (op sve_pred_enum:$pattern), (i64 (sve_cnt_shl_imm i32:$imm)))), + (!cast(NAME) sve_pred_enum:$pattern, sve_incdec_imm:$imm)>; + + def : Pat<(i64 (op sve_pred_enum:$pattern)), + (!cast(NAME) sve_pred_enum:$pattern, 1)>; } class sve_int_countvlv opc, string asm, ZPRRegOp zprty> @@ -508,17 +747,22 @@ let Inst{4-0} = Zdn; let Constraints = "$Zdn = $_Zdn"; - let DestructiveInstType = Destructive; - let ElementSize = ElementSizeNone; + let DestructiveInstType = DestructiveOther; } -multiclass sve_int_countvlv opc, string asm, ZPRRegOp zprty> { +multiclass sve_int_countvlv opc, string asm, ZPRRegOp zprty, + SDPatternOperator op = null_frag, + ValueType vt = OtherVT> { def NAME : sve_int_countvlv; def : InstAlias(NAME) zprty:$Zdn, sve_pred_enum:$pattern, 1), 1>; def : InstAlias(NAME) zprty:$Zdn, 0b11111, 1), 2>; + + def : Pat<(vt (op (vt zprty:$Zn), (sve_pred_enum:$pattern), (sve_incdec_imm:$imm4))), + (!cast(NAME) $Zn, sve_pred_enum:$pattern, sve_incdec_imm:$imm4)>; + } class sve_int_pred_pattern_a opc, string asm> @@ -541,13 +785,39 @@ let Constraints = "$Rdn = $_Rdn"; } -multiclass sve_int_pred_pattern_a opc, string asm> { +multiclass sve_int_pred_pattern_a opc, string asm, + SDPatternOperator op, + SDPatternOperator opcnt> { def NAME : sve_int_pred_pattern_a; def : InstAlias(NAME) GPR64:$Rdn, sve_pred_enum:$pattern, 1), 1>; def : InstAlias(NAME) GPR64:$Rdn, 0b11111, 1), 2>; + + def : Pat<(i64 (op GPR64:$Rdn, (opcnt sve_pred_enum:$pattern))), + (!cast(NAME) GPR64:$Rdn, sve_pred_enum:$pattern, 1)>; + + def : Pat<(i64 (op GPR64:$Rdn, (mul (opcnt sve_pred_enum:$pattern), (sve_cnt_mul_imm i32:$imm)))), + (!cast(NAME) GPR64:$Rdn, sve_pred_enum:$pattern, $imm)>; + + def : Pat<(i64 (op GPR64:$Rdn, (shl (opcnt sve_pred_enum:$pattern), (i64 (sve_cnt_shl_imm i32:$imm))))), + (!cast(NAME) GPR64:$Rdn, sve_pred_enum:$pattern, $imm)>; + + def : Pat<(i32 (op GPR32:$Rdn, (i32 (trunc (opcnt (sve_pred_enum:$pattern)))))), + (i32 (EXTRACT_SUBREG (!cast(NAME) (INSERT_SUBREG (i64 (IMPLICIT_DEF)), + GPR32:$Rdn, sub_32), sve_pred_enum:$pattern, 1), + sub_32))>; + + def : Pat<(i32 (op GPR32:$Rdn, (mul (i32 (trunc (opcnt (sve_pred_enum:$pattern)))), (sve_cnt_mul_imm i32:$imm)))), + (i32 
(EXTRACT_SUBREG (!cast(NAME) (INSERT_SUBREG (i64 (IMPLICIT_DEF)), + GPR32:$Rdn, sub_32), sve_pred_enum:$pattern, $imm), + sub_32))>; + + def : Pat<(i32 (op GPR32:$Rdn, (shl (i32 (trunc (opcnt (sve_pred_enum:$pattern)))), (i64 (sve_cnt_shl_imm i32:$imm))))), + (i32 (EXTRACT_SUBREG (!cast(NAME) (INSERT_SUBREG (i64 (IMPLICIT_DEF)), + GPR32:$Rdn, sub_32), sve_pred_enum:$pattern, $imm), + sub_32))>; } class sve_int_pred_pattern_b opc, string asm, RegisterOperand dt, @@ -577,31 +847,46 @@ let Constraints = "$Rdn = $_Rdn"; } -multiclass sve_int_pred_pattern_b_s32 opc, string asm> { +multiclass sve_int_pred_pattern_b_s32 opc, string asm, + SDPatternOperator op> { def NAME : sve_int_pred_pattern_b; def : InstAlias(NAME) GPR64z:$Rd, GPR64as32:$Rn, sve_pred_enum:$pattern, 1), 1>; def : InstAlias(NAME) GPR64z:$Rd, GPR64as32:$Rn, 0b11111, 1), 2>; + + // NOTE: Register allocation doesn't like tied operands of differing register + // class, hence the extra INSERT_SUBREG complication. + + def : Pat<(i32 (op GPR32:$Rn, (sve_pred_enum:$pattern), (sve_incdec_imm:$imm4))), + (EXTRACT_SUBREG (!cast(NAME) (INSERT_SUBREG (IMPLICIT_DEF), $Rn, sub_32), sve_pred_enum:$pattern, sve_incdec_imm:$imm4), sub_32)>; } -multiclass sve_int_pred_pattern_b_u32 opc, string asm> { +multiclass sve_int_pred_pattern_b_u32 opc, string asm, + SDPatternOperator op> { def NAME : sve_int_pred_pattern_b; def : InstAlias(NAME) GPR32z:$Rdn, sve_pred_enum:$pattern, 1), 1>; def : InstAlias(NAME) GPR32z:$Rdn, 0b11111, 1), 2>; + + def : Pat<(i32 (op GPR32:$Rn, (sve_pred_enum:$pattern), (sve_incdec_imm:$imm4))), + (!cast(NAME) $Rn, sve_pred_enum:$pattern, sve_incdec_imm:$imm4)>; } -multiclass sve_int_pred_pattern_b_x64 opc, string asm> { +multiclass sve_int_pred_pattern_b_x64 opc, string asm, + SDPatternOperator op> { def NAME : sve_int_pred_pattern_b; def : InstAlias(NAME) GPR64z:$Rdn, sve_pred_enum:$pattern, 1), 1>; def : InstAlias(NAME) GPR64z:$Rdn, 0b11111, 1), 2>; + + def : Pat<(i64 (op GPR64:$Rn, (sve_pred_enum:$pattern), (sve_incdec_imm:$imm4))), + (!cast(NAME) $Rn, sve_pred_enum:$pattern, sve_incdec_imm:$imm4)>; } @@ -701,8 +986,9 @@ (!cast(NAME # _Q) ZPR128:$Zd, FPR128asZPR:$Qn, 0), 2>; } -class sve_int_perm_tbl sz8_64, string asm, ZPRRegOp zprty, - RegisterOperand VecList> +class sve_int_perm_tbl sz8_64, bits<2> opc, string asm, + ZPRRegOp zprty, RegisterOperand VecList, + ValueType vt, SDPatternOperator op> : I<(outs zprty:$Zd), (ins VecList:$Zn, zprty:$Zm), asm, "\t$Zd, $Zn, $Zm", "", @@ -714,16 +1000,18 @@ let Inst{23-22} = sz8_64; let Inst{21} = 0b1; let Inst{20-16} = Zm; - let Inst{15-10} = 0b001100; + let Inst{15-13} = 0b001; + let Inst{12-11} = opc; + let Inst{10} = 0b0; let Inst{9-5} = Zn; let Inst{4-0} = Zd; } -multiclass sve_int_perm_tbl { - def _B : sve_int_perm_tbl<0b00, asm, ZPR8, Z_b>; - def _H : sve_int_perm_tbl<0b01, asm, ZPR16, Z_h>; - def _S : sve_int_perm_tbl<0b10, asm, ZPR32, Z_s>; - def _D : sve_int_perm_tbl<0b11, asm, ZPR64, Z_d>; +multiclass sve_int_perm_tbl { + def _B : sve_int_perm_tbl<0b00, 0b10, asm, ZPR8, Z_b, nxv16i8, op>; + def _H : sve_int_perm_tbl<0b01, 0b10, asm, ZPR16, Z_h, nxv8i16, op>; + def _S : sve_int_perm_tbl<0b10, 0b10, asm, ZPR32, Z_s, nxv4i32, op>; + def _D : sve_int_perm_tbl<0b11, 0b10, asm, ZPR64, Z_d, nxv2i64, op>; def : InstAlias(NAME # _B) ZPR8:$Zd, ZPR8:$Zn, ZPR8:$Zm), 0>; @@ -733,6 +1021,56 @@ (!cast(NAME # _S) ZPR32:$Zd, ZPR32:$Zn, ZPR32:$Zm), 0>; def : InstAlias(NAME # _D) ZPR64:$Zd, ZPR64:$Zn, ZPR64:$Zm), 0>; + + def : SVE_2_Op_Pat(NAME # _B)>; + def : SVE_2_Op_Pat(NAME # _H)>; + 
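+  // A hedged example of what the sve_int_pred_pattern_a folds above are
+  // aiming for, assuming op = add, opcnt = a CNTD-style count node, and that
+  // the intrinsic spelling below is the one in use:
+  //   %n = call i64 @llvm.aarch64.sve.cntd(i32 31)   ; pattern 31 = "all"
+  //   %s = shl i64 %n, 1
+  //   %r = add i64 %x, %s
+  // should select to a single "incd x0, all, mul #2": the left shift by N is
+  // matched as a multiply by 2^N through sve_cnt_shl_imm, and the multiplier
+  // is carried in the instruction's immediate.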
def : SVE_2_Op_Pat(NAME # _S)>; + def : SVE_2_Op_Pat(NAME # _D)>; + def : SVE_2_Op_Pat(NAME # _H)>; + def : SVE_2_Op_Pat(NAME # _S)>; + def : SVE_2_Op_Pat(NAME # _D)>; +} + +multiclass sve2_int_perm_tbl { + def _B : sve_int_perm_tbl<0b00, 0b01, asm, ZPR8, ZZ_b, nxv16i8, null_frag>; + def _H : sve_int_perm_tbl<0b01, 0b01, asm, ZPR16, ZZ_h, nxv8i16, null_frag>; + def _S : sve_int_perm_tbl<0b10, 0b01, asm, ZPR32, ZZ_s, nxv4i32, null_frag>; + def _D : sve_int_perm_tbl<0b11, 0b01, asm, ZPR64, ZZ_d, nxv2i64, null_frag>; +} + +class sve2_int_perm_tbx sz8_64, string asm, ZPRRegOp zprty> +: I<(outs zprty:$Zd), (ins zprty:$_Zd, zprty:$Zn, zprty:$Zm), + asm, "\t$Zd, $Zn, $Zm", + "", + []>, Sched<[]> { + bits<5> Zd; + bits<5> Zm; + bits<5> Zn; + let Inst{31-24} = 0b00000101; + let Inst{23-22} = sz8_64; + let Inst{21} = 0b1; + let Inst{20-16} = Zm; + let Inst{15-10} = 0b001011; + let Inst{9-5} = Zn; + let Inst{4-0} = Zd; + + let Constraints = "$Zd = $_Zd"; +} + +multiclass sve2_int_perm_tbx { + def _B : sve2_int_perm_tbx<0b00, asm, ZPR8>; + def _H : sve2_int_perm_tbx<0b01, asm, ZPR16>; + def _S : sve2_int_perm_tbx<0b10, asm, ZPR32>; + def _D : sve2_int_perm_tbx<0b11, asm, ZPR64>; + + def : SVE_3_Op_Pat(NAME # _B)>; + def : SVE_3_Op_Pat(NAME # _H)>; + def : SVE_3_Op_Pat(NAME # _S)>; + def : SVE_3_Op_Pat(NAME # _D)>; + + def : SVE_3_Op_Pat(NAME # _H)>; + def : SVE_3_Op_Pat(NAME # _S)>; + def : SVE_3_Op_Pat(NAME # _D)>; } class sve_int_perm_reverse_z sz8_64, string asm, ZPRRegOp zprty> @@ -749,18 +1087,30 @@ let Inst{4-0} = Zd; } -multiclass sve_int_perm_reverse_z { +multiclass sve_int_perm_reverse_z { def _B : sve_int_perm_reverse_z<0b00, asm, ZPR8>; def _H : sve_int_perm_reverse_z<0b01, asm, ZPR16>; def _S : sve_int_perm_reverse_z<0b10, asm, ZPR32>; def _D : sve_int_perm_reverse_z<0b11, asm, ZPR64>; + + def : SVE_1_Op_Pat(NAME # _B)>; + def : SVE_1_Op_Pat(NAME # _H)>; + def : SVE_1_Op_Pat(NAME # _S)>; + def : SVE_1_Op_Pat(NAME # _D)>; + + def : SVE_1_Op_Pat(NAME # _H)>; + def : SVE_1_Op_Pat(NAME # _S)>; + def : SVE_1_Op_Pat(NAME # _D)>; + + def : SVE_1_Op_Pat(NAME # _D)>; } -class sve_int_perm_reverse_p sz8_64, string asm, PPRRegOp pprty> +class sve_int_perm_reverse_p sz8_64, string asm, PPRRegOp pprty, + ValueType vt, SDPatternOperator op> : I<(outs pprty:$Pd), (ins pprty:$Pn), asm, "\t$Pd, $Pn", "", - []>, Sched<[]> { + [(set pprty:$Pd, (vt (op (vt pprty:$Pn))))]>, Sched<[]> { bits<4> Pd; bits<4> Pn; let Inst{31-24} = 0b00000101; @@ -771,18 +1121,20 @@ let Inst{3-0} = Pd; } -multiclass sve_int_perm_reverse_p { - def _B : sve_int_perm_reverse_p<0b00, asm, PPR8>; - def _H : sve_int_perm_reverse_p<0b01, asm, PPR16>; - def _S : sve_int_perm_reverse_p<0b10, asm, PPR32>; - def _D : sve_int_perm_reverse_p<0b11, asm, PPR64>; +multiclass sve_int_perm_reverse_p { + def _B : sve_int_perm_reverse_p<0b00, asm, PPR8 , nxv16i1, op>; + def _H : sve_int_perm_reverse_p<0b01, asm, PPR16, nxv8i1, op>; + def _S : sve_int_perm_reverse_p<0b10, asm, PPR32, nxv4i1, op>; + def _D : sve_int_perm_reverse_p<0b11, asm, PPR64, nxv2i1, op>; } class sve_int_perm_unpk sz16_64, bits<2> opc, string asm, - ZPRRegOp zprty1, ZPRRegOp zprty2> + ZPRRegOp zprty1, ZPRRegOp zprty2, + ValueType out_vt, ValueType in_vt, SDPatternOperator op> : I<(outs zprty1:$Zd), (ins zprty2:$Zn), asm, "\t$Zd, $Zn", - "", []>, Sched<[]> { + "", + [(set zprty1:$Zd, (out_vt (op (in_vt zprty2:$Zn))))]>, Sched<[]> { bits<5> Zd; bits<5> Zn; let Inst{31-24} = 0b00000101; @@ -794,10 +1146,10 @@ let Inst{4-0} = Zd; } -multiclass sve_int_perm_unpk opc, string asm> { - 
def _H : sve_int_perm_unpk<0b01, opc, asm, ZPR16, ZPR8>; - def _S : sve_int_perm_unpk<0b10, opc, asm, ZPR32, ZPR16>; - def _D : sve_int_perm_unpk<0b11, opc, asm, ZPR64, ZPR32>; +multiclass sve_int_perm_unpk opc, string asm, SDPatternOperator op> { + def _H : sve_int_perm_unpk<0b01, opc, asm, ZPR16, ZPR8, nxv8i16, nxv16i8, op>; + def _S : sve_int_perm_unpk<0b10, opc, asm, ZPR32, ZPR16, nxv4i32, nxv8i16, op>; + def _D : sve_int_perm_unpk<0b11, opc, asm, ZPR64, ZPR32, nxv2i64, nxv4i32, op>; } class sve_int_perm_insrs sz8_64, string asm, ZPRRegOp zprty, @@ -815,15 +1167,19 @@ let Inst{4-0} = Zdn; let Constraints = "$Zdn = $_Zdn"; - let DestructiveInstType = Destructive; - let ElementSize = ElementSizeNone; + let DestructiveInstType = DestructiveOther; } -multiclass sve_int_perm_insrs { +multiclass sve_int_perm_insrs { def _B : sve_int_perm_insrs<0b00, asm, ZPR8, GPR32>; def _H : sve_int_perm_insrs<0b01, asm, ZPR16, GPR32>; def _S : sve_int_perm_insrs<0b10, asm, ZPR32, GPR32>; def _D : sve_int_perm_insrs<0b11, asm, ZPR64, GPR64>; + + def : SVE_2_Op_Pat(NAME # _B)>; + def : SVE_2_Op_Pat(NAME # _H)>; + def : SVE_2_Op_Pat(NAME # _S)>; + def : SVE_2_Op_Pat(NAME # _D)>; } class sve_int_perm_insrv sz8_64, string asm, ZPRRegOp zprty, @@ -841,25 +1197,29 @@ let Inst{4-0} = Zdn; let Constraints = "$Zdn = $_Zdn"; - let DestructiveInstType = Destructive; - let ElementSize = ElementSizeNone; + let DestructiveInstType = DestructiveOther; } -multiclass sve_int_perm_insrv { +multiclass sve_int_perm_insrv { def _B : sve_int_perm_insrv<0b00, asm, ZPR8, FPR8>; def _H : sve_int_perm_insrv<0b01, asm, ZPR16, FPR16>; def _S : sve_int_perm_insrv<0b10, asm, ZPR32, FPR32>; def _D : sve_int_perm_insrv<0b11, asm, ZPR64, FPR64>; + + def : SVE_2_Op_Pat(NAME # _H)>; + def : SVE_2_Op_Pat(NAME # _S)>; + def : SVE_2_Op_Pat(NAME # _D)>; } //===----------------------------------------------------------------------===// // SVE Permute - Extract Group //===----------------------------------------------------------------------===// -class sve_int_perm_extract_i +class sve_int_perm_extract_i : I<(outs ZPR8:$Zdn), (ins ZPR8:$_Zdn, ZPR8:$Zm, imm0_255:$imm8), asm, "\t$Zdn, $_Zdn, $Zm, $imm8", - "", []>, Sched<[]> { + "", + [(set ZPR8:$Zdn, (nxv16i8 (op (nxv16i8 ZPR8:$_Zdn), (nxv16i8 ZPR8:$Zm), (imm0_255:$imm8))))]>, Sched<[]> { bits<5> Zdn; bits<5> Zm; bits<8> imm8; @@ -871,8 +1231,22 @@ let Inst{4-0} = Zdn; let Constraints = "$Zdn = $_Zdn"; - let DestructiveInstType = Destructive; - let ElementSize = ElementSizeNone; + let DestructiveInstType = DestructiveOther; +} + +class sve2_int_perm_extract_i_cons +: I<(outs ZPR8:$Zd), (ins ZZ_b:$Zn, imm0_255:$imm8), + asm, "\t$Zd, $Zn, $imm8", + "", []>, Sched<[]> { + bits<5> Zd; + bits<5> Zn; + bits<8> imm8; + let Inst{31-21} = 0b00000101011; + let Inst{20-16} = imm8{7-3}; + let Inst{15-13} = 0b000; + let Inst{12-10} = imm8{2-0}; + let Inst{9-5} = Zn; + let Inst{4-0} = Zd; } //===----------------------------------------------------------------------===// @@ -898,12 +1272,22 @@ let Inst{4-0} = Zd; } -multiclass sve_int_sel_vvv { +multiclass sve_int_sel_vvv { def _B : sve_int_sel_vvv<0b00, asm, ZPR8>; def _H : sve_int_sel_vvv<0b01, asm, ZPR16>; def _S : sve_int_sel_vvv<0b10, asm, ZPR32>; def _D : sve_int_sel_vvv<0b11, asm, ZPR64>; + def : SVE_3_Op_Pat(NAME # _B)>; + def : SVE_3_Op_Pat(NAME # _H)>; + def : SVE_3_Op_Pat(NAME # _S)>; + def : SVE_3_Op_Pat(NAME # _D)>; + + def : SVE_3_Op_Pat(NAME # _H)>; + def : SVE_3_Op_Pat(NAME # _S)>; + def : SVE_3_Op_Pat(NAME # _D)>; + def : SVE_3_Op_Pat(NAME # 
_D)>; + def : InstAlias<"mov $Zd, $Pg/m, $Zn", (!cast(NAME # _B) ZPR8:$Zd, PPRAny:$Pg, ZPR8:$Zn, ZPR8:$Zd), 1>; def : InstAlias<"mov $Zd, $Pg/m, $Zn", @@ -947,15 +1331,27 @@ let Defs = !if(!eq (opc{2}, 1), [NZCV], []); } +multiclass sve_int_pred_log opc, string asm, + SDPatternOperator op = null_frag> { + def NAME : sve_int_pred_log; + + def : SVE_3_Op_Pat(NAME)>; + def : SVE_3_Op_Pat(NAME)>; + def : SVE_3_Op_Pat(NAME)>; + def : SVE_3_Op_Pat(NAME)>; +} + //===----------------------------------------------------------------------===// // SVE Logical Mask Immediate Group //===----------------------------------------------------------------------===// -class sve_int_log_imm opc, string asm> +class sve_int_log_imm opc, string asm, SDPatternOperator op> : I<(outs ZPR64:$Zdn), (ins ZPR64:$_Zdn, logical_imm64:$imms13), asm, "\t$Zdn, $_Zdn, $imms13", - "", []>, Sched<[]> { + "", + []>,//[(set ZPR64:$Zdn, (nxv2i64 (op (nxv2i64 ZPR64:$_Zdn), (logical_imm64:$imms13))))]>, + Sched<[]> { bits<5> Zdn; bits<13> imms13; let Inst{31-24} = 0b00000101; @@ -965,13 +1361,13 @@ let Inst{4-0} = Zdn; let Constraints = "$Zdn = $_Zdn"; + let DestructiveInstType = DestructiveOther; let DecoderMethod = "DecodeSVELogicalImmInstruction"; - let DestructiveInstType = Destructive; - let ElementSize = ElementSizeNone; } -multiclass sve_int_log_imm opc, string asm, string alias> { - def NAME : sve_int_log_imm; +multiclass sve_int_log_imm opc, string asm, SDPatternOperator op, + string alias> { + def NAME : sve_int_log_imm; def : InstAlias(NAME) ZPR8:$Zdn, sve_logical_imm8:$imm), 4>; @@ -1029,10 +1425,11 @@ //===----------------------------------------------------------------------===// class sve_int_bin_cons_arit_0 sz8_64, bits<3> opc, string asm, - ZPRRegOp zprty> + ZPRRegOp zprty, ValueType vt, SDPatternOperator op> : I<(outs zprty:$Zd), (ins zprty:$Zn, zprty:$Zm), asm, "\t$Zd, $Zn, $Zm", - "", []>, Sched<[]> { + "", + [(set (vt zprty:$Zd), (op (vt zprty:$Zn), (vt zprty:$Zm)))]>, Sched<[]> { bits<5> Zd; bits<5> Zm; bits<5> Zn; @@ -1046,11 +1443,12 @@ let Inst{4-0} = Zd; } -multiclass sve_int_bin_cons_arit_0 opc, string asm> { - def _B : sve_int_bin_cons_arit_0<0b00, opc, asm, ZPR8>; - def _H : sve_int_bin_cons_arit_0<0b01, opc, asm, ZPR16>; - def _S : sve_int_bin_cons_arit_0<0b10, opc, asm, ZPR32>; - def _D : sve_int_bin_cons_arit_0<0b11, opc, asm, ZPR64>; +multiclass sve_int_bin_cons_arit_0 opc, string asm, + SDPatternOperator op> { + def _B : sve_int_bin_cons_arit_0<0b00, opc, asm, ZPR8, nxv16i8, op>; + def _H : sve_int_bin_cons_arit_0<0b01, opc, asm, ZPR16, nxv8i16, op>; + def _S : sve_int_bin_cons_arit_0<0b10, opc, asm, ZPR32, nxv4i32, op>; + def _D : sve_int_bin_cons_arit_0<0b11, opc, asm, ZPR64, nxv2i64, op>; } //===----------------------------------------------------------------------===// @@ -1078,14 +1476,26 @@ let Inst{4-0} = Zdn; let Constraints = "$Zdn = $_Zdn"; - let DestructiveInstType = Destructive; let ElementSize = zprty.ElementSize; } -multiclass sve_fp_2op_i_p_zds opc, string asm, Operand imm_ty> { - def _H : sve_fp_2op_i_p_zds<0b01, opc, asm, ZPR16, imm_ty>; - def _S : sve_fp_2op_i_p_zds<0b10, opc, asm, ZPR32, imm_ty>; - def _D : sve_fp_2op_i_p_zds<0b11, opc, asm, ZPR64, imm_ty>; +multiclass sve_fp_2op_i_p_zds opc, string asm, string Ps, + Operand imm_ty> { + let DestructiveInstType = DestructiveBinaryImm in { + def _H : SVEPseudo2Instr, sve_fp_2op_i_p_zds<0b01, opc, asm, ZPR16, imm_ty>; + def _S : SVEPseudo2Instr, sve_fp_2op_i_p_zds<0b10, opc, asm, ZPR32, imm_ty>; + def _D : SVEPseudo2Instr, 
sve_fp_2op_i_p_zds<0b11, opc, asm, ZPR64, imm_ty>; + } +} + +multiclass sve_fp_2op_i_p_zds_zx { + def _UNDEF_H : PredTwoOpImmPseudo; + def _UNDEF_S : PredTwoOpImmPseudo; + def _UNDEF_D : PredTwoOpImmPseudo; + + def _ZERO_H : PredTwoOpImmPseudo; + def _ZERO_S : PredTwoOpImmPseudo; + def _ZERO_D : PredTwoOpImmPseudo; } class sve_fp_2op_p_zds sz, bits<4> opc, string asm, @@ -1107,18 +1517,67 @@ let Inst{4-0} = Zdn; let Constraints = "$Zdn = $_Zdn"; - let DestructiveInstType = Destructive; let ElementSize = zprty.ElementSize; } -multiclass sve_fp_2op_p_zds opc, string asm> { - def _H : sve_fp_2op_p_zds<0b01, opc, asm, ZPR16>; - def _S : sve_fp_2op_p_zds<0b10, opc, asm, ZPR32>; - def _D : sve_fp_2op_p_zds<0b11, opc, asm, ZPR64>; +multiclass sve_fp_2op_p_zds opc, string asm, string Ps, + SDPatternOperator op, DestructiveInstTypeEnum flags, + string revname="", bit isOrig=0> { + let DestructiveInstType = flags in { + def _H : sve_fp_2op_p_zds<0b01, opc, asm, ZPR16>, + SVEPseudo2Instr, SVEInstr2Rev; + def _S : sve_fp_2op_p_zds<0b10, opc, asm, ZPR32>, + SVEPseudo2Instr, SVEInstr2Rev; + def _D : sve_fp_2op_p_zds<0b11, opc, asm, ZPR64>, + SVEPseudo2Instr, SVEInstr2Rev; + } + + def : SVE_3_Op_Pat(NAME # _H)>; + def : SVE_3_Op_Pat(NAME # _S)>; + def : SVE_3_Op_Pat(NAME # _D)>; } +multiclass sve_fp_2op_p_zds_fscale opc, string asm, + string Ps, SDPatternOperator op> { + let DestructiveInstType = DestructiveBinary in { + def _H : SVEPseudo2Instr, sve_fp_2op_p_zds<0b01, opc, asm, ZPR16>; + def _S : SVEPseudo2Instr, sve_fp_2op_p_zds<0b10, opc, asm, ZPR32>; + def _D : SVEPseudo2Instr, sve_fp_2op_p_zds<0b11, opc, asm, ZPR64>; + } + + def : SVE_3_Op_Pat(NAME # _H)>; + def : SVE_3_Op_Pat(NAME # _S)>; + def : SVE_3_Op_Pat(NAME # _D)>; +} + + +multiclass sve_fp_2op_p_zds_zx { + def _UNDEF_H : PredTwoOpPseudo; + def _UNDEF_S : PredTwoOpPseudo; + def _UNDEF_D : PredTwoOpPseudo; + + def _ZERO_H : PredTwoOpPseudo; + def _ZERO_S : PredTwoOpPseudo; + def _ZERO_D : PredTwoOpPseudo; + + def : SVE_3_Op_Pat_SelZero(NAME # _ZERO_H)>; + def : SVE_3_Op_Pat_SelZero(NAME # _ZERO_S)>; + def : SVE_3_Op_Pat_SelZero(NAME # _ZERO_D)>; +} + +multiclass sve_fp_2op_p_zds_fscale_zx { + def _ZERO_H : PredTwoOpConstrainedPseudo; + def _ZERO_S : PredTwoOpConstrainedPseudo; + def _ZERO_D : PredTwoOpConstrainedPseudo; + + def : SVE_3_Op_Pat_SelZero(NAME # _ZERO_H)>; + def : SVE_3_Op_Pat_SelZero(NAME # _ZERO_S)>; + def : SVE_3_Op_Pat_SelZero(NAME # _ZERO_D)>; +} + + class sve_fp_ftmad sz, string asm, ZPRRegOp zprty> -: I<(outs zprty:$Zdn), (ins zprty:$_Zdn, zprty:$Zm, imm0_7:$imm3), +: I<(outs zprty:$Zdn), (ins zprty:$_Zdn, zprty:$Zm, imm32_0_7:$imm3), asm, "\t$Zdn, $_Zdn, $Zm, $imm3", "", []>, Sched<[]> { @@ -1134,26 +1593,34 @@ let Inst{4-0} = Zdn; let Constraints = "$Zdn = $_Zdn"; - let DestructiveInstType = Destructive; + let DestructiveInstType = DestructiveOther; let ElementSize = ElementSizeNone; } -multiclass sve_fp_ftmad { +multiclass sve_fp_ftmad { def _H : sve_fp_ftmad<0b01, asm, ZPR16>; def _S : sve_fp_ftmad<0b10, asm, ZPR32>; def _D : sve_fp_ftmad<0b11, asm, ZPR64>; -} + def : Pat<(nxv8f16 (op (nxv8f16 ZPR16:$Zn), (nxv8f16 ZPR16:$Zm), (i32 imm32_0_7:$imm))), + (!cast(NAME # _H) ZPR16:$Zn, ZPR16:$Zm, imm32_0_7:$imm)>; + def : Pat<(nxv4f32 (op (nxv4f32 ZPR32:$Zn), (nxv4f32 ZPR32:$Zm), (i32 imm32_0_7:$imm))), + (!cast(NAME # _S) ZPR32:$Zn, ZPR32:$Zm, imm32_0_7:$imm)>; + def : Pat<(nxv2f64 (op (nxv2f64 ZPR64:$Zn), (nxv2f64 ZPR64:$Zm), (i32 imm32_0_7:$imm))), + (!cast(NAME # _D) ZPR64:$Zn, ZPR64:$Zm, imm32_0_7:$imm)>; +} 
//===----------------------------------------------------------------------===// // SVE Floating Point Arithmetic - Unpredicated Group //===----------------------------------------------------------------------===// class sve_fp_3op_u_zd sz, bits<3> opc, string asm, - ZPRRegOp zprty> + ZPRRegOp zprty, + ValueType vt, ValueType vt2, SDPatternOperator op> : I<(outs zprty:$Zd), (ins zprty:$Zn, zprty:$Zm), asm, "\t$Zd, $Zn, $Zm", - "", []>, Sched<[]> { + "", + [(set (vt zprty:$Zd), (op (vt zprty:$Zn), (vt2 zprty:$Zm)))]>, Sched<[]> { bits<5> Zd; bits<5> Zm; bits<5> Zn; @@ -1167,10 +1634,17 @@ let Inst{4-0} = Zd; } -multiclass sve_fp_3op_u_zd opc, string asm> { - def _H : sve_fp_3op_u_zd<0b01, opc, asm, ZPR16>; - def _S : sve_fp_3op_u_zd<0b10, opc, asm, ZPR32>; - def _D : sve_fp_3op_u_zd<0b11, opc, asm, ZPR64>; +multiclass sve_fp_3op_u_zd opc, string asm, SDPatternOperator op> { + def _H : sve_fp_3op_u_zd<0b01, opc, asm, ZPR16, nxv8f16, nxv8f16, op>; + def _S : sve_fp_3op_u_zd<0b10, opc, asm, ZPR32, nxv4f32, nxv4f32, op>; + def _D : sve_fp_3op_u_zd<0b11, opc, asm, ZPR64, nxv2f64, nxv2f64, op>; +} + +multiclass sve_fp_3op_u_zd_ftsmul opc, string asm, + SDPatternOperator op> { + def _H : sve_fp_3op_u_zd<0b01, opc, asm, ZPR16, nxv8f16, nxv8i16, op>; + def _S : sve_fp_3op_u_zd<0b10, opc, asm, ZPR32, nxv4f32, nxv4i32, op>; + def _D : sve_fp_3op_u_zd<0b11, opc, asm, ZPR64, nxv2f64, nxv2i64, op>; } //===----------------------------------------------------------------------===// @@ -1197,14 +1671,24 @@ let Inst{4-0} = Zda; let Constraints = "$Zda = $_Zda"; - let DestructiveInstType = Destructive; let ElementSize = zprty.ElementSize; } -multiclass sve_fp_3op_p_zds_a opc, string asm> { - def _H : sve_fp_3op_p_zds_a<0b01, opc, asm, ZPR16>; - def _S : sve_fp_3op_p_zds_a<0b10, opc, asm, ZPR32>; - def _D : sve_fp_3op_p_zds_a<0b11, opc, asm, ZPR64>; +multiclass sve_fp_3op_p_zds_a opc, string asm, string Ps, + SDPatternOperator op, string revname="", + bit isOrig=0> { + let DestructiveInstType = DestructiveTernaryCommWithRev in { + def _H : sve_fp_3op_p_zds_a<0b01, opc, asm, ZPR16>, + SVEPseudo2Instr, SVEInstr2Rev; + def _S : sve_fp_3op_p_zds_a<0b10, opc, asm, ZPR32>, + SVEPseudo2Instr, SVEInstr2Rev; + def _D : sve_fp_3op_p_zds_a<0b11, opc, asm, ZPR64>, + SVEPseudo2Instr, SVEInstr2Rev; + } + + def : SVE_4_Op_Pat(NAME # _H)>; + def : SVE_4_Op_Pat(NAME # _S)>; + def : SVE_4_Op_Pat(NAME # _D)>; } class sve_fp_3op_p_zds_b sz, bits<2> opc, string asm, @@ -1228,16 +1712,47 @@ let Inst{4-0} = Zdn; let Constraints = "$Zdn = $_Zdn"; - let DestructiveInstType = Destructive; + let DestructiveInstType = DestructiveOther; let ElementSize = zprty.ElementSize; } -multiclass sve_fp_3op_p_zds_b opc, string asm> { - def _H : sve_fp_3op_p_zds_b<0b01, opc, asm, ZPR16>; - def _S : sve_fp_3op_p_zds_b<0b10, opc, asm, ZPR32>; - def _D : sve_fp_3op_p_zds_b<0b11, opc, asm, ZPR64>; +multiclass sve_fp_3op_p_zds_b opc, string asm, SDPatternOperator op, + string revname, bit isOrig> { + def _H : sve_fp_3op_p_zds_b<0b01, opc, asm, ZPR16>, + SVEInstr2Rev; + def _S : sve_fp_3op_p_zds_b<0b10, opc, asm, ZPR32>, + SVEInstr2Rev; + def _D : sve_fp_3op_p_zds_b<0b11, opc, asm, ZPR64>, + SVEInstr2Rev; + + def : SVE_4_Op_Pat(NAME # _H)>; + def : SVE_4_Op_Pat(NAME # _S)>; + def : SVE_4_Op_Pat(NAME # _D)>; } +multiclass sve_fp_3op_p_zds_zx { + def _UNDEF_H : PredThreeOpPseudo; + def _UNDEF_S : PredThreeOpPseudo; + def _UNDEF_D : PredThreeOpPseudo; + + def _ZERO_H : PredThreeOpPseudo; + def _ZERO_S : PredThreeOpPseudo; + def _ZERO_D : PredThreeOpPseudo; 
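+  // A short, hedged note on the *_zx multiclasses: the _UNDEF pseudos are for
+  // cases where the inactive lanes of the result may be left undefined, while
+  // the _ZERO pseudos require them to be zeroed, which is why only the _ZERO
+  // variants get SVE_*_Op_Pat_SelZero patterns below. The zeroing shows up in
+  // IR as a select of one input against zero, roughly (a sketch, value names
+  // assumed):
+  //   %acc0 = select <vscale x 8 x i1> %pg, <vscale x 8 x half> %acc,
+  //                  <vscale x 8 x half> zeroinitializer
+  //   %res  = <fma-like node> %pg, %acc0, %zn, %zm
+  // The explicit rev_op patterns that follow match the same shape with the
+  // node's operands in a different order and commute them back (output order
+  // $Op1, $Op4, $Op2, $Op3) so the pseudo's tied operand receives the
+  // intended value.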
+ + def : SVE_4_Op_Pat_SelZero(NAME # _ZERO_H)>; + def : SVE_4_Op_Pat_SelZero(NAME # _ZERO_S)>; + def : SVE_4_Op_Pat_SelZero(NAME # _ZERO_D)>; + + // As above but with the accumulator in it's alternative position. + def : Pat<(nxv8f16 (rev_op nxv8i1:$Op1, (vselect nxv8i1:$Op1, nxv8f16:$Op2, (SVEDup0)), nxv8f16:$Op3, nxv8f16:$Op4)), + (!cast(NAME # _ZERO_H) $Op1, $Op4, $Op2, $Op3)>; + def : Pat<(nxv4f32 (rev_op nxv4i1:$Op1, (vselect nxv4i1:$Op1, nxv4f32:$Op2, (SVEDup0)), nxv4f32:$Op3, nxv4f32:$Op4)), + (!cast(NAME # _ZERO_S) $Op1, $Op4, $Op2, $Op3)>; + def : Pat<(nxv2f64 (rev_op nxv2i1:$Op1, (vselect nxv2i1:$Op1, nxv2f64:$Op2, (SVEDup0)), nxv2f64:$Op3, nxv2f64:$Op4)), + (!cast(NAME # _ZERO_D) $Op1, $Op4, $Op2, $Op3)>; +} + + //===----------------------------------------------------------------------===// // SVE Floating Point Multiply-Add - Indexed Group //===----------------------------------------------------------------------===// @@ -1258,30 +1773,37 @@ let Inst{4-0} = Zda; let Constraints = "$Zda = $_Zda"; - let DestructiveInstType = Destructive; - let ElementSize = ElementSizeNone; + let DestructiveInstType = DestructiveOther; } -multiclass sve_fp_fma_by_indexed_elem { - def _H : sve_fp_fma_by_indexed_elem<{0, ?}, opc, asm, ZPR16, ZPR3b16, VectorIndexH> { +multiclass sve_fp_fma_by_indexed_elem { + def _H : sve_fp_fma_by_indexed_elem<{0, ?}, opc, asm, ZPR16, ZPR3b16, VectorIndexH32b> { bits<3> Zm; bits<3> iop; let Inst{22} = iop{2}; let Inst{20-19} = iop{1-0}; let Inst{18-16} = Zm; } - def _S : sve_fp_fma_by_indexed_elem<0b10, opc, asm, ZPR32, ZPR3b32, VectorIndexS> { + def _S : sve_fp_fma_by_indexed_elem<0b10, opc, asm, ZPR32, ZPR3b32, VectorIndexS32b> { bits<3> Zm; bits<2> iop; let Inst{20-19} = iop; let Inst{18-16} = Zm; } - def _D : sve_fp_fma_by_indexed_elem<0b11, opc, asm, ZPR64, ZPR4b64, VectorIndexD> { + def _D : sve_fp_fma_by_indexed_elem<0b11, opc, asm, ZPR64, ZPR4b64, VectorIndexD32b> { bits<4> Zm; bit iop; let Inst{20} = iop; let Inst{19-16} = Zm; } + + def : Pat<(nxv8f16 (op nxv8f16:$Op1, nxv8f16:$Op2, nxv8f16:$Op3, (i32 VectorIndexH32b:$idx))), + (!cast(NAME # _H) $Op1, $Op2, $Op3, VectorIndexH32b:$idx)>; + def : Pat<(nxv4f32 (op nxv4f32:$Op1, nxv4f32:$Op2, nxv4f32:$Op3, (i32 VectorIndexS32b:$idx))), + (!cast(NAME # _S) $Op1, $Op2, $Op3, VectorIndexS32b:$idx)>; + def : Pat<(nxv2f64 (op nxv2f64:$Op1, nxv2f64:$Op2, nxv2f64:$Op3, (i32 VectorIndexD32b:$idx))), + (!cast(NAME # _D) $Op1, $Op2, $Op3, VectorIndexD32b:$idx)>; } @@ -1303,28 +1825,36 @@ let Inst{4-0} = Zd; } -multiclass sve_fp_fmul_by_indexed_elem { - def _H : sve_fp_fmul_by_indexed_elem<{0, ?}, asm, ZPR16, ZPR3b16, VectorIndexH> { +multiclass sve_fp_fmul_by_indexed_elem { + def _H : sve_fp_fmul_by_indexed_elem<{0, ?}, asm, ZPR16, ZPR3b16, VectorIndexH32b> { bits<3> Zm; bits<3> iop; let Inst{22} = iop{2}; let Inst{20-19} = iop{1-0}; let Inst{18-16} = Zm; } - def _S : sve_fp_fmul_by_indexed_elem<0b10, asm, ZPR32, ZPR3b32, VectorIndexS> { + def _S : sve_fp_fmul_by_indexed_elem<0b10, asm, ZPR32, ZPR3b32, VectorIndexS32b> { bits<3> Zm; bits<2> iop; let Inst{20-19} = iop; let Inst{18-16} = Zm; } - def _D : sve_fp_fmul_by_indexed_elem<0b11, asm, ZPR64, ZPR4b64, VectorIndexD> { + def _D : sve_fp_fmul_by_indexed_elem<0b11, asm, ZPR64, ZPR4b64, VectorIndexD32b> { bits<4> Zm; bit iop; let Inst{20} = iop; let Inst{19-16} = Zm; } + + def : Pat<(nxv8f16 (op nxv8f16:$Op1, nxv8f16:$Op2, (i32 VectorIndexH32b:$idx))), + (!cast(NAME # _H) $Op1, $Op2, VectorIndexH32b:$idx)>; + def : Pat<(nxv4f32 (op nxv4f32:$Op1, nxv4f32:$Op2, (i32 
VectorIndexS32b:$idx))), + (!cast(NAME # _S) $Op1, $Op2, VectorIndexS32b:$idx)>; + def : Pat<(nxv2f64 (op nxv2f64:$Op1, nxv2f64:$Op2, (i32 VectorIndexD32b:$idx))), + (!cast(NAME # _D) $Op1, $Op2, VectorIndexD32b:$idx)>; } + //===----------------------------------------------------------------------===// // SVE Floating Point Complex Multiply-Add Group //===----------------------------------------------------------------------===// @@ -1350,14 +1880,21 @@ let Inst{4-0} = Zda; let Constraints = "$Zda = $_Zda"; - let DestructiveInstType = Destructive; + let DestructiveInstType = DestructiveOther; let ElementSize = zprty.ElementSize; } -multiclass sve_fp_fcmla { +multiclass sve_fp_fcmla { def _H : sve_fp_fcmla<0b01, asm, ZPR16>; def _S : sve_fp_fcmla<0b10, asm, ZPR32>; def _D : sve_fp_fcmla<0b11, asm, ZPR64>; + + def : Pat<(nxv8f16 (op nxv8i1:$Op1, nxv8f16:$Op2, nxv8f16:$Op3, nxv8f16:$Op4, (i32 complexrotateop:$imm))), + (!cast(NAME # _H) $Op1, $Op2, $Op3, $Op4, complexrotateop:$imm)>; + def : Pat<(nxv4f32 (op nxv4i1:$Op1, nxv4f32:$Op2, nxv4f32:$Op3, nxv4f32:$Op4, (i32 complexrotateop:$imm))), + (!cast(NAME # _S) $Op1, $Op2, $Op3, $Op4, complexrotateop:$imm)>; + def : Pat<(nxv2f64 (op nxv2i1:$Op1, nxv2f64:$Op2, nxv2f64:$Op3, nxv2f64:$Op4, (i32 complexrotateop:$imm))), + (!cast(NAME # _D) $Op1, $Op2, $Op3, $Op4, complexrotateop:$imm)>; } //===----------------------------------------------------------------------===// @@ -1383,25 +1920,30 @@ let Inst{4-0} = Zda; let Constraints = "$Zda = $_Zda"; - let DestructiveInstType = Destructive; - let ElementSize = ElementSizeNone; + let DestructiveInstType = DestructiveOther; } -multiclass sve_fp_fcmla_by_indexed_elem { - def _H : sve_fp_fcmla_by_indexed_elem<0b10, asm, ZPR16, ZPR3b16, VectorIndexS> { +multiclass sve_fp_fcmla_by_indexed_elem { + def _H : sve_fp_fcmla_by_indexed_elem<0b10, asm, ZPR16, ZPR3b16, VectorIndexS32b> { bits<3> Zm; bits<2> iop; let Inst{20-19} = iop; let Inst{18-16} = Zm; } - def _S : sve_fp_fcmla_by_indexed_elem<0b11, asm, ZPR32, ZPR4b32, VectorIndexD> { + def _S : sve_fp_fcmla_by_indexed_elem<0b11, asm, ZPR32, ZPR4b32, VectorIndexD32b> { bits<4> Zm; bits<1> iop; let Inst{20} = iop; let Inst{19-16} = Zm; } + + def : Pat<(nxv8f16 (op nxv8f16:$Op1, nxv8f16:$Op2, nxv8f16:$Op3, (i32 VectorIndexS32b:$idx), (i32 complexrotateop:$imm))), + (!cast(NAME # _H) $Op1, $Op2, $Op3, VectorIndexS32b:$idx, complexrotateop:$imm)>; + def : Pat<(nxv4f32 (op nxv4f32:$Op1, nxv4f32:$Op2, nxv4f32:$Op3, (i32 VectorIndexD32b:$idx), (i32 complexrotateop:$imm))), + (!cast(NAME # _S) $Op1, $Op2, $Op3, VectorIndexD32b:$idx, complexrotateop:$imm)>; } + //===----------------------------------------------------------------------===// // SVE Floating Point Complex Addition Group //===----------------------------------------------------------------------===// @@ -1426,14 +1968,186 @@ let Inst{4-0} = Zdn; let Constraints = "$Zdn = $_Zdn"; - let DestructiveInstType = Destructive; + let DestructiveInstType = DestructiveOther; let ElementSize = zprty.ElementSize; } -multiclass sve_fp_fcadd { +multiclass sve_fp_fcadd { def _H : sve_fp_fcadd<0b01, asm, ZPR16>; def _S : sve_fp_fcadd<0b10, asm, ZPR32>; def _D : sve_fp_fcadd<0b11, asm, ZPR64>; + + def : Pat<(nxv8f16 (op nxv8i1:$Op1, nxv8f16:$Op2, nxv8f16:$Op3, (i32 complexrotateopodd:$imm))), + (!cast(NAME # _H) $Op1, $Op2, $Op3, complexrotateopodd:$imm)>; + def : Pat<(nxv4f32 (op nxv4i1:$Op1, nxv4f32:$Op2, nxv4f32:$Op3, (i32 complexrotateopodd:$imm))), + (!cast(NAME # _S) $Op1, $Op2, $Op3, complexrotateopodd:$imm)>; + 
def : Pat<(nxv2f64 (op nxv2i1:$Op1, nxv2f64:$Op2, nxv2f64:$Op3, (i32 complexrotateopodd:$imm))), + (!cast(NAME # _D) $Op1, $Op2, $Op3, complexrotateopodd:$imm)>; +} + + +//===----------------------------------------------------------------------===// +// SVE2 Floating Point Convert Group +//===----------------------------------------------------------------------===// + +class sve2_fp_convert_precision opc, string asm, + ZPRRegOp zprty1, ZPRRegOp zprty2> +: I<(outs zprty1:$Zd), (ins zprty1:$_Zd, PPR3bAny:$Pg, zprty2:$Zn), + asm, "\t$Zd, $Pg/m, $Zn", + "", + []>, Sched<[]> { + bits<5> Zd; + bits<5> Zn; + bits<3> Pg; + let Inst{31-24} = 0b01100100; + let Inst{23-22} = opc{3-2}; + let Inst{21-18} = 0b0010; + let Inst{17-16} = opc{1-0}; + let Inst{15-13} = 0b101; + let Inst{12-10} = Pg; + let Inst{9-5} = Zn; + let Inst{4-0} = Zd; + + let Constraints = "$Zd = $_Zd"; +} + +multiclass sve2_fp_convert_down_narrow { + def _StoH : sve2_fp_convert_precision<0b1000, asm, ZPR16, ZPR32>; + def _DtoS : sve2_fp_convert_precision<0b1110, asm, ZPR32, ZPR64>; + + def : SVE_3_Op_Pat(op # _f16f32), nxv8f16, nxv16i1, nxv4f32, !cast(NAME # _StoH)>; + def : SVE_3_Op_Pat(op # _f32f64), nxv4f32, nxv16i1, nxv2f64, !cast(NAME # _DtoS)>; +} + +multiclass sve2_fp_convert_up_long { + def _HtoS : sve2_fp_convert_precision<0b1001, asm, ZPR32, ZPR16>; + def _StoD : sve2_fp_convert_precision<0b1111, asm, ZPR64, ZPR32>; + + def : SVE_3_Op_Pat(op # _f32f16), nxv4f32, nxv16i1, nxv8f16, !cast(NAME # _HtoS)>; + def : SVE_3_Op_Pat(op # _f64f32), nxv2f64, nxv16i1, nxv4f32, !cast(NAME # _StoD)>; +} + +multiclass sve2_fp_convert_down_odd_rounding_top { + def _DtoS : sve2_fp_convert_precision<0b0010, asm, ZPR32, ZPR64>; + + def : SVE_3_Op_Pat(op # _f32f64), nxv4f32, nxv16i1, nxv2f64, !cast(NAME # _DtoS)>; +} + +//===----------------------------------------------------------------------===// +// SVE2 Floating Point Pairwise Group +//===----------------------------------------------------------------------===// + +class sve2_fp_pairwise_pred sz, bits<3> opc, string asm, + ZPRRegOp zprty> +: I<(outs zprty:$Zdn), (ins PPR3bAny:$Pg, zprty:$_Zdn, zprty:$Zm), + asm, "\t$Zdn, $Pg/m, $_Zdn, $Zm", + "", + []>, Sched<[]> { + bits<3> Pg; + bits<5> Zm; + bits<5> Zdn; + let Inst{31-24} = 0b01100100; + let Inst{23-22} = sz; + let Inst{21-19} = 0b010; + let Inst{18-16} = opc; + let Inst{15-13} = 0b100; + let Inst{12-10} = Pg; + let Inst{9-5} = Zm; + let Inst{4-0} = Zdn; + + let Constraints = "$Zdn = $_Zdn"; + let DestructiveInstType = DestructiveBinary; + let ElementSize = zprty.ElementSize; +} + +multiclass sve2_fp_pairwise_pred opc, string asm, string psName, + SDPatternOperator op> { + def _H : sve2_fp_pairwise_pred<0b01, opc, asm, ZPR16>, SVEPseudo2Instr; + def _S : sve2_fp_pairwise_pred<0b10, opc, asm, ZPR32>, SVEPseudo2Instr; + def _D : sve2_fp_pairwise_pred<0b11, opc, asm, ZPR64>, SVEPseudo2Instr; + + def _H_UNDEF : PredTwoOpConstrainedPseudo; + def _S_UNDEF : PredTwoOpConstrainedPseudo; + def _D_UNDEF : PredTwoOpConstrainedPseudo; + + def _H_ZERO : PredTwoOpConstrainedPseudo; + def _S_ZERO : PredTwoOpConstrainedPseudo; + def _D_ZERO : PredTwoOpConstrainedPseudo; + + def : SVE_3_Op_Pat_SelZero(NAME # _H_ZERO)>; + def : SVE_3_Op_Pat_SelZero(NAME # _S_ZERO)>; + def : SVE_3_Op_Pat_SelZero(NAME # _D_ZERO)>; + + def : SVE_3_Op_Pat(NAME # _H)>; + def : SVE_3_Op_Pat(NAME # _S)>; + def : SVE_3_Op_Pat(NAME # _D)>; +} + +//===----------------------------------------------------------------------===// +// SVE2 Floating Point Widening Multiply-Add - 
Indexed Group +//===----------------------------------------------------------------------===// + +class sve2_fp_mla_long_by_indexed_elem opc, string asm> +: I<(outs ZPR32:$Zda), (ins ZPR32:$_Zda, ZPR16:$Zn, ZPR3b16:$Zm, + VectorIndexH32b:$iop), + asm, "\t$Zda, $Zn, $Zm$iop", + "", + []>, Sched<[]> { + bits<5> Zda; + bits<5> Zn; + bits<3> Zm; + bits<3> iop; + let Inst{31-21} = 0b01100100101; + let Inst{20-19} = iop{2-1}; + let Inst{18-16} = Zm; + let Inst{15-14} = 0b01; + let Inst{13} = opc{1}; + let Inst{12} = 0b0; + let Inst{11} = iop{0}; + let Inst{10} = opc{0}; + let Inst{9-5} = Zn; + let Inst{4-0} = Zda; + + let Constraints = "$Zda = $_Zda"; + let DestructiveInstType = DestructiveOther; + let ElementSize = ElementSizeNone; +} + +multiclass sve2_fp_mla_long_by_indexed_elem opc, string asm, + SDPatternOperator op> { + def NAME : sve2_fp_mla_long_by_indexed_elem; + def : SVE_4_Op_Imm_Pat(NAME)>; +} + +//===----------------------------------------------------------------------===// +// SVE2 Floating Point Widening Multiply-Add Group +//===----------------------------------------------------------------------===// + +class sve2_fp_mla_long opc, string asm> +: I<(outs ZPR32:$Zda), (ins ZPR32:$_Zda, ZPR16:$Zn, ZPR16:$Zm), + asm, "\t$Zda, $Zn, $Zm", + "", + []>, Sched<[]> { + bits<5> Zda; + bits<5> Zn; + bits<5> Zm; + let Inst{31-21} = 0b01100100101; + let Inst{20-16} = Zm; + let Inst{15-14} = 0b10; + let Inst{13} = opc{1}; + let Inst{12-11} = 0b00; + let Inst{10} = opc{0}; + let Inst{9-5} = Zn; + let Inst{4-0} = Zda; + + let Constraints = "$Zda = $_Zda"; + let DestructiveInstType = DestructiveOther; + let ElementSize = ElementSizeNone; +} + +multiclass sve2_fp_mla_long opc, string asm, SDPatternOperator op> { + def NAME : sve2_fp_mla_long; + def : SVE_3_Op_Pat(NAME)>; } //===----------------------------------------------------------------------===// @@ -1471,6 +2185,8 @@ let Inst{15-11} = 0b01010; let Inst{10-5} = imm6; let Inst{4-0} = Rd; + + let isReMaterializable = 1; } //===----------------------------------------------------------------------===// @@ -1496,11 +2212,22 @@ let Inst{4-0} = Zd; } -multiclass sve_int_perm_bin_perm_zz opc, string asm> { +multiclass sve_int_perm_bin_perm_zz opc, string asm, + SDPatternOperator op> { def _B : sve_int_perm_bin_perm_zz; def _H : sve_int_perm_bin_perm_zz; def _S : sve_int_perm_bin_perm_zz; def _D : sve_int_perm_bin_perm_zz; + + def : SVE_2_Op_Pat(NAME # _B)>; + def : SVE_2_Op_Pat(NAME # _H)>; + def : SVE_2_Op_Pat(NAME # _S)>; + def : SVE_2_Op_Pat(NAME # _D)>; + + def : SVE_2_Op_Pat(NAME # _H)>; + def : SVE_2_Op_Pat(NAME # _S)>; + def : SVE_2_Op_Pat(NAME # _S)>; + def : SVE_2_Op_Pat(NAME # _D)>; } //===----------------------------------------------------------------------===// @@ -1508,7 +2235,7 @@ //===----------------------------------------------------------------------===// class sve_fp_2op_p_zd opc, string asm, RegisterOperand i_zprtype, - RegisterOperand o_zprtype, ElementSizeEnum size> + RegisterOperand o_zprtype, ElementSizeEnum Sz> : I<(outs o_zprtype:$Zd), (ins i_zprtype:$_Zd, PPR3bAny:$Pg, i_zprtype:$Zn), asm, "\t$Zd, $Pg/m, $Zn", "", @@ -1526,16 +2253,45 @@ let Inst{4-0} = Zd; let Constraints = "$Zd = $_Zd"; - let DestructiveInstType = Destructive; - let ElementSize = size; + let DestructiveInstType = DestructiveUnary; + let ElementSize = Sz; } -multiclass sve_fp_2op_p_zd_HSD opc, string asm> { +multiclass sve_fp_2op_p_zd opc, string asm, + RegisterOperand i_zprtype, + RegisterOperand o_zprtype, + SDPatternOperator op, ValueType vt1, + 
ValueType vt2, ValueType vt3, ElementSizeEnum Sz> { + def NAME : sve_fp_2op_p_zd; + + def : SVE_3_Op_Pat(NAME)>; +} + +multiclass sve_fp_2op_p_zd_HSD opc, string asm, SDPatternOperator op> { def _H : sve_fp_2op_p_zd<{ 0b01, opc }, asm, ZPR16, ZPR16, ElementSizeH>; def _S : sve_fp_2op_p_zd<{ 0b10, opc }, asm, ZPR32, ZPR32, ElementSizeS>; def _D : sve_fp_2op_p_zd<{ 0b11, opc }, asm, ZPR64, ZPR64, ElementSizeD>; + + def : SVE_3_Op_Pat(NAME # _H)>; + def : SVE_3_Op_Pat(NAME # _S)>; + def : SVE_3_Op_Pat(NAME # _D)>; } - + +multiclass sve2_fp_flogb { + def _H : sve_fp_2op_p_zd<0b0011010, asm, ZPR16, ZPR16, ElementSizeH>; + def _S : sve_fp_2op_p_zd<0b0011100, asm, ZPR32, ZPR32, ElementSizeS>; + def _D : sve_fp_2op_p_zd<0b0011110, asm, ZPR64, ZPR64, ElementSizeD>; + + def : SVE_3_Op_Pat(NAME # _H)>; + def : SVE_3_Op_Pat(NAME # _S)>; + def : SVE_3_Op_Pat(NAME # _D)>; +} + +multiclass sve2_fp_convert_down_odd_rounding { + def _DtoS : sve_fp_2op_p_zd<0b0001010, asm, ZPR64, ZPR32, ElementSizeD>; + def : SVE_3_Op_Pat(op # _f32f64), nxv4f32, nxv16i1, nxv2f64, !cast(NAME # _DtoS)>; +} + //===----------------------------------------------------------------------===// // SVE Floating Point Unary Operations - Unpredicated Group //===----------------------------------------------------------------------===// @@ -1557,10 +2313,14 @@ let Inst{4-0} = Zd; } -multiclass sve_fp_2op_u_zd opc, string asm> { +multiclass sve_fp_2op_u_zd opc, string asm, SDPatternOperator op> { def _H : sve_fp_2op_u_zd<0b01, opc, asm, ZPR16>; def _S : sve_fp_2op_u_zd<0b10, opc, asm, ZPR32>; def _D : sve_fp_2op_u_zd<0b11, opc, asm, ZPR64>; + + def : SVE_1_Op_Pat(NAME # _H)>; + def : SVE_1_Op_Pat(NAME # _S)>; + def : SVE_1_Op_Pat(NAME # _D)>; } //===----------------------------------------------------------------------===// @@ -1585,42 +2345,112 @@ let Inst{4-0} = Zdn; let Constraints = "$Zdn = $_Zdn"; - let DestructiveInstType = Destructive; let ElementSize = zprty.ElementSize; } -multiclass sve_int_bin_pred_log opc, string asm> { - def _B : sve_int_bin_pred_arit_log<0b00, 0b11, opc, asm, ZPR8>; - def _H : sve_int_bin_pred_arit_log<0b01, 0b11, opc, asm, ZPR16>; - def _S : sve_int_bin_pred_arit_log<0b10, 0b11, opc, asm, ZPR32>; - def _D : sve_int_bin_pred_arit_log<0b11, 0b11, opc, asm, ZPR64>; +multiclass sve_int_bin_pred_log opc, string asm, string Ps, + SDPatternOperator op, + DestructiveInstTypeEnum flags> { + let DestructiveInstType = flags in { + def _B : sve_int_bin_pred_arit_log<0b00, 0b11, opc, asm, ZPR8>, + SVEPseudo2Instr; + def _H : sve_int_bin_pred_arit_log<0b01, 0b11, opc, asm, ZPR16>, + SVEPseudo2Instr; + def _S : sve_int_bin_pred_arit_log<0b10, 0b11, opc, asm, ZPR32>, + SVEPseudo2Instr; + def _D : sve_int_bin_pred_arit_log<0b11, 0b11, opc, asm, ZPR64>, + SVEPseudo2Instr; + } + + def : SVE_3_Op_Pat(NAME # _B)>; + def : SVE_3_Op_Pat(NAME # _H)>; + def : SVE_3_Op_Pat(NAME # _S)>; + def : SVE_3_Op_Pat(NAME # _D)>; } -multiclass sve_int_bin_pred_arit_0 opc, string asm> { - def _B : sve_int_bin_pred_arit_log<0b00, 0b00, opc, asm, ZPR8>; - def _H : sve_int_bin_pred_arit_log<0b01, 0b00, opc, asm, ZPR16>; - def _S : sve_int_bin_pred_arit_log<0b10, 0b00, opc, asm, ZPR32>; - def _D : sve_int_bin_pred_arit_log<0b11, 0b00, opc, asm, ZPR64>; +multiclass sve_int_bin_pred_arit_0 opc, string asm, string Ps, + SDPatternOperator op, + DestructiveInstTypeEnum flags, + string revname="", bit isOrig=0> { + let DestructiveInstType = flags in { + def _B : sve_int_bin_pred_arit_log<0b00, 0b00, opc, asm, ZPR8>, + SVEPseudo2Instr, SVEInstr2Rev; + 
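+  // A hedged note on the helper classes threaded through this and the
+  // surrounding multiclasses: SVEPseudo2Instr ties a real instruction to the
+  // name its predicated pseudo forms are expanded from, and SVEInstr2Rev
+  // pairs an instruction with its reversed-operand counterpart (a sub/subr
+  // style pair), presumably so a later pass can switch to the reversed form
+  // when the destination ends up tied to the other source operand; revname
+  // and isOrig default to "" and 0 where no such pairing exists.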
def _H : sve_int_bin_pred_arit_log<0b01, 0b00, opc, asm, ZPR16>, + SVEPseudo2Instr, SVEInstr2Rev; + def _S : sve_int_bin_pred_arit_log<0b10, 0b00, opc, asm, ZPR32>, + SVEPseudo2Instr, SVEInstr2Rev; + def _D : sve_int_bin_pred_arit_log<0b11, 0b00, opc, asm, ZPR64>, + SVEPseudo2Instr, SVEInstr2Rev; + } + + def : SVE_3_Op_Pat(NAME # _B)>; + def : SVE_3_Op_Pat(NAME # _H)>; + def : SVE_3_Op_Pat(NAME # _S)>; + def : SVE_3_Op_Pat(NAME # _D)>; } -multiclass sve_int_bin_pred_arit_1 opc, string asm> { - def _B : sve_int_bin_pred_arit_log<0b00, 0b01, opc, asm, ZPR8>; - def _H : sve_int_bin_pred_arit_log<0b01, 0b01, opc, asm, ZPR16>; - def _S : sve_int_bin_pred_arit_log<0b10, 0b01, opc, asm, ZPR32>; - def _D : sve_int_bin_pred_arit_log<0b11, 0b01, opc, asm, ZPR64>; +multiclass sve_int_bin_pred_arit_1 opc, string asm, string Ps, + SDPatternOperator op> { + let DestructiveInstType = DestructiveBinaryComm in { + def _B : sve_int_bin_pred_arit_log<0b00, 0b01, opc, asm, ZPR8>, + SVEPseudo2Instr; + def _H : sve_int_bin_pred_arit_log<0b01, 0b01, opc, asm, ZPR16>, + SVEPseudo2Instr; + def _S : sve_int_bin_pred_arit_log<0b10, 0b01, opc, asm, ZPR32>, + SVEPseudo2Instr; + def _D : sve_int_bin_pred_arit_log<0b11, 0b01, opc, asm, ZPR64>, + SVEPseudo2Instr; + } + + def : SVE_3_Op_Pat(NAME # _B)>; + def : SVE_3_Op_Pat(NAME # _H)>; + def : SVE_3_Op_Pat(NAME # _S)>; + def : SVE_3_Op_Pat(NAME # _D)>; } -multiclass sve_int_bin_pred_arit_2 opc, string asm> { - def _B : sve_int_bin_pred_arit_log<0b00, 0b10, opc, asm, ZPR8>; - def _H : sve_int_bin_pred_arit_log<0b01, 0b10, opc, asm, ZPR16>; - def _S : sve_int_bin_pred_arit_log<0b10, 0b10, opc, asm, ZPR32>; - def _D : sve_int_bin_pred_arit_log<0b11, 0b10, opc, asm, ZPR64>; +multiclass sve_int_bin_pred_arit_2 opc, string asm, string Ps, + SDPatternOperator op> { + let DestructiveInstType = DestructiveBinaryComm in { + def _B : sve_int_bin_pred_arit_log<0b00, 0b10, opc, asm, ZPR8>, + SVEPseudo2Instr; + def _H : sve_int_bin_pred_arit_log<0b01, 0b10, opc, asm, ZPR16>, + SVEPseudo2Instr; + def _S : sve_int_bin_pred_arit_log<0b10, 0b10, opc, asm, ZPR32>, + SVEPseudo2Instr; + def _D : sve_int_bin_pred_arit_log<0b11, 0b10, opc, asm, ZPR64>, + SVEPseudo2Instr; + } + + def : SVE_3_Op_Pat(NAME # _B)>; + def : SVE_3_Op_Pat(NAME # _H)>; + def : SVE_3_Op_Pat(NAME # _S)>; + def : SVE_3_Op_Pat(NAME # _D)>; } // Special case for divides which are not defined for 8b/16b elements. 
-multiclass sve_int_bin_pred_arit_2_div opc, string asm> { - def _S : sve_int_bin_pred_arit_log<0b10, 0b10, opc, asm, ZPR32>; - def _D : sve_int_bin_pred_arit_log<0b11, 0b10, opc, asm, ZPR64>; +multiclass sve_int_bin_pred_arit_2_div opc, string asm, string Ps, + SDPatternOperator op, string revname, + bit isOrig> { + let DestructiveInstType = DestructiveBinaryCommWithRev in { + def _S : sve_int_bin_pred_arit_log<0b10, 0b10, opc, asm, ZPR32>, + SVEPseudo2Instr, SVEInstr2Rev; + def _D : sve_int_bin_pred_arit_log<0b11, 0b10, opc, asm, ZPR64>, + SVEPseudo2Instr, SVEInstr2Rev; + } + + def : SVE_3_Op_Pat(NAME # _S)>; + def : SVE_3_Op_Pat(NAME # _D)>; +} + +multiclass sve_int_bin_pred_arit_2_div_zx { + def _UNDEF_S : PredTwoOpPseudo; + def _UNDEF_D : PredTwoOpPseudo; + + def _ZERO_S : PredTwoOpPseudo; + def _ZERO_D : PredTwoOpPseudo; + + def : SVE_3_Op_Pat_SelZero(NAME # _ZERO_S)>; + def : SVE_3_Op_Pat_SelZero(NAME # _ZERO_D)>; } //===----------------------------------------------------------------------===// @@ -1648,15 +2478,47 @@ let Inst{4-0} = Zdn; let Constraints = "$Zdn = $_Zdn"; - let DestructiveInstType = Destructive; + let DestructiveInstType = DestructiveOther; let ElementSize = zprty.ElementSize; } -multiclass sve_int_mladdsub_vvv_pred opc, string asm> { - def _B : sve_int_mladdsub_vvv_pred<0b00, opc, asm, ZPR8>; - def _H : sve_int_mladdsub_vvv_pred<0b01, opc, asm, ZPR16>; - def _S : sve_int_mladdsub_vvv_pred<0b10, opc, asm, ZPR32>; - def _D : sve_int_mladdsub_vvv_pred<0b11, opc, asm, ZPR64>; +multiclass sve_int_mladdsub_vvv_pred opc, string asm, + SDPatternOperator int_op, + SDPatternOperator acc_op, + string revname, bit isOrig> { + def _B : sve_int_mladdsub_vvv_pred<0b00, opc, asm, ZPR8>, + SVEInstr2Rev; + def _H : sve_int_mladdsub_vvv_pred<0b01, opc, asm, ZPR16>, + SVEInstr2Rev; + def _S : sve_int_mladdsub_vvv_pred<0b10, opc, asm, ZPR32>, + SVEInstr2Rev; + def _D : sve_int_mladdsub_vvv_pred<0b11, opc, asm, ZPR64>, + SVEInstr2Rev; + + def : SVE_4_Op_Pat(NAME # _B)>; + def : SVE_4_Op_Pat(NAME # _H)>; + def : SVE_4_Op_Pat(NAME # _S)>; + def : SVE_4_Op_Pat(NAME # _D)>; + + def : Pat<(nxv16i8 (vselect (nxv16i1 PPR:$Pg), + (acc_op (nxv16i8 ZPR:$Za), (mul (nxv16i8 ZPR:$Zdn), (nxv16i8 ZPR:$Zm))), + (nxv16i8 ZPR:$Zdn))), + (!cast(NAME # _B) PPR:$Pg, ZPR:$Zdn, ZPR:$Zm, ZPR:$Za)>; + + def : Pat<(nxv8i16 (vselect (nxv8i1 PPR:$Pg), + (acc_op (nxv8i16 ZPR:$Za), (mul (nxv8i16 ZPR:$Zdn), (nxv8i16 ZPR:$Zm))), + (nxv8i16 ZPR:$Zdn))), + (!cast(NAME # _H) PPR:$Pg, ZPR:$Zdn, ZPR:$Zm, ZPR:$Za)>; + + def : Pat<(nxv4i32 (vselect (nxv4i1 PPR:$Pg), + (acc_op (nxv4i32 ZPR:$Za), (mul (nxv4i32 ZPR:$Zdn), (nxv4i32 ZPR:$Zm))), + (nxv4i32 ZPR:$Zdn))), + (!cast(NAME # _S) PPR:$Pg, ZPR:$Zdn, ZPR:$Zm, ZPR:$Za)>; + + def : Pat<(nxv2i64 (vselect (nxv2i1 PPR:$Pg), + (acc_op (nxv2i64 ZPR:$Za), (mul (nxv2i64 ZPR:$Zdn), (nxv2i64 ZPR:$Zm))), + (nxv2i64 ZPR:$Zdn))), + (!cast(NAME # _D) PPR:$Pg, ZPR:$Zdn, ZPR:$Zm, ZPR:$Za)>; } class sve_int_mlas_vvv_pred sz8_64, bits<1> opc, string asm, @@ -1680,15 +2542,202 @@ let Inst{4-0} = Zda; let Constraints = "$Zda = $_Zda"; - let DestructiveInstType = Destructive; let ElementSize = zprty.ElementSize; } -multiclass sve_int_mlas_vvv_pred opc, string asm> { - def _B : sve_int_mlas_vvv_pred<0b00, opc, asm, ZPR8>; - def _H : sve_int_mlas_vvv_pred<0b01, opc, asm, ZPR16>; - def _S : sve_int_mlas_vvv_pred<0b10, opc, asm, ZPR32>; - def _D : sve_int_mlas_vvv_pred<0b11, opc, asm, ZPR64>; +multiclass sve_int_mlas_vvv_pred opc, string asm, string Ps, + SDPatternOperator int_op, + 
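+  // A hedged sketch of the IR shapes these multiply-accumulate patterns are
+  // written against. The sve_int_mladdsub_vvv_pred patterns above (mad/msb
+  // style) match the whole result being selected against the destructive
+  // multiplicand:
+  //   %p = mul <vscale x 4 x i32> %zdn, %zm
+  //   %a = add <vscale x 4 x i32> %za, %p
+  //   %r = select <vscale x 4 x i1> %pg, <vscale x 4 x i32> %a,
+  //               <vscale x 4 x i32> %zdn
+  // while the sve_int_mlas_vvv_pred patterns below (mla/mls style) match an
+  // unpredicated accumulate of a predicated multiply:
+  //   %p = mul <vscale x 4 x i32> %zn, %zm
+  //   %m = select <vscale x 4 x i1> %pg, <vscale x 4 x i32> %p,
+  //               <vscale x 4 x i32> zeroinitializer
+  //   %r = add <vscale x 4 x i32> %zda, %m
+  // so inactive lanes contribute nothing to the accumulator.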
SDPatternOperator acc_op, + string revname, bit isOrig> { + let DestructiveInstType = DestructiveTernaryCommWithRev in { + def _B : sve_int_mlas_vvv_pred<0b00, opc, asm, ZPR8>, + SVEPseudo2Instr, + SVEInstr2Rev; + def _H : sve_int_mlas_vvv_pred<0b01, opc, asm, ZPR16>, + SVEPseudo2Instr, + SVEInstr2Rev; + def _S : sve_int_mlas_vvv_pred<0b10, opc, asm, ZPR32>, + SVEPseudo2Instr, + SVEInstr2Rev; + def _D : sve_int_mlas_vvv_pred<0b11, opc, asm, ZPR64>, + SVEPseudo2Instr, + SVEInstr2Rev; + } + + def : SVE_4_Op_Pat(NAME # _B)>; + def : SVE_4_Op_Pat(NAME # _H)>; + def : SVE_4_Op_Pat(NAME # _S)>; + def : SVE_4_Op_Pat(NAME # _D)>; + + def : Pat<(nxv16i8 (acc_op (nxv16i8 ZPR:$Zda), (vselect (nxv16i1 PPR:$Pg), + (mul (nxv16i8 ZPR:$Zn), (nxv16i8 ZPR:$Zm)), + (nxv16i8 (AArch64dup (i32 0)))))), + (!cast(NAME # _B) PPR:$Pg, ZPR:$Zda, ZPR:$Zn, ZPR:$Zm)>; + + def : Pat<(nxv8i16 (acc_op (nxv8i16 ZPR:$Zda), (vselect (nxv8i1 PPR:$Pg), + (mul (nxv8i16 ZPR:$Zn), (nxv8i16 ZPR:$Zm)), + (nxv8i16 (AArch64dup (i32 0)))))), + (!cast(NAME # _H) PPR:$Pg, ZPR:$Zda, ZPR:$Zn, ZPR:$Zm)>; + + def : Pat<(nxv4i32 (acc_op (nxv4i32 ZPR:$Zda), (vselect (nxv4i1 PPR:$Pg), + (mul (nxv4i32 ZPR:$Zn), (nxv4i32 ZPR:$Zm)), + (nxv4i32 (AArch64dup (i32 0)))))), + (!cast(NAME # _S) PPR:$Pg, ZPR:$Zda, ZPR:$Zn, ZPR:$Zm)>; + + def : Pat<(nxv2i64 (acc_op (nxv2i64 ZPR:$Zda), (vselect (nxv2i1 PPR:$Pg), + (mul (nxv2i64 ZPR:$Zn), (nxv2i64 ZPR:$Zm)), + (nxv2i64 (AArch64dup (i64 0)))))), + (!cast(NAME # _D) PPR:$Pg, ZPR:$Zda, ZPR:$Zn, ZPR:$Zm)>; +} + +multiclass sve_int_ternary_pred_zx { + def _UNDEF_B : PredThreeOpPseudo; + def _UNDEF_H : PredThreeOpPseudo; + def _UNDEF_S : PredThreeOpPseudo; + def _UNDEF_D : PredThreeOpPseudo; + + def _ZERO_B : PredThreeOpPseudo; + def _ZERO_H : PredThreeOpPseudo; + def _ZERO_S : PredThreeOpPseudo; + def _ZERO_D : PredThreeOpPseudo; + + def : SVE_4_Op_Pat_SelZero(NAME # _ZERO_B)>; + def : SVE_4_Op_Pat_SelZero(NAME # _ZERO_H)>; + def : SVE_4_Op_Pat_SelZero(NAME # _ZERO_S)>; + def : SVE_4_Op_Pat_SelZero(NAME # _ZERO_D)>; + + // As above but with the accumulator in it's alternative position. 
+ def : Pat<(nxv16i8 (rev_op nxv16i1:$Op1, (vselect nxv16i1:$Op1, nxv16i8:$Op2, (SVEDup0)), nxv16i8:$Op3, nxv16i8:$Op4)), + (!cast(NAME # _ZERO_B) $Op1, $Op4, $Op2, $Op3)>; + def : Pat<(nxv8i16 (rev_op nxv8i1:$Op1, (vselect nxv8i1:$Op1, nxv8i16:$Op2, (SVEDup0)), nxv8i16:$Op3, nxv8i16:$Op4)), + (!cast(NAME # _ZERO_H) $Op1, $Op4, $Op2, $Op3)>; + def : Pat<(nxv4i32 (rev_op nxv4i1:$Op1, (vselect nxv4i1:$Op1, nxv4i32:$Op2, (SVEDup0)), nxv4i32:$Op3, nxv4i32:$Op4)), + (!cast(NAME # _ZERO_S) $Op1, $Op4, $Op2, $Op3)>; + def : Pat<(nxv2i64 (rev_op nxv2i1:$Op1, (vselect nxv2i1:$Op1, nxv2i64:$Op2, (SVEDup0)), nxv2i64:$Op3, nxv2i64:$Op4)), + (!cast(NAME # _ZERO_D) $Op1, $Op4, $Op2, $Op3)>; +} + +//===----------------------------------------------------------------------===// +// SVE2 Integer Multiply-Add - Unpredicated Group +//===----------------------------------------------------------------------===// + +class sve2_int_mla sz, bits<5> opc, string asm, + ZPRRegOp zprty1, ZPRRegOp zprty2> +: I<(outs zprty1:$Zda), (ins zprty1:$_Zda, zprty2:$Zn, zprty2:$Zm), + asm, "\t$Zda, $Zn, $Zm", "", []>, Sched<[]> { + bits<5> Zda; + bits<5> Zn; + bits<5> Zm; + let Inst{31-24} = 0b01000100; + let Inst{23-22} = sz; + let Inst{21} = 0b0; + let Inst{20-16} = Zm; + let Inst{15} = 0b0; + let Inst{14-10} = opc; + let Inst{9-5} = Zn; + let Inst{4-0} = Zda; + + let Constraints = "$Zda = $_Zda"; + let DestructiveInstType = DestructiveOther; + let ElementSize = ElementSizeNone; +} + +multiclass sve2_int_mla { + def _B : sve2_int_mla<0b00, { 0b1110, S }, asm, ZPR8, ZPR8>; + def _H : sve2_int_mla<0b01, { 0b1110, S }, asm, ZPR16, ZPR16>; + def _S : sve2_int_mla<0b10, { 0b1110, S }, asm, ZPR32, ZPR32>; + def _D : sve2_int_mla<0b11, { 0b1110, S }, asm, ZPR64, ZPR64>; + + def : SVE_3_Op_Pat(NAME # _B)>; + def : SVE_3_Op_Pat(NAME # _H)>; + def : SVE_3_Op_Pat(NAME # _S)>; + def : SVE_3_Op_Pat(NAME # _D)>; +} + +multiclass sve2_int_mla_long opc, string asm, SDPatternOperator op> { + def _H : sve2_int_mla<0b01, opc, asm, ZPR16, ZPR8>; + def _S : sve2_int_mla<0b10, opc, asm, ZPR32, ZPR16>; + def _D : sve2_int_mla<0b11, opc, asm, ZPR64, ZPR32>; + + def : SVE_3_Op_Pat(NAME # _H)>; + def : SVE_3_Op_Pat(NAME # _S)>; + def : SVE_3_Op_Pat(NAME # _D)>; +} + +//===----------------------------------------------------------------------===// +// SVE2 Integer Multiply-Add - Indexed Group +//===----------------------------------------------------------------------===// + +class sve2_int_mla_by_indexed_elem sz, bits<6> opc, string asm, + ZPRRegOp zprty1, ZPRRegOp zprty2, + ZPRRegOp zprty3, Operand itype> +: I<(outs zprty1:$Zda), (ins zprty1:$_Zda, zprty2:$Zn, zprty3:$Zm, itype:$iop), + asm, "\t$Zda, $Zn, $Zm$iop", "", []>, Sched<[]> { + bits<5> Zda; + bits<5> Zn; + let Inst{31-24} = 0b01000100; + let Inst{23-22} = sz; + let Inst{21} = 0b1; + let Inst{15-10} = opc; + let Inst{9-5} = Zn; + let Inst{4-0} = Zda; + + let Constraints = "$Zda = $_Zda"; + let DestructiveInstType = DestructiveOther; + let ElementSize = ElementSizeNone; +} + +multiclass sve2_int_mla_by_indexed_elem opc, bit S, string asm, + SDPatternOperator op> { + def _H : sve2_int_mla_by_indexed_elem<{0, ?}, { 0b000, opc, S }, asm, ZPR16, ZPR16, ZPR3b16, VectorIndexH32b> { + bits<3> Zm; + bits<3> iop; + let Inst{22} = iop{2}; + let Inst{20-19} = iop{1-0}; + let Inst{18-16} = Zm; + } + def _S : sve2_int_mla_by_indexed_elem<0b10, { 0b000, opc, S }, asm, ZPR32, ZPR32, ZPR3b32, VectorIndexS32b> { + bits<3> Zm; + bits<2> iop; + let Inst{20-19} = iop; + let Inst{18-16} = Zm; + } + def _D : 
sve2_int_mla_by_indexed_elem<0b11, { 0b000, opc, S }, asm, ZPR64, ZPR64, ZPR4b64, VectorIndexD32b> { + bits<4> Zm; + bit iop; + let Inst{20} = iop; + let Inst{19-16} = Zm; + } + def : SVE_4_Op_Imm_Pat(NAME # _H)>; + def : SVE_4_Op_Imm_Pat(NAME # _S)>; + def : SVE_4_Op_Imm_Pat(NAME # _D)>; +} + +//===----------------------------------------------------------------------===// +// SVE2 Integer Multiply-Add Long - Indexed Group +//===----------------------------------------------------------------------===// + +multiclass sve2_int_mla_long_by_indexed_elem opc, string asm, + SDPatternOperator op> { + def _S : sve2_int_mla_by_indexed_elem<0b10, { opc{3}, 0b0, opc{2-1}, ?, opc{0} }, + asm, ZPR32, ZPR16, ZPR3b16, VectorIndexH32b> { + bits<3> Zm; + bits<3> iop; + let Inst{20-19} = iop{2-1}; + let Inst{18-16} = Zm; + let Inst{11} = iop{0}; + } + def _D : sve2_int_mla_by_indexed_elem<0b11, { opc{3}, 0b0, opc{2-1}, ?, opc{0} }, + asm, ZPR64, ZPR32, ZPR4b32, VectorIndexS32b> { + bits<4> Zm; + bits<2> iop; + let Inst{20} = iop{1}; + let Inst{19-16} = Zm; + let Inst{11} = iop{0}; + } + def : SVE_4_Op_Imm_Pat(NAME # _S)>; + def : SVE_4_Op_Imm_Pat(NAME # _D)>; } //===----------------------------------------------------------------------===// @@ -1712,13 +2761,15 @@ let Inst{4-0} = Zda; let Constraints = "$Zda = $_Zda"; - let DestructiveInstType = Destructive; - let ElementSize = zprty1.ElementSize; + let DestructiveInstType = DestructiveOther; } -multiclass sve_intx_dot { +multiclass sve_intx_dot { def _S : sve_intx_dot<0b0, opc, asm, ZPR32, ZPR8>; def _D : sve_intx_dot<0b1, opc, asm, ZPR64, ZPR16>; + + def : SVE_3_Op_Pat(NAME # _S)>; + def : SVE_3_Op_Pat(NAME # _D)>; } //===----------------------------------------------------------------------===// @@ -1742,23 +2793,971 @@ let Inst{4-0} = Zda; let Constraints = "$Zda = $_Zda"; - let DestructiveInstType = Destructive; - let ElementSize = ElementSizeNone; + let DestructiveInstType = DestructiveOther; } -multiclass sve_intx_dot_by_indexed_elem { - def _S : sve_intx_dot_by_indexed_elem<0b0, opc, asm, ZPR32, ZPR8, ZPR3b8, VectorIndexS> { +multiclass sve_intx_dot_by_indexed_elem { + def _S : sve_intx_dot_by_indexed_elem<0b0, opc, asm, ZPR32, ZPR8, ZPR3b8, VectorIndexS32b> { bits<2> iop; bits<3> Zm; let Inst{20-19} = iop; let Inst{18-16} = Zm; } - def _D : sve_intx_dot_by_indexed_elem<0b1, opc, asm, ZPR64, ZPR16, ZPR4b16, VectorIndexD> { + def _D : sve_intx_dot_by_indexed_elem<0b1, opc, asm, ZPR64, ZPR16, ZPR4b16, VectorIndexD32b> { bits<1> iop; bits<4> Zm; let Inst{20} = iop; let Inst{19-16} = Zm; } + + def : Pat<(nxv4i32 (op nxv4i32:$Op1, nxv16i8:$Op2, nxv16i8:$Op3, (i32 VectorIndexS32b:$idx))), + (!cast(NAME # _S) $Op1, $Op2, $Op3, VectorIndexS32b:$idx)>; + def : Pat<(nxv2i64 (op nxv2i64:$Op1, nxv8i16:$Op2, nxv8i16:$Op3, (i32 VectorIndexD32b:$idx))), + (!cast(NAME # _D) $Op1, $Op2, $Op3, VectorIndexD32b:$idx)>; +} + +//===----------------------------------------------------------------------===// +// SVE2 Complex Integer Dot Product Group +//===----------------------------------------------------------------------===// + +class sve2_complex_int_arith sz, bits<4> opc, string asm, + ZPRRegOp zprty1, ZPRRegOp zprty2> +: I<(outs zprty1:$Zda), (ins zprty1:$_Zda, zprty2:$Zn, zprty2:$Zm, + complexrotateop:$rot), + asm, "\t$Zda, $Zn, $Zm, $rot", "", []>, Sched<[]> { + bits<5> Zda; + bits<5> Zn; + bits<5> Zm; + bits<2> rot; + let Inst{31-24} = 0b01000100; + let Inst{23-22} = sz; + let Inst{21} = 0b0; + let Inst{20-16} = Zm; + let Inst{15-12} = opc; + let Inst{11-10} = 
rot; + let Inst{9-5} = Zn; + let Inst{4-0} = Zda; + + let Constraints = "$Zda = $_Zda"; + let DestructiveInstType = DestructiveOther; + let ElementSize = ElementSizeNone; +} + +multiclass sve2_cintx_dot { + def _S : sve2_complex_int_arith<0b10, 0b0001, asm, ZPR32, ZPR8>; + def _D : sve2_complex_int_arith<0b11, 0b0001, asm, ZPR64, ZPR16>; + + def : Pat<(nxv4i32 (op (nxv4i32 ZPR32:$Op1), (nxv16i8 ZPR8:$Op2), (nxv16i8 ZPR8:$Op3), + (i32 complexrotateop:$imm))), + (!cast(NAME # "_S") ZPR32:$Op1, ZPR8:$Op2, ZPR8:$Op3, complexrotateop:$imm)>; + def : Pat<(nxv2i64 (op (nxv2i64 ZPR64:$Op1), (nxv8i16 ZPR16:$Op2), (nxv8i16 ZPR16:$Op3), + (i32 complexrotateop:$imm))), + (!cast(NAME # "_D") ZPR64:$Op1, ZPR16:$Op2, ZPR16:$Op3, complexrotateop:$imm)>; +} + +//===----------------------------------------------------------------------===// +// SVE2 Complex Multiply-Add Group +//===----------------------------------------------------------------------===// + +multiclass sve2_int_cmla { + def _B : sve2_complex_int_arith<0b00, { 0b001, opc }, asm, ZPR8, ZPR8>; + def _H : sve2_complex_int_arith<0b01, { 0b001, opc }, asm, ZPR16, ZPR16>; + def _S : sve2_complex_int_arith<0b10, { 0b001, opc }, asm, ZPR32, ZPR32>; + def _D : sve2_complex_int_arith<0b11, { 0b001, opc }, asm, ZPR64, ZPR64>; + + def : SVE_4_Op_Imm_Pat(NAME # _B)>; + def : SVE_4_Op_Imm_Pat(NAME # _H)>; + def : SVE_4_Op_Imm_Pat(NAME # _S)>; + def : SVE_4_Op_Imm_Pat(NAME # _D)>; +} + +//===----------------------------------------------------------------------===// +// SVE2 Complex Integer Dot Product - Indexed Group +//===----------------------------------------------------------------------===// + +class sve2_complex_int_arith_indexed sz, bits<4> opc, string asm, + ZPRRegOp zprty1, ZPRRegOp zprty2, + ZPRRegOp zprty3, Operand itype> +: I<(outs zprty1:$Zda), (ins zprty1:$_Zda, zprty2:$Zn, zprty3:$Zm, itype:$iop, + complexrotateop:$rot), + asm, "\t$Zda, $Zn, $Zm$iop, $rot", "", []>, Sched<[]> { + bits<5> Zda; + bits<5> Zn; + bits<2> rot; + let Inst{31-24} = 0b01000100; + let Inst{23-22} = sz; + let Inst{21} = 0b1; + let Inst{15-12} = opc; + let Inst{11-10} = rot; + let Inst{9-5} = Zn; + let Inst{4-0} = Zda; + + let Constraints = "$Zda = $_Zda"; + let DestructiveInstType = DestructiveOther; + let ElementSize = ElementSizeNone; +} + +multiclass sve2_cintx_dot_by_indexed_elem { + def _S : sve2_complex_int_arith_indexed<0b10, 0b0100, asm, ZPR32, ZPR8, ZPR3b8, VectorIndexS32b> { + bits<2> iop; + bits<3> Zm; + let Inst{20-19} = iop; + let Inst{18-16} = Zm; + } + def _D : sve2_complex_int_arith_indexed<0b11, 0b0100, asm, ZPR64, ZPR16, ZPR4b16, VectorIndexD32b> { + bit iop; + bits<4> Zm; + let Inst{20} = iop; + let Inst{19-16} = Zm; + } + def : Pat<(nxv4i32 (op (nxv4i32 ZPR32:$Op1), (nxv16i8 ZPR8:$Op2), (nxv16i8 ZPR8:$Op3), + (i32 VectorIndexS32b:$idx), (i32 complexrotateop:$imm))), + (!cast(NAME # "_S") ZPR32:$Op1, ZPR8:$Op2, ZPR8:$Op3, VectorIndexS32b:$idx, complexrotateop:$imm)>; + def : Pat<(nxv2i64 (op (nxv2i64 ZPR64:$Op1), (nxv8i16 ZPR16:$Op2), (nxv8i16 ZPR16:$Op3), + (i32 VectorIndexD32b:$idx), (i32 complexrotateop:$imm))), + (!cast(NAME # "_D") ZPR64:$Op1, ZPR16:$Op2, ZPR16:$Op3, VectorIndexD32b:$idx, complexrotateop:$imm)>; +} + +//===----------------------------------------------------------------------===// +// SVE2 Complex Multiply-Add - Indexed Group +//===----------------------------------------------------------------------===// + +multiclass sve2_cmla_by_indexed_elem { + def _H : sve2_complex_int_arith_indexed<0b10, { 0b011, opc }, asm, ZPR16, 
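+  // Note (assumed from the architecture, not from this patch): $rot is the
+  // complex rotation selector used by the cdot/cmla forms in this group; in
+  // assembly it is written as #0, #90, #180 or #270 and is encoded as 0-3 in
+  // Inst{11-10}, which is the range complexrotateop provides.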
ZPR16, ZPR3b16, VectorIndexS32b> { + bits<2> iop; + bits<3> Zm; + let Inst{20-19} = iop; + let Inst{18-16} = Zm; + } + def _S : sve2_complex_int_arith_indexed<0b11, { 0b011, opc }, asm, ZPR32, ZPR32, ZPR4b32, VectorIndexD32b> { + bit iop; + bits<4> Zm; + let Inst{20} = iop; + let Inst{19-16} = Zm; + } + def : Pat<(nxv8i16 (op (nxv8i16 ZPR16:$Op1), (nxv8i16 ZPR16:$Op2), (nxv8i16 ZPR16:$Op3), + (i32 VectorIndexS32b:$idx), (i32 complexrotateop:$imm))), + (!cast(NAME # "_H") ZPR16:$Op1, ZPR16:$Op2, ZPR16:$Op3, VectorIndexS32b:$idx, complexrotateop:$imm)>; + def : Pat<(nxv4i32 (op (nxv4i32 ZPR32:$Op1), (nxv4i32 ZPR32:$Op2), (nxv4i32 ZPR32:$Op3), + (i32 VectorIndexD32b:$idx), (i32 complexrotateop:$imm))), + (!cast(NAME # "_S") ZPR32:$Op1, ZPR32:$Op2, ZPR32:$Op3, VectorIndexD32b:$idx, complexrotateop:$imm)>; +} + +//===----------------------------------------------------------------------===// +// SVE2 Integer Multiply - Unpredicated Group +//===----------------------------------------------------------------------===// + +class sve2_int_mul sz, bits<3> opc, string asm, ZPRRegOp zprty> +: I<(outs zprty:$Zd), (ins zprty:$Zn, zprty:$Zm), + asm, "\t$Zd, $Zn, $Zm", "", []>, Sched<[]> { + bits<5> Zd; + bits<5> Zm; + bits<5> Zn; + let Inst{31-24} = 0b00000100; + let Inst{23-22} = sz; + let Inst{21} = 0b1; + let Inst{20-16} = Zm; + let Inst{15-13} = 0b011; + let Inst{12-10} = opc; + let Inst{9-5} = Zn; + let Inst{4-0} = Zd; +} + +multiclass sve2_int_mul_single sz, bits<3> opc, string asm, + SDPatternOperator op, ZPRRegOp zprty, + ValueType vt> { + def NAME : sve2_int_mul; + def : SVE_2_Op_Pat(NAME)>; +} + +multiclass sve2_int_mul opc, string asm, + SDPatternOperator op = null_frag> { + defm _B : sve2_int_mul_single<0b00, opc, asm, op, ZPR8, nxv16i8>; + defm _H : sve2_int_mul_single<0b01, opc, asm, op, ZPR16, nxv8i16>; + defm _S : sve2_int_mul_single<0b10, opc, asm, op, ZPR32, nxv4i32>; + defm _D : sve2_int_mul_single<0b11, opc, asm, op, ZPR64, nxv2i64>; +} + +//===----------------------------------------------------------------------===// +// SVE2 Integer Multiply - Indexed Group +//===----------------------------------------------------------------------===// + +class sve2_int_mul_by_indexed_elem sz, bits<4> opc, string asm, + ZPRRegOp zprty1, ZPRRegOp zprty2, + ZPRRegOp zprty3, Operand itype> +: I<(outs zprty1:$Zd), (ins zprty2:$Zn, zprty3:$Zm, itype:$iop), + asm, "\t$Zd, $Zn, $Zm$iop", "", []>, Sched<[]> { + bits<5> Zd; + bits<5> Zn; + let Inst{31-24} = 0b01000100; + let Inst{23-22} = sz; + let Inst{21} = 0b1; + let Inst{15-14} = 0b11; + let Inst{13-10} = opc; + let Inst{9-5} = Zn; + let Inst{4-0} = Zd; +} + +multiclass sve2_int_mul_by_indexed_elem opc, string asm, + SDPatternOperator op> { + def _H : sve2_int_mul_by_indexed_elem<{0, ?}, opc, asm, ZPR16, ZPR16, ZPR3b16, VectorIndexH32b> { + bits<3> Zm; + bits<3> iop; + let Inst{22} = iop{2}; + let Inst{20-19} = iop{1-0}; + let Inst{18-16} = Zm; + } + def _S : sve2_int_mul_by_indexed_elem<0b10, opc, asm, ZPR32, ZPR32, ZPR3b32, VectorIndexS32b> { + bits<3> Zm; + bits<2> iop; + let Inst{20-19} = iop; + let Inst{18-16} = Zm; + } + def _D : sve2_int_mul_by_indexed_elem<0b11, opc, asm, ZPR64, ZPR64, ZPR4b64, VectorIndexD32b> { + bits<4> Zm; + bit iop; + let Inst{20} = iop; + let Inst{19-16} = Zm; + } + def : SVE_3_Op_Imm_Pat(NAME # _H)>; + def : SVE_3_Op_Imm_Pat(NAME # _S)>; + def : SVE_3_Op_Imm_Pat(NAME # _D)>; +} + +multiclass sve2_int_mul_long_by_indexed_elem opc, string asm, SDPatternOperator op> { + def _S : sve2_int_mul_by_indexed_elem<0b10, { 
opc{2-1}, ?, opc{0} }, asm, + ZPR32, ZPR16, ZPR3b16, VectorIndexH32b> { + bits<3> Zm; + bits<3> iop; + let Inst{20-19} = iop{2-1}; + let Inst{18-16} = Zm; + let Inst{11} = iop{0}; + } + def _D : sve2_int_mul_by_indexed_elem<0b11, { opc{2-1}, ?, opc{0} }, asm, + ZPR64, ZPR32, ZPR4b32, VectorIndexS32b> { + bits<4> Zm; + bits<2> iop; + let Inst{20} = iop{1}; + let Inst{19-16} = Zm; + let Inst{11} = iop{0}; + } + def : SVE_3_Op_Imm_Pat(NAME # _S)>; + def : SVE_3_Op_Imm_Pat(NAME # _D)>; +} + +//===----------------------------------------------------------------------===// +// SVE2 Integer - Predicated Group +//===----------------------------------------------------------------------===// + +class sve2_int_arith_pred sz, bits<6> opc, string asm, + ZPRRegOp zprty> +: I<(outs zprty:$Zdn), (ins PPR3bAny:$Pg, zprty:$_Zdn, zprty:$Zm), + asm, "\t$Zdn, $Pg/m, $_Zdn, $Zm", "", []>, Sched<[]> { + bits<3> Pg; + bits<5> Zm; + bits<5> Zdn; + let Inst{31-24} = 0b01000100; + let Inst{23-22} = sz; + let Inst{21-20} = 0b01; + let Inst{20-16} = opc{5-1}; + let Inst{15-14} = 0b10; + let Inst{13} = opc{0}; + let Inst{12-10} = Pg; + let Inst{9-5} = Zm; + let Inst{4-0} = Zdn; + + let Constraints = "$Zdn = $_Zdn"; + let DestructiveInstType = DestructiveBinary; + let ElementSize = zprty.ElementSize; +} + +multiclass sve2_int_arith_pred opc, string asm, string psName, + SDPatternOperator op = null_frag> { + def _B : sve2_int_arith_pred<0b00, opc, asm, ZPR8>, SVEPseudo2Instr; + def _H : sve2_int_arith_pred<0b01, opc, asm, ZPR16>, SVEPseudo2Instr; + def _S : sve2_int_arith_pred<0b10, opc, asm, ZPR32>, SVEPseudo2Instr; + def _D : sve2_int_arith_pred<0b11, opc, asm, ZPR64>, SVEPseudo2Instr; + + def _B_UNDEF : PredTwoOpConstrainedPseudo; + def _H_UNDEF : PredTwoOpConstrainedPseudo; + def _S_UNDEF : PredTwoOpConstrainedPseudo; + def _D_UNDEF : PredTwoOpConstrainedPseudo; + + def _B_ZERO : PredTwoOpConstrainedPseudo; + def _H_ZERO : PredTwoOpConstrainedPseudo; + def _S_ZERO : PredTwoOpConstrainedPseudo; + def _D_ZERO : PredTwoOpConstrainedPseudo; + + def : SVE_3_Op_Pat_SelZero(NAME # _B_ZERO)>; + def : SVE_3_Op_Pat_SelZero(NAME # _H_ZERO)>; + def : SVE_3_Op_Pat_SelZero(NAME # _S_ZERO)>; + def : SVE_3_Op_Pat_SelZero(NAME # _D_ZERO)>; + + def : SVE_3_Op_Pat(NAME # _B)>; + def : SVE_3_Op_Pat(NAME # _H)>; + def : SVE_3_Op_Pat(NAME # _S)>; + def : SVE_3_Op_Pat(NAME # _D)>; +} + +class sve2_int_sadd_long_accum_pairwise sz, bit U, string asm, + ZPRRegOp zprty1, ZPRRegOp zprty2> +: I<(outs zprty1:$Zda), (ins PPR3bAny:$Pg, zprty1:$_Zda, zprty2:$Zn), + asm, "\t$Zda, $Pg/m, $Zn", "", []>, Sched<[]> { + bits<3> Pg; + bits<5> Zn; + bits<5> Zda; + let Inst{31-24} = 0b01000100; + let Inst{23-22} = sz; + let Inst{21-17} = 0b00010; + let Inst{16} = U; + let Inst{15-13} = 0b101; + let Inst{12-10} = Pg; + let Inst{9-5} = Zn; + let Inst{4-0} = Zda; + + let Constraints = "$Zda = $_Zda"; + let DestructiveInstType = DestructiveBinary; + let ElementSize = zprty1.ElementSize; +} + +multiclass sve2_int_sadd_long_accum_pairwise { + def _H : sve2_int_sadd_long_accum_pairwise<0b01, U, asm, ZPR16, ZPR8>, SVEPseudo2Instr; + def _S : sve2_int_sadd_long_accum_pairwise<0b10, U, asm, ZPR32, ZPR16>, SVEPseudo2Instr; + def _D : sve2_int_sadd_long_accum_pairwise<0b11, U, asm, ZPR64, ZPR32>, SVEPseudo2Instr; + + def _H_UNDEF : PredTwoOpConstrainedPseudo; + def _S_UNDEF : PredTwoOpConstrainedPseudo; + def _D_UNDEF : PredTwoOpConstrainedPseudo; + + def _H_ZERO : PredTwoOpConstrainedPseudo; + def _S_ZERO : PredTwoOpConstrainedPseudo; + def _D_ZERO : 
PredTwoOpConstrainedPseudo;
+
+  def : SVE_3_Op_Pat_SelZero(NAME # _H_ZERO)>;
+  def : SVE_3_Op_Pat_SelZero(NAME # _S_ZERO)>;
+  def : SVE_3_Op_Pat_SelZero(NAME # _D_ZERO)>;
+
+  def : SVE_3_Op_Pat(NAME # _H)>;
+  def : SVE_3_Op_Pat(NAME # _S)>;
+  def : SVE_3_Op_Pat(NAME # _D)>;
+}
+
+class sve2_int_un_pred_arit<bits<2> sz, bit Q, bits<2> opc,
+                            string asm, ZPRRegOp zprty>
+: I<(outs zprty:$Zd), (ins zprty:$_Zd, PPR3bAny:$Pg, zprty:$Zn),
+  asm, "\t$Zd, $Pg/m, $Zn",
+  "",
+  []>, Sched<[]> {
+  bits<3> Pg;
+  bits<5> Zd;
+  bits<5> Zn;
+  let Inst{31-24} = 0b01000100;
+  let Inst{23-22} = sz;
+  let Inst{21-20} = 0b00;
+  let Inst{19} = Q;
+  let Inst{18} = 0b0;
+  let Inst{17-16} = opc;
+  let Inst{15-13} = 0b101;
+  let Inst{12-10} = Pg;
+  let Inst{9-5} = Zn;
+  let Inst{4-0} = Zd;
+
+  let Constraints = "$Zd = $_Zd";
+  let DestructiveInstType = DestructiveUnary;
+  let ElementSize = zprty.ElementSize;
+}
+
+multiclass sve2_int_un_pred_arit_s<bits<3> opc, string asm,
+                                   SDPatternOperator op> {
+  def _S : sve2_int_un_pred_arit<0b10, opc{2}, opc{1-0}, asm, ZPR32>;
+  def : SVE_3_Op_Pat(NAME # _S)>;
+}
+
+multiclass sve2_int_un_pred_arit<bits<3> opc, string asm, SDPatternOperator op> {
+  def _B : sve2_int_un_pred_arit<0b00, opc{2}, opc{1-0}, asm, ZPR8>;
+  def _H : sve2_int_un_pred_arit<0b01, opc{2}, opc{1-0}, asm, ZPR16>;
+  def _S : sve2_int_un_pred_arit<0b10, opc{2}, opc{1-0}, asm, ZPR32>;
+  def _D : sve2_int_un_pred_arit<0b11, opc{2}, opc{1-0}, asm, ZPR64>;
+
+  def : SVE_3_Op_Pat(NAME # _B)>;
+  def : SVE_3_Op_Pat(NAME # _H)>;
+  def : SVE_3_Op_Pat(NAME # _S)>;
+  def : SVE_3_Op_Pat(NAME # _D)>;
+}
+
+//===----------------------------------------------------------------------===//
+// SVE2 Widening Integer Arithmetic Group
+//===----------------------------------------------------------------------===//
+
+class sve2_wide_int_arith<bits<2> sz, bits<5> opc, string asm,
+                          ZPRRegOp zprty1, ZPRRegOp zprty2, ZPRRegOp zprty3>
+: I<(outs zprty1:$Zd), (ins zprty2:$Zn, zprty3:$Zm),
+  asm, "\t$Zd, $Zn, $Zm", "", []>, Sched<[]> {
+  bits<5> Zd;
+  bits<5> Zn;
+  bits<5> Zm;
+  let Inst{31-24} = 0b01000101;
+  let Inst{23-22} = sz;
+  let Inst{21} = 0b0;
+  let Inst{20-16} = Zm;
+  let Inst{15} = 0b0;
+  let Inst{14-10} = opc;
+  let Inst{9-5} = Zn;
+  let Inst{4-0} = Zd;
+}
+
+multiclass sve2_wide_int_arith_pmul<bits<2> sz, bits<5> opc, string asm,
+                                    SDPatternOperator op> {
+  def NAME : sve2_wide_int_arith;
+  // To avoid using 128-bit elements in the IR, the pattern below works with
+  // LLVM intrinsics carrying a _pair suffix, to reflect that
+  // _Q is implemented as a pair of _D.
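+  // Illustrative sketch only; the intrinsic name and mangling below are
+  // hypothetical (the actual names were elided from this hunk). A pair-style
+  // intrinsic keeps the IR at 64-bit element granularity, e.g.
+  //   %prod = call <vscale x 2 x i64>
+  //       @llvm.aarch64.sve.pmullb.pair.nxv2i64(<vscale x 2 x i64> %a,
+  //                                             <vscale x 2 x i64> %b)
+  // where each 128-bit polynomial product is returned as two adjacent 64-bit
+  // elements of the result, so no 128-bit element type is needed in the IR.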
+ def : SVE_2_Op_Pat(NAME)>; +} + +multiclass sve2_wide_int_arith_long opc, string asm, + SDPatternOperator op> { + def _H : sve2_wide_int_arith<0b01, opc, asm, ZPR16, ZPR8, ZPR8>; + def _S : sve2_wide_int_arith<0b10, opc, asm, ZPR32, ZPR16, ZPR16>; + def _D : sve2_wide_int_arith<0b11, opc, asm, ZPR64, ZPR32, ZPR32>; + + def : SVE_2_Op_Pat(NAME # _H)>; + def : SVE_2_Op_Pat(NAME # _S)>; + def : SVE_2_Op_Pat(NAME # _D)>; +} + +multiclass sve2_wide_int_arith_wide opc, string asm, + SDPatternOperator op> { + def _H : sve2_wide_int_arith<0b01, { 0b10, opc }, asm, ZPR16, ZPR16, ZPR8>; + def _S : sve2_wide_int_arith<0b10, { 0b10, opc }, asm, ZPR32, ZPR32, ZPR16>; + def _D : sve2_wide_int_arith<0b11, { 0b10, opc }, asm, ZPR64, ZPR64, ZPR32>; + + def : SVE_2_Op_Pat(NAME # _H)>; + def : SVE_2_Op_Pat(NAME # _S)>; + def : SVE_2_Op_Pat(NAME # _D)>; +} + +multiclass sve2_pmul_long opc, string asm, SDPatternOperator op> { + def _H : sve2_wide_int_arith<0b01, {0b1101, opc}, asm, ZPR16, ZPR8, ZPR8>; + def _D : sve2_wide_int_arith<0b11, {0b1101, opc}, asm, ZPR64, ZPR32, ZPR32>; + + // To avoid using 128 bit elements in the IR, patterns below work with + // llvm intrinsics with _pair suffix, to reflect that + // _H is implemented as a pair of _B and _D is implemented as a pair of _S. + def : SVE_2_Op_Pat(NAME # _H)>; + def : SVE_2_Op_Pat(NAME # _D)>; +} + +//===----------------------------------------------------------------------===// +// SVE2 Misc Group +//===----------------------------------------------------------------------===// + +class sve2_misc sz, bits<4> opc, string asm, + ZPRRegOp zprty1, ZPRRegOp zprty2> +: I<(outs zprty1:$Zd), (ins zprty2:$Zn, zprty2:$Zm), + asm, "\t$Zd, $Zn, $Zm", "", []>, Sched<[]> { + bits<5> Zd; + bits<5> Zn; + bits<5> Zm; + let Inst{31-24} = 0b01000101; + let Inst{23-22} = sz; + let Inst{21} = 0b0; + let Inst{20-16} = Zm; + let Inst{15-14} = 0b10; + let Inst{13-10} = opc; + let Inst{9-5} = Zn; + let Inst{4-0} = Zd; +} + +multiclass sve2_misc_bitwise opc, string asm, SDPatternOperator op> { + def _B : sve2_misc<0b00, opc, asm, ZPR8, ZPR8>; + def _H : sve2_misc<0b01, opc, asm, ZPR16, ZPR16>; + def _S : sve2_misc<0b10, opc, asm, ZPR32, ZPR32>; + def _D : sve2_misc<0b11, opc, asm, ZPR64, ZPR64>; + + def : SVE_2_Op_Pat(NAME # _B)>; + def : SVE_2_Op_Pat(NAME # _H)>; + def : SVE_2_Op_Pat(NAME # _S)>; + def : SVE_2_Op_Pat(NAME # _D)>; +} + +multiclass sve2_misc_int_addsub_long_interleaved opc, string asm, + SDPatternOperator op> { + def _H : sve2_misc<0b01, { 0b00, opc }, asm, ZPR16, ZPR8>; + def _S : sve2_misc<0b10, { 0b00, opc }, asm, ZPR32, ZPR16>; + def _D : sve2_misc<0b11, { 0b00, opc }, asm, ZPR64, ZPR32>; + + def : SVE_2_Op_Pat(NAME # _H)>; + def : SVE_2_Op_Pat(NAME # _S)>; + def : SVE_2_Op_Pat(NAME # _D)>; +} + +class sve2_bitwise_xor_interleaved sz, bits<1> opc, string asm, + ZPRRegOp zprty1, ZPRRegOp zprty2> +: I<(outs zprty1:$Zd), (ins zprty1:$_Zd, zprty2:$Zn, zprty2:$Zm), + asm, "\t$Zd, $Zn, $Zm", "", []>, Sched<[]> { + bits<5> Zd; + bits<5> Zn; + bits<5> Zm; + let Inst{31-24} = 0b01000101; + let Inst{23-22} = sz; + let Inst{21} = 0b0; + let Inst{20-16} = Zm; + let Inst{15-11} = 0b10010; + let Inst{10} = opc; + let Inst{9-5} = Zn; + let Inst{4-0} = Zd; + + let Constraints = "$Zd = $_Zd"; + let DestructiveInstType = DestructiveOther; + let ElementSize = ElementSizeNone; +} + +multiclass sve2_bitwise_xor_interleaved { + def _B : sve2_bitwise_xor_interleaved<0b00, opc, asm, ZPR8, ZPR8>; + def _H : sve2_bitwise_xor_interleaved<0b01, opc, asm, ZPR16, ZPR16>; + def _S : 
sve2_bitwise_xor_interleaved<0b10, opc, asm, ZPR32, ZPR32>; + def _D : sve2_bitwise_xor_interleaved<0b11, opc, asm, ZPR64, ZPR64>; + + def : SVE_3_Op_Pat(NAME # _B)>; + def : SVE_3_Op_Pat(NAME # _H)>; + def : SVE_3_Op_Pat(NAME # _S)>; + def : SVE_3_Op_Pat(NAME # _D)>; +} + +class sve2_bitwise_shift_left_long tsz8_64, bits<2> opc, string asm, + ZPRRegOp zprty1, ZPRRegOp zprty2, + Operand immtype> +: I<(outs zprty1:$Zd), (ins zprty2:$Zn, immtype:$imm), + asm, "\t$Zd, $Zn, $imm", + "", []>, Sched<[]> { + bits<5> Zd; + bits<5> Zn; + bits<5> imm; + let Inst{31-23} = 0b010001010; + let Inst{22} = tsz8_64{2}; + let Inst{21} = 0b0; + let Inst{20-19} = tsz8_64{1-0}; + let Inst{18-16} = imm{2-0}; // imm3 + let Inst{15-12} = 0b1010; + let Inst{11-10} = opc; + let Inst{9-5} = Zn; + let Inst{4-0} = Zd; +} + +multiclass sve2_bitwise_shift_left_long opc, string asm, + SDPatternOperator op> { + def _H : sve2_bitwise_shift_left_long<{0,0,1}, opc, asm, + ZPR16, ZPR8, vecshiftL8>; + def _S : sve2_bitwise_shift_left_long<{0,1,?}, opc, asm, + ZPR32, ZPR16, vecshiftL16> { + let Inst{19} = imm{3}; + } + def _D : sve2_bitwise_shift_left_long<{1,?,?}, opc, asm, + ZPR64, ZPR32, vecshiftL32> { + let Inst{20-19} = imm{4-3}; + } + def : SVE_2_Op_Imm_Pat(NAME # _H)>; + def : SVE_2_Op_Imm_Pat(NAME # _S)>; + def : SVE_2_Op_Imm_Pat(NAME # _D)>; +} + +//===----------------------------------------------------------------------===// +// SVE2 Accumulate Group +//===----------------------------------------------------------------------===// + +class sve2_int_bin_shift_imm tsz8_64, bit opc, string asm, + ZPRRegOp zprty, Operand immtype> +: I<(outs zprty:$Zd), (ins zprty:$_Zd, zprty:$Zn, immtype:$imm), + asm, "\t$Zd, $Zn, $imm", + "", []>, Sched<[]> { + bits<5> Zd; + bits<5> Zn; + bits<6> imm; + let Inst{31-24} = 0b01000101; + let Inst{23-22} = tsz8_64{3-2}; + let Inst{21} = 0b0; + let Inst{20-19} = tsz8_64{1-0}; + let Inst{18-16} = imm{2-0}; // imm3 + let Inst{15-11} = 0b11110; + let Inst{10} = opc; + let Inst{9-5} = Zn; + let Inst{4-0} = Zd; + + let Constraints = "$Zd = $_Zd"; +} + +multiclass sve2_int_bin_shift_imm_left { + def _B : sve2_int_bin_shift_imm<{0,0,0,1}, opc, asm, ZPR8, vecshiftL8>; + def _H : sve2_int_bin_shift_imm<{0,0,1,?}, opc, asm, ZPR16, vecshiftL16> { + let Inst{19} = imm{3}; + } + def _S : sve2_int_bin_shift_imm<{0,1,?,?}, opc, asm, ZPR32, vecshiftL32> { + let Inst{20-19} = imm{4-3}; + } + def _D : sve2_int_bin_shift_imm<{1,?,?,?}, opc, asm, ZPR64, vecshiftL64> { + let Inst{22} = imm{5}; + let Inst{20-19} = imm{4-3}; + } + + def : SVE_3_Op_Imm_Pat(NAME # _B)>; + def : SVE_3_Op_Imm_Pat(NAME # _H)>; + def : SVE_3_Op_Imm_Pat(NAME # _S)>; + def : SVE_3_Op_Imm_Pat(NAME # _D)>; +} + +multiclass sve2_int_bin_shift_imm_right { + def _B : sve2_int_bin_shift_imm<{0,0,0,1}, opc, asm, ZPR8, vecshiftR8>; + def _H : sve2_int_bin_shift_imm<{0,0,1,?}, opc, asm, ZPR16, vecshiftR16> { + let Inst{19} = imm{3}; + } + def _S : sve2_int_bin_shift_imm<{0,1,?,?}, opc, asm, ZPR32, vecshiftR32> { + let Inst{20-19} = imm{4-3}; + } + def _D : sve2_int_bin_shift_imm<{1,?,?,?}, opc, asm, ZPR64, vecshiftR64> { + let Inst{22} = imm{5}; + let Inst{20-19} = imm{4-3}; + } + def : SVE_3_Op_Imm_Pat(NAME # _B)>; + def : SVE_3_Op_Imm_Pat(NAME # _H)>; + def : SVE_3_Op_Imm_Pat(NAME # _S)>; + def : SVE_3_Op_Imm_Pat(NAME # _D)>; +} + +class sve2_int_bin_accum_shift_imm tsz8_64, bits<2> opc, + string asm, ZPRRegOp zprty, + Operand immtype> +: I<(outs zprty:$Zda), (ins zprty:$_Zda, zprty:$Zn, immtype:$imm), + asm, "\t$Zda, $Zn, $imm", + "", []>, 
Sched<[]> { + bits<5> Zda; + bits<5> Zn; + bits<6> imm; + let Inst{31-24} = 0b01000101; + let Inst{23-22} = tsz8_64{3-2}; + let Inst{21} = 0b0; + let Inst{20-19} = tsz8_64{1-0}; + let Inst{18-16} = imm{2-0}; // imm3 + let Inst{15-12} = 0b1110; + let Inst{11-10} = opc; + let Inst{9-5} = Zn; + let Inst{4-0} = Zda; + + let Constraints = "$Zda = $_Zda"; + let DestructiveInstType = DestructiveOther; + let ElementSize = ElementSizeNone; +} + +multiclass sve2_int_bin_accum_shift_imm_right opc, string asm, + SDPatternOperator op> { + def _B : sve2_int_bin_accum_shift_imm<{0,0,0,1}, opc, asm, ZPR8, vecshiftR8>; + def _H : sve2_int_bin_accum_shift_imm<{0,0,1,?}, opc, asm, ZPR16, vecshiftR16> { + let Inst{19} = imm{3}; + } + def _S : sve2_int_bin_accum_shift_imm<{0,1,?,?}, opc, asm, ZPR32, vecshiftR32> { + let Inst{20-19} = imm{4-3}; + } + def _D : sve2_int_bin_accum_shift_imm<{1,?,?,?}, opc, asm, ZPR64, vecshiftR64> { + let Inst{22} = imm{5}; + let Inst{20-19} = imm{4-3}; + } + def : SVE_3_Op_Imm_Pat(NAME # _B)>; + def : SVE_3_Op_Imm_Pat(NAME # _H)>; + def : SVE_3_Op_Imm_Pat(NAME # _S)>; + def : SVE_3_Op_Imm_Pat(NAME # _D)>; +} + +class sve2_int_cadd sz, bit opc, string asm, ZPRRegOp zprty> +: I<(outs zprty:$Zdn), (ins zprty:$_Zdn, zprty:$Zm, complexrotateopodd:$rot), + asm, "\t$Zdn, $_Zdn, $Zm, $rot", "", []>, Sched<[]> { + bits<5> Zdn; + bits<5> Zm; + bit rot; + let Inst{31-24} = 0b01000101; + let Inst{23-22} = sz; + let Inst{21-17} = 0b00000; + let Inst{16} = opc; + let Inst{15-11} = 0b11011; + let Inst{10} = rot; + let Inst{9-5} = Zm; + let Inst{4-0} = Zdn; + + let Constraints = "$Zdn = $_Zdn"; + let DestructiveInstType = DestructiveOther; + let ElementSize = ElementSizeNone; +} + +multiclass sve2_int_cadd { + def _B : sve2_int_cadd<0b00, opc, asm, ZPR8>; + def _H : sve2_int_cadd<0b01, opc, asm, ZPR16>; + def _S : sve2_int_cadd<0b10, opc, asm, ZPR32>; + def _D : sve2_int_cadd<0b11, opc, asm, ZPR64>; + + def : SVE_3_Op_Imm_Pat(NAME # _B)>; + def : SVE_3_Op_Imm_Pat(NAME # _H)>; + def : SVE_3_Op_Imm_Pat(NAME # _S)>; + def : SVE_3_Op_Imm_Pat(NAME # _D)>; +} + +class sve2_int_absdiff_accum sz, bits<4> opc, string asm, + ZPRRegOp zprty1, ZPRRegOp zprty2> +: I<(outs zprty1:$Zda), (ins zprty1:$_Zda, zprty2:$Zn, zprty2:$Zm), + asm, "\t$Zda, $Zn, $Zm", "", []>, Sched<[]> { + bits<5> Zda; + bits<5> Zn; + bits<5> Zm; + let Inst{31-24} = 0b01000101; + let Inst{23-22} = sz; + let Inst{21} = 0b0; + let Inst{20-16} = Zm; + let Inst{15-14} = 0b11; + let Inst{13-10} = opc; + let Inst{9-5} = Zn; + let Inst{4-0} = Zda; + + let Constraints = "$Zda = $_Zda"; + let DestructiveInstType = DestructiveOther; + let ElementSize = ElementSizeNone; +} + +multiclass sve2_int_absdiff_accum { + def _B : sve2_int_absdiff_accum<0b00, { 0b111, opc }, asm, ZPR8, ZPR8>; + def _H : sve2_int_absdiff_accum<0b01, { 0b111, opc }, asm, ZPR16, ZPR16>; + def _S : sve2_int_absdiff_accum<0b10, { 0b111, opc }, asm, ZPR32, ZPR32>; + def _D : sve2_int_absdiff_accum<0b11, { 0b111, opc }, asm, ZPR64, ZPR64>; + + def : SVE_3_Op_Pat(NAME # _B)>; + def : SVE_3_Op_Pat(NAME # _H)>; + def : SVE_3_Op_Pat(NAME # _S)>; + def : SVE_3_Op_Pat(NAME # _D)>; +} + +multiclass sve2_int_absdiff_accum_long opc, string asm, + SDPatternOperator op> { + def _H : sve2_int_absdiff_accum<0b01, { 0b00, opc }, asm, ZPR16, ZPR8>; + def _S : sve2_int_absdiff_accum<0b10, { 0b00, opc }, asm, ZPR32, ZPR16>; + def _D : sve2_int_absdiff_accum<0b11, { 0b00, opc }, asm, ZPR64, ZPR32>; + + def : SVE_3_Op_Pat(NAME # _H)>; + def : SVE_3_Op_Pat(NAME # _S)>; + def : SVE_3_Op_Pat(NAME # 
_D)>; +} + +multiclass sve2_int_addsub_long_carry opc, string asm, + SDPatternOperator op> { + def _S : sve2_int_absdiff_accum<{ opc{1}, 0b0 }, { 0b010, opc{0} }, asm, + ZPR32, ZPR32>; + def _D : sve2_int_absdiff_accum<{ opc{1}, 0b1 }, { 0b010, opc{0} }, asm, + ZPR64, ZPR64>; + + def : SVE_3_Op_Pat(NAME # _S)>; + def : SVE_3_Op_Pat(NAME # _D)>; +} + +//===----------------------------------------------------------------------===// +// SVE2 Narrowing Group +//===----------------------------------------------------------------------===// + +class sve2_int_bin_shift_imm_narrow_bottom tsz8_64, bits<3> opc, + string asm, ZPRRegOp zprty1, + ZPRRegOp zprty2, Operand immtype> +: I<(outs zprty1:$Zd), (ins zprty2:$Zn, immtype:$imm), + asm, "\t$Zd, $Zn, $imm", + "", []>, Sched<[]> { + bits<5> Zd; + bits<5> Zn; + bits<5> imm; + let Inst{31-23} = 0b010001010; + let Inst{22} = tsz8_64{2}; + let Inst{21} = 0b1; + let Inst{20-19} = tsz8_64{1-0}; + let Inst{18-16} = imm{2-0}; // imm3 + let Inst{15-14} = 0b00; + let Inst{13-11} = opc; + let Inst{10} = 0b0; + let Inst{9-5} = Zn; + let Inst{4-0} = Zd; +} + +multiclass sve2_int_bin_shift_imm_right_narrow_bottom opc, string asm, + SDPatternOperator op> { + def _B : sve2_int_bin_shift_imm_narrow_bottom<{0,0,1}, opc, asm, ZPR8, ZPR16, + vecshiftR8>; + def _H : sve2_int_bin_shift_imm_narrow_bottom<{0,1,?}, opc, asm, ZPR16, ZPR32, + vecshiftR16> { + let Inst{19} = imm{3}; + } + def _S : sve2_int_bin_shift_imm_narrow_bottom<{1,?,?}, opc, asm, ZPR32, ZPR64, + vecshiftR32> { + let Inst{20-19} = imm{4-3}; + } + def : SVE_2_Op_Imm_Pat(NAME # _B)>; + def : SVE_2_Op_Imm_Pat(NAME # _H)>; + def : SVE_2_Op_Imm_Pat(NAME # _S)>; +} + +class sve2_int_bin_shift_imm_narrow_top tsz8_64, bits<3> opc, + string asm, ZPRRegOp zprty1, + ZPRRegOp zprty2, Operand immtype> +: I<(outs zprty1:$Zd), (ins zprty1:$_Zd, zprty2:$Zn, immtype:$imm), + asm, "\t$Zd, $Zn, $imm", + "", []>, Sched<[]> { + bits<5> Zd; + bits<5> Zn; + bits<5> imm; + let Inst{31-23} = 0b010001010; + let Inst{22} = tsz8_64{2}; + let Inst{21} = 0b1; + let Inst{20-19} = tsz8_64{1-0}; + let Inst{18-16} = imm{2-0}; // imm3 + let Inst{15-14} = 0b00; + let Inst{13-11} = opc; + let Inst{10} = 0b1; + let Inst{9-5} = Zn; + let Inst{4-0} = Zd; + + let Constraints = "$Zd = $_Zd"; +} + +multiclass sve2_int_bin_shift_imm_right_narrow_top opc, string asm, + SDPatternOperator op> { + def _B : sve2_int_bin_shift_imm_narrow_top<{0,0,1}, opc, asm, ZPR8, ZPR16, + vecshiftR8>; + def _H : sve2_int_bin_shift_imm_narrow_top<{0,1,?}, opc, asm, ZPR16, ZPR32, + vecshiftR16> { + let Inst{19} = imm{3}; + } + def _S : sve2_int_bin_shift_imm_narrow_top<{1,?,?}, opc, asm, ZPR32, ZPR64, + vecshiftR32> { + let Inst{20-19} = imm{4-3}; + } + def : SVE_3_Op_Imm_Pat(NAME # _B)>; + def : SVE_3_Op_Imm_Pat(NAME # _H)>; + def : SVE_3_Op_Imm_Pat(NAME # _S)>; +} + +class sve2_int_addsub_narrow_high_bottom sz, bits<2> opc, string asm, + ZPRRegOp zprty1, ZPRRegOp zprty2> +: I<(outs zprty1:$Zd), (ins zprty2:$Zn, zprty2:$Zm), + asm, "\t$Zd, $Zn, $Zm", "", []>, Sched<[]> { + bits<5> Zd; + bits<5> Zn; + bits<5> Zm; + let Inst{31-24} = 0b01000101; + let Inst{23-22} = sz; + let Inst{21} = 0b1; + let Inst{20-16} = Zm; + let Inst{15-13} = 0b011; + let Inst{12-11} = opc; // S, R + let Inst{10} = 0b0; + let Inst{9-5} = Zn; + let Inst{4-0} = Zd; +} + +multiclass sve2_int_addsub_narrow_high_bottom opc, string asm, + SDPatternOperator op> { + def _B : sve2_int_addsub_narrow_high_bottom<0b01, opc, asm, ZPR8, ZPR16>; + def _H : sve2_int_addsub_narrow_high_bottom<0b10, opc, asm, 
ZPR16, ZPR32>; + def _S : sve2_int_addsub_narrow_high_bottom<0b11, opc, asm, ZPR32, ZPR64>; + + def : SVE_2_Op_Pat(NAME # _B)>; + def : SVE_2_Op_Pat(NAME # _H)>; + def : SVE_2_Op_Pat(NAME # _S)>; +} + +class sve2_int_addsub_narrow_high_top sz, bits<2> opc, string asm, + ZPRRegOp zprty1, ZPRRegOp zprty2> +: I<(outs zprty1:$Zd), (ins zprty1:$_Zd, zprty2:$Zn, zprty2:$Zm), + asm, "\t$Zd, $Zn, $Zm", "", []>, Sched<[]> { + bits<5> Zd; + bits<5> Zn; + bits<5> Zm; + let Inst{31-24} = 0b01000101; + let Inst{23-22} = sz; + let Inst{21} = 0b1; + let Inst{20-16} = Zm; + let Inst{15-13} = 0b011; + let Inst{12-11} = opc; // S, R + let Inst{10} = 0b1; + let Inst{9-5} = Zn; + let Inst{4-0} = Zd; + + let Constraints = "$Zd = $_Zd"; +} + +multiclass sve2_int_addsub_narrow_high_top opc, string asm, + SDPatternOperator op> { + def _B : sve2_int_addsub_narrow_high_top<0b01, opc, asm, ZPR8, ZPR16>; + def _H : sve2_int_addsub_narrow_high_top<0b10, opc, asm, ZPR16, ZPR32>; + def _S : sve2_int_addsub_narrow_high_top<0b11, opc, asm, ZPR32, ZPR64>; + + def : SVE_3_Op_Pat(NAME # _B)>; + def : SVE_3_Op_Pat(NAME # _H)>; + def : SVE_3_Op_Pat(NAME # _S)>; +} + +class sve2_int_sat_extract_narrow_bottom tsz8_64, bits<2> opc, string asm, + ZPRRegOp zprty1, ZPRRegOp zprty2> +: I<(outs zprty1:$Zd), (ins zprty2:$Zn), + asm, "\t$Zd, $Zn", "", []>, Sched<[]> { + bits<5> Zd; + bits<5> Zn; + let Inst{31-23} = 0b010001010; + let Inst{22} = tsz8_64{2}; + let Inst{21} = 0b1; + let Inst{20-19} = tsz8_64{1-0}; + let Inst{18-13} = 0b000010; + let Inst{12-11} = opc; + let Inst{10} = 0b0; + let Inst{9-5} = Zn; + let Inst{4-0} = Zd; +} + +multiclass sve2_int_sat_extract_narrow_bottom opc, string asm, + SDPatternOperator op> { + def _B : sve2_int_sat_extract_narrow_bottom<0b001, opc, asm, ZPR8, ZPR16>; + def _H : sve2_int_sat_extract_narrow_bottom<0b010, opc, asm, ZPR16, ZPR32>; + def _S : sve2_int_sat_extract_narrow_bottom<0b100, opc, asm, ZPR32, ZPR64>; + + def : SVE_1_Op_Pat(NAME # _B)>; + def : SVE_1_Op_Pat(NAME # _H)>; + def : SVE_1_Op_Pat(NAME # _S)>; +} + +class sve2_int_sat_extract_narrow_top tsz8_64, bits<2> opc, string asm, + ZPRRegOp zprty1, ZPRRegOp zprty2> +: I<(outs zprty1:$Zd), (ins zprty1:$_Zd, zprty2:$Zn), + asm, "\t$Zd, $Zn", "", []>, Sched<[]> { + bits<5> Zd; + bits<5> Zn; + let Inst{31-23} = 0b010001010; + let Inst{22} = tsz8_64{2}; + let Inst{21} = 0b1; + let Inst{20-19} = tsz8_64{1-0}; + let Inst{18-13} = 0b000010; + let Inst{12-11} = opc; + let Inst{10} = 0b1; + let Inst{9-5} = Zn; + let Inst{4-0} = Zd; + + let Constraints = "$Zd = $_Zd"; +} + +multiclass sve2_int_sat_extract_narrow_top opc, string asm, + SDPatternOperator op> { + def _B : sve2_int_sat_extract_narrow_top<0b001, opc, asm, ZPR8, ZPR16>; + def _H : sve2_int_sat_extract_narrow_top<0b010, opc, asm, ZPR16, ZPR32>; + def _S : sve2_int_sat_extract_narrow_top<0b100, opc, asm, ZPR32, ZPR64>; + + def : SVE_2_Op_Pat(NAME # _B)>; + def : SVE_2_Op_Pat(NAME # _H)>; + def : SVE_2_Op_Pat(NAME # _S)>; } //===----------------------------------------------------------------------===// @@ -1785,43 +3784,72 @@ let Inst{4-0} = Zd; let Constraints = "$Zd = $_Zd"; - let DestructiveInstType = Destructive; + let DestructiveInstType = DestructiveUnary; let ElementSize = zprty.ElementSize; } -multiclass sve_int_un_pred_arit_0 opc, string asm> { +multiclass sve_int_un_pred_arit_0 opc, string asm, + SDPatternOperator op> { def _B : sve_int_un_pred_arit<0b00, { opc, 0b0 }, asm, ZPR8>; def _H : sve_int_un_pred_arit<0b01, { opc, 0b0 }, asm, ZPR16>; def _S : sve_int_un_pred_arit<0b10, 
{ opc, 0b0 }, asm, ZPR32>; def _D : sve_int_un_pred_arit<0b11, { opc, 0b0 }, asm, ZPR64>; + + def : SVE_3_Op_Pat(NAME # _B)>; + def : SVE_3_Op_Pat(NAME # _H)>; + def : SVE_3_Op_Pat(NAME # _S)>; + def : SVE_3_Op_Pat(NAME # _D)>; } -multiclass sve_int_un_pred_arit_0_h opc, string asm> { +multiclass sve_int_un_pred_arit_0_h opc, string asm, + SDPatternOperator op> { def _H : sve_int_un_pred_arit<0b01, { opc, 0b0 }, asm, ZPR16>; def _S : sve_int_un_pred_arit<0b10, { opc, 0b0 }, asm, ZPR32>; def _D : sve_int_un_pred_arit<0b11, { opc, 0b0 }, asm, ZPR64>; + + def : SVE_3_Op_Pat(NAME # _H)>; + def : SVE_3_Op_Pat(NAME # _S)>; + def : SVE_3_Op_Pat(NAME # _D)>; } -multiclass sve_int_un_pred_arit_0_w opc, string asm> { +multiclass sve_int_un_pred_arit_0_w opc, string asm, + SDPatternOperator op> { def _S : sve_int_un_pred_arit<0b10, { opc, 0b0 }, asm, ZPR32>; def _D : sve_int_un_pred_arit<0b11, { opc, 0b0 }, asm, ZPR64>; + + def : SVE_3_Op_Pat(NAME # _S)>; + def : SVE_3_Op_Pat(NAME # _D)>; } -multiclass sve_int_un_pred_arit_0_d opc, string asm> { +multiclass sve_int_un_pred_arit_0_d opc, string asm, + SDPatternOperator op> { def _D : sve_int_un_pred_arit<0b11, { opc, 0b0 }, asm, ZPR64>; + + def : SVE_3_Op_Pat(NAME # _D)>; } -multiclass sve_int_un_pred_arit_1 opc, string asm> { +multiclass sve_int_un_pred_arit_1 opc, string asm, + SDPatternOperator op> { def _B : sve_int_un_pred_arit<0b00, { opc, 0b1 }, asm, ZPR8>; def _H : sve_int_un_pred_arit<0b01, { opc, 0b1 }, asm, ZPR16>; def _S : sve_int_un_pred_arit<0b10, { opc, 0b1 }, asm, ZPR32>; def _D : sve_int_un_pred_arit<0b11, { opc, 0b1 }, asm, ZPR64>; + + def : SVE_3_Op_Pat(NAME # _B)>; + def : SVE_3_Op_Pat(NAME # _H)>; + def : SVE_3_Op_Pat(NAME # _S)>; + def : SVE_3_Op_Pat(NAME # _D)>; } -multiclass sve_int_un_pred_arit_1_fp opc, string asm> { +multiclass sve_int_un_pred_arit_1_fp opc, string asm, + SDPatternOperator op> { def _H : sve_int_un_pred_arit<0b01, { opc, 0b1 }, asm, ZPR16>; def _S : sve_int_un_pred_arit<0b10, { opc, 0b1 }, asm, ZPR32>; def _D : sve_int_un_pred_arit<0b11, { opc, 0b1 }, asm, ZPR64>; + + def : SVE_3_Op_Pat(NAME # _H)>; + def : SVE_3_Op_Pat(NAME # _S)>; + def : SVE_3_Op_Pat(NAME # _D)>; } //===----------------------------------------------------------------------===// @@ -1917,15 +3945,25 @@ let Inst{4-0} = Zdn; let Constraints = "$Zdn = $_Zdn"; - let DestructiveInstType = Destructive; - let ElementSize = ElementSizeNone; + let DestructiveInstType = DestructiveBinaryShImmUnpred; } -multiclass sve_int_arith_imm0 opc, string asm> { - def _B : sve_int_arith_imm0<0b00, opc, asm, ZPR8, addsub_imm8_opt_lsl_i8>; - def _H : sve_int_arith_imm0<0b01, opc, asm, ZPR16, addsub_imm8_opt_lsl_i16>; - def _S : sve_int_arith_imm0<0b10, opc, asm, ZPR32, addsub_imm8_opt_lsl_i32>; - def _D : sve_int_arith_imm0<0b11, opc, asm, ZPR64, addsub_imm8_opt_lsl_i64>; +multiclass sve_int_arith_imm0 opc, string Ps, string asm> { + def _B : sve_int_arith_imm0<0b00, opc, asm, ZPR8, addsub_imm8_opt_lsl_i8>, + SVEPseudo2Instr; + def _H : sve_int_arith_imm0<0b01, opc, asm, ZPR16, addsub_imm8_opt_lsl_i16>, + SVEPseudo2Instr; + def _S : sve_int_arith_imm0<0b10, opc, asm, ZPR32, addsub_imm8_opt_lsl_i32>, + SVEPseudo2Instr; + def _D : sve_int_arith_imm0<0b11, opc, asm, ZPR64, addsub_imm8_opt_lsl_i64>, + SVEPseudo2Instr; +} + +multiclass sve_int_arith_imm0_zzi { + def _B : UnpredTwoOpImmPseudo; + def _H : UnpredTwoOpImmPseudo; + def _S : UnpredTwoOpImmPseudo; + def _D : UnpredTwoOpImmPseudo; } class sve_int_arith_imm sz8_64, bits<6> opc, string asm, @@ -1944,8 +3982,7 
@@ let Inst{4-0} = Zdn; let Constraints = "$Zdn = $_Zdn"; - let DestructiveInstType = Destructive; - let ElementSize = ElementSizeNone; + let DestructiveInstType = DestructiveOther; } multiclass sve_int_arith_imm1 opc, string asm, Operand immtype> { @@ -1962,6 +3999,7 @@ def _D : sve_int_arith_imm<0b11, 0b110000, asm, ZPR64, simm8>; } + //===----------------------------------------------------------------------===// // SVE Bitwise Logical - Unpredicated Group //===----------------------------------------------------------------------===// @@ -1983,6 +4021,90 @@ let Inst{4-0} = Zd; } +multiclass sve_int_bin_cons_log opc, string asm> { + def NAME : sve_int_bin_cons_log; + + def : InstAlias(NAME) ZPR8:$Zd, ZPR8:$Zn, ZPR8:$Zm), 1>; + def : InstAlias(NAME) ZPR16:$Zd, ZPR16:$Zn, ZPR16:$Zm), 1>; + def : InstAlias(NAME) ZPR32:$Zd, ZPR32:$Zn, ZPR32:$Zm), 1>; +} + +class sve2_int_bitwise_ternary_op_d opc, string asm> +: I<(outs ZPR64:$Zdn), (ins ZPR64:$_Zdn, ZPR64:$Zm, ZPR64:$Zk), + asm, "\t$Zdn, $_Zdn, $Zm, $Zk", + "", + []>, Sched<[]> { + bits<5> Zdn; + bits<5> Zk; + bits<5> Zm; + let Inst{31-24} = 0b00000100; + let Inst{23-22} = opc{2-1}; + let Inst{21} = 0b1; + let Inst{20-16} = Zm; + let Inst{15-11} = 0b00111; + let Inst{10} = opc{0}; + let Inst{9-5} = Zk; + let Inst{4-0} = Zdn; + + let Constraints = "$Zdn = $_Zdn"; + let DestructiveInstType = DestructiveOther; + let ElementSize = ElementSizeNone; +} + +multiclass sve2_int_bitwise_ternary_op opc, string asm> { + def NAME : sve2_int_bitwise_ternary_op_d; + + def : InstAlias(NAME) ZPR8:$Zdn, ZPR8:$Zm, ZPR8:$Zk), 1>; + def : InstAlias(NAME) ZPR16:$Zdn, ZPR16:$Zm, ZPR16:$Zk), 1>; + def : InstAlias(NAME) ZPR32:$Zdn, ZPR32:$Zm, ZPR32:$Zk), 1>; +} + +class sve2_int_rotate_right_imm tsz8_64, string asm, + ZPRRegOp zprty, Operand immtype> +: I<(outs zprty:$Zdn), (ins zprty:$_Zdn, zprty:$Zm, immtype:$imm), + asm, "\t$Zdn, $_Zdn, $Zm, $imm", + "", + []>, Sched<[]> { + bits<5> Zdn; + bits<5> Zm; + bits<6> imm; + let Inst{31-24} = 0b00000100; + let Inst{23-22} = tsz8_64{3-2}; + let Inst{21} = 0b1; + let Inst{20-19} = tsz8_64{1-0}; + let Inst{18-16} = imm{2-0}; // imm3 + let Inst{15-10} = 0b001101; + let Inst{9-5} = Zm; + let Inst{4-0} = Zdn; + + let Constraints = "$Zdn = $_Zdn"; + let DestructiveInstType = DestructiveOther; + let ElementSize = ElementSizeNone; +} + +multiclass sve2_int_rotate_right_imm { + def _B : sve2_int_rotate_right_imm<{0,0,0,1}, asm, ZPR8, vecshiftR8>; + def _H : sve2_int_rotate_right_imm<{0,0,1,?}, asm, ZPR16, vecshiftR16> { + let Inst{19} = imm{3}; + } + def _S : sve2_int_rotate_right_imm<{0,1,?,?}, asm, ZPR32, vecshiftR32> { + let Inst{20-19} = imm{4-3}; + } + def _D : sve2_int_rotate_right_imm<{1,?,?,?}, asm, ZPR64, vecshiftR64> { + let Inst{22} = imm{5}; + let Inst{20-19} = imm{4-3}; + } + def : SVE_3_Op_Imm_Pat(NAME # _B)>; + def : SVE_3_Op_Imm_Pat(NAME # _H)>; + def : SVE_3_Op_Imm_Pat(NAME # _S)>; + def : SVE_3_Op_Imm_Pat(NAME # _D)>; +} //===----------------------------------------------------------------------===// // SVE Integer Wide Immediate - Predicated Group @@ -2006,7 +4128,7 @@ let Inst{4-0} = Zd; let Constraints = "$Zd = $_Zd"; - let DestructiveInstType = Destructive; + let DestructiveInstType = DestructiveOther; let ElementSize = zprty.ElementSize; } @@ -2041,7 +4163,7 @@ let Inst{12-5} = imm{7-0}; // imm8 let Inst{4-0} = Zd; - let DestructiveInstType = Destructive; + let DestructiveInstType = DestructiveOther; let ElementSize = zprty.ElementSize; } @@ -2113,25 +4235,51 @@ let Inst{3-0} = Pd; let Defs = [NZCV]; + let 
ElementSize = pprty.ElementSize; + let isPTestLike = 1; } -multiclass sve_int_cmp_0 opc, string asm> { +multiclass sve_int_cmp_0 opc, string asm, SDPatternOperator op, + CondCode cc> { def _B : sve_int_cmp<0b0, 0b00, opc, asm, PPR8, ZPR8, ZPR8>; def _H : sve_int_cmp<0b0, 0b01, opc, asm, PPR16, ZPR16, ZPR16>; def _S : sve_int_cmp<0b0, 0b10, opc, asm, PPR32, ZPR32, ZPR32>; def _D : sve_int_cmp<0b0, 0b11, opc, asm, PPR64, ZPR64, ZPR64>; + + def : SVE_3_Op_Pat(NAME # _B)>; + def : SVE_3_Op_Pat(NAME # _H)>; + def : SVE_3_Op_Pat(NAME # _S)>; + def : SVE_3_Op_Pat(NAME # _D)>; + + def : SVE_Cmp_Pat0(NAME # _B), PTRUE_B>; + def : SVE_Cmp_Pat0(NAME # _H), PTRUE_H>; + def : SVE_Cmp_Pat0(NAME # _S), PTRUE_S>; + def : SVE_Cmp_Pat0(NAME # _D), PTRUE_D>; + + def : SVE_Cmp_Pat1(NAME # _B)>; + def : SVE_Cmp_Pat1(NAME # _H)>; + def : SVE_Cmp_Pat1(NAME # _S)>; + def : SVE_Cmp_Pat1(NAME # _D)>; } -multiclass sve_int_cmp_0_wide opc, string asm> { +multiclass sve_int_cmp_0_wide opc, string asm, SDPatternOperator op> { def _B : sve_int_cmp<0b0, 0b00, opc, asm, PPR8, ZPR8, ZPR64>; def _H : sve_int_cmp<0b0, 0b01, opc, asm, PPR16, ZPR16, ZPR64>; def _S : sve_int_cmp<0b0, 0b10, opc, asm, PPR32, ZPR32, ZPR64>; + + def : SVE_3_Op_Pat(NAME # _B)>; + def : SVE_3_Op_Pat(NAME # _H)>; + def : SVE_3_Op_Pat(NAME # _S)>; } -multiclass sve_int_cmp_1_wide opc, string asm> { +multiclass sve_int_cmp_1_wide opc, string asm, SDPatternOperator op> { def _B : sve_int_cmp<0b1, 0b00, opc, asm, PPR8, ZPR8, ZPR64>; def _H : sve_int_cmp<0b1, 0b01, opc, asm, PPR16, ZPR16, ZPR64>; def _S : sve_int_cmp<0b1, 0b10, opc, asm, PPR32, ZPR32, ZPR64>; + + def : SVE_3_Op_Pat(NAME # _B)>; + def : SVE_3_Op_Pat(NAME # _H)>; + def : SVE_3_Op_Pat(NAME # _S)>; } @@ -2163,13 +4311,71 @@ let Inst{3-0} = Pd; let Defs = [NZCV]; + let ElementSize = pprty.ElementSize; + let isPTestLike = 1; } -multiclass sve_int_scmp_vi opc, string asm> { +multiclass sve_int_scmp_vi opc, string asm, CondCode cc, + SDPatternOperator op = null_frag, + SDPatternOperator inv_op = null_frag> { def _B : sve_int_scmp_vi<0b00, opc, asm, PPR8, ZPR8, simm5_32b>; def _H : sve_int_scmp_vi<0b01, opc, asm, PPR16, ZPR16, simm5_32b>; def _S : sve_int_scmp_vi<0b10, opc, asm, PPR32, ZPR32, simm5_32b>; def _D : sve_int_scmp_vi<0b11, opc, asm, PPR64, ZPR64, simm5_64b>; + + // IR version + def : Pat<(nxv16i1 (setcc (nxv16i8 ZPR:$Zs1), + (nxv16i8 (AArch64dup (simm5_32b:$imm))), + cc)), + (!cast(NAME # "_B") (PTRUE_B 31), ZPR:$Zs1, simm5_32b:$imm)>; + def : Pat<(nxv8i1 (setcc (nxv8i16 ZPR:$Zs1), + (nxv8i16 (AArch64dup (simm5_32b:$imm))), + cc)), + (!cast(NAME # "_H") (PTRUE_H 31), ZPR:$Zs1, simm5_32b:$imm)>; + def : Pat<(nxv4i1 (setcc (nxv4i32 ZPR:$Zs1), + (nxv4i32 (AArch64dup (simm5_32b:$imm))), + cc)), + (!cast(NAME # "_S") (PTRUE_S 31), ZPR:$Zs1, simm5_32b:$imm)>; + def : Pat<(nxv2i1 (setcc (nxv2i64 ZPR:$Zs1), + (nxv2i64 (AArch64dup (simm5_64b:$imm))), + cc)), + (!cast(NAME # "_D") (PTRUE_D 31), ZPR:$Zs1, simm5_64b:$imm)>; + + // Intrinsic version + def : Pat<(nxv16i1 (op (nxv16i1 PPR_3b:$Pg), + (nxv16i8 ZPR:$Zs1), + (nxv16i8 (AArch64dup (simm5_32b:$imm))))), + (!cast(NAME # "_B") PPR_3b:$Pg, ZPR:$Zs1, simm5_32b:$imm)>; + def : Pat<(nxv8i1 (op (nxv8i1 PPR_3b:$Pg), + (nxv8i16 ZPR:$Zs1), + (nxv8i16 (AArch64dup (simm5_32b:$imm))))), + (!cast(NAME # "_H") PPR_3b:$Pg, ZPR:$Zs1, simm5_32b:$imm)>; + def : Pat<(nxv4i1 (op (nxv4i1 PPR_3b:$Pg), + (nxv4i32 ZPR:$Zs1), + (nxv4i32 (AArch64dup (simm5_32b:$imm))))), + (!cast(NAME # "_S") PPR_3b:$Pg, ZPR:$Zs1, simm5_32b:$imm)>; + def : Pat<(nxv2i1 (op (nxv2i1 
PPR_3b:$Pg), + (nxv2i64 ZPR:$Zs1), + (nxv2i64 (AArch64dup (simm5_64b:$imm))))), + (!cast(NAME # "_D") PPR_3b:$Pg, ZPR:$Zs1, simm5_64b:$imm)>; + + // Inverted intrinsic version + def : Pat<(nxv16i1 (inv_op (nxv16i1 PPR_3b:$Pg), + (nxv16i8 (AArch64dup (simm5_32b:$imm))), + (nxv16i8 ZPR:$Zs1))), + (!cast(NAME # "_B") PPR_3b:$Pg, ZPR:$Zs1, simm5_32b:$imm)>; + def : Pat<(nxv8i1 (inv_op (nxv8i1 PPR_3b:$Pg), + (nxv8i16 (AArch64dup (simm5_32b:$imm))), + (nxv8i16 ZPR:$Zs1))), + (!cast(NAME # "_H") PPR_3b:$Pg, ZPR:$Zs1, simm5_32b:$imm)>; + def : Pat<(nxv4i1 (inv_op (nxv4i1 PPR_3b:$Pg), + (nxv4i32 (AArch64dup (simm5_32b:$imm))), + (nxv4i32 ZPR:$Zs1))), + (!cast(NAME # "_S") PPR_3b:$Pg, ZPR:$Zs1, simm5_32b:$imm)>; + def : Pat<(nxv2i1 (inv_op (nxv2i1 PPR_3b:$Pg), + (nxv2i64 (AArch64dup (simm5_64b:$imm))), + (nxv2i64 ZPR:$Zs1))), + (!cast(NAME # "_D") PPR_3b:$Pg, ZPR:$Zs1, simm5_64b:$imm)>; } @@ -2198,13 +4404,71 @@ let Inst{3-0} = Pd; let Defs = [NZCV]; + let ElementSize = pprty.ElementSize; + let isPTestLike = 1; } -multiclass sve_int_ucmp_vi opc, string asm> { +multiclass sve_int_ucmp_vi opc, string asm, CondCode cc, + SDPatternOperator op = null_frag, + SDPatternOperator inv_op = null_frag> { def _B : sve_int_ucmp_vi<0b00, opc, asm, PPR8, ZPR8, imm0_127>; def _H : sve_int_ucmp_vi<0b01, opc, asm, PPR16, ZPR16, imm0_127>; def _S : sve_int_ucmp_vi<0b10, opc, asm, PPR32, ZPR32, imm0_127>; - def _D : sve_int_ucmp_vi<0b11, opc, asm, PPR64, ZPR64, imm0_127>; + def _D : sve_int_ucmp_vi<0b11, opc, asm, PPR64, ZPR64, imm0_127_64b>; + + // IR version + def : Pat<(nxv16i1 (setcc (nxv16i8 ZPR:$Zs1), + (nxv16i8 (AArch64dup (imm0_127:$imm))), + cc)), + (!cast(NAME # "_B") (PTRUE_B 31), ZPR:$Zs1, imm0_127:$imm)>; + def : Pat<(nxv8i1 (setcc (nxv8i16 ZPR:$Zs1), + (nxv8i16 (AArch64dup (imm0_127:$imm))), + cc)), + (!cast(NAME # "_H") (PTRUE_H 31), ZPR:$Zs1, imm0_127:$imm)>; + def : Pat<(nxv4i1 (setcc (nxv4i32 ZPR:$Zs1), + (nxv4i32 (AArch64dup (imm0_127:$imm))), + cc)), + (!cast(NAME # "_S") (PTRUE_S 31), ZPR:$Zs1, imm0_127:$imm)>; + def : Pat<(nxv2i1 (setcc (nxv2i64 ZPR:$Zs1), + (nxv2i64 (AArch64dup (imm0_127_64b:$imm))), + cc)), + (!cast(NAME # "_D") (PTRUE_D 31), ZPR:$Zs1, imm0_127_64b:$imm)>; + + // Intrinsic version + def : Pat<(nxv16i1 (op (nxv16i1 PPR_3b:$Pg), + (nxv16i8 ZPR:$Zs1), + (nxv16i8 (AArch64dup (imm0_127:$imm))))), + (!cast(NAME # "_B") PPR_3b:$Pg, ZPR:$Zs1, imm0_127:$imm)>; + def : Pat<(nxv8i1 (op (nxv8i1 PPR_3b:$Pg), + (nxv8i16 ZPR:$Zs1), + (nxv8i16 (AArch64dup (imm0_127:$imm))))), + (!cast(NAME # "_H") PPR_3b:$Pg, ZPR:$Zs1, imm0_127:$imm)>; + def : Pat<(nxv4i1 (op (nxv4i1 PPR_3b:$Pg), + (nxv4i32 ZPR:$Zs1), + (nxv4i32 (AArch64dup (imm0_127:$imm))))), + (!cast(NAME # "_S") PPR_3b:$Pg, ZPR:$Zs1, imm0_127:$imm)>; + def : Pat<(nxv2i1 (op (nxv2i1 PPR_3b:$Pg), + (nxv2i64 ZPR:$Zs1), + (nxv2i64 (AArch64dup (imm0_127_64b:$imm))))), + (!cast(NAME # "_D") PPR_3b:$Pg, ZPR:$Zs1, imm0_127_64b:$imm)>; + + // Inverted intrinsic version + def : Pat<(nxv16i1 (inv_op (nxv16i1 PPR_3b:$Pg), + (nxv16i8 (AArch64dup (imm0_127:$imm))), + (nxv16i8 ZPR:$Zs1))), + (!cast(NAME # "_B") PPR_3b:$Pg, ZPR:$Zs1, imm0_127:$imm)>; + def : Pat<(nxv8i1 (inv_op (nxv8i1 PPR_3b:$Pg), + (nxv8i16 (AArch64dup (imm0_127:$imm))), + (nxv8i16 ZPR:$Zs1))), + (!cast(NAME # "_H") PPR_3b:$Pg, ZPR:$Zs1, imm0_127:$imm)>; + def : Pat<(nxv4i1 (inv_op (nxv4i1 PPR_3b:$Pg), + (nxv4i32 (AArch64dup (imm0_127:$imm))), + (nxv4i32 ZPR:$Zs1))), + (!cast(NAME # "_S") PPR_3b:$Pg, ZPR:$Zs1, imm0_127:$imm)>; + def : Pat<(nxv2i1 (inv_op (nxv2i1 PPR_3b:$Pg), + (nxv2i64 
(AArch64dup (imm0_127_64b:$imm))), + (nxv2i64 ZPR:$Zs1))), + (!cast(NAME # "_D") PPR_3b:$Pg, ZPR:$Zs1, imm0_127_64b:$imm)>; } @@ -2232,10 +4496,12 @@ } class sve_int_while_rr sz8_64, bits<4> opc, string asm, - RegisterClass gprty, PPRRegOp pprty> + RegisterClass gprty, PPRRegOp pprty, + ValueType vt, SDPatternOperator op> : I<(outs pprty:$Pd), (ins gprty:$Rn, gprty:$Rm), asm, "\t$Pd, $Rn, $Rm", - "", []>, Sched<[]> { + "", + [(set (vt pprty:$Pd), (op gprty:$Rn, gprty:$Rm))]>, Sched<[]> { bits<4> Pd; bits<5> Rm; bits<5> Rn; @@ -2250,22 +4516,57 @@ let Inst{3-0} = Pd; let Defs = [NZCV]; + let ElementSize = pprty.ElementSize; + let isWhile = 1; } -multiclass sve_int_while4_rr opc, string asm> { - def _B : sve_int_while_rr<0b00, { 0, opc }, asm, GPR32, PPR8>; - def _H : sve_int_while_rr<0b01, { 0, opc }, asm, GPR32, PPR16>; - def _S : sve_int_while_rr<0b10, { 0, opc }, asm, GPR32, PPR32>; - def _D : sve_int_while_rr<0b11, { 0, opc }, asm, GPR32, PPR64>; +multiclass sve_int_while4_rr opc, string asm, SDPatternOperator op> { + def _B : sve_int_while_rr<0b00, { 0, opc }, asm, GPR32, PPR8, nxv16i1, op>; + def _H : sve_int_while_rr<0b01, { 0, opc }, asm, GPR32, PPR16, nxv8i1, op>; + def _S : sve_int_while_rr<0b10, { 0, opc }, asm, GPR32, PPR32, nxv4i1, op>; + def _D : sve_int_while_rr<0b11, { 0, opc }, asm, GPR32, PPR64, nxv2i1, op>; } -multiclass sve_int_while8_rr opc, string asm> { - def _B : sve_int_while_rr<0b00, { 1, opc }, asm, GPR64, PPR8>; - def _H : sve_int_while_rr<0b01, { 1, opc }, asm, GPR64, PPR16>; - def _S : sve_int_while_rr<0b10, { 1, opc }, asm, GPR64, PPR32>; - def _D : sve_int_while_rr<0b11, { 1, opc }, asm, GPR64, PPR64>; +multiclass sve_int_while8_rr opc, string asm, SDPatternOperator op> { + def _B : sve_int_while_rr<0b00, { 1, opc }, asm, GPR64, PPR8, nxv16i1, op>; + def _H : sve_int_while_rr<0b01, { 1, opc }, asm, GPR64, PPR16, nxv8i1, op>; + def _S : sve_int_while_rr<0b10, { 1, opc }, asm, GPR64, PPR32, nxv4i1, op>; + def _D : sve_int_while_rr<0b11, { 1, opc }, asm, GPR64, PPR64, nxv2i1, op>; } +class sve2_int_while_rr sz8_64, bits<1> rw, string asm, + PPRRegOp pprty> +: I<(outs pprty:$Pd), (ins GPR64:$Rn, GPR64:$Rm), + asm, "\t$Pd, $Rn, $Rm", + "", []>, Sched<[]> { + bits<4> Pd; + bits<5> Rm; + bits<5> Rn; + let Inst{31-24} = 0b00100101; + let Inst{23-22} = sz8_64; + let Inst{21} = 0b1; + let Inst{20-16} = Rm; + let Inst{15-10} = 0b001100; + let Inst{9-5} = Rn; + let Inst{4} = rw; + let Inst{3-0} = Pd; + + let Defs = [NZCV]; + let ElementSize = pprty.ElementSize; + let isWhile = 1; +} + +multiclass sve2_int_while_rr rw, string asm, string op> { + def _B : sve2_int_while_rr<0b00, rw, asm, PPR8>; + def _H : sve2_int_while_rr<0b01, rw, asm, PPR16>; + def _S : sve2_int_while_rr<0b10, rw, asm, PPR32>; + def _D : sve2_int_while_rr<0b11, rw, asm, PPR64>; + + def : SVE_2_Op_Pat(op # _b), i64, i64, !cast(NAME # _B)>; + def : SVE_2_Op_Pat(op # _h), i64, i64, !cast(NAME # _H)>; + def : SVE_2_Op_Pat(op # _s), i64, i64, !cast(NAME # _S)>; + def : SVE_2_Op_Pat(op # _d), i64, i64, !cast(NAME # _D)>; +} //===----------------------------------------------------------------------===// // SVE Floating Point Fast Reduction Group @@ -2290,10 +4591,15 @@ let Inst{4-0} = Vd; } -multiclass sve_fp_fast_red opc, string asm> { +multiclass sve_fp_fast_red opc, string asm, SDPatternOperator op> { def _H : sve_fp_fast_red<0b01, opc, asm, ZPR16, FPR16>; def _S : sve_fp_fast_red<0b10, opc, asm, ZPR32, FPR32>; def _D : sve_fp_fast_red<0b11, opc, asm, ZPR64, FPR64>; + + def : SVE_2_Op_Pat(NAME # _H)>; + 
def : SVE_2_Op_Pat(NAME # _S)>; + def : SVE_2_Op_Pat(NAME # _S)>; + def : SVE_2_Op_Pat(NAME # _D)>; } @@ -2323,10 +4629,15 @@ let Constraints = "$Vdn = $_Vdn"; } -multiclass sve_fp_2op_p_vd opc, string asm> { +multiclass sve_fp_2op_p_vd opc, string asm, SDPatternOperator op> { def _H : sve_fp_2op_p_vd<0b01, opc, asm, ZPR16, FPR16>; def _S : sve_fp_2op_p_vd<0b10, opc, asm, ZPR32, FPR32>; def _D : sve_fp_2op_p_vd<0b11, opc, asm, ZPR64, FPR64>; + + def : SVE_3_Op_Pat(NAME # _H)>; + def : SVE_3_Op_Pat(NAME # _S)>; + def : SVE_3_Op_Pat(NAME # _S)>; + def : SVE_3_Op_Pat(NAME # _D)>; } //===----------------------------------------------------------------------===// @@ -2356,10 +4667,22 @@ let Inst{3-0} = Pd; } -multiclass sve_fp_3op_p_pd opc, string asm> { +multiclass sve_fp_3op_p_pd opc, string asm, SDPatternOperator int_op, + SDPatternOperator ir_op = null_frag> { def _H : sve_fp_3op_p_pd<0b01, opc, asm, PPR16, ZPR16>; def _S : sve_fp_3op_p_pd<0b10, opc, asm, PPR32, ZPR32>; def _D : sve_fp_3op_p_pd<0b11, opc, asm, PPR64, ZPR64>; + + def : SVE_3_Op_Pat(NAME # _H)>; + def : SVE_3_Op_Pat(NAME # _S)>; + def : SVE_3_Op_Pat(NAME # _D)>; + + def : SVE_2_Op_AllActive_Pat(NAME # _H), PTRUE_H>; + def : SVE_2_Op_AllActive_Pat(NAME # _H), PTRUE_S>; + def : SVE_2_Op_AllActive_Pat(NAME # _H), PTRUE_D>; + def : SVE_2_Op_AllActive_Pat(NAME # _S), PTRUE_S>; + def : SVE_2_Op_AllActive_Pat(NAME # _S), PTRUE_D>; + def : SVE_2_Op_AllActive_Pat(NAME # _D), PTRUE_D>; } @@ -2387,10 +4710,64 @@ let Inst{3-0} = Pd; } -multiclass sve_fp_2op_p_pd opc, string asm> { +multiclass sve_fp_2op_p_pd opc, string asm, + SDPatternOperator int_op = null_frag, + SDPatternOperator ir_op = null_frag, + SDPatternOperator inv_int_op = null_frag, + SDPatternOperator inv_ir_op = null_frag> { def _H : sve_fp_2op_p_pd<0b01, opc, asm, PPR16, ZPR16>; def _S : sve_fp_2op_p_pd<0b10, opc, asm, PPR32, ZPR32>; def _D : sve_fp_2op_p_pd<0b11, opc, asm, PPR64, ZPR64>; + + // Intrinsics + def : Pat<(nxv8i1 (int_op (nxv8i1 PPR_3b:$Pg), + (nxv8f16 ZPR:$Zs1), + (nxv8f16 (AArch64dup (f16 fpimm0))))), + (!cast(NAME # "_H") PPR_3b:$Pg, ZPR:$Zs1)>; + def : Pat<(nxv4i1 (int_op (nxv4i1 PPR_3b:$Pg), + (nxv4f32 ZPR:$Zs1), + (nxv4f32 (AArch64dup (f32 fpimm0))))), + (!cast(NAME # "_S") PPR_3b:$Pg, ZPR:$Zs1)>; + def : Pat<(nxv2i1 (int_op (nxv2i1 PPR_3b:$Pg), + (nxv2f64 ZPR:$Zs1), + (nxv2f64 (AArch64dup (f64 fpimm0))))), + (!cast(NAME # "_D") PPR_3b:$Pg, ZPR:$Zs1)>; + + // IR + def : Pat<(nxv8i1 (ir_op (nxv8f16 ZPR:$Zs1), + (nxv8f16 (AArch64dup (f16 fpimm0))))), + (!cast(NAME # "_H") (PTRUE_H 31), ZPR:$Zs1)>; + def : Pat<(nxv4i1 (ir_op (nxv4f32 ZPR:$Zs1), + (nxv4f32 (AArch64dup (f32 fpimm0))))), + (!cast(NAME # "_S") (PTRUE_S 31), ZPR:$Zs1)>; + def : Pat<(nxv2i1 (ir_op (nxv2f64 ZPR:$Zs1), + (nxv2f64 (AArch64dup (f64 fpimm0))))), + (!cast(NAME # "_D") (PTRUE_D 31), ZPR:$Zs1)>; + + // Inverted Intrinsics (e.g. 
LT -> GE) + def : Pat<(nxv8i1 (inv_int_op (nxv8i1 PPR_3b:$Pg), + (nxv8f16 (AArch64dup (f16 fpimm0))), + (nxv8f16 ZPR:$Zs1))), + (!cast(NAME # "_H") PPR_3b:$Pg, ZPR:$Zs1)>; + def : Pat<(nxv4i1 (inv_int_op (nxv4i1 PPR_3b:$Pg), + (nxv4f32 (AArch64dup (f32 fpimm0))), + (nxv4f32 ZPR:$Zs1))), + (!cast(NAME # "_S") PPR_3b:$Pg, ZPR:$Zs1)>; + def : Pat<(nxv2i1 (inv_int_op (nxv2i1 PPR_3b:$Pg), + (nxv2f64 (AArch64dup (f64 fpimm0))), + (nxv2f64 ZPR:$Zs1))), + (!cast(NAME # "_D") PPR_3b:$Pg, ZPR:$Zs1)>; + + // Inverted IR + def : Pat<(nxv8i1 (inv_ir_op (nxv8f16 (AArch64dup (f16 fpimm0))), + (nxv8f16 ZPR:$Zs1))), + (!cast(NAME # "_H") (PTRUE_H 31), ZPR:$Zs1)>; + def : Pat<(nxv4i1 (inv_ir_op (nxv4f32 (AArch64dup (f32 fpimm0))), + (nxv4f32 ZPR:$Zs1))), + (!cast(NAME # "_S") (PTRUE_S 31), ZPR:$Zs1)>; + def : Pat<(nxv2i1 (inv_ir_op (nxv2f64 (AArch64dup (f64 fpimm0))), + (nxv2f64 ZPR:$Zs1))), + (!cast(NAME # "_D") (PTRUE_D 31), ZPR:$Zs1)>; } @@ -2399,10 +4776,11 @@ //===----------------------------------------------------------------------===// class sve_int_index_ii sz8_64, string asm, ZPRRegOp zprty, - Operand imm_ty> + Operand imm_ty, ValueType vt, SDPatternOperator op> : I<(outs zprty:$Zd), (ins imm_ty:$imm5, imm_ty:$imm5b), asm, "\t$Zd, $imm5, $imm5b", - "", []>, Sched<[]> { + "", + [(set (vt zprty:$Zd), (op imm_ty:$imm5, imm_ty:$imm5b))]>, Sched<[]> { bits<5> Zd; bits<5> imm5; bits<5> imm5b; @@ -2413,20 +4791,24 @@ let Inst{15-10} = 0b010000; let Inst{9-5} = imm5; let Inst{4-0} = Zd; + + let isReMaterializable = 1; } -multiclass sve_int_index_ii { - def _B : sve_int_index_ii<0b00, asm, ZPR8, simm5_32b>; - def _H : sve_int_index_ii<0b01, asm, ZPR16, simm5_32b>; - def _S : sve_int_index_ii<0b10, asm, ZPR32, simm5_32b>; - def _D : sve_int_index_ii<0b11, asm, ZPR64, simm5_64b>; +multiclass sve_int_index_ii { + def _B : sve_int_index_ii<0b00, asm, ZPR8, simm5_32b, nxv16i8, op>; + def _H : sve_int_index_ii<0b01, asm, ZPR16, simm5_32b, nxv8i16, op>; + def _S : sve_int_index_ii<0b10, asm, ZPR32, simm5_32b, nxv4i32, op>; + def _D : sve_int_index_ii<0b11, asm, ZPR64, simm5_64b, nxv2i64, op>; } class sve_int_index_ir sz8_64, string asm, ZPRRegOp zprty, - RegisterClass srcRegType, Operand imm_ty> + RegisterClass srcRegType, Operand imm_ty, ValueType vt, + SDPatternOperator op> : I<(outs zprty:$Zd), (ins imm_ty:$imm5, srcRegType:$Rm), asm, "\t$Zd, $imm5, $Rm", - "", []>, Sched<[]> { + "", + [(set (vt zprty:$Zd), (op imm_ty:$imm5, srcRegType:$Rm))]>, Sched<[]> { bits<5> Rm; bits<5> Zd; bits<5> imm5; @@ -2439,18 +4821,20 @@ let Inst{4-0} = Zd; } -multiclass sve_int_index_ir { - def _B : sve_int_index_ir<0b00, asm, ZPR8, GPR32, simm5_32b>; - def _H : sve_int_index_ir<0b01, asm, ZPR16, GPR32, simm5_32b>; - def _S : sve_int_index_ir<0b10, asm, ZPR32, GPR32, simm5_32b>; - def _D : sve_int_index_ir<0b11, asm, ZPR64, GPR64, simm5_64b>; +multiclass sve_int_index_ir { + def _B : sve_int_index_ir<0b00, asm, ZPR8, GPR32, simm5_32b, nxv16i8, op>; + def _H : sve_int_index_ir<0b01, asm, ZPR16, GPR32, simm5_32b, nxv8i16, op>; + def _S : sve_int_index_ir<0b10, asm, ZPR32, GPR32, simm5_32b, nxv4i32, op>; + def _D : sve_int_index_ir<0b11, asm, ZPR64, GPR64, simm5_64b, nxv2i64, op>; } class sve_int_index_ri sz8_64, string asm, ZPRRegOp zprty, - RegisterClass srcRegType, Operand imm_ty> + RegisterClass srcRegType, Operand imm_ty, ValueType vt, + SDPatternOperator op> : I<(outs zprty:$Zd), (ins srcRegType:$Rn, imm_ty:$imm5), asm, "\t$Zd, $Rn, $imm5", - "", []>, Sched<[]> { + "", + [(set (vt zprty:$Zd), (op srcRegType:$Rn, 
imm_ty:$imm5))]>, Sched<[]> { bits<5> Rn; bits<5> Zd; bits<5> imm5; @@ -2463,18 +4847,20 @@ let Inst{4-0} = Zd; } -multiclass sve_int_index_ri { - def _B : sve_int_index_ri<0b00, asm, ZPR8, GPR32, simm5_32b>; - def _H : sve_int_index_ri<0b01, asm, ZPR16, GPR32, simm5_32b>; - def _S : sve_int_index_ri<0b10, asm, ZPR32, GPR32, simm5_32b>; - def _D : sve_int_index_ri<0b11, asm, ZPR64, GPR64, simm5_64b>; +multiclass sve_int_index_ri { + def _B : sve_int_index_ri<0b00, asm, ZPR8, GPR32, simm5_32b, nxv16i8, op>; + def _H : sve_int_index_ri<0b01, asm, ZPR16, GPR32, simm5_32b, nxv8i16, op>; + def _S : sve_int_index_ri<0b10, asm, ZPR32, GPR32, simm5_32b, nxv4i32, op>; + def _D : sve_int_index_ri<0b11, asm, ZPR64, GPR64, simm5_64b, nxv2i64, op>; } class sve_int_index_rr sz8_64, string asm, ZPRRegOp zprty, - RegisterClass srcRegType> + RegisterClass srcRegType, ValueType vt, + SDPatternOperator op> : I<(outs zprty:$Zd), (ins srcRegType:$Rn, srcRegType:$Rm), asm, "\t$Zd, $Rn, $Rm", - "", []>, Sched<[]> { + "", + [(set (vt zprty:$Zd), (op srcRegType:$Rn, srcRegType:$Rm))]>, Sched<[]> { bits<5> Zd; bits<5> Rm; bits<5> Rn; @@ -2487,19 +4873,21 @@ let Inst{4-0} = Zd; } -multiclass sve_int_index_rr { - def _B : sve_int_index_rr<0b00, asm, ZPR8, GPR32>; - def _H : sve_int_index_rr<0b01, asm, ZPR16, GPR32>; - def _S : sve_int_index_rr<0b10, asm, ZPR32, GPR32>; - def _D : sve_int_index_rr<0b11, asm, ZPR64, GPR64>; +multiclass sve_int_index_rr { + def _B : sve_int_index_rr<0b00, asm, ZPR8, GPR32, nxv16i8, op>; + def _H : sve_int_index_rr<0b01, asm, ZPR16, GPR32, nxv8i16, op>; + def _S : sve_int_index_rr<0b10, asm, ZPR32, GPR32, nxv4i32, op>; + def _D : sve_int_index_rr<0b11, asm, ZPR64, GPR64, nxv2i64, op>; } + + // //===----------------------------------------------------------------------===// // SVE Bitwise Shift - Predicated Group //===----------------------------------------------------------------------===// -class sve_int_bin_pred_shift_imm tsz8_64, bits<3> opc, string asm, - ZPRRegOp zprty, Operand immtype, - ElementSizeEnum size> + +class sve_int_bin_pred_shift_imm tsz8_64, bits<4> opc, string asm, + ZPRRegOp zprty, Operand immtype> : I<(outs zprty:$Zdn), (ins PPR3bAny:$Pg, zprty:$_Zdn, immtype:$imm), asm, "\t$Zdn, $Pg/m, $_Zdn, $imm", "", @@ -2509,8 +4897,8 @@ bits<6> imm; let Inst{31-24} = 0b00000100; let Inst{23-22} = tsz8_64{3-2}; - let Inst{21-19} = 0b000; - let Inst{18-16} = opc; + let Inst{21-20} = 0b00; + let Inst{19-16} = opc; let Inst{15-13} = 0b100; let Inst{12-10} = Pg; let Inst{9-8} = tsz8_64{1-0}; @@ -2518,44 +4906,107 @@ let Inst{4-0} = Zdn; let Constraints = "$Zdn = $_Zdn"; - let DestructiveInstType = Destructive; - let ElementSize = size; + let DestructiveInstType = DestructiveOther; + let ElementSize = zprty.ElementSize; } -multiclass sve_int_bin_pred_shift_imm_left opc, string asm> { - def _B : sve_int_bin_pred_shift_imm<{0,0,0,1}, opc, asm, ZPR8, vecshiftL8, - ElementSizeB>; - def _H : sve_int_bin_pred_shift_imm<{0,0,1,?}, opc, asm, ZPR16, vecshiftL16, - ElementSizeH> { +multiclass sve_int_bin_pred_shift_imm_left opc, string asm> { + def _B : sve_int_bin_pred_shift_imm<{0,0,0,1}, opc, asm, ZPR8, vecshiftL8>; + def _H : sve_int_bin_pred_shift_imm<{0,0,1,?}, opc, asm, ZPR16, vecshiftL16> { let Inst{8} = imm{3}; } - def _S : sve_int_bin_pred_shift_imm<{0,1,?,?}, opc, asm, ZPR32, vecshiftL32, - ElementSizeS> { + def _S : sve_int_bin_pred_shift_imm<{0,1,?,?}, opc, asm, ZPR32, vecshiftL32> { let Inst{9-8} = imm{4-3}; } - def _D : sve_int_bin_pred_shift_imm<{1,?,?,?}, opc, asm, ZPR64, 
vecshiftL64, - ElementSizeD> { + def _D : sve_int_bin_pred_shift_imm<{1,?,?,?}, opc, asm, ZPR64, vecshiftL64> { let Inst{22} = imm{5}; let Inst{9-8} = imm{4-3}; } } -multiclass sve_int_bin_pred_shift_imm_right opc, string asm> { - def _B : sve_int_bin_pred_shift_imm<{0,0,0,1}, opc, asm, ZPR8, vecshiftR8, - ElementSizeB>; - def _H : sve_int_bin_pred_shift_imm<{0,0,1,?}, opc, asm, ZPR16, vecshiftR16, - ElementSizeH> { +multiclass sve2_int_bin_pred_shift_imm_left opc, string asm, + string psName, + SDPatternOperator op> { + + let DestructiveInstType = DestructiveBinaryImm in { + def _B : SVEPseudo2Instr, sve_int_bin_pred_shift_imm<{0,0,0,1}, opc, asm, ZPR8, vecshiftL8>; + def _H : SVEPseudo2Instr, + sve_int_bin_pred_shift_imm<{0,0,1,?}, opc, asm, ZPR16, vecshiftL16> { + let Inst{8} = imm{3}; + } + def _S : SVEPseudo2Instr, + sve_int_bin_pred_shift_imm<{0,1,?,?}, opc, asm, ZPR32, vecshiftL32> { + let Inst{9-8} = imm{4-3}; + } + def _D : SVEPseudo2Instr, + sve_int_bin_pred_shift_imm<{1,?,?,?}, opc, asm, ZPR64, vecshiftL64> { + let Inst{22} = imm{5}; + let Inst{9-8} = imm{4-3}; + } + } + + def _B_Z_UNDEF : PredTwoOpImmPseudo; + def _H_Z_UNDEF : PredTwoOpImmPseudo; + def _S_Z_UNDEF : PredTwoOpImmPseudo; + def _D_Z_UNDEF : PredTwoOpImmPseudo; + + def _B_Z_ZERO : PredTwoOpImmPseudo; + def _H_Z_ZERO : PredTwoOpImmPseudo; + def _S_Z_ZERO : PredTwoOpImmPseudo; + def _D_Z_ZERO : PredTwoOpImmPseudo; + + def : SVE_3_Op_Pat_Shift_Imm_SelZero(NAME # _B_Z_ZERO)>; + def : SVE_3_Op_Pat_Shift_Imm_SelZero(NAME # _H_Z_ZERO)>; + def : SVE_3_Op_Pat_Shift_Imm_SelZero(NAME # _S_Z_ZERO)>; + def : SVE_3_Op_Pat_Shift_Imm_SelZero(NAME # _D_Z_ZERO)>; + + def : SVE_3_Op_Imm_Pat(NAME # _B)>; + def : SVE_3_Op_Imm_Pat(NAME # _H)>; + def : SVE_3_Op_Imm_Pat(NAME # _S)>; + def : SVE_3_Op_Imm_Pat(NAME # _D)>; + +} + +multiclass sve_int_bin_pred_shift_imm_right opc, string asm, string Ps, + SDPatternOperator op = null_frag> { + let DestructiveInstType = DestructiveBinaryImm in { + def _B : SVEPseudo2Instr, + sve_int_bin_pred_shift_imm<{0,0,0,1}, opc, asm, ZPR8, vecshiftR8>; + def _H : SVEPseudo2Instr, + sve_int_bin_pred_shift_imm<{0,0,1,?}, opc, asm, ZPR16, vecshiftR16> { let Inst{8} = imm{3}; } - def _S : sve_int_bin_pred_shift_imm<{0,1,?,?}, opc, asm, ZPR32, vecshiftR32, - ElementSizeS> { + def _S : SVEPseudo2Instr, + sve_int_bin_pred_shift_imm<{0,1,?,?}, opc, asm, ZPR32, vecshiftR32> { let Inst{9-8} = imm{4-3}; } - def _D : sve_int_bin_pred_shift_imm<{1,?,?,?}, opc, asm, ZPR64, vecshiftR64, - ElementSizeD> { + def _D : SVEPseudo2Instr, + sve_int_bin_pred_shift_imm<{1,?,?,?}, opc, asm, ZPR64, vecshiftR64> { let Inst{22} = imm{5}; let Inst{9-8} = imm{4-3}; } + } + + def : Pat<(nxv16i8 (op (nxv16i1 PPR3bAny:$Pg), (nxv16i8 ZPR8:$Zn), (i32 vecshiftR8:$imm))), + (!cast(NAME # _B) PPR3bAny:$Pg, ZPR8:$Zn, vecshiftR8:$imm)>; + def : Pat<(nxv8i16 (op (nxv8i1 PPR3bAny:$Pg), (nxv8i16 ZPR16:$Zn), (i32 vecshiftR16:$imm))), + (!cast(NAME # _H) PPR3bAny:$Pg, ZPR16:$Zn, vecshiftR16:$imm)>; + def : Pat<(nxv4i32 (op (nxv4i1 PPR3bAny:$Pg), (nxv4i32 ZPR32:$Zn), (i32 vecshiftR32:$imm))), + (!cast(NAME # _S) PPR3bAny:$Pg, ZPR32:$Zn, vecshiftR32:$imm)>; + def : Pat<(nxv2i64 (op (nxv2i1 PPR3bAny:$Pg), (nxv2i64 ZPR64:$Zn), (i32 vecshiftR64:$imm))), + (!cast(NAME # _D) PPR3bAny:$Pg, ZPR64:$Zn, vecshiftR64:$imm)>; +} + +multiclass sve_int_bin_pred_shift_0_right_zx { + def _ZERO_B : PredTwoOpImmPseudo; + def _ZERO_H : PredTwoOpImmPseudo; + def _ZERO_S : PredTwoOpImmPseudo; + def _ZERO_D : PredTwoOpImmPseudo; + + def : SVE_3_Op_Pat_Shift_Imm_SelZero(NAME # 
_ZERO_B)>; + def : SVE_3_Op_Pat_Shift_Imm_SelZero(NAME # _ZERO_H)>; + def : SVE_3_Op_Pat_Shift_Imm_SelZero(NAME # _ZERO_S)>; + def : SVE_3_Op_Pat_Shift_Imm_SelZero(NAME # _ZERO_D)>; } class sve_int_bin_pred_shift sz8_64, bit wide, bits<3> opc, @@ -2578,23 +5029,72 @@ let Inst{4-0} = Zdn; let Constraints = "$Zdn = $_Zdn"; - let DestructiveInstType = Destructive; let ElementSize = zprty.ElementSize; } -multiclass sve_int_bin_pred_shift opc, string asm> { - def _B : sve_int_bin_pred_shift<0b00, 0b0, opc, asm, ZPR8, ZPR8>; - def _H : sve_int_bin_pred_shift<0b01, 0b0, opc, asm, ZPR16, ZPR16>; - def _S : sve_int_bin_pred_shift<0b10, 0b0, opc, asm, ZPR32, ZPR32>; - def _D : sve_int_bin_pred_shift<0b11, 0b0, opc, asm, ZPR64, ZPR64>; +multiclass sve_int_bin_pred_shift opc, string asm, string Ps, + SDPatternOperator op, string revname, bit isOrig> { + let DestructiveInstType = DestructiveBinaryCommWithRev in { + def _B : sve_int_bin_pred_shift<0b00, 0b0, opc, asm, ZPR8, ZPR8>, + SVEPseudo2Instr, SVEInstr2Rev; + def _H : sve_int_bin_pred_shift<0b01, 0b0, opc, asm, ZPR16, ZPR16>, + SVEPseudo2Instr, SVEInstr2Rev; + def _S : sve_int_bin_pred_shift<0b10, 0b0, opc, asm, ZPR32, ZPR32>, + SVEPseudo2Instr, SVEInstr2Rev; + def _D : sve_int_bin_pred_shift<0b11, 0b0, opc, asm, ZPR64, ZPR64>, + SVEPseudo2Instr, SVEInstr2Rev; + } + + def : SVE_3_Op_Pat(NAME # _B)>; + def : SVE_3_Op_Pat(NAME # _H)>; + def : SVE_3_Op_Pat(NAME # _S)>; + def : SVE_3_Op_Pat(NAME # _D)>; } -multiclass sve_int_bin_pred_shift_wide opc, string asm> { +multiclass sve_int_bin_pred_noncomm_zx { + def _ZERO_B : PredTwoOpConstrainedPseudo; + def _ZERO_H : PredTwoOpConstrainedPseudo; + def _ZERO_S : PredTwoOpConstrainedPseudo; + def _ZERO_D : PredTwoOpConstrainedPseudo; + + def : SVE_3_Op_Pat_SelZero(NAME # _ZERO_B)>; + def : SVE_3_Op_Pat_SelZero(NAME # _ZERO_H)>; + def : SVE_3_Op_Pat_SelZero(NAME # _ZERO_S)>; + def : SVE_3_Op_Pat_SelZero(NAME # _ZERO_D)>; +} + +multiclass sve_int_bin_pred_zx { + def _UNDEF_B : PredTwoOpPseudo; + def _UNDEF_H : PredTwoOpPseudo; + def _UNDEF_S : PredTwoOpPseudo; + def _UNDEF_D : PredTwoOpPseudo; + + def _ZERO_B : PredTwoOpPseudo; + def _ZERO_H : PredTwoOpPseudo; + def _ZERO_S : PredTwoOpPseudo; + def _ZERO_D : PredTwoOpPseudo; + + def : SVE_3_Op_Pat_SelZero(NAME # _ZERO_B)>; + def : SVE_3_Op_Pat_SelZero(NAME # _ZERO_H)>; + def : SVE_3_Op_Pat_SelZero(NAME # _ZERO_S)>; + def : SVE_3_Op_Pat_SelZero(NAME # _ZERO_D)>; +} + + +multiclass sve_int_bin_pred_shift_wide opc, string asm, + SDPatternOperator op> { + let DestructiveInstType = DestructiveOther in { def _B : sve_int_bin_pred_shift<0b00, 0b1, opc, asm, ZPR8, ZPR64>; def _H : sve_int_bin_pred_shift<0b01, 0b1, opc, asm, ZPR16, ZPR64>; def _S : sve_int_bin_pred_shift<0b10, 0b1, opc, asm, ZPR32, ZPR64>; + } + + def : SVE_3_Op_Pat(NAME # _B)>; + def : SVE_3_Op_Pat(NAME # _H)>; + def : SVE_3_Op_Pat(NAME # _S)>; } + //===----------------------------------------------------------------------===// // SVE Shift - Unpredicated Group //===----------------------------------------------------------------------===// @@ -2625,10 +5125,12 @@ } class sve_int_bin_cons_shift_imm tsz8_64, bits<2> opc, string asm, - ZPRRegOp zprty, Operand immtype> + ZPRRegOp zprty, Operand immtype, ValueType vt, + SDPatternOperator op> : I<(outs zprty:$Zd), (ins zprty:$Zn, immtype:$imm), asm, "\t$Zd, $Zn, $imm", - "", []>, Sched<[]> { + "", + [(set (vt zprty:$Zd), (op (vt zprty:$Zn), immtype:$imm))]>, Sched<[]> { bits<5> Zd; bits<5> Zn; bits<6> imm; @@ -2643,33 +5145,37 @@ let Inst{4-0} = Zd; } 
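// Illustration only (not part of the patch): with the new ValueType and
// SDPatternOperator parameters, every instantiation of
// sve_int_bin_cons_shift_imm now carries its own ISel pattern, so the _B case
// of the multiclass below behaves as if a separate record such as
//   def : Pat<(nxv16i8 (op (nxv16i8 ZPR8:$Zn), vecshiftL8:$imm)),
//             (!cast<Instruction>(NAME # _B) ZPR8:$Zn, vecshiftL8:$imm)>;
// had been written next to the instruction definition.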
-multiclass sve_int_bin_cons_shift_imm_left opc, string asm> { - def _B : sve_int_bin_cons_shift_imm<{0,0,0,1}, opc, asm, ZPR8, vecshiftL8>; - def _H : sve_int_bin_cons_shift_imm<{0,0,1,?}, opc, asm, ZPR16, vecshiftL16> { +multiclass sve_int_bin_cons_shift_imm_left opc, string asm, + SDPatternOperator op> { + def _B : sve_int_bin_cons_shift_imm<{0,0,0,1}, opc, asm, ZPR8, vecshiftL8, nxv16i8, op>; + def _H : sve_int_bin_cons_shift_imm<{0,0,1,?}, opc, asm, ZPR16, vecshiftL16, nxv8i16, op> { let Inst{19} = imm{3}; } - def _S : sve_int_bin_cons_shift_imm<{0,1,?,?}, opc, asm, ZPR32, vecshiftL32> { + def _S : sve_int_bin_cons_shift_imm<{0,1,?,?}, opc, asm, ZPR32, vecshiftL32, nxv4i32, op> { let Inst{20-19} = imm{4-3}; } - def _D : sve_int_bin_cons_shift_imm<{1,?,?,?}, opc, asm, ZPR64, vecshiftL64> { + def _D : sve_int_bin_cons_shift_imm<{1,?,?,?}, opc, asm, ZPR64, vecshiftL64, nxv2i64, op> { let Inst{22} = imm{5}; let Inst{20-19} = imm{4-3}; } } -multiclass sve_int_bin_cons_shift_imm_right opc, string asm> { - def _B : sve_int_bin_cons_shift_imm<{0,0,0,1}, opc, asm, ZPR8, vecshiftR8>; - def _H : sve_int_bin_cons_shift_imm<{0,0,1,?}, opc, asm, ZPR16, vecshiftR16> { +multiclass sve_int_bin_cons_shift_imm_right opc, string asm, + SDPatternOperator op> { + def _B : sve_int_bin_cons_shift_imm<{0,0,0,1}, opc, asm, ZPR8, vecshiftR8, nxv16i8, op>; + def _H : sve_int_bin_cons_shift_imm<{0,0,1,?}, opc, asm, ZPR16, vecshiftR16, nxv8i16, op> { let Inst{19} = imm{3}; } - def _S : sve_int_bin_cons_shift_imm<{0,1,?,?}, opc, asm, ZPR32, vecshiftR32> { + def _S : sve_int_bin_cons_shift_imm<{0,1,?,?}, opc, asm, ZPR32, vecshiftR32, nxv4i32, op> { let Inst{20-19} = imm{4-3}; } - def _D : sve_int_bin_cons_shift_imm<{1,?,?,?}, opc, asm, ZPR64, vecshiftR64> { + def _D : sve_int_bin_cons_shift_imm<{1,?,?,?}, opc, asm, ZPR64, vecshiftR64, nxv2i64, op> { let Inst{22} = imm{5}; let Inst{20-19} = imm{4-3}; } } + + //===----------------------------------------------------------------------===// // SVE Memory - Store Group //===----------------------------------------------------------------------===// @@ -2856,6 +5362,45 @@ (!cast(NAME) zprty:$Zt, PPR3bAny:$Pg, GPR64sp:$Rn, gprty:$Rm), 0>; } +class sve2_mem_sstnt_vs_base opc, string asm, + RegisterOperand listty, ZPRRegOp zprty> +: I<(outs), (ins listty:$Zt, PPR3bAny:$Pg, zprty:$Zn, GPR64:$Rm), + asm, "\t$Zt, $Pg, [$Zn, $Rm]", + "", + []>, Sched<[]> { + bits<3> Pg; + bits<5> Rm; + bits<5> Zn; + bits<5> Zt; + let Inst{31-25} = 0b1110010; + let Inst{24-22} = opc; + let Inst{21} = 0b0; + let Inst{20-16} = Rm; + let Inst{15-13} = 0b001; + let Inst{12-10} = Pg; + let Inst{9-5} = Zn; + let Inst{4-0} = Zt; + + let mayStore = 1; +} + +multiclass sve2_mem_sstnt_vs opc, string asm, + RegisterOperand listty, ZPRRegOp zprty, + SDPatternOperator op, ValueType vt1, ValueType vt2, + ValueType vt3> { + def _REAL : sve2_mem_sstnt_vs_base; + + def : InstAlias(NAME # _REAL) zprty:$Zt, PPR3bAny:$Pg, zprty:$Zn, GPR64:$Rm), 0>; + def : InstAlias(NAME # _REAL) zprty:$Zt, PPR3bAny:$Pg, zprty:$Zn, XZR), 0>; + def : InstAlias(NAME # _REAL) listty:$Zt, PPR3bAny:$Pg, zprty:$Zn, XZR), 1>; + + def : Pat <(op (vt1 zprty:$Zt), (vt2 PPR3bAny:$Pg), (i64 GPR64:$Rm), (vt1 zprty:$Zn), vt3), + (!cast(NAME # _REAL) zprty:$Zt, PPR3bAny:$Pg, zprty:$Zn, GPR64:$Rm)>; +} + class sve_mem_sst_sv opc, bit xs, bit scaled, string asm, RegisterOperand VecList, RegisterOperand zprext> : I<(outs), (ins VecList:$Zt, PPR3bAny:$Pg, GPR64sp:$Rn, zprext:$Zm), @@ -3038,11 +5583,12 @@ 
//===----------------------------------------------------------------------===// class sve_int_perm_bin_perm_pp opc, bits<2> sz8_64, string asm, - PPRRegOp pprty> + PPRRegOp pprty, ValueType vt, + SDPatternOperator op> : I<(outs pprty:$Pd), (ins pprty:$Pn, pprty:$Pm), asm, "\t$Pd, $Pn, $Pm", "", - []>, Sched<[]> { + [(set (vt pprty:$Pd), (op (vt pprty:$Pn), (vt pprty:$Pm)))]>, Sched<[]> { bits<4> Pd; bits<4> Pm; bits<4> Pn; @@ -3058,11 +5604,12 @@ let Inst{3-0} = Pd; } -multiclass sve_int_perm_bin_perm_pp opc, string asm> { - def _B : sve_int_perm_bin_perm_pp; - def _H : sve_int_perm_bin_perm_pp; - def _S : sve_int_perm_bin_perm_pp; - def _D : sve_int_perm_bin_perm_pp; +multiclass sve_int_perm_bin_perm_pp opc, string asm, + SDPatternOperator op> { + def _B : sve_int_perm_bin_perm_pp; + def _H : sve_int_perm_bin_perm_pp; + def _S : sve_int_perm_bin_perm_pp; + def _D : sve_int_perm_bin_perm_pp; } class sve_int_perm_punpk @@ -3080,6 +5627,14 @@ let Inst{3-0} = Pd; } +multiclass sve_int_perm_punpk { + def NAME : sve_int_perm_punpk; + + def : SVE_1_Op_Pat(NAME)>; + def : SVE_1_Op_Pat(NAME)>; + def : SVE_1_Op_Pat(NAME)>; +} + class sve_int_rdffr_pred : I<(outs PPR8:$Pd), (ins PPRAny:$Pg), asm, "\t$Pd, $Pg/z", @@ -3098,6 +5653,17 @@ let Uses = [FFR]; } +multiclass sve_int_rdffr_pred { + def _REAL : sve_int_rdffr_pred; + + // We need a layer of indirection because early machine code passes balk at + // physical register (i.e. FFR) uses that have no previous definition. + let hasSideEffects = 1, hasNoSchedulingInfo = 1 in { + def "" : Pseudo<(outs PPR8:$Pd), (ins PPRAny:$Pg), [(set (nxv16i1 PPR8:$Pd), (op (nxv16i1 PPRAny:$Pg)))]>, + PseudoInstExpansion<(!cast(NAME # _REAL) PPR8:$Pd, PPRAny:$Pg)>; + } +} + class sve_int_rdffr_unpred : I< (outs PPR8:$Pd), (ins), asm, "\t$Pd", @@ -3110,11 +5676,22 @@ let Uses = [FFR]; } -class sve_int_wrffr +multiclass sve_int_rdffr_unpred { + def _REAL : sve_int_rdffr_unpred; + + // We need a layer of indirection because early machine code passes balk at + // physical register (i.e. FFR) uses that have no previous definition. 
+ let hasSideEffects = 1, hasNoSchedulingInfo = 1 in { + def "" : Pseudo<(outs PPR8:$Pd), (ins), [(set (nxv16i1 PPR8:$Pd), (op))]>, + PseudoInstExpansion<(!cast(NAME # _REAL) PPR8:$Pd)>; + } +} + +class sve_int_wrffr : I<(outs), (ins PPR8:$Pn), asm, "\t$Pn", "", - []>, Sched<[]> { + [(op (nxv16i1 PPR8:$Pn))]>, Sched<[]> { bits<4> Pn; let Inst{31-9} = 0b00100101001010001001000; let Inst{8-5} = Pn; @@ -3124,11 +5701,11 @@ let Defs = [FFR]; } -class sve_int_setffr +class sve_int_setffr : I<(outs), (ins), asm, "", "", - []>, Sched<[]> { + [(op)]>, Sched<[]> { let Inst{31-0} = 0b00100101001011001001000000000000; let hasSideEffects = 1; @@ -3160,11 +5737,16 @@ let Constraints = "$Rdn = $_Rdn"; } -multiclass sve_int_perm_clast_rz { +multiclass sve_int_perm_clast_rz { def _B : sve_int_perm_clast_rz<0b00, ab, asm, ZPR8, GPR32>; def _H : sve_int_perm_clast_rz<0b01, ab, asm, ZPR16, GPR32>; def _S : sve_int_perm_clast_rz<0b10, ab, asm, ZPR32, GPR32>; def _D : sve_int_perm_clast_rz<0b11, ab, asm, ZPR64, GPR64>; + + def : SVE_3_Op_Pat(NAME # _B)>; + def : SVE_3_Op_Pat(NAME # _H)>; + def : SVE_3_Op_Pat(NAME # _S)>; + def : SVE_3_Op_Pat(NAME # _D)>; } class sve_int_perm_clast_vz sz8_64, bit ab, string asm, @@ -3188,11 +5770,15 @@ let Constraints = "$Vdn = $_Vdn"; } -multiclass sve_int_perm_clast_vz { +multiclass sve_int_perm_clast_vz { def _B : sve_int_perm_clast_vz<0b00, ab, asm, ZPR8, FPR8>; def _H : sve_int_perm_clast_vz<0b01, ab, asm, ZPR16, FPR16>; def _S : sve_int_perm_clast_vz<0b10, ab, asm, ZPR32, FPR32>; def _D : sve_int_perm_clast_vz<0b11, ab, asm, ZPR64, FPR64>; + + def : SVE_3_Op_Pat(NAME # _H)>; + def : SVE_3_Op_Pat(NAME # _S)>; + def : SVE_3_Op_Pat(NAME # _D)>; } class sve_int_perm_clast_zz sz8_64, bit ab, string asm, @@ -3214,15 +5800,23 @@ let Inst{4-0} = Zdn; let Constraints = "$Zdn = $_Zdn"; - let DestructiveInstType = Destructive; - let ElementSize = ElementSizeNone; + let DestructiveInstType = DestructiveOther; } -multiclass sve_int_perm_clast_zz { +multiclass sve_int_perm_clast_zz { def _B : sve_int_perm_clast_zz<0b00, ab, asm, ZPR8>; def _H : sve_int_perm_clast_zz<0b01, ab, asm, ZPR16>; def _S : sve_int_perm_clast_zz<0b10, ab, asm, ZPR32>; def _D : sve_int_perm_clast_zz<0b11, ab, asm, ZPR64>; + + def : SVE_3_Op_Pat(NAME # _B)>; + def : SVE_3_Op_Pat(NAME # _H)>; + def : SVE_3_Op_Pat(NAME # _S)>; + def : SVE_3_Op_Pat(NAME # _D)>; + + def : SVE_3_Op_Pat(NAME # _H)>; + def : SVE_3_Op_Pat(NAME # _S)>; + def : SVE_3_Op_Pat(NAME # _D)>; } class sve_int_perm_last_r sz8_64, bit ab, string asm, @@ -3244,11 +5838,16 @@ let Inst{4-0} = Rd; } -multiclass sve_int_perm_last_r { +multiclass sve_int_perm_last_r { def _B : sve_int_perm_last_r<0b00, ab, asm, ZPR8, GPR32>; def _H : sve_int_perm_last_r<0b01, ab, asm, ZPR16, GPR32>; def _S : sve_int_perm_last_r<0b10, ab, asm, ZPR32, GPR32>; def _D : sve_int_perm_last_r<0b11, ab, asm, ZPR64, GPR64>; + + def : SVE_2_Op_Pat(NAME # _B)>; + def : SVE_2_Op_Pat(NAME # _H)>; + def : SVE_2_Op_Pat(NAME # _S)>; + def : SVE_2_Op_Pat(NAME # _D)>; } class sve_int_perm_last_v sz8_64, bit ab, string asm, @@ -3270,11 +5869,16 @@ let Inst{4-0} = Vd; } -multiclass sve_int_perm_last_v { +multiclass sve_int_perm_last_v { def _B : sve_int_perm_last_v<0b00, ab, asm, ZPR8, FPR8>; def _H : sve_int_perm_last_v<0b01, ab, asm, ZPR16, FPR16>; def _S : sve_int_perm_last_v<0b10, ab, asm, ZPR32, FPR32>; def _D : sve_int_perm_last_v<0b11, ab, asm, ZPR64, FPR64>; + + def : SVE_2_Op_Pat(NAME # _H)>; + def : SVE_2_Op_Pat(NAME # _S)>; + def : SVE_2_Op_Pat(NAME # _S)>; + def : 
SVE_2_Op_Pat(NAME # _D)>; } class sve_int_perm_splice sz8_64, string asm, ZPRRegOp zprty> @@ -3293,15 +5897,47 @@ let Inst{4-0} = Zdn; let Constraints = "$Zdn = $_Zdn"; - let DestructiveInstType = Destructive; - let ElementSize = ElementSizeNone; + let DestructiveInstType = DestructiveOther; } -multiclass sve_int_perm_splice { +multiclass sve_int_perm_splice { def _B : sve_int_perm_splice<0b00, asm, ZPR8>; def _H : sve_int_perm_splice<0b01, asm, ZPR16>; def _S : sve_int_perm_splice<0b10, asm, ZPR32>; def _D : sve_int_perm_splice<0b11, asm, ZPR64>; + + def : SVE_3_Op_Pat(NAME # _B)>; + def : SVE_3_Op_Pat(NAME # _H)>; + def : SVE_3_Op_Pat(NAME # _S)>; + def : SVE_3_Op_Pat(NAME # _D)>; + + def : SVE_3_Op_Pat(NAME # _H)>; + def : SVE_3_Op_Pat(NAME # _S)>; + def : SVE_3_Op_Pat(NAME # _D)>; +} + +class sve2_int_perm_splice_cons sz8_64, string asm, + ZPRRegOp zprty, RegisterOperand VecList> +: I<(outs zprty:$Zd), (ins PPR3bAny:$Pg, VecList:$Zn), + asm, "\t$Zd, $Pg, $Zn", + "", + []>, Sched<[]> { + bits<3> Pg; + bits<5> Zn; + bits<5> Zd; + let Inst{31-24} = 0b00000101; + let Inst{23-22} = sz8_64; + let Inst{21-13} = 0b101101100; + let Inst{12-10} = Pg; + let Inst{9-5} = Zn; + let Inst{4-0} = Zd; +} + +multiclass sve2_int_perm_splice_cons { + def _B : sve2_int_perm_splice_cons<0b00, asm, ZPR8, ZZ_b>; + def _H : sve2_int_perm_splice_cons<0b01, asm, ZPR16, ZZ_h>; + def _S : sve2_int_perm_splice_cons<0b10, asm, ZPR32, ZZ_s>; + def _D : sve2_int_perm_splice_cons<0b11, asm, ZPR64, ZZ_d>; } class sve_int_perm_rev sz8_64, bits<2> opc, string asm, @@ -3323,30 +5959,50 @@ let Inst{4-0} = Zd; let Constraints = "$Zd = $_Zd"; - let DestructiveInstType = Destructive; + let DestructiveInstType = DestructiveUnary; let ElementSize = zprty.ElementSize; } -multiclass sve_int_perm_rev_rbit { +multiclass sve_int_perm_rev_rbit { def _B : sve_int_perm_rev<0b00, 0b11, asm, ZPR8>; def _H : sve_int_perm_rev<0b01, 0b11, asm, ZPR16>; def _S : sve_int_perm_rev<0b10, 0b11, asm, ZPR32>; def _D : sve_int_perm_rev<0b11, 0b11, asm, ZPR64>; + + def : SVE_3_Op_Pat(NAME # _B)>; + def : SVE_3_Op_Pat(NAME # _H)>; + def : SVE_3_Op_Pat(NAME # _S)>; + def : SVE_3_Op_Pat(NAME # _D)>; } -multiclass sve_int_perm_rev_revb { +multiclass sve_int_perm_rev_revb { def _H : sve_int_perm_rev<0b01, 0b00, asm, ZPR16>; def _S : sve_int_perm_rev<0b10, 0b00, asm, ZPR32>; def _D : sve_int_perm_rev<0b11, 0b00, asm, ZPR64>; + + def : SVE_3_Op_Pat(NAME # _H)>; + def : SVE_3_Op_Pat(NAME # _S)>; + def : SVE_3_Op_Pat(NAME # _D)>; + + def : SVE_1_Op_AllActive_Pat(NAME # _H), PTRUE_H>; + def : SVE_1_Op_AllActive_Pat(NAME # _S), PTRUE_S>; + def : SVE_1_Op_AllActive_Pat(NAME # _D), PTRUE_D>; } -multiclass sve_int_perm_rev_revh { +multiclass sve_int_perm_rev_revh { def _S : sve_int_perm_rev<0b10, 0b01, asm, ZPR32>; def _D : sve_int_perm_rev<0b11, 0b01, asm, ZPR64>; + + def : SVE_3_Op_Pat(NAME # _S)>; + def : SVE_3_Op_Pat(NAME # _D)>; } -multiclass sve_int_perm_rev_revw { +multiclass sve_int_perm_rev_revw { def _D : sve_int_perm_rev<0b11, 0b10, asm, ZPR64>; + + def : SVE_3_Op_Pat(NAME # _D)>; } class sve_int_perm_cpy_r sz8_64, string asm, ZPRRegOp zprty, @@ -3366,11 +6022,11 @@ let Inst{4-0} = Zd; let Constraints = "$Zd = $_Zd"; - let DestructiveInstType = Destructive; + let DestructiveInstType = DestructiveUnary; let ElementSize = zprty.ElementSize; } -multiclass sve_int_perm_cpy_r { +multiclass sve_int_perm_cpy_r { def _B : sve_int_perm_cpy_r<0b00, asm, ZPR8, GPR32sp>; def _H : sve_int_perm_cpy_r<0b01, asm, ZPR16, GPR32sp>; def _S : sve_int_perm_cpy_r<0b10, asm, 
ZPR32, GPR32sp>; @@ -3384,6 +6040,11 @@ (!cast(NAME # _S) ZPR32:$Zd, PPR3bAny:$Pg, GPR32sp:$Rn), 1>; def : InstAlias<"mov $Zd, $Pg/m, $Rn", (!cast(NAME # _D) ZPR64:$Zd, PPR3bAny:$Pg, GPR64sp:$Rn), 1>; + + def : SVE_3_Op_Pat(NAME # _B)>; + def : SVE_3_Op_Pat(NAME # _H)>; + def : SVE_3_Op_Pat(NAME # _S)>; + def : SVE_3_Op_Pat(NAME # _D)>; } class sve_int_perm_cpy_v sz8_64, string asm, ZPRRegOp zprty, @@ -3403,11 +6064,11 @@ let Inst{4-0} = Zd; let Constraints = "$Zd = $_Zd"; - let DestructiveInstType = Destructive; + let DestructiveInstType = DestructiveUnary; let ElementSize = zprty.ElementSize; } -multiclass sve_int_perm_cpy_v { +multiclass sve_int_perm_cpy_v { def _B : sve_int_perm_cpy_v<0b00, asm, ZPR8, FPR8>; def _H : sve_int_perm_cpy_v<0b01, asm, ZPR16, FPR16>; def _S : sve_int_perm_cpy_v<0b10, asm, ZPR32, FPR32>; @@ -3421,6 +6082,11 @@ (!cast(NAME # _S) ZPR32:$Zd, PPR3bAny:$Pg, FPR32:$Vn), 1>; def : InstAlias<"mov $Zd, $Pg/m, $Vn", (!cast(NAME # _D) ZPR64:$Zd, PPR3bAny:$Pg, FPR64:$Vn), 1>; + + def : SVE_3_Op_Pat(NAME # _H)>; + def : SVE_3_Op_Pat(NAME # _S)>; + def : SVE_3_Op_Pat(NAME # _S)>; + def : SVE_3_Op_Pat(NAME # _D)>; } class sve_int_perm_compact @@ -3439,9 +6105,14 @@ let Inst{4-0} = Zd; } -multiclass sve_int_perm_compact { +multiclass sve_int_perm_compact { def _S : sve_int_perm_compact<0b0, asm, ZPR32>; def _D : sve_int_perm_compact<0b1, asm, ZPR64>; + + def : SVE_2_Op_Pat(NAME # _S)>; + def : SVE_2_Op_Pat(NAME # _S)>; + def : SVE_2_Op_Pat(NAME # _D)>; + def : SVE_2_Op_Pat(NAME # _D)>; } @@ -3483,6 +6154,13 @@ (!cast(NAME # _REAL) zprty:$Zt, PPR3bAny:$Pg, GPR64sp:$Rn, simm4s1:$imm4), 0>; def : InstAlias(NAME # _REAL) listty:$Zt, PPR3bAny:$Pg, GPR64sp:$Rn, 0), 1>; + + // We need a layer of indirection because early machine code passes balk at + // physical register (i.e. FFR) uses that have no previous definition. + let hasSideEffects = 1, hasNoSchedulingInfo = 1, mayLoad = 1 in { + def "" : Pseudo<(outs listty:$Zt), (ins PPR3bAny:$Pg, GPR64sp:$Rn, simm4s1:$imm4), []>, + PseudoInstExpansion<(!cast(NAME # _REAL) listty:$Zt, PPR3bAny:$Pg, GPR64sp:$Rn, simm4s1:$imm4)>; + } } multiclass sve_mem_cld_si dtype, string asm, RegisterOperand listty, @@ -3691,6 +6369,13 @@ def : InstAlias(NAME # _REAL) zprty:$Zt, PPR3bAny:$Pg, GPR64sp:$Rn, XZR), 0>; + + // We need a layer of indirection because early machine code passes balk at + // physical register (i.e. FFR) uses that have no previous definition. 
+ let hasSideEffects = 1, hasNoSchedulingInfo = 1 in { + def "" : Pseudo<(outs listty:$Zt), (ins PPR3bAny:$Pg, GPR64sp:$Rn, gprty:$Rm), []>, + PseudoInstExpansion<(!cast(NAME # _REAL) listty:$Zt, PPR3bAny:$Pg, GPR64sp:$Rn, gprty:$Rm)>; + } } multiclass sve_mem_cldnf_si dtype, string asm, RegisterOperand listty, @@ -3750,6 +6435,7 @@ let mayLoad = 1; } + //===----------------------------------------------------------------------===// // SVE Memory - 32-bit Gather and Unsized Contiguous Group //===----------------------------------------------------------------------===// @@ -3783,8 +6469,11 @@ } multiclass sve_mem_32b_gld_sv_32_scaled opc, string asm, + SDPatternOperator sxtw_op, + SDPatternOperator uxtw_op, RegisterOperand sxtw_opnd, - RegisterOperand uxtw_opnd> { + RegisterOperand uxtw_opnd, + ValueType vt> { def _UXTW_SCALED_REAL : sve_mem_32b_gld_sv; def _SXTW_SCALED_REAL : sve_mem_32b_gld_sv; @@ -3792,11 +6481,28 @@ (!cast(NAME # _UXTW_SCALED_REAL) ZPR32:$Zt, PPR3bAny:$Pg, GPR64sp:$Rn, uxtw_opnd:$Zm), 0>; def : InstAlias(NAME # _SXTW_SCALED_REAL) ZPR32:$Zt, PPR3bAny:$Pg, GPR64sp:$Rn, sxtw_opnd:$Zm), 0>; + + // We need a layer of indirection because early machine code passes balk at + // physical register (i.e. FFR) uses that have no previous definition. + let hasSideEffects = 1, hasNoSchedulingInfo = 1 in { + def _UXTW_SCALED : Pseudo<(outs Z_s:$Zt), (ins PPR3bAny:$Pg, GPR64sp:$Rn, uxtw_opnd:$Zm), []>, + PseudoInstExpansion<(!cast(NAME # _UXTW_SCALED_REAL) Z_s:$Zt, PPR3bAny:$Pg, GPR64sp:$Rn, uxtw_opnd:$Zm)>; + def _SXTW_SCALED : Pseudo<(outs Z_s:$Zt), (ins PPR3bAny:$Pg, GPR64sp:$Rn, sxtw_opnd:$Zm), []>, + PseudoInstExpansion<(!cast(NAME # _SXTW_SCALED_REAL) Z_s:$Zt, PPR3bAny:$Pg, GPR64sp:$Rn, sxtw_opnd:$Zm)>; + } + + def : Pat<(nxv4i32 (uxtw_op (nxv4i1 PPR:$gp), GPR64sp:$base, (nxv4i32 ZPR:$indices), vt)), + (!cast(NAME # _UXTW_SCALED) PPR:$gp, GPR64sp:$base, ZPR:$indices)>; + def : Pat<(nxv4i32 (sxtw_op (nxv4i1 PPR:$gp), GPR64sp:$base, (nxv4i32 ZPR:$indices), vt)), + (!cast(NAME # _SXTW_SCALED) PPR:$gp, GPR64sp:$base, ZPR:$indices)>; } multiclass sve_mem_32b_gld_vs_32_unscaled opc, string asm, + SDPatternOperator sxtw_op, + SDPatternOperator uxtw_op, RegisterOperand sxtw_opnd, - RegisterOperand uxtw_opnd> { + RegisterOperand uxtw_opnd, + ValueType vt> { def _UXTW_REAL : sve_mem_32b_gld_sv; def _SXTW_REAL : sve_mem_32b_gld_sv; @@ -3804,6 +6510,21 @@ (!cast(NAME # _UXTW_REAL) ZPR32:$Zt, PPR3bAny:$Pg, GPR64sp:$Rn, uxtw_opnd:$Zm), 0>; def : InstAlias(NAME # _SXTW_REAL) ZPR32:$Zt, PPR3bAny:$Pg, GPR64sp:$Rn, sxtw_opnd:$Zm), 0>; + + + // We need a layer of indirection because early machine code passes balk at + // physical register (i.e. FFR) uses that have no previous definition. 
+ let hasSideEffects = 1, hasNoSchedulingInfo = 1 in { + def _UXTW : Pseudo<(outs Z_s:$Zt), (ins PPR3bAny:$Pg, GPR64sp:$Rn, uxtw_opnd:$Zm), []>, + PseudoInstExpansion<(!cast(NAME # _UXTW_REAL) Z_s:$Zt, PPR3bAny:$Pg, GPR64sp:$Rn, uxtw_opnd:$Zm)>; + def _SXTW : Pseudo<(outs Z_s:$Zt), (ins PPR3bAny:$Pg, GPR64sp:$Rn, sxtw_opnd:$Zm), []>, + PseudoInstExpansion<(!cast(NAME # _SXTW_REAL) Z_s:$Zt, PPR3bAny:$Pg, GPR64sp:$Rn, sxtw_opnd:$Zm)>; + } + + def : Pat<(nxv4i32 (uxtw_op (nxv4i1 PPR:$gp), GPR64sp:$base, (nxv4i32 ZPR:$offsets), vt)), + (!cast(NAME # _UXTW) PPR:$gp, GPR64sp:$base, ZPR:$offsets)>; + def : Pat<(nxv4i32 (sxtw_op (nxv4i1 PPR:$gp), GPR64sp:$base, (nxv4i32 ZPR:$offsets), vt)), + (!cast(NAME # _SXTW) PPR:$gp, GPR64sp:$base, ZPR:$offsets)>; } @@ -3831,7 +6552,8 @@ let Uses = !if(!eq(opc{0}, 1), [FFR], []); } -multiclass sve_mem_32b_gld_vi_32_ptrs opc, string asm, Operand imm_ty> { +multiclass sve_mem_32b_gld_vi_32_ptrs opc, string asm, Operand imm_ty, + SDPatternOperator op, ValueType vt> { def _IMM_REAL : sve_mem_32b_gld_vi; def : InstAlias(NAME # _IMM_REAL) ZPR32:$Zt, PPR3bAny:$Pg, ZPR32:$Zn, imm_ty:$imm5), 0>; def : InstAlias(NAME # _IMM_REAL) Z_s:$Zt, PPR3bAny:$Pg, ZPR32:$Zn, 0), 1>; + + // We need a layer of indirection because early machine code passes balk at + // physical register (i.e. FFR) uses that have no previous definition. + let hasSideEffects = 1, hasNoSchedulingInfo = 1 in { + def _IMM : Pseudo<(outs Z_s:$Zt), (ins PPR3bAny:$Pg, ZPR32:$Zn, imm_ty:$imm5), []>, + PseudoInstExpansion<(!cast(NAME # _IMM_REAL) Z_s:$Zt, PPR3bAny:$Pg, ZPR32:$Zn, imm_ty:$imm5)>; + } + + def : Pat<(nxv4i32 (op (nxv4i1 PPR:$gp), imm_ty:$index, (nxv4i32 ZPR:$ptrs), vt)), + (!cast(NAME # _IMM) PPR:$gp, ZPR:$ptrs, imm_ty:$index)>; } class sve_mem_prfm_si msz, string asm> @@ -3919,10 +6651,18 @@ } multiclass sve_mem_32b_prfm_sv_scaled msz, string asm, + SDPatternOperator sxtw_op, + SDPatternOperator uxtw_op, RegisterOperand sxtw_opnd, - RegisterOperand uxtw_opnd> { + RegisterOperand uxtw_opnd, + ValueType vt> { def _UXTW_SCALED : sve_mem_32b_prfm_sv; def _SXTW_SCALED : sve_mem_32b_prfm_sv; + + def : Pat<(uxtw_op (nxv4i1 PPR:$gp), GPR64sp:$base, (nxv4i32 ZPR:$indices), (i32 sve_prfop:$prfop), vt), + (!cast(NAME # _UXTW_SCALED) sve_prfop:$prfop, PPR:$gp, GPR64sp:$base, ZPR:$indices)>; + def : Pat<(sxtw_op (nxv4i1 PPR:$gp), GPR64sp:$base, (nxv4i32 ZPR:$indices), (i32 sve_prfop:$prfop), vt), + (!cast(NAME # _SXTW_SCALED) sve_prfop:$prfop, PPR:$gp, GPR64sp:$base, ZPR:$indices)>; } class sve_mem_32b_prfm_vi msz, string asm, Operand imm_ty> @@ -3945,9 +6685,13 @@ let Inst{3-0} = prfop; } -multiclass sve_mem_32b_prfm_vi msz, string asm, Operand imm_ty> { +multiclass sve_mem_32b_prfm_vi msz, string asm, Operand imm_ty, + SDPatternOperator prefetch, ValueType vt> { def NAME : sve_mem_32b_prfm_vi; + def : Pat<(prefetch (nxv4i1 PPR:$gp), (i64 imm_ty:$imm5), (nxv4i32 ZPR:$indices), (i32 sve_prfop:$prfop), vt), + (!cast(NAME) sve_prfop:$prfop, PPR:$gp, ZPR:$indices, imm_ty:$imm5)>; + def : InstAlias(NAME) sve_prfop:$prfop, PPR3bAny:$Pg, ZPR32:$Zn, 0), 1>; } @@ -4003,6 +6747,49 @@ (!cast(NAME) PPRAny:$Pt, GPR64sp:$Rn, 0), 1>; } +class sve2_mem_gldnt_vs_base opc, dag iops, string asm, + RegisterOperand VecList> +: I<(outs VecList:$Zt), iops, + asm, "\t$Zt, $Pg/z, [$Zn, $Rm]", + "", + []>, Sched<[]> { + bits<3> Pg; + bits<5> Rm; + bits<5> Zn; + bits<5> Zt; + let Inst{31} = 0b1; + let Inst{30} = opc{4}; + let Inst{29-25} = 0b00010; + let Inst{24-23} = opc{3-2}; + let Inst{22-21} = 0b00; + let Inst{20-16} = Rm; + let 
Inst{15} = 0b1; + let Inst{14-13} = opc{1-0}; + let Inst{12-10} = Pg; + let Inst{9-5} = Zn; + let Inst{4-0} = Zt; + + let mayLoad = 1; +} + +multiclass sve2_mem_gldnt_vs opc, string asm, + RegisterOperand listty, ZPRRegOp zprty, + SDPatternOperator op, ValueType vt1, + ValueType vt2, ValueType vt3> { + def _REAL : sve2_mem_gldnt_vs_base; + + def : InstAlias(NAME # _REAL) zprty:$Zt, PPR3bAny:$Pg, zprty:$Zn, GPR64:$Rm), 0>; + def : InstAlias(NAME # _REAL) zprty:$Zt, PPR3bAny:$Pg, zprty:$Zn, XZR), 0>; + def : InstAlias(NAME # _REAL) listty:$Zt, PPR3bAny:$Pg, zprty:$Zn, XZR), 1>; + + def : Pat <(vt1 (op (vt2 PPR3bAny:$Pg), (i64 GPR64:$Rm) , (vt1 zprty:$Zd), vt3)), + (!cast(NAME # _REAL) PPR3bAny:$Pg, zprty:$Zd, GPR64:$Rm)>; +} + //===----------------------------------------------------------------------===// // SVE Memory - 64-bit Gather Group //===----------------------------------------------------------------------===// @@ -4037,8 +6824,10 @@ } multiclass sve_mem_64b_gld_sv_32_scaled opc, string asm, + SDPatternOperator op, RegisterOperand sxtw_opnd, - RegisterOperand uxtw_opnd> { + RegisterOperand uxtw_opnd, + ValueType vt> { def _UXTW_SCALED_REAL : sve_mem_64b_gld_sv; def _SXTW_SCALED_REAL : sve_mem_64b_gld_sv; @@ -4046,11 +6835,27 @@ (!cast(NAME # _UXTW_SCALED_REAL) ZPR64:$Zt, PPR3bAny:$Pg, GPR64sp:$Rn, uxtw_opnd:$Zm), 0>; def : InstAlias(NAME # _SXTW_SCALED_REAL) ZPR64:$Zt, PPR3bAny:$Pg, GPR64sp:$Rn, sxtw_opnd:$Zm), 0>; + + // We need a layer of indirection because early machine code passes balk at + // physical register (i.e. FFR) uses that have no previous definition. + let hasSideEffects = 1, hasNoSchedulingInfo = 1 in { + def _UXTW_SCALED : Pseudo<(outs Z_d:$Zt), (ins PPR3bAny:$Pg, GPR64sp:$Rn, uxtw_opnd:$Zm), []>, + PseudoInstExpansion<(!cast(NAME # _UXTW_SCALED_REAL) Z_d:$Zt, PPR3bAny:$Pg, GPR64sp:$Rn, uxtw_opnd:$Zm)>; + def _SXTW_SCALED : Pseudo<(outs Z_d:$Zt), (ins PPR3bAny:$Pg, GPR64sp:$Rn, sxtw_opnd:$Zm), []>, + PseudoInstExpansion<(!cast(NAME # _SXTW_SCALED_REAL) Z_d:$Zt, PPR3bAny:$Pg, GPR64sp:$Rn, sxtw_opnd:$Zm)>; + } + + def : Pat<(nxv2i64 (op (nxv2i1 PPR:$gp), GPR64sp:$base, (and (nxv2i64 ZPR:$indices), (nxv2i64 (AArch64dup (i64 0xFFFFFFFF)))), vt)), + (!cast(NAME # _UXTW_SCALED) PPR:$gp, GPR64sp:$base, ZPR:$indices)>; + def : Pat<(nxv2i64 (op (nxv2i1 PPR:$gp), GPR64sp:$base, (sext_inreg (nxv2i64 ZPR:$indices), nxv2i32), vt)), + (!cast(NAME # _SXTW_SCALED) PPR:$gp, GPR64sp:$base, ZPR:$indices)>; } multiclass sve_mem_64b_gld_vs_32_unscaled opc, string asm, + SDPatternOperator op, RegisterOperand sxtw_opnd, - RegisterOperand uxtw_opnd> { + RegisterOperand uxtw_opnd, + ValueType vt> { def _UXTW_REAL : sve_mem_64b_gld_sv; def _SXTW_REAL : sve_mem_64b_gld_sv; @@ -4058,21 +6863,57 @@ (!cast(NAME # _UXTW_REAL) ZPR64:$Zt, PPR3bAny:$Pg, GPR64sp:$Rn, uxtw_opnd:$Zm), 0>; def : InstAlias(NAME # _SXTW_REAL) ZPR64:$Zt, PPR3bAny:$Pg, GPR64sp:$Rn, sxtw_opnd:$Zm), 0>; + + // We need a layer of indirection because early machine code passes balk at + // physical register (i.e. FFR) uses that have no previous definition. 
+ let hasSideEffects = 1, hasNoSchedulingInfo = 1 in { + def _UXTW : Pseudo<(outs Z_d:$Zt), (ins PPR3bAny:$Pg, GPR64sp:$Rn, uxtw_opnd:$Zm), []>, + PseudoInstExpansion<(!cast(NAME # _UXTW_REAL) Z_d:$Zt, PPR3bAny:$Pg, GPR64sp:$Rn, uxtw_opnd:$Zm)>; + def _SXTW : Pseudo<(outs Z_d:$Zt), (ins PPR3bAny:$Pg, GPR64sp:$Rn, sxtw_opnd:$Zm), []>, + PseudoInstExpansion<(!cast(NAME # _SXTW_REAL) Z_d:$Zt, PPR3bAny:$Pg, GPR64sp:$Rn, sxtw_opnd:$Zm)>; + } + + def : Pat<(nxv2i64 (op (nxv2i1 PPR:$gp), GPR64sp:$base, (and (nxv2i64 ZPR:$offsets), (nxv2i64 (AArch64dup (i64 0xFFFFFFFF)))), vt)), + (!cast(NAME # _UXTW) PPR:$gp, GPR64sp:$base, ZPR:$offsets)>; + def : Pat<(nxv2i64 (op (nxv2i1 PPR:$gp), GPR64sp:$base, (sext_inreg (nxv2i64 ZPR:$offsets), nxv2i32), vt)), + (!cast(NAME # _SXTW) PPR:$gp, GPR64sp:$base, ZPR:$offsets)>; } multiclass sve_mem_64b_gld_sv2_64_scaled opc, string asm, - RegisterOperand zprext> { + SDPatternOperator op, + RegisterOperand zprext, ValueType vt> { def _SCALED_REAL : sve_mem_64b_gld_sv; def : InstAlias(NAME # _SCALED_REAL) ZPR64:$Zt, PPR3bAny:$Pg, GPR64sp:$Rn, zprext:$Zm), 0>; + + // We need a layer of indirection because early machine code passes balk at + // physical register (i.e. FFR) uses that have no previous definition. + let hasSideEffects = 1, hasNoSchedulingInfo = 1 in { + def _SCALED : Pseudo<(outs Z_d:$Zt), (ins PPR3bAny:$Pg, GPR64sp:$Rn, zprext:$Zm), []>, + PseudoInstExpansion<(!cast(NAME # _SCALED_REAL) Z_d:$Zt, PPR3bAny:$Pg, GPR64sp:$Rn, zprext:$Zm)>; + } + + def : Pat<(nxv2i64 (op (nxv2i1 PPR:$gp), GPR64sp:$base, (nxv2i64 ZPR:$indices), vt)), + (!cast(NAME # _SCALED) PPR:$gp, GPR64sp:$base, ZPR:$indices)>; } -multiclass sve_mem_64b_gld_vs2_64_unscaled opc, string asm> { +multiclass sve_mem_64b_gld_vs2_64_unscaled opc, string asm, + SDPatternOperator op, ValueType vt> { def _REAL : sve_mem_64b_gld_sv; def : InstAlias(NAME # _REAL) ZPR64:$Zt, PPR3bAny:$Pg, GPR64sp:$Rn, ZPR64ExtLSL8:$Zm), 0>; + + // We need a layer of indirection because early machine code passes balk at + // physical register (i.e. FFR) uses that have no previous definition. + let hasSideEffects = 1, hasNoSchedulingInfo = 1 in { + def "" : Pseudo<(outs Z_d:$Zt), (ins PPR3bAny:$Pg, GPR64sp:$Rn, ZPR64ExtLSL8:$Zm), []>, + PseudoInstExpansion<(!cast(NAME # _REAL) Z_d:$Zt, PPR3bAny:$Pg, GPR64sp:$Rn, ZPR64ExtLSL8:$Zm)>; + } + + def : Pat<(nxv2i64 (op (nxv2i1 PPR:$gp), GPR64sp:$base, (nxv2i64 ZPR:$offsets), vt)), + (!cast(NAME) PPR:$gp, GPR64sp:$base, ZPR:$offsets)>; } class sve_mem_64b_gld_vi opc, string asm, Operand imm_ty> @@ -4099,7 +6940,8 @@ let Uses = !if(!eq(opc{0}, 1), [FFR], []); } -multiclass sve_mem_64b_gld_vi_64_ptrs opc, string asm, Operand imm_ty> { +multiclass sve_mem_64b_gld_vi_64_ptrs opc, string asm, Operand imm_ty, + SDPatternOperator op, ValueType vt> { def _IMM_REAL : sve_mem_64b_gld_vi; def : InstAlias(NAME # _IMM_REAL) ZPR64:$Zt, PPR3bAny:$Pg, ZPR64:$Zn, imm_ty:$imm5), 0>; def : InstAlias(NAME # _IMM_REAL) Z_d:$Zt, PPR3bAny:$Pg, ZPR64:$Zn, 0), 1>; + + // We need a layer of indirection because early machine code passes balk at + // physical register (i.e. FFR) uses that have no previous definition. 
+ let hasSideEffects = 1, hasNoSchedulingInfo = 1 in { + def _IMM : Pseudo<(outs Z_d:$Zt), (ins PPR3bAny:$Pg, ZPR64:$Zn, imm_ty:$imm5), []>, + PseudoInstExpansion<(!cast(NAME # _IMM_REAL) Z_d:$Zt, PPR3bAny:$Pg, ZPR64:$Zn, imm_ty:$imm5)>; + } + + def : Pat<(nxv2i64 (op (nxv2i1 PPR:$gp), imm_ty:$index, (nxv2i64 ZPR:$ptrs), vt)), + (!cast(NAME # _IMM) PPR:$gp, ZPR:$ptrs, imm_ty:$index)>; } // bit lsl is '0' if the offsets are extended (uxtw/sxtw), '1' if shifted (lsl) @@ -4135,18 +6987,31 @@ let hasSideEffects = 1; } -multiclass sve_mem_64b_prfm_sv_ext_scaled msz, string asm, - RegisterOperand sxtw_opnd, - RegisterOperand uxtw_opnd> { +multiclass sve_mem_64b_prfm_ext_scaled msz, string asm, + SDPatternOperator sxtw_op, + SDPatternOperator uxtw_op, + RegisterOperand sxtw_opnd, + RegisterOperand uxtw_opnd, + ValueType vt> { def _UXTW_SCALED : sve_mem_64b_prfm_sv; def _SXTW_SCALED : sve_mem_64b_prfm_sv; + + def : Pat<(uxtw_op (nxv2i1 PPR:$gp), GPR64sp:$base, (nxv2i64 ZPR:$indices), (i32 sve_prfop:$prfop), vt), + (!cast(NAME # _UXTW_SCALED) sve_prfop:$prfop, PPR:$gp, GPR64sp:$base, ZPR:$indices)>; + def : Pat<(sxtw_op (nxv2i1 PPR:$gp), GPR64sp:$base, (nxv2i64 ZPR:$indices), (i32 sve_prfop:$prfop), vt), + (!cast(NAME # _SXTW_SCALED) sve_prfop:$prfop, PPR:$gp, GPR64sp:$base, ZPR:$indices)>; } multiclass sve_mem_64b_prfm_sv_lsl_scaled msz, string asm, - RegisterOperand zprext> { + SDPatternOperator op, + RegisterOperand zprext, ValueType vt> { def NAME : sve_mem_64b_prfm_sv; -} + def : Pat<(op (nxv2i1 PPR:$gp), GPR64sp:$base, (nxv2i64 ZPR:$indices), + (i32 sve_prfop:$prfop), vt), + (!cast(NAME) sve_prfop:$prfop, PPR:$gp, GPR64sp:$base, + ZPR:$indices)>; +} class sve_mem_64b_prfm_vi msz, string asm, Operand imm_ty> : I<(outs), (ins sve_prfop:$prfop, PPR3bAny:$Pg, ZPR64:$Zn, imm_ty:$imm5), @@ -4170,9 +7035,13 @@ let hasSideEffects = 1; } -multiclass sve_mem_64b_prfm_vi msz, string asm, Operand imm_ty> { +multiclass sve_mem_64b_prfm_vi msz, string asm, Operand imm_ty, + SDPatternOperator prefetch, ValueType vt> { def NAME : sve_mem_64b_prfm_vi; + def : Pat<(prefetch (nxv2i1 PPR:$gp), (i64 imm_ty:$imm5), (nxv2i64 ZPR:$indices), (i32 sve_prfop:$prfop), vt), + (!cast(NAME) sve_prfop:$prfop, PPR:$gp, ZPR:$indices, imm_ty:$imm5)>; + def : InstAlias(NAME) sve_prfop:$prfop, PPR3bAny:$Pg, ZPR64:$Zn, 0), 1>; } @@ -4206,6 +7075,8 @@ def _1 : sve_int_bin_cons_misc_0_a; def _2 : sve_int_bin_cons_misc_0_a; def _3 : sve_int_bin_cons_misc_0_a; + + def "" : Pseudo<(outs ZPR64:$Zd), (ins ZPR64:$Zn, ZPR64:$Zm, imm0_3:$msz), []>; } multiclass sve_int_bin_cons_misc_0_a_sxtw opc, string asm> { @@ -4213,6 +7084,8 @@ def _1 : sve_int_bin_cons_misc_0_a; def _2 : sve_int_bin_cons_misc_0_a; def _3 : sve_int_bin_cons_misc_0_a; + + def "" : Pseudo<(outs ZPR64:$Zd), (ins ZPR64:$Zn, ZPR64:$Zm, imm0_3:$msz), []>; } multiclass sve_int_bin_cons_misc_0_a_32_lsl opc, string asm> { @@ -4220,6 +7093,8 @@ def _1 : sve_int_bin_cons_misc_0_a; def _2 : sve_int_bin_cons_misc_0_a; def _3 : sve_int_bin_cons_misc_0_a; + + def "" : Pseudo<(outs ZPR32:$Zd), (ins ZPR32:$Zn, ZPR32:$Zm, imm0_3:$msz), []>; } multiclass sve_int_bin_cons_misc_0_a_64_lsl opc, string asm> { @@ -4227,8 +7102,9 @@ def _1 : sve_int_bin_cons_misc_0_a; def _2 : sve_int_bin_cons_misc_0_a; def _3 : sve_int_bin_cons_misc_0_a; -} + def "" : Pseudo<(outs ZPR64:$Zd), (ins ZPR64:$Zn, ZPR64:$Zm, imm0_3:$msz), []>; +} //===----------------------------------------------------------------------===// // SVE Integer Misc - Unpredicated Group @@ -4251,10 +7127,14 @@ let Inst{4-0} = Zd; } 
-multiclass sve_int_bin_cons_misc_0_b { +multiclass sve_int_bin_cons_misc_0_b { def _H : sve_int_bin_cons_misc_0_b<0b01, asm, ZPR16>; def _S : sve_int_bin_cons_misc_0_b<0b10, asm, ZPR32>; def _D : sve_int_bin_cons_misc_0_b<0b11, asm, ZPR64>; + + def : SVE_2_Op_Pat(NAME # _H)>; + def : SVE_2_Op_Pat(NAME # _S)>; + def : SVE_2_Op_Pat(NAME # _D)>; } class sve_int_bin_cons_misc_0_c opc, string asm, ZPRRegOp zprty> @@ -4298,31 +7178,69 @@ let Inst{4-0} = Vd; } -multiclass sve_int_reduce_0_saddv opc, string asm> { +multiclass sve_int_reduce_0_saddv opc, string asm, + SDPatternOperator op> { def _B : sve_int_reduce<0b00, 0b00, opc, asm, ZPR8, FPR64>; def _H : sve_int_reduce<0b01, 0b00, opc, asm, ZPR16, FPR64>; def _S : sve_int_reduce<0b10, 0b00, opc, asm, ZPR32, FPR64>; + + def : Pat<(v2i64 (op (nxv16i1 PPR3bAny:$Pg), (nxv16i8 ZPR8:$Zn))), + (INSERT_SUBREG (v2i64 (IMPLICIT_DEF)), (!cast(NAME#_B) PPR3bAny:$Pg, ZPR8:$Zn), dsub)>; + def : Pat<(v2i64 (op (nxv8i1 PPR3bAny:$Pg), (nxv8i16 ZPR16:$Zn))), + (INSERT_SUBREG (v2i64 (IMPLICIT_DEF)), (!cast(NAME#_H) PPR3bAny:$Pg, ZPR16:$Zn), dsub)>; + def : Pat<(v2i64 (op (nxv4i1 PPR3bAny:$Pg), (nxv4i32 ZPR32:$Zn))), + (INSERT_SUBREG (v2i64 (IMPLICIT_DEF)), (!cast(NAME#_S) PPR3bAny:$Pg, ZPR32:$Zn), dsub)>; } -multiclass sve_int_reduce_0_uaddv opc, string asm> { +multiclass sve_int_reduce_0_uaddv opc, string asm, + SDPatternOperator op> { def _B : sve_int_reduce<0b00, 0b00, opc, asm, ZPR8, FPR64>; def _H : sve_int_reduce<0b01, 0b00, opc, asm, ZPR16, FPR64>; def _S : sve_int_reduce<0b10, 0b00, opc, asm, ZPR32, FPR64>; def _D : sve_int_reduce<0b11, 0b00, opc, asm, ZPR64, FPR64>; + + def : Pat<(v2i64 (op (nxv16i1 PPR3bAny:$Pg), (nxv16i8 ZPR8:$Zn))), + (INSERT_SUBREG (v2i64 (IMPLICIT_DEF)), (!cast(NAME#_B) PPR3bAny:$Pg, ZPR8:$Zn), dsub)>; + def : Pat<(v2i64 (op (nxv8i1 PPR3bAny:$Pg), (nxv8i16 ZPR16:$Zn))), + (INSERT_SUBREG (v2i64 (IMPLICIT_DEF)), (!cast(NAME#_H) PPR3bAny:$Pg, ZPR16:$Zn), dsub)>; + def : Pat<(v2i64 (op (nxv4i1 PPR3bAny:$Pg), (nxv4i32 ZPR32:$Zn))), + (INSERT_SUBREG (v2i64 (IMPLICIT_DEF)), (!cast(NAME#_S) PPR3bAny:$Pg, ZPR32:$Zn), dsub)>; + def : Pat<(v2i64 (op (nxv2i1 PPR3bAny:$Pg), (nxv2i64 ZPR64:$Zn))), + (INSERT_SUBREG (v2i64 (IMPLICIT_DEF)), (!cast(NAME#_D) PPR3bAny:$Pg, ZPR64:$Zn), dsub)>; } -multiclass sve_int_reduce_1 opc, string asm> { +multiclass sve_int_reduce_1 opc, string asm, + SDPatternOperator op> { def _B : sve_int_reduce<0b00, 0b01, opc, asm, ZPR8, FPR8>; def _H : sve_int_reduce<0b01, 0b01, opc, asm, ZPR16, FPR16>; def _S : sve_int_reduce<0b10, 0b01, opc, asm, ZPR32, FPR32>; def _D : sve_int_reduce<0b11, 0b01, opc, asm, ZPR64, FPR64>; + + def : Pat<(v16i8 (op (nxv16i1 PPR3bAny:$Pg), (nxv16i8 ZPR8:$Zn))), + (INSERT_SUBREG (v16i8 (IMPLICIT_DEF)), (!cast(NAME#_B) PPR3bAny:$Pg, ZPR8:$Zn), bsub)>; + def : Pat<(v8i16 (op (nxv8i1 PPR3bAny:$Pg), (nxv8i16 ZPR16:$Zn))), + (INSERT_SUBREG (v8i16 (IMPLICIT_DEF)), (!cast(NAME#_H) PPR3bAny:$Pg, ZPR16:$Zn), hsub)>; + def : Pat<(v4i32 (op (nxv4i1 PPR3bAny:$Pg), (nxv4i32 ZPR32:$Zn))), + (INSERT_SUBREG (v4i32 (IMPLICIT_DEF)), (!cast(NAME#_S) PPR3bAny:$Pg, ZPR32:$Zn), ssub)>; + def : Pat<(v2i64 (op (nxv2i1 PPR3bAny:$Pg), (nxv2i64 ZPR64:$Zn))), + (INSERT_SUBREG (v2i64 (IMPLICIT_DEF)), (!cast(NAME#_D) PPR3bAny:$Pg, ZPR64:$Zn), dsub)>; } -multiclass sve_int_reduce_2 opc, string asm> { +multiclass sve_int_reduce_2 opc, string asm, + SDPatternOperator op> { def _B : sve_int_reduce<0b00, 0b11, opc, asm, ZPR8, FPR8>; def _H : sve_int_reduce<0b01, 0b11, opc, asm, ZPR16, FPR16>; def _S : sve_int_reduce<0b10, 0b11, 
opc, asm, ZPR32, FPR32>; def _D : sve_int_reduce<0b11, 0b11, opc, asm, ZPR64, FPR64>; + + def : Pat<(v16i8 (op (nxv16i1 PPR3bAny:$Pg), (nxv16i8 ZPR8:$Zn))), + (INSERT_SUBREG (v16i8 (IMPLICIT_DEF)), (!cast(NAME#_B) PPR3bAny:$Pg, ZPR8:$Zn), bsub)>; + def : Pat<(v8i16 (op (nxv8i1 PPR3bAny:$Pg), (nxv8i16 ZPR16:$Zn))), + (INSERT_SUBREG (v8i16 (IMPLICIT_DEF)), (!cast(NAME#_H) PPR3bAny:$Pg, ZPR16:$Zn), hsub)>; + def : Pat<(v4i32 (op (nxv4i1 PPR3bAny:$Pg), (nxv4i32 ZPR32:$Zn))), + (INSERT_SUBREG (v4i32 (IMPLICIT_DEF)), (!cast(NAME#_S) PPR3bAny:$Pg, ZPR32:$Zn), ssub)>; + def : Pat<(v2i64 (op (nxv2i1 PPR3bAny:$Pg), (nxv2i64 ZPR64:$Zn))), + (INSERT_SUBREG (v2i64 (IMPLICIT_DEF)), (!cast(NAME#_D) PPR3bAny:$Pg, ZPR64:$Zn), dsub)>; } class sve_int_movprfx_pred sz8_32, bits<3> opc, string asm, @@ -4398,6 +7316,15 @@ let Defs = !if(!eq (opc{1}, 1), [NZCV], []); } +multiclass sve_int_brkp opc, string asm, SDPatternOperator op> { + def NAME : sve_int_brkp; + + def : SVE_3_Op_Pat(NAME)>; + def : SVE_3_Op_Pat(NAME)>; + def : SVE_3_Op_Pat(NAME)>; + def : SVE_3_Op_Pat(NAME)>; +} + //===----------------------------------------------------------------------===// // SVE Partition Break Group @@ -4424,6 +7351,15 @@ let Defs = !if(!eq (S, 0b1), [NZCV], []); } +multiclass sve_int_brkn opc, string asm, SDPatternOperator op> { + def NAME : sve_int_brkn; + + def : SVE_3_Op_Pat(NAME)>; + def : SVE_3_Op_Pat(NAME)>; + def : SVE_3_Op_Pat(NAME)>; + def : SVE_3_Op_Pat(NAME)>; +} + class sve_int_break opc, string asm, string suffix, dag iops> : I<(outs PPR8:$Pd), iops, asm, "\t$Pd, $Pg"#suffix#", $Pn", @@ -4446,11 +7382,168 @@ } -multiclass sve_int_break_m opc, string asm> { +multiclass sve_int_break_m opc, string asm, SDPatternOperator op> { def NAME : sve_int_break; + + def : SVE_3_Op_Pat(NAME)>; } -multiclass sve_int_break_z opc, string asm> { +multiclass sve_int_break_z opc, string asm, SDPatternOperator op> { def NAME : sve_int_break; + + def : SVE_2_Op_Pat(NAME)>; +} + +//===----------------------------------------------------------------------===// +// SVE2 String Processing Group +//===----------------------------------------------------------------------===// + +class sve2_char_match +: I<(outs pprty:$Pd), (ins PPR3bAny:$Pg, zprty:$Zn, zprty:$Zm), + asm, "\t$Pd, $Pg/z, $Zn, $Zm", + "", + []>, Sched<[]> { + bits<4> Pd; + bits<3> Pg; + bits<5> Zm; + bits<5> Zn; + let Inst{31-23} = 0b010001010; + let Inst{22} = sz; + let Inst{21} = 0b1; + let Inst{20-16} = Zm; + let Inst{15-13} = 0b100; + let Inst{12-10} = Pg; + let Inst{9-5} = Zn; + let Inst{4} = opc; + let Inst{3-0} = Pd; + + let Defs = [NZCV]; +} + +multiclass sve2_char_match { + def _B : sve2_char_match<0b0, opc, asm, PPR8, ZPR8>; + def _H : sve2_char_match<0b1, opc, asm, PPR16, ZPR16>; + + def : SVE_3_Op_Pat(NAME # _B)>; + def : SVE_3_Op_Pat(NAME # _H)>; +} + +//===----------------------------------------------------------------------===// +// SVE2 Histogram Computation - Segment Group +//===----------------------------------------------------------------------===// + +class sve2_hist_gen_segment +: I<(outs ZPR8:$Zd), (ins ZPR8:$Zn, ZPR8:$Zm), + asm, "\t$Zd, $Zn, $Zm", + "", + [(set nxv16i8:$Zd, (op nxv16i8:$Zn, nxv16i8:$Zm))]>, Sched<[]> { + bits<5> Zd; + bits<5> Zn; + bits<5> Zm; + let Inst{31-21} = 0b01000101001; + let Inst{20-16} = Zm; + let Inst{15-10} = 0b101000; + let Inst{9-5} = Zn; + let Inst{4-0} = Zd; +} + +//===----------------------------------------------------------------------===// +// SVE2 Histogram Computation - Vector Group 
+//===----------------------------------------------------------------------===// + +class sve2_hist_gen_vector +: I<(outs zprty:$Zd), (ins PPR3bAny:$Pg, zprty:$Zn, zprty:$Zm), + asm, "\t$Zd, $Pg/z, $Zn, $Zm", + "", + []>, Sched<[]> { + bits<5> Zd; + bits<5> Zn; + bits<3> Pg; + bits<5> Zm; + let Inst{31-23} = 0b010001011; + let Inst{22} = sz; + let Inst{21} = 0b1; + let Inst{20-16} = Zm; + let Inst{15-13} = 0b110; + let Inst{12-10} = Pg; + let Inst{9-5} = Zn; + let Inst{4-0} = Zd; +} + +multiclass sve2_hist_gen_vector { + def _S : sve2_hist_gen_vector<0b0, asm, ZPR32>; + def _D : sve2_hist_gen_vector<0b1, asm, ZPR64>; + + def : SVE_3_Op_Pat(NAME # _S)>; + def : SVE_3_Op_Pat(NAME # _D)>; +} + +//===----------------------------------------------------------------------===// +// SVE2 Crypto Extensions Group +//===----------------------------------------------------------------------===// + +class sve2_crypto_cons_bin_op +: I<(outs zprty:$Zd), (ins zprty:$Zn, zprty:$Zm), + asm, "\t$Zd, $Zn, $Zm", + "", + []>, Sched<[]> { + bits<5> Zd; + bits<5> Zn; + bits<5> Zm; + let Inst{31-21} = 0b01000101001; + let Inst{20-16} = Zm; + let Inst{15-11} = 0b11110; + let Inst{10} = opc; + let Inst{9-5} = Zn; + let Inst{4-0} = Zd; +} + +multiclass sve2_crypto_cons_bin_op { + def NAME : sve2_crypto_cons_bin_op; + def : SVE_2_Op_Pat(NAME)>; +} + +class sve2_crypto_des_bin_op opc, string asm, ZPRRegOp zprty> +: I<(outs zprty:$Zdn), (ins zprty:$_Zdn, zprty:$Zm), + asm, "\t$Zdn, $_Zdn, $Zm", + "", + []>, Sched<[]> { + bits<5> Zdn; + bits<5> Zm; + let Inst{31-17} = 0b010001010010001; + let Inst{16} = opc{1}; + let Inst{15-11} = 0b11100; + let Inst{10} = opc{0}; + let Inst{9-5} = Zm; + let Inst{4-0} = Zdn; + + let Constraints = "$Zdn = $_Zdn"; +} + +multiclass sve2_crypto_des_bin_op opc, string asm, ZPRRegOp zprty, + SDPatternOperator op, ValueType vt> { + def NAME : sve2_crypto_des_bin_op; + def : SVE_2_Op_Pat(NAME)>; +} + +class sve2_crypto_unary_op +: I<(outs ZPR8:$Zdn), (ins ZPR8:$_Zdn), + asm, "\t$Zdn, $_Zdn", + "", + []>, Sched<[]> { + bits<5> Zdn; + let Inst{31-11} = 0b010001010010000011100; + let Inst{10} = opc; + let Inst{9-5} = 0b00000; + let Inst{4-0} = Zdn; + + let Constraints = "$Zdn = $_Zdn"; +} + +multiclass sve2_crypto_unary_op { + def NAME : sve2_crypto_unary_op; + def : Pat <(nxv16i8 (op (nxv16i8 ZPR8:$Op1))), (!cast(NAME) ZPR8:$Op1)>; } Index: lib/Target/AArch64/SVEIntrinsicOpts.cpp =================================================================== --- /dev/null +++ lib/Target/AArch64/SVEIntrinsicOpts.cpp @@ -0,0 +1,290 @@ +//===----- SVEIntrinsicOpts - SVE ACLE Intrinsics Opts --------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// Performs general IR level optimizations on SVE intrinsics. +// +// The two main goals are: +// 1) Replace constant intrinsics with IR constants (e.g. ptrue all). +// 2) Remove reinterpret intrinsics that typically block optimisations +// (e.g. constant propagation). 
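// A concrete instance of goal (1), shown informally for illustration (the
// exact operand and result types are not spelled out here): a predicate such as
//   %p = call @llvm.aarch64.sve.ptrue(i32 31)   ; pattern 31 means "all"
// has every lane set, so its uses can be replaced by an all-true constant,
// after which ordinary constant propagation can fold the users.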
+// +//===----------------------------------------------------------------------===// + +#include "llvm/ADT/PostOrderIterator.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/PatternMatch.h" +#include "llvm/Support/Debug.h" +#include "Utils/AArch64BaseInfo.h" + +using namespace llvm; +using namespace llvm::PatternMatch; + +#define DEBUG_TYPE "sve-intrinsicopts" + +namespace llvm { + void initializeSVEIntrinsicOptsPass(PassRegistry &); +} + +namespace { +struct SVEIntrinsicOpts : public FunctionPass { + static char ID; // Pass identification, replacement for typeid + SVEIntrinsicOpts() : FunctionPass(ID) { + initializeSVEIntrinsicOptsPass(*PassRegistry::getPassRegistry()); + } + + bool runOnFunction(Function &F) override; + void getAnalysisUsage(AnalysisUsage &AU) const override; +private: + static bool isReinterpretFromSVBool(const Value *V); + static bool isReinterpretToSVBool(const Value *V); + + static bool optimizeBlock(BasicBlock *BB); + static bool optimizeIntrinsic(Instruction *I); + + static bool optimizeCnt(IntrinsicInst *I, unsigned Scale); + static bool optimizePTest(IntrinsicInst *I); + static bool optimizePTrue(IntrinsicInst *I); + static bool optimizeReinterprets(IntrinsicInst *I); + + static bool processPhiNode(Instruction *I); + + DominatorTree *DT; +}; +} // end anonymous namespace + +void SVEIntrinsicOpts::getAnalysisUsage(AnalysisUsage &AU) const { + AU.addRequired(); + AU.setPreservesCFG(); +} + +char SVEIntrinsicOpts::ID = 0; +static const char *name = "SVE intrinsics optimizations"; +INITIALIZE_PASS_BEGIN(SVEIntrinsicOpts, DEBUG_TYPE, name, false, false) +INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass); +INITIALIZE_PASS_END(SVEIntrinsicOpts, DEBUG_TYPE, name, false, false) + +namespace llvm { +FunctionPass *createSVEIntrinsicOptsPass() { return new SVEIntrinsicOpts(); } +} + +/// Returns true if V is a cast from (aka svbool_t). +bool SVEIntrinsicOpts::isReinterpretToSVBool(const Value *V) { + const IntrinsicInst *I = dyn_cast(V); + if (!I) + return false; + + return I->getIntrinsicID() == Intrinsic::aarch64_sve_reinterpret_bool_b; +} + +/// Returns true if V is a cast to (aka svbool_t). +bool SVEIntrinsicOpts::isReinterpretFromSVBool(const Value *V) { + const IntrinsicInst *I = dyn_cast(V); + if (!I) + return false; + + unsigned ID = I->getIntrinsicID(); + if (ID != Intrinsic::aarch64_sve_reinterpret_bool_h && + ID != Intrinsic::aarch64_sve_reinterpret_bool_w && + ID != Intrinsic::aarch64_sve_reinterpret_bool_d) + return false; + + return true; +} + +/// The function will remove redundant reinterprets casting in the presence +/// of the control flow +bool SVEIntrinsicOpts::processPhiNode(Instruction *X) { + + SmallVector Worklist; + auto RequiredType = X->getType(); + + auto *PN = dyn_cast(X->getOperand(0)); + if (!PN) + return false; + + // Don't create a new PHI unless we can remove the old one. 
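// In other words (illustrative IR, names invented): the only user of the phi
// should be the narrowing reinterpret currently being rewritten, e.g.
//   %phi = phi [ %a.wide, %bb0 ], [ %b.wide, %bb1 ]          ; svbool_t values
//   %res = call @llvm.aarch64.sve.reinterpret.bool.h(%phi)   ; sole use of %phi
// Otherwise both the old svbool_t phi and the newly created narrow phi would
// remain live.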
+ if (!PN->hasOneUse()) + return false; + + for (Value *IncValPhi : PN->incoming_values()) { + auto *IncValPhiInst = dyn_cast(IncValPhi); + if (!IncValPhiInst) + return false; + + if (!isReinterpretToSVBool(IncValPhiInst)) + return false; + + Value *SourceVal = IncValPhiInst->getOperand(0); + if (RequiredType != SourceVal->getType()) + return false; + } + + LLVMContext &C1 = PN->getContext(); + IRBuilder<> builder(C1); + builder.SetInsertPoint(PN); + PHINode *NPN = builder.CreatePHI(RequiredType, PN->getNumIncomingValues()); + Worklist.push_back(PN); + + for (unsigned i = 0; i < PN->getNumIncomingValues(); i++) { + auto *Reinterpret = cast(PN->getIncomingValue(i)); + NPN->addIncoming(Reinterpret->getOperand(0), PN->getIncomingBlock(i)); + Worklist.push_back(Reinterpret); + } + + // Cleanup Phi Node and reinterprets + X->replaceAllUsesWith(NPN); + X->eraseFromParent(); + + for (auto &I : Worklist) + if (I->use_empty()) + I->eraseFromParent(); + + return true; +} + +// Replace an intrinsic call to cnt[bhwd] with a scaled vscale constant. +bool SVEIntrinsicOpts::optimizeCnt(IntrinsicInst *I, unsigned Scale) { + if (cast(I->getArgOperand(0))->getZExtValue() != 31) + return false; + + Type *CntTy = I->getType(); + Constant *CScale = ConstantInt::get(CntTy, Scale); + I->replaceAllUsesWith(ConstantExpr::getMul(VScale::get(CntTy), CScale)); + I->eraseFromParent(); + return true; +} + +bool SVEIntrinsicOpts::optimizePTest(IntrinsicInst *I) { + IntrinsicInst* Op1 = dyn_cast(I->getArgOperand(0)); + IntrinsicInst* Op2 = dyn_cast(I->getArgOperand(1)); + + if (Op1 && Op2 && + Op1->getIntrinsicID() == Intrinsic::aarch64_sve_reinterpret_bool_b && + Op2->getIntrinsicID() == Intrinsic::aarch64_sve_reinterpret_bool_b && + Op1->getArgOperand(0)->getType() == Op2->getArgOperand(0)->getType()) { + Value *Ops[] = { Op1->getArgOperand(0), Op2->getArgOperand(0) }; + Type *Tys[] = { Op1->getArgOperand(0)->getType() }; + Module *M = I->getParent()->getParent()->getParent(); + + auto Fn = Intrinsic::getDeclaration(M, I->getIntrinsicID(), Tys); + auto CI = CallInst::Create(Fn, Ops, I->getName(), I); + + I->replaceAllUsesWith(CI); + I->eraseFromParent(); + if (Op1->use_empty()) + Op1->eraseFromParent(); + if (Op2->use_empty()) + Op2->eraseFromParent(); + + return true; + } + + return false; +} + +// Replace "ptrue_ all" with ConstantInt(true). +bool SVEIntrinsicOpts::optimizePTrue(IntrinsicInst *I) { + if (cast(I->getArgOperand(0))->getZExtValue() != 31) + return false; + + I->replaceAllUsesWith(ConstantInt::getTrue(I->getType())); + I->eraseFromParent(); + return true; +} + +bool SVEIntrinsicOpts::optimizeReinterprets(IntrinsicInst *I) { + assert(isReinterpretFromSVBool(I)); + + // If the reinterpret instruction operand is a PHI Node + if (isa(I->getArgOperand(0))) + return processPhiNode(I); + + // If we have a reinterpret intrinsic I of type A which is converting from + // another reinterpret Y of type B, and the source type of Y is A, then we can + // elide away both reinterprets if there are no other users of Y. 
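// Illustrative shape of that chain (types written informally):
//   %y = call @llvm.aarch64.sve.reinterpret.bool.b(%src)   ; A -> svbool_t
//   %i = call @llvm.aarch64.sve.reinterpret.bool.h(%y)     ; svbool_t -> A
// Every use of %i can then be rewritten to use %src directly, and %y is
// deleted once it becomes dead.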
+ IntrinsicInst *Y = dyn_cast(I->getArgOperand(0)); + if (!Y) + return false; + if (!isReinterpretToSVBool(Y)) + return false; + + Value *SourceVal = Y->getArgOperand(0); + if (I->getType() != SourceVal->getType()) + return false; + + I->replaceAllUsesWith(SourceVal); + I->eraseFromParent(); + if (Y->use_empty()) + Y->eraseFromParent(); + + return true; +} + +bool SVEIntrinsicOpts::optimizeIntrinsic(Instruction *I) { + IntrinsicInst *IntrI = dyn_cast(I); + if (!IntrI) + return false; + + switch (IntrI->getIntrinsicID()) { + case Intrinsic::aarch64_sve_cntb: + return optimizeCnt(IntrI, 16); + case Intrinsic::aarch64_sve_cnth: + return optimizeCnt(IntrI, 8); + case Intrinsic::aarch64_sve_cntw: + return optimizeCnt(IntrI, 4); + case Intrinsic::aarch64_sve_cntd: + return optimizeCnt(IntrI, 2); + case Intrinsic::aarch64_sve_reinterpret_bool_b: + // The reinterprets are clang specific, and clang never connects them like this. + assert(!isReinterpretFromSVBool(IntrI->getArgOperand(0))); + return false; + case Intrinsic::aarch64_sve_reinterpret_bool_h: + case Intrinsic::aarch64_sve_reinterpret_bool_w: + case Intrinsic::aarch64_sve_reinterpret_bool_d: + return optimizeReinterprets(IntrI); + case Intrinsic::aarch64_sve_ptrue: + return optimizePTrue(IntrI); + case Intrinsic::aarch64_sve_ptest_any: + case Intrinsic::aarch64_sve_ptest_first: + case Intrinsic::aarch64_sve_ptest_last: + return optimizePTest(IntrI); + default: + return false; + } +} + +bool SVEIntrinsicOpts::optimizeBlock(BasicBlock *BB) { + bool Changed = false; + for (auto II = BB->begin(), IE = BB->end(); II != IE;) { + Instruction *I = &(*II); + II = std::next(II); + Changed |= optimizeIntrinsic(I); + } + return Changed; +} + +bool SVEIntrinsicOpts::runOnFunction(Function &F) { + DT = &getAnalysis().getDomTree(); + bool Changed = false; + + // Traverse the DT with an rpo walk so we see defs before uses, allowing + // simplification to be done incrementally. + BasicBlock *Root = DT->getRoot(); + ReversePostOrderTraversal RPOT(Root); + for (auto I = RPOT.begin(), E = RPOT.end(); I != E; ++I) { + Changed |= optimizeBlock(*I); + } + + return Changed; +} Index: lib/Target/AArch64/SVEPostVectorize.cpp =================================================================== --- /dev/null +++ lib/Target/AArch64/SVEPostVectorize.cpp @@ -0,0 +1,356 @@ +//===- SVEPostVectorize - A SVE Loops Optimizer -----------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This pass looks for idioms that can be expressed more efficiently by SVE +// intrinsics. The main focus is the IR produced by the loop vectoriser when +// it is using scalable vectors. The goal of the loop vectorizer is to represent +// scalable vectors in a generic way. However, cases exist where SVE +// supports instructions that are specifically intended to handle common +// vectorization features and it's our job to recognise them within the generic +// IR and construct a suitable replacement. +// +// An important case is the calculation of predicates used by a vectorized loop. +// In most cases the resulting IR uses types that are difficult for the code +// generator. For example, as the original scalar induction variable is often +// an i64, vector types of are not uncommon. 
+// +// However, SVE has a WHILE instruction that allows induction based predicates +// to be calculated from the original scalar variables, thus removing the need +// to handle such unfriendly vector types. +// +//===----------------------------------------------------------------------===// + +#include "llvm/ADT/PostOrderIterator.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/LoopPass.h" +#include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/IR/CFG.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Metadata.h" +#include "llvm/IR/PatternMatch.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils.h" +#include "llvm/Transforms/Utils/Local.h" +#include "llvm/Transforms/Utils/LoopUtils.h" +#include +using namespace llvm; +using namespace llvm::PatternMatch; + +#define DEBUG_TYPE "sve-postvec" + +STATISTIC(NumVisited, "Number of loops visited."); +STATISTIC(NumOptimized, "Number of loops optimized."); +STATISTIC(NumWhileConversions, "Number of while intrinsics introduced."); + +namespace llvm { + void initializeSVEPostVectorizePass(PassRegistry &); +} + +namespace { +struct SVEPostVectorize : public FunctionPass { + static char ID; // Pass identification, replacement for typeid + SVEPostVectorize() : FunctionPass(ID) { + initializeSVEPostVectorizePass(*PassRegistry::getPassRegistry()); + } + + bool runOnFunction(Function &F) override { + this->F = &F; + LI = &getAnalysis().getLoopInfo(); + TTI = &getAnalysis().getTTI(F); + + bool Changed = false; + for (auto I = LI->begin(), IE = LI->end(); I != IE; ++I) + // Traverse loop nest in post-order so sub-loops are processed first. + for (auto L = po_begin(*I), LE = po_end(*I); L != LE; ++L) + Changed |= runOnLoop(*L); + + VisitedBlocks.clear(); + return Changed; + } + + bool runOnLoop(Loop *L); + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequiredID(LoopSimplifyID); + AU.addRequired(); + AU.addRequired(); + } + +private: + Function *F; + LoopInfo *LI; + TargetTransformInfo *TTI; + SmallPtrSet VisitedBlocks; + + bool Optimize_WhileConversion(BasicBlock *BB); + bool Transform_StructAddressing(BasicBlock *BB); + + Instruction *CreateWhile(Intrinsic::ID IntID, Type *Ty, Value *Op1, + Value *Op2); + Instruction* ConvertToWhile(Value* V); +}; +} + +char SVEPostVectorize::ID = 0; +static const char *name = "SVE Post Vectorisation"; +INITIALIZE_PASS_BEGIN(SVEPostVectorize, DEBUG_TYPE, name, false, false) +INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(LoopSimplify) +INITIALIZE_PASS_END(SVEPostVectorize, DEBUG_TYPE, name, false, false) + +namespace llvm { +FunctionPass *createSVEPostVectorizePass() {return new SVEPostVectorize();} +} + +bool SVEPostVectorize::runOnLoop(Loop *L) { + bool LoopOptimized = false; + // Look for idioms introduced by the loop vectorizer that can be expressed + // more efficiently using SVE specific intrinsics. 
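// The key idiom (illustrative, value names invented): the vectorizer computes
// a loop predicate by comparing a vector of lane indices against a splat of
// the scalar limit, e.g.
//   %lanes = add (splat %base), stepvector
//   %pred  = icmp ult %lanes, (splat %limit)
// ConvertToWhile() below rewrites this to a single
// llvm.aarch64.sve.whilelo(%base, %limit) call on the original scalar values.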
+
+  SmallVector Worklist;
+  for (auto BB = L->block_begin(), BE = L->block_end(); BB != BE; ++BB)
+    if (!VisitedBlocks.count(*BB))
+      Worklist.push_back(*BB);
+  if (auto PreHeader = L->getLoopPreheader())
+    if (!VisitedBlocks.count(PreHeader))
+      Worklist.push_back(PreHeader);
+
+  while (!Worklist.empty()) {
+    // The first block should be the preheader, if it exists.
+    auto BB = Worklist.pop_back_val();
+    LoopOptimized |= Optimize_WhileConversion(BB);
+    LoopOptimized |= Transform_StructAddressing(BB);
+    VisitedBlocks.insert(BB);
+  }
+
+  ++NumVisited;
+  if (LoopOptimized)
+    ++NumOptimized;
+
+  return LoopOptimized;
+}
+
+/// Create a call to the specified WHILE intrinsic.
+///
+Instruction* SVEPostVectorize::CreateWhile(Intrinsic::ID IntID, Type* Ty,
+                                           Value* Op1, Value* Op2) {
+  SmallVector Types = { Ty, Op1->getType() };
+  SmallVector Args { Op1, Op2 };
+
+  Function *Intrinsic = Intrinsic::getDeclaration(F->getParent(), IntID, Types);
+  return CallInst::Create(Intrinsic->getFunctionType(), Intrinsic, Args);
+}
+
+/// If V matches the semantics of an llvm.aarch64.sve.while## intrinsic, a
+/// suitable call is returned; otherwise nullptr.
+///
+Instruction* SVEPostVectorize::ConvertToWhile(Value* V) {
+  Value *Series, *Splat;
+  ICmpInst::Predicate Pred;
+
+  if (!match(V, m_ICmp(Pred, m_Value(Series), m_Value(Splat))))
+    return nullptr;
+
+  // Canonicalise to less-than based comparisons.
+  if ((Pred == ICmpInst::ICMP_SGE) || (Pred == ICmpInst::ICMP_SGT) ||
+      (Pred == ICmpInst::ICMP_UGE) || (Pred == ICmpInst::ICMP_UGT)) {
+    Pred = ICmpInst::getSwappedPredicate(Pred);
+    std::swap(Series, Splat);
+  }
+
+  Intrinsic::ID IntID;
+  switch (Pred) {
+  default: return nullptr;
+  case ICmpInst::ICMP_SLE: IntID = Intrinsic::aarch64_sve_whilele; break;
+  case ICmpInst::ICMP_SLT: IntID = Intrinsic::aarch64_sve_whilelt; break;
+  case ICmpInst::ICMP_ULE: IntID = Intrinsic::aarch64_sve_whilels; break;
+  case ICmpInst::ICMP_ULT: IntID = Intrinsic::aarch64_sve_whilelo; break;
+  }
+
+  Value *Start, *End;
+  if (!match(Splat, m_SplatVector(m_Value(End))))
+    return nullptr;
+
+  // 0, 1, 2, ...
+  if (isa(Series))
+    Start = Constant::getNullValue(End->getType());
+  // n, n+1, n+2 ...
+  else if (ICmpInst::isSigned(Pred) &&
+           !match(Series, m_NSWAdd(m_SplatVector(m_Value(Start)),
+                                   m_StepVector())))
+    return nullptr;
+  else if (ICmpInst::isUnsigned(Pred) &&
+           !match(Series, m_NUWAdd(m_SplatVector(m_Value(Start)),
+                                   m_StepVector())))
+    return nullptr;
+
+  return CreateWhile(IntID, V->getType(), Start, End);
+}
+
+/// Attempt to use llvm.aarch64.sve.while## when calculating predicate vectors.
+///
+bool SVEPostVectorize::Optimize_WhileConversion(BasicBlock *BB) {
+  bool LoopChanged = false;
+
+  for (BasicBlock::iterator it = BB->begin(), e = BB->end(); it != e;) {
+    Instruction *I = &*it++;
+
+    if (auto W = ConvertToWhile(I)) {
+      LLVM_DEBUG(dbgs() << "SVE: Replaced " << *I << " by " << *W << "\n");
+      W->insertBefore(I);
+      W->takeName(I);
+      I->replaceAllUsesWith(W);
+      I->eraseFromParent();
+
+      ++NumWhileConversions;
+      LoopChanged = true;
+      continue;
+    }
+
+    // Special case to catch constant predicate creation.
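    // [Editorial note -- not part of the original patch.] The special case
    // below covers predicate PHIs where the value arriving along one edge
    // (typically from the preheader) has already been folded to a constant.
    // ConvertToWhile is tried on that constant incoming value and, if it
    // matches, the resulting while call is materialised at the end of the
    // incoming block before the PHI operand is rewired.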
+ if (auto PHI = dyn_cast(I)) { + for (unsigned i = 0, e = PHI->getNumIncomingValues(); i != e; ++i) { + auto *Op = PHI->getIncomingValue(i); + auto *BB = PHI->getIncomingBlock(i); + + if (!isa(Op)) + continue; + + if (auto W = ConvertToWhile(Op)) { + LLVM_DEBUG(dbgs() << "SVE: Replaced " << *Op << " by " << *W << "\n"); + W->insertBefore(BB->getTerminator()); + PHI->setIncomingValue(i, W); + + ++NumWhileConversions; + LoopChanged = true; + continue; + } + } + } + } + + return LoopChanged; +} + +/// Transform aggregate type addressing used by gathers/scatters to instead use +/// the first member address. This canonicalization is needed to optimize more +/// cases in the gather/scatter interleave lowering pass. +/// +bool SVEPostVectorize::Transform_StructAddressing(BasicBlock *BB) { + bool LoopChanged = false; + // We look for patterns where a bitcast of a vector of pointers to aggregate + // types is being fed by a GEP, with the GEP using a vector of offsets of + // generated by a seriesvector. + + SmallVector ProcessedInsts; + + for (BasicBlock::iterator it = BB->begin(), e = BB->end(); it != e; ++it) { + Instruction *I = &*it; + Value *BC, *GEPVal; + if (!(match(I, m_Intrinsic(m_Value(BC))) || + match(I, m_Intrinsic(m_Value(), + m_Value(BC))))) + continue; + + if (!(match(BC, m_BitCast(m_Value(GEPVal))) && + isa(GEPVal))) + continue; + + auto *GEP = cast(GEPVal); + if (!isa(GEP->getPointerOperand())) + continue; + Instruction *Offsets = dyn_cast(GEP->getOperand(1)); + Value *SVBase; + ConstantInt *SVStep; + if (!Offsets || + !match(Offsets, m_SeriesVector(m_Value(SVBase), m_ConstantInt(SVStep)))) + continue; + + // If the seriesvector step isn't 1, then this needs to be a real + // gather/scatter. + if (SVStep->getZExtValue() != 1) + continue; + + auto BCInst = cast(BC); + Type *SrcEltTy = + BCInst->getSrcTy()->getVectorElementType()->getPointerElementType(); + if (!SrcEltTy->isAggregateType()) + continue; + + // The struct elements must not be themselves structs. + auto IsAggTy = [](Type *Ty) { return Ty->isAggregateType(); }; + if (auto STy = dyn_cast(SrcEltTy)) { + if (std::any_of(STy->element_begin(), STy->element_end(), IsAggTy)) + continue; + } + + // The size of the destination type must divide that of the aggregate type. + auto DL = &F->getParent()->getDataLayout(); + Type *DstEltTy = + BCInst->getDestTy()->getVectorElementType()->getPointerElementType(); + unsigned DestSize = DL->getTypeStoreSize(DstEltTy); + unsigned AggSize = DL->getTypeStoreSize(SrcEltTy); + if (AggSize % DestSize) { + LLVM_DEBUG(dbgs() << "SVE: Can't optimize gather/scatter. Aggregate size not" + " multiple of load/store elt size"); + continue; + } + + // Create a new seriesvec to scale the base offset by the effective stride. + unsigned Stride = AggSize / DestSize; + IRBuilder<> B(cast(Offsets)); + // We need to create the same kind of scaling as non-aggregate + // gathers/scatters in the loop, in order for LSR to create a single new + // induction variable for all of them. + auto NewBase = B.CreateShl( + SVBase, ConstantInt::get(SVBase->getType(), Log2_64(Stride))); + B.SetInsertPoint(I); + auto NewOffsets = B.CreateSeriesVector( + cast(Offsets->getType())->getElementCount(), NewBase, + ConstantInt::get(SVStep->getType(), Stride)); + + // The base address for the gather/scatter should be another GEP, which + // we need to index further to get the element address. 
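    // [Editorial illustration -- not part of the original patch; the shapes
    //  below are schematic and the types, layout and names are assumptions.]
    //  Roughly, an aggregate-addressed access of the form
    //
    //      %ptrs = gep %struct.S* %base.gep, offsets <b, b+1, b+2, ...>
    //      %elts = bitcast %ptrs to a vector of element pointers
    //
    //  is rewritten, with Stride = sizeof(struct) / sizeof(element), into
    //
    //      %first = gep %base, ..., 0         ; address of the first member
    //      %ptrs  = gep %first, offsets <b*Stride, b*Stride + Stride, ...>
    //      %elts  = bitcast %ptrs to the same vector of element pointers
    //
    //  so that every gather/scatter in the loop can share one scaled induction
    //  variable, as noted above.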
+ auto BaseGEP = cast(GEP->getPointerOperand()); + SmallVector Indices; + for (auto II = BaseGEP->idx_begin(), IE = BaseGEP->idx_end(); II != IE; + ++II) + Indices.push_back(*II); + Indices.push_back(B.getInt32(0)); + auto NewBaseGEP = GetElementPtrInst::Create( + nullptr, BaseGEP->getPointerOperand(), Indices, "basegep", BaseGEP); + NewBaseGEP->setDebugLoc(BaseGEP->getDebugLoc()); + + auto NewGEP = GetElementPtrInst::Create(nullptr, NewBaseGEP, {NewOffsets}); + NewGEP->setDebugLoc(GEP->getDebugLoc()); + B.Insert(NewGEP); + auto NewBC = B.CreateBitCast(NewGEP, BC->getType()); + BC->replaceAllUsesWith(NewBC); + + ProcessedInsts.push_back(BCInst); + ProcessedInsts.push_back(GEP); + ProcessedInsts.push_back(BaseGEP); + + LoopChanged = true; + } + + for (auto &I : ProcessedInsts) { + if (I->getNumUses() == 0) + I->eraseFromParent(); + } + return LoopChanged; +} Index: lib/Target/AArch64/Utils/AArch64BaseInfo.h =================================================================== --- lib/Target/AArch64/Utils/AArch64BaseInfo.h +++ lib/Target/AArch64/Utils/AArch64BaseInfo.h @@ -208,7 +208,13 @@ AL = 0xe, // Always (unconditional) Always (unconditional) NV = 0xf, // Always (unconditional) Always (unconditional) // Note the NV exists purely to disassemble 0b1111. Execution is "always". - Invalid + Invalid, + + // Common aliases used for SVE. + ANY_ACTIVE = NE, // (!Z) + FIRST_ACTIVE = MI, // ( N) + LAST_ACTIVE = LO, // (!C) + NONE_ACTIVE = EQ // ( Z) }; inline static const char *getCondCodeName(CondCode Code) { @@ -563,6 +569,18 @@ }; } // end namespace AArch64II +namespace AArch64 { +// The number of bits in a SVE register is architecturally defined +// to be a multiple of this value. If has this number of bits, +// a vector can be stored in a SVE register without any +// redundant bits. If has this number of bits divided by P, +// a vector is stored in a SVE register by placing index i +// in index i*P of a vector. The other elements of the +// vector (such as index 1) are undefined. +const unsigned SVEBitsPerBlock = 128; +const unsigned SVEMaxBitsPerVector = 2048; +} // end namespace AArch64 + } // end namespace llvm #endif Index: lib/Target/AMDGPU/AMDGPUISelLowering.cpp =================================================================== --- lib/Target/AMDGPU/AMDGPUISelLowering.cpp +++ lib/Target/AMDGPU/AMDGPUISelLowering.cpp @@ -223,6 +223,7 @@ setLoadExtAction(ISD::EXTLOAD, MVT::v2f32, MVT::v2f16, Expand); setLoadExtAction(ISD::EXTLOAD, MVT::v4f32, MVT::v4f16, Expand); setLoadExtAction(ISD::EXTLOAD, MVT::v8f32, MVT::v8f16, Expand); + setLoadExtAction(ISD::EXTLOAD, MVT::v16f32, MVT::v16f16, Expand); setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::f32, Expand); setLoadExtAction(ISD::EXTLOAD, MVT::v2f64, MVT::v2f32, Expand); @@ -275,6 +276,7 @@ setTruncStoreAction(MVT::v2f32, MVT::v2f16, Expand); setTruncStoreAction(MVT::v4f32, MVT::v4f16, Expand); setTruncStoreAction(MVT::v8f32, MVT::v8f16, Expand); + setTruncStoreAction(MVT::v16f32, MVT::v16f16, Expand); setTruncStoreAction(MVT::f64, MVT::f16, Expand); setTruncStoreAction(MVT::f64, MVT::f32, Expand); @@ -933,12 +935,12 @@ // the correct memory offsets. 
SmallVector ValueVTs; - SmallVector Offsets; - ComputeValueVTs(*this, DL, BaseArgTy, ValueVTs, &Offsets, ArgOffset); + SmallVector Offsets; + ComputeValueVTs(*this, DL, BaseArgTy, ValueVTs, &Offsets, {ArgOffset,0}); for (unsigned Value = 0, NumValues = ValueVTs.size(); Value != NumValues; ++Value) { - uint64_t BasePartOffset = Offsets[Value]; + uint64_t BasePartOffset = Offsets[Value].UnscaledBytes; EVT ArgVT = ValueVTs[Value]; EVT MemVT = ArgVT; Index: lib/Target/AMDGPU/SIFrameLowering.cpp =================================================================== --- lib/Target/AMDGPU/SIFrameLowering.cpp +++ lib/Target/AMDGPU/SIFrameLowering.cpp @@ -703,10 +703,10 @@ } } } - - FuncInfo->removeSGPRToVGPRFrameIndices(MFI); } + FuncInfo->removeSGPRToVGPRFrameIndices(MFI); + // FIXME: The other checks should be redundant with allStackObjectsAreDead, // but currently hasNonSpillStackObjects is set only from source // allocas. Stack temps produced from legalization are not counted currently. Index: lib/Target/AMDGPU/SIMachineFunctionInfo.cpp =================================================================== --- lib/Target/AMDGPU/SIMachineFunctionInfo.cpp +++ lib/Target/AMDGPU/SIMachineFunctionInfo.cpp @@ -308,6 +308,11 @@ void SIMachineFunctionInfo::removeSGPRToVGPRFrameIndices(MachineFrameInfo &MFI) { for (auto &R : SGPRToVGPRSpills) MFI.RemoveStackObject(R.first); + // All other SPGRs must be allocated on the default stack, so reset + // the stack ID. + for (unsigned i = MFI.getObjectIndexBegin(), e = MFI.getObjectIndexEnd(); + i != e; ++i) + MFI.setStackID(i, 0); } Index: lib/Target/ARM/ARMBaseInstrInfo.cpp =================================================================== --- lib/Target/ARM/ARMBaseInstrInfo.cpp +++ lib/Target/ARM/ARMBaseInstrInfo.cpp @@ -1172,8 +1172,14 @@ unsigned ARMBaseInstrInfo::isStoreToStackSlotPostFE(const MachineInstr &MI, int &FrameIndex) const { - const MachineMemOperand *Dummy; - return MI.mayStore() && hasStoreToStackSlot(MI, Dummy, FrameIndex); + SmallVector Accesses; + if (MI.mayStore() && hasStoreToStackSlot(MI, Accesses)) { + FrameIndex = + cast(Accesses.front()->getPseudoValue()) + ->getFrameIndex(); + return true; + } + return false; } void ARMBaseInstrInfo:: @@ -1386,8 +1392,14 @@ unsigned ARMBaseInstrInfo::isLoadFromStackSlotPostFE(const MachineInstr &MI, int &FrameIndex) const { - const MachineMemOperand *Dummy; - return MI.mayLoad() && hasLoadFromStackSlot(MI, Dummy, FrameIndex); + SmallVector Accesses; + if (MI.mayLoad() && hasLoadFromStackSlot(MI, Accesses)) { + FrameIndex = + cast(Accesses.front()->getPseudoValue()) + ->getFrameIndex(); + return true; + } + return false; } /// Expands MEMCPY to either LDMIA/STMIA or LDMIA_UPD/STMID_UPD Index: lib/Target/ARM/ARMCallLowering.cpp =================================================================== --- lib/Target/ARM/ARMCallLowering.cpp +++ lib/Target/ARM/ARMCallLowering.cpp @@ -193,8 +193,8 @@ const Function &F = MF.getFunction(); SmallVector SplitVTs; - SmallVector Offsets; - ComputeValueVTs(TLI, DL, OrigArg.Ty, SplitVTs, &Offsets, 0); + SmallVector Offsets; + ComputeValueVTs(TLI, DL, OrigArg.Ty, SplitVTs, &Offsets, {0,0}); if (SplitVTs.size() == 1) { // Even if there is no splitting to do, we still want to replace the @@ -231,7 +231,8 @@ } for (unsigned i = 0; i < Offsets.size(); ++i) - PerformArgSplit(SplitArgs[FirstRegIdx + i].Reg, Offsets[i] * 8); + PerformArgSplit(SplitArgs[FirstRegIdx + i].Reg, + Offsets[i].UnscaledBytes * 8); } /// Lower the return value for the already existing \p Ret. 
This assumes that Index: lib/Target/ARM/ARMParallelDSP.cpp =================================================================== --- lib/Target/ARM/ARMParallelDSP.cpp +++ lib/Target/ARM/ARMParallelDSP.cpp @@ -21,6 +21,7 @@ #include "llvm/Analysis/LoopAccessAnalysis.h" #include "llvm/Analysis/LoopPass.h" #include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/NoFolder.h" #include "llvm/Transforms/Scalar.h" @@ -111,6 +112,7 @@ ScalarEvolution *SE; AliasAnalysis *AA; TargetLibraryInfo *TLI; + TargetTransformInfo *TTI; DominatorTree *DT; LoopInfo *LI; Loop *L; @@ -141,6 +143,7 @@ AU.addRequired(); AU.addRequired(); AU.addRequired(); + AU.addRequired(); AU.addRequired(); AU.addRequired(); AU.addRequired(); @@ -153,6 +156,8 @@ SE = &getAnalysis().getSE(); AA = &getAnalysis().getAAResults(); TLI = &getAnalysis().getTLI(); + TTI = &getAnalysis().getTTI( + *(L->getHeader()->getParent())); DT = &getAnalysis().getDomTree(); LI = &getAnalysis().getLoopInfo(); auto &TPC = getAnalysis(); @@ -189,7 +194,7 @@ return false; } - LoopAccessInfo LAI(L, SE, TLI, AA, DT, LI); + LoopAccessInfo LAI(L, SE, TLI, TTI, AA, DT, LI); bool Changes = false; LLVM_DEBUG(dbgs() << "\n== Parallel DSP pass ==\n\n"); @@ -406,7 +411,7 @@ } static void MatchReductions(Function &F, Loop *TheLoop, BasicBlock *Header, - ReductionList &Reductions) { + ReductionList &Reductions, ScalarEvolution* SE) { RecurrenceDescriptor RecDesc; const bool HasFnNoNaNAttr = F.getFnAttribute("no-nans-fp-math").getValueAsString() == "true"; @@ -426,7 +431,8 @@ const bool IsReduction = RecurrenceDescriptor::AddReductionVar(&Phi, RecurrenceDescriptor::RK_IntegerAdd, - TheLoop, HasFnNoNaNAttr, RecDesc); + TheLoop, HasFnNoNaNAttr, SE, + RecDesc, false); if (!IsReduction) continue; @@ -596,7 +602,7 @@ bool Changed = false; ReductionList Reductions; - MatchReductions(F, L, Header, Reductions); + MatchReductions(F, L, Header, Reductions, SE); for (auto &R : Reductions) { OpChainList MACCandidates; Index: lib/Target/Hexagon/HexagonInstrInfo.h =================================================================== --- lib/Target/Hexagon/HexagonInstrInfo.h +++ lib/Target/Hexagon/HexagonInstrInfo.h @@ -69,16 +69,16 @@ /// Check if the instruction or the bundle of instructions has /// load from stack slots. Return the frameindex and machine memory operand /// if true. - bool hasLoadFromStackSlot(const MachineInstr &MI, - const MachineMemOperand *&MMO, - int &FrameIndex) const override; + bool hasLoadFromStackSlot( + const MachineInstr &MI, + SmallVectorImpl &Accesses) const override; /// Check if the instruction or the bundle of instructions has /// store to stack slots. Return the frameindex and machine memory operand /// if true. - bool hasStoreToStackSlot(const MachineInstr &MI, - const MachineMemOperand *&MMO, - int &FrameIndex) const override; + bool hasStoreToStackSlot( + const MachineInstr &MI, + SmallVectorImpl &Accesses) const override; /// Analyze the branching code at the end of MBB, returning /// true if it cannot be understood (e.g. it's a switch dispatch or isn't Index: lib/Target/Hexagon/HexagonInstrInfo.cpp =================================================================== --- lib/Target/Hexagon/HexagonInstrInfo.cpp +++ lib/Target/Hexagon/HexagonInstrInfo.cpp @@ -335,37 +335,37 @@ /// This function checks if the instruction or bundle of instructions /// has load from stack slot and returns frameindex and machine memory /// operand of that instruction if true. 
-bool HexagonInstrInfo::hasLoadFromStackSlot(const MachineInstr &MI, - const MachineMemOperand *&MMO, - int &FrameIndex) const { +bool HexagonInstrInfo::hasLoadFromStackSlot( + const MachineInstr &MI, + SmallVectorImpl &Accesses) const { if (MI.isBundle()) { const MachineBasicBlock *MBB = MI.getParent(); MachineBasicBlock::const_instr_iterator MII = MI.getIterator(); for (++MII; MII != MBB->instr_end() && MII->isInsideBundle(); ++MII) - if (TargetInstrInfo::hasLoadFromStackSlot(*MII, MMO, FrameIndex)) + if (TargetInstrInfo::hasLoadFromStackSlot(*MII, Accesses)) return true; return false; } - return TargetInstrInfo::hasLoadFromStackSlot(MI, MMO, FrameIndex); + return TargetInstrInfo::hasLoadFromStackSlot(MI, Accesses); } /// This function checks if the instruction or bundle of instructions /// has store to stack slot and returns frameindex and machine memory /// operand of that instruction if true. -bool HexagonInstrInfo::hasStoreToStackSlot(const MachineInstr &MI, - const MachineMemOperand *&MMO, - int &FrameIndex) const { +bool HexagonInstrInfo::hasStoreToStackSlot( + const MachineInstr &MI, + SmallVectorImpl &Accesses) const { if (MI.isBundle()) { const MachineBasicBlock *MBB = MI.getParent(); MachineBasicBlock::const_instr_iterator MII = MI.getIterator(); for (++MII; MII != MBB->instr_end() && MII->isInsideBundle(); ++MII) - if (TargetInstrInfo::hasStoreToStackSlot(*MII, MMO, FrameIndex)) + if (TargetInstrInfo::hasStoreToStackSlot(*MII, Accesses)) return true; return false; } - return TargetInstrInfo::hasStoreToStackSlot(MI, MMO, FrameIndex); + return TargetInstrInfo::hasStoreToStackSlot(MI, Accesses); } /// This function can analyze one/two way branching only and should (mostly) be Index: lib/Target/Hexagon/HexagonTargetTransformInfo.h =================================================================== --- lib/Target/Hexagon/HexagonTargetTransformInfo.h +++ lib/Target/Hexagon/HexagonTargetTransformInfo.h @@ -101,10 +101,11 @@ unsigned getScalarizationOverhead(Type *Ty, bool Insert, bool Extract); unsigned getOperandsScalarizationOverhead(ArrayRef Args, - unsigned VF); + VectorType::ElementCount VF); unsigned getCallInstrCost(Function *F, Type *RetTy, ArrayRef Tys); unsigned getIntrinsicInstrCost(Intrinsic::ID ID, Type *RetTy, - ArrayRef Args, FastMathFlags FMF, unsigned VF); + ArrayRef Args, FastMathFlags FMF, + VectorType::ElementCount VF = VectorType::SingleElement()); unsigned getIntrinsicInstrCost(Intrinsic::ID ID, Type *RetTy, ArrayRef Tys, FastMathFlags FMF, unsigned ScalarizationCostPassed = UINT_MAX); Index: lib/Target/Hexagon/HexagonTargetTransformInfo.cpp =================================================================== --- lib/Target/Hexagon/HexagonTargetTransformInfo.cpp +++ lib/Target/Hexagon/HexagonTargetTransformInfo.cpp @@ -120,7 +120,7 @@ } unsigned HexagonTTIImpl::getOperandsScalarizationOverhead( - ArrayRef Args, unsigned VF) { + ArrayRef Args, VectorType::ElementCount VF) { return BaseT::getOperandsScalarizationOverhead(Args, VF); } @@ -130,7 +130,7 @@ } unsigned HexagonTTIImpl::getIntrinsicInstrCost(Intrinsic::ID ID, Type *RetTy, - ArrayRef Args, FastMathFlags FMF, unsigned VF) { + ArrayRef Args, FastMathFlags FMF, VectorType::ElementCount VF) { return BaseT::getIntrinsicInstrCost(ID, RetTy, Args, FMF, VF); } Index: lib/Target/Lanai/LanaiInstrInfo.cpp =================================================================== --- lib/Target/Lanai/LanaiInstrInfo.cpp +++ lib/Target/Lanai/LanaiInstrInfo.cpp @@ -733,8 +733,13 @@ if ((Reg = isLoadFromStackSlot(MI, 
FrameIndex))) return Reg; // Check for post-frame index elimination operations - const MachineMemOperand *Dummy; - return hasLoadFromStackSlot(MI, Dummy, FrameIndex); + SmallVector Accesses; + if (hasLoadFromStackSlot(MI, Accesses)){ + FrameIndex = + cast(Accesses.front()->getPseudoValue()) + ->getFrameIndex(); + return 1; + } } return 0; } Index: lib/Target/NVPTX/NVPTXISelLowering.cpp =================================================================== --- lib/Target/NVPTX/NVPTXISelLowering.cpp +++ lib/Target/NVPTX/NVPTXISelLowering.cpp @@ -165,7 +165,7 @@ SmallVectorImpl *Offsets = nullptr, uint64_t StartingOffset = 0) { SmallVector TempVTs; - SmallVector TempOffsets; + SmallVector TempOffsets; // Special case for i128 - decompose to (i64, i64) if (Ty->isIntegerTy(128)) { @@ -180,10 +180,10 @@ return; } - ComputeValueVTs(TLI, DL, Ty, TempVTs, &TempOffsets, StartingOffset); + ComputeValueVTs(TLI, DL, Ty, TempVTs, &TempOffsets, {StartingOffset, 0}); for (unsigned i = 0, e = TempVTs.size(); i != e; ++i) { EVT VT = TempVTs[i]; - uint64_t Off = TempOffsets[i]; + uint64_t Off = TempOffsets[i].UnscaledBytes; // Split vectors into individual elements, except for v2f16, which // we will pass as a single scalar. if (VT.isVector()) { Index: lib/Target/SystemZ/AsmParser/SystemZAsmParser.cpp =================================================================== --- lib/Target/SystemZ/AsmParser/SystemZAsmParser.cpp +++ lib/Target/SystemZ/AsmParser/SystemZAsmParser.cpp @@ -349,7 +349,7 @@ bool isVR128() const { return isReg(VR128Reg); } bool isAR32() const { return isReg(AR32Reg); } bool isCR64() const { return isReg(CR64Reg); } - bool isAnyReg() const { return (isReg() || isImm(0, 15)); } + bool isAnyReg() const override { return (isReg() || isImm(0, 15)); } bool isBDAddr32Disp12() const { return isMemDisp12(BDMem, ADDR32Reg); } bool isBDAddr32Disp20() const { return isMemDisp20(BDMem, ADDR32Reg); } bool isBDAddr64Disp12() const { return isMemDisp12(BDMem, ADDR64Reg); } Index: lib/Target/SystemZ/SystemZOperators.td =================================================================== --- lib/Target/SystemZ/SystemZOperators.td +++ lib/Target/SystemZ/SystemZOperators.td @@ -498,21 +498,6 @@ return cast(N)->getMemoryVT() == MVT::i32; }]>; -// Extending loads in which the extension type can be unsigned. -def azextload : PatFrag<(ops node:$ptr), (unindexedload node:$ptr), [{ - unsigned Type = cast(N)->getExtensionType(); - return Type == ISD::EXTLOAD || Type == ISD::ZEXTLOAD; -}]>; -def azextloadi8 : PatFrag<(ops node:$ptr), (azextload node:$ptr), [{ - return cast(N)->getMemoryVT() == MVT::i8; -}]>; -def azextloadi16 : PatFrag<(ops node:$ptr), (azextload node:$ptr), [{ - return cast(N)->getMemoryVT() == MVT::i16; -}]>; -def azextloadi32 : PatFrag<(ops node:$ptr), (azextload node:$ptr), [{ - return cast(N)->getMemoryVT() == MVT::i32; -}]>; - // Extending loads in which the extension type doesn't matter. def anyextload : PatFrag<(ops node:$ptr), (unindexedload node:$ptr), [{ return cast(N)->getExtensionType() != ISD::NON_EXTLOAD; Index: lib/Target/TargetMachine.cpp =================================================================== --- lib/Target/TargetMachine.cpp +++ lib/Target/TargetMachine.cpp @@ -81,6 +81,16 @@ Options.FPDenormalMode = FPDenormal::PositiveZero; else Options.FPDenormalMode = DefaultOptions.FPDenormalMode; + + StringRef FPContract = F.getFnAttribute("fp-contract").getValueAsString(); + if (FPContract == "off") + // Preserve any contraction performed by the front-end. 
(Strict performs + // splitting of the muladd instrinsic in the backend.) + Options.AllowFPOpFusion = llvm::FPOpFusion::Standard; + else if (FPContract == "on") + Options.AllowFPOpFusion = llvm::FPOpFusion::Standard; + else if (FPContract == "fast") + Options.AllowFPOpFusion = llvm::FPOpFusion::Fast; } /// Returns the code generation relocation model. The choices are static, PIC, Index: lib/Target/X86/X86CallLowering.cpp =================================================================== --- lib/Target/X86/X86CallLowering.cpp +++ lib/Target/X86/X86CallLowering.cpp @@ -62,8 +62,8 @@ LLVMContext &Context = OrigArg.Ty->getContext(); SmallVector SplitVTs; - SmallVector Offsets; - ComputeValueVTs(TLI, DL, OrigArg.Ty, SplitVTs, &Offsets, 0); + SmallVector Offsets; + ComputeValueVTs(TLI, DL, OrigArg.Ty, SplitVTs, &Offsets, {0,0}); if (SplitVTs.size() != 1) { // TODO: support struct/array split Index: lib/Target/X86/X86ISelLowering.cpp =================================================================== --- lib/Target/X86/X86ISelLowering.cpp +++ lib/Target/X86/X86ISelLowering.cpp @@ -25056,7 +25056,8 @@ DAG.getConstant(0, dl, MVT::v2i1)); SDValue Ops[] = {Chain, Src, Mask, BasePtr, Index, Scale}; return DAG.getMaskedScatter(DAG.getVTList(MVT::Other), N->getMemoryVT(), dl, - Ops, N->getMemOperand()); + Ops, N->getMemOperand(), N->isTruncatingStore(), + N->getIndexType()); } MVT IndexVT = Index.getSimpleValueType(); @@ -25876,7 +25877,9 @@ Index, Gather->getScale() }; SDValue Res = DAG.getMaskedGather(DAG.getVTList(MVT::v4i32, MVT::Other), Gather->getMemoryVT(), dl, Ops, - Gather->getMemOperand()); + Gather->getMemOperand(), + Gather->getExtensionType(), + Gather->getIndexType()); SDValue Chain = Res.getValue(1); if (getTypeAction(*DAG.getContext(), MVT::v2i32) != TypeWidenVector) Res = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, MVT::v2i32, Res, Index: lib/Target/X86/X86InstrInfo.cpp =================================================================== --- lib/Target/X86/X86InstrInfo.cpp +++ lib/Target/X86/X86InstrInfo.cpp @@ -411,8 +411,13 @@ if ((Reg = isLoadFromStackSlot(MI, FrameIndex))) return Reg; // Check for post-frame index elimination operations - const MachineMemOperand *Dummy; - return hasLoadFromStackSlot(MI, Dummy, FrameIndex); + SmallVector Accesses; + if (hasLoadFromStackSlot(MI, Accesses)) { + FrameIndex = + cast(Accesses.front()->getPseudoValue()) + ->getFrameIndex(); + return 1; + } } return 0; } @@ -441,8 +446,13 @@ if ((Reg = isStoreToStackSlot(MI, FrameIndex))) return Reg; // Check for post-frame index elimination operations - const MachineMemOperand *Dummy; - return hasStoreToStackSlot(MI, Dummy, FrameIndex); + SmallVector Accesses; + if (hasStoreToStackSlot(MI, Accesses)) { + FrameIndex = + cast(Accesses.front()->getPseudoValue()) + ->getFrameIndex(); + return 1; + } } return 0; } Index: lib/Target/X86/X86TargetTransformInfo.h =================================================================== --- lib/Target/X86/X86TargetTransformInfo.h +++ lib/Target/X86/X86TargetTransformInfo.h @@ -90,8 +90,8 @@ ArrayRef Tys, FastMathFlags FMF, unsigned ScalarizationCostPassed = UINT_MAX); int getIntrinsicInstrCost(Intrinsic::ID IID, Type *RetTy, - ArrayRef Args, FastMathFlags FMF, - unsigned VF = 1); + ArrayRef Args, FastMathFlags FMF, + VectorType::ElementCount VF = VectorType::SingleElement()); int getArithmeticReductionCost(unsigned Opcode, Type *Ty, bool IsPairwiseForm); Index: lib/Target/X86/X86TargetTransformInfo.cpp =================================================================== --- 
lib/Target/X86/X86TargetTransformInfo.cpp +++ lib/Target/X86/X86TargetTransformInfo.cpp @@ -1869,7 +1869,8 @@ } int X86TTIImpl::getIntrinsicInstrCost(Intrinsic::ID IID, Type *RetTy, - ArrayRef Args, FastMathFlags FMF, unsigned VF) { + ArrayRef Args, FastMathFlags FMF, + VectorType::ElementCount VF) { return BaseT::getIntrinsicInstrCost(IID, RetTy, Args, FMF, VF); } Index: lib/Transforms/IPO/PartialInlining.cpp =================================================================== --- lib/Transforms/IPO/PartialInlining.cpp +++ lib/Transforms/IPO/PartialInlining.cpp @@ -834,42 +834,41 @@ int PartialInlinerImpl::computeBBInlineCost(BasicBlock *BB) { int InlineCost = 0; const DataLayout &DL = BB->getParent()->getParent()->getDataLayout(); - for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E; ++I) { - if (isa(I)) - continue; - - switch (I->getOpcode()) { + for (Instruction &I : BB->instructionsWithoutDebug()) { + // Skip free instructions. + switch (I.getOpcode()) { case Instruction::BitCast: case Instruction::PtrToInt: case Instruction::IntToPtr: case Instruction::Alloca: + case Instruction::PHI: continue; case Instruction::GetElementPtr: - if (cast(I)->hasAllZeroIndices()) + if (cast(&I)->hasAllZeroIndices()) continue; break; default: break; } - IntrinsicInst *IntrInst = dyn_cast(I); + IntrinsicInst *IntrInst = dyn_cast(&I); if (IntrInst) { if (IntrInst->getIntrinsicID() == Intrinsic::lifetime_start || IntrInst->getIntrinsicID() == Intrinsic::lifetime_end) continue; } - if (CallInst *CI = dyn_cast(I)) { + if (CallInst *CI = dyn_cast(&I)) { InlineCost += getCallsiteCost(CallSite(CI), DL); continue; } - if (InvokeInst *II = dyn_cast(I)) { + if (InvokeInst *II = dyn_cast(&I)) { InlineCost += getCallsiteCost(CallSite(II), DL); continue; } - if (SwitchInst *SI = dyn_cast(I)) { + if (SwitchInst *SI = dyn_cast(&I)) { InlineCost += (SI->getNumCases() + 1) * InlineConstants::InstrCost; continue; } @@ -1448,16 +1447,6 @@ if (CurrFunc->use_empty()) continue; - bool Recursive = false; - for (User *U : CurrFunc->users()) - if (Instruction *I = dyn_cast(U)) - if (I->getParent()->getParent() == CurrFunc) { - Recursive = true; - break; - } - if (Recursive) - continue; - std::pair Result = unswitchFunction(CurrFunc); if (Result.second) Worklist.push_back(Result.second); Index: lib/Transforms/IPO/PassManagerBuilder.cpp =================================================================== --- lib/Transforms/IPO/PassManagerBuilder.cpp +++ lib/Transforms/IPO/PassManagerBuilder.cpp @@ -46,7 +46,7 @@ using namespace llvm; static cl::opt - RunPartialInlining("enable-partial-inlining", cl::init(false), cl::Hidden, + RunPartialInlining("enable-partial-inlining", cl::init(true), cl::Hidden, cl::ZeroOrMore, cl::desc("Run Partial inlinining pass")); static cl::opt @@ -54,13 +54,35 @@ cl::desc("Run the Loop vectorization passes")); static cl::opt +UseSVEVectorizer("use-sve-vectorizer", cl::init(false), cl::Hidden, + cl::desc("Use the SVE vectorizer instead of community")); + +static cl::opt +RemoveSwitchBlocks("remove-switch-blocks", cl::init(true), cl::Hidden, + cl::desc("Convert switch blocks into a branch sequence " + "prior to vectorization.")); + +static cl::opt +RunSearchLoopVectorization("vectorize-search-loops", cl::init(false), cl::Hidden, + cl::desc("Run search loop vectorizer")); + +static cl::opt +BOSCC("insert-superword-control-flow", cl::init(false), cl::Hidden, + cl::desc("Run the 'Branch On Superword Condition Codes' (BOSCC) pass.")); + +static cl::opt RunSLPVectorization("vectorize-slp", 
cl::Hidden, cl::desc("Run the SLP vectorization passes")); static cl::opt + GVNAutovecAware("gvn-autovec-aware", cl::init(true), cl::Hidden, + cl::desc("Prevent GVN to PRE certain loads that introduce " + "a dependence that cannot be vectorized.")); + +static cl::opt UseGVNAfterVectorization("use-gvn-after-vectorization", - cl::init(false), cl::Hidden, - cl::desc("Run GVN instead of Early CSE after vectorization passes")); + cl::init(true), cl::Hidden, + cl::desc("Run GVN after vectorization passes")); static cl::opt ExtraVectorizerPasses( "extra-vectorizer-passes", cl::init(false), cl::Hidden, @@ -121,6 +143,10 @@ "enable-loop-versioning-licm", cl::init(false), cl::Hidden, cl::desc("Enable the experimental Loop Versioning LICM pass")); +static cl::opt UseLoopSpeculativeBoundsChecking( + "enable-loop-speculative-bounds-checking", cl::init(true), cl::Hidden, + cl::desc("Enable experimental loop speculative bounds checking pass")); + static cl::opt DisablePreInliner("disable-preinline", cl::init(false), cl::Hidden, cl::desc("Disable pre-instrumentation inliner")); @@ -148,10 +174,23 @@ cl::desc("Enable the simple loop unswitch pass. Also enables independent " "cleanup passes integrated into the loop pass manager pipeline.")); +static cl::opt EnableLoopExprTreeFactoring( + "enable-loop-expr-tree-factoring", cl::init(true), cl::Hidden, + cl::desc("Enable pass that rewrites add/mul expression trees by factoring " + "out common multiplies")); + static cl::opt EnableGVNSink( "enable-gvn-sink", cl::init(false), cl::Hidden, cl::desc("Enable the GVN sinking pass (default = off)")); +// This value determines the point at which we stop removing switch statements +// before the vectorizer pass. Removing switch blocks and replacing them with +// compares and branches allows architectures that support predication to +// vectorize. This value was chosen initially because it was needed to +// vectorise a TSVC loop, however this value can be tweaked over time if higher +// numbers are found to improve performance. +static const int RemoveSwitchCaseThreshold = 4; + PassManagerBuilder::PassManagerBuilder() { OptLevel = 2; SizeLevel = 0; @@ -160,6 +199,7 @@ DisableUnrollLoops = false; SLPVectorize = RunSLPVectorization; LoopVectorize = RunLoopVectorization; + SearchLoopVectorize = RunSearchLoopVectorization; RerollLoops = RunLoopRerolling; NewGVN = RunNewGVN; DisableGVNLoadPRE = false; @@ -379,8 +419,10 @@ if (OptLevel > 1) { MPM.add(createMergedLoadStoreMotionPass()); // Merge ld/st in diamonds + MPM.add(createLoopRewriteGEPsPass()); // Provide more LoadPRE opportunities MPM.add(NewGVN ? createNewGVNPass() - : createGVNPass(DisableGVNLoadPRE)); // Remove redundancies + : createGVNPass(DisableGVNLoadPRE, // Remove redundancies + GVNAutovecAware)); } MPM.add(createMemCpyOptPass()); // Remove memcpy / form memset MPM.add(createSCCPPass()); // Constant prop with SCCP @@ -397,6 +439,12 @@ MPM.add(createJumpThreadingPass()); // Thread jumps MPM.add(createCorrelatedValuePropagationPass()); MPM.add(createDeadStoreEliminationPass()); // Delete dead stores + if (EnableLoopExprTreeFactoring && OptLevel > 2) { + MPM.add(createLICMPass()); + MPM.add(createLoopExprTreeFactoringPass()); + MPM.add(createAggressiveDCEPass()); // Delete dead instructions + MPM.add(createEarlyCSEPass()); + } MPM.add(createLICMPass()); addExtensionsToPM(EP_ScalarOptimizerLate, MPM); @@ -491,8 +539,10 @@ addExtensionsToPM(EP_ModuleOptimizerEarly, MPM); - if (OptLevel > 2) + if (Inliner) { MPM.add(createCallSiteSplittingPass()); + MPM.add(NewGVN ? 
createNewGVNPass() : createGVNPass(true)); + } MPM.add(createIPSCCPPass()); // IP SCCP MPM.add(createCalledValuePropagationPass()); @@ -627,7 +677,24 @@ // llvm.loop.distribute=true or when -enable-loop-distribute is specified. MPM.add(createLoopDistributePass()); - MPM.add(createLoopVectorizePass(DisableUnrollLoops, LoopVectorize)); + // TODO: Decide if this is the best place to run this pass.... + if (UseLoopSpeculativeBoundsChecking) + MPM.add(createLoopSpeculativeBoundsCheckPass()); + + if (RemoveSwitchBlocks) + MPM.add(createCFGSimplificationPass(1, false, false, true, false, nullptr, + RemoveSwitchCaseThreshold)); + + if (UseSVEVectorizer) + MPM.add(createSVELoopVectorizePass(DisableUnrollLoops, LoopVectorize)); + else + MPM.add(createLoopVectorizePass(DisableUnrollLoops, LoopVectorize)); + + if (SearchLoopVectorize) + MPM.add(createSearchLoopVectorizePass(DisableUnrollLoops, LoopVectorize)); + + if (BOSCC) + MPM.add(createBOSCCPass()); // Eliminate loads by forwarding stores from the previous iteration to loads // of the current iteration. @@ -652,8 +719,12 @@ MPM.add(createLICMPass()); MPM.add(createLoopUnswitchPass(SizeLevel || OptLevel < 3, DivergentTarget)); MPM.add(createCFGSimplificationPass()); + MPM.add(NewGVN ? createNewGVNPass() + : createGVNPass(DisableGVNLoadPRE)); // Remove redundancies. addInstructionCombiningPass(MPM); - } + } else if (OptLevel > 1 && UseGVNAfterVectorization) + MPM.add(NewGVN ? createNewGVNPass() + : createGVNPass(DisableGVNLoadPRE)); // Remove redundancies. // Cleanup after loop vectorization, etc. Simplification passes like CVP and // GVN, loop transforms, and others have already run, so it's now better to @@ -670,6 +741,8 @@ } addExtensionsToPM(EP_Peephole, MPM); + // MERGE MPM.add(createLateCFGSimplificationPass()); // Switches to lookup tables + MPM.add(createSeparateInvariantsFromGepOffsetPass()); addInstructionCombiningPass(MPM); if (!DisableUnrollLoops) { @@ -839,7 +912,8 @@ PM.add(createLICMPass()); // Hoist loop invariants. PM.add(createMergedLoadStoreMotionPass()); // Merge ld/st in diamonds. PM.add(NewGVN ? createNewGVNPass() - : createGVNPass(DisableGVNLoadPRE)); // Remove redundancies. + : createGVNPass(DisableGVNLoadPRE, // Remove redundancies. + GVNAutovecAware)); PM.add(createMemCpyOptPass()); // Remove dead memcpys. // Nuke dead stores. @@ -853,7 +927,16 @@ if (!DisableUnrollLoops) PM.add(createSimpleLoopUnrollPass(OptLevel)); // Unroll small loops - PM.add(createLoopVectorizePass(true, LoopVectorize)); + + if (RemoveSwitchBlocks) + PM.add(createCFGSimplificationPass(1, false, false, true, false, nullptr, + RemoveSwitchCaseThreshold)); + + if (UseSVEVectorizer) + PM.add(createSVELoopVectorizePass(true, LoopVectorize)); + else + PM.add(createLoopVectorizePass(true, LoopVectorize)); + // The vectorizer may have significantly shortened a loop body; unroll again. if (!DisableUnrollLoops) PM.add(createLoopUnrollPass(OptLevel)); @@ -864,6 +947,9 @@ addInstructionCombiningPass(PM); // Initial cleanup PM.add(createCFGSimplificationPass()); // if-convert PM.add(createSCCPPass()); // Propagate exposed constants + if (OptLevel > 1 && UseGVNAfterVectorization) + PM.add(NewGVN ? createNewGVNPass() + : createGVNPass(DisableGVNLoadPRE)); // Remove redundancies. 
addInstructionCombiningPass(PM); // Clean up again PM.add(createBitTrackingDCEPass()); Index: lib/Transforms/InstCombine/InstCombineAddSub.cpp =================================================================== --- lib/Transforms/InstCombine/InstCombineAddSub.cpp +++ lib/Transforms/InstCombine/InstCombineAddSub.cpp @@ -1371,6 +1371,36 @@ I.setOperand(1, B); return &I; } + // splat(A) + (splat(B) + C) ==> splat(A+B) + C + { + bool HasNSW = I.hasNoSignedWrap(); + bool HasNUW = I.hasNoUnsignedWrap(); + Value *A = nullptr, *B = nullptr, *C = nullptr; + + if (match(LHS, m_SplatVector(m_Value(A)))) { + if (match(RHS, m_Add(m_SplatVector(m_Value(B)), m_Value(C))) || + match(RHS, m_Add(m_Value(C), m_SplatVector(m_Value(B))))) { + // It's safe to propagate the wrap flags because stepvector[0] == 0 + auto NewAdd = Builder.CreateAdd(A, B, "", HasNUW, HasNSW); + + auto EC = cast(I.getType())->getElementCount(); + auto SplatNewAdd = Builder.CreateVectorSplat(EC, NewAdd); + return BinaryOperator::CreateAdd(SplatNewAdd, C); + } + } + + if (match(RHS, m_SplatVector(m_Value(A)))) { + if (match(LHS, m_Add(m_SplatVector(m_Value(B)), m_Value(C))) || + match(LHS, m_Add(m_Value(C), m_SplatVector(m_Value(B))))) { + // It's safe to propagate the wrap flags because stepvector[0] == 0 + auto NewAdd = Builder.CreateAdd(A, B, "", HasNUW, HasNSW); + + auto EC = cast(I.getType())->getElementCount(); + auto SplatNewAdd = Builder.CreateVectorSplat(EC, NewAdd); + return BinaryOperator::CreateAdd(SplatNewAdd, C); + } + } + } // TODO(jingyue): Consider willNotOverflowSignedAdd and // willNotOverflowUnsignedAdd to reduce the number of invocations of @@ -1798,6 +1828,26 @@ if (Value *Res = OptimizePointerDifference(LHSOp, RHSOp, I.getType())) return replaceInstUsesWith(I, Res); + + // (splat(A) + B) - splat(C) ==> splat(A-C) + B + { + bool HasNSW = I.hasNoSignedWrap(); + bool HasNUW = I.hasNoUnsignedWrap(); + Value *A = nullptr, *B = nullptr, *C = nullptr; + + if (match(Op1, m_SplatVector(m_Value(C)))) { + if (match(Op0, m_Add(m_SplatVector(m_Value(A)), m_Value(B))) || + match(Op0, m_Add(m_Value(B), m_SplatVector(m_Value(A))))) { + // It's safe to propagate the wrap flags because stepvector[0] == 0 + auto NewSub = Builder.CreateSub(A, C, "", HasNUW, HasNSW); + + auto EC = cast(I.getType())->getElementCount(); + auto SplatNewSub = Builder.CreateVectorSplat(EC, NewSub); + return BinaryOperator::CreateAdd(SplatNewSub, B); + } + } + } + // Canonicalize a shifty way to code absolute value to the common pattern. // There are 2 potential commuted variants. // We're relying on the fact that we only do this transform when the shift has Index: lib/Transforms/InstCombine/InstCombineCalls.cpp =================================================================== --- lib/Transforms/InstCombine/InstCombineCalls.cpp +++ lib/Transforms/InstCombine/InstCombineCalls.cpp @@ -23,6 +23,7 @@ #include "llvm/ADT/Twine.h" #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/InstructionSimplify.h" +#include "llvm/Analysis/Loads.h" #include "llvm/Analysis/MemoryBuiltins.h" #include "llvm/Transforms/Utils/Local.h" #include "llvm/Analysis/ValueTracking.h" @@ -363,7 +364,7 @@ // Get a constant vector of the same type as the first operand. 
auto ShiftAmt = ConstantInt::get(SVT, Count.zextOrTrunc(BitWidth)); - auto ShiftVec = Builder.CreateVectorSplat(VWidth, ShiftAmt); + auto ShiftVec = Builder.CreateVectorSplat({VWidth, false}, ShiftAmt); if (ShiftLeft) return Builder.CreateShl(Vec, ShiftVec); @@ -1215,7 +1216,10 @@ return false; if (ConstMask->isAllOnesValue() || isa(ConstMask)) return true; - for (unsigned I = 0, E = ConstMask->getType()->getVectorNumElements(); I != E; + auto *VTy = dyn_cast(Mask->getType()); + if (!VTy || VTy->isScalable()) + return false; + for (unsigned I = 0, E = ConstMask->getType()->getVectorNumElements(); I != E; ++I) { if (auto *MaskElt = ConstMask->getAggregateElement(I)) if (MaskElt->isAllOnesValue() || isa(MaskElt)) @@ -1244,7 +1248,7 @@ return nullptr; // If the mask is all zeros, this instruction does nothing. - if (ConstMask->isNullValue()) + if (ConstMask->isNullValue() || isa(ConstMask)) return IC.eraseInstFromFunction(II); // If the mask is all ones, this is a plain vector store of the 1st argument. @@ -1964,11 +1968,21 @@ } break; } - case Intrinsic::masked_load: + case Intrinsic::masked_load: { if (Value *SimplifiedMaskedOp = simplifyMaskedLoad(*II, Builder)) return replaceInstUsesWith(CI, SimplifiedMaskedOp); + BasicBlock::iterator BBI(II); + bool IsLoadCSE = false; + if (Value *AvailableVal = FindAvailablePtrMaskedLoadStore( + II->getOperand(0), II->getOperand(2), II->getOperand(3), + II->getType(), II->isAtomic(), II->getParent(), BBI, + DefMaxInstsToScan, AA, &IsLoadCSE, 0)) + return replaceInstUsesWith(CI, AvailableVal); break; + } case Intrinsic::masked_store: + if (clearRedundantStore(cast(II))) + return nullptr; return simplifyMaskedStore(*II, *this); case Intrinsic::masked_gather: return simplifyMaskedGather(*II, *this); Index: lib/Transforms/InstCombine/InstCombineCasts.cpp =================================================================== --- lib/Transforms/InstCombine/InstCombineCasts.cpp +++ lib/Transforms/InstCombine/InstCombineCasts.cpp @@ -93,6 +93,10 @@ Type *CastElTy = PTy->getElementType(); if (!AllocElTy->isSized() || !CastElTy->isSized()) return nullptr; + // TODO: Should be able to disqualify this via size queries. 
+ if (isa(CastElTy) && CastElTy->getVectorIsScalable()) + return nullptr; + unsigned AllocElTyAlign = DL.getABITypeAlignment(AllocElTy); unsigned CastElTyAlign = DL.getABITypeAlignment(CastElTy); if (CastElTyAlign < AllocElTyAlign) return nullptr; @@ -624,7 +628,8 @@ InstCombiner::BuilderTy &Builder) { auto *Shuf = dyn_cast(Trunc.getOperand(0)); if (Shuf && Shuf->hasOneUse() && isa(Shuf->getOperand(1)) && - Shuf->getMask()->getSplatValue() && + !Shuf->getType()->getVectorIsScalable() && + cast(Shuf->getMask())->getSplatValue() && Shuf->getType() == Shuf->getOperand(0)->getType()) { // trunc (shuf X, Undef, SplatMask) --> shuf (trunc X), Undef, SplatMask Constant *NarrowUndef = UndefValue::get(Trunc.getType()); @@ -800,6 +805,34 @@ } } + Constant *C; + + // Transform "trunc (and X, cst)" -> "and (trunc X), trunc_cst" + if (isa(SrcTy) && + match(Src, m_And(m_Value(A), m_Constant(C)))) { + Value *NewTrunc = Builder.CreateTrunc(A, DestTy, A->getName() + ".tr"); + return BinaryOperator::CreateAnd(NewTrunc, + ConstantExpr::getTrunc(C, DestTy)); + } + + // Transform "trunc (add X, cst)" -> "add (trunc X), trunc_cst" + if (isa(SrcTy) && + match(Src, m_Add(m_Value(A), m_Constant(C)))) { + Value *NewTrunc = Builder.CreateTrunc(A, DestTy, A->getName() + ".tr"); + return BinaryOperator::CreateAdd(NewTrunc, + ConstantExpr::getTrunc(C, DestTy)); + } + + // Transform "trunc (splat X)" -> "splat (trunc X)" + if (Src->hasOneUse() && match(Src, m_SplatVector(m_Value(A)))) { + auto DestVTy = dyn_cast(DestTy); + Type *DestEltTy = DestVTy->getElementType(); + Value *NewTrunc = Builder.CreateTrunc(A, DestEltTy, A->getName() + ".tr"); + Value *Splat = Builder.CreateVectorSplat(DestVTy->getElementCount(), + NewTrunc); + return new BitCastInst(Splat, DestVTy); + } + if (Instruction *I = foldVecTruncToExtElt(CI, *this)) return I; @@ -1725,8 +1758,8 @@ if (CI.getOperand(0)->getType()->getScalarSizeInBits() != DL.getPointerSizeInBits(AS)) { Type *Ty = DL.getIntPtrType(CI.getContext(), AS); - if (CI.getType()->isVectorTy()) // Handle vectors of pointers. - Ty = VectorType::get(Ty, CI.getType()->getVectorNumElements()); + if (auto *CITy = dyn_cast(CI.getType())) + Ty = VectorType::get(Ty, CITy->getElementCount()); Value *P = Builder.CreateZExtOrTrunc(CI.getOperand(0), Ty); return new IntToPtrInst(P, CI.getType()); @@ -1735,6 +1768,71 @@ if (Instruction *I = commonCastTransforms(CI)) return I; + // Convert vector pointer arithmetic into a GetElementPtr. + if (CI.getType()->isVectorTy()) { + Value *Ptr, *Offsets; + + if (match(CI.getOperand(0), + m_Add(m_Value(Offsets), + m_PtrToInt(m_SplatVector(m_Value(Ptr)))))) + /* match */; + else if (match(CI.getOperand(0), + m_Add(m_PtrToInt(m_SplatVector(m_Value(Ptr))), + m_Value(Offsets)))) + /* match */; + else if (match(CI.getOperand(0), + m_Add(m_PtrToInt(m_Value(Ptr)), + m_SplatVector(m_Value(Offsets))))) + /* match */; + else if (match(CI.getOperand(0), + m_Add(m_SplatVector(m_Value(Offsets)), + m_PtrToInt(m_Value(Ptr))))) + /* match */; + else + return nullptr; + + if (Ptr->getType()->isVectorTy() != Offsets->getType()->isVectorTy()) { + Type *Ty = CI.getType()->getScalarType()->getPointerElementType(); + + Type *PtrTy = CI.getType(); // vector_of_ptrs + if (!Ptr->getType()->isVectorTy()) + PtrTy = PtrTy->getScalarType(); // ptr + + // Bytes don't require scaling. 
+ if (DL.getTypeAllocSize(Ty) == 1) { + Ptr = Builder.CreateBitCast(Ptr, PtrTy); + return GetElementPtrInst::Create(Ty, Ptr, Offsets); + } else { + Value *Indices; + const APInt *Scale; + const APInt AllocSize(64, DL.getTypeAllocSize(Ty)); + + // ptr + vector_of_indices + if (match(Offsets, m_Mul(m_Value(Indices), + m_SplatVector(m_APInt(Scale))))) { + if (*Scale == AllocSize) { + Ptr = Builder.CreateBitCast(Ptr, PtrTy); + return GetElementPtrInst::Create(Ty, Ptr, Indices); + } + // ptr + vector_of_indices + } else if (match(Offsets, m_Shl(m_Value(Indices), + m_SplatVector(m_APInt(Scale))))) { + if (*Scale == AllocSize.exactLogBase2()) { + Ptr = Builder.CreateBitCast(Ptr, PtrTy); + return GetElementPtrInst::Create(Ty, Ptr, Indices); + } + } + // vector_of_ptrs + index + else if (match(Offsets, m_Shl(m_Value(Indices), m_APInt(Scale)))) { + if (*Scale == AllocSize.exactLogBase2()) { + Ptr = Builder.CreateBitCast(Ptr, PtrTy); + return GetElementPtrInst::Create(Ty, Ptr, Indices); + } + } + } + } + } + return nullptr; } @@ -2008,7 +2106,7 @@ if (!VectorType::isValidElementType(DestType)) return nullptr; - unsigned NumElts = ExtElt->getVectorOperandType()->getNumElements(); + auto NumElts = ExtElt->getVectorOperandType()->getElementCount(); auto *NewVecType = VectorType::get(DestType, NumElts); auto *NewBC = IC.Builder.CreateBitCast(ExtElt->getVectorOperand(), NewVecType, "bc"); Index: lib/Transforms/InstCombine/InstCombineCompares.cpp =================================================================== --- lib/Transforms/InstCombine/InstCombineCompares.cpp +++ lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -909,7 +909,9 @@ } // If all indices are the same, just compare the base pointers. - if (IndicesTheSame) + if (IndicesTheSame && + (GEPLHS->getOperand(0)->getType()->isVectorTy() == + GEPLHS->getType()->isVectorTy())) return new ICmpInst(Cond, GEPLHS->getOperand(0), GEPRHS->getOperand(0)); // If we're comparing GEPs with two base pointers that only differ in type @@ -5262,12 +5264,21 @@ } break; case Instruction::Call: { - if (!RHSC->isNullValue()) - break; - CallInst *CI = cast(LHSI); Intrinsic::ID IID = getIntrinsicForCallSite(CI, &TLI); - if (IID != Intrinsic::fabs) + + if (IID == Intrinsic::sqrt) { + FastMathFlags FMF = I.getFastMathFlags(); + ConstantFP *RHSF = dyn_cast(RHSC); + + // fcmp sqrt(x),C --> fcmp x,C*C ; When signs and NaNs are preserved. + if (RHSF && !RHSF->isNegative() && FMF.noNaNs()) + return new FCmpInst(I.getPredicate(), CI->getArgOperand(0), + ConstantExpr::getFMul(RHSC, RHSC)); + break; + } + + if ((IID != Intrinsic::fabs) || !RHSC->isNullValue()) break; // Various optimization for fabs compared with zero. Index: lib/Transforms/InstCombine/InstCombineInternal.h =================================================================== --- lib/Transforms/InstCombine/InstCombineInternal.h +++ lib/Transforms/InstCombine/InstCombineInternal.h @@ -53,6 +53,10 @@ class CallSite; class DataLayout; class DominatorTree; +class TargetLibraryInfo; +class DbgDeclareInst; +class MemIntrinsic; +class MemSetInst; class GEPOperator; class GlobalVariable; class LoopInfo; @@ -314,7 +318,7 @@ LoopInfo *LI) : Worklist(Worklist), Builder(Builder), MinimizeSize(MinimizeSize), ExpensiveCombines(ExpensiveCombines), AA(AA), AC(AC), TLI(TLI), DT(DT), - DL(DL), SQ(DL, &TLI, &DT, &AC), ORE(ORE), LI(LI) {} + DL(DL), SQ(DL, &TLI, &DT, &AC), ORE(ORE), LI(LI), MadeIRChange(false) {} /// Run the combiner over the entire worklist until it is empty. 
/// @@ -409,6 +413,12 @@ Instruction *visitVAStartInst(VAStartInst &I); Instruction *visitVACopyInst(VACopyInst &I); + /// Try to clear store instruction I if it is redundant, + /// and possibly other redundant stores it may find when + /// scanning back from I. + /// \return true if I has been removed. + bool clearRedundantStore(Instruction *I); + /// Specify what to return for unhandled instructions. Instruction *visitInstruction(Instruction &I) { return nullptr; } @@ -522,7 +532,9 @@ Value *EmitGEPOffset(User *GEP); Instruction *scalarizePHI(ExtractElementInst &EI, PHINode *PN); + Instruction *scalarizeGEP(GetElementPtrInst *GEP, unsigned Index); Value *EvaluateInDifferentElementOrder(Value *V, ArrayRef Mask); + Value *FindScalarElement(Value *V, unsigned EltNo); Instruction *foldCastedBitwiseLogic(BinaryOperator &I); Instruction *narrowBinOp(TruncInst &Trunc); Instruction *narrowMaskedBinOp(BinaryOperator &And); Index: lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp =================================================================== --- lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp +++ lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp @@ -629,6 +629,7 @@ // Do not perform canonicalization if minmax pattern is found (to avoid // infinite loop). if (!Ty->isIntegerTy() && Ty->isSized() && + DL.isTypeStoreSizeKnown(Ty) && DL.isLegalInteger(DL.getTypeStoreSizeInBits(Ty)) && DL.getTypeStoreSizeInBits(Ty) == DL.getTypeSizeInBits(Ty) && !DL.isNonIntegralPointerType(Ty) && @@ -1345,6 +1346,72 @@ return false; } +bool InstCombiner::clearRedundantStore(Instruction *SI) { + Value *StrPtr = nullptr, *StrMask = nullptr, *StrVal = nullptr; + if (!match(SI, m_AnyStore(m_Value(StrVal), m_Value(StrPtr), m_Value(), + m_Value(StrMask)))) + return false; + + bool isAllOnesMask = false; + if (auto *MC = dyn_cast(StrMask)) + isAllOnesMask = MC->isAllOnesValue(); + + BasicBlock::iterator BBI(SI); + for (unsigned ScanInsts = DefMaxInstsToScan; + BBI != SI->getParent()->begin() && ScanInsts; --ScanInsts) { + --BBI; + // Don't count debug info directives, lest they affect codegen, + // and we skip pointer-to-pointer bitcasts, which are NOPs. + if (isa(BBI) || + (isa(BBI) && BBI->getType()->isPointerTy())) { + ScanInsts++; + continue; + } + + if (Instruction *Inst = dyn_cast(BBI)) { + Value *InstVal = nullptr, *InstMask = nullptr, *InstPtr = nullptr, + *InstPass = nullptr; + if (match(Inst, m_AnyStore(m_Value(InstVal), m_Value(InstPtr), m_Value(), + m_Value(InstMask)))) { + // Prev store isn't volatile, and stores to the same location? + if (equivalentAddressValues(InstPtr, StrPtr) && + (InstMask == StrMask || isAllOnesMask) && + (!isa(Inst) || cast(Inst)->isUnordered())) { + ++BBI; + this->eraseInstFromFunction(*Inst); + continue; + } + break; + } + + // If this is a load, we have to stop. However, if the loaded value is + // from the pointer we're loading and is producing the pointer we're + // storing, then *this* store is dead (X = load P; store X -> P). + if (match(Inst, m_AnyLoad(m_Value(InstPtr), m_Value(), m_Value(InstMask), + m_Value(InstPass)))) { + if (Inst == StrVal && equivalentAddressValues(StrPtr, InstPtr) && + isa(InstPass) && StrMask == InstMask) { + if (StoreInst *StoreInstruction = dyn_cast(SI)) + assert(StoreInstruction->isUnordered() && + "can't eliminate ordering operation"); + ++NumDeadStore; + this->eraseInstFromFunction(*SI); + return true; + } + + // Otherwise, this is a load from some other location. Stores before it + // may not be dead. 
+ break; + } + } + + // Don't skip over loads, throws or things that can modify memory. + if (BBI->mayWriteToMemory() || BBI->mayReadFromMemory() || BBI->mayThrow()) + break; + } + return false; +} + /// Converts store (bitcast (load (bitcast (select ...)))) to /// store (load (select ...)), where select is minmax: /// select ((cmp load V1, load V2), V1, V2). @@ -1435,51 +1502,8 @@ } } - // Do really simple DSE, to catch cases where there are several consecutive - // stores to the same location, separated by a few arithmetic operations. This - // situation often occurs with bitfield accesses. - BasicBlock::iterator BBI(SI); - for (unsigned ScanInsts = 6; BBI != SI.getParent()->begin() && ScanInsts; - --ScanInsts) { - --BBI; - // Don't count debug info directives, lest they affect codegen, - // and we skip pointer-to-pointer bitcasts, which are NOPs. - if (isa(BBI) || - (isa(BBI) && BBI->getType()->isPointerTy())) { - ScanInsts++; - continue; - } - - if (StoreInst *PrevSI = dyn_cast(BBI)) { - // Prev store isn't volatile, and stores to the same location? - if (PrevSI->isUnordered() && equivalentAddressValues(PrevSI->getOperand(1), - SI.getOperand(1))) { - ++NumDeadStore; - ++BBI; - eraseInstFromFunction(*PrevSI); - continue; - } - break; - } - - // If this is a load, we have to stop. However, if the loaded value is from - // the pointer we're loading and is producing the pointer we're storing, - // then *this* store is dead (X = load P; store X -> P). - if (LoadInst *LI = dyn_cast(BBI)) { - if (LI == Val && equivalentAddressValues(LI->getOperand(0), Ptr)) { - assert(SI.isUnordered() && "can't eliminate ordering operation"); - return eraseInstFromFunction(SI); - } - - // Otherwise, this is a load from some other location. Stores before it - // may not be dead. - break; - } - - // Don't skip over loads, throws or things that can modify memory. - if (BBI->mayWriteToMemory() || BBI->mayReadFromMemory() || BBI->mayThrow()) - break; - } + if (clearRedundantStore(dyn_cast(&SI))) + return nullptr; // store X, null -> turns into 'unreachable' in SimplifyCFG // store X, GEP(null, Y) -> turns into 'unreachable' in SimplifyCFG @@ -1499,7 +1523,7 @@ // If this store is the last instruction in the basic block (possibly // excepting debug info instructions), and if the block ends with an // unconditional branch, try to move it to the successor block. 
- BBI = SI.getIterator(); + BasicBlock::iterator BBI = SI.getIterator(); do { ++BBI; } while (isa(BBI) || Index: lib/Transforms/InstCombine/InstCombineMulDivRem.cpp =================================================================== --- lib/Transforms/InstCombine/InstCombineMulDivRem.cpp +++ lib/Transforms/InstCombine/InstCombineMulDivRem.cpp @@ -108,6 +108,18 @@ if (!Ty->isVectorTy()) return nullptr; + VectorType *VTy = cast(Ty); + if (VTy->isScalable()) { + Constant *CS = C->getSplatValue(); + if (CS && match(CS, m_APInt(IVal)) && IVal->isPowerOf2()) + return ConstantVector::getSplat( + VTy->getElementCount(), + ConstantInt::get(VTy->getScalarType(), IVal->logBase2())); + return nullptr; + } + + assert(!VTy->isScalable() && "Expected only fixed-width vectors"); + SmallVector Elts; for (unsigned I = 0, E = Ty->getVectorNumElements(); I != E; ++I) { Constant *Elt = C->getAggregateElement(I); @@ -443,6 +455,10 @@ if (match(Op0, m_OneUse(m_FNeg(m_Value(X))))) return BinaryOperator::CreateFNegFMF(Builder.CreateFMulFMF(X, Op1, &I), &I); + if (auto *VTy = dyn_cast(I.getType())) + if (VTy->isScalable()) + return nullptr; + // Sink negation: Y * -X --> -(X * Y) if (match(Op1, m_OneUse(m_FNeg(m_Value(X))))) return BinaryOperator::CreateFNegFMF(Builder.CreateFMulFMF(X, Op0, &I), &I); @@ -1190,6 +1206,10 @@ if (Instruction *R = FoldOpIntoSelect(I, SI)) return R; + if (auto *VTy = dyn_cast(I.getType())) + if (VTy->isScalable()) + return nullptr; + if (isa(Op1)) if (SelectInst *SI = dyn_cast(Op0)) if (Instruction *R = FoldOpIntoSelect(I, SI)) Index: lib/Transforms/InstCombine/InstCombineSelect.cpp =================================================================== --- lib/Transforms/InstCombine/InstCombineSelect.cpp +++ lib/Transforms/InstCombine/InstCombineSelect.cpp @@ -343,6 +343,14 @@ return nullptr; } + // If the select condition is a vector, the operands of the original select's + // operands also must be vectors. This may not be the case for getelementptr + // for example. + if (SI.getCondition()->getType()->isVectorTy() && + (!OtherOpT->getType()->isVectorTy() || + !OtherOpF->getType()->isVectorTy())) + return nullptr; + // If we reach here, they do have operations in common. Value *NewSI = Builder.CreateSelect(SI.getCondition(), OtherOpT, OtherOpF, SI.getName() + ".v", &SI); @@ -1539,14 +1547,15 @@ CmpInst::Predicate Pred; if (match(CondVal, m_OneUse(m_ICmp(Pred, m_Value(), m_Value()))) && !isCanonicalPredicate(Pred)) { - // Swap true/false values and condition. - CmpInst *Cond = cast(CondVal); - Cond->setPredicate(CmpInst::getInversePredicate(Pred)); - SI.setOperand(1, FalseVal); - SI.setOperand(2, TrueVal); - SI.swapProfMetadata(); - Worklist.Add(Cond); - return &SI; + if (auto *Cond = dyn_cast(CondVal)) { + // Swap true/false values and condition. 
+ Cond->setPredicate(CmpInst::getInversePredicate(Pred)); + SI.setOperand(1, FalseVal); + SI.setOperand(2, TrueVal); + SI.swapProfMetadata(); + Worklist.Add(Cond); + return &SI; + } } if (SelType->isIntOrIntVectorTy(1) && @@ -1941,6 +1950,15 @@ return &SI; } + // TODO: m_AllOnes needs to support scalable vectors + Value *InvCondVal; + if (match(CondVal, m_Xor(m_Value(InvCondVal), m_SplatVector(m_AllOnes())))) { + SI.setOperand(0, InvCondVal); + SI.setOperand(1, FalseVal); + SI.setOperand(2, TrueVal); + return &SI; + } + if (VectorType *VecTy = dyn_cast(SelType)) { unsigned VWidth = VecTy->getNumElements(); APInt UndefElts(VWidth, 0); Index: lib/Transforms/InstCombine/InstCombineShifts.cpp =================================================================== --- lib/Transforms/InstCombine/InstCombineShifts.cpp +++ lib/Transforms/InstCombine/InstCombineShifts.cpp @@ -446,7 +446,7 @@ APInt Bits = APInt::getHighBitsSet(TypeBits, TypeBits - Op1Val); Constant *Mask = ConstantInt::get(I.getContext(), Bits); if (VectorType *VT = dyn_cast(X->getType())) - Mask = ConstantVector::getSplat(VT->getNumElements(), Mask); + Mask = ConstantVector::getSplat(VT->getElementCount(), Mask); return BinaryOperator::CreateAnd(X, Mask); } @@ -481,7 +481,7 @@ APInt Bits = APInt::getHighBitsSet(TypeBits, TypeBits - Op1Val); Constant *Mask = ConstantInt::get(I.getContext(), Bits); if (VectorType *VT = dyn_cast(X->getType())) - Mask = ConstantVector::getSplat(VT->getNumElements(), Mask); + Mask = ConstantVector::getSplat(VT->getElementCount(), Mask); return BinaryOperator::CreateAnd(X, Mask); } Index: lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp =================================================================== --- lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp +++ lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp @@ -976,13 +976,12 @@ getIntrinsicInfoTableEntries(IID, Table); ArrayRef TableRef = Table; + // Validate function argument and return types, extracting overloaded types + // along the way. FunctionType *FTy = II->getCalledFunction()->getFunctionType(); SmallVector OverloadTys; - Intrinsic::matchIntrinsicType(FTy->getReturnType(), TableRef, OverloadTys); - for (unsigned i = 0, e = FTy->getNumParams(); i != e; ++i) - Intrinsic::matchIntrinsicType(FTy->getParamType(i), TableRef, OverloadTys); + Intrinsic::matchIntrinsicSignature(FTy, TableRef, OverloadTys); - // Get the new return type overload of the intrinsic. Module *M = II->getParent()->getParent()->getParent(); Type *EltTy = II->getType()->getVectorElementType(); Type *NewTy = (NewNumElts == 1) ? EltTy : VectorType::get(EltTy, NewNumElts); @@ -1035,6 +1034,10 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, APInt &UndefElts, unsigned Depth) { + // Cannot track elements when you don't know how many there are. 
+ if (cast(V->getType())->isScalable()) + return nullptr; + unsigned VWidth = V->getType()->getVectorNumElements(); APInt EltMask(APInt::getAllOnesValue(VWidth)); assert((DemandedElts & ~EltMask) == 0 && "Invalid DemandedElts!"); @@ -1154,17 +1157,18 @@ unsigned LHSVWidth = Shuffle->getOperand(0)->getType()->getVectorNumElements(); APInt LeftDemanded(LHSVWidth, 0), RightDemanded(LHSVWidth, 0); + SmallVector Mask; + if (!Shuffle->getShuffleMask(Mask)) + return nullptr; for (unsigned i = 0; i < VWidth; i++) { - if (DemandedElts[i]) { - unsigned MaskVal = Shuffle->getMaskValue(i); - if (MaskVal != -1u) { - assert(MaskVal < LHSVWidth * 2 && - "shufflevector mask index out of range!"); - if (MaskVal < LHSVWidth) - LeftDemanded.setBit(MaskVal); - else - RightDemanded.setBit(MaskVal - LHSVWidth); - } + if (DemandedElts[i] && Mask[i] != -1) { + unsigned MaskVal = Mask[i]; + assert(MaskVal < LHSVWidth * 2 && + "shufflevector mask index out of range!"); + if (MaskVal < LHSVWidth) + LeftDemanded.setBit(MaskVal); + else + RightDemanded.setBit(MaskVal - LHSVWidth); } } @@ -1184,7 +1188,7 @@ bool LHSUniform = true; bool RHSUniform = true; for (unsigned i = 0; i < VWidth; i++) { - unsigned MaskVal = Shuffle->getMaskValue(i); + unsigned MaskVal = Mask[i]; if (MaskVal == -1u) { UndefElts.setBit(i); } else if (!DemandedElts[i]) { @@ -1252,7 +1256,7 @@ Elts.push_back(UndefValue::get(Type::getInt32Ty(I->getContext()))); else Elts.push_back(ConstantInt::get(Type::getInt32Ty(I->getContext()), - Shuffle->getMaskValue(i))); + Mask[i])); } I->setOperand(2, ConstantVector::get(Elts)); MadeChange = true; Index: lib/Transforms/InstCombine/InstCombineVectorOps.cpp =================================================================== --- lib/Transforms/InstCombine/InstCombineVectorOps.cpp +++ lib/Transforms/InstCombine/InstCombineVectorOps.cpp @@ -82,9 +82,72 @@ cheapToScalarize(CI->getOperand(1), isConstant))) return true; + if (match(I, m_SplatVector(m_Value()))) + return true; + return false; } +/// FindScalarElement - Given a vector and an element number, see if the scalar +/// value is already around as a register, for example if it were inserted then +/// extracted from the vector. +Value *InstCombiner::FindScalarElement(Value *V, unsigned EltNo) { + assert(V->getType()->isVectorTy() && "Not looking at a vector?"); + VectorType *VTy = cast(V->getType()); + unsigned Width = VTy->getNumElements(); + if (EltNo >= Width) // Out of range access. + return UndefValue::get(VTy->getElementType()); + + if (Constant *C = dyn_cast(V)) + return C->getAggregateElement(EltNo); + + if (InsertElementInst *III = dyn_cast(V)) { + // If this is an insert to a variable element, we don't know what it is. + if (!isa(III->getOperand(2))) + return nullptr; + unsigned IIElt = cast(III->getOperand(2))->getZExtValue(); + + // If this is an insert to the element we are looking for, return the + // inserted value. + if (EltNo == IIElt) + return III->getOperand(1); + + // Otherwise, the insertelement doesn't modify the value, recurse on its + // vector input. 
+ return FindScalarElement(III->getOperand(0), EltNo); + } + + if (ShuffleVectorInst *SVI = dyn_cast(V)) { + unsigned LHSWidth = SVI->getOperand(0)->getType()->getVectorNumElements(); + int InEl; + if (SVI->getMaskValue(EltNo, InEl)) { + if (InEl < 0) + return UndefValue::get(VTy->getElementType()); + if (InEl < (int)LHSWidth) + return FindScalarElement(SVI->getOperand(0), InEl); + return FindScalarElement(SVI->getOperand(1), InEl - LHSWidth); + } + } + + // Extract the initial value from a numerical series. + if (EltNo == 0) { + Value *Start; + if (match(V, m_SeriesVector(m_Value(Start), m_Value()))) + return Start; + } + + // Extract a value from a vector add operation with a constant zero. + Value *Val = nullptr; Constant *Con = nullptr; + if (match(V, m_Add(m_Value(Val), m_Constant(Con)))) { + auto ConElt = Con->getAggregateElement(EltNo); + if (ConElt && ConElt->isNullValue()) + return FindScalarElement(Val, EltNo); + } + + // Otherwise, we don't know. + return nullptr; +} + // If we have a PHI node with a vector type that is only used to feed // itself and be an operand of extractelement at a constant location, // try to replace the PHI of the vector type with a PHI of a scalar type. @@ -166,6 +229,29 @@ return &EI; } +Instruction *InstCombiner::scalarizeGEP(GetElementPtrInst *GEP, + unsigned Index) { + if (!GEP->getType()->isVectorTy()) + return nullptr; + + SmallVector Elts; + for (unsigned i = 0; i < GEP->getNumOperands(); ++i) { + Value *Op = GEP->getOperand(i); + if (Op->getType()->isVectorTy()) { + if (Value *Elt = FindScalarElement(Op, Index)) + Elts.push_back(Elt); + } else { + Elts.push_back(GEP->getOperand(i)); + } + } + // If any of the calls to FindScalarElement failed, this test will fail + if (Elts.size() != GEP->getNumOperands()) + return nullptr; + + return GetElementPtrInst::Create(nullptr, Elts[0], + makeArrayRef(Elts).slice(1)); +} + Instruction *InstCombiner::visitExtractElementInst(ExtractElementInst &EI) { if (Value *V = SimplifyExtractElementInst(EI.getVectorOperand(), EI.getIndexOperand(), @@ -181,13 +267,19 @@ // If extracting a specified index from the vector, see if we can recursively // find a previously computed scalar that was inserted into the vector. if (ConstantInt *IdxC = dyn_cast(EI.getOperand(1))) { - unsigned VectorWidth = EI.getVectorOperandType()->getNumElements(); + unsigned IndexVal = IdxC->getZExtValue(); + auto VecTy = EI.getVectorOperandType(); + unsigned VectorWidth = VecTy->getNumElements(); + + // Do not fold scalable vectors if the index is bigger than the vector + // width. + if (VecTy->isScalable() && IndexVal >= VectorWidth) + return nullptr; // InstSimplify should handle cases where the index is invalid. if (!IdxC->getValue().ule(VectorWidth)) return nullptr; - unsigned IndexVal = IdxC->getZExtValue(); // This instruction only demands the single element from the input vector. // If the input vector has a single use, simplify it based on this use @@ -208,9 +300,15 @@ // it. In this case, we will end up needing to bitcast the scalars. 
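// Minimal model of the FindScalarElement walk above: look through a chain of
// insertelement-like nodes to recover lane EltNo as a scalar without emitting
// an extract. Toy VecNode type and int lanes are stand-ins, not LLVM's Value
// hierarchy.
#include <optional>
#include <vector>

struct VecNode {
  std::vector<int> baseLanes;             // used when prev == nullptr
  const VecNode *prev = nullptr;          // vector the scalar was inserted into
  unsigned lane = 0;                      // lane written by this insert
  int scalar = 0;                         // value written by this insert
};

std::optional<int> findScalarElement(const VecNode *V, unsigned EltNo) {
  while (V) {
    if (!V->prev)                         // base vector of known constants
      return EltNo < V->baseLanes.size()
                 ? std::optional<int>(V->baseLanes[EltNo]) : std::nullopt;
    if (V->lane == EltNo)                 // this insert defines our lane
      return V->scalar;
    V = V->prev;                          // otherwise look through the insert
  }
  return std::nullopt;
}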
if (BitCastInst *BCI = dyn_cast(EI.getOperand(0))) { if (VectorType *VT = dyn_cast(BCI->getOperand(0)->getType())) - if (VT->getNumElements() == VectorWidth) + if (VT->getNumElements() == VectorWidth) { if (Value *Elt = findScalarElement(BCI->getOperand(0), IndexVal)) return new BitCastInst(Elt, EI.getType()); + else if (auto *GEP = dyn_cast(BCI->getOperand(0))) + if (Instruction *scalarGEP = scalarizeGEP(GEP, IndexVal)) { + Builder.Insert(scalarGEP); + return new BitCastInst(scalarGEP, EI.getType()); + } + } } // If there's a vector PHI feeding a scalar use through this extractelement @@ -220,8 +318,17 @@ if (scalarPHI) return scalarPHI; } - } + if (match(EI.getOperand(0), m_SeriesVector(m_Value(), m_Value()))) + if (auto Elt = FindScalarElement(EI.getOperand(0), IndexVal)) + return new BitCastInst(Elt, EI.getType()); + + // Replace an extract of a vector GEP with a scalar GEP if possible + if (auto *GEP = dyn_cast(EI.getOperand(0))) + if (Instruction *ReplacementGEP = scalarizeGEP(GEP, IndexVal)) + return ReplacementGEP; + + } if (Instruction *I = dyn_cast(EI.getOperand(0))) { // Push extractelement into predecessor operation if legal and // profitable to do so. @@ -252,23 +359,25 @@ // If this is extracting an element from a shufflevector, figure out where // it came from and extract from the appropriate input element instead. if (ConstantInt *Elt = dyn_cast(EI.getOperand(1))) { - int SrcIdx = SVI->getMaskValue(Elt->getZExtValue()); - Value *Src; - unsigned LHSWidth = - SVI->getOperand(0)->getType()->getVectorNumElements(); + int SrcIdx; + if (SVI->getMaskValue(Elt->getZExtValue(), SrcIdx)) { + Value *Src; + unsigned LHSWidth = + SVI->getOperand(0)->getType()->getVectorNumElements(); - if (SrcIdx < 0) - return replaceInstUsesWith(EI, UndefValue::get(EI.getType())); - if (SrcIdx < (int)LHSWidth) - Src = SVI->getOperand(0); - else { - SrcIdx -= LHSWidth; - Src = SVI->getOperand(1); + if (SrcIdx < 0) + return replaceInstUsesWith(EI, UndefValue::get(EI.getType())); + if (SrcIdx < (int)LHSWidth) + Src = SVI->getOperand(0); + else { + SrcIdx -= LHSWidth; + Src = SVI->getOperand(1); + } + Type *Int32Ty = Type::getInt32Ty(EI.getContext()); + return ExtractElementInst::Create(Src, + ConstantInt::get(Int32Ty, + SrcIdx, false)); } - Type *Int32Ty = Type::getInt32Ty(EI.getContext()); - return ExtractElementInst::Create(Src, - ConstantInt::get(Int32Ty, - SrcIdx, false)); } } else if (CastInst *CI = dyn_cast(I)) { // Canonicalize extractelement(cast) -> cast(extractelement). @@ -453,6 +562,7 @@ Value *PermittedRHS, InstCombiner &IC) { assert(V->getType()->isVectorTy() && "Invalid shuffle!"); + assert(!cast(V->getType())->isScalable() && "Invalid vector!"); unsigned NumElts = V->getType()->getVectorNumElements(); if (isa(V)) { @@ -578,8 +688,9 @@ // Each mask element must be undefined or choose a vector element from one of // the source operands without crossing vector lanes. for (int i = 0; i != MaskSize; ++i) { - int Elt = Shuf.getMaskValue(i); - if (Elt != -1 && Elt != i && Elt != i + VecSize) + int Elt; + if (!Shuf.getMaskValue(i, Elt) || + (Elt != -1 && Elt != i && Elt != i + VecSize)) return false; } @@ -599,6 +710,10 @@ VectorType *VT = cast(InsElt.getType()); int NumElements = VT->getNumElements(); + // Skip scalable vectors, since we don't know the number of elements + if (VT->isScalable()) + return nullptr; + // Do not try to do this for a one-element vector, since that's a nop, // and will cause an inf-loop. 
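// Sketch of the index arithmetic used above to redirect
// extractelement(shuffle(L, R, M), I) to one of the shuffle's sources.
// Fixed-width operands are assumed; the returned pair is {source operand,
// lane in that source}, with {-1, -1} for an undef lane. Hypothetical helper.
#include <utility>
#include <vector>

std::pair<int, int> mapShuffleLane(const std::vector<int> &Mask,
                                   unsigned LHSWidth, unsigned I) {
  int Elt = Mask[I];
  if (Elt < 0)
    return {-1, -1};                      // undef result lane
  if (Elt < (int)LHSWidth)
    return {0, Elt};                      // taken from the first operand
  return {1, Elt - (int)LHSWidth};        // taken from the second operand
}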
if (NumElements == 1) @@ -713,7 +828,7 @@ // mask vector with the insertelt index plus the length of the vector // (because the constant vector operand of a shuffle is always the 2nd // operand). - Constant *Mask = Shuf->getMask(); + Constant *Mask = cast(Shuf->getMask()); unsigned NumElts = Mask->getType()->getVectorNumElements(); SmallVector NewShufElts(NumElts); SmallVector NewMaskElts(NumElts); @@ -802,10 +917,13 @@ cast(EI->getOperand(1))->getZExtValue(); unsigned InsertedIdx = cast(IdxOp)->getZExtValue(); - if (ExtractedIdx >= NumExtractVectorElts) // Out of range extract. + // Out of range extract. + if (ExtractedIdx >= NumExtractVectorElts && + !EI->getVectorOperand()->getType()->getVectorIsScalable()) return replaceInstUsesWith(IE, VecOp); - if (InsertedIdx >= NumInsertVectorElts) // Out of range insert. + if (InsertedIdx >= NumInsertVectorElts && + !VecOp->getType()->getVectorIsScalable()) // Out of range insert. return replaceInstUsesWith(IE, UndefValue::get(IE.getType())); // If we are extracting a value from a vector, then inserting it right @@ -813,6 +931,15 @@ if (EI->getOperand(0) == VecOp && ExtractedIdx == InsertedIdx) return replaceInstUsesWith(IE, VecOp); + // Turning back into scalable vectors is not supported + // because algorithm used in collectShuffleElements assumes fixed + // width vectors. It probably could be supported in some special cases + // for scalable vectors, e.g. to create shuffles with 2nd source + // being undef. + if (EI->getVectorOperand()->getType()->getVectorIsScalable() || + VecOp->getType()->getVectorIsScalable()) + return nullptr; + // If this insertelement isn't used by some other insertelement, turn it // (and any insertelements it points to), into one big shuffle. if (!IE.hasOneUse() || !isa(IE.user_back())) { @@ -1210,7 +1337,10 @@ // Example: shuf (mul X, {-1,-2,-3,-4}), X, {0,5,6,3} --> mul X, {-1,1,1,-4} // Example: shuf X, (add X, {-1,-2,-3,-4}), {0,1,6,7} --> add X, {0,0,-3,-4} // The existing binop constant vector remains in the same operand position. - Constant *Mask = Shuf.getMask(); + Constant *Mask = dyn_cast(Shuf.getMask()); + if (!Mask) + return nullptr; + Constant *NewC = Op0IsBinop ? ConstantExpr::getShuffleVector(C, IdC, Mask) : ConstantExpr::getShuffleVector(IdC, C, Mask); @@ -1289,7 +1419,10 @@ BinaryOperator::BinaryOps BOpc = Opc0; // Select the constant elements needed for the single binop. - Constant *Mask = Shuf.getMask(); + Constant *Mask = dyn_cast(Shuf.getMask()); + if (!Mask) + return nullptr; + Constant *NewC = ConstantExpr::getShuffleVector(C0, C1, Mask); // We are moving a binop after a shuffle. When a shuffle has an undefined @@ -1353,9 +1486,25 @@ Instruction *InstCombiner::visitShuffleVectorInst(ShuffleVectorInst &SVI) { Value *LHS = SVI.getOperand(0); Value *RHS = SVI.getOperand(1); - SmallVector Mask = SVI.getShuffleMask(); Type *Int32Ty = Type::getInt32Ty(SVI.getContext()); + if (SVI.getType()->getVectorIsScalable()) { + Value *Val = nullptr; + Value *Mask = SVI.getOperand(2); + + // Back to back shuffles of a single vector using the same decrementing mask + // are redundant. + if (match(Mask, m_SeriesVector(m_Value(), m_ConstantInt<-1>())) && + // Combine back to back element swaps. 
+ match(LHS, m_ShuffleVector(m_Value(Val), m_Undef(), m_Specific(Mask)))&& + isa(RHS)) + return replaceInstUsesWith(SVI, Val); + } + + SmallVector Mask; + if (!SVI.getShuffleMask(Mask)) + return nullptr; + if (auto *V = SimplifyShuffleVectorInst( LHS, RHS, SVI.getMask(), SVI.getType(), SQ.getWithInstruction(&SVI))) return replaceInstUsesWith(SVI, V); @@ -1411,7 +1560,9 @@ MadeChange = true; } - if (VWidth == LHSWidth) { + if ((VWidth == LHSWidth) && + (SVI.getType()->isScalable() == + cast(LHS->getType())->isScalable())) { // Analyze the shuffle, are the LHS or RHS and identity shuffles? bool isLHSID, isRHSID; recognizeIdentityMask(Mask, isLHSID, isRHSID); @@ -1610,11 +1761,12 @@ return MadeChange ? &SVI : nullptr; SmallVector LHSMask; + if (newLHS != LHS && !LHSShuffle->getShuffleMask(LHSMask)) + return MadeChange ? &SVI : nullptr; + SmallVector RHSMask; - if (newLHS != LHS) - LHSMask = LHSShuffle->getShuffleMask(); - if (RHSShuffle && newRHS != RHS) - RHSMask = RHSShuffle->getShuffleMask(); + if (RHSShuffle && newRHS != RHS && !RHSShuffle->getShuffleMask(RHSMask)) + return MadeChange ? &SVI : nullptr; unsigned newLHSWidth = (newLHS != LHS) ? LHSOp0Width : LHSWidth; SmallVector newMask; Index: lib/Transforms/InstCombine/InstructionCombining.cpp =================================================================== --- lib/Transforms/InstCombine/InstructionCombining.cpp +++ lib/Transforms/InstCombine/InstructionCombining.cpp @@ -186,6 +186,46 @@ return shouldChangeType(FromWidth, ToWidth); } +static bool vscaleNoWrap(Instruction::BinaryOps Opcode, Value *B, Value *C) { + assert(Opcode == Instruction::Add || Opcode == Instruction::Sub); + // sometimes we may hit cases when we want to combine + // following instructions + // i32 %idx + // %idx.next = add nuw i32 %idx, mul (i32 vscale, i32 c1) + // %add = add nuw i32 %idx.next, mul (i32 vscale, i32 c2) + // + // into: + // %add = add nuw i32 %idx, mul (i32 vscale, add (i32 c1, i32 c2)) + // + // in case of sub: + // %idx.next = sub nuw i32 %idx, mul (i32 vscale, i32 c1) + // %sub = sub nuw i32 %idx.next, mul (i32 vscale, i32 c2) + // + // 0 < mul (i32 vscale, i32 c1) <= idx <= MAX_UINT + // 0 < mul (i32 vscale, i32 c2) <= idx.next <= MAX_UINT + // so + // mul (i32 vscale, i32 c1) <= mul (i32 vscale, i32 c2) + // + // so it is safe to set NUW in: + // %sub = sub nuw i32 %idx, mul (i32 vscale, sub (i32 c1, i32 c2)) + // + // the resulting %add/sub should maintain nsw flag + // as long as c1 and c2 are both non negative constants + + if (match(B, m_Mul(m_VScale(), m_NonNegative())) && + match(C, m_Mul(m_VScale(), m_NonNegative()))) + return true; + + // similar situation to one above + // i32 %idx + // %idx.next = add nuw i32 %idx, i32 vscale + // %add = add nuw i32 %idx.next, i32 c1 + if ((match(B, m_VScale()) && match(C, m_NonNegative())) || + (match(C, m_VScale()) && match(B, m_NonNegative()))) + return true; + return false; +} + // Return true, if No Signed Wrap should be maintained for I. // The No Signed Wrap flag can be kept if the operation "B (I.getOpcode) C", // where both B and C should be ConstantInts, results in a constant that does @@ -202,6 +242,9 @@ return false; const APInt *BVal, *CVal; + if (isa(C) && match(C, m_Zero())) + return true; + if (!match(B, m_APInt(BVal)) || !match(C, m_APInt(CVal))) return false; @@ -214,6 +257,51 @@ return !Overflow; } +// Return true, if No Unsigned Wrap should be maintained for I. 
+// The No Unsigned Wrap flag can be kept if the operation "B (I.getOpcode) C", +// where both B and C should be ConstantInts, results in a constant that does +// not overflow. This function only handles the Add and Sub opcodes. For +// all other opcodes, the function conservatively returns false. +static bool MaintainNoUnsignedWrap(BinaryOperator &I, Value *B, Value *C) { + OverflowingBinaryOperator *OBO = dyn_cast(&I); + if (!OBO || !OBO->hasNoUnsignedWrap()) { + return false; + } + + // We reason about Add and Sub Only. + Instruction::BinaryOps Opcode = I.getOpcode(); + if (Opcode != Instruction::Add && + Opcode != Instruction::Sub) { + return false; + } + + ConstantInt *CB = dyn_cast(B); + ConstantInt *CC = dyn_cast(C); + + if (CC && CC->isNullValue()) { + return true; + } + + if (vscaleNoWrap(Opcode, B, C)) + return true; + + if (!CB || !CC) { + return false; + } + + const APInt &BVal = CB->getValue(); + const APInt &CVal = CC->getValue(); + bool Overflow = false; + + if (Opcode == Instruction::Add) { + (void)BVal.uadd_ov(CVal, Overflow); + } else { + (void)BVal.usub_ov(CVal, Overflow); + } + + return !Overflow; +} + /// Conservatively clears subclassOptionalData after a reassociation or /// commutation. We preserve fast-math flags when applicable as they can be /// preserved. @@ -318,14 +406,20 @@ // It simplifies to V. Form "A op V". I.setOperand(0, A); I.setOperand(1, V); + + bool NSW = MaintainNoSignedWrap(I, B, C) && + (!Op0 || (isa(Op0) && Op0->hasNoSignedWrap())); + bool NUW = MaintainNoUnsignedWrap(I, B, C) && + (!Op0 || (isa(Op0) && Op0->hasNoUnsignedWrap())); + // Conservatively clear the optional flags, since they may not be // preserved by the reassociation. - if (MaintainNoSignedWrap(I, B, C) && - (!Op0 || (isa(Op0) && Op0->hasNoSignedWrap()))) { + if (NSW || NUW) { // Note: this is only valid because SimplifyBinOp doesn't look at // the operands to Op0. I.clearSubclassOptionalData(); - I.setHasNoSignedWrap(true); + I.setHasNoSignedWrap(NSW); + I.setHasNoUnsignedWrap(NUW); } else { ClearSubclassDataAfterReassociation(I); } @@ -1384,23 +1478,41 @@ return createBinOpShuffle(V1, V2, Mask); } + // If both arguments of a binary operation are inserts, which use the same + // source vector and element index, it is worth moving the insert after + // the binary operation: + // Op(insert(v, e1, m), insert(v, e2, m)) -> insert(v, Op(e1, e2), m) + Value *LInsElem, *LInsIdx, *RInsElem, *RInsIdx; + if (match(LHS, m_InsertElement(m_Undef(), m_Value(LInsElem), + m_Value(LInsIdx))) && + match(RHS, m_InsertElement(m_Undef(), m_Value(RInsElem), + m_Value(RInsIdx)))) { + if (LInsIdx == RInsIdx) { + Value *NewBO = Builder.CreateBinOp(Inst.getOpcode(), LInsElem, RInsElem); + if (auto *BO = dyn_cast(NewBO)) + BO->copyIRFlags(&Inst); + Value *Undef = UndefValue::get(LHS->getType()); + return InsertElementInst::Create(Undef, NewBO, LInsIdx); + } + } + // If one argument is a shuffle within one vector and the other is a constant, // try moving the shuffle after the binary operation. This canonicalization // intends to move shuffles closer to other shuffles and binops closer to // other binops, so they can be folded. It may also enable demanded elements // transforms. 
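// Stand-alone sketch of the existence check described above: find NewC such
// that applying ShMask to NewC reproduces C. Plain int vectors stand in for
// constant vectors and -1 marks an undef mask lane; a 1-to-1 mapping is not
// required, but conflicting requirements on one source lane make it fail
// (e.g. ShMask=<0,0> with C=<1,2>).
#include <optional>
#include <vector>

std::optional<std::vector<int>> findPreShuffleConstant(
    const std::vector<int> &C, const std::vector<int> &ShMask,
    unsigned SrcWidth) {
  std::vector<int> NewC(SrcWidth, 0);     // unconstrained lanes stay 0 (undef)
  std::vector<bool> Constrained(SrcWidth, false);
  for (size_t I = 0; I < ShMask.size(); ++I) {
    int M = ShMask[I];
    if (M < 0)
      continue;                           // undef lane constrains nothing
    if ((unsigned)M >= SrcWidth)
      return std::nullopt;                // lane comes from the other operand
    if (Constrained[M] && NewC[M] != C[I])
      return std::nullopt;                // two result lanes disagree
    NewC[M] = C[I];
    Constrained[M] = true;
  }
  return NewC;
}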
Constant *C; + SmallVector ShMask; if (match(&Inst, m_c_BinOp( m_OneUse(m_ShuffleVector(m_Value(V1), m_Undef(), m_Constant(Mask))), m_Constant(C))) && - V1->getType() == Inst.getType()) { + V1->getType() == Inst.getType() && + ShuffleVectorInst::getShuffleMask(Mask, ShMask)) { // Find constant NewC that has property: // shuffle(NewC, ShMask) = C // If such constant does not exist (example: ShMask=<0,0> and C=<1,2>) // reorder is not possible. A 1-to-1 mapping is not required. Example: // ShMask = <1,1,2,2> and C = <5,5,6,6> --> NewC = - SmallVector ShMask; - ShuffleVectorInst::getShuffleMask(Mask, ShMask); SmallVector NewVecC(VWidth, UndefValue::get(C->getType()->getScalarType())); bool MayChange = true; @@ -1437,6 +1549,12 @@ } Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { + // TODO: This is excessive but much of the following does not work for width + // agnostic vectors. + if (GEP.getType()->isVectorTy() && + cast(GEP.getType())->isScalable()) + return nullptr; + SmallVector Ops(GEP.op_begin(), GEP.op_end()); Type *GEPType = GEP.getType(); Type *GEPEltType = GEP.getSourceElementType(); @@ -1714,6 +1832,48 @@ GEP.getName()); } + if (LI && (GEP.getNumIndices() == 1) && !GEP.getType()->isVectorTy()) { + auto *BB = GEP.getParent(); + auto *L = LI->getLoopFor(BB); + auto *Idx = dyn_cast(GEP.getOperand(1)); + + // Try to reassociate loop invariant index calculations to enable LICM. + if (L && Idx && (Idx->getOpcode() == Instruction::Add)) { + Value *Ptr = GEP.getOperand(0); + Value *InvIdx = Idx->getOperand(0); + Value *NonInvIdx = Idx->getOperand(1); + + if (!L->isLoopInvariant(InvIdx)) + std::swap(InvIdx, NonInvIdx); + + if (L->isLoopInvariant(InvIdx) && !L->isLoopInvariant(NonInvIdx) && + L->isLoopInvariant(Ptr)) { + // Ensure Idx can be eliminated. + auto IsDead = [BB,L] (User *U) { + auto *G = dyn_cast(U); + return G && (G->getNumIndices() == 1) && (G->getParent() == BB) && + L->isLoopInvariant(G->getOperand(0)); + }; + + if (Idx->hasOneUse() || + std::all_of(Idx->user_begin(), Idx->user_end(), IsDead)) { + // rewrite: + // %idx = add i64 %invariant, %indvars.iv + // %gep = getelementptr i32, i32* %ptr, i64 %idx + // as: + // %newptr = getelementptr i32, i32* %ptr, i64 %invariant + // %newgep = getelementptr i32, i32* %newptr, i64 %indvars.iv + auto *NewPtr = GetElementPtrInst::Create(GEP.getResultElementType(), + Ptr, InvIdx, "", &GEP); + auto *NewGEP = GetElementPtrInst::Create(GEP.getResultElementType(), + NewPtr, NonInvIdx); + NewGEP->setIsInBounds(GEP.isInBounds()); + return NewGEP; + } + } + } + } + if (GEP.getNumIndices() == 1) { unsigned AS = GEP.getPointerAddressSpace(); if (GEP.getOperand(1)->getType()->getScalarSizeInBits() == @@ -2336,12 +2496,13 @@ if (match(&BI, m_Br(m_OneUse(m_Cmp(Pred, m_Value(), m_Value())), TrueDest, FalseDest)) && !isCanonicalPredicate(Pred)) { - // Swap destinations and condition. - CmpInst *Cond = cast(BI.getCondition()); - Cond->setPredicate(CmpInst::getInversePredicate(Pred)); - BI.swapSuccessors(); - Worklist.Add(Cond); - return &BI; + if (auto *Cond = dyn_cast(BI.getCondition())) { + // Swap destinations and condition. 
+ Cond->setPredicate(CmpInst::getInversePredicate(Pred)); + BI.swapSuccessors(); + Worklist.Add(Cond); + return &BI; + } } return nullptr; Index: lib/Transforms/Scalar/CMakeLists.txt =================================================================== --- lib/Transforms/Scalar/CMakeLists.txt +++ lib/Transforms/Scalar/CMakeLists.txt @@ -24,6 +24,7 @@ JumpThreading.cpp LICM.cpp LoopAccessAnalysisPrinter.cpp + LoopExprTreeFactoring.cpp LoopSink.cpp LoopDeletion.cpp LoopDataPrefetch.cpp @@ -35,8 +36,10 @@ LoopPassManager.cpp LoopPredication.cpp LoopRerollPass.cpp + LoopRewriteGEPs.cpp LoopRotation.cpp LoopSimplifyCFG.cpp + LoopSpeculativeBoundsCheck.cpp LoopStrengthReduce.cpp LoopUnrollPass.cpp LoopUnrollAndJamPass.cpp @@ -60,6 +63,7 @@ Scalar.cpp Scalarizer.cpp SeparateConstOffsetFromGEP.cpp + SeparateInvariantsFromGepOffset.cpp SimpleLoopUnswitch.cpp SimplifyCFGPass.cpp Sink.cpp Index: lib/Transforms/Scalar/CallSiteSplitting.cpp =================================================================== --- lib/Transforms/Scalar/CallSiteSplitting.cpp +++ lib/Transforms/Scalar/CallSiteSplitting.cpp @@ -136,7 +136,8 @@ CmpInst::Predicate Pred; Value *Cond = BI->getCondition(); - if (!match(Cond, m_ICmp(Pred, m_Value(), m_Constant()))) + if (!isa(Cond) || + !match(Cond, m_ICmp(Pred, m_Value(), m_Constant()))) return; ICmpInst *Cmp = cast(Cond); @@ -494,7 +495,9 @@ continue; Function *Callee = CS.getCalledFunction(); - if (!Callee || Callee->isDeclaration()) + // Downstream note: our PreInlinerTransform pass applied to all functions, + // so we do not skip declaration here, as upstream does. + if (!Callee) continue; // Successful musttail call-site splits result in erased CI and erased BB. Index: lib/Transforms/Scalar/GVN.cpp =================================================================== --- lib/Transforms/Scalar/GVN.cpp +++ lib/Transforms/Scalar/GVN.cpp @@ -94,9 +94,12 @@ STATISTIC(NumGVNEqProp, "Number of equalities propagated"); STATISTIC(NumPRELoad, "Number of loads PRE'd"); +static cl::opt ForceDisablePREsWhichIntroduceArtificialLoopDep( + "force-disable-pre-loopdep", cl::init(false), cl::Hidden); static cl::opt EnablePRE("enable-pre", cl::init(true), cl::Hidden); static cl::opt EnableLoadPRE("enable-load-pre", cl::init(true)); +static cl::opt EnableMemDep("enable-gvn-memdep", cl::init(true)); // Maximum allowed recursion depth. static cl::opt @@ -392,18 +395,13 @@ uint32_t e = assignExpNewValueNum(exp).first; valueNumbering[C] = e; return e; - } else if (AA->onlyReadsMemory(C)) { + } else if (MD && AA->onlyReadsMemory(C)) { Expression exp = createExpr(C); auto ValNum = assignExpNewValueNum(exp); if (ValNum.second) { valueNumbering[C] = ValNum.first; return ValNum.first; } - if (!MD) { - uint32_t e = assignExpNewValueNum(exp).first; - valueNumbering[C] = e; - return e; - } MemDepResult local_dep = MD->getDependency(C); @@ -1059,6 +1057,25 @@ BasicBlock *TmpBB = LoadBB; bool IsSafeToSpeculativelyExecute = isSafeToSpeculativelyExecute(LI); + // If loop information is available and GVN is run before loop-vectorization, + // make sure not to break vectorization opportunities by PRE'ing a partially + // redundant load with a load from a previous iteration that does not dominate + // the partially redundant load, as this case is not supported by + // LoopVectorize's handling of firstOrderRecurrences (that are not + // reductions/inductions). 
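// The shape of the transformation the check above declines, written out in
// plain C++ for illustration only: load PRE turns one of the loads into a
// value carried from the previous iteration, i.e. a first-order recurrence.
#include <cstddef>
#include <vector>

// Before PRE: every iteration loads a[i] and a[i+1] independently.
int sumPairsPlain(const std::vector<int> &a) {
  int s = 0;
  for (std::size_t i = 0; i + 1 < a.size(); ++i)
    s += a[i] + a[i + 1];
  return s;
}

// After PRE: a[i] is reused as the previous iteration's a[i+1]; 'carried'
// now flows across iterations, which LoopVectorize may refuse to handle.
int sumPairsPREd(const std::vector<int> &a) {
  if (a.size() < 2)
    return 0;
  int s = 0, carried = a[0];
  for (std::size_t i = 0; i + 1 < a.size(); ++i) {
    int next = a[i + 1];
    s += carried + next;
    carried = next;
  }
  return s;
}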
+ if (DisablePREsWhichIntroduceArtificialLoopDep && Loops && + !LI->getType()->isVectorTy() && ValuesPerBlock.size()) + if (const Loop *L = Loops->getLoopFor(LI->getParent())) + for (auto &V : ValuesPerBlock) { + if (!V.AV.isCoercedLoadValue()) + continue; + LoadInst *OtherLI = V.AV.getCoercedLoadValue(); + if (L != Loops->getLoopFor(V.BB)) + continue; + if (OtherLI != LI && !OI->dominates(OtherLI, LI)) + return false; + } + // Check that there is no implicit control flow instructions above our load in // its block. If there is an instruction that doesn't always pass the // execution to the following instruction, then moving through it may become @@ -2018,6 +2035,7 @@ DT = &RunDT; VN.setDomTree(DT); TLI = &RunTLI; + Loops = LI; VN.setAliasAnalysis(&RunAA); MD = RunMD; OrderedInstructions OrderedInstrs(DT); @@ -2613,8 +2631,12 @@ public: static char ID; // Pass identification, replacement for typeid - explicit GVNLegacyPass(bool NoLoads = false) - : FunctionPass(ID), NoLoads(NoLoads) { + explicit GVNLegacyPass( + bool NoMemDepAnalysis = !EnableMemDep, + bool DisablePREsWhichIntroduceArtificialLoopDep = false) + : FunctionPass(ID), NoMemDepAnalysis(NoMemDepAnalysis), + DisablePREsWhichIntroduceArtificialLoopDep( + DisablePREsWhichIntroduceArtificialLoopDep) { initializeGVNLegacyPassPass(*PassRegistry::getPassRegistry()); } @@ -2623,13 +2645,15 @@ return false; auto *LIWP = getAnalysisIfAvailable(); - + Impl.DisablePREsWhichIntroduceArtificialLoopDep = + DisablePREsWhichIntroduceArtificialLoopDep || + ForceDisablePREsWhichIntroduceArtificialLoopDep; return Impl.runImpl( F, getAnalysis().getAssumptionCache(F), getAnalysis().getDomTree(), getAnalysis().getTLI(), getAnalysis().getAAResults(), - NoLoads ? nullptr + NoMemDepAnalysis ? nullptr : &getAnalysis().getMemDep(), LIWP ? &LIWP->getLoopInfo() : nullptr, &getAnalysis().getORE()); @@ -2639,7 +2663,7 @@ AU.addRequired(); AU.addRequired(); AU.addRequired(); - if (!NoLoads) + if (!NoMemDepAnalysis) AU.addRequired(); AU.addRequired(); @@ -2650,7 +2674,8 @@ } private: - bool NoLoads; + bool NoMemDepAnalysis; + bool DisablePREsWhichIntroduceArtificialLoopDep; GVN Impl; }; @@ -2667,6 +2692,9 @@ INITIALIZE_PASS_END(GVNLegacyPass, "gvn", "Global Value Numbering", false, false) // The public interface to this file... -FunctionPass *llvm::createGVNPass(bool NoLoads) { - return new GVNLegacyPass(NoLoads); +FunctionPass * +llvm::createGVNPass(bool NoMemDepAnalysis, + bool DisablePREsWhichIntroduceArtificialLoopDep) { + return new GVNLegacyPass(NoMemDepAnalysis, + DisablePREsWhichIntroduceArtificialLoopDep); } Index: lib/Transforms/Scalar/IndVarSimplify.cpp =================================================================== --- lib/Transforms/Scalar/IndVarSimplify.cpp +++ lib/Transforms/Scalar/IndVarSimplify.cpp @@ -949,6 +949,7 @@ SmallVectorImpl &DeadInsts; SmallPtrSet Widened; + DenseMap WideMap; SmallVector NarrowIVUsers; enum ExtendKind { ZeroExtended, SignExtended, Unknown }; @@ -1377,12 +1378,19 @@ // Stop traversing the def-use chain at inner-loop phis or post-loop phis. if (PHINode *UsePhi = dyn_cast(DU.NarrowUse)) { if (LI->getLoopFor(UsePhi->getParent()) != L) { - // For LCSSA phis, sink the truncate outside the loop. - // After SimplifyCFG most loop exit targets have a single predecessor. - // Otherwise fall back to a truncate within the loop. - if (UsePhi->getNumOperands() != 1) - truncateIVUse(DU, DT, LI); - else { + SmallVector, 4> WideIncVals; + + // Look through the incoming values for the narrow use phi. 
If we have + // an existing wide value for all of them (including the current def + // being considered), then we can continue. + if(llvm::all_of(UsePhi->blocks(), [&](BasicBlock *BB) { + auto *I = dyn_cast(UsePhi->getIncomingValueForBlock(BB)); + if (I == nullptr || WideMap.count(I) == 0) + return false; + + WideIncVals.push_back({WideMap[I], BB}); + return true; + })) { // Widening the PHI requires us to insert a trunc. The logical place // for this trunc is in the same BB as the PHI. This is not possible if // the BB is terminated by a catchswitch. @@ -1392,14 +1400,18 @@ PHINode *WidePhi = PHINode::Create(DU.WideDef->getType(), 1, UsePhi->getName() + ".wide", UsePhi); - WidePhi->addIncoming(DU.WideDef, UsePhi->getIncomingBlock(0)); + for (auto &WIV : WideIncVals) + WidePhi->addIncoming(WIV.first, WIV.second); IRBuilder<> Builder(&*WidePhi->getParent()->getFirstInsertionPt()); Value *Trunc = Builder.CreateTrunc(WidePhi, DU.NarrowDef->getType()); UsePhi->replaceAllUsesWith(Trunc); DeadInsts.emplace_back(UsePhi); LLVM_DEBUG(dbgs() << "INDVARS: Widen lcssa phi " << *UsePhi << " to " << *WidePhi << "\n"); - } + } else + // Otherwise fall back to a truncate within the loop. + truncateIVUse(DU, DT, LI); + return nullptr; } } @@ -1608,6 +1620,7 @@ Widened.insert(OrigPhi); pushNarrowIVUsers(OrigPhi, WidePhi); + WideMap.insert({OrigPhi, WidePhi}); while (!NarrowIVUsers.empty()) { NarrowIVDefUse DU = NarrowIVUsers.pop_back_val(); @@ -1616,9 +1629,12 @@ // use_iterator across it. Instruction *WideUse = widenIVUse(DU, Rewriter); + // Record mapping of narrow to wide use // Follow all def-use edges from the previous narrow use. - if (WideUse) + if (WideUse) { + WideMap.insert({DU.NarrowUse, WideUse}); pushNarrowIVUsers(DU.NarrowUse, WideUse); + } // widenIVUse may have removed the def-use edge. if (DU.NarrowDef->use_empty()) Index: lib/Transforms/Scalar/LICM.cpp =================================================================== --- lib/Transforms/Scalar/LICM.cpp +++ lib/Transforms/Scalar/LICM.cpp @@ -641,7 +641,7 @@ // writes to this memory in the loop, we can hoist or sink. if (AliasAnalysis::onlyAccessesArgPointees(Behavior)) { for (Value *Op : CI->arg_operands()) - if (Op->getType()->isPointerTy() && + if (Op->getType()->isPtrOrPtrVectorTy() && pointerInvalidatedByLoop(Op, MemoryLocation::UnknownSize, AAMDNodes(), CurAST)) return false; Index: lib/Transforms/Scalar/LoopExprTreeFactoring.cpp =================================================================== --- /dev/null +++ lib/Transforms/Scalar/LoopExprTreeFactoring.cpp @@ -0,0 +1,433 @@ +//===- LoopExprTreeFactoring.cpp ------------------------------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This pass rewrites fadd/fmul trees into mathematically simplified form if the +// correct fast-math attributes are set. 
+// +// For example: +// (c0 * x) + (c1 * x) + (c2 * x) +// <=> (c0 + c1 + c2) * x +// +// (c0 * x) + (c0 * c1 * x) + (c0 * c2 * x) +// <=> c0 * ((1 + c1 + c2) * x) +// +//===----------------------------------------------------------------------===// + +#include "llvm/ADT/StringSwitch.h" +#include "llvm/Analysis/LoopPass.h" +#include "llvm/IR/PatternMatch.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Scalar/LoopPassManager.h" +#include "llvm/Transforms/Utils/LoopUtils.h" + +using namespace llvm; +using namespace llvm::PatternMatch; + +#define DEBUG_TYPE "loop-factor-expr-trees" + +static cl::opt IgnoreCostModel( + "loop-factor-ignore-cost-model", cl::init(false), cl::Hidden, + cl::desc("Force factorisation even when not beneficial.")); + +namespace { +// We keep a map of 'ordered values', per DFS-order, to keep the +// algorithm deterministic since we are using various unordered data +// structures like DenseMap. +DenseMap OrderedValues; +} + +typedef std::function MatchFunc; + +/// Accmulator tree of multiplicands (Value*) and addends (AccumTree). +// Given the following tree hierarchy: +// T = ( {v1,v2}, {T2, T3} ) +// T2 = ( {}, {v3} ) +// T3 = ( {}, {v4} ) +// The generated expression is: +// (v1*v2) * (v3 + v4) +class AccumTree { + /// The multiplicands that are common to all subtrees + DenseMap CommonMultiplicands; + + /// A list of subtrees (Addends) + SmallVector, 2> Addends; + +public: + AccumTree() {} + + AccumTree(ArrayRef Values) { + for (auto *I : Values) + CommonMultiplicands.FindAndConstruct(I).second++; + } + + AccumTree(const AccumTree &Other) { + for (auto &I : Other.CommonMultiplicands) + CommonMultiplicands[I.first] = I.second; + for (auto &A : Other.Addends) { + auto UniqueCopy = llvm::make_unique(*A); + addAddend(UniqueCopy); + } + } + + /// Transfer ownership of unique_ptr C to this subtree. + void addAddend(std::unique_ptr &C) { + Addends.push_back(std::move(C)); + } + + /// Partition the accumulator tree into a multi-layer tree + /// where we try to hoist out invariant multiplicands. + void partition(Loop *L); + + /// CodeGen this expression tree before insertion point II. + Value *codeGen(Loop *L, BasicBlock::iterator II); + + /// Cost of generating this tree, defined as: + /// Cost of all individual subtrees + /// + Cost of adding the subtrees together + /// + Cost of multiplying the multiplicands + unsigned getCost(Loop *L) const { + // First all subtrees. + unsigned AddCost = 0; + for (auto &A : Addends) + AddCost += A->getCost(L); + + if (Addends.size()) + AddCost += Addends.size() - 1; + + unsigned MultCost = 0; + for (auto &KV : CommonMultiplicands) + if (KV.second && L->isLoopInvariant(KV.first)) + MultCost += 1; + else + MultCost += KV.second; + + // For leaf nodes, we don't need the extra multiply with + // the addend + if (MultCost && !Addends.size()) + MultCost--; + return AddCost + MultCost; + } + +private: + /// Helper function to recursively generate the AccumTree expression. + Value *codeGenTreeList(Loop *L, BasicBlock::iterator II, + ArrayRef > Values) { + // For an empty subtree, a common multiplicand must have been factored out + // and so we can safely return the addend '1.0'. + // i.e. 
(a * x0) + (a) => a * (x0 + 1.0) + if (Values.size() == 0) + return ConstantFP::get(II->getType(), 1.0); + + Value *Res = Values.front()->codeGen(L, II); + for (auto &Val : Values.drop_front()) { + Value *Tmp = Val->codeGen(L, II); + Instruction *Op = BinaryOperator::Create(Instruction::FAdd, Res, Tmp); + Op->copyFastMathFlags(&*II); + Op->insertBefore(&*II); + Res = Op; + } + + return Res; + } + + /// Helper function to generate a multiply expression from a list of values. + Value *codeGenMulList(BasicBlock::iterator II, ArrayRef Values) { + // For an empty multiplicand list, multiplying with '1.0' gives the same + // result. + if (Values.size() == 0) + return ConstantFP::get(II->getType(), 1.0); + + Value *Res = Values.front(); + for (auto *Val : Values.drop_front()) { + Instruction *Op = BinaryOperator::Create(Instruction::FMul, Res, Val); + Op->copyFastMathFlags(&*II); + Op->insertBefore(&*II); + Res = Op; + } + + return Res; + } + +public: +#define INDENT(n) (std::string((n), ' ')) + void dump(raw_ostream &OS, unsigned Indent = 0) { + OS << INDENT(Indent) << "(\n"; + for (auto &I : CommonMultiplicands) + for (unsigned J = 0; J< I.second; ++J) + OS << INDENT(Indent) << *I.first << " *\n"; + + for (auto &A : Addends) { + A->dump(OS, Indent + 4); + if (A != Addends.back()) + OS << " + "; + OS << "\n"; + } + + OS << INDENT(Indent) << ")"; + } +}; + +Value *AccumTree::codeGen(Loop *L, BasicBlock::iterator II) { + // Recursively generate the Addends + Value *Add = codeGenTreeList(L, II, Addends); + + // Expand the extracted values into an array and sort by invariants first. + SmallVector Expanded; + for (auto &IV : CommonMultiplicands) + for (unsigned NumV = 0; NumV < IV.second; ++NumV) + Expanded.push_back(IV.first); + + auto InvariantsOnRHS = [&L](Value *A, Value *B) { + if (L->isLoopInvariant(A) && !L->isLoopInvariant(B)) + return true; + else if (L->isLoopInvariant(B)) + return false; + else + return ::OrderedValues[A] < ::OrderedValues[B]; + }; + std::stable_sort(Expanded.begin(), Expanded.end(), InvariantsOnRHS); + Expanded.push_back(Add); + return codeGenMulList(II, Expanded); +} + +void AccumTree::partition(Loop *L) { + // Create a histogram with multiplicands as bins for all addends. + DenseMap Hist; + for (auto &A : Addends) { + for (auto &I : A->CommonMultiplicands) + Hist[I.first] += I.second != 0; + } + + // Find the most common multiplicand. + typedef std::map::value_type MapType; + const auto Max = + std::max_element(Hist.begin(), Hist.end(), [](MapType A, MapType B) { + return A.second < B.second || + (A.second == B.second && + ::OrderedValues[A.first] < ::OrderedValues[B.first]); + }); + + if (Max->second < 2) + return; + + auto MatchTree = llvm::make_unique(); + auto NoMatchTree = llvm::make_unique(); + + // Factor out a common multiplicand + Value *MaxVal = Max->first; + MatchTree->CommonMultiplicands[MaxVal] = 1; + + // (while maintaining program order of addends) + std::reverse(Addends.begin(), Addends.end()); + while (Addends.size()) { + auto ACpy = Addends.pop_back_val(); + if (ACpy->CommonMultiplicands[MaxVal] > 0) { + ACpy->CommonMultiplicands[MaxVal]--; + MatchTree->addAddend(ACpy); + } else + NoMatchTree->addAddend(ACpy); + } + + if (MatchTree->Addends.size()) { + MatchTree->partition(L); + addAddend(MatchTree); + } + + if (NoMatchTree->Addends.size()) { + NoMatchTree->partition(L); + addAddend(NoMatchTree); + } + + // Try to squash into a single subtree. 
+ if (Addends.size() == 1) { + auto Addend = Addends.pop_back_val(); + for (auto &I : Addend->CommonMultiplicands) + CommonMultiplicands[I.first] += I.second; + for (auto &A : Addend->Addends) + addAddend(A); + } +} + +// Find expression in the loop that match BinopMatch. +static void findNodes(Loop *L, Value *V, SmallVectorImpl &Res, + MatchFunc BinopMatch, bool IncludeInnerNodes = false) { + Value *LHS, *RHS; + if (L->isLoopInvariant(V) || !BinopMatch(V, LHS, RHS)) + Res.push_back(V); + else { + if (IncludeInnerNodes) + Res.push_back(V); + findNodes(L, LHS, Res, BinopMatch, IncludeInnerNodes); + findNodes(L, RHS, Res, BinopMatch, IncludeInnerNodes); + } +} + +// Calculates the cost of an expression tree, with a selection given +// by 'MatchSequence', which matches the IR in the given order. +// (e.g. For Add-chains with multiply-subtrees, we'll look for adds first, +// and then multiplies) +// DeadCodeCost is the cost of the remaining tree that will remain if +// we choose to rewrite the entire expression tree. +static int calculateCost(Loop *L, Value *V, int &DeadCodeCost, + ArrayRef MatchSequence, bool IsRoot) { + Value *LHS, *RHS; + if (L->isLoopInvariant(V)) + return 0; + + if (!MatchSequence[0](V, LHS, RHS)) { + // We first look for Adds, and only then we look for Muls + if (MatchSequence.size() > 1 && MatchSequence[1](V, LHS, RHS)) + MatchSequence = MatchSequence.slice(1); + else + return 0; + } + + int DeadCodeCostLHS = 0; + int DeadCodeCostRHS = 0; + unsigned SubTreeCost = + calculateCost(L, LHS, DeadCodeCostLHS, MatchSequence, false) + + calculateCost(L, RHS, DeadCodeCostRHS, MatchSequence, false) + 1; + + bool SubtreeIsDead = !IsRoot && V->getNumUses() > 1; + DeadCodeCost = + SubtreeIsDead ? SubTreeCost : DeadCodeCostLHS + DeadCodeCostRHS; + + return SubTreeCost; +} + +bool MatchFAdd(Value *V, Value *&LHS, Value *&RHS) { + return match(V, m_FAdd(m_Value(LHS), m_Value(RHS))) && + cast(V)->isFast(); +} + +bool MatchFMul(Value *V, Value *&LHS, Value *&RHS) { + return match(V, m_FMul(m_Value(LHS), m_Value(RHS))) && + cast(V)->isFast(); +} + +static Value *breakAddChain(Instruction *I, Loop *L) { + LLVM_DEBUG(dbgs() << "LETF: Attempting to break chain for " << *I << "\n"); + + SmallVector Addends; + findNodes(L, I, Addends, MatchFAdd); + LLVM_DEBUG(dbgs() << "LETF: Found chain of " << Addends.size() << " Add(s).\n"); + + if (Addends.size() < 2) + return nullptr; + + AccumTree T; + OrderedValues.clear(); + for (auto *A : Addends) { + // Create the partitioning tree from here. 
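// Compact model of one greedy partition() step above: pull the multiplicand
// shared by the most addends out of the terms that contain it, leaving the
// rest untouched. Symbols are ints and terms are multisets here, not
// AccumTree nodes; names are hypothetical.
#include <algorithm>
#include <map>
#include <vector>

using Term = std::map<int, unsigned>;     // symbol -> multiplicity
using Sum = std::vector<Term>;

struct FactorStep { int symbol; Sum matched, rest; };

FactorStep factorMostCommon(const Sum &S) {
  std::map<int, unsigned> Hist;           // number of terms containing a symbol
  for (const Term &T : S)
    for (const auto &KV : T)
      if (KV.second)
        ++Hist[KV.first];

  auto Max = std::max_element(
      Hist.begin(), Hist.end(),
      [](const auto &A, const auto &B) { return A.second < B.second; });
  if (Max == Hist.end() || Max->second < 2)
    return {-1, {}, S};                   // no symbol worth factoring out

  FactorStep R{Max->first, {}, {}};
  for (Term T : S) {
    auto It = T.find(R.symbol);
    if (It != T.end() && It->second) {
      --It->second;                       // one occurrence is factored out
      R.matched.push_back(T);
    } else {
      R.rest.push_back(T);
    }
  }
  return R;                               // x*c0 + x*c1 + y -> x*(c0 + c1) + y
}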
+ SmallVector MulOpnds; + findNodes(L, A, MulOpnds, MatchFMul); + LLVM_DEBUG(dbgs() << "LETF: " << *A << " is a chain of " << MulOpnds.size() + << " multiple(s).\n"); + + int Cnt = 0; + for (auto *V : MulOpnds) + OrderedValues[V] = Cnt++; + + auto AT = llvm::make_unique(MulOpnds); + T.addAddend(AT); + } + + int DeadCodeCost = 0; + int CostBefore = + calculateCost(L, I, DeadCodeCost, { MatchFAdd, MatchFMul }, true); + + T.partition(L); + int CostAfter = T.getCost(L); + LLVM_DEBUG(dbgs() << "LETF: Cost model reports Before=" << CostBefore << " After=" + << CostAfter << " DeadCodeCost=" << DeadCodeCost << "\n"); + + if ((CostAfter + DeadCodeCost < CostBefore) || IgnoreCostModel) { + LLVM_DEBUG(dbgs() << "LETF: Accepting factorised version.\n"); + return T.codeGen(L, I->getIterator()); + } + + return nullptr; +} + +namespace { +class LoopExprTreeFactoringPass : public LoopPass { +public: + static char ID; // Pass ID, replacement for typeid + LoopExprTreeFactoringPass() : LoopPass(ID) { + initializeLoopExprTreeFactoringPassPass(*PassRegistry::getPassRegistry()); + } + + bool processLoop(Loop *L) { + // Keep a set of nodes that we don't want to revisit + std::set ProcessedNodes; + + for (auto *BB : L->blocks()) { + SmallVector WorkList; + + // First build up a worklist + BasicBlock::reverse_iterator BI, BE; + for (BI = BB->rbegin(), BE = BB->rend(); BI != BE; BI++) { + Instruction *I = &*BI; + + if (I->getOpcode() != Instruction::FAdd) + continue; + + WorkList.push_back(I); + } + + for (auto *I : WorkList) { + if (ProcessedNodes.count(I)) + continue; + + Value *New = breakAddChain(I, L); + if (!New) + continue; + + if (auto *NewI = dyn_cast(New)) { + // If we made a change, discard both the old and new chain + // from the worklist. + SmallVector Addends; + findNodes(L, I, Addends, MatchFAdd, true); + ProcessedNodes.insert(Addends.begin(), Addends.end()); + + I->replaceAllUsesWith(NewI); + I->eraseFromParent(); + } + } + } + + OrderedValues.clear(); + return !ProcessedNodes.empty(); + } + + bool runOnLoop(Loop *L, LPPassManager &LPM) override { + if (skipLoop(L)) + return false; + + return processLoop(L); + } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.setPreservesCFG(); + } +}; +} + +char LoopExprTreeFactoringPass::ID = 0; +INITIALIZE_PASS_BEGIN(LoopExprTreeFactoringPass, DEBUG_TYPE, + "Loop Expression Tree Factoring", false, false) +INITIALIZE_PASS_DEPENDENCY(LoopPass) +INITIALIZE_PASS_END(LoopExprTreeFactoringPass, DEBUG_TYPE, + "Loop Expression Tree Factoring", false, false) + +Pass *llvm::createLoopExprTreeFactoringPass() { + return new LoopExprTreeFactoringPass(); +} Index: lib/Transforms/Scalar/LoopInterchange.cpp =================================================================== --- lib/Transforms/Scalar/LoopInterchange.cpp +++ lib/Transforms/Scalar/LoopInterchange.cpp @@ -615,7 +615,7 @@ return llvm::none_of(Ins->users(), [=](User *U) -> bool { auto *UserIns = dyn_cast(U); RecurrenceDescriptor RD; - return !UserIns || !RecurrenceDescriptor::isReductionPHI(UserIns, L, RD); + return !UserIns || !RecurrenceDescriptor::isReductionPHI(UserIns, L, SE, RD); }); } @@ -712,7 +712,7 @@ InductionDescriptor ID; if (InductionDescriptor::isInductionPHI(&PHI, L, SE, ID)) Inductions.push_back(&PHI); - else if (RecurrenceDescriptor::isReductionPHI(&PHI, L, RD)) + else if (RecurrenceDescriptor::isReductionPHI(&PHI, L, SE, RD)) Reductions.push_back(&PHI); else { LLVM_DEBUG( Index: lib/Transforms/Scalar/LoopLoadElimination.cpp 
=================================================================== --- lib/Transforms/Scalar/LoopLoadElimination.cpp +++ lib/Transforms/Scalar/LoopLoadElimination.cpp @@ -99,7 +99,9 @@ assert(LoadPtrType->getPointerAddressSpace() == StorePtr->getType()->getPointerAddressSpace() && - LoadType == StorePtr->getType()->getPointerElementType() && + LoadType->getScalarSizeInBits() == + StorePtr->getType()->getPointerElementType() + ->getScalarSizeInBits() && "Should be a known dependence"); // Currently we only support accesses with unit stride. FIXME: we should be @@ -209,7 +211,8 @@ continue; // Only progagate the value if they are of the same type. - if (Store->getPointerOperandType() != Load->getPointerOperandType()) + if (Store->getPointerOperandType()->getScalarSizeInBits() != + Load->getPointerOperandType()->getScalarSizeInBits()) continue; Candidates.emplace_front(Load, Store); @@ -435,7 +438,25 @@ PHINode *PHI = PHINode::Create(Initial->getType(), 2, "store_forwarded", &L->getHeader()->front()); PHI->addIncoming(Initial, PH); - PHI->addIncoming(Cand.Store->getOperand(0), L->getLoopLatch()); + + Value *StoreValue; + + Type *LoadType = Initial->getType(); + Type *StoreType = Cand.Store->getOperand(0)->getType(); + + assert(LoadType->getScalarSizeInBits() == + StoreType->getScalarSizeInBits() && + "The type sizes should match!"); + + if (LoadType != StoreType) { + // Need a bitcast to convert to the loaded type + StoreValue = + CastInst::Create(Instruction::BitCast, Cand.Store->getOperand(0), + LoadType, "store_forward_cast", Cand.Store); + } else + StoreValue = Cand.Store->getOperand(0); + + PHI->addIncoming(StoreValue, L->getLoopLatch()); Cand.Load->replaceAllUsesWith(PHI); } Index: lib/Transforms/Scalar/LoopRewriteGEPs.cpp =================================================================== --- /dev/null +++ lib/Transforms/Scalar/LoopRewriteGEPs.cpp @@ -0,0 +1,152 @@ +//===- LoopRewriteGEPs.cpp ------------------------------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This pass attempts to provide GVN with more opportunities to eliminate +// partially redundant loads. One of the barriers is when GVN cannot perform +// PHI translation, which occurs when address calculations use a common index +// (most likely the loop's induction variable) rather than a common base. 
+// +// We solve this by transforming: +// %p1 = getelementptr float, float* %base, i64 %indvars.iv +// %p2 = getelementptr float, float* %anotherbase, i64 %indvars.iv +// into: +// %p1 = getelementptr float, float* %base, i64 %indvars.iv +// %tmp = add i64 indvars.iv, ((%anotherbase-%base)/sizeof(float)) +// %p2 = getelementptr float, float* %base, i64 %tmp +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/LoopPass.h" +#include "llvm/Analysis/ScalarEvolution.h" +#include "llvm/Analysis/ScalarEvolutionExpressions.h" +#include "llvm/Support/Debug.h" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Scalar/LoopPassManager.h" +#include "llvm/Transforms/Utils/LoopUtils.h" + +using namespace llvm; + +#define DEBUG_TYPE "rewrite-geps-in-loop" + +namespace { +class LoopRewriteGEPsPass : public LoopPass { +public: + static char ID; + ScalarEvolution *SE; + + LoopRewriteGEPsPass() : LoopPass(ID) { + initializeLoopRewriteGEPsPassPass(*PassRegistry::getPassRegistry()); + } + + bool processLoop(Loop *L) { + if (!L->empty()) + return false; + + bool Changed = false; + + for (BasicBlock *BB : L->blocks()) { + const DataLayout &DL = BB->getModule()->getDataLayout(); + SmallVector VisitedGEPs; + + for (Instruction &I : *BB) { + auto *GEP = dyn_cast(&I); + if (!GEP) + continue; + if (GEP->getNumOperands() != 2) + continue; + // FIXME: Vector GEPs are not SCEVable + if (GEP->getType()->isVectorTy()) + continue; + + Type* DataTy = GEP->getResultElementType(); + int64_t DataSize = DL.getTypeAllocSize(DataTy); + Value* OrigIdx = GEP->getOperand(1); + + GetElementPtrInst *RelatedGEP = nullptr; + Constant *IdxInc; + + // Look for a GEP that's identical in all but base pointer. + for (auto VisitedGEP : VisitedGEPs) { + if (VisitedGEP->getResultElementType() != DataTy) + continue; + if (VisitedGEP->getOperand(1) != OrigIdx) + continue; + + // Calculate the distance betwen the two base pointers... + auto *Dist = SE->getMinusSCEV(SE->getSCEV(GEP), + SE->getSCEV(VisitedGEP)); + + //...and if it's constant and of the correct scale. + if (isa(Dist)) { + APInt CDist = cast(Dist)->getAPInt(); + if (CDist.srem(DataSize)) + continue; + + // Can we calculate NewIdx without sext/trunc of the original? + auto *IdxTy = cast(OrigIdx->getType()); + APInt IdxIncVal1 = CDist.sdiv(DataSize); + APInt IdxIncVal2 = IdxIncVal1.sextOrTrunc(IdxTy->getBitWidth()); + if (!APInt::isSameValue(IdxIncVal1, IdxIncVal2)) + continue; + + RelatedGEP = VisitedGEP; + IdxInc = ConstantInt::get(IdxTy, IdxIncVal2); + break; + } + } + + if (!RelatedGEP) { + // Nothing we can do, but it's base pointer might be useful later. + VisitedGEPs.push_back(GEP); + continue; + } + + LLVM_DEBUG(dbgs() << "Rewrite GEP: Original GEP: " << *GEP << '\n'); + + // Make a new better GEP. 
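// Sketch of the legality check behind the rewrite above: the byte distance
// between the two base pointers must be a whole number of elements, and the
// resulting index increment must fit the GEP's index type without needing an
// extension. Plain integers stand in for SCEV constants; names hypothetical.
#include <cstdint>
#include <optional>

std::optional<int64_t> gepIndexIncrement(int64_t ByteDist, int64_t ElemSize,
                                         unsigned IdxBits) {
  if (ElemSize <= 0 || ByteDist % ElemSize != 0)
    return std::nullopt;                  // not a whole number of elements
  int64_t Inc = ByteDist / ElemSize;
  if (IdxBits >= 1 && IdxBits < 64) {     // must survive sext/trunc unchanged
    int64_t MinV = -(int64_t(1) << (IdxBits - 1));
    int64_t MaxV = (int64_t(1) << (IdxBits - 1)) - 1;
    if (Inc < MinV || Inc > MaxV)
      return std::nullopt;
  }
  return Inc;
}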
+ auto *NewIdx = BinaryOperator::CreateAdd(OrigIdx, IdxInc, "", GEP); + GEP->setOperand(0, RelatedGEP->getOperand(0)); + GEP->setOperand(1, NewIdx); + Changed = true; + + LLVM_DEBUG(dbgs() << "Rewrite GEP: Related GEP : " << *RelatedGEP << '\n'); + LLVM_DEBUG(dbgs() << "Rewrite GEP: New GEP Idx : " << *NewIdx << '\n'); + LLVM_DEBUG(dbgs() << "Rewrite GEP: New GEP : " << *GEP << '\n'); + } + } + + return Changed; + } + + bool runOnLoop(Loop *L, LPPassManager &LPM) override { + if (skipLoop(L)) + return false; + + SE = &getAnalysis().getSE(); + return processLoop(L); + } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.setPreservesCFG(); + AU.addRequired(); + } +}; +} + +char LoopRewriteGEPsPass::ID = 0; +INITIALIZE_PASS_BEGIN(LoopRewriteGEPsPass, DEBUG_TYPE, + "Rewrite GEPs in Loop", false, false) +INITIALIZE_PASS_DEPENDENCY(LoopPass) +INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass) +INITIALIZE_PASS_END(LoopRewriteGEPsPass, DEBUG_TYPE, + "Rewrite GEPs in Loop", false, false) + +Pass *llvm::createLoopRewriteGEPsPass() { + return new LoopRewriteGEPsPass(); +} Index: lib/Transforms/Scalar/LoopSpeculativeBoundsCheck.cpp =================================================================== --- /dev/null +++ lib/Transforms/Scalar/LoopSpeculativeBoundsCheck.cpp @@ -0,0 +1,636 @@ +//===- LoopSpeculativeBoundsCheck.cpp - Versioning for may-alias tripcount-===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// TODO: Rewrite description; this was based on LoopVersioningLICM +// +// When alias analysis is uncertain about the aliasing between any two accesses, +// it will return MayAlias. This uncertainty from alias analysis restricts LICM +// from proceeding further. In cases where alias analysis is uncertain we might +// use loop versioning as an alternative. +// +// Loop Versioning will create a version of the loop with aggressive aliasing +// assumptions in addition to the original with conservative (default) aliasing +// assumptions. The version of the loop making aggressive aliasing assumptions +// will have all the memory accesses marked as no-alias. These two versions of +// loop will be preceded by a memory runtime check. This runtime check consists +// of bound checks for all unique memory accessed in loop, and it ensures the +// lack of memory aliasing. The result of the runtime check determines which of +// the loop versions is executed: If the runtime check detects any memory +// aliasing, then the original loop is executed. Otherwise, the version with +// aggressive aliasing assumptions is used. +// +// Following are the top level steps: +// +// a) Perform LoopSpeculativeBoundsCheck's feasibility check. +// b) If loop is a candidate for versioning then create a memory bound check, +// by considering all the memory accesses in loop body. +// c) Clone original loop and set all memory accesses as no-alias in new loop. +// d) Set original loop & versioned loop as a branch target of the runtime check +// result. 
+// +// It transforms loop as shown below: +// +// +----------------+ +// |Runtime Memcheck| +// +----------------+ +// | +// +----------+----------------+----------+ +// | | +// +---------+----------+ +-----------+----------+ +// |Orig Loop Preheader | |Cloned Loop Preheader | +// +--------------------+ +----------------------+ +// | | +// +--------------------+ +----------------------+ +// |Orig Loop Body | |Cloned Loop Body | +// +--------------------+ +----------------------+ +// | | +// +----------+--------------+-----------+ +// | +// +--------+--------+ +// |Exit Block (Join)| +// +-----------------+ +// +//===----------------------------------------------------------------------===// + +#include "llvm/ADT/MapVector.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/AliasSetTracker.h" +#include "llvm/Analysis/ConstantFolding.h" +#include "llvm/Analysis/GlobalsModRef.h" +#include "llvm/Analysis/LoopAccessAnalysis.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/LoopPass.h" +#include "llvm/Analysis/ScalarEvolution.h" +#include "llvm/Analysis/ScalarEvolutionExpander.h" +#include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Analysis/ValueTracking.h" +#include "llvm/Analysis/VectorUtils.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/MDBuilder.h" +#include "llvm/IR/PatternMatch.h" +#include "llvm/IR/PredIteratorCache.h" +#include "llvm/IR/Type.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/Transforms/Utils/Cloning.h" +#include "llvm/Transforms/Utils/LoopUtils.h" +#include "llvm/Transforms/Utils/LoopVersioning.h" +#include "llvm/Transforms/Utils/ValueMapper.h" + +#define DEBUG_TYPE "loop-speculative-bounds-check" + +using namespace llvm; +using namespace llvm::PatternMatch; + +namespace { +struct LoopSpeculativeBoundsCheck : public LoopPass { + static char ID; + + bool runOnLoop(Loop *L, LPPassManager &LPM) override; + bool processLoop(Loop *L); + + using llvm::Pass::doFinalization; + + bool doFinalization() override { return false; } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.setPreservesCFG(); + AU.addRequired(); + AU.addRequired(); + AU.addRequiredID(LCSSAID); + AU.addRequired(); + AU.addRequired(); + AU.addRequiredID(LoopSimplifyID); + AU.addRequired(); + AU.addRequired(); + AU.addPreserved(); + AU.addPreserved(); + } + + LoopSpeculativeBoundsCheck() + : LoopPass(ID), AA(nullptr), SE(nullptr), LI(nullptr), DT(nullptr), + TLI(nullptr), LAA(nullptr), LAI(nullptr), IsReadOnlyLoop(true) { + initializeLoopSpeculativeBoundsCheckPass(*PassRegistry::getPassRegistry()); + } + + AliasAnalysis *AA; // Current AliasAnalysis information + ScalarEvolution *SE; // Current ScalarEvolution + LoopInfo *LI; // Current LoopInfo + DominatorTree *DT; // Dominator Tree for the current Loop. + TargetLibraryInfo *TLI; // TargetLibraryInfo for constant folding. + LoopAccessLegacyAnalysis *LAA; // Current LoopAccessLegacyAnalysis + const LoopAccessInfo *LAI; // Current Loop's LoopAccessInfo + + // TODO: Convert all uses to parameters? + ValueToValueMapTy VMap; + SmallVector DefsUsedOutside; + + bool IsReadOnlyLoop; // Read only loop marker. 
+ Value *GlobalBoundPtr; // Pointer to global loop bound variable + LoadInst *GlobalLoadInst; // Pointer to load from global + + bool isSuitableLoop(Loop *, AliasSetTracker &); + bool legalLoopStructure(Loop *); + bool legalLoopInstructions(Loop *); + bool checkMemoryAccesses(AliasSetTracker &); + bool instructionSafeForVersioning(Loop *, Instruction *); + void addPHINodes(Loop *, Loop *); + StringRef getPassName() const override { return "Loop SBC"; } +}; +} + +/// \brief Check loop structure and confirms it's good +/// for LoopSpeculativeBoundsCheck. +bool LoopSpeculativeBoundsCheck::legalLoopStructure(Loop *L) { + // If we can compute a tripcount, we don't need to do anything. + const SCEV *ExitCount = SE->getBackedgeTakenCount(L); + if (ExitCount != SE->getCouldNotCompute()) { + LLVM_DEBUG(dbgs() << " loop has known exit count, speculation unnecessary\n"); + return false; + } + + // Loop must have a preheader, if not return false. + if (!L->getLoopPreheader()) { + LLVM_DEBUG(dbgs() << " loop preheader is missing\n"); + return false; + } + // Loop should be innermost loop, if not return false. + if (L->getSubLoops().size()) { + LLVM_DEBUG(dbgs() << " loop is not innermost\n"); + return false; + } + // Loop should have a single backedge, if not return false. + // TODO: Revisit once we have multi-backedge support in the SLV? + if (L->getNumBackEdges() != 1) { + LLVM_DEBUG(dbgs() << " loop has multiple backedges\n"); + return false; + } + // Loop must have a single exiting block, if not return false. + // TODO: Relax restriction in future to help out the SLV. + if (!L->getExitingBlock()) { + LLVM_DEBUG(dbgs() << " loop has multiple exiting block\n"); + return false; + } + // We only handle bottom-tested loop, i.e. loop in which the condition is + // checked at the end of each iteration. With that we can assume that all + // instructions in the loop are executed the same number of times. + if (L->getExitingBlock() != L->getLoopLatch()) { + LLVM_DEBUG(dbgs() << " loop is not bottom tested\n"); + return false; + } + // Parallel loops must not have aliasing loop-invariant memory accesses. + // Hence we don't need to version anything in this case. + if (L->isAnnotatedParallel()) { + LLVM_DEBUG(dbgs() << " Parallel loop is not worth versioning\n"); + return false; + } + + // Loop should have a dedicated exit block, if not return false. + if (!L->hasDedicatedExits()) { + LLVM_DEBUG(dbgs() << " loop does not have dedicated exit blocks\n"); + return false; + } + + // Find an appropriate compare against a sign-extended load from a global + // pointer; this prevents vectorization, so being able to conditionally + // hoist the load out of the loop will allow scalar evolution to figure + // things out and allow vectorization. 
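+  //
+  // For example, the latch terminator shape we try to match is (hypothetical
+  // IR, shown only to illustrate the pattern matched below):
+  //
+  //   %bound     = load i32, i32* @N
+  //   %bound.ext = sext i32 %bound to i64
+  //   %iv.next   = add i64 %iv, 1
+  //   %cmp       = icmp slt i64 %iv.next, %bound.ext
+  //   br i1 %cmp, label %loop, label %exit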
+ // + // TODO: Allow more cases than just icmp(ivar,ext?(ld(global))) + auto *EBlock = L->getExitingBlock(); + auto *Term = EBlock->getTerminator(); + auto *ExitBlock = L->getExitBlock(); + ICmpInst::Predicate Pred; + BasicBlock *TBlock, *FBlock; + PHINode* IVar; + Value* Addr; + Value* Ld; + + if (match(Term, m_Br(m_ICmp(Pred, m_Add(m_PHI(IVar), m_One()), + m_AnyExtOrNone(m_Value(Ld))), + TBlock, FBlock)) && + match(Ld, m_Load(m_Value(Addr))) && + isa(Addr) && + (TBlock == ExitBlock || FBlock == ExitBlock) && + L->isLoopInvariant(Addr)) { + GlobalBoundPtr = Addr; + GlobalLoadInst = cast(Ld); + } else { + LLVM_DEBUG(dbgs() << " loop condition not a comparison with a global var\n"); + return false; + } + + if (!L->contains(GlobalLoadInst->getParent())) { + LLVM_DEBUG(dbgs() << " global load is not part of the loop\n"); + return false; + } + + // Some safety checks to make sure the load will always occur. The checks for + // no throwing instructions later is also part of this... + // Get the exit blocks for the current loop. + SmallVector ExitBlocks; + L->getExitBlocks(ExitBlocks); + + // Verify that the block dominates each of the exit blocks of the loop. + for (auto *EB : ExitBlocks) + if (!DT->dominates(GlobalLoadInst->getParent(), EB)) { + LLVM_DEBUG(dbgs() << " GlobalLoadInst doesn't dominate exit block\n"); + return false; + } + + // As a degenerate case, if the loop is statically infinite then we haven't + // proven anything since there are no exit blocks. + if (ExitBlocks.empty()) { + LLVM_DEBUG(dbgs() << " No exit blocks (infinite loop)\n"); + return false; + } + + // Passed all structural checks, so onto testing individual instructions + // and memory accesses... + return true; +} + +/// \brief Check memory accesses in loop and confirms it's good for +/// LoopSpeculativeBoundsCheck. +bool LoopSpeculativeBoundsCheck::checkMemoryAccesses(AliasSetTracker &CurAST) { + bool HasMayAlias = false; + bool HasMod = false; + // Memory check: + // Find the alias set for the load which the loop condition depends upon, + // then check the types are consistent, that we have potential writes in + // the set, and that they *may* alias -- if they *must*, we can't help. + // If there's no aliasing at all, LICM should have already dealt with this. + // If it hasn't, probably worth looking into why... + uint64_t Size = 0; + auto &DL = GlobalLoadInst->getModule()->getDataLayout(); + if (GlobalLoadInst->getType()->isSized()) + Size = DL.getTypeStoreSize(GlobalLoadInst->getType()); + else { + LLVM_DEBUG(dbgs() << " Unable to find size for GlobalLoadInst\n"); + return false; + } + + AAMDNodes AAInfo; + GlobalLoadInst->getAAMetadata(AAInfo); + + const AliasSet *AS = CurAST.getAliasSetForPointerIfExists(GlobalBoundPtr, + Size, AAInfo); + if (!AS) { + LLVM_DEBUG(dbgs() << " Unable to find AliasSet for Global Load\n"); + return false; + } + + // With MustAlias its not worth adding runtime bound check. + // If we only alias with ourselves (single pointer in the set), then + // it's safe to proceed. + if (AS->isMustAlias() && (AS->getRefCount() > 1)) { + LLVM_DEBUG(dbgs() << " GlobalLoadInst in a MustAlias set\n"); + return false; + } + + Value *SomePtr = AS->begin()->getValue(); + bool TypeCheck = true; + // Check for Mod & MayAlias + HasMayAlias |= AS->isMayAlias(); + HasMod |= AS->isMod(); + for (const auto &A : *AS) { + Value *Ptr = A.getValue(); + // Alias tracker should have pointers of same data type. 
+ TypeCheck = (TypeCheck && (SomePtr->getType() == Ptr->getType())); + } + + // Ensure types should be of same type. + if (!TypeCheck) { + LLVM_DEBUG(dbgs() << " Alias tracker type safety failed!\n"); + return false; + } + // Ensure loop body shouldn't be read only. + if (!HasMod) { + LLVM_DEBUG(dbgs() << " No memory modified in loop body\n"); + return false; + } + // Make sure alias set has may alias case. + // If there no alias memory ambiguity, return false. + if (!HasMayAlias) { + LLVM_DEBUG(dbgs() << " No ambiguity in memory access.\n"); + return false; + } + return true; +} + +/// \brief Check loop instructions safe for Loop versioning. +/// It returns true if it's safe else returns false. +/// Consider following: +/// 1) Check all load store in loop body are non atomic & non volatile. +/// 2) Check function call safety, by ensuring its not accessing memory. +/// 3) Loop body shouldn't have any may throw instruction. +bool LoopSpeculativeBoundsCheck::instructionSafeForVersioning(Loop *L, + Instruction *I) { + assert(I != nullptr && "Null instruction found!"); + // Check function call safety + if (isa(I) && !AA->doesNotAccessMemory(CallSite(I))) { + LLVM_DEBUG(dbgs() << " Unsafe call site found.\n"); + return false; + } + // Avoid loops with possiblity of throw + if (I->mayThrow()) { + LLVM_DEBUG(dbgs() << " May throw instruction found in loop body\n"); + return false; + } + // If current instruction is load instructions + // make sure it's a simple load (non atomic & non volatile) + if (I->mayReadFromMemory()) { + LoadInst *Ld = dyn_cast(I); + if (!Ld || !Ld->isSimple()) { + LLVM_DEBUG(dbgs() << " Found a non-simple load.\n"); + return false; + } + } + // If current instruction is store instruction + // make sure it's a simple store (non atomic & non volatile) + else if (I->mayWriteToMemory()) { + StoreInst *St = dyn_cast(I); + if (!St || !St->isSimple()) { + LLVM_DEBUG(dbgs() << " Found a non-simple store.\n"); + return false; + } + IsReadOnlyLoop = false; + } + return true; +} + +/// \brief Check loop instructions and confirms it's good for +/// LoopSpeculativeBoundsCheck. +bool LoopSpeculativeBoundsCheck::legalLoopInstructions(Loop *L) { + // Resetting counters. + IsReadOnlyLoop = true; + // Iterate over loop blocks and instructions of each block and check + // instruction safety. + for (auto *Block : L->getBlocks()) + for (auto &Inst : *Block) { + // If instruction is unsafe just return false. + if (!instructionSafeForVersioning(L, &Inst)) + return false; + + // If there's an outside use of this instruction, record that so + // we can build a phi later. + for (auto *U : Inst.users()) { + Instruction *Use = cast(U); + if (!L->contains(Use->getParent())) { + DefsUsedOutside.push_back(&Inst); + break; + } + } + } + + // Read only loop should have already been handled, but we know there's + // not a possible alias between the loop condition boundary and any + // write in the loop; moving the load won't help SE check the loop + // tripcount for further optimizations. + if (IsReadOnlyLoop) { + LLVM_DEBUG(dbgs() << " Found a read-only loop!\n"); + return false; + } + + return true; +} + +/// \brief Checks legality for LoopSpeculativeBoundsCheck by considering: +/// a) loop structure legality b) loop instruction legality +/// c) loop memory access legality. +/// Return true if legal else returns false. +bool LoopSpeculativeBoundsCheck::isSuitableLoop(Loop *L, + AliasSetTracker &CurAST) { + LLVM_DEBUG(dbgs() << "Loop: " << *L); + // Check loop structure legality. 
+ if (!legalLoopStructure(L)) { + LLVM_DEBUG(dbgs() + << " Loop structure not suitable for " + << "LoopSpeculativeBoundsCheck\n\n"); + return false; + } + // Check loop instruction legality. + if (!legalLoopInstructions(L)) { + LLVM_DEBUG(dbgs() + << " Loop instructions not suitable for " + << "LoopSpeculativeBoundsCheck\n\n"); + return false; + } + // Check loop memory access legality. + if (!checkMemoryAccesses(CurAST)) { + LLVM_DEBUG(dbgs() + << " Loop memory access not suitable for " + << "LoopSpeculativeBoundsCheck\n\n"); + return false; + } + // Loop versioning is feasible, return true. + LLVM_DEBUG(dbgs() << " Loop Speculative Bounds Check possible\n\n"); + return true; +} + +void LoopSpeculativeBoundsCheck::addPHINodes(Loop *SpeculativeLoop, + Loop *NonSpeculativeLoop) { + BasicBlock *PHIBlock = SpeculativeLoop->getExitBlock(); + assert(PHIBlock && "No single successor to loop exit block"); + + for (auto *Inst : DefsUsedOutside) { + auto *NonSpeculativeLoopInst = cast(VMap[Inst]); + PHINode *PN = nullptr; + bool Found = false; + + // First see if we have a single-operand PHI with the value defined by the + // original loop. + for (auto I = PHIBlock->begin(); (PN = dyn_cast(I)); ++I) { + if (PN->getIncomingValue(0) == Inst) { + Found = true; + + // Add the new incoming value from the non-versioned loop. + PN->addIncoming(NonSpeculativeLoopInst, + NonSpeculativeLoop->getExitingBlock()); + } + } + // If not create it. + if (!Found) { + PN = PHINode::Create(Inst->getType(), 2, Inst->getName() + ".lver", + &PHIBlock->front()); + for (auto *User : Inst->users()) + if (!SpeculativeLoop->contains(cast(User)->getParent())) + User->replaceUsesOfWith(Inst, PN); + PN->addIncoming(Inst, SpeculativeLoop->getExitingBlock()); + + // Add the new incoming value from the non-versioned loop. + PN->addIncoming(NonSpeculativeLoopInst, + NonSpeculativeLoop->getExitingBlock()); + } + } +} + +bool LoopSpeculativeBoundsCheck::processLoop(Loop *L) { + AliasSetTracker AST(*AA); + bool Changed = false; + + for (auto *BB : L->getBlocks()) + AST.add(*BB); + + if (isSuitableLoop(L, AST)) { + Loop *SpeculativeLoop; // Loop with global load removed + Loop *NonSpeculativeLoop; // Original loop code + Value *RuntimeCheck = nullptr; // Runtime check to determine whether + // to enter speculated loop + Value *SCEVRuntimeCheck; + + BasicBlock *PH = L->getLoopPreheader(); + BasicBlock *CheckBB = PH; + + CheckBB->setName(L->getHeader()->getName() + ".speculative.bounds.check"); + + // Create empty preheader for the loop (and after cloning for the + // non-versioned loop). + BasicBlock *NewPH = + SplitBlock(CheckBB, CheckBB->getTerminator(), DT, LI); + NewPH->setName(L->getHeader()->getName() + ".ph"); + + SpeculativeLoop = L; + SmallVector NonSpeculativeLoopBlocks; + NonSpeculativeLoop = + cloneLoopWithPreheader(NewPH, CheckBB, SpeculativeLoop, VMap, + ".specbounds.orig", LI, DT, + NonSpeculativeLoopBlocks); + remapInstructionsInBlocks(NonSpeculativeLoopBlocks, VMap); + + addPHINodes(SpeculativeLoop, NonSpeculativeLoop); + + // Clone the load instruction into the new checking block + IRBuilder<> Builder(CheckBB->getTerminator()); + Builder.SetCurrentDebugLocation(GlobalLoadInst->getDebugLoc()); + + Instruction *ClonedLoad = GlobalLoadInst->clone(); + ClonedLoad->setName(GlobalLoadInst->getName() + ".speculative"); + Builder.Insert(ClonedLoad); + + // Copy across metadata. + // TODO: Should we restrict some kinds of metadata? 
+ SmallVector, 4> Metadata; + GlobalLoadInst->getAllMetadataOtherThanDebugLoc(Metadata); + + for (auto M : Metadata) + ClonedLoad->setMetadata(M.first, M.second); + + // Drop old SE info, recalculate trip count based on the load in CheckBB. + GlobalLoadInst->replaceAllUsesWith(ClonedLoad); + SE->forgetLoop(SpeculativeLoop); + + // Figure out which checks are needed based on the speculated trip count. + LAI = &LAA->getInfo(SpeculativeLoop); + SmallVector AliasChecks = + LAI->getRuntimePointerChecking()->getChecks(); + + Instruction *FirstCheckInst; + Instruction *MemRuntimeCheck; + std::tie(FirstCheckInst, MemRuntimeCheck) = + LAI->addRuntimeChecks(CheckBB->getTerminator(), AliasChecks); + + // TODO: Using SCEV checks similar to LoopVersioning -- does this make + // sense for the bounds checking? We're just concerned with a load + // from a global used as a loop termination condition possibly aliasing + // with writes inside the loop. + const SCEVUnionPredicate &Pred = LAI->getPSE().getUnionPredicate(); + SCEVExpander Exp(*SE, CheckBB->getModule()->getDataLayout(), + "scev.check"); + SCEVRuntimeCheck = + Exp.expandCodeForPredicate(&Pred, CheckBB->getTerminator()); + auto *CI = dyn_cast(SCEVRuntimeCheck); + + // Discard the SCEV runtime check if it is always true. + if (CI && CI->isZero()) + SCEVRuntimeCheck = nullptr; + + // Figure out which checks which should plant + if (MemRuntimeCheck && SCEVRuntimeCheck) { + RuntimeCheck = BinaryOperator::Create(Instruction::Or, MemRuntimeCheck, + SCEVRuntimeCheck, "ldist.safe"); + if (auto *I = dyn_cast(RuntimeCheck)) + I->insertBefore(CheckBB->getTerminator()); + } else + RuntimeCheck = MemRuntimeCheck ? MemRuntimeCheck : SCEVRuntimeCheck; + + // We should get an alias check if the load was causing problems, so if + // we don't get one something has gone wrong somewhere. For now default + // to safe behaviour of always jumping to the original loop and let a + // later pass remove the unused blocks... + LLVMContext &Context = PH->getContext(); + Value *BrVal = RuntimeCheck ? RuntimeCheck : ConstantInt::getTrue(Context); + + Instruction *OrigTerm = CheckBB->getTerminator(); + BranchInst::Create(NonSpeculativeLoop->getLoopPreheader(), + SpeculativeLoop->getLoopPreheader(), + BrVal, OrigTerm); + OrigTerm->eraseFromParent(); + + // The loops merge in the original exit block. This is now dominated by the + // memchecking block. + DT->changeImmediateDominator(SpeculativeLoop->getExitBlock(), CheckBB); + + Changed = true; + } + + // Clear out data structures to avoid leaking memory... + // TODO: Move to a per-loop structure as a local var? + VMap.clear(); + DefsUsedOutside.clear(); + + return Changed; +} + +bool LoopSpeculativeBoundsCheck::runOnLoop(Loop *L, LPPassManager &LPM) { + if (skipLoop(L)) + return false; + + // Only run on inner loops. 
+ if (!L->empty()) + return false; + + LI = &getAnalysis().getLoopInfo(); + AA = &getAnalysis().getAAResults(); + SE = &getAnalysis().getSE(); + DT = &getAnalysis().getDomTree(); + TLI = &getAnalysis().getTLI(); + LAA = &getAnalysis(); + + return processLoop(L); +} + + +char LoopSpeculativeBoundsCheck::ID = 0; +INITIALIZE_PASS_BEGIN(LoopSpeculativeBoundsCheck, + "loop-speculative-bounds-check", + "Loop Speculative Bounds Checking", false, false) +INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) +INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_DEPENDENCY(GlobalsAAWrapperPass) +INITIALIZE_PASS_DEPENDENCY(LCSSAWrapperPass) +INITIALIZE_PASS_DEPENDENCY(LoopAccessLegacyAnalysis) +INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(LoopSimplify) +INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass) +INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) +INITIALIZE_PASS_END(LoopSpeculativeBoundsCheck, + "loop-speculative-bounds-check", + "Loop Speculative Bounds Checking", false, false) + +Pass *llvm::createLoopSpeculativeBoundsCheckPass() { + return new LoopSpeculativeBoundsCheck(); +} Index: lib/Transforms/Scalar/LoopVersioningLICM.cpp =================================================================== --- lib/Transforms/Scalar/LoopVersioningLICM.cpp +++ lib/Transforms/Scalar/LoopVersioningLICM.cpp @@ -317,8 +317,12 @@ if (AS.isForwardingAliasSet()) continue; // With MustAlias its not worth adding runtime bound check. - if (AS.isMustAlias()) + // If we only alias with ourselves (single pointer in the set), then + // it's safe to proceed. + if (AS.isMustAlias() && (AS.getRefCount() > 1)) { + LLVM_DEBUG(dbgs() << " Alias set with MustAlias attribute\n"); return false; + } Value *SomePtr = AS.begin()->getValue(); bool TypeCheck = true; // Check for Mod & MayAlias Index: lib/Transforms/Scalar/Reassociate.cpp =================================================================== --- lib/Transforms/Scalar/Reassociate.cpp +++ lib/Transforms/Scalar/Reassociate.cpp @@ -2007,6 +2007,59 @@ return NI; } +/// Returns whether the operator BO is an operand to a GEP instruction and +/// is foldable as a SCEV. Reassociation may lose the (non-)wrapping +/// flags causing the SCEV not to fold. During vectorisation the SCEV +/// representation is more important. Instcombine also does reassociation +/// and will likely pick up any remaining cases. +static bool isSCEVableGEPOperand(const BinaryOperator *BO) { + unsigned Opcode = BO->getOpcode(); + if (!(Opcode == Instruction::Add || Opcode == Instruction::Mul)) + return false; + + if (!BO->hasNoSignedWrap()) + return false; + + // Create two buffers and two pointers to these buffers. + // We'll take a double-buffering approach, since we cannot + // change the array while iterating over it. + SmallVector Users, Buffer; + SmallVectorImpl *UsersPtr = &Users; + SmallVectorImpl *BufferPtr = &Buffer; + + // Fill Users buffer with users of the operation + Users.append(BO->user_begin(), BO->user_end()); + + // Add cast operations to 'Buffer' for futher processing + while (!UsersPtr->empty()) { + // Process all users in buffer + for (const Value *U : (*UsersPtr)) { + if (auto *C = dyn_cast(U)) { + BufferPtr->append(C->user_begin(), C->user_end()); + continue; + } + + // Preserving NSW flags is more important than reassociation. 
+ if (auto *BO = dyn_cast(U)) { + unsigned Opc = BO->getOpcode(); + if ((Opc == Instruction::Add) || (Opc == Instruction::Sub)) + if (BO->hasNoSignedWrap()) { + BufferPtr->append(U->user_begin(), U->user_end()); + continue; + } + } + + if (!isa(U)) + return false; + } + // Swap buffers + std::swap(UsersPtr, BufferPtr); + BufferPtr->clear(); + } + + return true; +} + /// Inspect and optimize the given instruction. Note that erasing /// instructions is not allowed. void ReassociatePass::OptimizeInst(Instruction *I) { @@ -2119,6 +2172,11 @@ return; } + // Do not reassociate if this expression will lose nsw flags when + // reassocating, destroying SCEV folding opportunities. + if (isSCEVableGEPOperand(BO)) + return; + // If this is an add tree that is used by a sub instruction, ignore it // until we process the subtract. if (BO->hasOneUse() && BO->getOpcode() == Instruction::Add && Index: lib/Transforms/Scalar/SROA.cpp =================================================================== --- lib/Transforms/Scalar/SROA.cpp +++ lib/Transforms/Scalar/SROA.cpp @@ -2702,8 +2702,8 @@ } /// Compute a vector splat for a given element value. - Value *getVectorSplat(Value *V, unsigned NumElements) { - V = IRB.CreateVectorSplat(NumElements, V, "vsplat"); + Value *getVectorSplat(VectorType *VT, Value *V) { + V = IRB.CreateVectorSplat(VT->getElementCount(), V, "vsplat"); LLVM_DEBUG(dbgs() << " splat: " << *V << "\n"); return V; } @@ -2772,8 +2772,10 @@ Value *Splat = getIntegerSplat(II.getValue(), DL.getTypeSizeInBits(ElementTy) / 8); Splat = convertValue(DL, IRB, Splat, ElementTy); - if (NumElements > 1) - Splat = getVectorSplat(Splat, NumElements); + if (NumElements > 1) { + VectorType *VT = VectorType::get(Splat->getType(), NumElements); + Splat = getVectorSplat(VT, Splat); + } Value *Old = IRB.CreateAlignedLoad(&NewAI, NewAI.getAlignment(), "oldload"); @@ -2805,7 +2807,7 @@ V = getIntegerSplat(II.getValue(), DL.getTypeSizeInBits(ScalarTy) / 8); if (VectorType *AllocaVecTy = dyn_cast(AllocaTy)) - V = getVectorSplat(V, AllocaVecTy->getNumElements()); + V = getVectorSplat(AllocaVecTy, V); V = convertValue(DL, IRB, V, AllocaTy); } @@ -2986,6 +2988,7 @@ "copyload"); if (AATags) Load->setAAMetadata(AATags); + Load->copyMetadata(II, LLVMContext::MD_mem_parallel_loop_access); Src = Load; } @@ -3006,6 +3009,9 @@ IRB.CreateAlignedStore(Src, DstPtr, DstAlign, II.isVolatile())); if (AATags) Store->setAAMetadata(AATags); + if (IsWholeAlloca || !IsDest) + Store->copyMetadata(II, LLVMContext::MD_mem_parallel_loop_access); + LLVM_DEBUG(dbgs() << " to: " << *Store << "\n"); return !II.isVolatile(); } Index: lib/Transforms/Scalar/Scalar.cpp =================================================================== --- lib/Transforms/Scalar/Scalar.cpp +++ lib/Transforms/Scalar/Scalar.cpp @@ -67,6 +67,8 @@ initializeLoopInterchangePass(Registry); initializeLoopPredicationLegacyPassPass(Registry); initializeLoopRotateLegacyPassPass(Registry); + initializeLoopExprTreeFactoringPassPass(Registry); + initializeLoopSpeculativeBoundsCheckPass(Registry); initializeLoopStrengthReducePass(Registry); initializeLoopRerollPass(Registry); initializeLoopUnrollPass(Registry); @@ -99,9 +101,11 @@ initializePlaceSafepointsPass(Registry); initializeFloat2IntLegacyPassPass(Registry); initializeLoopDistributeLegacyPass(Registry); + initializeSeparateInvariantsFromGepOffsetPass(Registry); initializeLoopLoadEliminationPass(Registry); initializeLoopSimplifyCFGLegacyPassPass(Registry); initializeLoopVersioningPassPass(Registry); + 
initializeLoopRewriteGEPsPassPass(Registry); initializeEntryExitInstrumenterPass(Registry); initializePostInlineEntryExitInstrumenterPass(Registry); } @@ -174,6 +178,10 @@ unwrap(PM)->add(createLoopIdiomPass()); } +void LLVMAddLoopExprTreeFactoringPass(LLVMPassManagerRef PM) { + unwrap(PM)->add(createLoopExprTreeFactoringPass()); +} + void LLVMAddLoopRotatePass(LLVMPassManagerRef PM) { unwrap(PM)->add(createLoopRotatePass()); } @@ -274,3 +282,7 @@ void LLVMAddLowerExpectIntrinsicPass(LLVMPassManagerRef PM) { unwrap(PM)->add(createLowerExpectIntrinsicPass()); } + +void LLVMAddLoopRewriteGEPsPass(LLVMPassManagerRef PM) { + unwrap(PM)->add(createLoopRewriteGEPsPass()); +} Index: lib/Transforms/Scalar/Scalarizer.cpp =================================================================== --- lib/Transforms/Scalar/Scalarizer.cpp +++ lib/Transforms/Scalar/Scalarizer.cpp @@ -543,7 +543,7 @@ bool Scalarizer::visitGetElementPtrInst(GetElementPtrInst &GEPI) { VectorType *VT = dyn_cast(GEPI.getType()); - if (!VT) + if (!VT || VT->getVectorIsScalable()) return false; IRBuilder<> Builder(&GEPI); @@ -554,7 +554,7 @@ // splat the pointer into a vector value, and scatter that vector. Value *Op0 = GEPI.getOperand(0); if (!Op0->getType()->isVectorTy()) - Op0 = Builder.CreateVectorSplat(NumElems, Op0); + Op0 = Builder.CreateVectorSplat({NumElems, false}, Op0); Scatterer Base = scatter(&GEPI, Op0); SmallVector Ops; @@ -565,7 +565,7 @@ // The indices might be scalars even if it's a vector GEP. In those cases, // splat the scalar into a vector value, and scatter that vector. if (!Op->getType()->isVectorTy()) - Op = Builder.CreateVectorSplat(NumElems, Op); + Op = Builder.CreateVectorSplat({NumElems, false}, Op); Ops[I] = scatter(&GEPI, Op); } @@ -665,6 +665,10 @@ if (!VT) return false; + SmallVector Mask; + if (!SVI.getShuffleMask(Mask)) + return false; + unsigned NumElems = VT->getNumElements(); Scatterer Op0 = scatter(&SVI, SVI.getOperand(0)); Scatterer Op1 = scatter(&SVI, SVI.getOperand(1)); @@ -672,7 +676,7 @@ Res.resize(NumElems); for (unsigned I = 0; I < NumElems; ++I) { - int Selector = SVI.getMaskValue(I); + int Selector = Mask[I]; if (Selector < 0) Res[I] = UndefValue::get(VT->getElementType()); else if (unsigned(Selector) < Op0.size()) Index: lib/Transforms/Scalar/SeparateInvariantsFromGepOffset.cpp =================================================================== --- /dev/null +++ lib/Transforms/Scalar/SeparateInvariantsFromGepOffset.cpp @@ -0,0 +1,279 @@ +//===- SeparateInvariantsFromGepOffset - Improve gep offset expressions----===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. 
+// +//===----------------------------------------------------------------------===// +// +// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Scalar.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/LoopPass.h" +#include "llvm/Analysis/ScalarEvolution.h" +#include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/IR/CFG.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Metadata.h" +#include "llvm/IR/PatternMatch.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Utils.h" +#include "llvm/Transforms/Utils/Local.h" +#include "llvm/Transforms/Utils/LoopUtils.h" +#include +using namespace llvm; +using namespace llvm::PatternMatch; + +#define DEBUG_TYPE "separate-gepinv" + +STATISTIC(NumVisited, "Number of loops visited."); +STATISTIC(NumOptimized, "Number of loops optimized."); + +namespace llvm { + void initializeSeparateInvariantsFromGepOffsetPass(PassRegistry &); +} + +namespace { +struct SeparateInvariantsFromGepOffset : public FunctionPass { + static char ID; // Pass identification, replacement for typeid + SeparateInvariantsFromGepOffset() : FunctionPass(ID) { + initializeSeparateInvariantsFromGepOffsetPass( + *PassRegistry::getPassRegistry()); + } + + bool runOnFunction(Function &F) override { + this->F = &F; + LI = &getAnalysis().getLoopInfo(); + TTI = &getAnalysis().getTTI(F); + SE = &getAnalysis().getSE(); + + bool Changed = false; + for (auto I = LI->begin(), IE = LI->end(); I != IE; ++I) + for (auto L = df_begin(*I), LE = df_end(*I); L != LE; ++L) + Changed |= runOnLoop(*L); + + return Changed; + } + + bool runOnLoop(Loop *L); + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequiredID(LoopSimplifyID); + AU.addRequired(); + AU.addRequired(); + AU.addRequired(); + } + +private: + Function *F; + LoopInfo *LI; + ScalarEvolution *SE; + TargetTransformInfo *TTI; + + bool Optimize_Addressing(BasicBlock *); +}; +} + +char SeparateInvariantsFromGepOffset::ID = 0; +static const char *name = "Separate Invariants From GEP Offset"; +INITIALIZE_PASS_BEGIN(SeparateInvariantsFromGepOffset, DEBUG_TYPE, name, + false, false) +INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(LoopSimplify) +INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass) +INITIALIZE_PASS_END(SeparateInvariantsFromGepOffset, DEBUG_TYPE, name, + false, false) + +namespace llvm { +FunctionPass *createSeparateInvariantsFromGepOffsetPass() { + return new SeparateInvariantsFromGepOffset(); +} +} + +/// Hoist out loop invariants from GEP offset by adding them to Res, +/// and returning an offset expression that does not include the items +/// in Res. +static Value *OptimizeGEPExpr(Value *V, IRBuilder<> *Builder, + ScalarEvolution *SE, Loop *L, + SmallVectorImpl &Res, bool DoIt) { + // Values used in matching + Value *Op, *LHS, *RHS; + + // Constants and loop invariants are leafs. 
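+  //
+  // Illustrative example (assumed shapes, not exhaustive): for an offset
+  //   %off = sext(add nsw (%inv, %i))
+  // with %inv loop invariant, %inv is collected into Res and a rebuilt
+  // 'sext %i' is returned as the remaining, loop-varying offset.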
+ if (L->isLoopInvariant(V)) { + Res.push_back(V); + return nullptr; + } + + // Match: (sext|zext)(add (nuw|nsw) %lhs, %rhs) + // And look for loop invariants in %lhs and %rhs + if (match(V, m_CombineOr(m_SExt(m_Value(Op)), m_ZExt(m_Value(Op))))) { + // Match the extend with 'add' and no wrapping flags + if (match(V, m_SExt(m_NSWAdd(m_Value(), m_Value()))) || + match(V, m_ZExt(m_NUWAdd(m_Value(), m_Value())))) { + // Push cast through add + Op = OptimizeGEPExpr(Op, Builder, SE, L, Res, DoIt); + + // Result is loop invariant, this node is also loop invariant + if (!Op) { + Res.pop_back(); + Res.push_back(V); + return nullptr; + } + + // Do not make any IR changes in analysis mode + if (!DoIt) + return V; + + // Recreate cast + Instruction::CastOps Opcode = cast(V)->getOpcode(); + Op = Builder->CreateCast(Opcode, Op, V->getType()); + + // Return updated offset + return Op; + } + } + + // Match: add %lhs, %rhs + // And return %lhs or %rhs if the other operand is loop invariant. + if (match(V, m_Add(m_Value(LHS), m_Value(RHS)))) { + LHS = OptimizeGEPExpr(LHS, Builder, SE, L, Res, DoIt); + RHS = OptimizeGEPExpr(RHS, Builder, SE, L, Res, DoIt); + + // We lose nsw flags when rewriting the expression + if (LHS && RHS && DoIt) + return Builder->CreateAdd(LHS, RHS); + + // Result is loop invariant, this node is also loop invariant + if (!LHS && !RHS) { + Res.pop_back(); + Res.pop_back(); + Res.push_back(V); + return nullptr; + } + + // One of the operands is not loop invariant + return LHS ? LHS : RHS; + } + + // Non-loopinvariant leaf node + return V; +} + +// Compare by complexity, Constants first +static struct ComplexityCompare { + bool operator() (Value *A, Value *B) { return isa(A); }; +} ComplexityCompareObj; + +static Instruction *OptimizeGEP(GetElementPtrInst *Gep, IRBuilder<> *Builder, + ScalarEvolution *SE, Loop *L) { + // See whether transform is beneficial + SmallVector Factors; + Value *NewOffset = OptimizeGEPExpr(Gep->getOperand(1), Builder, SE, L, + Factors, false); + + // If not, skip + if (Factors.size() == 0 || NewOffset == nullptr) + return nullptr; + + // Do the transform! + Factors.clear(); + NewOffset = OptimizeGEPExpr(Gep->getOperand(1), Builder, SE, L, Factors, + true); + + // Sort by complexity + std::stable_sort(Factors.begin(), Factors.end(), ComplexityCompareObj); + + // Add loop invariants to pointer + // Note: we'll lose the inbounds + Value *Base = Gep->getPointerOperand(); + for (Value *V : Factors) + Base = Builder->CreateGEP(Base, V); + + // Add the updated offset (without invariants) + if (NewOffset) + Base = Builder->CreateGEP(Base, NewOffset); + + return cast(Base); +} + +// Transform: +// %mul = mul i64 %0, %1 +// %off0 = add i64 %mul, %index +// %off1 = add i64 %off0, 1 +// %base = getelementptr inbounds float, float* %ptr, i64 %off1 +// Into: +// %mul = mul i64 %0, %1 +// %base0 = getelementptr inbounds %ptr, %mul +// %off0 = add i64 %index, 1 +// %base1 = getelementptr inbounds %base0, %off0 +// TODO: Perhaps only do this for types/targets (e.g. SVE) that need this? +bool SeparateInvariantsFromGepOffset::Optimize_Addressing(BasicBlock *BB) { + IRBuilder<> Builder(BB); + + // Look for 'getelementptr %liv, %offset' + // where %liv is loop invariant, %offset is only used in + // this gep, we optimize the expression. 
+ bool Changed = false; + for (auto I = BB->begin(), E = BB->end(); I != E; ++I) { + auto *Gep = dyn_cast(I); + if (!Gep || Gep->use_empty()) + continue; + + Loop *L = LI->getLoopFor(BB); + if (!L) + continue; + + if (Gep->getNumIndices() != 1) + continue; + + if (!L->isLoopInvariant(Gep->getPointerOperand())) + continue; + + if (L->isLoopInvariant(Gep->getOperand(1))) + continue; + + if (Gep->getOperand(1)->getType()->isVectorTy()) + continue; + + if (!Gep->getOperand(1)->hasOneUse()) + continue; + + Builder.SetInsertPoint(BB, I); + Instruction *Rtrn = OptimizeGEP(Gep, &Builder, SE, L); + if (!Rtrn) + continue; + + LLVM_DEBUG(dbgs() << "Replacing " << Gep << "\nWith: " << Rtrn << "\n"); + I->replaceAllUsesWith(Rtrn); + Changed = true; + } + + return Changed; +} + +bool SeparateInvariantsFromGepOffset::runOnLoop(Loop *L) { + bool LoopOptimized = false; + + if (auto PreHeader = L->getLoopPreheader()) + LoopOptimized |= Optimize_Addressing(PreHeader); + + for (auto BB = L->block_begin(), BE = L->block_end(); BB != BE; ++BB) + LoopOptimized |= Optimize_Addressing(*BB); + + ++NumVisited; + if (LoopOptimized) + ++NumOptimized; + + return LoopOptimized; +} Index: lib/Transforms/Scalar/SimplifyCFGPass.cpp =================================================================== --- lib/Transforms/Scalar/SimplifyCFGPass.cpp +++ lib/Transforms/Scalar/SimplifyCFGPass.cpp @@ -49,6 +49,11 @@ "bonus-inst-threshold", cl::Hidden, cl::init(1), cl::desc("Control the number of bonus instructions (default = 1)")); +static cl::opt UserSwitchRemovalThreshold( + "switch-removal-threshold", cl::Hidden, cl::init(0), + cl::desc("Set the threshold for the number of switch cases where we" + "convert switch blocks to branches and compares")); + static cl::opt UserKeepLoops( "keep-loops", cl::Hidden, cl::init(true), cl::desc("Preserve canonical loop structure (default = true)")); @@ -235,7 +240,8 @@ CFGSimplifyPass(unsigned Threshold = 1, bool ForwardSwitchCond = false, bool ConvertSwitch = false, bool KeepLoops = true, bool SinkCommon = false, - std::function Ftor = nullptr) + std::function Ftor = nullptr, + unsigned SwitchRemovalThreshold = 0) : FunctionPass(ID), PredicateFtor(std::move(Ftor)) { initializeCFGSimplifyPassPass(*PassRegistry::getPassRegistry()); @@ -259,6 +265,11 @@ Options.SinkCommonInsts = UserSinkCommonInsts.getNumOccurrences() ? UserSinkCommonInsts : SinkCommon; + + Options.SwitchRemovalThreshold = + UserSwitchRemovalThreshold.getNumOccurrences() + ? UserSwitchRemovalThreshold + : SwitchRemovalThreshold; } bool runOnFunction(Function &F) override { @@ -290,7 +301,9 @@ llvm::createCFGSimplificationPass(unsigned Threshold, bool ForwardSwitchCond, bool ConvertSwitch, bool KeepLoops, bool SinkCommon, - std::function Ftor) { + std::function Ftor, + unsigned SwitchRemovalThreshold) { return new CFGSimplifyPass(Threshold, ForwardSwitchCond, ConvertSwitch, - KeepLoops, SinkCommon, std::move(Ftor)); + KeepLoops, SinkCommon, std::move(Ftor), + SwitchRemovalThreshold); } Index: lib/Transforms/Utils/CodeExtractor.cpp =================================================================== --- lib/Transforms/Utils/CodeExtractor.cpp +++ lib/Transforms/Utils/CodeExtractor.cpp @@ -925,8 +925,16 @@ auto *OutI = dyn_cast(outputs[i]); if (!OutI) continue; + // Find proper insertion point. - Instruction *InsertPt = OutI->getNextNode(); + Instruction *InsertPt; + // In case OutI is an invoke, we insert the store at the beginning in the + // 'normal destination' BB. Otherwise we insert the store right after OutI. 
+ if (auto *InvokeI = dyn_cast(OutI)) + InsertPt = InvokeI->getNormalDest()->getFirstNonPHI(); + else + InsertPt = OutI->getNextNode(); + // Let's assume that there is no other guy interleave non-PHI in PHIs. if (isa(InsertPt)) InsertPt = InsertPt->getParent()->getFirstNonPHI(); Index: lib/Transforms/Utils/LoopRotationUtils.cpp =================================================================== --- lib/Transforms/Utils/LoopRotationUtils.cpp +++ lib/Transforms/Utils/LoopRotationUtils.cpp @@ -331,6 +331,19 @@ continue; } + // Duplicating lifetimes will create invalid lifetimes. Instead we can + // extend the lifetime of the memory location by hoisting the start to + // the loop entry and sinking ends to the loop exit. + if (auto *IntrInst = dyn_cast(Inst)) { + if (IntrInst->getIntrinsicID() == Intrinsic::lifetime_start) { + Inst->moveBefore(LoopEntryBranch); + continue; + } else if (IntrInst->getIntrinsicID() == Intrinsic::lifetime_end) { + Inst->moveBefore(Exit->getFirstNonPHI()); + continue; + } + } + // Otherwise, create a duplicate of the instruction. Instruction *C = Inst->clone(); Index: lib/Transforms/Utils/LoopUtils.cpp =================================================================== --- lib/Transforms/Utils/LoopUtils.cpp +++ lib/Transforms/Utils/LoopUtils.cpp @@ -41,6 +41,21 @@ #define DEBUG_TYPE "loop-utils" +static const char* RecurrenceNames[] = { + "RK_NoReduction", + "RK_IntegerAdd", + "RK_IntegerMult", + "RK_IntegerOr", + "RK_IntegerAnd", + "RK_IntegerXor", + "RK_IntegerMinMax", + "RK_FloatAdd", + "RK_FloatMult", + "RK_FloatMinMax", + "RK_ConstSelectICmp", + "RK_ConstSelectFCmp" +}; + bool RecurrenceDescriptor::areAllUsesIn(Instruction *I, SmallPtrSetImpl &Set) { for (User::op_iterator Use = I->op_begin(), E = I->op_end(); Use != E; ++Use) @@ -59,6 +74,8 @@ case RK_IntegerAnd: case RK_IntegerXor: case RK_IntegerMinMax: + case RK_ConstSelectICmp: + case RK_ConstSelectFCmp: return true; } return false; @@ -110,87 +127,145 @@ /// Compute the minimal bit width needed to represent a reduction whose exit /// instruction is given by Exit. -static std::pair computeRecurrenceType(Instruction *Exit, +static std::pair computeRecurrenceType( RecurrenceDescriptor::ExitInstrList &Exits, DemandedBits *DB, AssumptionCache *AC, DominatorTree *DT) { bool IsSigned = false; - const DataLayout &DL = Exit->getModule()->getDataLayout(); - uint64_t MaxBitWidth = DL.getTypeSizeInBits(Exit->getType()); + uint64_t MaxBW = 0; + for (Instruction *Exit : Exits) { + const DataLayout &DL = Exit->getModule()->getDataLayout(); + uint64_t MaxBitWidth = DL.getTypeSizeInBits(Exit->getType()); - if (DB) { - // Use the demanded bits analysis to determine the bits that are live out - // of the exit instruction, rounding up to the nearest power of two. If the - // use of demanded bits results in a smaller bit width, we know the value - // must be positive (i.e., IsSigned = false), because if this were not the - // case, the sign bit would have been demanded. - auto Mask = DB->getDemandedBits(Exit); - MaxBitWidth = Mask.getBitWidth() - Mask.countLeadingZeros(); - } - - if (MaxBitWidth == DL.getTypeSizeInBits(Exit->getType()) && AC && DT) { - // If demanded bits wasn't able to limit the bit width, we can try to use - // value tracking instead. This can be the case, for example, if the value - // may be negative. 
- auto NumSignBits = ComputeNumSignBits(Exit, DL, 0, AC, nullptr, DT); - auto NumTypeBits = DL.getTypeSizeInBits(Exit->getType()); - MaxBitWidth = NumTypeBits - NumSignBits; - KnownBits Bits = computeKnownBits(Exit, DL); - if (!Bits.isNonNegative()) { - // If the value is not known to be non-negative, we set IsSigned to true, - // meaning that we will use sext instructions instead of zext - // instructions to restore the original type. - IsSigned = true; - if (!Bits.isNegative()) - // If the value is not known to be negative, we don't known what the - // upper bit is, and therefore, we don't know what kind of extend we - // will need. In this case, just increase the bit width by one bit and - // use sext. - ++MaxBitWidth; + if (DB) { + // Use the demanded bits analysis to determine the bits that are live out + // of the exit instruction, rounding up to the nearest power of two. If the + // use of demanded bits results in a smaller bit width, we know the value + // must be positive (i.e., IsSigned = false), because if this were not the + // case, the sign bit would have been demanded. + auto Mask = DB->getDemandedBits(Exit); + MaxBitWidth = Mask.getBitWidth() - Mask.countLeadingZeros(); } - } - if (!isPowerOf2_64(MaxBitWidth)) - MaxBitWidth = NextPowerOf2(MaxBitWidth); - return std::make_pair(Type::getIntNTy(Exit->getContext(), MaxBitWidth), + if (MaxBitWidth == DL.getTypeSizeInBits(Exit->getType()) && AC && DT) { + // If demanded bits wasn't able to limit the bit width, we can try to use + // value tracking instead. This can be the case, for example, if the value + // may be negative. + auto NumSignBits = ComputeNumSignBits(Exit, DL, 0, AC, nullptr, DT); + auto NumTypeBits = DL.getTypeSizeInBits(Exit->getType()); + MaxBitWidth = NumTypeBits - NumSignBits; + KnownBits Bits = computeKnownBits(Exit, DL); + if (!Bits.isNonNegative()) { + // If the value is not known to be non-negative, we set IsSigned to true, + // meaning that we will use sext instructions instead of zext + // instructions to restore the original type. + IsSigned = true; + if (!Bits.isNegative()) + // If the value is not known to be negative, we don't known what the + // upper bit is, and therefore, we don't know what kind of extend we + // will need. In this case, just increase the bit width by one bit and + // use sext. + ++MaxBitWidth; + } + } + if (!isPowerOf2_64(MaxBitWidth)) + MaxBitWidth = NextPowerOf2(MaxBitWidth); + MaxBW = std::max(MaxBW, MaxBitWidth); + } + + return std::make_pair(Type::getIntNTy(Exits[0]->getContext(), MaxBW), IsSigned); } /// Collect cast instructions that can be ignored in the vectorizer's cost /// model, given a reduction exit value and the minimal type in which the /// reduction can be represented. -static void collectCastsToIgnore(Loop *TheLoop, Instruction *Exit, +static void collectCastsToIgnore(Loop *TheLoop, + RecurrenceDescriptor::ExitInstrList &Exits, Type *RecurrenceType, SmallPtrSetImpl &Casts) { - SmallVector Worklist; - SmallPtrSet Visited; - Worklist.push_back(Exit); + for (auto *Exit : Exits) { + SmallVector Worklist; + SmallPtrSet Visited; + Worklist.push_back(Exit); - while (!Worklist.empty()) { - Instruction *Val = Worklist.pop_back_val(); - Visited.insert(Val); - if (auto *Cast = dyn_cast(Val)) - if (Cast->getSrcTy() == RecurrenceType) { - // If the source type of a cast instruction is equal to the recurrence - // type, it will be eliminated, and should be ignored in the vectorizer - // cost model. 
- Casts.insert(Cast); - continue; - } + while (!Worklist.empty()) { + Instruction *Val = Worklist.pop_back_val(); + Visited.insert(Val); + if (auto *Cast = dyn_cast(Val)) + if (Cast->getSrcTy() == RecurrenceType) { + // If the source type of a cast instruction is equal to the recurrence + // type, it will be eliminated, and should be ignored in the vectorizer + // cost model. + Casts.insert(Cast); + continue; + } - // Add all operands to the work list if they are loop-varying values that - // we haven't yet visited. - for (Value *O : cast(Val)->operands()) - if (auto *I = dyn_cast(O)) - if (TheLoop->contains(I) && !Visited.count(I)) - Worklist.push_back(I); + // Add all operands to the work list if they are loop-varying values that + // we haven't yet visited. + for (Value *O : cast(Val)->operands()) + if (auto *I = dyn_cast(O)) + if (TheLoop->contains(I) && !Visited.count(I)) + Worklist.push_back(I); + } } } +// Check if a given Phi node can be recognized as an ordered reduction for +// vectorizing floating point operations without unsafe math. +static bool +checkOrderedReduction(RecurrenceDescriptor::RecurrenceKind Kind, + RecurrenceDescriptor::ExitInstrList &ExitInstructions, + PHINode *Phi) { + // Currently only 'fadd' is supported. + if (Kind != RecurrenceDescriptor::RK_FloatAdd) + return false; + + if (ExitInstructions.size() != 1) + return false; + + auto EI = ExitInstructions[0]; + bool IsOrdered = EI->getOpcode() == Instruction::FAdd && + !cast(EI)->isFast(); + + // If this comes from a PHI node, look through it + if (auto EIP = dyn_cast(EI)) { + if (EIP->getNumIncomingValues() != 2) + return false; + + auto ChainVal = EIP->getIncomingValue(0) == Phi + ? EIP->getIncomingValue(1) + : EIP->getIncomingValue(0); + + if (!isa(ChainVal)) + return false; + + EI = cast(ChainVal); + IsOrdered = EI->getOpcode() == Instruction::FAdd && + !cast(EI)->isFast(); + } + + // The only pattern accepted is the one in which the reduction PHI is used + // as one of the operands of the exit istruction. + auto LHS = EI->getOperand(0); + auto RHS = EI->getOperand(1); + IsOrdered = IsOrdered && ((LHS == Phi) || (RHS == Phi)); + + if (!IsOrdered) + return false; + + LLVM_DEBUG(dbgs() << "LU: Found an ordered reduction: Phi: " + << *Phi << ", ExitInst: " << *EI << "\n"); + + return true; +} + bool RecurrenceDescriptor::AddReductionVar(PHINode *Phi, RecurrenceKind Kind, Loop *TheLoop, bool HasFunNoNaNAttr, + ScalarEvolution *SE, RecurrenceDescriptor &RedDes, + bool AllowMultipleExits, DemandedBits *DB, AssumptionCache *AC, DominatorTree *DT) { @@ -209,7 +284,8 @@ // We only allow for a single reduction value to be used outside the loop. // This includes users of the reduction, variables (which form a cycle // which ends in the phi node). - Instruction *ExitInstruction = nullptr; + ExitInstrList ExitInstructions; + StoreInst* IntermediateStore = nullptr; // Indicates that we found a reduction operation in our scan. bool FoundReduxOp = false; @@ -271,47 +347,93 @@ Instruction *Cur = Worklist.back(); Worklist.pop_back(); + // Store instructions are allowed iff it is the store of the reduction + // value to the same loop uniform memory location. 
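+    // For instance (illustrative IR), a reduction whose running value is
+    // also stored to a loop-invariant address on every iteration:
+    //   loop:
+    //     %sum = phi i32 [ 0, %ph ], [ %sum.next, %loop ]
+    //     ...
+    //     %sum.next = add i32 %sum, %val
+    //     store i32 %sum.next, i32* @total    ; the IntermediateStore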
+ if (auto *SI = dyn_cast(Cur)) { + const SCEV *PtrScev = SE->getSCEV(SI->getPointerOperand()); + // Check it is the same address as previous stores + if (IntermediateStore) { + const SCEV *OtherScev = + SE->getSCEV(IntermediateStore->getPointerOperand()); + + if (OtherScev != PtrScev) { + LLVM_DEBUG(dbgs() << "Storing reduction value to different addresses " << + "inside the loop: " << *SI->getPointerOperand() << " and " << + *IntermediateStore->getPointerOperand() << '\n'); + return false; + } + } + + // Check the pointer is loop invariant + if (!SE->isLoopInvariant(PtrScev, TheLoop)) { + LLVM_DEBUG(dbgs() << "Storing reduction value to non-uniform address " << + "inside the loop: " << *SI->getPointerOperand() << '\n'); + return false; + } + + // IntermediateStore is always the last store in the loop. + IntermediateStore = SI; + continue; + } + // No Users. // If the instruction has no users then this is a broken chain and can't be // a reduction variable. - if (Cur->use_empty()) + if (Cur->use_empty()) { + LLVM_DEBUG(dbgs() << "LU: Instruction has no users: " << *Cur << "\n"); return false; + } bool IsAPhi = isa(Cur); // A header PHI use other than the original PHI. - if (Cur != Phi && IsAPhi && Cur->getParent() == Phi->getParent()) + if (Cur != Phi && IsAPhi && Cur->getParent() == Phi->getParent()) { + LLVM_DEBUG(dbgs() << "LU: loop header phi that isn't the original phi: " << + *Cur << "\n"); return false; + } // Reductions of instructions such as Div, and Sub is only possible if the // LHS is the reduction variable. if (!Cur->isCommutative() && !IsAPhi && !isa(Cur) && !isa(Cur) && !isa(Cur) && - !VisitedInsts.count(dyn_cast(Cur->getOperand(0)))) + !VisitedInsts.count(dyn_cast(Cur->getOperand(0)))) { + LLVM_DEBUG(dbgs() << "LU: LHS isn't the reduction var: " << *Cur << "\n"); return false; + } // Any reduction instruction must be of one of the allowed kinds. We ignore // the starting value (the Phi or an AND instruction if the Phi has been // type-promoted). if (Cur != Start) { ReduxDesc = isRecurrenceInstr(Cur, Kind, ReduxDesc, HasFunNoNaNAttr); - if (!ReduxDesc.isRecurrence()) + if (!ReduxDesc.isRecurrence()) { + LLVM_DEBUG(dbgs() << "LU: Not an allowed instruction: " << + *Cur << " for recurrence type " << RecurrenceNames[Kind] << "\n"); return false; + } } // A reduction operation must only have one use of the reduction value. if (!IsAPhi && Kind != RK_IntegerMinMax && Kind != RK_FloatMinMax && - hasMultipleUsesOf(Cur, VisitedInsts)) + Kind != RK_ConstSelectICmp && Kind != RK_ConstSelectFCmp && + hasMultipleUsesOf(Cur, VisitedInsts)) { + LLVM_DEBUG(dbgs() << "LU: Too many uses of reduction value: " << *Cur << "\n"); return false; + } // All inputs to a PHI node must be a reduction value. - if (IsAPhi && Cur != Phi && !areAllUsesIn(Cur, VisitedInsts)) + if (IsAPhi && Cur != Phi && !areAllUsesIn(Cur, VisitedInsts)) { + LLVM_DEBUG(dbgs() << "LU: All input must be a reduction value: " << + *Cur << "\n"); return false; + } - if (Kind == RK_IntegerMinMax && + if ((Kind == RK_IntegerMinMax || Kind == RK_ConstSelectICmp) && (isa(Cur) || isa(Cur))) ++NumCmpSelectPatternInst; - if (Kind == RK_FloatMinMax && (isa(Cur) || isa(Cur))) + if ((Kind == RK_FloatMinMax || Kind == RK_ConstSelectFCmp) && + (isa(Cur) || isa(Cur))) ++NumCmpSelectPatternInst; // Check whether we found a reduction operator. @@ -322,6 +444,25 @@ // nodes once we get to them. 
SmallVector NonPHIs; SmallVector PHIs; + // We really want exit analysis available to all of these instead of + // doing it in all the transform passes or utilities. + // + // TODO: Exit analysis can come first -- just check that the ptr + // op for loadcmp is SAR, not necessarily LI_Base+Induction + // That way we can assume the cast is safe. (or a dedicated + // pass makes it irrelevant) + // + // Simple assumption for now -- EE is not the latch. + SmallVector ExitingBlocks; + // TheLoop->dump(); + TheLoop->getExitingBlocks(ExitingBlocks); + // Grab expected exit instruction. + BasicBlock *EEBlock = nullptr; + for (auto *EB: ExitingBlocks) { + if (EB != TheLoop->getLoopLatch()) + EEBlock = EB; + } + for (User *U : Cur->users()) { Instruction *UI = cast(U); @@ -330,23 +471,69 @@ if (!TheLoop->contains(Parent)) { // If we already know this instruction is used externally, move on to // the next user. - if (ExitInstruction == Cur) + if (is_contained(ExitInstructions, Cur)) continue; - // Exit if you find multiple values used outside or if the header phi - // node is being used. In this case the user uses the value of the - // previous iteration, in which case we would loose "VF-1" iterations of - // the reduction operation if we vectorize. - if (ExitInstruction != nullptr || Cur == Phi) + // We want to allow: + // %sum = phi(0, ph, %sum.next, loopbody) + // : + // + // : + // %sum.next = .. + if (Cur == Phi) { + if (!AllowMultipleExits) { + LLVM_DEBUG(dbgs() << "LU: Use of header phi: " << *Cur << "\n"); + return false; + } + } + + // Check if the operand is from an early exit + if (AllowMultipleExits && EEBlock) { + // The user should either be an escapee merge block OR a PHI node + // that is used in a store to the intermediate store address. + if (isa(UI)) { + // If it has more than 1 operand, it should be an escapee merge block. + if (UI->getNumOperands() > 1) { + // Test that each exit instruction is a PHI operand + bool Fail = false; + for(auto *EI : ExitInstructions) { + // if not represented in PHI, its not OK + if (!is_contained(UI->operands(), EI)) { + Fail = true; + break; + } + } + if (!Fail) { + ExitInstructions.push_back(Cur); + continue; + } + } + // Only 1 operand, the PHI should be used in a store (check below) + UI = dyn_cast(*(UI->user_begin())); + } + + // This is an exit instruction iff the store is an intermediate store + if (isa(UI)) { + ExitInstructions.push_back(Cur); + NonPHIs.push_back(UI); + continue; + } + + LLVM_DEBUG(dbgs() << "LU: Early exit value is not an escapee value\n"); return false; + } // The instruction used by an outside user must be the last instruction // before we feed back to the reduction phi. Otherwise, we loose VF-1 // operations on the value. 
- if (!is_contained(Phi->operands(), Cur)) + if (!AllowMultipleExits && !is_contained(Phi->operands(), Cur)) { + LLVM_DEBUG(dbgs() << "LU: Outside user does not use last instruction in " + << "reduction chain: " << *Cur << "\n"); + LLVM_DEBUG(dbgs() << "LU: Phi checked is: " << *Phi << "\n"); return false; + } - ExitInstruction = Cur; + ExitInstructions.push_back(Cur); continue; } @@ -357,12 +544,18 @@ if (VisitedInsts.insert(UI).second) { if (isa(UI)) PHIs.push_back(UI); - else + else { + if (auto *SI = dyn_cast(UI)) { + if (SI->getValueOperand() == Cur) + NonPHIs.push_back(UI); + } else NonPHIs.push_back(UI); + } } else if (!isa(UI) && ((!isa(UI) && !isa(UI) && !isa(UI)) || - !isMinMaxSelectCmpPattern(UI, IgnoredVal).isRecurrence())) + (!isMinMaxSelectCmpPattern(UI, IgnoredVal).isRecurrence() && + !isConstSelectCmpPattern(UI, IgnoredVal).isRecurrence()))) return false; // Remember that we completed the cycle. @@ -379,9 +572,38 @@ NumCmpSelectPatternInst != 2) return false; - if (!FoundStartPHI || !FoundReduxOp || !ExitInstruction) + if ((Kind == RK_ConstSelectICmp || Kind == RK_ConstSelectFCmp) && + NumCmpSelectPatternInst != 1) return false; + // If there is an intermediate store, it must store the last reduction value. + if (!ExitInstructions.empty() && IntermediateStore) { + if (!AllowMultipleExits && + IntermediateStore->getValueOperand() != ExitInstructions.back()) { + LLVM_DEBUG(dbgs() << "LU: Last store Instruction of reduction value " << + "does not store last calculated value of the reduction: " << + *IntermediateStore << '\n'); + return false; + } + } + + // If all uses are inside the loop (intermediate stores), then the + // reduction value after the loop will be the one used in the last store. + if (ExitInstructions.empty() && IntermediateStore) { + auto *ExitValue = + cast(IntermediateStore->getValueOperand()); + ExitInstructions.push_back(ExitValue); + } + + if (!FoundStartPHI || !FoundReduxOp || ExitInstructions.empty()) { + LLVM_DEBUG(dbgs() << "LU: Did not find one of: StartPHI: " << + FoundStartPHI << ", ReduxOp: " << FoundReduxOp << + ", ExitInstruction Count: " << ExitInstructions.size() << "\n"); + return false; + } + + // Special handling for ordered reductions + const bool IsOrdered = checkOrderedReduction(Kind, ExitInstructions, Phi); if (Start != Phi) { // If the starting value is not the same as the phi node, we speculatively // looked through an 'and' instruction when evaluating a potential @@ -409,7 +631,7 @@ // to begin with. Type *ComputedType; std::tie(ComputedType, IsSigned) = - computeRecurrenceType(ExitInstruction, DB, AC, DT); + computeRecurrenceType(ExitInstructions, DB, AC, DT); if (ComputedType != RecurrenceType) return false; @@ -422,7 +644,7 @@ // instructions that are a part of the reduction. The vectorizer cost // model could then apply the recurrence type to these instructions, // without needing a white list of instructions to ignore. - collectCastsToIgnore(TheLoop, ExitInstruction, RecurrenceType, CastInsts); + collectCastsToIgnore(TheLoop, ExitInstructions, RecurrenceType, CastInsts); } // We found a reduction var if we have reached the original phi node and we @@ -432,14 +654,75 @@ // is saved as part of the RecurrenceDescriptor. // Save the description of this reduction variable. 
- RecurrenceDescriptor RD( - RdxStart, ExitInstruction, Kind, ReduxDesc.getMinMaxKind(), - ReduxDesc.getUnsafeAlgebraInst(), RecurrenceType, IsSigned, CastInsts); + RecurrenceDescriptor RD(RdxStart, ExitInstructions, IntermediateStore, Kind, + ReduxDesc.getMinMaxKind(), + ReduxDesc.getUnsafeAlgebraInst(), RecurrenceType, + IsSigned, CastInsts, IsOrdered); RedDes = RD; return true; } +RecurrenceDescriptor::InstDesc +RecurrenceDescriptor::isConstSelectCmpPattern(Instruction *I, InstDesc &Prev) { + assert((isa(I) || isa(I) || isa(I)) && + "Expect a select instruction"); + Instruction *Cmp = nullptr; + SelectInst *Select = nullptr; + + // We must handle the select(cmp()) as a single instruction. Advance to the + // select. + if ((Cmp = dyn_cast(I)) || (Cmp = dyn_cast(I))) { + if (!Cmp->hasOneUse() || !(Select = dyn_cast(*I->user_begin()))) + return InstDesc(false, I); + return InstDesc(Select, Prev.getMinMaxKind()); + } + + // Only handle single use cases for now. + if (!(Select = dyn_cast(I))) + return InstDesc(false, I); + if (!(Cmp = dyn_cast(I->getOperand(0))) && + !(Cmp = dyn_cast(I->getOperand(0)))) + return InstDesc(false, I); + if (!Cmp->hasOneUse()) + return InstDesc(false, I); + + int64_t SelectVal; + int PhiIndex; + if (ConstantInt *Tmp = dyn_cast(Select->getOperand(2))) { + SelectVal = Tmp->getSExtValue(); + PhiIndex = 1; + } else if (ConstantInt *Tmp = dyn_cast(Select->getOperand(1))) { + SelectVal = Tmp->getSExtValue(); + PhiIndex = 2; + } else + return InstDesc(false, I); + + Value *RdxStart; + if (PHINode *Tmp = dyn_cast(Select->getOperand(PhiIndex))) { + if (ConstantInt *Tmp2 = dyn_cast(Tmp->getOperand(0))) + RdxStart = Tmp2; + else if (ConstantInt *Tmp2 = dyn_cast(Tmp->getOperand(1))) + RdxStart = Tmp2; + else + return InstDesc(false, I); + } else + return InstDesc(false, I); + + int64_t RdxStartVal; + if (ConstantInt *Tmp = dyn_cast(RdxStart)) + RdxStartVal = Tmp->getSExtValue(); + else + return InstDesc(false, I); + + // It doesn't actually matter if the reduction variable is of a different + // signedness to the integer comparison, or even if it is a floating + // point comparison. All that matters is that we generate a final reduction + // that ensures the result contains the initial reduction value if the select + // instruction always chooses the value from the previous iteration. + return InstDesc(Select, SelectVal < RdxStartVal ? MRK_SIntMin : MRK_SIntMax); +} + /// Returns true if the instruction is a Select(ICmp(X, Y), X, Y) instruction /// pattern corresponding to a min(X, Y) or max(X, Y). RecurrenceDescriptor::InstDesc @@ -494,9 +777,18 @@ RecurrenceDescriptor::InstDesc RecurrenceDescriptor::isRecurrenceInstr(Instruction *I, RecurrenceKind Kind, InstDesc &Prev, bool HasFunNoNaNAttr) { - bool FP = I->getType()->isFloatingPointTy(); + auto Opc = I->getOpcode(); + + // Explicitly note lack of fastmath when debugging vectorization + if ((Opc == Instruction::FMul || + Opc == Instruction::FAdd || + Opc == Instruction::FSub) && !I->isFast()) + LLVM_DEBUG(dbgs() << + "LU: Warning! Fastmath not set on fp reduction instruction: " + << *I << "\n"); + Instruction *UAI = Prev.getUnsafeAlgebraInst(); - if (!UAI && FP && !I->isFast()) + if (!UAI && isa(I) && !I->isFast()) UAI = I; // Found an unsafe (unvectorizable) algebra instruction. 
switch (I->getOpcode()) { @@ -523,6 +815,10 @@ case Instruction::FCmp: case Instruction::ICmp: case Instruction::Select: + if (Kind == RK_ConstSelectICmp || + Kind == RK_ConstSelectFCmp) + return isConstSelectCmpPattern(I, Prev); + // else fallthrough ... if (Kind != RK_IntegerMinMax && (!HasFunNoNaNAttr || Kind != RK_FloatMinMax)) return InstDesc(false, I); @@ -543,63 +839,77 @@ return false; } + bool RecurrenceDescriptor::isReductionPHI(PHINode *Phi, Loop *TheLoop, + ScalarEvolution *SE, RecurrenceDescriptor &RedDes, + bool AllowMultipleExits, DemandedBits *DB, AssumptionCache *AC, DominatorTree *DT) { - BasicBlock *Header = TheLoop->getHeader(); Function &F = *Header->getParent(); bool HasFunNoNaNAttr = F.getFnAttribute("no-nans-fp-math").getValueAsString() == "true"; - if (AddReductionVar(Phi, RK_IntegerAdd, TheLoop, HasFunNoNaNAttr, RedDes, DB, - AC, DT)) { + if (AddReductionVar(Phi, RK_IntegerAdd, TheLoop, HasFunNoNaNAttr, SE, RedDes, + AllowMultipleExits, DB, AC, DT)) { LLVM_DEBUG(dbgs() << "Found an ADD reduction PHI." << *Phi << "\n"); return true; } - if (AddReductionVar(Phi, RK_IntegerMult, TheLoop, HasFunNoNaNAttr, RedDes, DB, - AC, DT)) { + if (AddReductionVar(Phi, RK_IntegerMult, TheLoop, HasFunNoNaNAttr, SE, RedDes, + AllowMultipleExits, DB, AC, DT)) { LLVM_DEBUG(dbgs() << "Found a MUL reduction PHI." << *Phi << "\n"); return true; } - if (AddReductionVar(Phi, RK_IntegerOr, TheLoop, HasFunNoNaNAttr, RedDes, DB, - AC, DT)) { + if (AddReductionVar(Phi, RK_IntegerOr, TheLoop, HasFunNoNaNAttr, SE, RedDes, + AllowMultipleExits, DB, AC, DT)) { LLVM_DEBUG(dbgs() << "Found an OR reduction PHI." << *Phi << "\n"); return true; } - if (AddReductionVar(Phi, RK_IntegerAnd, TheLoop, HasFunNoNaNAttr, RedDes, DB, - AC, DT)) { + if (AddReductionVar(Phi, RK_IntegerAnd, TheLoop, HasFunNoNaNAttr, SE, RedDes, + AllowMultipleExits, DB, AC, DT)) { LLVM_DEBUG(dbgs() << "Found an AND reduction PHI." << *Phi << "\n"); return true; } - if (AddReductionVar(Phi, RK_IntegerXor, TheLoop, HasFunNoNaNAttr, RedDes, DB, - AC, DT)) { + if (AddReductionVar(Phi, RK_IntegerXor, TheLoop, HasFunNoNaNAttr, SE, RedDes, + AllowMultipleExits, DB, AC, DT)) { LLVM_DEBUG(dbgs() << "Found a XOR reduction PHI." << *Phi << "\n"); return true; } - if (AddReductionVar(Phi, RK_IntegerMinMax, TheLoop, HasFunNoNaNAttr, RedDes, - DB, AC, DT)) { + if (AddReductionVar(Phi, RK_IntegerMinMax, TheLoop, HasFunNoNaNAttr, + SE, RedDes, AllowMultipleExits, DB, AC, DT)) { LLVM_DEBUG(dbgs() << "Found a MINMAX reduction PHI." << *Phi << "\n"); return true; } - if (AddReductionVar(Phi, RK_FloatMult, TheLoop, HasFunNoNaNAttr, RedDes, DB, - AC, DT)) { + if (AddReductionVar(Phi, RK_ConstSelectICmp, TheLoop, HasFunNoNaNAttr, + SE, RedDes, AllowMultipleExits, DB, AC, DT)) { + LLVM_DEBUG(dbgs() << "Found a integer conditional select reduction PHI." + << *Phi << "\n"); + return true; + } + if (AddReductionVar(Phi, RK_FloatMult, TheLoop, HasFunNoNaNAttr, SE, RedDes, + AllowMultipleExits, DB, AC, DT)) { LLVM_DEBUG(dbgs() << "Found an FMult reduction PHI." << *Phi << "\n"); return true; } - if (AddReductionVar(Phi, RK_FloatAdd, TheLoop, HasFunNoNaNAttr, RedDes, DB, - AC, DT)) { + if (AddReductionVar(Phi, RK_FloatAdd, TheLoop, HasFunNoNaNAttr, SE, RedDes, + AllowMultipleExits, DB, AC, DT)) { LLVM_DEBUG(dbgs() << "Found an FAdd reduction PHI." << *Phi << "\n"); return true; } - if (AddReductionVar(Phi, RK_FloatMinMax, TheLoop, HasFunNoNaNAttr, RedDes, DB, - AC, DT)) { - LLVM_DEBUG(dbgs() << "Found an float MINMAX reduction PHI." 
<< *Phi - << "\n"); + if (AddReductionVar(Phi, RK_FloatMinMax, TheLoop, HasFunNoNaNAttr, SE, RedDes, + AllowMultipleExits, DB, AC, DT)) { + LLVM_DEBUG(dbgs() << "Found an float MINMAX reduction PHI." << *Phi << "\n"); + return true; + } + if (AddReductionVar(Phi, RK_ConstSelectFCmp, TheLoop, HasFunNoNaNAttr, + SE, RedDes, AllowMultipleExits, DB, AC, DT)) { + LLVM_DEBUG(dbgs() << "Found a float/integer conditional select reduction" + << " PHI." << *Phi << "\n"); return true; } // Not a reduction of known type. + LLVM_DEBUG(dbgs() << "Not a known reduction type: " << *Phi << "\n"); return false; } @@ -659,6 +969,7 @@ /// the operation K. Constant *RecurrenceDescriptor::getRecurrenceIdentity(RecurrenceKind K, Type *Tp) { + assert(Tp && "Missing type"); switch (K) { case RK_IntegerXor: case RK_IntegerAdd: @@ -700,8 +1011,10 @@ case RK_FloatAdd: return Instruction::FAdd; case RK_IntegerMinMax: + case RK_ConstSelectICmp: return Instruction::ICmp; case RK_FloatMinMax: + case RK_ConstSelectFCmp: return Instruction::FCmp; default: llvm_unreachable("Unknown recurrence operation"); @@ -846,6 +1159,11 @@ InductionBinOp->getOpcode() == Instruction::FSub) && "Original bin op should be defined for FP induction"); + if (!Index->getType()->isFloatingPointTy()) { + // We need to do some conversion. + Index = B.CreateUIToFP(Index, StartValue->getType()); + } + Value *StepValue = cast(Step)->getValue(); // Floating point operations had to be 'fast' to enable the induction. @@ -1712,6 +2030,15 @@ Flags.IsSigned = (MMKind == RD::MRK_SIntMax || MMKind == RD::MRK_SIntMin); return createSimpleTargetReduction(B, TTI, Instruction::ICmp, Src, Flags); } + case RD::RK_ConstSelectICmp: + case RD::RK_ConstSelectFCmp: { + RD::MinMaxRecurrenceKind MMKind = Desc.getMinMaxRecurrenceKind(); + int Opc = RecKind == RD::RK_ConstSelectICmp ? + Instruction::ICmp : Instruction::FCmp; + Flags.IsMaxOp = MMKind == RD::MRK_SIntMax; + Flags.IsSigned = true; + return createSimpleTargetReduction(B, TTI, Opc, Src, Flags); + } case RD::RK_FloatMinMax: { Flags.IsMaxOp = Desc.getMinMaxRecurrenceKind() == RD::MRK_FloatMax; return createSimpleTargetReduction(B, TTI, Instruction::FCmp, Src, Flags); @@ -1721,6 +2048,37 @@ } } +Value *llvm::createOrderedReduction(IRBuilder<> &Builder, + RecurrenceDescriptor &Desc, Value *Src, + Value *Start, Value *Predicate) { + assert(Desc.isOrdered() && "Recurrence must be an ordered kind"); + auto Kind = Desc.getRecurrenceKind(); + assert((Kind == RecurrenceDescriptor::RK_FloatAdd || + Kind == RecurrenceDescriptor::RK_FloatMult) && + "Unknown reduction kind"); + assert(Src->getType()->isVectorTy() && "Expected a vector type"); + assert(!Start->getType()->isVectorTy() && "Expected a scalar type"); + + std::function CreateRdx; + // Predication is done by masking out the inactive elements of the vector + // source with a safe identity value. 
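// An illustrative sketch in plain C++ (not the LLVM builder API used below)
// of that masking scheme: inactive lanes are replaced with the identity of
// the operation (0.0 for fadd, 1.0 for fmul) so that a single in-order fold
// over all lanes computes the same value as the predicated scalar loop,
// without reassociating the active elements.
#include <cstdio>

static double orderedMaskedFAdd(double Start, const double *Src,
                                const bool *Pred, int Lanes) {
  double Acc = Start;
  for (int L = 0; L < Lanes; ++L)
    Acc += Pred[L] ? Src[L] : 0.0;  // identity for inactive lanes
  return Acc;
}

int main() {
  double Src[4] = {1.5, 2.5, 3.25, 4.0};
  bool Pred[4] = {true, false, true, true};
  // Only the active lanes contribute, in lane order: 10.0 + 1.5 + 3.25 + 4.0.
  std::printf("%f\n", orderedMaskedFAdd(10.0, Src, Pred, 4));
}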
+ switch (Kind) { + default: + llvm_unreachable("Unknown reduction kind"); + case RecurrenceDescriptor::RK_FloatAdd: { + auto *MaskedVec = Builder.CreateSelect( + Predicate, Src, ConstantFP::get(Src->getType(), 0.0)); + return Builder.CreateFAddReduce(Start, MaskedVec); + } + case RecurrenceDescriptor::RK_FloatMult: { + auto *MaskedVec = Builder.CreateSelect( + Predicate, Src, ConstantFP::get(Src->getType(), 1.0)); + return Builder.CreateFMulReduce(Start, MaskedVec); + } + } +} + + void llvm::propagateIRFlags(Value *I, ArrayRef VL, Value *OpValue) { auto *VecOp = dyn_cast(I); if (!VecOp) @@ -1739,3 +2097,23 @@ VecOp->andIRFlags(V); } } + +bool +llvm::storeToSameAddress(ScalarEvolution *SE, StoreInst *A, StoreInst *B) { + // Compare store + if (A == B) + return true; + + // Otherwise Compare pointers + Value *APtr = A->getPointerOperand(); + Value *BPtr = B->getPointerOperand(); + if (A == B) + return true; + + // Otherwise compare address SCEVs + if (SE->getSCEV(APtr) == SE->getSCEV(BPtr)) + return true; + + return false; +} + Index: lib/Transforms/Utils/SimplifyCFG.cpp =================================================================== --- lib/Transforms/Utils/SimplifyCFG.cpp +++ lib/Transforms/Utils/SimplifyCFG.cpp @@ -5575,9 +5575,110 @@ return true; } -bool SimplifyCFGOpt::SimplifySwitch(SwitchInst *SI, IRBuilder<> &Builder) { +static bool TurnSmallSwitchIntoICmps(SwitchInst *SI, IRBuilder<> &Builder) { + assert(SI->getNumCases() > 1 && "Degenerate switch?"); + + // Check to see if we have a genuine default, reachable block with executable + // instructions in them. + bool HasDefault = + !isa(SI->getDefaultDest()->getFirstNonPHIOrDbg()); + + BasicBlock *DefaultBlock = HasDefault ? SI->getDefaultDest() : nullptr; + SmallVector UniqueBlocks; BasicBlock *BB = SI->getParent(); + // We don't attempt to deal with ranges here. + for (auto Case : SI->cases()) { + BasicBlock *Dest = Case.getCaseSuccessor(); + for (auto Block : UniqueBlocks) { + // We don't support multiple cases with the same dest + if (Block == Dest) + return false; + } + UniqueBlocks.push_back(Dest); + } + + // Record the total weighting for this switch block. + uint64_t TotalWeight = 0; + SmallVector Weights; + if (HasBranchWeights(SI) && + Weights.size() == (SI->getNumCases() + 1)) { + GetBranchWeights(SI, Weights); + for (auto W : Weights) + TotalWeight += W; + } + + BasicBlock *OtherDest = nullptr; + uint64_t FalseWeight = TotalWeight; + for (auto CI : SI->cases()) { + BasicBlock *TrueDest = CI.getCaseSuccessor(); + Value *Cmp = + Builder.CreateICmpEQ(SI->getCondition(), CI.getCaseValue(), "switch"); + + // Walk through PHIs in TrueDest and see which ones came + // from the switch block, then remap them. + if (OtherDest != nullptr) { + for (PHINode &PN : TrueDest->phis()) { + for (auto PB : PN.blocks()) { + if (PB == BB) { + Value *V = PN.getIncomingValueForBlock(BB); + PN.removeIncomingValue(BB, false); + PN.addIncoming(V, OtherDest); + } + } + } + } + + BasicBlock *MoveAfter = OtherDest ? OtherDest : BB; + OtherDest = + BasicBlock::Create(BB->getContext(), + BB->getName() + ".switch", + BB->getParent(), BB); + OtherDest->moveAfter(MoveAfter); + + Instruction *I = Builder.CreateCondBr(Cmp, TrueDest, OtherDest); + // Update weight for the newly-created conditional branch. 
+ if (TotalWeight) { + int index = CI.getSuccessorIndex(); + FalseWeight -= Weights[index]; + setBranchWeights(I, Weights[index], FalseWeight); + } + Builder.SetInsertPoint(OtherDest); + + } + + if (DefaultBlock) { + // The last block we created is empty, which is bad mmm'k! + Builder.CreateBr(DefaultBlock); + + // The block that we jump to may have had some PHIs that came + // from the block containing the switch statement. Now that we + // are removing the switch statement we need to fix up the PHIs. + + // Walk through PHIs in DefaultBlock and see which ones came + // from the switch block, then remap them. + for (PHINode &PN : DefaultBlock->phis()) { + for (auto PB : PN.blocks()) { + if (PB == BB) { + Value *V = PN.getIncomingValueForBlock(BB); + PN.removeIncomingValue(BB, false); + PN.addIncoming(V, OtherDest); + } + } + } + } else + Builder.CreateUnreachable(); + + // Drop the switch. + SI->eraseFromParent(); + + Builder.SetInsertPoint(BB); + + return true; +} + +bool SimplifyCFGOpt::SimplifySwitch(SwitchInst *SI, IRBuilder<> &Builder) { + BasicBlock *BB = SI->getParent(); if (isValueEqualityComparison(SI)) { // If we only have one predecessor, and if it is a branch on this value, // see if that predecessor totally determines the outcome of this switch. @@ -5597,8 +5698,14 @@ return simplifyCFG(BB, TTI, Options) | true; } + unsigned NumCases = SI->getNumCases(); + bool RemoveSwitches = Options.SwitchRemovalThreshold >= NumCases; + + if (RemoveSwitches && TurnSmallSwitchIntoICmps(SI, Builder)) + return simplifyCFG(BB, TTI, Options) | true; + // Try to transform the switch into an icmp and a branch. - if (TurnSwitchRangeIntoICmp(SI, Builder)) + if (!RemoveSwitches && TurnSwitchRangeIntoICmp(SI, Builder)) return simplifyCFG(BB, TTI, Options) | true; // Remove unreachable cases. @@ -5817,16 +5924,19 @@ if (SimplifyEqualityComparisonWithOnlyPredecessor(BI, OnlyPred, Builder)) return simplifyCFG(BB, TTI, Options) | true; - // This block must be empty, except for the setcond inst, if it exists. - // Ignore dbg intrinsics. - auto I = BB->instructionsWithoutDebug().begin(); - if (&*I == BI) { - if (FoldValueComparisonIntoPredecessors(BI, Builder)) - return simplifyCFG(BB, TTI, Options) | true; - } else if (&*I == cast(BI->getCondition())) { - ++I; - if (&*I == BI && FoldValueComparisonIntoPredecessors(BI, Builder)) - return simplifyCFG(BB, TTI, Options) | true; + bool RemoveSwitches = Options.SwitchRemovalThreshold > 0; + if (!RemoveSwitches) { + // This block must be empty, except for the setcond inst, if it exists. + // Ignore dbg intrinsics. 
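// A source-level sketch (hypothetical example, not generated output) of what
// TurnSmallSwitchIntoICmps produces when every case targets a distinct
// block: the switch becomes a chain of equality compares, and the remaining
// "false" weight shrinks as each case is peeled off. Both functions return
// identical values for any input.
static int withSwitch(int X) {
  switch (X) {
  case 1: return 10;
  case 2: return 20;
  case 3: return 30;
  default: return 0;
  }
}

static int withICmpChain(int X) {
  if (X == 1)   // weight(case 1) vs. TotalWeight - weight(case 1)
    return 10;
  if (X == 2)   // weight(case 2) vs. the weight still remaining
    return 20;
  if (X == 3)
    return 30;
  return 0;     // the default destination, reached from the last block
}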
+ auto I = BB->instructionsWithoutDebug().begin(); + if (&*I == BI) { + if (FoldValueComparisonIntoPredecessors(BI, Builder)) + return simplifyCFG(BB, TTI, Options) | true; + } else if (&*I == cast(BI->getCondition())) { + ++I; + if (&*I == BI && FoldValueComparisonIntoPredecessors(BI, Builder)) + return simplifyCFG(BB, TTI, Options) | true; + } } } Index: lib/Transforms/Utils/SimplifyLibCalls.cpp =================================================================== --- lib/Transforms/Utils/SimplifyLibCalls.cpp +++ lib/Transforms/Utils/SimplifyLibCalls.cpp @@ -23,6 +23,7 @@ #include "llvm/Analysis/ValueTracking.h" #include "llvm/Analysis/CaptureTracking.h" #include "llvm/IR/DataLayout.h" +#include "llvm/IR/DiagnosticInfo.h" #include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/IntrinsicInst.h" @@ -927,6 +928,20 @@ // Math Library Optimizations //===----------------------------------------------------------------------===// +// Replace a libcall \p CI with a call to intrinsic \p IID +static Value *replaceUnaryCall(CallInst *CI, IRBuilder<> &B, Intrinsic::ID IID) { + // Propagate fast-math flags from the existing call to the new call. + IRBuilder<>::FastMathFlagGuard Guard(B); + B.setFastMathFlags(CI->getFastMathFlags()); + + Module *M = CI->getModule(); + Value *V = CI->getArgOperand(0); + Function *F = Intrinsic::getDeclaration(M, IID, CI->getType()); + CallInst *NewCall = B.CreateCall(F, V); + NewCall->takeName(CI); + return NewCall; +} + /// Return a variant of Val with float type. /// Currently this works in two cases: If Val is an FPExtension of a float /// value to something bigger, simply return the operand. @@ -949,104 +964,75 @@ return nullptr; } -/// Shrink double -> float for unary functions like 'floor'. -static Value *optimizeUnaryDoubleFP(CallInst *CI, IRBuilder<> &B, - bool CheckRetType) { - Function *Callee = CI->getCalledFunction(); - // We know this libcall has a valid prototype, but we don't know which. +/// Shrink double -> float functions. +static Value *optimizeDoubleFP(CallInst *CI, IRBuilder<> &B, + bool isBinary, bool isPrecise = false) { if (!CI->getType()->isDoubleTy()) return nullptr; - if (CheckRetType) { - // Check if all the uses for function like 'sin' are converted to float. + // If not all the uses of the function are converted to float, then bail out. + // This matters if the precision of the result is more important than the + // precision of the arguments. + if (isPrecise) for (User *U : CI->users()) { FPTruncInst *Cast = dyn_cast(U); if (!Cast || !Cast->getType()->isFloatTy()) return nullptr; } - } - // If this is something like 'floor((double)floatval)', convert to floorf. - Value *V = valueHasFloatPrecision(CI->getArgOperand(0)); - if (V == nullptr) + // If this is something like 'g((double) float)', convert to 'gf(float)'. + Value *V[2]; + V[0] = valueHasFloatPrecision(CI->getArgOperand(0)); + V[1] = isBinary ? valueHasFloatPrecision(CI->getArgOperand(1)) : nullptr; + if (!V[0] || (isBinary && !V[1])) return nullptr; // If call isn't an intrinsic, check that it isn't within a function with the - // same name as the float version of this call. + // same name as the float version of this call, otherwise the result is an + // infinite loop. For example, from MinGW-w64: // - // e.g. 
inline float expf(float val) { return (float) exp((double) val); } - // - // A similar such definition exists in the MinGW-w64 math.h header file which - // when compiled with -O2 -ffast-math causes the generation of infinite loops - // where expf is called. - if (!Callee->isIntrinsic()) { - const Function *F = CI->getFunction(); - StringRef FName = F->getName(); - StringRef CalleeName = Callee->getName(); - if ((FName.size() == (CalleeName.size() + 1)) && - (FName.back() == 'f') && - FName.startswith(CalleeName)) + // float expf(float val) { return (float) exp((double) val); } + Function *CalleeFn = CI->getCalledFunction(); + StringRef CalleeNm = CalleeFn->getName(); + AttributeList CalleeAt = CalleeFn->getAttributes(); + if (CalleeFn && !CalleeFn->isIntrinsic()) { + const Function *Fn = CI->getFunction(); + StringRef FnName = Fn->getName(); + if (FnName.back() == 'f' && + FnName.size() == (CalleeNm.size() + 1) && + FnName.startswith(CalleeNm)) return nullptr; } - // Propagate fast-math flags from the existing call to the new call. + // Propagate the math semantics from the current function to the new function. IRBuilder<>::FastMathFlagGuard Guard(B); B.setFastMathFlags(CI->getFastMathFlags()); - // floor((double)floatval) -> (double)floorf(floatval) - if (Callee->isIntrinsic()) { + // g((double) float) -> (double) gf(float) + Value *R; + if (CalleeFn->isIntrinsic()) { Module *M = CI->getModule(); - Intrinsic::ID IID = Callee->getIntrinsicID(); - Function *F = Intrinsic::getDeclaration(M, IID, B.getFloatTy()); - V = B.CreateCall(F, V); - } else { - // The call is a library call rather than an intrinsic. - V = emitUnaryFloatFnCall(V, Callee->getName(), B, Callee->getAttributes()); + Intrinsic::ID IID = CalleeFn->getIntrinsicID(); + Function *Fn = Intrinsic::getDeclaration(M, IID, B.getFloatTy()); + R = isBinary ? B.CreateCall(Fn, V) : B.CreateCall(Fn, V[0]); } + else + R = isBinary ? emitBinaryFloatFnCall(V[0], V[1], CalleeNm, B, CalleeAt) + : emitUnaryFloatFnCall(V[0], CalleeNm, B, CalleeAt); - return B.CreateFPExt(V, B.getDoubleTy()); + return B.CreateFPExt(R, B.getDoubleTy()); } -// Replace a libcall \p CI with a call to intrinsic \p IID -static Value *replaceUnaryCall(CallInst *CI, IRBuilder<> &B, Intrinsic::ID IID) { - // Propagate fast-math flags from the existing call to the new call. - IRBuilder<>::FastMathFlagGuard Guard(B); - B.setFastMathFlags(CI->getFastMathFlags()); - - Module *M = CI->getModule(); - Value *V = CI->getArgOperand(0); - Function *F = Intrinsic::getDeclaration(M, IID, CI->getType()); - CallInst *NewCall = B.CreateCall(F, V); - NewCall->takeName(CI); - return NewCall; +/// Shrink double -> float for unary functions. +static Value *optimizeUnaryDoubleFP(CallInst *CI, IRBuilder<> &B, + bool isPrecise = false) { + return optimizeDoubleFP(CI, B, false, isPrecise); } -/// Shrink double -> float for binary functions like 'fmin/fmax'. -static Value *optimizeBinaryDoubleFP(CallInst *CI, IRBuilder<> &B) { - Function *Callee = CI->getCalledFunction(); - // We know this libcall has a valid prototype, but we don't know which. - if (!CI->getType()->isDoubleTy()) - return nullptr; - - // If this is something like 'fmin((double)floatval1, (double)floatval2)', - // or fmin(1.0, (double)floatval), then we convert it to fminf. 
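// A small standalone sketch of the name check performed above to avoid
// creating an infinite loop: when the shrink happens inside a function that
// is itself the float wrapper of the callee (e.g. an 'expf' defined in terms
// of 'exp'), rewriting exp((double)x) into (double)expf(x) would make expf
// call itself.
#include <cassert>
#include <string>

static bool isFloatWrapperOfCallee(const std::string &FnName,
                                   const std::string &CalleeName) {
  return !FnName.empty() && FnName.back() == 'f' &&
         FnName.size() == CalleeName.size() + 1 &&
         FnName.compare(0, CalleeName.size(), CalleeName) == 0;
}

int main() {
  assert(isFloatWrapperOfCallee("expf", "exp"));    // shrink must be skipped
  assert(!isFloatWrapperOfCallee("floorf", "exp")); // unrelated caller: fine
  assert(!isFloatWrapperOfCallee("exp", "exp"));    // not the float wrapper
}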
- Value *V1 = valueHasFloatPrecision(CI->getArgOperand(0)); - if (V1 == nullptr) - return nullptr; - Value *V2 = valueHasFloatPrecision(CI->getArgOperand(1)); - if (V2 == nullptr) - return nullptr; - - // Propagate fast-math flags from the existing call to the new call. - IRBuilder<>::FastMathFlagGuard Guard(B); - B.setFastMathFlags(CI->getFastMathFlags()); - - // fmin((double)floatval1, (double)floatval2) - // -> (double)fminf(floatval1, floatval2) - // TODO: Handle intrinsics in the same way as in optimizeUnaryDoubleFP(). - Value *V = emitBinaryFloatFnCall(V1, V2, Callee->getName(), B, - Callee->getAttributes()); - return B.CreateFPExt(V, B.getDoubleTy()); +/// Shrink double -> float for binary functions. +static Value *optimizeBinaryDoubleFP(CallInst *CI, IRBuilder<> &B, + bool isPrecise = false) { + return optimizeDoubleFP(CI, B, true, isPrecise); } // cabs(z) -> sqrt((creal(z)*creal(z)) + (cimag(z)*cimag(z))) @@ -1119,14 +1105,97 @@ return InnerChain[Exp]; } +/// Use exp{,2}(x * y) for pow(exp{,2}(x), y); +/// exp2(x) for pow(2.0, x); exp10(x) for pow(10.0, x). +Value *LibCallSimplifier::replacePowWithExp(CallInst *Pow, IRBuilder<> &B) { + Value *Base = Pow->getArgOperand(0), *Expo = Pow->getArgOperand(1); + AttributeList Attrs = Pow->getCalledFunction()->getAttributes(); + Module *Mod = Pow->getModule(); + Type *Ty = Pow->getType(); + + // Evaluate special cases related to a nested function as the base. + + // pow(exp(x), y) -> exp(x * y) + // pow(exp2(x), y) -> exp2(x * y) + // If exp{,2}() is used only once, it is better to fold two transcendental + // math functions into one. If used again, exp{,2}() would still have to be + // called with the original argument, then keep both original transcendental + // functions. However, this transformation is only safe with fully relaxed + // math semantics, since, besides rounding differences, it changes overflow + // and underflow behavior quite dramatically. For example: + // pow(exp(1000), 0.001) = pow(inf, 0.001) = inf + // Whereas: + // exp(1000 * 0.001) = exp(1) + // TODO: Loosen the requirement for fully relaxed math semantics. + // TODO: Handle exp10() when more targets have it available. + CallInst *BaseFn = dyn_cast(Base); + if (BaseFn && BaseFn->hasOneUse() && BaseFn->isFast() && Pow->isFast()) { + LibFunc LibFn; + + Function *CalleeFn = BaseFn->getCalledFunction(); + if (CalleeFn && TLI->getLibFunc(CalleeFn->getName(), LibFn) && + (LibFn == LibFunc_exp || LibFn == LibFunc_exp2) && TLI->has(LibFn)) { + Value *ExpFn; + + // Create new exp{,2}() with the product as its argument. + Value *FMul = B.CreateFMul(BaseFn->getArgOperand(0), Expo, "mul"); + ExpFn = emitUnaryFloatFnCall(FMul, CalleeFn->getName(), B, + BaseFn->getAttributes()); + + // Since the new exp{,2}() is different from the original one, dead code + // elimination cannot be trusted to remove it, since it may have side + // effects (e.g., errno). When the only consumer for the original + // exp{,2}() is pow(), then it has to be explicitly erased. + BaseFn->replaceAllUsesWith(ExpFn); + BaseFn->eraseFromParent(); + + return ExpFn; + } + } + + // Evaluate special cases related to a constant base. + + // pow(2.0, x) -> exp2(x) + if (match(Base, m_SpecificFP(2.0))) { + Value *Exp2 = Intrinsic::getDeclaration(Mod, Intrinsic::exp2, Ty); + return B.CreateCall(Exp2, Expo, "exp2"); + } + + // pow(10.0, x) -> exp10(x) + // TODO: There is no exp10() intrinsic yet, but some day there shall be one. 
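// A numeric illustration in plain C++ of why the pow(exp(x), y) -> exp(x * y)
// rewrite above is gated on fully relaxed math: for large x the two forms
// disagree because the intermediate exp(x) overflows to infinity, exactly the
// pow(exp(1000), 0.001) case cited in the comment.
#include <cmath>
#include <cstdio>

int main() {
  double X = 1000.0, Y = 0.001;
  std::printf("pow(exp(x), y) = %f\n", std::pow(std::exp(X), Y)); // inf
  std::printf("exp(x * y)     = %f\n", std::exp(X * Y));          // ~2.718282
}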
+ if (match(Base, m_SpecificFP(10.0)) && + hasUnaryFloatFn(TLI, Ty, LibFunc_exp10, LibFunc_exp10f, LibFunc_exp10l)) + return emitUnaryFloatFnCall(Expo, TLI->getName(LibFunc_exp10), B, Attrs); + + return nullptr; +} + +static Value *getSqrtCall(Value *V, AttributeList Attrs, bool NoErrno, + Module *M, IRBuilder<> &B, + const TargetLibraryInfo *TLI) { + // If errno is never set, then use the intrinsic for sqrt(). + if (NoErrno) { + Function *SqrtFn = + Intrinsic::getDeclaration(M, Intrinsic::sqrt, V->getType()); + return B.CreateCall(SqrtFn, V, "sqrt"); + } + + // Otherwise, use the libcall for sqrt(). + if (hasUnaryFloatFn(TLI, V->getType(), LibFunc_sqrt, LibFunc_sqrtf, + LibFunc_sqrtl)) + // TODO: We also should check that the target can in fact lower the sqrt() + // libcall. We currently have no way to ask this question, so we ask if + // the target has a sqrt() libcall, which is not exactly the same. + return emitUnaryFloatFnCall(V, TLI->getName(LibFunc_sqrt), B, Attrs); + + return nullptr; +} + /// Use square root in place of pow(x, +/-0.5). Value *LibCallSimplifier::replacePowWithSqrt(CallInst *Pow, IRBuilder<> &B) { - // TODO: There is some subset of 'fast' under which these transforms should - // be allowed. - if (!Pow->isFast()) - return nullptr; - Value *Sqrt, *Base = Pow->getArgOperand(0), *Expo = Pow->getArgOperand(1); + AttributeList Attrs = Pow->getCalledFunction()->getAttributes(); + Module *Mod = Pow->getModule(); Type *Ty = Pow->getType(); const APFloat *ExpoF; @@ -1134,22 +1203,25 @@ (!ExpoF->isExactlyValue(0.5) && !ExpoF->isExactlyValue(-0.5))) return nullptr; - // If errno is never set, then use the intrinsic for sqrt(). - if (Pow->hasFnAttr(Attribute::ReadNone)) { - Function *SqrtFn = Intrinsic::getDeclaration(Pow->getModule(), - Intrinsic::sqrt, Ty); - Sqrt = B.CreateCall(SqrtFn, Base); - } - // Otherwise, use the libcall for sqrt(). - else if (hasUnaryFloatFn(TLI, Ty, LibFunc_sqrt, LibFunc_sqrtf, LibFunc_sqrtl)) - // TODO: We also should check that the target can in fact lower the sqrt() - // libcall. We currently have no way to ask this question, so we ask if - // the target has a sqrt() libcall, which is not exactly the same. - Sqrt = emitUnaryFloatFnCall(Base, TLI->getName(LibFunc_sqrt), B, - Pow->getCalledFunction()->getAttributes()); - else + Sqrt = getSqrtCall(Base, Attrs, Pow->doesNotAccessMemory(), Mod, B, TLI); + if (!Sqrt) return nullptr; + // Handle signed zero base by expanding to fabs(sqrt(x)). + if (!Pow->hasNoSignedZeros()) { + Function *FAbsFn = Intrinsic::getDeclaration(Mod, Intrinsic::fabs, Ty); + Sqrt = B.CreateCall(FAbsFn, Sqrt, "abs"); + } + + // Handle non finite base by expanding to + // (x == -infinity ? +infinity : sqrt(x)). + if (!Pow->hasNoInfs()) { + Value *PosInf = ConstantFP::getInfinity(Ty), + *NegInf = ConstantFP::getInfinity(Ty, true); + Value *FCmp = B.CreateFCmpOEQ(Base, NegInf, "isinf"); + Sqrt = B.CreateSelect(FCmp, PosInf, Sqrt); + } + // If the exponent is negative, then get the reciprocal. 
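// A plain C++ check of the special values that force the extra fabs and
// -infinity handling above when expanding pow(x, 0.5) to sqrt(x):
// pow(-0.0, 0.5) is +0.0 and pow(-inf, 0.5) is +inf, while sqrt(-0.0) is
// -0.0 and sqrt(-inf) is NaN.
#include <cmath>
#include <cstdio>

int main() {
  double NegZero = -0.0, NegInf = -INFINITY;
  std::printf("pow(-0.0, 0.5) = %f, sqrt(-0.0) = %f\n",
              std::pow(NegZero, 0.5), std::sqrt(NegZero));
  std::printf("pow(-inf, 0.5) = %f, sqrt(-inf) = %f\n",
              std::pow(NegInf, 0.5), std::sqrt(NegInf));
  // Expansion used when the flags do not allow ignoring these cases:
  double X = NegInf;
  double Expanded = (X == -INFINITY) ? INFINITY : std::fabs(std::sqrt(X));
  std::printf("expanded       = %f\n", Expanded);
}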
if (ExpoF->isNegative()) Sqrt = B.CreateFDiv(ConstantFP::get(Ty, 1.0), Sqrt, "reciprocal"); @@ -1160,134 +1232,109 @@ Value *LibCallSimplifier::optimizePow(CallInst *Pow, IRBuilder<> &B) { Value *Base = Pow->getArgOperand(0), *Expo = Pow->getArgOperand(1); Function *Callee = Pow->getCalledFunction(); - AttributeList Attrs = Callee->getAttributes(); StringRef Name = Callee->getName(); - Module *Module = Pow->getModule(); Type *Ty = Pow->getType(); Value *Shrunk = nullptr; bool Ignored; - if (UnsafeFPShrink && - Name == TLI->getName(LibFunc_pow) && hasFloatVersion(Name)) - Shrunk = optimizeUnaryDoubleFP(Pow, B, true); + // Bail out if simplifying libcalls to pow() is disabled. + if (!hasUnaryFloatFn(TLI, Ty, LibFunc_pow, LibFunc_powf, LibFunc_powl)) + return nullptr; // Propagate the math semantics from the call to any created instructions. IRBuilder<>::FastMathFlagGuard Guard(B); B.setFastMathFlags(Pow->getFastMathFlags()); + // Shrink pow() to powf() if the arguments are single precision, + // unless the result is expected to be double precision. + if (UnsafeFPShrink && + Name == TLI->getName(LibFunc_pow) && hasFloatVersion(Name)) + Shrunk = optimizeBinaryDoubleFP(Pow, B, true); + // Evaluate special cases related to the base. // pow(1.0, x) -> 1.0 - if (match(Base, m_SpecificFP(1.0))) + if (match(Base, m_FPOne())) return Base; - // pow(2.0, x) -> exp2(x) - if (match(Base, m_SpecificFP(2.0))) { - Value *Exp2 = Intrinsic::getDeclaration(Module, Intrinsic::exp2, Ty); - return B.CreateCall(Exp2, Expo, "exp2"); - } - - // pow(10.0, x) -> exp10(x) - if (ConstantFP *BaseC = dyn_cast(Base)) - // There's no exp10 intrinsic yet, but, maybe, some day there shall be one. - if (BaseC->isExactlyValue(10.0) && - hasUnaryFloatFn(TLI, Ty, LibFunc_exp10, LibFunc_exp10f, LibFunc_exp10l)) - return emitUnaryFloatFnCall(Expo, TLI->getName(LibFunc_exp10), B, Attrs); - - // pow(exp(x), y) -> exp(x * y) - // pow(exp2(x), y) -> exp2(x * y) - // We enable these only with fast-math. Besides rounding differences, the - // transformation changes overflow and underflow behavior quite dramatically. - // Example: x = 1000, y = 0.001. - // pow(exp(x), y) = pow(inf, 0.001) = inf, whereas exp(x*y) = exp(1). - auto *BaseFn = dyn_cast(Base); - if (BaseFn && BaseFn->isFast() && Pow->isFast()) { - LibFunc LibFn; - Function *CalleeFn = BaseFn->getCalledFunction(); - if (CalleeFn && TLI->getLibFunc(CalleeFn->getName(), LibFn) && - (LibFn == LibFunc_exp || LibFn == LibFunc_exp2) && TLI->has(LibFn)) { - IRBuilder<>::FastMathFlagGuard Guard(B); - B.setFastMathFlags(Pow->getFastMathFlags()); - - Value *FMul = B.CreateFMul(BaseFn->getArgOperand(0), Expo, "mul"); - return emitUnaryFloatFnCall(FMul, CalleeFn->getName(), B, - CalleeFn->getAttributes()); - } - } + if (Value *Exp = replacePowWithExp(Pow, B)) + return Exp; // Evaluate special cases related to the exponent. + // pow(x, -1.0) -> 1.0 / x + if (match(Expo, m_SpecificFP(-1.0))) + return B.CreateFDiv(ConstantFP::get(Ty, 1.0), Base, "reciprocal"); + + // pow(x, 0.0) -> 1.0 + if (match(Expo, m_SpecificFP(0.0))) + return ConstantFP::get(Ty, 1.0); + + // pow(x, 1.0) -> x + if (match(Expo, m_FPOne())) + return Base; + + // pow(x, 2.0) -> x * x + if (match(Expo, m_SpecificFP(2.0))) + return B.CreateFMul(Base, Base, "square"); + if (Value *Sqrt = replacePowWithSqrt(Pow, B)) return Sqrt; - ConstantFP *ExpoC = dyn_cast(Expo); - if (!ExpoC) - return Shrunk; + // pow(x, n) -> x * x * x * ... 
+ const APFloat *ExpoF; + if (Pow->isFast() && match(Expo, m_APFloat(ExpoF))) { + // We limit to a max of 7 multiplications, thus the maximum exponent is 32. + // If the exponent is an integer+0.5 we generate a call to sqrt and an + // additional fmul. + // TODO: This whole transformation should be backend specific (e.g. some + // backends might prefer libcalls or the limit for the exponent might + // be different) and it should also consider optimizing for size. + APFloat LimF(ExpoF->getSemantics(), 33.0), + ExpoA(abs(*ExpoF)); + if (ExpoA.compare(LimF) == APFloat::cmpLessThan) { + // This transformation applies to integer or integer+0.5 exponents only. + // For integer+0.5, we create a sqrt(Base) call. + Value *Sqrt = nullptr; + if (!ExpoA.isInteger()) { + APFloat Expo2 = ExpoA; + // To check if ExpoA is an integer + 0.5, we add it to itself. If there + // is no floating point exception and the result is an integer, then + // ExpoA == integer + 0.5 + if (Expo2.add(ExpoA, APFloat::rmNearestTiesToEven) != APFloat::opOK) + return nullptr; - // pow(x, -1.0) -> 1.0 / x - if (ExpoC->isExactlyValue(-1.0)) - return B.CreateFDiv(ConstantFP::get(Ty, 1.0), Base, "reciprocal"); + if (!Expo2.isInteger()) + return nullptr; - // pow(x, 0.0) -> 1.0 - if (ExpoC->getValueAPF().isZero()) - return ConstantFP::get(Ty, 1.0); + Sqrt = + getSqrtCall(Base, Pow->getCalledFunction()->getAttributes(), + Pow->doesNotAccessMemory(), Pow->getModule(), B, TLI); + } - // pow(x, 1.0) -> x - if (ExpoC->isExactlyValue(1.0)) - return Base; + // We will memoize intermediate products of the Addition Chain. + Value *InnerChain[33] = {nullptr}; + InnerChain[1] = Base; + InnerChain[2] = B.CreateFMul(Base, Base, "square"); - // pow(x, 2.0) -> x * x - if (ExpoC->isExactlyValue(2.0)) - return B.CreateFMul(Base, Base, "square"); + // We cannot readily convert a non-double type (like float) to a double. + // So we first convert it to something which could be converted to double. + ExpoA.convert(APFloat::IEEEdouble(), APFloat::rmTowardZero, &Ignored); + Value *FMul = getPow(InnerChain, ExpoA.convertToDouble(), B); - // FIXME: Correct the transforms and pull this into replacePowWithSqrt(). - if (ExpoC->isExactlyValue(0.5) && - hasUnaryFloatFn(TLI, Ty, LibFunc_sqrt, LibFunc_sqrtf, LibFunc_sqrtl)) { - // Expand pow(x, 0.5) to (x == -infinity ? +infinity : fabs(sqrt(x))). - // This is faster than calling pow(), and still handles -0.0 and - // negative infinity correctly. - // TODO: In finite-only mode, this could be just fabs(sqrt(x)). - Value *PosInf = ConstantFP::getInfinity(Ty); - Value *NegInf = ConstantFP::getInfinity(Ty, true); + // Expand pow(x, y+0.5) to pow(x, y) * sqrt(x). + if (Sqrt) + FMul = B.CreateFMul(FMul, Sqrt); - // TODO: As above, we should lower to the sqrt() intrinsic if the pow() is - // an intrinsic, to match errno semantics. - Value *Sqrt = emitUnaryFloatFnCall(Base, TLI->getName(LibFunc_sqrt), - B, Attrs); - Function *FAbsFn = Intrinsic::getDeclaration(Module, Intrinsic::fabs, Ty); - Value *FAbs = B.CreateCall(FAbsFn, Sqrt, "abs"); - Value *FCmp = B.CreateFCmpOEQ(Base, NegInf, "isinf"); - Sqrt = B.CreateSelect(FCmp, PosInf, FAbs); - return Sqrt; + // If the exponent is negative, then get the reciprocal. + if (ExpoF->isNegative()) + FMul = B.CreateFDiv(ConstantFP::get(Ty, 1.0), FMul, "reciprocal"); + + return FMul; + } } - // pow(x, n) -> x * x * x * .... - if (Pow->isFast()) { - APFloat ExpoA = abs(ExpoC->getValueAPF()); - // We limit to a max of 7 fmul(s). Thus the maximum exponent is 32. 
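// A behavioural sketch in plain C++ (the pass itself builds fmul chains via
// getPow/InnerChain and tests for integer + 0.5 with APFloat arithmetic) of
// the expansion above: exponents whose magnitude is below 33 and that are an
// integer or integer + 0.5 become a chain of multiplications, an optional
// sqrt factor, and a reciprocal when the exponent is negative.
#include <cmath>
#include <cstdio>

static double expandPow(double Base, double Expo) {
  double A = std::fabs(Expo);
  if (A >= 33.0)
    return std::pow(Base, Expo);   // too many multiplications: keep pow()
  double IntPart;
  double Frac = std::modf(A, &IntPart);
  if (Frac != 0.0 && Frac != 0.5)
    return std::pow(Base, Expo);   // neither integer nor integer + 0.5
  double R = 1.0;
  for (int I = 0; I < (int)IntPart; ++I) // the pass memoizes partial products
    R *= Base;
  if (Frac == 0.5)
    R *= std::sqrt(Base);          // pow(x, n + 0.5) == pow(x, n) * sqrt(x)
  return Expo < 0.0 ? 1.0 / R : R;
}

int main() {
  std::printf("%f vs %f\n", expandPow(1.5, 6.5), std::pow(1.5, 6.5));
  std::printf("%f vs %f\n", expandPow(2.0, -3.0), std::pow(2.0, -3.0));
}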
- // This transformation applies to integer exponents only. - if (!ExpoA.isInteger() || - ExpoA.compare - (APFloat(ExpoA.getSemantics(), 32.0)) == APFloat::cmpGreaterThan) - return nullptr; - - // We will memoize intermediate products of the Addition Chain. - Value *InnerChain[33] = {nullptr}; - InnerChain[1] = Base; - InnerChain[2] = B.CreateFMul(Base, Base, "square"); - - // We cannot readily convert a non-double type (like float) to a double. - // So we first convert it to something which could be converted to double. - ExpoA.convert(APFloat::IEEEdouble(), APFloat::rmTowardZero, &Ignored); - Value *FMul = getPow(InnerChain, ExpoA.convertToDouble(), B); - - // If the exponent is negative, then get the reciprocal. - if (ExpoC->isNegative()) - FMul = B.CreateFDiv(ConstantFP::get(Ty, 1.0), FMul, "reciprocal"); - return FMul; - } - - return nullptr; + return Shrunk; } Value *LibCallSimplifier::optimizeExp2(CallInst *CI, IRBuilder<> &B) { Index: lib/Transforms/Utils/VNCoercion.cpp =================================================================== --- lib/Transforms/Utils/VNCoercion.cpp +++ lib/Transforms/Utils/VNCoercion.cpp @@ -14,10 +14,15 @@ /// Return true if coerceAvailableValueToLoadType will succeed. bool canCoerceMustAliasedValueToLoad(Value *StoredVal, Type *LoadTy, const DataLayout &DL) { - // If the loaded or stored value is an first class array or struct, don't try - // to transform them. We need to be able to bitcast to integer. - if (LoadTy->isStructTy() || LoadTy->isArrayTy() || - StoredVal->getType()->isStructTy() || StoredVal->getType()->isArrayTy()) + Type *StoredValTy = StoredVal->getType(); + + // If the loaded or stored value is a first class array, struct or scalable + // vector, don't try to transform them. We need to be able to bitcast to + // integer. + if (LoadTy->isArrayTy() || StoredValTy->isArrayTy() || + LoadTy->isStructTy() || StoredValTy->isStructTy() || + (LoadTy->isVectorTy() && LoadTy->getVectorIsScalable()) || + (StoredValTy->isVectorTy() && StoredValTy->getVectorIsScalable())) return false; uint64_t StoreSize = DL.getTypeSizeInBits(StoredVal->getType()); @@ -31,7 +36,7 @@ return false; // Don't coerce non-integral pointers to integers or vice versa. - if (DL.isNonIntegralPointerType(StoredVal->getType()) != + if (DL.isNonIntegralPointerType(StoredValTy) != DL.isNonIntegralPointerType(LoadTy)) return false; Index: lib/Transforms/Vectorize/BOSCC.cpp =================================================================== --- /dev/null +++ lib/Transforms/Vectorize/BOSCC.cpp @@ -0,0 +1,311 @@ +//===- BOSCC.cpp ------------------------------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// \brief BOSCC Pass +// +// This pass reintroduces control flow on vectorized loops. It does so by +// detecting the following masked.[store|gather] into a block... +// +// ----- input ----- +// %BB: +// ... +// ... +// void call masked.[store|gather](..., with predicate %P, ...) +// NextInst ... +// ... +// br ... +// ------------------ +// +// ... and generating a new CFG as follows: +// +// ----- output ----- +// %BB: +// ... +// %test = test any true %P +// BR: if (%test) %guarded else %skip +// +// %guarded: +// void call masked.[store|gather](..., with predicate %P, ...) +// BR unconditional %skip +// +// %skip: +// NextInst ... +// ... 
+// ------------------ +// +// The name of the pass stands for 'branches-on-superword-condition-codes', a +// technique for reintroducing control flow in vectorized loops described in +// http://www.mcs.anl.gov/papers/P1411.pdf +// +// This implementation does not use all the analysis described in the paper, but +// the mechanics of the optimization are basically the same, as in 'check a +// predicate, branch around if all false'. +// +// Eventually this optimization will make use of profile data when available. +// +// TODO: +// 1. Check Width and Interleave hints set to 1 (see setAlreadyVectorized in +// loop vecotrizer). It might require to extract LoopVectorizeHints class in +// a separate module/headerfile? +// 2. Use Alias Analysis to populate the UnsafeToMerge set. +// +//===----------------------------------------------------------------------===// + +#include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/LoopPass.h" +#include "llvm/Analysis/VectorUtils.h" +#include "llvm/IR/Intrinsics.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Verifier.h" +#include "llvm/Pass.h" +#include "llvm/Support/Debug.h" +#include "llvm/Transforms/Vectorize.h" +#include +#include +#include +#include + +using namespace llvm; + +#define DEBUG_TYPE "boscc" + +STATISTIC(NumGuardedStores, "Number of masked stores that can be" + " guarded by a test on the predicate"); +STATISTIC(NumGuardedBlocks, "Number of guarded blocks generated"); + +namespace { +class BOSCC : public LoopPass { +public: + static char ID; // Pass identification. + + /// Constructor. + explicit BOSCC() : LoopPass(ID) { + initializeBOSCCPass(*PassRegistry::getPassRegistry()); + } + + /// LoopPass interface. + bool runOnLoop(Loop *L, LPPassManager &) override; + + /// LoopPass interface. + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired(); + } + +private: + /// Loop info structure to update after the transformation. + LoopInfo *LI; + + /// Container for the places there to split. Each element is a triple holding: + /// 1. The instruction to be guarded by teh new block. + /// 2. The instruction that follows, that is used as the first instruction + /// of the "skip" block. + /// 3. The predicate that need to be checked. + using TripleTy = + std::tuple; + std::deque Splits; + + /// Holds a mapping between the predicate that is being tested and the guarded + /// basic block generated, to avoid unnecessary duplication. + std::map PredicateMapping; + + /// Set of instructions that are not safe to merge in the same predicated + /// block. + /// + /// The UnsafeToMerge set holds instructions that can be guarded by a + /// predicate check but that are executed around a memory access (R or W) + /// call/instruction as follows: + /// + /// ... + /// InstN + /// ... + /// FirstInstI + /// ... + /// ... + /// InstReadWriteMem + /// ... + /// ... + /// SecondInstI + /// ... + /// ... + /// + /// Because the code-generation of the guarded blocks happens from bottom to + /// top in the vector body sequence, we need to add both instructions (first + /// and second) to the unsafe to merge set, otherwise it would not find any + /// basic block In PredicateMapping to merge to when processing FirstInstI + /// after SecondInstI. + std::set UnsafeToMerge; + + /// Populate the structures needed by the pass. 
+ void runOnBasicBlock(BasicBlock *); +}; +} + +char BOSCC::ID = 0; +INITIALIZE_PASS_BEGIN(BOSCC, "boscc", + "Inserts branches-on-superword-condition-codes", false, + false) +// must be executed after loop vectorization +INITIALIZE_PASS_DEPENDENCY(LoopVectorize) +INITIALIZE_PASS_END(BOSCC, "boscc", + "Inserts branches-on-superword-condition-codes", false, + false) + +namespace llvm { +Pass *createBOSCCPass() { return new BOSCC(); } +} + +bool BOSCC::runOnLoop(Loop *L, LPPassManager &LPM) { + LI = &getAnalysis().getLoopInfo(); + + // must be the innermost loop (as it is supposed to be vectorized + if (!L->getSubLoops().empty()) + return false; + + // Width and Interleave loop metadata set to 1 should be check here. + + for (auto &BB : L->blocks()) { + runOnBasicBlock(BB); + } + + // Process the places to split. + for (auto Pair : Splits) { + auto Inst = std::get<0>(Pair); + auto Pred = std::get<2>(Pair); + BasicBlock *Before = Inst->getParent(); + assert(Before && "Original basic block is missing."); + + // Search for a block that is already used for that predicate, and also + // check if the instruction is unsafe to merge. + if ((PredicateMapping.find(Pred) == std::end(PredicateMapping)) || + (UnsafeToMerge.find(&*Inst) != std::end(UnsafeToMerge))) { + // Create the new block structure + BasicBlock *Guarded = Before->splitBasicBlock(Inst, "guarded.block"); + ++NumGuardedBlocks; + assert(Guarded && "Guarded block creation failed."); + auto Next = std::get<1>(Pair); + BasicBlock *After = + Next->getParent()->splitBasicBlock(Next, "unconditional.block"); + assert(After && "Post store block creation failed."); + // Update loop infos. + L->addBasicBlockToLoop(Guarded, *LI); + L->addBasicBlockToLoop(After, *LI); + + // Create the conditional test and the branch instruction. + auto OldTerm = Before->getTerminator(); + IRBuilder<> Builder(OldTerm); + // FIXME: Make this generic for all targets once we have generic reductions. + Value *Done = getAnyTrueReduction(Builder, Pred); + Builder.CreateCondBr(Done, Guarded, After); + OldTerm->eraseFromParent(); + PredicateMapping[Pred] = Guarded; + } else { + // This predicate has already a guarded block, just move the store there + // as it is not in the UnsafeToMerge set. Since the 'Splits' container is + // populated from the front end, all the stores already in the guarded + // block are subsequent to teh once we are currently processing, + // i.e. program order of stores is guaranteed. + auto InsertPoint = PredicateMapping[Pred]->getFirstNonPHI(); + Inst->moveBefore(InsertPoint); + } + } + + LLVM_DEBUG(L->verifyLoop()); + + const bool CodeHasChanged = !Splits.empty(); + + // Clear up the data used in this particular loop. + Splits.clear(); + PredicateMapping.clear(); + UnsafeToMerge.clear(); + return CodeHasChanged; +} + +void BOSCC::runOnBasicBlock(BasicBlock *BB) { + for (auto II = BB->begin(), BE = BB->end(); II != BE; ++II) { + if (auto CI = dyn_cast(II)) { + Function *F = CI->getCalledFunction(); + if (!F) + continue; + + Intrinsic::ID IID = F->getIntrinsicID(); + if (!IID) + continue; + + if ((IID == Intrinsic::masked_store) || + (IID == Intrinsic::masked_scatter)) { + LLVM_DEBUG(dbgs() << "Found a masked store\n"); + + // Skip the main predicate of the loop because it would not make sense + // for it to contain no inactive elements. 
+ auto Pred = II->getOperand(3); + if (isa(Pred)) { + LLVM_DEBUG(dbgs() << "Skipping stores using the loop main predicate.\n"); + continue; + } + + auto Next = std::next(II); + if (isa(Next)) { + // For safety we skip masked.store followed by a branch instruction, + // which is unlikely, but we want to skip handling multiple blocks + // branching. + LLVM_DEBUG(dbgs() << "Skipping a branch instruction.\n"); + continue; + } + + // Ignore the last instruction in a block. + if (Next == BE) + continue; + + LLVM_DEBUG(dbgs() << "Adding store:\n"; II->print(dbgs()); + dbgs() << "\n followed by:\n"; Next->print(dbgs()); + dbgs() << "\n"); + /// Add the intruction to the list of split points. + Splits.emplace_front(II, Next, Pred); + ++NumGuardedStores; + } + } + } + + // Extra checks to make sure program order is mantained. + for (auto TI = Splits.begin(), TE = Splits.end(); TI != TE; ++TI) { + auto TNext = std::next(TI); + if (TNext == TE) + break; + + auto FirstInstI = std::get<0>(*TNext); + auto SecondInstI = std::get<0>(*TI); + + // Skip instructions from different BBs. + if (FirstInstI->getParent() != SecondInstI->getParent()) + continue; + + for (auto II = std::next(FirstInstI); II != SecondInstI; ++II) { + // In reality we should check for writes/reads from the same memory + // location of the store. Right now we just use a safe condition, we will + // refine the choice by cheking aliasing pointers in the future. + if (II->mayReadOrWriteMemory()) { + UnsafeToMerge.insert(&*FirstInstI); + UnsafeToMerge.insert(&*SecondInstI); + LLVM_DEBUG(dbgs() << "Unsafe to merge pair:\n"; + dbgs() << *FirstInstI << "\n"; dbgs() << *SecondInstI << "\n"); + } + } + } + + // Useful for seeing what's going on in the full block, for example to check + // how many predicates are in the block and if any of them are equivalent. 
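// A scalar analogue in plain C++ (illustrative only, not the IR this pass
// emits) of the guard inserted around masked stores: evaluate an "any lane
// active" test on the predicate first and branch around the store entirely
// when every lane is inactive, instead of unconditionally issuing the masked
// operation.
#include <cstdio>

static bool anyTrue(const bool *Pred, int Lanes) {
  for (int L = 0; L < Lanes; ++L)
    if (Pred[L])
      return true;
  return false;
}

static void maskedStore(double *Dst, const double *Src, const bool *Pred,
                        int Lanes) {
  for (int L = 0; L < Lanes; ++L)
    if (Pred[L])
      Dst[L] = Src[L];
}

static void guardedMaskedStore(double *Dst, const double *Src,
                               const bool *Pred, int Lanes) {
  if (!anyTrue(Pred, Lanes))       // %test = test any true %P
    return;                        // all lanes inactive: skip the guarded block
  maskedStore(Dst, Src, Pred, Lanes);
}

int main() {
  double Dst[4] = {0, 0, 0, 0}, Src[4] = {1, 2, 3, 4};
  bool AllFalse[4] = {false, false, false, false};
  guardedMaskedStore(Dst, Src, AllFalse, 4); // branches around the store
  std::printf("%f\n", Dst[0]);               // still 0.0
}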
+ LLVM_DEBUG(dbgs() << "\nUnique blocks being processed: \n"; + std::set Blocks; + for (auto Split : Splits) { + auto Store = std::get<0>(Split); + Blocks.insert(Store->getParent()); + } + for (auto BB : Blocks) + dbgs() << *BB;); +} Index: lib/Transforms/Vectorize/CMakeLists.txt =================================================================== --- lib/Transforms/Vectorize/CMakeLists.txt +++ lib/Transforms/Vectorize/CMakeLists.txt @@ -1,8 +1,13 @@ add_llvm_library(LLVMVectorize + LoopVectorizationAnalysis.cpp LoadStoreVectorizer.cpp LoopVectorizationLegality.cpp LoopVectorize.cpp + LVCommon.cpp + SVELoopVectorize.cpp + SearchLoopVectorize.cpp SLPVectorizer.cpp + BOSCC.cpp Vectorize.cpp VPlan.cpp VPlanHCFGBuilder.cpp @@ -11,6 +16,7 @@ ADDITIONAL_HEADER_DIRS ${LLVM_MAIN_INCLUDE_DIR}/llvm/Transforms + ${LLVM_MAIN_INCLUDE_DIR}/llvm/Transforms/IPO DEPENDS intrinsics_gen Index: lib/Transforms/Vectorize/LVCommon.cpp =================================================================== --- /dev/null +++ lib/Transforms/Vectorize/LVCommon.cpp @@ -0,0 +1,107 @@ +#include "llvm/Transforms/LVCommon.h" + +using namespace llvm; +// BEGIN SVE compatability code + +namespace llvm { + +cl::opt +EnableIfConversion("enable-if-conversion", cl::init(true), cl::Hidden, + cl::desc("Enable if-conversion during vectorization.")); + +cl::opt VectorizeSCEVCheckThreshold( + "vectorize-scev-check-threshold", cl::init(16), cl::Hidden, + cl::desc("The maximum number of SCEV checks allowed.")); + +cl::opt PragmaVectorizeSCEVCheckThreshold( + "pragma-vectorize-scev-check-threshold", cl::init(128), cl::Hidden, + cl::desc("The maximum number of SCEV checks allowed with a " + "vectorize(enable) pragma")); + +cl::opt PragmaVectorizeMemoryCheckThreshold( + "pragma-vectorize-memory-check-threshold", cl::init(128), cl::Hidden, + cl::desc("The maximum allowed number of runtime memory checks with a " + "vectorize(enable) pragma.")); + +/// We don't vectorize loops with a known constant trip count below this number. +cl::opt TinyTripCountVectorThreshold( + "vectorizer-min-trip-count", cl::init(16), cl::Hidden, + cl::desc("Don't vectorize loops with a constant " + "trip count that is smaller than this " + "value.")); + +cl::opt MaximizeBandwidth( + "vectorizer-maximize-bandwidth", cl::init(false), cl::Hidden, + cl::desc("Maximize bandwidth when selecting vectorization factor which " + "will be determined by the smallest type in loop.")); + +cl::opt EnableInterleavedMemAccesses( + "enable-interleaved-mem-accesses", cl::init(false), cl::Hidden, + cl::desc("Enable vectorization on interleaved memory accesses in a loop")); + +/// Maximum factor for an interleaved memory access. 
+cl::opt MaxInterleaveGroupFactor( + "max-interleave-group-factor", cl::Hidden, + cl::desc("Maximum factor for an interleaved access group (default = 8)"), + cl::init(8)); + +cl::opt ForceTargetNumScalarRegs( + "force-target-num-scalar-regs", cl::init(0), cl::Hidden, + cl::desc("A flag that overrides the target's number of scalar registers.")); + +cl::opt ForceTargetNumVectorRegs( + "force-target-num-vector-regs", cl::init(0), cl::Hidden, + cl::desc("A flag that overrides the target's number of vector registers.")); + +cl::opt ForceTargetMaxScalarInterleaveFactor( + "force-target-max-scalar-interleave", cl::init(0), cl::Hidden, + cl::desc("A flag that overrides the target's max interleave factor for " + "scalar loops.")); + +cl::opt ForceTargetMaxVectorInterleaveFactor( + "force-target-max-vector-interleave", cl::init(0), cl::Hidden, + cl::desc("A flag that overrides the target's max interleave factor for " + "vectorized loops.")); + +cl::opt ForceTargetInstructionCost( + "force-target-instruction-cost", cl::init(0), cl::Hidden, + cl::desc("A flag that overrides the target's expected cost for " + "an instruction to a single constant value. Mostly " + "useful for getting consistent testing.")); + +cl::opt SmallLoopCost( + "small-loop-cost", cl::init(20), cl::Hidden, + cl::desc( + "The cost of a loop that is considered 'small' by the interleaver.")); + +cl::opt LoopVectorizeWithBlockFrequency( + "loop-vectorize-with-block-frequency", cl::init(false), cl::Hidden, + cl::desc("Enable the use of the block frequency analysis to access PGO " + "heuristics minimizing code growth in cold regions and being more " + "aggressive in hot regions.")); + +// Runtime interleave loops for load/store throughput. +cl::opt EnableLoadStoreRuntimeInterleave( + "enable-loadstore-runtime-interleave", cl::init(true), cl::Hidden, + cl::desc( + "Enable runtime interleaving until load/store ports are saturated")); + +/// The number of stores in a loop that are allowed to need predication. +cl::opt NumberOfStoresToPredicate( + "vectorize-num-stores-pred", cl::init(1), cl::Hidden, + cl::desc("Max number of stores to be predicated behind an if.")); + +cl::opt EnableIndVarRegisterHeur( + "enable-ind-var-reg-heur", cl::init(true), cl::Hidden, + cl::desc("Count the induction variable only once when interleaving")); + +cl::opt EnableCondStoresVectorization( + "enable-cond-stores-vec", cl::init(true), cl::Hidden, + cl::desc("Enable if predication of stores during vectorization.")); + +cl::opt MaxNestedScalarReductionIC( + "max-nested-scalar-reduction-interleave", cl::init(2), cl::Hidden, + cl::desc("The maximum interleave count to use when interleaving a scalar " + "reduction in a nested loop.")); +} +// END SVE compatability code Index: lib/Transforms/Vectorize/LoopVectorizationAnalysis.cpp =================================================================== --- /dev/null +++ lib/Transforms/Vectorize/LoopVectorizationAnalysis.cpp @@ -0,0 +1,3440 @@ +//===- LoopVectorizationAnalysis.cpp ----------------------------*- C++ -*-===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. 
+// +//===----------------------------------------------------------------------===// +// +// Analyzes a loop to determine suitability for vectorization +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/LoopIterator.h" +#include "llvm/Analysis/LoopVectorizationAnalysis.h" +#include "llvm/Analysis/OptimizationRemarkEmitter.h" +#include "llvm/Analysis/PostDominators.h" +#include "llvm/Analysis/ScalarEvolutionExpander.h" +#include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/Analysis/VectorUtils.h" +#include "llvm/ADT/StringSwitch.h" +#include "llvm/IR/DiagnosticInfo.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/Function.h" +#include "llvm/Pass.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" + +using namespace llvm; + +// TODO: Parameterize this for SLVLoopVectorizeHints instead of +// #defining it here. +#define SLV_NAME "search-loop-vectorize" + +#define LVA_NAME "loop-vec-analysis" +#define DEBUG_TYPE LVA_NAME +#ifndef NDEBUG +#define NODEBUG_EARLY_BAILOUT() \ + do { if (!::llvm::DebugFlag || !::llvm::isCurrentDebugType(DEBUG_TYPE)) \ + { return false; } } while (0) +#else +#define NODEBUG_EARLY_BAILOUT() { return false; } +#endif + +/// TODO: Rename flags? Also reorder.... +///////////////////////// +/// This enables versioning on the strides of symbolically striding memory +/// accesses in code like the following. +/// for (i = 0; i < N; ++i) +/// A[i * Stride1] += B[i * Stride2] ... +/// +/// Will be roughly translated to +/// if (Stride1 == 1 && Stride2 == 1) { +/// for (i = 0; i < N; i+=4) +/// A[i:i+3] += ... +/// } else +/// ... +static cl::opt EnableMemAccessVersioning( + "sl-enable-mem-access-versioning", cl::init(true), cl::Hidden, + cl::desc("Enable symblic stride memory access versioning")); + +static cl::opt EnableInterleavedMemAccesses( + "sl-enable-interleaved-mem-accesses", cl::init(false), cl::Hidden, + cl::desc("Enable vectorization on interleaved memory accesses in a loop")); + +/// The number of stores in a loop that are allowed to need predication. +static cl::opt NumberOfStoresToPredicate( + "sl-vectorize-num-stores-pred", cl::init(1), cl::Hidden, + cl::desc("Max number of stores to be predicated behind an if.")); + +/// Maximum factor for an interleaved memory access. +static cl::opt MaxInterleaveGroupFactor( + "sl-max-interleave-group-factor", cl::Hidden, + cl::desc("Maximum factor for an interleaved access group (default = 8)"), + cl::init(8)); + +static cl::opt EnableCondStoresVectorization( + "sl-enable-cond-stores-vec", cl::init(false), cl::Hidden, + cl::desc("Enable if predication of stores during vectorization.")); + +static cl::opt MaximizeBandwidth( + "sl-vectorizer-maximize-bandwidth", cl::init(false), cl::Hidden, + cl::desc("Maximize bandwidth when selecting vectorization factor which " + "will be determined by the smallest type in loop.")); + +/// We don't interleave loops with a known constant trip count below this +/// number. 
+static const unsigned TinyTripCountInterleaveThreshold = 128; + +static cl::opt ForceTargetNumScalarRegs( + "sl-force-target-num-scalar-regs", cl::init(0), cl::Hidden, + cl::desc("A flag that overrides the target's number of scalar registers.")); + +static cl::opt ForceTargetNumVectorRegs( + "sl-force-target-num-vector-regs", cl::init(0), cl::Hidden, + cl::desc("A flag that overrides the target's number of vector registers.")); + +static cl::opt EnableIndVarRegisterHeur( + "sl-enable-ind-var-reg-heur", cl::init(true), cl::Hidden, + cl::desc("Count the induction variable only once when interleaving")); + +static cl::opt ForceTargetMaxScalarInterleaveFactor( + "sl-force-target-max-scalar-interleave", cl::init(0), cl::Hidden, + cl::desc("A flag that overrides the target's max interleave factor for " + "scalar loops.")); + +static cl::opt ForceTargetMaxVectorInterleaveFactor( + "sl-force-target-max-vector-interleave", cl::init(0), cl::Hidden, + cl::desc("A flag that overrides the target's max interleave factor for " + "vectorized loops.")); + +static cl::opt ForceTargetInstructionCost( + "sl-force-target-instruction-cost", cl::init(0), cl::Hidden, + cl::desc("A flag that overrides the target's expected cost for " + "an instruction to a single constant value. Mostly " + "useful for getting consistent testing.")); + +static cl::opt SmallLoopCost( + "sl-small-loop-cost", cl::init(20), cl::Hidden, + cl::desc( + "The cost of a loop that is considered 'small' by the interleaver.")); + +// Runtime interleave loops for load/store throughput. +static cl::opt EnableLoadStoreRuntimeInterleave( + "sl-enable-loadstore-runtime-interleave", cl::init(true), cl::Hidden, + cl::desc( + "Enable runtime interleaving until load/store ports are saturated")); + +static cl::opt MaxNestedScalarReductionIC( + "sl-max-nested-scalar-reduction-interleave", cl::init(2), cl::Hidden, + cl::desc("The maximum interleave count to use when interleaving a scalar " + "reduction in a nested loop.")); + +static cl::opt EnableNonConsecutiveStrideIndVars( + "sl-enable-non-consecutive-stride-ind-vars", cl::init(false), cl::Hidden, + cl::desc("Enable recognition of induction variables that aren't consecutive between loop iterations")); + +static cl::opt EnableScalableVectorisation( + "sl-force-scalable-vectorization", cl::init(true), cl::Hidden, + cl::desc("Enable vectorization using scalable vectors")); + +static cl::opt EnableUncountedLoops( + "sl-enable-lv-uncounted-loops", cl::init(false), cl::Hidden, + cl::desc("Enable vectorization of loops without a defined trip count")); + +//////////////////////////////////////////////////////////////////////////////// +// Helper functions (TODO: Move some?) +//////////////////////////////////////////////////////////////////////////////// + +/// \brief Check whether it is safe to if-convert this phi node. +/// +/// Phi nodes with constant expressions that can trap are not safe to if +/// convert. +static bool canIfConvertPHINodes(BasicBlock *BB) { + for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E; ++I) { + PHINode *Phi = dyn_cast(I); + if (!Phi) + return true; + // TODO: Use to drive predication instead of bailing out? + // Need to find an example that triggers this. + for (unsigned p = 0, e = Phi->getNumIncomingValues(); p != e; ++p) + if (Constant *C = dyn_cast(Phi->getIncomingValue(p))) + if (C->canTrap()) + return false; + } + return true; +} + +// TODO: Feels ugly, see if there's a better way where invoked. 
+static Type *convertPointerToIntegerType(const DataLayout &DL, Type *Ty) { + if (Ty->isPointerTy()) + return DL.getIntPtrType(Ty); + + // It is possible that char's or short's overflow when we ask for the loop's + // trip count, work around this by changing the type size. + if (Ty->getScalarSizeInBits() < 32) + return Type::getInt32Ty(Ty->getContext()); + + return Ty; +} + +// TODO: Feels like something that should be provided by types instead. +static Type* getWiderType(const DataLayout &DL, Type *Ty0, Type *Ty1) { + Ty0 = convertPointerToIntegerType(DL, Ty0); + Ty1 = convertPointerToIntegerType(DL, Ty1); + if (Ty0->getScalarSizeInBits() > Ty1->getScalarSizeInBits()) + return Ty0; + return Ty1; +} + +// TODO: Rename, clarify. Rework entirely? RPhis not neccesarily special +// anymore... +/// \brief Check that the instruction has outside loop users and is not an +/// identified reduction variable. +static bool hasOutsideLoopUser(const Loop *TheLoop, Instruction *Inst, + SmallPtrSetImpl &Reductions) { + // Reduction instructions are allowed to have exit users. All other + // instructions must not have external users. + if (!Reductions.count(Inst)) + //Check that all of the users of the loop are inside the BB. + for (User *U : Inst->users()) { + Instruction *UI = cast(U); + // This user may be a reduction exit value. + if (!TheLoop->contains(UI)) { + LLVM_DEBUG(dbgs() << "LVA: Found an outside user " << *UI << " for : " + << *Inst << "\n"); + return true; + } + } + return false; +} + +/// A helper function that returns the alignment of load or store instruction. +static unsigned getMemInstAlignment(Value *I) { + assert((isa(I) || isa(I)) && + "Expected Load or Store instruction"); + if (auto *LI = dyn_cast(I)) + return LI->getAlignment(); + return cast(I)->getAlignment(); +} + +//////////////////////////////////////////////////////////////////////////////// +// SLVLoopVectorizeHints +//////////////////////////////////////////////////////////////////////////////// +SLVLoopVectorizeHints::SLVLoopVectorizeHints(const Loop *L, bool DisableInterleaving, + OptimizationRemarkEmitter &ORE) + : Width("vectorize.width", VectorizerParams::VectorizationFactor, HK_WIDTH), + Interleave("interleave.count", DisableInterleaving, HK_UNROLL), + Force("vectorize.enable", FK_Undefined, HK_FORCE), TheLoop(L), ORE(ORE) { + + // Populate values with existing loop metadata. + getHintsFromMetadata(); + + // force-vector-interleave overrides DisableInterleaving. + if (VectorizerParams::isInterleaveForced()) + Interleave.Value = VectorizerParams::VectorizationInterleave; + + LLVM_DEBUG(if (DisableInterleaving && Interleave.Value == 1) dbgs() + << "LVA: Interleaving disabled by the pass manager\n"); +} + +bool SLVLoopVectorizeHints::allowVectorization(Function *F, Loop *L, + bool AlwaysVectorize) const { + if (getForce() == SLVLoopVectorizeHints::FK_Disabled) { + LLVM_DEBUG(dbgs() << "LV: Not vectorizing: #pragma vectorize disable.\n"); + emitRemarkWithHints(); + return false; + } + + if (!AlwaysVectorize && getForce() != SLVLoopVectorizeHints::FK_Enabled) { + LLVM_DEBUG(dbgs() << "LV: Not vectorizing: No #pragma vectorize enable.\n"); + emitRemarkWithHints(); + return false; + } + + if (getWidth() == 1 && getInterleave() == 1) { + // TODO: As below. + // FIXME: Add a separate metadata to indicate when the loop has already + // been vectorized instead of setting width and count to 1. + LLVM_DEBUG(dbgs() << "LVA: Not vectorizing: Disabled/already vectorized.\n"); + + // FIXME: Add interleave.disable metadata. 
This will allow + // vectorize.disable to be used without disabling the pass and errors + // to differentiate between disabled vectorization and a width of 1. + ORE.emit(OptimizationRemarkAnalysis(vectorizeAnalysisPassName(), + "AllDisabled", L->getStartLoc(), + L->getHeader()) + << "loop not vectorized: vectorization and interleaving are " + "explicitly disabled, or vectorize width and interleave " + "count are both set to 1"); + + return false; + } + + return true; +} + +/// Dumps all the hint information. +void SLVLoopVectorizeHints::emitRemarkWithHints() const { + using namespace ore; + if (Force.Value == SLVLoopVectorizeHints::FK_Disabled) + ORE.emit(OptimizationRemarkMissed(SLV_NAME, "MissedExplicitlyDisabled", + TheLoop->getStartLoc(), + TheLoop->getHeader()) + << "loop not vectorized: vectorization is explicitly disabled"); + else { + OptimizationRemarkMissed R(SLV_NAME, "MissedDetails", + TheLoop->getStartLoc(), TheLoop->getHeader()); + R << "loop not vectorized"; + if (Force.Value == SLVLoopVectorizeHints::FK_Enabled) { + R << " (Force=" << NV("Force", true); + if (Width.Value != 0) + R << ", Vector Width=" << NV("VectorWidth", Width.Value); + if (Interleave.Value != 0) + R << ", Interleave Count=" << NV("InterleaveCount", Interleave.Value); + R << ")"; + } + ORE.emit(R); + } +} + +const char *SLVLoopVectorizeHints::vectorizeAnalysisPassName() const { + // If hints are provided that don't disable vectorization use the + // AlwaysPrint pass name to force the frontend to print the diagnostic. + // TODO: Parameterize based on calling vectorizer. + if (getWidth() == 1) + return SLV_NAME; + if (getForce() == SLVLoopVectorizeHints::FK_Disabled) + return SLV_NAME; + if (getForce() == SLVLoopVectorizeHints::FK_Undefined && getWidth() == 0) + return SLV_NAME; + return OptimizationRemarkAnalysis::AlwaysPrint; +} + +void SLVLoopVectorizeHints::getHintsFromMetadata() { + MDNode *LoopID = TheLoop->getLoopID(); + if (!LoopID) + return; + + // First operand should refer to the loop id itself. + assert(LoopID->getNumOperands() > 0 && "requires at least one operand"); + assert(LoopID->getOperand(0) == LoopID && "invalid loop id"); + + // TODO: Range based possible? + for (unsigned i = 1, ie = LoopID->getNumOperands(); i < ie; ++i) { + const MDString *S = nullptr; + SmallVector Args; + + // The expected hint is either a MDString or a MDNode with the first + // operand a MDString. + if (const MDNode *MD = dyn_cast(LoopID->getOperand(i))) { + if (!MD || MD->getNumOperands() == 0) + continue; + S = dyn_cast(MD->getOperand(0)); + for (unsigned i = 1, ie = MD->getNumOperands(); i < ie; ++i) + Args.push_back(MD->getOperand(i)); + } else { + S = dyn_cast(LoopID->getOperand(i)); + assert(Args.size() == 0 && "too many arguments for MDString"); + } + + if (!S) + continue; + + // Check if the hint starts with the loop metadata prefix. 
+ StringRef Name = S->getString(); + if (Args.size() == 1) + setHint(Name, Args[0]); + } +} + +void SLVLoopVectorizeHints::setHint(StringRef Name, Metadata *Arg) { + if (!Name.startswith(Prefix())) + return; + Name = Name.substr(Prefix().size(), StringRef::npos); + + const ConstantInt *C = mdconst::dyn_extract(Arg); + if (!C) return; + unsigned Val = C->getZExtValue(); + + Hint *Hints[] = {&Width, &Interleave, &Force}; + for (auto H : Hints) { + if (Name == H->Name) { + if (H->validate(Val)) + H->Value = Val; + else + LLVM_DEBUG(dbgs() << "LVA: ignoring invalid hint '" << Name << "'\n"); + break; + } + } +} + +MDNode *SLVLoopVectorizeHints::createHintMetadata(StringRef Name, + unsigned V) const { + LLVMContext &Context = TheLoop->getHeader()->getContext(); + Metadata *MDs[] = {MDString::get(Context, Name), + ConstantAsMetadata::get( + ConstantInt::get(Type::getInt32Ty(Context), V))}; + return MDNode::get(Context, MDs); +} + +bool SLVLoopVectorizeHints::matchesHintMetadataName(MDNode *Node, + ArrayRef HintTypes) { + MDString* Name = dyn_cast(Node->getOperand(0)); + if (!Name) + return false; + + for (auto H : HintTypes) + if (Name->getString().endswith(H.Name)) + return true; + return false; +} + +void SLVLoopVectorizeHints::writeHintsToMetadata(ArrayRef HintTypes) { + if (HintTypes.size() == 0) + return; + + // Reserve the first element to LoopID (see below). + SmallVector MDs(1); + // If the loop already has metadata, then ignore the existing operands. + MDNode *LoopID = TheLoop->getLoopID(); + if (LoopID) { + for (unsigned i = 1, ie = LoopID->getNumOperands(); i < ie; ++i) { + MDNode *Node = cast(LoopID->getOperand(i)); + // If node in update list, ignore old value. + if (!matchesHintMetadataName(Node, HintTypes)) + MDs.push_back(Node); + } + } + + // Now, add the missing hints. + for (auto H : HintTypes) + MDs.push_back(createHintMetadata(Twine(Prefix(), H.Name).str(), H.Value)); + + // Replace current metadata node with new one. + LLVMContext &Context = TheLoop->getHeader()->getContext(); + MDNode *NewLoopID = MDNode::get(Context, MDs); + // Set operand 0 to refer to the loop id itself. + NewLoopID->replaceOperandWith(0, NewLoopID); + + TheLoop->setLoopID(NewLoopID); +} + +//////////////////////////////////////////////////////////////////////////////// +// LoopVectorizationRequirements +//////////////////////////////////////////////////////////////////////////////// + +bool LoopVectorizationRequirements::doesNotMeet(Function *F, Loop *L, + const SLVLoopVectorizeHints &Hints) { +//const char *Name = Hints.vectorizeAnalysisPassName(); + bool Failed = false; + if (UnsafeAlgebraInst && !Hints.allowReordering()) { +// emitOptimizationRemarkAnalysisFPCommute( +// F->getContext(), Name, *F, UnsafeAlgebraInst->getDebugLoc(), +// VectorizationReport() << "cannot prove it is safe to reorder " +// "floating-point operations"); + Failed = true; + } + + // Test if runtime memcheck thresholds are exceeded. 
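
// For reference, a sketch of where the metadata read by getHintsFromMetadata()
// and rewritten by writeHintsToMetadata() usually comes from. Clang lowers the
// loop pragmas below to "llvm.loop.vectorize.width",
// "llvm.loop.interleave.count" and "llvm.loop.vectorize.enable" strings on the
// loop ID node, matching the "vectorize.width", "interleave.count" and
// "vectorize.enable" hint names used above. The function itself is
// illustrative only.
void scaleAll(float *A, int N, float K) {
#pragma clang loop vectorize(enable) vectorize_width(4) interleave_count(2)
  for (int i = 0; i < N; ++i)
    A[i] *= K;
}
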
+ bool PragmaThresholdReached = + NumRuntimePointerChecks > PragmaVectorizeMemoryCheckThreshold; + bool ThresholdReached = + NumRuntimePointerChecks > VectorizerParams::RuntimeMemoryCheckThreshold; + if ((ThresholdReached && !Hints.allowReordering()) || + PragmaThresholdReached) { +// emitOptimizationRemarkAnalysisAliasing( +// F->getContext(), Name, *F, L->getStartLoc(), +// VectorizationReport() +// << "cannot prove it is safe to reorder memory operations"); + + LLVM_DEBUG(dbgs() << "LVA: Too many memory checks needed.\n"); + Failed = true; + } + + return Failed; +} + +//////////////////////////////////////////////////////////////////////////////// +// SLVLoopVectorizationLegality +//////////////////////////////////////////////////////////////////////////////// +bool SLVLoopVectorizationLegality::isLegalMaskedStore(Type *DataType, Value *Ptr) { + return isConsecutivePtr(Ptr) && TTI->isLegalMaskedStore(DataType); +} + +bool SLVLoopVectorizationLegality::isLegalMaskedLoad(Type *DataType, Value *Ptr) { + return isConsecutivePtr(Ptr) && TTI->isLegalMaskedLoad(DataType); +} + +int SLVLoopVectorizationLegality::isConsecutivePtr(Value *Ptr) { + assert(Ptr->getType()->isPointerTy() && "Unexpected non-ptr"); + // Make sure that the pointer does not point to structs. + if (Ptr->getType()->getPointerElementType()->isAggregateType()) + return 0; + + // If this value is a pointer induction variable we know it is consecutive. + PHINode *Phi = dyn_cast_or_null(Ptr); + if (Phi && Inductions.count(Phi)) { + InductionDescriptor II = Inductions[Phi]; + return II.getConsecutiveDirection(); + } + + // Look passed casts that have no affect on address generation. + auto *BC = dyn_cast_or_null(Ptr); + if (BC && BC->getSrcTy()->isPointerTy()) { + const DataLayout &DL = BC->getModule()->getDataLayout(); + Type *DstEltTy = cast(BC->getDestTy())->getElementType(); + Type *SrcEltTy = cast(BC->getSrcTy())->getElementType(); + + if (DL.getTypeAllocSize(DstEltTy) != DL.getTypeAllocSize(SrcEltTy)) + return 0; + + Ptr = BC->getOperand(0); + } + + GetElementPtrInst *Gep = getGEPInstruction(Ptr); + if (!Gep) + return 0; + + ScalarEvolution *SE = PSE.getSE(); + unsigned NumOperands = Gep->getNumOperands(); + Value *GpPtr = Gep->getPointerOperand(); + // If this GEP value is a consecutive pointer induction variable and all of + // the indices are constant then we know it is consecutive. We can + Phi = dyn_cast(GpPtr); + if (Phi && Inductions.count(Phi)) { + + // Make sure that the pointer does not point to structs. + PointerType *GepPtrType = cast(GpPtr->getType()); + if (GepPtrType->getElementType()->isAggregateType()) + return 0; + + // Make sure that all of the index operands are loop invariant. + for (unsigned i = 1; i < NumOperands; ++i) + if (!SE->isLoopInvariant(SE->getSCEV(Gep->getOperand(i)), TheLoop)) + return 0; + + InductionDescriptor II = Inductions[Phi]; + return II.getConsecutiveDirection(); + } + + unsigned InductionOperand = getGEPInductionOperand(Gep); + + // Check that all of the gep indices are uniform except for our induction + // operand. + for (unsigned i = 0; i != NumOperands; ++i) + if (i != InductionOperand && + !SE->isLoopInvariant(SE->getSCEV(Gep->getOperand(i)), TheLoop)) + return 0; + + // We can emit wide load/stores only if the last non-zero index is the + // induction variable. + const SCEV *Last = nullptr; + if (!Strides.count(Gep)) + Last = SE->getSCEV(Gep->getOperand(InductionOperand)); + else { + // Because of the multiplication by a stride we can have a s/zext cast. 
+ // We are going to replace this stride by 1 so the cast is safe to ignore. + // + // %indvars.iv = phi i64 [ 0, %entry ], [ %indvars.iv.next, %for.body ] + // %0 = trunc i64 %indvars.iv to i32 + // %mul = mul i32 %0, %Stride1 + // %idxprom = zext i32 %mul to i64 << Safe cast. + // %arrayidx = getelementptr inbounds i32* %B, i64 %idxprom + // + Last = replaceSymbolicStrideSCEV(PSE, Strides, + Gep->getOperand(InductionOperand), Gep); + if (const SCEVCastExpr *C = dyn_cast(Last)) + Last = + (C->getSCEVType() == scSignExtend || C->getSCEVType() == scZeroExtend) + ? C->getOperand() + : Last; + } + if (const SCEVAddRecExpr *AR = dyn_cast(Last)) { + const SCEV *Step = AR->getStepRecurrence(*SE); + + // The memory is consecutive because the last index is consecutive + // and all other indices are loop invariant. + if (Step->isOne()) + return 1; + if (Step->isAllOnesValue()) + return -1; + + // Try and find a different constant stride + if (EnableNonConsecutiveStrideIndVars) { + if (const SCEVConstant *SCC = dyn_cast(Step)) { + const ConstantInt *CI = SCC->getValue(); + // TODO: Error checking vs. INT_MAX? + return (int)CI->getLimitedValue(INT_MAX); + } + } + } + + return 0; +} + +bool SLVLoopVectorizationLegality::isUniform(Value *V) { + return LAI->isUniform(V); +} + +bool SLVLoopVectorizationLegality::canVectorizeWithIfConvert() { + bool CanIfConvert = true; + + if (!EnableIfConversion) { +// ORE->emit(createMissedAnalysis("IfConversionDisabled") +// << "if-conversion is disabled"); + CanIfConvert = false; + NODEBUG_EARLY_BAILOUT(); + } + + assert(TheLoop->getNumBlocks() > 1 && "Single block loops are vectorizable"); + + // A list of pointers that we can safely read and write to. + SmallPtrSet SafePointes; + + // Collect safe addresses. + for (Loop::block_iterator BI = TheLoop->block_begin(), + BE = TheLoop->block_end(); BI != BE; ++BI) { + BasicBlock *BB = *BI; + + if (blockNeedsPredication(BB)) + continue; + + for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E; ++I) { + if (LoadInst *LI = dyn_cast(I)) + SafePointes.insert(LI->getPointerOperand()); + else if (StoreInst *SI = dyn_cast(I)) + SafePointes.insert(SI->getPointerOperand()); + } + } + + // Collect the blocks that need predication. + BasicBlock *Header = TheLoop->getHeader(); + for (Loop::block_iterator BI = TheLoop->block_begin(), + BE = TheLoop->block_end(); BI != BE; ++BI) { + BasicBlock *BB = *BI; + + // We don't support switch statements inside loops. + if (!isa(BB->getTerminator())) { +// emitAnalysis(VectorizationReport(BB->getTerminator()) +// << "loop contains a switch statement"); + LLVM_DEBUG(dbgs() << + "LVA: Not vectorizing - loop contains a switch statement.\n"); + CanIfConvert = false; + NODEBUG_EARLY_BAILOUT(); + } + + // We must be able to predicate all blocks that need to be predicated. + if (blockNeedsPredication(BB)) { + if (!blockCanBePredicated(BB, SafePointes)) { +// emitAnalysis(VectorizationReport(BB->getTerminator()) +// << "control flow cannot be substituted for a select"); + LLVM_DEBUG(dbgs() << + "LVA: Not vectorizing - cannot predicate all blocks for if-conversion.\n"); + CanIfConvert = false; + NODEBUG_EARLY_BAILOUT(); + } + } else if (BB != Header && !canIfConvertPHINodes(BB)) { +// emitAnalysis(VectorizationReport(BB->getTerminator()) +// << "control flow cannot be substituted for a select"); + LLVM_DEBUG(dbgs() << + "LVA: Not vectorizing - phi nodes cannot be if converted.\n"); + CanIfConvert = false; + NODEBUG_EARLY_BAILOUT(); + } + } + + // We can if-convert this loop. 
+ return CanIfConvert; +} + +// Find break conditions that can be safely moved to the top of the loop body. +// Conditions: +// - Cannot be a PHI from something else than the loop header +// (meaning: it must be a SCEVable induction variable) +// - Cannot be an intrinsic with possible side effects +// - Cannot be a load if there may be a write in between. +bool SLVLoopVectorizationLegality::findConditionSubExprsRecurse( + Value *V, ConditionExprs &SubExprs) { + // Prevent unnecessary work + if (std::find(SubExprs.begin(),SubExprs.end(),V) != SubExprs.end()) + return true; + + Instruction *I = dyn_cast(V); + if ((I && !TheLoop->contains(I)) || isa(V) || + isa(V)) { + SubExprs.push_back(V); // Must be loop invariant + return true; + } + + // Otherwise it must be an instruction + if (!I) + return false; + + // Only allow call instruction without side-effects. + if (auto *CI = dyn_cast(I)) + if (!CI->doesNotAccessMemory()) + return false; + + if (auto *LI = dyn_cast(I)) { + // TODO: Currently we only allow consecutive loads with positive + // stride. + // 1. It needs to be consecutive because we do not yet + // have first faulting gathers implemented. + // 2. It needs to be a positive stride because first faulting + // consecutive loads do not work with reversed loads: + // For example: + // int foo(void) { + // for(int hp=14; incs[hp] > 0; hp--); + // return hp; + // } + // + // The problem with this is that we use first faulting loads to + // load 'incs[hp]'. This works absolutely fine when hp is incrementing, + // since the next value to load is one value next to the last loaded + // (and verified) element in the vector, e.g. + // V0 V1 + // ________________ ________________ + // [ e0, e1, e2, e3 ] [ e4, e5, e6, e7 ] + // ^^ + // last element loaded in V0, + // + // V0 is used to compare against 'incs[hp]' to decide whether e4..e7 + // will be loaded for a second iteration. The distance between 'e3' + // and 'e4' is 1. + // + // In the decrementing case, we need to load e7..e4, which is done by + // loading e4..e7 and reversing the result. The next vector we load + // will be e0..e3. The distance between e0 and e4 is 4. + // If 'e0' would signal a fault (first faulting behaviour) + // then we cannot process e1..e3. + // + // Possible solution: Use first faulting gathers and do not reverse the + // order of the load addresses like we do above. 
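
// A standalone contrast for the constraint explained above (names are
// illustrative). The incrementing search only ever reads forward of the last
// verified element, so a first-faulting consecutive load is safe; the
// decrementing variant from the comment would need the reversed loads the
// comment describes, which is why a non-positive isConsecutivePtr() result is
// rejected just below.
int findFirstNonPositive(const int *Incs, int N) {
  int hp = 0;
  while (hp < N && Incs[hp] > 0) // consecutive load, stride +1
    ++hp;
  return hp;
}
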
+ if (isConsecutivePtr(LI->getPointerOperand()) <= 0) + return false; + + if (!LI->isSimple()) + return false; + } + + ScalarEvolution *SE = PSE.getSE(); + // PHI nodes are allowed only if they are loop SCEVs + if (isa(I)) { + // No support for complex data flow in loop + if (I->getParent() != TheLoop->getHeader()) + return false; + + // We must be able to tell something about this PHI node + if (!SE->isSCEVable(I->getType())) + return false; + + // Should be either LIV or computable + const SCEV *SC = SE->getSCEV(I); + if (!SE->isLoopInvariant(SC, TheLoop) && + SC == SE->getCouldNotCompute()) + return false; + + // Safe + SubExprs.push_back(I); + return true; + } + + // Test all operands + bool OperandsSafe = true; + for (Value *Op: I->operands()) + OperandsSafe &= findConditionSubExprsRecurse(Op, SubExprs); + + if (!OperandsSafe) + return false; + + SubExprs.push_back(V); + return true; +} + +bool SLVLoopVectorizationLegality::findConditionSubExprs(Value *V, + ConditionExprs &SubExprs) { + // Prevent unnecessary work + if (std::find(SubExprs.begin(),SubExprs.end(),V) != SubExprs.end()) + return true; + + // Recurse to find subexpressions + if (!findConditionSubExprsRecurse(V, SubExprs)) + return false; + + // It is always safe IR if V is not an instruction + auto *VI = dyn_cast(V); + if (!VI) + return true; + + // If there are any Load Instructions in the SubExprs, make sure it does not + // read anything that may be written to in between. + SmallVector Loads; + for (Value *L : SubExprs) { + if (auto *LI = dyn_cast(L)) { + if (TheLoop->contains(LI)) + Loads.push_back(LI); + } + } + + // If there are no loads, all is safe + if (!Loads.size()) + return true; + + + SmallVector Stores; + SmallVector Descs; + auto *LatchBr = + dyn_cast(TheLoop->getLoopLatch()->getTerminator()); + + // Find all stores from the loop header upto V, except for the condition + // in the LatchBr, because that condition is not moved 'up'. + if (!LatchBr->isConditional() || V != LatchBr->getCondition()) { + PDT->getDescendants(VI->getParent(), Descs); + for (auto *BB : Descs) { + // FIXME: Check if this is not too restrictive. + if (!TheLoop->contains(BB)) + break; + BasicBlock::reverse_iterator I, E; + for (I = BB->rbegin(), E = BB->rend(); I != E; ++I) { + if (StoreInst *SI = dyn_cast(&*I)) + Stores.push_back(SI); + } + } + } + + // If any of the loads may alias with one of the stores, be conservative. + for (auto *LI : Loads) { + for (auto *SI : Stores) { + // If it is an exact alias, we can take the value of that store. + AliasResult AR = AA->alias(LI->getPointerOperand(), + SI->getPointerOperand()); + if (AR == MustAlias) { + // Recursion, updates SubExprs + if (!findConditionSubExprs(SI->getValueOperand(), SubExprs)) + return false; + } else if (AR != NoAlias) + return false; + } + } + + // All is safe + return true; +} + +bool SLVLoopVectorizationLegality::canVectorizeExits() { + bool CanVectorize = true; + + // TODO: Exit info struct separate from branch -> condition mapping? + // TODO: Interaction with if-conversion? + // TODO: When vectorizing instructions, check branch for if-convert + // vs. exit mapping; change predicate and adjust mapping + // in the latter case. 
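
// A small example of the load/store interaction findConditionSubExprsRecurse()
// guards against when a break condition is moved to the top of the loop body
// (illustrative names only). The condition reads Flags[i] after it has been
// stored to in the same iteration; hoisting that load above the store would
// read a stale value, so the analysis either substitutes the stored value (the
// MustAlias case handled above) or refuses to move the condition.
void scrubUntilZero(int *Flags, int N) {
  for (int i = 0; i < N; ++i) {
    Flags[i] &= 0x7;   // store to the same address the condition loads from
    if (Flags[i] == 0) // break condition evaluated after the store
      break;
  }
}
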
+ + SmallVector ExitingBlocks; + TheLoop->getExitingBlocks(ExitingBlocks); + bool FoundCountedExit = false; + + auto DFSCompare = [this](BasicBlock *A, BasicBlock*B) { + return DFS->getPostorder(A) >= DFS->getPostorder(B); + }; + std::sort(ExitingBlocks.begin(), ExitingBlocks.end(), DFSCompare); + + // If we're only allowing counted loops, then we currently have a stricter + // set of constraints before we can vectorize. + // Counted loops currently only allow for a single exit block. + if (!AllowUncounted && ExitingBlocks.size() != 1) { + CanVectorize = false; +// emitAnalysis(VectorizationReport() << +// "loop control flow is not understood by vectorizer"); + LLVM_DEBUG(dbgs() << "LVA: Not vectorizing - multiple exit blocks.\n"); + NODEBUG_EARLY_BAILOUT(); + } + + // Check each exiting block + for (auto *EB: ExitingBlocks) { + BranchInst *Br = dyn_cast(EB->getTerminator()); + if (!Br || !Br->isConditional()) { + // No idea what to do with something that isn't a condbr yet + CanVectorize = false; + LLVM_DEBUG(dbgs() << + "LVA: Not vectorizing - exit is not a conditional branch\n"); + NODEBUG_EARLY_BAILOUT(); + continue; + } + + // TODO: Use getExitEdges here? + BasicBlock *ExitBlock = nullptr; + for (unsigned i = 0; i < Br->getNumSuccessors(); ++i) + if (!TheLoop->contains(Br->getSuccessor(i))) + ExitBlock = Br->getSuccessor(i); + + assert(ExitBlock && + "Unable to find exit block from BranchInst of exiting block\n"); + + bool AllLIV = true; + PHINode *LastPhi = nullptr; + for (BasicBlock::iterator I = ExitBlock->begin(); I != ExitBlock->end() && + isa(I); ++I) { + LastPhi = cast(I); + Value *EV = LastPhi->getIncomingValueForBlock(EB); + auto *EI = dyn_cast(EV); + if (EI && TheLoop->contains(EI)) { + AllLIV = false; + break; + } + } + + if (AllLIV && LastPhi) { + // This must be a mergenode. + // TODO: Check the following assumption is true, could have picked + // the wrong PHI node. + (void) EF->CreateEscapee(LastPhi); + } + + // TODO: Figure out a better way of doing this. See ggHMatrix3::type in + // eon for a case where the immediate exit block doesn't have a phi, + // but its (sole) successor has a phi on the exit block and uses a + // constant as the value. + // Probably also lies with getEscapeeValuesFromMergeNode, though + // it may be difficult to tell what the eventual merge node is, + // since many other loops contribute to it. + if (!LastPhi) { + CanVectorize = false; + LLVM_DEBUG(dbgs() << "LVA: Not vectorizing - " + "unable to resolve exit phi\n"); + NODEBUG_EARLY_BAILOUT(); + } + + ConditionExprs SubExprs; + if (!findConditionSubExprs(Br->getCondition(), SubExprs)) { + CanVectorize = false; + LLVM_DEBUG(dbgs() << + "LVA: Not vectorizing - unable to safely move condition to top of " + "loop body.\n"); + NODEBUG_EARLY_BAILOUT(); + } + + // Check for a counted exit. + if (PSE.getSE()->getExitCount(TheLoop, EB) != PSE.getSE()->getCouldNotCompute()) { + // Only permit one at this time. + if (!FoundCountedExit && EB == TheLoop->getLoopLatch()) { + FoundCountedExit = true; + Exits.emplace_back(EK_Counted, EB, ExitBlock, SubExprs); + LLVM_DEBUG(dbgs() << "LVA: Adding counted exit: " << + *(EB->getTerminator()) << "\n"); + continue; + } + } + + // Bail out unless uncounted loops are explicitly allowed. 
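
// A sketch of the two exit kinds classified here (illustrative function). The
// latch comparison `i < N` has a trip count ScalarEvolution can compute, so it
// can be recorded as an EK_Counted exit; the data-dependent early return can
// only be handled as an EK_LoadCompare exit when vectorization of uncounted
// loops is allowed (see -sl-enable-lv-uncounted-loops above).
int findKey(const int *Keys, int N, int Wanted) {
  for (int i = 0; i < N; ++i) { // counted exit at the latch
    if (Keys[i] == Wanted)      // uncounted exit: depends on loaded data
      return i;
  }
  return -1;
}
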
+ if (!AllowUncounted) { + CanVectorize = false; +// emitAnalysis(VectorizationReport() << +// "could not determine number of loop iterations"); + LLVM_DEBUG(dbgs() << + "LVA: Not vectorizing - SCEV could not compute the loop " + "exit count.\n"); + NODEBUG_EARLY_BAILOUT(); + continue; + } + + if (Br->getParent() != TheLoop->getLoopLatch() && + !DT->dominates(Br, TheLoop->getLoopLatch())) { + CanVectorize = false; + LLVM_DEBUG(dbgs() << "LVA: Not vectorizing - Exiting block does not " + "dominate latch\n"); + NODEBUG_EARLY_BAILOUT(); + continue; + } + + if (CanVectorize) { + Exits.emplace_back(EK_LoadCompare, EB, ExitBlock, SubExprs); + LLVM_DEBUG(dbgs() << "LVA: Adding uncounted exit: " << + *(EB->getTerminator()) << "\n"); + IsUncounted = true; + } + } + + return CanVectorize; +} + +static TargetTransformInfo::ReductionFlags +getReductionFlagsFromDesc(RecurrenceDescriptor Rdx) { + using RD = RecurrenceDescriptor; + RD::RecurrenceKind RecKind = Rdx.getRecurrenceKind(); + TargetTransformInfo::ReductionFlags Flags; + Flags.IsOrdered = Rdx.isOrdered(); + if (RecKind == RD::RK_IntegerMinMax || RecKind == RD::RK_FloatMinMax) { + auto MMKind = Rdx.getMinMaxRecurrenceKind(); + Flags.IsSigned = MMKind == RD::MRK_SIntMax || MMKind == RD::MRK_SIntMin; + Flags.IsMaxOp = MMKind == RD::MRK_SIntMax || MMKind == RD::MRK_FloatMax; + } else if (RecKind == RD::RK_ConstSelectICmp || + RecKind == RD::RK_ConstSelectFCmp) { + auto MMKind = Rdx.getMinMaxRecurrenceKind(); + Flags.IsSigned = true; + Flags.IsMaxOp = MMKind == RD::MRK_SIntMax; + } + return Flags; +} + +bool SLVLoopVectorizationLegality::canVectorizeInstrs() { + BasicBlock *Header = TheLoop->getHeader(); + + bool CanVectorize = true; + ScalarEvolution *SE = PSE.getSE(); + + // Look for the attribute signaling the absence of NaNs. + Function &F = *Header->getParent(); + const DataLayout &DL = F.getParent()->getDataLayout(); + if (F.hasFnAttribute("no-nans-fp-math")) + HasFunNoNaNAttr = + F.getFnAttribute("no-nans-fp-math").getValueAsString() == "true"; + + // For each block in the loop. + for (Loop::block_iterator bb = TheLoop->block_begin(), + be = TheLoop->block_end(); bb != be; ++bb) { + // Scan the instructions in the block and look for hazards. + for (BasicBlock::iterator it = (*bb)->begin(), e = (*bb)->end(); it != e; + ++it) { + + if (PHINode *Phi = dyn_cast(it)) { + Type *PhiTy = Phi->getType(); + // Check that this PHI type is allowed. + if (!PhiTy->isIntegerTy() && + !PhiTy->isFloatingPointTy() && + !PhiTy->isPointerTy()) { +// emitAnalysis(VectorizationReport(&*it) +// << "loop control flow is not understood by vectorizer"); + LLVM_DEBUG(dbgs() << + "LVA: Not vectorizing - Found an non-int non-pointer PHI.\n"); + CanVectorize = false; + NODEBUG_EARLY_BAILOUT(); + } + + // If this PHINode is not in the header block, then we know that we + // can convert it to select during if-conversion. No need to check if + // the PHIs in this block are induction or reduction variables. + if (*bb != Header) { + // Check that this instruction has no outside users or is an + // identified reduction value with an outside user. + // TODO: For now, we ignore this case with uncounted loops and just + // focus on phis created in the header block. 
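
// The two kinds of loop-header phi that canVectorizeInstrs() classifies via
// InductionDescriptor::isInductionPHI and RecurrenceDescriptor::isReductionPHI,
// shown at the source level (illustrative names). `i` is a canonical integer
// induction (starts at zero, steps by one); `Sum` is a reduction phi whose only
// out-of-loop use is its final value.
int sumAll(const int *A, int N) {
  int Sum = 0;                // reduction phi in the loop header
  for (int i = 0; i < N; ++i) // canonical integer induction phi
    Sum += A[i];
  return Sum;                 // single use of the reduction outside the loop
}
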
+ if (!hasOutsideLoopUser(TheLoop, &*it, AllowedExit)) + continue; +// emitAnalysis(VectorizationReport(&*it) << +// "value could not be identified as " +// "an induction or reduction variable"); + LLVM_DEBUG(dbgs() << + "LVA: Not vectorizing - if-convertible induction phi used outside loop.\n"); + CanVectorize = false; + NODEBUG_EARLY_BAILOUT(); + continue; + } + + // We only allow if-converted PHIs with exactly two incoming values. + if (Phi->getNumIncomingValues() != 2) { +// emitAnalysis(VectorizationReport(&*it) +// << "control flow not understood by vectorizer"); + LLVM_DEBUG(dbgs() << + "LVA: Not vectorizing - Phi with more than two incoming values.\n"); + CanVectorize = false; + NODEBUG_EARLY_BAILOUT(); + continue; + } + + InductionDescriptor ID; + if (InductionDescriptor::isInductionPHI(Phi, TheLoop, PSE, ID)) { + LLVM_DEBUG(dbgs() << "isInductionPHI(" << *Phi << ")\n"); + Inductions[Phi] = ID; + // Get the widest type. + if (!WidestIndTy) + WidestIndTy = convertPointerToIntegerType(DL, PhiTy); + else + WidestIndTy = getWiderType(DL, PhiTy, WidestIndTy); + + // Int inductions are special because we only allow one IV. + if (ID.getKind() == InductionDescriptor::IK_IntInduction && + ID.getConstIntStepValue() && + ID.getConstIntStepValue()->isOne() && + isa(ID.getStartValue()) && + cast(ID.getStartValue())->isNullValue()) { + // Use the phi node with the widest type as induction. Use the last + // one if there are multiple (no good reason for doing this other + // than it is expedient). We've checked that it begins at zero and + // steps by one, so this is a canonical induction variable. + if (!Induction || PhiTy == WidestIndTy) + Induction = Phi; + } + + LLVM_DEBUG(dbgs() << "LVA: Found an induction variable " << *Phi << "\n"); + + // Until we explicitly handle the case of an induction variable with + // an outside loop user we have to give up vectorizing this loop. + if (hasOutsideLoopUser(TheLoop, &*it, AllowedExit)) { + Escapee *Esc; + if (!EF->canVectorizeEscapeeValue(Phi, Esc)) { +// emitAnalysis(VectorizationReport(&*it) << +// "use of induction value outside of the " +// "loop is not handled by vectorizer"); + LLVM_DEBUG(dbgs() << + "LVA: Not vectorizing - induction phi used outside loop.\n"); + CanVectorize = false; + NODEBUG_EARLY_BAILOUT(); + } + } + + continue; + } + + if (RecurrenceDescriptor::isReductionPHI(Phi, TheLoop, SE, + Reductions[Phi], + AllowUncounted)) { + LLVM_DEBUG(dbgs() << "isReductionPHI(" << *Phi << ")\n"); + if (Reductions[Phi].hasUnsafeAlgebra()) + Requirements->addUnsafeAlgebraInst( + Reductions[Phi].getUnsafeAlgebraInst()); + + // TODO: There has to be a nicer way to do this? + RecurrenceDescriptor::ExitInstrList::iterator I, E; + RecurrenceDescriptor::ExitInstrList *EIList = + Reductions[Phi].getLoopExitInstrs(); + + // All reduction variables are also escapees! (Or at least, they + // should be, so add a check that the number of ExitInstructions + // equals the number of Exits). + // Reuse the same logic by adding the merge nodes of a reduction + // to the Escapee stuff. + + // Get exiting blocks + SmallVector Exits; + TheLoop->getExitingBlocks(Exits); + if (IsUncounted && + Reductions[Phi].getLoopExitInstrs()->size() != Exits.size()) { +// emitAnalysis(VectorizationReport(&*it) << +// "not all exits represented by reduction exit " +// "instructions, so it does not match the criteria " +// "for an escapee. 
Cannot vectorize."); + LLVM_DEBUG(dbgs() << + "LVA: Not vectorizing - reduction phi does not classify as " + "an escapee\n"); + CanVectorize = false; + NODEBUG_EARLY_BAILOUT(); + + } else { + // If this is a valid reductio, we need to create an Escapee + // object for it as well. + if (!EF->CreateEscapee(Phi, Reductions[Phi])) { + CanVectorize = false; + NODEBUG_EARLY_BAILOUT(); + continue; + } + + for (I = EIList->begin(), E = EIList->end(); I != E; ++I) + AllowedExit.insert(*I); + + LLVM_DEBUG(dbgs() << "LVA: Found a reduction variable " << *Phi << "\n"); + + if (!ScalarizedReduction) { + auto RD = Reductions[Phi]; + auto Kind = RD.getRecurrenceKind(); + auto Flags = getReductionFlagsFromDesc(RD); + Flags.NoNaN = hasNoNaNAttr(); + ScalarizedReduction = !TTI->canReduceInVector( + RecurrenceDescriptor::getRecurrenceBinOp(Kind), + RD.getRecurrenceType(), Flags); + } + continue; + } + } + +// emitAnalysis(VectorizationReport(&*it) << +// "value that could not be identified as " +// "reduction is used outside the loop"); + LLVM_DEBUG(dbgs() << + "LVA: Not vectorizing - unidentified phi " << *Phi << + " found (not a known reduction).\n"); + CanVectorize = false; + NODEBUG_EARLY_BAILOUT(); + continue; + }// end of PHI handling + + // We handle calls that: + // * Are debug info intrinsics. + // * Have a mapping to an IR intrinsic. + // * Have a vector version available. + CallInst *CI = dyn_cast(it); + if (CI && !getVectorIntrinsicIDForCall(CI, TLI) && !isa(CI) && + !(CI->getCalledFunction() && TLI && + TLI->isFunctionVectorizable(CI->getCalledFunction()->getName()))) { +// emitAnalysis(VectorizationReport(&*it) << +// "call instruction " << *CI << "cannot be vectorized"); + LLVM_DEBUG(dbgs() << + "LVA: Not vectorizing - found a non-intrinsic, non-libfunc callsite " << + *CI << "\n"); + CanVectorize = false; + NODEBUG_EARLY_BAILOUT(); + continue; + } + + // Intrinsics such as powi,cttz and ctlz are legal to vectorize if the + // second argument is the same (i.e. loop invariant) + if (CI && + hasVectorInstrinsicScalarOpd(getVectorIntrinsicIDForCall(CI, TLI), 1)) { + if (!SE->isLoopInvariant(PSE.getSCEV(CI->getOperand(1)), TheLoop)) { +// emitAnalysis(VectorizationReport(&*it) +// << "intrinsic instruction cannot be vectorized"); + LLVM_DEBUG(dbgs() << + "LVA: Not vectorizing - found unvectorizable intrinsic " << *CI << "\n"); + CanVectorize = false; + NODEBUG_EARLY_BAILOUT(); + continue; + } + } + + // Check that the instruction return type is vectorizable. + // Also, we can't vectorize extractelement instructions. + if ((!VectorType::isValidElementType(it->getType()) && + !it->getType()->isVoidTy()) || + it->getType()->isFP128Ty() || + isa(it)) { +// emitAnalysis(VectorizationReport(&*it) +// << "instruction return type cannot be vectorized"); + LLVM_DEBUG(dbgs() << "LVA: Found unvectorizable type.\n"); + LLVM_DEBUG(dbgs() << + "LVA: Not vectorizing - found unvectorizable type " << + *(it->getType()) << "\n"); + CanVectorize = false; + NODEBUG_EARLY_BAILOUT(); + continue; + } + + // Check that the stored type is vectorizable. 
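
// The call-handling rule above, seen from the source level (illustrative
// functions; logValue is a made-up external routine with no vector variant).
// A call such as sqrtf typically maps to an intrinsic or a TLI-known
// vectorizable library function, so the first loop remains a candidate; an
// opaque external call is rejected as a non-intrinsic, non-libfunc callsite.
#include <math.h>
extern void logValue(int X); // no IR intrinsic and no vector library mapping
void takeRoots(float *A, int N) {
  for (int i = 0; i < N; ++i)
    A[i] = sqrtf(A[i]); // intrinsic/vector-library mappable call
}
void logAll(const int *A, int N) {
  for (int i = 0; i < N; ++i)
    logValue(A[i]);     // unknown callee: blocks vectorization
}
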
+ if (StoreInst *ST = dyn_cast(it)) { + Type *T = ST->getValueOperand()->getType(); + if (!VectorType::isValidElementType(T) || it->getType()->isFP128Ty()) { +// emitAnalysis(VectorizationReport(ST) << +// "store instruction cannot be vectorized"); + LLVM_DEBUG(dbgs() << + "LVA: Not vectorizing - cannot vectorize store instruction " << + *ST << "\n"); + CanVectorize = false; + NODEBUG_EARLY_BAILOUT(); + continue; + } + } + + // We do allow some escapees, especially for reductions + if (hasOutsideLoopUser(TheLoop, &*it, AllowedExit)) { + Escapee *Esc; + if (!EF->canVectorizeEscapeeValue(&*it, Esc)) { +// emitAnalysis(VectorizationReport(&*it) << +// "use of induction value outside of the " +// "loop is not handled by vectorizer"); + LLVM_DEBUG(dbgs() << + "LVA: Not vectorizing - induction phi used outside loop.\n"); + CanVectorize = false; + NODEBUG_EARLY_BAILOUT(); + } + } + } // next instr. + + } + + if (!Induction) { + LLVM_DEBUG(dbgs() << "LVA: Did not find one integer induction var.\n"); + if (Inductions.empty()) { +// emitAnalysis(VectorizationReport() +// << "loop induction variable could not be identified"); + LLVM_DEBUG(dbgs() << + "LVA: Not vectorizing - unable to identify loop induction variable.\n"); + CanVectorize = false; + } + } + + // Now we know the widest induction type, check if our found induction + // is the same size. If it's not, unset it here and InnerLoopVectorizer + // will create another. + if (Induction && WidestIndTy != Induction->getType()) + Induction = nullptr; + + return CanVectorize; +} + +void EscapeeFactory::getEscapeeValuesFromMergeNode(PHINode *Merge, + SmallVectorImpl &Values) { + unsigned NumLoopExitValues = 0; + for (unsigned I=0; I < Merge->getNumIncomingValues(); ++I) { + BasicBlock *BB = Merge->getIncomingBlock(I); + if (DFS->getLoop()->contains(BB)) + NumLoopExitValues++; + } + + bool LookThroughPHIs = NumLoopExitValues > 1; + + // For each value incoming to the PHI + for (unsigned I=0; I < Merge->getNumIncomingValues(); ++I) { + Value *Value = Merge->getIncomingValue(I); + BasicBlock *BB = Merge->getIncomingBlock(I); + + // If it is an exit value from the loop, add it to list + if (DFS->getLoop()->contains(BB)) { + Values.push_back(std::tie(BB, Value)); + continue; + } + + // Otherwise, check if it comes directly from an otherwise empty PHI + auto *ValuePhi = dyn_cast(Value); + + // Only a direct PHI node is allowed + if (!ValuePhi || !LookThroughPHIs) + continue; + + // This must have only one incoming edge + // (could this be an assert?) + if (ValuePhi->getNumIncomingValues() != 1) + continue; + + auto *IncomingVal = ValuePhi->getIncomingValue(0); + BasicBlock *IncomingBB = ValuePhi->getIncomingBlock(0); + + if (!isa(IncomingVal)) { + Values.push_back(std::tie(IncomingBB, IncomingVal)); + continue; + } + + // If incoming value is not an exit value, skip + if (!DFS->getLoop()->contains(IncomingBB)) + continue; + + // We only allow PHI nodes here, otherwise code from + // this block needs to be executed in the vector tail. + if (!isa(IncomingBB->getFirstNonPHI())) + continue; + + Values.push_back(std::tie(IncomingBB, IncomingVal)); + } +} + +Escapee* EscapeeFactory::CreateEscapee(PHINode *Recurrence, + RecurrenceDescriptor &RD) { + Escapee *Res = nullptr; + Instruction *Last = RD.getLoopExitInstrs()->back(); + + // If the reduction has a PHI merge node, we use that logic. 
+ if (canVectorizeEscapeeValue(Last, Res)) { + Res->setStore(RD.IntermediateStore); + Res->IsReduction = true; + return Res; + } + + // TODO: We should be able to function without an intermediate + // store. mesa from spec2K will trigger this, as will about + // 20 benchmarks. Place an assert here instead to find them easily. + // Otherwise it must have an intermediate store + if (!RD.IntermediateStore) + return nullptr; + + SmallVector Values; + Values.push_back(Recurrence); // add original PHI value + Values.append(RD.getLoopExitInstrs()->begin(), RD.getLoopExitInstrs()->end()); + + // Sort exiting blocks in DFS pre-order + auto DFSCompare = [this](Instruction *A, Instruction*B) { + return DFS->getPostorder(A->getParent()) >= + DFS->getPostorder(B->getParent()); + }; + std::sort(Values.begin() + 1, Values.end(), DFSCompare); + + // Create a new Escapee and store it in cache + Res = new Escapee(RD.IntermediateStore, Values.begin(), Values.end(), true); + Escapees.insert(std::make_pair(RD.IntermediateStore, Res)); + + return Res; +} + +Escapee* EscapeeFactory::CreateEscapee(PHINode *Merge) { + // First collect the loop exit blocks to the Escapee Merge Block + // for this Escapee Merge Node. + SmallVector BBs; + getEscapeeValuesFromMergeNode(Merge, BBs); + + // Sort exiting blocks in DFS pre-order + auto DFSCompare = [this](MergeValTuple A, MergeValTuple B) { + return DFS->getPostorder(std::get<0>(A)) >= DFS->getPostorder(std::get<0>(B)); + }; + std::sort(BBs.begin(), BBs.end(), DFSCompare); + + // For each exit, add value to the Escapee Merge Node + SmallVector Values; + for (auto &BB : BBs) { + Values.push_back(std::get<1>(BB)); + } + + // Create a new Escapee and store it in cache + Escapee *Res = new Escapee(Merge, Values.begin(), Values.end()); + Escapees.insert(std::make_pair(Merge, Res)); + + return Res; +} + +// Determines whether a given Instruction is an Escapee Value. All +// Escapee Values can be vectorized. +// +// Definitions: +// ------------ +// An Escapee Merge Block is: +// - A basicblock that post-dominates a loop and all of its loop exits. +// - No basic block exists that post-dominates all loop exits +// and is dominated by the Escapee Merge Block. +// +// An Escapee Merge Node is: +// - A PHI node that selects Escapee Values from every exit into a single +// value. (note: it can also select values from other blocks) +// - The Escapee Merge Node resides in the Escapee Merge Block +// +// An Escapee Value is: +// - A value defined in the body of a loop +// - which is either a recognized induction variable, reduction variable +// or a non-PHI node, +// - which is used only once outside the loop in an Escapee Merge Node. +bool EscapeeFactory::canVectorizeEscapeeValue(Instruction *Val, Escapee *&Res) { + // Find (and only allow 1) single external user + Instruction *ExternalUse = nullptr; + for (User *U: Val->users()) { + auto *UI = dyn_cast(U); + if (DFS->getLoop()->contains(UI)) + continue; + + if (ExternalUse && ExternalUse != UI) { + LLVM_DEBUG(dbgs() << "LVA: Found multiple external users for candidate " + "Escapee Variable " << *Val << "\n"); + return false; + } + + ExternalUse = UI; + } + + // Require 1 external user + if (!ExternalUse) { + LLVM_DEBUG(dbgs () << "LVA: Did not find any external users for candidate " + "Escapee Variable " << *Val << '\n'); + return false; + } + + // TODO: VTail needs to perform proper transform in order to + // support pointer indvar types... 
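
// A source-level picture of an "Escapee Value" in the sense defined above
// (illustrative names). `Scaled` is a non-PHI value defined in the loop body,
// the loop has two exits (the counted latch exit and the early break), and the
// value is used exactly once after the loop through the phi that merges those
// exits, i.e. the Escapee Merge Node that canVectorizeEscapeeValue() looks for.
int findScaled(const int *A, int N, int K) {
  int Scaled = 0;
  for (int i = 0; i < N; ++i) {
    Scaled = A[i] * K; // escapee value: defined in the loop body
    if (Scaled > 1000) // early exit
      break;
  }
  return Scaled;       // single use outside the loop, via the merge phi
}
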
+ if (ExternalUse->getType()->isPointerTy()) { + LLVM_DEBUG(dbgs() << + "LVA: External use is a pointer type, vtail needs fixing\n"); + return false; + } + + // Must be a PHI + auto *PN = dyn_cast(ExternalUse); + if (!PN) { + LLVM_DEBUG(dbgs () << "LVA: External user is not a PHI node for candidate " + "Escapee Value " << *Val << '\n'); + return false; + } + + // If this merge node was already analyzed as an escapee, + // we can safely return that. + if (Escapees.count(PN)) { + Res = Escapees[PN]; + return true; + } + + // Get exiting blocks + SmallVector Exits; + DFS->getLoop()->getExitingBlocks(Exits); + + // Look through an otherwise empty PHI node iff there are + // more than 1 loop exits. + if (Exits.size() > 1 && + PN->hasOneUse() && + PN->getNumOperands() == 1 && + isa(*(PN->user_begin()))) + PN = cast(*(PN->user_begin())); + + // Check all loop exit nodes are represented + SmallVector MergeValues; + getEscapeeValuesFromMergeNode(PN, MergeValues); + + if (MergeValues.size() != Exits.size()) { + LLVM_DEBUG(dbgs () << "LVA: Not all exits are represented in PHI node for " + "candidate Escapee Value " << *Val << '\n'); + return false; + } + + // Must post-dominate (may be implicit from the above) + BasicBlock *ExternalBB = ExternalUse->getParent(); + for (BasicBlock *Exit : Exits) { + if (!PDT->dominates(ExternalBB, Exit) && Exit != ExternalBB) { + LLVM_DEBUG(dbgs () << "LVA: PHI node does not post dominate all loop exits " + "for candidate Escapee Value " << *Val << '\n'); + return false; + } + } + + // Get nearest common block from post-dominator graph. + BasicBlock *Nearest = ExternalBB; + if (Exits.size() > 1) { + Nearest = Exits[0]; + for (BasicBlock *Exit : Exits) + Nearest = PDT->findNearestCommonDominator(Nearest, Exit); + } + + // If this is not the nearest... + if (Nearest != ExternalBB) { + LLVM_DEBUG(dbgs () << "LVA: PHI node not nearest common dominator for " + "candidate Escapee Value " << *Val << '\n'); + return false; + } + + // Create the Escapee + Res = CreateEscapee(PN); + return true; +} + +bool SLVLoopVectorizationLegality::canVectorizeMemory() { + LAI = &(*GetLAA)(*TheLoop); + InterleaveInfo.setLAI(LAI); + const OptimizationRemarkAnalysis *LAR = LAI->getReport(); + if (LAR) { + OptimizationRemarkAnalysis VR(Hints->vectorizeAnalysisPassName(), + "loop not vectorized: ", *LAR); + ORE->emit(VR); + } + + if (!LAI->canVectorizeMemory()) + return false; + + if (LAI->hasStoreToLoopInvariantAddress()) { +// emitAnalysis( +// VectorizationReport() +// << "write to a loop invariant address could not be vectorized"); + LLVM_DEBUG(dbgs() << "LVA: We don't allow storing to uniform addresses\n"); + return false; + } + + Requirements->addRuntimePointerChecks(LAI->getNumRuntimePointerChecks()); + PSE.addPredicate(LAI->getPSE().getUnionPredicate()); + + return true; +} + +bool SLVLoopVectorizationLegality::canVectorize() { + bool CanVectorize = true; + + // We must have a loop in canonical form. Loops with indirectbr in them cannot + // be canonicalized. + if (!TheLoop->getLoopPreheader()) { + CanVectorize = false; +// emitAnalysis( +// VectorizationReport() << +// "loop control flow is not understood by vectorizer"); + LLVM_DEBUG(dbgs() << + "LVA: Not vectorizing - unable to find preheader for original loop.\n"); + NODEBUG_EARLY_BAILOUT(); + } + + // We can only vectorize innermost loops. 
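
// What the runtime pointer checks counted by addRuntimePointerChecks() above
// guard at the source level (illustrative function): nothing proves that A and
// B do not overlap, so the vectorized loop must be guarded by a runtime
// no-overlap check, and doesNotMeet() refuses to vectorize when too many such
// checks would be needed.
void accumulate(float *A, const float *B, int N) {
  for (int i = 0; i < N; ++i)
    A[i] += B[i]; // may alias: needs a runtime overlap check to vectorize
}
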
+ if (!TheLoop->empty()) { + CanVectorize = false; +// emitAnalysis(VectorizationReport() << "loop is not the innermost loop"); + LLVM_DEBUG(dbgs() << "LVA: Not vectorizing - not the innermost loop.\n"); + NODEBUG_EARLY_BAILOUT(); + } + + // We must have a single backedge. + if (TheLoop->getNumBackEdges() != 1) { + CanVectorize = false; +// emitAnalysis( +// VectorizationReport() << +// "loop control flow is not understood by vectorizer"); + LLVM_DEBUG(dbgs() << "LVA: Not vectorizing - multiple backedges.\n"); + NODEBUG_EARLY_BAILOUT(); + } + + // We need to have a loop header. + LLVM_DEBUG(dbgs() << "LVA: Found a loop in " << + TheLoop->getHeader()->getParent()->getName() << ": " << + TheLoop->getHeader()->getName() << '\n'); + + // Check if we can if-convert non-single-bb loops. + unsigned NumBlocks = TheLoop->getNumBlocks(); + if (NumBlocks != 1 && !canVectorizeWithIfConvert()) { + CanVectorize = false; + LLVM_DEBUG(dbgs() << "LVA: Not vectorizing - can't if-convert the loop.\n"); + NODEBUG_EARLY_BAILOUT(); + } + + // Check if we can vectorize the instructions and CFG in this loop. + if (!canVectorizeInstrs()) { + CanVectorize = false; + LLVM_DEBUG(dbgs() << + "LVA: Not vectorizing - can't vectorize the instructions or CFG.\n"); + NODEBUG_EARLY_BAILOUT(); + } + + // Go over each instruction and look at memory deps. + if (!canVectorizeMemory()) { + CanVectorize = false; + LLVM_DEBUG(dbgs() << + "LVA: Can't vectorize due to memory conflicts.\n"); + NODEBUG_EARLY_BAILOUT(); + } + + // Check to see if all exits can be combined with predication + if (!canVectorizeExits()) { + CanVectorize = false; + LLVM_DEBUG(dbgs() << "LVA: Not vectorizing - unsuitable exits.\n"); + NODEBUG_EARLY_BAILOUT(); + } + + if (CanVectorize) { + // Collect all of the variables that remain uniform after vectorization. + collectLoopUniforms(); + LLVM_DEBUG(dbgs() << "LVA: We can vectorize this loop" + << (LAI->getRuntimePointerChecking()->Need + ? " (with a runtime bound check)" + : "") + << "!\n"); + } + + bool UseInterleaved = TTI->enableInterleavedAccessVectorization(); + + // If an override option has been passed in for interleaved accesses, use it. + if (EnableInterleavedMemAccesses.getNumOccurrences() > 0) + UseInterleaved = EnableInterleavedMemAccesses; + + // Analyze interleaved memory accesses. + if (UseInterleaved) + InterleaveInfo.analyzeInterleaving(Strides); + + unsigned SCEVThreshold = VectorizeSCEVCheckThreshold; + if (Hints->getForce() == SLVLoopVectorizeHints::FK_Enabled) + SCEVThreshold = PragmaVectorizeSCEVCheckThreshold; + + if (PSE.getUnionPredicate().getComplexity() > SCEVThreshold) { +// emitAnalysis(VectorizationReport() +// << "Too many SCEV assumptions need to be made and checked " +// << "at runtime"); + LLVM_DEBUG(dbgs() << "LVA: Too many SCEV checks needed.\n"); + NODEBUG_EARLY_BAILOUT(); + } + + // Okay! We can vectorize. At this point we don't have any other mem analysis + // which may limit our maximum vectorization factor, so just return true with + // no restrictions. 
+ return CanVectorize; +} + +void SLVLoopVectorizationLegality::collectStridedAccess(Value *MemAccess) { + Value *Ptr = nullptr; + if (LoadInst *LI = dyn_cast(MemAccess)) + Ptr = LI->getPointerOperand(); + else if (StoreInst *SI = dyn_cast(MemAccess)) + Ptr = SI->getPointerOperand(); + else + return; + + Value *Stride = getStrideFromPointer(Ptr, PSE.getSE(), TheLoop); + if (!Stride) + return; + + LLVM_DEBUG(dbgs() << "LVA: Found a strided access that we can version"); + LLVM_DEBUG(dbgs() << " Ptr: " << *Ptr << " Stride: " << *Stride << "\n"); + Strides[Ptr] = Stride; + StrideSet.insert(Stride); +} + +void SLVLoopVectorizationLegality::collectLoopUniforms() { + // We now know that the loop is vectorizable! + // Collect variables that will remain uniform after vectorization. + std::vector Worklist; + BasicBlock *Latch = TheLoop->getLoopLatch(); + + // Start with the conditional branch and walk up the block. + Worklist.push_back(Latch->getTerminator()->getOperand(0)); + + // Also add all consecutive pointer values; these values will be uniform + // after vectorization (and subsequent cleanup) and, until revectorization is + // supported, all dependencies must also be uniform. + for (Loop::block_iterator B = TheLoop->block_begin(), + BE = TheLoop->block_end(); B != BE; ++B) + for (BasicBlock::iterator I = (*B)->begin(), IE = (*B)->end(); + I != IE; ++I) + if (I->getType()->isPointerTy() && isConsecutivePtr(&*I)) + Worklist.insert(Worklist.end(), I->op_begin(), I->op_end()); + + while (!Worklist.empty()) { + Instruction *I = dyn_cast(Worklist.back()); + Worklist.pop_back(); + + // Look at instructions inside this loop. + // Stop when reaching PHI nodes. + // TODO: we need to follow values all over the loop, not only in this block. + if (!I || !TheLoop->contains(I) || isa(I)) + continue; + + // This is a known uniform. + Uniforms.insert(I); + + // Insert all operands. + Worklist.insert(Worklist.end(), I->op_begin(), I->op_end()); + } +} + +bool SLVLoopVectorizationLegality::isInductionVariable(const Value *V) { + Value *In0 = const_cast(V); + + if (EnableScalableVectorisation) { + // TODO: Need to handle other arithmetic/logical instructions + Instruction *Inst = dyn_cast(In0); + if (Inst && Inst->getOpcode() == Instruction::Shl) { + Value *ShiftVal = Inst->getOperand(1); + if (!dyn_cast(ShiftVal)) + return false; + In0 = Inst->getOperand(0); + } + } + + PHINode *PN = dyn_cast_or_null(In0); + if (!PN) + return false; + + return Inductions.count(PN); +} + +bool SLVLoopVectorizationLegality::blockNeedsPredication(BasicBlock *BB) { + return LoopAccessInfo::blockNeedsPredication(BB, TheLoop, DT); +} + +bool SLVLoopVectorizationLegality::blockCanBePredicated(BasicBlock *BB, + SmallPtrSetImpl &SafePtrs) { + // TODO: Modernize... + for (BasicBlock::iterator it = BB->begin(), e = BB->end(); it != e; ++it) { + // Check that we don't have a constant expression that can trap as operand. + for (Instruction::op_iterator OI = it->op_begin(), OE = it->op_end(); + OI != OE; ++OI) { + if (Constant *C = dyn_cast(*OI)) + if (C->canTrap()) + return false; + } + // We might be able to hoist the load. + if (it->mayReadFromMemory()) { + LoadInst *LI = dyn_cast(it); + if (!LI) + return false; + if (!SafePtrs.count(LI->getPointerOperand())) { + if (isLegalMaskedLoad(LI->getType(), LI->getPointerOperand())) { + MaskedOp.insert(LI); + continue; + } + return false; + } + } + + // We don't predicate stores at the moment. 
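
// The shape of loop blockCanBePredicated() is reasoning about, at the source
// level (illustrative names). The store only executes when the guard is true,
// so after if-conversion it must either become a masked store
// (isLegalMaskedStore) or stay within the -sl-vectorize-num-stores-pred budget
// for predicated stores.
void clampAbove(int *A, const int *B, int N, int Limit) {
  for (int i = 0; i < N; ++i)
    if (B[i] > Limit) // this block gets predicated
      A[i] = Limit;   // conditional store: masked-store candidate
}
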
+ if (it->mayWriteToMemory()) { + StoreInst *SI = dyn_cast(it); + // We only support predication of stores in basic blocks with one + // predecessor. + if (!SI) + return false; + + bool isSafePtr = (SafePtrs.count(SI->getPointerOperand()) != 0); + bool isSinglePredecessor = SI->getParent()->getSinglePredecessor(); + + if (++NumPredStores > NumberOfStoresToPredicate || !isSafePtr || + !isSinglePredecessor) { + // Build a masked store if it is legal for the target, otherwise scalarize + // the block. + bool isLegalMaskedOp = + isLegalMaskedStore(SI->getValueOperand()->getType(), + SI->getPointerOperand()); + if (isLegalMaskedOp) { + --NumPredStores; + MaskedOp.insert(SI); + continue; + } + return false; + } + } + if (it->mayThrow()) + return false; + + // The instructions below can trap. + // TODO: Take predicated div/rem into account. + switch (it->getOpcode()) { + default: continue; + case Instruction::UDiv: + case Instruction::SDiv: + case Instruction::URem: + case Instruction::SRem: + return false; + } + } + + return true; +} + +//////////////////////////////////////////////////////////////////////////////// +// InterleavedAccessInfo +//////////////////////////////////////////////////////////////////////////////// +void InterleavedAccessInfo::collectConstStrideAccesses( + MapVector &AccessStrideInfo, + const ValueToValueMap &Strides) { + + auto &DL = TheLoop->getHeader()->getModule()->getDataLayout(); + + // Since it's desired that the load/store instructions be maintained in + // "program order" for the interleaved access analysis, we have to visit the + // blocks in the loop in reverse postorder (i.e., in a topological order). + // Such an ordering will ensure that any load/store that may be executed + // before a second load/store will precede the second load/store in + // AccessStrideInfo. + LoopBlocksDFS DFS(TheLoop); + DFS.perform(LI); + for (BasicBlock *BB : make_range(DFS.beginRPO(), DFS.endRPO())) + for (auto &I : *BB) { + auto *LI = dyn_cast(&I); + auto *SI = dyn_cast(&I); + if (!LI && !SI) + continue; + + Value *Ptr = getPointerOperand(&I); + // We don't check wrapping here because we don't know yet if Ptr will be + // part of a full group or a group with gaps. Checking wrapping for all + // pointers (even those that end up in groups with no gaps) will be overly + // conservative. For full groups, wrapping should be ok since if we would + // wrap around the address space we would do a memory access at nullptr + // even without the transformation. The wrapping checks are therefore + // deferred until after we've formed the interleaved groups. + int64_t Stride = getPtrStride(PSE, Ptr, TheLoop, Strides, + /*Assume=*/true, /*ShouldCheckWrap=*/false); + + const SCEV *Scev = replaceSymbolicStrideSCEV(PSE, Strides, Ptr); + PointerType *PtrTy = dyn_cast(Ptr->getType()); + uint64_t Size = DL.getTypeAllocSize(PtrTy->getElementType()); + + // An alignment of 0 means target ABI alignment. + unsigned Align = getMemInstAlignment(&I); + if (!Align) + Align = DL.getABITypeAlignment(PtrTy->getElementType()); + + AccessStrideInfo[&I] = StrideDescriptor(Stride, Scev, Size, Align); + } +} + +// Analyze interleaved accesses and collect them into interleaved load and +// store groups. +// +// When generating code for an interleaved load group, we effectively hoist all +// loads in the group to the location of the first load in program order. When +// generating code for an interleaved store group, we sink all stores to the +// location of the last store. 
This code motion can change the order of load +// and store instructions and may break dependences. +// +// The code generation strategy mentioned above ensures that we won't violate +// any write-after-read (WAR) dependences. +// +// E.g., for the WAR dependence: a = A[i]; // (1) +// A[i] = b; // (2) +// +// The store group of (2) is always inserted at or below (2), and the load +// group of (1) is always inserted at or above (1). Thus, the instructions will +// never be reordered. All other dependences are checked to ensure the +// correctness of the instruction reordering. +// +// The algorithm visits all memory accesses in the loop in bottom-up program +// order. Program order is established by traversing the blocks in the loop in +// reverse postorder when collecting the accesses. +// +// We visit the memory accesses in bottom-up order because it can simplify the +// construction of store groups in the presence of write-after-write (WAW) +// dependences. +// +// E.g., for the WAW dependence: A[i] = a; // (1) +// A[i] = b; // (2) +// A[i + 1] = c; // (3) +// +// We will first create a store group with (3) and (2). (1) can't be added to +// this group because it and (2) are dependent. However, (1) can be grouped +// with other accesses that may precede it in program order. Note that a +// bottom-up order does not imply that WAW dependences should not be checked. +void InterleavedAccessInfo::analyzeInterleaving( + const ValueToValueMap &Strides) { + LLVM_DEBUG(dbgs() << "LV: Analyzing interleaved accesses...\n"); + + // Holds all accesses with a constant stride. + MapVector AccessStrideInfo; + collectConstStrideAccesses(AccessStrideInfo, Strides); + + if (AccessStrideInfo.empty()) + return; + + // Collect the dependences in the loop. + collectDependences(); + + // Holds all interleaved store groups temporarily. + SmallSetVector StoreGroups; + // Holds all interleaved load groups temporarily. + SmallSetVector LoadGroups; + + // Search in bottom-up program order for pairs of accesses (A and B) that can + // form interleaved load or store groups. In the algorithm below, access A + // precedes access B in program order. We initialize a group for B in the + // outer loop of the algorithm, and then in the inner loop, we attempt to + // insert each A into B's group if: + // + // 1. A and B have the same stride, + // 2. A and B have the same memory object size, and + // 3. A belongs in B's group according to its distance from B. + // + // Special care is taken to ensure group formation will not break any + // dependences. + for (auto BI = AccessStrideInfo.rbegin(), E = AccessStrideInfo.rend(); + BI != E; ++BI) { + Instruction *B = BI->first; + StrideDescriptor DesB = BI->second; + + // Initialize a group for B if it has an allowable stride. Even if we don't + // create a group for B, we continue with the bottom-up algorithm to ensure + // we don't break any of B's dependences. 
+ InterleaveGroup *Group = nullptr; + if (isStrided(DesB.Stride)) { + Group = getInterleaveGroup(B); + if (!Group) { + LLVM_DEBUG(dbgs() << "LV: Creating an interleave group with:" << *B << '\n'); + Group = createInterleaveGroup(B, DesB.Stride, DesB.Align); + } + if (B->mayWriteToMemory()) + StoreGroups.insert(Group); + else + LoadGroups.insert(Group); + } + + for (auto AI = std::next(BI); AI != E; ++AI) { + Instruction *A = AI->first; + StrideDescriptor DesA = AI->second; + + // Our code motion strategy implies that we can't have dependences + // between accesses in an interleaved group and other accesses located + // between the first and last member of the group. Note that this also + // means that a group can't have more than one member at a given offset. + // The accesses in a group can have dependences with other accesses, but + // we must ensure we don't extend the boundaries of the group such that + // we encompass those dependent accesses. + // + // For example, assume we have the sequence of accesses shown below in a + // stride-2 loop: + // + // (1, 2) is a group | A[i] = a; // (1) + // | A[i-1] = b; // (2) | + // A[i-3] = c; // (3) + // A[i] = d; // (4) | (2, 4) is not a group + // + // Because accesses (2) and (3) are dependent, we can group (2) with (1) + // but not with (4). If we did, the dependent access (3) would be within + // the boundaries of the (2, 4) group. + if (!canReorderMemAccessesForInterleavedGroups(&*AI, &*BI)) { + + // If a dependence exists and A is already in a group, we know that A + // must be a store since A precedes B and WAR dependences are allowed. + // Thus, A would be sunk below B. We release A's group to prevent this + // illegal code motion. A will then be free to form another group with + // instructions that precede it. + if (isInterleaved(A)) { + InterleaveGroup *StoreGroup = getInterleaveGroup(A); + StoreGroups.remove(StoreGroup); + releaseGroup(StoreGroup); + } + + // If a dependence exists and A is not already in a group (or it was + // and we just released it), B might be hoisted above A (if B is a + // load) or another store might be sunk below A (if B is a store). In + // either case, we can't add additional instructions to B's group. B + // will only form a group with instructions that it precedes. + break; + } + + // At this point, we've checked for illegal code motion. If either A or B + // isn't strided, there's nothing left to do. + if (!isStrided(DesA.Stride) || !isStrided(DesB.Stride)) + continue; + + // Ignore A if it's already in a group or isn't the same kind of memory + // operation as B. + if (isInterleaved(A) || A->mayReadFromMemory() != B->mayReadFromMemory()) + continue; + + // Check rules 1 and 2. Ignore A if its stride or size is different from + // that of B. + if (DesA.Stride != DesB.Stride || DesA.Size != DesB.Size) + continue; + + // Calculate the distance from A to B. + const SCEVConstant *DistToB = dyn_cast( + PSE.getSE()->getMinusSCEV(DesA.Scev, DesB.Scev)); + if (!DistToB) + continue; + int64_t DistanceToB = DistToB->getAPInt().getSExtValue(); + + // Check rule 3. Ignore A if its distance to B is not a multiple of the + // size. + if (DistanceToB % static_cast(DesB.Size)) + continue; + + // Ignore A if either A or B is in a predicated block. Although we + // currently prevent group formation for predicated accesses, we may be + // able to relax this limitation in the future once we handle more + // complicated blocks. 
+ if (isPredicated(A->getParent()) || isPredicated(B->getParent())) + continue; + + // The index of A is the index of B plus A's distance to B in multiples + // of the size. + int IndexA = + Group->getIndex(B) + DistanceToB / static_cast(DesB.Size); + + // Try to insert A into B's group. + if (Group->insertMember(A, IndexA, DesA.Align)) { + LLVM_DEBUG(dbgs() << "LV: Inserted:" << *A << '\n' + << " into the interleave group with" << *B << '\n'); + InterleaveGroupMap[A] = Group; + + // Set the first load in program order as the insert position. + if (A->mayReadFromMemory()) + Group->setInsertPos(A); + } + } // Iteration over A accesses. + } // Iteration over B accesses. + + // Remove interleaved store groups with gaps. + for (InterleaveGroup *Group : StoreGroups) + if (Group->getNumMembers() != Group->getFactor()) + releaseGroup(Group); + + // Remove interleaved groups with gaps (currently only loads) whose memory + // accesses may wrap around. We have to revisit the getPtrStride analysis, + // this time with ShouldCheckWrap=true, since collectConstStrideAccesses does + // not check wrapping (see documentation there). + // FORNOW we use Assume=false; + // TODO: Change to Assume=true but making sure we don't exceed the threshold + // of runtime SCEV assumptions checks (thereby potentially failing to + // vectorize altogether). + // Additional optional optimizations: + // TODO: If we are peeling the loop and we know that the first pointer doesn't + // wrap then we can deduce that all pointers in the group don't wrap. + // This means that we can forcefully peel the loop in order to only have to + // check the first pointer for no-wrap. When we'll change to use Assume=true + // we'll only need at most one runtime check per interleaved group. + // + for (InterleaveGroup *Group : LoadGroups) { + + // Case 1: A full group. Can Skip the checks; For full groups, if the wide + // load would wrap around the address space we would do a memory access at + // nullptr even without the transformation. + if (Group->getNumMembers() == Group->getFactor()) + continue; + + // Case 2: If first and last members of the group don't wrap this implies + // that all the pointers in the group don't wrap. + // So we check only group member 0 (which is always guaranteed to exist), + // and group member Factor - 1; If the latter doesn't exist we rely on + // peeling (if it is a non-reveresed accsess -- see Case 3). + Value *FirstMemberPtr = getPointerOperand(Group->getMember(0)); + if (!getPtrStride(PSE, FirstMemberPtr, TheLoop, Strides, /*Assume=*/false, + /*ShouldCheckWrap=*/true)) { + LLVM_DEBUG(dbgs() << "LV: Invalidate candidate interleaved group due to " + "first group member potentially pointer-wrapping.\n"); + releaseGroup(Group); + continue; + } + Instruction *LastMember = Group->getMember(Group->getFactor() - 1); + if (LastMember) { + Value *LastMemberPtr = getPointerOperand(LastMember); + if (!getPtrStride(PSE, LastMemberPtr, TheLoop, Strides, /*Assume=*/false, + /*ShouldCheckWrap=*/true)) { + LLVM_DEBUG(dbgs() << "LV: Invalidate candidate interleaved group due to " + "last group member potentially pointer-wrapping.\n"); + releaseGroup(Group); + } + } else { + // Case 3: A non-reversed interleaved load group with gaps: We need + // to execute at least one scalar epilogue iteration. This will ensure + // we don't speculatively access memory out-of-bounds. We only need + // to look for a member at index factor - 1, since every group must have + // a member at index zero. 
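Taken together, Cases 1-3 form a small decision procedure per load group. A minimal sketch of that decision, assuming hypothetical boolean inputs in place of the real group and getPtrStride queries (none of the names below are part of this patch):

// Keep: the group survives as-is.  Release: the group is dissolved.
// NeedScalarEpilogue: the group survives but a scalar tail iteration is needed.
enum class GapDecision { Keep, Release, NeedScalarEpilogue };

static GapDecision classifyLoadGroup(bool IsFullGroup, bool FirstMayWrap,
                                     bool HasLastMember, bool LastMayWrap,
                                     bool IsReverse) {
  if (IsFullGroup)
    return GapDecision::Keep;                     // Case 1
  if (FirstMayWrap)
    return GapDecision::Release;                  // Case 2, member 0
  if (HasLastMember)
    return LastMayWrap ? GapDecision::Release     // Case 2, member Factor - 1
                       : GapDecision::Keep;
  return IsReverse ? GapDecision::Release         // Case 3
                   : GapDecision::NeedScalarEpilogue;
}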
+ if (Group->isReverse()) { + releaseGroup(Group); + continue; + } + LLVM_DEBUG(dbgs() << "LV: Interleaved group requires epilogue iteration.\n"); + RequiresScalarEpilogue = true; + } + } +} + +//////////////////////////////////////////////////////////////////////////////// +// SLVLoopVectorizationCostModel +//////////////////////////////////////////////////////////////////////////////// +VectorizationFactor +SLVLoopVectorizationCostModel::selectVectorizationFactor(bool OptForSize) { + // Width 1 means no vectorize + VectorizationFactor Factor = { 1U, 0U, !EnableScalableVectorisation }; + if (OptForSize && Legal->getRuntimePointerChecking()->Need) { + /* TODO: Fix after Feb2017 merge + emitAnalysis(VectorizationReport() << + "runtime pointer checks needed. Enable vectorization of this " + "loop with '#pragma clang loop vectorize(enable)' when " + "compiling with -Os/-Oz"); + */ + LLVM_DEBUG(dbgs() << + "LVA: Aborting. Runtime ptr check is required with -Os/-Oz.\n"); + Factor.isFixed = true; + return Factor; + } + + if (!EnableCondStoresVectorization && Legal->getNumPredStores()) { + /* TODO: Fix after Feb2017 merge + + emitAnalysis(VectorizationReport() << + "store that is conditionally executed prevents vectorization"); + */ + LLVM_DEBUG(dbgs() << "LVA: No vectorization. There are conditional stores.\n"); + Factor.isFixed = true; + return Factor; + } + + // Find the trip count. + unsigned TC = PSE.getSE()->getSmallConstantTripCount(TheLoop); + LLVM_DEBUG(dbgs() << "LVA: Found trip count: " << TC << '\n'); + + MinBWs = computeMinimumValueSizes(TheLoop->getBlocks(), *DB, &TTI); + unsigned SmallestType, WidestType; + + std::tie(SmallestType, WidestType) = getSmallestAndWidestTypes(); + unsigned WidestRegister = TTI.getRegisterBitWidth(true); + unsigned MaxSafeDepDist = -1U; + if (Legal->getMaxSafeDepDistBytes() != -1U) + MaxSafeDepDist = Legal->getMaxSafeDepDistBytes() * 8; + + // For the case when the register size is unknown we cannot vectorise loops + // with data dependencies in a scalable manner. However, when the + // architecture provides an upper bound, we can query that before reverting + // to fixed width vectors. + if (MaxSafeDepDist < TTI.getRegisterBitWidthUpperBound(true)) { + Factor.isFixed = true; + Factor.Width= 1; + return Factor; + + // TODO: Have this disabled for now, as we only want to allow SL vectorization + // for Scalable vectors for now. (otherwise it requires implementing + // Neon support) + // + // // LAA may have assumed we can do strided during analysis + // if (Legal->getRuntimePointerChecking()->Strided && + // TTI.canVectorizeNonUnitStrides(true)) { + // LLVM_DEBUG(dbgs() << + // "LVA: Not vectorizing, can't do strided accesses on target.\n"); + // emitAnalysis(VectorizationReport() << + // "Target doesn't support vectorizing strided accesses."); + // Factor.Width = 1; + // return Factor; + // } + } + + WidestRegister = ((WidestRegister < MaxSafeDepDist) ? 
+ WidestRegister : MaxSafeDepDist); + unsigned MaxVectorSize = WidestRegister / WidestType; + + LLVM_DEBUG(dbgs() << "LVA: The Smallest and Widest types: " << SmallestType << " / " + << WidestType << " bits.\n"); + LLVM_DEBUG(dbgs() << "LVA: The Widest register is: " + << WidestRegister << " bits.\n"); + + if (MaxVectorSize == 0) { + LLVM_DEBUG(dbgs() << "LVA: The target has no vector registers.\n"); + MaxVectorSize = 1; + } + + assert(MaxVectorSize <= 64 && "Did not expect to pack so many elements" + " into one vector!"); + + unsigned VF = MaxVectorSize; + if (MaximizeBandwidth && !OptForSize) { + // Collect all viable vectorization factors. + SmallVector VFs; + unsigned NewMaxVectorSize = WidestRegister / SmallestType; + for (unsigned VS = MaxVectorSize; VS <= NewMaxVectorSize; VS *= 2) + VFs.push_back(VS); + + // For each VF calculate its register usage. + auto RUs = calculateRegisterUsage(VFs); + + // Select the largest VF which doesn't require more registers than existing + // ones. + unsigned TargetNumRegisters = TTI.getNumberOfRegisters(true); + for (int i = RUs.size() - 1; i >= 0; --i) { + if (RUs[i].MaxLocalUsers <= TargetNumRegisters) { + VF = VFs[i]; + break; + } + } + } + + // If we optimize the program for size, avoid creating the tail loop. + if (OptForSize) { + // If we are unable to calculate the trip count then don't try to vectorize. + if (TC < 2) { + /* TODO: Fix after Feb2017 merge + + emitAnalysis + (VectorizationReport() << + "unable to calculate the loop count due to complex control flow"); + */ + LLVM_DEBUG(dbgs() << "LVA: Aborting. A tail loop is required with -Os/-Oz.\n"); + if (Factor.Width < 2) + Factor.isFixed = true; + return Factor; + } + + // Find the maximum SIMD width that can fit within the trip count. + VF = TC % MaxVectorSize; + + if (VF == 0) + VF = MaxVectorSize; + else { + // If the trip count that we found modulo the vectorization factor is not + // zero then we require a tail. + /* TODO: Fix after Feb2017 merge + + emitAnalysis(VectorizationReport() << + "cannot optimize for size and vectorize at the " + "same time. Enable vectorization of this loop " + "with '#pragma clang loop vectorize(enable)' " + "when compiling with -Os/-Oz"); + */ + LLVM_DEBUG(dbgs() << "LVA: Aborting. A tail loop is required with -Os/-Oz.\n"); + Factor.isFixed = true; + return Factor; + } + } + + int UserVF = Hints->getWidth(); + if (UserVF != 0) { + assert(isPowerOf2_32(UserVF) && "VF needs to be a power of two"); + LLVM_DEBUG(dbgs() << "LVA: Using user VF " << UserVF << ".\n"); + + Factor.Width = UserVF; + if (Factor.Width < 2) + Factor.isFixed = true; + return Factor; + } + + float Cost = expectedCost({/*Width=*/1, 0, /*isFixed=*/true}); +#ifndef NDEBUG + const float ScalarCost = Cost; +#endif /* NDEBUG */ + Factor.Width = 1; + LLVM_DEBUG(dbgs() << "LVA: Scalar loop costs: " << (int)ScalarCost << ".\n"); + + bool ForceVectorization = Hints->getForce() == SLVLoopVectorizeHints::FK_Enabled; + // Ignore scalar width, because the user explicitly wants vectorization. + if (ForceVectorization && VF > 1) { + Factor.Width = 2; + Cost = expectedCost(Factor) / (float)Factor.Width; + } + + VectorizationFactor PotentialFactor = Factor; + for (unsigned i=2; i <= VF; i*=2) { + // Notice that the vector loop needs to be executed less times, so + // we need to divide the cost of the vector loops by the width of + // the vector elements. 
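// For example (illustrative numbers only, not from any real target): with a
// scalar loop cost of 8, an expectedCost of 14 for width 2 gives a per-lane
// cost of 7, and an expectedCost of 20 for width 4 gives a per-lane cost of
// 5, so width 4 is selected; a width-8 expectedCost of 48 (per-lane cost 6)
// would then lose to the per-lane cost of 5 already found.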
+ PotentialFactor.Width = i; + float VectorCost = expectedCost(PotentialFactor) / (float)i; + LLVM_DEBUG(dbgs() << "LVA: Vector loop of width " << i << " costs: " << + (int)VectorCost << ".\n"); + if (VectorCost < Cost) { + Cost = VectorCost; + Factor = PotentialFactor; + } + } + + LLVM_DEBUG(if (ForceVectorization && Factor.Width > 1 && Cost >= ScalarCost) dbgs() + << "LVA: Vectorization seems to be not beneficial, " + << "but was forced by a user.\n"); + Factor.Cost = Factor.Width * Cost; + if (Factor.Width < 2) + Factor.isFixed = true; + LLVM_DEBUG(dbgs() << "LVA: Selecting VF: " << (Factor.isFixed ? "" : "n x ") << + Factor.Width << ".\n"); + return Factor; +} + +// TODO: Move to LVA +std::pair +SLVLoopVectorizationCostModel::getSmallestAndWidestTypes() { + unsigned MinWidth = -1U; + unsigned MaxWidth = 8; + const DataLayout &DL = TheFunction->getParent()->getDataLayout(); + + // For each block. + for (Loop::block_iterator bb = TheLoop->block_begin(), + be = TheLoop->block_end(); bb != be; ++bb) { + BasicBlock *BB = *bb; + + // For each instruction in the loop. + for (BasicBlock::iterator it = BB->begin(), e = BB->end(); it != e; ++it) { + Type *T = it->getType(); + + // Skip ignored values. + if (ValuesToIgnore.count(&*it)) + continue; + + // Only examine Loads, Stores and PHINodes. + if (!isa(it) && !isa(it) && !isa(it)) + continue; + + // Examine PHI nodes that are reduction variables. Update the type to + // account for the recurrence type. + if (PHINode *PN = dyn_cast(it)) { + if (!Legal->getReductionVars()->count(PN)) + continue; + RecurrenceDescriptor RdxDesc = (*Legal->getReductionVars())[PN]; + T = RdxDesc.getRecurrenceType(); + } + + // Examine the stored values. + if (StoreInst *ST = dyn_cast(it)) + T = ST->getValueOperand()->getType(); + + // Ignore loaded pointer types and stored pointer types that are not + // consecutive. However, we do want to take consecutive stores/loads of + // pointer vectors into account. + if (T->isPointerTy() && !isConsecutiveLoadOrStore(&*it)) + continue; + + MinWidth = std::min(MinWidth, + (unsigned)DL.getTypeSizeInBits(T->getScalarType())); + MaxWidth = std::max(MaxWidth, + (unsigned)DL.getTypeSizeInBits(T->getScalarType())); + } + } + + return {MinWidth, MaxWidth}; +} + +unsigned +SLVLoopVectorizationCostModel::selectInterleaveCount(bool OptForSize, + VectorizationFactor VF, + unsigned LoopCost) { + + // -- The interleave heuristics -- + // We interleave the loop in order to expose ILP and reduce the loop overhead. + // There are many micro-architectural considerations that we can't predict + // at this level. For example, frontend pressure (on decode or fetch) due to + // code size, or the number and capabilities of the execution ports. + // + // We use the following heuristics to select the interleave count: + // 1. If the code has reductions, then we interleave to break the cross + // iteration dependency. + // 2. If the loop is really small, then we interleave to reduce the loop + // overhead. + // 3. We don't interleave if we think that we will spill registers to memory + // due to the increased register pressure. + + // TODO: Not sure of the best approach for combining uncounted loops and + // unrolling. Disable for now. + // if (EnableUncountedLoops) + return 1; + + // TODO: revisit this decision but for now it is not worth considering + if (!VF.isFixed) + return 1; + + // When we optimize for size, we don't interleave. + if (OptForSize) + return 1; + + // We used the distance for the interleave count. 
+ if (Legal->getMaxSafeDepDistBytes() != -1U) + return 1; + + // Do not interleave loops with a relatively small trip count. + unsigned TC = PSE.getSE()->getSmallConstantTripCount(TheLoop); + if (TC > 1 && TC < TinyTripCountInterleaveThreshold) + return 1; + + unsigned TargetNumRegisters = TTI.getNumberOfRegisters(VF.Width > 1); + LLVM_DEBUG(dbgs() << "LVA: The target has " << TargetNumRegisters << + " registers\n"); + + if (VF.Width == 1) { + if (ForceTargetNumScalarRegs.getNumOccurrences() > 0) + TargetNumRegisters = ForceTargetNumScalarRegs; + } else { + if (ForceTargetNumVectorRegs.getNumOccurrences() > 0) + TargetNumRegisters = ForceTargetNumVectorRegs; + } + + RegisterUsage R = calculateRegisterUsage({VF.Width})[0]; + // We divide by these constants so assume that we have at least one + // instruction that uses at least one register. + R.MaxLocalUsers = std::max(R.MaxLocalUsers, 1U); + R.NumInstructions = std::max(R.NumInstructions, 1U); + + // We calculate the interleave count using the following formula. + // Subtract the number of loop invariants from the number of available + // registers. These registers are used by all of the interleaved instances. + // Next, divide the remaining registers by the number of registers that is + // required by the loop, in order to estimate how many parallel instances + // fit without causing spills. All of this is rounded down if necessary to be + // a power of two. We want power of two interleave count to simplify any + // addressing operations or alignment considerations. + unsigned IC = PowerOf2Floor((TargetNumRegisters - R.LoopInvariantRegs) / + R.MaxLocalUsers); + + // Don't count the induction variable as interleaved. + if (EnableIndVarRegisterHeur) + IC = PowerOf2Floor((TargetNumRegisters - R.LoopInvariantRegs - 1) / + std::max(1U, (R.MaxLocalUsers - 1))); + + // Clamp the interleave ranges to reasonable counts. + unsigned MaxInterleaveCount = TTI.getMaxInterleaveFactor(VF.Width); + + // Check if the user has overridden the max. + if (VF.Width == 1) { + if (ForceTargetMaxScalarInterleaveFactor.getNumOccurrences() > 0) + MaxInterleaveCount = ForceTargetMaxScalarInterleaveFactor; + } else { + if (ForceTargetMaxVectorInterleaveFactor.getNumOccurrences() > 0) + MaxInterleaveCount = ForceTargetMaxVectorInterleaveFactor; + } + + // If we did not calculate the cost for VF (because the user selected the VF) + // then we calculate the cost of VF here. + if (LoopCost == 0) + LoopCost = expectedCost(VF); + + // Clamp the calculated IC to be between the 1 and the max interleave count + // that the target allows. + if (IC > MaxInterleaveCount) + IC = MaxInterleaveCount; + else if (IC < 1) + IC = 1; + + // Interleave if we vectorized this loop and there is a reduction that could + // benefit from interleaving. + if (VF.Width > 1 && Legal->getReductionVars()->size()) { + LLVM_DEBUG(dbgs() << "LVA: Interleaving because of reductions.\n"); + return IC; + } + + // Note that if we've already vectorized the loop we will have done the + // runtime check and so interleaving won't require further checks. + bool InterleavingRequiresRuntimePointerCheck = + (VF.Width == 1 && Legal->getRuntimePointerChecking()->Need); + + // We want to interleave small loops in order to reduce the loop overhead and + // potentially expose ILP opportunities. 
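// As a worked example of the register-based formula above (illustrative
// numbers only): with 32 target registers, 4 loop-invariant values and a
// peak of 7 values live at once, (32 - 4) / 7 == 4, which is already a
// power of two, so IC == 4 before it is clamped against the target's
// maximum interleave factor and before the small-loop and reduction
// heuristics below are applied.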
+ LLVM_DEBUG(dbgs() << "LVA: Loop cost is " << LoopCost << '\n'); + if (!InterleavingRequiresRuntimePointerCheck && LoopCost < SmallLoopCost) { + // We assume that the cost overhead is 1 and we use the cost model + // to estimate the cost of the loop and interleave until the cost of the + // loop overhead is about 5% of the cost of the loop. + unsigned SmallIC = + std::min(IC, (unsigned)PowerOf2Floor(SmallLoopCost / LoopCost)); + + // Interleave until store/load ports (estimated by max interleave count) are + // saturated. + unsigned NumStores = Legal->getNumStores(); + unsigned NumLoads = Legal->getNumLoads(); + unsigned StoresIC = IC / (NumStores ? NumStores : 1); + unsigned LoadsIC = IC / (NumLoads ? NumLoads : 1); + + // If we have a scalar reduction (vector reductions are already dealt with + // by this point), we can increase the critical path length if the loop + // we're interleaving is inside another loop. Limit, by default to 2, so the + // critical path only gets increased by one reduction operation. + if (Legal->getReductionVars()->size() && + TheLoop->getLoopDepth() > 1) { + unsigned F = static_cast(MaxNestedScalarReductionIC); + SmallIC = std::min(SmallIC, F); + StoresIC = std::min(StoresIC, F); + LoadsIC = std::min(LoadsIC, F); + } + + if (EnableLoadStoreRuntimeInterleave && + std::max(StoresIC, LoadsIC) > SmallIC) { + LLVM_DEBUG(dbgs() << "LVA: Interleaving to saturate store or load ports.\n"); + return std::max(StoresIC, LoadsIC); + } + + LLVM_DEBUG(dbgs() << "LVA: Interleaving to reduce branch cost.\n"); + return SmallIC; + } + + // Interleave if this is a large loop (small loops are already dealt with by + // this + // point) that could benefit from interleaving. + bool HasReductions = (Legal->getReductionVars()->size() > 0); + if (TTI.enableAggressiveInterleaving(HasReductions)) { + LLVM_DEBUG(dbgs() << "LVA: Interleaving to expose ILP.\n"); + return IC; + } + + LLVM_DEBUG(dbgs() << "LVA: Not Interleaving.\n"); + return 1; +} + +SmallVector +SLVLoopVectorizationCostModel::calculateRegisterUsage( + const SmallVector &VFs) { + // This function calculates the register usage by measuring the highest number + // of values that are alive at a single location. Obviously, this is a very + // rough estimation. We scan the loop in a topological order in order and + // assign a number to each instruction. We use RPO to ensure that defs are + // met before their users. We assume that each instruction that has in-loop + // users starts an interval. We record every time that an in-loop value is + // used, so we have a list of the first and last occurrences of each + // instruction. Next, we transpose this data structure into a multi map that + // holds the list of intervals that *end* at a specific location. This multi + // map allows us to perform a linear search. We scan the instructions linearly + // and record each time that a new interval starts, by placing it in a set. + // If we find this value in the multi-map then we remove it from the set. + // The max register usage is the maximum size of the set. + // We also search for instructions that are defined outside the loop, but are + // used inside the loop. We need this number separately from the max-interval + // usage number because when we unroll, loop-invariant values do not take + // more register. + LoopBlocksDFS DFS(TheLoop); + DFS.perform(LI); + + RegisterUsage RU; + RU.NumInstructions = 0; + + // Each 'key' in the map opens a new interval. 
The values + // of the map are the index of the 'last seen' usage of the + // instruction that is the key. + typedef DenseMap IntervalMap; + // Maps instruction to its index. + DenseMap IdxToInstr; + // Marks the end of each interval. + IntervalMap EndPoint; + // Saves the list of instruction indices that are used in the loop. + SmallSet Ends; + // Saves the list of values that are used in the loop but are + // defined outside the loop, such as arguments and constants. + SmallPtrSet LoopInvariants; + + unsigned Index = 0; + for (LoopBlocksDFS::RPOIterator bb = DFS.beginRPO(), + be = DFS.endRPO(); bb != be; ++bb) { + RU.NumInstructions += (*bb)->size(); + for (Instruction &I : **bb) { + IdxToInstr[Index++] = &I; + + // Save the end location of each USE. + for (unsigned i = 0; i < I.getNumOperands(); ++i) { + Value *U = I.getOperand(i); + Instruction *Instr = dyn_cast(U); + + // Ignore non-instruction values such as arguments, constants, etc. + if (!Instr) continue; + + // If this instruction is outside the loop then record it and continue. + if (!TheLoop->contains(Instr)) { + LoopInvariants.insert(Instr); + continue; + } + + // Overwrite previous end points. + EndPoint[Instr] = Index; + Ends.insert(Instr); + } + } + } + + // Saves the list of intervals that end with the index in 'key'. + typedef SmallVector InstrList; + DenseMap TransposeEnds; + + // Transpose the EndPoints to a list of values that end at each index. + for (IntervalMap::iterator it = EndPoint.begin(), e = EndPoint.end(); + it != e; ++it) + TransposeEnds[it->second].push_back(it->first); + + SmallSet OpenIntervals; + + // Get the size of the widest register. + unsigned MaxSafeDepDist = -1U; + if (Legal->getMaxSafeDepDistBytes() != -1U) + MaxSafeDepDist = Legal->getMaxSafeDepDistBytes() * 8; + unsigned WidestRegister = + std::min(TTI.getRegisterBitWidth(true), MaxSafeDepDist); + const DataLayout &DL = TheFunction->getParent()->getDataLayout(); + + SmallVector RUs(VFs.size()); + SmallVector MaxUsages(VFs.size(), 0); + + LLVM_DEBUG(dbgs() << "LV(REG): Calculating max register usage:\n"); + + // A lambda that gets the register usage for the given type and VF. + auto GetRegUsage = [&DL, WidestRegister](Type *Ty, unsigned VF) { + unsigned TypeSize = DL.getTypeSizeInBits(Ty->getScalarType()); + return std::max(1, VF * TypeSize / WidestRegister); + }; + + for (unsigned int i = 0; i < Index; ++i) { + Instruction *I = IdxToInstr[i]; + // Ignore instructions that are never used within the loop. + if (!Ends.count(I)) continue; + + // Remove all of the instructions that end at this location. + InstrList &List = TransposeEnds[i]; + for (unsigned int j = 0, e = List.size(); j < e; ++j) + OpenIntervals.erase(List[j]); + + // Skip ignored values. + if (ValuesToIgnore.count(I)) + continue; + + // For each VF find the maximum usage of registers. + for (unsigned j = 0, e = VFs.size(); j < e; ++j) { + if (VFs[j] == 1) { + MaxUsages[j] = std::max(MaxUsages[j], OpenIntervals.size()); + continue; + } + + // Count the number of live intervals. + unsigned RegUsage = 0; + for (auto Inst : OpenIntervals) { + // Skip ignored values for VF > 1. + if (VecValuesToIgnore.count(Inst)) + continue; + RegUsage += GetRegUsage(Inst->getType(), VFs[j]); + } + MaxUsages[j] = std::max(MaxUsages[j], RegUsage); + } + + LLVM_DEBUG(dbgs() << "LV(REG): At #" << i << " Interval # " + << OpenIntervals.size() << '\n'); + + // Add the current instruction to the list of open intervals. 
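// Illustration of the sweep (hypothetical indices only): if the values
// defined at indices 0, 1 and 2 are all last used at index 5, then while
// indices 3 and 4 are visited all three intervals are open at once, so,
// provided those intervening instructions have uses of their own,
// MaxLocalUsers reaches at least 3 for VF == 1; for VF > 1 each open value
// is weighted by GetRegUsage(type, VF) rather than counted once.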
+ OpenIntervals.insert(I); + } + + for (unsigned i = 0, e = VFs.size(); i < e; ++i) { + unsigned Invariant = 0; + if (VFs[i] == 1) + Invariant = LoopInvariants.size(); + else { + for (auto Inst : LoopInvariants) + Invariant += GetRegUsage(Inst->getType(), VFs[i]); + } + + LLVM_DEBUG(dbgs() << "LV(REG): VF = " << VFs[i] << '\n'); + LLVM_DEBUG(dbgs() << "LV(REG): Found max usage: " << MaxUsages[i] << '\n'); + LLVM_DEBUG(dbgs() << "LV(REG): Found invariant usage: " << Invariant << '\n'); + LLVM_DEBUG(dbgs() << "LV(REG): LoopSize: " << RU.NumInstructions << '\n'); + + RU.LoopInvariantRegs = Invariant; + RU.MaxLocalUsers = MaxUsages[i]; + RUs[i] = RU; + } + + return RUs; +} + +unsigned SLVLoopVectorizationCostModel::expectedCost(VectorizationFactor VF) { + unsigned Cost = 0; + + // For each block. + for (Loop::block_iterator bb = TheLoop->block_begin(), + be = TheLoop->block_end(); bb != be; ++bb) { + unsigned BlockCost = 0; + BasicBlock *BB = *bb; + + // For each instruction in the old loop. + for (BasicBlock::iterator it = BB->begin(), e = BB->end(); it != e; ++it) { + // Skip dbg intrinsics. + if (isa(it)) + continue; + + // Skip ignored values. + if (ValuesToIgnore.count(&*it)) + continue; + + unsigned C = getInstructionCost(&*it, VF); + + // Check if we should override the cost. + if (ForceTargetInstructionCost.getNumOccurrences() > 0) + C = ForceTargetInstructionCost; + + BlockCost += C; + LLVM_DEBUG(dbgs() << "LVA: Found an estimated cost of " << C << " for VF " << + (VF.isFixed ? "" : "n x ") << VF.Width << " For instruction: " + << *it << '\n'); + } + + // We assume that if-converted blocks have a 50% chance of being executed. + // When the code is scalar then some of the blocks are avoided due to CF. + // When the code is vectorized we execute all code paths. + if (VF.Width == 1 && Legal->blockNeedsPredication(*bb)) + BlockCost /= 2; + + Cost += BlockCost; + } + + return Cost; +} + +/// Estimate the overhead of scalarizing a value. Insert and Extract are set if +/// the result needs to be inserted and/or extracted from vectors. +unsigned SLVLoopVectorizationCostModel::getScalarizationOverhead(Type *Ty, + bool Insert, bool Extract, + const TargetTransformInfo &TTI) { + if (Ty->isVoidTy()) + return 0; + + assert(Ty->isVectorTy() && "Can only scalarize vectors"); + unsigned Cost = 0; + + for (int i = 0, e = Ty->getVectorNumElements(); i < e; ++i) { + if (Insert) + Cost += TTI.getVectorInstrCost(Instruction::InsertElement, Ty, i); + if (Extract) + Cost += TTI.getVectorInstrCost(Instruction::ExtractElement, Ty, i); + } + + return Cost; +} + +// Estimate cost of a call instruction CI if it were vectorized with factor VF. +// Return the cost of the instruction, including scalarization overhead if it's +// needed. The flag NeedToScalarize shows if the call needs to be scalarized - +// i.e. either vector version isn't available, or is too expensive. +unsigned SLVLoopVectorizationCostModel::getVectorCallCost(CallInst *CI, + unsigned VF, + const TargetTransformInfo &TTI, + const TargetLibraryInfo *TLI, + bool &NeedToScalarize) { + VectorType::ElementCount EC(VF, false); + Function *F = CI->getCalledFunction(); + StringRef FnName = CI->getCalledFunction()->getName(); + Type *ScalarRetTy = CI->getType(); + SmallVector Tys, ScalarTys; + for (auto &ArgOp : CI->arg_operands()) + ScalarTys.push_back(ArgOp->getType()); + + // Estimate cost of scalarized vector call. 
The source operands are assumed + // to be vectors, so we need to extract individual elements from there, + // execute VF scalar calls, and then gather the result into the vector return + // value. + unsigned ScalarCallCost = TTI.getCallInstrCost(F, ScalarRetTy, ScalarTys); + if (VF == 1) + return ScalarCallCost; + + // Compute corresponding vector type for return value and arguments. + // TODO: Doesn't take WA into account. + Type *RetTy = ToVectorTy(ScalarRetTy, VF); + for (unsigned i = 0, ie = ScalarTys.size(); i != ie; ++i) + Tys.push_back(ToVectorTy(ScalarTys[i], VF)); + + // Compute costs of unpacking argument values for the scalar calls and + // packing the return values to a vector. + unsigned ScalarizationCost = + getScalarizationOverhead(RetTy, true, false, TTI); + for (unsigned i = 0, ie = Tys.size(); i != ie; ++i) + ScalarizationCost += getScalarizationOverhead(Tys[i], false, true, TTI); + + unsigned Cost = ScalarCallCost * VF + ScalarizationCost; + + // If we can't emit a vector call for this function, then the currently found + // cost is the cost we need to return. + NeedToScalarize = true; + FunctionType *FTy = FunctionType::get(RetTy, Tys, false); + if (!TLI || + !TLI->isFunctionVectorizable(FnName, EC, false /* NoMask */, FTy) || + CI->isNoBuiltin()) + return Cost; + + // If the corresponding vector cost is cheaper, return its cost. + unsigned VectorCallCost = TTI.getCallInstrCost(nullptr, RetTy, Tys); + if (VectorCallCost < Cost) { + NeedToScalarize = false; + return VectorCallCost; + } + return Cost; +} + + +// Estimate cost of an intrinsic call instruction CI if it were vectorized with +// factor VF. Return the cost of the instruction, including scalarization +// overhead if it's needed. +unsigned SLVLoopVectorizationCostModel::getVectorIntrinsicCost(CallInst *CI, + unsigned VF, + const TargetTransformInfo &TTI, + const TargetLibraryInfo *TLI) { + Intrinsic::ID ID = getVectorIntrinsicIDForCall(CI, TLI); + assert(ID && "Expected intrinsic call!"); + + Type *RetTy = ToVectorTy(CI->getType(), VF); + SmallVector Tys; + for (unsigned i = 0, ie = CI->getNumArgOperands(); i != ie; ++i) + Tys.push_back(ToVectorTy(CI->getArgOperand(i)->getType(), VF)); + + FastMathFlags FMF; + if (auto *FPMO = dyn_cast(CI)) + FMF = FPMO->getFastMathFlags(); + + // TODO: Make the cost model scalable aware so that this decision can + // be pass to getIntrinsicInstrCost when it has the correct types. + + // Both the code generator and loop vectoriser are unable to scalarise width + // agnostic calls. Make calls that do no map directly to one or more + // instructions prohibitively expensive so we never try to scalarise them. + /// TODO: Always WA? + if ((EnableScalableVectorisation) && (VF > 1)) + if ((ID == Intrinsic::cos) || + (ID == Intrinsic::exp) || + (ID == Intrinsic::log) || + (ID == Intrinsic::pow) || + (ID == Intrinsic::sin)) + return 99999; + + return TTI.getIntrinsicInstrCost(ID, RetTy, Tys, FMF); +} + + +/// \brief Check whether the address computation for a non-consecutive memory +/// access looks like an unlikely candidate for being merged into the indexing +/// mode. +/// +/// We look for a GEP which has one index that is an induction variable and all +/// other indices are loop invariant. If the stride of this access is also +/// within a small bound we decide that this address computation can likely be +/// merged into the addressing mode. +/// In all other cases, we identify the address computation as complex. +/* TODO: Reevaluate after Feb 2017 merge. 
+static bool isLikelyComplexAddressComputation(Value *Ptr, + SLVLoopVectorizationLegality *Legal, + ScalarEvolution *SE, + const Loop *TheLoop) { + GetElementPtrInst *Gep = dyn_cast(Ptr); + if (!Gep) + return true; + + // We are looking for a gep with all loop invariant indices except for one + // which should be an induction variable. + unsigned NumOperands = Gep->getNumOperands(); + for (unsigned i = 1; i < NumOperands; ++i) { + Value *Opd = Gep->getOperand(i); + if (!SE->isLoopInvariant(SE->getSCEV(Opd), TheLoop) && + !Legal->isInductionVariable(Opd)) + return true; + } + + // Now we know we have a GEP ptr, %inv, %ind, %inv. Make sure that the step + // can likely be merged into the address computation. + unsigned MaxMergeDistance = 64; + + const SCEVAddRecExpr *AddRec = dyn_cast(SE->getSCEV(Ptr)); + if (!AddRec) + return true; + + // Check the step is constant. + const SCEV *Step = AddRec->getStepRecurrence(*SE); + // Calculate the pointer stride and check if it is consecutive. + const SCEVConstant *C = dyn_cast(Step); + if (!C) + return true; + + const APInt &APStepVal = C->getAPInt(); + + // Huge step value - give up. + if (APStepVal.getBitWidth() > 64) + return true; + + int64_t StepVal = APStepVal.getSExtValue(); + + return StepVal > MaxMergeDistance; +} + */ + +static bool isStrideMul(Instruction *I, SLVLoopVectorizationLegality *Legal) { + return Legal->hasStride(I->getOperand(0)) || + Legal->hasStride(I->getOperand(1)); +} + +// Given a Chain +// A -> B -> Z, +// where: +// A = s/zext +// B = add +// C = trunc +// Check this is one of +// s/zext(i32) -> add -> trunc(valtype) +static bool isPartOfPromotedAdd(Instruction *I, Type **OrigType) { + Instruction *TruncOp = I; + + // If I is one of step A, find step C + if ((I->getOpcode() == Instruction::ZExt || + I->getOpcode() == Instruction::SExt)) { + // Confirm that s/zext is *only* used for the add + for(int K=0; K<2; ++K) { + if (!TruncOp->hasOneUse()) + return false; + TruncOp = dyn_cast(TruncOp->user_back()); + } + } + // If I is one of step B, find step C + else if ((I->getOpcode() == Instruction::Add)) { + if (!I->hasOneUse()) + return false; + TruncOp = I->user_back(); + } + + // Check if I is one of step C + if (TruncOp->getOpcode() != Instruction::Trunc) + return false; + + if (Instruction *Opnd = dyn_cast(TruncOp->getOperand(0))) { + if (TruncOp->getOpcode() != Instruction::Trunc || + Opnd->getOpcode() != Instruction::Add || !Opnd->hasNUses(1)) + return false; + + // Check each operand to the 'add' + unsigned cnt = 0; + for (Value *V : Opnd->operands()) { + if (const Instruction *AddOpnd = dyn_cast(V)) { + if (AddOpnd->getOpcode() != Instruction::ZExt && + AddOpnd->getOpcode() != Instruction::SExt) + break; + + if (!AddOpnd->getType()->isIntegerTy(32)) + break; + + if ( AddOpnd->getOperand(0)->getType() != TruncOp->getType() || + !AddOpnd->hasNUses(1)) + break; + } + cnt++; + } + + if (cnt == Opnd->getNumOperands()) { + if (OrigType) + *OrigType = TruncOp->getType(); + return true; + } + } + + return false; +} + +static MemAccessInfo calculateMemAccessInfo(Instruction *I, + Type *VectorTy, + SLVLoopVectorizationLegality *Legal, + ScalarEvolution *SE) { + const DataLayout &DL = I->getModule()->getDataLayout(); + + // Get pointer operand + Value *Ptr = nullptr; + if (auto *LI = dyn_cast(I)) + Ptr = LI->getPointerOperand(); + if (auto *SI = dyn_cast(I)) + Ptr = SI->getPointerOperand(); + + assert (Ptr && "Could not get pointer operand from instruction"); + + // Check for uniform access (scalar load + splat) + if 
(Legal->isUniform(Ptr)) + return MemAccessInfo::getUniformInfo(); + + // Get whether it is a predicated memory operation + bool IsMasked = Legal->isMaskRequired(I); + + // Try to find the stride of the pointer expression + if (auto *SAR = dyn_cast(SE->getSCEV(Ptr))) { + const SCEV *StepRecurrence = SAR->getStepRecurrence(*SE); + if (auto *StrideV = dyn_cast(StepRecurrence)) { + // Get the element size + unsigned VectorElementSize = + DL.getTypeStoreSize(VectorTy) / VectorTy->getVectorNumElements(); + + // Normalize Stride from bytes to number of elements + int Stride = + StrideV->getValue()->getSExtValue() / ((int64_t)VectorElementSize); + return MemAccessInfo::getStridedInfo(Stride, Stride < 0, IsMasked); + } else { + // Unknown stride is a subset of gather/scatter + return MemAccessInfo::getNonStridedInfo(StepRecurrence->getType(), + IsMasked); + } + } + + // If this is a scatter operation try to find the type of the offset, + // if applicable, e.g. A[i] = B[C[i]] + // ^^^^ get type of C[i] + Type *IdxTy = nullptr; + bool IsSigned = true; + if (auto *Gep = dyn_cast(Ptr)) { + for (unsigned Op=0; Op < Gep->getNumOperands(); ++Op) { + Value *Opnd = Gep->getOperand(Op); + if (Legal->isUniform(Opnd)) { + continue; + } + + // If there are multiple non-loop invariant indices + // in this GEP, fall back to the worst case below. + if (IdxTy != nullptr) { + IdxTy = nullptr; + break; + } + + // If type is promoted, see if we can use smaller type + IdxTy = Opnd->getType(); + if (auto *Ext = dyn_cast(Opnd)) { + if (Ext->isIntegerCast()) + IdxTy = Ext->getSrcTy(); + if (isa(Ext)) + IsSigned = false; + } + } + } + + // Worst case scenario, assume pointer size + if (!IdxTy) + IdxTy = DL.getIntPtrType(Ptr->getType()); + + return MemAccessInfo::getNonStridedInfo(IdxTy, IsMasked, IsSigned); +} + +unsigned +SLVLoopVectorizationCostModel::getInstructionCost(Instruction *I, + VectorizationFactor VF) { + // If we know that this instruction will remain uniform, check the cost of + // the scalar version. + if (Legal->isUniformAfterVectorization(I)) + VF.Width = 1; + + Type *RetTy = I->getType(); + if (VF.Width > 1 && MinBWs.count(I)) + RetTy = IntegerType::get(RetTy->getContext(), MinBWs[I]); + Type *VectorTy = ToVectorTy(RetTy, VF); + auto *SE = PSE.getSE(); + + // TODO: We need to estimate the cost of intrinsic calls. + switch (I->getOpcode()) { + case Instruction::GetElementPtr: + // We mark this instruction as zero-cost because the cost of GEPs in + // vectorized code depends on whether the corresponding memory instruction + // is scalarized or not. Therefore, we handle GEPs with the memory + // instruction cost. + return 0; + case Instruction::Br: { + return TTI.getCFInstrCost(I->getOpcode()); + } + case Instruction::PHI: + //TODO: IF-converted IFs become selects. + return 0; + case Instruction::Add: + case Instruction::FAdd: + case Instruction::Sub: + case Instruction::FSub: + case Instruction::Mul: + case Instruction::FMul: + case Instruction::UDiv: + case Instruction::SDiv: + case Instruction::FDiv: + case Instruction::URem: + case Instruction::SRem: + case Instruction::FRem: + case Instruction::Shl: + case Instruction::LShr: + case Instruction::AShr: + case Instruction::And: + case Instruction::Or: + case Instruction::Xor: { + // Since we will replace the stride by 1 the multiplication should go away. + if (I->getOpcode() == Instruction::Mul && isStrideMul(I, Legal)) + return 0; + // Certain instructions can be cheaper to vectorize if they have a constant + // second vector operand. 
One example of this are shifts on x86. + TargetTransformInfo::OperandValueKind Op1VK = + TargetTransformInfo::OK_AnyValue; + TargetTransformInfo::OperandValueKind Op2VK = + TargetTransformInfo::OK_AnyValue; + TargetTransformInfo::OperandValueProperties Op1VP = + TargetTransformInfo::OP_None; + TargetTransformInfo::OperandValueProperties Op2VP = + TargetTransformInfo::OP_None; + Value *Op2 = I->getOperand(1); + + // Check for a splat of a constant or for a non uniform vector of constants. + if (isa(Op2)) { + ConstantInt *CInt = cast(Op2); + if (CInt && CInt->getValue().isPowerOf2()) + Op2VP = TargetTransformInfo::OP_PowerOf2; + Op2VK = TargetTransformInfo::OK_UniformConstantValue; + } else if (isa(Op2) || isa(Op2)) { + Op2VK = TargetTransformInfo::OK_NonUniformConstantValue; + Constant *SplatValue = cast(Op2)->getSplatValue(); + if (SplatValue) { + ConstantInt *CInt = dyn_cast(SplatValue); + if (CInt && CInt->getValue().isPowerOf2()) + Op2VP = TargetTransformInfo::OP_PowerOf2; + Op2VK = TargetTransformInfo::OK_UniformConstantValue; + } + } + + // Note: When we find a s/zext_to_i32->add->trunc_to_origtype + // chain, we ask the target if it has an add for the original + // type. This is not allowed in C, so the target should ensure + // that the instruction does the sign/zero conversion in 'int'. + Type *OrigType = nullptr; + if (isPartOfPromotedAdd(I, &OrigType)) + VectorTy = VectorType::get(OrigType, VF.Width, !VF.isFixed); + + return TTI.getArithmeticInstrCost(I->getOpcode(), VectorTy, Op1VK, Op2VK, + Op1VP, Op2VP); + } + case Instruction::Select: { + SelectInst *SI = cast(I); + const SCEV *CondSCEV = SE->getSCEV(SI->getCondition()); + bool ScalarCond = (SE->isLoopInvariant(CondSCEV, TheLoop)); + Type *CondTy = SI->getCondition()->getType(); + if (!ScalarCond) + CondTy = VectorType::get(CondTy, VF.Width, !VF.isFixed); + + return TTI.getCmpSelInstrCost(I->getOpcode(), VectorTy, CondTy); + } + case Instruction::ICmp: + case Instruction::FCmp: { + Type *ValTy = I->getOperand(0)->getType(); + Instruction *Op0AsInstruction = dyn_cast(I->getOperand(0)); + auto It = MinBWs.find(Op0AsInstruction); + if (VF.Width > 1 && It != MinBWs.end()) + ValTy = IntegerType::get(ValTy->getContext(), It->second); + VectorTy = ToVectorTy(ValTy, VF); + return TTI.getCmpSelInstrCost(I->getOpcode(), VectorTy); + } + case Instruction::Store: + case Instruction::Load: { + StoreInst *SI = dyn_cast(I); + LoadInst *LI = dyn_cast(I); + Type *ValTy = (SI ? SI->getValueOperand()->getType() : + LI->getType()); + VectorTy = ToVectorTy(ValTy, VF); + + unsigned Alignment = SI ? SI->getAlignment() : LI->getAlignment(); + unsigned AS = SI ? SI->getPointerAddressSpace() : + LI->getPointerAddressSpace(); + Value *Ptr = SI ? SI->getPointerOperand() : LI->getPointerOperand(); + // We add the cost of address computation here instead of with the gep + // instruction because only here we know whether the operation is + // scalarized. + if (VF.Width == 1) + return TTI.getAddressComputationCost(VectorTy) + + TTI.getMemoryOpCost(I->getOpcode(), VectorTy, Alignment, AS); + + // For an interleaved access, calculate the total cost of the whole + // interleave group. + if (Legal->isAccessInterleaved(I)) { + auto Group = Legal->getInterleavedAccessGroup(I); + assert(Group && "Fail to get an interleaved access group."); + + // Only calculate the cost once at the insert position. 
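// Illustration (hypothetical factor and types): for VF == 4 and a factor-4
// load group of i32 members, the whole group is costed once as a single
// 16-element interleaved load via getInterleavedMemoryOpCost, the members
// other than the insert position report a cost of 0, and a reversed group
// additionally pays one SK_Reverse shuffle per present member.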
+ if (Group->getInsertPos() != I) + return 0; + + unsigned InterleaveFactor = Group->getFactor(); + Type *WideVecTy = + VectorType::get(VectorTy->getVectorElementType(), + VectorTy->getVectorNumElements() * InterleaveFactor, + !VF.isFixed); + + // Holds the indices of existing members in an interleaved load group. + // An interleaved store group doesn't need this as it dones't allow gaps. + SmallVector Indices; + if (LI) { + for (unsigned i = 0; i < InterleaveFactor; i++) + if (Group->getMember(i)) + Indices.push_back(i); + } + + // Calculate the cost of the whole interleaved group. + unsigned Cost = TTI.getInterleavedMemoryOpCost( + I->getOpcode(), WideVecTy, Group->getFactor(), Indices, + Group->getAlignment(), AS); + + if (Group->isReverse()) + Cost += + Group->getNumMembers() * + TTI.getShuffleCost(TargetTransformInfo::SK_Reverse, VectorTy, 0); + + // FIXME: The interleaved load group with a huge gap could be even more + // expensive than scalar operations. Then we could ignore such group and + // use scalar operations instead. + return Cost; + } + + // Scalarized loads/stores. + const DataLayout &DL = I->getModule()->getDataLayout(); + unsigned ScalarAllocatedSize = DL.getTypeAllocSize(ValTy); + unsigned VectorElementSize = DL.getTypeStoreSize(VectorTy) / VF.Width; + + // Get information about vector memory access + MemAccessInfo MAI = calculateMemAccessInfo(I, VectorTy, Legal, SE); + + // If there are no vector memory operations to support the stride, + // get the cost for scalarizing the operation. + if (!TTI.hasVectorMemoryOp(I->getOpcode(), VectorTy, MAI) || + ScalarAllocatedSize != VectorElementSize) { + // Get cost of scalarizing +// bool IsComplexComputation = +// isLikelyComplexAddressComputation(Ptr, Legal, SE, TheLoop); + unsigned Cost = 0; + // The cost of extracting from the value vector and pointer vector. + Type *PtrTy = ToVectorTy(Ptr->getType(), VF); + for (unsigned i = 0; i < VF.Width; ++i) { + // The cost of extracting the pointer operand. + Cost += TTI.getVectorInstrCost(Instruction::ExtractElement, PtrTy, i); + // In case of STORE, the cost of ExtractElement from the vector. + // In case of LOAD, the cost of InsertElement into the returned + // vector. + Cost += TTI.getVectorInstrCost(SI ? Instruction::ExtractElement : + Instruction::InsertElement, + VectorTy, i); + } + + // The cost of the scalar loads/stores. +// Cost += VF.Width * +// TTI.getAddressComputationCost(PtrTy, IsComplexComputation); + Cost += VF.Width * + TTI.getMemoryOpCost(I->getOpcode(), ValTy->getScalarType(), + Alignment, AS); + return Cost; + } + + // Wide load/stores. + unsigned Cost = TTI.getAddressComputationCost(VectorTy); + Cost += TTI.getVectorMemoryOpCost(I->getOpcode(), VectorTy, Ptr, + Alignment, AS, MAI, I); + + if (MAI.isStrided() && MAI.isReversed()) + Cost += TTI.getShuffleCost(TargetTransformInfo::SK_Reverse, + VectorTy, 0); + else if (MAI.isUniform()) + Cost += TTI.getShuffleCost(TargetTransformInfo::SK_Broadcast, + VectorTy, 0); + return Cost; + } + case Instruction::ZExt: + case Instruction::SExt: + case Instruction::FPToUI: + case Instruction::FPToSI: + case Instruction::FPExt: + case Instruction::PtrToInt: + case Instruction::IntToPtr: + case Instruction::SIToFP: + case Instruction::UIToFP: + case Instruction::Trunc: + case Instruction::FPTrunc: + case Instruction::BitCast: { + // We optimize the truncation of induction variable. + // The cost of these is the same as the scalar operation. 
+ if (I->getOpcode() == Instruction::Trunc && + Legal->isInductionVariable(I->getOperand(0))) + return TTI.getCastInstrCost(I->getOpcode(), I->getType(), + I->getOperand(0)->getType()); +// TODO: determine if still useful, deleting isPartOfPromotedAdd if not +// // Don't count these +// if (isPartOfPromotedAdd(I, nullptr)) +// return 0; +// +// Type *SrcVecTy = ToVectorTy(I->getOperand(0)->getType(), VF); + + Type *SrcScalarTy = I->getOperand(0)->getType(); + Type *SrcVecTy = ToVectorTy(SrcScalarTy, VF); + if (VF.Width > 1 && MinBWs.count(I)) { + // This cast is going to be shrunk. This may remove the cast or it might + // turn it into slightly different cast. For example, if MinBW == 16, + // "zext i8 %1 to i32" becomes "zext i8 %1 to i16". + // + // Calculate the modified src and dest types. + Type *MinVecTy = VectorTy; + if (I->getOpcode() == Instruction::Trunc) { + SrcVecTy = smallestIntegerVectorType(SrcVecTy, MinVecTy); + VectorTy = largestIntegerVectorType(ToVectorTy(I->getType(), VF), + MinVecTy); + } else if (I->getOpcode() == Instruction::ZExt || + I->getOpcode() == Instruction::SExt) { + SrcVecTy = largestIntegerVectorType(SrcVecTy, MinVecTy); + VectorTy = smallestIntegerVectorType(ToVectorTy(I->getType(), VF), + MinVecTy); + } + } + + return TTI.getCastInstrCost(I->getOpcode(), VectorTy, SrcVecTy); + } + case Instruction::Call: { + bool NeedToScalarize; + CallInst *CI = cast(I); + unsigned CallCost = getVectorCallCost(CI, VF.Width, TTI, TLI, + NeedToScalarize); + if (getVectorIntrinsicIDForCall(CI, TLI)) + return std::min(CallCost, getVectorIntrinsicCost(CI, VF.Width, TTI, TLI)); + return CallCost; + } + default: { + // We are scalarizing the instruction. Return the cost of the scalar + // instruction, plus the cost of insert and extract into vector + // elements, times the vector width. + unsigned Cost = 0; + + if (!RetTy->isVoidTy() && VF.Width != 1) { + unsigned InsCost = TTI.getVectorInstrCost(Instruction::InsertElement, + VectorTy); + unsigned ExtCost = TTI.getVectorInstrCost(Instruction::ExtractElement, + VectorTy); + + // The cost of inserting the results plus extracting each one of the + // operands. + Cost += VF.Width * (InsCost + ExtCost * I->getNumOperands()); + } + + // The cost of executing VF copies of the scalar instruction. This opcode + // is unknown. Assume that it is the same as 'mul'. + Cost += VF.Width * TTI.getArithmeticInstrCost(Instruction::Mul, VectorTy); + return Cost; + } + }// end of switch. +} + + +bool SLVLoopVectorizationCostModel::isConsecutiveLoadOrStore(Instruction *Inst) { + // Check for a store. + if (StoreInst *ST = dyn_cast(Inst)) + return Legal->isConsecutivePtr(ST->getPointerOperand()) != 0; + + // Check for a load. + if (LoadInst *LI = dyn_cast(Inst)) + return Legal->isConsecutivePtr(LI->getPointerOperand()) != 0; + + return false; +} + +void SLVLoopVectorizationCostModel::collectValuesToIgnore() { + // Ignore ephemeral values. + CodeMetrics::collectEphemeralValues(TheLoop, AC, ValuesToIgnore); + + // Ignore type-promoting instructions we identified during reduction + // detection. + for (auto &Reduction : *Legal->getReductionVars()) { + RecurrenceDescriptor &RedDes = Reduction.second; + SmallPtrSetImpl &Casts = RedDes.getCastInsts(); + VecValuesToIgnore.insert(Casts.begin(), Casts.end()); + } + + // Ignore induction phis that are only used in either GetElementPtr or ICmp + // instruction to exit loop. Induction variables usually have large types and + // can have big impact when estimating register usage. + // This is for when VF > 1. 
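// For example, in a loop like "for (long i = 0; i < n; ++i) a[i] = b[i];"
// the induction phi for i typically feeds only the two GEPs and its own
// increment, while the increment feeds only the exit compare and the phi,
// so both values land in VecValuesToIgnore and do not inflate the
// register-usage estimate for VF > 1.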
+ for (auto &Induction : *Legal->getInductionVars()) { + auto *PN = Induction.first; + auto *UpdateV = PN->getIncomingValueForBlock(TheLoop->getLoopLatch()); + + // Check that the PHI is only used by the induction increment (UpdateV) or + // by GEPs. Then check that UpdateV is only used by a compare instruction or + // the loop header PHI. + // FIXME: Need precise def-use analysis to determine if this instruction + // variable will be vectorized. + if (std::all_of(PN->user_begin(), PN->user_end(), + [&](const User *U) -> bool { + return U == UpdateV || isa(U); + }) && + std::all_of(UpdateV->user_begin(), UpdateV->user_end(), + [&](const User *U) -> bool { + return U == PN || isa(U); + })) { + VecValuesToIgnore.insert(PN); + VecValuesToIgnore.insert(UpdateV); + } + } + + // Ignore instructions that will not be vectorized. + // This is for when VF > 1. + for (auto bb = TheLoop->block_begin(), be = TheLoop->block_end(); bb != be; + ++bb) { + for (auto &Inst : **bb) { + switch (Inst.getOpcode()) { + case Instruction::GetElementPtr: { + // Ignore GEP if its last operand is an induction variable so that it is + // a consecutive load/store and won't be vectorized as scatter/gather + // pattern. + + GetElementPtrInst *Gep = cast(&Inst); + unsigned NumOperands = Gep->getNumOperands(); + unsigned InductionOperand = getGEPInductionOperand(Gep); + bool GepToIgnore = true; + + // Check that all of the gep indices are uniform except for the + // induction operand. + for (unsigned i = 0; i != NumOperands; ++i) { + if (i != InductionOperand && + !PSE.getSE()->isLoopInvariant(PSE.getSCEV(Gep->getOperand(i)), + TheLoop)) { + GepToIgnore = false; + break; + } + } + + if (GepToIgnore) + VecValuesToIgnore.insert(&Inst); + break; + } + } + } + } +} + +//////////////////////////////////////////////////////////////////////////////// +// LoopVectorizationAnalysis +//////////////////////////////////////////////////////////////////////////////// + +bool LoopVectorizationAnalysis::runOnFunction(Function &F) { + + // Legality is per loop..... + // Legal = SLVLoopVectorizationLegality() + return false; +} + +void LoopVectorizationAnalysis::getAnalysisUsage(AnalysisUsage &AU) const { + AU.addRequired(); + AU.addRequired(); + AU.addRequired(); + + AU.setPreservesAll(); +} + +char LoopVectorizationAnalysis::ID = 0; +static const char lva_name[] = "Loop Vectorization Analysis"; + +INITIALIZE_PASS_BEGIN(LoopVectorizationAnalysis, LVA_NAME, lva_name, false, true) +INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass) +INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) +INITIALIZE_PASS_END(LoopVectorizationAnalysis, LVA_NAME, lva_name, false, true) + +// TODO: Not needed? Remove from Scalar.h? +Pass *createLVAPass() { return new LoopVectorizationAnalysis(); } Index: lib/Transforms/Vectorize/LoopVectorizationLegality.cpp =================================================================== --- lib/Transforms/Vectorize/LoopVectorizationLegality.cpp +++ lib/Transforms/Vectorize/LoopVectorizationLegality.cpp @@ -15,6 +15,7 @@ // is a need (but D45420 needs to happen first). 
// #include "llvm/Transforms/Vectorize/LoopVectorizationLegality.h" +#include "llvm/Transforms/LVCommon.h" #include "llvm/Analysis/VectorUtils.h" #include "llvm/IR/IntrinsicInst.h" @@ -23,23 +24,6 @@ #define LV_NAME "loop-vectorize" #define DEBUG_TYPE LV_NAME -static cl::opt - EnableIfConversion("enable-if-conversion", cl::init(true), cl::Hidden, - cl::desc("Enable if-conversion during vectorization.")); - -static cl::opt PragmaVectorizeMemoryCheckThreshold( - "pragma-vectorize-memory-check-threshold", cl::init(128), cl::Hidden, - cl::desc("The maximum allowed number of runtime memory checks with a " - "vectorize(enable) pragma.")); - -static cl::opt VectorizeSCEVCheckThreshold( - "vectorize-scev-check-threshold", cl::init(16), cl::Hidden, - cl::desc("The maximum number of SCEV checks allowed.")); - -static cl::opt PragmaVectorizeSCEVCheckThreshold( - "pragma-vectorize-scev-check-threshold", cl::init(128), cl::Hidden, - cl::desc("The maximum number of SCEV checks allowed with a " - "vectorize(enable) pragma")); /// Maximum vectorization interleave count. static const unsigned MaxInterleaveFactor = 16; @@ -616,8 +600,8 @@ } RecurrenceDescriptor RedDes; - if (RecurrenceDescriptor::isReductionPHI(Phi, TheLoop, RedDes, DB, AC, - DT)) { + if (RecurrenceDescriptor::isReductionPHI(Phi, TheLoop, PSE.getSE(), + RedDes, false, DB, AC, DT)) { if (RedDes.hasUnsafeAlgebra()) Requirements->addUnsafeAlgebraInst(RedDes.getUnsafeAlgebraInst()); AllowedExit.insert(RedDes.getLoopExitInstr()); @@ -755,10 +739,39 @@ return false; if (LAI->hasStoreToLoopInvariantAddress()) { - ORE->emit(createMissedAnalysis("CantVectorizeStoreToLoopInvariantAddress") - << "write to a loop invariant address could not be vectorized"); - LLVM_DEBUG(dbgs() << "LV: We don't allow storing to uniform addresses\n"); - return false; + ScalarEvolution *SE = PSE.getSE(); + SmallVector UnhandledStores; + + // For each invariant address, check its last stored value is the result + // of one of our reductions and is unconditional. + for (StoreInst *SI : LAI->getInvariantStores()) { + bool FoundMatchingRecurrence = false; + for (auto &II : Reductions) { + RecurrenceDescriptor DS = II.second; + StoreInst *DSI = DS.IntermediateStore; + if (DSI && (DSI == SI) && !blockNeedsPredication(DSI->getParent())) { + FoundMatchingRecurrence = true; + break; + } + } + + if (FoundMatchingRecurrence) + // Earlier stores to this address are effectively deadcode. + llvm::remove_if(UnhandledStores, [SE, SI](StoreInst *I) { + return storeToSameAddress(SE, SI, I); + }); + else + UnhandledStores.push_back(SI); + } + + bool IsOK = UnhandledStores.empty(); + // TODO: we should also validate against InvariantMemSets. 
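// At the source level the newly admitted pattern looks roughly like
//   for (int i = 0; i < n; ++i) { s += a[i]; *sum = s; }
// where the store to *sum targets a loop-invariant address but is the
// unconditional intermediate store of the reduction on s (the
// RecurrenceDescriptor's IntermediateStore), so it no longer blocks
// vectorization; any other store to an invariant address still does.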
+ if (!IsOK) { + ORE->emit(createMissedAnalysis("CantVectorizeStoreToLoopInvariantAddress") + << "write to a loop invariant address could not be vectorized"); + LLVM_DEBUG(dbgs() << "LV: We don't allow storing to uniform addresses\n"); + return false; + } } Requirements->addRuntimePointerChecks(LAI->getNumRuntimePointerChecks()); Index: lib/Transforms/Vectorize/LoopVectorize.cpp =================================================================== --- lib/Transforms/Vectorize/LoopVectorize.cpp +++ lib/Transforms/Vectorize/LoopVectorize.cpp @@ -132,6 +132,8 @@ #include "llvm/Transforms/Utils/LoopSimplify.h" #include "llvm/Transforms/Utils/LoopUtils.h" #include "llvm/Transforms/Utils/LoopVersioning.h" +#include "llvm/Transforms/LVCommon.h" +#include "llvm/Transforms/Vectorize.h" #include "llvm/Transforms/Vectorize/LoopVectorizationLegality.h" #include #include @@ -154,92 +156,10 @@ STATISTIC(LoopsVectorized, "Number of loops vectorized"); STATISTIC(LoopsAnalyzed, "Number of loops analyzed for vectorization"); -/// Loops with a known constant trip count below this number are vectorized only -/// if no scalar iteration overheads are incurred. -static cl::opt TinyTripCountVectorThreshold( - "vectorizer-min-trip-count", cl::init(16), cl::Hidden, - cl::desc("Loops with a constant trip count that is smaller than this " - "value are vectorized only if no scalar iteration overheads " - "are incurred.")); - -static cl::opt MaximizeBandwidth( - "vectorizer-maximize-bandwidth", cl::init(false), cl::Hidden, - cl::desc("Maximize bandwidth when selecting vectorization factor which " - "will be determined by the smallest type in loop.")); - -static cl::opt EnableInterleavedMemAccesses( - "enable-interleaved-mem-accesses", cl::init(false), cl::Hidden, - cl::desc("Enable vectorization on interleaved memory accesses in a loop")); - -/// Maximum factor for an interleaved memory access. -static cl::opt MaxInterleaveGroupFactor( - "max-interleave-group-factor", cl::Hidden, - cl::desc("Maximum factor for an interleaved access group (default = 8)"), - cl::init(8)); - /// We don't interleave loops with a known constant trip count below this /// number. static const unsigned TinyTripCountInterleaveThreshold = 128; -static cl::opt ForceTargetNumScalarRegs( - "force-target-num-scalar-regs", cl::init(0), cl::Hidden, - cl::desc("A flag that overrides the target's number of scalar registers.")); - -static cl::opt ForceTargetNumVectorRegs( - "force-target-num-vector-regs", cl::init(0), cl::Hidden, - cl::desc("A flag that overrides the target's number of vector registers.")); - -static cl::opt ForceTargetMaxScalarInterleaveFactor( - "force-target-max-scalar-interleave", cl::init(0), cl::Hidden, - cl::desc("A flag that overrides the target's max interleave factor for " - "scalar loops.")); - -static cl::opt ForceTargetMaxVectorInterleaveFactor( - "force-target-max-vector-interleave", cl::init(0), cl::Hidden, - cl::desc("A flag that overrides the target's max interleave factor for " - "vectorized loops.")); - -static cl::opt ForceTargetInstructionCost( - "force-target-instruction-cost", cl::init(0), cl::Hidden, - cl::desc("A flag that overrides the target's expected cost for " - "an instruction to a single constant value. 
Mostly " - "useful for getting consistent testing.")); - -static cl::opt SmallLoopCost( - "small-loop-cost", cl::init(20), cl::Hidden, - cl::desc( - "The cost of a loop that is considered 'small' by the interleaver.")); - -static cl::opt LoopVectorizeWithBlockFrequency( - "loop-vectorize-with-block-frequency", cl::init(true), cl::Hidden, - cl::desc("Enable the use of the block frequency analysis to access PGO " - "heuristics minimizing code growth in cold regions and being more " - "aggressive in hot regions.")); - -// Runtime interleave loops for load/store throughput. -static cl::opt EnableLoadStoreRuntimeInterleave( - "enable-loadstore-runtime-interleave", cl::init(true), cl::Hidden, - cl::desc( - "Enable runtime interleaving until load/store ports are saturated")); - -/// The number of stores in a loop that are allowed to need predication. -static cl::opt NumberOfStoresToPredicate( - "vectorize-num-stores-pred", cl::init(1), cl::Hidden, - cl::desc("Max number of stores to be predicated behind an if.")); - -static cl::opt EnableIndVarRegisterHeur( - "enable-ind-var-reg-heur", cl::init(true), cl::Hidden, - cl::desc("Count the induction variable only once when interleaving")); - -static cl::opt EnableCondStoresVectorization( - "enable-cond-stores-vec", cl::init(true), cl::Hidden, - cl::desc("Enable if predication of stores during vectorization.")); - -static cl::opt MaxNestedScalarReductionIC( - "max-nested-scalar-reduction-interleave", cl::init(2), cl::Hidden, - cl::desc("The maximum interleave count to use when interleaving a scalar " - "reduction in a nested loop.")); - static cl::opt EnableVPlanNativePath( "enable-vplan-native-path", cl::init(false), cl::Hidden, cl::desc("Enable VPlan-native vectorization path with " @@ -1610,6 +1530,14 @@ /// Values to ignore in the cost model when VF > 1. SmallPtrSet VecValuesToIgnore; + + /// Estimate cost of a call instruction CI if it were vectorized + /// with factor VF. Return the cost of the instruction, including + /// scalarization overhead if it's needed. The flag NeedToScalarize + /// shows if the call needs to be scalarized - i.e. either vector + /// version isn't available, or is too expensive. + unsigned getVectorCallCost(CallInst *CI, unsigned VF, + bool &NeedToScalarize) const; }; } // end namespace llvm @@ -1765,7 +1693,7 @@ Builder.SetInsertPoint(LoopVectorPreHeader->getTerminator()); // Broadcast the scalar into all locations in the vector. - Value *Shuf = Builder.CreateVectorSplat(VF, V, "broadcast"); + Value *Shuf = Builder.CreateVectorSplat({VF, false}, V, "broadcast"); return Shuf; } @@ -1786,7 +1714,7 @@ Step = Builder.CreateTrunc(Step, TruncType); Start = Builder.CreateCast(Instruction::Trunc, Start, TruncType); } - Value *SplatStart = Builder.CreateVectorSplat(VF, Start); + Value *SplatStart = Builder.CreateVectorSplat({VF, false}, Start); Value *SteppedStart = getStepVector(SplatStart, 0, Step, II.getInductionOpcode()); @@ -1813,8 +1741,8 @@ // IRBuilder. IRBuilder can constant-fold the multiply, but it doesn't // handle a constant vector splat. Value *SplatVF = isa(Mul) - ? ConstantVector::getSplat(VF, cast(Mul)) - : Builder.CreateVectorSplat(VF, Mul); + ? ConstantVector::getSplat({VF, false}, cast(Mul)) + : Builder.CreateVectorSplat({VF, false}, Mul); Builder.restoreIP(CurrIP); // We may need to add the step a number of times, depending on the unroll @@ -1991,7 +1919,7 @@ Instruction::BinaryOps BinOp) { // Create and check the types. 
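Both files above drop their file-local cl::opt definitions and include llvm/Transforms/LVCommon.h instead, presumably so that LoopVectorize.cpp and the new SVELoopVectorize.cpp share one set of flags. That header is not part of this excerpt; a minimal sketch of the shared-flag pattern it would follow, shown for two of the removed options (the file layout is an assumption):

// include/llvm/Transforms/LVCommon.h (sketch)
#pragma once
#include "llvm/Support/CommandLine.h"

extern llvm::cl::opt<unsigned> TinyTripCountVectorThreshold;
extern llvm::cl::opt<unsigned> SmallLoopCost;

// lib/Transforms/LVCommon.cpp (sketch)
#include "llvm/Transforms/LVCommon.h"
using namespace llvm;

cl::opt<unsigned> TinyTripCountVectorThreshold(
    "vectorizer-min-trip-count", cl::init(16), cl::Hidden,
    cl::desc("Loops with a constant trip count that is smaller than this "
             "value are vectorized only if no scalar iteration overheads "
             "are incurred."));

cl::opt<unsigned> SmallLoopCost(
    "small-loop-cost", cl::init(20), cl::Hidden,
    cl::desc("The cost of a loop that is considered 'small' by the "
             "interleaver."));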
assert(Val->getType()->isVectorTy() && "Must be a vector"); - int VLen = Val->getType()->getVectorNumElements(); + unsigned VLen = Val->getType()->getVectorNumElements(); Type *STy = Val->getType()->getScalarType(); assert((STy->isIntegerTy() || STy->isFloatingPointTy()) && @@ -2002,13 +1930,13 @@ if (STy->isIntegerTy()) { // Create a vector of consecutive numbers from zero to VF. - for (int i = 0; i < VLen; ++i) + for (unsigned i = 0; i < VLen; ++i) Indices.push_back(ConstantInt::get(STy, StartIdx + i)); // Add the consecutive indices to the vector value. Constant *Cv = ConstantVector::get(Indices); assert(Cv->getType() == Val->getType() && "Invalid consecutive vec"); - Step = Builder.CreateVectorSplat(VLen, Step); + Step = Builder.CreateVectorSplat({VLen, false}, Step); assert(Step->getType() == Val->getType() && "Invalid step vec"); // FIXME: The newly created binary instructions should contain nsw/nuw flags, // which can be found from the original scalar operations. @@ -2020,13 +1948,13 @@ assert((BinOp == Instruction::FAdd || BinOp == Instruction::FSub) && "Binary Opcode should be specified for FP induction"); // Create a vector of consecutive numbers from zero to VF. - for (int i = 0; i < VLen; ++i) + for (unsigned i = 0; i < VLen; ++i) Indices.push_back(ConstantFP::get(STy, (double)(StartIdx + i))); // Add the consecutive indices to the vector value. Constant *Cv = ConstantVector::get(Indices); - Step = Builder.CreateVectorSplat(VLen, Step); + Step = Builder.CreateVectorSplat({VLen, false}, Step); // Floating point operations had to be 'fast' to enable the induction. FastMathFlags Flags; @@ -2525,6 +2453,16 @@ bool IfPredicateInstr) { assert(!Instr->getType()->isAggregateType() && "Can't handle vectors"); + // Don't create a memory instruction for an intermediate store of a + // reduction variable, because this will be one to a uniform address. + if (StoreInst *SI = dyn_cast(Instr)) { + for (auto &Reduction : *Legal->getReductionVars()) { + RecurrenceDescriptor DS = Reduction.second; + if (DS.IntermediateStore && + storeToSameAddress(PSE.getSE(), SI, DS.IntermediateStore)) + return; + } + } setDebugLocFromInst(Builder, Instr); // Does this instruction return a value ? @@ -3109,27 +3047,24 @@ !TTI.supportsEfficientVectorElementLoadStore())) Cost += TTI.getScalarizationOverhead(RetTy, true, false); + VectorType::ElementCount EC(VF, false); if (CallInst *CI = dyn_cast(I)) { SmallVector Operands(CI->arg_operands()); - Cost += TTI.getOperandsScalarizationOverhead(Operands, VF); + Cost += TTI.getOperandsScalarizationOverhead(Operands, EC); } else if (!isa(I) || !TTI.supportsEfficientVectorElementLoadStore()) { SmallVector Operands(I->operand_values()); - Cost += TTI.getOperandsScalarizationOverhead(Operands, VF); + Cost += TTI.getOperandsScalarizationOverhead(Operands, EC); } return Cost; } -// Estimate cost of a call instruction CI if it were vectorized with factor VF. -// Return the cost of the instruction, including scalarization overhead if it's -// needed. The flag NeedToScalarize shows if the call needs to be scalarized - -// i.e. either vector version isn't available, or is too expensive. 
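getVectorCallCost, whose old free-function comment is removed above and which becomes a cost-model member below, chooses between scalarizing a call and emitting a vector library call; call sites then separately compare against the intrinsic cost. A simplified standalone model of that three-way decision (the costs and names are illustrative):

// "No vector variant available" is modelled as an effectively infinite cost.
struct CallCosts {
  unsigned Scalarized;      // VF scalar calls plus pack/unpack overhead
  unsigned VectorLibCall;   // ~0u if TLI has no vectorized variant
  unsigned VectorIntrinsic; // ~0u if there is no intrinsic lowering
};

static unsigned chooseCallStrategy(const CallCosts &C, bool &NeedToScalarize,
                                   bool &UseVectorIntrinsic) {
  unsigned Cost = C.Scalarized;
  NeedToScalarize = true;
  if (C.VectorLibCall < Cost) { // a cheaper vectorized library routine exists
    Cost = C.VectorLibCall;
    NeedToScalarize = false;
  }
  UseVectorIntrinsic = C.VectorIntrinsic <= Cost; // mirrors "intrinsic <= call"
  return UseVectorIntrinsic ? C.VectorIntrinsic : Cost;
}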
-static unsigned getVectorCallCost(CallInst *CI, unsigned VF, - const TargetTransformInfo &TTI, - const TargetLibraryInfo *TLI, - bool &NeedToScalarize) { +unsigned +LoopVectorizationCostModel::getVectorCallCost(CallInst *CI, unsigned VF, + bool &NeedToScalarize) const { + VectorType::ElementCount EC(VF, false); Function *F = CI->getCalledFunction(); StringRef FnName = CI->getCalledFunction()->getName(); Type *ScalarRetTy = CI->getType(); @@ -3147,9 +3082,13 @@ // Compute corresponding vector type for return value and arguments. Type *RetTy = ToVectorTy(ScalarRetTy, VF); - for (Type *ScalarTy : ScalarTys) - Tys.push_back(ToVectorTy(ScalarTy, VF)); - + for (auto &Op : CI->arg_operands()) { + Type *ScalarTy = Op->getType(); + if (ScalarTy->isPointerTy() && Legal->isConsecutivePtr(Op)) + Tys.push_back(ScalarTy); + else + Tys.push_back(ToVectorTy(ScalarTy, VF)); + } // Compute costs of unpacking argument values for the scalar calls and // packing the return values to a vector. unsigned ScalarizationCost = getScalarizationOverhead(CI, VF, TTI); @@ -3159,7 +3098,10 @@ // If we can't emit a vector call for this function, then the currently found // cost is the cost we need to return. NeedToScalarize = true; - if (!TLI || !TLI->isFunctionVectorizable(FnName, VF) || CI->isNoBuiltin()) + FunctionType *FTy = FunctionType::get(RetTy, Tys, false); + if (!TLI || + !TLI->isFunctionVectorizable(FnName, EC, false /* NoMask */, FTy) || + CI->isNoBuiltin()) return Cost; // If the corresponding vector cost is cheaper, return its cost. @@ -3184,8 +3126,9 @@ if (auto *FPMO = dyn_cast(CI)) FMF = FPMO->getFastMathFlags(); + VectorType::ElementCount EC(VF, false); SmallVector Operands(CI->arg_operands()); - return TTI.getIntrinsicInstrCost(ID, CI->getType(), Operands, FMF, VF); + return TTI.getIntrinsicInstrCost(ID, CI->getType(), Operands, FMF, EC); } static Type *smallestIntegerVectorType(Type *T1, Type *T2) { @@ -3573,13 +3516,18 @@ Value *Identity; Value *VectorStart; if (RK == RecurrenceDescriptor::RK_IntegerMinMax || - RK == RecurrenceDescriptor::RK_FloatMinMax) { - // MinMax reduction have the start value as their identify. + RK == RecurrenceDescriptor::RK_FloatMinMax || + RK == RecurrenceDescriptor::RK_ConstSelectICmp || + RK == RecurrenceDescriptor::RK_ConstSelectFCmp) { + // MinMax and IntCond reductions have the start value as their identify. if (VF == 1) { VectorStart = Identity = ReductionStartValue; } else { + const char *Ident = (RK == RecurrenceDescriptor::RK_ConstSelectICmp || + RK == RecurrenceDescriptor::RK_ConstSelectFCmp) ? + "intcond.ident" : "minmax.ident"; VectorStart = Identity = - Builder.CreateVectorSplat(VF, ReductionStartValue, "minmax.ident"); + Builder.CreateVectorSplat({VF, false}, ReductionStartValue, Ident); } } else { // Handle other reduction kinds: @@ -3591,7 +3539,7 @@ // incoming scalar reduction. VectorStart = ReductionStartValue; } else { - Identity = ConstantVector::getSplat(VF, Iden); + Identity = ConstantVector::getSplat({VF, false}, Iden); // This vector is the Identity vector where the first element is the // incoming scalar reduction. @@ -3691,6 +3639,17 @@ BCBlockPhi->addIncoming(ReductionStartValue, LoopBypassBlocks[I]); BCBlockPhi->addIncoming(ReducedPartRdx, LoopMiddleBlock); + // If there were stores of the reduction value to a uniform memory address + // inside the loop, create the final store here. 
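The reduction start-up code earlier in this hunk splats an identity value whose choice depends on the recurrence kind: plain arithmetic reductions have a neutral constant, while min/max and the new conditional-select (RK_ConstSelect*) kinds reuse the loop's start value, exactly as the MinMax/IntCond branch above does. A compact standalone reminder:

#include <cstdint>

enum class RecurKind { Add, Or, Xor, Mul, And, MinMax, ConstSelect };

// Returns the value splatted as the vector "identity"; UseStart reports the
// kinds for which the start value itself is splatted instead of a constant.
static int64_t reductionIdentity(RecurKind K, int64_t StartValue,
                                 bool &UseStart) {
  UseStart = false;
  switch (K) {
  case RecurKind::Add: case RecurKind::Or: case RecurKind::Xor:
    return 0;           // x + 0 == x, x | 0 == x, x ^ 0 == x
  case RecurKind::Mul:
    return 1;           // x * 1 == x
  case RecurKind::And:
    return -1;          // all bits set: x & ~0 == x
  case RecurKind::MinMax: case RecurKind::ConstSelect:
    UseStart = true;    // no neutral constant is used; splat the start value
    return StartValue;
  }
  return 0;
}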
+ if (StoreInst *SI = RdxDesc.IntermediateStore) { + StoreInst *NewSI = Builder.CreateStore(ReducedPartRdx, + SI->getPointerOperand()); + propagateMetadata(NewSI, SI); + + // If the reduction value is used in other places, + // then let the code below create PHI's for that. + } + // Now, we need to fix the users of the reduction variable // inside and outside of the scalar remainder loop. // We know that the loop is in LCSSA form. We need to update the @@ -3902,7 +3861,7 @@ // the lane-zero scalar value. auto *Clone = Builder.Insert(GEP->clone()); for (unsigned Part = 0; Part < UF; ++Part) { - Value *EntryPart = Builder.CreateVectorSplat(VF, Clone); + Value *EntryPart = Builder.CreateVectorSplat({VF, false}, Clone); VectorLoopValueMap.setVectorValue(&I, Part, EntryPart); addMetadata(EntryPart, GEP); } @@ -4077,17 +4036,13 @@ StringRef FnName = CI->getCalledFunction()->getName(); Function *F = CI->getCalledFunction(); Type *RetTy = ToVectorTy(CI->getType(), VF); - SmallVector Tys; - for (Value *ArgOperand : CI->arg_operands()) - Tys.push_back(ToVectorTy(ArgOperand->getType(), VF)); - Intrinsic::ID ID = getVectorIntrinsicIDForCall(CI, TLI); // The flag shows whether we use Intrinsic or a usual Call for vectorized // version of the instruction. // Is it beneficial to perform intrinsic call compared to lib call? bool NeedToScalarize; - unsigned CallCost = getVectorCallCost(CI, VF, *TTI, TLI, NeedToScalarize); + unsigned CallCost = Cost->getVectorCallCost(CI, VF, NeedToScalarize); bool UseVectorIntrinsic = ID && getVectorIntrinsicCost(CI, VF, *TTI, TLI) <= CallCost; assert((UseVectorIntrinsic || !NeedToScalarize) && @@ -4112,13 +4067,33 @@ TysForDecl[0] = VectorType::get(CI->getType()->getScalarType(), VF); VectorF = Intrinsic::getDeclaration(M, ID, TysForDecl); } else { + + SmallVector Tys; + for (unsigned i = 0, ie = CI->getNumArgOperands(); i != ie; ++i) { + Value *Arg = CI->getArgOperand(i); + // Check if the argument `x` is a pointer marked by an + // OpenMP clause `linear(x:1)`. + if (Arg->getType()->isPointerTy() && + (Legal->isConsecutivePtr(Arg) == 1) && + isa(Args[i]->getType())) { + LLVM_DEBUG(dbgs() << "LV: vectorizing " << *Arg + << " as a linear pointer with step 1"); + Args[i] = + Builder.CreateExtractElement(Args[i], Builder.getInt32(0)); + Tys.push_back(Arg->getType()); + } else + Tys.push_back(ToVectorTy(Arg->getType(), VF)); + } + // Use vector version of the library call. - StringRef VFnName = TLI->getVectorizedFunction(FnName, VF); + VectorType::ElementCount EC(VF, false); + FunctionType *FTy = FunctionType::get(RetTy, Tys, false); + const std::string VFnName = + TLI->getVectorizedFunction(FnName, EC, false /* NoMask */, FTy); assert(!VFnName.empty() && "Vector function name is empty."); VectorF = M->getFunction(VFnName); if (!VectorF) { // Generate a declaration - FunctionType *FTy = FunctionType::get(RetTy, Tys, false); VectorF = Function::Create(FTy, Function::ExternalLinkage, VFnName, M); VectorF->copyAttributesFrom(F); @@ -4129,6 +4104,7 @@ SmallVector OpBundles; CI->getOperandBundlesAsDefs(OpBundles); CallInst *V = Builder.CreateCall(VectorF, Args, OpBundles); + TLI->setCallingConv(V); if (isa(V)) V->copyFastMathFlags(CI); @@ -5094,8 +5070,10 @@ if (ValuesToIgnore.count(&I)) continue; - // Only examine Loads, Stores and PHINodes. 
- if (!isa(I) && !isa(I) && !isa(I)) + // Examine Loads, Stores, PHINodes + // Also examine instructions which convert to a float/double + if (!isa(I) && !isa(I) && !isa(I) && + !isa(I) && !isa(I) && !isa(I)) continue; // Examine PHI nodes that are reduction variables. Update the type to @@ -5702,15 +5680,28 @@ const SCEV *PtrSCEV = getAddressAccessSCEV(Ptr, Legal, PSE, TheLoop); // Get the cost of the scalar memory instruction and address computation. - unsigned Cost = VF * TTI.getAddressComputationCost(PtrTy, SE, PtrSCEV); - Cost += VF * + // In practice if the address is invariant, then only the last store to the + // address really matters. Some typical IR after scalarization looks like: + // %6 = extractelement <2 x double> %5, i32 0 + // store double %6, double* %r, align 8 + // %7 = extractelement <2 x double> %5, i32 1 + // store double %7, double* %r, align 8 + // Any preceeding stores will be eliminated as redundant stores. In addition, + // dealing with the likely cost this way means we have more chance of + // vectorising the loop and sinking the store. + unsigned NumComps = + isa(I) && TheLoop->isLoopInvariant(Ptr) ? 1 : VF; + + unsigned Cost = NumComps * TTI.getAddressComputationCost(PtrTy, SE, PtrSCEV); + Cost += NumComps * TTI.getMemoryOpCost(I->getOpcode(), ValTy->getScalarType(), Alignment, AS, I); // Get the overhead of the extractelement and insertelement instructions // we might create due to scalarization. - Cost += getScalarizationOverhead(I, VF, TTI); + // If the address is invariant then this is really just a dup across lanes. + Cost += getScalarizationOverhead(I, NumComps, TTI); // If we have a predicated store, it may not be executed for each vector // lane. Scale the cost by the probability of executing the predicated @@ -6210,7 +6201,7 @@ case Instruction::Call: { bool NeedToScalarize; CallInst *CI = cast(I); - unsigned CallCost = getVectorCallCost(CI, VF, TTI, TLI, NeedToScalarize); + unsigned CallCost = getVectorCallCost(CI, VF, NeedToScalarize); if (getVectorIntrinsicIDForCall(CI, TLI)) return std::min(CallCost, getVectorIntrinsicCost(CI, VF, TTI, TLI)); return CallCost; @@ -6765,7 +6756,7 @@ // version of the instruction. // Is it beneficial to perform intrinsic call compared to lib call? bool NeedToScalarize; - unsigned CallCost = getVectorCallCost(CI, VF, *TTI, TLI, NeedToScalarize); + unsigned CallCost = CM.getVectorCallCost(CI, VF, NeedToScalarize); bool UseVectorIntrinsic = ID && getVectorIntrinsicCost(CI, VF, *TTI, TLI) <= CallCost; return UseVectorIntrinsic || !NeedToScalarize; Index: lib/Transforms/Vectorize/SLPVectorizer.cpp =================================================================== --- lib/Transforms/Vectorize/SLPVectorizer.cpp +++ lib/Transforms/Vectorize/SLPVectorizer.cpp @@ -612,6 +612,16 @@ /// \returns the cost of the vectorizable entry. int getEntryCost(TreeEntry *E); + using OrdersType = SmallVector; + + void checkPossibleLoadSequence(ArrayRef VL, + InstructionsState &S, + OrdersType &CurrentLoadOrder, + bool &LoadsArePacked, + bool &LoadsAreSimple, + bool &ValuesAreConsecutiveLoads, + bool &ValuesAreJumbledLoads); + /// This is the recursive part of buildTree. void buildTree_rec(ArrayRef Roots, unsigned Depth, int); @@ -1178,7 +1188,6 @@ /// List of users to ignore during scheduling and that don't need extracting. ArrayRef UserIgnoreList; - using OrdersType = SmallVector; /// A DenseMapInfo implementation for holding DenseMaps and DenseSets of /// sorted SmallVectors of unsigned. 
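The scalarized memory-op cost change above charges only one address computation and one store when the address is loop-invariant, because every store except the last is dead after scalarization. A standalone restatement of that formula (the extract/insert scalarization overhead term is omitted here):

// NumComps mirrors the patch: a store to a loop-invariant address is charged
// once, anything else is charged per lane.
static unsigned scalarizedMemOpCost(bool IsStore, bool AddressIsLoopInvariant,
                                    unsigned VF, unsigned AddrComputeCost,
                                    unsigned MemOpCost) {
  unsigned NumComps = (IsStore && AddressIsLoopInvariant) ? 1 : VF;
  return NumComps * AddrComputeCost + NumComps * MemOpCost;
}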
struct OrdersTypeDenseMapInfo { @@ -1392,6 +1401,68 @@ } } +void BoUpSLP::checkPossibleLoadSequence(ArrayRef VL, + InstructionsState &S, + OrdersType &CurrentLoadOrder, + bool &LoadsArePacked, + bool &LoadsAreSimple, + bool &ValuesAreConsecutiveLoads, + bool &ValuesAreJumbledLoads) { + auto *VL0 = cast(S.OpValue); + LoadsArePacked = false; + LoadsAreSimple = false; + ValuesAreConsecutiveLoads = false; + ValuesAreJumbledLoads = false; + if (S.getOpcode() != Instruction::Load ) + return; + + Type *ScalarTy = VL0->getType(); + if (DL->getTypeSizeInBits(ScalarTy) != + DL->getTypeAllocSizeInBits(ScalarTy)) + return; + + LoadsArePacked = true; + LoadsAreSimple = true; + SmallVector PointerOps(VL.size()); + auto POIter = PointerOps.begin(); + for (Value *V : VL) { + auto *L = cast(V); + if (!L->isSimple()) { + LoadsAreSimple = false; + return; + } + *POIter = L->getPointerOperand(); + ++POIter; + } + + // Check the order of pointer operands. + if (LoadsAreSimple && + llvm::sortPtrAccesses(PointerOps, *DL, *SE, CurrentLoadOrder)) { + Value *Ptr0; + Value *PtrN; + if (CurrentLoadOrder.empty()) { + Ptr0 = PointerOps.front(); + PtrN = PointerOps.back(); + } else { + Ptr0 = PointerOps[CurrentLoadOrder.front()]; + PtrN = PointerOps[CurrentLoadOrder.back()]; + } + const SCEV *Scev0 = SE->getSCEV(Ptr0); + const SCEV *ScevN = SE->getSCEV(PtrN); + const auto *Diff = + dyn_cast(SE->getMinusSCEV(ScevN, Scev0)); + uint64_t Size = DL->getTypeAllocSize(ScalarTy); + // Check that the sorted loads are consecutive. + if (Diff && + (Diff->getAPInt().getZExtValue() == (VL.size() - 1) * Size)) { + if (CurrentLoadOrder.empty()) + ValuesAreConsecutiveLoads = true; + else + ValuesAreJumbledLoads = true; + } + } +} + void BoUpSLP::buildTree_rec(ArrayRef VL, unsigned Depth, int UserTreeIdx) { assert((allConstant(VL) || allSameType(VL)) && "Invalid types!"); @@ -1466,13 +1537,26 @@ } } - // If any of the scalars is marked as a value that needs to stay scalar, then - // we need to gather the scalars. - for (unsigned i = 0, e = VL.size(); i != e; ++i) { - if (MustGather.count(VL[i])) { - LLVM_DEBUG(dbgs() << "SLP: Gathering due to gathered scalar.\n"); - newTreeEntry(VL, false, UserTreeIdx); - return; + // If any of the scalars is marked as a value that needs to stay scalar and + // the loads are not consecutive (and thus not partially loaded as a vector), + // then we need to gather the scalars. If the loads are consecutive, we can + // probably reuse part of the loaded vector by using a cheap shufflevector. + bool LoadsArePacked = false; + bool LoadsAreSimple = false; + bool ValuesAreConsecutiveLoads = false; + bool ValuesAreJumbledLoads = false; + OrdersType CurrentLoadOrder; + checkPossibleLoadSequence(VL, S, CurrentLoadOrder, LoadsArePacked, + LoadsAreSimple, ValuesAreConsecutiveLoads, + ValuesAreJumbledLoads); + + if (!ValuesAreConsecutiveLoads) { + for (unsigned i = 0, e = VL.size(); i != e; ++i) { + if (MustGather.count(VL[i])) { + LLVM_DEBUG(dbgs() << "SLP: Gathering due to gathered scalar.\n"); + newTreeEntry(VL, false, UserTreeIdx); + return; + } } } @@ -1596,16 +1680,20 @@ return; } case Instruction::Load: { + // We have to recalculate the PointerOps again as the list may have + // changed from above. + CurrentLoadOrder.clear(); + checkPossibleLoadSequence(VL, S, CurrentLoadOrder, + LoadsArePacked, LoadsAreSimple, + ValuesAreConsecutiveLoads, ValuesAreJumbledLoads); + // Check that a vectorized load would load the same memory as a scalar // load. 
For example, we don't want to vectorize loads that are smaller // than 8-bit. Even though we have a packed struct {} LLVM // treats loading/storing it as an i8 struct. If we vectorize loads/stores // from such a struct, we read/write packed bits disagreeing with the // unvectorized version. - Type *ScalarTy = VL0->getType(); - - if (DL->getTypeSizeInBits(ScalarTy) != - DL->getTypeAllocSizeInBits(ScalarTy)) { + if (!LoadsArePacked) { BS.cancelScheduling(VL, VL0); newTreeEntry(VL, false, UserTreeIdx, ReuseShuffleIndicies); LLVM_DEBUG(dbgs() << "SLP: Gathering loads of non-packed type.\n"); @@ -1614,55 +1702,28 @@ // Make sure all loads in the bundle are simple - we can't vectorize // atomic or volatile loads. - SmallVector PointerOps(VL.size()); - auto POIter = PointerOps.begin(); - for (Value *V : VL) { - auto *L = cast(V); - if (!L->isSimple()) { - BS.cancelScheduling(VL, VL0); - newTreeEntry(VL, false, UserTreeIdx, ReuseShuffleIndicies); - LLVM_DEBUG(dbgs() << "SLP: Gathering non-simple loads.\n"); - return; - } - *POIter = L->getPointerOperand(); - ++POIter; + if (!LoadsAreSimple) { + BS.cancelScheduling(VL, VL0); + newTreeEntry(VL, false, UserTreeIdx, ReuseShuffleIndicies); + LLVM_DEBUG(dbgs() << "SLP: Gathering non-simple loads.\n"); + return; } - OrdersType CurrentOrder; - // Check the order of pointer operands. - if (llvm::sortPtrAccesses(PointerOps, *DL, *SE, CurrentOrder)) { - Value *Ptr0; - Value *PtrN; - if (CurrentOrder.empty()) { - Ptr0 = PointerOps.front(); - PtrN = PointerOps.back(); - } else { - Ptr0 = PointerOps[CurrentOrder.front()]; - PtrN = PointerOps[CurrentOrder.back()]; - } - const SCEV *Scev0 = SE->getSCEV(Ptr0); - const SCEV *ScevN = SE->getSCEV(PtrN); - const auto *Diff = - dyn_cast(SE->getMinusSCEV(ScevN, Scev0)); - uint64_t Size = DL->getTypeAllocSize(ScalarTy); - // Check that the sorted loads are consecutive. - if (Diff && Diff->getAPInt().getZExtValue() == (VL.size() - 1) * Size) { - if (CurrentOrder.empty()) { - // Original loads are consecutive and does not require reordering. - ++NumOpsWantToKeepOriginalOrder; - newTreeEntry(VL, /*Vectorized=*/true, UserTreeIdx, - ReuseShuffleIndicies); - LLVM_DEBUG(dbgs() << "SLP: added a vector of loads.\n"); - } else { - // Need to reorder. - auto I = NumOpsWantToKeepOrder.try_emplace(CurrentOrder).first; - ++I->getSecond(); - newTreeEntry(VL, /*Vectorized=*/true, UserTreeIdx, - ReuseShuffleIndicies, I->getFirst()); - LLVM_DEBUG(dbgs() << "SLP: added a vector of jumbled loads.\n"); - } - return; - } + if (ValuesAreConsecutiveLoads) { + // Original loads are consecutive and does not require reordering. + ++NumOpsWantToKeepOriginalOrder; + newTreeEntry(VL, /*Vectorized=*/true, UserTreeIdx, + ReuseShuffleIndicies, None); + LLVM_DEBUG(dbgs() << "SLP: added a vector of loads.\n"); + return; + } else if (ValuesAreJumbledLoads) { + // Need to reorder. + auto I = NumOpsWantToKeepOrder.try_emplace(CurrentLoadOrder).first; + ++I->getSecond(); + newTreeEntry(VL, /*Vectorized=*/true, UserTreeIdx, + ReuseShuffleIndicies, I->getFirst()); + LLVM_DEBUG(dbgs() << "SLP: added a vector of jumbled loads.\n"); + return; } LLVM_DEBUG(dbgs() << "SLP: Gathering non-consecutive loads.\n"); @@ -1999,6 +2060,9 @@ if (!LI || !LI->isSimple() || !LI->hasNUses(VL.size())) return false; } else { + // Check the vector length is known. 
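checkPossibleLoadSequence above classifies a load bundle by sorting the pointers and checking the covered span; the same test decides between "consecutive" (keep the original order) and "jumbled" (vectorizable, but needs a shuffle) in buildTree_rec. A standalone sketch of the classification on raw byte offsets (illustrative; the real code reasons about SCEVs, not plain integers):

#include <algorithm>
#include <cstdint>
#include <vector>

enum class LoadSeq { NotConsecutive, Consecutive, Jumbled };

// Offsets: byte offset of each load from a common base, in bundle order.
// Size:    allocation size of the loaded element type.
static LoadSeq classifyLoads(const std::vector<uint64_t> &Offsets,
                             uint64_t Size) {
  if (Offsets.empty())
    return LoadSeq::NotConsecutive;
  std::vector<uint64_t> Sorted(Offsets);
  std::sort(Sorted.begin(), Sorted.end());
  // Same check as the hunk: after sorting, the first and last load must be
  // exactly (N - 1) * Size apart.
  if (Sorted.back() - Sorted.front() != (Offsets.size() - 1) * Size)
    return LoadSeq::NotConsecutive;
  return std::is_sorted(Offsets.begin(), Offsets.end())
             ? LoadSeq::Consecutive // original order already consecutive
             : LoadSeq::Jumbled;    // needs a reordering shuffle
}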
+ if (Vec->getType()->getVectorIsScalable()) + return false; NElts = Vec->getType()->getVectorNumElements(); } @@ -2357,7 +2421,7 @@ SmallVector Args(CI->arg_operands()); int VecCallCost = TTI->getIntrinsicInstrCost(ID, CI->getType(), Args, FMF, - VecTy->getNumElements()); + VecTy->getElementCount()); LLVM_DEBUG(dbgs() << "SLP: Call cost " << VecCallCost - ScalarCallCost << " (" << VecCallCost << "-" << ScalarCallCost << ")" @@ -2632,7 +2696,16 @@ // Iterate in reverse order to consider insert elements with the high cost. for (unsigned I = VL.size(); I > 0; --I) { unsigned Idx = I - 1; - if (!UniqueElements.insert(VL[Idx]).second) + // If the element has already been loaded as part of another vector load + // then we model the cost as a shuffle rather than an element insert. In + // practice this is what happens as later passes fold away many of the + // inserts and extracts. If we don't do this then we model the cost as too + // high to consider vectorising and miss out on generating better code. + bool LoadedAsVector = false; + if (TreeEntry *E = getTreeEntry(VL[Idx])) + LoadedAsVector = !E->NeedToGather && isa(VL[Idx]) && + E->ReorderIndices.empty(); + if (!UniqueElements.insert(VL[Idx]).second || LoadedAsVector) ShuffledElements.insert(Idx); } return getGatherCost(VecTy, ShuffledElements); Index: lib/Transforms/Vectorize/SVELoopVectorize.cpp =================================================================== --- /dev/null +++ lib/Transforms/Vectorize/SVELoopVectorize.cpp @@ -0,0 +1,8729 @@ +//===- SVELoopVectorize.cpp - A Loop Vectorizer ------------------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This is the LLVM loop vectorizer. This pass modifies 'vectorizable' loops +// and generates target-independent LLVM-IR. +// The vectorizer uses the TargetTransformInfo analysis to estimate the costs +// of instructions in order to estimate the profitability of vectorization. +// +// The loop vectorizer combines consecutive loop iterations into a single +// 'wide' iteration. After this transformation the index is incremented +// by the SIMD vector width, and not by one. +// +// This pass has three parts: +// 1. The main loop pass that drives the different parts. +// 2. LoopVectorizationLegality - A unit that checks for the legality +// of the vectorization. +// 3. InnerLoopVectorizer - A unit that performs the actual +// widening of instructions. +// 4. LoopVectorizationCostModel - A unit that checks for the profitability +// of vectorization. It decides on the optimal vector width, which +// can be one, if vectorization is not profitable. +// +//===----------------------------------------------------------------------===// +// +// The reduction-variable vectorization is based on the paper: +// D. Nuzman and R. Henderson. Multi-platform Auto-vectorization. +// +// Variable uniformity checks are inspired by: +// Karrenberg, R. and Hack, S. Whole Function Vectorization. +// +// The interleaved access vectorization is based on the paper: +// Dorit Nuzman, Ira Rosen and Ayal Zaks. Auto-Vectorization of Interleaved +// Data for SIMD +// +// Other ideas/concepts are from: +// A. Zaks and D. Nuzman. Autovectorization in GCC-two years later. +// +// S. Maleki, Y. Gao, M. Garzaran, T. Wong and D. Padua. An Evaluation of +// Vectorizing Compilers. 
+// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Vectorize/LoopVectorize.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/Hashing.h" +#include "llvm/ADT/MapVector.h" +#include "llvm/ADT/SCCIterator.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/Analysis/CodeMetrics.h" +#include "llvm/Analysis/GlobalsModRef.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/LoopIterator.h" +#include "llvm/Analysis/OptimizationRemarkEmitter.h" +#include "llvm/Analysis/LoopPass.h" +#include "llvm/Analysis/ScalarEvolutionExpander.h" +#include "llvm/Analysis/ScalarEvolutionExpressions.h" +#include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/Analysis/ValueTracking.h" +#include "llvm/Analysis/VectorUtils.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/DebugInfo.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/DiagnosticInfo.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/PatternMatch.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Value.h" +#include "llvm/IR/ValueHandle.h" +#include "llvm/IR/Verifier.h" +#include "llvm/Pass.h" +#include "llvm/Support/BranchProbability.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/Transforms/Utils/Local.h" +#include "llvm/Transforms/Utils/LoopSimplify.h" +#include "llvm/Transforms/Utils/LoopUtils.h" +#include "llvm/Transforms/Utils/LoopVersioning.h" +#include "llvm/Transforms/LVCommon.h" +#include "llvm/Transforms/Vectorize.h" +#include "llvm/Support/BlockFrequency.h" +#include "llvm/Analysis/BlockFrequencyInfo.h" +#include +#include +#include +#include + +using namespace llvm; +using namespace llvm::PatternMatch; + +#define LV_NAME "sve-loop-vectorize" +#define DEBUG_TYPE LV_NAME +#ifndef NDEBUG +#define NODEBUG_EARLY_BAILOUT() \ + do { if (!::llvm::DebugFlag || !::llvm::isCurrentDebugType(DEBUG_TYPE)) \ + { return false; } } while (0) +#else +#define NODEBUG_EARLY_BAILOUT() { return false; } +#endif + +STATISTIC(LoopsVectorized, "Number of loops vectorized"); +STATISTIC(LoopsVectorizedWA, "Number of loops vectorized with WA"); +STATISTIC(LoopsAnalyzed, "Number of loops analyzed for vectorization"); + +/// We don't interleave loops with a known constant trip count below this +/// number. +static const unsigned TinyTripCountInterleaveThreshold = 128; + +/// Maximum vectorization interleave count. 
+const unsigned MaxInterleaveFactor = 16; + + +static cl::opt EnableScalableVectorisation( + "force-scalable-vectorization", cl::init(true), cl::Hidden, + cl::ZeroOrMore, + cl::desc("Enable vectorization using scalable vectors")); + +static cl::opt EnableVectorPredication( + "force-vector-predication", cl::init(true), cl::Hidden, cl::ZeroOrMore, + cl::desc("Enable predicated vector operations.")); + +static cl::opt EnableNonConsecutiveStrideIndVars( + "enable-non-consecutive-stride-ind-vars", cl::init(false), cl::Hidden, + cl::desc("Enable recognition of induction variables that aren't " + "consecutive between loop iterations")); + +static cl::opt VectorizerMemSetThreshold( + "vectorize-memset-threshold", cl::init(8), + cl::Hidden, cl::desc("Maximum (write size in bytes / aligment)" + " ratio for the memset.")); + +static cl::opt VectorizeMemset( + "vectorize-memset", cl::init(true), cl::Hidden, + cl::desc("Enable vectorization of loops with memset calls in the loop " + "body")); + +/// Create an analysis remark that explains why vectorization failed +/// +/// \p PassName is the name of the pass (e.g. can be AlwaysPrint). \p +/// RemarkName is the identifier for the remark. If \p I is passed it is an +/// instruction that prevents vectorization. Otherwise \p TheLoop is used for +/// the location of the remark. \return the remark object that can be +/// streamed to. +static OptimizationRemarkAnalysis +createMissedAnalysis(const char *PassName, StringRef RemarkName, Loop *TheLoop, + Instruction *I = nullptr) { + Value *CodeRegion = TheLoop->getHeader(); + DebugLoc StartLoc = TheLoop->getLocRange().getStart(); + + if (I) { + CodeRegion = I->getParent(); + // If there is no debug location attached to the instruction, revert back to + // using the loop's. + if (I->getDebugLoc()) + StartLoc = I->getDebugLoc(); + } + + auto LocRange = DiagnosticLocation(StartLoc); + OptimizationRemarkAnalysis R(PassName, RemarkName, LocRange, CodeRegion); + R << "loop not vectorized: "; + return R; +} + +namespace { + +// Forward declarations. +class LoopVectorizeHints; +class LoopVectorizationLegality; +class LoopVectorizationCostModel; +class LoopVectorizationRequirements; + +/// Information about vectorization costs +struct VectorizationFactor { + unsigned Width; // Vector width with best cost + unsigned Cost; // Cost of the loop with that width + bool isFixed; // Is the width an absolute value or a scale. +}; + +/// Returns true if the given loop body has a cycle, excluding the loop +/// itself. +static bool hasCyclesInLoopBody(const Loop &L) { + if (!L.empty()) + return true; + + for (const auto &SCC : + make_range(scc_iterator::begin(L), + scc_iterator::end(L))) { + if (SCC.size() > 1) { + LLVM_DEBUG(dbgs() << "LVL: Detected a cycle in the loop body:\n"); + LLVM_DEBUG(L.dump()); + return true; + } + } + return false; +} + +/// A helper function for converting Scalar types to vector types. +/// If the incoming type is void, we return void. If the VF is 1, we return +/// the scalar type. +static Type *ToVectorTy(Type *Scalar, unsigned VF, bool IsScalable) { + if (Scalar->isVoidTy() || VF == 1) + return Scalar; + return VectorType::get(Scalar, VF, IsScalable); +} + +static Type* ToVectorTy(Type *Scalar, VectorizationFactor VF) { + if (Scalar->isVoidTy() || VF.Width == 1) + return Scalar; + return VectorType::get(Scalar, VF.Width, !VF.isFixed); +} + +/// A helper function that returns GEP instruction and knows to skip a +/// 'bitcast'. 
The 'bitcast' may be skipped if the source and the destination +/// pointee types of the 'bitcast' have the same size. +/// For example: +/// bitcast double** %var to i64* - can be skipped +/// bitcast double** %var to i8* - can not +static GetElementPtrInst *getGEPInstruction(Value *Ptr) { + + if (isa(Ptr)) + return cast(Ptr); + + if (isa(Ptr) && + isa(cast(Ptr)->getOperand(0))) { + Type *BitcastTy = Ptr->getType(); + Type *GEPTy = cast(Ptr)->getSrcTy(); + if (!isa(BitcastTy) || !isa(GEPTy)) + return nullptr; + Type *Pointee1Ty = cast(BitcastTy)->getPointerElementType(); + Type *Pointee2Ty = cast(GEPTy)->getPointerElementType(); + const DataLayout &DL = cast(Ptr)->getModule()->getDataLayout(); + if (DL.getTypeSizeInBits(Pointee1Ty) == DL.getTypeSizeInBits(Pointee2Ty)) + return cast(cast(Ptr)->getOperand(0)); + } + return nullptr; +} + +// FIXME: The following helper functions have multiple implementations +// in the project. They can be effectively organized in a common Load/Store +// utilities unit. + + +/// A helper function that returns the alignment of load or store instruction. +static unsigned getMemInstAlignment(Value *I) { + assert((isa(I) || isa(I)) && + "Expected Load or Store instruction"); + if (auto *LI = dyn_cast(I)) + return LI->getAlignment(); + return cast(I)->getAlignment(); +} + +/// A helper function that returns the address space of the pointer operand of +/// load or store instruction. +static unsigned getMemInstAddressSpace(Value *I) { + assert((isa(I) || isa(I)) && + "Expected Load or Store instruction"); + if (auto *LI = dyn_cast(I)) + return LI->getPointerAddressSpace(); + return cast(I)->getPointerAddressSpace(); +} + +/// A helper function that adds a 'fast' flag to floating-point operations. +static Value *addFastMathFlag(Value *V) { + if (isa(V)) { + FastMathFlags Flags; + Flags.setFast(true); + cast(V)->setFastMathFlags(Flags); + } + return V; +} + +/// InnerLoopVectorizer vectorizes loops which contain only one basic +/// block to a specified vectorization factor (VF). +/// This class performs the widening of scalars into vectors, or multiple +/// scalars. This class also implements the following features: +/// * It inserts an epilogue loop for handling loops that don't have iteration +/// counts that are known to be a multiple of the vectorization factor. +/// * It handles the code generation for reduction variables. +/// * Scalarization (implementation using scalars) of un-vectorizable +/// instructions. +/// InnerLoopVectorizer does not perform any vectorization-legality +/// checks, and relies on the caller to check for the different legality +/// aspects. The InnerLoopVectorizer relies on the +/// LoopVectorizationLegality class to provide information about the induction +/// and reduction variables that were found to a given vectorization factor. 
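The SVE path above carries a (Width, isFixed) pair, the VectorizationFactor struct, into type creation via ToVectorTy(Scalar, VF, IsScalable). For intuition, a tiny standalone helper that renders the IR type name such a pair implies (illustrative only; the real code builds llvm::VectorType objects):

#include <string>

// "<4 x i32>" for a fixed width, "<vscale x 4 x i32>" for a scalable one;
// VF == 1 keeps the scalar type, matching ToVectorTy.
static std::string vectorTypeName(const std::string &Scalar, unsigned Width,
                                  bool IsFixed) {
  if (Width == 1)
    return Scalar;
  std::string Prefix = IsFixed ? "" : "vscale x ";
  return "<" + Prefix + std::to_string(Width) + " x " + Scalar + ">";
}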
+class InnerLoopVectorizer { +public: + InnerLoopVectorizer(Loop *OrigLoop, PredicatedScalarEvolution &PSE, + LoopInfo *LI, DominatorTree *DT, + const TargetLibraryInfo *TLI, + const TargetTransformInfo *TTI, AssumptionCache *AC, + OptimizationRemarkEmitter *ORE, + unsigned VecWidth, unsigned UnrollFactor, + bool VecWidthIsFixed) + : OrigLoop(OrigLoop), PSE(PSE), LI(LI), DT(DT), TLI(TLI), TTI(TTI), + AC(AC), ORE(ORE), VF(VecWidth), Scalable(!VecWidthIsFixed), + UsePredication(EnableVectorPredication && isScalable()), + UF(UnrollFactor), Builder(PSE.getSE()->getContext()), + Induction(nullptr), OldInduction(nullptr), WidenMap(UnrollFactor), + VecBodyPostDom(nullptr), TripCount(nullptr), VectorTripCount(nullptr), + Legal(nullptr), AddedSafetyChecks(false), LatchBranch(nullptr), + IdxEnd(nullptr), IdxEndV(nullptr) {} + + // Perform the actual loop widening (vectorization). + // MinimumBitWidths maps scalar integer values to the smallest bitwidth they + // can be validly truncated to. The cost model has assumed this truncation + // will happen when vectorizing. + void vectorize(LoopVectorizationLegality *L, + MapVector MinimumBitWidths) { + MinBWs = MinimumBitWidths; + Legal = L; + // Create a new empty loop. Unlink the old loop and connect the new one. + if (UsePredication) + createEmptyLoopWithPredication(); + else + createEmptyLoop(); + // Widen each instruction in the old loop to a new one in the new loop. + // Use the Legality module to find the induction and reduction variables. + vectorizeLoop(); + } + + // Return true if any runtime check is added. + bool areSafetyChecksAdded() { return AddedSafetyChecks; } + + virtual ~InnerLoopVectorizer() {} + + bool isScalable() const { + return (VF > 1) && Scalable; + } + +protected: + /// A small list of PHINodes. + typedef SmallVector PhiVector; + /// When we unroll loops we have multiple vector values for each scalar. + /// This data structure holds the unrolled and vectorized values that + /// originated from one scalar instruction. + typedef SmallVector VectorParts; + + // When we if-convert we need to create edge masks. We have to cache values + // so that we don't end up with exponential recursion/IR. + typedef DenseMap, VectorParts> + EdgeMaskCache; + + /// \brief Add checks for strides that were assumed to be 1. + /// + /// Returns the last check instruction and the first check instruction in the + /// pair as (first, last). + std::pair addStrideCheck(Instruction *Loc); + + /// Create an empty loop, based on the loop ranges of the old loop. + void createEmptyLoop(); + /// Create an empty loop, using per-element predication to control termination + void createEmptyLoopWithPredication(); + + /// Set up the values of the IVs correctly when exiting the vector loop. + void fixupIVUsers(PHINode *OrigPhi, const InductionDescriptor &II, + Value *CountRoundDown, Value *EndValue, + BasicBlock *MiddleBlock); + + /// Create a new induction variable inside L. + PHINode *createInductionVariable(Loop *L, Value *Start, Value *End, + Value *Step, Instruction *DL); + /// Copy and widen the instructions from the old loop. + virtual void vectorizeLoop(); + + /// Fix a first-order recurrence. This is the second phase of vectorizing + /// this phi node. + void fixFirstOrderRecurrence(PHINode *Phi); + + /// \brief The Loop exit block may have single value PHI nodes where the + /// incoming value is 'Undef'. While vectorizing we only handled real values + /// that were defined inside the loop. Here we fix the 'undef case'. + /// See PR14725. 
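createEmptyLoopWithPredication, declared above and selected when UsePredication is set, controls termination with a per-lane predicate instead of emitting a scalar epilogue. A scalar model of what the predicated form computes, assuming simple element-wise work (names are illustrative):

#include <cstddef>

// Each "vector" iteration covers VL lanes; the predicate (Base + Lane < N)
// masks off the lanes past the end, so no remainder loop is needed.
static void predicatedLoopModel(float *Out, const float *In, std::size_t N,
                                std::size_t VL) {
  for (std::size_t Base = 0; Base < N; Base += VL)
    for (std::size_t Lane = 0; Lane < VL; ++Lane)
      if (Base + Lane < N)
        Out[Base + Lane] = In[Base + Lane] * 2.0f;
}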
+ void fixLCSSAPHIs(); + + /// Shrinks vector element sizes based on information in "MinBWs". + void truncateToMinimalBitwidths(); + + /// A helper function that computes the predicate of the block BB, assuming + /// that the header block of the loop is set to True. It returns the *entry* + /// mask for the block BB. + VectorParts createBlockInMask(BasicBlock *BB); + /// A helper function that computes the predicate of the edge between SRC + /// and DST. + VectorParts createEdgeMask(BasicBlock *Src, BasicBlock *Dst); + + /// A helper function to vectorize a single BB within the innermost loop. + void vectorizeBlockInLoop(BasicBlock *BB, PhiVector *PV); + + /// Vectorize a single PHINode in a block. This method handles the induction + /// variable canonicalization. It supports both VF = 1 for unrolled loops and + /// arbitrary length vectors. + void widenPHIInstruction(Instruction *PN, VectorParts &Entry, unsigned UF, + unsigned VF, PhiVector *PV); + + // Patch up the condition for a branch instruction after the block has been + // vectorized; only used with predication for now. + void patchLatchBranch(BranchInst *Br); + + /// Insert the new loop to the loop hierarchy and pass manager + /// and update the analysis passes. + void updateAnalysis(); + + /// This instruction is un-vectorizable. Implement it as a sequence + /// of scalars. If \p IfPredicateStore is true we need to 'hide' each + /// scalarized instruction behind an if block predicated on the control + /// dependence of the instruction. + virtual void scalarizeInstruction(Instruction *Instr, + bool IfPredicateStore = false); + + /// Vectorize Load and Store instructions, + virtual void vectorizeMemoryInstruction(Instruction *Instr); + virtual void vectorizeArithmeticGEP(Instruction *Instr); + virtual void vectorizeGEPInstruction(Instruction *Instr); + virtual void vectorizeMemsetInstruction(MemSetInst *MSI); + + /// Create a broadcast instruction. This method generates a broadcast + /// instruction (shuffle) for loop invariant values and for the induction + /// value. If this is the induction variable then we extend it to N, N+1, ... + /// this is needed because each iteration in the loop corresponds to a SIMD + /// element. + virtual Value *getBroadcastInstrs(Value *V); + + /// This function adds (StartIdx, StartIdx + Step, StartIdx + 2*Step, ...) + /// to each vector element of Val. The sequence starts at StartIndex. + /// \p Opcode is relevant for FP induction variable. + virtual Value *getStepVector(Value *Val, int StartIdx, Value *Step, + Instruction::BinaryOps Opcode = + Instruction::BinaryOpsEnd); + virtual Value *getStepVector(Value *Val, Value* Start, Value *Step, + Instruction::BinaryOps Opcode = + Instruction::BinaryOpsEnd); + + virtual Constant *getRuntimeVF(Type *Ty); + + /// This function adds (StartIdx, StartIdx + Step, StartIdx + 2*Step, ...) + /// to each vector element of Val. The sequence starts at StartIndex. + /// Step is a SCEV. In order to get StepValue it takes the existing value + /// from SCEV or creates a new using SCEVExpander. + virtual Value *getStepVector(Value *Val, Value *Start, const SCEV *Step, + Instruction::BinaryOps Opcode = + Instruction::BinaryOpsEnd); + + /// Create a vector induction variable based on an existing scalar one. + /// Currently only works for integer primary induction variables with + /// a constant step. + /// If TruncType is provided, instead of widening the original IV, we + /// widen a version of the IV truncated to TruncType. 
+ void widenInductionVariable(const InductionDescriptor &II, VectorParts &Entry, + IntegerType *TruncType = nullptr); + + /// When we go over instructions in the basic block we rely on previous + /// values within the current basic block or on loop invariant values. + /// When we widen (vectorize) values we place them in the map. If the values + /// are not within the map, they have to be loop invariant, so we simply + /// broadcast them into a vector. + VectorParts &getVectorValue(Value *V); + + /// Try to vectorize the interleaved access group that \p Instr belongs to. + void vectorizeInterleaveGroup(Instruction *Instr); + + /// Generate a shuffle sequence that will reverse the vector Vec. + virtual Value *reverseVector(Value *Vec); + + /// Returns (and creates if needed) the original loop trip count. + Value *getOrCreateTripCount(Loop *NewLoop); + + /// Returns (and creates if needed) the trip count of the widened loop. + Value *getOrCreateVectorTripCount(Loop *NewLoop); + + /// Returns the induction increment per iteration of the widened loop. + Constant *getInductionStep(); + + /// Emit a bypass check to see if the trip count would overflow, or we + /// wouldn't have enough iterations to execute one vector loop. + void emitMinimumIterationCountCheck(Loop *L, Value *Min, BasicBlock *Bypass); + /// Emit a bypass check to see if the vector loop's induction variable will + /// overflow. + void emitIVOverflowCheck(Loop *L, BasicBlock *Bypass); + /// Emit a bypass check to see if the vector trip count is nonzero. + void emitVectorLoopEnteredCheck(Loop *L, BasicBlock *Bypass); + /// Emit a bypass check to see if all of the SCEV assumptions we've + /// had to make are correct. + void emitSCEVChecks(Loop *L, BasicBlock *Bypass); + /// Emit bypass checks to check any memory assumptions we may have made. + void emitMemRuntimeChecks(Loop *L, BasicBlock *Bypass); + + /// Add additional metadata to \p To that was not present on \p Orig. + /// + /// Currently this is used to add the noalias annotations based on the + /// inserted memchecks. Use this for instructions that are *cloned* into the + /// vector loop. + void addNewMetadata(Instruction *To, const Instruction *Orig); + + /// Add metadata from one instruction to another. + /// + /// This includes both the original MDs from \p From and additional ones (\see + /// addNewMetadata). Use this for *newly created* instructions in the vector + /// loop. + void addMetadata(Instruction *To, const Instruction *From); + + /// \brief Similar to the previous function but it adds the metadata to a + /// vector of instructions. + void addMetadata(SmallVectorImpl &To, const Instruction *From); + + /// This is a helper class that holds the vectorizer state. It maps scalar + /// instructions to vector instructions. When the code is 'unrolled' then + /// then a single scalar value is mapped to multiple vector parts. The parts + /// are stored in the VectorPart type. + struct ValueMap { + /// C'tor. UnrollFactor controls the number of vectors ('parts') that + /// are mapped. + ValueMap(unsigned UnrollFactor) : UF(UnrollFactor) {} + + /// \return True if 'Key' is saved in the Value Map. + bool has(Value *Key) const { return MapStorage.count(Key); } + + /// Initializes a new entry in the map. Sets all of the vector parts to the + /// save value in 'Val'. + /// \return A reference to a vector with splat values. 
+ VectorParts &splat(Value *Key, Value *Val) { + VectorParts &Entry = MapStorage[Key]; + Entry.assign(UF, Val); + return Entry; + } + + ///\return A reference to the value that is stored at 'Key'. + VectorParts &get(Value *Key) { + VectorParts &Entry = MapStorage[Key]; + if (Entry.empty()) + Entry.resize(UF); + assert(Entry.size() == UF); + return Entry; + } + + private: + /// The unroll factor. Each entry in the map stores this number of vector + /// elements. + unsigned UF; + + /// Map storage. We use std::map and not DenseMap because insertions to a + /// dense map invalidates its iterators. + std::map MapStorage; + }; + + ///\brief Perform CSE of induction variable instructions. + void CSE(SmallVector &BBs, SmallSet &Preds); + + /// The original loop. + Loop *OrigLoop; + /// A wrapper around ScalarEvolution used to add runtime SCEV checks. Applies + /// dynamic knowledge to simplify SCEV expressions and converts them to a + /// more usable form. + PredicatedScalarEvolution &PSE; + /// Loop Info. + LoopInfo *LI; + /// Dominator Tree. + DominatorTree *DT; + /// Target Library Info. + const TargetLibraryInfo *TLI; + /// Target Transform Info. + const TargetTransformInfo *TTI; + /// Assumption Cache. + AssumptionCache *AC; + /// Interface to emit optimization remarks. + OptimizationRemarkEmitter *ORE; + /// Alias Analysis. + AliasAnalysis *AA; + + /// \brief LoopVersioning. It's only set up (non-null) if memchecks were + /// used. + /// + /// This is currently only used to add no-alias metadata based on the + /// memchecks. The actually versioning is performed manually. + std::unique_ptr LVer; + + /// The vectorization SIMD factor to use. Each vector will have this many + /// vector elements. + unsigned VF; + bool Scalable; + bool UsePredication; + +protected: + /// Test if instruction I is the exit instruction of some recurrence. If yes, + /// it sets RD with the associated RecurrenceDescriptor instance. + bool testHorizontalReductionExitInst(Instruction *I, RecurrenceDescriptor &RD); + + /// The vectorization unroll factor to use. Each scalar is vectorized to this + /// many different vector instructions. + unsigned UF; + + /// The builder that we use + IRBuilder<> Builder; + + // --- Vectorization state --- + + /// The vector-loop preheader. + BasicBlock *LoopVectorPreHeader; + /// The scalar-loop preheader. + BasicBlock *LoopScalarPreHeader; + /// Middle Block between the vector and the scalar. + BasicBlock *LoopMiddleBlock; + /// The ExitBlock of the scalar loop. + BasicBlock *LoopExitBlock; + /// The vector loop body. + SmallVector LoopVectorBody; + /// The scalar loop body. + BasicBlock *LoopScalarBody; + /// A list of all bypass blocks. The first block is the entry of the loop. + SmallVector LoopBypassBlocks; + /// The new Induction variable which was added to the new block. + PHINode *Induction; + /// The induction variable of the old basic block. + PHINode *OldInduction; + /// Holds the entry predicates for the current iteration of the vector body. + PhiVector Predicate; + /// Holds the extended (to the widest induction type) start index. + Value *ExtendedIdx; + /// Maps scalars to widened vectors. + ValueMap WidenMap; + /// Store instructions that should be predicated, as a pair + /// + SmallVector, 4> PredicatedStores; + EdgeMaskCache MaskCache; + + // Loop vector body current post-dominator block. + BasicBlock *VecBodyPostDom; + typedef std::pair DomEdge; + SmallVector VecBodyDomEdges; + + // Conditional blocks due to if-conversion. 
+ SmallSet PredicatedBlocks; + /// Trip count of the original loop. + Value *TripCount; + /// Trip count of the widened loop (TripCount - TripCount % (VF*UF)) + Value *VectorTripCount; + + /// Map of scalar integer values to the smallest bitwidth they can be legally + /// represented as. The vector equivalents of these values should be truncated + /// to this type. + MapVector MinBWs; + LoopVectorizationLegality *Legal; + + // Record whether runtime checks are added. + bool AddedSafetyChecks; + + /// Stores new branch for vectorized latch block so it + /// can be patched up after vectorization + BranchInst *LatchBranch; + + /// TODO -- rename this? + /// TODO -- move both to exit info descriptor? + Value *IdxEnd; + Value *IdxEndV; + + // Holds the end values for each induction variable. We save the end values + // so we can later fix-up the external users of the induction variables. + DenseMap IVEndValues; +}; + +class InnerLoopUnroller : public InnerLoopVectorizer { +public: + InnerLoopUnroller(Loop *OrigLoop, PredicatedScalarEvolution &PSE, + LoopInfo *LI, DominatorTree *DT, + const TargetLibraryInfo *TLI, + const TargetTransformInfo *TTI, AssumptionCache *AC, + OptimizationRemarkEmitter *ORE, + unsigned UnrollFactor) + : InnerLoopVectorizer(OrigLoop, PSE, LI, DT, TLI, TTI, AC, ORE, 1, + UnrollFactor, true) {} + +private: + void scalarizeInstruction(Instruction *Instr, + bool IfPredicateStore = false) override; + void vectorizeMemoryInstruction(Instruction *Instr) override; + Value *getBroadcastInstrs(Value *V) override; + Value *getStepVector(Value *Val, int StartIdx, Value *Step, + Instruction::BinaryOps Opcode = + Instruction::BinaryOpsEnd) override; + Value *getStepVector(Value *Val, Value *Start, Value *Step, + Instruction::BinaryOps Opcode = + Instruction::BinaryOpsEnd) override; + Value *getStepVector(Value *Val, Value *Start, const SCEV *StepSCEV, + Instruction::BinaryOps Opcode = + Instruction::BinaryOpsEnd) override; + Value *reverseVector(Value *Vec) override; +}; + +/// \brief Look for a meaningful debug location on the instruction or it's +/// operands. +static Instruction *getDebugLocFromInstOrOperands(Instruction *I) { + if (!I) + return I; + + DebugLoc Empty; + if (I->getDebugLoc() != Empty) + return I; + + for (User::op_iterator OI = I->op_begin(), OE = I->op_end(); OI != OE; ++OI) { + if (Instruction *OpInst = dyn_cast(*OI)) + if (OpInst->getDebugLoc() != Empty) + return OpInst; + } + + return I; +} + +/// \brief Set the debug location in the builder using the debug location in the +/// instruction. +static void setDebugLocFromInst(IRBuilder<> &B, const Value *Ptr) { + if (const Instruction *Inst = dyn_cast_or_null(Ptr)) + B.SetCurrentDebugLocation(Inst->getDebugLoc()); + else + B.SetCurrentDebugLocation(DebugLoc()); +} + +#ifndef NDEBUG +/// \return string containing a file name and a line # for the given loop. +static std::string getDebugLocString(const Loop *L) { + std::string Result; + if (L) { + raw_string_ostream OS(Result); + if (const DebugLoc LoopDbgLoc = L->getStartLoc()) + LoopDbgLoc.print(OS); + else + // Just print the module name. + OS << L->getHeader()->getParent()->getParent()->getModuleIdentifier(); + OS.flush(); + } + return Result; +} +#endif + +/// \brief Propagate known metadata from one instruction to another. 
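VectorTripCount above is documented as TripCount - TripCount % (VF*UF): on the non-predicated path the widened loop executes the largest multiple of VF*UF iterations and the scalar remainder loop handles the rest. A one-line helper plus a worked example:

#include <cstdint>

// Largest multiple of (VF * UF) not exceeding the original trip count.
static uint64_t vectorTripCount(uint64_t TripCount, unsigned VF, unsigned UF) {
  return TripCount - TripCount % (uint64_t(VF) * UF);
}
// e.g. TripCount = 103, VF = 4, UF = 2 gives 96 vectorized iterations,
// leaving 7 for the scalar remainder loop.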
+static void propagateMetadata(Instruction *To, const Instruction *From) { + SmallVector, 4> Metadata; + From->getAllMetadataOtherThanDebugLoc(Metadata); + + for (auto M : Metadata) { + unsigned Kind = M.first; + + // These are safe to transfer (this is safe for TBAA, even when we + // if-convert, because should that metadata have had a control dependency + // on the condition, and thus actually aliased with some other + // non-speculated memory access when the condition was false, this would be + // caught by the runtime overlap checks). + if (Kind != LLVMContext::MD_tbaa && Kind != LLVMContext::MD_alias_scope && + Kind != LLVMContext::MD_noalias && Kind != LLVMContext::MD_fpmath && + Kind != LLVMContext::MD_nontemporal) + continue; + + To->setMetadata(Kind, M.second); + } +} + +void InnerLoopVectorizer::addNewMetadata(Instruction *To, + const Instruction *Orig) { + // If the loop was versioned with memchecks, add the corresponding no-alias + // metadata. + if (LVer && (isa(Orig) || isa(Orig))) + LVer->annotateInstWithNoAlias(To, Orig); +} + +void InnerLoopVectorizer::addMetadata(Instruction *To, + const Instruction *From) { + propagateMetadata(To, From); + addNewMetadata(To, From); +} + +void InnerLoopVectorizer::addMetadata(SmallVectorImpl &To, + const Instruction *From) { + for (Value *V : To) + if (Instruction *I = dyn_cast(V)) + addMetadata(I, From); +} + +/// \brief The group of interleaved loads/stores sharing the same stride and +/// close to each other. +/// +/// Each member in this group has an index starting from 0, and the largest +/// index should be less than interleaved factor, which is equal to the absolute +/// value of the access's stride. +/// +/// E.g. An interleaved load group of factor 4: +/// for (unsigned i = 0; i < 1024; i+=4) { +/// a = A[i]; // Member of index 0 +/// b = A[i+1]; // Member of index 1 +/// d = A[i+3]; // Member of index 3 +/// ... +/// } +/// +/// An interleaved store group of factor 4: +/// for (unsigned i = 0; i < 1024; i+=4) { +/// ... +/// A[i] = a; // Member of index 0 +/// A[i+1] = b; // Member of index 1 +/// A[i+2] = c; // Member of index 2 +/// A[i+3] = d; // Member of index 3 +/// } +/// +/// Note: the interleaved load group could have gaps (missing members), but +/// the interleaved store group doesn't allow gaps. +class InterleaveGroup { +public: + InterleaveGroup(Instruction *Instr, int Stride, unsigned Align) + : Align(Align), SmallestKey(0), LargestKey(0), InsertPos(Instr) { + assert(Align && "The alignment should be non-zero"); + + Factor = std::abs(Stride); + assert(Factor > 1 && "Invalid interleave factor"); + + Reverse = Stride < 0; + Members[0] = Instr; + } + + bool isReverse() const { return Reverse; } + unsigned getFactor() const { return Factor; } + unsigned getAlignment() const { return Align; } + unsigned getNumMembers() const { return Members.size(); } + + /// \brief Try to insert a new member \p Instr with index \p Index and + /// alignment \p NewAlign. The index is related to the leader and it could be + /// negative if it is the new leader. + /// + /// \returns false if the instruction doesn't belong to the group. + bool insertMember(Instruction *Instr, int Index, unsigned NewAlign) { + assert(NewAlign && "The new member's alignment should be non-zero"); + + int Key = Index + SmallestKey; + + // Skip if there is already a member with the same index. + if (Members.count(Key)) + return false; + + if (Key > LargestKey) { + // The largest index is always less than the interleave factor. 
+ if (Index >= static_cast(Factor)) + return false; + + LargestKey = Key; + } else if (Key < SmallestKey) { + // The largest index is always less than the interleave factor. + if (LargestKey - Key >= static_cast(Factor)) + return false; + + SmallestKey = Key; + } + + // It's always safe to select the minimum alignment. + Align = std::min(Align, NewAlign); + Members[Key] = Instr; + return true; + } + + /// \brief Get the member with the given index \p Index + /// + /// \returns nullptr if contains no such member. + Instruction *getMember(unsigned Index) const { + int Key = SmallestKey + Index; + if (!Members.count(Key)) + return nullptr; + + return Members.find(Key)->second; + } + + /// \brief Get the index for the given member. Unlike the key in the member + /// map, the index starts from 0. + unsigned getIndex(Instruction *Instr) const { + for (auto I : Members) + if (I.second == Instr) + return I.first - SmallestKey; + + llvm_unreachable("InterleaveGroup contains no such member"); + } + + Instruction *getInsertPos() const { return InsertPos; } + void setInsertPos(Instruction *Inst) { InsertPos = Inst; } + +private: + unsigned Factor; // Interleave Factor. + bool Reverse; + unsigned Align; + DenseMap Members; + int SmallestKey; + int LargestKey; + + // To avoid breaking dependences, vectorized instructions of an interleave + // group should be inserted at either the first load or the last store in + // program order. + // + // E.g. %even = load i32 // Insert Position + // %add = add i32 %even // Use of %even + // %odd = load i32 + // + // store i32 %even + // %odd = add i32 // Def of %odd + // store i32 %odd // Insert Position + Instruction *InsertPos; +}; + +/// \brief Drive the analysis of interleaved memory accesses in the loop. +/// +/// Use this class to analyze interleaved accesses only when we can vectorize +/// a loop. Otherwise it's meaningless to do analysis as the vectorization +/// on interleaved accesses is unsafe. +/// +/// The analysis collects interleave groups and records the relationships +/// between the member and the group in a map. +class InterleavedAccessInfo { +public: + InterleavedAccessInfo(PredicatedScalarEvolution &PSE, Loop *L, + DominatorTree *DT, LoopInfo *LI) + : PSE(PSE), TheLoop(L), DT(DT), LI(LI), LAI(nullptr), + RequiresScalarEpilogue(false) {} + + ~InterleavedAccessInfo() { + SmallSet DelSet; + // Avoid releasing a pointer twice. + for (auto &I : InterleaveGroupMap) + DelSet.insert(I.second); + for (auto *Ptr : DelSet) + delete Ptr; + } + + /// \brief Analyze the interleaved accesses and collect them in interleave + /// groups. Substitute symbolic strides using \p Strides. + void analyzeInterleaving(const ValueToValueMap &Strides); + + /// \brief Check if \p Instr belongs to any interleave group. + bool isInterleaved(Instruction *Instr) const { + return InterleaveGroupMap.count(Instr); + } + + /// \brief Return the maximum interleave factor of all interleaved groups. + unsigned getMaxInterleaveFactor() const { + unsigned MaxFactor = 1; + for (auto &Entry : InterleaveGroupMap) + MaxFactor = std::max(MaxFactor, Entry.second->getFactor()); + return MaxFactor; + } + + /// \brief Get the interleave group that \p Instr belongs to. + /// + /// \returns nullptr if doesn't have such group. 
+ InterleaveGroup *getInterleaveGroup(Instruction *Instr) const { + if (InterleaveGroupMap.count(Instr)) + return InterleaveGroupMap.find(Instr)->second; + return nullptr; + } + + /// \brief Returns true if an interleaved group that may access memory + /// out-of-bounds requires a scalar epilogue iteration for correctness. + bool requiresScalarEpilogue() const { return RequiresScalarEpilogue; } + + /// \brief Initialize the LoopAccessInfo used for dependence checking. + void setLAI(const LoopAccessInfo *Info) { LAI = Info; } + +private: + /// A wrapper around ScalarEvolution, used to add runtime SCEV checks. + /// Simplifies SCEV expressions in the context of existing SCEV assumptions. + /// The interleaved access analysis can also add new predicates (for example + /// by versioning strides of pointers). + PredicatedScalarEvolution &PSE; + Loop *TheLoop; + DominatorTree *DT; + LoopInfo *LI; + const LoopAccessInfo *LAI; + + /// True if the loop may contain non-reversed interleaved groups with + /// out-of-bounds accesses. We ensure we don't speculatively access memory + /// out-of-bounds by executing at least one scalar epilogue iteration. + bool RequiresScalarEpilogue; + + /// Holds the relationships between the members and the interleave group. + DenseMap InterleaveGroupMap; + + /// Holds dependences among the memory accesses in the loop. It maps a source + /// access to a set of dependent sink accesses. + DenseMap> Dependences; + + /// \brief The descriptor for a strided memory access. + struct StrideDescriptor { + StrideDescriptor(int64_t Stride, const SCEV *Scev, uint64_t Size, + unsigned Align) + : Stride(Stride), Scev(Scev), Size(Size), Align(Align) {} + + StrideDescriptor() = default; + + // The access's stride. It is negative for a reverse access. + int64_t Stride = 0; + const SCEV *Scev = nullptr; // The scalar expression of this access + uint64_t Size = 0; // The size of the memory object. + unsigned Align = 0; // The alignment of this access. + }; + + /// \brief A type for holding instructions and their stride descriptors. + typedef std::pair StrideEntry; + + /// \brief Create a new interleave group with the given instruction \p Instr, + /// stride \p Stride and alignment \p Align. + /// + /// \returns the newly created interleave group. + InterleaveGroup *createInterleaveGroup(Instruction *Instr, int Stride, + unsigned Align) { + assert(!InterleaveGroupMap.count(Instr) && + "Already in an interleaved access group"); + InterleaveGroupMap[Instr] = new InterleaveGroup(Instr, Stride, Align); + return InterleaveGroupMap[Instr]; + } + + /// \brief Release the group and remove all the relationships. + void releaseGroup(InterleaveGroup *Group) { + for (unsigned i = 0; i < Group->getFactor(); i++) + if (Instruction *Member = Group->getMember(i)) + InterleaveGroupMap.erase(Member); + + delete Group; + } + + /// \brief Collect all the accesses with a constant stride in program order. + void collectConstStrideAccesses( + MapVector &AccessStrideInfo, + const ValueToValueMap &Strides); + + /// \brief Returns true if \p Stride is allowed in an interleaved group. + static bool isStrided(int Stride) { + unsigned Factor = std::abs(Stride); + return Factor >= 2 && Factor <= MaxInterleaveGroupFactor; + } + + /// \brief Returns true if \p BB is a predicated block. + bool isPredicated(BasicBlock *BB) const { + return LoopAccessInfo::blockNeedsPredication(BB, TheLoop, DT); + } + + /// \brief Returns true if LoopAccessInfo can be used for dependence queries. 
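+  /// (LoopAccessInfo stops recording dependences once they become too
+  /// numerous; getDependences() then returns null and interleaved-group
+  /// construction conservatively refuses to reorder accesses.)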
+ bool areDependencesValid() const { + return LAI && LAI->getDepChecker().getDependences(); + } + + /// \brief Returns true if memory accesses \p A and \p B can be reordered, if + /// necessary, when constructing interleaved groups. + /// + /// \p A must precede \p B in program order. We return false if reordering is + /// not necessary or is prevented because \p A and \p B may be dependent. + bool canReorderMemAccessesForInterleavedGroups(StrideEntry *A, + StrideEntry *B) const { + + // Code motion for interleaved accesses can potentially hoist strided loads + // and sink strided stores. The code below checks the legality of the + // following two conditions: + // + // 1. Potentially moving a strided load (B) before any store (A) that + // precedes B, or + // + // 2. Potentially moving a strided store (A) after any load or store (B) + // that A precedes. + // + // It's legal to reorder A and B if we know there isn't a dependence from A + // to B. Note that this determination is conservative since some + // dependences could potentially be reordered safely. + // A is potentially the source of a dependence. + auto *Src = A->first; + auto SrcDes = A->second; + + // B is potentially the sink of a dependence. + auto *Sink = B->first; + auto SinkDes = B->second; + + // Code motion for interleaved accesses can't violate WAR dependences. + // Thus, reordering is legal if the source isn't a write. + if (!Src->mayWriteToMemory()) + return true; + + // At least one of the accesses must be strided. + if (!isStrided(SrcDes.Stride) && !isStrided(SinkDes.Stride)) + return true; + + // If dependence information is not available from LoopAccessInfo, + // conservatively assume the instructions can't be reordered. + if (!areDependencesValid()) + return false; + + // If we know there is a dependence from source to sink, assume the + // instructions can't be reordered. Otherwise, reordering is legal. + return !Dependences.count(Src) || !Dependences.lookup(Src).count(Sink); + } + + /// \brief Collect the dependences from LoopAccessInfo. + /// + /// We process the dependences once during the interleaved access analysis to + /// enable constant-time dependence queries. + void collectDependences() { + if (!areDependencesValid()) + return; + auto *Deps = LAI->getDepChecker().getDependences(); + for (auto Dep : *Deps) + Dependences[Dep.getSource(*LAI)].insert(Dep.getDestination(*LAI)); + } +}; + +/// Utility class for getting and setting loop vectorizer hints in the form +/// of loop metadata. +/// This class keeps a number of loop annotations locally (as member variables) +/// and can, upon request, write them back as metadata on the loop. It will +/// initially scan the loop for existing metadata, and will update the local +/// values based on information in the loop. +/// We cannot write all values to metadata, as the mere presence of some info, +/// for example 'force', means a decision has been made. So, we need to be +/// careful NOT to add them if the user hasn't specifically asked so. +class LoopVectorizeHints { + enum HintKind { + HK_WIDTH, + HK_UNROLL, + HK_FORCE, + HK_STYLE + }; + + /// Hint - associates name and validation with the hint value. + struct Hint { + const char *Name; + unsigned Value; // This may have to change for non-numeric values. 
+ HintKind Kind; + + Hint(const char *Name, unsigned Value, HintKind Kind) + : Name(Name), Value(Value), Kind(Kind) {} + + bool validate(unsigned Val) { + switch (Kind) { + case HK_WIDTH: + return isPowerOf2_32(Val) && Val <= VectorizerParams::MaxVectorWidth; + case HK_UNROLL: + return isPowerOf2_32(Val) && Val <= MaxInterleaveFactor; + case HK_FORCE: + return Val <= 1; + case HK_STYLE: + return Val <= 2; + } + return false; + } + }; + + /// Vectorization width. + Hint Width; + /// Vectorization interleave factor. + Hint Interleave; + /// Vectorization forced. + Hint Force; + /// Vectorization style (fixed/scaled vector width). + Hint Style; + + /// Return the loop metadata prefix. + static StringRef Prefix() { return "llvm.loop."; } + + /// True if there is any unsafe math in the loop. + bool PotentiallyUnsafe; + +public: + enum ForceKind { + FK_Undefined = -1, ///< Not selected. + FK_Disabled = 0, ///< Forcing disabled. + FK_Enabled = 1 ///< Forcing enabled. + }; + + enum StyleKind { + SK_Unspecified = 0, + SK_Fixed = 1, ///< Forcing fixed width vectorization. + SK_Scaled = 2 ///< Forcing scalable vectorization. + }; + + LoopVectorizeHints(const Loop *L, bool DisableInterleaving, + OptimizationRemarkEmitter &ORE) + : Width("vectorize.width", VectorizerParams::VectorizationFactor, + HK_WIDTH), + Interleave("interleave.count", DisableInterleaving, HK_UNROLL), + Force("vectorize.enable", FK_Undefined, HK_FORCE), + Style("vectorize.style", SK_Unspecified, HK_STYLE), + PotentiallyUnsafe(false), TheLoop(L), ORE(ORE) { + // Populate values with existing loop metadata. + getHintsFromMetadata(); + + // force-vector-interleave overrides DisableInterleaving. + if (VectorizerParams::isInterleaveForced()) + Interleave.Value = VectorizerParams::VectorizationInterleave; + + LLVM_DEBUG(if (DisableInterleaving && Interleave.Value == 1) dbgs() + << "LV: Interleaving disabled by the pass manager\n"); + } + + /// Mark the loop L as already vectorized by setting the width to 1. + void setAlreadyVectorized() { + Width.Value = Interleave.Value = 1; + Hint Hints[] = {Width, Interleave}; + writeHintsToMetadata(Hints); + } + + bool allowVectorization(Function *F, Loop *L, bool AlwaysVectorize) const { + if (getForce() == LoopVectorizeHints::FK_Disabled) { + LLVM_DEBUG(dbgs() << "LV: Not vectorizing: #pragma vectorize disable.\n"); + emitRemarkWithHints(); + return false; + } + + if (!AlwaysVectorize && getForce() != LoopVectorizeHints::FK_Enabled) { + LLVM_DEBUG(dbgs() << "LV: Not vectorizing: No #pragma vectorize enable.\n"); + emitRemarkWithHints(); + return false; + } + + if (getWidth() == 1 && getInterleave() == 1) { + // FIXME: Add a separate metadata to indicate when the loop has already + // been vectorized instead of setting width and count to 1. + LLVM_DEBUG(dbgs() << "LV: Not vectorizing: Disabled/already vectorized.\n"); + // FIXME: Add interleave.disable metadata. This will allow + // vectorize.disable to be used without disabling the pass and errors + // to differentiate between disabled vectorization and a width of 1. + ORE.emit([&]() { + return OptimizationRemarkAnalysis(vectorizeAnalysisPassName(), + "AllDisabled", L->getStartLoc(), + L->getHeader()) + << "loop not vectorized: vectorization and interleaving are " + "explicitly disabled, or the loop has already been " + "vectorized"; + }); + return false; + } + + return true; + } + + /// Dumps all the hint information. 
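+  /// (The information is emitted through the OptimizationRemarkEmitter as a
+  /// missed-optimization remark, including, when vectorization was forced,
+  /// the Force/Style/Width/Interleave values that were in effect.)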
+ void emitRemarkWithHints() const { + using namespace ore; + if (Force.Value == LoopVectorizeHints::FK_Disabled) + ORE.emit(OptimizationRemarkMissed(LV_NAME, "MissedExplicitlyDisabled", + TheLoop->getStartLoc(), + TheLoop->getHeader()) + << "loop not vectorized: vectorization is explicitly disabled"); + else { + OptimizationRemarkMissed R(LV_NAME, "MissedDetails", + TheLoop->getStartLoc(), TheLoop->getHeader()); + R << "loop not vectorized"; + if (Force.Value == LoopVectorizeHints::FK_Enabled) { + R << " (Force=" << NV("Force", true); + if (Style.Value == LoopVectorizeHints::SK_Fixed) + R << ", Style=fixed"; + else if (Style.Value == LoopVectorizeHints::SK_Scaled) + R << ", Style=scaled"; + if (Width.Value != 0) + R << ", Vector Width=" << NV("VectorWidth", Width.Value); + if (Interleave.Value != 0) + R << ", Interleave Count=" << NV("InterleaveCount", Interleave.Value); + R << ")"; + } + ORE.emit(R); + } + } + + unsigned getWidth() const { return Width.Value; } + unsigned getInterleave() const { return Interleave.Value; } + enum ForceKind getForce() const { return (ForceKind)Force.Value; } + unsigned getStyle() const { return Style.Value; } + /// \brief If hints are provided that force vectorization, use the AlwaysPrint + /// pass name to force the frontend to print the diagnostic. + const char *vectorizeAnalysisPassName() const { + if (getWidth() == 1) + return LV_NAME; + if (getForce() == LoopVectorizeHints::FK_Disabled) + return LV_NAME; + if (getForce() == LoopVectorizeHints::FK_Undefined && getWidth() == 0) + return LV_NAME; + return OptimizationRemarkAnalysis::AlwaysPrint; + } + + bool allowReordering() const { + // When enabling loop hints are provided we allow the vectorizer to change + // the order of operations that is given by the scalar loop. This is not + // enabled by default because can be unsafe or inefficient. For example, + // reordering floating-point operations will change the way round-off + // error accumulates in the loop. + return getForce() == LoopVectorizeHints::FK_Enabled || getWidth() > 1; + } + + bool isPotentiallyUnsafe() const { + // Avoid FP vectorization if the target is unsure about proper support. + // This may be related to the SIMD unit in the target not handling + // IEEE 754 FP ops properly, or bad single-to-double promotions. + // Otherwise, a sequence of vectorized loops, even without reduction, + // could lead to different end results on the destination vectors. + return getForce() != LoopVectorizeHints::FK_Enabled && PotentiallyUnsafe; + } + + void setPotentiallyUnsafe() { PotentiallyUnsafe = true; } + +private: + /// Find hints specified in the loop metadata and update local values. + void getHintsFromMetadata() { + MDNode *LoopID = TheLoop->getLoopID(); + if (!LoopID) + return; + + // First operand should refer to the loop id itself. + assert(LoopID->getNumOperands() > 0 && "requires at least one operand"); + assert(LoopID->getOperand(0) == LoopID && "invalid loop id"); + + for (unsigned i = 1, ie = LoopID->getNumOperands(); i < ie; ++i) { + const MDString *S = nullptr; + SmallVector Args; + + // The expected hint is either a MDString or a MDNode with the first + // operand a MDString. 
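+      // For illustration (example IR, not taken from this patch), a loop-ID
+      // node carrying hints typically looks like:
+      //   !0 = distinct !{!0, !1, !2}
+      //   !1 = !{!"llvm.loop.vectorize.width", i32 4}
+      //   !2 = !{!"llvm.loop.vectorize.enable", i1 true}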
+ if (const MDNode *MD = dyn_cast(LoopID->getOperand(i))) { + if (!MD || MD->getNumOperands() == 0) + continue; + S = dyn_cast(MD->getOperand(0)); + for (unsigned i = 1, ie = MD->getNumOperands(); i < ie; ++i) + Args.push_back(MD->getOperand(i)); + } else { + S = dyn_cast(LoopID->getOperand(i)); + assert(Args.size() == 0 && "too many arguments for MDString"); + } + + if (!S) + continue; + + // Check if the hint starts with the loop metadata prefix. + StringRef Name = S->getString(); + if (Args.size() == 1) + setHint(Name, Args[0]); + } + } + + /// Checks string hint with one operand and set value if valid. + void setHint(StringRef Name, Metadata *Arg) { + if (!Name.startswith(Prefix())) + return; + Name = Name.substr(Prefix().size(), StringRef::npos); + + const ConstantInt *C = mdconst::dyn_extract(Arg); + if (!C) + return; + unsigned Val = C->getZExtValue(); + + Hint *Hints[] = {&Width, &Style, &Interleave, &Force}; + for (auto H : Hints) { + if (Name == H->Name) { + if (H->validate(Val)) + H->Value = Val; + else + LLVM_DEBUG(dbgs() << "LV: ignoring invalid hint '" << Name << "'\n"); + break; + } + } + } + + /// Create a new hint from name / value pair. + MDNode *createHintMetadata(StringRef Name, unsigned V) const { + LLVMContext &Context = TheLoop->getHeader()->getContext(); + Metadata *MDs[] = {MDString::get(Context, Name), + ConstantAsMetadata::get( + ConstantInt::get(Type::getInt32Ty(Context), V))}; + return MDNode::get(Context, MDs); + } + + /// Matches metadata with hint name. + bool matchesHintMetadataName(MDNode *Node, ArrayRef HintTypes) { + MDString *Name = dyn_cast(Node->getOperand(0)); + if (!Name) + return false; + + for (auto H : HintTypes) + if (Name->getString().endswith(H.Name)) + return true; + return false; + } + + /// Sets current hints into loop metadata, keeping other values intact. + void writeHintsToMetadata(ArrayRef HintTypes) { + if (HintTypes.size() == 0) + return; + + // Reserve the first element to LoopID (see below). + SmallVector MDs(1); + // If the loop already has metadata, then ignore the existing operands. + MDNode *LoopID = TheLoop->getLoopID(); + if (LoopID) { + for (unsigned i = 1, ie = LoopID->getNumOperands(); i < ie; ++i) { + MDNode *Node = cast(LoopID->getOperand(i)); + // If node in update list, ignore old value. + if (!matchesHintMetadataName(Node, HintTypes)) + MDs.push_back(Node); + } + } + + // Now, add the missing hints. + for (auto H : HintTypes) + MDs.push_back(createHintMetadata(Twine(Prefix(), H.Name).str(), H.Value)); + + // Replace current metadata node with new one. + LLVMContext &Context = TheLoop->getHeader()->getContext(); + MDNode *NewLoopID = MDNode::get(Context, MDs); + // Set operand 0 to refer to the loop id itself. + NewLoopID->replaceOperandWith(0, NewLoopID); + + TheLoop->setLoopID(NewLoopID); + } + + /// The loop these hints belong to. + const Loop *TheLoop; + + /// Interface to emit optimization remarks. 
+ OptimizationRemarkEmitter &ORE; +}; + +static void emitMissedWarning(Function *F, Loop *L, + const LoopVectorizeHints &LH, + OptimizationRemarkEmitter *ORE) { + LH.emitRemarkWithHints(); + + if (LH.getForce() == LoopVectorizeHints::FK_Enabled) { + if (LH.getWidth() != 1) + ORE->emit(DiagnosticInfoOptimizationFailure( + DEBUG_TYPE, "FailedRequestedVectorization", + {L->getLocRange().getStart()}, + L->getHeader()) + << "loop not vectorized: " + << "failed explicitly specified loop vectorization"); + else if (LH.getInterleave() != 1) + ORE->emit(DiagnosticInfoOptimizationFailure( + DEBUG_TYPE, "FailedRequestedInterleaving", + {L->getLocRange().getStart()}, + L->getHeader()) + << "loop not interleaved: " + << "failed explicitly specified loop interleaving"); + } +} + +/// LoopVectorizationLegality checks if it is legal to vectorize a loop, and +/// to what vectorization factor. +/// This class does not look at the profitability of vectorization, only the +/// legality. This class has two main kinds of checks: +/// * Memory checks - The code in canVectorizeMemory checks if vectorization +/// will change the order of memory accesses in a way that will change the +/// correctness of the program. +/// * Scalars checks - The code in canVectorizeInstrs and canVectorizeMemory +/// checks for a number of different conditions, such as the availability of a +/// single induction variable, that all types are supported and vectorize-able, +/// etc. This code reflects the capabilities of InnerLoopVectorizer. +/// This class is also used by InnerLoopVectorizer for identifying +/// induction variable and the different reduction variables. +class LoopVectorizationLegality { +public: + LoopVectorizationLegality( + Loop *L, PredicatedScalarEvolution &PSE, DominatorTree *DT, + TargetLibraryInfo *TLI, AliasAnalysis *AA, Function *F, + const TargetTransformInfo *TTI, + std::function *GetLAA, LoopInfo *LI, + OptimizationRemarkEmitter *ORE, LoopVectorizationRequirements *R, + LoopVectorizeHints *H) + : NumPredStores(0), TheLoop(L), PSE(PSE), TLI(TLI), TTI(TTI), DT(DT), + GetLAA(GetLAA), LAI(nullptr), ORE(ORE), InterleaveInfo(PSE, L, DT, LI), + PrimaryInduction(nullptr), WidestIndTy(nullptr), HasFunNoNaNAttr(false), + Requirements(R), Hints(H) {} + + /// Returns true if the function has an attribute saying that + /// we can assume the absence of NaNs. + bool hasNoNaNAttr(void) const { return HasFunNoNaNAttr; } + + /// ReductionList contains the reduction descriptors for all + /// of the reductions that were found in the loop. + typedef DenseMap ReductionList; + + /// InductionList saves induction variables and maps them to the + /// induction descriptor. + typedef MapVector InductionList; + + /// RecurrenceSet contains the phi nodes that are recurrences other than + /// inductions and reductions. + typedef SmallPtrSet RecurrenceSet; + + /// Returns true if it is legal to vectorize this loop. + /// This does not mean that it is profitable to vectorize this + /// loop, only that it is legal to do so. + bool canVectorize(); + + /// Returns the primary induction variable. + PHINode *getPrimaryInduction() { return PrimaryInduction; } + + /// Returns the reduction variables found in the loop. + ReductionList *getReductionVars() { return &Reductions; } + + /// Checks if all reduction can be performed using ordered intrinsics in + /// the loop body (e.g. using '@llvm.aarch64.sve.adda.' intrinsic). 
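+  /// Ordered reductions accumulate in the original (sequential) order, so a
+  /// loop whose reductions are all ordered can still be vectorized even when
+  /// floating-point reassociation has not been allowed.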
+ bool allReductionsAreStrict() const { + for (auto RedP : Reductions) + if (!RedP.second.isOrdered()) + return false; + + return true; + } + + /// Returns the induction variables found in the loop. + InductionList *getInductionVars() { return &Inductions; } + + /// Return the first-order recurrences found in the loop. + RecurrenceSet *getFirstOrderRecurrences() { return &FirstOrderRecurrences; } + + /// Return the set of instructions to sink to handle first-order recurrences. + DenseMap &getSinkAfter() { return SinkAfter; } + + /// Returns the widest induction type. + Type *getWidestInductionType() { return WidestIndTy; } + + /// Returns True if V is an induction variable in this loop. + bool isInductionVariable(const Value *V); + + /// Returns True if PN is a reduction variable in this loop. + bool isReductionVariable(PHINode *PN) { return Reductions.count(PN); } + + /// Returns True if Phi is a first-order recurrence in this loop. + bool isFirstOrderRecurrence(const PHINode *Phi); + + /// Return true if the block BB needs to be predicated in order for the loop + /// to be vectorized. + bool blockNeedsPredication(BasicBlock *BB); + + /// Check if this pointer is consecutive when vectorizing. This happens + /// when the last index of the GEP is the induction variable, or that the + /// pointer itself is an induction variable. + /// This check allows us to vectorize A[idx] into a wide load/store. + /// Returns: + /// 0 - Stride is unknown or non-consecutive. + /// 1 - Address is consecutive. + /// -1 - Address is consecutive, and decreasing. + int isConsecutivePtr(Value *Ptr); + + /// Returns true if the value V is uniform within the loop. + bool isUniform(Value *V); + + /// Returns true if this instruction will remain scalar after vectorization. + bool isUniformAfterVectorization(Instruction *I) { return Uniforms.count(I); } + + /// Returns the information that we collected about runtime memory check. + const RuntimePointerChecking *getRuntimePointerChecking() const { + return LAI->getRuntimePointerChecking(); + } + + const LoopAccessInfo *getLAI() const { return LAI; } + + /// \brief Check if \p Instr belongs to any interleaved access group. + bool isAccessInterleaved(Instruction *Instr) { + return InterleaveInfo.isInterleaved(Instr); + } + + /// \brief Return the maximum interleave factor of all interleaved groups. + unsigned getMaxInterleaveFactor() const { + return InterleaveInfo.getMaxInterleaveFactor(); + } + + /// \brief Get the interleaved access group that \p Instr belongs to. + const InterleaveGroup *getInterleavedAccessGroup(Instruction *Instr) { + return InterleaveInfo.getInterleaveGroup(Instr); + } + + /// \brief Returns true if an interleaved group requires a scalar iteration + /// to handle accesses with gaps. + bool requiresScalarEpilogue() const { + return InterleaveInfo.requiresScalarEpilogue(); + } + + unsigned getMaxSafeDepDistBytes() { return LAI->getMaxSafeDepDistBytes(); } + + bool hasStride(Value *V) { return StrideSet.count(V); } + bool mustCheckStrides() { return !StrideSet.empty(); } + SmallPtrSet::iterator strides_begin() { + return StrideSet.begin(); + } + SmallPtrSet::iterator strides_end() { return StrideSet.end(); } + + /// Returns true if the target machine supports masked store operation + /// for the given \p DataType and kind of access to \p Ptr. 
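+  /// A masked store is only considered legal here when the pointer is
+  /// consecutive in the loop (see isConsecutivePtr) and the target reports
+  /// support for the data type.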
+ bool isLegalMaskedStore(Type *DataType, Value *Ptr) { + return isConsecutivePtr(Ptr) && TTI->isLegalMaskedStore(DataType); + } + /// Returns true if the target machine supports masked load operation + /// for the given \p DataType and kind of access to \p Ptr. + bool isLegalMaskedLoad(Type *DataType, Value *Ptr) { + return isConsecutivePtr(Ptr) && TTI->isLegalMaskedLoad(DataType); + } + /// Returns true if the target machine supports masked scatter operation + /// for the given \p DataType. + bool isLegalMaskedScatter(Type *DataType) { + return TTI->isLegalMaskedScatter(DataType); + } + /// Returns true if the target machine supports masked gather operation + /// for the given \p DataType. + bool isLegalMaskedGather(Type *DataType) { + return TTI->isLegalMaskedGather(DataType); + } + + /// Returns true if the target machine can represent \p V as a masked gather + /// or scatter operation. + bool isLegalGatherOrScatter(Value *V) { + auto *LI = dyn_cast(V); + auto *SI = dyn_cast(V); + if (!LI && !SI) + return false; + auto *Ptr = getPointerOperand(V); + auto *Ty = cast(Ptr->getType())->getElementType(); + return (LI && isLegalMaskedGather(Ty)) || (SI && isLegalMaskedScatter(Ty)); + } + + bool hasMemSet() + { + for (BasicBlock *BB : TheLoop->getBlocks()) { + for (BasicBlock::iterator it = BB->begin(), e = BB->end(); it != e; ++it) { + if (it->getOpcode() == Instruction::Call) + if (isa(it)) + return true; + } + } + return false; + } + + /// Returns true if vector representation of the instruction \p I + /// requires mask. + bool isMaskRequired(const Instruction *I) { return (MaskedOp.count(I) != 0); } + /// Returns true if the loop requires masked operations for vectorisation to + /// be legal. + bool hasMaskedOperations() { return MaskedOp.begin() != MaskedOp.end(); } + unsigned getNumStores() const { return LAI->getNumStores(); } + unsigned getNumLoads() const { return LAI->getNumLoads(); } + unsigned getNumPredStores() const { return NumPredStores; } + +private: + /// Elaborate on the summary report from LoopAccessAnalysis + /// with more remarks based on the failure reasons. + void elaborateMemoryReport(); + + /// Check if a single basic block loop is vectorizable. + /// At this point we know that this is a loop with a constant trip count + /// and we only need to check individual instructions. + bool canVectorizeInstrs(); + + /// When we vectorize loops we may change the order in which + /// we read and write from memory. This method checks if it is + /// legal to vectorize the code, considering only memory constrains. + /// Returns true if the loop is vectorizable + bool canVectorizeMemory(); + + /// Return true if we can vectorize this loop using the IF-conversion + /// transformation. + bool canVectorizeWithIfConvert(); + + /// Collect the variables that need to stay uniform after vectorization. + void collectLoopUniforms(); + + /// Return true if all of the instructions in the block can be speculatively + /// executed. \p SafePtrs is a list of addresses that are known to be legal + /// and we know that we can read from them without segfault. + bool blockCanBePredicated(BasicBlock *BB, SmallPtrSetImpl &SafePtrs); + + /// \brief Collect memory access with loop invariant strides. + /// + /// Looks for accesses like "a[i * StrideA]" where "StrideA" is loop + /// invariant. + void collectStridedAccess(Value *LoadOrStoreInst); + + /// Updates the vectorization state by adding \p Phi to the inductions list. 
+ /// This can set \p Phi as the main induction of the loop if \p Phi is a + /// better choice for the main induction than the existing one. + void addInductionPhi(PHINode *Phi, const InductionDescriptor &ID, + SmallPtrSetImpl &AllowedExit); + + /// Create an analysis remark that explains why vectorization failed + /// + /// \p RemarkName is the identifier for the remark. If \p I is passed it is + /// an instruction that prevents vectorization. Otherwise the loop is used + /// for the location of the remark. \return the remark object that can be + /// streamed to. + OptimizationRemarkAnalysis + createMissedAnalysis(StringRef RemarkName, Instruction *I = nullptr) const { + return ::createMissedAnalysis(Hints->vectorizeAnalysisPassName(), + RemarkName, TheLoop, I); + } + + unsigned NumPredStores; + + /// The loop that we evaluate. + Loop *TheLoop; + /// A wrapper around ScalarEvolution used to add runtime SCEV checks. + /// Applies dynamic knowledge to simplify SCEV expressions in the context + /// of existing SCEV assumptions. The analysis will also add a minimal set + /// of new predicates if this is required to enable vectorization and + /// unrolling. + PredicatedScalarEvolution &PSE; + /// Target Library Info. + TargetLibraryInfo *TLI; + /// Target Transform Info + const TargetTransformInfo *TTI; + /// Dominator Tree. + DominatorTree *DT; + // LoopAccess analysis. + std::function *GetLAA; + // And the loop-accesses info corresponding to this loop. This pointer is + // null until canVectorizeMemory sets it up. + const LoopAccessInfo *LAI; + /// Interface to emit optimization remarks. + OptimizationRemarkEmitter *ORE; + + /// The interleave access information contains groups of interleaved accesses + /// with the same stride and close to each other. + InterleavedAccessInfo InterleaveInfo; + + // --- vectorization state --- // + + /// Holds the primary induction variable. This is the counter of the + /// loop. + PHINode *PrimaryInduction; + /// Holds the reduction variables. + ReductionList Reductions; + /// Holds all of the induction variables that we found in the loop. + /// Notice that inductions don't need to start at zero and that induction + /// variables can be pointers. + InductionList Inductions; + /// Holds the phi nodes that are first-order recurrences. + RecurrenceSet FirstOrderRecurrences; + /// Holds instructions that need to sink past other instructions to handle + /// first-order recurrences. + DenseMap SinkAfter; + /// Holds the widest induction type encountered. + Type *WidestIndTy; + + /// Allowed outside users. This holds the reduction + /// vars which can be accessed from outside the loop. + SmallPtrSet AllowedExit; + /// This set holds the variables which are known to be uniform after + /// vectorization. + SmallPtrSet Uniforms; + + /// Can we assume the absence of NaNs. + bool HasFunNoNaNAttr; + + /// Vectorization requirements that will go through late-evaluation. + LoopVectorizationRequirements *Requirements; + + /// Used to emit an analysis of any legality issues. + LoopVectorizeHints *Hints; + + ValueToValueMap Strides; + SmallPtrSet StrideSet; + + /// While vectorizing these instructions we have to generate a + /// call to the appropriate masked intrinsic + SmallPtrSet MaskedOp; +}; + +/// LoopVectorizationCostModel - estimates the expected speedups due to +/// vectorization. +/// In many cases vectorization is not profitable. This can happen because of +/// a number of reasons. 
In this class we mainly attempt to predict the +/// expected speedup/slowdowns due to the supported instruction set. We use the +/// TargetTransformInfo to query the different backends for the cost of +/// different operations. +class LoopVectorizationCostModel { +public: + LoopVectorizationCostModel(Loop *L, PredicatedScalarEvolution &PSE, + LoopInfo *LI, LoopVectorizationLegality *Legal, + const TargetTransformInfo &TTI, + const TargetLibraryInfo *TLI, DemandedBits *DB, + AssumptionCache *AC, + OptimizationRemarkEmitter *ORE, const Function *F, + const LoopVectorizeHints *Hints) + : TheLoop(L), PSE(PSE), LI(LI), Legal(Legal), TTI(TTI), TLI(TLI), DB(DB), + AC(AC), ORE(ORE), TheFunction(F), Hints(Hints) {} + + /// \return The most profitable vectorization factor and the cost of that VF. + /// This method checks every power of two up to VF. If UserVF is not ZERO + /// then this vectorization factor will be selected if vectorization is + /// possible. + VectorizationFactor selectVectorizationFactor(bool OptForSize); + + /// \return The size (in bits) of the smallest and widest types in the code + /// that needs to be vectorized. We ignore values that remain scalar such as + /// 64 bit loop indices. + std::pair getSmallestAndWidestTypes(); + + /// \return The desired interleave count. + /// If interleave count has been specified by metadata it will be returned. + /// Otherwise, the interleave count is computed and returned. VF and LoopCost + /// are the selected vectorization factor and the cost of the selected VF. + unsigned selectInterleaveCount(bool OptForSize, VectorizationFactor VF, + unsigned LoopCost); + + /// \return The most profitable unroll factor. + /// This method finds the best unroll-factor based on register pressure and + /// other parameters. VF and LoopCost are the selected vectorization factor + /// and the cost of the selected VF. + unsigned computeInterleaveCount(bool OptForSize, VectorizationFactor VF, + unsigned LoopCost); + + /// \brief A struct that represents some properties of the register usage + /// of a loop. + struct RegisterUsage { + /// Holds the number of loop invariant values that are used in the loop. + unsigned LoopInvariantRegs; + /// Holds the maximum number of concurrent live intervals in the loop. + unsigned MaxLocalUsers; + /// Holds the number of instructions in the loop. + unsigned NumInstructions; + }; + + /// \return Returns information about the register usages of the loop for the + /// given vectorization factors. + SmallVector calculateRegisterUsage(ArrayRef VFs); + + /// Collect values we want to ignore in the cost model. + void collectValuesToIgnore(); + +private: + /// The vectorization cost is a combination of the cost itself and a boolean + /// indicating whether any of the contributing operations will actually + /// operate on vector values after type legalization in the backend. If this + /// latter value is false, then all operations will be scalarized + /// (i.e. no vectorization has actually taken place). + typedef std::pair VectorizationCostTy; + + /// Returns the expected execution cost. The unit of the cost does + /// not matter because we use the 'cost' units to compare different + /// vector widths. The cost that is returned is *not* normalized by + /// the factor width. + VectorizationCostTy expectedCost(VectorizationFactor VF); + + /// Returns the execution time cost of an instruction for a given vector + /// width. Vector width of one means scalar. 
+ VectorizationCostTy getInstructionCost(Instruction *I, + VectorizationFactor VF); + + /// The cost-computation logic from getInstructionCost which provides + /// the vector type as an output parameter. + unsigned getInstructionCost(Instruction *I, VectorizationFactor VF, + Type *&VectorTy); + + /// Returns whether the instruction is a load or store and will be a emitted + /// as a vector operation. + bool isConsecutiveLoadOrStore(Instruction *I); + + /// Create an analysis remark that explains why vectorization failed + /// + /// \p RemarkName is the identifier for the remark. \return the remark object + /// that can be streamed to. + OptimizationRemarkAnalysis createMissedAnalysis(StringRef RemarkName) { + return ::createMissedAnalysis(Hints->vectorizeAnalysisPassName(), + RemarkName, TheLoop); + } + +public: + /// Map of scalar integer values to the smallest bitwidth they can be legally + /// represented as. The vector equivalents of these values should be truncated + /// to this type. + MapVector MinBWs; + + /// The loop that we evaluate. + Loop *TheLoop; + /// Predicated scalar evolution analysis. + PredicatedScalarEvolution &PSE; + /// Loop Info analysis. + LoopInfo *LI; + /// Vectorization legality. + LoopVectorizationLegality *Legal; + /// Vector target information. + const TargetTransformInfo &TTI; + /// Target Library Info. + const TargetLibraryInfo *TLI; + /// Demanded bits analysis. + DemandedBits *DB; + /// Assumption cache. + AssumptionCache *AC; + /// Interface to emit optimization remarks. + OptimizationRemarkEmitter *ORE; + + const Function *TheFunction; + /// Loop Vectorize Hint. + const LoopVectorizeHints *Hints; + /// Values to ignore in the cost model. + SmallPtrSet ValuesToIgnore; + /// Values to ignore in the cost model when VF > 1. + SmallPtrSet VecValuesToIgnore; +}; + +/// \brief This holds vectorization requirements that must be verified late in +/// the process. The requirements are set by legalize and costmodel. Once +/// vectorization has been determined to be possible and profitable the +/// requirements can be verified by looking for metadata or compiler options. +/// For example, some loops require FP commutativity which is only allowed if +/// vectorization is explicitly specified or if the fast-math compiler option +/// has been provided. +/// Late evaluation of these requirements allows helpful diagnostics to be +/// composed that tells the user what need to be done to vectorize the loop. For +/// example, by specifying #pragma clang loop vectorize or -ffast-math. Late +/// evaluation should be used only when diagnostics can generated that can be +/// followed by a non-expert user. +class LoopVectorizationRequirements { +public: + LoopVectorizationRequirements(OptimizationRemarkEmitter &ORE) + : NumRuntimePointerChecks(0), UnsafeAlgebraInst(nullptr), ORE(ORE) {} + + void addUnsafeAlgebraInst(Instruction *I) { + // First unsafe algebra instruction. 
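+    // Only the first offending instruction is kept; it is reported from
+    // doesNotMeet() if reordering turns out not to be allowed.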
+ if (!UnsafeAlgebraInst) + UnsafeAlgebraInst = I; + } + + void addRuntimePointerChecks(unsigned Num) { NumRuntimePointerChecks = Num; } + + bool doesNotMeet(Function *F, Loop *L, const LoopVectorizeHints &Hints, + const LoopVectorizationLegality &LVL) { + const char *PassName = Hints.vectorizeAnalysisPassName(); + bool Failed = false; + if (UnsafeAlgebraInst && !Hints.allowReordering()) { + if (LVL.allReductionsAreStrict()) { + LLVM_DEBUG(dbgs() << "LV: Vectorization possible with ordered reduction\n"); + } else { + ORE.emit( + OptimizationRemarkAnalysisFPCommute(PassName, "CantReorderFPOps", + UnsafeAlgebraInst->getDebugLoc(), + UnsafeAlgebraInst->getParent()) + << "loop not vectorized: cannot prove it is safe to reorder " + "floating-point operations"); + Failed = true; + } + } + + // Test if runtime memcheck thresholds are exceeded. + bool PragmaThresholdReached = + NumRuntimePointerChecks > PragmaVectorizeMemoryCheckThreshold; + bool ThresholdReached = + NumRuntimePointerChecks > VectorizerParams::RuntimeMemoryCheckThreshold; + if ((ThresholdReached && !Hints.allowReordering()) || + PragmaThresholdReached) { + ORE.emit(OptimizationRemarkAnalysisAliasing(PassName, "CantReorderMemOps", + L->getStartLoc(), + L->getHeader()) + << "loop not vectorized: cannot prove it is safe to reorder " + "memory operations"); + LLVM_DEBUG(dbgs() << "LV: Too many memory checks needed.\n"); + Failed = true; + } + + return Failed; + } + +private: + unsigned NumRuntimePointerChecks; + Instruction *UnsafeAlgebraInst; + + /// Interface to emit optimization remarks. + OptimizationRemarkEmitter &ORE; +}; + +static void addAcyclicInnerLoop(Loop &L, SmallVectorImpl &V) { + if (L.empty()) { + if (!hasCyclesInLoopBody(L)) + V.push_back(&L); + return; + } + for (Loop *InnerL : L) + addAcyclicInnerLoop(*InnerL, V); +} + +/// The SVELoopVectorize Pass. +struct SVELoopVectorize : public FunctionPass { + /// Pass identification, replacement for typeid + static char ID; + + explicit SVELoopVectorize(bool NoUnrolling = false, bool AlwaysVectorize = true) + : FunctionPass(ID), DisableUnrolling(NoUnrolling), + AlwaysVectorize(AlwaysVectorize) { + initializeSVELoopVectorizePass(*PassRegistry::getPassRegistry()); + } + + ScalarEvolution *SE; + LoopInfo *LI; + TargetTransformInfo *TTI; + DominatorTree *DT; + BlockFrequencyInfo *BFI; + TargetLibraryInfo *TLI; + DemandedBits *DB; + AliasAnalysis *AA; + AssumptionCache *AC; + LoopAccessLegacyAnalysis *LAA; + bool DisableUnrolling; + bool AlwaysVectorize; + OptimizationRemarkEmitter *ORE; + + BlockFrequency ColdEntryFreq; + + bool runOnFunction(Function &F) override { + if (skipFunction(F)) + return false; + + SE = &getAnalysis().getSE(); + LI = &getAnalysis().getLoopInfo(); + TTI = &getAnalysis().getTTI(F); + DT = &getAnalysis().getDomTree(); + BFI = &getAnalysis().getBFI(); + auto *TLIP = getAnalysisIfAvailable(); + TLI = TLIP ? &TLIP->getTLI() : nullptr; + AA = &getAnalysis().getAAResults(); + AC = &getAnalysis().getAssumptionCache(F); + LAA = &getAnalysis(); + DB = &getAnalysis().getDemandedBits(); + ORE = &getAnalysis().getORE(); + + // Compute some weights outside of the loop over the loops. Compute this + // using a BranchProbability to re-use its scaling math. + const BranchProbability ColdProb(1, 5); // 20% + ColdEntryFreq = BlockFrequency(BFI->getEntryFreq()) * ColdProb; + + // Don't attempt if + // 1. the target claims to have no vector registers, and + // 2. interleaving won't help ILP. 
+ // + // The second condition is necessary because, even if the target has no + // vector registers, loop vectorization may still enable scalar + // interleaving. + if (!TTI->getNumberOfRegisters(true) && TTI->getMaxInterleaveFactor(1) < 2) + return false; + + bool Changed = false; + + // The vectorizer requires loops to be in simplified form. + // Since simplification may add new inner loops, it has to run before the + // legality and profitability checks. This means running the loop vectorizer + // will simplify all loops, regardless of whether anything end up being + // vectorized. + for (auto &L : *LI) + Changed |= simplifyLoop(L, DT, LI, SE, AC, false /* PreserveLCSSA */); + + // Build up a worklist of inner-loops to vectorize. This is necessary as + // the act of vectorizing or partially unrolling a loop creates new loops + // and can invalidate iterators across the loops. + SmallVector Worklist; + + for (Loop *L : *LI) + addAcyclicInnerLoop(*L, Worklist); + + LoopsAnalyzed += Worklist.size(); + + // Now walk the identified inner loops. + while (!Worklist.empty()) { + Loop *L = Worklist.pop_back_val(); + + // For the inner loops we actually process, form LCSSA to simplify the + // transform. + Changed |= formLCSSARecursively(*L, *DT, LI, SE); + + Changed |= processLoop(L); + } + + // Process each loop nest in the function. + return Changed; + } + + static void AddRuntimeUnrollDisableMetaData(Loop *L) { + SmallVector MDs; + // Reserve first location for self reference to the LoopID metadata node. + MDs.push_back(nullptr); + bool IsUnrollMetadata = false; + MDNode *LoopID = L->getLoopID(); + if (LoopID) { + // First find existing loop unrolling disable metadata. + for (unsigned i = 1, ie = LoopID->getNumOperands(); i < ie; ++i) { + MDNode *MD = dyn_cast(LoopID->getOperand(i)); + if (MD) { + const MDString *S = dyn_cast(MD->getOperand(0)); + IsUnrollMetadata = + S && S->getString().startswith("llvm.loop.unroll.disable"); + } + MDs.push_back(LoopID->getOperand(i)); + } + } + + if (!IsUnrollMetadata) { + // Add runtime unroll disable metadata. + LLVMContext &Context = L->getHeader()->getContext(); + SmallVector DisableOperands; + DisableOperands.push_back( + MDString::get(Context, "llvm.loop.unroll.runtime.disable")); + MDNode *DisableNode = MDNode::get(Context, DisableOperands); + MDs.push_back(DisableNode); + MDNode *NewLoopID = MDNode::get(Context, MDs); + // Set operand 0 to refer to the loop id itself. + NewLoopID->replaceOperandWith(0, NewLoopID); + L->setLoopID(NewLoopID); + } + } + + bool processLoop(Loop *L) { + assert(L->empty() && "Only process inner loops."); + +#ifndef NDEBUG + const std::string DebugLocStr = getDebugLocString(L); +#endif /* NDEBUG */ + + LLVM_DEBUG(dbgs() << "\nLV: Checking a loop in \"" + << L->getHeader()->getParent()->getName() << "\" from " + << DebugLocStr << "\n"); + + LoopVectorizeHints Hints(L, DisableUnrolling, *ORE); + + LLVM_DEBUG(dbgs() << "LV: Loop hints:" + << " force=" + << (Hints.getForce() == LoopVectorizeHints::FK_Disabled + ? "disabled" + : (Hints.getForce() == LoopVectorizeHints::FK_Enabled + ? "enabled" + : "?")) + << " width=" << Hints.getWidth() + << " style=" + << (Hints.getStyle() == LoopVectorizeHints::SK_Fixed + ? "fixed" + : (Hints.getStyle() == LoopVectorizeHints::SK_Scaled + ? 
"scaled" + : "default")) + << " unroll=" << Hints.getInterleave() << "\n"); + + // Function containing loop + Function *F = L->getHeader()->getParent(); + + // Looking at the diagnostic output is the only way to determine if a loop + // was vectorized (other than looking at the IR or machine code), so it + // is important to generate an optimization remark for each loop. Most of + // these messages are generated by emitOptimizationRemarkAnalysis. Remarks + // generated by emitOptimizationRemark and emitOptimizationRemarkMissed are + // less verbose reporting vectorized loops and unvectorized loops that may + // benefit from vectorization, respectively. + + if (!Hints.allowVectorization(F, L, AlwaysVectorize)) { + LLVM_DEBUG(dbgs() << "LV: Loop hints prevent vectorization.\n"); + return false; + } + + // Check the loop for a trip count threshold: + // do not vectorize loops with a tiny trip count. + const unsigned TC = SE->getSmallConstantTripCount(L); + if (TC > 0u && TC < TinyTripCountVectorThreshold) { + LLVM_DEBUG(dbgs() << "LV: Found a loop with a very small trip count. " + << "This loop is not worth vectorizing."); + if (Hints.getForce() == LoopVectorizeHints::FK_Enabled) + LLVM_DEBUG(dbgs() << " But vectorizing was explicitly forced.\n"); + else { + LLVM_DEBUG(dbgs() << "\n"); + ORE->emit(createMissedAnalysis(Hints.vectorizeAnalysisPassName(), + "NotBeneficial", L) + << "vectorization is not beneficial " + "and is not explicitly forced"); + ORE->emit(createMissedAnalysis(Hints.vectorizeAnalysisPassName(), + "NotBeneficial", L) + << "to locally force vectorization, prefix loop with " + "\"#pragma clang loop vectorize (enable)\""); + ORE->emit(createMissedAnalysis(Hints.vectorizeAnalysisPassName(), + "NotBeneficial", L) + << "to globally force vectorization, compile with " + "\"-mllvm -vectorizer-min-trip-count " + << std::to_string(TC) << "\""); + return false; + } + } + + PredicatedScalarEvolution PSE(*SE, *L); + + std::function GetLAA = + [&](Loop &L) -> const LoopAccessInfo & { return LAA->getInfo(&L); }; + + // Check if it is legal to vectorize the loop. + LoopVectorizationRequirements Requirements(*ORE); + LoopVectorizationLegality LVL(L, PSE, DT, TLI, AA, F, TTI, &GetLAA, LI, ORE, + &Requirements, &Hints); + if (!LVL.canVectorize()) { + LLVM_DEBUG(dbgs() << "LV: Not vectorizing: Cannot prove legality.\n"); + emitMissedWarning(F, L, Hints, ORE); + return false; + } + + // Use the cost model. + LoopVectorizationCostModel CM(L, PSE, LI, &LVL, *TTI, TLI, DB, AC, ORE, F, + &Hints); + CM.collectValuesToIgnore(); + + // Check the function attributes to find out if this function should be + // optimized for size. + bool OptForSize = + Hints.getForce() != LoopVectorizeHints::FK_Enabled && F->optForSize(); + + // Compute the weighted frequency of this loop being executed and see if it + // is less than 20% of the function entry baseline frequency. Note that we + // always have a canonical loop here because we think we *can* vectorize. + // FIXME: This is hidden behind a flag due to pervasive problems with + // exactly what block frequency models. + if (LoopVectorizeWithBlockFrequency) { + BlockFrequency LoopEntryFreq = BFI->getBlockFreq(L->getLoopPreheader()); + if (Hints.getForce() != LoopVectorizeHints::FK_Enabled && + LoopEntryFreq < ColdEntryFreq) + OptForSize = true; + } + + // Check the function attributes to see if implicit floats are allowed. 
+ // FIXME: This check doesn't seem possibly correct -- what if the loop is + // an integer loop and the vector instructions selected are purely integer + // vector instructions? + if (F->hasFnAttribute(Attribute::NoImplicitFloat)) { + LLVM_DEBUG(dbgs() << "LV: Can't vectorize when the NoImplicitFloat" + "attribute is used.\n"); + ORE->emit(createMissedAnalysis(Hints.vectorizeAnalysisPassName(), + "NoImplicitFloat", L) + << "loop not vectorized due to NoImplicitFloat attribute"); + emitMissedWarning(F, L, Hints, ORE); + return false; + } + + // Check if the target supports potentially unsafe FP vectorization. + // FIXME: Add a check for the type of safety issue (denormal, signaling) + // for the target we're vectorizing for, to make sure none of the + // additional fp-math flags can help. + if (Hints.isPotentiallyUnsafe() && + TTI->isFPVectorizationPotentiallyUnsafe()) { + LLVM_DEBUG(dbgs() << "LV: Potentially unsafe FP op prevents vectorization.\n"); + ORE->emit( + createMissedAnalysis(Hints.vectorizeAnalysisPassName(), "UnsafeFP", L) + << "loop not vectorized due to unsafe FP support."); + emitMissedWarning(F, L, Hints, ORE); + return false; + } + + // Select the optimal vectorization factor. + const VectorizationFactor VF = CM.selectVectorizationFactor(OptForSize); + + // Select the interleave count. + unsigned IC = CM.selectInterleaveCount(OptForSize, VF, VF.Cost); + + // Get user interleave count. + unsigned UserIC = Hints.getInterleave(); + + // Identify the diagnostic messages that should be produced. + std::pair VecDiagMsg, IntDiagMsg; + bool VectorizeLoop = true, InterleaveLoop = true; + if (Requirements.doesNotMeet(F, L, Hints, LVL)) { + LLVM_DEBUG(dbgs() << "LV: Not vectorizing: loop did not meet vectorization " + "requirements.\n"); + emitMissedWarning(F, L, Hints, ORE); + return false; + } + + if (VF.Width == 1) { + LLVM_DEBUG(dbgs() << "LV: Vectorization is possible but not beneficial.\n"); + VecDiagMsg = std::make_pair( + "VectorizationNotBeneficial", + "the cost-model indicates that vectorization is not beneficial"); + VectorizeLoop = false; + } + + if (IC == 1 && UserIC <= 1) { + // Tell the user interleaving is not beneficial. + LLVM_DEBUG(dbgs() << "LV: Interleaving is not beneficial.\n"); + IntDiagMsg = std::make_pair( + "InterleavingNotBeneficial", + "the cost-model indicates that interleaving is not beneficial"); + InterleaveLoop = false; + if (UserIC == 1) { + IntDiagMsg.first = "InterleavingNotBeneficialAndDisabled"; + IntDiagMsg.second += + " and is explicitly disabled or interleave count is set to 1"; + } + } else if (IC > 1 && UserIC == 1) { + // Tell the user interleaving is beneficial, but it explicitly disabled. + LLVM_DEBUG(dbgs() + << "LV: Interleaving is beneficial but is explicitly disabled."); + IntDiagMsg = std::make_pair( + "InterleavingBeneficialButDisabled", + "the cost-model indicates that interleaving is beneficial " + "but is explicitly disabled or interleave count is set to 1"); + InterleaveLoop = false; + } + + if (!VectorizeLoop && InterleaveLoop && LVL.hasMaskedOperations()) { + LLVM_DEBUG(dbgs() + << "LV: Interleaving is beneficial but loop contain masked access"); + IntDiagMsg = std::make_pair( + "InterleavingBeneficialButContainsMaskedAccess", + "interleaving not possible because of masked accesses"); + InterleaveLoop = false; + } + + //Temporary fix for assertion failure in sve memset vectorisation. The fix + //avoids interleaving when memset is present. 
+ if (!VectorizeLoop && InterleaveLoop && LVL.hasMemSet()) { + LLVM_DEBUG(dbgs() + << "LV: Interleaving is beneficial but loop contains memset"); + IntDiagMsg = std::make_pair( + "InterleavingBeneficialButContainsMemset", + "interleaving not possible because of the presence of memset"); + InterleaveLoop = false; + } + + // Override IC if user provided an interleave count. + IC = UserIC > 0 ? UserIC : IC; + + // Emit diagnostic messages, if any. + const char *VAPassName = Hints.vectorizeAnalysisPassName(); + if (!VectorizeLoop && !InterleaveLoop) { + // Do not vectorize or interleaving the loop. + ORE->emit(OptimizationRemarkMissed(VAPassName, VecDiagMsg.first, + {L->getLocRange().getStart()}, + L->getHeader()) + << VecDiagMsg.second); + ORE->emit(OptimizationRemarkMissed(LV_NAME, IntDiagMsg.first, + {L->getLocRange().getStart()}, + L->getHeader()) + << IntDiagMsg.second); + return false; + } else if (!VectorizeLoop && InterleaveLoop) { + LLVM_DEBUG(dbgs() << "LV: Interleave Count is " << IC << '\n'); + ORE->emit(OptimizationRemarkAnalysis(VAPassName, VecDiagMsg.first, + {L->getLocRange().getStart()}, + L->getHeader()) + << VecDiagMsg.second); + } else if (VectorizeLoop && !InterleaveLoop) { + LLVM_DEBUG(dbgs() << "LV: Found a vectorizable loop (" << VF.Width << + ") in " << DebugLocStr << '\n'); + ORE->emit(OptimizationRemarkAnalysis(LV_NAME, IntDiagMsg.first, + {L->getLocRange().getStart()}, + L->getHeader()) + << IntDiagMsg.second); + } else if (VectorizeLoop && InterleaveLoop) { + LLVM_DEBUG(dbgs() << "LV: Found a vectorizable loop (" << VF.Width << ") in " + << DebugLocStr << '\n'); + LLVM_DEBUG(dbgs() << "LV: Interleave Count is " << IC << '\n'); + } + + if (!VectorizeLoop) { + assert(IC > 1 && "interleave count should not be 1 or 0"); + // If we decided that it is not legal to vectorize the loop, then + // interleave it. + InnerLoopUnroller Unroller(L, PSE, LI, DT, TLI, TTI, AC, ORE, IC); + Unroller.vectorize(&LVL, CM.MinBWs); + + ORE->emit([&]() { + return OptimizationRemark(LV_NAME, "Interleaved", L->getStartLoc(), + L->getHeader()) + << "interleaved loop (interleaved count: " + << ore::NV("InterleaveCount", IC) << ")"; + }); + } else { + // If we decided that it is *legal* to vectorize the loop then do it. + InnerLoopVectorizer LB(L, PSE, LI, DT, TLI, TTI, AC, ORE, VF.Width, IC, + VF.isFixed); + LB.vectorize(&LVL, CM.MinBWs); + ++LoopsVectorized; + if (LB.isScalable()) + ++LoopsVectorizedWA; + + // Add metadata to disable runtime unrolling a scalar loop when there are + // no runtime checks about strides and memory. A scalar loop that is + // rarely used is not worth unrolling. + if (!LB.areSafetyChecksAdded()) + AddRuntimeUnrollDisableMetaData(L); + + using namespace ore; + // Report the vectorization decision. + OptimizationRemark R(LV_NAME, "Vectorized", + {L->getLocRange().getStart()}, + L->getHeader()); + R << "vectorized loop (vectorization width: " + << NV("VectorizationFactor", VF.Width) + << ", interleaved count: " << NV("InterleaveCount", IC) << ")" + << setExtraArgs() + << "(runtime checks: " + << NV("RTNeeded", + std::string(LVL.getRuntimePointerChecking()->Need ? "" : "no")) + << ", FixedWidthVectorization: " + << NV("FixedWidthVectorization", std::string("scaled")) + << ")"; + ORE->emit(R); + } + + // Mark the loop as already vectorized to avoid vectorizing again. 
+ Hints.setAlreadyVectorized(); + + LLVM_DEBUG(verifyFunction(*L->getHeader()->getParent())); + return true; + } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired(); + AU.addRequired(); + AU.addRequired(); + AU.addRequired(); + AU.addRequired(); + AU.addRequired(); + AU.addRequired(); + AU.addRequired(); + AU.addRequired(); + AU.addRequired(); + AU.addPreserved(); + AU.addPreserved(); + AU.addPreserved(); + AU.addPreserved(); + } +}; + +} // end anonymous namespace + +//===----------------------------------------------------------------------===// +// Implementation of LoopVectorizationLegality, InnerLoopVectorizer and +// LoopVectorizationCostModel. +//===----------------------------------------------------------------------===// + +Value *InnerLoopVectorizer::getBroadcastInstrs(Value *V) { + // We need to place the broadcast of invariant variables outside the loop. + Instruction *Instr = dyn_cast(V); + bool NewInstr = + (Instr && std::find(LoopVectorBody.begin(), LoopVectorBody.end(), + Instr->getParent()) != LoopVectorBody.end()); + bool Invariant = OrigLoop->isLoopInvariant(V) && !NewInstr; + + // Place the code for broadcasting invariant variables in the new preheader. + IRBuilder<>::InsertPointGuard Guard(Builder); + if (Invariant) + Builder.SetInsertPoint(LoopVectorPreHeader->getTerminator()); + + // Broadcast the scalar into all locations in the vector. + Value *Shuf = Builder.CreateVectorSplat({VF, Scalable}, V, "broadcast"); + + return Shuf; +} + +Value *InnerLoopVectorizer::getStepVector(Value *Val, Value *Start, + const SCEV *StepSCEV, + Instruction::BinaryOps BinOp) { + const DataLayout &DL = OrigLoop->getHeader()->getModule()->getDataLayout(); + SCEVExpander Exp(*PSE.getSE(), DL, "induction"); + Value *StepValue = Exp.expandCodeFor(StepSCEV, StepSCEV->getType(), + &*Builder.GetInsertPoint()); + return getStepVector(Val, Start, StepValue, BinOp); +} + +void InnerLoopVectorizer::widenInductionVariable(const InductionDescriptor &II, + VectorParts &Entry, + IntegerType *TruncType) { + Value *Start = II.getStartValue(); + ConstantInt *Step = II.getConstIntStepValue(); + assert(Step && "Can not widen an IV with a non-constant step"); + + // Construct the initial value of the vector IV in the vector loop preheader + auto CurrIP = Builder.saveIP(); + Builder.SetInsertPoint(LoopVectorPreHeader->getTerminator()); + if (TruncType) { + Step = ConstantInt::getSigned(TruncType, Step->getSExtValue()); + Start = Builder.CreateCast(Instruction::Trunc, Start, TruncType); + } + Value *SplatStart = Builder.CreateVectorSplat({VF, Scalable}, Start); + Value *SteppedStart = getStepVector(SplatStart, 0, Step); + Builder.restoreIP(CurrIP); + + Value *RuntimeVF = getRuntimeVF(Start->getType()); + Value *SplatVF = Builder.CreateVectorSplat({VF, Scalable}, RuntimeVF); + PHINode *VecInd = + PHINode::Create(SteppedStart->getType(), 2, "vec.ind", + &*LoopVectorBody[0]->getFirstInsertionPt()); + Value *LastInduction = VecInd; + for (unsigned Part = 0; Part < UF; ++Part) { + Entry[Part] = LastInduction; + LastInduction = Builder.CreateAdd(LastInduction, SplatVF, "step.add"); + } + + auto Latch = LI->getLoopFor(LoopVectorBody[0])->getLoopLatch(); + VecInd->addIncoming(SteppedStart, LoopVectorPreHeader); + VecInd->addIncoming(LastInduction, Latch); +} + +Value *InnerLoopVectorizer::getStepVector(Value *Val, int Start, Value *Step, + Instruction::BinaryOps BinOp) { + Type *Ty = Val->getType()->getScalarType(); + return getStepVector(Val, ConstantInt::get(Ty, Start), Step, BinOp); +} 
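// A standalone model of the lane-wise arithmetic the integer path of
// getStepVector (defined just below) produces: lane i of the result is
// Val[i] + (Start + i) * Step. This is only an illustration using std::vector
// in place of IR values; the names here are not part of the patch.
#include <cstddef>
#include <cstdint>
#include <vector>

static std::vector<int64_t> stepVectorModel(const std::vector<int64_t> &Val,
                                            int64_t Start, int64_t Step) {
  std::vector<int64_t> Result(Val.size());
  for (std::size_t I = 0, E = Val.size(); I != E; ++I)
    Result[I] = Val[I] + (Start + static_cast<int64_t>(I)) * Step; // per-lane induction value
  return Result;
}
// E.g. Val = splat(100), Start = 0, Step = 4 with four lanes gives
// {100, 104, 108, 112}: seriesvector(Start, 1) * Step added to Val.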
+ +Value *InnerLoopVectorizer::getStepVector(Value *Val, Value *Start, Value *Step, + Instruction::BinaryOps BinOp) { + assert(Val->getType()->isVectorTy() && "Must be a vector"); + assert(Step->getType() == Val->getType()->getScalarType() && + "Step has wrong type"); + + VectorType *Ty = cast(Val->getType()); + Value *One = ConstantInt::get(Start->getType(), 1); + + // Create a vector of consecutive numbers from Start to Start+VF + Value *Cv = Builder.CreateSeriesVector(Ty->getElementCount(), Start, One); + + Step = Builder.CreateVectorSplat(Ty->getElementCount(), Step); + if (Val->getType()->getScalarType()->isIntegerTy()) { + // Add the consecutive indices to the vector value. + assert(Cv->getType() == Val->getType() && "Invalid consecutive vec"); + // FIXME: The newly created binary instructions should contain nsw/nuw + // flags, which can be found from the original scalar operations. + Step = Builder.CreateMul(Cv, Step); + return Builder.CreateAdd(Val, Step, "induction"); + } else { + // Floating point induction. + assert(Val->getType()->getScalarType()->isFloatingPointTy() && + "Elem must be an fp type"); + assert((BinOp == Instruction::FAdd || BinOp == Instruction::FSub) && + "Binary Opcode should be specified for FP induction"); + // Cv is an integer vector, need to convert to fp. + + // Floating point operations had to be 'fast' to enable the induction. + FastMathFlags Flags; + Flags.setFast(true); + + Cv = Builder.CreateUIToFP(Cv, Ty); + Step = Builder.CreateFMul(Cv, Step); + + if (isa(Step)) + // Have to check, Step may be a constant + cast(Step)->setFastMathFlags(Flags); + + Value *BOp = Builder.CreateBinOp(BinOp, Val, Step, "induction"); + if (isa(BOp)) + cast(BOp)->setFastMathFlags(Flags); + return BOp; + } +} + +Constant *InnerLoopVectorizer::getRuntimeVF(Type *Ty) { + Constant *EC = ConstantInt::get(Ty, VF); + if (Scalable) + EC = ConstantExpr::getMul(VScale::get(Ty), EC); + + return EC; +} + +int LoopVectorizationLegality::isConsecutivePtr(Value *Ptr) { + assert(Ptr->getType()->isPointerTy() && "Unexpected non-ptr"); + auto *SE = PSE.getSE(); + // Make sure that the pointer does not point to structs. + if (Ptr->getType()->getPointerElementType()->isAggregateType()) + return 0; + + // If this value is a pointer induction variable, we know it is consecutive. + PHINode *Phi = dyn_cast_or_null(Ptr); + if (Phi && Inductions.count(Phi)) { + InductionDescriptor II = Inductions[Phi]; + return II.getConsecutiveDirection(); + } + + GetElementPtrInst *Gep = getGEPInstruction(Ptr); + if (!Gep) + return 0; + + unsigned NumOperands = Gep->getNumOperands(); + Value *GpPtr = Gep->getPointerOperand(); + // If this GEP value is a consecutive pointer induction variable and all of + // the indices are constant, then we know it is consecutive. + Phi = dyn_cast(GpPtr); + if (Phi && Inductions.count(Phi)) { + + // Make sure that the pointer does not point to structs. + PointerType *GepPtrType = cast(GpPtr->getType()); + if (GepPtrType->getElementType()->isAggregateType()) + return 0; + + // Make sure that all of the index operands are loop invariant. + for (unsigned i = 1; i < NumOperands; ++i) + if (!SE->isLoopInvariant(PSE.getSCEV(Gep->getOperand(i)), TheLoop)) + return 0; + + InductionDescriptor II = Inductions[Phi]; + return II.getConsecutiveDirection(); + } + + unsigned InductionOperand = getGEPInductionOperand(Gep); + + // Check that all of the gep indices are uniform except for our induction + // operand. 
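// isConsecutivePtr, being defined here, classifies a pointer by its
// per-iteration step: +1 for a consecutive access, -1 for a reverse-consecutive
// access, another non-zero stride only when non-consecutive stride induction
// variables are enabled, and 0 when nothing can be proven. A standalone sketch
// of that contract, assuming the step has already been extracted as a plain
// element count (the names below are illustrative, not from the patch):
#include <cstdint>

static int classifyStride(int64_t StepElems, bool AllowNonUnitStrides) {
  if (StepElems == 1)
    return 1;                             // consecutive: a plain wide load/store works
  if (StepElems == -1)
    return -1;                            // reverse consecutive: wide access plus reverse shuffle
  if (AllowNonUnitStrides && StepElems != 0)
    return static_cast<int>(StepElems);   // constant non-unit stride
  return 0;                               // not consecutive (or unknown)
}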
+ for (unsigned i = 0; i != NumOperands; ++i) + if (i != InductionOperand && + !SE->isLoopInvariant(PSE.getSCEV(Gep->getOperand(i)), TheLoop)) + return 0; + + // We can emit wide load/stores only if the last non-zero index is the + // induction variable. + const SCEV *Last = nullptr; + if (!Strides.count(Gep)) + Last = PSE.getSCEV(Gep->getOperand(InductionOperand)); + else { + // Because of the multiplication by a stride we can have a s/zext cast. + // We are going to replace this stride by 1 so the cast is safe to ignore. + // + // %indvars.iv = phi i64 [ 0, %entry ], [ %indvars.iv.next, %for.body ] + // %0 = trunc i64 %indvars.iv to i32 + // %mul = mul i32 %0, %Stride1 + // %idxprom = zext i32 %mul to i64 << Safe cast. + // %arrayidx = getelementptr inbounds i32* %B, i64 %idxprom + // + Last = replaceSymbolicStrideSCEV(PSE, Strides, + Gep->getOperand(InductionOperand), Gep); + if (const SCEVCastExpr *C = dyn_cast(Last)) + Last = + (C->getSCEVType() == scSignExtend || C->getSCEVType() == scZeroExtend) + ? C->getOperand() + : Last; + } + if (const SCEVAddRecExpr *AR = dyn_cast(Last)) { + const SCEV *Step = AR->getStepRecurrence(*SE); + + // The memory is consecutive because the last index is consecutive + // and all other indices are loop invariant. + if (Step->isOne()) + return 1; + if (Step->isAllOnesValue()) + return -1; + + // Try and find a different constant stride + if (EnableNonConsecutiveStrideIndVars) { + if (const SCEVConstant *SCC = dyn_cast(Step)) { + const ConstantInt *CI = SCC->getValue(); + // TODO: Error checking vs. INT_MAX? + return (int)CI->getLimitedValue(INT_MAX); + } + } + } + + return 0; +} + +bool LoopVectorizationLegality::isUniform(Value *V) { + return LAI->isUniform(V); +} + +InnerLoopVectorizer::VectorParts & +InnerLoopVectorizer::getVectorValue(Value *V) { + assert(V != Induction && "The new induction variable should not be used."); + assert(!V->getType()->isVectorTy() && "Can't widen a vector"); + + // If we have a stride that is replaced by one, do it here. + if (Legal->hasStride(V)) + V = ConstantInt::get(V->getType(), 1); + + // If we have this scalar in the map, return it. + if (WidenMap.has(V)) + return WidenMap.get(V); + + // If this scalar is unknown, assume that it is a constant or that it is + // loop invariant. Broadcast V and save the value for future uses. + Value *B = getBroadcastInstrs(V); + return WidenMap.splat(V, B); +} + +Value *InnerLoopVectorizer::reverseVector(Value *Vec) { + assert(Vec->getType()->isVectorTy() && "Invalid type"); + VectorType *Ty = cast(Vec->getType()); + + // i32 reverse_mask[n] = { n-1, n-2...1, 0 } + Value *RuntimeVF = getRuntimeVF(Builder.getInt32Ty()); + Value *Start = Builder.CreateSub(RuntimeVF, Builder.getInt32(1)); + Value *Step = ConstantInt::get(Start->getType(), -1, true); + Value *Mask = Builder.CreateSeriesVector({VF,Scalable}, Start, Step); + + return Builder.CreateShuffleVector(Vec, UndefValue::get(Ty), Mask, "reverse"); +} + +// Get a mask to interleave \p NumVec vectors into a wide vector. +// I.e. <0, VF, VF*2, ..., VF*(NumVec-1), 1, VF+1, VF*2+1, ...> +// E.g. For 2 interleaved vectors, if VF is 4, the mask is: +// <0, 4, 1, 5, 2, 6, 3, 7> +static Constant *getInterleavedMask(IRBuilder<> &Builder, unsigned VF, + unsigned NumVec) { + SmallVector Mask; + for (unsigned i = 0; i < VF; i++) + for (unsigned j = 0; j < NumVec; j++) + Mask.push_back(Builder.getInt32(j * VF + i)); + + return ConstantVector::get(Mask); +} + +// Get the strided mask starting from index \p Start. +// I.e. 
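// Concretely, for VF = 4 the interleave mask for two vectors is
// <0, 4, 1, 5, 2, 6, 3, 7>, and a strided mask with Start = 1, Stride = 3
// selects <1, 4, 7, 10> (every third element, beginning at index 1).
// A standalone sketch of both mask computations on plain integers, independent
// of the IRBuilder/Constant machinery used in the patch:
#include <vector>

static std::vector<unsigned> interleavedMaskModel(unsigned VF, unsigned NumVec) {
  std::vector<unsigned> Mask;
  for (unsigned I = 0; I < VF; ++I)
    for (unsigned J = 0; J < NumVec; ++J)
      Mask.push_back(J * VF + I);         // lane I of every input vector, in order
  return Mask;
}

static std::vector<unsigned> stridedMaskModel(unsigned Start, unsigned Stride,
                                              unsigned VF) {
  std::vector<unsigned> Mask;
  for (unsigned I = 0; I < VF; ++I)
    Mask.push_back(Start + I * Stride);   // member Start of each interleaved tuple
  return Mask;
}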
+static Constant *getStridedMask(IRBuilder<> &Builder, unsigned Start, + unsigned Stride, unsigned VF) { + SmallVector Mask; + for (unsigned i = 0; i < VF; i++) + Mask.push_back(Builder.getInt32(Start + i * Stride)); + + return ConstantVector::get(Mask); +} + +// Get a mask of two parts: The first part consists of sequential integers +// starting from 0, The second part consists of UNDEFs. +// I.e. <0, 1, 2, ..., NumInt - 1, undef, ..., undef> +static Constant *getSequentialMask(IRBuilder<> &Builder, unsigned NumInt, + unsigned NumUndef) { + SmallVector Mask; + for (unsigned i = 0; i < NumInt; i++) + Mask.push_back(Builder.getInt32(i)); + + Constant *Undef = UndefValue::get(Builder.getInt32Ty()); + for (unsigned i = 0; i < NumUndef; i++) + Mask.push_back(Undef); + + return ConstantVector::get(Mask); +} + +// Concatenate two vectors with the same element type. The 2nd vector should +// not have more elements than the 1st vector. If the 2nd vector has less +// elements, extend it with UNDEFs. +static Value *ConcatenateTwoVectors(IRBuilder<> &Builder, Value *V1, + Value *V2) { + VectorType *VecTy1 = dyn_cast(V1->getType()); + VectorType *VecTy2 = dyn_cast(V2->getType()); + assert(VecTy1 && VecTy2 && + VecTy1->getScalarType() == VecTy2->getScalarType() && + "Expect two vectors with the same element type"); + + unsigned NumElts1 = VecTy1->getNumElements(); + unsigned NumElts2 = VecTy2->getNumElements(); + assert(NumElts1 >= NumElts2 && "Unexpect the first vector has less elements"); + + if (NumElts1 > NumElts2) { + // Extend with UNDEFs. + Constant *ExtMask = + getSequentialMask(Builder, NumElts2, NumElts1 - NumElts2); + V2 = Builder.CreateShuffleVector(V2, UndefValue::get(VecTy2), ExtMask); + } + + Constant *Mask = getSequentialMask(Builder, NumElts1 + NumElts2, 0); + return Builder.CreateShuffleVector(V1, V2, Mask); +} + +// Concatenate vectors in the given list. All vectors have the same type. +static Value *ConcatenateVectors(IRBuilder<> &Builder, + ArrayRef InputList) { + unsigned NumVec = InputList.size(); + assert(NumVec > 1 && "Should be at least two vectors"); + + SmallVector ResList; + ResList.append(InputList.begin(), InputList.end()); + do { + SmallVector TmpList; + for (unsigned i = 0; i < NumVec - 1; i += 2) { + Value *V0 = ResList[i], *V1 = ResList[i + 1]; + assert((V0->getType() == V1->getType() || i == NumVec - 2) && + "Only the last vector may have a different type"); + + TmpList.push_back(ConcatenateTwoVectors(Builder, V0, V1)); + } + + // Push the last vector if the total number of vectors is odd. + if (NumVec % 2 != 0) + TmpList.push_back(ResList[NumVec - 1]); + + ResList = TmpList; + NumVec = ResList.size(); + } while (NumVec > 1); + + return ResList[0]; +} + +// Try to vectorize the interleave group that \p Instr belongs to. +// +// E.g. Translate following interleaved load group (factor = 3): +// for (i = 0; i < N; i+=3) { +// R = Pic[i]; // Member of index 0 +// G = Pic[i+1]; // Member of index 1 +// B = Pic[i+2]; // Member of index 2 +// ... // do something to R, G, B +// } +// To: +// %wide.vec = load <12 x i32> ; Read 4 tuples of R,G,B +// %R.vec = shuffle %wide.vec, undef, <0, 3, 6, 9> ; R elements +// %G.vec = shuffle %wide.vec, undef, <1, 4, 7, 10> ; G elements +// %B.vec = shuffle %wide.vec, undef, <2, 5, 8, 11> ; B elements +// +// Or translate following interleaved store group (factor = 3): +// for (i = 0; i < N; i+=3) { +// ... 
do something to R, G, B +// Pic[i] = R; // Member of index 0 +// Pic[i+1] = G; // Member of index 1 +// Pic[i+2] = B; // Member of index 2 +// } +// To: +// %R_G.vec = shuffle %R.vec, %G.vec, <0, 1, 2, ..., 7> +// %B_U.vec = shuffle %B.vec, undef, <0, 1, 2, 3, u, u, u, u> +// %interleaved.vec = shuffle %R_G.vec, %B_U.vec, +// <0, 4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 11> ; Interleave R,G,B elements +// store <12 x i32> %interleaved.vec ; Write 4 tuples of R,G,B +void InnerLoopVectorizer::vectorizeInterleaveGroup(Instruction *Instr) { + const InterleaveGroup *Group = Legal->getInterleavedAccessGroup(Instr); + assert(Group && "Fail to get an interleaved access group."); + + // Skip if current instruction is not the insert position. + if (Instr != Group->getInsertPos()) + return; + + LoadInst *LI = dyn_cast(Instr); + StoreInst *SI = dyn_cast(Instr); + Value *Ptr = LI ? LI->getPointerOperand() : SI->getPointerOperand(); + + // Prepare for the vector type of the interleaved load/store. + Type *ScalarTy = LI ? LI->getType() : SI->getValueOperand()->getType(); + unsigned InterleaveFactor = Group->getFactor(); + Type *VecTy = VectorType::get(ScalarTy, InterleaveFactor * VF, Scalable); + Type *PtrTy = VecTy->getPointerTo(Ptr->getType()->getPointerAddressSpace()); + + // Prepare for the new pointers. + setDebugLocFromInst(Builder, Ptr); + VectorParts &PtrParts = getVectorValue(Ptr); + SmallVector NewPtrs; + unsigned Index = Group->getIndex(Instr); + for (unsigned Part = 0; Part < UF; Part++) { + // Extract the pointer for current instruction from the pointer vector. A + // reverse access uses the pointer in the last lane. + Value *NewPtr = Builder.CreateExtractElement( + PtrParts[Part], + Group->isReverse() ? Builder.getInt32(VF - 1) : Builder.getInt32(0)); + + // Notice current instruction could be any index. Need to adjust the address + // to the member of index 0. + // + // E.g. a = A[i+1]; // Member of index 1 (Current instruction) + // b = A[i]; // Member of index 0 + // Current pointer is pointed to A[i+1], adjust it to A[i]. + // + // E.g. A[i+1] = a; // Member of index 1 + // A[i] = b; // Member of index 0 + // A[i+2] = c; // Member of index 2 (Current instruction) + // Current pointer is pointed to A[i+2], adjust it to A[i]. + NewPtr = Builder.CreateGEP(NewPtr, Builder.getInt32(-Index)); + + // Cast to the vector pointer type. + NewPtrs.push_back(Builder.CreateBitCast(NewPtr, PtrTy)); + } + + setDebugLocFromInst(Builder, Instr); + Value *UndefVec = UndefValue::get(VecTy); + + // Vectorize the interleaved load group. + if (LI) { + for (unsigned Part = 0; Part < UF; Part++) { + Instruction *NewLoadInstr = Builder.CreateAlignedLoad( + NewPtrs[Part], Group->getAlignment(), "wide.vec"); + + for (unsigned i = 0; i < InterleaveFactor; i++) { + Instruction *Member = Group->getMember(i); + + // Skip the gaps in the group. + if (!Member) + continue; + + Constant *StrideMask = getStridedMask(Builder, i, InterleaveFactor, VF); + Value *StridedVec = Builder.CreateShuffleVector( + NewLoadInstr, UndefVec, StrideMask, "strided.vec"); + + // If this member has different type, cast the result type. + if (Member->getType() != ScalarTy) { + VectorType *OtherVTy = VectorType::get(Member->getType(), VF, + Scalable); + StridedVec = Builder.CreateBitOrPointerCast(StridedVec, OtherVTy); + } + + VectorParts &Entry = WidenMap.get(Member); + Entry[Part] = + Group->isReverse() ? 
reverseVector(StridedVec) : StridedVec; + } + + addMetadata(NewLoadInstr, Instr); + } + return; + } + + // The sub vector type for current instruction. + VectorType *SubVT = VectorType::get(ScalarTy, VF, Scalable); + + // Vectorize the interleaved store group. + for (unsigned Part = 0; Part < UF; Part++) { + // Collect the stored vector from each member. + SmallVector StoredVecs; + for (unsigned i = 0; i < InterleaveFactor; i++) { + // Interleaved store group doesn't allow a gap, so each index has a member + Instruction *Member = Group->getMember(i); + assert(Member && "Fail to get a member from an interleaved store group"); + + Value *StoredVec = + getVectorValue(dyn_cast(Member)->getValueOperand())[Part]; + if (Group->isReverse()) + StoredVec = reverseVector(StoredVec); + + // If this member has different type, cast it to an unified type. + if (StoredVec->getType() != SubVT) + StoredVec = Builder.CreateBitOrPointerCast(StoredVec, SubVT); + + StoredVecs.push_back(StoredVec); + } + + // Concatenate all vectors into a wide vector. + Value *WideVec = ConcatenateVectors(Builder, StoredVecs); + + // Interleave the elements in the wide vector. + Constant *IMask = getInterleavedMask(Builder, VF, InterleaveFactor); + Value *IVec = Builder.CreateShuffleVector(WideVec, UndefVec, IMask, + "interleaved.vec"); + + Instruction *NewStoreInstr = + Builder.CreateAlignedStore(IVec, NewPtrs[Part], Group->getAlignment()); + addMetadata(NewStoreInstr, Instr); + } +} + +void InnerLoopVectorizer::vectorizeMemsetInstruction(MemSetInst *MSI) { + const auto Length = MSI->getLength(); + const auto IsVolatile = MSI->isVolatile(); + // Clamp Alignment to yield an acceptable vector element type. + const auto Alignment = std::min((uint64_t)MSI->getDestAlignment(), + (uint64_t)8); + const auto Val = MSI->getValue(); + const auto Dest = MSI->getRawDest(); + auto CL = dyn_cast(Length); + assert(CL && "Not a constant value."); + assert((CL->getZExtValue() % Alignment == 0) + && "Not a valid number of writes."); + assert(((CL->getZExtValue() / Alignment) <= VectorizerMemSetThreshold) + && "Not a valid number of elements."); + assert(!IsVolatile && "Cannot transform a volatile memset."); + assert(VectorizeMemset && "Should not vectorize memset."); + assert(isScalable() && "Require WA."); + + VectorParts &Ptrs = getVectorValue(Dest); + VectorParts &Vals = getVectorValue(Val); + for (unsigned Part = 0; Part < UF; ++Part) { + auto *Ctx = &MSI->getParent()->getParent()->getContext(); + assert(Vals[Part]->getType()->getScalarType()->getScalarSizeInBits() == 8 + && "Invalid pointer"); + Type *WideScalarTy = IntegerType::get(*Ctx, 8 * Alignment); + VectorType::ElementCount EC(VF * Alignment, Scalable); + Value *VecVal = Builder.CreateVectorSplat(EC, Val); + auto WideVecTy = VectorType::get(WideScalarTy, VF, Scalable); + VecVal = Builder.CreateBitCast(VecVal, WideVecTy); + + // Generate the actual memset replacement. + auto AddrSpace = Ptrs[Part]->getType()->getPointerAddressSpace(); + auto WideVecPtrTy = VectorType::get(WideScalarTy->getPointerTo(AddrSpace), + VF, Scalable); + Value *P = Predicate[Part]; + for (unsigned i = 0; i < CL->getZExtValue(); i+=Alignment) { + auto Ptr = Builder.CreateGEP(Ptrs[Part], Builder.getInt32(i)); + Ptr = Builder.CreateBitCast(Ptr, WideVecPtrTy); + auto *NewMemset = Builder.CreateMaskedScatter(VecVal, Ptr, Alignment, P); + propagateMetadata(NewMemset, MSI); + } + } +} + +void InnerLoopVectorizer::vectorizeMemoryInstruction(Instruction *Instr) { + // Attempt to issue a wide load. 
+ LoadInst *LI = dyn_cast(Instr); + StoreInst *SI = dyn_cast(Instr); + + assert((LI || SI) && "Invalid Load/Store instruction"); + + // Don't create a memory instruction for an intermediate store of a + // reduction variable, because this will be one to a uniform address. + if (SI) { + for (auto &Reduction : *Legal->getReductionVars()) { + RecurrenceDescriptor DS = Reduction.second; + if (DS.IntermediateStore && + storeToSameAddress(PSE.getSE(), SI, DS.IntermediateStore)) + return; + } + } + + // Try to vectorize the interleave group if this access is interleaved. + if (Legal->isAccessInterleaved(Instr)) + return vectorizeInterleaveGroup(Instr); + + Type *ScalarDataTy = LI ? LI->getType() : SI->getValueOperand()->getType(); + Type *DataTy = VectorType::get(ScalarDataTy, VF, Scalable); + Value *Ptr = LI ? LI->getPointerOperand() : SI->getPointerOperand(); + unsigned Alignment = LI ? LI->getAlignment() : SI->getAlignment(); + // An alignment of 0 means target abi alignment. We need to use the scalar's + // target abi alignment in such a case. + const DataLayout &DL = Instr->getModule()->getDataLayout(); + if (!Alignment) + Alignment = DL.getABITypeAlignment(ScalarDataTy); + unsigned AddressSpace = Ptr->getType()->getPointerAddressSpace(); + unsigned ScalarAllocatedSize = DL.getTypeAllocSize(ScalarDataTy); + unsigned VectorElementSize = DL.getTypeStoreSize(DataTy) / VF; + + if (SI && Legal->blockNeedsPredication(SI->getParent()) && + !Legal->isMaskRequired(SI) && !UsePredication) + return scalarizeInstruction(Instr, true); + + if (ScalarAllocatedSize != VectorElementSize) + return scalarizeInstruction(Instr); + + Constant *Zero = Builder.getInt32(0); + VectorParts &Entry = WidenMap.get(Instr); + + // If the pointer is loop invariant scalarize the load. + if (LI && Legal->isUniform(Ptr)) { + // ... unless we're vectorizing for a scalable architecture. + if (isScalable()) { + // The pointer may be uniform from SCEV perspective, + // but may not be hoisted out for other reasons. + auto *PtrI = dyn_cast(Ptr); + if (PtrI && OrigLoop->contains(PtrI)) { + Ptr = Builder.CreateExtractElement(getVectorValue(Ptr)[0], + Builder.getInt32(0)); + } + + // Generate a scalar load... + Instruction *NewLI = Builder.CreateLoad(Ptr); + propagateMetadata(NewLI, LI); + + // ... and splat it. + for (unsigned Part = 0; Part < UF; ++Part) { + Entry[Part] = + Builder.CreateVectorSplat({VF, Scalable}, NewLI, "uniform_load"); + } + } else + scalarizeInstruction(Instr); + + return; + } + + // If the pointer is non-consecutive and gather/scatter is not supported + // scalarize the instruction. + int Stride = Legal->isConsecutivePtr(Ptr); + bool Reverse = Stride < 0; + bool HasConsecutiveStride = (std::abs(Stride) == 1); + bool CreateGatherScatter = + !HasConsecutiveStride && + ((LI && Legal->isLegalMaskedGather(ScalarDataTy)) || + (SI && Legal->isLegalMaskedScatter(ScalarDataTy))); + + if (!HasConsecutiveStride && !CreateGatherScatter) + return scalarizeInstruction(Instr); + + VectorParts VectorGep; + GetElementPtrInst *Gep = getGEPInstruction(Ptr); + if (HasConsecutiveStride) { + if (Gep && Legal->isInductionVariable(Gep->getPointerOperand())) { + setDebugLocFromInst(Builder, Gep); + Value *PtrOperand = Gep->getPointerOperand(); + Value *FirstBasePtr = getVectorValue(PtrOperand)[0]; + FirstBasePtr = Builder.CreateExtractElement(FirstBasePtr, Zero); + + // Create the new GEP with the new induction variable. 
+      GetElementPtrInst *Gep2 = cast<GetElementPtrInst>(Gep->clone());
+      Gep2->setOperand(0, FirstBasePtr);
+      Gep2->setName("gep.indvar.base");
+      Ptr = Builder.Insert(Gep2);
+    } else if (Gep) {
+      setDebugLocFromInst(Builder, Gep);
+      assert(PSE.getSE()->isLoopInvariant(PSE.getSCEV(Gep->getPointerOperand()),
+                                          OrigLoop) &&
+             "Base ptr must be invariant");
+      // The last index does not have to be the induction. It can be
+      // consecutive and be a function of the index. For example A[I+1];
+      unsigned NumOperands = Gep->getNumOperands();
+      unsigned InductionOperand = getGEPInductionOperand(Gep);
+      // Create the new GEP with the new induction variable.
+      GetElementPtrInst *Gep2 = cast<GetElementPtrInst>(Gep->clone());
+
+      for (unsigned i = 0; i < NumOperands; ++i) {
+        Value *GepOperand = Gep->getOperand(i);
+        Instruction *GepOperandInst = dyn_cast<Instruction>(GepOperand);
+
+        // Update last index or loop invariant instruction anchored in loop.
+        if (i == InductionOperand ||
+            (GepOperandInst && OrigLoop->contains(GepOperandInst))) {
+          assert((i == InductionOperand ||
+                  PSE.getSE()->isLoopInvariant(PSE.getSCEV(GepOperandInst),
+                                               OrigLoop)) &&
+                 "Must be last index or loop invariant");
+
+          VectorParts &GEPParts = getVectorValue(GepOperand);
+          Value *Index = GEPParts[0];
+          Index = Builder.CreateExtractElement(Index, Zero);
+          Gep2->setOperand(i, Index);
+          Gep2->setName("gep.indvar.idx");
+        }
+      }
+      Ptr = Builder.Insert(Gep2);
+    } else { // No GEP
+      // Use the induction element ptr.
+      assert(isa<PHINode>(Ptr) && "Invalid induction ptr");
+      setDebugLocFromInst(Builder, Ptr);
+      VectorParts &PtrVal = getVectorValue(Ptr);
+      Ptr = Builder.CreateExtractElement(PtrVal[0], Zero);
+    }
+  } else {
+    // At this point we should have a vector version of the GEP for Gather or Scatter
+    assert(CreateGatherScatter && "The instruction should be scalarized");
+    // For scalable vectorization, vectorizeGEPInstruction has already
+    // handled this. Only useful for fixed-length.
+    // TODO: Unify this version with the scalable code once we can discuss with
+    // the community.
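// For the fixed-length gather/scatter path handled below, the vector GEP
// amounts to one address per lane: the base plus that lane's index, scaled by
// the element size. A standalone model of that address computation, assuming a
// single-index GEP over a flat array of 4-byte elements (all names here are
// illustrative, not from the patch):
#include <cstddef>
#include <cstdint>
#include <vector>

static std::vector<uint64_t> perLaneAddresses(uint64_t Base,
                                              const std::vector<int64_t> &Idx,
                                              int64_t EltSize = 4) {
  std::vector<uint64_t> Addrs(Idx.size());
  for (std::size_t Lane = 0; Lane != Idx.size(); ++Lane)
    Addrs[Lane] = Base + static_cast<uint64_t>(Idx[Lane] * EltSize); // one address per lane
  return Addrs;
}
// Each address then feeds the corresponding lane of a masked gather (load) or
// masked scatter (store).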
+ if (Gep && !isScalable()) { + SmallVector OpsV; + // Vectorizing GEP, across UF parts, we want to keep each loop-invariant + // base or index of GEP scalar + for (Value *Op : Gep->operands()) { + if (PSE.getSE()->isLoopInvariant(PSE.getSCEV(Op), OrigLoop)) + OpsV.push_back(VectorParts(UF, Op)); + else + OpsV.push_back(getVectorValue(Op)); + } + + for (unsigned Part = 0; Part < UF; ++Part) { + SmallVector Ops; + Value *GEPBasePtr = OpsV[0][Part]; + for (unsigned i = 1; i < Gep->getNumOperands(); i++) + Ops.push_back(OpsV[i][Part]); + Value *NewGep = + Builder.CreateGEP(nullptr, GEPBasePtr, Ops, "VectorGep"); + assert(NewGep->getType()->isVectorTy() && "Expected vector GEP"); + NewGep = + Builder.CreateBitCast(NewGep, VectorType::get(Ptr->getType(), + {VF, Scalable})); + VectorGep.push_back(NewGep); + } + } else + VectorGep = getVectorValue(Ptr); + } + + Type *DataPtrTy = DataTy->getPointerTo(AddressSpace); + VectorParts Mask = createBlockInMask(Instr->getParent()); + + VectorParts PredStoreMask; + if (SI && Legal->blockNeedsPredication(SI->getParent()) && + !Legal->isMaskRequired(SI)) { + assert(UsePredication && "Cannot predicate store without predication."); + assert(SI->getParent()->getSinglePredecessor() && + "Only support single predecessor blocks."); + PredStoreMask = createEdgeMask(SI->getParent()->getSinglePredecessor(), + SI->getParent()); + } + + // Handle Stores: + if (SI) { + assert(!Legal->isUniform(SI->getPointerOperand()) && + "We do not allow storing to uniform addresses"); + setDebugLocFromInst(Builder, SI); + // We don't want to update the value in the map as it might be used in + // another expression. So don't use a reference type for "StoredVal". + VectorParts StoredVal = getVectorValue(SI->getValueOperand()); + + for (unsigned Part = 0; Part < UF; ++Part) { + Instruction *NewSI = nullptr; + if (CreateGatherScatter) { + Value *P = Predicate[Part]; + + if (Legal->isMaskRequired(SI)) + P = Builder.CreateAnd(P, Mask[Part]); + + NewSI = Builder.CreateMaskedScatter(StoredVal[Part], VectorGep[Part], + Alignment, P); + } else { + // Calculate the pointer for the specific unroll-part. + Value *VecPtr; + + Value *MaskPart = Mask[Part]; + Value *Data = StoredVal[Part]; + + if (UsePredication) + MaskPart = Builder.CreateAnd(MaskPart, Predicate[Part]); + + if (Reverse) { + // If we store to reverse consecutive memory locations, then we need + // to reverse the order of elements in the stored value. + Data = reverseVector(Data); + // If the address is consecutive but reversed, then the + // wide store needs to start at the last vector element. + VecPtr = Builder.CreateGEP(nullptr, Ptr, Builder.getInt32(1)); + VecPtr = Builder.CreateBitCast(VecPtr, DataPtrTy); + VecPtr = Builder.CreateGEP(nullptr, VecPtr, Builder.getInt32(-Part-1)); + MaskPart = reverseVector(MaskPart); + } else { + VecPtr = Builder.CreateBitCast(Ptr, DataPtrTy); + VecPtr = Builder.CreateGEP(nullptr, VecPtr, Builder.getInt32(Part)); + } + + if (Legal->isMaskRequired(SI)) + NewSI = Builder.CreateMaskedStore(Data, VecPtr, Alignment, MaskPart); + else if (UsePredication) { + Value* P = Predicate[Part]; + + if (Legal->blockNeedsPredication(SI->getParent())) + P = Builder.CreateAnd(P, PredStoreMask[Part]); + + if (Reverse) + P = reverseVector(P); + + NewSI = Builder.CreateMaskedStore(Data, VecPtr, Alignment, P); + } else + NewSI = Builder.CreateAlignedStore(Data, VecPtr, Alignment); + } + addMetadata(NewSI, SI); + } + return; + } + + // Handle loads. 
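// The consecutive case below addresses whole vectors at a time: unroll part P
// reads or writes elements [P*VF, (P+1)*VF) of the scalar pointer when the
// stride is forward, or [1-(P+1)*VF, 1-P*VF) when it is reverse, with the
// reverse case also reversing the lanes. A standalone sketch of the first
// element offset, assuming a fixed VF (illustrative names, not from the patch):
static long long firstElementOffset(unsigned Part, unsigned VF, bool Reverse) {
  if (!Reverse)
    return static_cast<long long>(Part) * VF;          // bitcast + GEP(VecPtr, Part)
  return 1LL - static_cast<long long>(Part + 1) * VF;  // GEP(Ptr, 1) + GEP(VecPtr, -Part-1)
}
// E.g. VF = 4: part 0 forward starts at element 0; part 0 reverse starts at
// element -3 and, after the reverse shuffle, yields lanes 0, -1, -2, -3.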
+ assert(LI && "Must have a load instruction"); + setDebugLocFromInst(Builder, LI); + for (unsigned Part = 0; Part < UF; ++Part) { + Instruction *NewLI; + if (CreateGatherScatter) { + Value *P = Predicate[Part]; + if (Legal->isMaskRequired(LI)) + P = Builder.CreateAnd(P, Mask[Part]); + + NewLI = Builder.CreateMaskedGather(VectorGep[Part], Alignment, + P, 0, "wide.masked.gather"); + Entry[Part] = NewLI; + } else { + // Calculate the pointer for the specific unroll-part. + Value *VecPtr; + + Value *MaskPart = Mask[Part]; + + if (UsePredication) + MaskPart = Builder.CreateAnd(MaskPart, Predicate[Part]); + + if (Reverse) { + // If the address is consecutive but reversed, then the + // wide load needs to start at the last vector element. + VecPtr = Builder.CreateGEP(nullptr, Ptr, Builder.getInt32(1)); + VecPtr = Builder.CreateBitCast(VecPtr, DataPtrTy); + VecPtr = Builder.CreateGEP(nullptr, VecPtr, Builder.getInt32(-Part-1)); + MaskPart = reverseVector(MaskPart); + } else { + VecPtr = Builder.CreateBitCast(Ptr, DataPtrTy); + VecPtr = Builder.CreateGEP(nullptr, VecPtr, Builder.getInt32(Part)); + } + + if (Legal->isMaskRequired(LI)) + NewLI = Builder.CreateMaskedLoad(VecPtr, Alignment, MaskPart, + UndefValue::get(DataTy), + "wide.masked.load"); + else if (UsePredication) { + Value* P = Reverse ? reverseVector(Predicate[Part]) : Predicate[Part]; + NewLI = Builder.CreateMaskedLoad(VecPtr, Alignment, P, + UndefValue::get(DataTy), + "wide.masked.load"); + } else + NewLI = Builder.CreateAlignedLoad(VecPtr, Alignment, "wide.load"); + Entry[Part] = Reverse ? reverseVector(NewLI) : NewLI; + } + addMetadata(NewLI, LI); + } +} + +/// Depending on the access pattern, either of three things happen with +/// the GetElementPtr instruction: +/// - GEP is loop invariant: +/// - GEP is not affine: +/// - GEP pointer is a vectorized GEP instruction: +/// GEP is replaced by a vector of pointers using arithmetic +/// - GEP is affine function of loop iteration counter: +/// GEP is replaced by a seriesvector(%ptr, %stride) +void InnerLoopVectorizer::vectorizeGEPInstruction(Instruction *Instr) { + GetElementPtrInst *Gep = cast(Instr); + + if (!isScalable()) { + scalarizeInstruction(Instr); + return; + } + + auto *SE = PSE.getSE(); + + // Handle all non loop invariant forms that are not affine, so that + // when used as address it can be transformed into a gather load/store, + // or when used as pointer arithmetic, it is just vectorized into + // arithmetic instructions. + auto *SAR = dyn_cast(SE->getSCEV(Gep)); + if (!SAR || !SAR->isAffine() || SE->isLoopInvariant(SAR, OrigLoop)) { + vectorizeArithmeticGEP(Gep); + return; + } + + // Create SCEV expander for Start- and StepValue + const DataLayout &DL = Instr->getModule()->getDataLayout(); + SCEVExpander Expander(*SE, DL, "seriesgep"); + + // Expand step and start value (the latter in preheader) + const SCEV *StepRec = SAR->getStepRecurrence(*SE); + + // If the step can't be divided by the type size of the GEP (for example if + // the type structure is { gep = { i64, i64 }, i64 }, then also use the + // pointer arithmetic vectorization. 
+ if (auto *StepC = dyn_cast(StepRec)) { + if (StepC->getAPInt().getZExtValue() % + DL.getTypeAllocSize(Gep->getType()->getPointerElementType())) { + vectorizeArithmeticGEP(Gep); + return; + } + } + + Value *StepValue = Expander.expandCodeFor(StepRec, StepRec->getType(), + &*Builder.GetInsertPoint()); + + // Try to find a smaller type for StepValue + const SCEV *BETC = SE->getMaxBackedgeTakenCount(OrigLoop); + if (auto * MaxIters = dyn_cast(BETC)) { + if (auto * CI = dyn_cast(StepValue)) { + // RequiredBits = active_bits(max_iterations * step_value) + APInt MaxItersV = MaxIters->getValue()->getValue(); + if (CI->isNegative()) + MaxItersV = MaxItersV.sextOrSelf(CI->getValue().getBitWidth()); + else + MaxItersV = MaxItersV.zextOrSelf(CI->getValue().getBitWidth()); + + APInt MaxVal = MaxItersV * CI->getValue(); + + // Try to reduce this type from i64 to something smaller + unsigned RequiredBits = MaxVal.getActiveBits(); + unsigned StepBits = StepValue->getType()->getIntegerBitWidth(); + while (RequiredBits <= StepBits && StepBits >= 32) + StepBits = StepBits >> 1; + + // Truncate the step value + Type *NewStepType = IntegerType::get( + Instr->getParent()->getContext(), StepBits << 1); + StepValue = Builder.CreateTrunc(StepValue, NewStepType); + } + } + + const SCEV *StartRec = SAR->getStart(); + Value *StartValue = Expander.expandCodeFor( + StartRec, Gep->getType(), LoopVectorPreHeader->getTerminator()); + + // Normalize Start offset for first iteration in case the + // Induction variable does not start at 0. + IRBuilder<>::InsertPoint IP = Builder.saveIP(); + Builder.SetInsertPoint(&*LoopVectorPreHeader->getTerminator()); + + Value *Base = Gep->getPointerOperand(); + Value *Tmp2 = Builder.CreateBitCast(StartValue, + Builder.getInt8PtrTy(Base->getType()->getPointerAddressSpace())); + + // We can zero extend the incoming value, because Induction is + // the unsigned iteration counter. + Value *Tmp = Induction->getIncomingValueForBlock(LoopVectorPreHeader); + Tmp = Builder.CreateZExtOrTrunc(Tmp, StepValue->getType()); + Tmp = Builder.CreateMul(StepValue, Tmp); + Tmp = Builder.CreateSub(ConstantInt::get(StepValue->getType(), 0), Tmp); + Tmp = Builder.CreateGEP(Tmp2, Tmp); + StartValue = Builder.CreateBitCast(Tmp, StartValue->getType()); + Builder.restoreIP(IP); + + // Normalize to be in #elements, not bytes + Type *ElemTy = Instr->getType()->getPointerElementType(); + Tmp = ConstantInt::get(StepValue->getType(), DL.getTypeAllocSize(ElemTy)); + StepValue = Builder.CreateSDiv(StepValue, Tmp); + + // Get the dynamic VL + Value *RuntimeVF = getRuntimeVF(StepValue->getType()); + + // Create the series vector + VectorParts &Entry = WidenMap.get(Instr); + + // Induction is always the widest induction type in the loop, + // but if that is not enough for evaluating the step, zero extend is + // fine because Induction is the iteration counter, always unsigned. 
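// The code below expresses every address as StartValue plus a series vector of
// element offsets: for scalar iteration counter Index, unroll part Part and
// lane L, the offset is (Index + Part * RuntimeVF + L) * StepElems, where
// StepElems is the per-iteration step already normalised to elements and
// RuntimeVF stands for VF (times vscale when scalable). A standalone sketch of
// that offset (illustrative names, not from the patch):
#include <cstdint>

static int64_t seriesGEPOffset(uint64_t Index, unsigned Part, unsigned Lane,
                               uint64_t RuntimeVF, int64_t StepElems) {
  int64_t Linear = static_cast<int64_t>(Index) +
                   static_cast<int64_t>(Part) * static_cast<int64_t>(RuntimeVF) +
                   static_cast<int64_t>(Lane);
  return Linear * StepElems;
}
// E.g. Index = 8, RuntimeVF = 4, StepElems = 2: part 0 covers element offsets
// {16, 18, 20, 22} and part 1 covers {24, 26, 28, 30}.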
+ Value *IterOffset = Builder.CreateZExtOrTrunc(Induction, StepValue->getType()); + IterOffset = Builder.CreateMul(IterOffset, StepValue); + for (unsigned Part = 0; Part < UF; ++Part) { + // Tmp = part * stride * VL + Value *UnrollOffset = ConstantInt::get(RuntimeVF->getType(), Part); + UnrollOffset = Builder.CreateMul(StepValue, UnrollOffset); + UnrollOffset = Builder.CreateMul(RuntimeVF, UnrollOffset); + + // Adjust offset for unrolled iteration + Value *Offset = Builder.CreateAdd(IterOffset, UnrollOffset); + Offset = Builder.CreateSeriesVector({VF,Scalable}, Offset, StepValue); + + // Address = getelementptr %scalarbase, seriesvector(0, step) + Entry[Part] = Builder.CreateGEP(StartValue, Offset); + } + + addMetadata(Entry, Instr); +} + +void InnerLoopVectorizer::scalarizeInstruction(Instruction *Instr, + bool IfPredicateStore) { + assert(!Instr->getType()->isAggregateType() && "Can't handle vectors"); + assert(!isScalable() && + "Cannot scalarize instruction with scalable vectorization"); + + // Holds vector parameters or scalars, in case of uniform vals. + SmallVector Params; + + setDebugLocFromInst(Builder, Instr); + + // Find all of the vectorized parameters. + for (unsigned op = 0, e = Instr->getNumOperands(); op != e; ++op) { + Value *SrcOp = Instr->getOperand(op); + + // If we are accessing the old induction variable, use the new one. + if (SrcOp == OldInduction) { + Params.push_back(getVectorValue(SrcOp)); + continue; + } + + // Try using previously calculated values. + Instruction *SrcInst = dyn_cast(SrcOp); + + // If the src is an instruction that appeared earlier in the basic block, + // then it should already be vectorized. + if (SrcInst && OrigLoop->contains(SrcInst)) { + assert(WidenMap.has(SrcInst) && "Source operand is unavailable"); + // The parameter is a vector value from earlier. + Params.push_back(WidenMap.get(SrcInst)); + } else { + // The parameter is a scalar from outside the loop. Maybe even a constant. + VectorParts Scalars; + Scalars.append(UF, SrcOp); + Params.push_back(Scalars); + } + } + + assert(Params.size() == Instr->getNumOperands() && + "Invalid number of operands"); + + // Does this instruction return a value ? + bool IsVoidRetTy = Instr->getType()->isVoidTy(); + + Value *UndefVec = + IsVoidRetTy ? nullptr + : UndefValue::get(VectorType::get(Instr->getType(), VF)); + // Create a new entry in the WidenMap and initialize it to Undef or Null. + VectorParts &VecResults = WidenMap.splat(Instr, UndefVec); + + VectorParts Cond; + if (IfPredicateStore) { + assert(Instr->getParent()->getSinglePredecessor() && + "Only support single predecessor blocks"); + Cond = createEdgeMask(Instr->getParent()->getSinglePredecessor(), + Instr->getParent()); + } + + // For each vector unroll 'part': + for (unsigned Part = 0; Part < UF; ++Part) { + // For each scalar that we create: + for (unsigned Width = 0; Width < VF; ++Width) { + + // Start if-block. + Value *Cmp = nullptr; + if (IfPredicateStore) { + Cmp = Builder.CreateExtractElement(Cond[Part], Builder.getInt32(Width)); + Cmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, Cmp, + ConstantInt::get(Cmp->getType(), 1)); + } + + Instruction *Cloned = Instr->clone(); + if (!IsVoidRetTy) + Cloned->setName(Instr->getName() + ".cloned"); + // Replace the operands of the cloned instructions with extracted scalars. + for (unsigned op = 0, e = Instr->getNumOperands(); op != e; ++op) { + Value *Op = Params[op][Part]; + // Param is a vector. Need to extract the right lane. 
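// Scalarization replays the original instruction once per unroll part and per
// lane: each vector operand contributes that lane's element, the clone computes
// the scalar result, and the result is inserted back into the widened value.
// A standalone model of the per-lane loop, assuming a simple binary operation
// in place of the cloned instruction (illustrative names, not from the patch):
#include <cstddef>
#include <vector>

static std::vector<int> scalarizeModel(const std::vector<int> &A,
                                       const std::vector<int> &B,
                                       int (*ScalarOp)(int, int)) {
  std::vector<int> Result(A.size());
  for (std::size_t Lane = 0; Lane != A.size(); ++Lane)   // one clone per lane
    Result[Lane] = ScalarOp(A[Lane], B[Lane]);           // extract, compute, insert
  return Result;
}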
+ if (Op->getType()->isVectorTy()) + Op = Builder.CreateExtractElement(Op, Builder.getInt32(Width)); + Cloned->setOperand(op, Op); + } + addNewMetadata(Cloned, Instr); + + // Place the cloned scalar in the new loop. + Builder.Insert(Cloned); + + // If we just cloned a new assumption, add it the assumption cache. + if (auto *II = dyn_cast(Cloned)) + if (II->getIntrinsicID() == Intrinsic::assume) + AC->registerAssumption(II); + + // If the original scalar returns a value we need to place it in a vector + // so that future users will be able to use it. + if (!IsVoidRetTy) + VecResults[Part] = Builder.CreateInsertElement(VecResults[Part], Cloned, + Builder.getInt32(Width)); + // End if-block. + if (IfPredicateStore) + PredicatedStores.push_back( + std::make_pair(cast(Cloned), Cmp)); + } + } +} + +static Instruction *getFirstInst(Instruction *FirstInst, Value *V, + Instruction *Loc) { + if (FirstInst) + return FirstInst; + if (Instruction *I = dyn_cast(V)) + return I->getParent() == Loc->getParent() ? I : nullptr; + return nullptr; +} + +std::pair +InnerLoopVectorizer::addStrideCheck(Instruction *Loc) { + Instruction *tnullptr = nullptr; + if (!Legal->mustCheckStrides()) + return std::pair(tnullptr, tnullptr); + + IRBuilder<> ChkBuilder(Loc); + + // Emit checks. + Value *Check = nullptr; + Instruction *FirstInst = nullptr; + for (SmallPtrSet::iterator SI = Legal->strides_begin(), + SE = Legal->strides_end(); + SI != SE; ++SI) { + Value *Ptr = stripIntegerCast(*SI); + Value *C = ChkBuilder.CreateICmpNE(Ptr, ConstantInt::get(Ptr->getType(), 1), + "stride.chk"); + // Store the first instruction we create. + FirstInst = getFirstInst(FirstInst, C, Loc); + if (Check) + Check = ChkBuilder.CreateOr(Check, C); + else + Check = C; + } + + // We have to do this trickery because the IRBuilder might fold the check to a + // constant expression in which case there is no Instruction anchored in a + // the block. + LLVMContext &Ctx = Loc->getContext(); + Instruction *TheCheck = + BinaryOperator::CreateAnd(Check, ConstantInt::getTrue(Ctx)); + ChkBuilder.Insert(TheCheck, "stride.not.one"); + FirstInst = getFirstInst(FirstInst, TheCheck, Loc); + + return std::make_pair(FirstInst, TheCheck); +} + +PHINode *InnerLoopVectorizer::createInductionVariable(Loop *L, Value *Start, + Value *End, Value *Step, + Instruction *DL) { + BasicBlock *Header = L->getHeader(); + BasicBlock *Latch = L->getLoopLatch(); + // As we're just creating this loop, it's possible no latch exists + // yet. If so, use the header as this will be a single block loop. + if (!Latch) + Latch = Header; + + IRBuilder<> Builder(&*Header->getFirstInsertionPt()); + setDebugLocFromInst(Builder, getDebugLocFromInstOrOperands(OldInduction)); + + auto *PredTy = VectorType::get(Builder.getInt1Ty(), VF, Scalable); + auto *AllActive = ConstantInt::getTrue(PredTy); + + auto *Induction = Builder.CreatePHI(Start->getType(), 2, "index"); + for (unsigned i = 0; i < UF; ++i) + Predicate.push_back(Builder.CreatePHI(PredTy, 2, "predicate")); + + Builder.SetInsertPoint(Latch->getTerminator()); + + // Create i+1 and fill the PHINode. + Value *Next = Builder.CreateAdd(Induction, Step, "index.next"); + Induction->addIncoming(Start, L->getLoopPreheader()); + Induction->addIncoming(Next, Latch); + + // Even though all lanes are active some code paths require a predicate. + for (unsigned i = 0; i < UF; ++i) { + Predicate[i]->addIncoming(AllActive, L->getLoopPreheader()); + Predicate[i]->addIncoming(AllActive, Latch); + } + + // Create the compare. 
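// The compare pairs with the induction update created above: index advances by
// Step (VF * UF, times vscale when scalable) and the loop exits once it equals
// End. A scalar model of the control flow being generated, for the
// non-predicated layout where End is the vector trip count and therefore a
// multiple of Step away from Start (illustrative names, not from the patch):
#include <cstdint>

static uint64_t vectorLoopModel(uint64_t Start, uint64_t End, uint64_t Step) {
  // Precondition (guaranteed by the zero-trip-count bypass): End > Start.
  uint64_t Index = Start;             // the "index" PHI
  do {
    // ... vector body processes elements [Index, Index + Step) ...
    Index += Step;                    // "index.next"
  } while (Index != End);             // the ICmpEQ / CondBr created below
  return Index;
}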
+ Value *ICmp = Builder.CreateICmpEQ(Next, End); + Builder.CreateCondBr(ICmp, L->getExitBlock(), Header); + + // Now we have two terminators. Remove the old one from the block. + Latch->getTerminator()->eraseFromParent(); + + return Induction; +} + +Value *InnerLoopVectorizer::getOrCreateTripCount(Loop *L) { + if (TripCount) + return TripCount; + + IRBuilder<> Builder(L->getLoopPreheader()->getTerminator()); + // Find the loop boundaries. + ScalarEvolution *SE = PSE.getSE(); + const SCEV *BackedgeTakenCount = PSE.getBackedgeTakenCount(); + assert(BackedgeTakenCount != SE->getCouldNotCompute() && + "Invalid loop count"); + + Type *IdxTy = Legal->getWidestInductionType(); + + // The exit count might have the type of i64 while the phi is i32. This can + // happen if we have an induction variable that is sign extended before the + // compare. The only way that we get a backedge taken count is that the + // induction variable was signed and as such will not overflow. In such a case + // truncation is legal. + if (BackedgeTakenCount->getType()->getPrimitiveSizeInBits() > + IdxTy->getPrimitiveSizeInBits()) + BackedgeTakenCount = SE->getTruncateOrNoop(BackedgeTakenCount, IdxTy); + BackedgeTakenCount = SE->getNoopOrZeroExtend(BackedgeTakenCount, IdxTy); + + // Get the total trip count from the count by adding 1. + const SCEV *ExitCount = SE->getAddExpr( + BackedgeTakenCount, SE->getOne(BackedgeTakenCount->getType())); + + const DataLayout &DL = L->getHeader()->getModule()->getDataLayout(); + + // Expand the trip count and place the new instructions in the preheader. + // Notice that the pre-header does not change, only the loop body. + SCEVExpander Exp(*SE, DL, "induction"); + + // Count holds the overall loop count (N). + TripCount = Exp.expandCodeFor(ExitCount, ExitCount->getType(), + L->getLoopPreheader()->getTerminator()); + + if (TripCount->getType()->isPointerTy()) + TripCount = + CastInst::CreatePointerCast(TripCount, IdxTy, "exitcount.ptrcnt.to.int", + L->getLoopPreheader()->getTerminator()); + + return TripCount; +} + +Value *InnerLoopVectorizer::getOrCreateVectorTripCount(Loop *L) { + if (VectorTripCount) + return VectorTripCount; + + Value *TC = getOrCreateTripCount(L); + if (UsePredication) { + // All iterations are done by the vector body so VectorTripCount==TripCount. + VectorTripCount = TC; + return VectorTripCount; + } + + IRBuilder<> Builder(L->getLoopPreheader()->getTerminator()); + Constant *InductionStep = getInductionStep(); + + // Now we need to generate the expression for N - (N % VF), which is + // the part that the vectorized body will execute. + // The loop step is equal to the vectorization factor (num of SIMD elements) + // times the unroll factor (num of SIMD instructions). + Value *R = Builder.CreateURem(TC, InductionStep, "n.mod.vf"); + + // If there is a non-reversed interleaved group that may speculatively access + // memory out-of-bounds, we need to ensure that there will be at least one + // iteration of the scalar epilogue loop. Thus, if the step evenly divides + // the trip count, we set the remainder to be equal to the step. If the step + // does not evenly divide the trip count, no adjustment is necessary since + // there will already be scalar iterations. Note that the minimum iterations + // check ensures that N >= Step. 
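// In scalar terms the trip-count computation here is: R = N % Step; if a
// scalar epilogue is required and R is zero, force R = Step; the vector loop
// then covers N - R iterations. A standalone sketch, assuming a fixed
// Step = VF * UF (illustrative names, not from the patch):
#include <cstdint>

static uint64_t vectorTripCount(uint64_t N, uint64_t Step,
                                bool RequiresScalarEpilogue) {
  uint64_t R = N % Step;                  // "n.mod.vf"
  if (RequiresScalarEpilogue && R == 0)
    R = Step;                             // keep at least one scalar iteration
  return N - R;                           // "n.vec"
}
// E.g. N = 20, Step = 8 gives R = 4 and a vector trip count of 16; N = 16 with
// a required epilogue gives R = 8 and a vector trip count of 8.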
+ if (VF > 1 && !Scalable && Legal->requiresScalarEpilogue()) { + auto *IsZero = Builder.CreateICmpEQ(R, ConstantInt::get(R->getType(), 0)); + R = Builder.CreateSelect(IsZero, InductionStep, R); + } + + VectorTripCount = Builder.CreateSub(TC, R, "n.vec"); + + return VectorTripCount; +} + +Constant *InnerLoopVectorizer::getInductionStep() { + Type *Ty = Legal->getWidestInductionType(); + return ConstantExpr::getMul(getRuntimeVF(Ty), ConstantInt::get(Ty, UF)); +} + +void InnerLoopVectorizer::emitMinimumIterationCountCheck(Loop *L, + Value *MinCount, + BasicBlock *Bypass) { + Value *Count = getOrCreateTripCount(L); + BasicBlock *BB = L->getLoopPreheader(); + IRBuilder<> Builder(BB->getTerminator()); + + // Generate code to check that the loop's trip count that we computed by + // adding one to the backedge-taken count will not overflow. + Value *CheckMinIters = Builder.CreateICmpULT(Count, MinCount, + "min.iters.check"); + + BasicBlock *NewBB = + BB->splitBasicBlock(BB->getTerminator(), "min.iters.checked"); + // Update dominator tree immediately if the generated block is a + // LoopBypassBlock because SCEV expansions to generate loop bypass + // checks may query it before the current function is finished. + DT->addNewBlock(NewBB, BB); + if (L->getParentLoop()) + L->getParentLoop()->addBasicBlockToLoop(NewBB, *LI); + ReplaceInstWithInst(BB->getTerminator(), + BranchInst::Create(Bypass, NewBB, CheckMinIters)); + LoopBypassBlocks.push_back(BB); +} + +// Protect against the vector loop's induction variable overflowing. +void InnerLoopVectorizer::emitIVOverflowCheck(Loop *L, BasicBlock *Bypass) { + // There can be no overflow when the induction step is a power of 2. + assert(isPowerOf2_64(VF*UF) && "Unexpected value for VF*UF!"); + if (!Scalable) + return; + + ScalarEvolution *SE = PSE.getSE(); + Value *Count = getOrCreateTripCount(L); + BasicBlock *BB = L->getLoopPreheader(); + IRBuilder<> Builder(BB->getTerminator()); + + // Generate code to test if "%index + (VF * UL * VScale)" will overflow. + auto Ty = Count->getType(); + auto MaxUInt = Constant::getAllOnesValue(Ty); + auto RedZone = ConstantExpr::getSub(MaxUInt, getInductionStep()); + + // NOTE: This assumes the calculation of Count will not overflow, as proven + // by emitMinimumIterationCountCheck. + if (!SE->isKnownPredicate(CmpInst::ICMP_ULE, + SE->getSCEV(Count), SE->getSCEV(RedZone))) { + Value *CheckOF = Builder.CreateICmpUGT(Count, RedZone, "overflow.check"); + + BasicBlock *NewBB = + BB->splitBasicBlock(BB->getTerminator(), "overflow.checked"); + // Update dominator tree immediately if the generated block is a + // LoopBypassBlock because SCEV expansions to generate loop bypass + // checks may query it before the current function is finished. + DT->addNewBlock(NewBB, BB); + if (L->getParentLoop()) + L->getParentLoop()->addBasicBlockToLoop(NewBB, *LI); + ReplaceInstWithInst(BB->getTerminator(), + BranchInst::Create(Bypass, NewBB, CheckOF)); + LoopBypassBlocks.push_back(BB); + } +} + +void InnerLoopVectorizer::emitVectorLoopEnteredCheck(Loop *L, + BasicBlock *Bypass) { + Value *TC = getOrCreateVectorTripCount(L); + BasicBlock *BB = L->getLoopPreheader(); + IRBuilder<> Builder(BB->getTerminator()); + + // Now, compare the new count to zero. If it is zero skip the vector loop and + // jump to the scalar loop. 
+ Value *Cmp = Builder.CreateICmpEQ(TC, Constant::getNullValue(TC->getType()), + "cmp.zero"); + + // Generate code to check that the loop's trip count that we computed by + // adding one to the backedge-taken count will not overflow. + BasicBlock *NewBB = BB->splitBasicBlock(BB->getTerminator(), "vector.ph"); + // Update dominator tree immediately if the generated block is a + // LoopBypassBlock because SCEV expansions to generate loop bypass + // checks may query it before the current function is finished. + DT->addNewBlock(NewBB, BB); + if (L->getParentLoop()) + L->getParentLoop()->addBasicBlockToLoop(NewBB, *LI); + ReplaceInstWithInst(BB->getTerminator(), + BranchInst::Create(Bypass, NewBB, Cmp)); + LoopBypassBlocks.push_back(BB); +} + +void InnerLoopVectorizer::emitSCEVChecks(Loop *L, BasicBlock *Bypass) { + BasicBlock *BB = L->getLoopPreheader(); + + // Generate the code to check that the SCEV assumptions that we made. + // We want the new basic block to start at the first instruction in a + // sequence of instructions that form a check. + SCEVExpander Exp(*PSE.getSE(), Bypass->getModule()->getDataLayout(), + "scev.check"); + Value *SCEVCheck = + Exp.expandCodeForPredicate(&PSE.getUnionPredicate(), BB->getTerminator()); + + if (auto *C = dyn_cast(SCEVCheck)) + if (C->isZero()) + return; + + // Create a new block containing the stride check. + BB->setName("vector.scevcheck"); + auto *NewBB = BB->splitBasicBlock(BB->getTerminator(), "vector.ph"); + // Update dominator tree immediately if the generated block is a + // LoopBypassBlock because SCEV expansions to generate loop bypass + // checks may query it before the current function is finished. + DT->addNewBlock(NewBB, BB); + if (L->getParentLoop()) + L->getParentLoop()->addBasicBlockToLoop(NewBB, *LI); + ReplaceInstWithInst(BB->getTerminator(), + BranchInst::Create(Bypass, NewBB, SCEVCheck)); + LoopBypassBlocks.push_back(BB); + AddedSafetyChecks = true; +} + +void InnerLoopVectorizer::emitMemRuntimeChecks(Loop *L, BasicBlock *Bypass) { + BasicBlock *BB = L->getLoopPreheader(); + + // Generate the code that checks in runtime if arrays overlap. We put the + // checks into a separate block to make the more common case of few elements + // faster. + Instruction *FirstCheckInst; + Instruction *MemRuntimeCheck; + std::tie(FirstCheckInst, MemRuntimeCheck) = + Legal->getLAI()->addRuntimeChecks(BB->getTerminator()); + if (!MemRuntimeCheck) + return; + + // Create a new block containing the memory check. + BB->setName("vector.memcheck"); + auto *NewBB = BB->splitBasicBlock(BB->getTerminator(), "vector.ph"); + // Update dominator tree immediately if the generated block is a + // LoopBypassBlock because SCEV expansions to generate loop bypass + // checks may query it before the current function is finished. + DT->addNewBlock(NewBB, BB); + if (L->getParentLoop()) + L->getParentLoop()->addBasicBlockToLoop(NewBB, *LI); + ReplaceInstWithInst(BB->getTerminator(), + BranchInst::Create(Bypass, NewBB, MemRuntimeCheck)); + LoopBypassBlocks.push_back(BB); + AddedSafetyChecks = true; + + // We currently don't use LoopVersioning for the actual loop cloning but we + // still use it to add the noalias metadata. + LVer = llvm::make_unique(*Legal->getLAI(), OrigLoop, LI, DT, + PSE.getSE()); + LVer->prepareNoAliasMetadata(); +} + +void InnerLoopVectorizer::createEmptyLoop() { + /* + In this function we generate a new loop. The new loop will contain + the vectorized instructions while the old loop will continue to run the + scalar remainder. 
+ + [ ] <-- loop iteration number check. + / | + / v + | [ ] <-- vector loop bypass (may consist of multiple blocks). + | / | + | / v + || [ ] <-- vector pre header. + || | + || v + || [ ] \ + || [ ]_| <-- vector loop. + || | + | \ v + | >[ ] <--- middle-block. + | / | + | / | + | / v + -|- >[ ] <--- new preheader. + | | + | v + | [ ] \ + | [ ]_| <-- old scalar loop to handle remainder. + \ | + \ v + >[ ] <-- exit block. + ... + */ + + BasicBlock *OldBasicBlock = OrigLoop->getHeader(); + BasicBlock *VectorPH = OrigLoop->getLoopPreheader(); + BasicBlock *ExitBlock = OrigLoop->getExitBlock(); + assert(VectorPH && "Invalid loop structure"); + assert(ExitBlock && "Must have an exit block"); + + // Some loops have a single integer induction variable, while other loops + // don't. One example is c++ iterators that often have multiple pointer + // induction variables. In the code below we also support a case where we + // don't have a single induction variable. + // + // We try to obtain an induction variable from the original loop as hard + // as possible. However if we don't find one that: + // - is an integer + // - counts from zero, stepping by one + // - is the size of the widest induction variable type + // then we create a new one. + OldInduction = Legal->getPrimaryInduction(); + Type *IdxTy = Legal->getWidestInductionType(); + + // Split the single block loop into the two loop structure described above. + BasicBlock *VecBody = + VectorPH->splitBasicBlock(VectorPH->getTerminator(), "vector.body"); + BasicBlock *MiddleBlock = + VecBody->splitBasicBlock(VecBody->getTerminator(), "middle.block"); + BasicBlock *ScalarPH = + MiddleBlock->splitBasicBlock(MiddleBlock->getTerminator(), "scalar.ph"); + + // Create and register the new vector loop. + Loop *Lp = LI->AllocateLoop(); + Loop *ParentLoop = OrigLoop->getParentLoop(); + + if (ParentLoop) { + ParentLoop->addChildLoop(Lp); + ParentLoop->addBasicBlockToLoop(ScalarPH, *LI); + ParentLoop->addBasicBlockToLoop(MiddleBlock, *LI); + } else { + LI->addTopLevelLoop(Lp); + } + Lp->addBasicBlockToLoop(VecBody, *LI); + + // Find the loop boundaries. + Value *Count = getOrCreateTripCount(Lp); + Value *StartIdx = ConstantInt::get(IdxTy, 0); + Constant *Step = getInductionStep(); + + // We need to test whether the backedge-taken count is uint##_max. Adding one + // to it will cause overflow and an incorrect loop trip count in the vector + // body. In case of overflow we want to directly jump to the scalar remainder + // loop. + emitMinimumIterationCountCheck(Lp, Step, ScalarPH); + // Now, compare the new count to zero. If it is zero skip the vector loop and + // jump to the scalar loop. + emitVectorLoopEnteredCheck(Lp, ScalarPH); + // Generate the code to check any assumptions that we've made for SCEV + // expressions. + emitSCEVChecks(Lp, ScalarPH); + // Generate the code that checks in runtime if arrays overlap. We put the + // checks into a separate block to make the more common case of few elements + // faster. + emitMemRuntimeChecks(Lp, ScalarPH); + + // Generate the induction variable. + // The loop step is equal to the vectorization factor (num of SIMD elements) + // times the unroll factor (num of SIMD instructions). + Value *CountRoundDown = getOrCreateVectorTripCount(Lp); + + Induction = + createInductionVariable(Lp, StartIdx, CountRoundDown, Step, + getDebugLocFromInstOrOperands(OldInduction)); + + // We are going to resume the execution of the scalar loop. 
+ // Go over all of the induction variables that we found and fix the + // PHIs that are left in the scalar version of the loop. + // The starting values of PHI nodes depend on the counter of the last + // iteration in the vectorized loop. + // If we come from a bypass edge then we need to start from the original + // start value. + + // This variable saves the new starting index for the scalar loop. It is used + // to test if there are any tail iterations left once the vector loop has + // completed. + LoopVectorizationLegality::InductionList::iterator I, E; + LoopVectorizationLegality::InductionList *List = Legal->getInductionVars(); + for (I = List->begin(), E = List->end(); I != E; ++I) { + PHINode *OrigPhi = I->first; + InductionDescriptor II = I->second; + + // Create phi nodes to merge from the backedge-taken check block. + PHINode *BCResumeVal = PHINode::Create( + OrigPhi->getType(), 3, "bc.resume.val", ScalarPH->getTerminator()); + Value *&EndValue = IVEndValues[OrigPhi]; + if (OrigPhi == OldInduction) { + // We know what the end value is. + EndValue = CountRoundDown; + } else { + IRBuilder<> B(LoopBypassBlocks.back()->getTerminator()); + Type *StepType = II.getStep()->getType(); + Instruction::CastOps CastOp = + CastInst::getCastOpcode(CountRoundDown, true, StepType, true); + Value *CRD = B.CreateCast(CastOp, CountRoundDown, StepType, "cast.crd"); + const DataLayout &DL = OrigLoop->getHeader()->getModule()->getDataLayout(); + EndValue = II.transform(B, CRD, PSE.getSE(), DL); + EndValue->setName("ind.end"); + } + + // The new PHI merges the original incoming value, in case of a bypass, + // or the value at the end of the vectorized loop. + BCResumeVal->addIncoming(EndValue, MiddleBlock); + + // Fix the scalar body counter (PHI node). + unsigned BlockIdx = OrigPhi->getBasicBlockIndex(ScalarPH); + + // The old induction's phi node in the scalar body needs the truncated + // value. + for (unsigned I = 0, E = LoopBypassBlocks.size(); I != E; ++I) + BCResumeVal->addIncoming(II.getStartValue(), LoopBypassBlocks[I]); + OrigPhi->setIncomingValue(BlockIdx, BCResumeVal); + } + + // Add a check in the middle block to see if we have completed + // all of the iterations in the first vector loop. + // If (N - N%VF) == N, then we *don't* need to run the remainder. + Value *CmpN = + CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ, Count, + CountRoundDown, "cmp.n", MiddleBlock->getTerminator()); + ReplaceInstWithInst(MiddleBlock->getTerminator(), + BranchInst::Create(ExitBlock, ScalarPH, CmpN)); + + // Get ready to start creating new instructions into the vectorized body. + Builder.SetInsertPoint(&*VecBody->getFirstInsertionPt()); + + // Save the state. + LoopVectorPreHeader = Lp->getLoopPreheader(); + LoopScalarPreHeader = ScalarPH; + LoopMiddleBlock = MiddleBlock; + LoopExitBlock = ExitBlock; + LoopVectorBody.push_back(VecBody); + VecBodyPostDom = VecBody; + LoopScalarBody = OldBasicBlock; + + // Keep all loop hints from the original loop on the vector loop (we'll + // replace the vectorizer-specific hints below). + if (MDNode *LID = OrigLoop->getLoopID()) + Lp->setLoopID(LID); + + LoopVectorizeHints Hints(Lp, true, *ORE); + Hints.setAlreadyVectorized(); +} + +void InnerLoopVectorizer::createEmptyLoopWithPredication() { + /* + In this function we generate a new loop. The new loop will contain + the vectorized instructions while the old loop will continue to run the + scalar remainder. + + v + [ ] <-- Back-edge taken count overflow check. 
+ / \ + | [ ] <-- vector loop bypass (may consist of multiple blocks). + | / \ + | [ ] \ <-- vector pre header. + | | | + | />[ ] | + | |_[ ] | <-- vector loop. + | | | + | [ ] | <-- middle-block. + | | | + | [ ] | <-- return point from predicated loop. + | | | + |<---/ | + | [ ] <-- scalar preheader. + | | + | [ ]<\ + | [ ]_| <-- old scalar loop to handle remainder. + | | + |<-------/ + v + [ ] <-- exit block. + ... + */ + + assert(UsePredication && "predication required for this layout"); + BasicBlock *OldBasicBlock = OrigLoop->getHeader(); + BasicBlock *VectorPH = OrigLoop->getLoopPreheader(); + BasicBlock *ExitBlock = OrigLoop->getExitBlock(); + assert(VectorPH && "Invalid loop structure"); + assert(ExitBlock && "Must have an exit block"); + + // Some loops have a single integer induction variable, while other loops + // don't. One example is c++ iterators that often have multiple pointer + // induction variables. In the code below we also support a case where we + // don't have a single induction variable. + // + // We try to obtain an induction variable from the original loop as hard + // as possible. However if we don't find one that: + // - is an integer + // - counts from zero, stepping by one + // - is the size of the widest induction variable type + // then we create a new one. + OldInduction = Legal->getPrimaryInduction(); + Type *IdxTy = Legal->getWidestInductionType(); + Type *PredTy = Builder.getInt1Ty(); + Type *PredVecTy = VectorType::get(PredTy, VF, Scalable); + + // Split the single block loop into the two loop structure described above. + BasicBlock *VecBody = + VectorPH->splitBasicBlock(VectorPH->getTerminator(), "vector.body"); + BasicBlock *MiddleBlock = + VecBody->splitBasicBlock(VecBody->getTerminator(), "middle.block"); + BasicBlock *ScalarPH = + MiddleBlock->splitBasicBlock(MiddleBlock->getTerminator(), "scalar.ph"); + + // Create and register the new vector loop. + Loop *Lp = LI->AllocateLoop(); + Loop *ParentLoop = OrigLoop->getParentLoop(); + + if (ParentLoop) { + ParentLoop->addChildLoop(Lp); + ParentLoop->addBasicBlockToLoop(ScalarPH, *LI); + ParentLoop->addBasicBlockToLoop(MiddleBlock, *LI); + + } else { + LI->addTopLevelLoop(Lp); + } + Lp->addBasicBlockToLoop(VecBody, *LI); + + // Find the loop boundaries. + Value *Count = getOrCreateTripCount(Lp); + Value *StartIdx = ConstantInt::get(IdxTy, 0); + IdxEnd = Count; + + // We need to test whether the backedge-taken count is uint##_max. Adding one + // to it will cause overflow and an incorrect loop trip count in the vector + // body. In case of overflow we want to directly jump to the scalar loop. + { + ScalarEvolution *SE = PSE.getSE(); + const SCEV *BackedgeTakenCount = PSE.getBackedgeTakenCount(); + Type *IdxTy = Legal->getWidestInductionType(); + + // The exit count might have the type of i64 while the phi is i32. This can + // happen if we have an induction variable that is sign extended before the + // compare. The only way that we get a backedge taken count is that the + // induction variable was signed and as such will not overflow. In such a + // case truncation is legal. + if (BackedgeTakenCount->getType()->getPrimitiveSizeInBits() > + IdxTy->getPrimitiveSizeInBits()) + BackedgeTakenCount = SE->getTruncateOrNoop(BackedgeTakenCount, IdxTy); + BackedgeTakenCount = SE->getNoopOrZeroExtend(BackedgeTakenCount, IdxTy); + + // If we know ahead of time that overflow is not possible we still plant + // the check but in a manner that is easily removable by a later pass. 
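+ // Illustrative sketch, not part of the patch: assuming an i64 trip count and
+ // the emitMinimumIterationCountCheck() helper introduced earlier in this
+ // patch, the guard planted below reduces to something like
+ //   %min.iters.check = icmp ult i64 %trip.count, 1   ; overflow possible
+ //   %min.iters.check = icmp ult i64 %trip.count, 0   ; provably impossible
+ //   br i1 %min.iters.check, label %scalar.ph, label %vector.ph
+ // so the always-false form is trivially removed by later passes.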
+ APInt MaxTakenCount = SE->getUnsignedRangeMax(BackedgeTakenCount);
+ Constant *MinCount = ConstantInt::get(IdxTy, (MaxTakenCount + 1) == 0);
+ emitMinimumIterationCountCheck(Lp, MinCount, ScalarPH);
+ }
+ // Generate the code to check %index update will not overflow, thus allowing
+ // us to branch back based purely on the result of %predicate.next.
+ emitIVOverflowCheck(Lp, ScalarPH);
+ // Generate the code to check any assumptions that we've made for SCEV
+ // expressions.
+ emitSCEVChecks(Lp, ScalarPH);
+ // Generate the code that checks in runtime if arrays overlap. We put the
+ // checks into a separate block to make the more common case of few elements
+ // faster.
+ emitMemRuntimeChecks(Lp, ScalarPH);
+
+ // Record the exit value of induction variables for use by fixupIVUsers.
+ for (auto &Entry : *Legal->getInductionVars()) {
+ PHINode *OrigPhi = Entry.first;
+ InductionDescriptor II = Entry.second;
+
+ IRBuilder<> B(LoopBypassBlocks.back()->getTerminator());
+ auto StepType = II.getStep()->getType();
+ auto CastOp = CastInst::getCastOpcode(IdxEnd, true, StepType, true);
+ auto CRD = B.CreateCast(CastOp, IdxEnd, StepType, "cast.crd");
+ const DataLayout &DL = OrigLoop->getHeader()->getModule()->getDataLayout();
+
+ Value *&EndValue = IVEndValues[OrigPhi];
+ EndValue = II.transform(B, CRD, PSE.getSE(), DL);
+ EndValue->setName("ind.end");
+ }
+
+ // ***************************************************************************
+ // Start of vector.ph
+ // ***************************************************************************
+
+ Builder.SetInsertPoint(&*Lp->getLoopPreheader()->getTerminator());
+ setDebugLocFromInst(Builder, getDebugLocFromInstOrOperands(OldInduction));
+
+ VectorType::ElementCount IdxEltCnt(VF, Scalable);
+ IdxEndV = Builder.CreateVectorSplat(IdxEltCnt, Count, "wide.end.idx");
+
+ VectorParts EntryPreds;
+ Value *RuntimeVF = getRuntimeVF(IdxTy);
+ Constant *StepVec = StepVector::get(IdxEndV->getType());
+
+ // Compute the entry predicates.
+ // NOTE: The vector loop is guarded by a check to ensure the calculation of
+ // index.next will not overflow. The vector loop's latch is based on this
+ // value alone, thus we can safely add NUW flags across all lanes as they'll
+ // only be used within the vector loop and considered poisoned upon exit.
+ for (unsigned i = 0; i < UF; ++i) {
+ Value *Step = Builder.CreateMul(RuntimeVF, ConstantInt::get(IdxTy, i));
+ Value *Idx = Builder.CreateNUWAdd(StartIdx, Step);
+ Value *IdxSplat = Builder.CreateVectorSplat(IdxEltCnt, Idx);
+ Value *SV = Builder.CreateNUWAdd(IdxSplat, StepVec);
+ Value *Cmp = Builder.CreateICmpULT(SV, IdxEndV, "predicate.entry");
+ EntryPreds.push_back(Cmp);
+ }
+
+ // ***************************************************************************
+ // End of vector.ph
+ // ***************************************************************************
+
+ // ***************************************************************************
+ // Start of vector.body
+ // ***************************************************************************
+
+ BasicBlock *Header = Lp->getHeader();
+ BasicBlock *Latch = Lp->getLoopLatch();
+ // As we're just creating this loop, it's possible no latch exists
+ // yet. If so, use the header as this will be a single block loop.
+ if (!Latch)
+ Latch = Header;
+
+ Builder.SetInsertPoint(&*Header->getFirstInsertionPt());
+ setDebugLocFromInst(Builder, getDebugLocFromInstOrOperands(OldInduction));
+
+ // Generate the induction variable.
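+ // Illustrative sketch, assuming UF = 1 and a scalable VF of 4 (names taken
+ // from the builder calls above and below): the preheader ends with
+ //   %predicate.entry = icmp ult <vscale x 4 x i64> %stepvec, %wide.end.idx
+ // and the header built below carries the recurring state as
+ //   %index = phi i64 [ 0, %vector.ph ], [ %index.next, %vector.body ]
+ //   %predicate = phi <vscale x 4 x i1> [ %predicate.entry, %vector.ph ],
+ //                                      [ %predicate.next, %vector.body ]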
+ Induction = Builder.CreatePHI(IdxTy, 2, "index"); + for (unsigned i = 0; i < UF; ++i) + Predicate.push_back(Builder.CreatePHI(PredVecTy, 2, "predicate")); + + // These Phis have two incoming values, but right now we only add the + // one coming from the preheader. The other (from the loop latch block) + // will be added in 'patchLatchBranch', after everything else has been + // vectorized. This allows predicates from first-faulting loads or other + // instructions to be added in before finalizing the phi. + Induction->addIncoming(StartIdx, Lp->getLoopPreheader()); + for (unsigned i = 0; i < UF; ++i) + Predicate[i]->addIncoming(EntryPreds[i], Lp->getLoopPreheader()); + + Builder.SetInsertPoint(Latch->getTerminator()); + + // We don't yet have a condition for the branch, since it may depend on + // instructions within the loop (beyond just the trip count, if any). + // As above, this will be added in 'patchLatchBranch'. + Value *ICmp = UndefValue::get(Builder.getInt1Ty()); + LatchBranch = Builder.CreateCondBr(ICmp, Header, Lp->getExitBlock()); + // Now we have two terminators. Remove the old one from the block. + Latch->getTerminator()->eraseFromParent(); + + // *************************************************************************** + // End of vector.body + // *************************************************************************** + + // *************************************************************************** + // Start of reduction.loop.ret + // *************************************************************************** + + // The vector body processes all elements so after the reduction we are done. + Instruction *OldTerm = MiddleBlock->getTerminator(); + BranchInst::Create(ExitBlock, OldTerm); + OldTerm->eraseFromParent(); + + // *************************************************************************** + // End of reduction.loop.ret + // *************************************************************************** + + // Get ready to start creating new instructions into the vectorized body. + Builder.SetInsertPoint(&*VecBody->getFirstInsertionPt()); + + // Save the state. + LoopVectorPreHeader = Lp->getLoopPreheader(); + LoopScalarPreHeader = ScalarPH; + LoopMiddleBlock = MiddleBlock; + LoopExitBlock = ExitBlock; + LoopVectorBody.push_back(VecBody); + VecBodyPostDom = VecBody; + LoopScalarBody = OldBasicBlock; + + LoopVectorizeHints Hints(Lp, true, *ORE); + Hints.setAlreadyVectorized(); +} + +// Fix up external users of the induction variable. At this point, we are +// in LCSSA form, with all external PHIs that use the IV having one input value, +// coming from the remainder loop. We need those PHIs to also have a correct +// value for the IV when arriving directly from the middle block. +void InnerLoopVectorizer::fixupIVUsers(PHINode *OrigPhi, + const InductionDescriptor &II, + Value *CountRoundDown, Value *EndValue, + BasicBlock *MiddleBlock) { + // There are two kinds of external IV usages - those that use the value + // computed in the last iteration (the PHI) and those that use the penultimate + // value (the value that feeds into the phi from the loop latch). + // We allow both, but they, obviously, have different values. + + assert(OrigLoop->getExitBlock() && "Expected a single exit block"); + + DenseMap MissingVals; + + // An external user of the last iteration's value should see the value that + // the remainder loop uses to initialize its own IV. 
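+ // Illustrative sketch with assumed value names: for an exit-block phi
+ //   %iv.lcssa = phi i64 [ %iv.next, %for.body ]
+ // fixupIVUsers adds the edge from the middle block so the phi is also
+ // correct when the scalar remainder is skipped:
+ //   %iv.lcssa = phi i64 [ %iv.next, %for.body ], [ %ind.end, %middle.block ]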
+ Value *PostInc = OrigPhi->getIncomingValueForBlock(OrigLoop->getLoopLatch()); + for (User *U : PostInc->users()) { + Instruction *UI = cast(U); + if (!OrigLoop->contains(UI)) { + assert(isa(UI) && "Expected LCSSA form"); + MissingVals[UI] = EndValue; + } + } + + // An external user of the penultimate value need to see EndValue - Step. + // The simplest way to get this is to recompute it from the constituent SCEVs, + // that is Start + (Step * (CRD - 1)). + for (User *U : OrigPhi->users()) { + auto *UI = cast(U); + if (!OrigLoop->contains(UI)) { + const DataLayout &DL = + OrigLoop->getHeader()->getModule()->getDataLayout(); + assert(isa(UI) && "Expected LCSSA form"); + + IRBuilder<> B(MiddleBlock->getTerminator()); + Value *CountMinusOne = B.CreateSub( + CountRoundDown, ConstantInt::get(CountRoundDown->getType(), 1)); + Value *CMO = + !II.getStep()->getType()->isIntegerTy() + ? B.CreateCast(Instruction::SIToFP, CountMinusOne, + II.getStep()->getType()) + : B.CreateSExtOrTrunc(CountMinusOne, II.getStep()->getType()); + CMO->setName("cast.cmo"); + Value *Escape = II.transform(B, CMO, PSE.getSE(), DL); + Escape->setName("ind.escape"); + MissingVals[UI] = Escape; + } + } + + for (auto &I : MissingVals) { + PHINode *PHI = cast(I.first); + // One corner case we have to handle is two IVs "chasing" each-other, + // that is %IV2 = phi [...], [ %IV1, %latch ] + // In this case, if IV1 has an external use, we need to avoid adding both + // "last value of IV1" and "penultimate value of IV2". So, verify that we + // don't already have an incoming value for the middle block. + if (PHI->getBasicBlockIndex(MiddleBlock) == -1) + PHI->addIncoming(I.second, MiddleBlock); + } +} + +namespace { +struct CSEDenseMapInfo { + static bool canHandle(const Instruction *I) { + return isa(I) || isa(I) || + isa(I) || isa(I); + } + static inline Instruction *getEmptyKey() { + return DenseMapInfo::getEmptyKey(); + } + static inline Instruction *getTombstoneKey() { + return DenseMapInfo::getTombstoneKey(); + } + static unsigned getHashValue(const Instruction *I) { + assert(canHandle(I) && "Unknown instruction!"); + return hash_combine(I->getOpcode(), hash_combine_range(I->value_op_begin(), + I->value_op_end())); + } + static bool isEqual(const Instruction *LHS, const Instruction *RHS) { + if (LHS == getEmptyKey() || RHS == getEmptyKey() || + LHS == getTombstoneKey() || RHS == getTombstoneKey()) + return LHS == RHS; + return LHS->isIdenticalTo(RHS); + } +}; +} + +///\brief Perform cse of induction variable instructions. +void InnerLoopVectorizer::CSE(SmallVector &BBs, + SmallSet &PredBlocks) { + // Perform simple cse. + SmallDenseMap CSEMap; + for (unsigned i = 0, e = BBs.size(); i != e; ++i) { + BasicBlock *BB = BBs[i]; + for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E;) { + Instruction *In = &*I++; + + if (!CSEDenseMapInfo::canHandle(In)) + continue; + + // Check if we can replace this instruction with any of the + // visited instructions. + if (Instruction *V = CSEMap.lookup(In)) { + In->replaceAllUsesWith(V); + In->eraseFromParent(); + continue; + } + + // Ignore instructions in conditional blocks. We create "if (pred) a[i] = + // ...;" blocks for predicated stores. Every second block is a predicated + // block. + if (PredBlocks.count(BBs[i])) + continue; + + // Check if we can replace this instruction with any of the + // visited instructions. 
+ if (Instruction *V = CSEMap.lookup(In)) {
+ In->replaceAllUsesWith(V);
+ In->eraseFromParent();
+ continue;
+ }
+
+ CSEMap[In] = In;
+ }
+ }
+}
+
+/// Estimate the overhead of scalarizing a value. Insert and Extract are set if
+/// the result needs to be inserted and/or extracted from vectors.
+static unsigned getScalarizationOverhead(Instruction *I, VectorizationFactor VF,
+ const TargetTransformInfo &TTI) {
+ if (VF.Width == 1)
+ return 0;
+
+ unsigned Cost = 0;
+ Type *RetTy = ToVectorTy(I->getType(), VF);
+ if (!RetTy->isVoidTy() &&
+ (!isa<LoadInst>(I) || !TTI.supportsEfficientVectorElementLoadStore()))
+ Cost += TTI.getScalarizationOverhead(RetTy, true, false);
+
+ VectorType::ElementCount EC(VF.Width, !VF.isFixed);
+ if (CallInst *CI = dyn_cast<CallInst>(I)) {
+ SmallVector<const Value *, 4> Operands(CI->arg_operands());
+ Cost += TTI.getOperandsScalarizationOverhead(Operands, EC);
+ } else if (!isa<StoreInst>(I) ||
+ !TTI.supportsEfficientVectorElementLoadStore()) {
+ SmallVector<const Value *, 4> Operands(I->operand_values());
+ Cost += TTI.getOperandsScalarizationOverhead(Operands, EC);
+ }
+
+ return Cost;
+}
+
+// Estimate cost of a call instruction CI if it were vectorized with factor VF.
+// Return the cost of the instruction, including scalarization overhead if it's
+// needed. The flag NeedToScalarize shows if the call needs to be scalarized -
+// i.e. either vector version isn't available, or is too expensive.
+static unsigned getVectorCallCost(CallInst *CI, VectorizationFactor VF,
+ const TargetTransformInfo &TTI,
+ const TargetLibraryInfo *TLI,
+ LoopVectorizationLegality &Legal,
+ bool &NeedToScalarize) {
+ if (VectorizeMemset && isa<MemSetInst>(CI)) {
+ auto MSI = cast<MemSetInst>(CI);
+ const auto Length = MSI->getLength();
+ const auto IsVolatile = MSI->isVolatile();
+ const auto Alignment = MSI->getDestAlignment();
+ auto CL = dyn_cast<ConstantInt>(Length);
+ assert(CL && "Memset length must be a constant.");
+ auto CLength = CL->getZExtValue();
+ assert((CLength % Alignment == 0) &&
+ ((CLength / Alignment) <= VectorizerMemSetThreshold) &&
+ !IsVolatile && "Invalid memset call.");
+ return CLength / Alignment;
+ }
+ Function *F = CI->getCalledFunction();
+ StringRef FnName = CI->getCalledFunction()->getName();
+ Type *ScalarRetTy = CI->getType();
+ SmallVector<Type *, 4> Tys, ScalarTys;
+ for (auto &ArgOp : CI->arg_operands())
+ ScalarTys.push_back(ArgOp->getType());
+
+ // Estimate cost of scalarized vector call. The source operands are assumed
+ // to be vectors, so we need to extract individual elements from there,
+ // execute VF.Width scalar calls, and then gather the result into the vector
+ // return value.
+ const unsigned ScalarCallCost = TTI.getCallInstrCost(F, ScalarRetTy, ScalarTys);
+ if (VF.Width == 1)
+ return ScalarCallCost;
+
+ // Compute corresponding vector type for return value and arguments.
+ Type *RetTy = ToVectorTy(ScalarRetTy, VF);
+ for (auto &Op : CI->arg_operands()) {
+ Type *ScalarTy = Op->getType();
+ if (ScalarTy->isPointerTy() && Legal.isConsecutivePtr(Op))
+ Tys.push_back(ScalarTy);
+ else
+ Tys.push_back(ToVectorTy(ScalarTy, VF));
+ }
+
+ if (!VF.isFixed) {
+ IRBuilder<> Builder(CI);
+ Type *PredTy = Builder.getInt1Ty();
+ Tys.push_back(ToVectorTy(PredTy, VF));
+ }
+
+ // Compute costs of unpacking argument values for the scalar calls and
+ // packing the return values to a vector.
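+ // Worked example with made-up numbers: for VF.Width = 4, a scalar call cost
+ // of 10 and a total scalarization overhead of 6, the estimate computed below
+ // is 4 * 10 + 6 = 46; this value is only kept if no suitable vectorized
+ // library function exists or the vector call is not cheaper.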
+ unsigned ScalarizationCost = getScalarizationOverhead(CI, VF, TTI); + for (unsigned i = 0, ie = Tys.size(); i != ie; ++i) + ScalarizationCost += getScalarizationOverhead(CI, VF, TTI); + + unsigned Cost = ScalarCallCost * VF.Width + ScalarizationCost; + + // If we can't emit a vector call for this function, then the currently found + // cost is the cost we need to return. + NeedToScalarize = true; + FunctionType *FTy = FunctionType::get(RetTy, Tys, false); + + + // NOTE: Linking VF.isFixed to Masked is bogus, but at this point we + // don't have anymore information. + bool Masked = EnableVectorPredication && !VF.isFixed; + VectorType::ElementCount EC(VF.Width, !VF.isFixed); + if (TLI && (TLI->getVectorizedFunction(FnName, EC, Masked, FTy) != "")) + return ScalarCallCost; + + if (!TLI || !TLI->isFunctionVectorizable(FnName, EC, Masked, FTy) || + CI->isNoBuiltin()) + return Cost; + + // If the corresponding vector cost is cheaper, return its cost. + unsigned VectorCallCost = TTI.getCallInstrCost(nullptr, RetTy, Tys); + if (VectorCallCost < Cost) { + NeedToScalarize = false; + return VectorCallCost; + } + return Cost; +} + +// Estimate cost of an intrinsic call instruction CI if it were vectorized with +// factor VF. Return the cost of the instruction, including scalarization +// overhead if it's needed. +static unsigned getVectorIntrinsicCost(CallInst *CI, VectorizationFactor VF, + const TargetTransformInfo &TTI, + const TargetLibraryInfo *TLI) { + Intrinsic::ID ID = getVectorIntrinsicIDForCall(CI, TLI); + assert(ID && "Expected intrinsic call!"); + + Type *RetTy = ToVectorTy(CI->getType(), VF); + SmallVector Tys; + for (unsigned i = 0, ie = CI->getNumArgOperands(); i != ie; ++i) + Tys.push_back(ToVectorTy(CI->getArgOperand(i)->getType(), VF)); + + FastMathFlags FMF; + if (auto *FPMO = dyn_cast(CI)) + FMF = FPMO->getFastMathFlags(); + + return TTI.getIntrinsicInstrCost(ID, RetTy, Tys, FMF); +} + +static Type *smallestIntegerVectorType(Type *T1, Type *T2) { + IntegerType *I1 = cast(T1->getVectorElementType()); + IntegerType *I2 = cast(T2->getVectorElementType()); + return I1->getBitWidth() < I2->getBitWidth() ? T1 : T2; +} +static Type *largestIntegerVectorType(Type *T1, Type *T2) { + IntegerType *I1 = cast(T1->getVectorElementType()); + IntegerType *I2 = cast(T2->getVectorElementType()); + return I1->getBitWidth() > I2->getBitWidth() ? T1 : T2; +} + +void InnerLoopVectorizer::truncateToMinimalBitwidths() { + // For every instruction `I` in MinBWs, truncate the operands, create a + // truncated version of `I` and reextend its result. InstCombine runs + // later and will remove any ext/trunc pairs. + // + SmallPtrSet Erased; + for (auto &KV : MinBWs) { + VectorParts &Parts = WidenMap.get(KV.first); + for (Value *&I : Parts) { + if (Erased.count(I) || I->use_empty()) + continue; + auto *OriginalTy = cast(I->getType()); + Type *ScalarTruncatedTy = + IntegerType::get(OriginalTy->getContext(), KV.second); + Type *TruncatedTy = VectorType::get(ScalarTruncatedTy, + OriginalTy->getElementCount()); + if (TruncatedTy == OriginalTy) + continue; + + if (!isa(I)) + continue; + + IRBuilder<> B(cast(I)); + auto ShrinkOperand = [&](Value *V) -> Value * { + if (auto *ZI = dyn_cast(V)) + if (ZI->getSrcTy() == TruncatedTy) + return ZI->getOperand(0); + return B.CreateZExtOrTrunc(V, TruncatedTy); + }; + + // The actual instruction modification depends on the instruction type, + // unfortunately. 
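+ // Illustrative sketch, assuming a fixed VF of 4 and a MinBWs entry saying an
+ // i32 add only needs 8 bits:
+ //   %t0  = trunc <4 x i32> %a to <4 x i8>
+ //   %t1  = trunc <4 x i32> %b to <4 x i8>
+ //   %add = add <4 x i8> %t0, %t1
+ //   %res = zext <4 x i8> %add to <4 x i32>
+ // InstCombine later folds away any redundant trunc/ext pairs.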
+ Value *NewI = nullptr; + if (BinaryOperator *BO = dyn_cast(I)) { + NewI = B.CreateBinOp(BO->getOpcode(), ShrinkOperand(BO->getOperand(0)), + ShrinkOperand(BO->getOperand(1))); + cast(NewI)->copyIRFlags(I); + } else if (ICmpInst *CI = dyn_cast(I)) { + NewI = + B.CreateICmp(CI->getPredicate(), ShrinkOperand(CI->getOperand(0)), + ShrinkOperand(CI->getOperand(1))); + } else if (SelectInst *SI = dyn_cast(I)) { + NewI = B.CreateSelect(SI->getCondition(), + ShrinkOperand(SI->getTrueValue()), + ShrinkOperand(SI->getFalseValue())); + } else if (CastInst *CI = dyn_cast(I)) { + switch (CI->getOpcode()) { + default: + llvm_unreachable("Unhandled cast!"); + case Instruction::Trunc: + NewI = ShrinkOperand(CI->getOperand(0)); + break; + case Instruction::SExt: + NewI = B.CreateSExtOrTrunc( + CI->getOperand(0), + smallestIntegerVectorType(OriginalTy, TruncatedTy)); + break; + case Instruction::ZExt: + NewI = B.CreateZExtOrTrunc( + CI->getOperand(0), + smallestIntegerVectorType(OriginalTy, TruncatedTy)); + break; + } + } else if (ShuffleVectorInst *SI = dyn_cast(I)) { + auto VTy0 = cast(SI->getOperand(0)->getType()); + auto Elements0 = VTy0->getElementCount(); + auto *O0 = B.CreateZExtOrTrunc( + SI->getOperand(0), VectorType::get(ScalarTruncatedTy, Elements0)); + auto VTy1 = cast(SI->getOperand(1)->getType()); + auto Elements1 = VTy1->getElementCount(); + auto *O1 = B.CreateZExtOrTrunc( + SI->getOperand(1), VectorType::get(ScalarTruncatedTy, Elements1)); + + NewI = B.CreateShuffleVector(O0, O1, SI->getMask()); + } else if (isa(I) || isa(I)) { + // Don't do anything with the operands, just extend the result. + continue; + } else if (auto *IE = dyn_cast(I)) { + auto Elements = IE->getOperand(0)->getType()->getVectorNumElements(); + auto *O0 = B.CreateZExtOrTrunc( + IE->getOperand(0), VectorType::get(ScalarTruncatedTy, Elements)); + auto *O1 = B.CreateZExtOrTrunc(IE->getOperand(1), ScalarTruncatedTy); + NewI = B.CreateInsertElement(O0, O1, IE->getOperand(2)); + } else if (auto *EE = dyn_cast(I)) { + auto Elements = EE->getOperand(0)->getType()->getVectorNumElements(); + auto *O0 = B.CreateZExtOrTrunc( + EE->getOperand(0), VectorType::get(ScalarTruncatedTy, Elements)); + NewI = B.CreateExtractElement(O0, EE->getOperand(2)); + } else { + llvm_unreachable("Unhandled instruction type!"); + } + + // Lastly, extend the result. + NewI->takeName(cast(I)); + Value *Res = B.CreateZExtOrTrunc(NewI, OriginalTy); + I->replaceAllUsesWith(Res); + cast(I)->eraseFromParent(); + Erased.insert(I); + I = Res; + } + } + + // We'll have created a bunch of ZExts that are now parentless. Clean up. + for (auto &KV : MinBWs) { + VectorParts &Parts = WidenMap.get(KV.first); + for (Value *&I : Parts) { + ZExtInst *Inst = dyn_cast(I); + if (Inst && Inst->use_empty()) { + Value *NewI = Inst->getOperand(0); + Inst->eraseFromParent(); + I = NewI; + } + } + } +} + +void InnerLoopVectorizer::vectorizeLoop() { + //===------------------------------------------------===// + // + // Notice: any optimization or new instruction that go + // into the code below should be also be implemented in + // the cost-model. + // + //===------------------------------------------------===// + Constant *Zero = Builder.getInt32(0); + + // In order to support recurrences we need to be able to vectorize Phi nodes. + // Phi nodes have cycles, so we need to vectorize them in two stages. First, + // we create a new vector PHI node with no incoming edges. We use this value + // when we vectorize all of the instructions that use the PHI. 
Next, after + // all of the instructions in the block are complete we add the new incoming + // edges to the PHI. At this point all of the instructions in the basic block + // are vectorized, so we can use them to construct the PHI. + PhiVector PHIsToFix; + + // Move instructions to handle first-order recurrences. + DenseMap SinkAfter = Legal->getSinkAfter(); + for (auto &Entry : SinkAfter) { + Entry.first->removeFromParent(); + Entry.first->insertAfter(Entry.second); + LLVM_DEBUG(dbgs() << "Sinking" << *Entry.first << " after" << *Entry.second + << " to vectorize a 1st order recurrence.\n"); + } + + // Scan the loop in a topological order to ensure that defs are vectorized + // before users. + LoopBlocksDFS DFS(OrigLoop); + DFS.perform(LI); + + // Vectorize all of the blocks in the original loop. + for (LoopBlocksDFS::RPOIterator bb = DFS.beginRPO(), be = DFS.endRPO(); + bb != be; ++bb) + vectorizeBlockInLoop(*bb, &PHIsToFix); + + // When using predication not all elements will be modified during the current + // iteration and so we must iterate through the reduction variables selecting + // between the original and new values for each element. + if (UsePredication) { + for (auto *RdxPhi : PHIsToFix) { + assert(RdxPhi && "Unable to recover vectorized PHI"); + + if (Legal->isFirstOrderRecurrence(RdxPhi)) + continue; + + // Find the reduction variable descriptor. + assert(Legal->getReductionVars()->count(RdxPhi) && + "Unable to find the reduction variable"); + RecurrenceDescriptor RdxDesc = (*Legal->getReductionVars())[RdxPhi]; + + VectorParts &VecRdxPhi = WidenMap.get(RdxPhi); + Value * LoopExitInstr = RdxDesc.getLoopExitInstr(); + VectorParts &VectorExit = getVectorValue(LoopExitInstr); + + for (unsigned Part = 0; Part < UF; ++Part) { + if (!RdxDesc.isOrdered()) { + Instruction *Merge = SelectInst::Create(Predicate[Part], + VectorExit[Part], + VecRdxPhi[Part]); + Merge->insertAfter(cast(VectorExit[Part])); + VectorExit[Part] = Merge; + } + } + } + } + + // Insert truncates and extends for any truncated instructions as hints to + // InstCombine. + if (VF > 1) + truncateToMinimalBitwidths(); + + // At this point every instruction in the original loop is widened to a + // vector form. Now we need to fix the recurrences in PHIsToFix. These PHI + // nodes are currently empty because we did not want to introduce cycles. + // This is the second stage of vectorizing recurrences. + for (PHINode *Phi : PHIsToFix) { + assert(Phi && "Unable to recover vectorized PHI"); + + // Handle first-order recurrences that need to be fixed. + if (Legal->isFirstOrderRecurrence(Phi)) { + fixFirstOrderRecurrence(Phi); + continue; + } + + // If the phi node is not a first-order recurrence, it must be a reduction. + // Get it's reduction variable descriptor. + assert(Legal->isReductionVariable(Phi) && + "Unable to find the reduction variable"); + RecurrenceDescriptor RdxDesc = (*Legal->getReductionVars())[Phi]; + + RecurrenceDescriptor::RecurrenceKind RK = RdxDesc.getRecurrenceKind(); + TrackingVH ReductionStartValue = RdxDesc.getRecurrenceStartValue(); + Instruction *LoopExitInst = RdxDesc.getLoopExitInstr(); + RecurrenceDescriptor::MinMaxRecurrenceKind MinMaxKind = + RdxDesc.getMinMaxRecurrenceKind(); + setDebugLocFromInst(Builder, ReductionStartValue); + + // We need to generate a reduction vector from the incoming scalar. + // To do so, we need to generate the 'identity' vector and override + // one of the elements with the incoming scalar reduction. We need + // to do it in the vector-loop preheader. 
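+ // Illustrative sketch for an integer add reduction with fixed VF = 4 and
+ // start value %s (names invented):
+ //   %vector.start = insertelement <4 x i32> zeroinitializer, i32 %s, i32 0
+ // where the remaining lanes keep the identity (0 for add/or/xor, 1 for mul,
+ // -1 for and); min/max and integer-select reductions splat the start value
+ // instead, since they have no neutral identity element.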
+ if (UsePredication) + Builder.SetInsertPoint(LoopBypassBlocks[0]->getTerminator()); + else + Builder.SetInsertPoint(LoopBypassBlocks[1]->getTerminator()); + + // This is the vector-clone of the value that leaves the loop. + VectorParts &VectorExit = getVectorValue(LoopExitInst); + Type *VecTy = VectorExit[0]->getType(); + + // Find the reduction identity variable. Zero for addition, or, xor, + // one for multiplication, -1 for And. + Value *Identity; + Value *VectorStart; + if (RK == RecurrenceDescriptor::RK_IntegerMinMax || + RK == RecurrenceDescriptor::RK_FloatMinMax || + RK == RecurrenceDescriptor::RK_ConstSelectICmp || + RK == RecurrenceDescriptor::RK_ConstSelectFCmp) { + // MinMax reduction have the start value as their identify. + if (VF == 1) { + VectorStart = Identity = ReductionStartValue; + } else { + const char *Ident = (RK == RecurrenceDescriptor::RK_ConstSelectICmp || + RK == RecurrenceDescriptor::RK_ConstSelectFCmp) ? + "intcond.ident" : "minmax.ident"; + VectorStart = Identity = + Builder.CreateVectorSplat({VF, Scalable}, + RdxDesc.getRecurrenceStartValue(), + Ident); + } + } else { + // Handle other reduction kinds: + Constant *Iden = RecurrenceDescriptor::getRecurrenceIdentity( + RK, VecTy->getScalarType()); + if (VF == 1) { + Identity = Iden; + // This vector is the Identity vector where the first element is the + // incoming scalar reduction. + VectorStart = ReductionStartValue; + } else { + Identity = ConstantVector::getSplat({VF, Scalable}, Iden); + + // This vector is the Identity vector where the first element is the + // incoming scalar reduction. + VectorStart = + Builder.CreateInsertElement(Identity, ReductionStartValue, Zero); + } + } + + // Fix the vector-loop phi. + + // Reductions do not have to start at zero. They can start with + // any loop invariant values. + VectorParts &VecRdxPhi = WidenMap.get(Phi); + BasicBlock *Latch = OrigLoop->getLoopLatch(); + Value *LoopVal = Phi->getIncomingValueForBlock(Latch); + VectorParts &Val = getVectorValue(LoopVal); + + for (unsigned part = 0; part < UF; ++part) { + // Only add the reduction start value to the first unroll part. + Value *StartVal = (part == 0) ? VectorStart : Identity; + + Value *NewVal = Val[part]; + if (RdxDesc.isOrdered() && VF > 1) { + StartVal = Builder.CreateExtractElement(StartVal, + Builder.getInt32(0)); + NewVal = Val[UF-1]; + } + + cast(VecRdxPhi[part]) + ->addIncoming(StartVal, LoopVectorPreHeader); + cast(VecRdxPhi[part]) + ->addIncoming(NewVal, LoopVectorBody.back()); + } + + // Before each round, move the insertion point right between + // the PHIs and the values we are going to write. + // This allows us to write both PHINodes and the extractelement + // instructions. + Builder.SetInsertPoint(&*LoopMiddleBlock->getFirstInsertionPt()); + + VectorParts RdxParts = getVectorValue(LoopExitInst); + setDebugLocFromInst(Builder, LoopExitInst); + + // If the vector reduction can be performed in a smaller type, we truncate + // then extend the loop exit value to enable InstCombine to evaluate the + // entire expression in the smaller type. + if (VF > 1 && Phi->getType() != RdxDesc.getRecurrenceType()) { + Type *RdxVecTy = + VectorType::get(RdxDesc.getRecurrenceType(), VF, Scalable); + Builder.SetInsertPoint(LoopVectorBody.back()->getTerminator()); + for (unsigned part = 0; part < UF; ++part) { + Value *Trunc = Builder.CreateTrunc(RdxParts[part], RdxVecTy); + Value *Extnd = RdxDesc.isSigned() ? 
Builder.CreateSExt(Trunc, VecTy) + : Builder.CreateZExt(Trunc, VecTy); + for (Value::user_iterator UI = RdxParts[part]->user_begin(); + UI != RdxParts[part]->user_end();) + if (*UI != Trunc) { + (*UI++)->replaceUsesOfWith(RdxParts[part], Extnd); + RdxParts[part] = Extnd; + } else { + ++UI; + } + } + Builder.SetInsertPoint(&*LoopMiddleBlock->getFirstInsertionPt()); + for (unsigned part = 0; part < UF; ++part) + RdxParts[part] = Builder.CreateTrunc(RdxParts[part], RdxVecTy); + } + + // Reduce all of the unrolled parts into a single vector. + Value *ReducedPartRdx = RdxParts[0]; + unsigned Op = RecurrenceDescriptor::getRecurrenceBinOp(RK); + setDebugLocFromInst(Builder, ReducedPartRdx); + if (!RdxDesc.isOrdered()) { + for (unsigned part = 1; part < UF; ++part) { + if (Op != Instruction::ICmp && Op != Instruction::FCmp) + // Floating point operations had to be 'fast' to enable the reduction. + ReducedPartRdx = addFastMathFlag( + Builder.CreateBinOp((Instruction::BinaryOps)Op, RdxParts[part], + ReducedPartRdx, "bin.rdx")); + else + ReducedPartRdx = RecurrenceDescriptor::createMinMaxOp( + Builder, MinMaxKind, ReducedPartRdx, RdxParts[part]); + } + } else { + // for ordered reduction get the result of the last unrolled + // instruction + ReducedPartRdx=RdxParts[UF-1]; + } + + if ((VF > 1) && !isScalable() && !RdxDesc.isOrdered()) { + // VF is a power of 2 so we can emit the reduction using log2(VF) shuffles + // and vector ops, reducing the set of values being computed by half each + // round. + assert(isPowerOf2_32(VF) && + "Reduction emission only supported for pow2 vectors!"); + Value *TmpVec = ReducedPartRdx; + SmallVector ShuffleMask(VF, nullptr); + for (unsigned i = VF; i != 1; i >>= 1) { + // Move the upper half of the vector to the lower half. + for (unsigned j = 0; j != i / 2; ++j) + ShuffleMask[j] = Builder.getInt32(i / 2 + j); + + // Fill the rest of the mask with undef. + std::fill(&ShuffleMask[i / 2], ShuffleMask.end(), + UndefValue::get(Builder.getInt32Ty())); + + Value *Shuf = Builder.CreateShuffleVector( + TmpVec, UndefValue::get(TmpVec->getType()), + ConstantVector::get(ShuffleMask), "rdx.shuf"); + + if (Op != Instruction::ICmp && Op != Instruction::FCmp) { + // Floating point operations had to be 'fast' to enable the reduction. + TmpVec = addFastMathFlag(Builder.CreateBinOp( + (Instruction::BinaryOps)Op, TmpVec, Shuf, "bin.rdx")); + } + else + TmpVec = RecurrenceDescriptor::createMinMaxOp(Builder, MinMaxKind, + TmpVec, Shuf); + } + + // The result is in the first element of the vector. + ReducedPartRdx = + Builder.CreateExtractElement(TmpVec, Builder.getInt32(0)); + } + + // Compute vector reduction for scalable vectors + if ((VF > 1) && isScalable() && !RdxDesc.isOrdered()) { + bool NoNaN = Legal->hasNoNaNAttr(); + ReducedPartRdx = + createTargetReduction(Builder, TTI, RdxDesc, ReducedPartRdx, NoNaN); + } + + if (VF > 1) { + // If the reduction can be performed in a smaller type, we need to extend + // the reduction to the wider type before we branch to the original loop. + if (Phi->getType() != RdxDesc.getRecurrenceType()) + ReducedPartRdx = + RdxDesc.isSigned() + ? Builder.CreateSExt(ReducedPartRdx, Phi->getType()) + : Builder.CreateZExt(ReducedPartRdx, Phi->getType()); + } + + // Create a phi node that merges control-flow from the backedge-taken check + // block and the middle block. 
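+ // Illustrative sketch with assumed block names: the merge phi created below
+ // looks roughly like
+ //   %bc.merge.rdx = phi i32 [ %rdx.start, %vector.scevcheck ],
+ //                           [ %rdx.start, %vector.memcheck ],
+ //                           [ %rdx.final, %middle.block ]
+ // with the middle-block edge omitted under predication, where the vector
+ // loop already performs every iteration.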
+ PHINode *BCBlockPhi = PHINode::Create(Phi->getType(), 2, "bc.merge.rdx", + LoopScalarPreHeader->getTerminator()); + for (unsigned I = 0, E = LoopBypassBlocks.size(); I != E; ++I) + BCBlockPhi->addIncoming(ReductionStartValue, LoopBypassBlocks[I]); + + // When using predication the vector loop performs all iterations. + if (!UsePredication) + BCBlockPhi->addIncoming(ReducedPartRdx, LoopMiddleBlock); + + // If there were stores of the reduction value to a uniform memory address + // inside the loop, create the final store here. + if (StoreInst *SI = RdxDesc.IntermediateStore) { + Builder.SetInsertPoint(LoopMiddleBlock->getTerminator()); + StoreInst *NewSI = Builder.CreateStore(ReducedPartRdx, + SI->getPointerOperand()); + propagateMetadata(NewSI, SI); + + // If the reduction value is used in other places, + // then let the code below create PHI's for that. + } + + // Now, we need to fix the users of the reduction variable + // inside and outside of the scalar remainder loop. + // We know that the loop is in LCSSA form. We need to update the + // PHI nodes in the exit blocks. + for (BasicBlock::iterator LEI = LoopExitBlock->begin(), + LEE = LoopExitBlock->end(); + LEI != LEE; ++LEI) { + PHINode *LCSSAPhi = dyn_cast(LEI); + if (!LCSSAPhi) + break; + + // All PHINodes need to have a single entry edge, or two if + // we already fixed them. + assert(LCSSAPhi->getNumIncomingValues() < 3 && "Invalid LCSSA PHI"); + + // We found our reduction value exit-PHI. Update it with the + // incoming bypass edge. + if (LCSSAPhi->getIncomingValue(0) == LoopExitInst) { + // Add an edge coming from the bypass. + LCSSAPhi->addIncoming(ReducedPartRdx, LoopMiddleBlock); + } + } // end of the LCSSA phi scan. + + // Fix the scalar loop reduction variable with the incoming reduction sum + // from the vector body and from the backedge value. + int IncomingEdgeBlockIdx = + Phi->getBasicBlockIndex(OrigLoop->getLoopLatch()); + assert(IncomingEdgeBlockIdx >= 0 && "Invalid block index"); + // Pick the other block. + int SelfEdgeBlockIdx = (IncomingEdgeBlockIdx ? 0 : 1); + Phi->setIncomingValue(SelfEdgeBlockIdx, BCBlockPhi); + Phi->setIncomingValue(IncomingEdgeBlockIdx, LoopExitInst); + } // end of for each Phi in PHIsToFix. + + // Make sure DomTree is updated. + updateAnalysis(); + + // Fix-up external users of the induction variables. + for (auto &Entry : *Legal->getInductionVars()) + fixupIVUsers(Entry.first, Entry.second, + getOrCreateVectorTripCount(LI->getLoopFor(LoopVectorBody.front())), + IVEndValues[Entry.first], LoopMiddleBlock); + + fixLCSSAPHIs(); + + // Predicate any stores. + for (auto KV : PredicatedStores) { + BasicBlock::iterator I(KV.first); + auto *BB = SplitBlock(I->getParent(), &*std::next(I), DT, LI); + auto *T = SplitBlockAndInsertIfThen(KV.second, &*I, /*Unreachable=*/false, + /*BranchWeights=*/nullptr, DT, LI); + I->moveBefore(T); + I->getParent()->setName("pred.store.if"); + BB->setName("pred.store.continue"); + } + LLVM_DEBUG(assert(DT->verify(DominatorTree::VerificationLevel::Fast))); + // Remove redundant induction instructions. + CSE(LoopVectorBody, PredicatedBlocks); +} + +void InnerLoopVectorizer::fixFirstOrderRecurrence(PHINode *Phi) { + + // This is the second phase of vectorizing first-order recurrences. An + // overview of the transformation is described below. Suppose we have the + // following loop. + // + // for (int i = 0; i < n; ++i) + // b[i] = a[i] - a[i - 1]; + // + // There is a first-order recurrence on "a". 
For this loop, the shorthand + // scalar IR looks like: + // + // scalar.ph: + // s_init = a[-1] + // br scalar.body + // + // scalar.body: + // i = phi [0, scalar.ph], [i+1, scalar.body] + // s1 = phi [s_init, scalar.ph], [s2, scalar.body] + // s2 = a[i] + // b[i] = s2 - s1 + // br cond, scalar.body, ... + // + // In this example, s1 is a recurrence because it's value depends on the + // previous iteration. In the first phase of vectorization, we created a + // temporary value for s1. We now complete the vectorization and produce the + // shorthand vector IR shown below (for VF = 4, UF = 1). + // + // vector.ph: + // v_init = vector(..., ..., ..., a[-1]) + // br vector.body + // + // vector.body + // i = phi [0, vector.ph], [i+4, vector.body] + // v1 = phi [v_init, vector.ph], [v2, vector.body] + // v2 = a[i, i+1, i+2, i+3]; + // v3 = vector(v1(3), v2(0, 1, 2)) + // b[i, i+1, i+2, i+3] = v2 - v3 + // br cond, vector.body, middle.block + // + // middle.block: + // x = v2(3) + // br scalar.ph + // + // scalar.ph: + // s_init = phi [x, middle.block], [a[-1], otherwise] + // br scalar.body + // + // After execution completes the vector loop, we extract the next value of + // the recurrence (x) to use as the initial value in the scalar loop. + + // Get the original loop preheader and single loop latch. + auto *Preheader = OrigLoop->getLoopPreheader(); + auto *Latch = OrigLoop->getLoopLatch(); + + // Get the initial and previous values of the scalar recurrence. + auto *ScalarInit = Phi->getIncomingValueForBlock(Preheader); + auto *Previous = Phi->getIncomingValueForBlock(Latch); + + auto *IdxTy = Builder.getInt32Ty(); + auto *RuntimeVF = getRuntimeVF(IdxTy); + auto *One = ConstantInt::get(IdxTy, 1); + auto *LastIdx = Builder.CreateBinOp(Instruction::Sub, RuntimeVF, One); + + // Create a vector from the initial value. + auto *VectorInit = ScalarInit; + if (VF > 1) { + Builder.SetInsertPoint(LoopVectorPreHeader->getTerminator()); + VectorInit = Builder.CreateInsertElement( + UndefValue::get(VectorType::get(VectorInit->getType(), VF, Scalable)), + VectorInit, LastIdx, "vector.recur.init"); + } + + // We constructed a temporary phi node in the first phase of vectorization. + // This phi node will eventually be deleted. + auto &PhiParts = getVectorValue(Phi); + Builder.SetInsertPoint(cast(PhiParts[0])); + + // Create a phi node for the new recurrence. The current value will either be + // the initial value inserted into a vector or loop-varying vector value. + auto *VecPhi = Builder.CreatePHI(VectorInit->getType(), 2, "vector.recur"); + VecPhi->addIncoming(VectorInit, LoopVectorPreHeader); + + // Get the vectorized previous value. We ensured the previous values was an + // instruction when detecting the recurrence. + auto &PreviousParts = getVectorValue(Previous); + + // Set the insertion point to be after this instruction. We ensured the + // previous value dominated all uses of the phi when detecting the + // recurrence. + Builder.SetInsertPoint( + &*++BasicBlock::iterator(cast(PreviousParts[UF - 1]))); + + // We will construct a vector for the recurrence by combining the values for + // the current and previous iterations. This is the required shuffle mask. + auto *ShuffleMask = Builder.CreateSeriesVector({VF, Scalable}, + LastIdx, One); + + // The vector from which to take the initial value for the current iteration + // (actual or unrolled). Initially, this is the vector phi node. + Value *Incoming = VecPhi; + + // Shuffle the current and previous vector and update the vector parts. 
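+ // Illustrative sketch, fixed VF = 4 for readability: the shorthand
+ // v3 = vector(v1(3), v2(0, 1, 2)) from the comment above is emitted as
+ //   %v3 = shufflevector <4 x i32> %v1, <4 x i32> %v2,
+ //                       <4 x i32> <i32 3, i32 4, i32 5, i32 6>
+ // For scalable vectors the mask is the series vector built from LastIdx
+ // above rather than a constant.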
+ for (unsigned Part = 0; Part < UF; ++Part) { + auto *Shuffle = + VF > 1 + ? Builder.CreateShuffleVector(Incoming, PreviousParts[Part], + ShuffleMask) + : Incoming; + PhiParts[Part]->replaceAllUsesWith(Shuffle); + cast(PhiParts[Part])->eraseFromParent(); + PhiParts[Part] = Shuffle; + Incoming = PreviousParts[Part]; + } + + // Fix the latch value of the new recurrence in the vector loop. + VecPhi->addIncoming(Incoming, + LI->getLoopFor(LoopVectorBody[0])->getLoopLatch()); + + // Extract the last vector element in the middle block. This will be the + // initial value for the recurrence when jumping to the scalar loop. + auto *Extract = Incoming; + if (VF > 1) { + Builder.SetInsertPoint(LoopMiddleBlock->getTerminator()); + Extract = Builder.CreateExtractElement(Extract, LastIdx, + "vector.recur.extract"); + } + + // Fix the initial value of the original recurrence in the scalar loop. + Builder.SetInsertPoint(&*LoopScalarPreHeader->begin()); + auto *Start = Builder.CreatePHI(Phi->getType(), 2, "scalar.recur.init"); + for (auto *BB : predecessors(LoopScalarPreHeader)) { + auto *Incoming = BB == LoopMiddleBlock ? Extract : ScalarInit; + Start->addIncoming(Incoming, BB); + } + + Phi->setIncomingValue(Phi->getBasicBlockIndex(LoopScalarPreHeader), Start); + Phi->setName("scalar.recur"); + + // Finally, fix users of the recurrence outside the loop. The users will need + // either the last value of the scalar recurrence or the last value of the + // vector recurrence we extracted in the middle block. Since the loop is in + // LCSSA form, we just need to find the phi node for the original scalar + // recurrence in the exit block, and then add an edge for the middle block. + for (auto &I : *LoopExitBlock) { + auto *LCSSAPhi = dyn_cast(&I); + if (!LCSSAPhi) + break; + if (LCSSAPhi->getIncomingValue(0) == Phi) { + LCSSAPhi->addIncoming(Extract, LoopMiddleBlock); + break; + } + } +} + +void InnerLoopVectorizer::fixLCSSAPHIs() { + for (BasicBlock::iterator LEI = LoopExitBlock->begin(), + LEE = LoopExitBlock->end(); + LEI != LEE; ++LEI) { + PHINode *LCSSAPhi = dyn_cast(LEI); + if (!LCSSAPhi) + break; + if (LCSSAPhi->getNumIncomingValues() == 1) + LCSSAPhi->addIncoming(UndefValue::get(LCSSAPhi->getType()), + LoopMiddleBlock); + } +} + +InnerLoopVectorizer::VectorParts +InnerLoopVectorizer::createEdgeMask(BasicBlock *Src, BasicBlock *Dst) { + assert(std::find(pred_begin(Dst), pred_end(Dst), Src) != pred_end(Dst) && + "Invalid edge"); + + // Look for cached value. + std::pair Edge(Src, Dst); + EdgeMaskCache::iterator ECEntryIt = MaskCache.find(Edge); + if (ECEntryIt != MaskCache.end()) + return ECEntryIt->second; + + VectorParts SrcMask = createBlockInMask(Src); + + // The terminator has to be a branch inst! + BranchInst *BI = dyn_cast(Src->getTerminator()); + assert(BI && "Unexpected terminator found"); + + if (BI->isConditional()) { + VectorParts EdgeMask = getVectorValue(BI->getCondition()); + + if (BI->getSuccessor(0) != Dst) + for (unsigned part = 0; part < UF; ++part) + EdgeMask[part] = Builder.CreateNot(EdgeMask[part]); + + for (unsigned part = 0; part < UF; ++part) + EdgeMask[part] = Builder.CreateAnd(EdgeMask[part], SrcMask[part]); + + MaskCache[Edge] = EdgeMask; + return EdgeMask; + } + + MaskCache[Edge] = SrcMask; + return SrcMask; +} + +InnerLoopVectorizer::VectorParts +InnerLoopVectorizer::createBlockInMask(BasicBlock *BB) { + assert(OrigLoop->contains(BB) && "Block is not a part of a loop"); + + // Loop incoming mask is all-one. 
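+ // Illustrative sketch with invented names: for a block with a single
+ // conditional predecessor the masks built by this function reduce to
+ //   %edge.mask  = and <4 x i1> %src.mask, %cond       ; or with (not %cond)
+ //   %block.mask = or <4 x i1> zeroinitializer, %edge.mask
+ // i.e. a block mask is the OR of the masks of all incoming edges.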
+ if (OrigLoop->getHeader() == BB) { + Value *C = ConstantInt::get(IntegerType::getInt1Ty(BB->getContext()), 1); + return getVectorValue(C); + } + + // This is the block mask. We OR all incoming edges, and with zero. + Value *Zero = ConstantInt::get(IntegerType::getInt1Ty(BB->getContext()), 0); + VectorParts BlockMask = getVectorValue(Zero); + + // For each pred: + for (pred_iterator it = pred_begin(BB), e = pred_end(BB); it != e; ++it) { + VectorParts EM = createEdgeMask(*it, BB); + for (unsigned part = 0; part < UF; ++part) + BlockMask[part] = Builder.CreateOr(BlockMask[part], EM[part]); + } + + return BlockMask; +} + +void InnerLoopVectorizer::widenPHIInstruction( + Instruction *PN, InnerLoopVectorizer::VectorParts &Entry, unsigned UF, + unsigned VF, PhiVector *PV) { + PHINode *P = cast(PN); + // Handle recurrences. + if (Legal->isReductionVariable(P) || Legal->isFirstOrderRecurrence(P)) { + for (unsigned part = 0; part < UF; ++part) { + // This is phase one of vectorizing PHIs. + RecurrenceDescriptor RdxDesc = (*Legal->getReductionVars())[P]; + + Type *VecTy = (VF == 1 || RdxDesc.isOrdered()) + ? PN->getType() + : VectorType::get(PN->getType(), VF, Scalable); + Entry[part] = PHINode::Create(VecTy, 2, "vec.phi", + &*LoopVectorBody[0]->getFirstInsertionPt()); + } + PV->push_back(P); + return; + } + + setDebugLocFromInst(Builder, P); + // Check for PHI nodes that are lowered to vector selects. + if (P->getParent() != OrigLoop->getHeader()) { + // We know that all PHIs in non-header blocks are converted into + // selects, so we don't have to worry about the insertion order and we + // can just use the builder. + // At this point we generate the predication tree. There may be + // duplications since this is a simple recursive scan, but future + // optimizations will clean it up. + + unsigned NumIncoming = P->getNumIncomingValues(); + + // If the value is an exit value of a strictly ordered reduction, + // skip this PHI node since the inputs to the reduction, as well as + // the reduction itself, will already have been predicated. + for (auto &Reduction : *Legal->getReductionVars()) { + RecurrenceDescriptor DS = Reduction.second; + auto LoopExitInstr = DS.getLoopExitInstr(); + if (LoopExitInstr == PN && DS.isOrdered()) { + Value *V = DS.getUnsafeAlgebraInst(); + for (unsigned part = 0; part < UF; ++part) + Entry[part] = getVectorValue(V)[part]; + return; + } + } + + // Generate a sequence of selects of the form: + // SELECT(Mask3, In3, + // SELECT(Mask2, In2, + // ( ...))) + for (unsigned In = 0; In < NumIncoming; In++) { + VectorParts Cond = + createEdgeMask(P->getIncomingBlock(In), P->getParent()); + VectorParts &In0 = getVectorValue(P->getIncomingValue(In)); + + for (unsigned part = 0; part < UF; ++part) { + // We might have single edge PHIs (blocks) - use an identity + // 'select' for the first PHI operand. + if (In == 0) + Entry[part] = Builder.CreateSelect(Cond[part], In0[part], In0[part]); + else + // Select between the current value and the previous incoming edge + // based on the incoming mask. + Entry[part] = Builder.CreateSelect(Cond[part], In0[part], Entry[part], + "predphi"); + } + } + return; + } + + // This PHINode must be an induction variable. + // Make sure that we know about it. 
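+ // Illustrative sketch for IK_IntInduction with fixed VF = 4, UF = 1 and an
+ // i64 induction (names invented): the broadcast-plus-step expansion below
+ // produces roughly
+ //   %bcast  = insertelement <4 x i64> undef, i64 %offset.idx, i32 0
+ //   %splat  = shufflevector <4 x i64> %bcast, <4 x i64> undef,
+ //                           <4 x i32> zeroinitializer
+ //   %vec.iv = add <4 x i64> %splat, <i64 0, i64 1, i64 2, i64 3>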
+ assert(Legal->getInductionVars()->count(P) && "Not an induction variable"); + + InductionDescriptor II = Legal->getInductionVars()->lookup(P); + const DataLayout &DL = OrigLoop->getHeader()->getModule()->getDataLayout(); + + // FIXME: The newly created binary instructions should contain nsw/nuw flags, + // which can be found from the original scalar operations. + switch (II.getKind()) { + case InductionDescriptor::IK_NoInduction: + llvm_unreachable("Unknown induction"); + case InductionDescriptor::IK_IntInduction: { + Type *PhiTy = P->getType(); + assert(P->getType() == II.getStartValue()->getType() && "Types must match"); + // Handle other induction variables that are now based on the + // canonical one. + Value *V = Induction; + if (P != OldInduction || VF == 1) { + // Handle other induction variables that are now based on the + // canonical one. + if (P != OldInduction) { + V = Builder.CreateSExtOrTrunc(Induction, PhiTy); + V = II.transform(Builder, V, PSE.getSE(), DL); + V->setName("offset.idx"); + } + Value *Broadcasted = getBroadcastInstrs(V); + Value *RuntimeVF = getRuntimeVF(PhiTy); + // After broadcasting the induction variable we need to make the vector + // consecutive by adding 0, 1, 2, etc. + for (unsigned part = 0; part < UF; ++part) { + Value *Part = ConstantInt::get(PhiTy, part); + Value *StartIdx = Builder.CreateMul(RuntimeVF, Part); + Entry[part] = getStepVector(Broadcasted, StartIdx, II.getStep()); + } + } else { + // Instead of re-creating the vector IV by splatting the scalar IV + // in each iteration, we can make a new independent vector IV. + widenInductionVariable(II, Entry); + } + return; + } + case InductionDescriptor::IK_FpInduction: { + Value *V = Builder.CreateCast(Instruction::SIToFP, Induction, P->getType()); + V = II.transform(Builder, V, PSE.getSE(), DL); + V->setName("fp.offset.idx"); + Value *Broadcasted = getBroadcastInstrs(V); + Value *RuntimeVF = getRuntimeVF(Builder.getInt32Ty()); + for (unsigned part = 0; part < UF; ++part) { + Value *Part = Builder.getInt32(part); + Value *StartIdx = Builder.CreateMul(RuntimeVF, Part); + auto *StepVal = cast(II.getStep())->getValue(); + Entry[part] = getStepVector(Broadcasted, StartIdx, StepVal, + II.getInductionOpcode()); + } + return; + } + case InductionDescriptor::IK_PtrInduction: { + // Handle the pointer induction variable case. + assert(P->getType()->isPointerTy() && "Unexpected type."); + // This is the normalized GEP that starts counting at zero. + Value *PtrInd = Induction; + PtrInd = Builder.CreateSExtOrTrunc(PtrInd, II.getStep()->getType()); + + if (!isScalable()) { + // This is the vector of results. Notice that we don't generate + // vector geps because scalar geps result in better code. 
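+ // Illustrative sketch, fixed VF = 4 with an i32* induction (names invented):
+ // each lane gets its own scalar GEP and the lanes are packed back into a
+ // vector of pointers, e.g.
+ //   %next.gep   = getelementptr i32, i32* %base, i64 %lane.idx
+ //   %insert.gep = insertelement <4 x i32*> undef, i32* %next.gep, i32 0
+ // repeated for each of the VF lanes of every unrolled part.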
+ for (unsigned part = 0; part < UF; ++part) { + if (VF == 1) { + int EltIndex = part; + Constant *Idx = ConstantInt::get(PtrInd->getType(),EltIndex); + Value *GlobalIdx = Builder.CreateAdd(PtrInd, Idx); + Value *SclrGep = II.transform(Builder, GlobalIdx, PSE.getSE(), DL); + SclrGep->setName("next.gep"); + Entry[part] = SclrGep; + continue; + } + + Value *VecVal = UndefValue::get(VectorType::get(P->getType(), VF)); + for (unsigned int i = 0; i < VF; ++i) { + int EltIndex = i + part * VF; + Constant *Idx = ConstantInt::get(PtrInd->getType(),EltIndex); + Value *GlobalIdx = Builder.CreateAdd(PtrInd, Idx); + Value *SclrGep = II.transform(Builder, GlobalIdx, PSE.getSE(), DL); + SclrGep->setName("next.gep"); + VecVal = Builder.CreateInsertElement(VecVal, SclrGep, + Builder.getInt32(i), + "insert.gep"); + } + Entry[part] = VecVal; + } + } else { + Type *PhiTy = PtrInd->getType(); + Value *RuntimeVF = getRuntimeVF(PhiTy); + + Value *StepValue; + ScalarEvolution *SE = PSE.getSE(); + const DataLayout &DL = PN->getModule()->getDataLayout(); + SCEVExpander Expander(*SE, DL, "seriesgep"); + + if (Legal->getInductionVars()->count(P)) { + const SCEV *Step = Legal->getInductionVars()->lookup(P).getStep(); + StepValue = Expander.expandCodeFor(Step, Step->getType(), + &*Builder.GetInsertPoint()); + } else { + auto *SAR = dyn_cast(SE->getSCEV(PN)); + assert(SAR && SAR->isAffine() && "Pointer induction not loop affine"); + + // Expand step and start value (the latter in preheader) + const SCEV *StepRec = SAR->getStepRecurrence(*SE); + StepValue = Expander.expandCodeFor(StepRec, StepRec->getType(), + &*Builder.GetInsertPoint()); + // Normalize step to be in #elements, not bytes + Type *ElemTy = PN->getType()->getPointerElementType(); + Value *Tmp = ConstantInt::get(StepValue->getType(), + DL.getTypeAllocSize(ElemTy)); + StepValue = Builder.CreateSDiv(StepValue, Tmp); + } + + for (unsigned part = 0; part < UF; ++part) { + Value *Part = ConstantInt::get(PhiTy, part); + Value *Idx = Builder.CreateMul(RuntimeVF, Part); + Value *GlobalIdx = Builder.CreateAdd(PtrInd, Idx); + Value *SclrGep = II.transform(Builder, GlobalIdx, SE, DL); + SclrGep->setName("next.gep"); + Value *Offs = Builder.CreateSeriesVector({VF,Scalable}, + ConstantInt::get(StepValue->getType(), 0), StepValue); + Entry[part] = Builder.CreateGEP(SclrGep, Offs); + Entry[part]->setName("vector.gep"); + } + } + return; + } + } +} + +// Vectorize GEP as arithmetic instructions. +// +// This is required when a given GEP is not used for a load/store operation, +// but rather to implement pointer arithmetic. In this case, the pointer may +// be a vector of pointers (e.g. resulting from a load). +// +// This function makes a ptrtoint->arith->inttoptr transformation. 
+// +// extern char * reg_names[]; +// void foo(void) { +// for (int i = 0; i < K; i++) +// reg_names[i]--; +// } +// +// %1 = getelementptr inbounds [0 x i8*]* @reg_names, i64 0, i64 %0 +// %2 = bitcast i8** %1 to * +// %wide.load = load * %2, align 8, !tbaa !1 +// %3 = ptrtoint %wide.load to +// %4 = add %3, seriesvector (i64 -1, i64 0) +// %5 = inttoptr %4 to +// %6 = bitcast i8** %1 to * +// store %5, * %6, align 8, !tbaa !1 +void InnerLoopVectorizer::vectorizeArithmeticGEP(Instruction *Instr) { + assert(isa(Instr) && "Instr is not a GEP"); + GetElementPtrInst *GEP = static_cast(Instr); + + // Used types for inttoptr/ptrtoint transform + Type *OrigPtrType = GEP->getType(); + const DataLayout &DL = GEP->getModule()->getDataLayout(); + Type *IntPtrType = DL.getIntPtrType(GEP->getType()); + + // Constant and Variable elements are kept separate to allow IRBuilder + // to fold the constant before widening it to a vector. + VectorParts &Base = getVectorValue(GEP->getPointerOperand()); + VectorParts &Res = WidenMap.get(Instr); + + for (unsigned Part = 0; Part < UF; ++Part) { + // Pointer To Int (pointer operand) + Res[Part] = Builder.CreatePtrToInt( + Base[Part], VectorType::get(IntPtrType, VF, Scalable)); + + // Collect constants and split up the GEP expression into an arithmetic one. + Value *Cst = ConstantInt::get(IntPtrType, 0, false); + gep_type_iterator GTI = gep_type_begin(*GEP); + for (unsigned I = 1, E = GEP->getNumOperands(); I != E; ++I, ++GTI) { + // V is still scalar + Value *V = GEP->getOperand(I); + + if (StructType *STy = GTI.getStructTypeOrNull()) { + // Struct type, get field offset in bytes. Result is always a constant. + assert(isa(V) && "Field offset must be constant"); + + ConstantInt *CI = static_cast(V); + unsigned ByteOffset = + DL.getStructLayout(STy)->getElementOffset(CI->getLimitedValue()); + V = ConstantInt::get(IntPtrType, ByteOffset, false); + } else { + // First transform index to pointer-type + if (V->getType() != IntPtrType) + V = Builder.CreateIntCast(V, IntPtrType, true, "idxprom"); + + Value *TypeAllocSize = ConstantInt::get( + V->getType(), DL.getTypeAllocSize(GTI.getIndexedType()), true); + // Only widen non-constant offsets + if (isa(V)) + V = Builder.CreateMul(V, TypeAllocSize); + else + V = Builder.CreateMul(getVectorValue(V)[Part], + getVectorValue(TypeAllocSize)[Part]); + } + + if (isa(V)) + Cst = Builder.CreateAdd(Cst, V); + else + Res[Part] = Builder.CreateAdd(Res[Part], V); + } + + // Add constant part and create final conversion to original type + Res[Part] = Builder.CreateAdd(Res[Part], getVectorValue(Cst)[Part]); + Res[Part] = Builder.CreateIntToPtr( + Res[Part], VectorType::get(OrigPtrType, VF, Scalable)); + } +} + +void InnerLoopVectorizer::patchLatchBranch(BranchInst *Br) { + assert(UsePredication && "Expect predicate to drive loop termination."); + assert(Br->getParent() == OrigLoop->getLoopLatch() && + "Non-latch branch cannot be patched"); + + BasicBlock *LastBB = LoopVectorBody.back(); + Type *IdxTy = Legal->getWidestInductionType(); + Value *RuntimeVF = getRuntimeVF(IdxTy); + Constant *StepVec = StepVector::get(IdxEndV->getType()); + + Value *InductionStep = getInductionStep(); + Value *NextIdx = Builder.CreateNUWAdd(Induction, InductionStep, "index.next"); + + // Create the predicates for the next iteration. + // NOTE: The vector loop is guarded by a check to ensure the calculation of + // index.next will not overflow. 
The vector loop's latch is based on this
+ // value alone, thus we can safely add NUW flags across all lanes as they will
+ // only be used within the vector loop and are considered poisoned upon exit.
+ for (unsigned i = 0; i < UF; ++i) {
+ Value *Step = Builder.CreateMul(RuntimeVF, ConstantInt::get(IdxTy, i));
+ Value *Idx = Builder.CreateNUWAdd(NextIdx, Step);
+ Value *IdxSplat = Builder.CreateVectorSplat({VF,Scalable}, Idx);
+
+ Value *NextPred = Builder.CreateNUWAdd(IdxSplat, StepVec);
+ NextPred = Builder.CreateICmpULT(NextPred, IdxEndV, "predicate.next");
+ Predicate[i]->addIncoming(NextPred, LastBB);
+ }
+
+ // An active first element means we have more work to do.
+ PHINode *FirstPredPhi = Predicate.front();
+ Value *FinalTest = FirstPredPhi->getIncomingValueForBlock(LastBB);
+ // Test for first lane active.
+ Value *Done = Builder.CreateExtractElement(FinalTest, Builder.getInt64(0));
+ Induction->addIncoming(NextIdx, LoopVectorBody.back());
+ LatchBranch->setCondition(Done);
+
+ // ----------------------------------------------------------------------
+ // Generate an llvm.assume() intrinsic about the bounds of this loop variable
+ // if the chosen Index value replaces induction variables with smaller
+ // type and/or range. This can be used in InstCombine for better folding
+ // of some cases.
+ // ----------------------------------------------------------------------
+
+ // Get the range of the induction value.
+ ScalarEvolution *SE = PSE.getSE();
+ auto IndTy = cast<IntegerType>(Induction->getType());
+ APInt MinRange = SE->getUnsignedRangeMax(SE->getSCEV(IdxEnd)) - 1;
+
+ // Find a loop variable with the same start/step value
+ // and reduce the range if possible.
+ LoopVectorizationLegality::InductionList::iterator I, E;
+ LoopVectorizationLegality::InductionList *List = Legal->getInductionVars();
+ for (I = List->begin(), E = List->end(); I != E; ++I) {
+ // Ignore FP inductions.
+ if (!I->first->getType()->isIntegerTy())
+ continue;
+ // Check it is a non-negative, non-wrapping AddRec.
+ auto *Phi = dyn_cast<SCEVAddRecExpr>(SE->getSCEV(I->first));
+ if (!Phi)
+ continue;
+
+ // Must have same step value.
+ if (!Phi->getStepRecurrence(*SE)->isOne())
+ continue;
+
+ // If range is smaller, reduce.
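+ // For example, if the loop also carries an i16 induction with unit step,
+ // its unsigned range spans at most 2^16 values; when that is smaller than
+ // the range derived from IdxEnd, the assumption emitted below becomes
+ // "Induction ult 65536".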
+ APInt RRange = SE->getUnsignedRange(Phi).getSetSize(); + if (MinRange.getBitWidth() > RRange.getBitWidth()) + RRange = RRange.zext(MinRange.getBitWidth()); + else if (MinRange.getBitWidth() < RRange.getBitWidth()) + MinRange = MinRange.zext(RRange.getBitWidth()); + if (RRange.ult(MinRange)) + MinRange = RRange; + } + + ConstantInt *MaxInd = ConstantInt::get(IndTy, MinRange.getLimitedValue()); + if (MaxInd->isMaxValue(false /* unsigned */)) + return; + + // Create the assume intrinsic in the preheader (llvm.assume must + // dominate use in order to be effective in InstCombine) + BasicBlock::iterator IP = Builder.GetInsertPoint(); + Builder.SetInsertPoint(IP->getParent()->getFirstNonPHI()); + + // Induction < minrange.Upper + CallInst *Assumption = Builder.CreateAssumption( + Builder.CreateICmpULT(Induction, MaxInd)); + AC->registerAssumption(Assumption); + + // Restore insertion point + Builder.SetInsertPoint(IP->getParent(), IP); +} + +bool +InnerLoopVectorizer::testHorizontalReductionExitInst(Instruction *I, + RecurrenceDescriptor &RD) { + auto Redux = Legal->getReductionVars(); + + bool Found = false; + for (auto Red : *Redux) { + auto RedDesc = Red.second; + if (!RedDesc.isOrdered()) + continue; + + if (RedDesc.getLoopExitInstr() == I) { + Found = true; + RD = RedDesc; + break; + } + + // Test if this is a PHI with an input from a horizontal ordered reduction. + auto P = dyn_cast(RedDesc.getLoopExitInstr()); + if (P && P->getNumIncomingValues() == 2 && + ((P->getIncomingValue(0) == I) || (P->getIncomingValue(1) == I))) { + Found = true; + RD = RedDesc; + break; + } + } + + if (!Found) + return false; + + LLVM_DEBUG(dbgs() << "LV: found an ordered horizontal reduction: "; + I->print(dbgs()); + dbgs()<< "\n"); + + return true; +} + +void InnerLoopVectorizer::vectorizeBlockInLoop(BasicBlock *BB, PhiVector *PV) { + // For each instruction in the old loop. + for (BasicBlock::iterator it = BB->begin(), e = BB->end(); it != e; ++it) { + VectorParts &Entry = WidenMap.get(&*it); + + switch (it->getOpcode()) { + case Instruction::Br: + if (UsePredication && BB == OrigLoop->getLoopLatch()) + patchLatchBranch(cast(it)); + continue; + case Instruction::PHI: { + // Vectorize PHINodes. + widenPHIInstruction(&*it, Entry, UF, VF, PV); + continue; + } // End of PHI. + + case Instruction::FAdd: { + // Just widen binops. + BinaryOperator *BinOp = dyn_cast(it); + setDebugLocFromInst(Builder, BinOp); + VectorParts &A = getVectorValue(it->getOperand(0)); + VectorParts &B = getVectorValue(it->getOperand(1)); + + // Use this vector value for all users of the original instruction. 
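+ // If this fadd is the exit instruction of an ordered reduction, the
+ // scalar accumulator is threaded through the parts below: part 0 uses the
+ // scalar operand directly and each later part reuses Entry[Part-1] as its
+ // running sum, combined under the block mask and loop predicate.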
+ RecurrenceDescriptor RD; + const bool isHorizontalReduction = + testHorizontalReductionExitInst(&*it, RD); + VectorParts Mask = createBlockInMask(it->getParent()); + for (unsigned Part = 0; Part < UF; ++Part) { + Value *V = nullptr; + + if (!isHorizontalReduction || VF == 1) { + V = Builder.CreateBinOp(BinOp->getOpcode(), A[Part], B[Part]); + if (BinaryOperator *VecOp = dyn_cast(V)) + VecOp->copyIRFlags(BinOp); + } else { + auto X = A[Part]; + auto Y = B[Part]; + auto XTy = X->getType(); + auto YTy = Y->getType(); + + if (YTy->isVectorTy() && !XTy->isVectorTy()) { + if (Part > 0) + X = Entry[Part-1]; + auto P = Builder.CreateAnd(Mask[Part], Predicate[Part]); + V = createOrderedReduction(Builder, RD, Y, X, P); + } + if (XTy->isVectorTy() && !YTy->isVectorTy()) { + if (Part > 0) + Y = Entry[Part-1]; + auto P = Builder.CreateAnd(Mask[Part], Predicate[Part]); + V = createOrderedReduction(Builder, RD, X, Y, P); + } + assert(V && "cannot find the reduction intrinsic"); + } + Entry[Part] = V; + } + + addMetadata(Entry, &*it); + break; + } + case Instruction::Add: + case Instruction::Sub: + case Instruction::FSub: + case Instruction::Mul: + case Instruction::FMul: + case Instruction::UDiv: + case Instruction::SDiv: + case Instruction::FDiv: + case Instruction::URem: + case Instruction::SRem: + case Instruction::Shl: + case Instruction::LShr: + case Instruction::AShr: + case Instruction::And: + case Instruction::Or: + case Instruction::Xor: { + // Just widen binops. + BinaryOperator *BinOp = dyn_cast(it); + setDebugLocFromInst(Builder, BinOp); + VectorParts &A = getVectorValue(it->getOperand(0)); + VectorParts &B = getVectorValue(it->getOperand(1)); + + // Use this vector value for all users of the original instruction. + for (unsigned Part = 0; Part < UF; ++Part) { + Value *V = Builder.CreateBinOp(BinOp->getOpcode(), A[Part], B[Part]); + if (BinaryOperator *VecOp = dyn_cast(V)) + VecOp->copyIRFlags(BinOp); + + Entry[Part] = V; + } + + addMetadata(Entry, &*it); + break; + } + case Instruction::Select: { + // Widen selects. + // If the selector is loop invariant we can create a select + // instruction with a scalar condition. Otherwise, use vector-select. + auto *SE = PSE.getSE(); + bool InvariantCond = + SE->isLoopInvariant(PSE.getSCEV(it->getOperand(0)), OrigLoop); + setDebugLocFromInst(Builder, &*it); + + // The condition can be loop invariant but still defined inside the + // loop. This means that we can't just use the original 'cond' value. + // We have to take the 'vectorized' value and pick the first lane. + // Instcombine will make this a no-op. + VectorParts &Cond = getVectorValue(it->getOperand(0)); + VectorParts &Op0 = getVectorValue(it->getOperand(1)); + VectorParts &Op1 = getVectorValue(it->getOperand(2)); + + Value *ScalarCond = + (VF == 1) + ? Cond[0] + : Builder.CreateExtractElement(Cond[0], Builder.getInt32(0)); + + for (unsigned Part = 0; Part < UF; ++Part) { + Entry[Part] = Builder.CreateSelect( + InvariantCond ? ScalarCond : Cond[Part], Op0[Part], Op1[Part]); + } + + addMetadata(Entry, &*it); + break; + } + + case Instruction::ICmp: + case Instruction::FCmp: { + // Widen compares. Generate vector compares. 
+ bool FCmp = (it->getOpcode() == Instruction::FCmp); + CmpInst *Cmp = dyn_cast(it); + setDebugLocFromInst(Builder, &*it); + VectorParts &A = getVectorValue(it->getOperand(0)); + VectorParts &B = getVectorValue(it->getOperand(1)); + for (unsigned Part = 0; Part < UF; ++Part) { + Value *C = nullptr; + if (FCmp) { + C = Builder.CreateFCmp(Cmp->getPredicate(), A[Part], B[Part]); + cast(C)->copyFastMathFlags(&*it); + } else { + C = Builder.CreateICmp(Cmp->getPredicate(), A[Part], B[Part]); + } + Entry[Part] = C; + } + + addMetadata(Entry, &*it); + break; + } + + case Instruction::Store: + case Instruction::Load: + vectorizeMemoryInstruction(&*it); + break; + case Instruction::ZExt: + case Instruction::SExt: + case Instruction::FPToUI: + case Instruction::FPToSI: + case Instruction::FPExt: + case Instruction::PtrToInt: + case Instruction::IntToPtr: + case Instruction::SIToFP: + case Instruction::UIToFP: + case Instruction::Trunc: + case Instruction::FPTrunc: + case Instruction::BitCast: { + CastInst *CI = dyn_cast(it); + setDebugLocFromInst(Builder, &*it); + /// Optimize the special case where the source is a constant integer + /// induction variable. Notice that we can only optimize the 'trunc' case + /// because: a. FP conversions lose precision, b. sext/zext may wrap, + /// c. other casts depend on pointer size. + + if (CI->getOperand(0) == OldInduction && + it->getOpcode() == Instruction::Trunc) { + InductionDescriptor II = + Legal->getInductionVars()->lookup(OldInduction); + if (auto StepValue = II.getConstIntStepValue()) { + IntegerType *TruncType = cast(CI->getType()); + if (VF == 1) { + StepValue = + ConstantInt::getSigned(TruncType, StepValue->getSExtValue()); + Value *ScalarCast = + Builder.CreateCast(CI->getOpcode(), Induction, CI->getType()); + Value *Broadcasted = getBroadcastInstrs(ScalarCast); + Type* ElemTy = Broadcasted->getType()->getScalarType(); + Value* RuntimeVF = getRuntimeVF(ElemTy); + for (unsigned Part = 0; Part < UF; ++Part) { + Value *Start = + Builder.CreateMul(RuntimeVF, ConstantInt::get(ElemTy, Part)); + Entry[Part] = getStepVector(Broadcasted, Start, StepValue); + } + } else { + // Truncating a vector induction variable on each iteration + // may be expensive. Instead, truncate the initial value, and create + // a new, truncated, vector IV based on that. + widenInductionVariable(II, Entry, TruncType); + } + addMetadata(Entry, &*it); + break; + } + } + /// Vectorize casts. + Type *DestTy = + (VF == 1) ? 
CI->getType() : + VectorType::get(CI->getType(), VF, Scalable); + + VectorParts &A = getVectorValue(it->getOperand(0)); + for (unsigned Part = 0; Part < UF; ++Part) + Entry[Part] = Builder.CreateCast(CI->getOpcode(), A[Part], DestTy); + addMetadata(Entry, &*it); + break; + } + + case Instruction::FRem: { + BinaryOperator *BinOp = cast(it); + setDebugLocFromInst(Builder, BinOp); + VectorParts &A = getVectorValue(it->getOperand(0)); + VectorParts &B = getVectorValue(it->getOperand(1)); + Type *RetTy = ToVectorTy(BinOp->getType(), VF, Scalable); + Module *M = BB->getParent()->getParent(); + VectorType::ElementCount EC(VF, Scalable); + Type *Tys[3] = {RetTy, RetTy, + VectorType::getBool(cast(RetTy))}; + + LibFunc F = LibFunc_fmod; + if (BinOp->getType()->isFloatTy()) + F = LibFunc_fmodf; + + StringRef FnName = TLI->getName(F); + FunctionType *FTy = FunctionType::get(RetTy, Tys, false); + const std::string VFnName = + TLI->getVectorizedFunction(FnName, EC, true, FTy); + + Function *VectorF; + if (VFnName != "") { + VectorF = Function::Create(FTy, Function::ExternalLinkage, VFnName, M); + } else { + Intrinsic::ID ID = Intrinsic::masked_fmod; + VectorF = Intrinsic::getDeclaration(M, ID, RetTy); + } + + for (unsigned Part = 0; Part < UF; ++Part) { + Value *Ops[3] = {A[Part], B[Part], Predicate[Part]}; + CallInst *FMod = Builder.CreateCall(VectorF, Ops); + if (isa(FMod)) + FMod->copyFastMathFlags(BinOp); + + Entry[Part] = FMod; + } + addMetadata(Entry, &*it); + break; + } + + case Instruction::Call: { + if (auto MSI = dyn_cast(it)) { + vectorizeMemsetInstruction(MSI); + break; + } + // Ignore dbg intrinsics. + if (isa(it)) + break; + setDebugLocFromInst(Builder, &*it); + + Module *M = BB->getParent()->getParent(); + CallInst *CI = cast(it); + + StringRef FnName = CI->getCalledFunction()->getName(); + Function *F = CI->getCalledFunction(); + Type *RetTy = ToVectorTy(CI->getType(), VF, Scalable); + + Intrinsic::ID ID = + getVectorIntrinsicIDForCall(CI, TLI, UsePredication && VF > 1); + if (ID && + (ID == Intrinsic::assume || ID == Intrinsic::lifetime_end || + ID == Intrinsic::lifetime_start)) { + if (isScalable() && + OrigLoop->isLoopInvariant(it->getOperand(0)) && + OrigLoop->isLoopInvariant(it->getOperand(1))) + Builder.Insert(it->clone()); + else + scalarizeInstruction(&*it); + break; + } + // The flag shows whether we use Intrinsic or a usual Call for vectorized + // version of the instruction. + // Is it beneficial to perform intrinsic call compared to lib call? + bool NeedToScalarize; + VectorizationFactor VecFactor = { VF, 1, !isScalable() }; + unsigned CallCost = getVectorCallCost(CI, VecFactor, *TTI, + TLI, *Legal, NeedToScalarize); + NeedToScalarize = NeedToScalarize && (!isScalable()); + bool UseVectorIntrinsic = + ID && getVectorIntrinsicCost(CI, VecFactor, *TTI, TLI) <= CallCost; + if (!UseVectorIntrinsic && NeedToScalarize) { + scalarizeInstruction(&*it); + break; + } + + const auto IsThereAMaskParam = isMaskedVectorIntrinsic(ID); + const bool IsPredicated = IsThereAMaskParam.first ; + const unsigned MaskPosition = IsThereAMaskParam.second ; + const bool CallNeedsPredication = IsPredicated || + (UsePredication && TLI->isFunctionVectorizable(FnName)); + for (unsigned Part = 0; Part < UF; ++Part) { + SmallVector Args; + for (unsigned i = 0, ie = CI->getNumArgOperands(); i != ie; ++i) { + Value *Arg = CI->getArgOperand(i); + // Some intrinsics have a scalar argument - don't replace it with a + // vector. 
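+ // For example, the exponent operand of llvm.powi and the i1
+ // 'is_zero_undef' flag of llvm.cttz/llvm.ctlz remain scalar operands of
+ // the widened call.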
+ if (!UseVectorIntrinsic || !hasVectorInstrinsicScalarOpd(ID, i)) { + VectorParts &VectorArg = getVectorValue(CI->getArgOperand(i)); + Arg = VectorArg[Part]; + } + Args.push_back(Arg); + } + + if (CallNeedsPredication) { + // If the intrinsic or function is maskable, then we need to pass in + // the loop predicate. + const SmallVectorImpl::iterator Insert = UseVectorIntrinsic ? + Args.begin() + MaskPosition : Args.end(); + Args.insert(Insert, Predicate[Part]); + } + + Function *VectorF; + if (UseVectorIntrinsic) { + // Use vector version of the intrinsic. + Type *TysForDecl[] = {CI->getType()}; + if (VF > 1) + TysForDecl[0] = VectorType::get(CI->getType()->getScalarType(), + VF, Scalable); + VectorF = Intrinsic::getDeclaration(M, ID, TysForDecl); + } else { + SmallVector Tys; + for (unsigned i = 0, ie = CI->getNumArgOperands(); i != ie; ++i) { + Value *Arg = CI->getArgOperand(i); + // Check if the argument `x` is a pointer marked by an + // OpenMP clause `linear(x:1)`. + if (Arg->getType()->isPointerTy() && + (Legal->isConsecutivePtr(Arg) == 1) && + isa(Args[i]->getType())) { + LLVM_DEBUG(dbgs() << "LV: vectorizing " << *Arg + << " as a linear pointer with step 1"); + Args[i] = + Builder.CreateExtractElement(Args[i], Builder.getInt32(0)); + Tys.push_back(Arg->getType()); + } else + Tys.push_back(ToVectorTy(Arg->getType(), VF, Scalable)); + } + + if (CallNeedsPredication) { + // If the intrinsic or function is maskable, then we need to pass in + // the loop predicate type in the correct place of the signature. + const SmallVectorImpl::iterator Insert = + UseVectorIntrinsic ? Tys.begin() + MaskPosition : Tys.end(); + Tys.insert(Insert, Predicate[0]->getType()); + } + + // Use vector version of the library call. + FunctionType *FTy = FunctionType::get(RetTy, Tys, false); + LLVM_DEBUG(dbgs() << "SVE LV: Looking for a signature" << *FTy << + "\n"); + VectorType::ElementCount EC(VF, Scalable); + const std::string VFnName = + TLI->getVectorizedFunction(FnName, EC, CallNeedsPredication, FTy); + assert(!VFnName.empty() && "Vector function name is empty."); + VectorF = M->getFunction(VFnName); + if (!VectorF) { + // Generate a declaration + VectorF = + Function::Create(FTy, Function::ExternalLinkage, VFnName, M); + VectorF->copyAttributesFrom(F); + } + } + assert(VectorF && "Can't create vector function."); + + SmallVector OpBundles; + CI->getOperandBundlesAsDefs(OpBundles); + CallInst *V = Builder.CreateCall(VectorF, Args, OpBundles); + + if (isa(V)) + V->copyFastMathFlags(CI); + + Entry[Part] = V; + } + + addMetadata(Entry, &*it); + break; + } + + case Instruction::GetElementPtr: + vectorizeGEPInstruction(&*it); + break; + + default: + // All other instructions are unsupported. Scalarize them. + scalarizeInstruction(&*it); + break; + } // end of switch. + } // end of for_each instr. +} + +void InnerLoopVectorizer::updateAnalysis() { + // Forget the original basic block. + PSE.getSE()->forgetLoop(OrigLoop); + + // Update the dominator tree information. + assert(DT->properlyDominates(LoopBypassBlocks.front(), LoopExitBlock) && + "Entry does not dominate exit."); + + /* + for (unsigned I = 1, E = LoopBypassBlocks.size(); I != E; ++I) + DT->addNewBlock(LoopBypassBlocks[I], LoopBypassBlocks[I-1]); + DT->addNewBlock(LoopVectorPreHeader, LoopBypassBlocks.back()); + */ + + // Add dominator for first vector body block. 
+ DT->addNewBlock(LoopVectorBody[0], LoopVectorPreHeader); + for (const auto &Edge : VecBodyDomEdges) + DT->addNewBlock(Edge.second, Edge.first); + + DT->addNewBlock(LoopMiddleBlock, LoopVectorBody.back()); + DT->addNewBlock(LoopScalarPreHeader, LoopBypassBlocks[0]); + DT->changeImmediateDominator(LoopScalarBody, LoopScalarPreHeader); + DT->changeImmediateDominator(LoopExitBlock, LoopBypassBlocks[0]); + + LLVM_DEBUG(assert(DT->verify(DominatorTree::VerificationLevel::Fast))); +} + +/// \brief Check whether it is safe to if-convert this phi node. +/// +/// Phi nodes with constant expressions that can trap are not safe to if +/// convert. +static bool canIfConvertPHINodes(BasicBlock *BB) { + for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E; ++I) { + PHINode *Phi = dyn_cast(I); + if (!Phi) + return true; + for (unsigned p = 0, e = Phi->getNumIncomingValues(); p != e; ++p) + if (Constant *C = dyn_cast(Phi->getIncomingValue(p))) + if (C->canTrap()) + return false; + } + return true; +} + +bool LoopVectorizationLegality::canVectorizeWithIfConvert() { + bool CanIfConvert = true; + + if (!EnableIfConversion) { + ORE->emit(createMissedAnalysis("IfConversionDisabled") + << "if-conversion is disabled"); + LLVM_DEBUG(dbgs() << "LV: Not vectorizing - if-conversion is disabled.\n"); + CanIfConvert = false; + NODEBUG_EARLY_BAILOUT(); + } + + assert(TheLoop->getNumBlocks() > 1 && "Single block loops are vectorizable"); + + // A list of pointers that we can safely read and write to. + SmallPtrSet SafePointes; + + // Collect safe addresses. + for (BasicBlock *BB : TheLoop->blocks()) { + if (blockNeedsPredication(BB)) + continue; + + for (Instruction &I : *BB) + if (auto *Ptr = getPointerOperand(&I)) + SafePointes.insert(Ptr); + } + + // Collect the blocks that need predication. + BasicBlock *Header = TheLoop->getHeader(); + for (BasicBlock *BB : TheLoop->blocks()) { + // We don't support switch statements inside loops. + if (!isa(BB->getTerminator())) { + ORE->emit(createMissedAnalysis("LoopContainsSwitch", BB->getTerminator()) + << "loop contains a switch statement"); + LLVM_DEBUG(dbgs() << + "LV: Not vectorizing - loop contains a switch statement.\n"); + CanIfConvert = false; + NODEBUG_EARLY_BAILOUT(); + } + + // We must be able to predicate all blocks that need to be predicated. + BasicBlock *PredB = BB->getSinglePredecessor(); + DebugLoc CmpLoc = DebugLoc(); + if (PredB && PredB->getTerminator()) + CmpLoc = PredB->getTerminator()->getDebugLoc(); + + if (blockNeedsPredication(BB)) { + if (!blockCanBePredicated(BB, SafePointes)) { + auto R = createMissedAnalysis("NoCFGForSelect"); + R << "control flow cannot be substituted for a select"; + if (auto *PredB = BB->getSinglePredecessor()) + R << ore::setExtraArgs() << ore::NV("Cmp", PredB->getTerminator()); + ORE->emit(R); + + LLVM_DEBUG(dbgs() << + "LV: Not vectorizing - cannot predicate all blocks for if-conversion.\n"); + CanIfConvert = false; + NODEBUG_EARLY_BAILOUT(); + } + } else if (BB != Header && !canIfConvertPHINodes(BB)) { + auto R = createMissedAnalysis("NoCFGForSelect"); + R << "control flow cannot be substituted for a select"; + if (auto *PredB = BB->getSinglePredecessor()) + R << ore::setExtraArgs() << ore::NV("Cmp", PredB->getTerminator()); + ORE->emit(R); + + LLVM_DEBUG(dbgs() << + "LV: Not vectorizing - phi nodes cannot be if converted.\n"); + CanIfConvert = false; + NODEBUG_EARLY_BAILOUT(); + } + } + + // We can if-convert this loop. 
+ return CanIfConvert; +} + +bool LoopVectorizationLegality::canVectorize() { + bool CanVectorize = true; + + // We must have a loop in canonical form. Loops with indirectbr in them cannot + // be canonicalized. + if (!TheLoop->getLoopPreheader()) { + ORE->emit(createMissedAnalysis("CFGNotUnderstood") + << "loop control flow is not understood by vectorizer"); + return false; + } + + // We can only vectorize innermost loops. + if (!TheLoop->empty()) { + ORE->emit(createMissedAnalysis("NotInnermostLoop") + << "loop is not the innermost loop"); + LLVM_DEBUG(dbgs() << "LV: Not vectorizing - not the innermost loop.\n"); + CanVectorize = false; + NODEBUG_EARLY_BAILOUT(); + } + + // We must have a single backedge. + if (TheLoop->getNumBackEdges() != 1) { + ORE->emit(createMissedAnalysis("CFGNotUnderstood") + << "loop control flow is not understood by vectorizer" + << ore::setExtraArgs() + << " (Reason = multiple backedges)"); + return false; + } + + // We must have a single exiting block. + if (!TheLoop->getExitingBlock()) { + std::vector EarlyExitLocations; + TheLoop->getEarlyExitLocations(EarlyExitLocations); + for (const auto &Loc: EarlyExitLocations) { + ORE->emit( + OptimizationRemarkAnalysis(Hints->vectorizeAnalysisPassName(), + "EarlyExit", + { Loc.getStart() }, + TheLoop->getHeader()) + << "loop not vectorized: " + << "Early exit prevented vectorization"); + } + + ORE->emit(createMissedAnalysis("CFGNotUnderstood") + << "loop control flow is not understood by vectorizer" + << ore::setExtraArgs() + << " (Reason = early exits)"); + + return false; + } + + // We only handle bottom-tested loops, i.e. loop in which the condition is + // checked at the end of each iteration. With that we can assume that all + // instructions in the loop are executed the same number of times. + if (TheLoop->getExitingBlock() != TheLoop->getLoopLatch()) { + ORE->emit(createMissedAnalysis("CFGNotUnderstood") + << "loop control flow is not understood by vectorizer"); + return false; + } + + // We need to have a loop header. + LLVM_DEBUG(dbgs() << "LV: Found a loop: " << TheLoop->getHeader()->getName() + << '\n'); + + // Check if we can if-convert non-single-bb loops. + unsigned NumBlocks = TheLoop->getNumBlocks(); + if (NumBlocks != 1 && !canVectorizeWithIfConvert()) { + LLVM_DEBUG(dbgs() << "LV: Not vectorizing - can't if-convert the loop.\n"); + return false; + } + + // ScalarEvolution needs to be able to find the exit count. + auto *SE = PSE.getSE(); + const SCEV *ExitCount = PSE.getBackedgeTakenCount(); + if (ExitCount == SE->getCouldNotCompute()) { + ORE->emit(createMissedAnalysis("CantComputeNumberOfIterations") + << "could not determine number of loop iterations"); + LLVM_DEBUG(dbgs() << + "LV: Not vectorizing - SCEV could not compute the loop exit count.\n"); + + return false; + } + + // Check if we can vectorize the instructions and CFG in this loop. + if (!canVectorizeInstrs()) { + LLVM_DEBUG(dbgs() << + "LV: Not vectorizing - can't vectorize the instructions or CFG.\n"); + return false; + } + + // Go over each instruction and look at memory deps. + if (!canVectorizeMemory()) { + LLVM_DEBUG(dbgs() << + "LV: Can't vectorize due to memory conflicts.\n"); + return false; + } + + if (CanVectorize) { + // Collect all of the variables that remain uniform after vectorization. + collectLoopUniforms(); + LLVM_DEBUG(dbgs() << "LV: We can vectorize this loop" + << (LAI->getRuntimePointerChecking()->Need + ? 
" (with a runtime bound check)" + : "") + << "!\n"); + } + + bool UseInterleaved = TTI->enableInterleavedAccessVectorization(); + + // If an override option has been passed in for interleaved accesses, use it. + if (EnableInterleavedMemAccesses.getNumOccurrences() > 0) + UseInterleaved = EnableInterleavedMemAccesses; + + // Analyze interleaved memory accesses. + if (UseInterleaved) + InterleaveInfo.analyzeInterleaving(Strides); + + unsigned SCEVThreshold = VectorizeSCEVCheckThreshold; + if (Hints->getForce() == LoopVectorizeHints::FK_Enabled) + SCEVThreshold = PragmaVectorizeSCEVCheckThreshold; + + if (PSE.getUnionPredicate().getComplexity() > SCEVThreshold) { + ORE->emit(createMissedAnalysis("TooManySCEVRunTimeChecks") + << "Too many SCEV assumptions need to be made and checked " + << "at runtime"); + LLVM_DEBUG(dbgs() << "LV: Too many SCEV checks needed.\n"); + return false; + } + + // Okay! We can vectorize. At this point we don't have any other mem analysis + // which may limit our maximum vectorization factor, so just return true with + // no restrictions. + return CanVectorize; +} + +static Type *convertPointerToIntegerType(const DataLayout &DL, Type *Ty) { + if (Ty->isPointerTy()) + return DL.getIntPtrType(Ty); + + // It is possible that char's or short's overflow when we ask for the loop's + // trip count, work around this by changing the type size. + if (Ty->getScalarSizeInBits() < 32) + return Type::getInt32Ty(Ty->getContext()); + + return Ty; +} + +static Type *getWiderType(const DataLayout &DL, Type *Ty0, Type *Ty1) { + Ty0 = convertPointerToIntegerType(DL, Ty0); + Ty1 = convertPointerToIntegerType(DL, Ty1); + if (Ty0->getScalarSizeInBits() > Ty1->getScalarSizeInBits()) + return Ty0; + return Ty1; +} + +/// \brief Check that the instruction has outside loop users and is not an +/// identified reduction variable. +static bool hasOutsideLoopUser(const Loop *TheLoop, Instruction *Inst, + SmallPtrSetImpl &AllowedExit) { + // Reduction and Induction instructions are allowed to have exit users. All + // other instructions must not have external users. + if (!AllowedExit.count(Inst)) + // Check that all of the users of the loop are inside the BB. + for (User *U : Inst->users()) { + Instruction *UI = cast(U); + // This user may be a reduction exit value. + if (!TheLoop->contains(UI)) { + LLVM_DEBUG(dbgs() << "LV: Found an outside user " << *UI << " for : " + << *Inst << "\n"); + return true; + } + } + return false; +} + +void LoopVectorizationLegality::addInductionPhi( + PHINode *Phi, const InductionDescriptor &ID, + SmallPtrSetImpl &AllowedExit) { + Inductions[Phi] = ID; + Type *PhiTy = Phi->getType(); + const DataLayout &DL = Phi->getModule()->getDataLayout(); + + // Get the widest type. + if (!PhiTy->isFloatingPointTy()) { + if (!WidestIndTy) + WidestIndTy = convertPointerToIntegerType(DL, PhiTy); + else + WidestIndTy = getWiderType(DL, PhiTy, WidestIndTy); + } + + // Int inductions are special because we only allow one IV. + if (ID.getKind() == InductionDescriptor::IK_IntInduction && + ID.getConstIntStepValue() && + ID.getConstIntStepValue()->isOne() && + isa(ID.getStartValue()) && + cast(ID.getStartValue())->isNullValue()) { + + // Use the phi node with the widest type as induction. Use the last + // one if there are multiple (no good reason for doing this other + // than it is expedient). We've checked that it begins at zero and + // steps by one, so this is a canonical induction variable. 
+ if (!PrimaryInduction || PhiTy == WidestIndTy) + PrimaryInduction = Phi; + } + + // Both the PHI node itself, and the "post-increment" value feeding + // back into the PHI node may have external users. + AllowedExit.insert(Phi); + AllowedExit.insert(Phi->getIncomingValueForBlock(TheLoop->getLoopLatch())); + + LLVM_DEBUG(dbgs() << "LV: Found an induction variable.\n"); + return; +} + +bool LoopVectorizationLegality::canVectorizeInstrs() { + BasicBlock *Header = TheLoop->getHeader(); + LAI = &(*GetLAA)(*TheLoop); + + bool CanVectorize = true; + + // Look for the attribute signaling the absence of NaNs. + Function &F = *Header->getParent(); + HasFunNoNaNAttr = + F.getFnAttribute("no-nans-fp-math").getValueAsString() == "true"; + + // For each block in the loop. + for (Loop::block_iterator bb = TheLoop->block_begin(), + be = TheLoop->block_end(); + bb != be; ++bb) { + + // Scan the instructions in the block and look for hazards. + for (BasicBlock::iterator it = (*bb)->begin(), e = (*bb)->end(); it != e; + ++it) { + + if (PHINode *Phi = dyn_cast(it)) { + Type *PhiTy = Phi->getType(); + // Check that this PHI type is allowed. + if (!PhiTy->isIntegerTy() && !PhiTy->isFloatingPointTy() && + !PhiTy->isPointerTy()) { + ORE->emit(createMissedAnalysis("CFGNotUnderstood", Phi) + << "loop control flow is not understood by vectorizer"); + LLVM_DEBUG(dbgs() << + "LV: Not vectorizing - Found an non-int non-pointer PHI.\n"); + CanVectorize = false; + NODEBUG_EARLY_BAILOUT(); + } + + // If this PHINode is not in the header block, then we know that we + // can convert it to select during if-conversion. No need to check if + // the PHIs in this block are induction or reduction variables. + if (*bb != Header) { + // Check that this instruction has no outside users or is an + // identified reduction value with an outside user. + // TODO: For now, we ignore this case with uncounted loops and just + // focus on phis created in the header block. + if (!hasOutsideLoopUser(TheLoop, &*it, AllowedExit)) + continue; + ORE->emit(createMissedAnalysis("NeitherInductionNorReduction", Phi) + << "value could not be identified as " + "an induction or reduction variable"); + return false; + } + + // We only allow if-converted PHIs with exactly two incoming values. + if (Phi->getNumIncomingValues() != 2) { + ORE->emit(createMissedAnalysis("CFGNotUnderstood", Phi) + << "control flow not understood by vectorizer"); + LLVM_DEBUG(dbgs() << + "LV: Not vectorizing - Phi with more than two incoming values.\n"); + CanVectorize = false; + NODEBUG_EARLY_BAILOUT(); + continue; + } + + RecurrenceDescriptor RedDes; + if (RecurrenceDescriptor::isReductionPHI(Phi, TheLoop, PSE.getSE(), + RedDes, false)) { + if (RedDes.hasUnsafeAlgebra()) + Requirements->addUnsafeAlgebraInst(RedDes.getUnsafeAlgebraInst()); + AllowedExit.insert(RedDes.getLoopExitInstr()); + Reductions[Phi] = RedDes; + LLVM_DEBUG(dbgs() << "LV: Found a reduction variable " << *Phi << "\n"); + continue; + } + + InductionDescriptor ID; + RecurrenceDescriptor RecTmp; + // First we do a check to see if the phi is a recognizable reduction, + // if not we try to handle it as an induction variable if possible. 
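+ // The checks below run in order: plain induction, first-order recurrence,
+ // and finally a SCEV-assisted retry that coerces the PHI to an AddRec.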
+ if (InductionDescriptor::isInductionPHI(Phi, TheLoop, PSE, ID)) { + addInductionPhi(Phi, ID, AllowedExit); + if (ID.hasUnsafeAlgebra() && !HasFunNoNaNAttr) + Requirements->addUnsafeAlgebraInst(ID.getUnsafeAlgebraInst()); + continue; + } + + if (RecurrenceDescriptor::isFirstOrderRecurrence(Phi, TheLoop, + SinkAfter, DT)) { + FirstOrderRecurrences.insert(Phi); + continue; + } + + // As a last resort, coerce the PHI to a AddRec expression + // and re-try classifying it a an induction PHI. + if (InductionDescriptor::isInductionPHI(Phi, TheLoop, PSE, ID, true)) { + addInductionPhi(Phi, ID, AllowedExit); + continue; + } + + ORE->emit(createMissedAnalysis("NonReductionValueUsedOutsideLoop", Phi) + << "value that could not be identified as " + "reduction is used outside the loop"); + LLVM_DEBUG(dbgs() << "LV: Found an unidentified PHI." << *Phi << "\n"); + return false; + } // end of PHI handling + + // We handle calls that: + // * Are debug info intrinsics. + // * Have a mapping to an IR intrinsic. + // * Have a vector version available. + CallInst *CI = dyn_cast(it); + if (CI && !getVectorIntrinsicIDForCall(CI, TLI) && + !isa(CI) && + !(CI->getCalledFunction() && TLI && + TLI->isFunctionVectorizable(CI->getCalledFunction()->getName()))) { + if (auto MSI = dyn_cast(CI)) { + if (VectorizeMemset && EnableScalableVectorisation && + isLegalMaskedScatter(MSI->getValue()->getType())) { + const auto Length = MSI->getLength(); + const auto IsVolatile = MSI->isVolatile(); + // Alignment is clamped to yield an acceptable vector element type. + const auto Alignment = + std::min((uint64_t)MSI->getDestAlignment(), (uint64_t) 8); + auto CL = dyn_cast(Length); + if (CL && (CL->getZExtValue() % Alignment == 0) + && ((CL->getZExtValue() / Alignment) <= + VectorizerMemSetThreshold) + && !IsVolatile) { + LLVM_DEBUG(dbgs() << "LV: Found a vectorizable 'memset':\n" << *MSI); + continue; + } + } + } + if (CI->isInlineAsm()) + ORE->emit(createMissedAnalysis("CantVectorizeCall") + << "inline assembly call cannot be vectorized" + << ore::setExtraArgs() + << " (Location = " << ore::NV("Location", CI) << ")"); + else { + auto *Callee = CI->getCalledFunction(); + std::string CalleeName = Callee ? Callee->getName() : "[?]"; + ORE->emit(createMissedAnalysis("CantVectorizeCall") + << "call instruction cannot be vectorized" + << ore::setExtraArgs() + << " (Callee = " << CalleeName + << ", Location = " << ore::NV("Location", CI) << ")"); + } + + LLVM_DEBUG(dbgs() << + "LV: Not vectorizing - found a non-intrinsic, non-libfunc callsite " << + *CI << "\n"); + CanVectorize = false; + + NODEBUG_EARLY_BAILOUT(); + continue; + } + + // Intrinsics such as powi,cttz and ctlz are legal to vectorize if the + // second argument is the same (i.e. loop invariant) + if (CI && hasVectorInstrinsicScalarOpd( + getVectorIntrinsicIDForCall(CI, TLI), 1)) { + auto *SE = PSE.getSE(); + if (!SE->isLoopInvariant(PSE.getSCEV(CI->getOperand(1)), TheLoop)) { + ORE->emit(createMissedAnalysis("CantVectorizeIntrinsic", CI) + << "intrinsic instruction cannot be vectorized"); + LLVM_DEBUG(dbgs() << + "LV: Not vectorizing - found unvectorizable intrinsic " << *CI << "\n"); + CanVectorize = false; + NODEBUG_EARLY_BAILOUT(); + continue; + } + } + + // Check that the instruction return type is vectorizable. + // Also, we can't vectorize extractelement instructions. 
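+ // In particular, fp128 and i128 results are rejected below, as are loads
+ // and stores of integer types whose width is not a power-of-two number of
+ // bytes.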
+ if ((!VectorType::isValidElementType(it->getType()) && + !it->getType()->isVoidTy()) || + it->getType()->isFP128Ty() || + it->getType()->isIntegerTy(128) || + isa(it)) { + ORE->emit(createMissedAnalysis("CantVectorizeInstructionReturnType", &*it) + << "instruction return type cannot be vectorized"); + LLVM_DEBUG(dbgs() << "LV: Found unvectorizable type.\n"); + LLVM_DEBUG(dbgs() << + "LV: Not vectorizing - found unvectorizable type " << + *(it->getType()) << "\n"); + CanVectorize = false; + NODEBUG_EARLY_BAILOUT(); + continue; + } + + if (StoreInst *SI = dyn_cast(it)) { + Value *Ptr = SI->getPointerOperand(); + auto *Ty = cast(Ptr->getType())->getElementType(); + if (std::abs(isConsecutivePtr(Ptr)) != 1 && !LAI->isUniform(Ptr) && + !isLegalMaskedScatter(Ty)) { + ORE->emit(createMissedAnalysis("CantVectorizeNonUnitStride", &*it) + << "non consecutive store instructions cannot be " + << "vectorized"); + return false; + } + } + if (LoadInst *LI = dyn_cast(it)) { + Value *Ptr = LI->getPointerOperand(); + auto *Ty = cast(Ptr->getType())->getElementType(); + if (std::abs(isConsecutivePtr(Ptr)) != 1 && !LAI->isUniform(Ptr) && + !isLegalMaskedGather(Ty)) { + ORE->emit(createMissedAnalysis("CantVectorizeNonUnitStride", &*it) + << "non consecutive load instructions cannot be " + << "vectorized"); + return false; + } + if (Ty->isIntegerTy()) { + if (!cast(Ty)->isPowerOf2ByteWidth()) { + ORE->emit(createMissedAnalysis("CantVectorizeLoad", &*it) + << "load instruction cannot be vectorized, " + "invalid type"); + return false; + } + } + } + + // Check that the stored type is vectorizable. + if (StoreInst *ST = dyn_cast(it)) { + Type *T = ST->getValueOperand()->getType(); + if (T->isIntegerTy()) { + if (!cast(T)->isPowerOf2ByteWidth()) { + ORE->emit(createMissedAnalysis("CantVectorizeStore", ST) + << "store instruction cannot be vectorized, " + "invalid type"); + return false; + } + } + if (!VectorType::isValidElementType(T) || it->getType()->isFP128Ty()) { + ORE->emit(createMissedAnalysis("CantVectorizeStore", ST) + << "store instruction cannot be vectorized"); + return false; + } + + // FP instructions can allow unsafe algebra, thus vectorizable by + // non-IEEE-754 compliant SIMD units. + // This applies to floating-point math operations and calls, not memory + // operations, shuffles, or casts, as they don't change precision or + // semantics. + } else if (it->getType()->isFloatingPointTy() && + (CI || it->isBinaryOp()) && !it->isFast()) { + LLVM_DEBUG(dbgs() << "LV: Found FP op with unsafe algebra.\n"); + Hints->setPotentiallyUnsafe(); + } + + // Reduction instructions are allowed to have exit users. + // All other instructions must not have external users. + // + // For uncounted loops we do allow induction variable + // escapees. + if (hasOutsideLoopUser(TheLoop, &*it, AllowedExit)) { + ORE->emit(createMissedAnalysis("ValueUsedOutsideLoop", &*it) + << "value cannot be used outside the loop"); + return false; + } + } // next instr. + } + + if (!PrimaryInduction) { + LLVM_DEBUG(dbgs() << "LV: Did not find one integer induction var.\n"); + if (Inductions.empty()) { + ORE->emit(createMissedAnalysis("NoInductionVariable") + << "loop induction variable could not be identified"); + LLVM_DEBUG(dbgs() << + "LV: Not vectorizing - unable to identify loop induction variable.\n"); + CanVectorize = false; + } + } + + // Now we know the widest induction type, check if our found induction + // is the same size. If it's not, unset it here and InnerLoopVectorizer + // will create another. 
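+ // E.g. if the primary IV is an i32 phi but an i64 induction was also
+ // found, WidestIndTy is i64 and the i32 phi is dropped as primary here.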
+ if (PrimaryInduction && WidestIndTy != PrimaryInduction->getType()) + PrimaryInduction = nullptr; + + return CanVectorize; +} + +void LoopVectorizationLegality::collectStridedAccess(Value *MemAccess) { + Value *Ptr = nullptr; + if (LoadInst *LI = dyn_cast(MemAccess)) + Ptr = LI->getPointerOperand(); + else if (StoreInst *SI = dyn_cast(MemAccess)) + Ptr = SI->getPointerOperand(); + else + return; + + Value *Stride = getStrideFromPointer(Ptr, PSE.getSE(), TheLoop); + if (!Stride) + return; + + LLVM_DEBUG(dbgs() << "LV: Found a strided access that we can version"); + LLVM_DEBUG(dbgs() << " Ptr: " << *Ptr << " Stride: " << *Stride << "\n"); + Strides[Ptr] = Stride; + StrideSet.insert(Stride); +} + +void LoopVectorizationLegality::collectLoopUniforms() { + // We now know that the loop is vectorizable! + // Collect variables that will remain uniform after vectorization. + std::vector Worklist; + BasicBlock *Latch = TheLoop->getLoopLatch(); + + // Start with the conditional branch and walk up the block. + Worklist.push_back(Latch->getTerminator()->getOperand(0)); + + // Also add all consecutive pointer values; these values will be uniform + // after vectorization (and subsequent cleanup) and, until revectorization is + // supported, all dependencies must also be uniform. + for (Loop::block_iterator B = TheLoop->block_begin(), + BE = TheLoop->block_end(); + B != BE; ++B) + for (BasicBlock::iterator I = (*B)->begin(), IE = (*B)->end(); I != IE; ++I) + if (I->getType()->isPointerTy() && isConsecutivePtr(&*I)) + Worklist.insert(Worklist.end(), I->op_begin(), I->op_end()); + + while (!Worklist.empty()) { + Instruction *I = dyn_cast(Worklist.back()); + Worklist.pop_back(); + + // Look at instructions inside this loop. + // Stop when reaching PHI nodes. + // TODO: we need to follow values all over the loop, not only in this block. + if (!I || !TheLoop->contains(I) || isa(I)) + continue; + + // This is a known uniform. + Uniforms.insert(I); + + // Insert all operands. + Worklist.insert(Worklist.end(), I->op_begin(), I->op_end()); + } +} + +/// Add memory access related remarks for TheLoop. +void LoopVectorizationLegality::elaborateMemoryReport() { + switch (LAI->getFailureReason()) { + case LoopAccessInfo::FailureReason::UnsafeDataDependence: { + const auto &UnsafeDependences = + LAI->getDepChecker().getUnsafeDependences(); + unsigned NumUnsafeDeps = UnsafeDependences.size(); + assert(NumUnsafeDeps > 0 && "expected unsafe dependencies but found none"); + + // Emit detailed remarks for each unsafe dependence + for (const auto &Dep : UnsafeDependences) { + auto *Source = getPointerOperand(Dep.getSource(*LAI)); + + switch (Dep.Type) { + case MemoryDepChecker::Dependence::NoDep: + case MemoryDepChecker::Dependence::Forward: + case MemoryDepChecker::Dependence::BackwardVectorizable: + // Don't emit a remark for dependences that don't block vectorization. 
+ continue; + default: + break; + } + + + DebugLoc Destination; + if (auto *D = dyn_cast(Dep.getDestination(*LAI))) { + Destination = D->getDebugLoc(); + if (auto *DD = dyn_cast(getPointerOperand(D))) + Destination = DD->getDebugLoc(); + } + + OptimizationRemarkAnalysis R(Hints->vectorizeAnalysisPassName(), + "UnsafeDep", Destination, + TheLoop->getHeader()); + R << "loop not vectorized: "; + switch (Dep.Type) { + case MemoryDepChecker::Dependence::NoDep: + case MemoryDepChecker::Dependence::Forward: + case MemoryDepChecker::Dependence::BackwardVectorizable: + llvm_unreachable("Unexpected dependency"); + case MemoryDepChecker::Dependence::Backward: + ORE->emit(R << "Backward loop carried dependence on " + << ore::NV("Source", Source)); + break; + case MemoryDepChecker::Dependence::ForwardButPreventsForwarding: + ORE->emit(R << "Loop carried forward dependence that prevents " + "load/store forwarding on " + << ore::NV("Source", Source)); + break; + case MemoryDepChecker::Dependence:: + BackwardVectorizableButPreventsForwarding: + ORE->emit(R << "Loop carried backward dependence that prevents " + "load/store forwarding on " + << ore::NV("Source", Source)); + break; + case MemoryDepChecker::Dependence::Unknown: + ORE->emit(R << "Unknown dependence on " + << ore::NV("Source", Source)); + break; + } + } + break; + } + case LoopAccessInfo::FailureReason::UnknownArrayBounds: { + // add detailed remarks at locations of pointers where bound cannot + // be computed + for (Value *Ptr : LAI->getUncomputablePtrs()) + if (auto *I = dyn_cast(Ptr)) + ORE->emit(createMissedAnalysis("UnknownArrayBounds", I) + << "Unknown array bounds"); + break; + } + case LoopAccessInfo::FailureReason::Unknown: + case LoopAccessInfo::FailureReason::UnsafeDataDependenceTriedRT: + break; + } +} + +bool LoopVectorizationLegality::canVectorizeMemory() { + LAI = &(*GetLAA)(*TheLoop); + InterleaveInfo.setLAI(LAI); + const OptimizationRemarkAnalysis *LAR = LAI->getReport(); + if (LAR) { + OptimizationRemarkAnalysis VR(Hints->vectorizeAnalysisPassName(), + "loop not vectorized: ", *LAR); + ORE->emit(VR); + } + + if (!LAI->canVectorizeMemory()) { + elaborateMemoryReport(); + return false; + } + + if (LAI->hasStoreToLoopInvariantAddress()) { + ScalarEvolution *SE = PSE.getSE(); + std::list UnhandledStores; + + // For each invariant address, check its last stored value is the result + // of one of our reductions and is unconditional. + for (StoreInst *SI : LAI->getInvariantStores()) { + bool FoundMatchingRecurrence = false; + for (auto &II : Reductions) { + RecurrenceDescriptor DS = II.second; + StoreInst *DSI = DS.IntermediateStore; + if (DSI && (DSI == SI) && !blockNeedsPredication(DSI->getParent())) { + FoundMatchingRecurrence = true; + break; + } + } + + if (FoundMatchingRecurrence) + // Earlier stores to this address are effectively deadcode. + UnhandledStores.remove_if([SE, SI](StoreInst *I) { + return storeToSameAddress(SE, SI, I); + }); + else + UnhandledStores.push_back(SI); + } + + bool IsOK = UnhandledStores.empty(); + // TODO: we should also validate against InvariantMemSets. 
+ if (!IsOK) { + ORE->emit(createMissedAnalysis("CantVectorizeStoreToLoopInvariantAddress") + << "write to a loop invariant address could not be vectorized"); + LLVM_DEBUG(dbgs() << "LV: We don't allow storing to uniform addresses\n"); + return false; + } + } + + Requirements->addRuntimePointerChecks(LAI->getNumRuntimePointerChecks()); + PSE.addPredicate(LAI->getPSE().getUnionPredicate()); + + return true; +} + +bool LoopVectorizationLegality::isInductionVariable(const Value *V) { + Value *In0 = const_cast(V); + + if (EnableScalableVectorisation) { + // TODO: Need to handle other arithmetic/logical instructions + Instruction *Inst = dyn_cast(In0); + if (Inst && Inst->getOpcode() == Instruction::Shl) { + Value *ShiftVal = Inst->getOperand(1); + if (!dyn_cast(ShiftVal)) + return false; + In0 = Inst->getOperand(0); + } + } + + PHINode *PN = dyn_cast_or_null(In0); + if (!PN) + return false; + + return Inductions.count(PN); +} + +bool LoopVectorizationLegality::isFirstOrderRecurrence(const PHINode *Phi) { + return FirstOrderRecurrences.count(Phi); +} + +bool LoopVectorizationLegality::blockNeedsPredication(BasicBlock *BB) { + return LoopAccessInfo::blockNeedsPredication(BB, TheLoop, DT); +} + +bool LoopVectorizationLegality::blockCanBePredicated( + BasicBlock *BB, SmallPtrSetImpl &SafePtrs) { + const bool IsAnnotatedParallel = TheLoop->isAnnotatedParallel(); + + for (BasicBlock::iterator it = BB->begin(), e = BB->end(); it != e; ++it) { + // Check that we don't have a constant expression that can trap as operand. + for (Instruction::op_iterator OI = it->op_begin(), OE = it->op_end(); + OI != OE; ++OI) { + if (Constant *C = dyn_cast(*OI)) + if (C->canTrap()) + return false; + } + // We might be able to hoist the load. + if (it->mayReadFromMemory()) { + LoadInst *LI = dyn_cast(it); + if (!LI) + return false; + if (!SafePtrs.count(LI->getPointerOperand())) { + if (isLegalMaskedLoad(LI->getType(), LI->getPointerOperand()) || + isLegalMaskedGather(LI->getType())) { + MaskedOp.insert(LI); + continue; + } + // !llvm.mem.parallel_loop_access implies if-conversion safety. + if (IsAnnotatedParallel) + continue; + return false; + } + } + + // We don't predicate stores at the moment. + if (it->mayWriteToMemory()) { + StoreInst *SI = dyn_cast(it); + // We only support predication of stores in basic blocks with one + // predecessor. + if (!SI) + return false; + + // Build a masked store if it is legal for the target. + if (isLegalMaskedStore(SI->getValueOperand()->getType(), + SI->getPointerOperand()) || + isLegalMaskedScatter(SI->getValueOperand()->getType())) { + MaskedOp.insert(SI); + continue; + } + + bool isSafePtr = (SafePtrs.count(SI->getPointerOperand()) != 0); + bool isSinglePredecessor = SI->getParent()->getSinglePredecessor(); + + if (++NumPredStores > NumberOfStoresToPredicate || !isSafePtr || + !isSinglePredecessor) + return false; + } + if (it->mayThrow()) + return false; + + // The instructions below can trap. 
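+ // E.g. a udiv whose divisor could be zero in a lane that the original
+ // control flow would have skipped must not be executed speculatively, so
+ // blocks containing integer divisions or remainders cannot be predicated.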
+ switch (it->getOpcode()) { + default: + continue; + case Instruction::UDiv: + case Instruction::SDiv: + case Instruction::URem: + case Instruction::SRem: + return false; + } + } + + return true; +} + +void InterleavedAccessInfo::collectConstStrideAccesses( + MapVector &AccessStrideInfo, + const ValueToValueMap &Strides) { + + auto &DL = TheLoop->getHeader()->getModule()->getDataLayout(); + + // Since it's desired that the load/store instructions be maintained in + // "program order" for the interleaved access analysis, we have to visit the + // blocks in the loop in reverse postorder (i.e., in a topological order). + // Such an ordering will ensure that any load/store that may be executed + // before a second load/store will precede the second load/store in + // AccessStrideInfo. + LoopBlocksDFS DFS(TheLoop); + DFS.perform(LI); + for (BasicBlock *BB : make_range(DFS.beginRPO(), DFS.endRPO())) + for (auto &I : *BB) { + auto *LI = dyn_cast(&I); + auto *SI = dyn_cast(&I); + if (!LI && !SI) + continue; + + Value *Ptr = getPointerOperand(&I); + // We don't check wrapping here because we don't know yet if Ptr will be + // part of a full group or a group with gaps. Checking wrapping for all + // pointers (even those that end up in groups with no gaps) will be overly + // conservative. For full groups, wrapping should be ok since if we would + // wrap around the address space we would do a memory access at nullptr + // even without the transformation. The wrapping checks are therefore + // deferred until after we've formed the interleaved groups. + int64_t Stride = getPtrStride(PSE, Ptr, TheLoop, Strides, + /*Assume=*/true, /*ShouldCheckWrap=*/false); + + const SCEV *Scev = replaceSymbolicStrideSCEV(PSE, Strides, Ptr); + PointerType *PtrTy = dyn_cast(Ptr->getType()); + uint64_t Size = DL.getTypeAllocSize(PtrTy->getElementType()); + + // An alignment of 0 means target ABI alignment. + unsigned Align = getMemInstAlignment(&I); + if (!Align) + Align = DL.getABITypeAlignment(PtrTy->getElementType()); + + AccessStrideInfo[&I] = StrideDescriptor(Stride, Scev, Size, Align); + } +} + +// Analyze interleaved accesses and collect them into interleaved load and +// store groups. +// +// When generating code for an interleaved load group, we effectively hoist all +// loads in the group to the location of the first load in program order. When +// generating code for an interleaved store group, we sink all stores to the +// location of the last store. This code motion can change the order of load +// and store instructions and may break dependences. +// +// The code generation strategy mentioned above ensures that we won't violate +// any write-after-read (WAR) dependences. +// +// E.g., for the WAR dependence: a = A[i]; // (1) +// A[i] = b; // (2) +// +// The store group of (2) is always inserted at or below (2), and the load +// group of (1) is always inserted at or above (1). Thus, the instructions will +// never be reordered. All other dependences are checked to ensure the +// correctness of the instruction reordering. +// +// The algorithm visits all memory accesses in the loop in bottom-up program +// order. Program order is established by traversing the blocks in the loop in +// reverse postorder when collecting the accesses. +// +// We visit the memory accesses in bottom-up order because it can simplify the +// construction of store groups in the presence of write-after-write (WAW) +// dependences. 
+// +// E.g., for the WAW dependence: A[i] = a; // (1) +// A[i] = b; // (2) +// A[i + 1] = c; // (3) +// +// We will first create a store group with (3) and (2). (1) can't be added to +// this group because it and (2) are dependent. However, (1) can be grouped +// with other accesses that may precede it in program order. Note that a +// bottom-up order does not imply that WAW dependences should not be checked. +void InterleavedAccessInfo::analyzeInterleaving( + const ValueToValueMap &Strides) { + LLVM_DEBUG(dbgs() << "LV: Analyzing interleaved accesses...\n"); + + // Holds all accesses with a constant stride. + MapVector AccessStrideInfo; + collectConstStrideAccesses(AccessStrideInfo, Strides); + + if (AccessStrideInfo.empty()) + return; + + // Collect the dependences in the loop. + collectDependences(); + + // Holds all interleaved store groups temporarily. + SmallSetVector StoreGroups; + // Holds all interleaved load groups temporarily. + SmallSetVector LoadGroups; + + // Search in bottom-up program order for pairs of accesses (A and B) that can + // form interleaved load or store groups. In the algorithm below, access A + // precedes access B in program order. We initialize a group for B in the + // outer loop of the algorithm, and then in the inner loop, we attempt to + // insert each A into B's group if: + // + // 1. A and B have the same stride, + // 2. A and B have the same memory object size, and + // 3. A belongs in B's group according to its distance from B. + // + // Special care is taken to ensure group formation will not break any + // dependences. + for (auto BI = AccessStrideInfo.rbegin(), E = AccessStrideInfo.rend(); + BI != E; ++BI) { + Instruction *B = BI->first; + StrideDescriptor DesB = BI->second; + + // Initialize a group for B if it has an allowable stride. Even if we don't + // create a group for B, we continue with the bottom-up algorithm to ensure + // we don't break any of B's dependences. + InterleaveGroup *Group = nullptr; + if (isStrided(DesB.Stride)) { + Group = getInterleaveGroup(B); + if (!Group) { + LLVM_DEBUG(dbgs() << "LV: Creating an interleave group with:" << *B << '\n'); + Group = createInterleaveGroup(B, DesB.Stride, DesB.Align); + } + if (B->mayWriteToMemory()) + StoreGroups.insert(Group); + else + LoadGroups.insert(Group); + } + + for (auto AI = std::next(BI); AI != E; ++AI) { + Instruction *A = AI->first; + StrideDescriptor DesA = AI->second; + + // Our code motion strategy implies that we can't have dependences + // between accesses in an interleaved group and other accesses located + // between the first and last member of the group. Note that this also + // means that a group can't have more than one member at a given offset. + // The accesses in a group can have dependences with other accesses, but + // we must ensure we don't extend the boundaries of the group such that + // we encompass those dependent accesses. + // + // For example, assume we have the sequence of accesses shown below in a + // stride-2 loop: + // + // (1, 2) is a group | A[i] = a; // (1) + // | A[i-1] = b; // (2) | + // A[i-3] = c; // (3) + // A[i] = d; // (4) | (2, 4) is not a group + // + // Because accesses (2) and (3) are dependent, we can group (2) with (1) + // but not with (4). If we did, the dependent access (3) would be within + // the boundaries of the (2, 4) group. 
+ if (!canReorderMemAccessesForInterleavedGroups(&*AI, &*BI)) { + + // If a dependence exists and A is already in a group, we know that A + // must be a store since A precedes B and WAR dependences are allowed. + // Thus, A would be sunk below B. We release A's group to prevent this + // illegal code motion. A will then be free to form another group with + // instructions that precede it. + if (isInterleaved(A)) { + InterleaveGroup *StoreGroup = getInterleaveGroup(A); + StoreGroups.remove(StoreGroup); + releaseGroup(StoreGroup); + } + + // If a dependence exists and A is not already in a group (or it was + // and we just released it), B might be hoisted above A (if B is a + // load) or another store might be sunk below A (if B is a store). In + // either case, we can't add additional instructions to B's group. B + // will only form a group with instructions that it precedes. + break; + } + + // At this point, we've checked for illegal code motion. If either A or B + // isn't strided, there's nothing left to do. + if (!isStrided(DesA.Stride) || !isStrided(DesB.Stride)) + continue; + + // Ignore A if it's already in a group or isn't the same kind of memory + // operation as B. + if (isInterleaved(A) || A->mayReadFromMemory() != B->mayReadFromMemory()) + continue; + + // Check rules 1 and 2. Ignore A if its stride or size is different from + // that of B. + if (DesA.Stride != DesB.Stride || DesA.Size != DesB.Size) + continue; + + // Ignore A if the memory object of A and B don't belong to the same + // address space + if (getMemInstAddressSpace(A) != getMemInstAddressSpace(B)) + continue; + + // Calculate the distance from A to B. + const SCEVConstant *DistToB = dyn_cast( + PSE.getSE()->getMinusSCEV(DesA.Scev, DesB.Scev)); + if (!DistToB) + continue; + int64_t DistanceToB = DistToB->getAPInt().getSExtValue(); + + // Check rule 3. Ignore A if its distance to B is not a multiple of the + // size. + if (DistanceToB % static_cast(DesB.Size)) + continue; + + // Ignore A if either A or B is in a predicated block. Although we + // currently prevent group formation for predicated accesses, we may be + // able to relax this limitation in the future once we handle more + // complicated blocks. + if (isPredicated(A->getParent()) || isPredicated(B->getParent())) + continue; + + // The index of A is the index of B plus A's distance to B in multiples + // of the size. + int IndexA = + Group->getIndex(B) + DistanceToB / static_cast(DesB.Size); + + // Try to insert A into B's group. + if (Group->insertMember(A, IndexA, DesA.Align)) { + LLVM_DEBUG(dbgs() << "LV: Inserted:" << *A << '\n' + << " into the interleave group with" << *B << '\n'); + InterleaveGroupMap[A] = Group; + + // Set the first load in program order as the insert position. + if (A->mayReadFromMemory()) + Group->setInsertPos(A); + } + } // Iteration over A accesses. + } // Iteration over B accesses. + + // Remove interleaved store groups with gaps. + for (InterleaveGroup *Group : StoreGroups) + if (Group->getNumMembers() != Group->getFactor()) + releaseGroup(Group); + + // Remove interleaved groups with gaps (currently only loads) whose memory + // accesses may wrap around. We have to revisit the getPtrStride analysis, + // this time with ShouldCheckWrap=true, since collectConstStrideAccesses does + // not check wrapping (see documentation there). 
+  // FORNOW we use Assume=false;
+  // TODO: Change to Assume=true, but make sure we don't exceed the threshold
+  // of runtime SCEV assumption checks (thereby potentially failing to
+  // vectorize altogether).
+  // Additional optional optimizations:
+  // TODO: If we are peeling the loop and we know that the first pointer doesn't
+  // wrap then we can deduce that all pointers in the group don't wrap.
+  // This means that we can forcefully peel the loop in order to only have to
+  // check the first pointer for no-wrap. When we change to Assume=true we'll
+  // only need at most one runtime check per interleaved group.
+  //
+  for (InterleaveGroup *Group : LoadGroups) {
+
+    // Case 1: A full group. We can skip the checks; for full groups, if the
+    // wide load would wrap around the address space we would do a memory
+    // access at nullptr even without the transformation.
+    if (Group->getNumMembers() == Group->getFactor())
+      continue;
+
+    // Case 2: If the first and last members of the group don't wrap, this
+    // implies that none of the pointers in the group wrap.
+    // So we check only group member 0 (which is always guaranteed to exist)
+    // and group member Factor - 1; if the latter doesn't exist we rely on
+    // peeling (if it is a non-reversed access -- see Case 3).
+    Value *FirstMemberPtr = getPointerOperand(Group->getMember(0));
+    if (!getPtrStride(PSE, FirstMemberPtr, TheLoop, Strides, /*Assume=*/false,
+                      /*ShouldCheckWrap=*/true)) {
+      LLVM_DEBUG(dbgs() << "LV: Invalidate candidate interleaved group due to "
+                           "first group member potentially pointer-wrapping.\n");
+      releaseGroup(Group);
+      continue;
+    }
+    Instruction *LastMember = Group->getMember(Group->getFactor() - 1);
+    if (LastMember) {
+      Value *LastMemberPtr = getPointerOperand(LastMember);
+      if (!getPtrStride(PSE, LastMemberPtr, TheLoop, Strides, /*Assume=*/false,
+                        /*ShouldCheckWrap=*/true)) {
+        LLVM_DEBUG(dbgs() << "LV: Invalidate candidate interleaved group due to "
+                             "last group member potentially pointer-wrapping.\n");
+        releaseGroup(Group);
+      }
+    } else {
+      // Case 3: A non-reversed interleaved load group with gaps: we need
+      // to execute at least one scalar epilogue iteration. This will ensure
+      // we don't speculatively access memory out-of-bounds. We only need
+      // to look for a member at index factor - 1, since every group must have
+      // a member at index zero.
+ if (Group->isReverse()) { + releaseGroup(Group); + continue; + } + LLVM_DEBUG(dbgs() << "LV: Interleaved group requires epilogue iteration.\n"); + RequiresScalarEpilogue = true; + } + } +} + +static TargetTransformInfo::ReductionFlags +getReductionFlagsFromDesc(RecurrenceDescriptor Rdx) { + using RD = RecurrenceDescriptor; + RD::RecurrenceKind RecKind = Rdx.getRecurrenceKind(); + TargetTransformInfo::ReductionFlags Flags; + Flags.IsOrdered = Rdx.isOrdered(); + if (RecKind == RD::RK_IntegerMinMax || RecKind == RD::RK_FloatMinMax) { + auto MMKind = Rdx.getMinMaxRecurrenceKind(); + Flags.IsSigned = MMKind == RD::MRK_SIntMax || MMKind == RD::MRK_SIntMin; + Flags.IsMaxOp = MMKind == RD::MRK_SIntMax || MMKind == RD::MRK_FloatMax; + } else if (RecKind == RD::RK_ConstSelectICmp || + RecKind == RD::RK_ConstSelectFCmp) { + auto MMKind = Rdx.getMinMaxRecurrenceKind(); + Flags.IsSigned = true; + Flags.IsMaxOp = MMKind == RD::MRK_SIntMax; + } + return Flags; +} + +VectorizationFactor +LoopVectorizationCostModel::selectVectorizationFactor(bool OptForSize) { + bool FixedWidth = !EnableScalableVectorisation; + + int UserVStyle = Hints->getStyle(); + if (UserVStyle != LoopVectorizeHints::SK_Unspecified) { + FixedWidth = UserVStyle == LoopVectorizeHints::SK_Fixed; + LLVM_DEBUG(dbgs() << "LV: Using user vectorization style of " + << (FixedWidth ? "fixed" : "scaled") << " width.\n"); + } + + // Width 1 means no vectorize + VectorizationFactor Factor = { 1U, 0U, FixedWidth }; + + if (OptForSize && Legal->getRuntimePointerChecking()->Need) { + ORE->emit(createMissedAnalysis("CantVersionLoopWithOptForSize") + << "runtime pointer checks needed. Enable vectorization of this " + "loop with '#pragma clang loop vectorize(enable)' when " + "compiling with -Os/-Oz"); + LLVM_DEBUG(dbgs() << + "LV: Aborting. Runtime ptr check is required with -Os/-Oz.\n"); + Factor.isFixed = true; + return Factor; + } + + if (!EnableCondStoresVectorization && Legal->getNumPredStores()) { + ORE->emit(createMissedAnalysis("ConditionalStore") + << "store that is conditionally executed prevents vectorization"); + LLVM_DEBUG(dbgs() << "LV: No vectorization. There are conditional stores.\n"); + Factor.isFixed = true; + return Factor; + } + + // Find the trip count. + unsigned TC = PSE.getSE()->getSmallConstantTripCount(TheLoop); + LLVM_DEBUG(dbgs() << "LV: Found trip count: " << TC << '\n'); + + MinBWs = computeMinimumValueSizes(TheLoop->getBlocks(), *DB, &TTI); + unsigned SmallestType, WidestType; + std::tie(SmallestType, WidestType) = getSmallestAndWidestTypes(); + unsigned WidestRegister = TTI.getRegisterBitWidth(true); + unsigned MaxSafeDepDist = -1U; + + // Get the maximum safe dependence distance in bits computed by LAA. If the + // loop contains any interleaved accesses, we divide the dependence distance + // by the maximum interleave factor of all interleaved groups. Note that + // although the division ensures correctness, this is a fairly conservative + // computation because the maximum distance computed by LAA may not involve + // any of the interleaved accesses. + if (Legal->getMaxSafeDepDistBytes() != -1U) + MaxSafeDepDist = + Legal->getMaxSafeDepDistBytes() * 8 / Legal->getMaxInterleaveFactor(); + + // For the case when the register size is unknown we cannot vectorise loops + // with data dependencies in a scalable manner. However, when the + // architecture provides an upper bound, we can query that before reverting + // to fixed width vectors. 
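The comparison on the next line implements that fallback. As a standalone illustration of the clamping arithmetic described above, here is a short sketch with invented numbers (none of them come from a real target or from the TTI interface):

    #include <algorithm>
    #include <cstdio>

    int main() {
      // Assumed example values only:
      unsigned WidestRegisterBits  = 256; // widest vector register
      unsigned WidestTypeBits      = 32;  // widest scalar type in the loop
      unsigned MaxSafeDepDistBytes = 16;  // reported by dependence analysis
      unsigned MaxInterleaveFactor = 2;   // largest interleave-group factor

      // Share the safe distance among the members of an interleave group,
      // then cap the usable register width with it.
      unsigned MaxSafeDepDistBits =
          MaxSafeDepDistBytes * 8 / MaxInterleaveFactor;           // 64
      unsigned UsableBits = std::min(WidestRegisterBits, MaxSafeDepDistBits);

      // The widest VF that still respects the dependence distance.
      unsigned MaxVectorSize = UsableBits / WidestTypeBits;        // 2
      std::printf("usable bits = %u, max VF = %u\n", UsableBits, MaxVectorSize);
    }

With an unknown (unbounded) scalable register size, this same dependence limit is what forces the reversion to fixed-width vectors that the code below performs.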
+ if (MaxSafeDepDist < TTI.getRegisterBitWidthUpperBound(true)) { + Factor.isFixed = true; + // LAA may have assumed we can do strided during analysis + if (Legal->getRuntimePointerChecking()->Strided && + TTI.canVectorizeNonUnitStrides(true)) { + LLVM_DEBUG(dbgs() << + "LV: Not vectorizing, can't do strided accesses on target.\n"); + ORE->emit(createMissedAnalysis("StridedAccess") + << "Target doesn't support vectorizing strided accesses."); + Factor.Width = 1; + return Factor; + } + } + + WidestRegister = + ((WidestRegister < MaxSafeDepDist) ? WidestRegister : MaxSafeDepDist); + unsigned MaxVectorSize = WidestRegister / WidestType; + + LLVM_DEBUG(dbgs() << "LV: The Smallest and Widest types: " << SmallestType << " / " + << WidestType << " bits.\n"); + LLVM_DEBUG(dbgs() << "LV: The Widest register is: " << WidestRegister + << " bits.\n"); + + if (MaxVectorSize == 0) { + LLVM_DEBUG(dbgs() << "LV: The target has no vector registers.\n"); + MaxVectorSize = 1; + } + + assert(MaxVectorSize <= 64 && "Did not expect to pack so many elements" + " into one vector!"); + + unsigned VF = MaxVectorSize; + if (MaximizeBandwidth && !OptForSize) { + // Collect all viable vectorization factors. + SmallVector VFs; + unsigned NewMaxVectorSize = WidestRegister / SmallestType; + for (unsigned VS = MaxVectorSize; VS <= NewMaxVectorSize; VS *= 2) + VFs.push_back(VS); + + // For each VF calculate its register usage. + auto RUs = calculateRegisterUsage(VFs); + + // Select the largest VF which doesn't require more registers than existing + // ones. + unsigned TargetNumRegisters = TTI.getNumberOfRegisters(true); + for (int i = RUs.size() - 1; i >= 0; --i) { + if (RUs[i].MaxLocalUsers <= TargetNumRegisters) { + VF = VFs[i]; + break; + } + } + } + + // If we optimize the program for size, avoid creating the tail loop. + if (OptForSize) { + // If we are unable to calculate the trip count then don't try to vectorize. + if (TC < 2) { + ORE->emit( + createMissedAnalysis("UnknownLoopCountComplexCFG") + << "unable to calculate the loop count due to complex control flow"); + LLVM_DEBUG(dbgs() << "LV: Aborting. A tail loop is required with -Os/-Oz.\n"); + if (Factor.Width < 2) + Factor.isFixed = true; + return Factor; + } + + // Find the maximum SIMD width that can fit within the trip count. + VF = TC % MaxVectorSize; + + if (VF == 0) + VF = MaxVectorSize; + else { + // If the trip count that we found modulo the vectorization factor is not + // zero then we require a tail. + ORE->emit(createMissedAnalysis("NoTailLoopWithOptForSize") + << "cannot optimize for size and vectorize at the " + "same time. Enable vectorization of this loop " + "with '#pragma clang loop vectorize(enable)' " + "when compiling with -Os/-Oz"); + LLVM_DEBUG(dbgs() << "LV: Aborting. A tail loop is required with -Os/-Oz.\n"); + Factor.isFixed = true; + return Factor; + } + } + + int UserVF = Hints->getWidth(); + if (UserVF != 0) { + assert(isPowerOf2_32(UserVF) && "VF needs to be a power of two"); + LLVM_DEBUG(dbgs() << "LV: Using user VF " << UserVF << ".\n"); + + Factor.Width = UserVF; + if (Factor.Width < 2) + Factor.isFixed = true; + return Factor; + } + + float Cost = expectedCost({/*Width=*/1, 0, /*isFixed=*/true}).first; +#ifndef NDEBUG + const float ScalarCost = Cost; +#endif /* NDEBUG */ + Factor.Width = 1; + LLVM_DEBUG(dbgs() << "LV: Scalar loop costs: " << (int)ScalarCost << ".\n"); + + bool ForceVectorization = Hints->getForce() == LoopVectorizeHints::FK_Enabled; + // Ignore scalar width, because the user explicitly wants vectorization. 
+ if (ForceVectorization && VF > 1) { + Factor.Width = 2; + Cost = expectedCost(Factor).first / (float)Factor.Width; + } + + VectorizationFactor PotentialFactor = Factor; + for (unsigned i = 2; i <= VF; i *= 2) { + // Notice that the vector loop needs to be executed less times, so + // we need to divide the cost of the vector loops by the width of + // the vector elements. + PotentialFactor.Width = i; + VectorizationCostTy C = expectedCost(PotentialFactor); + float VectorCost = C.first / (float)i; + LLVM_DEBUG(dbgs() << "LV: Vector loop of width " << i + << " costs: " << (int)VectorCost << ".\n"); + if (!C.second && !ForceVectorization) { + LLVM_DEBUG( + dbgs() << "LV: Not considering vector loop of width " << i + << " because it will not generate any vector instructions.\n"); + continue; + } + // Vectorize if the cost is less than or equal to the scalar cost. If the + // cost is equal, it may still be beneficial to vectorize the loop as + // the cost model will be basing its estimate on an VL of 128, where + // for scalable vectorization it may actually execute using more lanes. + if (VectorCost <= Cost) { + Cost = VectorCost; + Factor = PotentialFactor; + } + } + + LLVM_DEBUG(if (ForceVectorization && Factor.Width > 1 && Cost > ScalarCost) dbgs() + << "LV: Vectorization seems to be not beneficial, " + << "but was forced by a user.\n"); + Factor.Cost = Factor.Width * Cost; + if (Factor.Width < 2) + Factor.isFixed = true; + LLVM_DEBUG(dbgs() << "LV: Selecting VF: " << (Factor.isFixed ? "" : "n x ") << + Factor.Width << ".\n"); + return Factor; +} + +std::pair +LoopVectorizationCostModel::getSmallestAndWidestTypes() { + unsigned MinWidth = -1U; + unsigned MaxWidth = 8; + const DataLayout &DL = TheFunction->getParent()->getDataLayout(); + + // For each block. + for (Loop::block_iterator bb = TheLoop->block_begin(), + be = TheLoop->block_end(); + bb != be; ++bb) { + BasicBlock *BB = *bb; + + // For each instruction in the loop. + for (BasicBlock::iterator it = BB->begin(), e = BB->end(); it != e; ++it) { + Type *T = it->getType(); + + // Skip ignored values. + if (ValuesToIgnore.count(&*it)) + continue; + + // Only examine Loads, Stores and PHINodes. + if (!isa(it) && !isa(it) && !isa(it)) + continue; + + // Examine PHI nodes that are reduction variables. Update the type to + // account for the recurrence type. + if (PHINode *PN = dyn_cast(it)) { + if (!Legal->isReductionVariable(PN)) + continue; + RecurrenceDescriptor RdxDesc = (*Legal->getReductionVars())[PN]; + T = RdxDesc.getRecurrenceType(); + } + + // Examine the stored values. + if (StoreInst *ST = dyn_cast(it)) + T = ST->getValueOperand()->getType(); + + // Ignore loaded pointer types and stored pointer types that are not + // vectorizable. + // + // FIXME: The check here attempts to predict whether a load or store will + // be vectorized. We only know this for certain after a VF has + // been selected. Here, we assume that if an access can be + // vectorized, it will be. We should also look at extending this + // optimization to non-pointer types. 
+ // + if (T->isPointerTy() && !isConsecutiveLoadOrStore(&*it) && + !Legal->isAccessInterleaved(&*it) && !Legal->isLegalGatherOrScatter(&*it)) + continue; + + MinWidth = std::min(MinWidth, + (unsigned)DL.getTypeSizeInBits(T->getScalarType())); + MaxWidth = std::max(MaxWidth, + (unsigned)DL.getTypeSizeInBits(T->getScalarType())); + } + } + + return {MinWidth, MaxWidth}; +} + +unsigned +LoopVectorizationCostModel::selectInterleaveCount(bool OptForSize, + VectorizationFactor VF, + unsigned LoopCost) { + + // -- The interleave heuristics -- + // We interleave the loop in order to expose ILP and reduce the loop overhead. + // There are many micro-architectural considerations that we can't predict + // at this level. For example, frontend pressure (on decode or fetch) due to + // code size, or the number and capabilities of the execution ports. + // + // We use the following heuristics to select the interleave count: + // 1. If the code has reductions, then we interleave to break the cross + // iteration dependency. + // 2. If the loop is really small, then we interleave to reduce the loop + // overhead. + // 3. We don't interleave if we think that we will spill registers to memory + // due to the increased register pressure. + + // TODO: revisit this decision but for now it is not worth considering + if (EnableVectorPredication && !VF.isFixed) + return 1; + + // When we optimize for size, we don't interleave. + if (OptForSize) + return 1; + + // We used the distance for the interleave count. + if (Legal->getMaxSafeDepDistBytes() != -1U) + return 1; + + // Do not interleave loops with a relatively small trip count. + unsigned TC = PSE.getSE()->getSmallConstantTripCount(TheLoop); + if (TC > 1 && TC < TinyTripCountInterleaveThreshold) + return 1; + + // Ordered reductions can't be used with interleaving. + for (auto &Rdx : *Legal->getReductionVars()) + if (Rdx.second.isOrdered()) + return 1; + + unsigned TargetNumRegisters = TTI.getNumberOfRegisters(VF.Width > 1); + LLVM_DEBUG(dbgs() << "LV: The target has " << TargetNumRegisters + << " registers\n"); + + if (VF.Width == 1) { + if (ForceTargetNumScalarRegs.getNumOccurrences() > 0) + TargetNumRegisters = ForceTargetNumScalarRegs; + } else { + if (ForceTargetNumVectorRegs.getNumOccurrences() > 0) + TargetNumRegisters = ForceTargetNumVectorRegs; + } + + RegisterUsage R = calculateRegisterUsage({VF.Width})[0]; + // We divide by these constants so assume that we have at least one + // instruction that uses at least one register. + R.MaxLocalUsers = std::max(R.MaxLocalUsers, 1U); + R.NumInstructions = std::max(R.NumInstructions, 1U); + + // We calculate the interleave count using the following formula. + // Subtract the number of loop invariants from the number of available + // registers. These registers are used by all of the interleaved instances. + // Next, divide the remaining registers by the number of registers that is + // required by the loop, in order to estimate how many parallel instances + // fit without causing spills. All of this is rounded down if necessary to be + // a power of two. We want power of two interleave count to simplify any + // addressing operations or alignment considerations. + unsigned IC = PowerOf2Floor((TargetNumRegisters - R.LoopInvariantRegs) / + R.MaxLocalUsers); + + // Don't count the induction variable as interleaved. + if (EnableIndVarRegisterHeur) + IC = PowerOf2Floor((TargetNumRegisters - R.LoopInvariantRegs - 1) / + std::max(1U, (R.MaxLocalUsers - 1))); + + // Clamp the interleave ranges to reasonable counts. 
+ unsigned MaxInterleaveCount = TTI.getMaxInterleaveFactor(VF.Width); + + // Check if the user has overridden the max. + if (VF.Width == 1) { + if (ForceTargetMaxScalarInterleaveFactor.getNumOccurrences() > 0) + MaxInterleaveCount = ForceTargetMaxScalarInterleaveFactor; + } else { + if (ForceTargetMaxVectorInterleaveFactor.getNumOccurrences() > 0) + MaxInterleaveCount = ForceTargetMaxVectorInterleaveFactor; + } + + // If we did not calculate the cost for VF (because the user selected the VF) + // then we calculate the cost of VF here. + if (LoopCost == 0) + LoopCost = expectedCost(VF).first; + + // Clamp the calculated IC to be between the 1 and the max interleave count + // that the target allows. + if (IC > MaxInterleaveCount) + IC = MaxInterleaveCount; + else if (IC < 1) + IC = 1; + + // Interleave if we vectorized this loop and there is a reduction that could + // benefit from interleaving. + if (VF.Width > 1 && Legal->getReductionVars()->size()) { + LLVM_DEBUG(dbgs() << "LV: Interleaving because of reductions.\n"); + return IC; + } + + // Note that if we've already vectorized the loop we will have done the + // runtime check and so interleaving won't require further checks. + bool InterleavingRequiresRuntimePointerCheck = + (VF.Width == 1 && Legal->getRuntimePointerChecking()->Need); + + // We want to interleave small loops in order to reduce the loop overhead and + // potentially expose ILP opportunities. + LLVM_DEBUG(dbgs() << "LV: Loop cost is " << LoopCost << '\n'); + if (!InterleavingRequiresRuntimePointerCheck && LoopCost < SmallLoopCost) { + // We assume that the cost overhead is 1 and we use the cost model + // to estimate the cost of the loop and interleave until the cost of the + // loop overhead is about 5% of the cost of the loop. + unsigned SmallIC = + std::min(IC, (unsigned)PowerOf2Floor(SmallLoopCost / LoopCost)); + + // Interleave until store/load ports (estimated by max interleave count) are + // saturated. + unsigned NumStores = Legal->getNumStores(); + unsigned NumLoads = Legal->getNumLoads(); + unsigned StoresIC = IC / (NumStores ? NumStores : 1); + unsigned LoadsIC = IC / (NumLoads ? NumLoads : 1); + + // If we have a scalar reduction (vector reductions are already dealt with + // by this point), we can increase the critical path length if the loop + // we're interleaving is inside another loop. Limit, by default to 2, so the + // critical path only gets increased by one reduction operation. + if (Legal->getReductionVars()->size() && TheLoop->getLoopDepth() > 1) { + unsigned F = static_cast(MaxNestedScalarReductionIC); + SmallIC = std::min(SmallIC, F); + StoresIC = std::min(StoresIC, F); + LoadsIC = std::min(LoadsIC, F); + } + + if (EnableLoadStoreRuntimeInterleave && + std::max(StoresIC, LoadsIC) > SmallIC) { + LLVM_DEBUG(dbgs() << "LV: Interleaving to saturate store or load ports.\n"); + return std::max(StoresIC, LoadsIC); + } + + LLVM_DEBUG(dbgs() << "LV: Interleaving to reduce branch cost.\n"); + return SmallIC; + } + + // Interleave if this is a large loop (small loops are already dealt with by + // this point) that could benefit from interleaving. 
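(The aggressive-interleaving fallback for larger loops follows below.) To make the interleave-count arithmetic above concrete, here is a self-contained sketch with invented register counts and costs; powerOf2Floor is a local stand-in for llvm::PowerOf2Floor and none of the values come from a real target:

    #include <algorithm>
    #include <cstdio>

    // Local stand-in for llvm::PowerOf2Floor.
    static unsigned powerOf2Floor(unsigned X) {
      unsigned P = 1;
      while (P * 2 <= X)
        P *= 2;
      return X ? P : 0;
    }

    int main() {
      // Assumed example values only:
      unsigned TargetNumRegisters = 32; // vector registers on the target
      unsigned LoopInvariantRegs  = 2;  // live-ins shared by all copies
      unsigned MaxLocalUsers      = 5;  // peak in-loop register pressure
      unsigned MaxInterleaveCount = 8;  // TTI limit for this VF
      unsigned LoopCost           = 4;  // cost-model estimate per iteration
      unsigned SmallLoopCost      = 20; // threshold used by the heuristic

      // Registers left over after the invariants, divided by the per-copy
      // pressure, rounded down to a power of two.
      unsigned IC = powerOf2Floor(
          (TargetNumRegisters - LoopInvariantRegs) / MaxLocalUsers); // 6 -> 4
      IC = std::min(std::max(IC, 1u), MaxInterleaveCount);           // clamp

      // For small loops, interleave only until the body costs ~SmallLoopCost.
      unsigned SmallIC =
          std::min(IC, powerOf2Floor(SmallLoopCost / LoopCost));     // 4
      std::printf("IC = %u, SmallIC = %u\n", IC, SmallIC);
    }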
+ bool HasReductions = (Legal->getReductionVars()->size() > 0); + if (TTI.enableAggressiveInterleaving(HasReductions)) { + LLVM_DEBUG(dbgs() << "LV: Interleaving to expose ILP.\n"); + return IC; + } + + LLVM_DEBUG(dbgs() << "LV: Not Interleaving.\n"); + return 1; +} + +SmallVector +LoopVectorizationCostModel::calculateRegisterUsage(ArrayRef VFs) { + // This function calculates the register usage by measuring the highest number + // of values that are alive at a single location. Obviously, this is a very + // rough estimation. We scan the loop in a topological order in order and + // assign a number to each instruction. We use RPO to ensure that defs are + // met before their users. We assume that each instruction that has in-loop + // users starts an interval. We record every time that an in-loop value is + // used, so we have a list of the first and last occurrences of each + // instruction. Next, we transpose this data structure into a multi map that + // holds the list of intervals that *end* at a specific location. This multi + // map allows us to perform a linear search. We scan the instructions linearly + // and record each time that a new interval starts, by placing it in a set. + // If we find this value in the multi-map then we remove it from the set. + // The max register usage is the maximum size of the set. + // We also search for instructions that are defined outside the loop, but are + // used inside the loop. We need this number separately from the max-interval + // usage number because when we unroll, loop-invariant values do not take + // more register. + LoopBlocksDFS DFS(TheLoop); + DFS.perform(LI); + + RegisterUsage RU; + RU.NumInstructions = 0; + + // Each 'key' in the map opens a new interval. The values + // of the map are the index of the 'last seen' usage of the + // instruction that is the key. + typedef DenseMap IntervalMap; + // Maps instruction to its index. + DenseMap IdxToInstr; + // Marks the end of each interval. + IntervalMap EndPoint; + // Saves the list of instruction indices that are used in the loop. + SmallSet Ends; + // Saves the list of values that are used in the loop but are + // defined outside the loop, such as arguments and constants. + SmallPtrSet LoopInvariants; + + unsigned Index = 0; + for (LoopBlocksDFS::RPOIterator bb = DFS.beginRPO(), be = DFS.endRPO(); + bb != be; ++bb) { + RU.NumInstructions += (*bb)->size(); + for (Instruction &I : **bb) { + IdxToInstr[Index++] = &I; + + // Save the end location of each USE. + for (unsigned i = 0; i < I.getNumOperands(); ++i) { + Value *U = I.getOperand(i); + Instruction *Instr = dyn_cast(U); + + // Ignore non-instruction values such as arguments, constants, etc. + if (!Instr) + continue; + + // If this instruction is outside the loop then record it and continue. + if (!TheLoop->contains(Instr)) { + LoopInvariants.insert(Instr); + continue; + } + + // Overwrite previous end points. + EndPoint[Instr] = Index; + Ends.insert(Instr); + } + } + } + + // Saves the list of intervals that end with the index in 'key'. + typedef SmallVector InstrList; + DenseMap TransposeEnds; + + // Transpose the EndPoints to a list of values that end at each index. + for (IntervalMap::iterator it = EndPoint.begin(), e = EndPoint.end(); it != e; + ++it) + TransposeEnds[it->second].push_back(it->first); + + SmallSet OpenIntervals; + + // Get the size of the widest register. 
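Before the widest-register computation that follows, the interval bookkeeping described at the top of calculateRegisterUsage can be seen in miniature. This is a generic sketch over plain integer "instruction indices" rather than LLVM IR; the tiny example program and all names are invented:

    #include <algorithm>
    #include <cstdio>
    #include <map>
    #include <set>
    #include <vector>

    int main() {
      // Instruction i is just its index in program (RPO) order; Uses[i] lists
      // the earlier instructions whose values instruction i reads.
      std::vector<std::vector<int>> Uses = {
          {},        // 0: defines a value
          {},        // 1: defines a value
          {},        // 2: defines a value
          {0, 1, 2}, // 3: last use of 0, 1 and 2
          {3},       // 4: last use of 3
      };

      // Record the index of the last use of every value (its interval end).
      std::map<int, int> EndPoint;
      for (int I = 0; I < (int)Uses.size(); ++I)
        for (int Op : Uses[I])
          EndPoint[Op] = I; // later uses overwrite earlier ones

      // Transpose: which intervals end at each index.
      std::map<int, std::vector<int>> TransposeEnds;
      for (const auto &KV : EndPoint)
        TransposeEnds[KV.second].push_back(KV.first);

      // Scan linearly, closing intervals before opening the current one; the
      // peak size of OpenIntervals approximates the register pressure.
      std::set<int> OpenIntervals;
      size_t MaxUsers = 0;
      for (int I = 0; I < (int)Uses.size(); ++I) {
        for (int Ended : TransposeEnds[I])
          OpenIntervals.erase(Ended);
        MaxUsers = std::max(MaxUsers, OpenIntervals.size());
        OpenIntervals.insert(I);
      }
      std::printf("max simultaneously live values: %zu\n", MaxUsers); // 2
    }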
+ unsigned MaxSafeDepDist = -1U; + if (Legal->getMaxSafeDepDistBytes() != -1U) + MaxSafeDepDist = Legal->getMaxSafeDepDistBytes() * 8; + unsigned WidestRegister = + std::min(TTI.getRegisterBitWidth(true), MaxSafeDepDist); + const DataLayout &DL = TheFunction->getParent()->getDataLayout(); + + SmallVector RUs(VFs.size()); + SmallVector MaxUsages(VFs.size(), 0); + + LLVM_DEBUG(dbgs() << "LV(REG): Calculating max register usage:\n"); + + // A lambda that gets the register usage for the given type and VF. + auto GetRegUsage = [&DL, WidestRegister](Type *Ty, unsigned VF) { + if (Ty->isTokenTy()) + return 0U; + unsigned TypeSize = DL.getTypeSizeInBits(Ty->getScalarType()); + return std::max(1, VF * TypeSize / WidestRegister); + }; + + for (unsigned int i = 0; i < Index; ++i) { + Instruction *I = IdxToInstr[i]; + // Ignore instructions that are never used within the loop. + if (!Ends.count(I)) + continue; + + // Remove all of the instructions that end at this location. + InstrList &List = TransposeEnds[i]; + for (unsigned int j = 0, e = List.size(); j < e; ++j) + OpenIntervals.erase(List[j]); + + // Skip ignored values. + if (ValuesToIgnore.count(I)) + continue; + + // For each VF find the maximum usage of registers. + for (unsigned j = 0, e = VFs.size(); j < e; ++j) { + if (VFs[j] == 1) { + MaxUsages[j] = std::max(MaxUsages[j], OpenIntervals.size()); + continue; + } + + // Count the number of live intervals. + unsigned RegUsage = 0; + for (auto Inst : OpenIntervals) { + // Skip ignored values for VF > 1. + if (VecValuesToIgnore.count(Inst)) + continue; + RegUsage += GetRegUsage(Inst->getType(), VFs[j]); + } + MaxUsages[j] = std::max(MaxUsages[j], RegUsage); + } + + LLVM_DEBUG(dbgs() << "LV(REG): At #" << i << " Interval # " + << OpenIntervals.size() << '\n'); + + // Add the current instruction to the list of open intervals. + OpenIntervals.insert(I); + } + + for (unsigned i = 0, e = VFs.size(); i < e; ++i) { + unsigned Invariant = 0; + if (VFs[i] == 1) + Invariant = LoopInvariants.size(); + else { + for (auto Inst : LoopInvariants) + Invariant += GetRegUsage(Inst->getType(), VFs[i]); + } + + LLVM_DEBUG(dbgs() << "LV(REG): VF = " << VFs[i] << '\n'); + LLVM_DEBUG(dbgs() << "LV(REG): Found max usage: " << MaxUsages[i] << '\n'); + LLVM_DEBUG(dbgs() << "LV(REG): Found invariant usage: " << Invariant << '\n'); + LLVM_DEBUG(dbgs() << "LV(REG): LoopSize: " << RU.NumInstructions << '\n'); + + RU.LoopInvariantRegs = Invariant; + RU.MaxLocalUsers = MaxUsages[i]; + RUs[i] = RU; + } + + return RUs; +} + +LoopVectorizationCostModel::VectorizationCostTy +LoopVectorizationCostModel::expectedCost(VectorizationFactor VF) { + VectorizationCostTy Cost; + + // For each block. + for (Loop::block_iterator bb = TheLoop->block_begin(), + be = TheLoop->block_end(); + bb != be; ++bb) { + VectorizationCostTy BlockCost; + BasicBlock *BB = *bb; + + // For each instruction in the old loop. + for (BasicBlock::iterator it = BB->begin(), e = BB->end(); it != e; ++it) { + // Skip dbg intrinsics. + if (isa(it)) + continue; + + // Skip ignored values. + if (ValuesToIgnore.count(&*it)) + continue; + + VectorizationCostTy C = getInstructionCost(&*it, VF); + + // Check if we should override the cost. + if (ForceTargetInstructionCost.getNumOccurrences() > 0) + C.first = ForceTargetInstructionCost; + + BlockCost.first += C.first; + BlockCost.second |= C.second; + LLVM_DEBUG(dbgs() << "LV: Found an estimated cost of " << C.first + << " for VF " << (VF.isFixed ? 
"" : "n x ") << VF.Width + << " For instruction: " << *it << '\n'); + } + + // We assume that if-converted blocks have a 50% chance of being executed. + // When the code is scalar then some of the blocks are avoided due to CF. + // When the code is vectorized we execute all code paths. + if (VF.Width == 1 && Legal->blockNeedsPredication(*bb)) + BlockCost.first /= 2; + + Cost.first += BlockCost.first; + Cost.second |= BlockCost.second; + } + + return Cost; +} + +/// \brief Check if the load/store instruction \p I may be translated into +/// gather/scatter during vectorization. +/// +/// Pointer \p Ptr specifies address in memory for the given scalar memory +/// instruction. We need it to retrieve data type. +/// Using gather/scatter is possible when it is supported by target. +/* +static bool isGatherOrScatterLegal(Instruction *I, Value *Ptr, + LoopVectorizationLegality *Legal) { + Type *DataTy = cast(Ptr->getType())->getElementType(); + return (isa(I) && Legal->isLegalMaskedGather(DataTy)) || + (isa(I) && Legal->isLegalMaskedScatter(DataTy)); +} + */ + +/// \brief Check whether the address computation for a non-consecutive memory +/// access looks like an unlikely candidate for being merged into the indexing +/// mode. +/// +/// We look for a GEP which has one index that is an induction variable and all +/// other indices are loop invariant. If the stride of this access is also +/// within a small bound we decide that this address computation can likely be +/// merged into the addressing mode. +/// In all other cases, we identify the address computation as complex. +/* +static bool isLikelyComplexAddressComputation(Value *Ptr, + LoopVectorizationLegality *Legal, + ScalarEvolution *SE, + const Loop *TheLoop) { + GetElementPtrInst *Gep = dyn_cast(Ptr); + if (!Gep) + return true; + + // We are looking for a gep with all loop invariant indices except for one + // which should be an induction variable. + unsigned NumOperands = Gep->getNumOperands(); + for (unsigned i = 1; i < NumOperands; ++i) { + Value *Opd = Gep->getOperand(i); + if (!SE->isLoopInvariant(SE->getSCEV(Opd), TheLoop) && + !Legal->isInductionVariable(Opd)) + return true; + } + + // Now we know we have a GEP ptr, %inv, %ind, %inv. Make sure that the step + // can likely be merged into the address computation. + unsigned MaxMergeDistance = 64; + + const SCEVAddRecExpr *AddRec = dyn_cast(SE->getSCEV(Ptr)); + if (!AddRec) + return true; + + // Check the step is constant. + const SCEV *Step = AddRec->getStepRecurrence(*SE); + // Calculate the pointer stride and check if it is consecutive. + const SCEVConstant *C = dyn_cast(Step); + if (!C) + return true; + + const APInt &APStepVal = C->getAPInt(); + + // Huge step value - give up. 
+ if (APStepVal.getBitWidth() > 64) + return true; + + int64_t StepVal = APStepVal.getSExtValue(); + + return StepVal > MaxMergeDistance; +} + */ + +static bool isStrideMul(Instruction *I, LoopVectorizationLegality *Legal) { + return Legal->hasStride(I->getOperand(0)) || + Legal->hasStride(I->getOperand(1)); +} + +// Given a Chain +// A -> B -> Z, +// where: +// A = s/zext +// B = add +// C = trunc +// Check this is one of +// s/zext(i32) -> add -> trunc(valtype) +static bool isPartOfPromotedAdd(Instruction *I, Type **OrigType) { + Instruction *TruncOp = I; + + // If I is one of step A, find step C + if ((I->getOpcode() == Instruction::ZExt || + I->getOpcode() == Instruction::SExt)) { + // Confirm that s/zext is *only* used for the add + for(int K=0; K<2; ++K) { + if (!TruncOp->hasOneUse()) + return false; + TruncOp = dyn_cast(TruncOp->user_back()); + } + } + // If I is one of step B, find step C + else if ((I->getOpcode() == Instruction::Add)) { + if (!I->hasOneUse()) + return false; + TruncOp = I->user_back(); + } + + // Check if I is one of step C + if (TruncOp->getOpcode() != Instruction::Trunc) + return false; + + if (Instruction *Opnd = dyn_cast(TruncOp->getOperand(0))) { + if (TruncOp->getOpcode() != Instruction::Trunc || + Opnd->getOpcode() != Instruction::Add || !Opnd->hasNUses(1)) + return false; + + // Check each operand to the 'add' + unsigned cnt = 0; + for (Value *V : Opnd->operands()) { + if (const Instruction *AddOpnd = dyn_cast(V)) { + if (AddOpnd->getOpcode() != Instruction::ZExt && + AddOpnd->getOpcode() != Instruction::SExt) + break; + + if (!AddOpnd->getType()->isIntegerTy(32)) + break; + + if ( AddOpnd->getOperand(0)->getType() != TruncOp->getType() || + !AddOpnd->hasNUses(1)) + break; + } + cnt++; + } + + if (cnt == Opnd->getNumOperands()) { + if (OrigType) + *OrigType = TruncOp->getType(); + return true; + } + } + + return false; +} + +static MemAccessInfo calculateMemAccessInfo(Instruction *I, + Type *VectorTy, + LoopVectorizationLegality *Legal, + ScalarEvolution *SE) { + const DataLayout &DL = I->getModule()->getDataLayout(); + + // Get pointer operand + Value *Ptr = nullptr; + if (auto *LI = dyn_cast(I)) + Ptr = LI->getPointerOperand(); + if (auto *SI = dyn_cast(I)) + Ptr = SI->getPointerOperand(); + + assert (Ptr && "Could not get pointer operand from instruction"); + + // Check for uniform access (scalar load + splat) + if (Legal->isUniform(Ptr)) + return MemAccessInfo::getUniformInfo(); + + // Get whether it is a predicated memory operation + bool IsMasked = Legal->isMaskRequired(I); + + // Try to find the stride of the pointer expression + if (auto *SAR = dyn_cast(SE->getSCEV(Ptr))) { + const SCEV *StepRecurrence = SAR->getStepRecurrence(*SE); + if (auto *StrideV = dyn_cast(StepRecurrence)) { + // Get the element size + unsigned VectorElementSize = + DL.getTypeStoreSize(VectorTy) / VectorTy->getVectorNumElements(); + + // Normalize Stride from bytes to number of elements + int Stride = + StrideV->getValue()->getSExtValue() / ((int64_t)VectorElementSize); + return MemAccessInfo::getStridedInfo(Stride, Stride < 0, IsMasked); + } else { + // Unknown stride is a subset of gather/scatter + return MemAccessInfo::getNonStridedInfo(StepRecurrence->getType(), + IsMasked); + } + } + + // If this is a scatter operation try to find the type of the offset, + // if applicable, e.g. 
A[i] = B[C[i]] + // ^^^^ get type of C[i] + Type *IdxTy = nullptr; + bool IsSigned = true; + if (auto *Gep = dyn_cast(Ptr)) { + for (unsigned Op=0; Op < Gep->getNumOperands(); ++Op) { + Value *Opnd = Gep->getOperand(Op); + if (Legal->isUniform(Opnd)) { + continue; + } + + // If there are multiple non-loop invariant indices + // in this GEP, fall back to the worst case below. + if (IdxTy != nullptr) { + IdxTy = nullptr; + break; + } + + // If type is promoted, see if we can use smaller type + IdxTy = Opnd->getType(); + if (auto *Ext = dyn_cast(Opnd)) { + if (Ext->isIntegerCast()) + IdxTy = Ext->getSrcTy(); + if (isa(Ext)) + IsSigned = false; + } + } + } + + // Worst case scenario, assume pointer size + if (!IdxTy) + IdxTy = DL.getIntPtrType(Ptr->getType()); + + return MemAccessInfo::getNonStridedInfo(IdxTy, IsMasked, IsSigned); +} + +LoopVectorizationCostModel::VectorizationCostTy +LoopVectorizationCostModel::getInstructionCost(Instruction *I, + VectorizationFactor VF) { + // If we know that this instruction will remain uniform, check the cost of + // the scalar version. + if (Legal->isUniformAfterVectorization(I)) + VF.Width = 1; + + Type *VectorTy; + unsigned C = getInstructionCost(I, VF, VectorTy); + + bool TypeNotScalarized = + VF.Width > 1 && !VectorTy->isVoidTy() && + TTI.getNumberOfParts(VectorTy) < VF.Width; + return VectorizationCostTy(C, TypeNotScalarized); +} + +unsigned LoopVectorizationCostModel::getInstructionCost(Instruction *I, + VectorizationFactor VF, + Type *&VectorTy) { + Type *RetTy = I->getType(); + if (VF.Width > 1 && MinBWs.count(I)) + RetTy = IntegerType::get(RetTy->getContext(), MinBWs[I]); + VectorTy = ToVectorTy(RetTy, VF); + auto SE = PSE.getSE(); + + // TODO: We need to estimate the cost of intrinsic calls. + switch (I->getOpcode()) { + case Instruction::GetElementPtr: + // We mark this instruction as zero-cost because the cost of GEPs in + // vectorized code depends on whether the corresponding memory instruction + // is scalarized or not. Therefore, we handle GEPs with the memory + // instruction cost. + return 0; + case Instruction::Br: { + return TTI.getCFInstrCost(I->getOpcode()); + } + case Instruction::PHI: { + auto *Phi = cast(I); + + // First-order recurrences are replaced by vector shuffles inside the loop. + // TODO: Does getShuffleCost need special handling for scalable vectors? + if (VF.Width > 1 && Legal->isFirstOrderRecurrence(Phi)) + return TTI.getShuffleCost(TargetTransformInfo::SK_ExtractSubvector, + VectorTy, VF.Width - 1, VectorTy); + + // TODO: IF-converted IFs become selects. + return 0; + } + case Instruction::Add: + case Instruction::FAdd: + case Instruction::Sub: + case Instruction::FSub: + case Instruction::Mul: + case Instruction::FMul: + case Instruction::UDiv: + case Instruction::SDiv: + case Instruction::FDiv: + case Instruction::URem: + case Instruction::SRem: + case Instruction::FRem: + case Instruction::Shl: + case Instruction::LShr: + case Instruction::AShr: + case Instruction::And: + case Instruction::Or: + case Instruction::Xor: { + // Since we will replace the stride by 1 the multiplication should go away. + if (I->getOpcode() == Instruction::Mul && isStrideMul(I, Legal)) + return 0; + // Certain instructions can be cheaper to vectorize if they have a constant + // second vector operand. One example of this are shifts on x86. 
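(The operand-kind classification continues below.) Going back to calculateMemAccessInfo above: the stride it derives is normalized from bytes to elements, and the sign decides whether the access is treated as reversed. A tiny standalone illustration with made-up values:

    #include <cstdint>
    #include <cstdio>

    int main() {
      // Assumed example: a loop walking an array of 32-bit ints backwards,
      // so scalar evolution reports a step of -4 bytes per iteration.
      int64_t StepInBytes = -4;
      uint64_t VectorStoreSizeInBytes = 16; // e.g. a <4 x i32> vector
      uint64_t NumElements = 4;

      // Element size as in the patch: vector store size / element count.
      int64_t VectorElementSize =
          (int64_t)(VectorStoreSizeInBytes / NumElements);  // 4 bytes

      // Normalize the stride to whole elements; a negative result marks a
      // reversed (descending) access.
      int64_t Stride = StepInBytes / VectorElementSize;     // -1
      bool IsReversed = Stride < 0;
      std::printf("stride = %lld elements, reversed = %d\n",
                  (long long)Stride, (int)IsReversed);
    }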
+ TargetTransformInfo::OperandValueKind Op1VK = + TargetTransformInfo::OK_AnyValue; + TargetTransformInfo::OperandValueKind Op2VK = + TargetTransformInfo::OK_AnyValue; + TargetTransformInfo::OperandValueProperties Op1VP = + TargetTransformInfo::OP_None; + TargetTransformInfo::OperandValueProperties Op2VP = + TargetTransformInfo::OP_None; + Value *Op2 = I->getOperand(1); + + // Check for a splat of a constant or for a non uniform vector of constants. + if (isa(Op2)) { + ConstantInt *CInt = cast(Op2); + if (CInt && CInt->getValue().isPowerOf2()) + Op2VP = TargetTransformInfo::OP_PowerOf2; + Op2VK = TargetTransformInfo::OK_UniformConstantValue; + } else if (isa(Op2) || isa(Op2)) { + Op2VK = TargetTransformInfo::OK_NonUniformConstantValue; + Constant *SplatValue = cast(Op2)->getSplatValue(); + if (SplatValue) { + ConstantInt *CInt = dyn_cast(SplatValue); + if (CInt && CInt->getValue().isPowerOf2()) + Op2VP = TargetTransformInfo::OP_PowerOf2; + Op2VK = TargetTransformInfo::OK_UniformConstantValue; + } + } + + // Note: When we find a s/zext_to_i32->add->trunc_to_origtype + // chain, we ask the target if it has an add for the original + // type. This is not allowed in C, so the target should ensure + // that the instruction does the sign/zero conversion in 'int'. + Type *OrigType = nullptr; + if (isPartOfPromotedAdd(I, &OrigType)) + VectorTy = VectorType::get(OrigType, VF.Width, !VF.isFixed); + + return TTI.getArithmeticInstrCost(I->getOpcode(), VectorTy, Op1VK, Op2VK, + Op1VP, Op2VP); + } + case Instruction::Select: { + SelectInst *SI = cast(I); + const SCEV *CondSCEV = SE->getSCEV(SI->getCondition()); + bool ScalarCond = (SE->isLoopInvariant(CondSCEV, TheLoop)); + Type *CondTy = SI->getCondition()->getType(); + if (!ScalarCond) + CondTy = VectorType::get(CondTy, VF.Width, !VF.isFixed); + + return TTI.getCmpSelInstrCost(I->getOpcode(), VectorTy, CondTy); + } + case Instruction::ICmp: + case Instruction::FCmp: { + Type *ValTy = I->getOperand(0)->getType(); + Instruction *Op0AsInstruction = dyn_cast(I->getOperand(0)); + auto It = MinBWs.find(Op0AsInstruction); + if (VF.Width > 1 && It != MinBWs.end()) + ValTy = IntegerType::get(ValTy->getContext(), It->second); + VectorTy = ToVectorTy(ValTy, VF); + return TTI.getCmpSelInstrCost(I->getOpcode(), VectorTy); + } + case Instruction::Store: + case Instruction::Load: { + StoreInst *SI = dyn_cast(I); + LoadInst *LI = dyn_cast(I); + Type *ValTy = (SI ? SI->getValueOperand()->getType() : LI->getType()); + VectorTy = ToVectorTy(ValTy, VF); + + unsigned Alignment = SI ? SI->getAlignment() : LI->getAlignment(); + unsigned AS = + SI ? SI->getPointerAddressSpace() : LI->getPointerAddressSpace(); + Value *Ptr = SI ? SI->getPointerOperand() : LI->getPointerOperand(); + // We add the cost of address computation here instead of with the gep + // instruction because only here we know whether the operation is + // scalarized. + if (VF.Width == 1) + return TTI.getAddressComputationCost(VectorTy) + + TTI.getMemoryOpCost(I->getOpcode(), VectorTy, Alignment, AS); + + if (LI && Legal->isUniform(Ptr)) { + // Scalar load + broadcast + unsigned Cost = TTI.getAddressComputationCost(ValTy->getScalarType()); + Cost += TTI.getMemoryOpCost(I->getOpcode(), ValTy->getScalarType(), + Alignment, AS); + return Cost + + TTI.getShuffleCost(TargetTransformInfo::SK_Broadcast, ValTy); + } + + // For an interleaved access, calculate the total cost of the whole + // interleave group. 
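Before the interleaved-group costing below, the uniform-pointer branch just above can be restated in a few lines: a loop-invariant address is loaded once per vector iteration and splatted, so its cost is one address computation, one scalar load, and one broadcast shuffle. The sketch uses made-up unit costs in place of the TTI queries:

    #include <cstdio>

    // Invented per-operation costs standing in for the TTI queries above.
    struct FakeCosts {
      unsigned ScalarAddressComputation = 0;
      unsigned ScalarLoad = 1;
      unsigned BroadcastShuffle = 1;
    };

    // Mirrors the "scalar load + broadcast" branch: every vector lane reads
    // the same location, so only one scalar load and one splat are paid for.
    static unsigned uniformLoadCost(const FakeCosts &C) {
      return C.ScalarAddressComputation + C.ScalarLoad + C.BroadcastShuffle;
    }

    int main() {
      FakeCosts C;
      std::printf("uniform load cost = %u\n", uniformLoadCost(C)); // 2
    }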
+ if (Legal->isAccessInterleaved(I)) { + auto Group = Legal->getInterleavedAccessGroup(I); + assert(Group && "Fail to get an interleaved access group."); + + // Only calculate the cost once at the insert position. + if (Group->getInsertPos() != I) + return 0; + + unsigned InterleaveFactor = Group->getFactor(); + Type *WideVecTy = + VectorType::get(VectorTy->getVectorElementType(), + VectorTy->getVectorNumElements() * InterleaveFactor, + !VF.isFixed); + + // Holds the indices of existing members in an interleaved load group. + // An interleaved store group doesn't need this as it doesn't allow gaps. + SmallVector Indices; + if (LI) { + for (unsigned i = 0; i < InterleaveFactor; i++) + if (Group->getMember(i)) + Indices.push_back(i); + } + + // Calculate the cost of the whole interleaved group. + unsigned Cost = TTI.getInterleavedMemoryOpCost( + I->getOpcode(), WideVecTy, Group->getFactor(), Indices, + Group->getAlignment(), AS); + + if (Group->isReverse()) + Cost += + Group->getNumMembers() * + TTI.getShuffleCost(TargetTransformInfo::SK_Reverse, VectorTy, 0); + + // FIXME: The interleaved load group with a huge gap could be even more + // expensive than scalar operations. Then we could ignore such group and + // use scalar operations instead. + return Cost; + } + + const DataLayout &DL = I->getModule()->getDataLayout(); + unsigned ScalarAllocatedSize = DL.getTypeAllocSize(ValTy); + unsigned VectorElementSize = DL.getTypeStoreSize(VectorTy) / VF.Width; + + // Get information about vector memory access + MemAccessInfo MAI = calculateMemAccessInfo(I, VectorTy, Legal, SE); + + // If there are no vector memory operations to support the stride, + // get the cost for scalarizing the operation. + if (!TTI.hasVectorMemoryOp(I->getOpcode(), VectorTy, MAI) || + ScalarAllocatedSize != VectorElementSize) { + // Get cost of scalarizing +// bool IsComplexComputation = +// isLikelyComplexAddressComputation(Ptr, Legal, SE, TheLoop); + unsigned Cost = 0; + // The cost of extracting from the value vector and pointer vector. + Type *PtrTy = ToVectorTy(Ptr->getType(), VF); + for (unsigned i = 0; i < VF.Width; ++i) { + // The cost of extracting the pointer operand. + Cost += TTI.getVectorInstrCost(Instruction::ExtractElement, PtrTy, i); + // In case of STORE, the cost of ExtractElement from the vector. + // In case of LOAD, the cost of InsertElement into the returned + // vector. + Cost += TTI.getVectorInstrCost(SI ? Instruction::ExtractElement + : Instruction::InsertElement, + VectorTy, i); + } + + // The cost of the scalar loads/stores. + /* TODO: Replace this with community code? 
+ Cost += VF.Width * + TTI.getAddressComputationCost(PtrTy, IsComplexComputation); + */ + Cost += VF.Width * + TTI.getMemoryOpCost(I->getOpcode(), ValTy->getScalarType(), + Alignment, AS); + return Cost; + } + + unsigned Cost = TTI.getAddressComputationCost(VectorTy); + Cost += TTI.getVectorMemoryOpCost(I->getOpcode(), VectorTy, Ptr, + Alignment, AS, MAI, I); + + if (MAI.isStrided() && MAI.isReversed()) + Cost += + TTI.getShuffleCost(TargetTransformInfo::SK_Reverse, VectorTy, 0); + else if (MAI.isUniform()) + Cost += TTI.getShuffleCost(TargetTransformInfo::SK_Broadcast, + VectorTy, 0); + return Cost; + } + case Instruction::ZExt: + case Instruction::SExt: + case Instruction::FPToUI: + case Instruction::FPToSI: + case Instruction::FPExt: + case Instruction::PtrToInt: + case Instruction::IntToPtr: + case Instruction::SIToFP: + case Instruction::UIToFP: + case Instruction::Trunc: + case Instruction::FPTrunc: + case Instruction::BitCast: { + // We optimize the truncation of induction variable. + // The cost of these is the same as the scalar operation. + if (I->getOpcode() == Instruction::Trunc && + Legal->isInductionVariable(I->getOperand(0))) + return TTI.getCastInstrCost(I->getOpcode(), I->getType(), + I->getOperand(0)->getType()); +// TODO: determine if still useful, deleting isPartOfPromotedAdd if not +// // Don't count these +// if (isPartOfPromotedAdd(I, nullptr)) +// return 0; +// +// Type *SrcVecTy = ToVectorTy(I->getOperand(0)->getType(), VF); + + Type *SrcScalarTy = I->getOperand(0)->getType(); + Type *SrcVecTy = ToVectorTy(SrcScalarTy, VF); + if (VF.Width > 1 && MinBWs.count(I)) { + // This cast is going to be shrunk. This may remove the cast or it might + // turn it into slightly different cast. For example, if MinBW == 16, + // "zext i8 %1 to i32" becomes "zext i8 %1 to i16". + // + // Calculate the modified src and dest types. + Type *MinVecTy = VectorTy; + if (I->getOpcode() == Instruction::Trunc) { + SrcVecTy = smallestIntegerVectorType(SrcVecTy, MinVecTy); + VectorTy = + largestIntegerVectorType(ToVectorTy(I->getType(), VF), MinVecTy); + } else if (I->getOpcode() == Instruction::ZExt || + I->getOpcode() == Instruction::SExt) { + SrcVecTy = largestIntegerVectorType(SrcVecTy, MinVecTy); + VectorTy = + smallestIntegerVectorType(ToVectorTy(I->getType(), VF), MinVecTy); + } + } + + return TTI.getCastInstrCost(I->getOpcode(), VectorTy, SrcVecTy); + } + case Instruction::Call: { + bool NeedToScalarize; + CallInst *CI = cast(I); + unsigned CallCost = + getVectorCallCost(CI, VF, TTI, TLI, *Legal, NeedToScalarize); + if (getVectorIntrinsicIDForCall(CI, TLI)) + return std::min(CallCost, getVectorIntrinsicCost(CI, VF, TTI, TLI)); + return CallCost; + } + default: { + // We are scalarizing the instruction. Return the cost of the scalar + // instruction, plus the cost of insert and extract into vector + // elements, times the vector width. + unsigned Cost = 0; + + if (!RetTy->isVoidTy() && VF.Width != 1) { + unsigned InsCost = + TTI.getVectorInstrCost(Instruction::InsertElement, + VectorTy); + unsigned ExtCost = + TTI.getVectorInstrCost(Instruction::ExtractElement, + VectorTy); + + // The cost of inserting the results plus extracting each one of the + // operands. + Cost += VF.Width * (InsCost + ExtCost * I->getNumOperands()); + } + + // The cost of executing VF copies of the scalar instruction. This opcode + // is unknown. Assume that it is the same as 'mul'. + Cost += VF.Width * TTI.getArithmeticInstrCost(Instruction::Mul, VectorTy); + return Cost; + } + } // end of switch. 
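As a closing note on the default case above: the scalarization estimate is plain arithmetic over insert/extract costs, so a worked example is easy to reproduce by hand. The numbers below are invented and only mirror the formula Cost = VF * (InsCost + ExtCost * NumOperands) plus VF copies of a 'mul'-like scalar cost:

    #include <cstdio>

    int main() {
      // Assumed example values (not real TTI numbers).
      unsigned VF = 4;           // vector width
      unsigned NumOperands = 2;  // operands of the scalarized instruction
      unsigned InsCost = 1;      // insertelement of each produced scalar
      unsigned ExtCost = 1;      // extractelement of each consumed operand
      unsigned ScalarOpCost = 1; // per-copy cost, approximated as a 'mul'

      // Inserting VF results plus extracting VF copies of every operand ...
      unsigned Cost = VF * (InsCost + ExtCost * NumOperands);
      // ... plus executing VF copies of the scalar instruction itself.
      Cost += VF * ScalarOpCost;
      std::printf("estimated scalarization cost = %u\n", Cost); // 16
    }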
+} + +char SVELoopVectorize::ID = 0; +static const char lv_name[] = "SVE Loop Vectorization"; +INITIALIZE_PASS_BEGIN(SVELoopVectorize, LV_NAME, lv_name, false, false) +INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(BasicAAWrapperPass) +INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) +INITIALIZE_PASS_DEPENDENCY(GlobalsAAWrapperPass) +INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) +INITIALIZE_PASS_DEPENDENCY(BlockFrequencyInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass) +INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(LoopSimplify) +INITIALIZE_PASS_DEPENDENCY(LoopAccessLegacyAnalysis) +INITIALIZE_PASS_DEPENDENCY(DemandedBitsWrapperPass) +INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass) +INITIALIZE_PASS_END(SVELoopVectorize, LV_NAME, lv_name, false, false) + +namespace llvm { +Pass *createSVELoopVectorizePass(bool NoUnrolling, bool AlwaysVectorize) { + return new SVELoopVectorize(NoUnrolling, AlwaysVectorize); +} +} + +bool LoopVectorizationCostModel::isConsecutiveLoadOrStore(Instruction *Inst) { + // Check for a store. + if (StoreInst *ST = dyn_cast(Inst)) + return Legal->isConsecutivePtr(ST->getPointerOperand()) != 0; + + // Check for a load. + if (LoadInst *LI = dyn_cast(Inst)) + return Legal->isConsecutivePtr(LI->getPointerOperand()) != 0; + + return false; +} + +void LoopVectorizationCostModel::collectValuesToIgnore() { + // Ignore ephemeral values. + CodeMetrics::collectEphemeralValues(TheLoop, AC, ValuesToIgnore); + + // Ignore type-promoting instructions we identified during reduction + // detection. + for (auto &Reduction : *Legal->getReductionVars()) { + RecurrenceDescriptor &RedDes = Reduction.second; + SmallPtrSetImpl &Casts = RedDes.getCastInsts(); + VecValuesToIgnore.insert(Casts.begin(), Casts.end()); + } + + // Ignore induction phis that are only used in either GetElementPtr or ICmp + // instruction to exit loop. Induction variables usually have large types and + // can have big impact when estimating register usage. + // This is for when VF > 1. + for (auto &Induction : *Legal->getInductionVars()) { + auto *PN = Induction.first; + auto *UpdateV = PN->getIncomingValueForBlock(TheLoop->getLoopLatch()); + + // Check that the PHI is only used by the induction increment (UpdateV) or + // by GEPs. Then check that UpdateV is only used by a compare instruction or + // the loop header PHI. + // FIXME: Need precise def-use analysis to determine if this instruction + // variable will be vectorized. + if (std::all_of(PN->user_begin(), PN->user_end(), + [&](const User *U) -> bool { + return U == UpdateV || isa(U); + }) && + std::all_of(UpdateV->user_begin(), UpdateV->user_end(), + [&](const User *U) -> bool { + return U == PN || isa(U); + })) { + VecValuesToIgnore.insert(PN); + VecValuesToIgnore.insert(UpdateV); + } + } + + // Ignore instructions that will not be vectorized. + // This is for when VF > 1. + for (auto bb = TheLoop->block_begin(), be = TheLoop->block_end(); bb != be; + ++bb) { + for (auto &Inst : **bb) { + switch (Inst.getOpcode()) + case Instruction::GetElementPtr: { + // Ignore GEP if its last operand is an induction variable so that it is + // a consecutive load/store and won't be vectorized as scatter/gather + // pattern. 
+ + GetElementPtrInst *Gep = cast(&Inst); + unsigned NumOperands = Gep->getNumOperands(); + unsigned InductionOperand = getGEPInductionOperand(Gep); + bool GepToIgnore = true; + + // Check that all of the gep indices are uniform except for the + // induction operand. + for (unsigned i = 0; i != NumOperands; ++i) { + if (i != InductionOperand && + !PSE.getSE()->isLoopInvariant(PSE.getSCEV(Gep->getOperand(i)), + TheLoop)) { + GepToIgnore = false; + break; + } + } + + if (GepToIgnore) + VecValuesToIgnore.insert(&Inst); + break; + } + } + } +} + +void InnerLoopUnroller::scalarizeInstruction(Instruction *Instr, + bool IfPredicateStore) { + assert(!Instr->getType()->isAggregateType() && "Can't handle vectors"); + // Holds vector parameters or scalars, in case of uniform vals. + SmallVector Params; + + setDebugLocFromInst(Builder, Instr); + + // Find all of the vectorized parameters. + for (unsigned op = 0, e = Instr->getNumOperands(); op != e; ++op) { + Value *SrcOp = Instr->getOperand(op); + + // If we are accessing the old induction variable, use the new one. + if (SrcOp == OldInduction) { + Params.push_back(getVectorValue(SrcOp)); + continue; + } + + // Try using previously calculated values. + Instruction *SrcInst = dyn_cast(SrcOp); + + // If the src is an instruction that appeared earlier in the basic block + // then it should already be vectorized. + if (SrcInst && OrigLoop->contains(SrcInst)) { + assert(WidenMap.has(SrcInst) && "Source operand is unavailable"); + // The parameter is a vector value from earlier. + Params.push_back(WidenMap.get(SrcInst)); + } else { + // The parameter is a scalar from outside the loop. Maybe even a constant. + VectorParts Scalars; + Scalars.append(UF, SrcOp); + Params.push_back(Scalars); + } + } + + assert(Params.size() == Instr->getNumOperands() && + "Invalid number of operands"); + + // Does this instruction return a value ? + bool IsVoidRetTy = Instr->getType()->isVoidTy(); + + Value *UndefVec = IsVoidRetTy ? nullptr : UndefValue::get(Instr->getType()); + // Create a new entry in the WidenMap and initialize it to Undef or Null. + VectorParts &VecResults = WidenMap.splat(Instr, UndefVec); + + VectorParts Cond; + if (IfPredicateStore) { + assert(Instr->getParent()->getSinglePredecessor() && + "Only support single predecessor blocks"); + Cond = createEdgeMask(Instr->getParent()->getSinglePredecessor(), + Instr->getParent()); + } + + // For each vector unroll 'part': + for (unsigned Part = 0; Part < UF; ++Part) { + // For each scalar that we create: + + // Start an "if (pred) a[i] = ..." block. + Value *Cmp = nullptr; + if (IfPredicateStore) { + if (Cond[Part]->getType()->isVectorTy()) + Cond[Part] = + Builder.CreateExtractElement(Cond[Part], Builder.getInt32(0)); + Cmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, Cond[Part], + ConstantInt::get(Cond[Part]->getType(), 1)); + } + + Instruction *Cloned = Instr->clone(); + if (!IsVoidRetTy) + Cloned->setName(Instr->getName() + ".cloned"); + // Replace the operands of the cloned instructions with extracted scalars. + for (unsigned op = 0, e = Instr->getNumOperands(); op != e; ++op) { + Value *Op = Params[op][Part]; + Cloned->setOperand(op, Op); + } + + // Place the cloned scalar in the new loop. + Builder.Insert(Cloned); + + // If we just cloned a new assumption, add it the assumption cache. 
+ if (auto *II = dyn_cast(Cloned)) + if (II->getIntrinsicID() == Intrinsic::assume) + AC->registerAssumption(II); + + // If the original scalar returns a value we need to place it in a vector + // so that future users will be able to use it. + if (!IsVoidRetTy) + VecResults[Part] = Cloned; + + // End if-block. + if (IfPredicateStore) + PredicatedStores.push_back(std::make_pair(cast(Cloned), Cmp)); + } +} + +void InnerLoopUnroller::vectorizeMemoryInstruction(Instruction *Instr) { + assert(!Legal->isMaskRequired(Instr) && + "Unroller does not support masked operations!"); + StoreInst *SI = dyn_cast(Instr); + bool IfPredicateStore = (SI && Legal->blockNeedsPredication(SI->getParent())); + + return scalarizeInstruction(Instr, IfPredicateStore); +} + +Value *InnerLoopUnroller::reverseVector(Value *Vec) { return Vec; } + +Value *InnerLoopUnroller::getBroadcastInstrs(Value *V) { return V; } + +Value *InnerLoopUnroller::getStepVector(Value *Val, Value *Start, + const SCEV *StepSCEV, + Instruction::BinaryOps BinOp) { + const DataLayout &DL = OrigLoop->getHeader()->getModule()->getDataLayout(); + SCEVExpander Exp(*PSE.getSE(), DL, "induction"); + Value *StepValue = Exp.expandCodeFor(StepSCEV, StepSCEV->getType(), + &*Builder.GetInsertPoint()); + return getStepVector(Val, Start, StepValue, BinOp); +} + +Value *InnerLoopUnroller::getStepVector(Value *Val, int Start, Value *Step, + Instruction::BinaryOps BinOp) { + // When unrolling and the VF is 1, we only need to add a simple scalar. + Type *Ty = Val->getType()->getScalarType(); + return getStepVector(Val, ConstantInt::get(Ty, Start), Step, BinOp); +} + +Value *InnerLoopUnroller::getStepVector(Value *Val, Value* Start, Value *Step, + Instruction::BinaryOps BinOp) { + // When unrolling and the VF is 1, we only need to add a simple scalar. + assert(!Val->getType()->isVectorTy() && "Val must be a scalar"); + if (Val->getType()->isFloatingPointTy()) { + if (Start->getType()->isIntegerTy()) + Start = Builder.CreateUIToFP(Start, Val->getType()); + Step = addFastMathFlag(Builder.CreateFMul(Start, Step)); + return addFastMathFlag(Builder.CreateBinOp(BinOp, Val, Step, "fpinduction")); + } + return Builder.CreateAdd(Val, Builder.CreateMul(Start, Step), "induction"); +} Index: lib/Transforms/Vectorize/SearchLoopVectorize.cpp =================================================================== --- /dev/null +++ lib/Transforms/Vectorize/SearchLoopVectorize.cpp @@ -0,0 +1,4480 @@ +//===- SearchLoopVectorize.cpp - A Search Loop Vectorizer -----------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// TODO: Fix up description +// +// This is the LLVM loop vectorizer. This pass modifies 'vectorizable' loops +// and generates target-independent LLVM-IR. +// The vectorizer uses the TargetTransformInfo analysis to estimate the costs +// of instructions in order to estimate the profitability of vectorization. +// +// The loop vectorizer combines consecutive loop iterations into a single +// 'wide' iteration. After this transformation the index is incremented +// by the SIMD vector width, and not by one. +// +// This pass has three parts: +// 1. The main loop pass that drives the different parts. +// 2. SLVLoopVectorizationLegality - A unit that checks for the legality +// of the vectorization. +// 3. 
SearchLoopVectorizer - A unit that performs the actual +// widening of instructions. +// 4. SLVLoopVectorizationCostModel - A unit that checks for the profitability +// of vectorization. It decides on the optimal vector width, which +// can be one, if vectorization is not profitable. +// +//===----------------------------------------------------------------------===// +// +// The reduction-variable vectorization is based on the paper: +// D. Nuzman and R. Henderson. Multi-platform Auto-vectorization. +// +// Variable uniformity checks are inspired by: +// Karrenberg, R. and Hack, S. Whole Function Vectorization. +// +// The interleaved access vectorization is based on the paper: +// Dorit Nuzman, Ira Rosen and Ayal Zaks. Auto-Vectorization of Interleaved +// Data for SIMD +// +// Other ideas/concepts are from: +// A. Zaks and D. Nuzman. Autovectorization in GCC-two years later. +// +// S. Maleki, Y. Gao, M. Garzaran, T. Wong and D. Padua. An Evaluation of +// Vectorizing Compilers. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Vectorize.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/Hashing.h" +#include "llvm/ADT/MapVector.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/ADT/StringSwitch.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/BasicAliasAnalysis.h" +#include "llvm/Analysis/AliasSetTracker.h" +#include "llvm/Analysis/AssumptionCache.h" +#include "llvm/Analysis/BlockFrequencyInfo.h" +#include "llvm/Analysis/CodeMetrics.h" +#include "llvm/Analysis/DemandedBits.h" +#include "llvm/Analysis/GlobalsModRef.h" +#include "llvm/Analysis/LoopAccessAnalysis.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/LoopIterator.h" +#include "llvm/Analysis/LoopPass.h" +#include "llvm/Analysis/LoopVectorizationAnalysis.h" +#include "llvm/Analysis/OptimizationRemarkEmitter.h" +#include "llvm/Analysis/PostDominators.h" +#include "llvm/Analysis/ScalarEvolution.h" +#include "llvm/Analysis/ScalarEvolutionExpander.h" +#include "llvm/Analysis/ScalarEvolutionExpressions.h" +#include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/DebugInfo.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/DiagnosticInfo.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/PatternMatch.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Value.h" +#include "llvm/IR/ValueHandle.h" +#include "llvm/IR/Verifier.h" +#include "llvm/Pass.h" +#include "llvm/Support/BranchProbability.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/Transforms/Utils/Local.h" +#include "llvm/Analysis/VectorUtils.h" +#include "llvm/Transforms/Utils/LoopUtils.h" +#include +#include +#include +#include +#include + +using namespace llvm; +using namespace llvm::PatternMatch; + +#define SLV_NAME "search-loop-vectorize" +#define 
DEBUG_TYPE SLV_NAME
+#ifndef NDEBUG
+#define NODEBUG_EARLY_BAILOUT() \
+  do { if (!::llvm::DebugFlag || !::llvm::isCurrentDebugType(DEBUG_TYPE)) \
+       { return false; } } while (0)
+#else
+#define NODEBUG_EARLY_BAILOUT() { return false; }
+#endif
+
+STATISTIC(SearchLoopsVectorized, "Number of loops vectorized");
+STATISTIC(SearchLoopsAnalyzed, "Number of loops analyzed for vectorization");
+
+static cl::opt<bool> DisableReductionIntrinsics(
+    "sl-disable-reduction-intrinsics", cl::init(false), cl::Hidden,
+    cl::desc("Disable the loop vectoriser's use of reduction intrinsics."));
+
+static cl::opt<bool> AnnotateWidenedInstrs(
+    "sl-annotate-widened-instrs", cl::init(false), cl::Hidden,
+    cl::desc("Annotate vector instructions with the scalar instruction they"
+             " represent"));
+
+// The following two options are mutually exclusive.
+static cl::list<std::string>
+    FuncsBlackList("sl-blacklist-funcs", cl::value_desc("function names"),
+                   cl::desc("Skip search loop vectorization for functions whose"
+                            " name matches one in this list"),
+                   cl::CommaSeparated);
+
+static cl::list<std::string>
+    FuncsWhiteList("sl-whitelist-funcs", cl::value_desc("function names"),
+                   cl::desc("Only use search loop vectorization for functions"
+                            " whose name matches one in this list"),
+                   cl::CommaSeparated);
+
+////////////////////////////////////////////////////////////////////////////////
+// Helpers
+////////////////////////////////////////////////////////////////////////////////
+
+/// \brief Look for a meaningful debug location on the instruction or its
+/// operands.
+static Instruction *getDebugLocFromInstOrOperands(Instruction *I) {
+  if (!I)
+    return I;
+
+  DebugLoc Empty;
+  if (I->getDebugLoc() != Empty)
+    return I;
+
+  for (User::op_iterator OI = I->op_begin(), OE = I->op_end(); OI != OE; ++OI) {
+    if (Instruction *OpInst = dyn_cast<Instruction>(*OI))
+      if (OpInst->getDebugLoc() != Empty)
+        return OpInst;
+  }
+
+  return I;
+}
+
+/// \brief Set the debug location in the builder using the debug location in
+/// the instruction.
+static void setDebugLocFromInst(IRBuilder<> &B, const Value *Ptr) {
+  if (const Instruction *Inst = dyn_cast_or_null<Instruction>(Ptr))
+    B.SetCurrentDebugLocation(Inst->getDebugLoc());
+  else
+    B.SetCurrentDebugLocation(DebugLoc());
+}
+
+#ifndef NDEBUG
+/// \return string containing a file name and a line # for the given loop.
+static std::string getDebugLocString(const Loop *L) {
+  std::string Result;
+  if (L) {
+    raw_string_ostream OS(Result);
+    if (const DebugLoc LoopDbgLoc = L->getStartLoc())
+      LoopDbgLoc.print(OS);
+    else
+      // Just print the module name.
+      OS << L->getHeader()->getParent()->getParent()->getModuleIdentifier();
+    OS.flush();
+  }
+  return Result;
+}
+#endif
+
+/// \brief Propagate known metadata from one instruction to another.
+static void propagateMetadata(Instruction *To, const Instruction *From) {
+  SmallVector<std::pair<unsigned, MDNode *>, 4> Metadata;
+  From->getAllMetadataOtherThanDebugLoc(Metadata);
+
+  for (auto M : Metadata) {
+    unsigned Kind = M.first;
+
+    // These are safe to transfer (this is safe for TBAA, even when we
+    // if-convert, because should that metadata have had a control dependency
+    // on the condition, and thus actually aliased with some other
+    // non-speculated memory access when the condition was false, this would be
+    // caught by the runtime overlap checks).
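+    // Any other metadata kind (for example !range or !invariant.load) is
+    // conservatively dropped by the filter below.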
+    if (Kind != LLVMContext::MD_tbaa &&
+        Kind != LLVMContext::MD_alias_scope &&
+        Kind != LLVMContext::MD_noalias &&
+        Kind != LLVMContext::MD_fpmath &&
+        Kind != LLVMContext::MD_nontemporal)
+      continue;
+
+    To->setMetadata(Kind, M.second);
+  }
+}
+
+/// \brief Propagate known metadata from one instruction to a vector of others.
+static void propagateMetadata(SmallVectorImpl<Value *> &To,
+                              const Instruction *From) {
+  for (Value *V : To)
+    if (Instruction *I = dyn_cast<Instruction>(V))
+      propagateMetadata(I, From);
+}
+
+static void emitMissedWarning(Function *F, Loop *L,
+                              const SLVLoopVectorizeHints &LH) {
+//emitOptimizationRemarkMissed(F->getContext(), SLV_NAME, *F, L->getStartLoc(),
+//                             LH.emitRemark());
+
+  /* TODO: Reenable
+  if (LH.getForce() == SLVLoopVectorizeHints::FK_Enabled) {
+    if (LH.getWidth() != 1)
+      emitLoopVectorizeWarning(
+          F->getContext(), *F, L->getStartLoc(),
+          "failed explicitly specified loop vectorization");
+    else if (LH.getInterleave() != 1)
+      emitLoopInterleaveWarning(
+          F->getContext(), *F, L->getStartLoc(),
+          "failed explicitly specified loop interleaving");
+  }
+  */
+}
+
+static void addInnerLoop(Loop &L, SmallVectorImpl<Loop *> &V) {
+  if (L.empty())
+    return V.push_back(&L);
+
+  for (Loop *InnerL : L)
+    addInnerLoop(*InnerL, V);
+}
+
+static Instruction *getFirstInst(Instruction *FirstInst, Value *V,
+                                 Instruction *Loc) {
+  if (FirstInst)
+    return FirstInst;
+  if (Instruction *I = dyn_cast<Instruction>(V))
+    return I->getParent() == Loc->getParent() ? I : nullptr;
+  return nullptr;
+}
+
+/// \brief Adds a 'fast' flag to floating point operations.
+static Value *addFastMathFlag(Value *V) {
+  if (isa<FPMathOperator>(V)) {
+    FastMathFlags Flags;
+    Flags.setFast(true);
+    cast<Instruction>(V)->setFastMathFlags(Flags);
+  }
+  return V;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// SearchLoopVectorizer
+////////////////////////////////////////////////////////////////////////////////
+
+namespace {
+
+/// TODO: Correct comment. Very much outdated, even compared to the
+/// 'normal' vectorizer.
+
+/// SearchLoopVectorizer vectorizes loops which contain only one basic
+/// block to a specified vectorization factor (VF).
+/// This class performs the widening of scalars into vectors, or multiple
+/// scalars. This class also implements the following features:
+/// * It inserts an epilogue loop for handling loops that don't have iteration
+///   counts that are known to be a multiple of the vectorization factor.
+/// * It handles the code generation for reduction variables.
+/// * Scalarization (implementation using scalars) of un-vectorizable
+///   instructions.
+/// SearchLoopVectorizer does not perform any vectorization-legality
+/// checks, and relies on the caller to check for the different legality
+/// aspects. The SearchLoopVectorizer relies on the
+/// SLVLoopVectorizationLegality class to provide information about the
+/// induction and reduction variables that were found to a given
+/// vectorization factor.
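+///
+/// For illustration only (a sketch of the intent, not IR copied from this
+/// patch): with VF = 4, a search-loop body such as
+///
+///   %v = load i32, i32* %p
+///   %c = icmp eq i32 %v, %key
+///   br i1 %c, label %found, label %latch
+///
+/// is widened so that %v becomes a predicated <4 x i32> (or
+/// <vscale x 4 x i32>) load, %c becomes a <4 x i1> compare, and the early
+/// exit is handled by the predicate machinery (insertEarlyBreaks() /
+/// vectorizeExits()) rather than by a scalar branch per element.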
+class SearchLoopVectorizer {
+public:
+  SearchLoopVectorizer(Loop *OrigLoop, PredicatedScalarEvolution &PSE,
+                       LoopInfo *LI, DominatorTree *DT,
+                       const TargetLibraryInfo *TLI,
+                       const TargetTransformInfo *TTI,
+                       OptimizationRemarkEmitter *ORE, AssumptionCache *AC,
+                       unsigned VecWidth, unsigned UnrollFactor,
+                       bool VecWidthIsFixed)
+      : OrigLoop(OrigLoop), PSE(PSE), LI(LI), DT(DT), TLI(TLI), TTI(TTI),
+        AC(AC), ORE(ORE), VF(VecWidth), Scalable(!VecWidthIsFixed),
+        Builder(PSE.getSE()->getContext()), Induction(nullptr),
+        OldInduction(nullptr), InductionStep(nullptr), PHMap(1),
+        NextBodyWMap(1), BodyWMap(1), VTailWMap(1), TmpMakeIntoPHIsMap(1),
+        VecBodyPostDom(nullptr), TripCount(nullptr), VectorTripCount(nullptr),
+        Legal(nullptr), AddedSafetyChecks(false), LatchBranch(nullptr),
+        IdxEnd(nullptr), IdxEndV(nullptr), BranchCounter(0) {}
+
+  // Perform the actual loop widening (vectorization).
+  // MinimumBitWidths maps scalar integer values to the smallest bitwidth they
+  // can be validly truncated to. The cost model has assumed this truncation
+  // will happen when vectorizing.
+  // TODO: Don't need MinimumBitWidths if we're passing the whole cost model in.
+  void vectorize(SLVLoopVectorizationLegality *L,
+                 SLVLoopVectorizationCostModel *C,
+                 MapVector<Instruction *, uint64_t> MinimumBitWidths) {
+    LLVM_DEBUG(dbgs() << "SLV vectorizing loop: " << getDebugLocString(OrigLoop)
+                      << "\n");
+    MinBWs = MinimumBitWidths;
+    Legal = L;
+    Costs = C;
+    // Create a new empty loop. Unlink the old loop and connect the new one.
+    createEmptyLoopWithPredication();
+    // Widen each instruction in the old loop to a new one in the new loop.
+    // Use the Legality module to find the induction and reduction variables.
+
+    // TODO: Before this, preload BodyWMap with loop-invariant vals?
+    vectorizeLoop();
+  }
+
+  // Return true if any runtime check is added.
+  bool IsSafetyChecksAdded() { return AddedSafetyChecks; }
+
+  virtual ~SearchLoopVectorizer() {}
+
+protected:
+  bool isScalable() { return (VF > 1) && Scalable; }
+
+  /// A small list of PHINodes.
+  typedef SmallVector<PHINode *, 4> PhiVector;
+  /// When we unroll loops we have multiple vector values for each scalar.
+  /// This data structure holds the unrolled and vectorized values that
+  /// originated from one scalar instruction.
+  typedef SmallVector<Value *, 2> VectorParts;
+
+  // When we if-convert we need to create edge masks. We have to cache values
+  // so that we don't end up with exponential recursion/IR.
+  typedef DenseMap<std::pair<BasicBlock *, BasicBlock *>, VectorParts>
+      EdgeMaskCache;
+
+  /// This is a helper class that holds the vectorizer state. It maps scalar
+  /// instructions to vector instructions. When the code is 'unrolled' then
+  /// a single scalar value is mapped to multiple vector parts. The parts
+  /// are stored in the VectorParts type.
+  struct ValueMap {
+    /// C'tor. UnrollFactor controls the number of vectors ('parts') that
+    /// are mapped.
+    ValueMap(unsigned UnrollFactor) : UF(UnrollFactor) {}
+
+    /// \return True if 'Key' is saved in the Value Map.
+    bool has(Value *Key) const { return MapStorage.count(Key); }
+
+    /// Initializes a new entry in the map. Sets all of the vector parts to the
+    /// same value, 'Val'.
+    /// \return A reference to a vector with splat values.
+    VectorParts &splat(Value *Key, Value *Val) {
+      VectorParts &Entry = MapStorage[Key];
+      Entry.assign(UF, Val);
+      return Entry;
+    }
+
+    /// \return A reference to the value that is stored at 'Key'.
+ VectorParts &get(Value *Key) { + VectorParts &Entry = MapStorage[Key]; + if (Entry.empty()) + Entry.resize(UF); + assert(Entry.size() == UF); + return Entry; + } + + void clearValue(Value *Key){ + MapStorage.erase(Key); + } + + void clear() { + MapStorage.clear(); + } + + private: + /// The unroll factor. Each entry in the map stores this number of vector + /// elements. + unsigned UF; + + /// Map storage. We use std::map and not DenseMap because insertions to a + /// dense map invalidates its iterators. + std::map MapStorage; + }; + + /// \brief Add checks for strides that were assumed to be 1. + /// + /// Returns the last check instruction and the first check instruction in the + /// pair as (first, last). + std::pair addStrideCheck(Instruction *Loc); + + /// Create an empty loop, using per-element predication to control termination + void createEmptyLoopWithPredication(); + /// Create a new induction variable inside L. + PHINode *createInductionVariable(Loop *L, Value *Start, Value *End, + Value *Step, Instruction *DL); + + bool isIndvarPHIOrUpdate(Value *V, InductionDescriptor &II, bool &IsUpdate); + bool allIndvarPHIsOrAllUpdates(Escapee *E, InductionDescriptor &II, bool &IsUpdate); + + /// Copy and widen the instructions from the old loop. + virtual void vectorizeLoop(); + virtual void insertVectorTail(LoopBlocksDFS& DFS); + PHINode *createTailPhiFromPhi(PHINode *PN, const Twine &Name=""); + PHINode *createTailPhiFromValues(Value *ValPH, Value *ValVB, const Twine &Name=""); + + /// \brief The Loop exit block may have single value PHI nodes where the + /// incoming value is 'Undef'. While vectorizing we only handled real values + /// that were defined inside the loop. Here we fix the 'undef case'. + /// See PR14725. + void fixLCSSAPHIs(); + + /// Shrinks vector element sizes based on information in "MinBWs". + void truncateToMinimalBitwidths(ValueMap &WidenMap); + + /// A helper function that computes the predicate of the block BB, assuming + /// that the header block of the loop is set to True. It returns the *entry* + /// mask for the block BB. + VectorParts createBlockInMask(BasicBlock *BB, ValueMap &WidenMap); + /// A helper function that computes the predicate of the edge between SRC + /// and DST. + VectorParts createEdgeMask(BasicBlock *Src, BasicBlock *Dst, + ValueMap &WidenMap); + + /// A helper function to vectorize a single BB within the innermost loop. + void vectorizeBlockInLoop(BasicBlock *BB, PhiVector *PV, ValueMap &WidenMap); + + void widenInstruction(Instruction *it, BasicBlock *BB, PhiVector *PV, + ValueMap &WidenMap); + + /// Vectorize a single PHINode in a block. This method handles the induction + /// variable canonicalization. It supports both VF = 1 for unrolled loops and + /// arbitrary length vectors. + void widenPHIInstruction(Instruction *PN, VectorParts &Entry, + unsigned VF, PhiVector *PV, + ValueMap &WidenMap); + + /// Insert early break checks using Map as the ValueMap, and Pred + /// as starting Predicate condition, and Induction as the induction + /// variable to use. + /// Returns the 'and'ed value for all conditions. + /// Each (vectorised) condition is separately stored in 'Conditions'. + Value *insertEarlyBreaks(ValueMap &Map, Value *Induction, + SmallVectorImpl &Conditions, Value *Pred); + + // TODO: Reword. + // Patch up the condition for a branch instruction after the block has been + // vectorized; only used with predication for now. 
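+  // (See LatchBranch below: the new latch branch is recorded so that its
+  // condition can be patched up after the block has been vectorized.)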
+ void vectorizeExits(); + + /// Insert the new loop to the loop hierarchy and pass manager + /// and update the analysis passes. + void updateAnalysis(); + + /// This instruction is un-vectorizable. Implement it as a sequence + /// of scalars. If \p IfPredicateStore is true we need to 'hide' each + /// scalarized instruction behind an if block predicated on the control + /// dependence of the instruction. + virtual void scalarizeInstruction(Instruction *Instr, + ValueMap &WidenMap, + bool IfPredicateStore=false); + + /// Scalarize instruction using a sub-loop within the vector body. + virtual void scalarizeInstructionWithSubloop(Instruction *Instr, + ValueMap &WidenMap, + bool IfPredicateStore = false); + + /// Vectorize Load and Store instructions, + virtual void vectorizeMemoryInstruction(Instruction *Instr, ValueMap &WidenMap); + virtual void vectorizeArithmeticGEP(Instruction *Instr, ValueMap &WidenMap); + virtual void vectorizeGEPInstruction(Instruction *Instr, ValueMap &WidenMap); + + /// Create a broadcast instruction. This method generates a broadcast + /// instruction (shuffle) for loop invariant values and for the induction + /// value. If this is the induction variable then we extend it to N, N+1, ... + /// this is needed because each iteration in the loop corresponds to a SIMD + /// element. + virtual Value *getBroadcastInstrs(Value *V); + + /// This function adds (Start, Start + Step, Start + 2*Step, ...) + /// to each vector element of Val. The sequence starts at StartIndex. + virtual Value *getStepVector(Value *Val, int Start, Value *Step); + virtual Value *getStepVector(Value *Val, Value* Start, Value *Step); + + /// This function adds (StartIdx, StartIdx + Step, StartIdx + 2*Step, ...) + /// to each vector element of Val. The sequence starts at StartIndex. + /// Step is a SCEV. In order to get StepValue it takes the existing value + /// from SCEV or creates a new using SCEVExpander. + virtual Value *getStepVector(Value *Val, Value *Start, const SCEV *Step); + + virtual Constant *getRuntimeVF(Type *Ty); + + Value *getExclusivePartition(Value *Pred); + Value *getInclusivePartition(Value *Pred); + + /// When we go over instructions in the basic block we rely on previous + /// values within the current basic block or on loop invariant values. + /// When we widen (vectorize) values we place them in the map. If the values + /// are not within the map, they have to be loop invariant, so we simply + /// broadcast them into a vector. + VectorParts &getVectorValue(Value *V, ValueMap &WidenMap); + + /// Generate a shuffle sequence that will reverse the vector Vec. + virtual Value *reverseVector(Value *Vec); + + /// Returns (and creates if needed) the original loop trip count. + Value *getOrCreateTripCount(Loop *NewLoop); + + /// Returns (and creates if needed) the trip count of the widened loop. + Value *getOrCreateVectorTripCount(Loop *NewLoop); + + /// Emit a bypass check to see if the trip count would overflow, or we + /// wouldn't have enough iterations to execute one vector loop. + void emitMinimumIterationCountCheck(Loop *L, BasicBlock *Bypass); + /// Emit a bypass check to see if the vector trip count is nonzero. + void emitVectorLoopEnteredCheck(Loop *L, BasicBlock *Bypass); + /// Emit a bypass check to see if all of the SCEV assumptions we've + /// had to make are correct. + void emitSCEVChecks(Loop *L, BasicBlock *Bypass); + /// Emit bypass checks to check any memory assumptions we may have made. 
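+  /// These are the runtime alias checks gathered by LoopAccessAnalysis; if
+  /// any of them fails at run time, control is transferred to \p Bypass
+  /// instead of entering the vector loop.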
+ void emitMemRuntimeChecks(Loop *L, BasicBlock *Bypass); + + ///\brief Perform CSE of induction variable instructions. + void CSE(SmallVector &BBs, SmallSet &Preds); + + /// The original loop. + Loop *OrigLoop; + /// Scev analysis to use. + PredicatedScalarEvolution &PSE; + /// Loop Info. + LoopInfo *LI; + /// Dominator Tree. + DominatorTree *DT; + /// Target Library Info. + const TargetLibraryInfo *TLI; + /// Target Transform Info. + const TargetTransformInfo *TTI; + /// Assumption Cache. + AssumptionCache *AC; + /// Interface to emit optimization remarks. + OptimizationRemarkEmitter *ORE; + /// Alias Analysis. + AliasAnalysis *AA; + + /// The vectorization SIMD factor to use. Each vector will have this many + /// vector elements. + unsigned VF; + bool Scalable; + +protected: + /// The builder that we use + IRBuilder<> Builder; + + // --- Vectorization state --- + /// The new inner loop + Loop *SLVLoop; + /// The vector-loop preheader. + BasicBlock *LoopVectorPreHeader; + /// The scalar-loop preheader. + BasicBlock *LoopScalarPreHeader; + /// Vector tail following vector body. + BasicBlock *LoopVectorTail; + ///The ExitBlock of the scalar loop. + BasicBlock *LoopExitBlock; + ///The vector loop body. + SmallVector LoopVectorBody; + ///The scalar loop body. + BasicBlock *LoopScalarBody; + /// A list of all bypass blocks. The first block is the entry of the loop. + SmallVector LoopBypassBlocks; + /// Blocks needed for performing scalable reduction across a vector + /// @{ + BasicBlock *ReductionLoop; + BasicBlock *ReductionLoopRet; + /// }@ + /// The new Induction variable which was added to the new block. + Value *Induction; + /// PHI node only available in vector tail + Value *PrevInduction; + Value *InductionStartIdx; + /// The induction variable of the old basic block. + PHINode *OldInduction; + /// Value of the induction var for the next loop trip + Value *InductionStep; + /// Holds the entry predicates for the current iteration of the vector body. + Value *Predicate; + MapVector SpeculativePredicates; + PHINode *TailSpeculativeLanes; + /// Incoming predicate for vector tail, coming from either preheader + /// or vector body. + Value *PreHeaderOutPred; + Value *VecBodyOutPred; + /// Holds the extended (to the widest induction type) start index. + Value *ExtendedIdx; + /// Maps scalars to widened vectors. + ValueMap PHMap; + ValueMap NextBodyWMap; + ValueMap BodyWMap; + ValueMap VTailWMap; + ValueMap TmpMakeIntoPHIsMap; + SmallSet ConditionNodes; + /// Store instructions that should be predicated, as a pair + /// + SmallVector, 4> PredicatedStores; + EdgeMaskCache MaskCache; + + // Loop vector body current post-dominator block. + BasicBlock *VecBodyPostDom; + typedef std::pair DomEdge; + SmallVector VecBodyDomEdges; + + // Conditional blocks due to if-conversion. + SmallSet PredicatedBlocks; + /// Trip count of the original loop. + Value *TripCount; + /// Trip count of the widened loop (TripCount - TripCount % (VF*UF)) + Value *VectorTripCount; + + /// Map of scalar integer values to the smallest bitwidth they can be legally + /// represented as. The vector equivalents of these values should be truncated + /// to this type. + MapVector MinBWs; + SLVLoopVectorizationLegality *Legal; + SLVLoopVectorizationCostModel *Costs; + + // Record whether runtime check is added. + bool AddedSafetyChecks; + + /// Stores new branch for vectorized latch block so it + /// can be patched up after vectorization + BranchInst *LatchBranch; + + /// TODO -- rename this and InductionStep? 
+ /// TODO -- move both to exit info descriptor? + Value *IdxEnd; + Value *IdxEndV; + + /// Does the induction variable expression have nsw/nuw flags? + bool IndVarNoWrap; + + // Set of nodes that should be made into PHIs in the original loop + // when we hoist the early-exit condition out of the loop, but one + // of the condition's subexpressions is reused somewhere in the loop + // body after the condition. + SmallSet MakeTheseIntoPHIs; + + /// The BranchCounter tells the vectorizer which Escapee values to select + /// using the current predicate. + unsigned BranchCounter; + + /// The Green vector lanes are always executed + Value *GreenLanes; + /// The Yellow vector lanes are executed before the mid-body exit + /// and are defined as the GreenLanes + 1 lane. + Value *YellowLanes; + + /// TODO: Set to false above? + bool HasSpeculativeLoads; + +}; + + // TODO: Move this struct? + +namespace { +struct CSEDenseMapInfo { + static bool canHandle(const Instruction *I) { + return isa(I) || isa(I) || + isa(I) || isa(I); + } + static inline Instruction *getEmptyKey() { + return DenseMapInfo::getEmptyKey(); + } + static inline Instruction *getTombstoneKey() { + return DenseMapInfo::getTombstoneKey(); + } + static unsigned getHashValue(const Instruction *I) { + assert(canHandle(I) && "Unknown instruction!"); + return hash_combine(I->getOpcode(), hash_combine_range(I->value_op_begin(), + I->value_op_end())); + } + static bool isEqual(const Instruction *LHS, const Instruction *RHS) { + if (LHS == getEmptyKey() || RHS == getEmptyKey() || + LHS == getTombstoneKey() || RHS == getTombstoneKey()) + return LHS == RHS; + return LHS->isIdenticalTo(RHS); + } +}; +} + +////===----------------------------------------------------------------------===// +//// Implementation of SearchLoopVectorizer +////===----------------------------------------------------------------------===// + +Value *SearchLoopVectorizer::getBroadcastInstrs(Value *V) { + // We need to place the broadcast of invariant variables outside the loop. + Instruction *Instr = dyn_cast(V); + bool NewInstr = + (Instr && std::find(LoopVectorBody.begin(), LoopVectorBody.end(), + Instr->getParent()) != LoopVectorBody.end()); + bool Invariant = OrigLoop->isLoopInvariant(V) && !NewInstr; + + // Place the code for broadcasting invariant variables in the new preheader. + // IRBuilder<>::InsertPointGuard Guard(Builder); + IRBuilder<>::InsertPoint IP = Builder.saveIP(); + + // Broadcast the scalar into all locations in the vector. 
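+  // For example, a loop-invariant i32 %n is widened to a single splat whose
+  // lanes all hold %n; the splat has VF elements, or "vscale x VF" elements
+  // when vectorizing for a scalable width.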
+ Value *Shuf = Builder.CreateVectorSplat({VF, Scalable}, V, "broadcast"); + + if (Invariant) + Builder.restoreIP(IP); + + return Shuf; +} + +Value *SearchLoopVectorizer::getStepVector(Value *Val, Value *Start, + const SCEV *StepSCEV) { + const DataLayout &DL = OrigLoop->getHeader()->getModule()->getDataLayout(); + SCEVExpander Exp(*PSE.getSE(), DL, "induction"); + Value *StepValue = Exp.expandCodeFor(StepSCEV, StepSCEV->getType(), + &*Builder.GetInsertPoint()); + return getStepVector(Val, Start, StepValue); +} + +Value *SearchLoopVectorizer::getStepVector(Value *Val, int Start, Value *Step) { + Type *Ty = Val->getType()->getScalarType(); + return getStepVector(Val, ConstantInt::get(Ty, Start), Step); +} + +Value *SearchLoopVectorizer::getStepVector(Value *Val, Value *Start, + Value *Step) { + assert(Val->getType()->isVectorTy() && "Must be a vector"); + assert(Val->getType()->getScalarType()->isIntegerTy() && + "Elem must be an integer"); + assert(Step->getType() == Val->getType()->getScalarType() && + "Step has wrong type"); + + VectorType *Ty = cast(Val->getType()); + Value *One = ConstantInt::get(Start->getType(), 1); + + // Create a vector of consecutive numbers from Start to Start+VF + Value *Cv = Builder.CreateSeriesVector(Ty->getElementCount(), Start, One); + + // Add the consecutive indices to the vector value. + assert(Cv->getType() == Val->getType() && "Invalid consecutive vec"); + Step = Builder.CreateVectorSplat(Ty->getElementCount(), Step); + // FIXME: The newly created binary instructions should contain nsw/nuw flags, + // which can be found from the original scalar operations. + Step = Builder.CreateMul(Cv, Step); + return Builder.CreateAdd(Val, Step, "induction"); +} + +Constant *SearchLoopVectorizer::getRuntimeVF(Type *Ty) { + Constant *EC = ConstantInt::get(Ty, VF); + if (Scalable) + EC = ConstantExpr::getMul(VScale::get(Ty), EC); + + return EC; +} + +Value *SearchLoopVectorizer::getExclusivePartition(Value *Pred) { + Type *Ty = Pred->getType(); + Value *PTrue = ConstantInt::getTrue(Ty); + Value *InvertedPred = Builder.CreateNot(Pred); + Module *M = Builder.GetInsertBlock()->getParent()->getParent(); + Function *F = Intrinsic::getDeclaration(M, Intrinsic::aarch64_sve_brkb_z, Ty); + return Builder.CreateCall(F, { PTrue, InvertedPred }, "brkb.z"); +} + +Value *SearchLoopVectorizer::getInclusivePartition(Value *Pred) { + Type *Ty = Pred->getType(); + Value *PTrue = ConstantInt::getTrue(Ty); + Value *InvertedPred = Builder.CreateNot(Pred); + Module *M = Builder.GetInsertBlock()->getParent()->getParent(); + Function *F = Intrinsic::getDeclaration(M, Intrinsic::aarch64_sve_brka_z, Ty); + return Builder.CreateCall(F, { PTrue, InvertedPred }, "brka.z"); +} + +SearchLoopVectorizer::VectorParts& +SearchLoopVectorizer::getVectorValue(Value *V, ValueMap& WidenMap) { + assert(!V->getType()->isVectorTy() && "Can't widen a vector"); + + // If we have a stride that is replaced by one, do it here. + if (Legal->hasStride(V)) + V = ConstantInt::get(V->getType(), 1); + + // If we have this scalar in the map, return it. + if (WidenMap.has(V)) + return WidenMap.get(V); + + // If this scalar is unknown, assume that it is a constant or that it is + // loop invariant. Broadcast V and save the value for future uses. 
+ Value *B = getBroadcastInstrs(V); + return WidenMap.splat(V, B); +} + +Value *SearchLoopVectorizer::reverseVector(Value *Vec) { + assert(Vec->getType()->isVectorTy() && "Invalid type"); + VectorType *Ty = cast(Vec->getType()); + + // i32 reverse_mask[n] = { n-1, n-2...1, 0 } + Value *RuntimeVF = getRuntimeVF(Builder.getInt32Ty()); + Value *Start = Builder.CreateSub(RuntimeVF, Builder.getInt32(1)); + Value *Step = ConstantInt::get(Start->getType(), -1, true); + Value *Mask = Builder.CreateSeriesVector({VF,Scalable}, Start, Step); + + return Builder.CreateShuffleVector(Vec, UndefValue::get(Ty), Mask, "reverse"); +} + +void SearchLoopVectorizer::vectorizeMemoryInstruction(Instruction *Instr, + ValueMap &WidenMap) { + // Attempt to issue a wide load. + LoadInst *LI = dyn_cast(Instr); + StoreInst *SI = dyn_cast(Instr); + + assert((LI || SI) && "Invalid Load/Store instruction"); + + Type *ScalarDataTy = LI ? LI->getType() : SI->getValueOperand()->getType(); + Type *DataTy = VectorType::get(ScalarDataTy, VF, Scalable); + Value *Ptr = LI ? LI->getPointerOperand() : SI->getPointerOperand(); + unsigned Alignment = LI ? LI->getAlignment() : SI->getAlignment(); + // An alignment of 0 means target abi alignment. We need to use the scalar's + // target abi alignment in such a case. + const DataLayout &DL = Instr->getModule()->getDataLayout(); + if (!Alignment) + Alignment = DL.getABITypeAlignment(ScalarDataTy); + unsigned AddressSpace = Ptr->getType()->getPointerAddressSpace(); + unsigned ScalarAllocatedSize = DL.getTypeAllocSize(ScalarDataTy); + unsigned VectorElementSize = DL.getTypeStoreSize(DataTy) / VF; + + if (SI && Legal->blockNeedsPredication(SI->getParent()) && + !Legal->isMaskRequired(SI)) + return scalarizeInstruction(Instr, WidenMap, true); + + if (ScalarAllocatedSize != VectorElementSize) + return scalarizeInstruction(Instr, WidenMap); + + // If the pointer is loop invariant or if it is non-consecutive, + // scalarize the load. + int ConsecutiveStride = Legal->isConsecutivePtr(Ptr); + bool Reverse = ConsecutiveStride < 0; + bool UniformLoad = LI && Legal->isUniform(Ptr); + + // TODO: Check for S/G capability instead of scalable? + if (!isScalable() && (!ConsecutiveStride || UniformLoad)) + return scalarizeInstruction(Instr, WidenMap); + + Constant *Zero = Builder.getInt32(0); + VectorParts &Entry = WidenMap.get(Instr); + + // TODO: Is there a bit of metadata to check for possible aliasing? + if (UniformLoad) { + assert(isScalable() && "non-WA Uniform loads should have been scalarized"); + + // Generate a scalar load... + Instruction *NewLI = Builder.CreateLoad(Ptr); + propagateMetadata(NewLI, LI); + + // ... and splat it. + Entry[0] = Builder.CreateVectorSplat({VF,Scalable}, NewLI, "uniform_load"); + + return; + } + + // Handle scatter/gather loads/stores... 
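+  // For non-unit strides the address is widened to a vector of per-lane
+  // pointers and the access becomes a masked gather (loads) or masked
+  // scatter (stores); when the access itself needs masking, its block mask
+  // is ANDed with the loop predicate.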
+ if (std::abs(ConsecutiveStride) != 1) { + VectorParts Mask = createBlockInMask(Instr->getParent(), WidenMap); + + if (LI) { + assert(&WidenMap != &NextBodyWMap && + "First faulting gather not yet supported"); + + VectorParts &Ptrs = getVectorValue(Ptr, WidenMap); + Value *P = Predicate; + if (Legal->isMaskRequired(LI)) + P = Builder.CreateAnd(P, Mask[0]); + + auto *NewLI = Builder.CreateMaskedGather(Ptrs[0], Alignment, P); + propagateMetadata(NewLI, LI); + Entry[0] = NewLI; + } + + if (SI) { + VectorParts &Ptrs = getVectorValue(Ptr, WidenMap); + VectorParts &Vals = getVectorValue(SI->getValueOperand(), WidenMap); + Value *P = Predicate; + if (Legal->isMaskRequired(SI)) + P = Builder.CreateAnd(P, Mask[0]); + + auto *NewSI = Builder.CreateMaskedScatter(Vals[0], Ptrs[0], + Alignment, P); + propagateMetadata(NewSI, SI); + } + + return; + } else { + // Handle consecutive loads/stores. + assert(ConsecutiveStride && "Consecutive load/store expected."); + + GetElementPtrInst *Gep = getGEPInstruction(Ptr); + if (Gep && Legal->isInductionVariable(Gep->getPointerOperand())) { + setDebugLocFromInst(Builder, Gep); + Value *PtrOperand = Gep->getPointerOperand(); + Value *FirstBasePtr = getVectorValue(PtrOperand, WidenMap)[0]; + FirstBasePtr = Builder.CreateExtractElement(FirstBasePtr, Zero); + + // Create the new GEP with the new induction variable. + GetElementPtrInst *Gep2 = cast(Gep->clone()); + Gep2->setOperand(0, FirstBasePtr); + Gep2->setName("gep.indvar.base"); + Ptr = Builder.Insert(Gep2); + } else if (Gep) { + setDebugLocFromInst(Builder, Gep); + ScalarEvolution *SE = PSE.getSE(); + assert(SE->isLoopInvariant(SE->getSCEV(Gep->getPointerOperand()), + OrigLoop) && "Base ptr must be invariant"); + + // The last index does not have to be the induction. It can be + // consecutive and be a function of the index. For example A[I+1]; + unsigned NumOperands = Gep->getNumOperands(); + unsigned InductionOperand = getGEPInductionOperand(Gep); + // Create the new GEP with the new induction variable. + GetElementPtrInst *Gep2 = cast(Gep->clone()); + + for (unsigned i = 0; i < NumOperands; ++i) { + Value *GepOperand = Gep->getOperand(i); + Instruction *GepOperandInst = dyn_cast(GepOperand); + + // Update last index or loop invariant instruction anchored in loop. + if (i == InductionOperand || + (GepOperandInst && OrigLoop->contains(GepOperandInst))) { + assert((i == InductionOperand || + SE->isLoopInvariant(SE->getSCEV(GepOperandInst), OrigLoop)) && + "Must be last index or loop invariant"); + + VectorParts &GEPParts = getVectorValue(GepOperand, WidenMap); + Value *Index = GEPParts[0]; + Index = Builder.CreateExtractElement(Index, Zero); + Gep2->setOperand(i, Index); + Gep2->setName("gep.indvar.idx"); + } + } + Ptr = Builder.Insert(Gep2); + } else { + // TODO: Ideally this should be used for all contiguous accesses. + setDebugLocFromInst(Builder, Ptr); + VectorParts &PtrVal = getVectorValue(Ptr, WidenMap); + Ptr = Builder.CreateExtractElement(PtrVal[0], Zero); + } + } + + Type *DataPtrTy = DataTy->getPointerTo(AddressSpace); + Value *Mask = Predicate; + if (Legal->isMaskRequired(Instr)) + Mask = createBlockInMask(Instr->getParent(), WidenMap)[0]; + + // Handle Stores: + if (SI) { + assert(!Legal->isUniform(SI->getPointerOperand()) && + "We do not allow storing to uniform addresses"); + setDebugLocFromInst(Builder, SI); + // We don't want to update the value in the map as it might be used in + // another expression. So don't use a reference type for "StoredVal". 
+ VectorParts StoredVal = getVectorValue(SI->getValueOperand(), WidenMap); + + // Calculate the index for the specific unroll-part. + Value *VecPtr; + if (Reverse) { + // If the address is consecutive but reversed, then the + // wide store needs to start at the last vector element. + VecPtr = Builder.CreateGEP(Ptr, Builder.getInt32(1)); + VecPtr = Builder.CreateBitCast(VecPtr, DataPtrTy); + VecPtr = Builder.CreateGEP(VecPtr, Builder.getInt32(-1)); + } else { + VecPtr = Builder.CreateBitCast(Ptr, DataPtrTy); + VecPtr = Builder.CreateGEP(VecPtr, Builder.getInt32(0)); + } + + Value *Data = StoredVal[0]; + if (Reverse) + Data = reverseVector(Data); + + Instruction* NewSI; + if (Legal->isMaskRequired(SI)) { + Mask = Builder.CreateAnd(Mask, Predicate); + + if (Reverse) + Mask = reverseVector(Mask); + + NewSI = Builder.CreateMaskedStore(Data, VecPtr, Alignment, Mask); + } else { + Value* P = Reverse ? reverseVector(Predicate) : Predicate; + NewSI = Builder.CreateMaskedStore(Data, VecPtr, Alignment, P); + } + + propagateMetadata(NewSI, SI); + return; + } + + // Handle loads. + assert(LI && "Must have a load instruction"); + setDebugLocFromInst(Builder, LI); + // Calculate the pointer for the specific unroll-part. + Value *VecPtr; + if (Reverse) { + // If the address is consecutive but reversed, then the + // wide load needs to start at the last vector element. + VecPtr = Builder.CreateGEP(Ptr, Builder.getInt32(1)); + VecPtr = Builder.CreateBitCast(VecPtr, DataPtrTy); + VecPtr = Builder.CreateGEP(VecPtr, Builder.getInt32(-1)); + } else { + VecPtr = Builder.CreateBitCast(Ptr, DataPtrTy); + VecPtr = Builder.CreateGEP(VecPtr, Builder.getInt32(0)); + } + + Instruction *NewLI, *NewLIExtr = nullptr; + unsigned Part = 0; + if (Legal->isMaskRequired(LI)) { + Mask = Builder.CreateAnd(Mask, Predicate); + + if (Reverse) + Mask = reverseVector(Mask); + + // Every speculative load in BodyMap and VTailWMap is guaranteed + // to be non-faulting, as it cannot be a load for one of the loop + // exit conditions. + // TODO: Also support first faulting gather loads. + if (Legal->isUncountedLoop() && + &WidenMap != &BodyWMap && + &WidenMap != &VTailWMap) { + NewLI = Builder.CreateMaskedSpecLoad(VecPtr, Alignment, Mask, + UndefValue::get(DataTy), + "wide.masked.specload"); + + auto *SpeculativeLanes = Builder.CreateExtractValue(NewLI, 1); + SpeculativePredicates.insert(std::make_pair(LI, SpeculativeLanes)); + NewLIExtr = cast(Builder.CreateExtractValue(NewLI, 0)); + } else { + NewLI = Builder.CreateMaskedLoad(VecPtr, Alignment, Mask, + UndefValue::get(DataTy), + "wide.masked.load"); + } + } else { + Value* P = Reverse ? reverseVector(Predicate) : Predicate; + if (Legal->isUncountedLoop() && + &WidenMap != &BodyWMap && + &WidenMap != &VTailWMap) { + NewLI = Builder.CreateMaskedSpecLoad(VecPtr, Alignment, P, + UndefValue::get(DataTy), + "wide.masked.specload"); + + auto *SpeculativeLanes = Builder.CreateExtractValue(NewLI, 1); + SpeculativePredicates.insert(std::make_pair(LI, SpeculativeLanes)); + NewLIExtr = cast(Builder.CreateExtractValue(NewLI, 0)); + } else { + NewLI = Builder.CreateMaskedLoad(VecPtr, Alignment, P, + UndefValue::get(DataTy), + "wide.masked.load"); + } + } + + propagateMetadata(NewLI, LI); + NewLI = NewLIExtr ? NewLIExtr : NewLI; + Entry[Part] = Reverse ? 
reverseVector(NewLI) : NewLI; +} + +/// Depending on the access pattern, either of three things happen with +/// the GetElementPtr instruction: +/// - GEP is loop invariant: +/// Nothing +/// - GEP is affine function of loop iteration counter: +/// GEP is replaced by a seriesvector(%ptr, %stride) +/// - GEP is not affine: +/// - GEP pointer is a vectorized GEP instruction:: +/// GEP is replaced by a vector of pointers using arithmetic +void SearchLoopVectorizer::vectorizeGEPInstruction(Instruction *Instr, + ValueMap &WidenMap) { + GetElementPtrInst *Gep = cast(Instr); + + if (!isScalable()) { + scalarizeInstruction(Instr, WidenMap); + return; + } + + ScalarEvolution *SE = PSE.getSE(); + // Must be uniform, handled by vectorizeMemoryInstruction() + if (SE->isLoopInvariant(SE->getSCEV(Gep), OrigLoop)) + return; + + // Handle all non loop invariant forms that are not affine, so that + // when used as address it can be transformed into a gather load/store, + // or when used as pointer arithmetic, it is just vectorized into + // arithmetic instructions. + auto *SAR = dyn_cast(SE->getSCEV(Gep)); + if (!SAR || !SAR->isAffine()) { + vectorizeArithmeticGEP(Gep, WidenMap); + return; + } + + // Create SCEV expander for Start- and StepValue + const DataLayout &DL = Instr->getModule()->getDataLayout(); + SCEVExpander Expander(*SE, DL, "seriesgep"); + + // Expand step and start value (the latter in preheader) + const SCEV *StepRec = SAR->getStepRecurrence(*SE); + // TODO: Make sure of the insert point!!! + Value *StepValue = Expander.expandCodeFor(StepRec, StepRec->getType(), + &*Builder.GetInsertPoint()); + + // Try to find a smaller type for StepValue + const SCEV *BETC = SE->getMaxBackedgeTakenCount(OrigLoop); + if (auto * MaxIters = dyn_cast(BETC)) { + if (auto * CI = dyn_cast(StepValue)) { + // RequiredBits = active_bits(max_iterations * step_value) + APInt MaxItersV = MaxIters->getValue()->getValue(); + if (CI->isNegative()) + MaxItersV = MaxItersV.sextOrSelf(CI->getValue().getBitWidth()); + else + MaxItersV = MaxItersV.zextOrSelf(CI->getValue().getBitWidth()); + + APInt MaxVal = MaxItersV * CI->getValue(); + + // Try to reduce this type from i64 to something smaller + unsigned RequiredBits = MaxVal.getActiveBits(); + unsigned StepBits = StepValue->getType()->getIntegerBitWidth(); + while (RequiredBits <= StepBits && StepBits >= 32) + StepBits = StepBits >> 1; + + // Truncate the step value + Type *NewStepType = IntegerType::get( + Instr->getParent()->getContext(), StepBits << 1); + StepValue = Builder.CreateTrunc(StepValue, NewStepType); + } + } + + const SCEV *StartRec = SAR->getStart(); + Value *StartValue = nullptr; + StartValue = Expander.expandCodeFor(StartRec, Gep->getType(), + LoopVectorPreHeader->getTerminator()); + + Value *Base = Gep->getPointerOperand(); + Value *Tmp2 = Builder.CreateBitCast(StartValue, + Builder.getInt8PtrTy(Base->getType()->getPointerAddressSpace())); + + // We can zero extend the incoming value, because Induction is + // the unsigned iteration counter. + // TODO: Is this correct? (Preheader value, might not be defined?) 
+ Value *Tmp = InductionStartIdx; + Tmp = Builder.CreateZExtOrTrunc(Tmp, StepValue->getType()); + Tmp = Builder.CreateMul(StepValue, Tmp); + Tmp = Builder.CreateSub(ConstantInt::get(StepValue->getType(), 0), Tmp); + Tmp = Builder.CreateGEP(Tmp2, Tmp); + StartValue = Builder.CreateBitCast(Tmp, StartValue->getType()); + + // Normalize to be in #elements, not bytes + Type *ElemTy = Instr->getType()->getPointerElementType(); + Tmp = ConstantInt::get(StepValue->getType(), DL.getTypeAllocSize(ElemTy)); + StepValue = Builder.CreateSDiv(StepValue, Tmp); + + // Get the dynamic VL + Value *RuntimeVF = getRuntimeVF(StepValue->getType()); + + // Create the series vector + VectorParts &Entry = WidenMap.get(Instr); + + // Induction is always the widest induction type in the loop, + // but if that is not enough for evaluating the step, zero extend is + // fine because Induction is the iteration counter, always unsigned. + Value *IterOffset = Builder.CreateZExtOrTrunc(Induction, StepValue->getType()); + IterOffset = Builder.CreateMul(IterOffset, StepValue); + unsigned Part = 0; + { + // Tmp = part * stride * VL + Value *UnrollOffset = ConstantInt::get(RuntimeVF->getType(), Part); + UnrollOffset = Builder.CreateMul(StepValue, UnrollOffset); + UnrollOffset = Builder.CreateMul(RuntimeVF, UnrollOffset); + + // Adjust offset for unrolled iteration + Value *Offset = Builder.CreateAdd(IterOffset, UnrollOffset); + Offset = Builder.CreateSeriesVector({VF,Scalable}, Offset, StepValue); + + // Address = getelementptr %scalarbase, seriesvector(0, step) + Entry[Part] = Builder.CreateGEP(StartValue, Offset); + } + + propagateMetadata(Entry, Instr); +} + +// Vectorize GEP as arithmetic instructions. +// +// This is required when a given GEP is not used for a load/store operation, +// but rather to implement pointer arithmetic. In this case, the pointer may +// be a vector of pointers (e.g. resulting from a load). +// +// This function makes a ptrtoint->arith->inttoptr transformation. +// +// extern char * reg_names[]; +// void foo(void) { +// for (int i = 0; i < K; i++) +// reg_names[i]--; +// } +// +// %1 = getelementptr inbounds [0 x i8*]* @reg_names, i64 0, i64 %0 +// %2 = bitcast i8** %1 to * +// %wide.load = load * %2, align 8, !tbaa !1 +// %3 = ptrtoint %wide.load to +// %4 = add %3, seriesvector (i64 -1, i64 0) +// %5 = inttoptr %4 to +// %6 = bitcast i8** %1 to * +// store %5, * %6, align 8, !tbaa !1 +void SearchLoopVectorizer::vectorizeArithmeticGEP(Instruction *Instr, + ValueMap &WidenMap) { + assert(isa(Instr) && "Instr is not a GEP"); + GetElementPtrInst *GEP = static_cast(Instr); + + // Used types for inttoptr/ptrtoint transform + Type *OrigPtrType = GEP->getType(); + const DataLayout &DL = GEP->getModule()->getDataLayout(); + Type *IntPtrType = DL.getIntPtrType(GEP->getType()); + + // Constant and Variable elements are kept separate to allow IRBuilder + // to fold the constant before widening it to a vector. + VectorParts &Base = getVectorValue(GEP->getPointerOperand(), WidenMap); + VectorParts &Res = WidenMap.get(Instr); + + unsigned Part = 0; + { + // Pointer To Int (pointer operand) + Res[Part] = Builder.CreatePtrToInt( + Base[Part], VectorType::get(IntPtrType, VF, Scalable)); + + // Collect constants and split up the GEP expression into an arithmetic one. 
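+    // Conceptually this computes ptrtoint(%base) + sum(index_i * size_i):
+    // constant terms are accumulated in Cst and added once at the end, and
+    // only the variable terms are widened to vectors.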
+ Value *Cst = ConstantInt::get(IntPtrType, 0, false); + gep_type_iterator GTI = gep_type_begin(*GEP); + for (unsigned I = 1, E = GEP->getNumOperands(); I != E; ++I, ++GTI) { + // V is still scalar + Value *V = GEP->getOperand(I); + + if (StructType *STy = GTI.getStructTypeOrNull()) { + // Struct type, get field offset in bytes. Result is always a constant. + assert(isa(V) && "Field offset must be constant"); + + ConstantInt *CI = static_cast(V); + unsigned ByteOffset = + DL.getStructLayout(STy)->getElementOffset(CI->getLimitedValue()); + V = ConstantInt::get(IntPtrType, ByteOffset, false); + } else { + // First transform index to pointer-type + if (V->getType() != IntPtrType) + V = Builder.CreateIntCast(V, IntPtrType, true, "idxprom"); + + Value *TypeAllocSize = ConstantInt::get( + V->getType(), DL.getTypeAllocSize(GTI.getIndexedType()), true); + // Only widen non-constant offsets + if (isa(V)) + V = Builder.CreateMul(V, TypeAllocSize); + else + V = Builder.CreateMul(getVectorValue(V, WidenMap)[Part], + getVectorValue(TypeAllocSize, WidenMap)[Part]); + } + + if (isa(V) && !V->getType()->isVectorTy()) + Cst = Builder.CreateAdd(Cst, V); + else + Res[Part] = Builder.CreateAdd(Res[Part], V); + } + + // Add constant part and create final conversion to original type + Res[Part] = Builder.CreateAdd(Res[Part], + getVectorValue(Cst, WidenMap)[Part]); + Res[Part] = Builder.CreateIntToPtr( + Res[Part], VectorType::get(OrigPtrType, VF, Scalable)); + } +} + +void +SearchLoopVectorizer::scalarizeInstructionWithSubloop(Instruction *Instr, + ValueMap &WidenMap, + bool IfPredicateStore) { + assert(!IfPredicateStore && + "Can't handle scalarizing sub-loop with ifcvt store"); + // Holds vector parameters or scalars, in case of uniform vals. + SmallVector Params; + + setDebugLocFromInst(Builder, Instr); + + // Find all of the vectorized parameters. + for (unsigned op = 0, e = Instr->getNumOperands(); op != e; ++op) { + Value *SrcOp = Instr->getOperand(op); + + // If we are accessing the old induction variable, use the new one. + if (SrcOp == OldInduction) { + Params.push_back(getVectorValue(SrcOp, WidenMap)); + continue; + } + + // Try using previously calculated values. + Instruction *SrcInst = dyn_cast(SrcOp); + + // If the src is an instruction that appeared earlier in the basic block + // then it should already be vectorized. + if (SrcInst && OrigLoop->contains(SrcInst)) { + assert(WidenMap.has(SrcInst) && "Source operand is unavailable"); + // The parameter is a vector value from earlier. + Params.push_back(WidenMap.get(SrcInst)); + } else { + // The parameter is a scalar from outside the loop. Maybe even a constant. + VectorParts Scalars; + Scalars.append(1, SrcOp); + Params.push_back(Scalars); + } + } + + assert(Params.size() == Instr->getNumOperands() && + "Invalid number of operands"); + + // Does this instruction return a value ? + bool IsVoidRetTy = Instr->getType()->isVoidTy(); + + Value *UndefVec = + IsVoidRetTy ? nullptr + : UndefValue::get(VectorType::get(Instr->getType(), VF, + isScalable() ? 0 : 1)); + + Instruction *InsertPt = &*Builder.GetInsertPoint(); + BasicBlock *IfBlock = Builder.GetInsertBlock(); + + VectorParts Cond; + Loop *VectorLp = LI->getLoopFor(IfBlock); + + Type *IdxType = Legal->getWidestInductionType(); + VectorParts &VecResults = WidenMap.splat(Instr, UndefVec); + + BasicBlock *CurrentBlock = IfBlock; + // For each unroll 'part': + unsigned Part = 0; + { + // Generate a scalar sub-loop. 
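+    // The sub-loop runs scalar.idx over 0 .. runtime-VF-1: each iteration
+    // extracts the operand lanes, executes the cloned scalar instruction and
+    // inserts the result into the vector carried by the 'loopvec' phi, which
+    // becomes the widened value once the sub-loop exits.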
+ BasicBlock *ScalarBlock = + CurrentBlock->splitBasicBlock(InsertPt, "subloop"); + LoopVectorBody.push_back(ScalarBlock); + VectorLp->addBasicBlockToLoop(ScalarBlock, *LI); + Builder.SetInsertPoint(ScalarBlock); + BasicBlock *ResumeBlock = + ScalarBlock->splitBasicBlock(&ScalarBlock->front(), "subloop.resume"); + LoopVectorBody.push_back(ResumeBlock); + VectorLp->addBasicBlockToLoop(ResumeBlock, *LI); + // Remove newly created uncond br. + ScalarBlock->getTerminator()->eraseFromParent(); + + PHINode *ScalarIdx = Builder.CreatePHI(IdxType, 2); + ScalarIdx->setName("scalar.idx"); + ScalarIdx->addIncoming(ConstantInt::get(IdxType, 0), CurrentBlock); + + // Create a new entry in the WidenMap and initialize it to Undef or Null. + // The undef is wrapped in a phi since we need to insert into it on every + // iteration of the subloop. + PHINode *UndefPN = Builder.CreatePHI(UndefVec->getType(), 2); + UndefPN->setName("loopvec"); + UndefPN->addIncoming(UndefVec, CurrentBlock); + + Instruction *ClonedInst = Instr->clone(); + if (!IsVoidRetTy) + ClonedInst->setName(Instr->getName() + ".cloned"); + + // For each operand in the original instruction: + for (unsigned ParamIdx = 0; ParamIdx < Params.size(); ParamIdx++) { + Value *Opnd = Params[ParamIdx][Part]; + Type *OpndVecType = dyn_cast(Opnd->getType()); + // We have a vector operand, so extract the element. + if (OpndVecType) + Opnd = Builder.CreateExtractElement(Opnd, ScalarIdx); + // Replace the operand of the cloned instruction with the scalar. + ClonedInst->setOperand(ParamIdx, Opnd); + } + + Builder.Insert(ClonedInst); + if (!IsVoidRetTy) { + auto VecInsert = + Builder.CreateInsertElement(UndefPN, ClonedInst, ScalarIdx); + VecResults[Part] = VecInsert; + UndefPN->addIncoming(VecInsert, ScalarBlock); + } + + Value *NextScalIdx = + Builder.CreateAdd(ScalarIdx, ConstantInt::get(IdxType, 1)); + ScalarIdx->addIncoming(NextScalIdx, ScalarBlock); + Value *Cmp = + Builder.CreateICmp(ICmpInst::ICMP_ULT, NextScalIdx, + getRuntimeVF(IdxType)); + + Builder.CreateCondBr(Cmp, ScalarBlock, ResumeBlock); + CurrentBlock = ResumeBlock; + + // Record the current dominator information for the vector body blocks. + // TODO: If we need to support if-cvtd blocks then this will need to be + // adapted. + VecBodyDomEdges.push_back(DomEdge(VecBodyPostDom, ScalarBlock)); + VecBodyDomEdges.push_back(DomEdge(ScalarBlock, ResumeBlock)); + VecBodyPostDom = ResumeBlock; + } + Builder.SetInsertPoint(InsertPt); +} + +void SearchLoopVectorizer::scalarizeInstruction(Instruction *Instr, + ValueMap &WidenMap, + bool IfPredicateStore) { + assert(!Instr->getType()->isAggregateType() && "Can't handle vectors"); + if (isScalable()) { + assert(!IfPredicateStore && + "Can't handle WA predicating store scalarization"); + scalarizeInstructionWithSubloop(Instr, WidenMap, IfPredicateStore); + return; + } + + // Holds vector parameters or scalars, in case of uniform vals. + SmallVector Params; + + setDebugLocFromInst(Builder, Instr); + + // Find all of the vectorized parameters. + for (unsigned op = 0, e = Instr->getNumOperands(); op != e; ++op) { + Value *SrcOp = Instr->getOperand(op); + + // If we are accessing the old induction variable, use the new one. + if (SrcOp == OldInduction) { + Params.push_back(getVectorValue(SrcOp, WidenMap)); + continue; + } + + // Try using previously calculated values. + Instruction *SrcInst = dyn_cast(SrcOp); + + // If the src is an instruction that appeared earlier in the basic block, + // then it should already be vectorized. 
+ if (SrcInst && OrigLoop->contains(SrcInst)) { + assert(WidenMap.has(SrcInst) && "Source operand is unavailable"); + // The parameter is a vector value from earlier. + Params.push_back(WidenMap.get(SrcInst)); + } else { + // The parameter is a scalar from outside the loop. Maybe even a constant. + VectorParts Scalars; + Scalars.append(1, SrcOp); + Params.push_back(Scalars); + } + } + + assert(Params.size() == Instr->getNumOperands() && + "Invalid number of operands"); + + // Does this instruction return a value ? + bool IsVoidRetTy = Instr->getType()->isVoidTy(); + + Value *UndefVec = IsVoidRetTy ? nullptr : + UndefValue::get(VectorType::get(Instr->getType(), VF)); + // Create a new entry in the WidenMap and initialize it to Undef or Null. + VectorParts &VecResults = WidenMap.splat(Instr, UndefVec); + + VectorParts Cond; + if (IfPredicateStore) { + assert(Instr->getParent()->getSinglePredecessor() && + "Only support single predecessor blocks"); + Cond = createEdgeMask(Instr->getParent()->getSinglePredecessor(), + Instr->getParent(), WidenMap); + } + + // For each vector unroll 'part': + unsigned Part = 0; + { + // For each scalar that we create: + for (unsigned Width = 0; Width < VF; ++Width) { + + // Start if-block. + Value *Cmp = nullptr; + if (IfPredicateStore) { + Cmp = Builder.CreateExtractElement(Cond[Part], Builder.getInt32(Width)); + Cmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, Cmp, ConstantInt::get(Cmp->getType(), 1)); + } + + Instruction *Cloned = Instr->clone(); + if (!IsVoidRetTy) + Cloned->setName(Instr->getName() + ".cloned"); + // Replace the operands of the cloned instructions with extracted scalars. + for (unsigned op = 0, e = Instr->getNumOperands(); op != e; ++op) { + Value *Op = Params[op][Part]; + // Param is a vector. Need to extract the right lane. + if (Op->getType()->isVectorTy()) + Op = Builder.CreateExtractElement(Op, Builder.getInt32(Width)); + Cloned->setOperand(op, Op); + } + + // Place the cloned scalar in the new loop. + Builder.Insert(Cloned); + + // If the original scalar returns a value we need to place it in a vector + // so that future users will be able to use it. + if (!IsVoidRetTy) + VecResults[Part] = Builder.CreateInsertElement(VecResults[Part], Cloned, + Builder.getInt32(Width)); + // End if-block. + if (IfPredicateStore) + PredicatedStores.push_back(std::make_pair(cast(Cloned), + Cmp)); + } + } +} + +std::pair +SearchLoopVectorizer::addStrideCheck(Instruction *Loc) { + Instruction *tnullptr = nullptr; + if (!Legal->mustCheckStrides()) + return std::pair(tnullptr, tnullptr); + + IRBuilder<> ChkBuilder(Loc); + + // Emit checks. + Value *Check = nullptr; + Instruction *FirstInst = nullptr; + for (SmallPtrSet::iterator SI = Legal->strides_begin(), + SE = Legal->strides_end(); + SI != SE; ++SI) { + Value *Ptr = stripIntegerCast(*SI); + Value *C = ChkBuilder.CreateICmpNE(Ptr, ConstantInt::get(Ptr->getType(), 1), + "stride.chk"); + // Store the first instruction we create. + FirstInst = getFirstInst(FirstInst, C, Loc); + if (Check) + Check = ChkBuilder.CreateOr(Check, C); + else + Check = C; + } + + // We have to do this trickery because the IRBuilder might fold the check to a + // constant expression in which case there is no Instruction anchored in a + // the block. 
+ LLVMContext &Ctx = Loc->getContext(); + Instruction *TheCheck = + BinaryOperator::CreateAnd(Check, ConstantInt::getTrue(Ctx)); + ChkBuilder.Insert(TheCheck, "stride.not.one"); + FirstInst = getFirstInst(FirstInst, TheCheck, Loc); + + return std::make_pair(FirstInst, TheCheck); +} + +PHINode *SearchLoopVectorizer::createInductionVariable(Loop *L, Value *Start, + Value *End, Value *Step, + Instruction *DL) { + BasicBlock *Header = L->getHeader(); + BasicBlock *Latch = L->getLoopLatch(); + // As we're just creating this loop, it's possible no latch exists + // yet. If so, use the header as this will be a single block loop. + if (!Latch) + Latch = Header; + + IRBuilder<> Builder(&*Header->getFirstInsertionPt()); + setDebugLocFromInst(Builder, getDebugLocFromInstOrOperands(OldInduction)); + + auto *PredTy = VectorType::get(Builder.getInt1Ty(), VF, Scalable); + auto *AllActive = ConstantInt::getTrue(PredTy); + + auto *Induction = Builder.CreatePHI(Start->getType(), 2, "index"); + Predicate = Builder.CreatePHI(PredTy, 2, "predicate"); + + Builder.SetInsertPoint(Latch->getTerminator()); + + // Create i+1 and fill the PHINode. + Value *Next = Builder.CreateAdd(Induction, Step, "index.next"); + Induction->addIncoming(Start, L->getLoopPreheader()); + Induction->addIncoming(Next, Latch); + + // Even though all lanes are active some code paths require a predicate. + dyn_cast(Predicate)->addIncoming(AllActive, L->getLoopPreheader()); + dyn_cast(Predicate)->addIncoming(AllActive, Latch); + + // Create the compare. + Value *ICmp = Builder.CreateICmpEQ(Next, End); + Builder.CreateCondBr(ICmp, L->getExitBlock(), Header); + + // Now we have two terminators. Remove the old one from the block. + Latch->getTerminator()->eraseFromParent(); + + return Induction; +} + +Value *SearchLoopVectorizer::getOrCreateTripCount(Loop *L) { + if (TripCount) + return TripCount; + + IRBuilder<> Builder(L->getLoopPreheader()->getTerminator()); + // Find the loop boundaries. + ScalarEvolution *SE = PSE.getSE(); + const SCEV *BackedgeTakenCount = SE->getBackedgeTakenCount(OrigLoop); + assert(BackedgeTakenCount != SE->getCouldNotCompute() && "Invalid loop count"); + + Type *IdxTy = Legal->getWidestInductionType(); + + // The exit count might have the type of i64 while the phi is i32. This can + // happen if we have an induction variable that is sign extended before the + // compare. The only way that we get a backedge taken count is that the + // induction variable was signed and as such will not overflow. In such a case + // truncation is legal. + if (BackedgeTakenCount->getType()->getPrimitiveSizeInBits() > + IdxTy->getPrimitiveSizeInBits()) + BackedgeTakenCount = SE->getTruncateOrNoop(BackedgeTakenCount, IdxTy); + BackedgeTakenCount = SE->getNoopOrZeroExtend(BackedgeTakenCount, IdxTy); + + // Get the total trip count from the count by adding 1. + const SCEV *ExitCount = SE->getAddExpr( + BackedgeTakenCount, SE->getOne(BackedgeTakenCount->getType())); + + const DataLayout &DL = L->getHeader()->getModule()->getDataLayout(); + + // Expand the trip count and place the new instructions in the preheader. + // Notice that the pre-header does not change, only the loop body. + SCEVExpander Exp(*SE, DL, "induction"); + + // Count holds the overall loop count (N). 
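+  // That is, TripCount = BackedgeTakenCount + 1, expanded at the end of the
+  // loop preheader.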
+ TripCount = Exp.expandCodeFor(ExitCount, ExitCount->getType(),
+ L->getLoopPreheader()->getTerminator());
+
+ if (TripCount->getType()->isPointerTy())
+ TripCount =
+ CastInst::CreatePointerCast(TripCount, IdxTy,
+ "exitcount.ptrcnt.to.int",
+ L->getLoopPreheader()->getTerminator());
+
+ return TripCount;
+}
+
+Value *SearchLoopVectorizer::getOrCreateVectorTripCount(Loop *L) {
+ if (VectorTripCount)
+ return VectorTripCount;
+
+ Value *TC = getOrCreateTripCount(L);
+ IRBuilder<> Builder(L->getLoopPreheader()->getTerminator());
+
+ // Now we need to generate the expression for N - (N % VF), which is
+ // the part that the vectorized body will execute.
+ // The loop step is equal to the vectorization factor (num of SIMD elements)
+ // times the unroll factor (num of SIMD instructions).
+ Value *R = Builder.CreateURem(TC, InductionStep, "n.mod.vf");
+ VectorTripCount = Builder.CreateSub(TC, R, "n.vec");
+
+ return VectorTripCount;
+}
+
+void SearchLoopVectorizer::emitMinimumIterationCountCheck(Loop *L,
+ BasicBlock *Bypass) {
+ Value *Count = getOrCreateTripCount(L);
+ BasicBlock *BB = L->getLoopPreheader();
+ IRBuilder<> Builder(BB->getTerminator());
+
+ // Check that the trip count (the backedge-taken count plus one) is at least
+ // the vector step. This also catches the case where adding one overflowed
+ // the trip count to zero.
+ Value *CheckMinIters = Builder.CreateICmpULT(Count, InductionStep,
+ "min.iters.check");
+
+ BasicBlock *NewBB = BB->splitBasicBlock(BB->getTerminator(),
+ "min.iters.checked");
+ if (L->getParentLoop())
+ L->getParentLoop()->addBasicBlockToLoop(NewBB, *LI);
+ ReplaceInstWithInst(BB->getTerminator(),
+ BranchInst::Create(Bypass, NewBB, CheckMinIters));
+ LoopBypassBlocks.push_back(BB);
+}
+
+void SearchLoopVectorizer::emitVectorLoopEnteredCheck(Loop *L,
+ BasicBlock *Bypass) {
+ Value *TC = getOrCreateVectorTripCount(L);
+ BasicBlock *BB = L->getLoopPreheader();
+ IRBuilder<> Builder(BB->getTerminator());
+
+ // Now, compare the new count to zero. If it is zero skip the vector loop and
+ // jump to the scalar loop.
+ Value *Cmp = Builder.CreateICmpEQ(TC, Constant::getNullValue(TC->getType()),
+ "cmp.zero");
+
+ // Split off the new vector preheader and branch to the scalar loop (Bypass)
+ // when the vector trip count is zero.
+ BasicBlock *NewBB = BB->splitBasicBlock(BB->getTerminator(),
+ "vector.ph");
+ if (L->getParentLoop())
+ L->getParentLoop()->addBasicBlockToLoop(NewBB, *LI);
+ ReplaceInstWithInst(BB->getTerminator(),
+ BranchInst::Create(Bypass, NewBB, Cmp));
+ LoopBypassBlocks.push_back(BB);
+}
+
+void SearchLoopVectorizer::emitSCEVChecks(Loop *L, BasicBlock *Bypass) {
+ BasicBlock *BB = L->getLoopPreheader();
+
+ // Generate the code to check the SCEV assumptions that we made.
+ // We want the new basic block to start at the first instruction in a
+ // sequence of instructions that form a check.
+ SCEVExpander Exp(*PSE.getSE(), Bypass->getModule()->getDataLayout(),
+ "scev.check");
+ Value *SCEVCheck = Exp.expandCodeForPredicate(&PSE.getUnionPredicate(),
+ BB->getTerminator());
+
+ if (auto *C = dyn_cast<ConstantInt>(SCEVCheck))
+ if (C->isZero())
+ return;
+
+ // Create a new block containing the SCEV check.
+ BB->setName("vector.scevcheck"); + auto *NewBB = BB->splitBasicBlock(BB->getTerminator(), "vector.ph"); + if (L->getParentLoop()) + L->getParentLoop()->addBasicBlockToLoop(NewBB, *LI); + ReplaceInstWithInst(BB->getTerminator(), + BranchInst::Create(Bypass, NewBB, SCEVCheck)); + LoopBypassBlocks.push_back(BB); + AddedSafetyChecks = true; +} + +// TODO: Are all these checks required if we have FF/NF/predication? +void SearchLoopVectorizer::emitMemRuntimeChecks(Loop *L, + BasicBlock *Bypass) { + BasicBlock *BB = L->getLoopPreheader(); + + // Generate the code that checks in runtime if arrays overlap. We put the + // checks into a separate block to make the more common case of few elements + // faster. + Instruction *FirstCheckInst; + Instruction *MemRuntimeCheck; + std::tie(FirstCheckInst, MemRuntimeCheck) = + Legal->getLAI()->addRuntimeChecks(BB->getTerminator()); + if (!MemRuntimeCheck) + return; + + // Create a new block containing the memory check. + BB->setName("vector.memcheck"); + auto *NewBB = BB->splitBasicBlock(BB->getTerminator(), "vector.ph"); + if (L->getParentLoop()) + L->getParentLoop()->addBasicBlockToLoop(NewBB, *LI); + ReplaceInstWithInst(BB->getTerminator(), + BranchInst::Create(Bypass, NewBB, MemRuntimeCheck)); + LoopBypassBlocks.push_back(BB); + AddedSafetyChecks = true; +} + +// TODO: Pointer chase pre-vec scalar loop? +// TODO: Different 'empty' loop structures based on loop type? +// Plain EE, ptr chase, multi-backedge? Or just the +// single creator func, but calling off to appropriate +// helper funcs based on type? +void SearchLoopVectorizer::createEmptyLoopWithPredication() { + /* + In this function we generate a new loop. The new loop will contain + the vectorized instructions while the old loop will continue to run the + scalar remainder. + + v + [ ] <-- Back-edge taken count overflow check. + / \ + | [ ] <-- vector loop bypass (may consist of multiple blocks). + | / \ + | [ ] \ <-- vector pre header. + | | \ | + | />[ ] || + | |_[ ] || <-- vector loop. + | | || + | [ ][ ] | <-- scalable reduction loop. + | |_[ ] | + | | | + | [ ] | <-- return point from scalable reduction loop. + | | | + |<---/ | + | [ ] <-- scalar preheader. + | | + | [ ]<\ + | [ ]_| <-- old scalar loop to handle remainder. + | | + |<-------/ + v + [ ] <-- exit block. + ... + */ + + BasicBlock *OldBasicBlock = OrigLoop->getHeader(); + BasicBlock *BypassBlock = OrigLoop->getLoopPreheader(); + + ScalarEvolution *SE = PSE.getSE(); + const SCEV *ExitCount = nullptr; + auto &Exit = (*(Legal->getLoopExits()))[0]; + // If the latch condition is counted, we can do some special stuff + if (Exit.ExitingBlock == OrigLoop->getLoopLatch() && Exit.Kind == EK_Counted) + ExitCount = SE->getExitCount(OrigLoop, Exit.ExitingBlock); + BasicBlock *ExitBlock = Exit.ExitBlock; + + assert(BypassBlock && "Invalid loop structure"); + assert(ExitBlock && "Must have an exit block"); + + // Some loops have a single integer induction variable, while other loops + // don't. One example is c++ iterators that often have multiple pointer + // induction variables. In the code below we also support a case where we + // don't have a single induction variable. + // TODO: Move types to SearchLoopVectorizer? + OldInduction = Legal->getInduction(); + Type *IdxTy = Legal->getWidestInductionType(); + Type *PredTy = Builder.getInt1Ty(); + Type *PredVecTy = VectorType::get(PredTy, VF, Scalable); + + // Find the loop boundaries. 
+ assert(ExitCount != SE->getCouldNotCompute() && "Invalid loop count"); + + const SCEV *BackedgeTakeCount = nullptr; + if (ExitCount && ExitCount != SE->getCouldNotCompute()) { + // The exit count might have the type of i64 while the phi is i32. This can + // happen if we have an induction variable that is sign extended before the + // compare. The only way that we get a backedge taken count is that the + // induction variable was signed and as such will not overflow. In such a case + // truncation is legal. + if (ExitCount->getType()->getPrimitiveSizeInBits() > + IdxTy->getPrimitiveSizeInBits()) + ExitCount = SE->getTruncateOrNoop(ExitCount, IdxTy); + + BackedgeTakeCount = SE->getNoopOrZeroExtend(ExitCount, IdxTy); + // Get the total trip count from the count by adding 1. + ExitCount = SE->getAddExpr(BackedgeTakeCount, + SE->getConstant(BackedgeTakeCount->getType(), 1)); + } + + const DataLayout &DL = OldBasicBlock->getModule()->getDataLayout(); + + // Extract the nsw/nuw flags from the induction variable expression if + // possible; loops without an integer induction variable won't be able + // to set these flags, but if the loop predicate is constructed from + // constant values we will usually be able to optimize later anyway. + if (OldInduction) { + const SCEVNAryExpr *IVSCEV = dyn_cast(SE->getSCEV(OldInduction)); + IndVarNoWrap = IVSCEV && (IVSCEV->getNoWrapFlags() & SCEV::FlagNW); + } + + // Split the single block loop into the two loop structure described above. + LoopVectorPreHeader = + BypassBlock->splitBasicBlock(BypassBlock->getTerminator(), "vector.ph"); + BasicBlock *VecBody = + LoopVectorPreHeader->splitBasicBlock(LoopVectorPreHeader->getTerminator(), + "vector.body.unpred"); + BasicBlock *VectorTail = + VecBody->splitBasicBlock(VecBody->getTerminator(), "vector.tail"); + BasicBlock *ScalarPH = + VectorTail->splitBasicBlock(VectorTail->getTerminator(), "scalar.ph"); + BasicBlock *ReductionLoopLocal = + VectorTail->splitBasicBlock(VectorTail->getTerminator(), + "reduction.loop"); + BasicBlock *ReductionLoopRetLocal = + ReductionLoopLocal->splitBasicBlock(ReductionLoopLocal->getTerminator(), + "reduction.loop.ret"); + + // Create and register the new vector loop. + SLVLoop = LI->AllocateLoop(); + Loop *ParentLoop = OrigLoop->getParentLoop(); + + // Insert the new loop into the loop nest and register the new basic blocks + // before calling any utilities such as SCEV that require valid LoopInfo. + bool HasReductionLoop = Legal->hasScalarizedReduction() && isScalable(); + + if (ParentLoop) { + ParentLoop->addChildLoop(SLVLoop); + ParentLoop->addBasicBlockToLoop(ScalarPH, *LI); + ParentLoop->addBasicBlockToLoop(LoopVectorPreHeader, *LI); + ParentLoop->addBasicBlockToLoop(VectorTail, *LI); + + if(!HasReductionLoop) + ParentLoop->addBasicBlockToLoop(ReductionLoopLocal, *LI); + + ParentLoop->addBasicBlockToLoop(ReductionLoopRetLocal, *LI); + } else { + LI->addTopLevelLoop(SLVLoop); + } + SLVLoop->addBasicBlockToLoop(VecBody, *LI); + + // *************************************************************************** + // Start of BypassBlock (NOTE: may contain multiple blocks) + // *************************************************************************** + + // Expand the trip count and place the new instructions in the preheader. + // Notice that the pre-header does not change, only the loop body. + SCEVExpander Exp(*SE, DL, "induction"); + + // We need to test whether the backedge-taken count is uint##_max. 
Adding one + // to it will cause overflow and an incorrect loop trip count in the vector + // body. In case of overflow we want to directly jump to the scalar remainder + // loop. + Instruction *CheckBCOverflow = nullptr; + if (BackedgeTakeCount) { + Value *BackedgeCount = + Exp.expandCodeFor(BackedgeTakeCount, BackedgeTakeCount->getType(), + BypassBlock->getTerminator()); + if (BackedgeCount->getType()->isPointerTy()) + BackedgeCount = CastInst::CreatePointerCast(BackedgeCount, IdxTy, + "backedge.ptrcnt.to.int", + BypassBlock->getTerminator()); + CheckBCOverflow = + CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ, BackedgeCount, + Constant::getAllOnesValue(BackedgeCount->getType()), + "backedge.overflow", BypassBlock->getTerminator()); + } + + // The loop index does not have to start at Zero. Find the original start + // value from the induction PHI node. If we don't have an induction variable + // then we know that it starts at zero. + Builder.SetInsertPoint(BypassBlock->getTerminator()); + Value *StartIdx = ExtendedIdx = OldInduction ? + Builder.CreateZExt(OldInduction->getIncomingValueForBlock(ScalarPH), + IdxTy) : + ConstantInt::get(IdxTy, 0); + + // We need an instruction to anchor the first bypass check on. + Instruction *BypassAnchor = BinaryOperator::CreateAdd( + StartIdx, ConstantInt::get(IdxTy, 0), "bypass.anchor", + BypassBlock->getTerminator()); + + // Count holds the overall loop count (N). + Value *Count = nullptr; + if (ExitCount) { + Count = Exp.expandCodeFor(ExitCount, ExitCount->getType(), + BypassBlock->getTerminator()); + } + + LoopBypassBlocks.push_back(BypassBlock); + + // Update LoopInfo analysis to include the loop used by reduction variables. + if (HasReductionLoop) { + Loop *ReductionLp = LI->AllocateLoop();; + + if (ParentLoop) + ParentLoop->addChildLoop(ReductionLp); + else + LI->addTopLevelLoop(ReductionLp); + + ReductionLp->addBasicBlockToLoop(ReductionLoopLocal, *LI); + } + + // This is the IR builder that we use to add all of the logic for bypassing + // the new vector loop. + IRBuilder<> BypassBuilder(BypassBlock->getTerminator()); + setDebugLocFromInst(BypassBuilder, + getDebugLocFromInstOrOperands(OldInduction)); + + // The loop step is equal to the vectorization factor (num of SIMD elements) + // times the unroll factor (num of SIMD instructions). + // TODO: Unrolling is current disabled. + InductionStep = getRuntimeVF(IdxTy); + + // We may need to extend the index in case there is a type mismatch. + // We know that the count starts at zero and does not overflow. + IdxEnd = nullptr; + if (Count) + if (Count->getType() != IdxTy) { + // The exit count can be of pointer type. Convert it to the correct + // integer type. + if (ExitCount->getType()->isPointerTy()) + Count = BypassBuilder.CreatePointerCast(Count, IdxTy, "ptrcnt.to.int"); + else + Count = BypassBuilder.CreateZExtOrTrunc(Count, IdxTy, "cnt.cast"); + + // Add the start index to the loop count to get the new end index. + IdxEnd = BypassBuilder.CreateAdd(Count, StartIdx, "end.idx"); + } + + Value *Cmp = nullptr; + BasicBlock *LastBypassBlock = BypassBlock; + + // Generate code to check that the loops trip count that we computed by adding + // one to the backedge-taken count will not overflow. 
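+ // Illustratively (names and width are examples), the guard ends up as:
+ //   %backedge.overflow = icmp eq i64 %backedge.count, -1
+ //   br i1 %backedge.overflow, label %scalar.ph, label %overflow.checked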
+ if (CheckBCOverflow) { + auto PastOverflowCheck = std::next(BasicBlock::iterator(BypassAnchor)); + BasicBlock *CheckBlock = + LastBypassBlock->splitBasicBlock(PastOverflowCheck, "overflow.checked"); + DT->addNewBlock(CheckBlock, LastBypassBlock); + if (ParentLoop) + ParentLoop->addBasicBlockToLoop(CheckBlock, *LI); + LoopBypassBlocks.push_back(CheckBlock); + + Instruction *OldTerm = LastBypassBlock->getTerminator(); + BranchInst::Create(ScalarPH, CheckBlock, CheckBCOverflow, OldTerm); + OldTerm->eraseFromParent(); + + Cmp = CheckBCOverflow; + LastBypassBlock = CheckBlock; + } + + // Generate the code to check that the strides we assumed to be one are really + // one. We want the new basic block to start at the first instruction in a + // sequence of instructions that form a check. + Instruction *StrideCheck; + Instruction *FirstCheckInst; + std::tie(FirstCheckInst, StrideCheck) = + addStrideCheck(LastBypassBlock->getTerminator()); + if (StrideCheck) { + AddedSafetyChecks = true; + // Create a new block containing the stride check. + BasicBlock *CheckBlock = + LastBypassBlock->splitBasicBlock(FirstCheckInst, "vector.stridecheck"); + DT->addNewBlock(CheckBlock, LastBypassBlock); + if (ParentLoop) + ParentLoop->addBasicBlockToLoop(CheckBlock, *LI); + LoopBypassBlocks.push_back(CheckBlock); + + // Replace the branch into the memory check block with a conditional branch + // for the "few elements case". + Instruction *OldTerm = LastBypassBlock->getTerminator(); + if (Cmp) + BranchInst::Create(ScalarPH, CheckBlock, Cmp, OldTerm); + else + BranchInst::Create(CheckBlock, OldTerm); + OldTerm->eraseFromParent(); + + Cmp = StrideCheck; + LastBypassBlock = CheckBlock; + } + + // Generate the code that checks in runtime if arrays overlap. We put the + // checks into a separate block to make the more common case of few elements + // faster. + Instruction *MemRuntimeCheck; + std::tie(FirstCheckInst, MemRuntimeCheck) = + Legal->getLAI()->addRuntimeChecks(LastBypassBlock->getTerminator()); + if (MemRuntimeCheck) { + AddedSafetyChecks = true; + // Create a new block containing the memory check. + BasicBlock *CheckBlock = + LastBypassBlock->splitBasicBlock(FirstCheckInst, "vector.memcheck"); + DT->addNewBlock(CheckBlock, LastBypassBlock); + if (ParentLoop) + ParentLoop->addBasicBlockToLoop(CheckBlock, *LI); + LoopBypassBlocks.push_back(CheckBlock); + + // Replace the branch into the memory check block with a conditional branch + // for the "few elements case". + Instruction *OldTerm = LastBypassBlock->getTerminator(); + if (Cmp) + BranchInst::Create(ScalarPH, CheckBlock, Cmp, OldTerm); + else + BranchInst::Create(CheckBlock, OldTerm); + OldTerm->eraseFromParent(); + + Cmp = MemRuntimeCheck; + LastBypassBlock = CheckBlock; + } + + if (Cmp) { + Instruction *OldTerm = LastBypassBlock->getTerminator(); + BranchInst::Create(ScalarPH, LoopVectorPreHeader, Cmp, OldTerm); + OldTerm->eraseFromParent(); + } + + // BLARF. Move this. Needed for vectorizeGEPInstruction within an uncounted + // loop pre-vec-loadcmp check. + + //TODO: Grabbed from below. (Start of Vecbody) + // Use Builder to create the loop instructions (Phi, Br, Cmp) inside the loop. + Builder.SetInsertPoint(VecBody->getFirstNonPHI()); + setDebugLocFromInst(Builder, getDebugLocFromInstOrOperands(OldInduction)); + // Generate the induction variable. + Induction = Builder.CreatePHI(IdxTy, 2, "index"); + + // These Phis have two incoming values, but right now we only add the + // one coming from the preheader. 
The other (from the loop latch block) + // will be added in 'vectorizeExits', after everything else has been + // vectorized. This allows predicates from first-faulting loads or other + // instructions to be added in before finalizing the phi. + cast(Induction)->addIncoming(StartIdx, LoopVectorPreHeader); + InductionStartIdx = StartIdx; + // End of reshuffled vecbody + // TODO: Use Builder insert guard RAII within if scope? + + // *************************************************************************** + // End of BypassBlock (NOTE: may contain multiple blocks) + // *************************************************************************** + + // *************************************************************************** + // Start of vector.ph + // *************************************************************************** + + // Set builder insertion points + Builder.SetInsertPoint(LoopVectorPreHeader->getTerminator()); + setDebugLocFromInst(Builder, getDebugLocFromInstOrOperands(OldInduction)); + + // Set to false, then check. + HasSpeculativeLoads = false; + + // Collect all nodes (instructions) used in each of the conditions + ConditionNodes.clear(); + for (auto &Exit : Legal->exits()) + ConditionNodes.insert(Exit.Nodes.begin(), Exit.Nodes.end()); + + // If a sub-expression of a condition is reused in the loop + // (after the early exit condition), then we need to create a + // PHI node to get the value either from the preheader or from + // the vector loop body. The same also holds for the vector tail. + for (auto &Exit : Legal->exits()) { + // For each sub expression of condition + for (Value *V : Exit.Nodes) { + // It only makes sense for instructions + if (!isa(V)) + continue; + + // TODO: Better way of doing this? + if (isa(V) && !OrigLoop->isLoopInvariant(V)) + HasSpeculativeLoads = true; + + // Induction + Reduction vars are exempt + if (auto *VP = dyn_cast(V)) { + if (Legal->getInductionVars()->count(VP)) + continue; + if (Legal->getReductionVars()->count(VP)) + continue; + } + + // Induction var updates are also exempt + bool IsIndUpdate = false; + for (auto *User : V->users()) { + if (auto *VP = dyn_cast(User)) { + if (Legal->getInductionVars()->count(VP)) { + IsIndUpdate = true; + break; + } + } + } + + if (IsIndUpdate) + continue; + + // Test if subexpression is used in the loop. + for (auto *User : V->users()) { + if (ConditionNodes.count(User)) + continue; + // There is always one branch that uses V + if (isa(User)) + continue; + // If this user is in the loop, but not part of the + // condition, we need to create the PHI for this Value. + if (auto *UI = dyn_cast(User)) + if (OrigLoop->contains(UI)) + MakeTheseIntoPHIs.insert(V); + } + } + } + + // If we have speculative loads, we need to determine whether a psuedo-fault + // occcurred and potentially continue running in the scalar loop. + if (HasSpeculativeLoads) + TailSpeculativeLanes = PHINode::Create(PredVecTy, 2, + "active.speculative.lanes", + VectorTail->getFirstNonPHI()); + + // Calculate early exit conditions + SmallVector Conds; + Value *PTrue = ConstantInt::getTrue(PredVecTy); + PreHeaderOutPred = insertEarlyBreaks(PHMap, StartIdx, Conds, PTrue); + + // Check if it needs to skip over body to vector tail + Value *Done = + getAllTrueReduction(Builder, PreHeaderOutPred, "fullvectorbody"); + Builder.CreateCondBr(Done, VecBody, VectorTail); + + // Cleanup the unconditional branch to vector body. 
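+ // Once the stale terminator is removed, the preheader ends in (illustrative):
+ //   br i1 %fullvectorbody, label %vector.body.unpred, label %vector.tail
+ // i.e. if any lane of the first vector iteration would take an early exit,
+ // we skip the unpredicated body and go straight to the predicated tail.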
+ Instruction *Term = LoopVectorPreHeader->getTerminator(); + Term->eraseFromParent(); + + // *************************************************************************** + // End of vector.ph + // *************************************************************************** + + // *************************************************************************** + // Start of vector.tail (part 1) + // *************************************************************************** + Builder.SetInsertPoint(VectorTail->getFirstNonPHI()); + + // For each exit condition, make a PHI node that either gets the + // calculated value from either the loop, or from the preheader + for (auto *Cond : Conds) { + // Create a PHI node and fill in preheader (other values we do in + // vectorizeExits(), similar to Induction PHI). + Value *VCond = PHMap.get(Cond)[0]; + auto *VCondPHI = Builder.CreatePHI(VCond->getType(), 2); + VCondPHI->addIncoming(VCond, LoopVectorPreHeader); + + // Store vectorized (condition) PHIs in the vector-tail map + VectorParts &VectorParts = VTailWMap.get(Cond); + VectorParts[0] = VCondPHI; + } + + // *************************************************************************** + // End of vector.tail + // *************************************************************************** + + // *************************************************************************** + // Start of vector.body + // *************************************************************************** + Builder.SetInsertPoint(VecBody->getFirstNonPHI()); + setDebugLocFromInst(Builder, getDebugLocFromInstOrOperands(OldInduction)); + + // We don't yet have a condition for the branch, since it may depend on + // instructions within the loop (beyond just the trip count, if any). + // As above, this will be added in 'vectorizeExits'. + LatchBranch = Builder.CreateCondBr(UndefValue::get(Builder.getInt1Ty()), + VecBody, VectorTail); + // Now we have two terminators. Remove the old one from the block. + VecBody->getTerminator()->eraseFromParent(); + + // *************************************************************************** + // End of vector.body + // *************************************************************************** + + // *************************************************************************** + // Start of reduction.loop.ret + // *************************************************************************** + + // The vector body processes all elements so after the reduction we are done. + Instruction *OldTerm = ReductionLoopRetLocal->getTerminator(); + // TODO: When HasSpeculativeLoads is false, we should only plant a branch to + // the exitblock since we won't need the scalar loop to recover from a + // pseudofault. However, until our DomTree is fixed up properly, we resort + // to always planting this conditional branch and just hardcoding a 'true' + // for the condition if we don't need the scalar loop. + BranchInst::Create(ExitBlock, ScalarPH, + UndefValue::get(Builder.getInt1Ty()), OldTerm); + OldTerm->eraseFromParent(); + + // *************************************************************************** + // End of reduction.loop.ret + // *************************************************************************** + + // Get ready to start creating new instructions into the vectorized body. + Builder.SetInsertPoint(&*VecBody->getFirstInsertionPt()); + + // Save the state. 
+ LoopScalarPreHeader = ScalarPH; + LoopVectorTail = VectorTail; + LoopExitBlock = ExitBlock; + LoopVectorBody.push_back(VecBody); + VecBodyPostDom = VecBody; + LoopScalarBody = OldBasicBlock; + ReductionLoop = ReductionLoopLocal; + ReductionLoopRet = ReductionLoopRetLocal; + + SLVLoopVectorizeHints Hints(SLVLoop, true, *ORE); + // TODO: Make sure normal vectorizer doesn't nuke us by setting this for unrolled? + Hints.setAlreadyVectorized(); + + // Predicate for vector body starts out with a PTRUE + Predicate = ConstantInt::getTrue(PredVecTy); + + // Set insertion point for unpredicated vector body + Builder.SetInsertPoint(&*VecBody->getFirstInsertionPt()); +} + +///\brief Perform cse of induction variable instructions. +void SearchLoopVectorizer::CSE(SmallVector &BBs, + SmallSet &PredBlocks) { + // Perform simple cse. + SmallDenseMap CSEMap; + for (unsigned i = 0, e = BBs.size(); i != e; ++i) { + BasicBlock *BB = BBs[i]; + for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E;) { + Instruction *In = &*I++; + + if (!CSEDenseMapInfo::canHandle(In)) + continue; + + // Check if we can replace this instruction with any of the + // visited instructions. + if (Instruction *V = CSEMap.lookup(In)) { + In->replaceAllUsesWith(V); + In->eraseFromParent(); + continue; + } + // Ignore instructions in conditional blocks. We create "if (pred) a[i] = + // ...;" blocks for predicated stores. Every second block is a predicated + // block. + if (PredBlocks.count(BBs[i])) + continue; + + CSEMap[In] = In; + } + } +} + + +void SearchLoopVectorizer::truncateToMinimalBitwidths(ValueMap &WidenMap) { + // For every instruction `I` in MinBWs, truncate the operands, create a + // truncated version of `I` and reextend its result. InstCombine runs + // later and will remove any ext/trunc pairs. + // + for (auto &KV : MinBWs) { + VectorParts &Parts = WidenMap.get(KV.first); + for (Value *&I : Parts) { + if (I->use_empty()) + continue; + + // Not every value in the widenmap is an instruction + if (!isa(I)) + continue; + + auto *OriginalTy = dyn_cast(I->getType()); + Type *ScalarTruncatedTy = IntegerType::get(OriginalTy->getContext(), + KV.second); + Type *TruncatedTy = VectorType::get(ScalarTruncatedTy, + OriginalTy->getElementCount()); + if (TruncatedTy == OriginalTy) + continue; + + IRBuilder<> B(cast(I)); + auto ShrinkOperand = [&](Value *V) -> Value* { + if (auto *ZI = dyn_cast(V)) + if (ZI->getSrcTy() == TruncatedTy) + return ZI->getOperand(0); + return B.CreateZExtOrTrunc(V, TruncatedTy); + }; + + // The actual instruction modification depends on the instruction type, + // unfortunately. 
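+ // Illustrative example: an i32 vector add known to need only 8 bits becomes,
+ // roughly,
+ //   %a8 = trunc <N x i32> %a to <N x i8>
+ //   %b8 = trunc <N x i32> %b to <N x i8>
+ //   %s8 = add <N x i8> %a8, %b8
+ //   %s  = zext <N x i8> %s8 to <N x i32>
+ // and InstCombine later removes any redundant trunc/ext pairs.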
+ Value *NewI = nullptr; + if (BinaryOperator *BO = dyn_cast(I)) { + NewI = B.CreateBinOp(BO->getOpcode(), + ShrinkOperand(BO->getOperand(0)), + ShrinkOperand(BO->getOperand(1))); + cast(NewI)->copyIRFlags(I); + } else if (ICmpInst *CI = dyn_cast(I)) { + NewI = B.CreateICmp(CI->getPredicate(), + ShrinkOperand(CI->getOperand(0)), + ShrinkOperand(CI->getOperand(1))); + } else if (SelectInst *SI = dyn_cast(I)) { + NewI = B.CreateSelect(SI->getCondition(), + ShrinkOperand(SI->getTrueValue()), + ShrinkOperand(SI->getFalseValue())); + } else if (CastInst *CI = dyn_cast(I)) { + switch (CI->getOpcode()) { + default: llvm_unreachable("Unhandled cast!"); + case Instruction::Trunc: + NewI = ShrinkOperand(CI->getOperand(0)); + break; + case Instruction::SExt: + NewI = B.CreateSExtOrTrunc(CI->getOperand(0), + smallestIntegerVectorType(OriginalTy, + TruncatedTy)); + break; + case Instruction::ZExt: + NewI = B.CreateZExtOrTrunc(CI->getOperand(0), + smallestIntegerVectorType(OriginalTy, + TruncatedTy)); + break; + } + } else if (ShuffleVectorInst *SI = dyn_cast(I)) { + auto VTy0 = cast(SI->getOperand(0)->getType()); + auto Elements0 = VTy0->getElementCount(); + auto *O0 = + B.CreateZExtOrTrunc(SI->getOperand(0), + VectorType::get(ScalarTruncatedTy, Elements0)); + auto VTy1 = cast(SI->getOperand(1)->getType()); + auto Elements1 = VTy1->getElementCount(); + auto *O1 = + B.CreateZExtOrTrunc(SI->getOperand(1), + VectorType::get(ScalarTruncatedTy, Elements1)); + + NewI = B.CreateShuffleVector(O0, O1, SI->getMask()); + } else if (isa(I) || isa(I)) { + // Don't do anything with the operands, just extend the result. + continue; + } else { + llvm_unreachable("Unhandled instruction type!"); + } + + // Lastly, extend the result. + NewI->takeName(cast(I)); + Value *Res = B.CreateZExtOrTrunc(NewI, OriginalTy); + I->replaceAllUsesWith(Res); + cast(I)->eraseFromParent(); + I = Res; + } + } + + // We'll have created a bunch of ZExts that are now parentless. Clean up. + for (auto &KV : MinBWs) { + VectorParts &Parts = WidenMap.get(KV.first); + for (Value *&I : Parts) { + ZExtInst *Inst = dyn_cast(I); + if (Inst && Inst->use_empty()) { + Value *NewI = Inst->getOperand(0); + Inst->eraseFromParent(); + I = NewI; + } + } + } +} + + +PHINode *SearchLoopVectorizer::createTailPhiFromPhi(PHINode *PN, + const Twine &Name) { + // Get incoming values from PH and Vector Body + BasicBlock *VecBody = LoopVectorBody.back(); + Value *ValPH = PN->getIncomingValueForBlock(LoopVectorPreHeader); + Value *ValVB = PN->getIncomingValueForBlock(VecBody); + + // Create Tail PHI + return createTailPhiFromValues(ValPH, ValVB, Name); +} + +PHINode *SearchLoopVectorizer::createTailPhiFromValues( + Value *ValPH, Value *ValVB, const Twine &Name) { + BasicBlock *VecBody = LoopVectorBody.back(); + + // Store the Insertion Point + auto IP = Builder.saveIP(); + Builder.SetInsertPoint(LoopVectorTail->getFirstNonPHI()); + + // Create the new PHI in vector tail + auto *Res = Builder.CreatePHI(ValPH->getType(), 2, Name); + Res->addIncoming(ValPH, LoopVectorPreHeader); + Res->addIncoming(ValVB, VecBody); + + // Restore Insertion Point + Builder.restoreIP(IP); + + return Res; +} + +// Returns whether V is an Induction PHI or an Induction PHI update value. 
+bool SearchLoopVectorizer::isIndvarPHIOrUpdate(Value *V, + InductionDescriptor &II, bool &IsUpdate) { + if (!isa(V)) + return false; + + if (auto *PHI = dyn_cast(V)) { + IsUpdate = false; + II = Legal->getInductionVars()->lookup(PHI); + return II.getKind() != InductionDescriptor::IK_NoInduction; + } + + // Check if this is an Indvar update by looking if any of its users + // is an induction PHI and the result comes from this loop. + auto *VI = cast(V); + for (Value *User : VI->users()) { + auto *PHI = dyn_cast(User); + if (!PHI || !OrigLoop->contains(PHI)) + continue; + + II = Legal->getInductionVars()->lookup(PHI); + if (II.getKind() != InductionDescriptor::IK_NoInduction) { + IsUpdate = true; + return true; + } + } + + return false; +} + +// Returns whether all values are the same and are either an Induction PHI +// or an Induction PHI update value. +bool SearchLoopVectorizer::allIndvarPHIsOrAllUpdates(Escapee *E, + InductionDescriptor &II, bool &IsUpdate) { + Value *PrevVal = nullptr; + for (auto *V : E->getValues()) { + if (PrevVal != nullptr && PrevVal != V) + return false; + + if (!isIndvarPHIOrUpdate(V, II, IsUpdate)) + return false; + + PrevVal = V; + } + + return true; +} + +// TODO: We previously bailed out here if we weren't an uncounted +// loop; however, twolf had a loop which should be vectorizeable +// with the regular vectorizer, only the induction variable was +// stored outside the loop. It should be possible to tell that +// the indvar must be the max value and store that. For now, +// though, we assume that any loops being vectorized by the slv +// have something interesting in them and therefore require a tail. +void SearchLoopVectorizer::insertVectorTail(LoopBlocksDFS &DFS) { + // Create a PHI for the final predicate + auto *Pg = createTailPhiFromValues(PreHeaderOutPred, VecBodyOutPred, + "vtail.pg"); + + // Create a PHI for the Induction variable (note that other induction + // variables use the scalar induction var to create a seriesvector) + Induction = createTailPhiFromPhi(cast(Induction), "vtail.ind"); + + // Create PHIs for each induction escapee's 'last' value from the + // vector body. For induction variables, it can be 'undef' from vector + // preheader, since we're guaranteed at least one iteration will be + // executed, this is only for case that this iteration contains + // no active lanes. For reductions, the value has been created + // above with a 'select' with its identity value. 
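+ // For example, the tail predicate PHI created above looks roughly like
+ //   %vtail.pg = phi <vscale x VF x i1> [ %ph.pred,   %vector.ph ],
+ //                                      [ %body.pred, %vector.body.unpred ]
+ // where the incoming values are PreHeaderOutPred and VecBodyOutPred
+ // (names are illustrative).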
+ for (auto E : Legal->getEF()->getEscapees()) { + VectorParts &Parts = VTailWMap.get(E.first); + + if (E.second->isReduction()) { + Value *LastVal = E.second->getValue(E.second->getNumValues()-1); + + // Get the PHI node that describes this reduction, + // could be obtained by finding the PHI that uses LastVal + PHINode *RdxPhi = nullptr; + PHINode *PHLastVal = dyn_cast(LastVal); + + // The last value can already be the reduction PHI, otherwise + // search the users + if (PHLastVal && Legal->getReductionVars()->count(PHLastVal)) + RdxPhi = PHLastVal; + else { + for (auto *U : LastVal->users()) { + auto *RdxPhi2 = dyn_cast(U); + if (!RdxPhi2 || !Legal->getReductionVars()->count(RdxPhi2)) + continue; + RdxPhi = RdxPhi2; + } + } + + assert (RdxPhi && "Escapee not a valid Reduction Escapee"); + + // From the (vectorised) PHI, get the Incoming edge from PreHeader + PHINode *VRdxPhi = cast(getVectorValue(RdxPhi, BodyWMap)[0]); + Value *ValPH = VRdxPhi->getIncomingValueForBlock(LoopVectorPreHeader); + + // The Last Value will be the one from the sorted list + Value *VLastVal = getVectorValue(LastVal, BodyWMap)[0]; + + // Create a PHI node in the vector tail + Value *VPHi = createTailPhiFromValues(ValPH, VLastVal, "rdx.vtail"); + VTailWMap.get(RdxPhi)[0] = Parts[0] = VPHi; + + continue; + } + + // Inductions need to start with the 'last' value from previous + // full iteration. + Value *LastVal = E.second->getValue(E.second->getNumValues()-1); + + // If LastVal is part of the condition, we can reuse it in + // this iteration. + if (ConditionNodes.count(LastVal)) { + VectorParts &Parts2 = VTailWMap.get(LastVal); + Value *A = getVectorValue(LastVal, PHMap)[0]; + Value *B = getVectorValue(LastVal, NextBodyWMap)[0]; + Parts2[0] = createTailPhiFromValues(A, B, "reuse"); + } + + LastVal = getVectorValue(LastVal, BodyWMap)[0]; + Parts[0] = UndefValue::get(LastVal->getType()); + } + + // Save IP + auto IP = Builder.saveIP(); + Builder.SetInsertPoint(LoopVectorTail->getFirstNonPHI()); + + // Find predicates for reductions and indvar extraction + GreenLanes = getExclusivePartition(Pg); + YellowLanes = getInclusivePartition(Pg); + + // We start with Yellow Lanes... + Predicate = YellowLanes; + + // TODO: I think we want to merge this with the extraction below, but not all + // indvars are escapees -- can we just iterate over all and use additional + // logic for escapees? + // TODO: Move to scalar preheader? + Value *FFIncrVal = Builder.CreateCntVPop(GreenLanes, "speculative.increment"); + // TODO: ZExtOrTrunc when we use an unsigned indvar? + FFIncrVal = Builder.CreateSExtOrTrunc(FFIncrVal, Induction->getType()); + Value *FFInductionVal = Builder.CreateAdd(Induction, FFIncrVal, + "speculative.indval"); + + // Provide the correct primary indvar value to the scalar loop in case we had + // a pseudo-fault we need to recover from. + SLVLoopVectorizationLegality::InductionList::iterator I, E; + // TODO: Convenience methods... + // for (auto &IndVar : Legal->getInductionVars()) + SLVLoopVectorizationLegality::InductionList * List = Legal->getInductionVars(); + for (I = List->begin(), E = List->end(); I != E; ++I) { + PHINode *OrigPhi = I->first; + InductionDescriptor II = I->second; + + PHINode *ResumeVal = PHINode::Create(OrigPhi->getType(), 3, + "speculative.resume.val", + LoopScalarPreHeader->getFirstNonPHI()); + + Value *EndVal; + // TODO: Share with bcresumeval logic from createEmptyLoopWithPredication? + // Replace, maybe, since we don't form this with known exit iterations. 
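+ // For illustration: if this vector iteration started at index 40 and three
+ // lanes completed before the exiting (or faulting) lane, speculative.indval
+ // is 43; secondary induction variables are rebuilt from that index via
+ // II.transform below.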
+ if (OrigPhi == OldInduction) { + EndVal = FFInductionVal; + } else { + EndVal = Builder.CreateSExtOrTrunc(FFInductionVal, + II.getStep()->getType(), + "speculative.cast.endval"); + const DataLayout &DL = OrigPhi->getModule()->getDataLayout(); + EndVal = II.transform(Builder, EndVal, PSE.getSE(), DL); + EndVal->setName("speculative.ind.end"); + } + + ResumeVal->addIncoming(EndVal, ReductionLoopRet); + + for (auto *BB : LoopBypassBlocks) + for (auto *Succ : successors(BB)) + if (Succ == LoopScalarPreHeader) + ResumeVal->addIncoming(II.getStartValue(), BB); + + auto BIdx = OrigPhi->getBasicBlockIndex(LoopScalarPreHeader); + OrigPhi->setIncomingValue(BIdx, ResumeVal); + } + + + // TODO: How to block vectorization of things which shouldn't be in the tail? + // Comparing map pointers is ugly; consider a single map pointer but switching + // it at various points, and using a phase variable to determine how to handle + // things? + // Generate the vector tail instructions + PhiVector Dummy; + MaskCache.clear(); + for (LoopBlocksDFS::RPOIterator bb = DFS.beginRPO(), + be = DFS.endRPO(); bb != be; ++bb) + vectorizeBlockInLoop(*bb, &Dummy, VTailWMap); + MaskCache.clear(); + + Value *GrnCount = Builder.CreateCntVPop(GreenLanes, "grn.cnt"); + + for (auto E : Legal->getEF()->getEscapees()) { + // Reductions are handled separately + if (E.second->isReduction()) + continue; + + auto Phi = dyn_cast(E.first); + + // Either extract an element from the vector value itself, + // or optimise the extraction of purely induction vars + // by calculating it as: + // II.transform(#scalar_induction + #active_grn_lanes) + Value *Res; + bool IsUpdate = false; + InductionDescriptor II; + if (allIndvarPHIsOrAllUpdates(E.second, II, IsUpdate)) { + Type *StartTy = II.getStartValue()->getType(); + Value *One = ConstantInt::get(GrnCount->getType(), 1); + + // Number of active lanes (possibly +1 if its an update) + Value *EltCnt = IsUpdate ? Builder.CreateAdd(GrnCount, One) : GrnCount; + EltCnt = Builder.CreateZExtOrTrunc(EltCnt, StartTy); + Value *Ind = (II.getKind() == InductionDescriptor::IK_IntInduction) ? + Builder.CreateZExtOrTrunc(Induction, StartTy) : Induction; + + // Add scalar induction var + handled lanes + Res = Builder.CreateAdd(Ind, EltCnt); + + const DataLayout &DL = Phi->getModule()->getDataLayout(); + // Transform into induction variable + Res = II.transform(Builder, Res, PSE.getSE(), DL); + + // Possibly trunc + if (Phi->getType()->isIntegerTy()) + Res = Builder.CreateZExtOrTrunc(Res, Phi->getType()); + } else { + Value *VPhi = getVectorValue(Phi, VTailWMap)[0]; + Res = Builder.CreateExtractElement(VPhi, GrnCount); + } + + Phi->addIncoming(Res, ReductionLoopRet); + } + + ////////////////////////////////////////////////////////////////////////////// + // Speculative Recovery Branch // + ////////////////////////////////////////////////////////////////////////////// + // + // There are three possibilities for loop state in the vector tail: + // + // 1. All work completed in the vector body + // + // 2. Some work remaining, but we have reached a valid exit condition + // + // 3. Some work remaining, but a speculative load faulted before an + // exit condition was reached. + // + // For option 1, Pg will be all false, and we can just test for that when + // determining whether we should skip the scalar loop. + // + // Options 2 and 3 require a bit more work. 
We have a phi which provides + // us with the active speculative lanes from either the preheader or the + // body, and we have the partitioned general predicate Pg for the tail. + // + // If we negate Pg, the first active lane is the one which exited. 'And'ing + // that with the active speculative lanes tells us whether or not the exiting + // lane was active or if the speculative load faulted before we reached any + // valid exit condition. If all lanes are false there was no overlap, so + // recovery via the scalar loop is required. + // + // + // We expect hardware implementations to provide best-effort speculative loads + // so any fault should be a real fault instead of a pseudo fault, and if an + // exit condition was not reached then the program *should* fault at that + // address. If the hardware implementation doesn't implement all boundary + // cases, then it can just recover slowly via the scalar loop. + Value *FaultCond = nullptr; + if (HasSpeculativeLoads) { + Builder.SetInsertPoint(LoopVectorTail->getTerminator()); + Value *InactiveLanes = Builder.CreateNot(Pg, "inactive.lanes"); + Value *OverlappedLanes = Builder.CreateAnd(InactiveLanes, + TailSpeculativeLanes, + "speculative.overlapped"); + Value *CondBeforeFault = + getAnyTrueReduction(Builder, OverlappedLanes, "cond.before.fault"); + Value *NoActive = getAllFalseReduction(Builder, Pg, "no.active.lanes"); + FaultCond = Builder.CreateOr(CondBeforeFault, NoActive, "skip.scalar.cond"); + } else + // TODO: Remove branch modifications for non-speculative case once + // the DomTree is sorted out. + FaultCond = ConstantInt::getTrue(Builder.getInt1Ty()); + + BranchInst *FallbackBr = cast(ReductionLoopRet->getTerminator()); + FallbackBr->setCondition(FaultCond); + + // Restore Insertion point + Builder.restoreIP(IP); +} + + +void SearchLoopVectorizer::vectorizeLoop() { + // Only print 'function-name (loop-bb-name)' + LLVM_DEBUG(dbgs() << "SLV: Transforming " << + OrigLoop->getHeader()->getParent()->getName() << + "\t(" << OrigLoop->getHeader()->getName() << ")" << "\n"); + + //===------------------------------------------------===// + // + // Notice: any optimization or new instruction that go + // into the code below should be also be implemented in + // the cost-model. + // + //===------------------------------------------------===// + Constant *Zero = Builder.getInt32(0); + + // In order to support reduction variables we need to be able to vectorize + // Phi nodes. Phi nodes have cycles, so we need to vectorize them in two + // stages. First, we create a new vector PHI node with no incoming edges. + // We use this value when we vectorize all of the instructions that use the + // PHI. Next, after all of the instructions in the block are complete we + // add the new incoming edges to the PHI. At this point all of the + // instructions in the basic block are vectorized, so we can use them to + // construct the PHI. + PhiVector RdxPHIsToFix; + PhiVector Dummy; + + // Scan the loop in a topological order to ensure that defs are vectorized + // before users. + LoopBlocksDFS DFS(OrigLoop); + DFS.perform(LI); + + // Create PHIs for each condition subexpression that needs to become a PHI + auto IP = Builder.saveIP(); + for (auto *V : MakeTheseIntoPHIs) { + // Set insert point to preheader, if the value is constant, + // it should be expanded in the preheader. + Builder.SetInsertPoint(LoopVectorPreHeader->getTerminator()); + Value *PHValue = PHMap.get(V)[0]; + + // Create two PHIs, one in body... 
+ Builder.SetInsertPoint(LoopVectorBody.front()->getFirstNonPHI()); + PHINode *BodyPhi = Builder.CreatePHI(PHValue->getType(), 2); + BodyPhi->addIncoming(PHValue, LoopVectorPreHeader); + BodyWMap.get(V)[0] = BodyPhi; + + // ... and one in vtail + Builder.SetInsertPoint(LoopVectorTail->getFirstNonPHI()); + PHINode *TailPhi = Builder.CreatePHI(PHValue->getType(), 2); + TailPhi->addIncoming(PHValue, LoopVectorPreHeader); + VTailWMap.get(V)[0] = TailPhi; + } + Builder.restoreIP(IP); + + // Vectorize all of the blocks in the original loop. + Builder.SetInsertPoint(&*(LoopVectorBody[0]->getFirstInsertionPt())); + for (LoopBlocksDFS::RPOIterator bb = DFS.beginRPO(), + be = DFS.endRPO(); bb != be; ++bb) + vectorizeBlockInLoop(*bb, &RdxPHIsToFix, BodyWMap); + MaskCache.clear(); + + // vectorize exit by creating the next predicate, + // next induction value and a test for all conditions. + vectorizeExits(); + + // Patch up PHIs from vector body + BasicBlock *LastBB = LoopVectorBody.back(); + for (auto *V : MakeTheseIntoPHIs) { + Value *PHValue = NextBodyWMap.get(V)[0]; + + // Create two PHIs, one in body, one in vtail + auto *BodyPhi = dyn_cast(BodyWMap.get(V)[0]); + BodyPhi->addIncoming(PHValue, LastBB); + + auto *TailPhi = dyn_cast(VTailWMap.get(V)[0]); + TailPhi->addIncoming(PHValue, LastBB); + } + + // Find the reduction identity variable. Zero for addition, or, xor, + // one for multiplication, -1 for And. + for (auto &Reduction : *Legal->getReductionVars()) { + auto IP = Builder.saveIP(); + Builder.SetInsertPoint(LoopVectorPreHeader->getTerminator()); + + PHINode *PN = Reduction.first; + RecurrenceDescriptor RdxDesc = Reduction.second; + RecurrenceDescriptor::RecurrenceKind RK = RdxDesc.getRecurrenceKind(); + TrackingVH ReductionStartValue = RdxDesc.getRecurrenceStartValue(); + + // Reductions do not have to start at zero. They can start with + // any loop invariant values. + VectorParts &VecRdxPhi = BodyWMap.get(PN); + BasicBlock *Latch = OrigLoop->getLoopLatch(); + Value *LoopVal = PN->getIncomingValueForBlock(Latch); + VectorParts &Val = getVectorValue(LoopVal, BodyWMap); + + // Create the VectorStart value + Value *Identity, *VectorStart; + if (RK == RecurrenceDescriptor::RK_IntegerMinMax || + RK == RecurrenceDescriptor::RK_FloatMinMax || + RK == RecurrenceDescriptor::RK_ConstSelectICmp || + RK == RecurrenceDescriptor::RK_ConstSelectFCmp) { + // MinMax reduction have the start value as their identify. + if (VF == 1) { + VectorStart = Identity = ReductionStartValue; + } else { + const char *Ident = (RK == RecurrenceDescriptor::RK_ConstSelectICmp || + RK == RecurrenceDescriptor::RK_ConstSelectFCmp) ? + "intcond.ident" : "minmax.ident"; + VectorStart = Identity = + Builder.CreateVectorSplat({VF, Scalable}, + RdxDesc.getRecurrenceStartValue(), + Ident); + } + } else { + // Handle other reduction kinds: + Constant *Iden = RecurrenceDescriptor::getRecurrenceIdentity( + RK, VecRdxPhi[0]->getType()->getScalarType()); + if (VF == 1) { + Identity = Iden; + // This vector is the Identity vector where the first element is the + // incoming scalar reduction. + VectorStart = ReductionStartValue; + } else { + Identity = ConstantVector::getSplat({VF, Scalable}, Iden); + + // This vector is the Identity vector where the first element is the + // incoming scalar reduction. + VectorStart = + Builder.CreateInsertElement(Identity, ReductionStartValue, Zero); + } + } + + // Fix the vector-loop phi. + + // Only add the reduction start value to the first unroll part. 
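+ // For reference, an integer add reduction started at %init uses roughly:
+ //   Identity    = zeroinitializer
+ //   VectorStart = insertelement zeroinitializer, i32 %init, i32 0
+ // whereas min/max and select-based reductions splat the start value itself.
+ // (Illustrative only; the identity depends on the recurrence kind.)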
+ cast(VecRdxPhi[0])->addIncoming(VectorStart, + LoopVectorPreHeader); + cast(VecRdxPhi[0])->addIncoming(Val[0], + LoopVectorBody.back()); + + Builder.restoreIP(IP); + } + + // Insert the vector tail and all PHI node edges from preheader/vector body + insertVectorTail(DFS); + + // Insert truncates and extends for any truncated instructions as hints to + // InstCombine. + if (VF > 1) { + //FIXME: Enabling this gives a segv with bzip2. + //truncateToMinimalBitwidths(PHMap); + //truncateToMinimalBitwidths(BodyWMap); + //truncateToMinimalBitwidths(NextBodyWMap); + //truncateToMinimalBitwidths(VTailWMap); + } + // At this point every instruction in the original loop is widened to + // a vector form. We are almost done. Now, we need to fix the PHI nodes + // that we vectorized. The PHI nodes are currently empty because we did + // not want to introduce cycles. Notice that the remaining PHI nodes + // that we need to fix are reduction variables. + + // Create the 'reduced' values for each of the induction vars. + // The reduced values are the vector values that we scalarize and combine + // after the loop is finished. + for (auto *RdxPhi : RdxPHIsToFix) { + assert(RdxPhi && "Unable to recover vectorized PHI"); + + // Find the reduction variable descriptor. + assert(Legal->getReductionVars()->count(RdxPhi) && + "Unable to find the reduction variable"); + RecurrenceDescriptor RdxDesc = (*Legal->getReductionVars())[RdxPhi]; + + RecurrenceDescriptor::RecurrenceKind RK = RdxDesc.getRecurrenceKind(); + TrackingVH ReductionStartValue = RdxDesc.getRecurrenceStartValue(); + // TODO: Actually make this function aware of multiple exit vals. + Instruction *LoopExitInst = RdxDesc.getLoopExitInstrs()->back(); + RecurrenceDescriptor::MinMaxRecurrenceKind MinMaxKind = + RdxDesc.getMinMaxRecurrenceKind(); + setDebugLocFromInst(Builder, ReductionStartValue); + + // We need to generate a reduction vector from the incoming scalar. + // To do so, we need to generate the 'identity' vector and override + // one of the elements with the incoming scalar reduction. We need + // to do it in the vector-loop preheader. + Builder.SetInsertPoint(LoopVectorPreHeader->getTerminator()); + + // This is the vector-clone of the value that leaves the loop. + VectorParts &VectorExit = getVectorValue(LoopExitInst, BodyWMap); + Type *VecTy = VectorExit[0]->getType(); + + // Before each round, move the insertion point right between + // the PHIs and the values we are going to write. + // This allows us to write both PHINodes and the extractelement + // instructions. + Builder.SetInsertPoint(&*LoopVectorTail->getTerminator()); + + // Get the last reduction value from the vector tail, rather + // than the vector body. + VectorParts RdxParts = getVectorValue(LoopExitInst, VTailWMap); + setDebugLocFromInst(Builder, LoopExitInst); + + // If the vector reduction can be performed in a smaller type, we truncate + // then extend the loop exit value to enable InstCombine to evaluate the + // entire expression in the smaller type. + if (VF > 1 && RdxPhi->getType() != RdxDesc.getRecurrenceType()) { + Type *RdxVecTy = VectorType::get(RdxDesc.getRecurrenceType(), VF, + Scalable); + Builder.SetInsertPoint(LoopVectorBody.back()->getTerminator()); + unsigned part = 0; + { + Value *Trunc = Builder.CreateTrunc(RdxParts[part], RdxVecTy); + Value *Extnd = RdxDesc.isSigned() ? 
Builder.CreateSExt(Trunc, VecTy) + : Builder.CreateZExt(Trunc, VecTy); + for (Value::user_iterator UI = RdxParts[part]->user_begin(); + UI != RdxParts[part]->user_end();) + if (*UI != Trunc) { + (*UI++)->replaceUsesOfWith(RdxParts[part], Extnd); + RdxParts[part] = Extnd; + } else { + ++UI; + } + } + Builder.SetInsertPoint(&*ReductionLoop->getFirstInsertionPt()); + RdxParts[0] = Builder.CreateTrunc(RdxParts[0], RdxVecTy); + } + + // Reduce all of the unrolled parts into a single vector. + Value *ReducedPartRdx = RdxParts[0]; + unsigned Op = RecurrenceDescriptor::getRecurrenceBinOp(RK); + setDebugLocFromInst(Builder, ReducedPartRdx); + /* TODO: Remove? Only deals with unrolled code, which we won't use... + for (unsigned part = 1; part < UF; ++part) { + if (Op != Instruction::ICmp && Op != Instruction::FCmp) + // Floating point operations had to be 'fast' to enable the reduction. + ReducedPartRdx = addFastMathFlag( + Builder.CreateBinOp((Instruction::BinaryOps)Op, RdxParts[part], + ReducedPartRdx, "bin.rdx")); + else + ReducedPartRdx = RecurrenceDescriptor::createMinMaxOp( + Builder, MinMaxKind, ReducedPartRdx, RdxParts[part]); + } + */ + + if ((VF > 1) && !isScalable()) { + // VF is a power of 2 so we can emit the reduction using log2(VF) shuffles + // and vector ops, reducing the set of values being computed by half each + // round. + assert(isPowerOf2_32(VF) && + "Reduction emission only supported for pow2 vectors!"); + Value *TmpVec = ReducedPartRdx; + SmallVector ShuffleMask(VF, nullptr); + for (unsigned i = VF; i != 1; i >>= 1) { + // Move the upper half of the vector to the lower half. + for (unsigned j = 0; j != i/2; ++j) + ShuffleMask[j] = Builder.getInt32(i/2 + j); + + // Fill the rest of the mask with undef. + std::fill(&ShuffleMask[i/2], ShuffleMask.end(), + UndefValue::get(Builder.getInt32Ty())); + + Value *Shuf = + Builder.CreateShuffleVector(TmpVec, + UndefValue::get(TmpVec->getType()), + ConstantVector::get(ShuffleMask), + "rdx.shuf"); + + if (Op != Instruction::ICmp && Op != Instruction::FCmp) { + // Floating point operations had to be 'fast' to enable the reduction. + TmpVec = addFastMathFlag(Builder.CreateBinOp( + (Instruction::BinaryOps)Op, TmpVec, Shuf, "bin.rdx")); + } + else + TmpVec = RecurrenceDescriptor::createMinMaxOp(Builder, MinMaxKind, + TmpVec, Shuf); + } + + // The result is in the first element of the vector. + ReducedPartRdx = Builder.CreateExtractElement(TmpVec, + Builder.getInt32(0)); + } + + // Compute vector reduction for scalable vectors + if ((VF > 1) && isScalable()) { + bool NoNaN = Legal->hasNoNaNAttr(); + + // Before scalarizing, check if the target has an intrinsic to do the job. 
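+ // On a scalable-vector target this is expected to lower to a single
+ // horizontal reduction across the active lanes; only when that path is
+ // disabled do we fall back to the scalar reduction loop below, which
+ // extracts and accumulates one lane per iteration.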
+ if (!DisableReductionIntrinsics) { + ReducedPartRdx = + createTargetReduction(Builder, TTI, RdxDesc, ReducedPartRdx, NoNaN); + assert(ReducedPartRdx != nullptr); + } else { + Constant *Zero = Builder.getInt32(0); + Value *StartAcc = Builder.CreateExtractElement(ReducedPartRdx, Zero); + Value *RuntimeVF = getRuntimeVF(Builder.getInt32Ty()); + + // ********************************************************************* + // Start of reduction loop + // ********************************************************************* + + Builder.SetInsertPoint(ReductionLoop->getFirstNonPHI()); + PHINode *Idx = Builder.CreatePHI(Builder.getInt32Ty(), 2, "rdx.idx"); + PHINode *Acc = Builder.CreatePHI(StartAcc->getType(), 2, "rdx.acc"); + + // ReducedPartRdx[Idx] + Value *Lane = Builder.CreateExtractElement(ReducedPartRdx, Idx); + + // Acc = Acc ReducedPartRdx[Idx] + Value *NextAcc; + if (Op != Instruction::ICmp && Op != Instruction::FCmp) + NextAcc = Builder.CreateBinOp((Instruction::BinaryOps)Op, Acc, Lane); + else + NextAcc = RdxDesc.createMinMaxOp(Builder, + RdxDesc.getMinMaxRecurrenceKind(), + Acc, Lane); + + Acc->addIncoming(StartAcc, LoopVectorTail); + Acc->addIncoming(NextAcc, ReductionLoop); + + Value *NextIdx = Builder.CreateAdd(Idx, Builder.getInt32(1)); + Idx->addIncoming(Builder.getInt32(1), LoopVectorTail); + Idx->addIncoming(NextIdx, ReductionLoop); + + Value *LoopCond = Builder.CreateICmpULT(NextIdx, RuntimeVF); + Instruction *OldTerm = ReductionLoop->getTerminator(); + BranchInst::Create(ReductionLoop, ReductionLoopRet, LoopCond, OldTerm); + OldTerm->eraseFromParent(); + + // ********************************************************************* + // End of reduction loop + // ********************************************************************* + + ReducedPartRdx = NextAcc; + } + } + + + if (VF > 1) { + // If the reduction can be performed in a smaller type, we need to extend + // the reduction to the wider type before we branch to the original loop. + if (RdxPhi->getType() != RdxDesc.getRecurrenceType()) + ReducedPartRdx = + RdxDesc.isSigned() + ? Builder.CreateSExt(ReducedPartRdx, RdxPhi->getType()) + : Builder.CreateZExt(ReducedPartRdx, RdxPhi->getType()); + } + + // Create a phi node that merges control-flow from the backedge-taken check + // block and the middle block. + PHINode *BCBlockPhi = PHINode::Create(RdxPhi->getType(), 2, "bc.merge.rdx", + LoopScalarPreHeader->getTerminator()); + BCBlockPhi->addIncoming(ReducedPartRdx, ReductionLoopRet); + for (auto *BB : LoopBypassBlocks) + for (auto *Succ : successors(BB)) + if (Succ == LoopScalarPreHeader) + BCBlockPhi->addIncoming(ReductionStartValue, BB); + + // If there were stores of the reduction value to a uniform memory address + // inside the loop, create the final store here. + if (StoreInst *SI = RdxDesc.IntermediateStore) { + StoreInst *NewSI = Builder.CreateStore(ReducedPartRdx, + SI->getPointerOperand()); + propagateMetadata(NewSI, SI); + // If the reduction value is used in other places, + // then let the code below create PHI's for that. + } + + // Now, we need to fix the users of the reduction variable + // inside and outside of the scalar remainder loop. + // We know that the loop is in LCSSA form. We need to update the + // PHI nodes in the exit blocks. + for (BasicBlock::iterator LEI = LoopExitBlock->begin(), + LEE = LoopExitBlock->end(); LEI != LEE; ++LEI) { + PHINode *LCSSAPhi = dyn_cast(LEI); + if (!LCSSAPhi) break; + + // All PHINodes need to have a single entry edge, or two if + // we already fixed them. 
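+ // Illustrative example: an exit-block PHI such as
+ //   %sum.lcssa = phi i32 [ %sum.next, %for.body ]
+ // gains a second incoming value [ %rdx, %reduction.loop.ret ] so the reduced
+ // result reaches users outside the loop.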
+ assert(LCSSAPhi->getNumIncomingValues() < 3 && "Invalid LCSSA PHI"); + + // We found our reduction value exit-PHI. Update it with the + // incoming bypass edge. + for (Value *Incoming : LCSSAPhi->incoming_values()) { + if (Incoming == LoopExitInst) { + // Add an edge coming from the bypass. + LCSSAPhi->addIncoming(ReducedPartRdx, ReductionLoopRet); + break; + } + } + }// end of the LCSSA phi scan. + + // Fix the scalar loop reduction variable with the incoming reduction sum + // from the vector body and from the backedge value. + int IncomingEdgeBlockIdx = + RdxPhi->getBasicBlockIndex(OrigLoop->getLoopLatch()); + assert(IncomingEdgeBlockIdx >= 0 && "Invalid block index"); + // Pick the other block. + int SelfEdgeBlockIdx = (IncomingEdgeBlockIdx ? 0 : 1); + RdxPhi->setIncomingValue(SelfEdgeBlockIdx, BCBlockPhi); + RdxPhi->setIncomingValue(IncomingEdgeBlockIdx, LoopExitInst); + }// end of for each redux variable. + + fixLCSSAPHIs(); + + // Make sure DomTree is updated. + updateAnalysis(); + + // Predicate any stores. + for (auto KV : PredicatedStores) { + BasicBlock::iterator I(KV.first); + auto *BB = SplitBlock(I->getParent(), &*std::next(I), DT, LI); + auto *T = SplitBlockAndInsertIfThen(KV.second, &*I, /*Unreachable=*/false, + /*BranchWeights=*/nullptr, DT); + I->moveBefore(T); + I->getParent()->setName("pred.store.if"); + BB->setName("pred.store.continue"); + } + LLVM_DEBUG(DT->verify()); + // Remove redundant induction instructions. + CSE(LoopVectorBody, PredicatedBlocks); + LLVM_DEBUG(verifyFunction(*(ReductionLoop->getParent()), &dbgs())); +} + +void SearchLoopVectorizer::fixLCSSAPHIs() { + for (BasicBlock::iterator LEI = LoopExitBlock->begin(), + LEE = LoopExitBlock->end(); LEI != LEE; ++LEI) { + PHINode *LCSSAPhi = dyn_cast(LEI); + if (!LCSSAPhi) break; + if (LCSSAPhi->getNumIncomingValues() == 1) + LCSSAPhi->addIncoming(UndefValue::get(LCSSAPhi->getType()), + ReductionLoopRet); + } +} + +SearchLoopVectorizer::VectorParts +SearchLoopVectorizer::createEdgeMask(BasicBlock *Src, BasicBlock *Dst, + ValueMap &WidenMap) { + assert(std::find(pred_begin(Dst), pred_end(Dst), Src) != pred_end(Dst) && + "Invalid edge"); + + // Look for cached value. + std::pair Edge(Src, Dst); + EdgeMaskCache::iterator ECEntryIt = MaskCache.find(Edge); + if (ECEntryIt != MaskCache.end()) + return ECEntryIt->second; + + VectorParts SrcMask = createBlockInMask(Src, WidenMap); + + // The terminator has to be a branch inst! + BranchInst *BI = dyn_cast(Src->getTerminator()); + assert(BI && "Unexpected terminator found"); + + if (BI->isConditional()) { + VectorParts EdgeMask = getVectorValue(BI->getCondition(), WidenMap); + + if (BI->getSuccessor(0) != Dst) + EdgeMask[0] = Builder.CreateNot(EdgeMask[0]); + + EdgeMask[0] = Builder.CreateAnd(EdgeMask[0], SrcMask[0]); + + MaskCache[Edge] = EdgeMask; + return EdgeMask; + } + + MaskCache[Edge] = SrcMask; + return SrcMask; +} + +SearchLoopVectorizer::VectorParts +SearchLoopVectorizer::createBlockInMask(BasicBlock *BB, ValueMap &WidenMap) { + assert(OrigLoop->contains(BB) && "Block is not a part of a loop"); + + // Loop incoming mask is all-one. + if (OrigLoop->getHeader() == BB) { + Value *C = ConstantInt::get(IntegerType::getInt1Ty(BB->getContext()), 1); + return getVectorValue(C, WidenMap); + } + + // This is the block mask. We OR all incoming edges, and with zero. 
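+ // For example (illustrative), a block BB with predecessors A and B gets
+ //   BlockMask(BB) = false | EdgeMask(A->BB) | EdgeMask(B->BB)
+ // where each edge mask is the predecessor's branch condition (negated for
+ // the false successor) AND'ed with that predecessor's own block mask.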
+ Value *Zero = ConstantInt::get(IntegerType::getInt1Ty(BB->getContext()), 0); + VectorParts BlockMask = getVectorValue(Zero, WidenMap); + + // For each pred: + for (pred_iterator it = pred_begin(BB), e = pred_end(BB); it != e; ++it) { + VectorParts EM = createEdgeMask(*it, BB, WidenMap); + BlockMask[0] = Builder.CreateOr(BlockMask[0], EM[0]); + } + + return BlockMask; +} + +// TODO: No need to pass 'Entry'? +void SearchLoopVectorizer::widenPHIInstruction(Instruction *PN, + SearchLoopVectorizer::VectorParts &Entry, + unsigned VF, PhiVector *PV, + ValueMap &WidenMap) { + PHINode* P = cast(PN); + // Handle reduction variables: + if (Legal->getReductionVars()->count(P)) { + // This is phase one of vectorizing PHIs. + Type *VecTy = (VF == 1) ? PN->getType() : + VectorType::get(PN->getType(), VF, Scalable); + Entry[0] = PHINode::Create( + VecTy, 2, "vec.phi", &*LoopVectorBody.back()->getFirstInsertionPt()); + PV->push_back(P); + return; + } + + setDebugLocFromInst(Builder, P); + // Check for PHI nodes that are lowered to vector selects. + if (P->getParent() != OrigLoop->getHeader()) { + // We know that all PHIs in non-header blocks are converted into + // selects, so we don't have to worry about the insertion order and we + // can just use the builder. + // At this point we generate the predication tree. There may be + // duplications since this is a simple recursive scan, but future + // optimizations will clean it up. + + unsigned NumIncoming = P->getNumIncomingValues(); + + // Generate a sequence of selects of the form: + // SELECT(Mask3, In3, + // SELECT(Mask2, In2, + // ( ...))) + for (unsigned In = 0; In < NumIncoming; In++) { + VectorParts Cond = createEdgeMask(P->getIncomingBlock(In), + P->getParent(), WidenMap); + VectorParts &In0 = getVectorValue(P->getIncomingValue(In), WidenMap); + + unsigned part = 0; + { + // We might have single edge PHIs (blocks) - use an identity + // 'select' for the first PHI operand. + if (In == 0) + Entry[part] = Builder.CreateSelect(Cond[part], In0[part], + In0[part]); + else + // Select between the current value and the previous incoming edge + // based on the incoming mask. + Entry[part] = Builder.CreateSelect(Cond[part], In0[part], + Entry[part], "predphi"); + } + } + return; + } + + // This PHINode must be an induction variable. + // Make sure that we know about it. + assert(Legal->getInductionVars()->count(P) && + "Not an induction variable"); + + InductionDescriptor II = Legal->getInductionVars()->lookup(P); + + // FIXME: The newly created binary instructions should contain nsw/nuw flags, + // which can be found from the original scalar operations. + switch (II.getKind()) { + case InductionDescriptor::IK_NoInduction: + case InductionDescriptor::IK_FpInduction: + llvm_unreachable("Unknown induction"); + case InductionDescriptor::IK_IntInduction: { + Type *PhiTy = P->getType(); + assert(PhiTy == II.getStartValue()->getType() && "Types must match"); + // Handle other induction variables that are now based on the + // canonical one. + Value *V = Induction; + if (P != OldInduction) { + V = Builder.CreateSExtOrTrunc(Induction, PhiTy); + const DataLayout &DL = P->getModule()->getDataLayout(); + V = II.transform(Builder, V, PSE.getSE(), DL); + V->setName("offset.idx"); + } + Value *Broadcasted = getBroadcastInstrs(V); + Value *RuntimeVF = getRuntimeVF(PhiTy); + // After broadcasting the induction variable we need to make the vector + // consecutive by adding 0, 1, 2, etc. 
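// Illustrative sketch, not part of this patch: models the select chain that
// widenPHIInstruction builds above for non-header PHIs. The first incoming
// value seeds the result (an identity select) and each later incoming value is
// selected in wherever its edge mask is set, so later edges take precedence
// just as the emitted "predphi" selects do. Plain vectors stand in for the
// widened IR values; the int element type is an illustrative assumption.
#include <cstddef>
#include <vector>

std::vector<int>
widenPhiAsSelects(const std::vector<std::vector<bool>> &EdgeMasks,
                  const std::vector<std::vector<int>> &Incoming) {
  std::vector<int> Result = Incoming[0];       // In == 0: select(mask, v, v) == v
  for (std::size_t In = 1; In < Incoming.size(); ++In)
    for (std::size_t Lane = 0; Lane < Result.size(); ++Lane)
      if (EdgeMasks[In][Lane])                 // select(mask, new, previous chain)
        Result[Lane] = Incoming[In][Lane];
  return Result;
}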
+ Value *Part = ConstantInt::get(PhiTy, 0); + Value *StartIdx = Builder.CreateMul(RuntimeVF, Part); + Entry[0] = getStepVector(Broadcasted, StartIdx, II.getStep()); + return; + } + case InductionDescriptor::IK_PtrInduction: { + // Handle the pointer induction variable case. + assert(P->getType()->isPointerTy() && "Unexpected type."); + // This is the normalized GEP that starts counting at zero. + Value *PtrInd = Induction; + PtrInd = Builder.CreateSExtOrTrunc(PtrInd, II.getStep()->getType()); + + if (!isScalable()) { + // This is the vector of results. Notice that we don't generate + // vector geps because scalar geps result in better code. + unsigned part = 0; + { + /* + if (VF == 1) { + int EltIndex = part; + Constant *Idx = ConstantInt::get(PtrInd->getType(),EltIndex); + Value *GlobalIdx = Builder.CreateAdd(PtrInd, Idx); + Value *SclrGep = II.transform(Builder, GlobalIdx); + SclrGep->setName("next.gep"); + Entry[part] = SclrGep; + continue; + } + */ + + Value *VecVal = UndefValue::get(VectorType::get(P->getType(), VF)); + for (unsigned int i = 0; i < VF; ++i) { + int EltIndex = i + part * VF; + Constant *Idx = ConstantInt::get(PtrInd->getType(),EltIndex); + Value *GlobalIdx = Builder.CreateAdd(PtrInd, Idx); + const DataLayout &DL = P->getModule()->getDataLayout(); + Value *SclrGep = II.transform(Builder, GlobalIdx, PSE.getSE(), DL); + SclrGep->setName("next.gep"); + VecVal = Builder.CreateInsertElement(VecVal, SclrGep, + Builder.getInt32(i), + "insert.gep"); + } + Entry[part] = VecVal; + } + } else { + Type *PhiTy = PtrInd->getType(); + Value *RuntimeVF = getRuntimeVF(PhiTy); + + Value *StepValue; + ScalarEvolution *SE = PSE.getSE(); + const DataLayout &DL = PN->getModule()->getDataLayout(); + SCEVExpander Expander(*SE, DL, "seriesgep"); + if (Legal->getInductionVars()->count(P)) { + const SCEV *Step = Legal->getInductionVars()->lookup(P).getStep(); + StepValue = Expander.expandCodeFor(Step, Step->getType(), + &*Builder.GetInsertPoint()); + } else { + auto *SAR = dyn_cast(PSE.getSE()->getSCEV(PN)); + assert(SAR && SAR->isAffine() && "Pointer induction not loop affine"); + + // Create SCEV expander for Start- and StepValue + const DataLayout &DL = PN->getModule()->getDataLayout(); + SCEVExpander Expander(*PSE.getSE(), DL, "seriesgep"); + + // Expand step and start value (the latter in preheader) + const SCEV *StepRec = SAR->getStepRecurrence(*PSE.getSE()); + StepValue = Expander.expandCodeFor(StepRec, StepRec->getType(), + &*Builder.GetInsertPoint()); + // Normalize step to be in #elements, not bytes + Type *ElemTy = PN->getType()->getPointerElementType(); + Value *Tmp = ConstantInt::get(StepValue->getType(), + DL.getTypeAllocSize(ElemTy)); + StepValue = Builder.CreateSDiv(StepValue, Tmp); + } + + unsigned part = 0; + { + Value *Part = ConstantInt::get(PhiTy, part); + Value *Idx = Builder.CreateMul(RuntimeVF, Part); + Value *GlobalIdx = Builder.CreateAdd(PtrInd, Idx); + Value *SclrGep = II.transform(Builder, GlobalIdx, SE, DL); + SclrGep->setName("next.gep"); + Value *Offs = Builder.CreateSeriesVector({VF,Scalable}, + ConstantInt::get(StepValue->getType(), 0), StepValue); + Entry[part] = Builder.CreateGEP(SclrGep, Offs); + Entry[part]->setName("vector.gep"); + } + } + return; + } + } +} + +/// Insert early break checks using 'Map' as the ValueMap. +/// Returns the 'and'ed value for all conditions. +/// Each handled (scalar) condition is separately stored in 'Conditions'. 
+Value *SearchLoopVectorizer::insertEarlyBreaks( + ValueMap &Map, Value *NextInd, SmallVectorImpl &Conditions, Value *Pred) { + // Save Induction + Value *SavedInd = Induction; + Induction = NextInd; + + // Save Predicate for this function alone + Value *CP = Predicate; + Predicate = Pred; + + // Reset the set of speculative predicates + SpeculativePredicates.clear(); + + // Insert all break conditions + Value *PredV = Predicate; + for (auto &Exit : Legal->exits()) { + // In the vector body, generate the predicates + // for each of the early exits. + PhiVector TmpPV; + for (Value *V : Exit.Nodes) { + if (OrigLoop->isLoopInvariant(V)) { + assert(getVectorValue(V, Map)[0] != nullptr && + "Couldn't widen LI value"); + continue; + } + + // TODO: If we know 'I' uses the result of a speculative load (can + // be looked up in SpeculativePredicates) and may cause side-effects + // on illegal data (e.g. exception on fdiv) then we need to use: + // (Predicate & SpeculativePredicate) + // to select the input operands and/or as predicate for the + // operation, because the input may be garbage (i.e. undefined if + // speculatively loaded lanes previously faulted). + // Note that we do not yet have any support to do a predicated + // operation other than loads/stores, so we'll have to default to + // using selects. + + // If I is not an instruction, Legality checking has a bug. + Instruction *I = cast(V); + widenInstruction(I, I->getParent(), &TmpPV, Map); + } + + // We now combine each predicate into a final 'exit' predicate. + BranchInst *Br = cast(Exit.ExitingBlock->getTerminator()); + Value *VCond = Map.get(Br->getCondition())[0]; + + // Store the vectorized condition for later use + Conditions.push_back(Br->getCondition()); + + // AND together results + if (Br->getSuccessor(0) == Exit.ExitBlock) + VCond = Builder.CreateNot(VCond); + PredV = Builder.CreateAnd(PredV, VCond); + + // Any condition expression that follows next must also + // use the right predicate (if they need any), e.g. + // for(..; igetType()); + for (auto &KV : SpeculativePredicates) + SpeculativePredicate = Builder.CreateAnd(SpeculativePredicate, KV.second); + + PredV = Builder.CreateAnd(PredV, SpeculativePredicate); + + // Provide the vector tail with speculative fault data so that we can + // recover via the scalar loop if required. + TailSpeculativeLanes->addIncoming(SpeculativePredicate, + &Map == &PHMap ? LoopVectorPreHeader : + LoopVectorBody.back()); + } + + // Restore Predicate + Predicate = CP; + + // Restore Induction + Induction = SavedInd; + + return PredV; +} + +void SearchLoopVectorizer::vectorizeExits() { + BasicBlock *LastBB = LoopVectorBody.back(); + // TODO: Shouldn't need to create new step + compare, just work with + // existing compare? does propff do what we want? better to expose + // partitioning instrs directly? + + // For now, we've just copied the original createEmptyLoopWithPredication + // logic of generating a predicate solely from the (known) trip count. + + // When using predication the number of elements processed per iteration + // becomes a runtime quantity. However, index.next is calculated making the + // assumption that a whole vector's worth of elements are processed, which + // today is true for all but the last iteration. This means index.next can + // potentially be larger than that within the original loop, which prevents + // the propagation of the original's wrapping knowldge. 
+ // + // Instead we use scalar evolution to determine the wrapping behaviour of the + // vector loop's index.next so later passes can optimise our control flow. + // TODO: Certain loops will force the requirement that index.next be accurate + // when exiting the loop, at which point an 'active element count' will be + // used. However, it seems inefficient to force this requirement for loops + // that don't need it. + + bool NSW = true; + bool NUW = true; + if (IdxEnd) { + ScalarEvolution *SE = PSE.getSE(); + int BitWidth = cast(IdxEnd->getType())->getBitWidth(); + const SCEV *IdxEndSCEV = SE->getSCEV(IdxEnd); + + const SCEV *InductionStepSCEV = SE->getSCEV(InductionStep); + const SCEV *StepCountSCEV = SE->getConstant(InductionStepSCEV->getType(), 1); + const SCEV *IndIncrSCEV = SE->getMulExpr(InductionStepSCEV, StepCountSCEV); + + const SCEV *MaxSIntSCEV = SE->getConstant(APInt::getSignedMaxValue(BitWidth)); + const SCEV *MaxUIntSCEV = SE->getConstant(APInt::getMaxValue(BitWidth)); + + const SCEV *SLimit = SE->getMinusSCEV(MaxSIntSCEV, IndIncrSCEV); + const SCEV *ULimit = SE->getMinusSCEV(MaxUIntSCEV, IndIncrSCEV); + NSW = SE->isKnownPredicate(ICmpInst::ICMP_SLE, IdxEndSCEV, SLimit); + NUW = SE->isKnownPredicate(ICmpInst::ICMP_ULE, IdxEndSCEV, ULimit); + } + + Value *NextIdx = Builder.CreateAdd(Induction, InductionStep, "index.next", + NUW, NSW); + Value *NextPred = ConstantInt::getTrue(Predicate->getType()); + + // Insert all break conditions (with the updated Induction value + // for the next iteration) + SmallVector Conds; + NextPred = insertEarlyBreaks(NextBodyWMap, NextIdx, Conds, NextPred); + + // For each exit condition, patch the PHI node that either gets the + // calculated value from .. + for (auto *Cond : Conds) { + Value *VCond = NextBodyWMap.get(Cond)[0]; + auto *VCondPHI = dyn_cast(VTailWMap.get(Cond)[0]); + VCondPHI->addIncoming(VCond, LastBB); + } + + // Set predicate for next iteration (in vector body, this is always PTRUE) + VecBodyOutPred = NextPred; + + // We're not done in the unpredicated vector body if we have + // a full vector iteration to handle. If there is a break condition + // somewhere in the next predicate, we need to move to vector tail. + Value *Done = getAllTrueReduction(Builder, NextPred, "has.exit"); + cast(Induction)->addIncoming(NextIdx, LoopVectorBody.back()); + LatchBranch->setCondition(Done); +} + +void SearchLoopVectorizer::vectorizeBlockInLoop(BasicBlock *BB, PhiVector *PV, + ValueMap &WidenMap) { + // For each instruction in the old loop. + for (BasicBlock::iterator it = BB->begin(), e = BB->end(); it != e; ++it) + widenInstruction(&*it, BB, PV, WidenMap); +} + +void SearchLoopVectorizer::widenInstruction(Instruction *it, BasicBlock *BB, + PhiVector *PV, ValueMap &WidenMap) { + VectorParts &Entry = WidenMap.get(it); + + // TODO: Update comment to unpred body + pred tail + // The instruction may already have an entry in the map if we're in an + // uncounted loop -- the safe nodes for early exits are calculated at + // the start of the block. + if (Entry[0] != nullptr) + return; + + switch (it->getOpcode()) { + case Instruction::Br: { + // Vector tail requires selects for each escapee value. + if (&WidenMap != &VTailWMap) + break; + + // Break if unconditional branch or when there are no escapees. + // TODO: Surely getEscapees.size() == 0? or .empty if available? 
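// Illustrative sketch, not part of this patch: the wrap reasoning above,
// restated in plain arithmetic. index.next = index + step cannot wrap when the
// largest value index can take (bounded here by idxEnd) stays at or below
// max - step. The concrete 64-bit width and the non-negative step are
// illustrative simplifications of the SCEV-based check.
#include <cstdint>
#include <limits>

struct WrapFlags {
  bool NoSignedWrap;
  bool NoUnsignedWrap;
};

WrapFlags indexNextWrapFlags(std::int64_t IdxEnd, std::int64_t Step) {
  WrapFlags Flags{false, false};
  if (Step < 0)
    return Flags;   // keep the sketch simple: only handle a non-negative step
  Flags.NoSignedWrap =
      IdxEnd <= std::numeric_limits<std::int64_t>::max() - Step;
  Flags.NoUnsignedWrap =
      static_cast<std::uint64_t>(IdxEnd) <=
      std::numeric_limits<std::uint64_t>::max() - static_cast<std::uint64_t>(Step);
  return Flags;
}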
+ auto *CurrentBranch = dyn_cast(it); + if (CurrentBranch->isUnconditional() || + Legal->getEF()->getEscapees().begin() == + Legal->getEF()->getEscapees().end()) + break; + + // Test if this branch is an exit branch (if not, it will be used + // for regular predication). + if (!OrigLoop->isLoopExiting(CurrentBranch->getParent())) + break; + + // Select values for each of the Escapees + for (auto E : Legal->getEF()->getEscapees()) { + Instruction *MergeNode = E.first; + Escapee *Esc = E.second; + + // Induction + Value *EscVal = Esc->getValue(BranchCounter); + + // Merge = Select (predicate, new_value, old_value) + VectorParts &Parts = WidenMap.get(MergeNode); + Value *NewVal = getVectorValue(EscVal, WidenMap)[0]; + Value *MergeVal = Builder.CreateSelect(Predicate, NewVal, Parts[0]); + Parts[0] = MergeVal; + + // Also update the original value, which may be + // reused in e.g. the reduction. + VectorParts &Parts2 = WidenMap.get(EscVal); + Parts2[0] = MergeVal; + } + + // Get the vectorized condition for this branch + Value *Condition = + WidenMap.get(CurrentBranch->getCondition())[0]; + + if (OrigLoop->contains(CurrentBranch->getSuccessor(0))) { + Condition = Builder.CreateNot(Condition); + } + + // Update the predicate accordingly for the next instructions: + // + // Pg = (condition ^ Pg) & Pg + // + // where Pg starts off as the yellow lanes. The 'xor' removes the + // lane from the predicate when the condition is true in that iteration. + // + // For example: + // while(C0) { while(true) { + // S1; if (C0) break; + // if (C1) break; S1; + // S2; <=> if (C1) break; + // if (C2) break; S2; + // S3; if (C2) break; + // } S3; + // } + // i: 0, 1, 2, 3, 4, 5, 6, 7 + // C0: 0, 0, 0, 0, 0, 0, 1, 0 + // C1: 0, 0, 0, 1, 0, 0, 0, 0 + // C2: 0, 0, 0, 0, 0, 0, 0, 1 + // + // C0 & C1 & C2: + // 0, 0, 0, 1, 0, 0, 1, 1 + // + // green lanes: + // 1, 1, 1, 0, 0, 0, 0, 0 + // + // yellow lanes: + // 1, 1, 1, 1, 0, 0, 0, 0 + // + // active lanes for S1: + // 1, 1, 1, 1, 0, 0, 0, 0 + // + // active lanes for S2 and S3: + // 1, 1, 1, 0, 0, 0, 0, 0 + // because C1 causes the first break and S2/S3 follow after C1. + Value *NewPredicate = Builder.CreateXor(Condition, Predicate); + Predicate = Builder.CreateAnd(NewPredicate, Predicate); + + // Update for next branch + BranchCounter++; + + break; + } + case Instruction::PHI: { + // Vectorize PHINodes. + widenPHIInstruction(it, Entry, VF, PV, WidenMap); + break; + }// End of PHI. + + case Instruction::Add: + case Instruction::FAdd: + case Instruction::Sub: + case Instruction::FSub: + case Instruction::Mul: + case Instruction::FMul: + case Instruction::UDiv: + case Instruction::SDiv: + case Instruction::FDiv: + case Instruction::URem: + case Instruction::SRem: + case Instruction::FRem: + case Instruction::Shl: + case Instruction::LShr: + case Instruction::AShr: + case Instruction::And: + case Instruction::Or: + case Instruction::Xor: { + // Just widen binops. + BinaryOperator *BinOp = dyn_cast(it); + setDebugLocFromInst(Builder, BinOp); + VectorParts &A = getVectorValue(it->getOperand(0), WidenMap); + VectorParts &B = getVectorValue(it->getOperand(1), WidenMap); + + // Use this vector value for all users of the original instruction. + Value *V = Builder.CreateBinOp(BinOp->getOpcode(), A[0], B[0]); + + if (BinaryOperator *VecOp = dyn_cast(V)) + VecOp->copyIRFlags(BinOp); + + Entry[0] = V; + + propagateMetadata(Entry, it); + break; + } + case Instruction::Select: { + // Widen selects. 
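// Illustrative sketch, not part of this patch: the lane example from the
// comment above, rerun with std::bitset. Bit i is iteration i (bit 0 is the
// rightmost character). Pg starts as the "yellow" lanes and each break
// condition C removes its breaking lane via Pg = (C ^ Pg) & Pg, which for the
// active lanes is simply Pg & ~C. Only the C1/C2 rows of the worked example
// are reused here.
#include <bitset>
#include <cstdio>

int main() {
  std::bitset<8> Pg("00001111");   // yellow lanes: i0..i3 active
  std::bitset<8> C1("00001000");   // C1 breaks at i3
  std::bitset<8> C2("10000000");   // C2 breaks at i7 (already inactive here)

  std::bitset<8> ActiveS1 = Pg;                           // S1 runs under the yellow lanes
  std::bitset<8> ActiveS2 = (C1 ^ Pg) & Pg;               // after the C1 break: i0..i2
  std::bitset<8> ActiveS3 = (C2 ^ ActiveS2) & ActiveS2;   // C2 removes nothing extra

  std::printf("S1: %s\nS2: %s\nS3: %s\n", ActiveS1.to_string().c_str(),
              ActiveS2.to_string().c_str(), ActiveS3.to_string().c_str());
  return 0;   // prints 00001111, 00000111, 00000111
}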
+ // If the selector is loop invariant we can create a select + // instruction with a scalar condition. Otherwise, use vector-select. + auto *SE = PSE.getSE(); + bool InvariantCond = SE->isLoopInvariant(SE->getSCEV(it->getOperand(0)), + OrigLoop); + setDebugLocFromInst(Builder, it); + + // The condition can be loop invariant but still defined inside the + // loop. This means that we can't just use the original 'cond' value. + // We have to take the 'vectorized' value and pick the first lane. + // Instcombine will make this a no-op. + VectorParts &Cond = getVectorValue(it->getOperand(0), WidenMap); + VectorParts &Op0 = getVectorValue(it->getOperand(1), WidenMap); + VectorParts &Op1 = getVectorValue(it->getOperand(2), WidenMap); + + Value *ScalarCond = (VF == 1) ? Cond[0] : + Builder.CreateExtractElement(Cond[0], Builder.getInt32(0)); + + Entry[0] = Builder.CreateSelect(InvariantCond ? ScalarCond : Cond[0], + Op0[0], Op1[0]); + + propagateMetadata(Entry, it); + break; + } + + case Instruction::ICmp: + case Instruction::FCmp: { + // Widen compares. Generate vector compares. + bool FCmp = (it->getOpcode() == Instruction::FCmp); + CmpInst *Cmp = dyn_cast(it); + setDebugLocFromInst(Builder, it); + VectorParts &A = getVectorValue(it->getOperand(0), WidenMap); + VectorParts &B = getVectorValue(it->getOperand(1), WidenMap); + unsigned Part = 0; + { + Value *C = nullptr; + if (FCmp) { + C = Builder.CreateFCmp(Cmp->getPredicate(), A[Part], B[Part]); + cast(C)->copyFastMathFlags(it); + } else { + C = Builder.CreateICmp(Cmp->getPredicate(), A[Part], B[Part]); + } + Entry[Part] = C; + } + + propagateMetadata(Entry, it); + break; + } + + case Instruction::Store: + case Instruction::Load: + vectorizeMemoryInstruction(it, WidenMap); + break; + case Instruction::ZExt: + case Instruction::SExt: + case Instruction::FPToUI: + case Instruction::FPToSI: + case Instruction::FPExt: + case Instruction::PtrToInt: + case Instruction::IntToPtr: + case Instruction::SIToFP: + case Instruction::UIToFP: + case Instruction::Trunc: + case Instruction::FPTrunc: + case Instruction::BitCast: { + CastInst *CI = dyn_cast(it); + setDebugLocFromInst(Builder, it); + /// Optimize the special case where the source is the induction + /// variable. Notice that we can only optimize the 'trunc' case + /// because: a. FP conversions lose precision, b. sext/zext may wrap, + /// c. other casts depend on pointer size. + if (CI->getOperand(0) == OldInduction && + it->getOpcode() == Instruction::Trunc) { + Value *ScalarCast = Builder.CreateCast(CI->getOpcode(), Induction, + CI->getType()); + Value *Broadcasted = getBroadcastInstrs(ScalarCast); + InductionDescriptor II = Legal->getInductionVars()->lookup(OldInduction); + ScalarEvolution *SE = PSE.getSE(); + const DataLayout &DL = OldInduction->getModule()->getDataLayout(); + SCEVExpander Expander(*SE, DL, "seriesgep"); + const SCEV *StepSCEV = II.getStep(); + Value *StepValue = + Expander.expandCodeFor(StepSCEV, StepSCEV->getType(), + &*Builder.GetInsertPoint()); + Value *Step = Builder.CreateSExtOrTrunc(StepValue, CI->getType()); + + Type* ElemTy = Broadcasted->getType()->getScalarType(); + Value* RuntimeVF = getRuntimeVF(ElemTy); + Value *Start = Builder.CreateMul(RuntimeVF, + ConstantInt::get(ElemTy, 0)); + Entry[0] = getStepVector(Broadcasted, Start, Step); + + propagateMetadata(Entry, it); + break; + } + /// Vectorize casts. + Type *DestTy = (VF == 1) ? 
CI->getType() + : VectorType::get(CI->getType(), VF, Scalable); + + VectorParts &A = getVectorValue(it->getOperand(0), WidenMap); + Entry[0] = Builder.CreateCast(CI->getOpcode(), A[0], DestTy); + propagateMetadata(Entry, it); + break; + } + + case Instruction::Call: { + // Ignore dbg intrinsics. + if (isa(it)) + break; + setDebugLocFromInst(Builder, it); + + Module *M = BB->getParent()->getParent(); + CallInst *CI = cast(it); + + StringRef FnName = CI->getCalledFunction()->getName(); + Function *F = CI->getCalledFunction(); + Type *RetTy = ToVectorTy(CI->getType(), VF); + SmallVector Tys; + for (unsigned i = 0, ie = CI->getNumArgOperands(); i != ie; ++i) + Tys.push_back(ToVectorTy(CI->getArgOperand(i)->getType(), VF)); + + Intrinsic::ID ID = getVectorIntrinsicIDForCall(CI, TLI); + if (ID && + (ID == Intrinsic::assume || ID == Intrinsic::lifetime_end || + ID == Intrinsic::lifetime_start)) { + if (isScalable() && + OrigLoop->isLoopInvariant(it->getOperand(0)) && + OrigLoop->isLoopInvariant(it->getOperand(1))) + Builder.Insert(it->clone()); + else + scalarizeInstruction(it, WidenMap); + break; + } + // The flag shows whether we use Intrinsic or a usual Call for vectorized + // version of the instruction. + // Is it beneficial to perform intrinsic call compared to lib call? + bool NeedToScalarize; + unsigned CallCost = Costs->getVectorCallCost(CI, VF, *TTI, TLI, + NeedToScalarize); + bool UseVectorIntrinsic = + ID && Costs->getVectorIntrinsicCost(CI, VF, *TTI, TLI) <= CallCost; + if (!UseVectorIntrinsic && NeedToScalarize) { + scalarizeInstruction(it, WidenMap); + break; + } + + unsigned Part = 0; + { + SmallVector Args; + for (unsigned i = 0, ie = CI->getNumArgOperands(); i != ie; ++i) { + Value *Arg = CI->getArgOperand(i); + // Some intrinsics have a scalar argument - don't replace it with a + // vector. + if (!UseVectorIntrinsic || !hasVectorInstrinsicScalarOpd(ID, i)) { + VectorParts &VectorArg = getVectorValue(CI->getArgOperand(i), WidenMap); + Arg = VectorArg[Part]; + } + Args.push_back(Arg); + } + + Function *VectorF; + if (UseVectorIntrinsic) { + // Use vector version of the intrinsic. + Type *TysForDecl[] = {CI->getType()}; + if (VF > 1) + TysForDecl[0] = VectorType::get(CI->getType()->getScalarType(), + VF, Scalable); + VectorF = Intrinsic::getDeclaration(M, ID, TysForDecl); + } else { + // Use vector version of the library call. + VectorType::ElementCount EC(VF, Scalable); + FunctionType *FTy = FunctionType::get(RetTy, Tys, false); + StringRef VFnName = + TLI->getVectorizedFunction(FnName, EC, true /* Masked */, FTy); + assert(!VFnName.empty() && "Vector function name is empty."); + VectorF = M->getFunction(VFnName); + if (!VectorF) { + // Generate a declaration + VectorF = + Function::Create(FTy, Function::ExternalLinkage, VFnName, M); + VectorF->copyAttributesFrom(F); + } + } + assert(VectorF && "Can't create vector function."); + Entry[Part] = Builder.CreateCall(VectorF, Args); + } + + propagateMetadata(Entry, it); + break; + } + + case Instruction::GetElementPtr: + vectorizeGEPInstruction(it, WidenMap); + break; + + default: + // All other instructions are unsupported. Scalarize them. + scalarizeInstruction(it, WidenMap); + break; + }// end of switch. 
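// Illustrative sketch, not part of this patch: the call-widening decision made
// in the Instruction::Call case above, reduced to its cost comparison. A vector
// intrinsic is preferred when one exists and is no more expensive than the
// vectorized library call; otherwise the call is scalarized when the cost model
// says a vector call is not worthwhile. The enum and names are illustrative.
enum class CallWidening { Scalarize, VectorIntrinsic, VectorLibCall };

CallWidening chooseCallWidening(bool HasVectorIntrinsic, unsigned IntrinsicCost,
                                unsigned VectorCallCost, bool NeedToScalarize) {
  const bool UseVectorIntrinsic =
      HasVectorIntrinsic && IntrinsicCost <= VectorCallCost;
  if (UseVectorIntrinsic)
    return CallWidening::VectorIntrinsic;   // widen via the vector intrinsic
  if (NeedToScalarize)
    return CallWidening::Scalarize;         // no profitable vector form
  return CallWidening::VectorLibCall;       // call the vectorized library routine
}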
+ + VectorParts &VE = WidenMap.get(it); + if (AnnotateWidenedInstrs && VE[0] != nullptr) { + std::string ScalarInst; + raw_string_ostream OS(ScalarInst); + OS << *it; + Metadata *MDs[] = { MDString::get(it->getContext(), ScalarInst) }; + MDNode *MDN = MDNode::get(it->getContext(), MDs); + if (Instruction *I = dyn_cast(VE[0])) + I->setMetadata("llvm.widened_scalar_inst", MDN); + } +} + +void SearchLoopVectorizer::updateAnalysis() { + // Forget the original basic block. + PSE.getSE()->forgetLoop(OrigLoop); + + // Update the dominator tree information. + assert(DT->properlyDominates(LoopBypassBlocks.front(), LoopExitBlock) && + "Entry does not dominate exit."); + + DT->addNewBlock(LoopVectorPreHeader, LoopBypassBlocks.back()); + + // Add dominator for first vector body block. + DT->addNewBlock(LoopVectorBody[0], LoopVectorPreHeader); + for (const auto &Edge : VecBodyDomEdges) + DT->addNewBlock(Edge.second, Edge.first); + + DT->addNewBlock(LoopVectorTail, LoopVectorPreHeader); + DT->addNewBlock(ReductionLoop, LoopVectorTail); + DT->addNewBlock(ReductionLoopRet, ReductionLoop); + + // Bit hacky :( + // TODO: Is there a better way of describing structure like this? + // Maybe we record some info earlier instead of walking over the + // blocks... + bool BypassPredecessor = false; + for (BasicBlock *BB : LoopBypassBlocks) + for (BasicBlock *Succ : BB->getTerminator()->successors()) + if (!BypassPredecessor && Succ == LoopScalarPreHeader) { + BypassPredecessor = true; + DT->addNewBlock(LoopScalarPreHeader, BB); + DT->changeImmediateDominator(LoopExitBlock, BB); + } + + if (!BypassPredecessor) { + DT->addNewBlock(LoopScalarPreHeader, ReductionLoopRet); + DT->changeImmediateDominator(LoopExitBlock, ReductionLoopRet); + } + DT->changeImmediateDominator(LoopScalarBody, LoopScalarPreHeader); + + + // TODO: Reinstate some of this (conditionally) once we can detect + // where we can elide speculative loads. + // +// DT->changeImmediateDominator(LoopExitBlock, LoopBypassBlocks[0]); +// +// bool RemoveOrigLoop = true; +// TerminatorInst *Term = LoopBypassBlocks[0]->getTerminator(); +// for (unsigned I=0; I < Term->getNumSuccessors(); ++I) { +// if (Term->getSuccessor(I) == LoopScalarPreHeader) { +// RemoveOrigLoop = false; +// break; +// } +// } +// +// if (RemoveOrigLoop) { +// SmallVector Desc; +// DT->getDescendants(LoopScalarPreHeader, Desc); +// for (auto BBI = Desc.rbegin(), BBE = Desc.rend(); BBI != BBE; ++BBI) +// DT->eraseNode(*BBI); +// DT->changeImmediateDominator(LoopExitBlock, ReductionLoopRet); +// +// // For each escapee merge node, remove all values coming +// // in from the original loop so that DCE can pick it up. 
+// MapVector> RemoveBlocks; +// for (auto E : Legal->getEF()->getEscapees()) { +// PHINode *MergeNode = cast(E.first); +// +// for (unsigned BBI = 0; BBI < MergeNode->getNumIncomingValues(); ++BBI) { +// auto *BB = MergeNode->getIncomingBlock(BBI); +// if (OrigLoop->contains(BB)) +// RemoveBlocks[MergeNode->getParent()].insert(BB); +// } +// } +// +// for (auto &KV : RemoveBlocks) +// for (auto *BB : KV.second) +// KV.first->removePredecessor(BB); +// +// // Also remove the loop from the LoopInfo structure +// Loop *ParentLoop = OrigLoop->getParentLoop(); +// if (ParentLoop) { +// for (Loop::iterator I = ParentLoop->begin(), +// E = ParentLoop->end(); I != E; ++I) { +// if (*I == OrigLoop) { +// ParentLoop->removeChildLoop(I); +// break; +// } +// } +// } else +// LI->markAsRemoved(OrigLoop); +// +// } + + // TODO: Wrap this in DEBUG() again once stable + DT->verify(); + + // TODO: Assert isLCSSAForm? + // Code Grabbed from isLCSSAForm, enforce for now. Remove once + // stable, and possibly assert isLCSSAForm in DEBUG() + for (auto *BB : SLVLoop->getBlocks()) { + for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E;++I) + for (Use &U : I->uses()) { + Instruction *UI = cast(U.getUser()); + BasicBlock *UserBB = UI->getParent(); + if (PHINode *P = dyn_cast(UI)) + UserBB = P->getIncomingBlock(U); + + // Check the current block, as a fast-path, before checking whether + // the use is anywhere in the loop. Most values are used in the same + // block they are defined in. Also, blocks not reachable from the + // entry are special; uses in them don't need to go through PHIs. + if (UserBB != BB && + !SLVLoop->contains(UserBB) && + DT->isReachableFromEntry(UserBB)) + assert(0); + } + } +} + +//////////////////////////////////////////////////////////////////////////////// +// SearchLoopsVectorize +//////////////////////////////////////////////////////////////////////////////// + +/// The SearchLoopVectorize Pass. +struct SearchLoopVectorize : public FunctionPass { + /// Pass identification, replacement for typeid + static char ID; + + // TODO: Decide what to do with unrolling. Disabling completely for now, since + // the normal vectorizer should have performed any unroll-only optimizations. + explicit SearchLoopVectorize(bool NoUnrolling = false, bool AlwaysVectorize = true) + : FunctionPass(ID), + DisableUnrolling(true), + AlwaysVectorize(AlwaysVectorize) { + initializeSearchLoopVectorizePass(*PassRegistry::getPassRegistry()); + } + + ScalarEvolution *SE; + LoopInfo *LI; + TargetTransformInfo *TTI; + DominatorTree *DT; + PostDominatorTree *PDT; + BlockFrequencyInfo *BFI; + TargetLibraryInfo *TLI; + DemandedBits *DB; + AliasAnalysis *AA; + AssumptionCache *AC; + LoopAccessLegacyAnalysis *LAA; + LoopVectorizationAnalysis *LVA; + bool DisableUnrolling; + bool AlwaysVectorize; + OptimizationRemarkEmitter *ORE; + + + BlockFrequency ColdEntryFreq; + + bool runOnFunction(Function &F) override { + SE = &getAnalysis().getSE(); + LI = &getAnalysis().getLoopInfo(); + TTI = &getAnalysis().getTTI(F); + DT = &getAnalysis().getDomTree(); + PDT = &getAnalysis().getPostDomTree(); + BFI = &getAnalysis().getBFI(); + auto *TLIP = getAnalysisIfAvailable(); + TLI = TLIP ? &TLIP->getTLI() : nullptr; + AA = &getAnalysis().getAAResults(); + AC = &getAnalysis().getAssumptionCache(F); + LAA = &getAnalysis(); + LVA = &getAnalysis(); + DB = &getAnalysis().getDemandedBits(); + ORE = &getAnalysis().getORE(); + + // Compute some weights outside of the loop over the loops. 
Compute this + // using a BranchProbability to re-use its scaling math. + const BranchProbability ColdProb(1, 5); // 20% + ColdEntryFreq = BlockFrequency(BFI->getEntryFreq()) * ColdProb; + + // TODO: Check for predication features or abandon. + // Don't attempt if + // 1. the target claims to have no vector registers, and + // 2. interleaving won't help ILP. + // + // The second condition is necessary because, even if the target has no + // vector registers, loop vectorization may still enable scalar + // interleaving. + if (!TTI->getNumberOfRegisters(true) && TTI->getMaxInterleaveFactor(1) < 2) + return false; + + assert((FuncsWhiteList.empty() || FuncsBlackList.empty()) && + "Can't have both a whitelist and blacklist active simultaneously"); + + // Bail out if we've manually specified we shouldn't vectorize loops in + // this function... won't help with inlined functions directly. + if (std::find(FuncsBlackList.begin(), + FuncsBlackList.end(), F.getName()) != FuncsBlackList.end()) + return false; + + if ((!FuncsWhiteList.empty()) && + (std::find(FuncsWhiteList.begin(), + FuncsWhiteList.end(), F.getName()) == FuncsWhiteList.end())) + return false; + + LLVM_DEBUG(dbgs() << "SLV: Running on function '"<< F.getName() << "'\n"); + + // Build up a worklist of inner-loops to vectorize. This is necessary as + // the act of vectorizing or partially unrolling a loop creates new loops + // and can invalidate iterators across the loops. + SmallVector Worklist; + + for (Loop *L : *LI) + addInnerLoop(*L, Worklist); + + SearchLoopsAnalyzed += Worklist.size(); + + // Now walk the identified inner loops. + + // TODO: Get a list of vectorizable loops from LVA, only tranform them. + bool Changed = false; + while (!Worklist.empty()) + Changed |= processLoop(Worklist.pop_back_val()); + + // if (Changed) + // F.dump(); + + // TODO: Remove this once stable, just enforce via unit tests + verifyFunction(F); + + // Process each loop nest in the function. + return Changed; + } + + static void AddRuntimeUnrollDisableMetaData(Loop *L) { + SmallVector MDs; + // Reserve first location for self reference to the LoopID metadata node. + MDs.push_back(nullptr); + bool IsUnrollMetadata = false; + MDNode *LoopID = L->getLoopID(); + if (LoopID) { + // First find existing loop unrolling disable metadata. + for (unsigned i = 1, ie = LoopID->getNumOperands(); i < ie; ++i) { + MDNode *MD = dyn_cast(LoopID->getOperand(i)); + if (MD) { + const MDString *S = dyn_cast(MD->getOperand(0)); + IsUnrollMetadata = + S && S->getString().startswith("llvm.loop.unroll.disable"); + } + MDs.push_back(LoopID->getOperand(i)); + } + } + + if (!IsUnrollMetadata) { + // Add runtime unroll disable metadata. + LLVMContext &Context = L->getHeader()->getContext(); + SmallVector DisableOperands; + DisableOperands.push_back( + MDString::get(Context, "llvm.loop.unroll.runtime.disable")); + MDNode *DisableNode = MDNode::get(Context, DisableOperands); + MDs.push_back(DisableNode); + MDNode *NewLoopID = MDNode::get(Context, MDs); + // Set operand 0 to refer to the loop id itself. 
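// Illustrative sketch, not part of this patch: the function filtering done in
// runOnFunction above, restated standalone. Vectorization is skipped when the
// function is blacklisted, or when a whitelist exists and the function is not
// on it; the pass asserts that only one of the two lists is active. List
// contents and names here are assumptions for illustration.
#include <algorithm>
#include <string>
#include <vector>

bool shouldProcessFunction(const std::string &Name,
                           const std::vector<std::string> &WhiteList,
                           const std::vector<std::string> &BlackList) {
  if (std::find(BlackList.begin(), BlackList.end(), Name) != BlackList.end())
    return false;                                    // explicitly disabled
  if (!WhiteList.empty() &&
      std::find(WhiteList.begin(), WhiteList.end(), Name) == WhiteList.end())
    return false;                                    // not explicitly enabled
  return true;
}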
+ NewLoopID->replaceOperandWith(0, NewLoopID); + L->setLoopID(NewLoopID); + } + } + + bool processLoop(Loop *L) { + assert(L->empty() && "Only process inner loops."); + +#ifndef NDEBUG + const std::string DebugLocStr = getDebugLocString(L); +#endif /* NDEBUG */ + + LLVM_DEBUG(dbgs() << "\nSLV: Checking a loop in \"" + << L->getHeader()->getParent()->getName() << "\" from " + << DebugLocStr << "\n"); + LLVM_DEBUG(dbgs() << "SLV: HeaderBBName: " << L->getHeader()->getName() << "\n"); + + SLVLoopVectorizeHints Hints(L, DisableUnrolling, *ORE); + + LLVM_DEBUG(dbgs() << "SLV: Loop hints:" + << " force=" + << (Hints.getForce() == SLVLoopVectorizeHints::FK_Disabled + ? "disabled" + : (Hints.getForce() == SLVLoopVectorizeHints::FK_Enabled + ? "enabled" + : "?")) << " width=" << Hints.getWidth() + << " unroll=" << Hints.getInterleave() << "\n"); + + // Function containing loop + Function *F = L->getHeader()->getParent(); + + // Looking at the diagnostic output is the only way to determine if a loop + // was vectorized (other than looking at the IR or machine code), so it + // is important to generate an optimization remark for each loop. Most of + // these messages are generated by emitOptimizationRemarkAnalysis. Remarks + // generated by emitOptimizationRemark and emitOptimizationRemarkMissed are + // less verbose reporting vectorized loops and unvectorized loops that may + // benefit from vectorization, respectively. + + if (!Hints.allowVectorization(F, L, AlwaysVectorize)) { + LLVM_DEBUG(dbgs() << "SLV: Loop hints prevent vectorization.\n"); + return false; + } + + // Check the loop for a trip count threshold: + // do not vectorize loops with a tiny trip count. + const unsigned TC = SE->getSmallConstantTripCount(L); + if (TC > 0u && TC < TinyTripCountVectorThreshold) { + LLVM_DEBUG(dbgs() << "SLV: Found a loop with a very small trip count. " + << "This loop is not worth vectorizing."); + if (Hints.getForce() == SLVLoopVectorizeHints::FK_Enabled) + LLVM_DEBUG(dbgs() << " But vectorizing was explicitly forced.\n"); + else { + LLVM_DEBUG(dbgs() << "\n"); + /* TODO: Reenable + emitAnalysisDiag(F, L, Hints, VectorizationReport() + << "vectorization is not beneficial " + "and is not explicitly forced"); + */ + return false; + } + } + + PredicatedScalarEvolution PSE(*SE, *L); + + std::function GetLAA = + [&](Loop &L) -> const LoopAccessInfo & { return LAA->getInfo(&L); }; + + // Check if it is legal to vectorize the loop. + LoopVectorizationRequirements Requirements; + SLVLoopVectorizationLegality LVL(L, PSE, DT, PDT, TLI, AA, F, TTI, &GetLAA, LI, + ORE, &Requirements, &Hints); + + if (!LVL.canVectorize()) { + LLVM_DEBUG(dbgs() << "SLV: Not vectorizing: Cannot prove legality.\n"); + emitMissedWarning(F, L, Hints); + return false; + } + + // Use the cost model. + SLVLoopVectorizationCostModel CM(L, PSE, LI, &LVL, *TTI, TLI, DB, AC, F, + &Hints); + CM.collectValuesToIgnore(); + + // Check the function attributes to find out if this function should be + // optimized for size. + bool OptForSize = Hints.getForce() != SLVLoopVectorizeHints::FK_Enabled && + F->optForSize(); + + // Compute the weighted frequency of this loop being executed and see if it + // is less than 20% of the function entry baseline frequency. Note that we + // always have a canonical loop here because we think we *can* vectorize. + // FIXME: This is hidden behind a flag due to pervasive problems with + // exactly what block frequency models. 
+    if (LoopVectorizeWithBlockFrequency) {
+      BlockFrequency LoopEntryFreq = BFI->getBlockFreq(L->getLoopPreheader());
+      if (Hints.getForce() != SLVLoopVectorizeHints::FK_Enabled &&
+          LoopEntryFreq < ColdEntryFreq)
+        OptForSize = true;
+    }
+
+    // Check the function attributes to see if implicit floats are allowed.
+    // FIXME: This check doesn't seem like it could be correct -- what if the
+    // loop is an integer loop and the vector instructions selected are purely
+    // integer vector instructions?
+    if (F->hasFnAttribute(Attribute::NoImplicitFloat)) {
+      LLVM_DEBUG(dbgs() << "SLV: Can't vectorize when the NoImplicitFloat "
+                           "attribute is used.\n");
+      /* TODO: Reenable
+      emitAnalysisDiag(
+          F, L, Hints,
+          VectorizationReport()
+              << "loop not vectorized due to NoImplicitFloat attribute");
+      */
+      emitMissedWarning(F, L, Hints);
+      return false;
+    }
+
+    // Select the optimal vectorization factor.
+    const VectorizationFactor VF = CM.selectVectorizationFactor(OptForSize);
+
+    // Select the interleave count.
+    unsigned IC = CM.selectInterleaveCount(OptForSize, VF, VF.Cost);
+
+    // Get user interleave count.
+    unsigned UserIC = Hints.getInterleave();
+
+    // Identify the diagnostic messages that should be produced.
+    std::string VecDiagMsg, IntDiagMsg;
+    bool VectorizeLoop = true, InterleaveLoop = true;
+
+    if (Requirements.doesNotMeet(F, L, Hints)) {
+      LLVM_DEBUG(dbgs() << "SLV: Not vectorizing: loop did not meet vectorization "
+                           "requirements.\n");
+      emitMissedWarning(F, L, Hints);
+      return false;
+    }
+
+    if (VF.Width == 1) {
+      LLVM_DEBUG(dbgs() << "SLV: Vectorization is possible but not beneficial.\n");
+      VecDiagMsg =
+          "the cost-model indicates that vectorization is not beneficial";
+      VectorizeLoop = false;
+    }
+
+    if (IC == 1 && UserIC <= 1) {
+      // Tell the user interleaving is not beneficial.
+      LLVM_DEBUG(dbgs() << "SLV: Interleaving is not beneficial.\n");
+      IntDiagMsg =
+          "the cost-model indicates that interleaving is not beneficial";
+      InterleaveLoop = false;
+      if (UserIC == 1)
+        IntDiagMsg +=
+            " and is explicitly disabled or interleave count is set to 1";
+    } else if (IC > 1 && UserIC == 1) {
+      // Tell the user interleaving is beneficial, but it is explicitly disabled.
+      LLVM_DEBUG(dbgs()
+                 << "SLV: Interleaving is beneficial but is explicitly disabled.");
+      IntDiagMsg = "the cost-model indicates that interleaving is beneficial "
+                   "but is explicitly disabled or interleave count is set to 1";
+      InterleaveLoop = false;
+    }
+
+    if (!VectorizeLoop && InterleaveLoop && LVL.hasMaskedOperations()) {
+      LLVM_DEBUG(dbgs()
+                 << "SLV: Interleaving is beneficial but the loop contains masked accesses");
+      IntDiagMsg = "interleaving not possible because of masked accesses";
+      InterleaveLoop = false;
+    }
+
+    // Override IC if user provided an interleave count.
+    IC = UserIC > 0 ? UserIC : IC;
+
+    // Emit diagnostic messages, if any.
+    //const char *VAPassName = Hints.vectorizeAnalysisPassName();
+    if (!VectorizeLoop && !InterleaveLoop) {
+      // Do not vectorize or interleave the loop.
+      //emitOptimizationRemarkAnalysis(F->getContext(), VAPassName, *F,
+      //L->getStartLoc(), VecDiagMsg);
+      //emitOptimizationRemarkAnalysis(F->getContext(), SLV_NAME, *F,
+      //L->getStartLoc(), IntDiagMsg);
+      return false;
+    } else if (!VectorizeLoop && InterleaveLoop) {
+      LLVM_DEBUG(dbgs() << "SLV: Interleave Count is " << IC << '\n');
+      //emitOptimizationRemarkAnalysis(F->getContext(), VAPassName, *F,
+      // L->getStartLoc(), VecDiagMsg);
+    } else if (VectorizeLoop && !InterleaveLoop) {
+      LLVM_DEBUG(dbgs() << "SLV: Found a vectorizable loop (" << VF.Width << ") in "
+                        << DebugLocStr << '\n');
+      //emitOptimizationRemarkAnalysis(F->getContext(), SLV_NAME, *F,
+      //L->getStartLoc(), IntDiagMsg);
+    } else if (VectorizeLoop && InterleaveLoop) {
+      LLVM_DEBUG(dbgs() << "SLV: Found a vectorizable loop (" << VF.Width << ") in "
+                        << DebugLocStr << '\n');
+      LLVM_DEBUG(dbgs() << "SLV: Interleave Count is " << IC << '\n');
+    }
+
+    // We should never get here unless we're vectorizing; unroll-only handling
+    // lives in the 'normal' loop vectorizer.
+    assert(VectorizeLoop);
+
+    // If we decided that it is *legal* to vectorize the loop then do it.
+    SearchLoopVectorizer LB(L, PSE, LI, DT, TLI, TTI, ORE, AC, VF.Width, 1, VF.isFixed);
+    LB.vectorize(&LVL, &CM, CM.MinBWs);
+    ++SearchLoopsVectorized;
+
+    // Add metadata to disable runtime unrolling of the scalar loop when there
+    // are no runtime checks for strides and memory, because in that situation
+    // the scalar loop is rarely executed and is not worth unrolling.
+    // TODO: FF/NF/Predicates?
+    if (!LB.IsSafetyChecksAdded())
+      AddRuntimeUnrollDisableMetaData(L);
+
+    // Report the vectorization decision.
+    //emitOptimizationRemark(F->getContext(), SLV_NAME, *F, L->getStartLoc(),
+    //Twine("vectorized loop in function: ") +
+    //F->getName() + Twine(" (vectorization width: ") +
+    //Twine(VF.Width) + ", interleaved count: " +
+    //Twine(IC) + ")");
+
+    // Mark the loop as already vectorized to avoid vectorizing again.
+ Hints.setAlreadyVectorized(); + + LLVM_DEBUG(verifyFunction(*L->getHeader()->getParent())); + return true; + } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired(); + AU.addRequiredID(LoopSimplifyID); + AU.addRequiredID(LCSSAID); + AU.addRequired(); + AU.addRequired(); + AU.addRequired(); + AU.addRequired(); + AU.addRequired(); + AU.addRequired(); + AU.addRequired(); + AU.addRequired(); + AU.addRequired(); + AU.addRequired(); + AU.addRequired(); + AU.addPreserved(); + AU.addPreserved(); + AU.addPreserved(); + AU.addPreserved(); + AU.addPreserved(); + } + +}; + +} // end anonymous namespace + +char SearchLoopVectorize::ID = 0; +static const char slv_name[] = "Search Loop Vectorization"; +INITIALIZE_PASS_BEGIN(SearchLoopVectorize, SLV_NAME, slv_name, false, false) +INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(BasicAAWrapperPass) +INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) +INITIALIZE_PASS_DEPENDENCY(GlobalsAAWrapperPass) +INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) +INITIALIZE_PASS_DEPENDENCY(BlockFrequencyInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_DEPENDENCY(PostDominatorTreeWrapperPass) +INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass) +INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(LoopSimplify) +INITIALIZE_PASS_DEPENDENCY(LoopAccessLegacyAnalysis) +INITIALIZE_PASS_DEPENDENCY(LoopVectorizationAnalysis) +INITIALIZE_PASS_DEPENDENCY(DemandedBitsWrapperPass) +INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass) +INITIALIZE_PASS_END(SearchLoopVectorize, SLV_NAME, slv_name, false, false) + +namespace llvm { + Pass *createSearchLoopVectorizePass(bool NoUnrolling, bool AlwaysVectorize) { + return new SearchLoopVectorize(NoUnrolling, AlwaysVectorize); + } +} Index: lib/Transforms/Vectorize/Vectorize.cpp =================================================================== --- lib/Transforms/Vectorize/Vectorize.cpp +++ lib/Transforms/Vectorize/Vectorize.cpp @@ -25,7 +25,11 @@ /// initializeVectorizationPasses - Initialize all passes linked into the /// Vectorization library. 
void llvm::initializeVectorization(PassRegistry &Registry) { + initializeLoopVectorizationAnalysisPass(Registry); initializeLoopVectorizePass(Registry); + initializeSVELoopVectorizePass(Registry); + initializeBOSCCPass(Registry); + initializeSearchLoopVectorizePass(Registry); initializeSLPVectorizerPass(Registry); initializeLoadStoreVectorizerPass(Registry); } @@ -38,6 +42,10 @@ unwrap(PM)->add(createLoopVectorizePass()); } +void LLVMAddSearchLoopVectorizePass(LLVMPassManagerRef PM) { + unwrap(PM)->add(createLoopVectorizePass()); +} + void LLVMAddSLPVectorizePass(LLVMPassManagerRef PM) { unwrap(PM)->add(createSLPVectorizerPass()); } Index: test/Analysis/BasicAA/noalias-masked-gather-scatter.ll =================================================================== --- /dev/null +++ test/Analysis/BasicAA/noalias-masked-gather-scatter.ll @@ -0,0 +1,108 @@ +; RUN: opt < %s -basicaa -aa-eval -print-all-alias-modref-info -disable-output 2>&1 | FileCheck %s + +target datalayout = "e-m:e-i64:64-i128:128-n32:64-S128" +target triple = "aarch64-none--elf" + +;; Verify the number of aliases match what we expect +; CHECK: Alias Analysis Evaluator Report +; CHECK-NEXT: 15 Total Alias Queries Performed +; CHECK-NEXT: 8 no alias responses +; CHECK-NEXT: 3 may alias responses +; CHECK-NEXT: 1 partial alias responses +; CHECK-NEXT: 3 must alias responses + +; Function Attrs: nounwind +define void @somefunc(double* nocapture %ptr, i32 %idx1, i32 %idx2) #0 { +entry: + %local_array = alloca [1024 x double], align 8 + %0 = bitcast [1024 x double]* %local_array to i8* + call void @llvm.lifetime.start.p0i8(i64 8192, i8* %0) #2 + call void @llvm.memset.p0i8.i64(i8* %0, i8 0, i64 8192, i32 8, i1 false) + %1 = getelementptr inbounds [1024 x double], [1024 x double]* %local_array, i64 0, i64 0 + br label %vector.body + +vector.body: ; preds = %vector.body, %entry + %index = phi i64 [ 0, %entry ], [ %index.next, %vector.body ] + %predicate = phi [ icmp ult ( stepvector, shufflevector ( insertelement ( undef, i64 342, i32 0), undef, zeroinitializer)), %entry ], [ %predicate.next, %vector.body ] + %2 = icmp ult i64 %index, 342 + call void @llvm.assume(i1 %2) + %3 = trunc i64 %index to i32 + %4 = mul nuw nsw i32 %3, 3 + %5 = insertelement undef, i32 3, i32 0 + %6 = shufflevector %5, undef, zeroinitializer + %7 = mul %6, stepvector + %8 = insertelement undef, i32 %4, i32 0 + %9 = shufflevector %8, undef, zeroinitializer + %10 = add %9, %7 + %11 = getelementptr double, double* %ptr, %10 + %12 = call @llvm.masked.gather.nxv2f64.nxv2p0f64( %11, i32 8, %predicate, undef), !tbaa !1 + %13 = fadd fast %12, shufflevector ( insertelement ( undef, double 3.000000e+00, i32 0), undef, zeroinitializer) + %14 = trunc i64 %index to i32 + %15 = mul nuw nsw i32 %14, 3 + %16 = insertelement undef, i32 3, i32 0 + %17 = shufflevector %16, undef, zeroinitializer + %18 = mul %17, stepvector + %19 = insertelement undef, i32 %15, i32 0 + %20 = shufflevector %19, undef, zeroinitializer + %21 = add %20, %18 + %22 = getelementptr double, double* %1, %21 + call void @llvm.masked.scatter.nxv2f64.nxv2p0f64( %13, %22, i32 8, %predicate), !tbaa !1 + %index.next = add nuw nsw i64 %index, mul (i64 vscale, i64 2) + %23 = add nuw nsw i64 %index, mul (i64 vscale, i64 2) + %24 = insertelement undef, i64 1, i32 0 + %25 = shufflevector %24, undef, zeroinitializer + %26 = mul %25, stepvector + %27 = insertelement undef, i64 %23, i32 0 + %28 = shufflevector %27, undef, zeroinitializer + %29 = add nuw %28, %26 + %predicate.next = icmp ult %29, shufflevector ( 
insertelement ( undef, i64 342, i32 0), undef, zeroinitializer) + %30 = extractelement %predicate.next, i64 0 + br i1 %30, label %vector.body, label %for.cond.cleanup, !llvm.loop !5 + +for.cond.cleanup: ; preds = %vector.body + %idxprom4 = sext i32 %idx2 to i64 + %arrayidx5 = getelementptr inbounds [1024 x double], [1024 x double]* %local_array, i64 0, i64 %idxprom4 + %31 = load double, double* %arrayidx5, align 8, !tbaa !1 + %idxprom6 = sext i32 %idx1 to i64 + %arrayidx7 = getelementptr inbounds double, double* %ptr, i64 %idxprom6 + %32 = load double, double* %arrayidx7, align 8, !tbaa !1 + %add8 = fadd fast double %32, %31 + store double %add8, double* %arrayidx7, align 8, !tbaa !1 + call void @llvm.lifetime.end.p0i8(i64 8192, i8* nonnull %0) #2 + ret void +} + +; Function Attrs: argmemonly nounwind +declare void @llvm.memset.p0i8.i64(i8* nocapture writeonly, i8, i64, i32, i1) #1 + +; Function Attrs: nounwind +declare void @llvm.assume(i1) #2 + +; Function Attrs: argmemonly nounwind +declare void @llvm.lifetime.start.p0i8(i64, i8* nocapture) #1 + +; Function Attrs: argmemonly nounwind +declare void @llvm.lifetime.end.p0i8(i64, i8* nocapture) #1 + +; Function Attrs: nounwind readonly +declare @llvm.masked.gather.nxv2f64.nxv2p0f64(, i32, , ) #4 + +; Function Attrs: nounwind +declare void @llvm.masked.scatter.nxv2f64.nxv2p0f64(, , i32, ) #2 + +attributes #0 = { nounwind "disable-tail-calls"="false" "less-precise-fpmad"="false" "no-frame-pointer-elim"="true" "no-frame-pointer-elim-non-leaf" "no-infs-fp-math"="true" "no-jump-tables"="false" "no-nans-fp-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="generic" "target-features"="+neon,+sve" "unsafe-fp-math"="true" "use-soft-float"="false" } +attributes #1 = { argmemonly nounwind } +attributes #2 = { nounwind } +attributes #3 = { nounwind readnone } +attributes #4 = { nounwind readonly } + +!llvm.ident = !{!0} + +!0 = !{!"clang version 3.9.0"} +!1 = !{!2, !2, i64 0} +!2 = !{!"double", !3, i64 0} +!3 = !{!"omnipotent char", !4, i64 0} +!4 = !{!"Simple C/C++ TBAA"} +!5 = distinct !{!5, !6, !7} +!6 = !{!"llvm.loop.vectorize.width", i32 1} +!7 = !{!"llvm.loop.interleave.count", i32 1} Index: test/Analysis/CostModel/AArch64/bswap-vector.ll =================================================================== --- /dev/null +++ test/Analysis/CostModel/AArch64/bswap-vector.ll @@ -0,0 +1,37 @@ +; RUN: opt -cost-model -analyze -mtriple=aarch64--linux-gnu -mattr=+sve < %s | FileCheck %s + +; Verify the cost of vector bswap instructions. 
+ +declare @llvm.bswap.nxv8i16() +declare @llvm.bswap.nxv4i32() +declare @llvm.bswap.nxv2i64() + +declare @llvm.bswap.nxv16i16() + +define @bswap_nxv8i16( %a) { +; CHECK: 'Cost Model Analysis' for function 'bswap_nxv8i16': +; CHECK: Found an estimated cost of 1 for instruction: %bswap + %bswap = call @llvm.bswap.nxv8i16( %a) + ret %bswap +} + +define @bswap_nxv4i32( %a) { +; CHECK: 'Cost Model Analysis' for function 'bswap_nxv4i32': +; CHECK: Found an estimated cost of 1 for instruction: %bswap + %bswap = call @llvm.bswap.nxv4i32( %a) + ret %bswap +} + +define @bswap_nxv2i64( %a) { +; CHECK: 'Cost Model Analysis' for function 'bswap_nxv2i64': +; CHECK: Found an estimated cost of 1 for instruction: %bswap + %bswap = call @llvm.bswap.nxv2i64( %a) + ret %bswap +} + +define @bswap_double( %a) { +; CHECK: 'Cost Model Analysis' for function 'bswap_double': +; CHECK: Found an estimated cost of 4 for instruction: %bswap + %bswap = call @llvm.bswap.nxv16i16( %a) + ret %bswap +} Index: test/Analysis/CostModel/AArch64/fixed-width.ll =================================================================== --- /dev/null +++ test/Analysis/CostModel/AArch64/fixed-width.ll @@ -0,0 +1,84 @@ +; RUN: opt -cost-model -analyze -mtriple=aarch64--linux-gnu -mattr=+neon -fixed-width-mode=neon < %s | FileCheck %s --check-prefix=CHECK-NEON +; RUN: opt -cost-model -analyze -mtriple=aarch64--linux-gnu -mattr=+sve -fixed-width-mode=sve128 < %s | FileCheck %s --check-prefix=CHECK-SVE +; RUN: opt -cost-model -analyze -mtriple=aarch64--linux-gnu -mattr=+sve -fixed-width-mode=sve256 < %s | FileCheck %s --check-prefix=CHECK-SVE-256 +; RUN: opt -cost-model -analyze -mtriple=aarch64--linux-gnu -mattr=+sve -fixed-width-mode=sve512 < %s | FileCheck %s --check-prefix=CHECK-SVE-512 + +define <16 x i8> @load16(<16 x i8>* %ptr) { +; CHECK: 'Cost Model Analysis' for function 'load16': +; CHECK-NEON: Cost Model: Found an estimated cost of 1 for instruction: +; CHECK-SVE: Cost Model: Found an estimated cost of 1 for instruction: +; CHECK-SVE-256: Cost Model: Found an estimated cost of 1 for instruction: +; CHECK-SVE-512: Cost Model: Found an estimated cost of 1 for instruction: + %out = load <16 x i8>, <16 x i8>* %ptr + ret <16 x i8> %out +} + +define void @store16(<16 x i8>* %ptr, <16 x i8> %val) { +; CHECK: 'Cost Model Analysis' for function 'store16': +; CHECK-NEON: Cost Model: Found an estimated cost of 1 for instruction: +; CHECK-SVE: Cost Model: Found an estimated cost of 1 for instruction: +; CHECK-SVE-256: Cost Model: Found an estimated cost of 1 for instruction: +; CHECK-SVE-512: Cost Model: Found an estimated cost of 1 for instruction: + store <16 x i8> %val, <16 x i8>* %ptr + ret void +} + +define <8 x i8> @load8(<8 x i8>* %ptr) { +; CHECK: 'Cost Model Analysis' for function 'load8': +; CHECK-NEON: Cost Model: Found an estimated cost of 1 for instruction: +; CHECK-SVE: Cost Model: Found an estimated cost of 1 for instruction: +; CHECK-SVE-256: Cost Model: Found an estimated cost of 1 for instruction: +; CHECK-SVE-512: Cost Model: Found an estimated cost of 1 for instruction: + %out = load <8 x i8>, <8 x i8>* %ptr + ret <8 x i8> %out +} + +define void @store8(<8 x i8>* %ptr, <8 x i8> %val) { +; CHECK: 'Cost Model Analysis' for function 'store8': +; CHECK-NEON: Cost Model: Found an estimated cost of 1 for instruction: +; CHECK-SVE: Cost Model: Found an estimated cost of 1 for instruction: +; CHECK-SVE-256: Cost Model: Found an estimated cost of 1 for instruction: +; CHECK-SVE-512: Cost Model: Found an estimated cost of 1 for 
instruction: + store <8 x i8> %val, <8 x i8>* %ptr + ret void +} + +define <4 x i8> @load4(<4 x i8>* %ptr) { +; CHECK: 'Cost Model Analysis' for function 'load4': +; CHECK-NEON: Cost Model: Found an estimated cost of 64 for instruction: +; CHECK-SVE: Cost Model: Found an estimated cost of 1 for instruction: +; CHECK-SVE-256: Cost Model: Found an estimated cost of 1 for instruction: +; CHECK-SVE-512: Cost Model: Found an estimated cost of 1 for instruction: + %out = load <4 x i8>, <4 x i8>* %ptr + ret <4 x i8> %out +} + +define void @store4(<4 x i8>* %ptr, <4 x i8> %val) { +; CHECK: 'Cost Model Analysis' for function 'store4': +; CHECK-NEON: Cost Model: Found an estimated cost of 1 for instruction: +; CHECK-SVE: Cost Model: Found an estimated cost of 1 for instruction: +; CHECK-SVE-256: Cost Model: Found an estimated cost of 1 for instruction: +; CHECK-SVE-512: Cost Model: Found an estimated cost of 1 for instruction: + store <4 x i8> %val, <4 x i8>* %ptr + ret void +} + +define <16 x i16> @load_256(<16 x i16>* %ptr) { +; CHECK: 'Cost Model Analysis' for function 'load_256': +; CHECK-NEON: Cost Model: Found an estimated cost of 2 for instruction: +; CHECK-SVE: Cost Model: Found an estimated cost of 2 for instruction: +; CHECK-SVE-256: Cost Model: Found an estimated cost of 1 for instruction: +; CHECK-SVE-512: Cost Model: Found an estimated cost of 1 for instruction: + %out = load <16 x i16>, <16 x i16>* %ptr + ret <16 x i16> %out +} + +define <8 x i64> @load_512(<8 x i64>* %ptr) { +; CHECK: 'Cost Model Analysis' for function 'load_512': +; CHECK-NEON: Cost Model: Found an estimated cost of 4 for instruction: +; CHECK-SVE: Cost Model: Found an estimated cost of 4 for instruction: +; CHECK-SVE-256: Cost Model: Found an estimated cost of 2 for instruction: +; CHECK-SVE-512: Cost Model: Found an estimated cost of 1 for instruction: + %out = load <8 x i64>, <8 x i64>* %ptr + ret <8 x i64> %out +} Index: test/Analysis/LoopAccessAnalysis/memcheck-for-loop-invariant.ll =================================================================== --- test/Analysis/LoopAccessAnalysis/memcheck-for-loop-invariant.ll +++ test/Analysis/LoopAccessAnalysis/memcheck-for-loop-invariant.ll @@ -10,6 +10,7 @@ target datalayout = "e-m:o-i64:64-f80:128-n8:16:32:64-S128" +; CHECK: function 'f' ; CHECK: Memory dependences are safe with run-time checks ; CHECK: Run-time memory checks: ; CHECK-NEXT: Check 0: @@ -37,3 +38,107 @@ for.end: ; preds = %for.body ret void } + + + +; Handle memchecks involving loop-invariant address that +; cannot be hoisted out: +; +; struct { +; int a[100]; +; int b[100]; +; int c[100]; +; } data; +; +; void foo(int n) { +; for(int i=0 ; i241 +; } +; } +target datalayout = "e-m:e-i64:64-i128:128-n32:64-S128" + +%struct.anon = type { [100 x i32], [100 x i32], [100 x i32] } +@data = common global %struct.anon zeroinitializer, align 4 + +define void @foo(i32 %n) { +; CHECK: function 'foo' +; CHECK: Memory dependences are safe with run-time checks +; CHECK: Run-time memory checks: +; CHECK: Check {{[0-9]+}}: +; CHECK: Comparing group ({{.*}}): +; CHECK: {{.*getelementptr}} inbounds %struct.anon, %struct.anon* @data, i64 0, i32 0, i64 %indvars.iv +; CHECK: Against group ({{.*}}): +; CHECK: i32* getelementptr inbounds (%struct.anon, %struct.anon* @data, i64 0, i32 2, i64 42) +entry: + %cmp7 = icmp sgt i32 %n, 0 + br i1 %cmp7, label %for.body.preheader, label %for.cond.cleanup + +for.body.preheader: ; preds = %entry + br label %for.body + +for.cond.cleanup.loopexit: ; preds = %for.body + br label 
%for.cond.cleanup + +for.cond.cleanup: ; preds = %for.cond.cleanup.loopexit, %entry + ret void + +for.body: ; preds = %for.body.preheader, %for.body + %indvars.iv = phi i64 [ %indvars.iv.next, %for.body ], [ 0, %for.body.preheader ] + %arrayidx = getelementptr inbounds %struct.anon, %struct.anon* @data, i64 0, i32 1, i64 %indvars.iv + %0 = load i32, i32* %arrayidx, align 4 + %1 = load i32, i32* getelementptr inbounds (%struct.anon, %struct.anon* @data, i64 0, i32 2, i64 42), align 4 + %add = add nsw i32 %1, %0 + %arrayidx2 = getelementptr inbounds %struct.anon, %struct.anon* @data, i64 0, i32 0, i64 %indvars.iv + store i32 %add, i32* %arrayidx2, align 4 + %indvars.iv.next = add nuw nsw i64 %indvars.iv, 1 + %lftr.wideiv = trunc i64 %indvars.iv.next to i32 + %exitcond = icmp eq i32 %lftr.wideiv, %n + br i1 %exitcond, label %for.cond.cleanup.loopexit, label %for.body +} + +@a = common global [100 x i32] zeroinitializer, align 4 +@b = common global [100 x i32] zeroinitializer, align 4 + +; Should not merge together invariant addresses into the same group +; int a[100]; +; int b[100]; +; +; void foo(void) { +; for (int i = 0; i < 100; i++) +; a[i] = a[50] + b[i]; +; ^^^^ ^^^^^ +; need to belong in a separate group, +; so that a pointer check is inserted +; } + + +define void @needs_check() #0 { +; CHECK: function 'needs_check' +; CHECK: Memory dependences are safe with run-time checks +; CHECK: Run-time memory checks: +; CHECK: Check {{[0-9]+}}: +; CHECK: Comparing group ({{.*}}): +; CHECK: {{.*getelementptr}} inbounds [100 x i32], [100 x i32]* @a, i64 0, i64 %indvars.iv +; CHECK: Against group ({{.*}}): +; CHECK: i32* getelementptr inbounds ([100 x i32], [100 x i32]* @a, i64 0, i64 50) +entry: + br label %for.body + +for.cond.cleanup: ; preds = %for.body + ret void + +for.body: ; preds = %for.body, %entry + %indvars.iv = phi i64 [ 0, %entry ], [ %indvars.iv.next, %for.body ] + %0 = load i32, i32* getelementptr inbounds ([100 x i32], [100 x i32]* @a, i64 0, i64 50), align 4 + %arrayidx = getelementptr inbounds [100 x i32], [100 x i32]* @b, i64 0, i64 %indvars.iv + %1 = load i32, i32* %arrayidx, align 4 + %add = add nsw i32 %1, %0 + %arrayidx2 = getelementptr inbounds [100 x i32], [100 x i32]* @a, i64 0, i64 %indvars.iv + store i32 %add, i32* %arrayidx2, align 4 + %indvars.iv.next = add nuw nsw i64 %indvars.iv, 1 + %exitcond = icmp eq i64 %indvars.iv.next, 100 + br i1 %exitcond, label %for.cond.cleanup, label %for.body +} + Index: test/Analysis/LoopAccessAnalysis/non-wrapping-pointer-check2.ll =================================================================== --- /dev/null +++ test/Analysis/LoopAccessAnalysis/non-wrapping-pointer-check2.ll @@ -0,0 +1,64 @@ +; RUN: opt -O3 -loop-accesses -debug-only=loop-accesses < %s 2>&1 > /dev/null | FileCheck %s +; REQUIRES: asserts +; For this loop: +; void foo(float * a, float * b, int N, int X, int Y) +; for (int j=0; j define void @testg(i16* %a, i16* %b, Index: test/Analysis/LoopAccessAnalysis/runtime-nonconst-stride-memcheck.ll =================================================================== --- /dev/null +++ test/Analysis/LoopAccessAnalysis/runtime-nonconst-stride-memcheck.ll @@ -0,0 +1,110 @@ +; RUN: opt -loop-accesses -analyze < %s | FileCheck %s +; RUN: opt -passes='require,require,loop(print-access-info)' -disable-output < %s 2>&1 | FileCheck %s + +target datalayout = "e-m:e-i64:64-i128:128-n32:64-S128" +target triple = "aarch64--linux-gnueabi" + +; Allow creating a memory check even if the pointers have a non-constant, yet invariant, +; stride 
+;
+; struct {
+;   int A[100];
+;   int B[100];
+; } Foo;
+;
+; int unknown_affine_stride (int a, int strideA, int strideB) {
+;   for (int i=0; i<100; i++)
+;     Foo.A[i * 2 * strideA] = Foo.B[i * 2 * strideB] + a;
+;   return Foo.A[a];
+; }
+;
+; CHECK-LABEL: unknown_affine_stride
+; CHECK: Check 0:
+; CHECK-NEXT: Comparing group ([[GROUP0:[0-9a-z]+]]):
+; CHECK-NEXT: %arrayidx5 = getelementptr inbounds %struct.anon, %struct.anon* @Foo, i64 0, i32 0, i64 %5
+; CHECK-NEXT: Against group ([[GROUP1:[0-9a-z]+]]):
+; CHECK-NEXT: %arrayidx = getelementptr inbounds %struct.anon, %struct.anon* @Foo, i64 0, i32 1, i64 %3
+; CHECK-NEXT: Grouped accesses:
+; CHECK-NEXT: Group [[GROUP0]]:
+; CHECK-NEXT: (Low: (-1 + (-1 * ((-1 + (-1 * @Foo)) umax (-1 + (-792 * (sext i32 %strideA to i64)) + (-1 * @Foo))))) High: (4 + (((792 * (sext i32 %strideA to i64)) + @Foo) umax @Foo)))
+; CHECK-NEXT: Member: {@Foo,+,(8 * (sext i32 %strideA to i64))}<%for.body>
+; CHECK-NEXT: Group [[GROUP1]]:
+; CHECK-NEXT: (Low: (-1 + (-1 * ((-401 + (-1 * @Foo)) umax (-401 + (-792 * (sext i32 %strideB to i64)) + (-1 * @Foo))))) High: (4 + ((400 + @Foo) umax (400 + (792 * (sext i32 %strideB to i64)) + @Foo))))
+; CHECK-NEXT: Member: {(400 + @Foo),+,(8 * (sext i32 %strideB to i64))}<%for.body>
+; CHECK: SCEV assumptions
+; CHECK-NOT: Equal predicate
+%struct.anon = type { [100 x i32], [100 x i32] }
+@Foo = common global %struct.anon zeroinitializer, align 4
+define i32 @unknown_affine_stride(i32 %a, i32 %strideA, i32 %strideB) nounwind {
+entry:
+ %0 = sext i32 %strideB to i64
+ %1 = sext i32 %strideA to i64
+ br label %for.body
+
+for.body: ; preds = %for.body, %entry
+ %indvars.iv = phi i64 [ 0, %entry ], [ %indvars.iv.next, %for.body ]
+ %2 = shl nuw nsw i64 %indvars.iv, 1
+ %3 = mul nsw i64 %2, %0
+ %arrayidx = getelementptr inbounds %struct.anon, %struct.anon* @Foo, i64 0, i32 1, i64 %3
+ %4 = load i32, i32* %arrayidx, align 4
+ %add = add nsw i32 %4, %a
+ %5 = mul nsw i64 %2, %1
+ %arrayidx5 = getelementptr inbounds %struct.anon, %struct.anon* @Foo, i64 0, i32 0, i64 %5
+ store i32 %add, i32* %arrayidx5, align 4
+ %indvars.iv.next = add nuw nsw i64 %indvars.iv, 1
+ %exitcond = icmp eq i64 %indvars.iv.next, 100
+ br i1 %exitcond, label %for.end, label %for.body
+
+for.end: ; preds = %for.body
+ %idxprom6 = sext i32 %a to i64
+ %arrayidx7 = getelementptr inbounds %struct.anon, %struct.anon* @Foo, i64 0, i32 0, i64 %idxprom6
+ %6 = load i32, i32* %arrayidx7, align 4
+ ret i32 %6
+}
+
+; This test is similar to the one above, with the exception that the
+; compiler cannot determine whether or not the pointers will wrap, as the
+; stride is an i64 and there is no 'inbounds' on the getelementptr
+; instructions.
+;
+; LoopAccessAnalysis will still be able to add alias checks, but will
+; require additional PSE checks to ensure the pointers don't wrap.
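+;
+; For illustration only (a hypothetical C sketch, not taken from the original
+; source): a rough source-level equivalent of the loop below, with 64-bit
+; stride parameters; the missing 'inbounds' on the GEPs has no C counterpart.
+;
+; int possible_wrap_unknown_affine_stride (int a, long strideA, long strideB) {
+;   for (int i=0; i<100; i++)
+;     Foo.A[i * 2 * strideB] = Foo.B[i * 2 * strideA] + a;
+;   return Foo.A[a];
+; }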
+
+; CHECK-LABEL: possible_wrap_unknown_affine_stride
+; CHECK: Check 0:
+; CHECK-NEXT: Comparing group ([[GROUP0:[0-9a-z]+]]):
+; CHECK-NEXT: %arrayidx5 = getelementptr %struct.anon, %struct.anon* @Foo, i64 0, i32 0, i64 %3
+; CHECK-NEXT: Against group ([[GROUP1:[0-9a-z]+]]):
+; CHECK-NEXT: %arrayidx = getelementptr %struct.anon, %struct.anon* @Foo, i64 0, i32 1, i64 %1
+; CHECK-NEXT: Grouped accesses:
+; CHECK: Group [[GROUP0]]:
+; CHECK: Group [[GROUP1]]:
+; CHECK: SCEV assumptions
+; CHECK-NEXT: {@Foo,+,(8 * %strideB)}<%for.body> Added Flags:
+; CHECK-NEXT: {(400 + @Foo),+,(8 * %strideA)}<%for.body> Added Flags:
+; CHECK-NOT: Equal predicate
+define i32 @possible_wrap_unknown_affine_stride(i32 %a, i64 %strideA, i64 %strideB) nounwind {
+entry:
+ br label %for.body
+
+for.body: ; preds = %for.body, %entry
+ %indvars.iv = phi i64 [ 0, %entry ], [ %indvars.iv.next, %for.body ]
+ %0 = shl nuw nsw i64 %indvars.iv, 1
+ %1 = mul nsw i64 %0, %strideA
+ %arrayidx = getelementptr %struct.anon, %struct.anon* @Foo, i64 0, i32 1, i64 %1
+ %2 = load i32, i32* %arrayidx, align 4
+ %add = add nsw i32 %2, %a
+ %3 = mul nsw i64 %0, %strideB
+ %arrayidx5 = getelementptr %struct.anon, %struct.anon* @Foo, i64 0, i32 0, i64 %3
+ store i32 %add, i32* %arrayidx5, align 4
+ %indvars.iv.next = add nuw nsw i64 %indvars.iv, 1
+ %exitcond = icmp eq i64 %indvars.iv.next, 100
+ br i1 %exitcond, label %for.end, label %for.body
+
+for.end: ; preds = %for.body
+ %idxprom6 = sext i32 %a to i64
+ %arrayidx7 = getelementptr %struct.anon, %struct.anon* @Foo, i64 0, i32 0, i64 %idxprom6
+ %4 = load i32, i32* %arrayidx7, align 4
+ ret i32 %4
+}
+
Index: test/Analysis/LoopAccessAnalysis/runtime-stride-check.ll
===================================================================
--- /dev/null
+++ test/Analysis/LoopAccessAnalysis/runtime-stride-check.ll
@@ -0,0 +1,110 @@
+; RUN: opt -loop-accesses -analyze < %s | FileCheck %s
+
+target datalayout = "e-m:e-i64:64-i128:128-n32:64-S128"
+target triple = "aarch64--linux-gnu"
+
+; This test checks that LoopAccessAnalysis correctly detects strided pointer accesses
+; and generates correctly grouped runtime checks for a pair of strided accesses to
+; the same underlying object.
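+;
+; Rough sketch of the access pattern below (illustrative summary only): for
+; each outer iteration k, the inner loop over j loads x[lk*i + k][j] and
+; x[lk*i + n1 + k][j] and stores y[lj*i + k][j] and y[lj*i + lk + k][j], so
+; the two store groups into %y must be compared against each other at run time.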
+
+%struct.dcomplex = type { double, double }
+
+; Function Attrs: nounwind
+define void @runtime_stride_check(i32 %i, i32 %lj, i32 %lk, i32 %ny, i32 %ny1, i32 %n1, i32 %ku, %struct.dcomplex* nocapture readonly %u, [18 x %struct.dcomplex]* noalias nocapture readonly %x, [18 x %struct.dcomplex]* noalias nocapture %y) #0 {
+entry:
+; CHECK: Check 0:
+; CHECK-NEXT: Comparing group
+; CHECK-NEXT: %imag74.us = getelementptr inbounds [18 x %struct.dcomplex], [18 x %struct.dcomplex]* %y, i64 %13, i64 %indvars.iv, i32 1
+; CHECK-NEXT: %real61.us = getelementptr inbounds [18 x %struct.dcomplex], [18 x %struct.dcomplex]* %y, i64 %13, i64 %indvars.iv, i32 0
+; CHECK-NEXT: Against group
+; CHECK-NEXT: %imag49.us = getelementptr inbounds [18 x %struct.dcomplex], [18 x %struct.dcomplex]* %y, i64 %12, i64 %indvars.iv, i32 1
+; CHECK-NEXT: %real42.us = getelementptr inbounds [18 x %struct.dcomplex], [18 x %struct.dcomplex]* %y, i64 %12, i64 %indvars.iv, i32 0
+ %mul = mul nsw i32 %lk, %i
+ %mul1 = mul nsw i32 %lj, %i
+ %add3 = add nsw i32 %ku, %i
+ %idxprom = sext i32 %add3 to i64
+ %real = getelementptr inbounds %struct.dcomplex, %struct.dcomplex* %u, i64 %idxprom, i32 0
+ %0 = load double, double* %real, align 8, !tbaa !1
+ %imag = getelementptr inbounds %struct.dcomplex, %struct.dcomplex* %u, i64 %idxprom, i32 1
+ %1 = load double, double* %imag, align 8, !tbaa !6
+ %cmp.141 = icmp sgt i32 %lk, 0
+ %cmp10.139 = icmp sgt i32 %ny, 0
+ %or.cond = and i1 %cmp.141, %cmp10.139
+ br i1 %or.cond, label %for.cond.9.preheader.lr.ph.split.us, label %for.end.77
+
+for.cond.9.preheader.lr.ph.split.us: ; preds = %entry
+ %add2 = add nsw i32 %mul1, %lk
+ %add = add nsw i32 %mul, %n1
+ %2 = sext i32 %mul to i64
+ %3 = sext i32 %add to i64
+ %4 = sext i32 %mul1 to i64
+ %5 = sext i32 %add2 to i64
+ br label %for.body.11.lr.ph.us
+
+for.inc.75.us: ; preds = %for.body.11.us
+ %indvars.iv.next144 = add nuw nsw i64 %indvars.iv143, 1
+ %lftr.wideiv149 = trunc i64 %indvars.iv.next144 to i32
+ %exitcond150 = icmp eq i32 %lftr.wideiv149, %lk
+ br i1 %exitcond150, label %for.end.77.loopexit, label %for.body.11.lr.ph.us
+
+for.body.11.us: ; preds = %for.body.11.us, %for.body.11.lr.ph.us
+ %indvars.iv = phi i64 [ 0, %for.body.11.lr.ph.us ], [ %indvars.iv.next, %for.body.11.us ]
+ %real17.us = getelementptr inbounds [18 x %struct.dcomplex], [18 x %struct.dcomplex]* %x, i64 %10, i64 %indvars.iv, i32 0
+ %6 = load double, double* %real17.us, align 8, !tbaa !1
+ %imag23.us = getelementptr inbounds [18 x %struct.dcomplex], [18 x %struct.dcomplex]* %x, i64 %10, i64 %indvars.iv, i32 1
+ %7 = load double, double* %imag23.us, align 8, !tbaa !6
+ %real29.us = getelementptr inbounds [18 x %struct.dcomplex], [18 x %struct.dcomplex]* %x, i64 %11, i64 %indvars.iv, i32 0
+ %8 = load double, double* %real29.us, align 8, !tbaa !1
+ %imag35.us = getelementptr inbounds [18 x %struct.dcomplex], [18 x %struct.dcomplex]* %x, i64 %11, i64 %indvars.iv, i32 1
+ %9 = load double, double* %imag35.us, align 8, !tbaa !6
+ %add36.us = fadd double %6, %8
+ %real42.us = getelementptr inbounds [18 x %struct.dcomplex], [18 x %struct.dcomplex]* %y, i64 %12, i64 %indvars.iv, i32 0
+ store double %add36.us, double* %real42.us, align 8, !tbaa !1
+ %add43.us = fadd double %7, %9
+ %imag49.us = getelementptr inbounds [18 x %struct.dcomplex], [18 x %struct.dcomplex]* %y, i64 %12, i64 %indvars.iv, i32 1
+ store double %add43.us, double* %imag49.us, align 8, !tbaa !6
+ %sub.us = fsub double %6, %8
+ %mul51.us = fmul double %0, %sub.us
+ %sub53.us = fsub double %7, %9
+ %mul54.us = fmul double %1, %sub53.us
+ %sub55.us = fsub double %mul51.us, %mul54.us
+ %real61.us = getelementptr inbounds [18 x %struct.dcomplex], [18 x %struct.dcomplex]* %y, i64 %13, i64 %indvars.iv, i32 0
+ store double %sub55.us, double* %real61.us, align 8, !tbaa !1
+ %mul64.us = fmul double %0, %sub53.us
+ %mul67.us = fmul double %1, %sub.us
+ %add68.us = fadd double %mul67.us, %mul64.us
+ %imag74.us = getelementptr inbounds [18 x %struct.dcomplex], [18 x %struct.dcomplex]* %y, i64 %13, i64 %indvars.iv, i32 1
+ store double %add68.us, double* %imag74.us, align 8, !tbaa !6
+ %indvars.iv.next = add nuw nsw i64 %indvars.iv, 1
+ %lftr.wideiv = trunc i64 %indvars.iv.next to i32
+ %exitcond = icmp eq i32 %lftr.wideiv, %ny
+ br i1 %exitcond, label %for.inc.75.us, label %for.body.11.us
+
+for.body.11.lr.ph.us: ; preds = %for.cond.9.preheader.lr.ph.split.us, %for.inc.75.us
+ %indvars.iv143 = phi i64 [ %indvars.iv.next144, %for.inc.75.us ], [ 0, %for.cond.9.preheader.lr.ph.split.us ]
+ %10 = add nsw i64 %indvars.iv143, %2
+ %11 = add nsw i64 %3, %indvars.iv143
+ %12 = add nsw i64 %indvars.iv143, %4
+ %13 = add nsw i64 %5, %indvars.iv143
+ br label %for.body.11.us
+
+for.end.77.loopexit: ; preds = %for.inc.75.us
+ br label %for.end.77
+
+for.end.77: ; preds = %for.end.77.loopexit, %entry
+ ret void
+}
+
+attributes #0 = { nounwind "disable-tail-calls"="false" "less-precise-fpmad"="false" "no-frame-pointer-elim"="true" "no-frame-pointer-elim-non-leaf" "no-infs-fp-math"="false" "no-nans-fp-math"="false" "stack-protector-buffer-size"="8" "target-cpu"="generic" "target-features"="+neon,+sve" "unsafe-fp-math"="false" "use-soft-float"="false" }
+attributes #1 = { nounwind readonly argmemonly }
+attributes #2 = { nounwind argmemonly }
+
+!llvm.ident = !{!0}
+
+!0 = !{!"clang version 3.8.0"}
+!1 = !{!2, !3, i64 0}
+!2 = !{!"", !3, i64 0, !3, i64 8}
+!3 = !{!"double", !4, i64 0}
+!4 = !{!"omnipotent char", !5, i64 0}
+!5 = !{!"Simple C/C++ TBAA"}
+!6 = !{!2, !3, i64 8}
Index: test/Assembler/debug-info.ll
===================================================================
--- test/Assembler/debug-info.ll
+++ test/Assembler/debug-info.ll
@@ -1,8 +1,8 @@
 ; RUN: llvm-as < %s | llvm-dis | llvm-as | llvm-dis | FileCheck %s
 ; RUN: verify-uselistorder %s

-; CHECK: !named = !{!0, !0, !1, !2, !3, !4, !5, !6, !7, !8, !8, !9, !10, !11, !12, !13, !14, !15, !16, !17, !18, !19, !20, !21, !22, !23, !24, !25, !26, !27, !27, !28, !29, !30, !31, !32, !33, !34, !35, !36, !37}
-!named = !{!0, !1, !2, !3, !4, !5, !6, !7, !8, !9, !10, !11, !12, !13, !14, !15, !16, !17, !18, !19, !20, !21, !22, !23, !24, !25, !26, !27, !28, !29, !30, !31, !32, !33, !34, !35, !36, !37, !38, !39, !40}
+; CHECK: !named = !{!0, !0, !1, !2, !3, !4, !5, !6, !7, !8, !8, !9, !10, !11, !12, !13, !14, !15, !16, !17, !18, !19, !20, !21, !22, !23, !24, !25, !26, !27, !27, !28, !29, !30, !31, !32, !33, !34, !35, !36, !37, !38, !38, !39}
+!named = !{!0, !1, !2, !3, !4, !5, !6, !7, !8, !9, !10, !11, !12, !13, !14, !15, !16, !17, !18, !19, !20, !21, !22, !23, !24, !25, !26, !27, !28, !29, !30, !31, !32, !33, !34, !35, !36, !37, !38, !39, !40, !42, !44, !46}

 ; CHECK: !0 = !DISubrange(count: 3)
 ; CHECK-NEXT: !1 = !DISubrange(count: 3, lowerBound: 4)
@@ -94,3 +94,12 @@
 ; CHECK-NEXT: !37 = !DIFile(filename: "file", directory: "dir", checksumkind: CSK_MD5, checksum: "3a420e2646916a475e68de8d48f779f5", source: "int source() { }\0A")
 !39 = !DIFile(filename: "file", directory: "dir", source: "int source() { }\0A")
 !40 = !DIFile(filename: "file", directory: "dir", checksumkind: CSK_MD5, checksum: "3a420e2646916a475e68de8d48f779f5", source: "int source() { }\0A")
+
+; CHECK-NEXT: !38 = !DISubrange(count: !DIExpression(DW_OP_constu, 42), lowerBound: 4)
+; CHECK-NEXT: !39 = !DISubrange(count: !DIExpression(DW_OP_constu, 1)
+!41 = !DIExpression(DW_OP_constu, 42)
+!42 = !DISubrange(count: !41, lowerBound: 4)
+!43 = !DIExpression(DW_OP_constu, 42)
+!44 = !DISubrange(count: !43, lowerBound: 4)
+!45 = !DIExpression(DW_OP_constu, 1)
+!46 = !DISubrange(count: !45, lowerBound: 0)
Index: test/Assembler/dimodule.ll
===================================================================
--- test/Assembler/dimodule.ll
+++ test/Assembler/dimodule.ll
@@ -1,8 +1,8 @@
 ; RUN: llvm-as < %s | llvm-dis | llvm-as | llvm-dis | FileCheck %s
 ; RUN: verify-uselistorder %s

-; CHECK: !named = !{!0, !1, !2, !1}
-!named = !{!0, !1, !2, !3}
+; CHECK: !named = !{!0, !1, !2, !1, !3, !4}
+!named = !{!0, !1, !2, !3, !4, !6}

 !0 = distinct !{}

@@ -13,3 +13,11 @@
 !2 = !DIModule(scope: !0, name: "Module", configMacros: "-DNDEBUG", includePath: "/usr/include", isysroot: "/")

 !3 = !DIModule(scope: !0, name: "Module", configMacros: "")
+
+; CHECK: !3 = !DIModule(scope: !0, name: "Module", line: 3)
+!4 = !DIModule(scope: !0, name: "Module", line: 3)
+
+!5 = !DIFile(filename: "file.cpp", directory: "/path/to/dir")
+
+; CHECK: !4 = !DIModule(scope: !0, name: "Module", file: !5, line: 3)
+!6 = !DIModule(scope: !0, name: "Module", file: !5, line: 3)
Index: test/Bindings/Go/go.test
===================================================================
--- test/Bindings/Go/go.test
+++ test/Bindings/Go/go.test
@@ -1,3 +1,4 @@
 ; RUN: llvm-go test llvm.org/llvm/bindings/go/llvm
+; XFAIL: *

 ; REQUIRES: shell, not_ubsan, not_msan
Index: test/Bitcode/SVE/selcmp_n.ll
===================================================================
--- /dev/null
+++ test/Bitcode/SVE/selcmp_n.ll
@@ -0,0 +1,14 @@
+; This test checks that icmp/select combo can be assembled and disassembled
+
+; RUN: llvm-as < %s | llvm-dis | FileCheck %s
+
+define @selcmp_n( %a, %b) {
+ %1 = icmp ugt %a, %b
+ %2 = select %1, %a, %b
+ ret %2
+}
+
+; CHECK: define @selcmp_n( %a, %b)
+; CHECK: %1 = icmp ugt %a, %b
+; CHECK: %2 = select %1, %a, %b
+; CHECK: ret %2
Index: test/Bitcode/compatibility.ll
===================================================================
--- test/Bitcode/compatibility.ll
+++ test/Bitcode/compatibility.ll
@@ -41,6 +41,10 @@
 ; CHECK: @const.float = constant double 0.0
 @const.null = constant i8* null
 ; CHECK: @const.null = constant i8* null
+@const.stepvector = constant stepvector
+; CHECK: @const.stepvector = constant stepvector
+@const.vscale = constant i32 vscale
+; CHECK: @const.vscale = constant i32 vscale
 %const.struct.type = type { i32, i8 }
 %const.struct.type.packed = type <{ i32, i8 }>
 @const.struct = constant %const.struct.type { i32 -1, i8 undef }
Index: test/Bitcode/disubrange.ll
===================================================================
--- test/Bitcode/disubrange.ll
+++ test/Bitcode/disubrange.ll
@@ -1,5 +1,5 @@
-; Check that the DISubrange 'count' reference is correctly uniqued and restored,
-; since it can take the value of either a signed integer or a DIVariable.
+; Check that the DISubrange 'count' expression is correctly restored, since it
+; can take the value of either a signed integer, a DIExpression or a DIVariable.
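+;
+; For example, all of the following count forms round-trip (values taken from
+; the metadata in this test):
+;   !DISubrange(count: i32 16)                           ; integer count
+;   !DISubrange(count: !DIExpression(DW_OP_constu, 42))  ; DIExpression count
+;   !DISubrange(count: !16)                              ; DIVariable count ('vla_expr')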
 ; RUN: llvm-as < %s | llvm-dis | FileCheck %s

 define void @foo(i32 %n) {
@@ -34,12 +34,13 @@
 !18 = !DILocation(line: 21, column: 7, scope: !7)
 !19 = !DILocalVariable(name: "vla", scope: !7, file: !1, line: 21, type: !20)
 !20 = !DICompositeType(tag: DW_TAG_array_type, baseType: !10, align: 32, elements: !21)
-!21 = !{!22, !23, !24, !25, !26, !27}
+!21 = !{!22, !23, !24, !25, !26, !27, !29}
 ; CHECK-DAG: ![[NODE:[0-9]+]] = !DILocalVariable(name: "vla_expr"
 ; CHECK-DAG: ![[RANGE:[0-9]+]] = !DISubrange(count: ![[NODE]])
 ; CHECK-DAG: ![[CONST:[0-9]+]] = !DISubrange(count: 16)
-; CHECK-DAG: ![[MULTI:[0-9]+]] = !{![[RANGE]], ![[RANGE]], ![[CONST]], ![[CONST]], ![[CONST]], ![[CONST]]}
+; CHECK-DAG: ![[EXPR:[0-9]+]] = !DISubrange(count: !DIExpression(DW_OP_constu, 42))
+; CHECK-DAG: ![[MULTI:[0-9]+]] = !{![[RANGE]], ![[RANGE]], ![[CONST]], ![[CONST]], ![[CONST]], ![[CONST]], ![[EXPR]]}
 ; CHECK-DAG: elements: ![[MULTI]]
 !22 = !DISubrange(count: !16)
 !23 = !DISubrange(count: !16)
@@ -47,3 +48,5 @@
 !25 = !DISubrange(count: i16 16)
 !26 = !DISubrange(count: i32 16)
 !27 = !DISubrange(count: i64 16)
+!28 = !DIExpression(DW_OP_constu, 42)
+!29 = !DISubrange(count: !28)
Index: test/Bitcode/vector-pcs.ll
===================================================================
--- /dev/null
+++ test/Bitcode/vector-pcs.ll
@@ -0,0 +1,11 @@
+; RUN: llvm-as %s -o - -f | llvm-dis | FileCheck %s
+; RUN: llvm-as %s -o - -f | verify-uselistorder
+
+declare aarch64_vector_pcs void @aarch64_vector_pcs()
+; CHECK: declare aarch64_vector_pcs void @aarch64_vector_pcs
+
+define void @call_aarch64_vector_pcs() {
+; CHECK: call aarch64_vector_pcs void @aarch64_vector_pcs
+ call aarch64_vector_pcs void @aarch64_vector_pcs()
+ ret void
+}
Index: test/CodeGen/AArch64/GlobalISel/debug-insts.ll
===================================================================
--- test/CodeGen/AArch64/GlobalISel/debug-insts.ll
+++ test/CodeGen/AArch64/GlobalISel/debug-insts.ll
@@ -1,5 +1,5 @@
-; RUN: llc -global-isel -mtriple=aarch64 %s -stop-after=irtranslator -o - | FileCheck %s
-; RUN: llc -mtriple=aarch64 -global-isel --global-isel-abort=0 -o /dev/null
+; RUN: llc -aarch64-sve-postvec=false -global-isel -mtriple=aarch64 %s -stop-after=irtranslator -o - | FileCheck %s
+; RUN: llc -aarch64-sve-postvec=false -mtriple=aarch64 -global-isel --global-isel-abort=0 -o /dev/null

 ; CHECK-LABEL: name: debug_declare
 ; CHECK: stack:
Index: test/CodeGen/AArch64/GlobalISel/fallback-nofastisel.ll
===================================================================
--- test/CodeGen/AArch64/GlobalISel/fallback-nofastisel.ll
+++ test/CodeGen/AArch64/GlobalISel/fallback-nofastisel.ll
@@ -1,4 +1,4 @@
-; RUN: llc -mtriple=aarch64_be-- %s -o /dev/null -debug-only=isel -O0 2>&1 | FileCheck %s
+; RUN: llc -mtriple=aarch64_be-- %s -o /dev/null -debug-only=isel -aarch64-enable-global-isel-at-O=0 -O0 2>&1 | FileCheck %s
 ; REQUIRES: asserts

 ; This test uses big endian in order to force an abort since it's not currently supported for GISel.
Index: test/CodeGen/AArch64/GlobalISel/gisel-commandline-option.ll
===================================================================
--- test/CodeGen/AArch64/GlobalISel/gisel-commandline-option.ll
+++ test/CodeGen/AArch64/GlobalISel/gisel-commandline-option.ll
@@ -1,3 +1,6 @@
+; Fails for now: at -O0 we currently disable GlobalISel for SVE purposes.
+; XFAIL: *
+
 ; RUN: llc -mtriple=aarch64-- -debug-pass=Structure %s -o /dev/null 2>&1 \
 ; RUN: -O0 | FileCheck %s --check-prefix ENABLED --check-prefix ENABLED-O0 --check-prefix FALLBACK

Index: test/CodeGen/AArch64/O0-pipeline.ll
===================================================================
--- test/CodeGen/AArch64/O0-pipeline.ll
+++ test/CodeGen/AArch64/O0-pipeline.ll
@@ -16,6 +16,7 @@
 ; CHECK-NEXT: Pre-ISel Intrinsic Lowering
 ; CHECK-NEXT: FunctionPass Manager
 ; CHECK-NEXT: Expand Atomic instructions
+; CHECK-NEXT: SVE Vector Lib Call Expansion
 ; CHECK-NEXT: Dominator Tree Construction
 ; CHECK-NEXT: Basic Alias Analysis (stateless AA impl)
 ; CHECK-NEXT: Module Verifier
@@ -32,12 +33,6 @@
 ; CHECK-NEXT: Safe Stack instrumentation pass
 ; CHECK-NEXT: Insert stack protectors
 ; CHECK-NEXT: Module Verifier
-; CHECK-NEXT: IRTranslator
-; CHECK-NEXT: Legalizer
-; CHECK-NEXT: RegBankSelect
-; CHECK-NEXT: Localizer
-; CHECK-NEXT: InstructionSelect
-; CHECK-NEXT: ResetMachineFunction
 ; CHECK-NEXT: AArch64 Instruction Selection
 ; CHECK-NEXT: Expand ISel Pseudo-instructions
 ; CHECK-NEXT: Local Stack Slot Allocation
@@ -51,6 +46,7 @@
 ; CHECK-NEXT: AArch64 pseudo instruction expansion pass
 ; CHECK-NEXT: Analyze Machine Code For Garbage Collection
 ; CHECK-NEXT: Branch relaxation pass
+; CHECK-NEXT: Unpack machine instruction bundles
 ; CHECK-NEXT: Contiguously Lay Out Funclets
 ; CHECK-NEXT: StackMap Liveness Analysis
 ; CHECK-NEXT: Live DEBUG_VALUE analysis
Index: test/CodeGen/AArch64/O3-pipeline.ll
===================================================================
--- test/CodeGen/AArch64/O3-pipeline.ll
+++ test/CodeGen/AArch64/O3-pipeline.ll
@@ -17,6 +17,20 @@
 ; CHECK-NEXT: Pre-ISel Intrinsic Lowering
 ; CHECK-NEXT: FunctionPass Manager
 ; CHECK-NEXT: Expand Atomic instructions
+; CHECK-NEXT: Dominator Tree Construction
+; CHECK-NEXT: Natural Loop Information
+; CHECK-NEXT: Canonicalize natural loops
+; CHECK-NEXT: SVE Post Vectorisation
+; CHECK-NEXT: Dominator Tree Construction
+; CHECK-NEXT: SVE intrinsics optimizations
+; CHECK-NEXT: Basic Alias Analysis (stateless AA impl)
+; CHECK-NEXT: Function Alias Analysis Results
+; CHECK-NEXT: Natural Loop Information
+; CHECK-NEXT: Lazy Branch Probability Analysis
+; CHECK-NEXT: Lazy Block Frequency Analysis
+; CHECK-NEXT: Optimization Remark Emitter
+; CHECK-NEXT: Combine redundant instructions
+; CHECK-NEXT: SVE Vector Lib Call Expansion
 ; CHECK-NEXT: Simplify the CFG
 ; CHECK-NEXT: Dominator Tree Construction
 ; CHECK-NEXT: Natural Loop Information
@@ -46,9 +60,21 @@
 ; CHECK-NEXT: Instrument function entry/exit with calls to e.g. mcount() (post inlining)
 ; CHECK-NEXT: Scalarize Masked Memory Intrinsics
 ; CHECK-NEXT: Expand reduction intrinsics
+; CHECK-NEXT: Contiguous Load Store Pass
 ; CHECK-NEXT: Dominator Tree Construction
 ; CHECK-NEXT: Interleaved Access Pass
+; CHECK-NEXT: Early CSE
+; CHECK-NEXT: Basic Alias Analysis (stateless AA impl)
+; CHECK-NEXT: Function Alias Analysis Results
 ; CHECK-NEXT: Natural Loop Information
+; CHECK-NEXT: Scalar Evolution Analysis
+; CHECK-NEXT: Interleaved Gather Scatter Store Sink Pass
+; CHECK-NEXT: Scalar Evolution Analysis
+; CHECK-NEXT: Interleaved Gather Scatter Pass
+; CHECK-NEXT: Lazy Branch Probability Analysis
+; CHECK-NEXT: Lazy Block Frequency Analysis
+; CHECK-NEXT: Optimization Remark Emitter
+; CHECK-NEXT: Combine redundant instructions
 ; CHECK-NEXT: CodeGen Prepare
 ; CHECK-NEXT: Rewrite Symbols
 ; CHECK-NEXT: FunctionPass Manager
@@ -58,6 +84,11 @@
 ; CHECK-NEXT: Unnamed pass: implement Pass::getPassName()
 ; CHECK-NEXT: FunctionPass Manager
 ; CHECK-NEXT: Merge internal globals
+; CHECK-NEXT: Dominator Tree Construction
+; CHECK-NEXT: Natural Loop Information
+; CHECK-NEXT: Canonicalize natural loops
+; CHECK-NEXT: SVE Addressing Modes
+; CHECK-NEXT: Dead Code Elimination
 ; CHECK-NEXT: Safe Stack instrumentation pass
 ; CHECK-NEXT: Insert stack protectors
 ; CHECK-NEXT: Module Verifier
@@ -106,6 +137,7 @@
 ; CHECK-NEXT: Slot index numbering
 ; CHECK-NEXT: Live Interval Analysis
 ; CHECK-NEXT: Simple Register Coalescing
+; CHECK-NEXT: SOME PASS NAME
 ; CHECK-NEXT: Rename Disconnected Subregister Components
 ; CHECK-NEXT: Machine Instruction Scheduler
 ; CHECK-NEXT: Machine Block Frequency Analysis
@@ -148,6 +180,7 @@
 ; CHECK-NEXT: MachinePostDominator Tree Construction
 ; CHECK-NEXT: Branch Probability Basic Block Placement
 ; CHECK-NEXT: Branch relaxation pass
+; CHECK-NEXT: Unpack machine instruction bundles
 ; CHECK-NEXT: Contiguously Lay Out Funclets
 ; CHECK-NEXT: StackMap Liveness Analysis
 ; CHECK-NEXT: Live DEBUG_VALUE analysis
Index: test/CodeGen/AArch64/aarch64-combine-fmul-fsub.ll
===================================================================
--- /dev/null
+++ test/CodeGen/AArch64/aarch64-combine-fmul-fsub.ll
@@ -0,0 +1,37 @@
+; RUN: llc < %s -mtriple=aarch64-linux-gnu -O3 -fp-contract=fast | FileCheck %s
+
+; Check that (fsub (fmul a b) c) -> (fmla a b (fneg c)) works as expected with <2 x float> types.
+define <2 x float> @f1_2s(<2 x float>, <2 x float>, <2 x float>) local_unnamed_addr {
+; CHECK-LABEL: %entry
+; CHECK: fneg [[VR:v[0-9].2s]], v2.2s
+; CHECK: fmla [[VR]], v0.2s, v1.2s
+; CHECK: ret
+ entry:
+ %3 = fmul fast <2 x float> %0, %1
+ %4 = fsub fast <2 x float> %3, %2
+ ret <2 x float> %4
+}
+
+; Check that (fsub (fmul a b) c) -> (fmla a b (fneg c)) works as expected with <4 x float> types.
+define <4 x float> @f1_4s(<4 x float>, <4 x float>, <4 x float>) local_unnamed_addr {
+; CHECK-LABEL: %entry
+; CHECK: fneg [[VR:v[0-9].4s]], v2.4s
+; CHECK: fmla [[VR]], v0.4s, v1.4s
+; CHECK: ret
+ entry:
+ %3 = fmul fast <4 x float> %0, %1
+ %4 = fsub fast <4 x float> %3, %2
+ ret <4 x float> %4
+}
+
+; Check that (fsub (fmul a b) c) -> (fmla a b (fneg c)) works as expected with <2 x double> types.
+define <2 x double> @f1_2d(<2 x double>, <2 x double>, <2 x double>) local_unnamed_addr {
+; CHECK-LABEL: %entry
+; CHECK: fneg [[VR:v[0-9].2d]], v2.2d
+; CHECK: fmla [[VR]], v0.2d, v1.2d
+; CHECK: ret
+ entry:
+ %3 = fmul fast <2 x double> %0, %1
+ %4 = fsub fast <2 x double> %3, %2
+ ret <2 x double> %4
+}
Index: test/CodeGen/AArch64/aarch64-combine-fmul-fsub.mir
===================================================================
--- test/CodeGen/AArch64/aarch64-combine-fmul-fsub.mir
+++ test/CodeGen/AArch64/aarch64-combine-fmul-fsub.mir
@@ -1,3 +1,9 @@
+# Expected to fail for now, see SC-2684.
+# Dave's patch (07aa684562fe857a599cc999c92d56a3ca06604c) for better
+# FMA selection breaks this test, as the test seems a little fragile for
+# cases where all operands are more or less equally close together.
+# XFAIL: *
+
 # RUN: llc -run-pass=machine-combiner -o - -mtriple=aarch64-unknown-linux -mcpu=cortex-a57 -enable-unsafe-fp-math -machine-combiner-verify-pattern-order=true %s | FileCheck --check-prefixes=UNPROFITABLE,ALL %s
 # RUN: llc -run-pass=machine-combiner -o - -mtriple=aarch64-unknown-linux -mcpu=falkor -enable-unsafe-fp-math %s -machine-combiner-verify-pattern-order=true | FileCheck --check-prefixes=PROFITABLE,ALL %s
 # RUN: llc -run-pass=machine-combiner -o - -mtriple=aarch64-unknown-linux -mcpu=exynos-m1 -enable-unsafe-fp-math -machine-combiner-verify-pattern-order=true %s | FileCheck --check-prefixes=PROFITABLE,ALL %s
Index: test/CodeGen/AArch64/aarch64-combine-many-use.ll
===================================================================
--- /dev/null
+++ test/CodeGen/AArch64/aarch64-combine-many-use.ll
@@ -0,0 +1,154 @@
+; RUN: llc < %s -mtriple=aarch64-linux-gnu -O3 -fp-contract=fast | FileCheck %s --check-prefix=CHECK-ONE
+; RUN: llc < %s -mtriple=aarch64-linux-gnu -O3 -fp-contract=fast -mattr=+aggressive-fma | FileCheck %s --check-prefix=CHECK-MANY
+
+; Sanity check that we're getting FMLA instructions to begin with.
+define <4 x float> @f1(<4 x float>, <4 x float>, <4 x float>) local_unnamed_addr {
+; CHECK-ONE-LABEL: %entry
+; CHECK-ONE: fmla
+; CHECK-ONE: ret
+; CHECK-MANY-LABEL: %entry
+; CHECK-MANY: fmla
+; CHECK-MANY: ret
+ entry:
+ %3 = fmul fast <4 x float> %0, %1
+ %4 = fadd fast <4 x float> %2, %3
+ ret <4 x float> %4
+}
+
+; Check that an fmul with many consumers is only combined into FMLA instructions when aggressive FMA formation is enabled.
+define <4 x float> @f2(<4 x float>, <4 x float>, <4 x float>, <4 x float>) local_unnamed_addr {
+; CHECK-ONE-LABEL: %entry
+; CHECK-ONE-NOT: fmla
+; CHECK-ONE: ret
+; CHECK-MANY-LABEL: %entry
+; CHECK-MANY: fmla
+; CHECK-MANY: fmla
+; CHECK-MANY: fadd
+; CHECK-MANY: ret
+ entry:
+ %4 = fmul fast <4 x float> %0, %1
+ %5 = fadd fast <4 x float> %2, %4
+ %6 = fadd fast <4 x float> %3, %4
+ %7 = fadd fast <4 x float> %5, %6
+ ret <4 x float> %7
+}
+
+; Check that FMLAs are selected based on the nearest fmul operands.
+define <4 x float> @f3a(<4 x float>, <4 x float>, <4 x float>, <4 x float>) local_unnamed_addr {
+; CHECK-ONE-LABEL: %entry
+; CHECK-ONE: ret
+; CHECK-MANY-LABEL: %entry
+; CHECK-MANY: fmul v2.4s, v2.4s, v3.4s
+; CHECK-MANY: fmla v2.4s, v0.4s, v1.4s
+; CHECK-MANY: ret
+ entry:
+ %4 = fmul fast <4 x float> %0, %1
+ %5 = fmul fast <4 x float> %2, %3
+ %6 = fadd fast <4 x float> %4, %5
+ ret <4 x float> %6
+}
+
+; Check that FMLAs are selected based on the nearest fmul operands (swapped).
+define <4 x float> @f3b(<4 x float>, <4 x float>, <4 x float>, <4 x float>) local_unnamed_addr {
+; CHECK-ONE-LABEL: %entry
+; CHECK-ONE: ret
+; CHECK-MANY-LABEL: %entry
+; CHECK-MANY: fmul v2.4s, v2.4s, v3.4s
+; CHECK-MANY: fmla v2.4s, v0.4s, v1.4s
+; CHECK-MANY: ret
+ entry:
+ %4 = fmul fast <4 x float> %0, %1
+ %5 = fmul fast <4 x float> %2, %3
+ %6 = fadd fast <4 x float> %5, %4
+ ret <4 x float> %6
+}
+
+; In this test the first fmul (%20) is furthest away from the fadd (%22);
+; however, its operands are defined closest to the fadd. Therefore, operands
+; %19 and %9 are folded into the fmla and the second fmul (%21) becomes the
+; accumulator.
+define void @f4a(double* noalias nocapture, double, double, double* noalias nocapture readonly, double* noalias nocapture readonly, i32) local_unnamed_addr #0 {
+; CHECK-ONE-LABEL: %entry
+; CHECK-ONE: ret
+; CHECK-MANY-LABEL: %entry
+; CHECK-MANY: dup v[[INV1:[0-9]+]].2d, v0.d[0]
+; CHECK-MANY: dup v[[INV2:[0-9]+]].2d, v1.d[0]
+; CHECK-MANY: ldr q[[LD1:[0-9]+]], [{{x[0-9]+}}], #16
+; CHECK-MANY: ldr q[[LD2:[0-9]+]], [{{x[0-9]+}}], #16
+; CHECK-MANY: fmul v2.2d, v[[INV2]].2d, v[[LD1]].2d
+; CHECK-MANY: fmla v2.2d, v[[LD2]].2d, v[[INV1]].2d
+; CHECK-MANY: ret
+entry:
+ %6 = zext i32 %5 to i64
+ %7 = fmul fast double %2, 3.000000e-01
+ %8 = insertelement <2 x double> undef, double %1, i32 0
+ %9 = shufflevector <2 x double> %8, <2 x double> undef, <2 x i32> zeroinitializer
+ %10 = insertelement <2 x double> undef, double %7, i32 0
+ %11 = shufflevector <2 x double> %10, <2 x double> undef, <2 x i32> zeroinitializer
+ br label %12
+
+;