diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -609,10 +609,6 @@
   /// the block that was created for it.
   void sinkScalarOperands(Instruction *PredInst);
 
-  /// Shrinks vector element sizes to the smallest bitwidth they can be legally
-  /// represented as.
-  void truncateToMinimalBitwidths(VPTransformState &State);
-
   /// Returns (and creates if needed) the trip count of the widened loop.
   Value *getOrCreateVectorTripCount(BasicBlock *InsertBlock);
 
@@ -3554,146 +3550,8 @@
   return I1->getBitWidth() > I2->getBitWidth() ? T1 : T2;
 }
 
-void InnerLoopVectorizer::truncateToMinimalBitwidths(VPTransformState &State) {
-  // For every instruction `I` in MinBWs, truncate the operands, create a
-  // truncated version of `I` and reextend its result. InstCombine runs
-  // later and will remove any ext/trunc pairs.
-  SmallPtrSet<Value *, 4> Erased;
-  for (const auto &KV : Cost->getMinimalBitwidths()) {
-    // If the value wasn't vectorized, we must maintain the original scalar
-    // type. The absence of the value from State indicates that it
-    // wasn't vectorized.
-    // FIXME: Should not rely on getVPValue at this point.
-    VPValue *Def = State.Plan->getVPValue(KV.first, true);
-    if (!State.hasAnyVectorValue(Def))
-      continue;
-    for (unsigned Part = 0; Part < UF; ++Part) {
-      Value *I = State.get(Def, Part);
-      if (Erased.count(I) || I->use_empty() || !isa<Instruction>(I))
-        continue;
-      Type *OriginalTy = I->getType();
-      Type *ScalarTruncatedTy =
-          IntegerType::get(OriginalTy->getContext(), KV.second);
-      auto *TruncatedTy = VectorType::get(
-          ScalarTruncatedTy, cast<VectorType>(OriginalTy)->getElementCount());
-      if (TruncatedTy == OriginalTy)
-        continue;
-
-      IRBuilder<> B(cast<Instruction>(I));
-      auto ShrinkOperand = [&](Value *V) -> Value * {
-        if (auto *ZI = dyn_cast<ZExtInst>(V))
-          if (ZI->getSrcTy() == TruncatedTy)
-            return ZI->getOperand(0);
-        return B.CreateZExtOrTrunc(V, TruncatedTy);
-      };
-
-      // The actual instruction modification depends on the instruction type,
-      // unfortunately.
-      Value *NewI = nullptr;
-      if (auto *BO = dyn_cast<BinaryOperator>(I)) {
-        NewI = B.CreateBinOp(BO->getOpcode(), ShrinkOperand(BO->getOperand(0)),
-                             ShrinkOperand(BO->getOperand(1)));
-
-        // Any wrapping introduced by shrinking this operation shouldn't be
-        // considered undefined behavior. So, we can't unconditionally copy
-        // arithmetic wrapping flags to NewI.
-        cast<BinaryOperator>(NewI)->copyIRFlags(I, /*IncludeWrapFlags=*/false);
-      } else if (auto *CI = dyn_cast<ICmpInst>(I)) {
-        NewI =
-            B.CreateICmp(CI->getPredicate(), ShrinkOperand(CI->getOperand(0)),
-                         ShrinkOperand(CI->getOperand(1)));
-      } else if (auto *SI = dyn_cast<SelectInst>(I)) {
-        NewI = B.CreateSelect(SI->getCondition(),
-                              ShrinkOperand(SI->getTrueValue()),
-                              ShrinkOperand(SI->getFalseValue()));
-      } else if (auto *CI = dyn_cast<CastInst>(I)) {
-        switch (CI->getOpcode()) {
-        default:
-          llvm_unreachable("Unhandled cast!");
-        case Instruction::Trunc:
-          NewI = ShrinkOperand(CI->getOperand(0));
-          break;
-        case Instruction::SExt:
-          NewI = B.CreateSExtOrTrunc(
-              CI->getOperand(0),
-              smallestIntegerVectorType(OriginalTy, TruncatedTy));
-          break;
-        case Instruction::ZExt:
-          NewI = B.CreateZExtOrTrunc(
-              CI->getOperand(0),
-              smallestIntegerVectorType(OriginalTy, TruncatedTy));
-          break;
-        }
-      } else if (auto *SI = dyn_cast<ShuffleVectorInst>(I)) {
-        auto Elements0 =
-            cast<VectorType>(SI->getOperand(0)->getType())->getElementCount();
-        auto *O0 = B.CreateZExtOrTrunc(
-            SI->getOperand(0), VectorType::get(ScalarTruncatedTy, Elements0));
-        auto Elements1 =
-            cast<VectorType>(SI->getOperand(1)->getType())->getElementCount();
-        auto *O1 = B.CreateZExtOrTrunc(
-            SI->getOperand(1), VectorType::get(ScalarTruncatedTy, Elements1));
-
-        NewI = B.CreateShuffleVector(O0, O1, SI->getShuffleMask());
-      } else if (isa<LoadInst>(I) || isa<PHINode>(I)) {
-        // Don't do anything with the operands, just extend the result.
-        continue;
-      } else if (auto *IE = dyn_cast<InsertElementInst>(I)) {
-        auto Elements =
-            cast<VectorType>(IE->getOperand(0)->getType())->getElementCount();
-        auto *O0 = B.CreateZExtOrTrunc(
-            IE->getOperand(0), VectorType::get(ScalarTruncatedTy, Elements));
-        auto *O1 = B.CreateZExtOrTrunc(IE->getOperand(1), ScalarTruncatedTy);
-        NewI = B.CreateInsertElement(O0, O1, IE->getOperand(2));
-      } else if (auto *EE = dyn_cast<ExtractElementInst>(I)) {
-        auto Elements =
-            cast<VectorType>(EE->getOperand(0)->getType())->getElementCount();
-        auto *O0 = B.CreateZExtOrTrunc(
-            EE->getOperand(0), VectorType::get(ScalarTruncatedTy, Elements));
-        NewI = B.CreateExtractElement(O0, EE->getOperand(2));
-      } else {
-        // If we don't know what to do, be conservative and don't do anything.
-        continue;
-      }
-
-      // Lastly, extend the result.
-      NewI->takeName(cast<Instruction>(I));
-      Value *Res = B.CreateZExtOrTrunc(NewI, OriginalTy);
-      I->replaceAllUsesWith(Res);
-      cast<Instruction>(I)->eraseFromParent();
-      Erased.insert(I);
-      State.reset(Def, Res, Part);
-    }
-  }
-
-  // We'll have created a bunch of ZExts that are now parentless. Clean up.
-  for (const auto &KV : Cost->getMinimalBitwidths()) {
-    // If the value wasn't vectorized, we must maintain the original scalar
-    // type. The absence of the value from State indicates that it
-    // wasn't vectorized.
-    // FIXME: Should not rely on getVPValue at this point.
-    VPValue *Def = State.Plan->getVPValue(KV.first, true);
-    if (!State.hasAnyVectorValue(Def))
-      continue;
-    for (unsigned Part = 0; Part < UF; ++Part) {
-      Value *I = State.get(Def, Part);
-      ZExtInst *Inst = dyn_cast<ZExtInst>(I);
-      if (Inst && Inst->use_empty()) {
-        Value *NewI = Inst->getOperand(0);
-        Inst->eraseFromParent();
-        State.reset(Def, NewI, Part);
-      }
-    }
-  }
-}
-
 void InnerLoopVectorizer::fixVectorizedLoop(VPTransformState &State,
                                             VPlan &Plan) {
-  // Insert truncates and extends for any truncated instructions as hints to
-  // InstCombine.
-  if (VF.isVector())
-    truncateToMinimalBitwidths(State);
-
   // Fix widened non-induction PHIs by setting up the PHI operands.
   if (EnableVPlanNativePath)
     fixNonInductionPHIs(Plan, State);
 
@@ -9019,6 +8877,7 @@
 
   VPlanTransforms::removeRedundantCanonicalIVs(*Plan);
   VPlanTransforms::removeRedundantInductionCasts(*Plan);
+  VPlanTransforms::truncateToMinimalBitwidths(*Plan, CM.getMinimalBitwidths());
 
   // Adjust the recipes for any inloop reductions.
   adjustRecipesForReductions(cast<VPBasicBlock>(TopRegion->getExiting()), Plan,
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.h b/llvm/lib/Transforms/Vectorize/VPlanTransforms.h
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.h
@@ -86,6 +86,11 @@
   static void optimizeForVFAndUF(VPlan &Plan, ElementCount BestVF,
                                  unsigned BestUF,
                                  PredicatedScalarEvolution &PSE);
+
+  /// Insert truncates and extends for any truncated instructions as hints to
+  /// InstCombine.
+  static void
+  truncateToMinimalBitwidths(VPlan &Plan,
+                             const MapVector<Instruction *, uint64_t> &MinBWs);
 };
 
 } // namespace llvm
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -754,3 +754,82 @@
   }
   return true;
 }
+
+void VPlanTransforms::truncateToMinimalBitwidths(
+    VPlan &Plan, const MapVector<Instruction *, uint64_t> &MinBWs) {
+  auto GetType = [](VPValue *Op) {
+    auto *UV = Op->getUnderlyingValue();
+    if (UV)
+      return UV->getType();
+    if (auto *VPC = dyn_cast<VPWidenCastRecipe>(Op))
+      return VPC->getResultType();
+    llvm_unreachable("trying to get type of a VPValue without type info");
+  };
+  for (VPBasicBlock *VPBB : VPBlockUtils::blocksOnly<VPBasicBlock>(
+           vp_depth_first_deep(Plan.getEntry()))) {
+    for (VPRecipeBase &R : make_early_inc_range(*VPBB)) {
+      if (R.getNumDefinedValues() != 1)
+        continue;
+
+      auto *UV = cast_or_null<Instruction>(
+          R.getVPSingleValue()->getUnderlyingValue());
+      auto I = MinBWs.find(UV);
+      if (!UV || I == MinBWs.end())
+        continue;
+
+      Type *ResTy = UV->getType();
+      if (!ResTy->isIntegerTy() || ResTy->getScalarSizeInBits() == I->second)
+        continue;
+
+      if (!isa<VPWidenRecipe, VPWidenCastRecipe, VPWidenSelectRecipe>(&R))
+        continue;
+
+      LLVMContext &Ctx = ResTy->getContext();
+
+      // Try to replace wider SExt/ZExts with narrower ones if possible.
+      if (auto *VPW = dyn_cast<VPWidenCastRecipe>(&R)) {
+        Instruction *UI = VPW->getUnderlyingInstr();
+        switch (UI->getOpcode()) {
+        default:
+          break;
+        case Instruction::SExt:
+        case Instruction::ZExt: {
+          if (UI->getType()->getScalarSizeInBits() > I->second) {
+            if (GetType(VPW->getOperand(0))->getScalarSizeInBits() >=
+                I->second)
+              break;
+            auto *C = new VPWidenCastRecipe(cast<CastInst>(UI)->getOpcode(),
+                                            VPW->getOperand(0),
+                                            IntegerType::get(Ctx, I->second));
+            C->insertBefore(VPW);
+            VPW->replaceAllUsesWith(C);
+            continue;
+          }
+          break;
+        }
+        }
+      }
+
+      // Shrink operands by introducing truncates as needed.
+      for (unsigned Idx = 0; Idx != R.getNumOperands(); ++Idx) {
+        auto *Op = R.getOperand(Idx);
+        if (GetType(Op)->getScalarSizeInBits() == I->second)
+          continue;
+        if (auto *VPW = dyn_cast<VPRecipeWithIRFlags>(&R))
+          VPW->dropPoisonGeneratingFlags();
+
+        auto *Shrunk =
+            new VPWidenCastRecipe(Instruction::Trunc, R.getOperand(Idx),
+                                  IntegerType::get(Ctx, I->second));
+        R.setOperand(Idx, Shrunk);
+        Shrunk->insertBefore(&R);
+      }
+
+      // Extend result to original width.
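+      // All users of the original recipe still expect the wide type, so
+      // zero-extend the narrowed result back to it; InstCombine runs later
+      // and will remove any redundant ext/trunc pairs this introduces.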
+      auto *Ext = new VPWidenCastRecipe(Instruction::ZExt,
+                                        R.getVPSingleValue(), ResTy);
+      R.getVPSingleValue()->replaceAllUsesWith(Ext);
+      Ext->setOperand(0, R.getVPSingleValue());
+      Ext->insertAfter(&R);
+    }
+  }
+}
diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/deterministic-type-shrinkage.ll b/llvm/test/Transforms/LoopVectorize/AArch64/deterministic-type-shrinkage.ll
--- a/llvm/test/Transforms/LoopVectorize/AArch64/deterministic-type-shrinkage.ll
+++ b/llvm/test/Transforms/LoopVectorize/AArch64/deterministic-type-shrinkage.ll
@@ -9,11 +9,11 @@
 ;; CHECK-LABEL: @test(
 ; CHECK: load <16 x i8>
+; CHECK-NEXT: zext <16 x i8>
 ; CHECK-NEXT: getelementptr
 ; CHECK-NEXT: bitcast
 ; CHECK-NEXT: load <16 x i8>
 ; CHECK-NEXT: zext <16 x i8>
-; CHECK-NEXT: zext <16 x i8>
 define void @test(i32 %n, i8* nocapture %a, i8* nocapture %b, i8* nocapture readonly %c) {
 entry:
   %cmp.28 = icmp eq i32 %n, 0
diff --git a/llvm/test/Transforms/LoopVectorize/ARM/pointer_iv.ll b/llvm/test/Transforms/LoopVectorize/ARM/pointer_iv.ll
--- a/llvm/test/Transforms/LoopVectorize/ARM/pointer_iv.ll
+++ b/llvm/test/Transforms/LoopVectorize/ARM/pointer_iv.ll
@@ -62,8 +62,8 @@
 ; CHECK-NEXT:    [[TMP2:%.*]] = add nsw <4 x i32> [[STRIDED_VEC]], [[BROADCAST_SPLAT]]
 ; CHECK-NEXT:    store <4 x i32> [[TMP2]], ptr [[NEXT_GEP4]], align 4
 ; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 4
-; CHECK-NEXT:    [[TMP4:%.*]] = icmp eq i32 [[INDEX_NEXT]], 996
-; CHECK-NEXT:    br i1 [[TMP4]], label [[FOR_BODY:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP3:![0-9]+]]
+; CHECK-NEXT:    [[TMP3:%.*]] = icmp eq i32 [[INDEX_NEXT]], 996
+; CHECK-NEXT:    br i1 [[TMP3]], label [[FOR_BODY:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP3:![0-9]+]]
 ; CHECK:       for.body:
 ; CHECK-NEXT:    [[A_ADDR_09:%.*]] = phi ptr [ [[ADD_PTR:%.*]], [[FOR_BODY]] ], [ [[IND_END]], [[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[I_08:%.*]] = phi i32 [ [[INC:%.*]], [[FOR_BODY]] ], [ 996, [[VECTOR_BODY]] ]
@@ -881,10 +881,10 @@
 define hidden void @mult_ptr_iv(ptr noalias nocapture readonly %x, ptr noalias nocapture %z) {
; CHECK-LABEL: @mult_ptr_iv(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[UGLYGEP:%.*]] = getelementptr i8, ptr [[Z:%.*]], i32 3000
-; CHECK-NEXT:    [[UGLYGEP1:%.*]] = getelementptr i8, ptr [[X:%.*]], i32 3000
-; CHECK-NEXT:    [[BOUND0:%.*]] = icmp ugt ptr [[UGLYGEP1]], [[Z]]
-; CHECK-NEXT:    [[BOUND1:%.*]] = icmp ugt ptr [[UGLYGEP]], [[X]]
+; CHECK-NEXT:    [[SCEVGEP:%.*]] = getelementptr i8, ptr [[Z:%.*]], i32 3000
+; CHECK-NEXT:    [[SCEVGEP1:%.*]] = getelementptr i8, ptr [[X:%.*]], i32 3000
+; CHECK-NEXT:    [[BOUND0:%.*]] = icmp ugt ptr [[SCEVGEP1]], [[Z]]
+; CHECK-NEXT:    [[BOUND1:%.*]] = icmp ugt ptr [[SCEVGEP]], [[X]]
 ; CHECK-NEXT:    [[FOUND_CONFLICT:%.*]] = and i1 [[BOUND0]], [[BOUND1]]
 ; CHECK-NEXT:    br i1 [[FOUND_CONFLICT]], label [[FOR_BODY:%.*]], label [[VECTOR_PH:%.*]]
 ; CHECK:       vector.ph:
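
Illustrative sketch (hypothetical IR, not taken verbatim from this patch or
its tests): for a loop that zero-extends two i8 inputs to i32, adds them, and
truncates the sum back to i8 for the store, MinBWs records that 8 bits
suffice, so the transform brackets the narrowed add with trunc/zext hints
roughly like

  %ta  = trunc <16 x i32> %ext.a to <16 x i8>
  %tb  = trunc <16 x i32> %ext.b to <16 x i8>
  %sum = add <16 x i8> %ta, %tb
  %res = zext <16 x i8> %sum to <16 x i32>

after which InstCombine folds the trunc/zext pairs with the feeding extends,
leaving a single <16 x i8> add (all value names above are hypothetical).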