Index: llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
===================================================================
--- llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -620,9 +620,10 @@
   /// (StartIdx * Step, (StartIdx + 1) * Step, (StartIdx + 2) * Step, ...)
   /// to each vector element of Val. The sequence starts at StartIndex.
   /// \p Opcode is relevant for FP induction variable.
-  virtual Value *getStepVector(Value *Val, int StartIdx, Value *Step,
-                               Instruction::BinaryOps Opcode =
-                               Instruction::BinaryOpsEnd);
+  /// \p StartIdx must have the same scalar type as \p Step.
+  virtual Value *
+  getStepVector(Value *Val, Value *StartIdx, Value *Step,
+                Instruction::BinaryOps Opcode = Instruction::BinaryOpsEnd);
 
   /// Compute scalar induction steps. \p ScalarIV is the scalar induction
   /// variable on which to base the steps, \p Step is the size of the step, and
@@ -889,9 +890,9 @@
 
 private:
   Value *getBroadcastInstrs(Value *V) override;
-  Value *getStepVector(Value *Val, int StartIdx, Value *Step,
-                       Instruction::BinaryOps Opcode =
-                       Instruction::BinaryOpsEnd) override;
+  Value *getStepVector(
+      Value *Val, Value *StartIdx, Value *Step,
+      Instruction::BinaryOps Opcode = Instruction::BinaryOpsEnd) override;
   Value *reverseVector(Value *Vec) override;
 };
 
@@ -1119,6 +1120,15 @@
   return VF.isScalable() ? B.CreateVScale(EC) : EC;
 }
 
+/// Return the runtime value of \p VF (as computed by getRuntimeVF) converted
+/// to the floating-point type \p FTy, via an integer of the same bit width.
+static Value *getRuntimeVFAsFloat(IRBuilder<> &B, Type *FTy, ElementCount VF) {
+  assert(FTy->isFloatingPointTy() && "Expected floating point type!");
+  Type *IntTy = IntegerType::get(FTy->getContext(), FTy->getScalarSizeInBits());
+  Value *RuntimeVF = getRuntimeVF(B, IntTy, VF);
+  return B.CreateSIToFP(RuntimeVF, FTy);
+}
+
 void reportVectorizationFailure(const StringRef DebugMsg,
                                 const StringRef OREMsg, const StringRef ORETag,
                                 OptimizationRemarkEmitter *ORE, Loop *TheLoop,
@@ -2286,9 +2296,14 @@
     Step = Builder.CreateTrunc(Step, TruncType);
     Start = Builder.CreateCast(Instruction::Trunc, Start, TruncType);
   }
+
+  // SteppedStart applies i * Step to lane i of SplatStart, so the start index
+  // is zero; getNullValue gives that zero in Start's type (FP or integer).
+  Value *Zero = Constant::getNullValue(Start->getType());
+
   Value *SplatStart = Builder.CreateVectorSplat(VF, Start);
   Value *SteppedStart =
-      getStepVector(SplatStart, 0, Step, II.getInductionOpcode());
+      getStepVector(SplatStart, Zero, Step, II.getInductionOpcode());
 
   // We create vector phi nodes for both integer and floating-point induction
   // variables. Here, we determine the kind of arithmetic we will perform.
@@ -2305,12 +2320,11 @@
   // Multiply the vectorization factor by the step using integer or
   // floating-point arithmetic as appropriate.
   Type *StepType = Step->getType();
+  Value *RuntimeVF;
   if (Step->getType()->isFloatingPointTy())
-    StepType = IntegerType::get(StepType->getContext(),
-                                StepType->getScalarSizeInBits());
-  Value *RuntimeVF = getRuntimeVF(Builder, StepType, VF);
-  if (Step->getType()->isFloatingPointTy())
-    RuntimeVF = Builder.CreateSIToFP(RuntimeVF, Step->getType());
+    RuntimeVF = getRuntimeVFAsFloat(Builder, StepType, VF);
+  else
+    RuntimeVF = getRuntimeVF(Builder, StepType, VF);
   Value *Mul = Builder.CreateBinOp(MulOp, Step, RuntimeVF);
 
   // Create a vector splat to use in the induction update.
@@ -2459,9 +2473,14 @@
     Value *Broadcasted = getBroadcastInstrs(ScalarIV);
     for (unsigned Part = 0; Part < UF; ++Part) {
       assert(!VF.isScalable() && "scalable vectors not yet supported.");
+      Value *StartIdx;
+      if (Step->getType()->isFloatingPointTy())
+        StartIdx = getRuntimeVFAsFloat(Builder, Step->getType(), VF * Part);
+      else
+        StartIdx = getRuntimeVF(Builder, Step->getType(), VF * Part);
+
       Value *EntryPart =
-          getStepVector(Broadcasted, VF.getKnownMinValue() * Part, Step,
-                        ID.getInductionOpcode());
+          getStepVector(Broadcasted, StartIdx, Step, ID.getInductionOpcode());
       State.set(Def, EntryPart, Part);
       if (Trunc)
         addMetadata(EntryPart, Trunc);
@@ -2517,7 +2536,8 @@
   buildScalarSteps(ScalarIV, Step, EntryVal, ID, Def, CastDef, State);
 }
 
-Value *InnerLoopVectorizer::getStepVector(Value *Val, int StartIdx, Value *Step,
+Value *InnerLoopVectorizer::getStepVector(Value *Val, Value *StartIdx,
+                                          Value *Step,
                                           Instruction::BinaryOps BinOp) {
   // Create and check the types.
   auto *ValVTy = cast<VectorType>(Val->getType());
@@ -2540,12 +2560,11 @@
   }
   Value *InitVec = Builder.CreateStepVector(InitVecValVTy);
 
-  // Add on StartIdx
-  Value *StartIdxSplat = Builder.CreateVectorSplat(
-      VLen, ConstantInt::get(InitVecValSTy, StartIdx));
-  InitVec = Builder.CreateAdd(InitVec, StartIdxSplat);
+  // Splat the StartIdx
+  Value *StartIdxSplat = Builder.CreateVectorSplat(VLen, StartIdx);
 
   if (STy->isIntegerTy()) {
+    InitVec = Builder.CreateAdd(InitVec, StartIdxSplat);
     Step = Builder.CreateVectorSplat(VLen, Step);
     assert(Step->getType() == Val->getType() && "Invalid step vec");
     // FIXME: The newly created binary instructions should contain nsw/nuw flags,
@@ -2558,6 +2577,8 @@
   assert((BinOp == Instruction::FAdd || BinOp == Instruction::FSub) &&
          "Binary Opcode should be specified for FP induction");
   InitVec = Builder.CreateUIToFP(InitVec, ValVTy);
+  InitVec = Builder.CreateFAdd(InitVec, StartIdxSplat);
+
   Step = Builder.CreateVectorSplat(VLen, Step);
   Value *MulOp = Builder.CreateFMul(InitVec, Step);
   return Builder.CreateBinOp(BinOp, Val, MulOp, "induction");
@@ -8306,21 +8327,19 @@
 
 Value *InnerLoopUnroller::getBroadcastInstrs(Value *V) { return V; }
 
-Value *InnerLoopUnroller::getStepVector(Value *Val, int StartIdx, Value *Step,
+Value *InnerLoopUnroller::getStepVector(Value *Val, Value *StartIdx,
+                                        Value *Step,
                                         Instruction::BinaryOps BinOp) {
   // When unrolling and the VF is 1, we only need to add a simple scalar.
   Type *Ty = Val->getType();
   assert(!Ty->isVectorTy() && "Val must be a scalar");
 
   if (Ty->isFloatingPointTy()) {
-    Constant *C = ConstantFP::get(Ty, (double)StartIdx);
-
     // Floating-point operations inherit FMF via the builder's flags.
-    Value *MulOp = Builder.CreateFMul(C, Step);
+    Value *MulOp = Builder.CreateFMul(StartIdx, Step);
     return Builder.CreateBinOp(BinOp, Val, MulOp);
   }
-  Constant *C = ConstantInt::get(Ty, StartIdx);
-  return Builder.CreateAdd(Val, Builder.CreateMul(C, Step), "induction");
+  return Builder.CreateAdd(Val, Builder.CreateMul(StartIdx, Step), "induction");
 }
 
 static void AddRuntimeUnrollDisableMetaData(Loop *L) {