diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h --- a/llvm/include/llvm/Analysis/TargetTransformInfo.h +++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h @@ -1265,13 +1265,21 @@ TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput) const; /// Calculate the cost of an extended reduction pattern, similar to - /// getArithmeticReductionCost of an Add reduction with an extension and - /// optional multiply. This is the cost of as: - /// ResTy vecreduce.add(ext(Ty A)), or if IsMLA flag is set then: - /// ResTy vecreduce.add(mul(ext(Ty A), ext(Ty B)). The reduction happens - /// on a VectorType with ResTy elements and Ty lanes. - InstructionCost getExtendedAddReductionCost( - bool IsMLA, bool IsUnsigned, Type *ResTy, VectorType *Ty, + /// getArithmeticReductionCost of an Add reduction with multiply and optional + /// extensions. This is the cost of: + /// ResTy vecreduce.add(mul(A, B)). + /// ResTy vecreduce.add(mul(ext(Ty A), ext(Ty B))). + InstructionCost getMulAccReductionCost( + bool IsUnsigned, Type *ResTy, VectorType *Ty, + TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput) const; + + /// Calculate the cost of an extended reduction pattern, similar to + /// getArithmeticReductionCost of a reduction with an extension. + /// This is the cost of: + /// ResTy vecreduce.op(ext(Ty A)), where op is given by Opcode. + InstructionCost getExtendedReductionCost( + unsigned Opcode, bool IsUnsigned, Type *ResTy, VectorType *Ty, + Optional<FastMathFlags> FMF, + TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput) const; /// \returns The cost of Intrinsic instructions. Analyses the real arguments.
@@ -1775,8 +1783,12 @@ virtual InstructionCost getMinMaxReductionCost(VectorType *Ty, VectorType *CondTy, bool IsUnsigned, TTI::TargetCostKind CostKind) = 0; - virtual InstructionCost getExtendedAddReductionCost( - bool IsMLA, bool IsUnsigned, Type *ResTy, VectorType *Ty, + virtual InstructionCost getExtendedReductionCost( + unsigned Opcode, bool IsUnsigned, Type *ResTy, VectorType *Ty, + Optional<FastMathFlags> FMF, + TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput) = 0; + virtual InstructionCost getMulAccReductionCost( + bool IsUnsigned, Type *ResTy, VectorType *Ty, TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput) = 0; virtual InstructionCost getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA, @@ -2345,11 +2357,17 @@ TTI::TargetCostKind CostKind) override { return Impl.getMinMaxReductionCost(Ty, CondTy, IsUnsigned, CostKind); } - InstructionCost getExtendedAddReductionCost( - bool IsMLA, bool IsUnsigned, Type *ResTy, VectorType *Ty, + InstructionCost getExtendedReductionCost( + unsigned Opcode, bool IsUnsigned, Type *ResTy, VectorType *Ty, + Optional<FastMathFlags> FMF, + TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput) override { + return Impl.getExtendedReductionCost(Opcode, IsUnsigned, ResTy, Ty, FMF, + CostKind); + } + InstructionCost getMulAccReductionCost( + bool IsUnsigned, Type *ResTy, VectorType *Ty, TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput) override { - return Impl.getExtendedAddReductionCost(IsMLA, IsUnsigned, ResTy, Ty, - CostKind); + return Impl.getMulAccReductionCost(IsUnsigned, ResTy, Ty, CostKind); } InstructionCost getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA, TTI::TargetCostKind CostKind) override { diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h --- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h +++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h @@ -681,10 +681,16 @@ return 1; } - InstructionCost - getExtendedAddReductionCost(bool
IsMLA, bool IsUnsigned, Type *ResTy, - VectorType *Ty, - TTI::TargetCostKind CostKind) const { + InstructionCost getExtendedReductionCost(unsigned Opcode, bool IsUnsigned, + Type *ResTy, VectorType *Ty, + Optional<FastMathFlags> FMF, + TTI::TargetCostKind CostKind) const { + return 1; + } + + InstructionCost getMulAccReductionCost(bool IsUnsigned, Type *ResTy, + VectorType *Ty, + TTI::TargetCostKind CostKind) const { return 1; } diff --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h --- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h +++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h @@ -2318,23 +2318,37 @@ thisT()->getVectorInstrCost(Instruction::ExtractElement, Ty, 0); } - InstructionCost getExtendedAddReductionCost(bool IsMLA, bool IsUnsigned, - Type *ResTy, VectorType *Ty, - TTI::TargetCostKind CostKind) { + InstructionCost getExtendedReductionCost(unsigned Opcode, bool IsUnsigned, + Type *ResTy, VectorType *Ty, + Optional<FastMathFlags> FMF, + TTI::TargetCostKind CostKind) { // Without any native support, this is equivalent to the cost of - // vecreduce.add(ext) or if IsMLA vecreduce.add(mul(ext, ext)) + // vecreduce.op(ext). + VectorType *ExtTy = VectorType::get(ResTy, Ty); + InstructionCost RedCost = + thisT()->getArithmeticReductionCost(Opcode, ExtTy, FMF, CostKind); + InstructionCost ExtCost = thisT()->getCastInstrCost( + IsUnsigned ? Instruction::ZExt : Instruction::SExt, ExtTy, Ty, + TTI::CastContextHint::None, CostKind); + + return RedCost + ExtCost; + } + + InstructionCost getMulAccReductionCost(bool IsUnsigned, Type *ResTy, + VectorType *Ty, + TTI::TargetCostKind CostKind) { + // Without any native support, this is equivalent to the cost of + // vecreduce.add(mul(ext, ext)). + VectorType *ExtTy = VectorType::get(ResTy, Ty); + InstructionCost RedCost = thisT()->getArithmeticReductionCost( + Instruction::Add, ExtTy, None, CostKind); - InstructionCost MulCost = 0; InstructionCost ExtCost = thisT()->getCastInstrCost( IsUnsigned ?
Instruction::ZExt : Instruction::SExt, ExtTy, Ty, TTI::CastContextHint::None, CostKind); - if (IsMLA) { - MulCost = - thisT()->getArithmeticInstrCost(Instruction::Mul, ExtTy, CostKind); - ExtCost *= 2; - } + + InstructionCost MulCost = + thisT()->getArithmeticInstrCost(Instruction::Mul, ExtTy, CostKind); + ExtCost *= 2; return RedCost + MulCost + ExtCost; } diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp --- a/llvm/lib/Analysis/TargetTransformInfo.cpp +++ b/llvm/lib/Analysis/TargetTransformInfo.cpp @@ -976,11 +976,17 @@ return Cost; } -InstructionCost TargetTransformInfo::getExtendedAddReductionCost( - bool IsMLA, bool IsUnsigned, Type *ResTy, VectorType *Ty, +InstructionCost TargetTransformInfo::getExtendedReductionCost( + unsigned Opcode, bool IsUnsigned, Type *ResTy, VectorType *Ty, + Optional<FastMathFlags> FMF, TTI::TargetCostKind CostKind) const { + return TTIImpl->getExtendedReductionCost(Opcode, IsUnsigned, ResTy, Ty, FMF, + CostKind); +} + +InstructionCost TargetTransformInfo::getMulAccReductionCost( + bool IsUnsigned, Type *ResTy, VectorType *Ty, TTI::TargetCostKind CostKind) const { - return TTIImpl->getExtendedAddReductionCost(IsMLA, IsUnsigned, ResTy, Ty, - CostKind); + return TTIImpl->getMulAccReductionCost(IsUnsigned, ResTy, Ty, CostKind); } InstructionCost diff --git a/llvm/lib/Target/ARM/ARMTargetTransformInfo.h b/llvm/lib/Target/ARM/ARMTargetTransformInfo.h --- a/llvm/lib/Target/ARM/ARMTargetTransformInfo.h +++ b/llvm/lib/Target/ARM/ARMTargetTransformInfo.h @@ -275,9 +275,13 @@ InstructionCost getArithmeticReductionCost(unsigned Opcode, VectorType *ValTy, Optional<FastMathFlags> FMF, TTI::TargetCostKind CostKind); - InstructionCost getExtendedAddReductionCost(bool IsMLA, bool IsUnsigned, - Type *ResTy, VectorType *ValTy, - TTI::TargetCostKind CostKind); + InstructionCost getExtendedReductionCost(unsigned Opcode, bool IsUnsigned, + Type *ResTy, VectorType *ValTy, + Optional<FastMathFlags> FMF, + TTI::TargetCostKind CostKind); +
InstructionCost getMulAccReductionCost(bool IsUnsigned, Type *ResTy, + VectorType *ValTy, + TTI::TargetCostKind CostKind); InstructionCost getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA, TTI::TargetCostKind CostKind); diff --git a/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp b/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp --- a/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp +++ b/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp @@ -1677,10 +1677,46 @@ return BaseT::getArithmeticReductionCost(Opcode, ValTy, FMF, CostKind); } +InstructionCost ARMTTIImpl::getExtendedReductionCost( + unsigned Opcode, bool IsUnsigned, Type *ResTy, VectorType *ValTy, + Optional<FastMathFlags> FMF, TTI::TargetCostKind CostKind) { + EVT ValVT = TLI->getValueType(DL, ValTy); + EVT ResVT = TLI->getValueType(DL, ResTy); + + int ISD = TLI->InstructionOpcodeToISD(Opcode); + + switch (ISD) { + case ISD::ADD: + if (ST->hasMVEIntegerOps() && ValVT.isSimple() && ResVT.isSimple()) { + std::pair<InstructionCost, MVT> LT = + TLI->getTypeLegalizationCost(DL, ValTy); + + // The legal cases are: + // VADDV u/s 8/16/32 + // VADDLV u/s 32 + // Codegen currently cannot always handle larger than legal vectors very + // well, especially for predicated reductions where the mask needs to be + // split, so restrict to 128bit or smaller input types. + unsigned RevVTSize = ResVT.getSizeInBits(); + if (ValVT.getSizeInBits() <= 128 && + ((LT.second == MVT::v16i8 && RevVTSize <= 32) || + (LT.second == MVT::v8i16 && RevVTSize <= 32) || + (LT.second == MVT::v4i32 && RevVTSize <= 64))) + return ST->getMVEVectorCostFactor(CostKind) * LT.first; + } + break; + // TODO: Add costs for other reduction opcodes, for example FADD.
+ default: + break; + } + return BaseT::getExtendedReductionCost(Opcode, IsUnsigned, ResTy, ValTy, FMF, + CostKind); +} + InstructionCost -ARMTTIImpl::getExtendedAddReductionCost(bool IsMLA, bool IsUnsigned, - Type *ResTy, VectorType *ValTy, - TTI::TargetCostKind CostKind) { +ARMTTIImpl::getMulAccReductionCost(bool IsUnsigned, Type *ResTy, + VectorType *ValTy, + TTI::TargetCostKind CostKind) { EVT ValVT = TLI->getValueType(DL, ValTy); EVT ResVT = TLI->getValueType(DL, ResTy); @@ -1689,9 +1725,7 @@ TLI->getTypeLegalizationCost(DL, ValTy); // The legal cases are: - // VADDV u/s 8/16/32 // VMLAV u/s 8/16/32 - // VADDLV u/s 32 // VMLALV u/s 16/32 // Codegen currently cannot always handle larger than legal vectors very // well, especially for predicated reductions where the mask needs to be @@ -1699,13 +1733,12 @@ unsigned RevVTSize = ResVT.getSizeInBits(); if (ValVT.getSizeInBits() <= 128 && ((LT.second == MVT::v16i8 && RevVTSize <= 32) || - (LT.second == MVT::v8i16 && RevVTSize <= (IsMLA ?
64u : 32u)) || + (LT.second == MVT::v8i16 && RevVTSize <= 64) || (LT.second == MVT::v4i32 && RevVTSize <= 64))) return ST->getMVEVectorCostFactor(CostKind) * LT.first; } - return BaseT::getExtendedAddReductionCost(IsMLA, IsUnsigned, ResTy, ValTy, - CostKind); + return BaseT::getMulAccReductionCost(IsUnsigned, ResTy, ValTy, CostKind); } InstructionCost diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp --- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp +++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp @@ -6547,7 +6547,7 @@ VectorTy = VectorType::get(I->getOperand(0)->getType(), VectorTy); Instruction *Op0, *Op1; - if (RedOp && + if (RedOp && RdxDesc.getOpcode() == Instruction::Add && match(RedOp, m_ZExtOrSExt(m_Mul(m_Instruction(Op0), m_Instruction(Op1)))) && match(Op0, m_ZExtOrSExt(m_Value())) && @@ -6556,7 +6556,7 @@ !TheLoop->isLoopInvariant(Op0) && !TheLoop->isLoopInvariant(Op1) && (Op0->getOpcode() == RedOp->getOpcode() || Op0 == Op1)) { - // Matched reduce(ext(mul(ext(A), ext(B))) + // Matched reduce.add(ext(mul(ext(A), ext(B)))) // Note that the extend opcodes need to all match, or if A==B they will have // been converted to zext(mul(sext(A), sext(A))) as it is known positive, // which is equally fine.
@@ -6573,9 +6573,8 @@ TTI.getCastInstrCost(RedOp->getOpcode(), VectorTy, MulType, TTI::CastContextHint::None, CostKind, RedOp); - InstructionCost RedCost = TTI.getExtendedAddReductionCost( - /*IsMLA=*/true, IsUnsigned, RdxDesc.getRecurrenceType(), ExtType, - CostKind); + InstructionCost RedCost = TTI.getMulAccReductionCost( + IsUnsigned, RdxDesc.getRecurrenceType(), ExtType, CostKind); if (RedCost.isValid() && RedCost < ExtCost * 2 + MulCost + Ext2Cost + BaseCost) @@ -6585,16 +6584,16 @@ // Matched reduce(ext(A)) bool IsUnsigned = isa<ZExtInst>(RedOp); auto *ExtType = VectorType::get(RedOp->getOperand(0)->getType(), VectorTy); - InstructionCost RedCost = TTI.getExtendedAddReductionCost( - /*IsMLA=*/false, IsUnsigned, RdxDesc.getRecurrenceType(), ExtType, - CostKind); + InstructionCost RedCost = TTI.getExtendedReductionCost( + RdxDesc.getOpcode(), IsUnsigned, RdxDesc.getRecurrenceType(), ExtType, + RdxDesc.getFastMathFlags(), CostKind); InstructionCost ExtCost = TTI.getCastInstrCost(RedOp->getOpcode(), VectorTy, ExtType, TTI::CastContextHint::None, CostKind, RedOp); if (RedCost.isValid() && RedCost < BaseCost + ExtCost) return I == RetI ? RedCost : 0; - } else if (RedOp && + } else if (RedOp && RdxDesc.getOpcode() == Instruction::Add && match(RedOp, m_Mul(m_Instruction(Op0), m_Instruction(Op1)))) { if (match(Op0, m_ZExtOrSExt(m_Value())) && Op0->getOpcode() == Op1->getOpcode() && @@ -6607,7 +6606,7 @@ : Op0Ty; auto *ExtType = VectorType::get(LargestOpTy, VectorTy); - // Matched reduce(mul(ext(A), ext(B))), where the two ext may be of + // Matched reduce.add(mul(ext(A), ext(B))), where the two ext may be of // different sizes. We take the largest type as the ext to reduce, and add // the remaining cost as, for example reduce(mul(ext(ext(A)), ext(B))).
InstructionCost ExtCost0 = TTI.getCastInstrCost( @@ -6619,9 +6618,8 @@ InstructionCost MulCost = TTI.getArithmeticInstrCost(Instruction::Mul, VectorTy, CostKind); - InstructionCost RedCost = TTI.getExtendedAddReductionCost( - /*IsMLA=*/true, IsUnsigned, RdxDesc.getRecurrenceType(), ExtType, - CostKind); + InstructionCost RedCost = TTI.getMulAccReductionCost( + IsUnsigned, RdxDesc.getRecurrenceType(), ExtType, CostKind); InstructionCost ExtraExtCost = 0; if (Op0Ty != LargestOpTy || Op1Ty != LargestOpTy) { Instruction *ExtraExtOp = (Op0Ty != LargestOpTy) ? Op0 : Op1; @@ -6635,13 +6633,12 @@ (RedCost + ExtraExtCost) < (ExtCost0 + ExtCost1 + MulCost + BaseCost)) return I == RetI ? RedCost : 0; } else if (!match(I, m_ZExtOrSExt(m_Value()))) { - // Matched reduce(mul()) + // Matched reduce.add(mul()) InstructionCost MulCost = TTI.getArithmeticInstrCost(Instruction::Mul, VectorTy, CostKind); - InstructionCost RedCost = TTI.getExtendedAddReductionCost( - /*IsMLA=*/true, true, RdxDesc.getRecurrenceType(), VectorTy, - CostKind); + InstructionCost RedCost = TTI.getMulAccReductionCost( + true, RdxDesc.getRecurrenceType(), VectorTy, CostKind); if (RedCost.isValid() && RedCost < MulCost + BaseCost) return I == RetI ? RedCost : 0;