diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -21,6 +21,7 @@
 #ifndef LLVM_ANALYSIS_TARGETTRANSFORMINFO_H
 #define LLVM_ANALYSIS_TARGETTRANSFORMINFO_H
 
+#include "llvm/ADT/SmallBitVector.h"
 #include "llvm/IR/FMF.h"
 #include "llvm/IR/InstrTypes.h"
 #include "llvm/IR/PassManager.h"
@@ -678,6 +679,16 @@
   /// Return true if the target supports masked expand load.
   bool isLegalMaskedExpandLoad(Type *DataType) const;
 
+  /// Return true if this is an alternating opcode pattern that can be lowered
+  /// to a single instruction on the target. On X86 this is the addsub
+  /// instruction, which corresponds to a Shuffle + FAdd + FSub pattern in IR.
+  /// This function expects two opcodes, \p Opcode0 and \p Opcode1, selected
+  /// by \p OpcodeMask. The mask contains one bit per lane and is `0` when
+  /// \p Opcode0 is selected and `1` when \p Opcode1 is selected.
+  /// \p VecTy is the vector type of the instruction to be generated.
+  bool isLegalAltInstr(VectorType *VecTy, unsigned Opcode0, unsigned Opcode1,
+                       const SmallBitVector &OpcodeMask) const;
+
   /// Return true if we should be enabling ordered reductions for the target.
   bool enableOrderedReductions() const;
 
@@ -1581,6 +1592,9 @@
                                            Align Alignment) = 0;
   virtual bool isLegalMaskedCompressStore(Type *DataType) = 0;
   virtual bool isLegalMaskedExpandLoad(Type *DataType) = 0;
+  virtual bool isLegalAltInstr(VectorType *VecTy, unsigned Opcode0,
+                               unsigned Opcode1,
+                               const SmallBitVector &OpcodeMask) const = 0;
   virtual bool enableOrderedReductions() = 0;
   virtual bool hasDivRemOp(Type *DataType, bool IsSigned) = 0;
   virtual bool hasVolatileVariant(Instruction *I, unsigned AddrSpace) = 0;
@@ -2006,6 +2020,10 @@
   bool isLegalMaskedExpandLoad(Type *DataType) override {
     return Impl.isLegalMaskedExpandLoad(DataType);
   }
+  bool isLegalAltInstr(VectorType *VecTy, unsigned Opcode0, unsigned Opcode1,
+                       const SmallBitVector &OpcodeMask) const override {
+    return Impl.isLegalAltInstr(VecTy, Opcode0, Opcode1, OpcodeMask);
+  }
   bool enableOrderedReductions() override {
     return Impl.enableOrderedReductions();
   }
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -279,6 +279,11 @@
   bool isLegalMaskedCompressStore(Type *DataType) const { return false; }
 
+  bool isLegalAltInstr(VectorType *VecTy, unsigned Opcode0, unsigned Opcode1,
+                       const SmallBitVector &OpcodeMask) const {
+    return false;
+  }
+
   bool isLegalMaskedExpandLoad(Type *DataType) const { return false; }
 
   bool enableOrderedReductions() const { return false; }
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -412,6 +412,12 @@
   return TTIImpl->isLegalMaskedGather(DataType, Alignment);
 }
 
+bool TargetTransformInfo::isLegalAltInstr(
+    VectorType *VecTy, unsigned Opcode0, unsigned Opcode1,
+    const SmallBitVector &OpcodeMask) const {
+  return TTIImpl->isLegalAltInstr(VecTy, Opcode0, Opcode1, OpcodeMask);
+}
+
 bool TargetTransformInfo::isLegalMaskedScatter(Type *DataType,
                                                Align Alignment) const {
   return TTIImpl->isLegalMaskedScatter(DataType, Alignment);
diff --git a/llvm/lib/Target/X86/X86TargetTransformInfo.h b/llvm/lib/Target/X86/X86TargetTransformInfo.h
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.h
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.h
@@ -241,6 +241,8 @@
   bool isLegalMaskedScatter(Type *DataType, Align Alignment);
   bool isLegalMaskedExpandLoad(Type *DataType);
   bool isLegalMaskedCompressStore(Type *DataType);
+  bool isLegalAltInstr(VectorType *VecTy, unsigned Opcode0, unsigned Opcode1,
+                       const SmallBitVector &OpcodeMask) const;
   bool hasDivRemOp(Type *DataType, bool IsSigned);
   bool isFCmpOrdCheaperThanFCmpZero(Type *Ty);
   bool areInlineCompatible(const Function *Caller,
diff --git a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
@@ -5341,6 +5341,39 @@
   return IntWidth == 32 || IntWidth == 64;
 }
 
+bool X86TTIImpl::isLegalAltInstr(VectorType *VecTy, unsigned Opcode0,
+                                 unsigned Opcode1,
+                                 const SmallBitVector &OpcodeMask) const {
+  // ADDSUBPS  4xf32 SSE3
+  // VADDSUBPS 4xf32 AVX
+  // VADDSUBPS 8xf32 AVX2
+  // ADDSUBPD  2xf64 SSE3
+  // VADDSUBPD 2xf64 AVX
+  // VADDSUBPD 4xf64 AVX2
+
+  unsigned NumElements = cast<FixedVectorType>(VecTy)->getNumElements();
+  assert(OpcodeMask.size() == NumElements && "Mask and VecTy are incompatible");
+  if (!isPowerOf2_32(NumElements))
+    return false;
+  // Check the opcode pattern. We apply the mask to the opcode arguments and
+  // then check that it is what we expect.
+  for (int Lane : seq<int>(0, NumElements)) {
+    unsigned Opc = OpcodeMask.test(Lane) ? Opcode1 : Opcode0;
+    // We expect FSub for even lanes and FAdd for odd lanes.
+    if (Lane % 2 == 0 && Opc != Instruction::FSub)
+      return false;
+    if (Lane % 2 == 1 && Opc != Instruction::FAdd)
+      return false;
+  }
+  // Now check that the pattern is supported by the target ISA.
+  Type *ElemTy = cast<FixedVectorType>(VecTy)->getElementType();
+  if (ElemTy->isFloatTy())
+    return ST->hasSSE3() && NumElements % 4 == 0;
+  if (ElemTy->isDoubleTy())
+    return ST->hasSSE3() && NumElements % 2 == 0;
+  return false;
+}
+
 bool X86TTIImpl::isLegalMaskedScatter(Type *DataType, Align Alignment) {
   // AVX2 doesn't support scatter
   if (!ST->hasAVX512())
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -3710,6 +3710,10 @@
   // their ordering.
   DenseMap<const TreeEntry *, OrdersType> GathersToOrders;
 
+  // AltShuffles can also have a preferred ordering that leads to fewer
+  // instructions, e.g., the addsub instruction in x86.
+  DenseMap<const TreeEntry *, OrdersType> AltShufflesToOrders;
+
   // Maps a TreeEntry to the reorder indices of external users.
   DenseMap<const TreeEntry *, SmallVector<OrdersType, 1>>
       ExternalUserReorderMap;
@@ -3717,7 +3721,7 @@
   // Currently the are vectorized stores,loads,extracts + some gathering of
   // extracts.
   for_each(VectorizableTree, [this, &VFToOrderedEntries, &GathersToOrders,
-                              &ExternalUserReorderMap](
+                              &ExternalUserReorderMap, &AltShufflesToOrders](
                                  const std::unique_ptr<TreeEntry> &TE) {
     // Look for external users that will probably be vectorized.
     SmallVector<OrdersType, 1> ExternalUserReorderIndices =
@@ -3728,6 +3732,27 @@
                                          std::move(ExternalUserReorderIndices));
     }
 
+    // Patterns like [fadd,fsub] can be combined into a single instruction in
+    // x86. Reordering them into [fsub,fadd] blocks this pattern. So we need
+    // to take into account their order when looking for the most used order.
+    if (TE->isAltShuffle()) {
+      VectorType *VecTy =
+          FixedVectorType::get(TE->Scalars[0]->getType(), TE->Scalars.size());
+      unsigned Opcode0 = TE->getOpcode();
+      unsigned Opcode1 = TE->getAltOpcode();
+      // The opcode mask selects between the two opcodes.
+      SmallBitVector OpcodeMask(TE->Scalars.size(), 0);
+      for (unsigned Lane : seq<unsigned>(0, TE->Scalars.size()))
+        if (cast<Instruction>(TE->Scalars[Lane])->getOpcode() == Opcode1)
+          OpcodeMask.set(Lane);
+      // If this pattern is supported by the target then we consider the order.
+      if (TTI->isLegalAltInstr(VecTy, Opcode0, Opcode1, OpcodeMask)) {
+        VFToOrderedEntries[TE->Scalars.size()].insert(TE.get());
+        AltShufflesToOrders.try_emplace(TE.get(), OrdersType());
+      }
+      // TODO: Check the reverse order too.
+    }
+
     if (Optional<OrdersType> CurrentOrder =
             getReorderingData(*TE, /*TopToBottom=*/true)) {
       // Do not include ordering for nodes used in the alt opcode vectorization,
@@ -3778,12 +3803,18 @@
       if (!OpTE->ReuseShuffleIndices.empty())
         continue;
      // Count number of orders uses.
-      const auto &Order = [OpTE, &GathersToOrders]() -> const OrdersType & {
+      const auto &Order = [OpTE, &GathersToOrders,
+                           &AltShufflesToOrders]() -> const OrdersType & {
        if (OpTE->State == TreeEntry::NeedToGather) {
          auto It = GathersToOrders.find(OpTE);
          if (It != GathersToOrders.end())
            return It->second;
        }
+        if (OpTE->isAltShuffle()) {
+          auto It = AltShufflesToOrders.find(OpTE);
+          if (It != AltShufflesToOrders.end())
+            return It->second;
+        }
        return OpTE->ReorderIndices;
      }();
      // First consider the order of the external scalar users.
diff --git a/llvm/test/Transforms/SLPVectorizer/X86/reorder_with_external_users.ll b/llvm/test/Transforms/SLPVectorizer/X86/reorder_with_external_users.ll
--- a/llvm/test/Transforms/SLPVectorizer/X86/reorder_with_external_users.ll
+++ b/llvm/test/Transforms/SLPVectorizer/X86/reorder_with_external_users.ll
@@ -118,20 +118,21 @@
 ; CHECK-NEXT:    [[LD:%.*]] = load double, double* undef, align 8
 ; CHECK-NEXT:    [[TMP0:%.*]] = insertelement <2 x double> poison, double [[LD]], i32 0
 ; CHECK-NEXT:    [[TMP1:%.*]] = insertelement <2 x double> [[TMP0]], double [[LD]], i32 1
-; CHECK-NEXT:    [[TMP2:%.*]] = fsub <2 x double> [[TMP1]],
-; CHECK-NEXT:    [[TMP3:%.*]] = fadd <2 x double> [[TMP1]],
-; CHECK-NEXT:    [[TMP4:%.*]] = shufflevector <2 x double> [[TMP2]], <2 x double> [[TMP3]], <2 x i32>
-; CHECK-NEXT:    [[TMP5:%.*]] = fdiv <2 x double> [[TMP4]],
-; CHECK-NEXT:    [[TMP6:%.*]] = fmul <2 x double> [[TMP5]],
+; CHECK-NEXT:    [[TMP2:%.*]] = fsub <2 x double> [[TMP1]],
+; CHECK-NEXT:    [[TMP3:%.*]] = fadd <2 x double> [[TMP1]],
+; CHECK-NEXT:    [[TMP4:%.*]] = shufflevector <2 x double> [[TMP2]], <2 x double> [[TMP3]], <2 x i32>
+; CHECK-NEXT:    [[TMP5:%.*]] = fdiv <2 x double> [[TMP4]],
+; CHECK-NEXT:    [[TMP6:%.*]] = fmul <2 x double> [[TMP5]],
 ; CHECK-NEXT:    [[PTRA0:%.*]] = getelementptr inbounds double, double* [[A:%.*]], i64 0
+; CHECK-NEXT:    [[SHUFFLE:%.*]] = shufflevector <2 x double> [[TMP6]], <2 x double> poison, <2 x i32>
 ; CHECK-NEXT:    [[TMP7:%.*]] = bitcast double* [[PTRA0]] to <2 x double>*
-; CHECK-NEXT:    store <2 x double> [[TMP6]], <2 x double>* [[TMP7]], align 8
+; CHECK-NEXT:    store <2 x double> [[SHUFFLE]], <2 x double>* [[TMP7]], align 8
 ; CHECK-NEXT:    br label [[BB2:%.*]]
 ; CHECK:       bb2:
-; CHECK-NEXT:    [[TMP8:%.*]] = fadd <2 x double> [[TMP6]],
+; CHECK-NEXT:    [[TMP8:%.*]] = fadd <2 x double> [[TMP6]],
 ; CHECK-NEXT:    [[TMP9:%.*]] = extractelement <2 x double> [[TMP8]], i32 0
 ; CHECK-NEXT:    [[TMP10:%.*]] = extractelement <2 x double> [[TMP8]], i32 1
-; CHECK-NEXT:    [[SEED:%.*]] = fcmp ogt double [[TMP10]], [[TMP9]]
+; CHECK-NEXT:    [[SEED:%.*]] = fcmp ogt double [[TMP9]], [[TMP10]]
 ; CHECK-NEXT:    ret void
 ;
 bb1:
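
For illustration, a minimal hand-written IR sketch of the alternating pattern that isLegalAltInstr() recognizes; the function name @addsub_sketch and its operands are made up and not part of the patch or the test above. FSub produces the even lanes, FAdd produces the odd lanes, and a shufflevector merges them; on an SSE3-capable X86 target this whole pattern can be lowered to a single (v)addsubpd, which is why the SLP vectorizer change above avoids reordering [fsub,fadd] nodes when the target reports the pattern as legal:

define <2 x double> @addsub_sketch(<2 x double> %a, <2 x double> %b) {
  ; even lanes compute a - b
  %sub = fsub <2 x double> %a, %b
  ; odd lanes compute a + b
  %add = fadd <2 x double> %a, %b
  ; select lane 0 from %sub and lane 1 from %add; this Shuffle + FSub + FAdd
  ; pattern matches what X86TTIImpl::isLegalAltInstr() checks for
  %res = shufflevector <2 x double> %sub, <2 x double> %add, <2 x i32> <i32 0, i32 3>
  ret <2 x double> %res
}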