// Review notes for this patch. These lines sit before the first `diff --git`
// header, so they are outside every hunk.
//
// What the patch does, as visible in the hunks below:
//  * LoopVectorize.cpp: deletes the declaration and the definition of
//    InnerLoopVectorizer::truncateToMinimalBitwidths(VPTransformState &) and
//    deletes its call from fixVectorizedLoop(), where it ran only when
//    VF.isVector(). Adds a call to
//    VPlanTransforms::truncateToMinimalBitwidths(*Plan, CM.getMinimalBitwidths())
//    directly after VPlanTransforms::removeRedundantInductionCasts(*Plan).
//  * VPlanTransforms.h: declares the new static member
//    truncateToMinimalBitwidths(VPlan &Plan, const MapVector &MinBWs).
//  * VPlanTransforms.cpp: defines it. It visits every VPBasicBlock reached by
//    vp_depth_first_deep(Plan.getEntry()) and, for each recipe that passes the
//    isa check, looks up the recipe's underlying Instruction in MinBWs. The
//    recipe is skipped when there is no underlying instruction, no MinBWs
//    entry, the instruction's type is not an integer type, or the current
//    width already equals the recorded width. Otherwise:
//      - An SExt/ZExt cast recipe whose operand is narrower than the new
//        width gets a new VPWidenCastRecipe (same opcode, new narrower type)
//        inserted before it, and all uses of the old result are redirected to
//        the new cast. The old recipe is not erased at this point.
//      - In every other case each operand whose width differs from the new
//        width is asserted to be wider and is replaced by a Trunc
//        VPWidenCastRecipe inserted before the recipe; poison-generating
//        flags are dropped when the recipe is a VPWidenRecipe. A ZExt back to
//        the original type is then inserted after the recipe and takes over
//        all uses of the result.
//      - Ext->setOperand(0, ResultVPV) follows replaceAllUsesWith(Ext);
//        presumably RAUW also redirected Ext's own operand to Ext and this
//        restores it -- TODO confirm against VPValue::replaceAllUsesWith.
//  * deterministic-type-shrinkage.ll: one `zext <16 x i8>` CHECK-NEXT line
//    moves from the end of the checked sequence to directly after the first
//    `load <16 x i8>`.
//
// NOTE(review): this text looks like the result of a lossy extraction. The
// original line breaks are collapsed, and template argument lists appear to
// be missing (for example `isa(I)`, `cast(OriginalTy)`, `SmallPtrSet Erased;`,
// `const MapVector &MinBWs`). Presumably it will neither apply nor compile
// as-is; regenerate it from the original revision before use.
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp --- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp +++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp @@ -614,10 +614,6 @@ /// the block that was created for it. void sinkScalarOperands(Instruction *PredInst); - /// Shrinks vector element sizes to the smallest bitwidth they can be legally - /// represented as. - void truncateToMinimalBitwidths(VPTransformState &State); - /// Returns (and creates if needed) the trip count of the widened loop. Value *getOrCreateVectorTripCount(BasicBlock *InsertBlock); @@ -3524,146 +3520,8 @@ return I1->getBitWidth() > I2->getBitWidth() ? T1 : T2; } -void InnerLoopVectorizer::truncateToMinimalBitwidths(VPTransformState &State) { - // For every instruction `I` in MinBWs, truncate the operands, create a - // truncated version of `I` and reextend its result. InstCombine runs - // later and will remove any ext/trunc pairs. - SmallPtrSet Erased; - for (const auto &KV : Cost->getMinimalBitwidths()) { - // If the value wasn't vectorized, we must maintain the original scalar - // type. The absence of the value from State indicates that it - // wasn't vectorized. - // FIXME: Should not rely on getVPValue at this point. 
- VPValue *Def = State.Plan->getVPValue(KV.first, true); - if (!State.hasAnyVectorValue(Def)) - continue; - for (unsigned Part = 0; Part < UF; ++Part) { - Value *I = State.get(Def, Part); - if (Erased.count(I) || I->use_empty() || !isa(I)) - continue; - Type *OriginalTy = I->getType(); - Type *ScalarTruncatedTy = - IntegerType::get(OriginalTy->getContext(), KV.second); - auto *TruncatedTy = VectorType::get( - ScalarTruncatedTy, cast(OriginalTy)->getElementCount()); - if (TruncatedTy == OriginalTy) - continue; - - IRBuilder<> B(cast(I)); - auto ShrinkOperand = [&](Value *V) -> Value * { - if (auto *ZI = dyn_cast(V)) - if (ZI->getSrcTy() == TruncatedTy) - return ZI->getOperand(0); - return B.CreateZExtOrTrunc(V, TruncatedTy); - }; - - // The actual instruction modification depends on the instruction type, - // unfortunately. - Value *NewI = nullptr; - if (auto *BO = dyn_cast(I)) { - NewI = B.CreateBinOp(BO->getOpcode(), ShrinkOperand(BO->getOperand(0)), - ShrinkOperand(BO->getOperand(1))); - - // Any wrapping introduced by shrinking this operation shouldn't be - // considered undefined behavior. So, we can't unconditionally copy - // arithmetic wrapping flags to NewI. 
- cast(NewI)->copyIRFlags(I, /*IncludeWrapFlags=*/false); - } else if (auto *CI = dyn_cast(I)) { - NewI = - B.CreateICmp(CI->getPredicate(), ShrinkOperand(CI->getOperand(0)), - ShrinkOperand(CI->getOperand(1))); - } else if (auto *SI = dyn_cast(I)) { - NewI = B.CreateSelect(SI->getCondition(), - ShrinkOperand(SI->getTrueValue()), - ShrinkOperand(SI->getFalseValue())); - } else if (auto *CI = dyn_cast(I)) { - switch (CI->getOpcode()) { - default: - llvm_unreachable("Unhandled cast!"); - case Instruction::Trunc: - NewI = ShrinkOperand(CI->getOperand(0)); - break; - case Instruction::SExt: - NewI = B.CreateSExtOrTrunc( - CI->getOperand(0), - smallestIntegerVectorType(OriginalTy, TruncatedTy)); - break; - case Instruction::ZExt: - NewI = B.CreateZExtOrTrunc( - CI->getOperand(0), - smallestIntegerVectorType(OriginalTy, TruncatedTy)); - break; - } - } else if (auto *SI = dyn_cast(I)) { - auto Elements0 = - cast(SI->getOperand(0)->getType())->getElementCount(); - auto *O0 = B.CreateZExtOrTrunc( - SI->getOperand(0), VectorType::get(ScalarTruncatedTy, Elements0)); - auto Elements1 = - cast(SI->getOperand(1)->getType())->getElementCount(); - auto *O1 = B.CreateZExtOrTrunc( - SI->getOperand(1), VectorType::get(ScalarTruncatedTy, Elements1)); - - NewI = B.CreateShuffleVector(O0, O1, SI->getShuffleMask()); - } else if (isa(I) || isa(I)) { - // Don't do anything with the operands, just extend the result. 
- continue; - } else if (auto *IE = dyn_cast(I)) { - auto Elements = - cast(IE->getOperand(0)->getType())->getElementCount(); - auto *O0 = B.CreateZExtOrTrunc( - IE->getOperand(0), VectorType::get(ScalarTruncatedTy, Elements)); - auto *O1 = B.CreateZExtOrTrunc(IE->getOperand(1), ScalarTruncatedTy); - NewI = B.CreateInsertElement(O0, O1, IE->getOperand(2)); - } else if (auto *EE = dyn_cast(I)) { - auto Elements = - cast(EE->getOperand(0)->getType())->getElementCount(); - auto *O0 = B.CreateZExtOrTrunc( - EE->getOperand(0), VectorType::get(ScalarTruncatedTy, Elements)); - NewI = B.CreateExtractElement(O0, EE->getOperand(2)); - } else { - // If we don't know what to do, be conservative and don't do anything. - continue; - } - - // Lastly, extend the result. - NewI->takeName(cast(I)); - Value *Res = B.CreateZExtOrTrunc(NewI, OriginalTy); - I->replaceAllUsesWith(Res); - cast(I)->eraseFromParent(); - Erased.insert(I); - State.reset(Def, Res, Part); - } - } - - // We'll have created a bunch of ZExts that are now parentless. Clean up. - for (const auto &KV : Cost->getMinimalBitwidths()) { - // If the value wasn't vectorized, we must maintain the original scalar - // type. The absence of the value from State indicates that it - // wasn't vectorized. - // FIXME: Should not rely on getVPValue at this point. - VPValue *Def = State.Plan->getVPValue(KV.first, true); - if (!State.hasAnyVectorValue(Def)) - continue; - for (unsigned Part = 0; Part < UF; ++Part) { - Value *I = State.get(Def, Part); - ZExtInst *Inst = dyn_cast(I); - if (Inst && Inst->use_empty()) { - Value *NewI = Inst->getOperand(0); - Inst->eraseFromParent(); - State.reset(Def, NewI, Part); - } - } - } -} - void InnerLoopVectorizer::fixVectorizedLoop(VPTransformState &State, VPlan &Plan) { - // Insert truncates and extends for any truncated instructions as hints to - // InstCombine. - if (VF.isVector()) - truncateToMinimalBitwidths(State); - // Fix widened non-induction PHIs by setting up the PHI operands. 
if (EnableVPlanNativePath) fixNonInductionPHIs(Plan, State); @@ -9003,6 +8861,7 @@ VPlanTransforms::removeRedundantCanonicalIVs(*Plan); VPlanTransforms::removeRedundantInductionCasts(*Plan); + VPlanTransforms::truncateToMinimalBitwidths(*Plan, CM.getMinimalBitwidths()); // Adjust the recipes for any inloop reductions. adjustRecipesForReductions(cast(TopRegion->getExiting()), Plan, diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.h b/llvm/lib/Transforms/Vectorize/VPlanTransforms.h --- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.h +++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.h @@ -86,6 +86,11 @@ static void optimizeForVFAndUF(VPlan &Plan, ElementCount BestVF, unsigned BestUF, PredicatedScalarEvolution &PSE); + /// Insert truncates and extends for any truncated instructions as hints to + /// InstCombine. + static void + truncateToMinimalBitwidths(VPlan &Plan, + const MapVector &MinBWs); }; } // namespace llvm diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp --- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp @@ -743,3 +743,78 @@ } return true; } + +void VPlanTransforms::truncateToMinimalBitwidths( + VPlan &Plan, const MapVector &MinBWs) { + auto GetSizeInBits = [](VPValue *VPV) { + auto *UV = VPV->getUnderlyingValue(); + if (UV) + return UV->getType()->getScalarSizeInBits(); + if (auto *VPC = dyn_cast(VPV)) { + return VPC->getResultType()->getScalarSizeInBits(); + } + llvm_unreachable("trying to get type of a VPValue without type info"); + }; + + for (VPBasicBlock *VPBB : VPBlockUtils::blocksOnly( + vp_depth_first_deep(Plan.getEntry()))) { + for (VPRecipeBase &R : make_early_inc_range(*VPBB)) { + if (!isa(&R)) + continue; + + VPValue *ResultVPV = R.getVPSingleValue(); + auto *UI = cast_or_null(ResultVPV->getUnderlyingValue()); + auto I = MinBWs.find(UI); + if (!UI || I == MinBWs.end()) + continue; + + unsigned ResSizeInBits 
= GetSizeInBits(ResultVPV); + unsigned NewResSizeInBits = I->second; + Type *ResTy = UI->getType(); + if (!ResTy->isIntegerTy() || ResSizeInBits == NewResSizeInBits) + continue; + + LLVMContext &Ctx = ResTy->getContext(); + auto *NewResTy = IntegerType::get(Ctx, NewResSizeInBits); + + // Try to replace wider SExt/ZExts with narrower ones if possible. + if (auto *VPC = dyn_cast(&R)) { + switch (VPC->getOpcode()) { + default: + break; + case Instruction::SExt: + case Instruction::ZExt: { + auto *Op = R.getOperand(0); + assert(ResSizeInBits > NewResSizeInBits && "Nothing to shrink?"); + if (GetSizeInBits(Op) >= NewResSizeInBits) + break; + auto *C = new VPWidenCastRecipe(VPC->getOpcode(), Op, NewResTy); + C->insertBefore(&R); + ResultVPV->replaceAllUsesWith(C); + continue; + } + } + } + + // Shrink operands by introducing truncates as needed. + for (unsigned Idx = 0; Idx != R.getNumOperands(); ++Idx) { + auto *Op = R.getOperand(Idx); + unsigned OpSizeInBits = GetSizeInBits(Op); + if (OpSizeInBits == NewResSizeInBits) + continue; + assert(OpSizeInBits > NewResSizeInBits && "nothing to truncate"); + auto *Shrunk = new VPWidenCastRecipe(Instruction::Trunc, Op, NewResTy); + R.setOperand(Idx, Shrunk); + Shrunk->insertBefore(&R); + if (auto *VPW = dyn_cast(&R)) + VPW->dropPoisonGeneratingFlags(); + } + + // Extend result to original width. 
+ auto *Ext = new VPWidenCastRecipe(Instruction::ZExt, ResultVPV, ResTy); + ResultVPV->replaceAllUsesWith(Ext); + Ext->setOperand(0, ResultVPV); + Ext->insertAfter(&R); + } + } +} diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/deterministic-type-shrinkage.ll b/llvm/test/Transforms/LoopVectorize/AArch64/deterministic-type-shrinkage.ll --- a/llvm/test/Transforms/LoopVectorize/AArch64/deterministic-type-shrinkage.ll +++ b/llvm/test/Transforms/LoopVectorize/AArch64/deterministic-type-shrinkage.ll @@ -9,11 +9,11 @@ ;; CHECK-LABEL: @test( ; CHECK: load <16 x i8> +; CHECK-NEXT: zext <16 x i8> ; CHECK-NEXT: getelementptr ; CHECK-NEXT: bitcast ; CHECK-NEXT: load <16 x i8> ; CHECK-NEXT: zext <16 x i8> -; CHECK-NEXT: zext <16 x i8> define void @test(i32 %n, i8* nocapture %a, i8* nocapture %b, i8* nocapture readonly %c) { entry: %cmp.28 = icmp eq i32 %n, 0