Index: llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
===================================================================
--- llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -20,6 +20,7 @@
 #include "llvm/Analysis/InstructionSimplify.h"
 #include "llvm/Analysis/OverflowInstAnalysis.h"
 #include "llvm/Analysis/ValueTracking.h"
+#include "llvm/Analysis/VectorUtils.h"
 #include "llvm/IR/BasicBlock.h"
 #include "llvm/IR/Constant.h"
 #include "llvm/IR/ConstantRange.h"
@@ -2638,6 +2639,28 @@
   return R;
 }
 
+/// Return true if \p Val can be treated as zero when used as the passthru of
+/// a masked load or gather.
+///
+/// This holds when \p Val matches the m_Zero() pattern, or when it is a splat
+/// of FP -0.0 and \p F carries the "no-signed-zeros-fp-math"="true"
+/// attribute. \p F may be null, in which case only the m_Zero() match is
+/// attempted.
+static bool isEffectivelyZero(Function *F, Value *Val) {
+  if (match(Val, m_Zero()))
+    return true;
+
+  if (F) {
+    Value *SplatVal = llvm::getSplatValue(Val);
+    ConstantFP *SplatFPVal = dyn_cast_or_null<ConstantFP>(SplatVal);
+    if (SplatFPVal && SplatFPVal->isExactlyValue(-0.0) &&
+        F->getFnAttribute("no-signed-zeros-fp-math").getValueAsBool())
+      return true;
+  }
+
+  return false;
+}
+
 Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
   Value *CondVal = SI.getCondition();
   Value *TrueVal = SI.getTrueValue();
@@ -3167,14 +3190,17 @@
 
   // select(mask, mload(,,mask,0), 0) -> mload(,,mask,0)
   // Load inst is intentionally not checked for hasOneUse()
-  if (match(FalseVal, m_Zero()) &&
+  if (isEffectivelyZero(SI.getParent()->getParent(), FalseVal) &&
       (match(TrueVal, m_MaskedLoad(m_Value(), m_Value(), m_Specific(CondVal),
                                    m_CombineOr(m_Undef(), m_Zero()))) ||
        match(TrueVal, m_MaskedGather(m_Value(), m_Value(), m_Specific(CondVal),
                                      m_CombineOr(m_Undef(), m_Zero()))))) {
     auto *MaskedInst = cast<IntrinsicInst>(TrueVal);
+    // FalseVal may be a -0.0 splat, so install an all-zero passthru of the
+    // same type instead of reusing FalseVal.
     if (isa<UndefValue>(MaskedInst->getArgOperand(3)))
-      MaskedInst->setArgOperand(3, FalseVal /* Zero */);
+      MaskedInst->setArgOperand(
+          3, ConstantAggregateZero::get(FalseVal->getType()));
     return replaceInstUsesWith(SI, MaskedInst);
   }
 
Index: llvm/test/Transforms/InstCombine/select-masked_load.ll
===================================================================
--- llvm/test/Transforms/InstCombine/select-masked_load.ll
+++ llvm/test/Transforms/InstCombine/select-masked_load.ll
@@ -106,6 +106,27 @@
   ret <8 x float> %1
 }
 
+define <4 x float> @masked_load_and_minus_zero_inactive_1(<4 x float>* %ptr, <4 x i1> %mask) #0 {
+; CHECK-LABEL: @masked_load_and_minus_zero_inactive_1(
+; CHECK: %load = call <4 x float> @llvm.masked.load.v4f32.p0v4f32(<4 x float>* %ptr, i32 4, <4 x i1> %mask, <4 x float> zeroinitializer)
+; CHECK-NEXT: ret <4 x float> %load
+  %load = call <4 x float> @llvm.masked.load.v4f32.p0v4f32(<4 x float>* %ptr, i32 4, <4 x i1> %mask, <4 x float> poison)
+  %masked = select <4 x i1> %mask, <4 x float> %load, <4 x float> <float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00>
+  ret <4 x float> %masked
+}
+
+define <vscale x 4 x float> @masked_load_and_minus_zero_inactive_2(<vscale x 4 x float>* %ptr, <vscale x 4 x i1> %mask) #0 {
+; CHECK-LABEL: @masked_load_and_minus_zero_inactive_2(
+; CHECK: %load = call <vscale x 4 x float> @llvm.masked.load.nxv4f32.p0nxv4f32(<vscale x 4 x float>* %ptr, i32 4, <vscale x 4 x i1> %mask, <vscale x 4 x float> zeroinitializer)
+; CHECK-NEXT: ret <vscale x 4 x float> %load
+  %load = call <vscale x 4 x float> @llvm.masked.load.nxv4f32.p0nxv4f32(<vscale x 4 x float>* %ptr, i32 4, <vscale x 4 x i1> %mask, <vscale x 4 x float> poison)
+  %masked = select <vscale x 4 x i1> %mask, <vscale x 4 x float> %load, <vscale x 4 x float> shufflevector (<vscale x 4 x float> insertelement (<vscale x 4 x float> poison, float -0.000000e+00, i32 0), <vscale x 4 x float> poison, <vscale x 4 x i32> zeroinitializer)
+  ret <vscale x 4 x float> %masked
+}
+
 declare <8 x float> @llvm.masked.load.v8f32.p0v8f32(<8 x float>*, i32 immarg, <8 x i1>, <8 x float>)
 declare <4 x i32> @llvm.masked.load.v4i32.p0v4i32(<4 x i32>*, i32 immarg, <4 x i1>, <4 x i32>)
 declare <4 x float> @llvm.masked.load.v4f32.p0v4f32(<4 x float>*, i32 immarg, <4 x i1>, <4 x float>)
+declare <vscale x 4 x float> @llvm.masked.load.nxv4f32.p0nxv4f32(<vscale x 4 x float>*, i32 immarg, <vscale x 4 x i1>, <vscale x 4 x float>)
+
+attributes #0 = { "no-signed-zeros-fp-math"="true" }