diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -13,6 +13,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "llvm/Transforms/Vectorize/VectorCombine.h"
+#include "llvm/ADT/SmallBitVector.h"
 #include "llvm/ADT/Statistic.h"
 #include "llvm/Analysis/AssumptionCache.h"
 #include "llvm/Analysis/BasicAliasAnalysis.h"
@@ -103,11 +104,12 @@
   bool foldSingleElementStore(Instruction &I);
   bool scalarizeLoadExtract(Instruction &I);
   bool foldShuffleOfBinops(Instruction &I);
+  bool foldShuffleFromReductions(Instruction &I);
 
   void replaceValue(Value &Old, Value &New) {
     Old.replaceAllUsesWith(&New);
-    New.takeName(&Old);
     if (auto *NewI = dyn_cast<Instruction>(&New)) {
+      New.takeName(&Old);
       Worklist.pushUsersToWorkList(*NewI);
       Worklist.pushValue(NewI);
     }
@@ -1113,6 +1115,119 @@
   return true;
 }
 
+/// Given a commutative reduction, the order of the input lanes does not alter
+/// the results. We can use this to remove certain shuffles feeding the
+/// reduction, removing the need to shuffle at all.
+bool VectorCombine::foldShuffleFromReductions(Instruction &I) {
+  auto *II = dyn_cast<IntrinsicInst>(&I);
+  if (!II)
+    return false;
+  switch (II->getIntrinsicID()) {
+  case Intrinsic::vector_reduce_add:
+  case Intrinsic::vector_reduce_mul:
+  case Intrinsic::vector_reduce_and:
+  case Intrinsic::vector_reduce_or:
+  case Intrinsic::vector_reduce_xor:
+  case Intrinsic::vector_reduce_smin:
+  case Intrinsic::vector_reduce_smax:
+  case Intrinsic::vector_reduce_umin:
+  case Intrinsic::vector_reduce_umax:
+    break;
+  default:
+    return false;
+  }
+
+  // Find all the inputs when looking through operations that do not alter the
+  // lane order (binops, for example). Currently we look for a single shuffle,
+  // and can ignore splat values.
+  std::queue<Value *> Worklist;
+  SmallPtrSet<Value *, 4> Visited;
+  ShuffleVectorInst *Shuffle = nullptr;
+  if (auto *Op = dyn_cast<Instruction>(I.getOperand(0)))
+    Worklist.push(Op);
+
+  while (!Worklist.empty()) {
+    Value *CV = Worklist.front();
+    Worklist.pop();
+    if (Visited.contains(CV))
+      continue;
+
+    // Splats don't change the order, so can be safely ignored.
+    if (isSplatValue(CV))
+      continue;
+
+    Visited.insert(CV);
+
+    if (auto *CI = dyn_cast<Instruction>(CV)) {
+      if (CI->isBinaryOp()) {
+        for (auto *Op : CI->operand_values())
+          Worklist.push(Op);
+        continue;
+      } else if (auto *SV = dyn_cast<ShuffleVectorInst>(CI)) {
+        if (Shuffle && Shuffle != SV)
+          return false;
+        Shuffle = SV;
+        continue;
+      }
+    }
+
+    // Anything else is currently an unknown node.
+    return false;
+  }
+
+  if (!Shuffle)
+    return false;
+
+  // Check all uses of the binary ops and shuffles are also included in the
+  // lane-invariant operations (Visited should be the list of lanewise
+  // instructions, including the shuffle that we found).
+  for (auto *V : Visited)
+    for (auto *U : V->users())
+      if (!Visited.contains(U) && U != &I)
+        return false;
+
+  FixedVectorType *VecType =
+      dyn_cast<FixedVectorType>(II->getOperand(0)->getType());
+  if (!VecType)
+    return false;
+  FixedVectorType *ShuffleInputType =
+      dyn_cast<FixedVectorType>(Shuffle->getOperand(0)->getType());
+  if (!ShuffleInputType)
+    return false;
+  int NumInputElts = ShuffleInputType->getNumElements();
+
+  // Find the mask from sorting the lanes into order. This is most likely to
+  // become an identity or concat mask. Undef elements are pushed to the end.
+  SmallVector<int> ConcatMask;
+  Shuffle->getShuffleMask(ConcatMask);
+  // Unsigned compare puts UndefMaskElem (-1) last and is a strict weak order.
+  sort(ConcatMask, [](int X, int Y) { return (unsigned)X < (unsigned)Y; });
+  bool UsesSecondVec =
+      any_of(ConcatMask, [&](int M) { return M >= NumInputElts; });
+  InstructionCost OldCost = TTI.getShuffleCost(
+      UsesSecondVec ? TTI::SK_PermuteTwoSrc : TTI::SK_PermuteSingleSrc, VecType,
+      Shuffle->getShuffleMask());
+  InstructionCost NewCost = TTI.getShuffleCost(
+      UsesSecondVec ? TTI::SK_PermuteTwoSrc : TTI::SK_PermuteSingleSrc, VecType,
+      ConcatMask);
+
+  LLVM_DEBUG(dbgs() << "Found a reduction feeding from a shuffle: " << *Shuffle
+                    << "\n");
+  LLVM_DEBUG(dbgs() << " OldCost: " << OldCost << " vs NewCost: " << NewCost
+                    << "\n");
+  if (NewCost < OldCost) {
+    Builder.SetInsertPoint(Shuffle);
+    Value *NewShuffle = Builder.CreateShuffleVector(
+        Shuffle->getOperand(0), Shuffle->getOperand(1), ConcatMask);
+    LLVM_DEBUG(dbgs() << "Created new shuffle: " << *NewShuffle << "\n");
+    replaceValue(*Shuffle, *NewShuffle);
+    // The IR changed; report it so the caller's MadeChange stays accurate.
+    return true;
+  }
+
+  return false;
+}
+
 /// This is the entry point for all transforms. Pass manager differences are
 /// handled in the callers of this function.
 bool VectorCombine::run() {
@@ -1132,6 +1247,7 @@
       MadeChange |= foldBitcastShuf(I);
       MadeChange |= foldExtractedCmps(I);
       MadeChange |= foldShuffleOfBinops(I);
+      MadeChange |= foldShuffleFromReductions(I);
     }
     MadeChange |= scalarizeBinopOrCmp(I);
     MadeChange |= scalarizeLoadExtract(I);
diff --git a/llvm/test/Transforms/VectorCombine/AArch64/vecreduce-shuffle.ll b/llvm/test/Transforms/VectorCombine/AArch64/vecreduce-shuffle.ll --- a/llvm/test/Transforms/VectorCombine/AArch64/vecreduce-shuffle.ll +++ b/llvm/test/Transforms/VectorCombine/AArch64/vecreduce-shuffle.ll @@ -16,7 +16,7 @@ define i32 @reduceshuffle_onein_v4i32(<4 x i32> %a) { ; CHECK-LABEL: @reduceshuffle_onein_v4i32( -; CHECK-NEXT: [[X:%.*]] = shufflevector <4 x i32> [[A:%.*]], <4 x i32> undef, <4 x i32> +; CHECK-NEXT: [[X:%.*]] = shufflevector <4 x i32> [[A:%.*]], <4 x i32> undef, <4 x i32> ; CHECK-NEXT: [[R:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[X]]) ; CHECK-NEXT: ret i32 [[R]] ; @@ -27,7 +27,7 @@ define i32 @reduceshuffle_onein_const_v4i32(<4 x i32> %a) { ; CHECK-LABEL: @reduceshuffle_onein_const_v4i32( -; CHECK-NEXT: [[S:%.*]] = shufflevector <4 x i32> 
[[A:%.*]], <4 x i32> undef, <4 x i32> +; CHECK-NEXT: [[S:%.*]] = shufflevector <4 x i32> [[A:%.*]], <4 x i32> undef, <4 x i32> ; CHECK-NEXT: [[X:%.*]] = xor <4 x i32> [[S]], ; CHECK-NEXT: [[R:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[X]]) ; CHECK-NEXT: ret i32 [[R]] @@ -77,7 +77,7 @@ define i32 @reduceshuffle_twoin_concat_v4i32(<2 x i32> %a, <2 x i32> %b) { ; CHECK-LABEL: @reduceshuffle_twoin_concat_v4i32( -; CHECK-NEXT: [[X:%.*]] = shufflevector <2 x i32> [[A:%.*]], <2 x i32> [[B:%.*]], <4 x i32> +; CHECK-NEXT: [[X:%.*]] = shufflevector <2 x i32> [[A:%.*]], <2 x i32> [[B:%.*]], <4 x i32> ; CHECK-NEXT: [[R:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[X]]) ; CHECK-NEXT: ret i32 [[R]] ; @@ -99,7 +99,7 @@ define i32 @reduceshuffle_twoin_notlowelts_v4i32(<4 x i32> %a, <4 x i32> %b) { ; CHECK-LABEL: @reduceshuffle_twoin_notlowelts_v4i32( -; CHECK-NEXT: [[X:%.*]] = shufflevector <4 x i32> [[A:%.*]], <4 x i32> [[B:%.*]], <4 x i32> +; CHECK-NEXT: [[X:%.*]] = shufflevector <4 x i32> [[A:%.*]], <4 x i32> [[B:%.*]], <4 x i32> ; CHECK-NEXT: [[R:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[X]]) ; CHECK-NEXT: ret i32 [[R]] ; @@ -121,7 +121,7 @@ define i32 @reduceshuffle_twoin_uneven_v4i32(<4 x i32> %a, <4 x i32> %b) { ; CHECK-LABEL: @reduceshuffle_twoin_uneven_v4i32( -; CHECK-NEXT: [[X:%.*]] = shufflevector <4 x i32> [[A:%.*]], <4 x i32> [[B:%.*]], <4 x i32> +; CHECK-NEXT: [[X:%.*]] = shufflevector <4 x i32> [[A:%.*]], <4 x i32> [[B:%.*]], <4 x i32> ; CHECK-NEXT: [[R:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[X]]) ; CHECK-NEXT: ret i32 [[R]] ; @@ -143,7 +143,7 @@ define i32 @reduceshuffle_twoin_undef2_v4i32(<4 x i32> %a, <4 x i32> %b) { ; CHECK-LABEL: @reduceshuffle_twoin_undef2_v4i32( -; CHECK-NEXT: [[X:%.*]] = shufflevector <4 x i32> [[A:%.*]], <4 x i32> [[B:%.*]], <4 x i32> +; CHECK-NEXT: [[X:%.*]] = shufflevector <4 x i32> [[A:%.*]], <4 x i32> [[B:%.*]], <4 x i32> ; CHECK-NEXT: [[R:%.*]] = call i32 
@llvm.vector.reduce.add.v4i32(<4 x i32> [[X]]) ; CHECK-NEXT: ret i32 [[R]] ; @@ -154,7 +154,7 @@ define i32 @reduceshuffle_twoin_multiundef_v4i32(<4 x i32> %a, <4 x i32> %b) { ; CHECK-LABEL: @reduceshuffle_twoin_multiundef_v4i32( -; CHECK-NEXT: [[X:%.*]] = shufflevector <4 x i32> [[A:%.*]], <4 x i32> [[B:%.*]], <4 x i32> +; CHECK-NEXT: [[X:%.*]] = shufflevector <4 x i32> [[A:%.*]], <4 x i32> [[B:%.*]], <4 x i32> ; CHECK-NEXT: [[R:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[X]]) ; CHECK-NEXT: ret i32 [[R]] ; @@ -222,7 +222,7 @@ define i32 @reduceshuffle_onein_v16i32(<16 x i32> %a) { ; CHECK-LABEL: @reduceshuffle_onein_v16i32( -; CHECK-NEXT: [[X:%.*]] = shufflevector <16 x i32> [[A:%.*]], <16 x i32> undef, <16 x i32> +; CHECK-NEXT: [[X:%.*]] = shufflevector <16 x i32> [[A:%.*]], <16 x i32> undef, <16 x i32> ; CHECK-NEXT: [[R:%.*]] = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[X]]) ; CHECK-NEXT: ret i32 [[R]] ; @@ -233,7 +233,7 @@ define i32 @reduceshuffle_onein_ext_v16i32(<16 x i32> %a) { ; CHECK-LABEL: @reduceshuffle_onein_ext_v16i32( -; CHECK-NEXT: [[S:%.*]] = shufflevector <16 x i32> [[A:%.*]], <16 x i32> undef, <16 x i32> +; CHECK-NEXT: [[S:%.*]] = shufflevector <16 x i32> [[A:%.*]], <16 x i32> undef, <16 x i32> ; CHECK-NEXT: [[X:%.*]] = xor <16 x i32> [[S]], ; CHECK-NEXT: [[R:%.*]] = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[X]]) ; CHECK-NEXT: ret i32 [[R]] @@ -246,7 +246,7 @@ define i32 @reduceshuffle_twoin_concat_v16i32(<8 x i32> %a, <8 x i32> %b) { ; CHECK-LABEL: @reduceshuffle_twoin_concat_v16i32( -; CHECK-NEXT: [[S:%.*]] = shufflevector <8 x i32> [[A:%.*]], <8 x i32> [[B:%.*]], <16 x i32> +; CHECK-NEXT: [[S:%.*]] = shufflevector <8 x i32> [[A:%.*]], <8 x i32> [[B:%.*]], <16 x i32> ; CHECK-NEXT: [[X:%.*]] = xor <16 x i32> [[S]], ; CHECK-NEXT: [[R:%.*]] = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[X]]) ; CHECK-NEXT: ret i32 [[R]] @@ -259,7 +259,7 @@ define i32 @reduceshuffle_twoin_lowelt_v16i32(<16 x i32> %a, <16 x 
i32> %b) { ; CHECK-LABEL: @reduceshuffle_twoin_lowelt_v16i32( -; CHECK-NEXT: [[S:%.*]] = shufflevector <16 x i32> [[A:%.*]], <16 x i32> [[B:%.*]], <16 x i32> +; CHECK-NEXT: [[S:%.*]] = shufflevector <16 x i32> [[A:%.*]], <16 x i32> [[B:%.*]], <16 x i32> ; CHECK-NEXT: [[X:%.*]] = xor <16 x i32> [[S]], ; CHECK-NEXT: [[R:%.*]] = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[X]]) ; CHECK-NEXT: ret i32 [[R]] @@ -285,7 +285,7 @@ define i32 @reduceshuffle_twoin_uneven_v16i32(<16 x i32> %a, <16 x i32> %b) { ; CHECK-LABEL: @reduceshuffle_twoin_uneven_v16i32( -; CHECK-NEXT: [[S:%.*]] = shufflevector <16 x i32> [[A:%.*]], <16 x i32> [[B:%.*]], <16 x i32> +; CHECK-NEXT: [[S:%.*]] = shufflevector <16 x i32> [[A:%.*]], <16 x i32> [[B:%.*]], <16 x i32> ; CHECK-NEXT: [[X:%.*]] = xor <16 x i32> [[S]], ; CHECK-NEXT: [[R:%.*]] = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[X]]) ; CHECK-NEXT: ret i32 [[R]] @@ -298,7 +298,7 @@ define i32 @reduceshuffle_twoin_shr1_v16i32(<16 x i32> %a, <16 x i32> %b) { ; CHECK-LABEL: @reduceshuffle_twoin_shr1_v16i32( -; CHECK-NEXT: [[S:%.*]] = shufflevector <16 x i32> [[A:%.*]], <16 x i32> [[B:%.*]], <16 x i32> +; CHECK-NEXT: [[S:%.*]] = shufflevector <16 x i32> [[A:%.*]], <16 x i32> [[B:%.*]], <16 x i32> ; CHECK-NEXT: [[A1:%.*]] = lshr <16 x i32> [[S]], ; CHECK-NEXT: [[A2:%.*]] = and <16 x i32> [[A1]], ; CHECK-NEXT: [[A3:%.*]] = mul nuw <16 x i32> [[A2]], @@ -319,7 +319,7 @@ define i32 @reduceshuffle_twoin_shr2_v16i32(<16 x i32> %a, <16 x i32> %b) { ; CHECK-LABEL: @reduceshuffle_twoin_shr2_v16i32( -; CHECK-NEXT: [[S:%.*]] = shufflevector <16 x i32> [[A:%.*]], <16 x i32> [[B:%.*]], <16 x i32> +; CHECK-NEXT: [[S:%.*]] = shufflevector <16 x i32> [[A:%.*]], <16 x i32> [[B:%.*]], <16 x i32> ; CHECK-NEXT: [[A1:%.*]] = lshr <16 x i32> , [[S]] ; CHECK-NEXT: [[A2:%.*]] = and <16 x i32> [[A1]], ; CHECK-NEXT: [[A3:%.*]] = mul nuw <16 x i32> [[A2]], @@ -353,7 +353,7 @@ define i16 @reduceshuffle_onein_v16i16(<16 x i16> %a) { ; CHECK-LABEL: 
@reduceshuffle_onein_v16i16( -; CHECK-NEXT: [[X:%.*]] = shufflevector <16 x i16> [[A:%.*]], <16 x i16> undef, <16 x i32> +; CHECK-NEXT: [[X:%.*]] = shufflevector <16 x i16> [[A:%.*]], <16 x i16> undef, <16 x i32> ; CHECK-NEXT: [[R:%.*]] = call i16 @llvm.vector.reduce.add.v16i16(<16 x i16> [[X]]) ; CHECK-NEXT: ret i16 [[R]] ; @@ -364,7 +364,7 @@ define i16 @reduceshuffle_onein_ext_v16i16(<16 x i16> %a) { ; CHECK-LABEL: @reduceshuffle_onein_ext_v16i16( -; CHECK-NEXT: [[S:%.*]] = shufflevector <16 x i16> [[A:%.*]], <16 x i16> undef, <16 x i32> +; CHECK-NEXT: [[S:%.*]] = shufflevector <16 x i16> [[A:%.*]], <16 x i16> undef, <16 x i32> ; CHECK-NEXT: [[X:%.*]] = xor <16 x i16> [[S]], ; CHECK-NEXT: [[R:%.*]] = call i16 @llvm.vector.reduce.add.v16i16(<16 x i16> [[X]]) ; CHECK-NEXT: ret i16 [[R]] @@ -377,7 +377,7 @@ define i16 @reduceshuffle_twoin_concat_v16i16(<8 x i16> %a, <8 x i16> %b) { ; CHECK-LABEL: @reduceshuffle_twoin_concat_v16i16( -; CHECK-NEXT: [[S:%.*]] = shufflevector <8 x i16> [[A:%.*]], <8 x i16> [[B:%.*]], <16 x i32> +; CHECK-NEXT: [[S:%.*]] = shufflevector <8 x i16> [[A:%.*]], <8 x i16> [[B:%.*]], <16 x i32> ; CHECK-NEXT: [[X:%.*]] = xor <16 x i16> [[S]], ; CHECK-NEXT: [[R:%.*]] = call i16 @llvm.vector.reduce.add.v16i16(<16 x i16> [[X]]) ; CHECK-NEXT: ret i16 [[R]] @@ -390,7 +390,7 @@ define i16 @reduceshuffle_twoin_lowelt_v16i16(<16 x i16> %a, <16 x i16> %b) { ; CHECK-LABEL: @reduceshuffle_twoin_lowelt_v16i16( -; CHECK-NEXT: [[S:%.*]] = shufflevector <16 x i16> [[A:%.*]], <16 x i16> [[B:%.*]], <16 x i32> +; CHECK-NEXT: [[S:%.*]] = shufflevector <16 x i16> [[A:%.*]], <16 x i16> [[B:%.*]], <16 x i32> ; CHECK-NEXT: [[X:%.*]] = xor <16 x i16> [[S]], ; CHECK-NEXT: [[R:%.*]] = call i16 @llvm.vector.reduce.add.v16i16(<16 x i16> [[X]]) ; CHECK-NEXT: ret i16 [[R]] @@ -429,7 +429,7 @@ define i16 @reduceshuffle_twoin_ext_v16i16(<16 x i16> %a, <16 x i16> %b) { ; CHECK-LABEL: @reduceshuffle_twoin_ext_v16i16( -; CHECK-NEXT: [[S:%.*]] = shufflevector <16 x i16> 
[[A:%.*]], <16 x i16> [[B:%.*]], <16 x i32> +; CHECK-NEXT: [[S:%.*]] = shufflevector <16 x i16> [[A:%.*]], <16 x i16> [[B:%.*]], <16 x i32> ; CHECK-NEXT: [[A1:%.*]] = lshr <16 x i16> [[S]], ; CHECK-NEXT: [[A2:%.*]] = and <16 x i16> [[A1]], ; CHECK-NEXT: [[A3:%.*]] = mul nuw <16 x i16> [[A2]],