Index: lib/Transforms/InstCombine/InstCombineInternal.h
===================================================================
--- lib/Transforms/InstCombine/InstCombineInternal.h
+++ lib/Transforms/InstCombine/InstCombineInternal.h
@@ -812,7 +812,8 @@
                                                int DmaskIdx = -1);
 
   Value *SimplifyDemandedVectorElts(Value *V, APInt DemandedElts,
-                                    APInt &UndefElts, unsigned Depth = 0);
+                                    APInt &UndefElts, unsigned Depth = 0,
+                                    bool AllowMultipleUsers = false);
 
   /// Canonicalize the position of binops relative to shufflevector.
   Instruction *foldVectorBinop(BinaryOperator &Inst);
Index: lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
===================================================================
--- lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
+++ lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
@@ -1067,16 +1067,24 @@
 }
 
 /// The specified value produces a vector with any number of elements.
+/// This method analyzes which elements of the operand are undef and returns
+/// that information in UndefElts.
+///
 /// DemandedElts contains the set of elements that are actually used by the
-/// caller. This method analyzes which elements of the operand are undef and
-/// returns that information in UndefElts.
+/// caller. By default (AllowMultipleUsers == false), if V has more than one
+/// user, every element of V is conservatively treated as demanded when V is
+/// the root (Depth == 0), and no simplification is attempted for a deeper
+/// value with more than one user. If AllowMultipleUsers is true, the caller
+/// guarantees that DemandedElts is the union of the elements demanded by all
+/// users of V, so V is simplified using DemandedElts as given.
 ///
 /// If the information about demanded elements can be used to simplify the
 /// operation, the operation is simplified, then the resultant value is
 /// returned. This returns null if no change was made.
 Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts,
                                                 APInt &UndefElts,
-                                                unsigned Depth) {
+                                                unsigned Depth,
+                                                bool AllowMultipleUsers) {
   unsigned VWidth = V->getType()->getVectorNumElements();
   APInt EltMask(APInt::getAllOnesValue(VWidth));
   assert((DemandedElts & ~EltMask) == 0 && "Invalid DemandedElts!");
@@ -1130,19 +1138,21 @@
   if (Depth == 10)
     return nullptr;
 
-  // If multiple users are using the root value, proceed with
-  // simplification conservatively assuming that all elements
-  // are needed.
-  if (!V->hasOneUse()) {
-    // Quit if we find multiple users of a non-root value though.
-    // They'll be handled when it's their turn to be visited by
-    // the main instcombine process.
-    if (Depth != 0)
-      // TODO: Just compute the UndefElts information recursively.
-      return nullptr;
+  if (!AllowMultipleUsers) {
+    // If multiple users are using the root value, proceed with
+    // simplification conservatively assuming that all elements
+    // are needed.
+    if (!V->hasOneUse()) {
+      // Quit if we find multiple users of a non-root value though.
+      // They'll be handled when it's their turn to be visited by
+      // the main instcombine process.
+      if (Depth != 0)
+        // TODO: Just compute the UndefElts information recursively.
+        return nullptr;
 
-    // Conservatively assume that all elements are needed.
-    DemandedElts = EltMask;
+      // Conservatively assume that all elements are needed.
+      DemandedElts = EltMask;
+    }
   }
 
   Instruction *I = dyn_cast<Instruction>(V);
Index: lib/Transforms/InstCombine/InstCombineVectorOps.cpp
===================================================================
--- lib/Transforms/InstCombine/InstCombineVectorOps.cpp
+++ lib/Transforms/InstCombine/InstCombineVectorOps.cpp
@@ -270,20 +270,53 @@
     if (!IndexC->getValue().ule(NumElts))
       return nullptr;
 
-    // This instruction only demands the single element from the input vector.
-    // If the input vector has a single use, simplify it based on this use
-    // property.
-    if (SrcVec->hasOneUse() && NumElts != 1) {
-      APInt UndefElts(NumElts, 0);
-      APInt DemandedElts(NumElts, 0);
-      DemandedElts.setBit(IndexC->getZExtValue());
-      if (Value *V = SimplifyDemandedVectorElts(SrcVec, DemandedElts,
-                                                UndefElts)) {
-        EI.setOperand(0, V);
-        return &EI;
+    if (NumElts != 1) {
+      // This instruction only demands the single element from the input
+      // vector. If the input vector has a single use, simplify it based on
+      // this use property.
+      if (SrcVec->hasOneUse()) {
+        APInt UndefElts(NumElts, 0);
+        APInt DemandedElts(NumElts, 0);
+        DemandedElts.setBit(IndexC->getZExtValue());
+        if (Value *V = SimplifyDemandedVectorElts(SrcVec, DemandedElts,
+                                                  UndefElts)) {
+          EI.setOperand(0, V);
+          return &EI;
+        }
+      } else {
+        // If the input vector has multiple uses, and all uses come from
+        // extractelement instructions with in-range constant indices,
+        // simplify it based on the union of all elements used.
+        APInt DemandedElts(NumElts, 0);
+        bool AllUsesAreValidConstEEI = true;
+        for (const Use &U : SrcVec->uses()) {
+          auto *EEI = dyn_cast<ExtractElementInst>(U.getUser());
+          auto *EEIIndexC =
+              EEI ? dyn_cast<ConstantInt>(EEI->getIndexOperand()) : nullptr;
+
+          // Require the index to be strictly less than NumElts: an index of
+          // NumElts or more names no element, and setBit() on the
+          // NumElts-bit DemandedElts requires an in-range bit position.
+          if (!EEIIndexC || !EEIIndexC->getValue().ult(NumElts)) {
+            AllUsesAreValidConstEEI = false;
+            break;
+          }
+          DemandedElts.setBit(EEIIndexC->getZExtValue());
+        }
+        if (AllUsesAreValidConstEEI) {
+          APInt UndefElts(NumElts, 0);
+          if (Value *V = SimplifyDemandedVectorElts(
+                  SrcVec, DemandedElts, UndefElts, 0 /* Depth */,
+                  true /* AllowMultipleUsers */)) {
+            // A non-null result that is SrcVec itself means SrcVec was
+            // simplified in place; otherwise redirect all of its users.
+            if (V != SrcVec)
+              SrcVec->replaceAllUsesWith(V);
+            return &EI;
+          }
+        }
       }
     }
-
     if (Instruction *I = foldBitcastExtElt(EI, Builder, DL.isBigEndian()))
       return I;
 
Index: test/Transforms/InstCombine/AMDGPU/amdgcn-demanded-vector-elts.ll
===================================================================
--- test/Transforms/InstCombine/AMDGPU/amdgcn-demanded-vector-elts.ll
+++ test/Transforms/InstCombine/AMDGPU/amdgcn-demanded-vector-elts.ll
@@ -152,11 +152,10 @@
   ret <3 x float> %shuf
 }
 
-; FIXME: Not handled even though only 2 elts used
 ; CHECK-LABEL: @extract_elt0_elt1_buffer_load_v4f32_2(
-; CHECK-NEXT: %data = call <4 x float> @llvm.amdgcn.buffer.load.v4f32(<4 x i32> %rsrc, i32 %idx, i32 %ofs, i1 false, i1 false)
-; CHECK-NEXT: %elt0 = extractelement <4 x float> %data, i32 0
-; CHECK-NEXT: %elt1 = extractelement <4 x float> %data, i32 1
+; CHECK-NEXT: %data = call <2 x float> @llvm.amdgcn.buffer.load.v2f32(<4 x i32> %rsrc, i32 %idx, i32 %ofs, i1 false, i1 false)
+; CHECK-NEXT: %elt0 = extractelement <2 x float> %data, i32 0
+; CHECK-NEXT: %elt1 = extractelement <2 x float> %data, i32 1
 ; CHECK-NEXT: %ins0 = insertvalue { float, float } undef, float %elt0, 0
 ; CHECK-NEXT: %ins1 = insertvalue { float, float } %ins0, float %elt1, 1
 ; CHECK-NEXT: ret { float, float } %ins1