Index: include/llvm/Analysis/TargetTransformInfo.h =================================================================== --- include/llvm/Analysis/TargetTransformInfo.h +++ include/llvm/Analysis/TargetTransformInfo.h @@ -434,9 +434,36 @@ int getShuffleCost(ShuffleKind Kind, Type *Tp, int Index = 0, Type *SubTp = nullptr) const; + /// A struct to hold information for contextual cost estimates. + struct InstrContext { + InstrContext() : InstrTy(nullptr), Opcode(0), User(nullptr) { + Operands = makeArrayRef((InstrContext *)nullptr, 0); + } + InstrContext(Type *Ty, unsigned Opc = 0, InstrContext *Usr = nullptr) + : InstrTy(Ty), Opcode(Opc), User(Usr) { + Operands = makeArrayRef((InstrContext *)nullptr, 0); + } + InstrContext(Type *Ty, unsigned Opc, ArrayRef Ops, + InstrContext *Usr = nullptr) + : InstrTy(Ty), Opcode(Opc), Operands(Ops), User(Usr) {} + + /// The type of the instruction. + Type *InstrTy; + + /// The opcode of the instruction. + unsigned Opcode; + + /// A list of InstrContexts describing the operands of the instruction. + ArrayRef Operands; + + /// An InstrContext describing a single user of the instruction. + InstrContext *User; + }; + /// \return The expected cost of cast instructions, such as bitcast, trunc, /// zext, etc. - int getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src) const; + int getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src, + InstrContext *Ctx = nullptr) const; /// \return The expected cost of control-flow related instructions such as /// Phi, Ret, Br. 
@@ -617,7 +644,8 @@ OperandValueProperties Opd2PropInfo) = 0; virtual int getShuffleCost(ShuffleKind Kind, Type *Tp, int Index, Type *SubTp) = 0; - virtual int getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src) = 0; + virtual int getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src, + InstrContext *Ctx) = 0; virtual int getCFInstrCost(unsigned Opcode) = 0; virtual int getCmpSelInstrCost(unsigned Opcode, Type *ValTy, Type *CondTy) = 0; @@ -790,8 +818,9 @@ Type *SubTp) override { return Impl.getShuffleCost(Kind, Tp, Index, SubTp); } - int getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src) override { - return Impl.getCastInstrCost(Opcode, Dst, Src); + int getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src, + InstrContext *Ctx) override { + return Impl.getCastInstrCost(Opcode, Dst, Src, Ctx); } int getCFInstrCost(unsigned Opcode) override { return Impl.getCFInstrCost(Opcode); Index: include/llvm/Analysis/TargetTransformInfoImpl.h =================================================================== --- include/llvm/Analysis/TargetTransformInfoImpl.h +++ include/llvm/Analysis/TargetTransformInfoImpl.h @@ -279,7 +279,10 @@ return 1; } - unsigned getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src) { return 1; } + unsigned getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src, + TTI::InstrContext *Ctx) { + return 1; + } unsigned getCFInstrCost(unsigned Opcode) { return 1; } Index: include/llvm/CodeGen/BasicTTIImpl.h =================================================================== --- include/llvm/CodeGen/BasicTTIImpl.h +++ include/llvm/CodeGen/BasicTTIImpl.h @@ -335,7 +335,8 @@ return 1; } - unsigned getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src) { + unsigned getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src, + TTI::InstrContext *Ctx) { const TargetLoweringBase *TLI = getTLI(); int ISD = TLI->InstructionOpcodeToISD(Opcode); assert(ISD && "Invalid opcode"); @@ -407,7 +408,7 @@ // scalarization costs. 
unsigned Num = Dst->getVectorNumElements(); unsigned Cost = static_cast(this)->getCastInstrCost( - Opcode, Dst->getScalarType(), Src->getScalarType()); + Opcode, Dst->getScalarType(), Src->getScalarType(), Ctx); // Return the cost of multiple scalar invocation plus the cost of // inserting and extracting the values. Index: lib/Analysis/TargetTransformInfo.cpp =================================================================== --- lib/Analysis/TargetTransformInfo.cpp +++ lib/Analysis/TargetTransformInfo.cpp @@ -236,9 +236,9 @@ return Cost; } -int TargetTransformInfo::getCastInstrCost(unsigned Opcode, Type *Dst, - Type *Src) const { - int Cost = TTIImpl->getCastInstrCost(Opcode, Dst, Src); +int TargetTransformInfo::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src, + InstrContext *Ctx) const { + int Cost = TTIImpl->getCastInstrCost(Opcode, Dst, Src, Ctx); assert(Cost >= 0 && "TTI should not produce negative costs!"); return Cost; } Index: lib/Target/AArch64/AArch64TargetTransformInfo.h =================================================================== --- lib/Target/AArch64/AArch64TargetTransformInfo.h +++ lib/Target/AArch64/AArch64TargetTransformInfo.h @@ -97,7 +97,11 @@ unsigned getMaxInterleaveFactor(unsigned VF); - int getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src); + int getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src, + TTI::InstrContext *Ctx); + + int getCastInstrCostInCtx(unsigned Opcode, Type *Dst, Type *Src, + TTI::InstrContext *Ctx); int getVectorInstrCost(unsigned Opcode, Type *Val, unsigned Index); Index: lib/Target/AArch64/AArch64TargetTransformInfo.cpp =================================================================== --- lib/Target/AArch64/AArch64TargetTransformInfo.cpp +++ lib/Target/AArch64/AArch64TargetTransformInfo.cpp @@ -176,7 +176,15 @@ return TTI::PSK_Software; } -int AArch64TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src) { +int AArch64TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src, + 
TTI::InstrContext *Ctx) { + + // Try and use the contextual cost estimate. Note that the contextual version + // calls the implementation here if nothing useful can be done, so we can + // directly return the cost. + if (Ctx) + return getCastInstrCostInCtx(Opcode, Dst, Src, Ctx); + int ISD = TLI->InstructionOpcodeToISD(Opcode); assert(ISD && "Invalid opcode"); @@ -184,7 +192,7 @@ EVT DstTy = TLI->getValueType(DL, Dst); if (!SrcTy.isSimple() || !DstTy.isSimple()) - return BaseT::getCastInstrCost(Opcode, Dst, Src); + return BaseT::getCastInstrCost(Opcode, Dst, Src, Ctx); static const TypeConversionCostTblEntry ConversionTbl[] = { @@ -288,7 +296,138 @@ SrcTy.getSimpleVT())) return Entry->Cost; - return BaseT::getCastInstrCost(Opcode, Dst, Src); + return BaseT::getCastInstrCost(Opcode, Dst, Src, Ctx); +} + +int AArch64TTIImpl::getCastInstrCostInCtx(unsigned Opcode, Type *Dst, Type *Src, + TTI::InstrContext *Ctx) { + + assert(!Ctx->Operands.empty() && "Invalid context"); + + // Contextual cost estimate. We can base the estimate on information about + // the value being cast. + switch (Ctx->Operands[0].Opcode) { + default: + break; + + // If the cast value has been extracted from a vector, we can usually perform + // sign and zero extensions automatically with smov and umov. We check for + // the free cases below, and fall back to the context-free estimate + // otherwise. + case Instruction::ExtractElement: + + // We only contextualize sext and zext casts. + if (Opcode != Instruction::SExt && Opcode != Instruction::ZExt) + break; + + // See if we've been given enough context to do anything useful. We need to + // know the first operand of the extractelement. + auto &Operands = Ctx->Operands[0].Operands; + if (Operands.empty()) + break; + + // Make sure the first operand of the extractelement is a vector type. This + // must be the case if we have a valid context. 
+ assert(isa(Operands[0].InstrTy) && "Invalid context"); + auto *Vec = cast(Operands[0].InstrTy); + + // Because the cast is either zext or sext, we have to be working with all + // integer types. Again, this must be the case if we have a valid context. + assert(isa(Dst) && Src == Vec->getElementType() && + "Invalid context"); + + // Now we legalize the types. If the resulting type is still a vector and + // the destination type is legal, we may get the extension for free. + auto VecLT = TLI->getTypeLegalizationCost(DL, Vec); + auto DstVT = TLI->getValueType(DL, Dst); + if (!VecLT.second.isVector() || !TLI->isTypeLegal(DstVT)) + break; + + // Ensure we have a valid extract-extend combination. The source and + // element types should be the same, and the destination type should be + // larger than the element type. + auto SrcVT = TLI->getValueType(DL, Src); + auto ElmVT = VecLT.second.getVectorElementType(); + if (SrcVT != ElmVT || DstVT.getSizeInBits() < ElmVT.getSizeInBits()) + break; + + switch (Opcode) { + default: + llvm_unreachable("Ctx->Opcode should be either SExt or ZExt"); + + // For sign extensions, we only need a smov, which performs the extension + // automatically. + case Instruction::SExt: + return 0; + + // For zero extensions, the extend is performed automatically by a umov + // unless the destination type is i64 and the element type is i8 or i16. + case Instruction::ZExt: + if (DstVT.getSizeInBits() != 64u || ElmVT.getSizeInBits() == 32u) + return 0; + } + } + + // At this point, we've looked at the operand of the cast if interesting to + // do so. If we don't yet know that the cast will be free, we can try to look + // at the user of the cast. If there is no information about what the user of + // the cast looks like, there is nothing else to do. + if (!Ctx->User) + return getCastInstrCost(Opcode, Dst, Src, nullptr); + + // Otherwise, we can base the estimate on information about the user. 
+ switch (Ctx->User->Opcode) { + default: + break; + + // AArch64 can perform widening and narrowing vector arithmetic. If the cast + // value is a vector extend that is used by an arithmetic operation, it may + // happen automatically. We check for the free cases below, and fall back to + // the context-free estimate otherwise. + case Instruction::Add: + case Instruction::Sub: + + // We only contextualize sext and zext casts. + if (Opcode != Instruction::SExt && Opcode != Instruction::ZExt) + break; + + // We have to look at the operands of widening and narrowing instructions. + // If we weren't given this information, there's nothing to do. + if (Ctx->User->Operands.size() != 2) + break; + + // Type check the extend and its user to ensure we have a valid context. + // Because we have an extend, its source and destination types have to be + // integers or vectors of integers. + auto &Operands = Ctx->User->Operands; + assert(Src && Src->isIntOrIntVectorTy() && Dst && + Dst->isIntOrIntVectorTy() && Ctx->User->InstrTy == Dst && + Operands[0].InstrTy == Dst && Operands[1].InstrTy == Dst && + "Invalid context"); + + // Check if this operation can be matched to a lengthening or widening + // vector operation. If not, give up. + auto IsWidening = Operands[1].Opcode == Opcode; + auto IsLengthening = IsWidening && Operands[0].Opcode == Opcode; + if (!IsWidening && !IsLengthening) + break; + + // Now we legalize the types. If the resulting types are not vectors, this + // won't match the lengthening and widening operations. + auto SrcVT = TLI->getTypeLegalizationCost(DL, Src).second; + auto DstVT = TLI->getTypeLegalizationCost(DL, Dst).second; + if (!SrcVT.isVector() || !DstVT.isVector()) + break; + + // If we have legal vectors and the destination element size is twice the + // source element size, the extend can be performed automatically by the + // lengthening or widening instruction. 
+ if (DstVT.getScalarSizeInBits() == 2 * SrcVT.getScalarSizeInBits()) + return 0; + } + + // If we can't do anything useful, fall back to the default implementation. + return getCastInstrCost(Opcode, Dst, Src, nullptr); } int AArch64TTIImpl::getVectorInstrCost(unsigned Opcode, Type *Val, Index: lib/Target/ARM/ARMTargetTransformInfo.h =================================================================== --- lib/Target/ARM/ARMTargetTransformInfo.h +++ lib/Target/ARM/ARMTargetTransformInfo.h @@ -96,7 +96,8 @@ int getShuffleCost(TTI::ShuffleKind Kind, Type *Tp, int Index, Type *SubTp); - int getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src); + int getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src, + TTI::InstrContext *Ctx); int getCmpSelInstrCost(unsigned Opcode, Type *ValTy, Type *CondTy); Index: lib/Target/ARM/ARMTargetTransformInfo.cpp =================================================================== --- lib/Target/ARM/ARMTargetTransformInfo.cpp +++ lib/Target/ARM/ARMTargetTransformInfo.cpp @@ -47,7 +47,8 @@ return 3; } -int ARMTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src) { +int ARMTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src, + TTI::InstrContext *Ctx) { int ISD = TLI->InstructionOpcodeToISD(Opcode); assert(ISD && "Invalid opcode"); @@ -70,7 +71,7 @@ EVT DstTy = TLI->getValueType(DL, Dst); if (!SrcTy.isSimple() || !DstTy.isSimple()) - return BaseT::getCastInstrCost(Opcode, Dst, Src); + return BaseT::getCastInstrCost(Opcode, Dst, Src, Ctx); // Some arithmetic, load and store operations have specific instructions // to cast up/down their types automatically at no extra cost. 
@@ -237,7 +238,7 @@ return Entry->Cost; } - return BaseT::getCastInstrCost(Opcode, Dst, Src); + return BaseT::getCastInstrCost(Opcode, Dst, Src, Ctx); } int ARMTTIImpl::getVectorInstrCost(unsigned Opcode, Type *ValTy, Index: lib/Target/PowerPC/PPCTargetTransformInfo.h =================================================================== --- lib/Target/PowerPC/PPCTargetTransformInfo.h +++ lib/Target/PowerPC/PPCTargetTransformInfo.h @@ -78,7 +78,8 @@ TTI::OperandValueProperties Opd1PropInfo = TTI::OP_None, TTI::OperandValueProperties Opd2PropInfo = TTI::OP_None); int getShuffleCost(TTI::ShuffleKind Kind, Type *Tp, int Index, Type *SubTp); - int getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src); + int getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src, + TTI::InstrContext *Ctx); int getCmpSelInstrCost(unsigned Opcode, Type *ValTy, Type *CondTy); int getVectorInstrCost(unsigned Opcode, Type *Val, unsigned Index); int getMemoryOpCost(unsigned Opcode, Type *Src, unsigned Alignment, Index: lib/Target/PowerPC/PPCTargetTransformInfo.cpp =================================================================== --- lib/Target/PowerPC/PPCTargetTransformInfo.cpp +++ lib/Target/PowerPC/PPCTargetTransformInfo.cpp @@ -281,10 +281,11 @@ return LT.first; } -int PPCTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src) { +int PPCTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src, + TTI::InstrContext *Ctx) { assert(TLI->InstructionOpcodeToISD(Opcode) && "Invalid opcode"); - return BaseT::getCastInstrCost(Opcode, Dst, Src); + return BaseT::getCastInstrCost(Opcode, Dst, Src, Ctx); } int PPCTTIImpl::getCmpSelInstrCost(unsigned Opcode, Type *ValTy, Type *CondTy) { Index: lib/Target/X86/X86TargetTransformInfo.h =================================================================== --- lib/Target/X86/X86TargetTransformInfo.h +++ lib/Target/X86/X86TargetTransformInfo.h @@ -69,7 +69,8 @@ TTI::OperandValueProperties Opd1PropInfo = TTI::OP_None, TTI::OperandValueProperties 
Opd2PropInfo = TTI::OP_None); int getShuffleCost(TTI::ShuffleKind Kind, Type *Tp, int Index, Type *SubTp); - int getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src); + int getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src, + TTI::InstrContext *Ctx); int getCmpSelInstrCost(unsigned Opcode, Type *ValTy, Type *CondTy); int getVectorInstrCost(unsigned Opcode, Type *Val, unsigned Index); int getMemoryOpCost(unsigned Opcode, Type *Src, unsigned Alignment, Index: lib/Target/X86/X86TargetTransformInfo.cpp =================================================================== --- lib/Target/X86/X86TargetTransformInfo.cpp +++ lib/Target/X86/X86TargetTransformInfo.cpp @@ -524,7 +524,8 @@ return BaseT::getShuffleCost(Kind, Tp, Index, SubTp); } -int X86TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src) { +int X86TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src, + TTI::InstrContext *Ctx) { int ISD = TLI->InstructionOpcodeToISD(Opcode); assert(ISD && "Invalid opcode"); @@ -805,7 +806,7 @@ // The function getSimpleVT only handles simple value types. 
if (!SrcTy.isSimple() || !DstTy.isSimple()) - return BaseT::getCastInstrCost(Opcode, Dst, Src); + return BaseT::getCastInstrCost(Opcode, Dst, Src, Ctx); if (ST->hasDQI()) if (const auto *Entry = ConvertCostTableLookup(AVX512DQConversionTbl, ISD, @@ -847,7 +848,7 @@ return Entry->Cost; } - return BaseT::getCastInstrCost(Opcode, Dst, Src); + return BaseT::getCastInstrCost(Opcode, Dst, Src, Ctx); } int X86TTIImpl::getCmpSelInstrCost(unsigned Opcode, Type *ValTy, Type *CondTy) { Index: lib/Transforms/Vectorize/SLPVectorizer.cpp =================================================================== --- lib/Transforms/Vectorize/SLPVectorizer.cpp +++ lib/Transforms/Vectorize/SLPVectorizer.cpp @@ -362,6 +362,8 @@ typedef SmallPtrSet ValueSet; typedef SmallVector StoreList; + typedef TargetTransformInfo::InstrContext InstrCtx; + BoUpSLP(Function *Func, ScalarEvolution *Se, TargetTransformInfo *Tti, TargetLibraryInfo *TLi, AliasAnalysis *Aa, LoopInfo *Li, DominatorTree *Dt, AssumptionCache *AC, DemandedBits *DB) @@ -1543,15 +1545,69 @@ case Instruction::Trunc: case Instruction::FPTrunc: case Instruction::BitCast: { - Type *SrcTy = VL0->getOperand(0)->getType(); - // Calculate the cost of this instruction. - int ScalarCost = VL.size() * TTI->getCastInstrCost(VL0->getOpcode(), - VL0->getType(), SrcTy); + // Simplifying definitions to avoid confusion in the code below. + auto *SrcScalarTy = VL0->getOperand(0)->getType(); + auto *SrcVectorTy = VectorType::get(SrcScalarTy, VL.size()); + auto *DstScalarTy = VL0->getType(); + auto *DstVectorTy = VecTy; + auto DstOpcode = VL0->getOpcode(); + + // The scalar and vector cost estimates. We will define these below. + int ScalarCost; + int VectorCost; + + // If there is only one user of the scalar value, we can attempt to use + // the more precise, contextual cost estimates. If this is the case, we + // build up the context. Otherwise, we use the context-free estimates. 
+ auto *TheUser = VL0->hasOneUse() ? dyn_cast_or_null(*VL0->user_begin()) : nullptr; + if (!TheUser || !TheUser->getType()->isIntegerTy()) { + ScalarCost = TTI->getCastInstrCost(DstOpcode, DstScalarTy, SrcScalarTy); + VectorCost = TTI->getCastInstrCost(DstOpcode, DstVectorTy, SrcVectorTy); + ScalarCost *= VL.size(); + return VectorCost - ScalarCost; + } - VectorType *SrcVecTy = VectorType::get(SrcTy, VL.size()); - int VecCost = TTI->getCastInstrCost(VL0->getOpcode(), VecTy, SrcVecTy); - return VecCost - ScalarCost; + // We're going to use the contextual estimates. First define the contexts + // for both the vector and scalar estimates. + InstrCtx ScalarCast, ScalarOper(SrcScalarTy), ScalarUser; + InstrCtx VectorCast, VectorOper(SrcVectorTy), VectorUser; + SmallVector ScalarOperands; + SmallVector VectorOperands; + + // Next, gather information about the user. + auto UsrOpcode = TheUser->getOpcode(); + auto *UsrScalarTy = TheUser->getType(); + auto *UsrVectorTy = MaxRequiredIntegerTy + ? DstVectorTy + : VectorType::get(UsrScalarTy, VL.size()); + + // Finally, build the contexts. We first visit each operand of the user, + // and then connect the user to the cast. + for (Use &U : TheUser->operands()) { + auto *ScalarTy = U.get()->getType(); + auto *VectorTy = VectorType::get(ScalarTy, VL.size()); + auto *I = dyn_cast(U.get()); + auto Opc = I ? I->getOpcode() : 0; + ScalarOperands.push_back(InstrCtx(ScalarTy, Opc, &ScalarUser)); + if (MaxRequiredIntegerTy) + VectorOperands.push_back(InstrCtx(DstVectorTy, Opc, &VectorUser)); + else + VectorOperands.push_back(InstrCtx(VectorTy, Opc, &VectorUser)); + } + ScalarUser = InstrCtx(UsrScalarTy, UsrOpcode, ScalarOperands); + VectorUser = InstrCtx(UsrVectorTy, UsrOpcode, VectorOperands); + ScalarCast = InstrCtx(DstScalarTy, DstOpcode, ScalarOper, &ScalarUser); + VectorCast = InstrCtx(DstVectorTy, DstOpcode, VectorOper, &VectorUser); + + // The contextual cost estimates. 
+ ScalarCost = TTI->getCastInstrCost(DstOpcode, DstScalarTy, SrcScalarTy, + &ScalarCast); + VectorCost = TTI->getCastInstrCost(DstOpcode, DstVectorTy, SrcVectorTy, + &VectorCast); + + ScalarCost *= VL.size(); + return VectorCost - ScalarCost; } case Instruction::FCmp: case Instruction::ICmp: @@ -1791,7 +1847,7 @@ PrevInst = Inst; } - DEBUG(dbgs() << "SLP: SpillCost=" << Cost << "\n"); + DEBUG(dbgs() << "SLP: Spill Cost = " << Cost << ".\n"); return Cost; } @@ -1833,12 +1889,19 @@ // If we plan to rewrite the tree in a smaller type, we will need to sign // extend the extracted value back to the original type. Here, we account - // for the extract and the added cost of the sign extend if needed. + // for the extract and the added cost of the sign extend if needed. We + // build a context for the extend that includes the extract to improve the + // accuracy of the estimate. auto *VecTy = VectorType::get(I->Scalar->getType(), BundleWidth); if (MaxRequiredIntegerTy) { VecTy = VectorType::get(MaxRequiredIntegerTy, BundleWidth); - ExtractCost += TTI->getCastInstrCost( - Instruction::SExt, I->Scalar->getType(), MaxRequiredIntegerTy); + InstrCtx Vector(VecTy), Extract, Extend; + Extract = InstrCtx(MaxRequiredIntegerTy, Instruction::ExtractElement, + Vector, &Extend); + Extend = InstrCtx(I->Scalar->getType(), Instruction::SExt, Extract); + ExtractCost += + TTI->getCastInstrCost(Instruction::SExt, I->Scalar->getType(), + MaxRequiredIntegerTy, &Extend); } ExtractCost += TTI->getVectorInstrCost(Instruction::ExtractElement, VecTy, I->Lane); @@ -1846,7 +1909,9 @@ Cost += getSpillCost(); - DEBUG(dbgs() << "SLP: Total Cost " << Cost + ExtractCost<< ".\n"); + DEBUG(dbgs() << "SLP: Extract Cost = " << ExtractCost << ".\n"); + DEBUG(dbgs() << "SLP: Total Cost = " << Cost + ExtractCost << ".\n"); + return Cost + ExtractCost; } Index: test/Transforms/SLPVectorizer/AArch64/gather-reduce.ll =================================================================== --- 
test/Transforms/SLPVectorizer/AArch64/gather-reduce.ll +++ test/Transforms/SLPVectorizer/AArch64/gather-reduce.ll @@ -1,5 +1,4 @@ -; RUN: opt -S -slp-vectorizer -dce -instcombine < %s | FileCheck %s --check-prefix=PROFITABLE -; RUN: opt -S -slp-vectorizer -slp-threshold=-12 -dce -instcombine < %s | FileCheck %s --check-prefix=UNPROFITABLE +; RUN: opt -S -slp-vectorizer -dce -instcombine < %s | FileCheck %s target datalayout = "e-m:e-i64:64-i128:128-n32:64-S128" target triple = "aarch64--linux-gnu" @@ -19,13 +18,13 @@ ; return sum; ; } -; PROFITABLE-LABEL: @gather_reduce_8x16_i32 +; CHECK-LABEL: @gather_reduce_8x16_i32 ; -; PROFITABLE: [[L:%[a-zA-Z0-9.]+]] = load <8 x i16> -; PROFITABLE: zext <8 x i16> [[L]] to <8 x i32> -; PROFITABLE: [[S:%[a-zA-Z0-9.]+]] = sub nsw <8 x i32> -; PROFITABLE: [[X:%[a-zA-Z0-9.]+]] = extractelement <8 x i32> [[S]] -; PROFITABLE: sext i32 [[X]] to i64 +; CHECK: [[L:%[a-zA-Z0-9.]+]] = load <8 x i16> +; CHECK: zext <8 x i16> [[L]] to <8 x i32> +; CHECK: [[S:%[a-zA-Z0-9.]+]] = sub nsw <8 x i32> +; CHECK: [[X:%[a-zA-Z0-9.]+]] = extractelement <8 x i32> [[S]] +; CHECK: sext i32 [[X]] to i64 ; define i32 @gather_reduce_8x16_i32(i16* nocapture readonly %a, i16* nocapture readonly %b, i16* nocapture readonly %g, i32 %n) { entry: @@ -138,18 +137,13 @@ br i1 %exitcond, label %for.cond.cleanup.loopexit, label %for.body } -; UNPROFITABLE-LABEL: @gather_reduce_8x16_i64 +; CHECK-LABEL: @gather_reduce_8x16_i64 ; -; UNPROFITABLE: [[L:%[a-zA-Z0-9.]+]] = load <8 x i16> -; UNPROFITABLE: zext <8 x i16> [[L]] to <8 x i32> -; UNPROFITABLE: [[S:%[a-zA-Z0-9.]+]] = sub nsw <8 x i32> -; UNPROFITABLE: [[X:%[a-zA-Z0-9.]+]] = extractelement <8 x i32> [[S]] -; UNPROFITABLE: sext i32 [[X]] to i64 -; -; TODO: Although we can now vectorize this case while converting the i64 -; subtractions to i32, the cost model currently finds vectorization to be -; unprofitable. 
The cost model is penalizing the sign and zero -; extensions in the vectorized version, but they are actually free. +; CHECK: [[L:%[a-zA-Z0-9.]+]] = load <8 x i16> +; CHECK: zext <8 x i16> [[L]] to <8 x i32> +; CHECK: [[S:%[a-zA-Z0-9.]+]] = sub nsw <8 x i32> +; CHECK: [[X:%[a-zA-Z0-9.]+]] = extractelement <8 x i32> [[S]] +; CHECK: sext i32 [[X]] to i64 ; define i32 @gather_reduce_8x16_i64(i16* nocapture readonly %a, i16* nocapture readonly %b, i16* nocapture readonly %g, i32 %n) { entry: