diff --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h
--- a/llvm/include/llvm/IR/PatternMatch.h
+++ b/llvm/include/llvm/IR/PatternMatch.h
@@ -2117,6 +2117,14 @@
   return m_Intrinsic<Intrinsic::masked_load>(Op0, Op1, Op2, Op3);
 }
 
+/// Matches MaskedGather Intrinsic.
+template <typename Opnd0, typename Opnd1, typename Opnd2, typename Opnd3>
+inline typename m_Intrinsic_Ty<Opnd0, Opnd1, Opnd2, Opnd3>::Ty
+m_MaskedGather(const Opnd0 &Op0, const Opnd1 &Op1, const Opnd2 &Op2,
+               const Opnd3 &Op3) {
+  return m_Intrinsic<Intrinsic::masked_gather>(Op0, Op1, Op2, Op3);
+}
+
 template <Intrinsic::ID IntrID, typename T0>
 inline typename m_Intrinsic_Ty<T0>::Ty m_Intrinsic(const T0 &Op0) {
   return m_CombineAnd(m_Intrinsic<IntrID>(), m_Argument<0>(Op0));
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -3031,18 +3031,22 @@
   // select(mask, mload(,,mask,0), 0) -> mload(,,mask,0)
   // Load inst is intentionally not checked for hasOneUse()
   if (match(FalseVal, m_Zero()) &&
-      match(TrueVal, m_MaskedLoad(m_Value(), m_Value(), m_Specific(CondVal),
-                                  m_CombineOr(m_Undef(), m_Zero())))) {
-    auto *MaskedLoad = cast<IntrinsicInst>(TrueVal);
-    if (isa<UndefValue>(MaskedLoad->getArgOperand(3)))
-      MaskedLoad->setArgOperand(3, FalseVal /* Zero */);
-    return replaceInstUsesWith(SI, MaskedLoad);
+      (match(TrueVal, m_MaskedLoad(m_Value(), m_Value(), m_Specific(CondVal),
+                                   m_CombineOr(m_Undef(), m_Zero()))) ||
+       match(TrueVal, m_MaskedGather(m_Value(), m_Value(), m_Specific(CondVal),
+                                     m_CombineOr(m_Undef(), m_Zero()))))) {
+    auto *MaskedInst = cast<IntrinsicInst>(TrueVal);
+    if (isa<UndefValue>(MaskedInst->getArgOperand(3)))
+      MaskedInst->setArgOperand(3, FalseVal /* Zero */);
+    return replaceInstUsesWith(SI, MaskedInst);
   }
 
   Value *Mask;
   if (match(TrueVal, m_Zero()) &&
-      match(FalseVal, m_MaskedLoad(m_Value(), m_Value(), m_Value(Mask),
-                                   m_CombineOr(m_Undef(), m_Zero()))) &&
+      (match(FalseVal, m_MaskedLoad(m_Value(), m_Value(), m_Value(Mask),
+                                    m_CombineOr(m_Undef(), m_Zero()))) ||
+       match(FalseVal, m_MaskedGather(m_Value(), m_Value(), m_Value(Mask),
+                                      m_CombineOr(m_Undef(), m_Zero())))) &&
       (CondVal->getType() == Mask->getType())) {
     // We can remove the select by ensuring the load zeros all lanes the
     // select would have. We determine this by proving there is no overlap
@@ -3053,10 +3057,10 @@
       CanMergeSelectIntoLoad = match(V, m_Zero());
 
     if (CanMergeSelectIntoLoad) {
-      auto *MaskedLoad = cast<IntrinsicInst>(FalseVal);
-      if (isa<UndefValue>(MaskedLoad->getArgOperand(3)))
-        MaskedLoad->setArgOperand(3, TrueVal /* Zero */);
-      return replaceInstUsesWith(SI, MaskedLoad);
+      auto *MaskedInst = cast<IntrinsicInst>(FalseVal);
+      if (isa<UndefValue>(MaskedInst->getArgOperand(3)))
+        MaskedInst->setArgOperand(3, TrueVal /* Zero */);
+      return replaceInstUsesWith(SI, MaskedInst);
     }
   }
 
diff --git a/llvm/test/Transforms/InstCombine/select-masked_gather.ll b/llvm/test/Transforms/InstCombine/select-masked_gather.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/select-masked_gather.ll
@@ -0,0 +1,125 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt < %s -passes=instcombine -S | FileCheck %s
+
+; Fold zeroing of inactive lanes into the load's passthrough parameter.
+define <vscale x 2 x float> @masked_load_and_zero_inactive_1(<vscale x 2 x float*> %ptr, <vscale x 2 x i1> %mask) {
+; CHECK-LABEL: @masked_load_and_zero_inactive_1(
+; CHECK-NEXT:    [[LOAD:%.*]] = call <vscale x 2 x float> @llvm.masked.gather.nxv2f32.nxv2p0f32(<vscale x 2 x float*> [[PTR:%.*]], i32 4, <vscale x 2 x i1> [[MASK:%.*]], <vscale x 2 x float> zeroinitializer)
+; CHECK-NEXT:    ret <vscale x 2 x float> [[LOAD]]
+;
+  %load = call <vscale x 2 x float> @llvm.masked.gather.nxv2f32.p0nxv2f32(<vscale x 2 x float*> %ptr, i32 4, <vscale x 2 x i1> %mask, <vscale x 2 x float> undef)
+  %masked = select <vscale x 2 x i1> %mask, <vscale x 2 x float> %load, <vscale x 2 x float> zeroinitializer
+  ret <vscale x 2 x float> %masked
+}
+
+; As above but reuse the load's existing passthrough.
+define <vscale x 2 x i32> @masked_load_and_zero_inactive_2(<vscale x 2 x i32*> %ptr, <vscale x 2 x i1> %mask) {
+; CHECK-LABEL: @masked_load_and_zero_inactive_2(
+; CHECK-NEXT:    [[LOAD:%.*]] = call <vscale x 2 x i32> @llvm.masked.gather.nxv2i32.nxv2p0i32(<vscale x 2 x i32*> [[PTR:%.*]], i32 4, <vscale x 2 x i1> [[MASK:%.*]], <vscale x 2 x i32> zeroinitializer)
+; CHECK-NEXT:    ret <vscale x 2 x i32> [[LOAD]]
+;
+  %load = call <vscale x 2 x i32> @llvm.masked.gather.nxv2i32.nxv2p0i32(<vscale x 2 x i32*> %ptr, i32 4, <vscale x 2 x i1> %mask, <vscale x 2 x i32> zeroinitializer)
+  %masked = select <vscale x 2 x i1> %mask, <vscale x 2 x i32> %load, <vscale x 2 x i32> zeroinitializer
+  ret <vscale x 2 x i32> %masked
+}
+
+; No transform when the load's passthrough cannot be reused or altered.
+define <vscale x 2 x i32> @masked_load_and_zero_inactive_3(<vscale x 2 x i32*> %ptr, <vscale x 2 x i1> %mask, <vscale x 2 x i32> %passthrough) {
+; CHECK-LABEL: @masked_load_and_zero_inactive_3(
+; CHECK-NEXT:    [[LOAD:%.*]] = call <vscale x 2 x i32> @llvm.masked.gather.nxv2i32.nxv2p0i32(<vscale x 2 x i32*> [[PTR:%.*]], i32 4, <vscale x 2 x i1> [[MASK:%.*]], <vscale x 2 x i32> [[PASSTHROUGH:%.*]])
+; CHECK-NEXT:    [[MASKED:%.*]] = select <vscale x 2 x i1> [[MASK]], <vscale x 2 x i32> [[LOAD]], <vscale x 2 x i32> zeroinitializer
+; CHECK-NEXT:    ret <vscale x 2 x i32> [[MASKED]]
+;
+  %load = call <vscale x 2 x i32> @llvm.masked.gather.nxv2i32.nxv2p0i32(<vscale x 2 x i32*> %ptr, i32 4, <vscale x 2 x i1> %mask, <vscale x 2 x i32> %passthrough)
+  %masked = select <vscale x 2 x i1> %mask, <vscale x 2 x i32> %load, <vscale x 2 x i32> zeroinitializer
+  ret <vscale x 2 x i32> %masked
+}
+
+; Remove redundant select when its mask doesn't overlap with the load mask.
+define <vscale x 2 x i32> @masked_load_and_zero_inactive_4(<vscale x 2 x i32*> %ptr, <vscale x 2 x i1> %inv_mask) {
+; CHECK-LABEL: @masked_load_and_zero_inactive_4(
+; CHECK-NEXT:    [[MASK:%.*]] = xor <vscale x 2 x i1> [[INV_MASK:%.*]], shufflevector (<vscale x 2 x i1> insertelement (<vscale x 2 x i1> undef, i1 true, i32 0), <vscale x 2 x i1> undef, <vscale x 2 x i32> zeroinitializer)
+; CHECK-NEXT:    [[LOAD:%.*]] = call <vscale x 2 x i32> @llvm.masked.gather.nxv2i32.nxv2p0i32(<vscale x 2 x i32*> [[PTR:%.*]], i32 4, <vscale x 2 x i1> [[MASK]], <vscale x 2 x i32> zeroinitializer)
+; CHECK-NEXT:    ret <vscale x 2 x i32> [[LOAD]]
+;
+  %splat = shufflevector <vscale x 2 x i1> insertelement (<vscale x 2 x i1> undef, i1 true, i32 0), <vscale x 2 x i1> undef, <vscale x 2 x i32> zeroinitializer
+  %mask = xor <vscale x 2 x i1> %inv_mask, %splat
+  %load = call <vscale x 2 x i32> @llvm.masked.gather.nxv2i32.nxv2p0i32(<vscale x 2 x i32*> %ptr, i32 4, <vscale x 2 x i1> %mask, <vscale x 2 x i32> undef)
+  %masked = select <vscale x 2 x i1> %inv_mask, <vscale x 2 x i32> zeroinitializer, <vscale x 2 x i32> %load
+  ret <vscale x 2 x i32> %masked
+}
+
+; As above but reuse the load's existing passthrough.
+define <vscale x 2 x i32> @masked_load_and_zero_inactive_5(<vscale x 2 x i32*> %ptr, <vscale x 2 x i1> %inv_mask) {
+; CHECK-LABEL: @masked_load_and_zero_inactive_5(
+; CHECK-NEXT:    [[MASK:%.*]] = xor <vscale x 2 x i1> [[INV_MASK:%.*]], shufflevector (<vscale x 2 x i1> insertelement (<vscale x 2 x i1> undef, i1 true, i32 0), <vscale x 2 x i1> undef, <vscale x 2 x i32> zeroinitializer)
+; CHECK-NEXT:    [[LOAD:%.*]] = call <vscale x 2 x i32> @llvm.masked.gather.nxv2i32.nxv2p0i32(<vscale x 2 x i32*> [[PTR:%.*]], i32 4, <vscale x 2 x i1> [[MASK]], <vscale x 2 x i32> zeroinitializer)
+; CHECK-NEXT:    ret <vscale x 2 x i32> [[LOAD]]
+;
+  %splat = shufflevector <vscale x 2 x i1> insertelement (<vscale x 2 x i1> undef, i1 true, i32 0), <vscale x 2 x i1> undef, <vscale x 2 x i32> zeroinitializer
+  %mask = xor <vscale x 2 x i1> %inv_mask, %splat
+  %load = call <vscale x 2 x i32> @llvm.masked.gather.nxv2i32.nxv2p0i32(<vscale x 2 x i32*> %ptr, i32 4, <vscale x 2 x i1> %mask, <vscale x 2 x i32> zeroinitializer)
+  %masked = select <vscale x 2 x i1> %inv_mask, <vscale x 2 x i32> zeroinitializer, <vscale x 2 x i32> %load
+  ret <vscale x 2 x i32> %masked
+}
+
+; No transform when the load's passthrough cannot be reused or altered.
+define <vscale x 2 x i32> @masked_load_and_zero_inactive_6(<vscale x 2 x i32*> %ptr, <vscale x 2 x i1> %inv_mask, <vscale x 2 x i32> %passthrough) {
+; CHECK-LABEL: @masked_load_and_zero_inactive_6(
+; CHECK-NEXT:    [[MASK:%.*]] = xor <vscale x 2 x i1> [[INV_MASK:%.*]], shufflevector (<vscale x 2 x i1> insertelement (<vscale x 2 x i1> undef, i1 true, i32 0), <vscale x 2 x i1> undef, <vscale x 2 x i32> zeroinitializer)
+; CHECK-NEXT:    [[LOAD:%.*]] = call <vscale x 2 x i32> @llvm.masked.gather.nxv2i32.nxv2p0i32(<vscale x 2 x i32*> [[PTR:%.*]], i32 4, <vscale x 2 x i1> [[MASK]], <vscale x 2 x i32> [[PASSTHROUGH:%.*]])
+; CHECK-NEXT:    [[MASKED:%.*]] = select <vscale x 2 x i1> [[INV_MASK]], <vscale x 2 x i32> zeroinitializer, <vscale x 2 x i32> [[LOAD]]
+; CHECK-NEXT:    ret <vscale x 2 x i32> [[MASKED]]
+;
+  %splat = shufflevector <vscale x 2 x i1> insertelement (<vscale x 2 x i1> undef, i1 true, i32 0), <vscale x 2 x i1> undef, <vscale x 2 x i32> zeroinitializer
+  %mask = xor <vscale x 2 x i1> %inv_mask, %splat
+  %load = call <vscale x 2 x i32> @llvm.masked.gather.nxv2i32.nxv2p0i32(<vscale x 2 x i32*> %ptr, i32 4, <vscale x 2 x i1> %mask, <vscale x 2 x i32> %passthrough)
+  %masked = select <vscale x 2 x i1> %inv_mask, <vscale x 2 x i32> zeroinitializer, <vscale x 2 x i32> %load
+  ret <vscale x 2 x i32> %masked
+}
+
+; No transform when select and load masks have no relation.
+define <vscale x 2 x i32> @masked_load_and_zero_inactive_7(<vscale x 2 x i32*> %ptr, <vscale x 2 x i1> %mask1, <vscale x 2 x i1> %mask2) {
+; CHECK-LABEL: @masked_load_and_zero_inactive_7(
+; CHECK-NEXT:    [[LOAD:%.*]] = call <vscale x 2 x i32> @llvm.masked.gather.nxv2i32.nxv2p0i32(<vscale x 2 x i32*> [[PTR:%.*]], i32 4, <vscale x 2 x i1> [[MASK1:%.*]], <vscale x 2 x i32> zeroinitializer)
+; CHECK-NEXT:    [[MASKED:%.*]] = select <vscale x 2 x i1> [[MASK2:%.*]], <vscale x 2 x i32> zeroinitializer, <vscale x 2 x i32> [[LOAD]]
+; CHECK-NEXT:    ret <vscale x 2 x i32> [[MASKED]]
+;
+  %load = call <vscale x 2 x i32> @llvm.masked.gather.nxv2i32.nxv2p0i32(<vscale x 2 x i32*> %ptr, i32 4, <vscale x 2 x i1> %mask1, <vscale x 2 x i32> zeroinitializer)
+  %masked = select <vscale x 2 x i1> %mask2, <vscale x 2 x i32> zeroinitializer, <vscale x 2 x i32> %load
+  ret <vscale x 2 x i32> %masked
+}
+
+; A more complex case where we can prove the select mask is a subset of the
+; load's inactive lanes and thus the load's passthrough takes effect.
+define <vscale x 2 x float> @masked_load_and_zero_inactive_8(<vscale x 2 x float*> %ptr, <vscale x 2 x i1> %inv_mask, <vscale x 2 x i1> %cond) {
+; CHECK-LABEL: @masked_load_and_zero_inactive_8(
+; CHECK-NEXT:    [[MASK:%.*]] = xor <vscale x 2 x i1> [[INV_MASK:%.*]], shufflevector (<vscale x 2 x i1> insertelement (<vscale x 2 x i1> undef, i1 true, i32 0), <vscale x 2 x i1> undef, <vscale x 2 x i32> zeroinitializer)
+; CHECK-NEXT:    [[PG:%.*]] = and <vscale x 2 x i1> [[MASK]], [[COND:%.*]]
+; CHECK-NEXT:    [[LOAD:%.*]] = call <vscale x 2 x float> @llvm.masked.gather.nxv2f32.nxv2p0f32(<vscale x 2 x float*> [[PTR:%.*]], i32 4, <vscale x 2 x i1> [[PG]], <vscale x 2 x float> zeroinitializer)
+; CHECK-NEXT:    ret <vscale x 2 x float> [[LOAD]]
+;
+  %splat = shufflevector <vscale x 2 x i1> insertelement (<vscale x 2 x i1> undef, i1 true, i32 0), <vscale x 2 x i1> undef, <vscale x 2 x i32> zeroinitializer
+  %mask = xor <vscale x 2 x i1> %inv_mask, %splat
+  %pg = and <vscale x 2 x i1> %mask, %cond
+  %load = call <vscale x 2 x float> @llvm.masked.gather.v4f32.p0v4f32(<vscale x 2 x float*> %ptr, i32 4, <vscale x 2 x i1> %pg, <vscale x 2 x float> undef)
+  %masked = select <vscale x 2 x i1> %inv_mask, <vscale x 2 x float> zeroinitializer, <vscale x 2 x float> %load
+  ret <vscale x 2 x float> %masked
+}
+
+define <vscale x 2 x float> @masked_load_and_scalar_select_cond(<vscale x 2 x float*> %ptr, <vscale x 2 x i1> %mask, i1 %cond) {
+; CHECK-LABEL: @masked_load_and_scalar_select_cond(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[TMP0:%.*]] = call <vscale x 2 x float> @llvm.masked.gather.nxv2f32.nxv2p0f32(<vscale x 2 x float*> [[PTR:%.*]], i32 32, <vscale x 2 x i1> [[MASK:%.*]], <vscale x 2 x float> undef)
+; CHECK-NEXT:    [[TMP1:%.*]] = select i1 [[COND:%.*]], <vscale x 2 x float> zeroinitializer, <vscale x 2 x float> [[TMP0]]
+; CHECK-NEXT:    ret <vscale x 2 x float> [[TMP1]]
+;
+entry:
+  %0 = call <vscale x 2 x float> @llvm.masked.gather.nxv2f32.p0nxv2f32(<vscale x 2 x float*> %ptr, i32 32, <vscale x 2 x i1> %mask, <vscale x 2 x float> undef)
+  %1 = select i1 %cond, <vscale x 2 x float> zeroinitializer, <vscale x 2 x float> %0
+  ret <vscale x 2 x float> %1
+}
+
+declare <vscale x 2 x float> @llvm.masked.gather.v4f32.p0v4f32(<vscale x 2 x float*>, i32 immarg, <vscale x 2 x i1>, <vscale x 2 x float>)
+declare <vscale x 2 x i32> @llvm.masked.gather.nxv2i32.nxv2p0i32(<vscale x 2 x i32*>, i32, <vscale x 2 x i1>, <vscale x 2 x i32>)
+declare <vscale x 2 x float> @llvm.masked.gather.nxv2f32.p0nxv2f32(<vscale x 2 x float*>, i32, <vscale x 2 x i1>, <vscale x 2 x float>)
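
For reference only (not part of the patch): a minimal before/after sketch of the fold this change enables on masked gathers, using the same nxv2i32 intrinsic as the tests. The value names %ptrs, %mask, %gather and %res are illustrative.

  ; Before: the select zeroes exactly the lanes the gather's mask leaves inactive,
  ; so the result is the gather with a zero passthrough.
  %gather = call <vscale x 2 x i32> @llvm.masked.gather.nxv2i32.nxv2p0i32(<vscale x 2 x i32*> %ptrs, i32 4, <vscale x 2 x i1> %mask, <vscale x 2 x i32> undef)
  %res = select <vscale x 2 x i1> %mask, <vscale x 2 x i32> %gather, <vscale x 2 x i32> zeroinitializer

  ; After: the undef passthrough is replaced with zero and the select is removed.
  %res = call <vscale x 2 x i32> @llvm.masked.gather.nxv2i32.nxv2p0i32(<vscale x 2 x i32*> %ptrs, i32 4, <vscale x 2 x i1> %mask, <vscale x 2 x i32> zeroinitializer)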