Index: llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
===================================================================
--- llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -931,6 +931,33 @@
   return false;
 }
 
+static bool isConstantSplatVectorMaskForType(SDNode *N, EVT ScalarTy) {
+  if (!ScalarTy.isSimple())
+    return false;
+
+  uint64_t MaskForTy = 0ULL;
+  switch (ScalarTy.getSimpleVT().SimpleTy) {
+  case MVT::i8:
+    MaskForTy = 0xFFULL;
+    break;
+  case MVT::i16:
+    MaskForTy = 0xFFFFULL;
+    break;
+  case MVT::i32:
+    MaskForTy = 0xFFFFFFFFULL;
+    break;
+  default:
+    return false;
+    break;
+  }
+
+  APInt Val;
+  if (ISD::isConstantSplatVector(N, Val))
+    return Val.getLimitedValue() == MaskForTy;
+
+  return false;
+}
+
 // Returns the SDNode if it is a constant float BuildVector
 // or constant float.
 static SDNode *isConstantFPBuildVectorOrConstantFP(SDValue N) {
@@ -5647,6 +5674,28 @@
     }
   }
 
+  // fold (and (masked_gather x)) -> (zext_masked_gather x)
+  if (auto *GN0 = dyn_cast<MaskedGatherSDNode>(N0)) {
+    EVT MemVT = GN0->getMemoryVT();
+    EVT ScalarVT = MemVT.getScalarType();
+
+    if (SDValue(GN0, 0).hasOneUse() &&
+        isConstantSplatVectorMaskForType(N1.getNode(), ScalarVT) &&
+        TLI.isVectorLoadExtDesirable(SDValue(SDValue(GN0, 0)))) {
+      SDValue Ops[] = {GN0->getChain(), GN0->getPassThru(), GN0->getMask(),
+                       GN0->getBasePtr(), GN0->getIndex(), GN0->getScale()};
+
+      SDValue ZExtLoad = DAG.getMaskedGather(
+          DAG.getVTList(VT, MVT::Other), MemVT, SDLoc(N), Ops,
+          GN0->getMemOperand(), GN0->getIndexType(), ISD::ZEXTLOAD);
+
+      CombineTo(N, ZExtLoad);
+      AddToWorklist(ZExtLoad.getNode());
+      // Avoid recheck of N.
+      return SDValue(N, 0);
+    }
+  }
+
   // fold (and (load x), 255) -> (zextload x, i8)
   // fold (and (extload x, i16), 255) -> (zextload x, i8)
   // fold (and (any_ext (extload x, i16)), 255) -> (zextload x, i8)
@@ -11504,6 +11553,25 @@
     }
   }
 
+  // fold (sext_inreg (masked_gather x)) -> (sext_masked_gather x)
+  if (auto *GN0 = dyn_cast<MaskedGatherSDNode>(N0)) {
+    if (SDValue(GN0, 0).hasOneUse() &&
+        ExtVT == GN0->getMemoryVT() &&
+        TLI.isVectorLoadExtDesirable(SDValue(SDValue(GN0, 0)))) {
+      SDValue Ops[] = {GN0->getChain(), GN0->getPassThru(), GN0->getMask(),
+                       GN0->getBasePtr(), GN0->getIndex(), GN0->getScale()};
+
+      SDValue ExtLoad = DAG.getMaskedGather(
+          DAG.getVTList(VT, MVT::Other), ExtVT, SDLoc(N), Ops,
+          GN0->getMemOperand(), GN0->getIndexType(), ISD::SEXTLOAD);
+
+      CombineTo(N, ExtLoad);
+      CombineTo(N0.getNode(), ExtLoad, ExtLoad.getValue(1));
+      AddToWorklist(ExtLoad.getNode());
+      return SDValue(N, 0); // Return N so it doesn't get rechecked!
+    }
+  }
+
   // Form (sext_inreg (bswap >> 16)) or (sext_inreg (rotl (bswap) 16))
   if (ExtVTBits <= 16 && N0.getOpcode() == ISD::OR) {
     if (SDValue BSwap = MatchBSwapHWordLow(N0.getNode(), N0.getOperand(0),
Index: llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
===================================================================
--- llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -3820,6 +3820,26 @@
   return AddrModes.find(Key)->second;
 }
 
+unsigned getExtendedGatherOpcode(unsigned Opcode) {
+  switch (Opcode) {
+  default:
+    llvm_unreachable("unimplemented opcode");
+    return Opcode;
+  case AArch64ISD::GLD1_MERGE_ZERO:
+    return AArch64ISD::GLD1S_MERGE_ZERO;
+  case AArch64ISD::GLD1_UXTW_MERGE_ZERO:
+    return AArch64ISD::GLD1S_UXTW_MERGE_ZERO;
+  case AArch64ISD::GLD1_SXTW_MERGE_ZERO:
+    return AArch64ISD::GLD1S_SXTW_MERGE_ZERO;
+  case AArch64ISD::GLD1_SCALED_MERGE_ZERO:
+    return AArch64ISD::GLD1S_SCALED_MERGE_ZERO;
+  case AArch64ISD::GLD1_UXTW_SCALED_MERGE_ZERO:
+    return AArch64ISD::GLD1S_UXTW_SCALED_MERGE_ZERO;
+  case AArch64ISD::GLD1_SXTW_SCALED_MERGE_ZERO:
+    return AArch64ISD::GLD1S_SXTW_SCALED_MERGE_ZERO;
+  }
+}
+
 bool getGatherScatterIndexIsExtended(SDValue Index) {
   unsigned Opcode = Index.getOpcode();
   if (Opcode == ISD::SIGN_EXTEND_INREG)
@@ -3849,6 +3869,7 @@
   SDValue PassThru = MGT->getPassThru();
   SDValue Mask = MGT->getMask();
   SDValue BasePtr = MGT->getBasePtr();
+  ISD::LoadExtType ExtTy = MGT->getExtensionType();
 
   ISD::MemIndexType IndexType = MGT->getIndexType();
   bool IsScaled =
@@ -3858,6 +3879,7 @@
   bool IdxNeedsExtend =
       getGatherScatterIndexIsExtended(Index) ||
       Index.getSimpleValueType().getVectorElementType() == MVT::i32;
+  bool ResNeedsExtend = ExtTy == ISD::EXTLOAD || ExtTy == ISD::SEXTLOAD;
 
   EVT VT = PassThru.getSimpleValueType();
   EVT MemVT = MGT->getMemoryVT();
@@ -3884,9 +3906,12 @@
   if (getGatherScatterIndexIsExtended(Index))
     Index = Index.getOperand(0);
 
+  unsigned Opcode = getGatherVecOpcode(IsScaled, IsSigned, IdxNeedsExtend);
+  if (ResNeedsExtend)
+    Opcode = getExtendedGatherOpcode(Opcode);
+
   SDValue Ops[] = {Chain, Mask, BasePtr, Index, InputVT, PassThru};
-  return DAG.getNode(getGatherVecOpcode(IsScaled, IsSigned, IdxNeedsExtend), DL,
-                     VTs, Ops);
+  return DAG.getNode(Opcode, DL, VTs, Ops);
 }
 
 SDValue AArch64TargetLowering::LowerMSCATTER(SDValue Op,
Index: llvm/test/CodeGen/AArch64/sve-masked-gather-legalize.ll
===================================================================
--- llvm/test/CodeGen/AArch64/sve-masked-gather-legalize.ll
+++ llvm/test/CodeGen/AArch64/sve-masked-gather-legalize.ll
@@ -7,7 +7,7 @@
 define <vscale x 2 x i32> @masked_gather_nxv2i32(<vscale x 2 x i32*> %ptrs, <vscale x 2 x i1> %mask) {
 ; CHECK-LABEL: masked_gather_nxv2i32:
 ; CHECK-DAG: mov x8, xzr
-; CHECK-DAG: ld1w { z0.d }, p0/z, [x8, z0.d]
+; CHECK-DAG: ld1sw { z0.d }, p0/z, [x8, z0.d]
 ; CHECK: ret
   %data = call <vscale x 2 x i32> @llvm.masked.gather.nxv2i32(<vscale x 2 x i32*> %ptrs, i32 4, <vscale x 2 x i1> %mask, <vscale x 2 x i32> undef)
   ret <vscale x 2 x i32> %data
@@ -41,8 +41,8 @@
 ; CHECK-NEXT: mov x8, xzr
 ; CHECK-NEXT: zip2 p2.s, p0.s, p1.s
 ; CHECK-NEXT: zip1 p0.s, p0.s, p1.s
-; CHECK-NEXT: ld1b { z1.d }, p2/z, [x8, z1.d]
-; CHECK-NEXT: ld1b { z0.d }, p0/z, [x8, z0.d]
+; CHECK-NEXT: ld1sb { z1.d }, p2/z, [x8, z1.d]
+; CHECK-NEXT: ld1sb { z0.d }, p0/z, [x8, z0.d]
 ; CHECK-NEXT: ptrue p0.s
 ; CHECK-NEXT: uzp1 z0.s, z0.s, z1.s
 ; CHECK-NEXT: sxtb z0.s, p0/m, z0.s