diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -4662,6 +4662,17 @@ ISD::LoadExtType ExtType = MGT->getExtensionType(); ISD::MemIndexType IndexType = MGT->getIndexType(); + // SVE supports zero (and so undef) passthrough values only, everything else + // must be handled manually by an explicit select on the load's output. + if (!PassThru->isUndef() && !isZerosVector(PassThru.getNode())) { + SDValue Ops[] = {Chain, DAG.getUNDEF(VT), Mask, BasePtr, Index, Scale}; + SDValue Load = + DAG.getMaskedGather(MGT->getVTList(), MemVT, DL, Ops, + MGT->getMemOperand(), IndexType, ExtType); + SDValue Select = DAG.getSelect(DL, VT, Mask, Load, PassThru); + return DAG.getMergeValues({Select, Load.getValue(1)}, DL); + } + bool IsScaled = IndexType == ISD::SIGNED_SCALED || IndexType == ISD::UNSIGNED_SCALED; bool IsSigned = @@ -4708,17 +4719,9 @@ VT.changeVectorElementType(IndexVT.getVectorElementType()), Mask); } - if (PassThru->isUndef() || isZerosVector(PassThru.getNode())) - PassThru = SDValue(); - - if (VT.isFloatingPoint() && !IsFixedLength) { - // Handle FP data by using an integer gather and casting the result. - if (PassThru) { - EVT PassThruVT = getPackedSVEVectorVT(VT.getVectorElementCount()); - PassThru = getSVESafeBitCast(PassThruVT, PassThru, DAG); - } + // Handle FP data by using an integer gather and casting the result. + if (VT.isFloatingPoint() && !IsFixedLength) InputVT = DAG.getValueType(MemVT.changeVectorElementTypeToInteger()); - } SDVTList VTs = DAG.getVTList(IndexVT, MVT::Other); @@ -4750,16 +4753,8 @@ Result); Result = DAG.getNode(ISD::TRUNCATE, DL, VT.changeTypeToInteger(), Result); Result = DAG.getNode(ISD::BITCAST, DL, VT, Result); - - if (PassThru) - Result = DAG.getSelect(DL, VT, MGT->getMask(), Result, PassThru); - } else { - if (PassThru) - Result = DAG.getSelect(DL, IndexVT, Mask, Result, PassThru); - - if (VT.isFloatingPoint()) - Result = getSVESafeBitCast(VT, Result, DAG); - } + } else if (VT.isFloatingPoint()) + Result = getSVESafeBitCast(VT, Result, DAG); return DAG.getMergeValues({Result, Chain}, DL); }