Index: llvm/lib/Transforms/Vectorize/VectorCombine.cpp
===================================================================
--- llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -17,6 +17,7 @@
 #include "llvm/Analysis/GlobalsModRef.h"
 #include "llvm/Analysis/TargetTransformInfo.h"
 #include "llvm/Analysis/ValueTracking.h"
+#include "llvm/Analysis/VectorUtils.h"
 #include "llvm/IR/Dominators.h"
 #include "llvm/IR/Function.h"
 #include "llvm/IR/IRBuilder.h"
@@ -244,6 +245,74 @@
   return true;
 }
 
+/// Try to move a bitcast ahead of a single-source shuffle:
+///   bitcast (shuf V, undef, MaskC) --> shuf (bitcast V), undef, MaskC'
+/// Only bitcasts that keep or increase the number of vector elements are
+/// handled, and only when the new shuffle is not more expensive than the old
+/// one. Returns true if the IR changed; \p I itself is not erased here, only
+/// its uses are replaced.
+static bool foldBitcastShuf(Instruction &I, const TargetTransformInfo &TTI) {
+  Value *V;
+  Constant *MaskC;
+  if (!match(&I, m_BitCast(m_OneUse(m_ShuffleVector(m_Value(V), m_Undef(),
+                                                     m_Constant(MaskC))))))
+    return false;
+
+  // Bail out on non-vector destinations and on length-changing shuffles: the
+  // shuffle result must have the same type as its source vector.
+  Type *DestTy = I.getType();
+  Type *SrcTy = V->getType();
+  if (!DestTy->isVectorTy() || I.getOperand(0)->getType() != SrcTy)
+    return false;
+
+  // TODO: Handle bitcast from narrow element type to wide element type.
+  assert(SrcTy->isVectorTy() && "Shuffle of non-vector type?");
+  unsigned DestNumElts = DestTy->getVectorNumElements();
+  unsigned SrcNumElts = SrcTy->getVectorNumElements();
+  if (SrcNumElts > DestNumElts)
+    return false;
+
+  // Each source element must split into a whole number of destination
+  // elements. Equal total bit width does not guarantee that for odd element
+  // sizes (e.g. bitcast <3 x i32> to <4 x i24>), so bail out instead of
+  // asserting.
+  if (DestNumElts % SrcNumElts != 0)
+    return false;
+
+  // The new shuffle must not cost more than the old shuffle. The bitcast is
+  // moved ahead of the shuffle, so assume that it costs the same as before.
+  if (TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, DestTy) >
+      TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, SrcTy))
+    return false;
+
+  // Bitcast the source vector and expand the shuffle mask to the equivalent for
+  // narrow elements.
+  // bitcast (shuf V, MaskC) --> shuf (bitcast V), MaskC'
+  IRBuilder<> Builder(&I);
+  Value *CastV = Builder.CreateBitCast(V, DestTy);
+  SmallVector<int, 16> OldMask, NewMask;
+  ShuffleVectorInst::getShuffleMask(MaskC, OldMask);
+  unsigned ScaleFactor = DestNumElts / SrcNumElts;
+  scaleShuffleMask(ScaleFactor, makeArrayRef(OldMask), NewMask);
+
+  // TODO: IRBuilder has a CreateShuffleVector that takes an array of uint32_t,
+  // but does not recognize plain 'int'.
+  // getShuffleMask reports undef lanes as -1 and scaleShuffleMask propagates
+  // them; they must become undef mask elements, not an out-of-range i32 index.
+  Type *Int32Ty = Builder.getInt32Ty();
+  SmallVector<Constant *, 16> NewMaskC;
+  for (int MaskElt : NewMask) {
+    if (MaskElt < 0)
+      NewMaskC.push_back(UndefValue::get(Int32Ty));
+    else
+      NewMaskC.push_back(Builder.getInt32(MaskElt));
+  }
+  Value *Shuf = Builder.CreateShuffleVector(CastV, UndefValue::get(DestTy),
+                                            ConstantVector::get(NewMaskC));
+  I.replaceAllUsesWith(Shuf);
+  return true;
+}
+
 /// This is the entry point for all transforms. Pass manager differences are
 /// handled in the callers of this function.
 static bool runImpl(Function &F, const TargetTransformInfo &TTI,
@@ -261,8 +330,10 @@
     // use->defs, so we're more likely to succeed by starting from the bottom.
     // TODO: It could be more efficient to remove dead instructions
     // iteratively in this loop rather than waiting until the end.
-    for (Instruction &I : make_range(BB.rbegin(), BB.rend()))
+    for (Instruction &I : make_range(BB.rbegin(), BB.rend())) {
       MadeChange |= foldExtractExtract(I, TTI);
+      MadeChange |= foldBitcastShuf(I, TTI);
+    }
   }
 
   // We're done with transforms, so remove dead instructions.
NOTE(review): In this copy of the patch the constant-vector operands that
follow vector types appear to have been stripped. Examples are the
shufflevector masks after "<4 x i32>" / "<16 x i32>" / "<8 x i32>" and the
shift amounts after "[[BC1]]," / "[[TMP2]],". Restore them from the original
patch before applying or running FileCheck.
NOTE(review): Whether foldBitcastShuf fires on these functions depends on the
target's SK_PermuteSingleSrc shuffle costs for the wide vs. narrow vector
types under this file's RUN line(s), which are not part of this diff. If the
RUN lines cover more than one -mattr level, the expected output may need
per-target check prefixes; regenerate with update_test_checks.py to confirm.
Index: llvm/test/Transforms/VectorCombine/X86/shuffle.ll
===================================================================
--- llvm/test/Transforms/VectorCombine/X86/shuffle.ll
+++ llvm/test/Transforms/VectorCombine/X86/shuffle.ll
@@ -4,9 +4,9 @@
 
 define <16 x i8> @bitcast_shuf_narrow_element(<4 x i32> %v) {
 ; CHECK-LABEL: @bitcast_shuf_narrow_element(
-; CHECK-NEXT:    [[SHUF:%.*]] = shufflevector <4 x i32> [[V:%.*]], <4 x i32> undef, <4 x i32>
-; CHECK-NEXT:    [[R:%.*]] = bitcast <4 x i32> [[SHUF]] to <16 x i8>
-; CHECK-NEXT:    ret <16 x i8> [[R]]
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <4 x i32> [[V:%.*]] to <16 x i8>
+; CHECK-NEXT:    [[TMP2:%.*]] = shufflevector <16 x i8> [[TMP1]], <16 x i8> undef, <16 x i32>
+; CHECK-NEXT:    ret <16 x i8> [[TMP2]]
 ;
   %shuf = shufflevector <4 x i32> %v, <4 x i32> undef, <4 x i32>
   %r = bitcast <4 x i32> %shuf to <16 x i8>
@@ -15,9 +15,9 @@
 
 define <4 x float> @bitcast_shuf_same_size(<4 x i32> %v) {
 ; CHECK-LABEL: @bitcast_shuf_same_size(
-; CHECK-NEXT:    [[SHUF:%.*]] = shufflevector <4 x i32> [[V:%.*]], <4 x i32> undef, <4 x i32>
-; CHECK-NEXT:    [[R:%.*]] = bitcast <4 x i32> [[SHUF]] to <4 x float>
-; CHECK-NEXT:    ret <4 x float> [[R]]
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <4 x i32> [[V:%.*]] to <4 x float>
+; CHECK-NEXT:    [[TMP2:%.*]] = shufflevector <4 x float> [[TMP1]], <4 x float> undef, <4 x i32>
+; CHECK-NEXT:    ret <4 x float> [[TMP2]]
 ;
   %shuf = shufflevector <4 x i32> %v, <4 x i32> undef, <4 x i32>
   %r = bitcast <4 x i32> %shuf to <4 x float>
@@ -53,9 +53,9 @@
 define <2 x i64> @PR35454_1(<2 x i64> %v) {
 ; CHECK-LABEL: @PR35454_1(
 ; CHECK-NEXT:    [[BC:%.*]] = bitcast <2 x i64> [[V:%.*]] to <4 x i32>
-; CHECK-NEXT:    [[PERMIL:%.*]] = shufflevector <4 x i32> [[BC]], <4 x i32> undef, <4 x i32>
-; CHECK-NEXT:    [[BC1:%.*]] = bitcast <4 x i32> [[PERMIL]] to <16 x i8>
-; CHECK-NEXT:    [[ADD:%.*]] = shl <16 x i8> [[BC1]],
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <4 x i32> [[BC]] to <16 x i8>
+; CHECK-NEXT:    [[TMP2:%.*]] = shufflevector <16 x i8> [[TMP1]], <16 x i8> undef, <16 x i32>
+; CHECK-NEXT:    [[ADD:%.*]] = shl <16 x i8> [[TMP2]],
 ; CHECK-NEXT:    [[BC2:%.*]] = bitcast <16 x i8> [[ADD]] to <4 x i32>
 ; CHECK-NEXT:    [[PERMIL1:%.*]] = shufflevector <4 x i32> [[BC2]], <4 x i32> undef, <4 x i32>
 ; CHECK-NEXT:    [[BC3:%.*]] = bitcast <4 x i32> [[PERMIL1]] to <2 x i64>
@@ -74,9 +74,9 @@
 define <2 x i64> @PR35454_2(<2 x i64> %v) {
 ; CHECK-LABEL: @PR35454_2(
 ; CHECK-NEXT:    [[BC:%.*]] = bitcast <2 x i64> [[V:%.*]] to <4 x i32>
-; CHECK-NEXT:    [[PERMIL:%.*]] = shufflevector <4 x i32> [[BC]], <4 x i32> undef, <4 x i32>
-; CHECK-NEXT:    [[BC1:%.*]] = bitcast <4 x i32> [[PERMIL]] to <8 x i16>
-; CHECK-NEXT:    [[ADD:%.*]] = shl <8 x i16> [[BC1]],
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <4 x i32> [[BC]] to <8 x i16>
+; CHECK-NEXT:    [[TMP2:%.*]] = shufflevector <8 x i16> [[TMP1]], <8 x i16> undef, <8 x i32>
+; CHECK-NEXT:    [[ADD:%.*]] = shl <8 x i16> [[TMP2]],
 ; CHECK-NEXT:    [[BC2:%.*]] = bitcast <8 x i16> [[ADD]] to <4 x i32>
 ; CHECK-NEXT:    [[PERMIL1:%.*]] = shufflevector <4 x i32> [[BC2]], <4 x i32> undef, <4 x i32>
 ; CHECK-NEXT:    [[BC3:%.*]] = bitcast <4 x i32> [[PERMIL1]] to <2 x i64>