diff --git a/llvm/include/llvm/IR/IntrinsicsAArch64.td b/llvm/include/llvm/IR/IntrinsicsAArch64.td
--- a/llvm/include/llvm/IR/IntrinsicsAArch64.td
+++ b/llvm/include/llvm/IR/IntrinsicsAArch64.td
@@ -1257,6 +1257,14 @@
 def int_aarch64_sve_stnt1 : AdvSIMD_1Vec_PredStore_Intrinsic;
 
 //
+// Prefetch
+//
+
+def int_aarch64_sve_prf : Intrinsic<[], [llvm_anyvector_ty,
+                                         llvm_ptr_ty,
+                                         llvm_i32_ty], [IntrArgMemOnly]>;
+
+//
 // Scalar to vector operations
 //
 
diff --git a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
@@ -4411,6 +4411,51 @@
   return new AArch64DAGToDAGISel(TM, OptLevel);
 }
 
+namespace {
+/// When \p PredVT is a scalable vector predicate in the form
+/// MVT::nx<M>xi1, it builds the corresponding scalable vector of
+/// integers MVT::nx<M>xi<bits> s.t. M x bits = 128. If the input
+/// PredVT is not in the form MVT::nx<M>xi1, it returns an invalid
+/// EVT.
+EVT getPackedVectorTypeFromPredicateType(LLVMContext &Ctx, EVT PredVT) {
+  if (!PredVT.isScalableVector())
+    return EVT();
+
+  if (PredVT.getVectorElementType() != MVT::i1)
+    return EVT();
+
+  const unsigned NumElts = PredVT.getVectorNumElements();
+
+  if (NumElts != 2 && NumElts != 4 && NumElts != 8 && NumElts != 16)
+    return EVT();
+
+  EVT ScalarVT = EVT::getIntegerVT(Ctx, 128 / NumElts);
+  EVT MemVT = EVT::getVectorVT(Ctx, ScalarVT, NumElts, /*IsScalable=*/true);
+  return MemVT;
+}
+
+/// Return the EVT of the data associated with a memory operation in \p
+/// Root. If such EVT cannot be retrieved, it returns an invalid EVT.
+EVT getMemVTFromNode(LLVMContext &Ctx, SDNode *Root) {
+  if (isa<MemSDNode>(Root))
+    return cast<MemSDNode>(Root)->getMemoryVT();
+
+  const unsigned Opcode = Root->getOpcode();
+  if (Opcode != ISD::INTRINSIC_VOID)
+    return EVT();
+
+  const unsigned IntNo =
+      cast<ConstantSDNode>(Root->getOperand(1))->getZExtValue();
+  if (IntNo != Intrinsic::aarch64_sve_prf)
+    return EVT();
+
+  // We are using an SVE prefetch intrinsic. Type must be inferred
+  // from the width of the predicate.
+  return getPackedVectorTypeFromPredicateType(
+      Ctx, Root->getOperand(2)->getValueType(0));
+}
+} // end anonymous namespace
+
 /// SelectAddrModeIndexedSVE - Attempt selection of the addressing mode:
 /// Base + OffImm * sizeof(MemVT) for Min >= OffImm <= Max
 /// where Root is the memory access using N for its address.
@@ -4418,9 +4463,10 @@
 bool AArch64DAGToDAGISel::SelectAddrModeIndexedSVE(SDNode *Root, SDValue N,
                                                    SDValue &Base,
                                                    SDValue &OffImm) {
-  assert(isa<MemSDNode>(Root) && "Invalid node.");
+  const EVT MemVT = getMemVTFromNode(*(CurDAG->getContext()), Root);
 
-  EVT MemVT = cast<MemSDNode>(Root)->getMemoryVT();
+  if (MemVT == EVT())
+    return false;
 
   if (N.getOpcode() != ISD::ADD)
     return false;
diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -780,6 +780,29 @@
   def PRFS_PRR : sve_mem_prfm_ss<0b101, "prfw", GPR64NoXZRshifted32>;
   def PRFD_PRR : sve_mem_prfm_ss<0b111, "prfd", GPR64NoXZRshifted64>;
 
+  multiclass sve_prefetch<SDPatternOperator prefetch, ValueType PredTy,
+                          Instruction RegImmInst, Instruction RegRegInst,
+                          ComplexPattern AddrCP> {
+    // reg + imm
+    let AddedComplexity = 2 in {
+      def _reg_imm : Pat<(prefetch (PredTy PPR_3b:$gp), (am_sve_indexed_s6 GPR64sp:$base, simm6s1:$offset), (i32 sve_prfop:$prfop)),
+                         (RegImmInst sve_prfop:$prfop, PPR_3b:$gp, GPR64:$base, simm6s1:$offset)>;
+    }
+
+    // reg + reg
+    let AddedComplexity = 1 in {
+      def _reg_reg : Pat<(prefetch (PredTy PPR_3b:$gp), (AddrCP GPR64sp:$base, GPR64:$index), (i32 sve_prfop:$prfop)),
+                         (RegRegInst sve_prfop:$prfop, PPR_3b:$gp, GPR64:$base, GPR64:$index)>;
+    }
+
+    // default fallback
+    def _default : Pat<(prefetch (PredTy PPR_3b:$gp), GPR64:$base, (i32 sve_prfop:$prfop)),
+                       (RegImmInst sve_prfop:$prfop, PPR_3b:$gp, GPR64:$base, (i64 0))>;
+  }
+
+  defm : sve_prefetch<int_aarch64_sve_prf, nxv16i1, PRFB_PRI, PRFB_PRR, am_sve_regreg_lsl0>;
+  defm : sve_prefetch<int_aarch64_sve_prf, nxv8i1,  PRFH_PRI, PRFH_PRR, am_sve_regreg_lsl1>;
+  defm : sve_prefetch<int_aarch64_sve_prf, nxv4i1,  PRFW_PRI, PRFS_PRR, am_sve_regreg_lsl2>;
+  defm : sve_prefetch<int_aarch64_sve_prf, nxv2i1,  PRFD_PRI, PRFD_PRR, am_sve_regreg_lsl3>;
+
   // Gather prefetch using scaled 32-bit offsets, e.g.
   //    prfh pldl1keep, p0, [x0, z0.s, uxtw #1]
   defm PRFB_S : sve_mem_32b_prfm_sv_scaled<0b00, "prfb", ZPR32ExtSXTW8Only, ZPR32ExtUXTW8Only>;
diff --git a/llvm/lib/Target/AArch64/SVEInstrFormats.td b/llvm/lib/Target/AArch64/SVEInstrFormats.td
--- a/llvm/lib/Target/AArch64/SVEInstrFormats.td
+++ b/llvm/lib/Target/AArch64/SVEInstrFormats.td
@@ -6592,7 +6592,6 @@
                   (!cast<Instruction>(NAME) sve_prfop:$prfop, PPR3bAny:$Pg, ZPR64:$Zn, 0), 1>;
 }
 
-
 //===----------------------------------------------------------------------===//
 // SVE Compute Vector Address Group
 //===----------------------------------------------------------------------===//
@@ -7057,6 +7056,7 @@
 
 /// Addressing modes
 def am_sve_indexed_s4 : ComplexPattern<i64, 2, "SelectAddrModeIndexedSVE<-8, 7>", [], [SDNPWantRoot]>;
+def am_sve_indexed_s6 : ComplexPattern<i64, 2, "SelectAddrModeIndexedSVE<-32, 31>", [], [SDNPWantRoot]>;
 
 def am_sve_regreg_lsl0 : ComplexPattern<i64, 2, "SelectSVERegRegAddrMode<0>", []>;
 def am_sve_regreg_lsl1 : ComplexPattern<i64, 2, "SelectSVERegRegAddrMode<1>", []>;
diff --git a/llvm/test/CodeGen/AArch64/sve-intrinsics-contiguous-prefetches.ll b/llvm/test/CodeGen/AArch64/sve-intrinsics-contiguous-prefetches.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sve-intrinsics-contiguous-prefetches.ll
@@ -0,0 +1,252 @@
+; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sve < %s | FileCheck %s
+
+;
+; Testing prfop encodings
+;
+define void @test_svprf_pldl1strm(<vscale x 16 x i1> %pg, i8* %base) {
+; CHECK-LABEL: test_svprf_pldl1strm
+; CHECK: prfb pldl1strm, p0, [x0]
+entry:
+  tail call void @llvm.aarch64.sve.prf.nxv16i1(<vscale x 16 x i1> %pg, i8* %base, i32 1)
+  ret void
+}
+
+define void @test_svprf_pldl2keep(<vscale x 16 x i1> %pg, i8* %base) {
+; CHECK-LABEL: test_svprf_pldl2keep
+; CHECK: prfb pldl2keep, p0, [x0]
+entry:
+  tail call void @llvm.aarch64.sve.prf.nxv16i1(<vscale x 16 x i1> %pg, i8* %base, i32 2)
+  ret void
+}
+
+define void @test_svprf_pldl2strm(<vscale x 16 x i1> %pg, i8* %base) {
+; CHECK-LABEL: test_svprf_pldl2strm
+; CHECK: prfb pldl2strm, p0, [x0]
+entry:
+  tail call void @llvm.aarch64.sve.prf.nxv16i1(<vscale x 16 x i1> %pg, i8* %base, i32 3)
+  ret void
+}
+
+define void @test_svprf_pldl3keep(<vscale x 16 x i1> %pg, i8* %base) {
+; CHECK-LABEL: test_svprf_pldl3keep
+; CHECK: prfb pldl3keep, p0, [x0]
+entry:
+  tail call void @llvm.aarch64.sve.prf.nxv16i1(<vscale x 16 x i1> %pg, i8* %base, i32 4)
+  ret void
+}
+
+define void @test_svprf_pldl3strm(<vscale x 16 x i1> %pg, i8* %base) {
+; CHECK-LABEL: test_svprf_pldl3strm
+; CHECK: prfb pldl3strm, p0, [x0]
+entry:
+  tail call void @llvm.aarch64.sve.prf.nxv16i1(<vscale x 16 x i1> %pg, i8* %base, i32 5)
+  ret void
+}
+
+define void @test_svprf_pstl1keep(<vscale x 16 x i1> %pg, i8* %base) {
+; CHECK-LABEL: test_svprf_pstl1keep
+; CHECK: prfb pstl1keep, p0, [x0]
+entry:
+  tail call void @llvm.aarch64.sve.prf.nxv16i1(<vscale x 16 x i1> %pg, i8* %base, i32 8)
+  ret void
+}
+
+define void @test_svprf_pstl1strm(<vscale x 16 x i1> %pg, i8* %base) {
+; CHECK-LABEL: test_svprf_pstl1strm
+; CHECK: prfb pstl1strm, p0, [x0]
+entry:
+  tail call void @llvm.aarch64.sve.prf.nxv16i1(<vscale x 16 x i1> %pg, i8* %base, i32 9)
+  ret void
+}
+
+define void @test_svprf_pstl2keep(<vscale x 16 x i1> %pg, i8* %base) {
+; CHECK-LABEL: test_svprf_pstl2keep
+; CHECK: prfb pstl2keep, p0, [x0]
+entry:
+  tail call void @llvm.aarch64.sve.prf.nxv16i1(<vscale x 16 x i1> %pg, i8* %base, i32 10)
+  ret void
+}
+
+define void @test_svprf_pstl2strm(<vscale x 16 x i1> %pg, i8* %base) {
+; CHECK-LABEL: test_svprf_pstl2strm
+; CHECK: prfb pstl2strm, p0, [x0]
+entry:
+  tail call void @llvm.aarch64.sve.prf.nxv16i1(<vscale x 16 x i1> %pg, i8* %base, i32 11)
+  ret void
+}
+
+define void @test_svprf_pstl3keep(<vscale x 16 x i1> %pg, i8* %base) {
+; CHECK-LABEL: test_svprf_pstl3keep
+; CHECK: prfb pstl3keep, p0, [x0]
+entry:
+  tail call void @llvm.aarch64.sve.prf.nxv16i1(<vscale x 16 x i1> %pg, i8* %base, i32 12)
+  ret void
+}
+
+define void @test_svprf_pstl3strm(<vscale x 16 x i1> %pg, i8* %base) {
+; CHECK-LABEL: test_svprf_pstl3strm
+; CHECK: prfb pstl3strm, p0, [x0]
+entry:
+  tail call void @llvm.aarch64.sve.prf.nxv16i1(<vscale x 16 x i1> %pg, i8* %base, i32 13)
+  ret void
+}
+
+;
+; Testing imm limits of SI form
+;
+
+define void @test_svprf_vnum_under(<vscale x 16 x i1> %pg, <vscale x 16 x i8>* %base) {
+; CHECK-LABEL: test_svprf_vnum_under
+; CHECK-NOT: prfb pstl3strm, p0, [x0, #-33, mul vl]
+entry:
+  %gep = getelementptr inbounds <vscale x 16 x i8>, <vscale x 16 x i8>* %base, i64 -33, i64 0
+  tail call void @llvm.aarch64.sve.prf.nxv16i1(<vscale x 16 x i1> %pg, i8* %gep, i32 13)
+  ret void
+}
+
+define void @test_svprf_vnum_min(<vscale x 16 x i1> %pg, <vscale x 16 x i8>* %base) {
+; CHECK-LABEL: test_svprf_vnum_min
+; CHECK: prfb pstl3strm, p0, [x0, #-32, mul vl]
+entry:
+  %gep = getelementptr inbounds <vscale x 16 x i8>, <vscale x 16 x i8>* %base, i64 -32, i64 0
+  tail call void @llvm.aarch64.sve.prf.nxv16i1(<vscale x 16 x i1> %pg, i8* %gep, i32 13)
+  ret void
+}
+
+define void @test_svprf_vnum_over(<vscale x 16 x i1> %pg, <vscale x 16 x i8>* %base) {
+; CHECK-LABEL: test_svprf_vnum_over
+; CHECK-NOT: prfb pstl3strm, p0, [x0, #32, mul vl]
+entry:
+  %gep = getelementptr inbounds <vscale x 16 x i8>, <vscale x 16 x i8>* %base, i64 32, i64 0
+  tail call void @llvm.aarch64.sve.prf.nxv16i1(<vscale x 16 x i1> %pg, i8* %gep, i32 13)
+  ret void
+}
+
+define void @test_svprf_vnum_max(<vscale x 16 x i1> %pg, <vscale x 16 x i8>* %base) {
+; CHECK-LABEL: test_svprf_vnum_max
+; CHECK: prfb pstl3strm, p0, [x0, #31, mul vl]
+entry:
+  %gep = getelementptr inbounds <vscale x 16 x i8>, <vscale x 16 x i8>* %base, i64 31, i64 0
+  tail call void @llvm.aarch64.sve.prf.nxv16i1(<vscale x 16 x i1> %pg, i8* %gep, i32 13)
+  ret void
+}
+
+;
+; scalar contiguous
+;
+
+define void @test_svprfb(<vscale x 16 x i1> %pg, i8* %base) {
+; CHECK-LABEL: test_svprfb
+; CHECK: prfb pldl1keep, p0, [x0]
+entry:
+  tail call void @llvm.aarch64.sve.prf.nxv16i1(<vscale x 16 x i1> %pg, i8* %base, i32 0)
+  ret void
+}
+
+define void @test_svprfh(<vscale x 8 x i1> %pg, i8* %base) {
+; CHECK-LABEL: test_svprfh
+; CHECK: prfh pldl1keep, p0, [x0]
+entry:
+  tail call void @llvm.aarch64.sve.prf.nxv8i1(<vscale x 8 x i1> %pg, i8* %base, i32 0)
+  ret void
+}
+
+define void @test_svprfw(<vscale x 4 x i1> %pg, i8* %base) {
+; CHECK-LABEL: test_svprfw
+; CHECK: prfw pldl1keep, p0, [x0]
+entry:
+  tail call void @llvm.aarch64.sve.prf.nxv4i1(<vscale x 4 x i1> %pg, i8* %base, i32 0)
+  ret void
+}
+
+define void @test_svprfd(<vscale x 2 x i1> %pg, i8* %base) {
+; CHECK-LABEL: test_svprfd
+; CHECK: prfd pldl1keep, p0, [x0]
+entry:
+  tail call void @llvm.aarch64.sve.prf.nxv2i1(<vscale x 2 x i1> %pg, i8* %base, i32 0)
+  ret void
+}
+
+;
+; scalar + imm contiguous
+;
+; imm form of prfb is tested above
+
+define void @test_svprfh_vnum(<vscale x 8 x i1> %pg, <vscale x 8 x i16>* %base) {
+; CHECK-LABEL: test_svprfh_vnum
+; CHECK: prfh pstl3strm, p0, [x0, #31, mul vl]
+entry:
+  %gep = getelementptr <vscale x 8 x i16>, <vscale x 8 x i16>* %base, i64 31
+  %addr = bitcast <vscale x 8 x i16>* %gep to i8*
+  tail call void @llvm.aarch64.sve.prf.nxv8i1(<vscale x 8 x i1> %pg, i8* %addr, i32 13)
+  ret void
+}
+
+define void @test_svprfw_vnum(<vscale x 4 x i1> %pg, <vscale x 4 x i32>* %base) {
+; CHECK-LABEL: test_svprfw_vnum
+; CHECK: prfw pstl3strm, p0, [x0, #31, mul vl]
+entry:
+  %gep = getelementptr <vscale x 4 x i32>, <vscale x 4 x i32>* %base, i64 31
+  %addr = bitcast <vscale x 4 x i32>* %gep to i8*
+  tail call void @llvm.aarch64.sve.prf.nxv4i1(<vscale x 4 x i1> %pg, i8* %addr, i32 13)
+  ret void
+}
+
+define void @test_svprfd_vnum(<vscale x 2 x i1> %pg, <vscale x 2 x i64>* %base) {
+; CHECK-LABEL: test_svprfd_vnum
+; CHECK: prfd pstl3strm, p0, [x0, #31, mul vl]
+entry:
+  %gep = getelementptr <vscale x 2 x i64>, <vscale x 2 x i64>* %base, i64 31
+  %addr = bitcast <vscale x 2 x i64>* %gep to i8*
+  tail call void @llvm.aarch64.sve.prf.nxv2i1(<vscale x 2 x i1> %pg, i8* %addr, i32 13)
+  ret void
+}
+
+;
+; scalar + scaled scalar contiguous
+;
+
+define void @test_svprfb_ss(<vscale x 16 x i1> %pg, i8* %base, i64 %offset) {
+; CHECK-LABEL: test_svprfb_ss
+; CHECK: prfb pstl3strm, p0, [x0, x1]
+entry:
+  %addr = getelementptr i8, i8* %base, i64 %offset
+  tail call void @llvm.aarch64.sve.prf.nxv16i1(<vscale x 16 x i1> %pg, i8* %addr, i32 13)
+  ret void
+}
+
+define void @test_svprfh_ss(<vscale x 8 x i1> %pg, i16* %base, i64 %offset) {
+; CHECK-LABEL: test_svprfh_ss
+; CHECK: prfh pstl3strm, p0, [x0, x1, lsl #1]
+entry:
+  %gep = getelementptr i16, i16* %base, i64 %offset
+  %addr = bitcast i16* %gep to i8*
+  tail call void @llvm.aarch64.sve.prf.nxv8i1(<vscale x 8 x i1> %pg, i8* %addr, i32 13)
+  ret void
+}
+
+define void @test_svprfw_ss(<vscale x 4 x i1> %pg, i32* %base, i64 %offset) {
+; CHECK-LABEL: test_svprfw_ss
+; CHECK: prfw pstl3strm, p0, [x0, x1, lsl #2]
+entry:
+  %gep = getelementptr i32, i32* %base, i64 %offset
+  %addr = bitcast i32* %gep to i8*
+  tail call void @llvm.aarch64.sve.prf.nxv4i1(<vscale x 4 x i1> %pg, i8* %addr, i32 13)
+  ret void
+}
+
+define void @test_svprfd_ss(<vscale x 2 x i1> %pg, i64* %base, i64 %offset) {
+; CHECK-LABEL: test_svprfd_ss
+; CHECK: prfd pstl3strm, p0, [x0, x1, lsl #3]
+entry:
+  %gep = getelementptr i64, i64* %base, i64 %offset
+  %addr = bitcast i64* %gep to i8*
+  tail call void @llvm.aarch64.sve.prf.nxv2i1(<vscale x 2 x i1> %pg, i8* %addr, i32 13)
+  ret void
+}
+
+
+declare void @llvm.aarch64.sve.prf.nxv16i1(<vscale x 16 x i1>, i8*, i32)
+declare void @llvm.aarch64.sve.prf.nxv8i1(<vscale x 8 x i1>, i8*, i32)
+declare void @llvm.aarch64.sve.prf.nxv4i1(<vscale x 4 x i1>, i8*, i32)
+declare void @llvm.aarch64.sve.prf.nxv2i1(<vscale x 2 x i1>, i8*, i32)
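
Note (not part of the patch): the sketch below is a minimal C example, assuming the ACLE contiguous-prefetch intrinsics from <arm_sve.h> (svprfb, svprfh_vnum and the enum svprfop values) and assuming clang lowers them to the llvm.aarch64.sve.prf intrinsic added above. Under those assumptions, the calls are expected to be selected into the same prfb/prfh forms that the tests check; the exact registers and instructions shown in the comments are illustrative, not guaranteed.

// Illustrative only; assumes <arm_sve.h> ACLE prefetch intrinsics and an
// SVE-enabled compile such as: clang -O2 --target=aarch64-linux-gnu -march=armv8-a+sve
#include <arm_sve.h>

void prefetch_block(const void *base, const void *vec_base) {
  // All-true byte predicate; expected to become the governing p-register.
  svbool_t pg = svptrue_b8();

  // Expected to map to the reg form:     prfb pldl1keep, p0, [x0]
  svprfb(pg, base, SV_PLDL1KEEP);

  // Constant vnum in [-32, 31]; expected to map to the reg + imm form:
  //                                      prfh pstl3strm, p0, [x1, #31, mul vl]
  svprfh_vnum(svptrue_b16(), vec_base, 31, SV_PSTL3STRM);
}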