[SLP] Discount scalar fadd/fsub that will be fused into an FMA.

An fadd/fsub whose operand is a single-use fmul, where both instructions
carry the `contract` fast-math flag, is expected to be selected as one FMA
on targets that report FMA as faster than fmul + fadd. Charge such a scalar
add as a single TTI::TCC_Basic operation in the SLP cost model.

Adds TargetTransformInfo::isFMAFasterThanFMulAndFAdd (conservative default:
false); the AArch64 implementation forwards it to TargetLowering.

NOTE(review): the context lines below were reconstructed from a
whitespace-damaged copy of this patch; regenerate it with `git diff` before
landing.

Index: llvm/include/llvm/Analysis/TargetTransformInfo.h
===================================================================
--- llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -1539,6 +1539,10 @@
         : EVLParamStrategy(EVLParamStrategy), OpStrategy(OpStrategy) {}
   };
 
+  /// \returns True if, for values of type \p Ty in function \p F, a fused
+  /// multiply-add is faster on this target than an fmul followed by an fadd.
+  bool isFMAFasterThanFMulAndFAdd(const Function &F, Type *Ty) const;
+
   /// \returns How the target needs this vector-predicated operation to be
   /// transformed.
   VPLegalization getVPLegalizationStrategy(const VPIntrinsic &PI) const;
@@ -1895,6 +1899,8 @@
   virtual bool supportsScalableVectors() const = 0;
   virtual bool hasActiveVectorLength(unsigned Opcode, Type *DataType,
                                      Align Alignment) const = 0;
+  virtual bool isFMAFasterThanFMulAndFAdd(const Function &F,
+                                          Type *Ty) const = 0;
   virtual VPLegalization
   getVPLegalizationStrategy(const VPIntrinsic &PI) const = 0;
 };
@@ -2562,6 +2568,11 @@
     return Impl.hasActiveVectorLength(Opcode, DataType, Alignment);
   }
 
+  bool isFMAFasterThanFMulAndFAdd(const Function &F,
+                                  Type *Ty) const override {
+    return Impl.isFMAFasterThanFMulAndFAdd(F, Ty);
+  }
+
   VPLegalization
   getVPLegalizationStrategy(const VPIntrinsic &PI) const override {
     return Impl.getVPLegalizationStrategy(PI);
Index: llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
===================================================================
--- llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -847,6 +847,10 @@
     return false;
   }
 
+  bool isFMAFasterThanFMulAndFAdd(const Function &F, Type *Ty) const {
+    return false;
+  }
+
   TargetTransformInfo::VPLegalization
   getVPLegalizationStrategy(const VPIntrinsic &PI) const {
     return TargetTransformInfo::VPLegalization(
Index: llvm/lib/Analysis/TargetTransformInfo.cpp
===================================================================
--- llvm/lib/Analysis/TargetTransformInfo.cpp
+++ llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -1180,6 +1180,11 @@
   return TTIImpl->hasActiveVectorLength(Opcode, DataType, Alignment);
 }
 
+bool TargetTransformInfo::isFMAFasterThanFMulAndFAdd(const Function &F,
+                                                     Type *Ty) const {
+  return TTIImpl->isFMAFasterThanFMulAndFAdd(F, Ty);
+}
+
 TargetTransformInfo::Concept::~Concept() = default;
 
 TargetIRAnalysis::TargetIRAnalysis() : TTICallback(&getDefaultTTI) {}
Index: llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
===================================================================
--- llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -368,6 +368,10 @@
     return ST->hasSVE();
   }
 
+  bool isFMAFasterThanFMulAndFAdd(const Function &F, Type *Ty) const {
+    return TLI->isFMAFasterThanFMulAndFAdd(F, Ty);
+  }
+
   InstructionCost getArithmeticReductionCost(unsigned Opcode, VectorType *Ty,
                                              Optional<FastMathFlags> FMF,
                                              TTI::TargetCostKind CostKind);
Index: llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
===================================================================
--- llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -6857,6 +6857,23 @@
       TTI::OperandValueInfo Op2Info =
           TTI::getOperandInfo(VI->getOperand(OpIdx));
       SmallVector<const Value *> Operands(VI->operand_values());
+      // An fadd/fsub whose operand is a single-use fmul is expected to be
+      // fused into one FMA when both may be contracted and the target reports
+      // FMA as faster than fmul + fadd, so charge it as one basic operation.
+      // NOTE(review): the vector cost is not adjusted the same way even
+      // though vector FMAs exist; confirm this asymmetry is intended.
+      if ((VI->getOpcode() == Instruction::FAdd ||
+           VI->getOpcode() == Instruction::FSub) &&
+          VI->hasOneUse() && VI->hasAllowContract() &&
+          any_of(VI->operands(), [](const Use &U) {
+            auto *Mul = dyn_cast<Instruction>(U.get());
+            return Mul && Mul->getOpcode() == Instruction::FMul &&
+                   Mul->hasOneUse() && Mul->hasAllowContract();
+          }) &&
+          TTI->isFMAFasterThanFMulAndFAdd(*VI->getFunction(), VI->getType())) {
+        LLVM_DEBUG(dbgs() << "SLP: Assuming FMA fusion for " << *VI << "\n");
+        return InstructionCost(TTI::TCC_Basic);
+      }
       return TTI->getArithmeticInstrCost(Opcode, ScalarTy, CostKind, Op1Info,
                                          Op2Info, Operands, VI);
     };
@@ -6980,6 +6997,21 @@
       auto *VI = cast<Instruction>(VL[Idx]);
       assert(E->isOpcodeOrAlt(VI) && "Unexpected main/alternate opcode");
       (void)E;
+      // Apply the same FMA-fusion discount as for plain binary operators: a
+      // contractable fadd/fsub of a single-use, contractable fmul is expected
+      // to become one FMA on targets where that is faster.
+      if ((VI->getOpcode() == Instruction::FAdd ||
+           VI->getOpcode() == Instruction::FSub) &&
+          VI->hasOneUse() && VI->hasAllowContract() &&
+          any_of(VI->operands(), [](const Use &U) {
+            auto *Mul = dyn_cast<Instruction>(U.get());
+            return Mul && Mul->getOpcode() == Instruction::FMul &&
+                   Mul->hasOneUse() && Mul->hasAllowContract();
+          }) &&
+          TTI->isFMAFasterThanFMulAndFAdd(*VI->getFunction(), VI->getType())) {
+        LLVM_DEBUG(dbgs() << "SLP: Assuming FMA fusion for " << *VI << "\n");
+        return InstructionCost(TTI::TCC_Basic);
+      }
       return TTI->getInstructionCost(VI, CostKind);
     };
     // Need to clear CommonCost since the final shuffle cost is included into