diff --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h
--- a/llvm/include/llvm/IR/PatternMatch.h
+++ b/llvm/include/llvm/IR/PatternMatch.h
@@ -2109,6 +2109,14 @@
   return IntrinsicID_match(IntrID);
 }
+/// Matches MaskedLoad Intrinsic.
+template <typename Opnd0, typename Opnd1, typename Opnd2, typename Opnd3>
+inline typename m_Intrinsic_Ty<Opnd0, Opnd1, Opnd2, Opnd3>::Ty
+m_MaskedLoad(const Opnd0 &Op0, const Opnd1 &Op1, const Opnd2 &Op2,
+             const Opnd3 &Op3) {
+  return m_Intrinsic<Intrinsic::masked_load>(Op0, Op1, Op2, Op3);
+}
+
 template <Intrinsic::ID IntrID, typename T0>
 inline typename m_Intrinsic_Ty<T0>::Ty m_Intrinsic(const T0 &Op0) {
   return m_CombineAnd(m_Intrinsic<IntrID>(), m_Argument<0>(Op0));
 }
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -3218,5 +3218,36 @@
   if (Value *Fr = foldSelectWithFrozenICmp(SI, Builder))
     return replaceInstUsesWith(SI, Fr);
 
+  // select(mask, mload(,,mask,0), 0) -> mload(,,mask,0)
+  // Load inst is intentionally not checked for hasOneUse()
+  if (match(FalseVal, m_Zero()) &&
+      match(TrueVal, m_MaskedLoad(m_Value(), m_Value(), m_Specific(CondVal),
+                                  m_CombineOr(m_Undef(), m_Zero())))) {
+    auto *MaskedLoad = cast<IntrinsicInst>(TrueVal);
+    if (isa<UndefValue>(MaskedLoad->getArgOperand(3)))
+      MaskedLoad->setArgOperand(3, FalseVal /* Zero */);
+    return replaceInstUsesWith(SI, MaskedLoad);
+  }
+
+  Value *Mask;
+  if (match(TrueVal, m_Zero()) &&
+      match(FalseVal, m_MaskedLoad(m_Value(), m_Value(), m_Value(Mask),
+                                   m_CombineOr(m_Undef(), m_Zero())))) {
+    // We can remove the select by ensuring the load zeros all lanes the
+    // select would have. We determine this by proving there is no overlap
+    // between the load and select masks.
+    // (i.e. (load_mask & select_mask) == 0 == no overlap)
+    bool CanMergeSelectIntoLoad = false;
+    if (Value *V = SimplifyAndInst(CondVal, Mask, SQ.getWithInstruction(&SI)))
+      CanMergeSelectIntoLoad = match(V, m_Zero());
+
+    if (CanMergeSelectIntoLoad) {
+      auto *MaskedLoad = cast<IntrinsicInst>(FalseVal);
+      if (isa<UndefValue>(MaskedLoad->getArgOperand(3)))
+        MaskedLoad->setArgOperand(3, TrueVal /* Zero */);
+      return replaceInstUsesWith(SI, MaskedLoad);
+    }
+  }
+
   return nullptr;
 }
diff --git a/llvm/test/Transforms/InstCombine/select-masked_load.ll b/llvm/test/Transforms/InstCombine/select-masked_load.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/select-masked_load.ll
@@ -0,0 +1,98 @@
+; RUN: opt < %s -instcombine -S | FileCheck %s
+
+; Fold zeroing of inactive lanes into the load's passthrough parameter.
+define <4 x float> @masked_load_and_zero_inactive_1(<4 x float>* %ptr, <4 x i1> %mask) {
+; CHECK-LABEL: @masked_load_and_zero_inactive_1(
+; CHECK: %load = call <4 x float> @llvm.masked.load.v4f32.p0v4f32(<4 x float>* %ptr, i32 4, <4 x i1> %mask, <4 x float> zeroinitializer)
+; CHECK-NEXT: ret <4 x float> %load
+  %load = call <4 x float> @llvm.masked.load.v4f32.p0v4f32(<4 x float>* %ptr, i32 4, <4 x i1> %mask, <4 x float> undef)
+  %masked = select <4 x i1> %mask, <4 x float> %load, <4 x float> zeroinitializer
+  ret <4 x float> %masked
+}
+
+; As above but reuse the load's existing passthrough.
+define <4 x i32> @masked_load_and_zero_inactive_2(<4 x i32>* %ptr, <4 x i1> %mask) {
+; CHECK-LABEL: @masked_load_and_zero_inactive_2(
+; CHECK: %load = call <4 x i32> @llvm.masked.load.v4i32.p0v4i32(<4 x i32>* %ptr, i32 4, <4 x i1> %mask, <4 x i32> zeroinitializer)
+; CHECK-NEXT: ret <4 x i32> %load
+  %load = call <4 x i32> @llvm.masked.load.v4i32.p0v4i32(<4 x i32>* %ptr, i32 4, <4 x i1> %mask, <4 x i32> zeroinitializer)
+  %masked = select <4 x i1> %mask, <4 x i32> %load, <4 x i32> zeroinitializer
+  ret <4 x i32> %masked
+}
+
+; No transform when the load's passthrough cannot be reused or altered.
+define <4 x i32> @masked_load_and_zero_inactive_3(<4 x i32>* %ptr, <4 x i1> %mask, <4 x i32> %passthrough) {
+; CHECK-LABEL: @masked_load_and_zero_inactive_3(
+; CHECK: %load = call <4 x i32> @llvm.masked.load.v4i32.p0v4i32(<4 x i32>* %ptr, i32 4, <4 x i1> %mask, <4 x i32> %passthrough)
+; CHECK-NEXT: %masked = select <4 x i1> %mask, <4 x i32> %load, <4 x i32> zeroinitializer
+; CHECK-NEXT: ret <4 x i32> %masked
+  %load = call <4 x i32> @llvm.masked.load.v4i32.p0v4i32(<4 x i32>* %ptr, i32 4, <4 x i1> %mask, <4 x i32> %passthrough)
+  %masked = select <4 x i1> %mask, <4 x i32> %load, <4 x i32> zeroinitializer
+  ret <4 x i32> %masked
+}
+
+; Remove redundant select when its mask doesn't overlap with the load mask.
+define <4 x i32> @masked_load_and_zero_inactive_4(<4 x i32>* %ptr, <4 x i1> %inv_mask) {
+; CHECK-LABEL: @masked_load_and_zero_inactive_4(
+; CHECK: %mask = xor <4 x i1> %inv_mask, <i1 true, i1 true, i1 true, i1 true>
+; CHECK-NEXT: %load = call <4 x i32> @llvm.masked.load.v4i32.p0v4i32(<4 x i32>* %ptr, i32 4, <4 x i1> %mask, <4 x i32> zeroinitializer)
+; CHECK-NEXT: ret <4 x i32> %load
+  %mask = xor <4 x i1> %inv_mask, <i1 true, i1 true, i1 true, i1 true>
+  %load = call <4 x i32> @llvm.masked.load.v4i32.p0v4i32(<4 x i32>* %ptr, i32 4, <4 x i1> %mask, <4 x i32> undef)
+  %masked = select <4 x i1> %inv_mask, <4 x i32> zeroinitializer, <4 x i32> %load
+  ret <4 x i32> %masked
+}
+
+; As above but reuse the load's existing passthrough.
+define <4 x i32> @masked_load_and_zero_inactive_5(<4 x i32>* %ptr, <4 x i1> %inv_mask) {
+; CHECK-LABEL: @masked_load_and_zero_inactive_5(
+; CHECK: %mask = xor <4 x i1> %inv_mask, <i1 true, i1 true, i1 true, i1 true>
+; CHECK-NEXT: %load = call <4 x i32> @llvm.masked.load.v4i32.p0v4i32(<4 x i32>* %ptr, i32 4, <4 x i1> %mask, <4 x i32> zeroinitializer)
+; CHECK-NEXT: ret <4 x i32> %load
+  %mask = xor <4 x i1> %inv_mask, <i1 true, i1 true, i1 true, i1 true>
+  %load = call <4 x i32> @llvm.masked.load.v4i32.p0v4i32(<4 x i32>* %ptr, i32 4, <4 x i1> %mask, <4 x i32> zeroinitializer)
+  %masked = select <4 x i1> %inv_mask, <4 x i32> zeroinitializer, <4 x i32> %load
+  ret <4 x i32> %masked
+}
+
+; No transform when the load's passthrough cannot be reused or altered.
+define <4 x i32> @masked_load_and_zero_inactive_6(<4 x i32>* %ptr, <4 x i1> %inv_mask, <4 x i32> %passthrough) {
+; CHECK-LABEL: @masked_load_and_zero_inactive_6(
+; CHECK: %mask = xor <4 x i1> %inv_mask, <i1 true, i1 true, i1 true, i1 true>
+; CHECK-NEXT: %load = call <4 x i32> @llvm.masked.load.v4i32.p0v4i32(<4 x i32>* %ptr, i32 4, <4 x i1> %mask, <4 x i32> %passthrough)
+; CHECK-NEXT: %masked = select <4 x i1> %inv_mask, <4 x i32> zeroinitializer, <4 x i32> %load
+; CHECK-NEXT: ret <4 x i32> %masked
+  %mask = xor <4 x i1> %inv_mask, <i1 true, i1 true, i1 true, i1 true>
+  %load = call <4 x i32> @llvm.masked.load.v4i32.p0v4i32(<4 x i32>* %ptr, i32 4, <4 x i1> %mask, <4 x i32> %passthrough)
+  %masked = select <4 x i1> %inv_mask, <4 x i32> zeroinitializer, <4 x i32> %load
+  ret <4 x i32> %masked
+}
+
+; No transform when select and load masks have no relation.
+define <4 x i32> @masked_load_and_zero_inactive_7(<4 x i32>* %ptr, <4 x i1> %mask1, <4 x i1> %mask2) {
+; CHECK-LABEL: @masked_load_and_zero_inactive_7(
+; CHECK: %load = call <4 x i32> @llvm.masked.load.v4i32.p0v4i32(<4 x i32>* %ptr, i32 4, <4 x i1> %mask1, <4 x i32> zeroinitializer)
+; CHECK-NEXT: %masked = select <4 x i1> %mask2, <4 x i32> zeroinitializer, <4 x i32> %load
+; CHECK-NEXT: ret <4 x i32> %masked
+  %load = call <4 x i32> @llvm.masked.load.v4i32.p0v4i32(<4 x i32>* %ptr, i32 4, <4 x i1> %mask1, <4 x i32> zeroinitializer)
+  %masked = select <4 x i1> %mask2, <4 x i32> zeroinitializer, <4 x i32> %load
+  ret <4 x i32> %masked
+}
+
+; A more complex case where we can prove the select mask is a subset of the
+; load's inactive lanes and thus the load's passthrough takes effect.
+define <4 x float> @masked_load_and_zero_inactive_8(<4 x float>* %ptr, <4 x i1> %inv_mask, <4 x i1> %cond) {
+; CHECK-LABEL: @masked_load_and_zero_inactive_8(
+; CHECK: %mask = xor <4 x i1> %inv_mask, <i1 true, i1 true, i1 true, i1 true>
+; CHECK-NEXT: %pg = and <4 x i1> %mask, %cond
+; CHECK-NEXT: %load = call <4 x float> @llvm.masked.load.v4f32.p0v4f32(<4 x float>* %ptr, i32 4, <4 x i1> %pg, <4 x float> zeroinitializer)
+; CHECK-NEXT: ret <4 x float> %load
+  %mask = xor <4 x i1> %inv_mask, <i1 true, i1 true, i1 true, i1 true>
+  %pg = and <4 x i1> %mask, %cond
+  %load = call <4 x float> @llvm.masked.load.v4f32.p0v4f32(<4 x float>* %ptr, i32 4, <4 x i1> %pg, <4 x float> undef)
+  %masked = select <4 x i1> %inv_mask, <4 x float> zeroinitializer, <4 x float> %load
+  ret <4 x float> %masked
+}
+
+declare <4 x i32> @llvm.masked.load.v4i32.p0v4i32(<4 x i32>*, i32 immarg, <4 x i1>, <4 x i32>)
+declare <4 x float> @llvm.masked.load.v4f32.p0v4f32(<4 x float>*, i32 immarg, <4 x i1>, <4 x float>)