Index: include/llvm/Analysis/TargetTransformInfo.h
===================================================================
--- include/llvm/Analysis/TargetTransformInfo.h
+++ include/llvm/Analysis/TargetTransformInfo.h
@@ -455,6 +455,11 @@
   /// zext, etc.
   int getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src) const;
 
+  /// \return The expected cost of a sign- or zero-extended vector extract. Use
+  /// -1 to indicate that there is no information about the index value.
+  int getExtractWithExtendCost(unsigned Opcode, Type *Dst, VectorType *VecTy,
+                               unsigned Index = -1) const;
+
   /// \return The expected cost of control-flow related instructions such as
   /// Phi, Ret, Br.
   int getCFInstrCost(unsigned Opcode) const;
@@ -639,6 +644,8 @@
   virtual int getShuffleCost(ShuffleKind Kind, Type *Tp, int Index,
                              Type *SubTp) = 0;
   virtual int getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src) = 0;
+  virtual int getExtractWithExtendCost(unsigned Opcode, Type *Dst,
+                                       VectorType *VecTy, unsigned Index) = 0;
   virtual int getCFInstrCost(unsigned Opcode) = 0;
   virtual int getCmpSelInstrCost(unsigned Opcode, Type *ValTy,
                                  Type *CondTy) = 0;
@@ -824,6 +831,10 @@
   int getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src) override {
     return Impl.getCastInstrCost(Opcode, Dst, Src);
   }
+  int getExtractWithExtendCost(unsigned Opcode, Type *Dst, VectorType *VecTy,
+                               unsigned Index) override {
+    return Impl.getExtractWithExtendCost(Opcode, Dst, VecTy, Index);
+  }
   int getCFInstrCost(unsigned Opcode) override {
     return Impl.getCFInstrCost(Opcode);
   }
Index: include/llvm/Analysis/TargetTransformInfoImpl.h
===================================================================
--- include/llvm/Analysis/TargetTransformInfoImpl.h
+++ include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -289,6 +289,11 @@
 
   unsigned getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src) { return 1; }
 
+  unsigned getExtractWithExtendCost(unsigned Opcode, Type *Dst,
+                                    VectorType *VecTy, unsigned Index) {
+    return 1;
+  }
+
   unsigned getCFInstrCost(unsigned Opcode) { return 1; }
 
   unsigned getCmpSelInstrCost(unsigned Opcode, Type *ValTy, Type *CondTy) {
Index: include/llvm/CodeGen/BasicTTIImpl.h
===================================================================
--- include/llvm/CodeGen/BasicTTIImpl.h
+++ include/llvm/CodeGen/BasicTTIImpl.h
@@ -433,6 +433,16 @@
     llvm_unreachable("Unhandled cast");
   }
 
+  /// Default cost of an extract followed by an extend: the extractelement
+  /// cost plus the cast cost, each queried through the derived implementation.
+  unsigned getExtractWithExtendCost(unsigned Opcode, Type *Dst,
+                                    VectorType *VecTy, unsigned Index) {
+    return static_cast<T *>(this)->getVectorInstrCost(
+               Instruction::ExtractElement, VecTy, Index) +
+           static_cast<T *>(this)->getCastInstrCost(Opcode, Dst,
+                                                    VecTy->getElementType());
+  }
+
   unsigned getCFInstrCost(unsigned Opcode) {
     // Branches are assumed to be predicted.
     return 0;
Index: lib/Analysis/TargetTransformInfo.cpp
===================================================================
--- lib/Analysis/TargetTransformInfo.cpp
+++ lib/Analysis/TargetTransformInfo.cpp
@@ -259,6 +259,14 @@
   return Cost;
 }
 
+int TargetTransformInfo::getExtractWithExtendCost(unsigned Opcode, Type *Dst,
+                                                  VectorType *VecTy,
+                                                  unsigned Index) const {
+  int Cost = TTIImpl->getExtractWithExtendCost(Opcode, Dst, VecTy, Index);
+  assert(Cost >= 0 && "TTI should not produce negative costs!");
+  return Cost;
+}
+
 int TargetTransformInfo::getCFInstrCost(unsigned Opcode) const {
   int Cost = TTIImpl->getCFInstrCost(Opcode);
   assert(Cost >= 0 && "TTI should not produce negative costs!");
Index: lib/Target/AArch64/AArch64TargetTransformInfo.h
===================================================================
--- lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -99,6 +99,9 @@
 
   int getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src);
 
+  int getExtractWithExtendCost(unsigned Opcode, Type *Dst, VectorType *VecTy,
+                               unsigned Index);
+
   int getVectorInstrCost(unsigned Opcode, Type *Val, unsigned Index);
 
   int getArithmeticInstrCost(
Index: lib/Target/AArch64/AArch64TargetTransformInfo.cpp
===================================================================
--- lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -309,6 +309,70 @@
   return BaseT::getCastInstrCost(Opcode, Dst, Src);
 }
 
+/// Return the cost of extracting lane \p Index from \p VecTy and sign- or
+/// zero-extending the result to \p Dst. \p Opcode must be Instruction::SExt
+/// or Instruction::ZExt. The extract cost is always included; the cost of the
+/// extend is added only when it cannot be folded into the extract.
+int AArch64TTIImpl::getExtractWithExtendCost(unsigned Opcode, Type *Dst,
+                                             VectorType *VecTy,
+                                             unsigned Index) {
+
+  // Make sure we were given a valid extend opcode.
+  assert((Opcode == Instruction::SExt || Opcode == Instruction::ZExt) &&
+         "Invalid opcode");
+
+  // We are extending an element we extract from a vector, so the source type
+  // of the extend is the element type of the vector.
+  auto *Src = VecTy->getElementType();
+
+  // Sign- and zero-extends are for integer types only.
+  assert(isa<IntegerType>(Dst) && isa<IntegerType>(Src) && "Invalid type");
+
+  // Get the cost for the extract. We compute the cost (if any) for the extend
+  // below.
+  auto Cost = getVectorInstrCost(Instruction::ExtractElement, VecTy, Index);
+
+  // Legalize the types.
+  auto VecLT = TLI->getTypeLegalizationCost(DL, VecTy);
+  auto DstVT = TLI->getValueType(DL, Dst);
+  auto SrcVT = TLI->getValueType(DL, Src);
+
+  // If the resulting type is still a vector and the destination type is legal,
+  // we may get the extension for free. If not, get the default cost for the
+  // extend.
+  if (!VecLT.second.isVector() || !TLI->isTypeLegal(DstVT)) {
+    Cost += getCastInstrCost(Opcode, Dst, Src);
+    return Cost;
+  }
+
+  // The destination type should be larger than the element type. If not, get
+  // the default cost for the extend.
+  if (DstVT.getSizeInBits() < SrcVT.getSizeInBits()) {
+    Cost += getCastInstrCost(Opcode, Dst, Src);
+    return Cost;
+  }
+
+  switch (Opcode) {
+  default:
+    llvm_unreachable("Opcode should be either SExt or ZExt");
+
+  // For sign-extends, we only need a smov, which performs the extension
+  // automatically, so only the cost of the extract remains.
+  case Instruction::SExt:
+    return Cost;
+
+  // For zero-extends, the extend is performed automatically by a umov unless
+  // the destination type is i64 and the element type is i8 or i16.
+  case Instruction::ZExt:
+    if (DstVT.getSizeInBits() != 64u || SrcVT.getSizeInBits() == 32u)
+      return Cost;
+  }
+
+  // If we are unable to perform the extend for free, get the default cost.
+  Cost += getCastInstrCost(Opcode, Dst, Src);
+  return Cost;
+}
+
 int AArch64TTIImpl::getVectorInstrCost(unsigned Opcode, Type *Val,
                                        unsigned Index) {
   assert(Val->isVectorTy() && "This must be a vector type");
Index: lib/Transforms/Vectorize/SLPVectorizer.cpp
===================================================================
--- lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -1827,11 +1827,12 @@
     if (MinBWs.count(ScalarRoot)) {
       auto *MinTy = IntegerType::get(F->getContext(), MinBWs[ScalarRoot]);
       VecTy = VectorType::get(MinTy, BundleWidth);
+      ExtractCost += TTI->getExtractWithExtendCost(
+          Instruction::SExt, EU.Scalar->getType(), VecTy, EU.Lane);
+    } else {
       ExtractCost +=
-          TTI->getCastInstrCost(Instruction::SExt, EU.Scalar->getType(), MinTy);
+          TTI->getVectorInstrCost(Instruction::ExtractElement, VecTy, EU.Lane);
     }
-    ExtractCost +=
-        TTI->getVectorInstrCost(Instruction::ExtractElement, VecTy, EU.Lane);
   }
 
   int SpillCost = getSpillCost();
Index: test/Transforms/SLPVectorizer/AArch64/gather-reduce.ll
===================================================================
--- test/Transforms/SLPVectorizer/AArch64/gather-reduce.ll
+++ test/Transforms/SLPVectorizer/AArch64/gather-reduce.ll
@@ -1,5 +1,5 @@
-; RUN: opt -S -slp-vectorizer -dce -instcombine < %s | FileCheck %s --check-prefix=PROFITABLE
-; RUN: opt -S -slp-vectorizer -slp-threshold=-12 -dce -instcombine < %s | FileCheck %s --check-prefix=UNPROFITABLE
+; RUN: opt -S -slp-vectorizer -dce -instcombine < %s | FileCheck %s --check-prefix=GENERIC
+; RUN: opt -S -mcpu=kryo -slp-vectorizer -dce -instcombine <
%s | FileCheck %s --check-prefix=KRYO target datalayout = "e-m:e-i64:64-i128:128-n32:64-S128" target triple = "aarch64--linux-gnu" @@ -19,13 +19,13 @@ ; return sum; ; } -; PROFITABLE-LABEL: @gather_reduce_8x16_i32 +; GENERIC-LABEL: @gather_reduce_8x16_i32 ; -; PROFITABLE: [[L:%[a-zA-Z0-9.]+]] = load <8 x i16> -; PROFITABLE: zext <8 x i16> [[L]] to <8 x i32> -; PROFITABLE: [[S:%[a-zA-Z0-9.]+]] = sub nsw <8 x i32> -; PROFITABLE: [[X:%[a-zA-Z0-9.]+]] = extractelement <8 x i32> [[S]] -; PROFITABLE: sext i32 [[X]] to i64 +; GENERIC: [[L:%[a-zA-Z0-9.]+]] = load <8 x i16> +; GENERIC: zext <8 x i16> [[L]] to <8 x i32> +; GENERIC: [[S:%[a-zA-Z0-9.]+]] = sub nsw <8 x i32> +; GENERIC: [[X:%[a-zA-Z0-9.]+]] = extractelement <8 x i32> [[S]] +; GENERIC: sext i32 [[X]] to i64 ; define i32 @gather_reduce_8x16_i32(i16* nocapture readonly %a, i16* nocapture readonly %b, i16* nocapture readonly %g, i32 %n) { entry: @@ -138,18 +138,13 @@ br i1 %exitcond, label %for.cond.cleanup.loopexit, label %for.body } -; UNPROFITABLE-LABEL: @gather_reduce_8x16_i64 +; KRYO-LABEL: @gather_reduce_8x16_i64 ; -; UNPROFITABLE: [[L:%[a-zA-Z0-9.]+]] = load <8 x i16> -; UNPROFITABLE: zext <8 x i16> [[L]] to <8 x i32> -; UNPROFITABLE: [[S:%[a-zA-Z0-9.]+]] = sub nsw <8 x i32> -; UNPROFITABLE: [[X:%[a-zA-Z0-9.]+]] = extractelement <8 x i32> [[S]] -; UNPROFITABLE: sext i32 [[X]] to i64 -; -; TODO: Although we can now vectorize this case while converting the i64 -; subtractions to i32, the cost model currently finds vectorization to be -; unprofitable. The cost model is penalizing the sign and zero -; extensions in the vectorized version, but they are actually free. 
+; KRYO: [[L:%[a-zA-Z0-9.]+]] = load <8 x i16> +; KRYO: zext <8 x i16> [[L]] to <8 x i32> +; KRYO: [[S:%[a-zA-Z0-9.]+]] = sub nsw <8 x i32> +; KRYO: [[X:%[a-zA-Z0-9.]+]] = extractelement <8 x i32> [[S]] +; KRYO: sext i32 [[X]] to i64 ; define i32 @gather_reduce_8x16_i64(i16* nocapture readonly %a, i16* nocapture readonly %b, i16* nocapture readonly %g, i32 %n) { entry: