diff --git a/llvm/include/llvm/Analysis/MemoryLocation.h b/llvm/include/llvm/Analysis/MemoryLocation.h --- a/llvm/include/llvm/Analysis/MemoryLocation.h +++ b/llvm/include/llvm/Analysis/MemoryLocation.h @@ -19,6 +19,7 @@ #include "llvm/ADT/Optional.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Metadata.h" +#include "llvm/Support/TypeSize.h" namespace llvm { @@ -240,6 +241,12 @@ return getForArgument(Call, ArgIdx, &TLI); } + // Return the exact size if the exact size is known at compile time, + // otherwise return MemoryLocation::UnknownSize. + static uint64_t getSizeOrUnknown(const TypeSize &T) { + return T.isScalable() ? UnknownSize : T.getFixedSize(); + } + explicit MemoryLocation(const Value *Ptr = nullptr, LocationSize Size = LocationSize::unknown(), const AAMDNodes &AATags = AAMDNodes()) diff --git a/llvm/lib/Analysis/Loads.cpp b/llvm/lib/Analysis/Loads.cpp --- a/llvm/lib/Analysis/Loads.cpp +++ b/llvm/lib/Analysis/Loads.cpp @@ -140,7 +140,9 @@ const DataLayout &DL, const Instruction *CtxI, const DominatorTree *DT) { - if (!Ty->isSized()) + // For unsized types or scalable vectors we don't know exactly how many bytes + // are dereferenced, so bail out. + if (!Ty->isSized() || (Ty->isVectorTy() && Ty->getVectorIsScalable())) return false; // When dereferenceability information is provided by a dereferenceable diff --git a/llvm/lib/CodeGen/CodeGenPrepare.cpp b/llvm/lib/CodeGen/CodeGenPrepare.cpp --- a/llvm/lib/CodeGen/CodeGenPrepare.cpp +++ b/llvm/lib/CodeGen/CodeGenPrepare.cpp @@ -6809,6 +6809,14 @@ const TargetLowering &TLI) { // Handle simple but common cases only. Type *StoreType = SI.getValueOperand()->getType(); + + // The code below assumes shifting a value by <number of bits>, + // whereas scalable vectors would have to be shifted by + // <2log(vscale) + number of bits> in order to store the + // low/high parts. Bailing out for now. 
+ if (StoreType->isVectorTy() && StoreType->getVectorIsScalable()) + return false; + if (!DL.typeSizeEqualsStoreSize(StoreType) || DL.getTypeSizeInBits(StoreType) == 0) return false; diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp --- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -15738,7 +15738,14 @@ if (OptLevel == CodeGenOpt::None || !EnableStoreMerging) return false; + // TODO: Extend this function to merge stores of scalable vectors. + // (i.e. two <vscale x 8 x i8> stores can be merged to one <vscale x 16 x i8> + // store since we know <vscale x 16 x i8> is exactly twice as large as + // <vscale x 8 x i8>). Until then, bail out for scalable vectors. EVT MemVT = St->getMemoryVT(); + if (MemVT.isScalableVector()) + return false; + int64_t ElementSizeBytes = MemVT.getStoreSize(); unsigned NumMemElts = MemVT.isVector() ? MemVT.getVectorNumElements() : 1; @@ -20842,9 +20849,11 @@ : (LSN->getAddressingMode() == ISD::PRE_DEC) ? -1 * C->getSExtValue() : 0; + uint64_t Size = + MemoryLocation::getSizeOrUnknown(LSN->getMemoryVT().getStoreSize()); return {LSN->isVolatile(), LSN->isAtomic(), LSN->getBasePtr(), Offset /*base offset*/, - Optional(LSN->getMemoryVT().getStoreSize()), + Optional(Size), LSN->getMemOperand()}; } if (const auto *LN = cast(N)) @@ -21124,6 +21133,12 @@ if (BasePtr.getBase().isUndef()) return false; + // BaseIndexOffset assumes that offsets are fixed-size, which + // is not valid for scalable vectors where the offsets are + // scaled by `vscale`, so bail out early. + if (St->getMemoryVT().isScalableVector()) + return false; + // Add ST's interval. 
Intervals.insert(0, (St->getMemoryVT().getSizeInBits() + 7) / 8, Unit); diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp @@ -6824,9 +6824,10 @@ if (PtrInfo.V.isNull()) PtrInfo = InferPointerInfo(PtrInfo, *this, Ptr, Offset); + uint64_t Size = MemoryLocation::getSizeOrUnknown(MemVT.getStoreSize()); MachineFunction &MF = getMachineFunction(); MachineMemOperand *MMO = MF.getMachineMemOperand( - PtrInfo, MMOFlags, MemVT.getStoreSize(), Alignment, AAInfo, Ranges); + PtrInfo, MMOFlags, Size, Alignment, AAInfo, Ranges); return getLoad(AM, ExtType, VT, dl, Chain, Ptr, Offset, MemVT, MMO); } @@ -6946,8 +6947,10 @@ PtrInfo = InferPointerInfo(PtrInfo, *this, Ptr); MachineFunction &MF = getMachineFunction(); - MachineMemOperand *MMO = MF.getMachineMemOperand( - PtrInfo, MMOFlags, Val.getValueType().getStoreSize(), Alignment, AAInfo); + uint64_t Size = + MemoryLocation::getSizeOrUnknown(Val.getValueType().getStoreSize()); + MachineMemOperand *MMO = + MF.getMachineMemOperand(PtrInfo, MMOFlags, Size, Alignment, AAInfo); return getStore(Chain, dl, Val, Ptr, MMO); } diff --git a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp --- a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp @@ -220,6 +220,8 @@ void SelectLoadLane(SDNode *N, unsigned NumVecs, unsigned Opc); void SelectPostLoadLane(SDNode *N, unsigned NumVecs, unsigned Opc); + bool SelectAddrModeFrameIndexSVE(SDValue N, SDValue &Base, SDValue &OffImm); + void SelectStore(SDNode *N, unsigned NumVecs, unsigned Opc); void SelectPostStore(SDNode *N, unsigned NumVecs, unsigned Opc); void SelectStoreLane(SDNode *N, unsigned NumVecs, unsigned Opc); @@ -1374,6 +1376,23 @@ ReplaceNode(N, St); } +bool AArch64DAGToDAGISel::SelectAddrModeFrameIndexSVE(SDValue N, SDValue 
&Base, + SDValue &OffImm) { + SDLoc dl(N); + const DataLayout &DL = CurDAG->getDataLayout(); + const TargetLowering *TLI = getTargetLowering(); + + // Match a plain frame index: use it as the base with a zero immediate offset. + if (auto FINode = dyn_cast(N)) { + int FI = FINode->getIndex(); + Base = CurDAG->getTargetFrameIndex(FI, TLI->getPointerTy(DL)); + OffImm = CurDAG->getTargetConstant(0, dl, MVT::i64); + return true; + } + + return false; +} + void AArch64DAGToDAGISel::SelectPostStore(SDNode *N, unsigned NumVecs, unsigned Opc) { SDLoc dl(N); diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -9458,6 +9458,10 @@ if (AM.HasBaseReg && AM.BaseOffs && AM.Scale) return false; + // FIXME: Update this method to support scalable addressing modes. + if (Ty->isVectorTy() && Ty->getVectorIsScalable()) + return AM.HasBaseReg && !AM.BaseOffs && !AM.Scale; + // check reg + imm case: // i.e., reg + 0, reg + imm9, reg + SIZE_IN_BYTES * uimm12 uint64_t NumBytes = 0; diff --git a/llvm/lib/Target/AArch64/AArch64InstrFormats.td b/llvm/lib/Target/AArch64/AArch64InstrFormats.td --- a/llvm/lib/Target/AArch64/AArch64InstrFormats.td +++ b/llvm/lib/Target/AArch64/AArch64InstrFormats.td @@ -349,6 +349,8 @@ let PrintMethod = "printImmScale<16>"; } +def am_sve_fi : ComplexPattern; + def am_indexed7s8 : ComplexPattern; def am_indexed7s16 : ComplexPattern; def am_indexed7s32 : ComplexPattern; diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp --- a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp +++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp @@ -1691,6 +1691,8 @@ case AArch64::STRSui: case AArch64::STRDui: case AArch64::STRQui: + case AArch64::LDR_PXI: + case AArch64::STR_PXI: if (MI.getOperand(0).getSubReg() == 0 && MI.getOperand(1).isFI() && MI.getOperand(2).isImm() && MI.getOperand(2).getImm() == 0) { 
FrameIndex = MI.getOperand(1).getIndex(); @@ -1803,9 +1805,19 @@ case AArch64::STNPSi: case AArch64::LDG: case AArch64::STGPi: + case AArch64::LD1B_IMM: + case AArch64::LD1H_IMM: + case AArch64::LD1W_IMM: + case AArch64::LD1D_IMM: + case AArch64::ST1B_IMM: + case AArch64::ST1H_IMM: + case AArch64::ST1W_IMM: + case AArch64::ST1D_IMM: return 3; case AArch64::ADDG: case AArch64::STGOffset: + case AArch64::LDR_PXI: + case AArch64::STR_PXI: return 2; } } @@ -2056,6 +2068,7 @@ bool AArch64InstrInfo::getMemOpInfo(unsigned Opcode, unsigned &Scale, unsigned &Width, int64_t &MinOffset, int64_t &MaxOffset) { + const unsigned SVEMaxBytesPerVector = AArch64::SVEMaxBitsPerVector / 8; switch (Opcode) { // Not a memory operation or something we want to handle. default: @@ -2220,16 +2233,33 @@ break; case AArch64::LDR_PXI: case AArch64::STR_PXI: - Scale = Width = 2; + Scale = 2; + Width = SVEMaxBytesPerVector / 8; MinOffset = -256; MaxOffset = 255; break; case AArch64::LDR_ZXI: case AArch64::STR_ZXI: - Scale = Width = 16; + Scale = 16; + Width = SVEMaxBytesPerVector; MinOffset = -256; MaxOffset = 255; break; + case AArch64::LD1B_IMM: + case AArch64::LD1H_IMM: + case AArch64::LD1W_IMM: + case AArch64::LD1D_IMM: + case AArch64::ST1B_IMM: + case AArch64::ST1H_IMM: + case AArch64::ST1W_IMM: + case AArch64::ST1D_IMM: + // A full vector's worth of data + // Width = mbytes * elements + Scale = 16; + Width = SVEMaxBytesPerVector; + MinOffset = -8; + MaxOffset = 7; + break; case AArch64::ST2GOffset: case AArch64::STZ2GOffset: Scale = 16; @@ -3433,6 +3463,14 @@ case AArch64::STR_ZXI: case AArch64::LDR_PXI: case AArch64::STR_PXI: + case AArch64::LD1B_IMM: + case AArch64::LD1H_IMM: + case AArch64::LD1W_IMM: + case AArch64::LD1D_IMM: + case AArch64::ST1B_IMM: + case AArch64::ST1H_IMM: + case AArch64::ST1W_IMM: + case AArch64::ST1D_IMM: return true; default: return false; diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td --- 
a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td +++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td @@ -1261,6 +1261,52 @@ defm : pred_store; defm : pred_store; + multiclass unpred_store { + def _fi : Pat<(store (Ty ZPR:$val), (am_sve_fi GPR64sp:$base, simm4s1:$offset)), + (RegImmInst ZPR:$val, (PTrue 31), GPR64sp:$base, simm4s1:$offset)>; + } + + defm Pat_ST1B : unpred_store; + defm Pat_ST1H : unpred_store; + defm Pat_ST1W : unpred_store; + defm Pat_ST1D : unpred_store; + defm Pat_ST1H_float16: unpred_store; + defm Pat_ST1W_float : unpred_store; + defm Pat_ST1D_double : unpred_store; + + multiclass unpred_load { + def _fi : Pat<(Ty (load (am_sve_fi GPR64sp:$base, simm4s1:$offset))), + (RegImmInst (PTrue 31), GPR64sp:$base, simm4s1:$offset)>; + } + + defm Pat_LD1B : unpred_load; + defm Pat_LD1H : unpred_load; + defm Pat_LD1W : unpred_load; + defm Pat_LD1D : unpred_load; + defm Pat_LD1H_float16: unpred_load; + defm Pat_LD1W_float : unpred_load; + defm Pat_LD1D_double : unpred_load; + + multiclass unpred_store_predicate { + def _fi : Pat<(store (Ty PPR:$val), (am_sve_fi GPR64sp:$base, simm9:$offset)), + (Store PPR:$val, GPR64sp:$base, simm9:$offset)>; + } + + defm Pat_Store_P16 : unpred_store_predicate; + defm Pat_Store_P8 : unpred_store_predicate; + defm Pat_Store_P4 : unpred_store_predicate; + defm Pat_Store_P2 : unpred_store_predicate; + + multiclass unpred_load_predicate { + def _fi : Pat<(Ty (load (am_sve_fi GPR64sp:$base, simm9:$offset))), + (Load GPR64sp:$base, simm9:$offset)>; + } + + defm Pat_Load_P16 : unpred_load_predicate; + defm Pat_Load_P8 : unpred_load_predicate; + defm Pat_Load_P4 : unpred_load_predicate; + defm Pat_Load_P2 : unpred_load_predicate; + multiclass ldnf1 { // base def : Pat<(Ty (Load (PredTy PPR:$gp), GPR64:$base, MemVT)), diff --git a/llvm/lib/Target/AArch64/Utils/AArch64BaseInfo.h b/llvm/lib/Target/AArch64/Utils/AArch64BaseInfo.h --- a/llvm/lib/Target/AArch64/Utils/AArch64BaseInfo.h +++ 
b/llvm/lib/Target/AArch64/Utils/AArch64BaseInfo.h @@ -658,6 +658,7 @@ // in index i*P of a vector. The other elements of the // vector (such as index 1) are undefined. static constexpr unsigned SVEBitsPerBlock = 128; +static constexpr unsigned SVEMaxBitsPerVector = 2048; const unsigned NeonBitsPerVector = 128; } // end namespace AArch64 } // end namespace llvm diff --git a/llvm/test/CodeGen/AArch64/spillfill-sve.ll b/llvm/test/CodeGen/AArch64/spillfill-sve.ll new file mode 100644 --- /dev/null +++ b/llvm/test/CodeGen/AArch64/spillfill-sve.ll @@ -0,0 +1,189 @@ +; RUN: llc -mtriple=aarch64-none-linux-gnu -mattr=+sve < %s | FileCheck %s + +; This file checks that unpredicated load/store instructions to locals +; use the right instructions and offsets. + +; Data fills + +define void @fill_nxv16i8() { +; CHECK-LABEL: fill_nxv16i8 +; CHECK-DAG: ld1b { z{{[01]}}.b }, p0/z, [sp] +; CHECK-DAG: ld1b { z{{[01]}}.b }, p0/z, [sp, #1, mul vl] + %local0 = alloca + %local1 = alloca + load volatile , * %local0 + load volatile , * %local1 + ret void +} + +define void @fill_nxv8i16() { +; CHECK-LABEL: fill_nxv8i16 +; CHECK-DAG: ld1h { z{{[01]}}.h }, p0/z, [sp] +; CHECK-DAG: ld1h { z{{[01]}}.h }, p0/z, [sp, #1, mul vl] + %local0 = alloca + %local1 = alloca + load volatile , * %local0 + load volatile , * %local1 + ret void +} + +define void @fill_nxv4i32() { +; CHECK-LABEL: fill_nxv4i32 +; CHECK-DAG: ld1w { z{{[01]}}.s }, p0/z, [sp] +; CHECK-DAG: ld1w { z{{[01]}}.s }, p0/z, [sp, #1, mul vl] + %local0 = alloca + %local1 = alloca + load volatile , * %local0 + load volatile , * %local1 + ret void +} + +define void @fill_nxv2i64() { +; CHECK-LABEL: fill_nxv2i64 +; CHECK-DAG: ld1d { z{{[01]}}.d }, p0/z, [sp] +; CHECK-DAG: ld1d { z{{[01]}}.d }, p0/z, [sp, #1, mul vl] + %local0 = alloca + %local1 = alloca + load volatile , * %local0 + load volatile , * %local1 + ret void +} + + +; Data spills + +define void @spill_nxv16i8( %v0, %v1) { +; CHECK-LABEL: spill_nxv16i8 +; CHECK-DAG: st1b { 
z{{[01]}}.b }, p0, [sp] +; CHECK-DAG: st1b { z{{[01]}}.b }, p0, [sp, #1, mul vl] + %local0 = alloca + %local1 = alloca + store volatile %v0, * %local0 + store volatile %v1, * %local1 + ret void +} + +define void @spill_nxv8i16( %v0, %v1) { +; CHECK-LABEL: spill_nxv8i16 +; CHECK-DAG: st1h { z{{[01]}}.h }, p0, [sp] +; CHECK-DAG: st1h { z{{[01]}}.h }, p0, [sp, #1, mul vl] + %local0 = alloca + %local1 = alloca + store volatile %v0, * %local0 + store volatile %v1, * %local1 + ret void +} + +define void @spill_nxv4i32( %v0, %v1) { +; CHECK-LABEL: spill_nxv4i32 +; CHECK-DAG: st1w { z{{[01]}}.s }, p0, [sp] +; CHECK-DAG: st1w { z{{[01]}}.s }, p0, [sp, #1, mul vl] + %local0 = alloca + %local1 = alloca + store volatile %v0, * %local0 + store volatile %v1, * %local1 + ret void +} + +define void @spill_nxv2i64( %v0, %v1) { +; CHECK-LABEL: spill_nxv2i64 +; CHECK-DAG: st1d { z{{[01]}}.d }, p0, [sp] +; CHECK-DAG: st1d { z{{[01]}}.d }, p0, [sp, #1, mul vl] + %local0 = alloca + %local1 = alloca + store volatile %v0, * %local0 + store volatile %v1, * %local1 + ret void +} + +; Predicate fills + +define void @fill_nxv16i1() { +; CHECK-LABEL: fill_nxv16i1 +; CHECK-DAG: ldr p{{[01]}}, [sp, #8, mul vl] +; CHECK-DAG: ldr p{{[01]}}, [sp] + %local0 = alloca + %local1 = alloca + load volatile , * %local0 + load volatile , * %local1 + ret void +} + +define void @fill_nxv8i1() { +; CHECK-LABEL: fill_nxv8i1 +; CHECK-DAG: ldr p{{[01]}}, [sp, #4, mul vl] +; CHECK-DAG: ldr p{{[01]}}, [sp] + %local0 = alloca + %local1 = alloca + load volatile , * %local0 + load volatile , * %local1 + ret void +} + +define void @fill_nxv4i1() { +; CHECK-LABEL: fill_nxv4i1 +; CHECK-DAG: ldr p{{[01]}}, [sp, #6, mul vl] +; CHECK-DAG: ldr p{{[01]}}, [sp, #4, mul vl] + %local0 = alloca + %local1 = alloca + load volatile , * %local0 + load volatile , * %local1 + ret void +} + +define void @fill_nxv2i1() { +; CHECK-LABEL: fill_nxv2i1 +; CHECK-DAG: ldr p{{[01]}}, [sp, #7, mul vl] +; CHECK-DAG: ldr p{{[01]}}, [sp, #6, mul 
vl] + %local0 = alloca + %local1 = alloca + load volatile , * %local0 + load volatile , * %local1 + ret void +} + +; Predicate spills + +define void @spill_nxv16i1( %v0, %v1) { +; CHECK-LABEL: spill_nxv16i1 +; CHECK-DAG: str p{{[01]}}, [sp, #8, mul vl] +; CHECK-DAG: str p{{[01]}}, [sp] + %local0 = alloca + %local1 = alloca + store volatile %v0, * %local0 + store volatile %v1, * %local1 + ret void +} + +define void @spill_nxv8i1( %v0, %v1) { +; CHECK-LABEL: spill_nxv8i1 +; CHECK-DAG: str p{{[01]}}, [sp, #4, mul vl] +; CHECK-DAG: str p{{[01]}}, [sp] + %local0 = alloca + %local1 = alloca + store volatile %v0, * %local0 + store volatile %v1, * %local1 + ret void +} + +define void @spill_nxv4i1( %v0, %v1) { +; CHECK-LABEL: spill_nxv4i1 +; CHECK-DAG: str p{{[01]}}, [sp, #6, mul vl] +; CHECK-DAG: str p{{[01]}}, [sp, #4, mul vl] + %local0 = alloca + %local1 = alloca + store volatile %v0, * %local0 + store volatile %v1, * %local1 + ret void +} + +define void @spill_nxv2i1( %v0, %v1) { +; CHECK-LABEL: spill_nxv2i1 +; CHECK-DAG: str p{{[01]}}, [sp, #7, mul vl] +; CHECK-DAG: str p{{[01]}}, [sp, #6, mul vl] + %local0 = alloca + %local1 = alloca + store volatile %v0, * %local0 + store volatile %v1, * %local1 + ret void +}