diff --git a/llvm/include/llvm/CodeGen/ValueTypes.h b/llvm/include/llvm/CodeGen/ValueTypes.h
--- a/llvm/include/llvm/CodeGen/ValueTypes.h
+++ b/llvm/include/llvm/CodeGen/ValueTypes.h
@@ -364,6 +364,12 @@
     return {(BaseSize.getKnownMinSize() + 7) / 8, BaseSize.isScalable()};
   }

+  // Return the number of bytes overwritten by a store of this value type or
+  // this value type's element type in the case of a vector.
+  uint64_t getScalarStoreSize() const {
+    return getScalarType().getStoreSize().getFixedSize();
+  }
+
   /// Return the number of bits overwritten by a store of the specified value
   /// type.
   ///
diff --git a/llvm/include/llvm/Support/MachineValueType.h b/llvm/include/llvm/Support/MachineValueType.h
--- a/llvm/include/llvm/Support/MachineValueType.h
+++ b/llvm/include/llvm/Support/MachineValueType.h
@@ -1078,6 +1078,12 @@
     return {(BaseSize.getKnownMinSize() + 7) / 8, BaseSize.isScalable()};
   }

+  // Return the number of bytes overwritten by a store of this value type or
+  // this value type's element type in the case of a vector.
+  uint64_t getScalarStoreSize() const {
+    return getScalarType().getStoreSize().getFixedSize();
+  }
+
   /// Return the number of bits overwritten by a store of the specified value
   /// type.
   ///
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -4400,9 +4400,14 @@
   Base = SDB->getValue(BasePtr);
   Index = SDB->getValue(IndexVal);
   IndexType = ISD::SIGNED_SCALED;
-  Scale = DAG.getTargetConstant(
-      DL.getTypeAllocSize(GEP->getResultElementType()),
-      SDB->getCurSDLoc(), TLI.getPointerTy(DL));
+
+  // MGATHER/MSCATTER only support scaling by a power-of-two.
+  uint64_t ScaleVal = DL.getTypeAllocSize(GEP->getResultElementType());
+  if (!isPowerOf2_64(ScaleVal))
+    return false;
+
+  Scale =
+      DAG.getTargetConstant(ScaleVal, SDB->getCurSDLoc(), TLI.getPointerTy(DL));
   return true;
 }

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -4650,33 +4650,50 @@
 SDValue AArch64TargetLowering::LowerMGATHER(SDValue Op,
                                             SelectionDAG &DAG) const {
-  SDLoc DL(Op);
   MaskedGatherSDNode *MGT = cast<MaskedGatherSDNode>(Op);
-  assert(MGT && "Can only custom lower gather load nodes");
-
-  bool IsFixedLength = MGT->getMemoryVT().isFixedLengthVector();
-  SDValue Index = MGT->getIndex();
+  SDLoc DL(Op);
   SDValue Chain = MGT->getChain();
   SDValue PassThru = MGT->getPassThru();
   SDValue Mask = MGT->getMask();
   SDValue BasePtr = MGT->getBasePtr();
-  ISD::LoadExtType ExtTy = MGT->getExtensionType();
-
+  SDValue Index = MGT->getIndex();
+  SDValue Scale = MGT->getScale();
+  EVT VT = Op.getValueType();
+  EVT MemVT = MGT->getMemoryVT();
+  ISD::LoadExtType ExtType = MGT->getExtensionType();
   ISD::MemIndexType IndexType = MGT->getIndexType();
+  bool IsScaled = IndexType == ISD::SIGNED_SCALED || IndexType == ISD::UNSIGNED_SCALED;
   bool IsSigned = IndexType == ISD::SIGNED_SCALED || IndexType == ISD::SIGNED_UNSCALED;
+
+  // SVE supports an index scaled by sizeof(MemVT.elt) only; everything else
+  // must be calculated beforehand.
+  uint64_t ScaleVal = cast<ConstantSDNode>(Scale)->getZExtValue();
+  if (IsScaled && ScaleVal != MemVT.getScalarStoreSize()) {
+    assert(isPowerOf2_64(ScaleVal) && "Expecting power-of-two types");
+    EVT IndexVT = Index.getValueType();
+    Index = DAG.getNode(ISD::SHL, DL, IndexVT, Index,
+                        DAG.getConstant(Log2_32(ScaleVal), DL, IndexVT));
+    Scale = DAG.getTargetConstant(1, DL, Scale.getValueType());
+
+    SDValue Ops[] = {Chain, PassThru, Mask, BasePtr, Index, Scale};
+    IndexType = IsSigned ? ISD::SIGNED_UNSCALED : ISD::UNSIGNED_UNSCALED;
+    return DAG.getMaskedGather(MGT->getVTList(), MemVT, DL, Ops,
+                               MGT->getMemOperand(), IndexType, ExtType);
+  }
+
   bool IdxNeedsExtend = getGatherScatterIndexIsExtended(Index) || Index.getSimpleValueType().getVectorElementType() == MVT::i32;
-  EVT VT = PassThru.getSimpleValueType();
   EVT IndexVT = Index.getSimpleValueType();
-  EVT MemVT = MGT->getMemoryVT();
   SDValue InputVT = DAG.getValueType(MemVT);

+  bool IsFixedLength = MGT->getMemoryVT().isFixedLengthVector();
+
   if (IsFixedLength) {
     assert(Subtarget->useSVEForFixedLengthVectors() && "Cannot lower when not using SVE for fixed vectors");
@@ -4714,7 +4731,7 @@
   selectGatherScatterAddrMode(BasePtr, Index, MemVT, Opcode,
                               /*isGather=*/true, DAG);

-  if (ExtTy == ISD::SEXTLOAD)
+  if (ExtType == ISD::SEXTLOAD)
     Opcode = getSignExtendedGatherOpcode(Opcode);

   if (IsFixedLength) {
@@ -4751,33 +4768,51 @@
 SDValue AArch64TargetLowering::LowerMSCATTER(SDValue Op,
                                              SelectionDAG &DAG) const {
-  SDLoc DL(Op);
   MaskedScatterSDNode *MSC = cast<MaskedScatterSDNode>(Op);
-  assert(MSC && "Can only custom lower scatter store nodes");
-  bool IsFixedLength = MSC->getMemoryVT().isFixedLengthVector();
-
-  SDValue Index = MSC->getIndex();
+  SDLoc DL(Op);
   SDValue Chain = MSC->getChain();
   SDValue StoreVal = MSC->getValue();
   SDValue Mask = MSC->getMask();
   SDValue BasePtr = MSC->getBasePtr();
-
+  SDValue Index = MSC->getIndex();
+  SDValue Scale = MSC->getScale();
+  EVT VT = StoreVal.getValueType();
+  EVT MemVT = MSC->getMemoryVT();
   ISD::MemIndexType IndexType = MSC->getIndexType();
+  bool IsScaled = IndexType == ISD::SIGNED_SCALED || IndexType == ISD::UNSIGNED_SCALED;
   bool IsSigned = IndexType == ISD::SIGNED_SCALED || IndexType == ISD::SIGNED_UNSCALED;
+
+  // SVE supports an index scaled by sizeof(MemVT.elt) only; everything else
+  // must be calculated beforehand.
+  uint64_t ScaleVal = cast<ConstantSDNode>(Scale)->getZExtValue();
+  if (IsScaled && ScaleVal != MemVT.getScalarStoreSize()) {
+    assert(isPowerOf2_64(ScaleVal) && "Expecting power-of-two types");
+    EVT IndexVT = Index.getValueType();
+    Index = DAG.getNode(ISD::SHL, DL, IndexVT, Index,
+                        DAG.getConstant(Log2_32(ScaleVal), DL, IndexVT));
+    Scale = DAG.getTargetConstant(1, DL, Scale.getValueType());
+
+    SDValue Ops[] = {Chain, StoreVal, Mask, BasePtr, Index, Scale};
+    IndexType = IsSigned ? ISD::SIGNED_UNSCALED : ISD::UNSIGNED_UNSCALED;
+    return DAG.getMaskedScatter(MSC->getVTList(), MemVT, DL, Ops,
+                                MSC->getMemOperand(), IndexType,
+                                MSC->isTruncatingStore());
+  }
+
   bool NeedsExtend = getGatherScatterIndexIsExtended(Index) || Index.getSimpleValueType().getVectorElementType() == MVT::i32;
-  EVT VT = StoreVal.getSimpleValueType();
   EVT IndexVT = Index.getSimpleValueType();
   SDVTList VTs = DAG.getVTList(MVT::Other);
-  EVT MemVT = MSC->getMemoryVT();
   SDValue InputVT = DAG.getValueType(MemVT);

+  bool IsFixedLength = MSC->getMemoryVT().isFixedLengthVector();
+
   if (IsFixedLength) {
     assert(Subtarget->useSVEForFixedLengthVectors() && "Cannot lower when not using SVE for fixed vectors");
diff --git a/llvm/test/CodeGen/AArch64/sve-masked-gather.ll b/llvm/test/CodeGen/AArch64/sve-masked-gather.ll
--- a/llvm/test/CodeGen/AArch64/sve-masked-gather.ll
+++ b/llvm/test/CodeGen/AArch64/sve-masked-gather.ll
@@ -1,5 +1,5 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
-; RUN: llc -mtriple=aarch64--linux-gnu -mattr=+sve < %s | FileCheck %s
+; RUN: llc -mtriple=aarch64--linux-gnu -mattr=+sve -opaque-pointers < %s | FileCheck %s

 define <vscale x 2 x i64> @masked_gather_nxv2i8(<vscale x 2 x i8*> %ptrs, <vscale x 2 x i1> %mask) {
 ; CHECK-LABEL: masked_gather_nxv2i8:
@@ -127,6 +127,30 @@
   ret <vscale x 2 x i64> %vals.sext
 }

+%i64_x3 = type { i64, i64, i64}
+define <vscale x 2 x i64> @masked_gather_non_power_of_two_based_scaling(ptr %base, <vscale x 2 x i64> %offsets, <vscale x 2 x i1> %mask) {
+; CHECK-LABEL: masked_gather_non_power_of_two_based_scaling:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mul z0.d, z0.d, #24
+; CHECK-NEXT:    ld1d { z0.d }, p0/z, [x0, z0.d]
+; CHECK-NEXT:    ret
+  %ptrs = getelementptr inbounds %i64_x3, ptr %base, <vscale x 2 x i64> %offsets
+  %vals = call <vscale x 2 x i64> @llvm.masked.gather.nxv2i64(<vscale x 2 x ptr> %ptrs, i32 8, <vscale x 2 x i1> %mask, <vscale x 2 x i64> undef)
+  ret <vscale x 2 x i64> %vals
+}
+
+%i64_x4 = type { i64, i64, i64, i64}
+define <vscale x 2 x i64> @masked_gather_non_element_type_based_scaling(ptr %base, <vscale x 2 x i64> %offsets, <vscale x 2 x i1> %mask) {
+; CHECK-LABEL: masked_gather_non_element_type_based_scaling:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    lsl z0.d, z0.d, #5
+; CHECK-NEXT:    ld1d { z0.d }, p0/z, [x0, z0.d]
+; CHECK-NEXT:    ret
+  %ptrs = getelementptr inbounds %i64_x4, ptr %base, <vscale x 2 x i64> %offsets
+  %vals = call <vscale x 2 x i64> @llvm.masked.gather.nxv2i64(<vscale x 2 x ptr> %ptrs, i32 8, <vscale x 2 x i1> %mask, <vscale x 2 x i64> undef)
+  ret <vscale x 2 x i64> %vals
+}
+
 declare <vscale x 2 x i8> @llvm.masked.gather.nxv2i8(<vscale x 2 x i8*>, i32, <vscale x 2 x i1>, <vscale x 2 x i8>)
 declare <vscale x 2 x i16> @llvm.masked.gather.nxv2i16(<vscale x 2 x i16*>, i32, <vscale x 2 x i1>, <vscale x 2 x i16>)
 declare <vscale x 2 x i32> @llvm.masked.gather.nxv2i32(<vscale x 2 x i32*>, i32, <vscale x 2 x i1>, <vscale x 2 x i32>)
diff --git a/llvm/test/CodeGen/AArch64/sve-masked-scatter.ll b/llvm/test/CodeGen/AArch64/sve-masked-scatter.ll
--- a/llvm/test/CodeGen/AArch64/sve-masked-scatter.ll
+++ b/llvm/test/CodeGen/AArch64/sve-masked-scatter.ll
@@ -1,5 +1,5 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
-; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sve < %s | FileCheck %s
+; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sve -opaque-pointers < %s | FileCheck %s

 define void @masked_scatter_nxv2i8(<vscale x 2 x i8> %data, <vscale x 2 x i8*> %ptrs, <vscale x 2 x i1> %masks) nounwind {
 ; CHECK-LABEL: masked_scatter_nxv2i8:
@@ -79,17 +79,41 @@
 ; CHECK-NEXT:    mov z0.d, #0 // =0x0
 ; CHECK-NEXT:    punpklo p1.h, p0.b
 ; CHECK-NEXT:    punpkhi p0.h, p0.b
-; CHECK-NEXT:    st1w { z0.d }, p1, [x8, z0.d, lsl #2]
-; CHECK-NEXT:    st1w { z0.d }, p0, [x8, z0.d, lsl #2]
+; CHECK-NEXT:    st1w { z0.d }, p1, [z0.d]
+; CHECK-NEXT:    st1w { z0.d }, p0, [z0.d]
 ; CHECK-NEXT:    ret
 vector.body:
   call void @llvm.masked.scatter.nxv4i32.nxv4p0i32(<vscale x 4 x i32> undef,
-    <vscale x 4 x i32*> shufflevector (<vscale x 4 x i32*> insertelement (<vscale x 4 x i32*> poison, i32* undef, i32 0), <vscale x 4 x i32*> poison, <vscale x 4 x i32> zeroinitializer),
+    <vscale x 4 x i32*> shufflevector (<vscale x 4 x i32*> insertelement (<vscale x 4 x i32*> poison, i32* null, i32 0), <vscale x 4 x i32*> poison, <vscale x 4 x i32> zeroinitializer),
     i32 4, <vscale x 4 x i1> %pg)
   ret void
 }

+%i64_x3 = type { i64, i64, i64 }
+define void @masked_scatter_non_power_of_two_based_scaling(<vscale x 2 x double> %data, ptr %base, <vscale x 2 x i64> %offsets, <vscale x 2 x i1> %mask) {
+; CHECK-LABEL: masked_scatter_non_power_of_two_based_scaling:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mul z1.d, z1.d, #24
+; CHECK-NEXT:    st1d { z0.d }, p0, [x0, z1.d]
+; CHECK-NEXT:    ret
+  %ptrs = getelementptr inbounds %i64_x3, ptr %base, <vscale x 2 x i64> %offsets
+  call void @llvm.masked.scatter.nxv2f64(<vscale x 2 x double> %data, <vscale x 2 x ptr> %ptrs, i32 8, <vscale x 2 x i1> %mask)
+  ret void
+}
+
+%i64_x4 = type { i64, i64, i64, i64}
+define void @masked_scatter_non_element_type_based_scaling(<vscale x 2 x double> %data, ptr %base, <vscale x 2 x i64> %offsets, <vscale x 2 x i1> %mask) {
+; CHECK-LABEL: masked_scatter_non_element_type_based_scaling:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    lsl z1.d, z1.d, #5
+; CHECK-NEXT:    st1d { z0.d }, p0, [x0, z1.d]
+; CHECK-NEXT:    ret
+  %ptrs = getelementptr inbounds %i64_x4, ptr %base, <vscale x 2 x i64> %offsets
+  call void @llvm.masked.scatter.nxv2f64(<vscale x 2 x double> %data, <vscale x 2 x ptr> %ptrs, i32 8, <vscale x 2 x i1> %mask)
+  ret void
+}
+
 declare void @llvm.masked.scatter.nxv2f16(<vscale x 2 x half>, <vscale x 2 x half*>, i32, <vscale x 2 x i1>)
 declare void @llvm.masked.scatter.nxv2bf16(<vscale x 2 x bfloat>, <vscale x 2 x bfloat*>, i32, <vscale x 2 x i1>)
 declare void @llvm.masked.scatter.nxv2f32(<vscale x 2 x float>, <vscale x 2 x float*>, i32, <vscale x 2 x i1>)
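
For context on the numbers in the new tests: when the requested scale is a power of two but does not match the memory element store size, the lowering above folds log2(scale) into the index as a left shift and switches the operation to an unscaled index type, which is why the 32-byte-struct tests expect "lsl ..., #5" with 8-byte elements; a 24-byte struct is rejected in getUniformBase because 24 is not a power of two, so the explicit "mul ..., #24" from the GEP expansion remains. A minimal standalone C++ sketch of that arithmetic follows (hypothetical helper, not LLVM API; plain integers stand in for SDValues):

#include <cassert>
#include <cstdint>

// Fold a power-of-two scale that differs from the element store size into
// the index, mirroring the scaled-to-unscaled rewrite performed above.
static uint64_t foldScaleIntoIndex(uint64_t Index, uint64_t Scale,
                                   uint64_t ElemStoreSize) {
  if (Scale == ElemStoreSize)
    return Index;                  // Hardware scaling already matches.
  assert((Scale & (Scale - 1)) == 0 && "Expecting a power-of-two scale");
  unsigned Shift = 0;
  while ((UINT64_C(1) << Shift) != Scale)
    ++Shift;                       // Shift == log2(Scale).
  return Index << Shift;           // Byte offset with the scale folded in.
}

int main() {
  // 32-byte struct of four i64s, 8-byte elements: index 3 becomes 3 << 5 = 96,
  // matching the "lsl ..., #5" expected by the non_element_type_based tests.
  assert(foldScaleIntoIndex(3, 32, 8) == 96);
  return 0;
}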