Index: llvm/lib/Transforms/Vectorize/VectorCombine.cpp =================================================================== --- llvm/lib/Transforms/Vectorize/VectorCombine.cpp +++ llvm/lib/Transforms/Vectorize/VectorCombine.cpp @@ -13,6 +13,7 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Vectorize/VectorCombine.h" +#include "llvm/ADT/SmallBitVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/BasicAliasAnalysis.h" @@ -103,11 +104,12 @@ bool foldSingleElementStore(Instruction &I); bool scalarizeLoadExtract(Instruction &I); bool foldShuffleOfBinops(Instruction &I); + bool foldShuffleFromReductions(Instruction &I); void replaceValue(Value &Old, Value &New) { Old.replaceAllUsesWith(&New); - New.takeName(&Old); if (auto *NewI = dyn_cast<Instruction>(&New)) { + New.takeName(&Old); Worklist.pushUsersToWorkList(*NewI); Worklist.pushValue(NewI); } @@ -1113,6 +1115,137 @@ return true; } +/// Given a commutative reduction, the order of the input lanes does not alter +/// the results. We can use this to remove certain shuffles feeding the +/// reduction, removing the need to shuffle at all. +bool VectorCombine::foldShuffleFromReductions(Instruction &I) { + auto *II = dyn_cast<IntrinsicInst>(&I); + if (!II) + return false; + switch (II->getIntrinsicID()) { + case Intrinsic::vector_reduce_add: + case Intrinsic::vector_reduce_mul: + case Intrinsic::vector_reduce_and: + case Intrinsic::vector_reduce_or: + case Intrinsic::vector_reduce_xor: + case Intrinsic::vector_reduce_smin: + case Intrinsic::vector_reduce_smax: + case Intrinsic::vector_reduce_umin: + case Intrinsic::vector_reduce_umax: + break; + default: + return false; + } + + // Find all the inputs when looking through operations that do not alter the + // lane order (binops, for example). Currently we look for a single shuffle, + // and can ignore splat values. 
+ std::queue<Value *> Worklist; + SmallPtrSet<Value *, 4> Visited; + ShuffleVectorInst *Shuffle = nullptr; + if (auto *Op = dyn_cast<Instruction>(I.getOperand(0))) + Worklist.push(Op); + + while (!Worklist.empty()) { + Value *CV = Worklist.front(); + Worklist.pop(); + if (Visited.contains(CV)) + continue; + + // Splats don't change the order, so can be safely ignored. + if (isSplatValue(CV)) + continue; + + Visited.insert(CV); + + if (auto *CI = dyn_cast<Instruction>(CV)) { + if (CI->isBinaryOp()) { + for (auto *Op : CI->operand_values()) + Worklist.push(Op); + continue; + } else if (auto *SV = dyn_cast<ShuffleVectorInst>(CI)) { + if (Shuffle && Shuffle != SV) + return false; + Shuffle = SV; + continue; + } + } + + // Anything else is currently an unknown node. + return false; + } + + if (!Shuffle) + return false; + + // Check all uses of the binary ops and shuffles are also included in the + // lane-invariant operations (Visited should be the list of lanewise + // instructions, including the shuffle that we found). + for (auto *V : Visited) + for (auto *U : V->users()) + if (!Visited.contains(U) && U != &I) + return false; + + FixedVectorType *VecType = + dyn_cast<FixedVectorType>(II->getOperand(0)->getType()); + if (!VecType) + return false; + int NumVecElts = VecType->getNumElements(); + FixedVectorType *ShuffleInputType = + dyn_cast<FixedVectorType>(Shuffle->getOperand(0)->getType()); + if (!ShuffleInputType) + return false; + int NumInputElts = ShuffleInputType->getNumElements(); + + SmallBitVector UsedLanes1(NumInputElts); + SmallBitVector UsedLanes2(NumInputElts); + for (int M : Shuffle->getShuffleMask()) { + if (M < 0) + continue; + else if (M < NumInputElts) + UsedLanes1.set(M); + else + UsedLanes2.set(M - NumInputElts); + } + + // Check if only the low lanes from each vector input are used. The + // simplest case of this is an identity mask from the first vector. 
+ int NumUsed1 = UsedLanes1.find_first_unset(); + if (NumUsed1 == -1) + NumUsed1 = NumInputElts; + int NumUsed2 = UsedLanes2.find_first_unset(); + if (NumUsed2 == -1) + NumUsed2 = NumInputElts; + if (NumUsed1 + NumUsed2 == NumVecElts) { + // Create a shuffle mask in-order and see if the cost is really cheaper. + SmallVector<int> ConcatMask; + for (int i = 0; i < NumUsed1; i++) + ConcatMask.push_back(i); + for (int i = 0; i < NumUsed2; i++) + ConcatMask.push_back(NumInputElts + i); + InstructionCost OldCost = TTI.getShuffleCost( + UsedLanes2.any() ? TTI::SK_PermuteTwoSrc : TTI::SK_PermuteSingleSrc, + VecType, Shuffle->getShuffleMask()); + InstructionCost NewCost = TTI.getShuffleCost( + UsedLanes2.any() ? TTI::SK_PermuteTwoSrc : TTI::SK_PermuteSingleSrc, + VecType, ConcatMask); + + LLVM_DEBUG(dbgs() << "Found a reduction feeding from a shuffle: " + << *Shuffle << "\n"); + LLVM_DEBUG(dbgs() << " OldCost: " << OldCost << " vs NewCost: " << NewCost + << "\n"); + if (NewCost < OldCost) { + Builder.SetInsertPoint(Shuffle); + Value *NewShuffle = Builder.CreateShuffleVector( + Shuffle->getOperand(0), Shuffle->getOperand(1), ConcatMask); + LLVM_DEBUG(dbgs() << "Created new shuffle: " << *NewShuffle << "\n"); + replaceValue(*Shuffle, *NewShuffle); + } + } + + return false; +} + /// This is the entry point for all transforms. Pass manager differences are /// handled in the callers of this function. 
bool VectorCombine::run() { @@ -1132,6 +1265,7 @@ MadeChange |= foldBitcastShuf(I); MadeChange |= foldExtractedCmps(I); MadeChange |= foldShuffleOfBinops(I); + MadeChange |= foldShuffleFromReductions(I); } MadeChange |= scalarizeBinopOrCmp(I); MadeChange |= scalarizeLoadExtract(I); Index: llvm/test/Transforms/VectorCombine/AArch64/vecreduce-shuffle.ll =================================================================== --- llvm/test/Transforms/VectorCombine/AArch64/vecreduce-shuffle.ll +++ llvm/test/Transforms/VectorCombine/AArch64/vecreduce-shuffle.ll @@ -16,7 +16,7 @@ define i32 @reduceshuffle_onein_v4i32(<4 x i32> %a) { ; CHECK-LABEL: @reduceshuffle_onein_v4i32( -; CHECK-NEXT: [[X:%.*]] = shufflevector <4 x i32> [[A:%.*]], <4 x i32> undef, <4 x i32> +; CHECK-NEXT: [[X:%.*]] = shufflevector <4 x i32> [[A:%.*]], <4 x i32> undef, <4 x i32> ; CHECK-NEXT: [[R:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[X]]) ; CHECK-NEXT: ret i32 [[R]] ; @@ -27,7 +27,7 @@ define i32 @reduceshuffle_onein_const_v4i32(<4 x i32> %a) { ; CHECK-LABEL: @reduceshuffle_onein_const_v4i32( -; CHECK-NEXT: [[S:%.*]] = shufflevector <4 x i32> [[A:%.*]], <4 x i32> undef, <4 x i32> +; CHECK-NEXT: [[S:%.*]] = shufflevector <4 x i32> [[A:%.*]], <4 x i32> undef, <4 x i32> ; CHECK-NEXT: [[X:%.*]] = xor <4 x i32> [[S]], ; CHECK-NEXT: [[R:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[X]]) ; CHECK-NEXT: ret i32 [[R]] @@ -66,7 +66,7 @@ define i32 @reduceshuffle_twoin_concat_v4i32(<2 x i32> %a, <2 x i32> %b) { ; CHECK-LABEL: @reduceshuffle_twoin_concat_v4i32( -; CHECK-NEXT: [[X:%.*]] = shufflevector <2 x i32> [[A:%.*]], <2 x i32> [[B:%.*]], <4 x i32> +; CHECK-NEXT: [[X:%.*]] = shufflevector <2 x i32> [[A:%.*]], <2 x i32> [[B:%.*]], <4 x i32> ; CHECK-NEXT: [[R:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[X]]) ; CHECK-NEXT: ret i32 [[R]] ; @@ -77,7 +77,7 @@ define i32 @reduceshuffle_twoin_lowelts_v4i32(<4 x i32> %a, <4 x i32> %b) { ; CHECK-LABEL: 
@reduceshuffle_twoin_lowelts_v4i32( -; CHECK-NEXT: [[X:%.*]] = shufflevector <4 x i32> [[A:%.*]], <4 x i32> [[B:%.*]], <4 x i32> +; CHECK-NEXT: [[X:%.*]] = shufflevector <4 x i32> [[A:%.*]], <4 x i32> [[B:%.*]], <4 x i32> ; CHECK-NEXT: [[R:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[X]]) ; CHECK-NEXT: ret i32 [[R]] ; @@ -110,7 +110,7 @@ define i32 @reduceshuffle_twoin_uneven_v4i32(<4 x i32> %a, <4 x i32> %b) { ; CHECK-LABEL: @reduceshuffle_twoin_uneven_v4i32( -; CHECK-NEXT: [[X:%.*]] = shufflevector <4 x i32> [[A:%.*]], <4 x i32> [[B:%.*]], <4 x i32> +; CHECK-NEXT: [[X:%.*]] = shufflevector <4 x i32> [[A:%.*]], <4 x i32> [[B:%.*]], <4 x i32> ; CHECK-NEXT: [[R:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[X]]) ; CHECK-NEXT: ret i32 [[R]] ; @@ -160,7 +160,7 @@ define i32 @reduceshuffle_twoin_splat_v4i32(<4 x i32> %a, <4 x i32> %b, i32 %c) { ; CHECK-LABEL: @reduceshuffle_twoin_splat_v4i32( -; CHECK-NEXT: [[S:%.*]] = shufflevector <4 x i32> [[A:%.*]], <4 x i32> [[B:%.*]], <4 x i32> +; CHECK-NEXT: [[S:%.*]] = shufflevector <4 x i32> [[A:%.*]], <4 x i32> [[B:%.*]], <4 x i32> ; CHECK-NEXT: [[INSERT:%.*]] = insertelement <4 x i32> poison, i32 [[C:%.*]], i32 0 ; CHECK-NEXT: [[SPLAT:%.*]] = shufflevector <4 x i32> [[INSERT]], <4 x i32> poison, <4 x i32> zeroinitializer ; CHECK-NEXT: [[X:%.*]] = xor <4 x i32> [[S]], [[SPLAT]] @@ -189,7 +189,7 @@ define i32 @reduceshuffle_onein_v16i32(<16 x i32> %a) { ; CHECK-LABEL: @reduceshuffle_onein_v16i32( -; CHECK-NEXT: [[X:%.*]] = shufflevector <16 x i32> [[A:%.*]], <16 x i32> undef, <16 x i32> +; CHECK-NEXT: [[X:%.*]] = shufflevector <16 x i32> [[A:%.*]], <16 x i32> undef, <16 x i32> ; CHECK-NEXT: [[R:%.*]] = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[X]]) ; CHECK-NEXT: ret i32 [[R]] ; @@ -200,7 +200,7 @@ define i32 @reduceshuffle_onein_ext_v16i32(<16 x i32> %a) { ; CHECK-LABEL: @reduceshuffle_onein_ext_v16i32( -; CHECK-NEXT: [[S:%.*]] = shufflevector <16 x i32> [[A:%.*]], <16 x i32> undef, <16 x 
i32> +; CHECK-NEXT: [[S:%.*]] = shufflevector <16 x i32> [[A:%.*]], <16 x i32> undef, <16 x i32> ; CHECK-NEXT: [[X:%.*]] = xor <16 x i32> [[S]], ; CHECK-NEXT: [[R:%.*]] = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[X]]) ; CHECK-NEXT: ret i32 [[R]] @@ -213,7 +213,7 @@ define i32 @reduceshuffle_twoin_concat_v16i32(<8 x i32> %a, <8 x i32> %b) { ; CHECK-LABEL: @reduceshuffle_twoin_concat_v16i32( -; CHECK-NEXT: [[S:%.*]] = shufflevector <8 x i32> [[A:%.*]], <8 x i32> [[B:%.*]], <16 x i32> +; CHECK-NEXT: [[S:%.*]] = shufflevector <8 x i32> [[A:%.*]], <8 x i32> [[B:%.*]], <16 x i32> ; CHECK-NEXT: [[X:%.*]] = xor <16 x i32> [[S]], ; CHECK-NEXT: [[R:%.*]] = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[X]]) ; CHECK-NEXT: ret i32 [[R]] @@ -226,7 +226,7 @@ define i32 @reduceshuffle_twoin_lowelt_v16i32(<16 x i32> %a, <16 x i32> %b) { ; CHECK-LABEL: @reduceshuffle_twoin_lowelt_v16i32( -; CHECK-NEXT: [[S:%.*]] = shufflevector <16 x i32> [[A:%.*]], <16 x i32> [[B:%.*]], <16 x i32> +; CHECK-NEXT: [[S:%.*]] = shufflevector <16 x i32> [[A:%.*]], <16 x i32> [[B:%.*]], <16 x i32> ; CHECK-NEXT: [[X:%.*]] = xor <16 x i32> [[S]], ; CHECK-NEXT: [[R:%.*]] = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[X]]) ; CHECK-NEXT: ret i32 [[R]] @@ -252,7 +252,7 @@ define i32 @reduceshuffle_twoin_uneven_v16i32(<16 x i32> %a, <16 x i32> %b) { ; CHECK-LABEL: @reduceshuffle_twoin_uneven_v16i32( -; CHECK-NEXT: [[S:%.*]] = shufflevector <16 x i32> [[A:%.*]], <16 x i32> [[B:%.*]], <16 x i32> +; CHECK-NEXT: [[S:%.*]] = shufflevector <16 x i32> [[A:%.*]], <16 x i32> [[B:%.*]], <16 x i32> ; CHECK-NEXT: [[X:%.*]] = xor <16 x i32> [[S]], ; CHECK-NEXT: [[R:%.*]] = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[X]]) ; CHECK-NEXT: ret i32 [[R]] @@ -265,7 +265,7 @@ define i32 @reduceshuffle_twoin_ext_v16i32(<16 x i32> %a, <16 x i32> %b) { ; CHECK-LABEL: @reduceshuffle_twoin_ext_v16i32( -; CHECK-NEXT: [[S:%.*]] = shufflevector <16 x i32> [[A:%.*]], <16 x i32> [[B:%.*]], <16 x i32> +; 
CHECK-NEXT: [[S:%.*]] = shufflevector <16 x i32> [[A:%.*]], <16 x i32> [[B:%.*]], <16 x i32> ; CHECK-NEXT: [[A1:%.*]] = lshr <16 x i32> [[S]], ; CHECK-NEXT: [[A2:%.*]] = and <16 x i32> [[A1]], ; CHECK-NEXT: [[A3:%.*]] = mul nuw <16 x i32> [[A2]], @@ -299,7 +299,7 @@ define i16 @reduceshuffle_onein_v16i16(<16 x i16> %a) { ; CHECK-LABEL: @reduceshuffle_onein_v16i16( -; CHECK-NEXT: [[X:%.*]] = shufflevector <16 x i16> [[A:%.*]], <16 x i16> undef, <16 x i32> +; CHECK-NEXT: [[X:%.*]] = shufflevector <16 x i16> [[A:%.*]], <16 x i16> undef, <16 x i32> ; CHECK-NEXT: [[R:%.*]] = call i16 @llvm.vector.reduce.add.v16i16(<16 x i16> [[X]]) ; CHECK-NEXT: ret i16 [[R]] ; @@ -310,7 +310,7 @@ define i16 @reduceshuffle_onein_ext_v16i16(<16 x i16> %a) { ; CHECK-LABEL: @reduceshuffle_onein_ext_v16i16( -; CHECK-NEXT: [[S:%.*]] = shufflevector <16 x i16> [[A:%.*]], <16 x i16> undef, <16 x i32> +; CHECK-NEXT: [[S:%.*]] = shufflevector <16 x i16> [[A:%.*]], <16 x i16> undef, <16 x i32> ; CHECK-NEXT: [[X:%.*]] = xor <16 x i16> [[S]], ; CHECK-NEXT: [[R:%.*]] = call i16 @llvm.vector.reduce.add.v16i16(<16 x i16> [[X]]) ; CHECK-NEXT: ret i16 [[R]] @@ -323,7 +323,7 @@ define i16 @reduceshuffle_twoin_concat_v16i16(<8 x i16> %a, <8 x i16> %b) { ; CHECK-LABEL: @reduceshuffle_twoin_concat_v16i16( -; CHECK-NEXT: [[S:%.*]] = shufflevector <8 x i16> [[A:%.*]], <8 x i16> [[B:%.*]], <16 x i32> +; CHECK-NEXT: [[S:%.*]] = shufflevector <8 x i16> [[A:%.*]], <8 x i16> [[B:%.*]], <16 x i32> ; CHECK-NEXT: [[X:%.*]] = xor <16 x i16> [[S]], ; CHECK-NEXT: [[R:%.*]] = call i16 @llvm.vector.reduce.add.v16i16(<16 x i16> [[X]]) ; CHECK-NEXT: ret i16 [[R]] @@ -336,7 +336,7 @@ define i16 @reduceshuffle_twoin_lowelt_v16i16(<16 x i16> %a, <16 x i16> %b) { ; CHECK-LABEL: @reduceshuffle_twoin_lowelt_v16i16( -; CHECK-NEXT: [[S:%.*]] = shufflevector <16 x i16> [[A:%.*]], <16 x i16> [[B:%.*]], <16 x i32> +; CHECK-NEXT: [[S:%.*]] = shufflevector <16 x i16> [[A:%.*]], <16 x i16> [[B:%.*]], <16 x i32> ; CHECK-NEXT: [[X:%.*]] = 
xor <16 x i16> [[S]], ; CHECK-NEXT: [[R:%.*]] = call i16 @llvm.vector.reduce.add.v16i16(<16 x i16> [[X]]) ; CHECK-NEXT: ret i16 [[R]] @@ -375,7 +375,7 @@ define i16 @reduceshuffle_twoin_ext_v16i16(<16 x i16> %a, <16 x i16> %b) { ; CHECK-LABEL: @reduceshuffle_twoin_ext_v16i16( -; CHECK-NEXT: [[S:%.*]] = shufflevector <16 x i16> [[A:%.*]], <16 x i16> [[B:%.*]], <16 x i32> +; CHECK-NEXT: [[S:%.*]] = shufflevector <16 x i16> [[A:%.*]], <16 x i16> [[B:%.*]], <16 x i32> ; CHECK-NEXT: [[A1:%.*]] = lshr <16 x i16> [[S]], ; CHECK-NEXT: [[A2:%.*]] = and <16 x i16> [[A1]], ; CHECK-NEXT: [[A3:%.*]] = mul nuw <16 x i16> [[A2]],