diff --git a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
--- a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
@@ -366,7 +366,13 @@
   APInt UnionUsedElts(VWidth, 0);
   for (const Use &U : V->uses()) {
     if (Instruction *I = dyn_cast<Instruction>(U.getUser())) {
-      UnionUsedElts |= findDemandedEltsBySingleUser(V, I);
+      // A bitcast between vectors with the same element count maps each lane
+      // to the same lane, so the lanes its users demand are demanded from V.
+      if (isa<BitCastInst>(I) && isa<FixedVectorType>(I->getType()) &&
+          VWidth == cast<FixedVectorType>(I->getType())->getNumElements())
+        UnionUsedElts |= findDemandedEltsByAllUsers(I);
+      else
+        UnionUsedElts |= findDemandedEltsBySingleUser(V, I);
     } else {
       UnionUsedElts = APInt::getAllOnes(VWidth);
       break;
@@ -562,24 +568,36 @@
     if (!EC.isScalable() && NumElts != 1) {
       // If the input vector has a single use, simplify it based on this use
       // property.
-      if (SrcVec->hasOneUse()) {
+      auto *SimplifyVec = SrcVec;
+      Value *BitCastSrc = nullptr;
+      // Look one step further through a bitcast between vectors with the same
+      // element count. The demanded elements of the bitcast source must then
+      // be collected from all of its users, so the single-use path is skipped.
+      if (match(SrcVec, m_BitCast(m_Value(BitCastSrc))) &&
+          isa<FixedVectorType>(BitCastSrc->getType()) &&
+          cast<FixedVectorType>(BitCastSrc->getType())->getNumElements() ==
+              NumElts) {
+        SimplifyVec = BitCastSrc;
+      }
+
+      if (SimplifyVec == SrcVec && SrcVec->hasOneUse()) {
         APInt UndefElts(NumElts, 0);
         APInt DemandedElts(NumElts, 0);
         DemandedElts.setBit(IndexC->getZExtValue());
-        if (Value *V =
-                SimplifyDemandedVectorElts(SrcVec, DemandedElts, UndefElts))
+        if (Value *V = SimplifyDemandedVectorElts(SimplifyVec, DemandedElts,
+                                                  UndefElts))
           return replaceOperand(EI, 0, V);
       } else {
         // If the input vector has multiple uses, simplify it based on a union
         // of all elements used.
-        APInt DemandedElts = findDemandedEltsByAllUsers(SrcVec);
+        APInt DemandedElts = findDemandedEltsByAllUsers(SimplifyVec);
         if (!DemandedElts.isAllOnes()) {
           APInt UndefElts(NumElts, 0);
           if (Value *V = SimplifyDemandedVectorElts(
-                  SrcVec, DemandedElts, UndefElts, 0 /* Depth */,
+                  SimplifyVec, DemandedElts, UndefElts, 0 /* Depth */,
                   true /* AllowMultipleUsers */)) {
-            if (V != SrcVec) {
-              SrcVec->replaceAllUsesWith(V);
+            if (V != SimplifyVec) {
+              SimplifyVec->replaceAllUsesWith(V);
               return &EI;
             }
           }
diff --git a/llvm/test/Transforms/InstCombine/AMDGPU/demanded-vector-elts-multi-user.ll b/llvm/test/Transforms/InstCombine/AMDGPU/demanded-vector-elts-multi-user.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/AMDGPU/demanded-vector-elts-multi-user.ll
@@ -0,0 +1,85 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -S -passes=instcombine -mtriple=amdgcn-amd-amdhsa %s | FileCheck %s
+
+declare <4 x i32> @llvm.amdgcn.raw.buffer.load.v4i32(<4 x i32>, i32, i32, i32) #1
+
+; extractelem (bitcast x) + extractelem
+define float @extract_bitcast_and_extract(<4 x i32> inreg %rsrc, i32 %ofs, i32 %sofs) #0 {
+; CHECK-LABEL: @extract_bitcast_and_extract(
+; CHECK-NEXT:    [[VAR:%.*]] = call <2 x i32> @llvm.amdgcn.raw.buffer.load.v2i32(<4 x i32> [[RSRC:%.*]], i32 [[OFS:%.*]], i32 [[SOFS:%.*]], i32 0)
+; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <2 x i32> [[VAR]], <2 x i32> poison, <4 x i32>
+; CHECK-NEXT:    [[VAR1:%.*]] = bitcast <4 x i32> [[TMP1]] to <4 x float>
+; CHECK-NEXT:    [[VAR2:%.*]] = extractelement <4 x float> [[VAR1]], i64 0
+; CHECK-NEXT:    [[VAR3:%.*]] = extractelement <2 x i32> [[VAR]], i64 1
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i32 [[VAR3]], -1
+; CHECK-NEXT:    [[VAR4:%.*]] = select i1 [[CMP]], float [[VAR2]], float 1.000000e+00
+; CHECK-NEXT:    ret float [[VAR4]]
+;
+  %var = call <4 x i32> @llvm.amdgcn.raw.buffer.load.v4i32(<4 x i32> %rsrc, i32 %ofs, i32 %sofs, i32 0)
+  %var1 = bitcast <4 x i32> %var to <4 x float>
+  %var2 = extractelement <4 x float> %var1, i32 0
+  %var3 = extractelement <4 x i32> %var, i32 1
+  %cmp = icmp eq i32 %var3, -1
+  %var4 = select i1 %cmp, float %var2, float 1.0
+  ret float %var4
+}
+
+; multiple extractelem (bitcast x)
+define float @multi_extract_bitcast(<4 x i32> inreg %rsrc, i32 %ofs, i32 %sofs, i32 %flag) #0 {
+; CHECK-LABEL: @multi_extract_bitcast(
+; CHECK-NEXT:  bb0:
+; CHECK-NEXT:    [[VAR:%.*]] = call <2 x i32> @llvm.amdgcn.raw.buffer.load.v2i32(<4 x i32> [[RSRC:%.*]], i32 [[OFS:%.*]], i32 [[SOFS:%.*]], i32 0)
+; CHECK-NEXT:    [[TMP0:%.*]] = shufflevector <2 x i32> [[VAR]], <2 x i32> poison, <4 x i32>
+; CHECK-NEXT:    [[VAR1:%.*]] = bitcast <4 x i32> [[TMP0]] to <4 x float>
+; CHECK-NEXT:    [[VAR2:%.*]] = extractelement <4 x float> [[VAR1]], i64 0
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i32 [[FLAG:%.*]], -1
+; CHECK-NEXT:    br i1 [[CMP]], label [[BB2:%.*]], label [[BB1:%.*]]
+; CHECK:       bb1:
+; CHECK-NEXT:    [[VAR10:%.*]] = bitcast <4 x i32> [[TMP0]] to <4 x float>
+; CHECK-NEXT:    [[VAR11:%.*]] = extractelement <4 x float> [[VAR10]], i64 1
+; CHECK-NEXT:    br label [[BB2]]
+; CHECK:       bb2:
+; CHECK-NEXT:    [[VAR20:%.*]] = phi float [ [[VAR2]], [[BB0:%.*]] ], [ [[VAR11]], [[BB1]] ]
+; CHECK-NEXT:    ret float [[VAR20]]
+;
+bb0:
+  %var = call <4 x i32> @llvm.amdgcn.raw.buffer.load.v4i32(<4 x i32> %rsrc, i32 %ofs, i32 %sofs, i32 0)
+  %var1 = bitcast <4 x i32> %var to <4 x float>
+  %var2 = extractelement <4 x float> %var1, i32 0
+  %cmp = icmp eq i32 %flag, -1
+  br i1 %cmp, label %bb2, label %bb1
+
+bb1:
+  %var10 = bitcast <4 x i32> %var to <4 x float>
+  %var11 = extractelement <4 x float> %var10, i32 1
+  br label %bb2
+
+bb2:
+  %var20 = phi float [ %var2, %bb0 ], [ %var11, %bb1 ]
+  ret float %var20
+}
+
+; extractelem (bitcast x) + shufflevector
+define float @extract_bitcast_and_shufflevector(<4 x i32> inreg %rsrc, i32 %ofs, i32 %sofs) #0 {
+; CHECK-LABEL: @extract_bitcast_and_shufflevector(
+; CHECK-NEXT:    [[VAR:%.*]] = call <2 x i32> @llvm.amdgcn.raw.buffer.load.v2i32(<4 x i32> [[RSRC:%.*]], i32 [[OFS:%.*]], i32 [[SOFS:%.*]], i32 0)
+; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <2 x i32> [[VAR]], <2 x i32> poison, <4 x i32>
+; CHECK-NEXT:    [[VAR1:%.*]] = bitcast <4 x i32> [[TMP1]] to <4 x float>
+; CHECK-NEXT:    [[VAR2:%.*]] = extractelement <4 x float> [[VAR1]], i64 0
+; CHECK-NEXT:    [[VAR4:%.*]] = extractelement <2 x i32> [[VAR]], i64 1
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i32 [[VAR4]], -1
+; CHECK-NEXT:    [[VAR5:%.*]] = select i1 [[CMP]], float [[VAR2]], float 1.000000e+00
+; CHECK-NEXT:    ret float [[VAR5]]
+;
+  %var = call <4 x i32> @llvm.amdgcn.raw.buffer.load.v4i32(<4 x i32> %rsrc, i32 %ofs, i32 %sofs, i32 0)
+  %var1 = bitcast <4 x i32> %var to <4 x float>
+  %var2 = extractelement <4 x float> %var1, i32 0
+  %var3 = shufflevector <4 x i32> %var, <4 x i32> poison, <4 x i32>
+  %var4 = extractelement <4 x i32> %var3, i32 0
+  %cmp = icmp eq i32 %var4, -1
+  %var5 = select i1 %cmp, float %var2, float 1.0
+  ret float %var5
+}
+
+attributes #0 = { nounwind }
+attributes #1 = { nounwind readonly }