diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h --- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h +++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h @@ -858,7 +858,8 @@ int DmaskIdx = -1); Value *SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, - APInt &UndefElts, unsigned Depth = 0); + APInt &UndefElts, unsigned Depth = 0, + bool AllowMultipleUsers = false); /// Canonicalize the position of binops relative to shufflevector. Instruction *foldVectorBinop(BinaryOperator &Inst); diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp --- a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp @@ -1074,16 +1074,22 @@ } /// The specified value produces a vector with any number of elements. +/// This method analyzes which elements of the operand are undef and returns +/// that information in UndefElts. +/// /// DemandedElts contains the set of elements that are actually used by the -/// caller. This method analyzes which elements of the operand are undef and -/// returns that information in UndefElts. +/// caller. By default (AllowMultipleUsers is false), if V has several users, +/// a root call (Depth 0) treats every element as demanded and a recursive +/// call gives up. If AllowMultipleUsers is true, DemandedElts must already +/// be the union of the elements demanded by every user of V. /// /// If the information about demanded elements can be used to simplify the /// operation, the operation is simplified, then the resultant value is /// returned. This returns null if no change was made.
Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, APInt &UndefElts, - unsigned Depth) { + unsigned Depth, + bool AllowMultipleUsers) { unsigned VWidth = V->getType()->getVectorNumElements(); APInt EltMask(APInt::getAllOnesValue(VWidth)); assert((DemandedElts & ~EltMask) == 0 && "Invalid DemandedElts!"); @@ -1137,19 +1143,21 @@ if (Depth == 10) return nullptr; - // If multiple users are using the root value, proceed with - // simplification conservatively assuming that all elements - // are needed. - if (!V->hasOneUse()) { - // Quit if we find multiple users of a non-root value though. - // They'll be handled when it's their turn to be visited by - // the main instcombine process. - if (Depth != 0) - // TODO: Just compute the UndefElts information recursively. - return nullptr; + if (!AllowMultipleUsers) { + // If multiple users are using the root value, proceed with + // simplification conservatively assuming that all elements + // are needed. + if (!V->hasOneUse()) { + // Quit if we find multiple users of a non-root value though. + // They'll be handled when it's their turn to be visited by + // the main instcombine process. + if (Depth != 0) + // TODO: Just compute the UndefElts information recursively. + return nullptr; - // Conservatively assume that all elements are needed. - DemandedElts = EltMask; + // Conservatively assume that all elements are needed. + DemandedElts = EltMask; + } } Instruction *I = dyn_cast(V); diff --git a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp --- a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp @@ -253,6 +253,69 @@ return nullptr; } +/// Find the elements of V demanded by UserInstr; all if it can't be analyzed.
+static APInt findDemandedEltsBySingleUser(Value *V, Instruction *UserInstr) { + unsigned VWidth = V->getType()->getVectorNumElements(); + + // Conservatively assume all elements are needed; cases below may narrow it. + APInt UsedElts(APInt::getAllOnesValue(VWidth)); + + switch (UserInstr->getOpcode()) { + case Instruction::ExtractElement: { + ExtractElementInst *EEI = cast(UserInstr); + assert(EEI->getVectorOperand() == V); + ConstantInt *EEIIndexC = dyn_cast(EEI->getIndexOperand()); + if (EEIIndexC && EEIIndexC->getValue().ult(VWidth)) { + UsedElts = APInt::getOneBitSet(VWidth, EEIIndexC->getZExtValue()); + } + break; + } + case Instruction::ShuffleVector: { + ShuffleVectorInst *Shuffle = cast(UserInstr); + unsigned MaskNumElts = UserInstr->getType()->getVectorNumElements(); + + UsedElts = APInt(VWidth, 0); // Only lanes the mask selects are used. + for (unsigned i = 0; i < MaskNumElts; i++) { + unsigned MaskVal = Shuffle->getMaskValue(i); // -1 (undef) wraps to -1u. + if (MaskVal == -1u || MaskVal >= 2 * VWidth) + continue; + if (Shuffle->getOperand(0) == V && (MaskVal < VWidth)) + UsedElts.setBit(MaskVal); + if (Shuffle->getOperand(1) == V && + ((MaskVal >= VWidth) && (MaskVal < 2 * VWidth))) + UsedElts.setBit(MaskVal - VWidth); + } + break; + } + default: + break; + } + return UsedElts; +} + +/// Find the union of the elements of V demanded by all of its users. +/// A bit stays unset only if findDemandedEltsBySingleUser shows that no +/// user demands that element; any non-instruction user makes the result +/// all-ones.
+static APInt findDemandedEltsByAllUsers(Value *V) { + unsigned VWidth = V->getType()->getVectorNumElements(); + + APInt UnionUsedElts(VWidth, 0); + for (const Use &U : V->uses()) { + if (Instruction *I = dyn_cast(U.getUser())) { + UnionUsedElts |= findDemandedEltsBySingleUser(V, I); + } else { + UnionUsedElts = APInt::getAllOnesValue(VWidth); + break; + } + + if (UnionUsedElts.isAllOnesValue()) + break; + } + + return UnionUsedElts; +} + Instruction *InstCombiner::visitExtractElementInst(ExtractElementInst &EI) { Value *SrcVec = EI.getVectorOperand(); Value *Index = EI.getIndexOperand(); @@ -271,19 +334,36 @@ return nullptr; // This instruction only demands the single element from the input vector. - // If the input vector has a single use, simplify it based on this use - // property. - if (SrcVec->hasOneUse() && NumElts != 1) { - APInt UndefElts(NumElts, 0); - APInt DemandedElts(NumElts, 0); - DemandedElts.setBit(IndexC->getZExtValue()); - if (Value *V = SimplifyDemandedVectorElts(SrcVec, DemandedElts, - UndefElts)) { - EI.setOperand(0, V); - return &EI; + if (NumElts != 1) { + // If the input vector has a single use, simplify it based on this use + // property. + if (SrcVec->hasOneUse()) { + APInt UndefElts(NumElts, 0); + APInt DemandedElts(NumElts, 0); + DemandedElts.setBit(IndexC->getZExtValue()); + if (Value *V = + SimplifyDemandedVectorElts(SrcVec, DemandedElts, UndefElts)) { + EI.setOperand(0, V); + return &EI; + } + } else { + // If the input vector has multiple uses, simplify it based on a union + // of all elements used.
+ APInt DemandedElts = findDemandedEltsByAllUsers(SrcVec); + if (!DemandedElts.isAllOnesValue()) { + APInt UndefElts(NumElts, 0); + if (Value *V = SimplifyDemandedVectorElts( + SrcVec, DemandedElts, UndefElts, 0 /* Depth */, + true /* AllowMultipleUsers */)) { + if (V != SrcVec) + SrcVec->replaceAllUsesWith(V); + // V == SrcVec means SrcVec was simplified in place; report that + // change too, as the single-use path does via EI.setOperand. + return &EI; + } + } + } } } - if (Instruction *I = foldBitcastExtElt(EI, Builder, DL.isBigEndian())) return I; diff --git a/llvm/test/Transforms/InstCombine/AMDGPU/amdgcn-demanded-vector-elts.ll b/llvm/test/Transforms/InstCombine/AMDGPU/amdgcn-demanded-vector-elts.ll --- a/llvm/test/Transforms/InstCombine/AMDGPU/amdgcn-demanded-vector-elts.ll +++ b/llvm/test/Transforms/InstCombine/AMDGPU/amdgcn-demanded-vector-elts.ll @@ -152,11 +152,10 @@ ret <3 x float> %shuf } -; FIXME: Not handled even though only 2 elts used ; CHECK-LABEL: @extract_elt0_elt1_buffer_load_v4f32_2( -; CHECK-NEXT: %data = call <4 x float> @llvm.amdgcn.buffer.load.v4f32(<4 x i32> %rsrc, i32 %idx, i32 %ofs, i1 false, i1 false) -; CHECK-NEXT: %elt0 = extractelement <4 x float> %data, i32 0 -; CHECK-NEXT: %elt1 = extractelement <4 x float> %data, i32 1 +; CHECK-NEXT: %data = call <2 x float> @llvm.amdgcn.buffer.load.v2f32(<4 x i32> %rsrc, i32 %idx, i32 %ofs, i1 false, i1 false) +; CHECK-NEXT: %elt0 = extractelement <2 x float> %data, i32 0 +; CHECK-NEXT: %elt1 = extractelement <2 x float> %data, i32 1 ; CHECK-NEXT: %ins0 = insertvalue { float, float } undef, float %elt0, 0 ; CHECK-NEXT: %ins1 = insertvalue { float, float } %ins0, float %elt1, 1 ; CHECK-NEXT: ret { float, float } %ins1 @@ -169,6 +168,76 @@ ret { float, float } %ins1 } +; CHECK-LABEL: @extract_elt0_elt1_elt2_buffer_load_v4f32_2( +; CHECK-NEXT: %data = call <3 x float> @llvm.amdgcn.buffer.load.v3f32(<4 x i32> %rsrc, i32 %idx, i32 %ofs, i1 false, i1 false) +; CHECK-NEXT: %elt0 = extractelement <3 x float> %data, i32 0 +; CHECK-NEXT: %elt1 = extractelement <3 x float> %data, i32 1 +; CHECK-NEXT: %elt2 = extractelement <3 x
float> %data, i32 2 +; CHECK-NEXT: %ins0 = insertvalue { float, float, float } undef, float %elt0, 0 +; CHECK-NEXT: %ins1 = insertvalue { float, float, float } %ins0, float %elt1, 1 +; CHECK-NEXT: %ins2 = insertvalue { float, float, float } %ins1, float %elt2, 2 +; CHECK-NEXT: ret { float, float, float } %ins2 +define amdgpu_ps { float, float, float } @extract_elt0_elt1_elt2_buffer_load_v4f32_2(<4 x i32> inreg %rsrc, i32 %idx, i32 %ofs) #0 { + %data = call <4 x float> @llvm.amdgcn.buffer.load.v4f32(<4 x i32> %rsrc, i32 %idx, i32 %ofs, i1 false, i1 false) + %elt0 = extractelement <4 x float> %data, i32 0 + %elt1 = extractelement <4 x float> %data, i32 1 + %elt2 = extractelement <4 x float> %data, i32 2 + %ins0 = insertvalue { float, float, float } undef, float %elt0, 0 + %ins1 = insertvalue { float, float, float } %ins0, float %elt1, 1 + %ins2 = insertvalue { float, float, float } %ins1, float %elt2, 2 + ret { float, float, float } %ins2 +} + +; CHECK-LABEL: @extract_elt0_elt1_elt2_buffer_load_v4f32_3( +; CHECK-NEXT: %data = call <3 x float> @llvm.amdgcn.buffer.load.v3f32(<4 x i32> %rsrc, i32 %idx, i32 %ofs, i1 false, i1 false) +; CHECK-NEXT: %ins1 = shufflevector <3 x float> %data, <3 x float> undef, <2 x i32> +; CHECK-NEXT: %shuf = shufflevector <3 x float> %data, <3 x float> undef, <2 x i32> +; CHECK-NEXT: %ret = fadd <2 x float> %ins1, %shuf +; CHECK-NEXT: ret <2 x float> %ret +define amdgpu_ps <2 x float> @extract_elt0_elt1_elt2_buffer_load_v4f32_3(<4 x i32> inreg %rsrc, i32 %idx, i32 %ofs) #0 { + %data = call <4 x float> @llvm.amdgcn.buffer.load.v4f32(<4 x i32> %rsrc, i32 %idx, i32 %ofs, i1 false, i1 false) + %elt0 = extractelement <4 x float> %data, i32 0 + %elt2 = extractelement <4 x float> %data, i32 2 + %ins0 = insertelement <2 x float> undef, float %elt0, i32 0 + %ins1 = insertelement <2 x float> %ins0, float %elt2, i32 1 + %shuf = shufflevector <4 x float> %data, <4 x float> undef, <2 x i32> + %ret = fadd <2 x float> %ins1, %shuf + ret <2 x float> %ret +} + +; CHECK-LABEL:
@extract_elt0_elt1_elt2_buffer_load_v4f32_4( +; CHECK-NEXT: %data = call <3 x float> @llvm.amdgcn.buffer.load.v3f32(<4 x i32> %rsrc, i32 %idx, i32 %ofs, i1 false, i1 false) +; CHECK-NEXT: %ins1 = shufflevector <3 x float> %data, <3 x float> undef, <2 x i32> +; CHECK-NEXT: %shuf = shufflevector <3 x float> %data, <3 x float> undef, <2 x i32> +; CHECK-NEXT: %ret = fadd <2 x float> %ins1, %shuf +; CHECK-NEXT: ret <2 x float> %ret +define amdgpu_ps <2 x float> @extract_elt0_elt1_elt2_buffer_load_v4f32_4(<4 x i32> inreg %rsrc, i32 %idx, i32 %ofs) #0 { + %data = call <4 x float> @llvm.amdgcn.buffer.load.v4f32(<4 x i32> %rsrc, i32 %idx, i32 %ofs, i1 false, i1 false) + %elt0 = extractelement <4 x float> %data, i32 0 + %elt2 = extractelement <4 x float> %data, i32 2 + %ins0 = insertelement <2 x float> undef, float %elt0, i32 0 + %ins1 = insertelement <2 x float> %ins0, float %elt2, i32 1 + %shuf = shufflevector <4 x float> undef, <4 x float> %data, <2 x i32> + %ret = fadd <2 x float> %ins1, %shuf + ret <2 x float> %ret +} + +; CHECK-LABEL: @extract_elt0_elt1_elt2_buffer_load_v4f32_5( +; CHECK-NEXT: %data = call <3 x float> @llvm.amdgcn.buffer.load.v3f32(<4 x i32> %rsrc, i32 %idx, i32 %ofs, i1 false, i1 false) +; CHECK-NEXT: %ins1 = shufflevector <3 x float> %data, <3 x float> undef, <2 x i32> +; CHECK-NEXT: %shuf = shufflevector <3 x float> %data, <3 x float> undef, <2 x i32> +; CHECK-NEXT: %ret = fadd <2 x float> %ins1, %shuf +; CHECK-NEXT: ret <2 x float> %ret +define amdgpu_ps <2 x float> @extract_elt0_elt1_elt2_buffer_load_v4f32_5(<4 x i32> inreg %rsrc, i32 %idx, i32 %ofs) #0 { + %data = call <4 x float> @llvm.amdgcn.buffer.load.v4f32(<4 x i32> %rsrc, i32 %idx, i32 %ofs, i1 false, i1 false) + %elt2 = extractelement <4 x float> %data, i32 2 + %ins0 = insertelement <2 x float> undef, float %elt2, i32 0 + %ins1 = insertelement <2 x float> %ins0, float %elt2, i32 1 + %shuf = shufflevector <4 x float> %data, <4 x float> %data, <2 x i32> + %ret = fadd <2 x float> %ins1, %shuf + ret <2 x float> %ret +} + ;
CHECK-LABEL: @extract_elt0_buffer_load_v3f32( ; CHECK-NEXT: %data = call float @llvm.amdgcn.buffer.load.f32(<4 x i32> %rsrc, i32 %idx, i32 %ofs, i1 false, i1 false) ; CHECK-NEXT: ret float %data