Index: lib/Target/X86/X86ISelDAGToDAG.cpp
===================================================================
--- lib/Target/X86/X86ISelDAGToDAG.cpp
+++ lib/Target/X86/X86ISelDAGToDAG.cpp
@@ -1538,6 +1538,19 @@
   return true;
 }
 
+// Make sure this node has one path back to Root. Otherwise it's not legal
+// to fold the load.
+static bool hasSingleUsesFromRoot(SDNode *Root, SDNode *N) {
+  SDNode *User = *N->use_begin();
+  while (User != Root) {
+    if (!User->hasOneUse())
+      return false;
+    User = *User->use_begin();
+  }
+
+  return true;
+}
+
 /// Match a scalar SSE load. In particular, we want to match a load whose top
 /// elements are either undef or zeros. The load flavor is derived from the
 /// type of N, which is either v4f32 or v2f64.
@@ -1554,7 +1567,8 @@
   if (ISD::isNON_EXTLoad(N.getNode())) {
     PatternNodeWithChain = N;
     if (IsProfitableToFold(PatternNodeWithChain, N.getNode(), Root) &&
-        IsLegalToFold(PatternNodeWithChain, *N->use_begin(), Root, OptLevel)) {
+        IsLegalToFold(PatternNodeWithChain, *N->use_begin(), Root, OptLevel) &&
+        hasSingleUsesFromRoot(Root, N.getNode())) {
       LoadSDNode *LD = cast<LoadSDNode>(PatternNodeWithChain);
       return selectAddr(LD, LD->getBasePtr(), Base, Scale, Index, Disp,
                         Segment);
@@ -1565,7 +1579,8 @@
   if (N.getOpcode() == X86ISD::VZEXT_LOAD) {
     PatternNodeWithChain = N;
     if (IsProfitableToFold(PatternNodeWithChain, N.getNode(), Root) &&
-        IsLegalToFold(PatternNodeWithChain, *N->use_begin(), Root, OptLevel)) {
+        IsLegalToFold(PatternNodeWithChain, *N->use_begin(), Root, OptLevel) &&
+        hasSingleUsesFromRoot(Root, N.getNode())) {
       auto *MI = cast<MemIntrinsicSDNode>(PatternNodeWithChain);
       return selectAddr(MI, MI->getBasePtr(), Base, Scale, Index, Disp,
                         Segment);
@@ -1579,7 +1594,8 @@
     PatternNodeWithChain = N.getOperand(0);
     if (ISD::isNON_EXTLoad(PatternNodeWithChain.getNode()) &&
         IsProfitableToFold(PatternNodeWithChain, N.getNode(), Root) &&
-        IsLegalToFold(PatternNodeWithChain, N.getNode(), Root, OptLevel)) {
+        IsLegalToFold(PatternNodeWithChain, N.getNode(), Root, OptLevel) &&
+        hasSingleUsesFromRoot(Root, N.getNode())) {
       LoadSDNode *LD = cast<LoadSDNode>(PatternNodeWithChain);
       return selectAddr(LD, LD->getBasePtr(), Base, Scale, Index, Disp,
                         Segment);
@@ -1595,7 +1611,8 @@
     PatternNodeWithChain = N.getOperand(0).getOperand(0);
     if (ISD::isNON_EXTLoad(PatternNodeWithChain.getNode()) &&
         IsProfitableToFold(PatternNodeWithChain, N.getNode(), Root) &&
-        IsLegalToFold(PatternNodeWithChain, N.getNode(), Root, OptLevel)) {
+        IsLegalToFold(PatternNodeWithChain, N.getNode(), Root, OptLevel) &&
+        hasSingleUsesFromRoot(Root, N.getNode())) {
       // Okay, this is a zero extending load. Fold it.
       LoadSDNode *LD = cast<LoadSDNode>(PatternNodeWithChain);
       return selectAddr(LD, LD->getBasePtr(), Base, Scale, Index, Disp,
Index: test/CodeGen/X86/avx512-memfold.ll
===================================================================
--- test/CodeGen/X86/avx512-memfold.ll
+++ test/CodeGen/X86/avx512-memfold.ll
@@ -72,9 +72,10 @@
 define <4 x float> @test_mask_add_ss_double_use(<4 x float> %a, float* %b, i8 %mask, <4 x float> %c) {
 ; CHECK-LABEL: test_mask_add_ss_double_use:
 ; CHECK:       ## BB#0:
+; CHECK-NEXT:    vmovss {{.*#+}} xmm2 = mem[0],zero,zero,zero
 ; CHECK-NEXT:    kmovw %esi, %k1
-; CHECK-NEXT:    vaddss (%rdi), %xmm0, %xmm1 {%k1}
-; CHECK-NEXT:    vaddss (%rdi), %xmm0, %xmm0 {%k1} {z}
+; CHECK-NEXT:    vaddss %xmm2, %xmm0, %xmm1 {%k1}
+; CHECK-NEXT:    vaddss %xmm2, %xmm0, %xmm0 {%k1} {z}
 ; CHECK-NEXT:    vmulps %xmm0, %xmm1, %xmm0
 ; CHECK-NEXT:    retq
   %b.val = load float, float* %b