Index: include/llvm/Analysis/TargetTransformInfo.h
===================================================================
--- include/llvm/Analysis/TargetTransformInfo.h
+++ include/llvm/Analysis/TargetTransformInfo.h
@@ -438,6 +438,35 @@
   /// zext, etc.
   int getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src) const;
 
+  /// A struct to hold information for contextual cost estimates.
+  struct InstrContext {
+    InstrContext() : InstrTy(nullptr), Opcode(0), User(nullptr) {
+      Operands = makeArrayRef((InstrContext *)nullptr, 0);
+    }
+    InstrContext(Type *Ty, unsigned Opc = 0, InstrContext *Usr = nullptr)
+        : InstrTy(Ty), Opcode(Opc), User(Usr) {
+      Operands = makeArrayRef((InstrContext *)nullptr, 0);
+    }
+    InstrContext(Type *Ty, unsigned Opc, ArrayRef<InstrContext> Ops,
+                 InstrContext *Usr = nullptr)
+        : InstrTy(Ty), Opcode(Opc), Operands(Ops), User(Usr) {}
+
+    /// The type of the instruction.
+    Type *InstrTy;
+
+    /// The opcode of the instruction.
+    unsigned Opcode;
+
+    /// A list of InstrContexts describing the operands of the instruction.
+    ArrayRef<InstrContext> Operands;
+
+    /// An InstrContext describing a single user of the instruction.
+    InstrContext *User;
+  };
+
+  /// \return The expected cost of a cast instruction in the given context.
+  int getCastInstrCostInCtx(InstrContext &Ctx) const;
+
   /// \return The expected cost of control-flow related instructions such as
   /// Phi, Ret, Br.
  int getCFInstrCost(unsigned Opcode) const;
@@ -618,6 +647,7 @@
   virtual int getShuffleCost(ShuffleKind Kind, Type *Tp, int Index,
                              Type *SubTp) = 0;
   virtual int getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src) = 0;
+  virtual int getCastInstrCostInCtx(InstrContext &Ctx) = 0;
   virtual int getCFInstrCost(unsigned Opcode) = 0;
   virtual int getCmpSelInstrCost(unsigned Opcode, Type *ValTy,
                                  Type *CondTy) = 0;
@@ -793,6 +823,9 @@
   int getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src) override {
     return Impl.getCastInstrCost(Opcode, Dst, Src);
   }
+  int getCastInstrCostInCtx(InstrContext &Ctx) override {
+    return Impl.getCastInstrCostInCtx(Ctx);
+  }
   int getCFInstrCost(unsigned Opcode) override {
     return Impl.getCFInstrCost(Opcode);
   }
Index: include/llvm/Analysis/TargetTransformInfoImpl.h
===================================================================
--- include/llvm/Analysis/TargetTransformInfoImpl.h
+++ include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -281,6 +281,8 @@
 
   unsigned getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src) { return 1; }
 
+  unsigned getCastInstrCostInCtx(TTI::InstrContext &Ctx) { return 1; }
+
   unsigned getCFInstrCost(unsigned Opcode) { return 1; }
 
   unsigned getCmpSelInstrCost(unsigned Opcode, Type *ValTy, Type *CondTy) {
Index: include/llvm/CodeGen/BasicTTIImpl.h
===================================================================
--- include/llvm/CodeGen/BasicTTIImpl.h
+++ include/llvm/CodeGen/BasicTTIImpl.h
@@ -428,6 +428,12 @@
     llvm_unreachable("Unhandled cast");
   }
 
+  unsigned getCastInstrCostInCtx(TTI::InstrContext &Ctx) {
+    assert(!Ctx.Operands.empty() && "Empty context");
+    return static_cast<T *>(this)->getCastInstrCost(Ctx.Opcode, Ctx.InstrTy,
+                                                    Ctx.Operands[0].InstrTy);
+  }
+
   unsigned getCFInstrCost(unsigned Opcode) {
     // Branches are assumed to be predicted.
     return 0;
Index: lib/Analysis/TargetTransformInfo.cpp
===================================================================
--- lib/Analysis/TargetTransformInfo.cpp
+++ lib/Analysis/TargetTransformInfo.cpp
@@ -243,6 +243,12 @@
   return Cost;
 }
 
+int TargetTransformInfo::getCastInstrCostInCtx(InstrContext &Ctx) const {
+  int Cost = TTIImpl->getCastInstrCostInCtx(Ctx);
+  assert(Cost >= 0 && "TTI should not produce negative costs!");
+  return Cost;
+}
+
 int TargetTransformInfo::getCFInstrCost(unsigned Opcode) const {
   int Cost = TTIImpl->getCFInstrCost(Opcode);
   assert(Cost >= 0 && "TTI should not produce negative costs!");
Index: lib/Target/AArch64/AArch64TargetTransformInfo.h
===================================================================
--- lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -99,6 +99,8 @@
 
   int getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src);
 
+  int getCastInstrCostInCtx(TTI::InstrContext &Ctx);
+
   int getVectorInstrCost(unsigned Opcode, Type *Val, unsigned Index);
 
   int getArithmeticInstrCost(
Index: lib/Target/AArch64/AArch64TargetTransformInfo.cpp
===================================================================
--- lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -291,6 +291,139 @@
   return BaseT::getCastInstrCost(Opcode, Dst, Src);
 }
 
+int AArch64TTIImpl::getCastInstrCostInCtx(TTI::InstrContext &Ctx) {
+
+  assert(!Ctx.Operands.empty() && "Invalid context");
+
+  // Contextual cost estimate. We can base the estimate on information about
+  // the value being cast.
+  switch (Ctx.Operands[0].Opcode) {
+  default:
+    break;
+
+  // If the cast value has been extracted from a vector, we can usually perform
+  // sign and zero extensions automatically with smov and umov. We check for
+  // the free cases below, and fall back to the context-free estimate
+  // otherwise.
+  case Instruction::ExtractElement:
+
+    // We only contextualize sext and zext casts.
+    if (Ctx.Opcode != Instruction::SExt && Ctx.Opcode != Instruction::ZExt)
+      break;
+
+    // See if we've been given enough context to do anything useful. We need to
+    // know the first operand of the extractelement.
+    auto &Operands = Ctx.Operands[0].Operands;
+    if (Operands.empty())
+      break;
+
+    // Make sure the first operand of the extractelement is a vector type. This
+    // must be the case if we have a valid context.
+    auto *Vec = dyn_cast_or_null<VectorType>(Operands[0].InstrTy);
+    assert(Vec && "Invalid context");
+
+    // Because the cast is either zext or sext, we have to be working with all
+    // integer types. Again, this must be the case if we have a valid context.
+    auto *Dst = dyn_cast_or_null<IntegerType>(Ctx.InstrTy);
+    auto *Src = dyn_cast_or_null<IntegerType>(Ctx.Operands[0].InstrTy);
+    assert(Dst && Src == Vec->getElementType() && "Invalid context");
+
+    // Now we legalize the types. If the resulting type is still a vector and
+    // the destination type is legal, we may get the extension for free.
+    auto VecLT = TLI->getTypeLegalizationCost(DL, Vec);
+    auto DstVT = TLI->getValueType(DL, Dst);
+    if (!VecLT.second.isVector() || !TLI->isTypeLegal(DstVT))
+      break;
+
+    // Ensure we have a valid extract-extend combination. The source and
+    // element types should be the same, and the destination type should be
+    // larger than the element type.
+    auto SrcVT = TLI->getValueType(DL, Src);
+    auto ElmVT = VecLT.second.getVectorElementType();
+    if (SrcVT != ElmVT || DstVT.getSizeInBits() < ElmVT.getSizeInBits())
+      break;
+
+    switch (Ctx.Opcode) {
+    default:
+      llvm_unreachable("Ctx.Opcode should be either SExt or ZExt");
+
+    // For sign extensions, we only need a smov, which performs the extension
+    // automatically.
+    case Instruction::SExt:
+      return 0;
+
+    // For zero extensions, the extend is performed automatically by a umov
+    // unless the destination type is i64 and the element type is i8 or i16.
+    case Instruction::ZExt:
+      if (DstVT.getSizeInBits() != 64u || ElmVT.getSizeInBits() == 32u)
+        return 0;
+    }
+  }
+
+  // At this point, we've looked at the operand of the cast if it was
+  // interesting to do so. If we don't yet know that the cast will be free, we
+  // can try to look at the user of the cast. If there is no information about
+  // what the user of the cast looks like, there is nothing else to do.
+  if (!Ctx.User)
+    return getCastInstrCost(Ctx.Opcode, Ctx.InstrTy, Ctx.Operands[0].InstrTy);
+
+  // Otherwise, we can base the estimate on information about the user.
+  switch (Ctx.User->Opcode) {
+  default:
+    break;
+
+  // AArch64 can perform widening and narrowing vector arithmetic. If the cast
+  // value is a vector extend that is used by an arithmetic operation, it may
+  // happen automatically. We check for the free cases below, and fall back to
+  // the context-free estimate otherwise.
+  case Instruction::Add:
+  case Instruction::Sub:
+
+    // We only contextualize sext and zext casts.
+    if (Ctx.Opcode != Instruction::SExt && Ctx.Opcode != Instruction::ZExt)
+      break;
+
+    // We have to look at the operands of widening and narrowing instructions.
+    // If we weren't given this information, there's nothing to do.
+    if (Ctx.User->Operands.size() != 2)
+      break;
+
+    // Type check the extend and its user to ensure we have a valid context.
+    // Because we have an extend, its source and destination types have to be
+    // integers or vectors of integers.
+    auto *Dst = Ctx.InstrTy;
+    auto *Src = Ctx.Operands[0].InstrTy;
+    auto &Operands = Ctx.User->Operands;
+    assert(Src && Src->isIntOrIntVectorTy() && Dst &&
+           Dst->isIntOrIntVectorTy() && Ctx.User->InstrTy == Dst &&
+           Operands[0].InstrTy == Dst && Operands[1].InstrTy == Dst &&
+           "Invalid context");
+
+    // Check if this operation can be matched to a lengthening or widening
+    // vector operation. If not, give up.
+    auto IsWidening = Operands[1].Opcode == Ctx.Opcode;
+    auto IsLengthening = IsWidening && Operands[0].Opcode == Ctx.Opcode;
+    if (!IsWidening && !IsLengthening)
+      break;
+
+    // Now we legalize the types. If the resulting types are not vectors, this
+    // won't match the lengthening and widening operations.
+    auto SrcVT = TLI->getTypeLegalizationCost(DL, Src).second;
+    auto DstVT = TLI->getTypeLegalizationCost(DL, Dst).second;
+    if (!SrcVT.isVector() || !DstVT.isVector())
+      break;
+
+    // If we have legal vectors and the destination element size is twice the
+    // source element size, the extend can be performed automatically by the
+    // lengthening or widening instruction.
+    if (DstVT.getScalarSizeInBits() == 2 * SrcVT.getScalarSizeInBits())
+      return 0;
+  }
+
+  // If we can't do anything useful, fall back to the default implementation.
+  return getCastInstrCost(Ctx.Opcode, Ctx.InstrTy, Ctx.Operands[0].InstrTy);
+}
+
 int AArch64TTIImpl::getVectorInstrCost(unsigned Opcode, Type *Val,
                                        unsigned Index) {
   assert(Val->isVectorTy() && "This must be a vector type");
Index: lib/Transforms/Vectorize/SLPVectorizer.cpp
===================================================================
--- lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -361,6 +361,8 @@
   typedef SmallPtrSet<Value *, 16> ValueSet;
   typedef SmallVector<StoreInst *, 8> StoreList;
 
+  typedef TargetTransformInfo::InstrContext InstrCtx;
+
   BoUpSLP(Function *Func, ScalarEvolution *Se, TargetTransformInfo *Tti,
           TargetLibraryInfo *TLi, AliasAnalysis *Aa, LoopInfo *Li,
           DominatorTree *Dt, AssumptionCache *AC)
@@ -1539,15 +1541,67 @@
     case Instruction::Trunc:
     case Instruction::FPTrunc:
     case Instruction::BitCast: {
-      Type *SrcTy = VL0->getOperand(0)->getType();
-
-      // Calculate the cost of this instruction.
-      int ScalarCost = VL.size() * TTI->getCastInstrCost(VL0->getOpcode(),
-                                                         VL0->getType(), SrcTy);
+      // Simplifying definitions to avoid confusion in the code below.
+      auto *SrcScalarTy = VL0->getOperand(0)->getType();
+      auto *SrcVectorTy = VectorType::get(SrcScalarTy, VL.size());
+      auto *DstScalarTy = VL0->getType();
+      auto *DstVectorTy = VecTy;
+      auto DstOpcode = VL0->getOpcode();
+
+      // The scalar and vector cost estimates. We will define these below.
+      int ScalarCost;
+      int VectorCost;
+
+      // If there is only one user of the scalar value, we can attempt to use
+      // the more precise, contextual cost estimates. If this is the case, we
+      // build up the context. Otherwise, we use the context-free estimates.
+      auto *TheUser = dyn_cast_or_null<Instruction>(*VL0->user_begin());
+      if (!TheUser || !TheUser->getType()->isIntegerTy() || !VL0->hasOneUse()) {
+        ScalarCost = TTI->getCastInstrCost(DstOpcode, DstScalarTy, SrcScalarTy);
+        VectorCost = TTI->getCastInstrCost(DstOpcode, DstVectorTy, SrcVectorTy);
+        ScalarCost *= VL.size();
+        return VectorCost - ScalarCost;
+      }
 
-      VectorType *SrcVecTy = VectorType::get(SrcTy, VL.size());
-      int VecCost = TTI->getCastInstrCost(VL0->getOpcode(), VecTy, SrcVecTy);
-      return VecCost - ScalarCost;
+      // We're going to use the contextual estimates. First define the contexts
+      // for both the vector and scalar estimates.
+      InstrCtx ScalarCast, ScalarOper(SrcScalarTy), ScalarUser;
+      InstrCtx VectorCast, VectorOper(SrcVectorTy), VectorUser;
+      SmallVector<InstrCtx, 2> ScalarOperands;
+      SmallVector<InstrCtx, 2> VectorOperands;
+
+      // Next, gather information about the user.
+      auto UsrOpcode = TheUser->getOpcode();
+      auto *UsrScalarTy = TheUser->getType();
+      auto *UsrVectorTy = MaxRequiredIntegerTy
+                              ? DstVectorTy
+                              : VectorType::get(UsrScalarTy, VL.size());
+
+      // Finally, build the contexts. We first visit each operand of the user,
+      // and then connect the user to the cast.
+      for (Use &U : TheUser->operands()) {
+        auto *ScalarTy = U.get()->getType();
+        auto *VectorTy = VectorType::get(ScalarTy, VL.size());
+        auto *I = dyn_cast<Instruction>(U.get());
+        auto Opc = I ? I->getOpcode() : 0;
+        ScalarOperands.push_back(InstrCtx(ScalarTy, Opc, &ScalarUser));
+        if (MaxRequiredIntegerTy)
+          VectorOperands.push_back(InstrCtx(DstVectorTy, Opc, &VectorUser));
+        else
+          VectorOperands.push_back(InstrCtx(VectorTy, Opc, &VectorUser));
+      }
+      ScalarUser = InstrCtx(UsrScalarTy, UsrOpcode, ScalarOperands);
+      VectorUser = InstrCtx(UsrVectorTy, UsrOpcode, VectorOperands);
+      ScalarCast = InstrCtx(DstScalarTy, DstOpcode, ScalarOper, &ScalarUser);
+      VectorCast = InstrCtx(DstVectorTy, DstOpcode, VectorOper, &VectorUser);
+
+      // The contextual cost estimates.
+      ScalarCost = TTI->getCastInstrCostInCtx(ScalarCast);
+      VectorCost = TTI->getCastInstrCostInCtx(VectorCast);
+
+      ScalarCost *= VL.size();
+      return VectorCost - ScalarCost;
     }
     case Instruction::FCmp:
     case Instruction::ICmp:
@@ -1787,7 +1841,7 @@
       PrevInst = Inst;
     }
 
-    DEBUG(dbgs() << "SLP: SpillCost=" << Cost << "\n");
+    DEBUG(dbgs() << "SLP: Spill Cost = " << Cost << ".\n");
     return Cost;
   }
 
@@ -1829,12 +1883,17 @@
 
       // If we plan to rewrite the tree in a smaller type, we will need to sign
      // extend the extracted value back to the original type. Here, we account
-      // for the extract and the added cost of the sign extend if needed.
+      // for the extract and the added cost of the sign extend if needed. We
+      // build a context for the extend that includes the extract to improve the
+      // accuracy of the estimate.
       auto *VecTy = VectorType::get(I->Scalar->getType(), BundleWidth);
       if (MaxRequiredIntegerTy) {
         VecTy = VectorType::get(MaxRequiredIntegerTy, BundleWidth);
-        ExtractCost += TTI->getCastInstrCost(
-            Instruction::SExt, I->Scalar->getType(), MaxRequiredIntegerTy);
+        InstrCtx Vector(VecTy), Extract, Extend;
+        Extract = InstrCtx(MaxRequiredIntegerTy, Instruction::ExtractElement,
+                           Vector, &Extend);
+        Extend = InstrCtx(I->Scalar->getType(), Instruction::SExt, Extract);
+        ExtractCost += TTI->getCastInstrCostInCtx(Extend);
       }
       ExtractCost += TTI->getVectorInstrCost(Instruction::ExtractElement, VecTy,
                                              I->Lane);
@@ -1842,7 +1901,9 @@
 
   Cost += getSpillCost();
 
-  DEBUG(dbgs() << "SLP: Total Cost " << Cost + ExtractCost<< ".\n");
+  DEBUG(dbgs() << "SLP: Extract Cost = " << ExtractCost << ".\n");
+  DEBUG(dbgs() << "SLP: Total Cost = " << Cost + ExtractCost << ".\n");
+
   return Cost + ExtractCost;
 }
 
Index: test/Transforms/SLPVectorizer/AArch64/gather-reduce.ll
===================================================================
--- test/Transforms/SLPVectorizer/AArch64/gather-reduce.ll
+++ test/Transforms/SLPVectorizer/AArch64/gather-reduce.ll
@@ -1,4 +1,4 @@
-; RUN: opt -S -slp-vectorizer -slp-threshold=-12 -dce -instcombine < %s | FileCheck %s
+; RUN: opt -S -slp-vectorizer -dce -instcombine < %s | FileCheck %s
 target datalayout = "e-m:e-i64:64-i128:128-n32:64-S128"
 target triple = "aarch64--linux-gnu"