diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp --- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -8084,13 +8084,17 @@ } SDValue DAGCombiner::visitMGATHER(SDNode *N) { - if (Level >= AfterLegalizeTypes) - return SDValue(); - MaskedGatherSDNode *MGT = cast<MaskedGatherSDNode>(N); SDValue Mask = MGT->getMask(); SDLoc DL(N); + // Zap gathers with a zero mask. + if (isNullOrNullSplat(Mask)) + return CombineTo(N, MGT->getPassThru(), MGT->getChain()); + + if (Level >= AfterLegalizeTypes) + return SDValue(); + // If the MGATHER result requires splitting and the mask is provided by a // SETCC, then split both nodes and its operands before legalization. This // prevents the type legalizer from unrolling SETCC into scalar comparisons diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp @@ -8668,7 +8668,8 @@ return C && C->isAllOnesValue(); } -ConstantSDNode *llvm::isConstOrConstSplat(SDValue N, bool AllowUndefs) { +static ConstantSDNode *isConstOrConstSplat(SDValue N, bool AllowUndefs, + bool IgnoreBitWidth) { if (ConstantSDNode *CN = dyn_cast<ConstantSDNode>(N)) return CN; @@ -8678,13 +8679,18 @@ // BuildVectors can truncate their operands. Ignore that case here. 
if (CN && (UndefElements.none() || AllowUndefs) && - CN->getValueType(0) == N.getValueType().getScalarType()) + (IgnoreBitWidth || + CN->getValueType(0) == N.getValueType().getScalarType())) return CN; } return nullptr; } +ConstantSDNode *llvm::isConstOrConstSplat(SDValue N, bool AllowUndefs) { + return ::isConstOrConstSplat(N, AllowUndefs, /*IgnoreBitWidth=*/false); +} + ConstantSDNode *llvm::isConstOrConstSplat(SDValue N, const APInt &DemandedElts, bool AllowUndefs) { if (ConstantSDNode *CN = dyn_cast<ConstantSDNode>(N)) @@ -8736,7 +8742,8 @@ bool llvm::isNullOrNullSplat(SDValue N, bool AllowUndefs) { // TODO: may want to use peekThroughBitcast() here. - ConstantSDNode *C = isConstOrConstSplat(N, AllowUndefs); + ConstantSDNode *C = + ::isConstOrConstSplat(N, AllowUndefs, /*IgnoreBitWidth=*/true); return C && C->isNullValue(); } diff --git a/llvm/test/CodeGen/AArch64/vecreduce-bool.ll b/llvm/test/CodeGen/AArch64/vecreduce-bool.ll --- a/llvm/test/CodeGen/AArch64/vecreduce-bool.ll +++ b/llvm/test/CodeGen/AArch64/vecreduce-bool.ll @@ -96,9 +96,8 @@ define i32 @reduce_and_v32(<32 x i8> %a0, i32 %a1, i32 %a2) nounwind { ; CHECK-LABEL: reduce_and_v32: ; CHECK: // %bb.0: -; CHECK-NEXT: cmlt v1.16b, v1.16b, #0 -; CHECK-NEXT: cmlt v0.16b, v0.16b, #0 ; CHECK-NEXT: and v0.16b, v0.16b, v1.16b +; CHECK-NEXT: cmlt v0.16b, v0.16b, #0 ; CHECK-NEXT: uminv b0, v0.16b ; CHECK-NEXT: fmov w8, s0 ; CHECK-NEXT: tst w8, #0x1 @@ -191,9 +190,8 @@ define i32 @reduce_or_v32(<32 x i8> %a0, i32 %a1, i32 %a2) nounwind { ; CHECK-LABEL: reduce_or_v32: ; CHECK: // %bb.0: -; CHECK-NEXT: cmlt v1.16b, v1.16b, #0 -; CHECK-NEXT: cmlt v0.16b, v0.16b, #0 ; CHECK-NEXT: orr v0.16b, v0.16b, v1.16b +; CHECK-NEXT: cmlt v0.16b, v0.16b, #0 ; CHECK-NEXT: umaxv b0, v0.16b ; CHECK-NEXT: fmov w8, s0 ; CHECK-NEXT: tst w8, #0x1 diff --git a/llvm/test/CodeGen/X86/avx2-masked-gather.ll b/llvm/test/CodeGen/X86/avx2-masked-gather.ll --- a/llvm/test/CodeGen/X86/avx2-masked-gather.ll +++ b/llvm/test/CodeGen/X86/avx2-masked-gather.ll @@ 
-769,3 +769,24 @@ ret <2 x double> %res } + +define <2 x double> @masked_gather_zeromask(<2 x double*>* %ptr, <2 x double> %dummy, <2 x double> %passthru) { +; X86-LABEL: masked_gather_zeromask: +; X86: # %bb.0: # %entry +; X86-NEXT: vmovaps %xmm1, %xmm0 +; X86-NEXT: retl +; +; X64-LABEL: masked_gather_zeromask: +; X64: # %bb.0: # %entry +; X64-NEXT: vmovaps %xmm1, %xmm0 +; X64-NEXT: retq +; +; NOGATHER-LABEL: masked_gather_zeromask: +; NOGATHER: # %bb.0: # %entry +; NOGATHER-NEXT: vmovaps %xmm1, %xmm0 +; NOGATHER-NEXT: retq +entry: + %ld = load <2 x double*>, <2 x double*>* %ptr + %res = call <2 x double> @llvm.masked.gather.v2double(<2 x double*> %ld, i32 0, <2 x i1> zeroinitializer, <2 x double> %passthru) + ret <2 x double> %res +}