diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp --- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -12092,13 +12092,9 @@ return SDValue(N, 0); // Return N so it doesn't get rechecked! } -static SDValue tryToFoldExtOfMaskedLoad(SelectionDAG &DAG, - const TargetLowering &TLI, EVT VT, - SDNode *N, SDValue N0, - ISD::LoadExtType ExtLoadType, - ISD::NodeType ExtOpc) { - if (!N0.hasOneUse()) - return SDValue(); +static SDValue tryToFoldExtOfMaskedLoad( + SelectionDAG &DAG, DAGCombiner &Combiner, const TargetLowering &TLI, EVT VT, + SDNode *N, SDValue N0, ISD::LoadExtType ExtLoadType, ISD::NodeType ExtOpc) { MaskedLoadSDNode *Ld = dyn_cast<MaskedLoadSDNode>(N0); if (!Ld || Ld->getExtensionType() != ISD::NON_EXTLOAD) @@ -12110,14 +12106,51 @@ if (!TLI.isVectorLoadExtDesirable(SDValue(N, 0))) return SDValue(); + auto AllUsesCanBeReplaced = [&](SDValue V) { + bool isTruncFree = TLI.isTruncateFree(VT, V.getValueType()); + for (SDNode::use_iterator UI = V->use_begin(), UE = V->use_end(); UI != UE; + ++UI) { + SDNode *User = *UI; + // Skip chain uses and the extension dag node N + if (UI.getUse().getResNo() != V.getResNo() || User == N) + continue; + // FIXME: May be possible to handle these cases: + if (User->getOpcode() == ISD::SETCC || + User->getOpcode() == ISD::CopyToReg) + return false; + // Replacing a non-free truncate with another non-free truncate should + // not generate extra code. 
+ if (User->getOpcode() == ISD::TRUNCATE && + !TLI.isTruncateFree(V.getValueType(), User->getValueType(0))) + continue; + if (!isTruncFree) + return false; + } + return true; + }; + + if (!AllUsesCanBeReplaced(N0)) + return SDValue(); + SDLoc dl(Ld); SDValue PassThru = DAG.getNode(ExtOpc, dl, VT, Ld->getPassThru()); SDValue NewLoad = DAG.getMaskedLoad( VT, dl, Ld->getChain(), Ld->getBasePtr(), Ld->getOffset(), Ld->getMask(), PassThru, Ld->getMemoryVT(), Ld->getMemOperand(), Ld->getAddressingMode(), ExtLoadType, Ld->isExpandingLoad()); - DAG.ReplaceAllUsesOfValueWith(SDValue(Ld, 1), SDValue(NewLoad.getNode(), 1)); - return NewLoad; + + bool OnlyReplaceChainUses = SDValue(Ld, 0).hasOneUse(); + Combiner.CombineTo(N, NewLoad); + if (OnlyReplaceChainUses) { + DAG.ReplaceAllUsesOfValueWith(SDValue(Ld, 1), NewLoad.getValue(1)); + Combiner.recursivelyDeleteUnusedNodes(Ld); + } else { + // Replace all old uses with a truncate of the new load. + SDValue Trunc = DAG.getNode(ISD::TRUNCATE, dl, N0.getValueType(), NewLoad); + Combiner.CombineTo(Ld, Trunc, NewLoad.getValue(1)); + } + + return SDValue(N, 0); } static SDValue foldExtendedSignBitTest(SDNode *N, SelectionDAG &DAG, @@ -12354,9 +12387,8 @@ ISD::SEXTLOAD, ISD::SIGN_EXTEND)) return foldedExt; - if (SDValue foldedExt = - tryToFoldExtOfMaskedLoad(DAG, TLI, VT, N, N0, ISD::SEXTLOAD, - ISD::SIGN_EXTEND)) + if (SDValue foldedExt = tryToFoldExtOfMaskedLoad( + DAG, *this, TLI, VT, N, N0, ISD::SEXTLOAD, ISD::SIGN_EXTEND)) return foldedExt; // fold (sext (load x)) to multiple smaller sextloads. @@ -12630,9 +12662,8 @@ ISD::ZEXTLOAD, ISD::ZERO_EXTEND)) return foldedExt; - if (SDValue foldedExt = - tryToFoldExtOfMaskedLoad(DAG, TLI, VT, N, N0, ISD::ZEXTLOAD, - ISD::ZERO_EXTEND)) + if (SDValue foldedExt = tryToFoldExtOfMaskedLoad( + DAG, *this, TLI, VT, N, N0, ISD::ZEXTLOAD, ISD::ZERO_EXTEND)) return foldedExt; // fold (zext (load x)) to multiple smaller zextloads. 
diff --git a/llvm/test/CodeGen/AArch64/sve-load-compare-store.ll b/llvm/test/CodeGen/AArch64/sve-load-compare-store.ll --- a/llvm/test/CodeGen/AArch64/sve-load-compare-store.ll +++ b/llvm/test/CodeGen/AArch64/sve-load-compare-store.ll @@ -6,9 +6,7 @@ ; CHECK: // %bb.0: // %entry ; CHECK-NEXT: ptrue p0.b ; CHECK-NEXT: ld1h { z0.s }, p0/z, [x0] -; CHECK-NEXT: mov z1.d, z0.d -; CHECK-NEXT: and z1.s, z1.s, #0xffff -; CHECK-NEXT: cmphs p0.s, p0/z, z1.s, #0 +; CHECK-NEXT: cmphs p0.s, p0/z, z0.s, #0 ; CHECK-NEXT: st1b { z0.s }, p0, [x1] ; CHECK-NEXT: ret entry: