diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -835,6 +835,8 @@
   if (Subtarget->supportsAddressTopByteIgnored())
     setTargetDAGCombine(ISD::LOAD);
 
+  setTargetDAGCombine(ISD::MSCATTER);
+
   setTargetDAGCombine(ISD::MUL);
 
   setTargetDAGCombine(ISD::SELECT);
@@ -13944,6 +13946,44 @@
   return SDValue();
 }
 
+static SDValue performMSCATTERCombine(SDNode *N,
+                                      TargetLowering::DAGCombinerInfo &DCI,
+                                      SelectionDAG &DAG) {
+  MaskedScatterSDNode *MSC = cast<MaskedScatterSDNode>(N);
+  assert(MSC && "Can only combine scatter store nodes");
+
+  SDLoc DL(MSC);
+  SDValue Chain = MSC->getChain();
+  SDValue Scale = MSC->getScale();
+  SDValue Index = MSC->getIndex();
+  SDValue Data = MSC->getValue();
+  SDValue Mask = MSC->getMask();
+  SDValue BasePtr = MSC->getBasePtr();
+  ISD::MemIndexType IndexType = MSC->getIndexType();
+
+  EVT IdxVT = Index.getValueType();
+
+  if (DCI.isBeforeLegalize()) {
+    // SVE gather/scatter requires indices of i32/i64. Promote anything smaller
+    // prior to legalisation so the result can be split if required.
+    if ((IdxVT.getVectorElementType() == MVT::i8) ||
+        (IdxVT.getVectorElementType() == MVT::i16)) {
+      EVT NewIdxVT = IdxVT.changeVectorElementType(MVT::i32);
+      if (MSC->isIndexSigned())
+        Index = DAG.getNode(ISD::SIGN_EXTEND, DL, NewIdxVT, Index);
+      else
+        Index = DAG.getNode(ISD::ZERO_EXTEND, DL, NewIdxVT, Index);
+
+      SDValue Ops[] = { Chain, Data, Mask, BasePtr, Index, Scale };
+      return DAG.getMaskedScatter(DAG.getVTList(MVT::Other),
+                                  MSC->getMemoryVT(), DL, Ops,
+                                  MSC->getMemOperand(), IndexType,
+                                  MSC->isTruncatingStore());
+    }
+  }
+
+  return SDValue();
+}
 
 /// Target-specific DAG combine function for NEON load/store intrinsics
 /// to merge base address updates.
@@ -15136,6 +15176,8 @@
     break;
   case ISD::STORE:
     return performSTORECombine(N, DCI, DAG, Subtarget);
+  case ISD::MSCATTER:
+    return performMSCATTERCombine(N, DCI, DAG);
   case AArch64ISD::BRCOND:
     return performBRCONDCombine(N, DCI, DAG);
   case AArch64ISD::TBNZ:
diff --git a/llvm/test/CodeGen/AArch64/sve-masked-scatter-legalise.ll b/llvm/test/CodeGen/AArch64/sve-masked-scatter-legalise.ll
new file
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sve-masked-scatter-legalise.ll
@@ -0,0 +1,59 @@
+; RUN: llc -mtriple=aarch64--linux-gnu -mattr=+sve < %s | FileCheck %s
+
+; Tests that exercise various type legalisation scenarios for ISD::MSCATTER.
+
+; Code generate the scenario where the offset vector type is illegal.
+define void @masked_scatter_nxv16i8(<vscale x 16 x i8> %data, i8* %base, <vscale x 16 x i8> %offsets, <vscale x 16 x i1> %mask) {
+; CHECK-LABEL: masked_scatter_nxv16i8:
+; CHECK-DAG: st1b { {{z[0-9]+}}.s }, {{p[0-9]+}}, [x0, {{z[0-9]+}}.s, sxtw]
+; CHECK-DAG: st1b { {{z[0-9]+}}.s }, {{p[0-9]+}}, [x0, {{z[0-9]+}}.s, sxtw]
+; CHECK-DAG: st1b { {{z[0-9]+}}.s }, {{p[0-9]+}}, [x0, {{z[0-9]+}}.s, sxtw]
+; CHECK-DAG: st1b { {{z[0-9]+}}.s }, {{p[0-9]+}}, [x0, {{z[0-9]+}}.s, sxtw]
+; CHECK: ret
+  %ptrs = getelementptr i8, i8* %base, <vscale x 16 x i8> %offsets
+  call void @llvm.masked.scatter.nxv16i8(<vscale x 16 x i8> %data, <vscale x 16 x i8*> %ptrs, i32 1, <vscale x 16 x i1> %mask)
+  ret void
+}
+
+define void @masked_scatter_nxv8i16(<vscale x 8 x i16> %data, i16* %base, <vscale x 8 x i16> %offsets, <vscale x 8 x i1> %mask) {
+; CHECK-LABEL: masked_scatter_nxv8i16
+; CHECK-DAG: st1h { {{z[0-9]+}}.s }, {{p[0-9]+}}, [x0, {{z[0-9]+}}.s, sxtw #1]
+; CHECK-DAG: st1h { {{z[0-9]+}}.s }, {{p[0-9]+}}, [x0, {{z[0-9]+}}.s, sxtw #1]
+; CHECK: ret
+  %ptrs = getelementptr i16, i16* %base, <vscale x 8 x i16> %offsets
+  call void @llvm.masked.scatter.nxv8i16(<vscale x 8 x i16> %data, <vscale x 8 x i16*> %ptrs, i32 1, <vscale x 8 x i1> %mask)
+  ret void
+}
+
+define void @masked_scatter_nxv8f32(<vscale x 8 x float> %data, float* %base, <vscale x 8 x i32> %indexes, <vscale x 8 x i1> %masks) {
+; CHECK-LABEL: masked_scatter_nxv8f32
+; CHECK-DAG: st1w { z0.s }, {{p[0-9]+}}, [x0, {{z[0-9]+}}.s, uxtw #2]
+; CHECK-DAG: st1w { z1.s }, {{p[0-9]+}}, [x0, {{z[0-9]+}}.s, uxtw #2]
+  %ext = zext <vscale x 8 x i32> %indexes to <vscale x 8 x i64>
+  %ptrs = getelementptr float, float* %base, <vscale x 8 x i64> %ext
+  call void @llvm.masked.scatter.nxv8f32(<vscale x 8 x float> %data, <vscale x 8 x float*> %ptrs, i32 0, <vscale x 8 x i1> %masks)
+  ret void
+}
+
+; Code generate the worst case scenario when all vector types are illegal.
+define void @masked_scatter_nxv32i32(<vscale x 32 x i32> %data, i32* %base, <vscale x 32 x i32> %offsets, <vscale x 32 x i1> %mask) {
+; CHECK-LABEL: masked_scatter_nxv32i32:
+; CHECK-NOT: unpkhi
+; CHECK-DAG: st1w { z0.s }, {{p[0-9]+}}, [x0, {{z[0-9]+}}.s, sxtw #2]
+; CHECK-DAG: st1w { z1.s }, {{p[0-9]+}}, [x0, {{z[0-9]+}}.s, sxtw #2]
+; CHECK-DAG: st1w { z2.s }, {{p[0-9]+}}, [x0, {{z[0-9]+}}.s, sxtw #2]
+; CHECK-DAG: st1w { z3.s }, {{p[0-9]+}}, [x0, {{z[0-9]+}}.s, sxtw #2]
+; CHECK-DAG: st1w { z4.s }, {{p[0-9]+}}, [x0, {{z[0-9]+}}.s, sxtw #2]
+; CHECK-DAG: st1w { z5.s }, {{p[0-9]+}}, [x0, {{z[0-9]+}}.s, sxtw #2]
+; CHECK-DAG: st1w { z6.s }, {{p[0-9]+}}, [x0, {{z[0-9]+}}.s, sxtw #2]
+; CHECK-DAG: st1w { z7.s }, {{p[0-9]+}}, [x0, {{z[0-9]+}}.s, sxtw #2]
+; CHECK: ret
+  %ptrs = getelementptr i32, i32* %base, <vscale x 32 x i32> %offsets
+  call void @llvm.masked.scatter.nxv32i32(<vscale x 32 x i32> %data, <vscale x 32 x i32*> %ptrs, i32 4, <vscale x 32 x i1> %mask)
+  ret void
+}
+
+declare void @llvm.masked.scatter.nxv16i8(<vscale x 16 x i8>, <vscale x 16 x i8*>, i32, <vscale x 16 x i1>)
+declare void @llvm.masked.scatter.nxv8i16(<vscale x 8 x i16>, <vscale x 8 x i16*>, i32, <vscale x 8 x i1>)
+declare void @llvm.masked.scatter.nxv8f32(<vscale x 8 x float>, <vscale x 8 x float*>, i32, <vscale x 8 x i1>)
+declare void @llvm.masked.scatter.nxv32i32(<vscale x 32 x i32>, <vscale x 32 x i32*>, i32, <vscale x 32 x i1>)
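
For reference, the IR below sketches the effect of performMSCATTERCombine on the nxv16i8 case: before legalisation the narrow (i8) offsets are widened to i32 so the scatter can later be split into legal nxv4i32-indexed pieces. This snippet is illustrative only and is not part of the patch; the function name @scatter_widened_offsets is hypothetical, and sign-extension is assumed because the combine uses ISD::SIGN_EXTEND when MSC->isIndexSigned() is true.

; Illustrative sketch only (not part of the patch): the scatter as if the
; offsets had been widened to i32 before address computation, which is what
; the new DAG combine arranges at the SelectionDAG level.
define void @scatter_widened_offsets(<vscale x 16 x i8> %data, i8* %base,
                                     <vscale x 16 x i8> %offsets,
                                     <vscale x 16 x i1> %mask) {
  ; Sign-extend the i8 offsets to i32 (assumption: signed indices).
  %offsets.ext = sext <vscale x 16 x i8> %offsets to <vscale x 16 x i32>
  %ptrs = getelementptr i8, i8* %base, <vscale x 16 x i32> %offsets.ext
  call void @llvm.masked.scatter.nxv16i8(<vscale x 16 x i8> %data, <vscale x 16 x i8*> %ptrs, i32 1, <vscale x 16 x i1> %mask)
  ret void
}

declare void @llvm.masked.scatter.nxv16i8(<vscale x 16 x i8>, <vscale x 16 x i8*>, i32, <vscale x 16 x i1>)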