diff --git a/llvm/include/llvm/IR/IntrinsicsAArch64.td b/llvm/include/llvm/IR/IntrinsicsAArch64.td
--- a/llvm/include/llvm/IR/IntrinsicsAArch64.td
+++ b/llvm/include/llvm/IR/IntrinsicsAArch64.td
@@ -2625,4 +2625,10 @@
   def int_aarch64_sme_st1w_vert : SME_Load_Store_S_Intrinsic;
   def int_aarch64_sme_st1d_vert : SME_Load_Store_D_Intrinsic;
   def int_aarch64_sme_st1q_vert : SME_Load_Store_Q_Intrinsic;
+
+  // Spill + fill
+  def int_aarch64_sme_ldr : DefaultAttrsIntrinsic<
+      [], [llvm_i32_ty, llvm_ptr_ty]>;
+  def int_aarch64_sme_str : DefaultAttrsIntrinsic<
+      [], [llvm_i32_ty, llvm_ptr_ty]>;
 }
diff --git a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
@@ -5114,6 +5114,10 @@
   const unsigned IntNo =
       cast<ConstantSDNode>(Root->getOperand(1))->getZExtValue();
+  if (IntNo == Intrinsic::aarch64_sme_ldr ||
+      IntNo == Intrinsic::aarch64_sme_str)
+    return MVT::nxv16i8;
+
   if (IntNo != Intrinsic::aarch64_sve_prf)
     return EVT();
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -559,6 +559,7 @@
   MachineBasicBlock *EmitTileLoad(unsigned Opc, unsigned BaseReg,
                                   MachineInstr &MI,
                                   MachineBasicBlock *BB) const;
+  MachineBasicBlock *EmitFill(MachineInstr &MI, MachineBasicBlock *BB) const;
 
   MachineBasicBlock *
   EmitInstrWithCustomInserter(MachineInstr &MI,
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -2341,6 +2341,22 @@
   return BB;
 }
 
+MachineBasicBlock *
+AArch64TargetLowering::EmitFill(MachineInstr &MI, MachineBasicBlock *BB) const {
+  const TargetInstrInfo *TII = Subtarget->getInstrInfo();
+  MachineInstrBuilder MIB =
+      BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(AArch64::LDR_ZA));
+
+  MIB.addReg(AArch64::ZA, RegState::Define);
+  MIB.add(MI.getOperand(0)); // Vector select register
+  MIB.add(MI.getOperand(1)); // Vector select offset
+  MIB.add(MI.getOperand(2)); // Base
+  MIB.add(MI.getOperand(1)); // Offset, same as vector select offset
+
+  MI.eraseFromParent(); // The pseudo is gone now.
+  return BB;
+}
+
 MachineBasicBlock *AArch64TargetLowering::EmitInstrWithCustomInserter(
     MachineInstr &MI, MachineBasicBlock *BB) const {
   switch (MI.getOpcode()) {
@@ -2391,6 +2407,8 @@
     return EmitTileLoad(AArch64::LD1_MXIPXX_V_D, AArch64::ZAD0, MI, BB);
   case AArch64::LD1_MXIPXX_V_PSEUDO_Q:
     return EmitTileLoad(AArch64::LD1_MXIPXX_V_Q, AArch64::ZAQ0, MI, BB);
+  case AArch64::LDR_ZA_PSEUDO:
+    return EmitFill(MI, BB);
   }
 }
diff --git a/llvm/lib/Target/AArch64/SMEInstrFormats.td b/llvm/lib/Target/AArch64/SMEInstrFormats.td
--- a/llvm/lib/Target/AArch64/SMEInstrFormats.td
+++ b/llvm/lib/Target/AArch64/SMEInstrFormats.td
@@ -22,6 +22,8 @@
 def tileslice64  : ComplexPattern<i32, 2, "SelectSMETileSlice<1>", []>;
 def tileslice128 : ComplexPattern<i32, 2, "SelectSMETileSlice<0>", []>; // nop
 
+def am_sme_indexed_b4 : ComplexPattern<iPTR, 2, "SelectAddrModeIndexedSVE<0, 15>", [], [SDNPWantRoot]>;
+
 //===----------------------------------------------------------------------===//
 // SME Outer Products
 //===----------------------------------------------------------------------===//
@@ -509,7 +511,7 @@
 // SME Save and Restore Array
 //===----------------------------------------------------------------------===//
 
-class sme_spill_fill_inst<bit isStore, dag outs, dag ins, string opcodestr>
+class sme_spill_fill_base<bit isStore, dag outs, dag ins, string opcodestr>
     : I<outs, ins, opcodestr, "\t$ZAt[$Rv, $imm4], [$Rn, $offset, mul vl]",
         "", []>,
       Sched<[]> {
@@ -524,33 +526,61 @@
   let Inst{9-5} = Rn;
   let Inst{4} = 0b0;
   let Inst{3-0} = imm4;
-
-  let mayLoad = !not(isStore);
-  let mayStore = isStore;
 }
 
-multiclass sme_spill_fill<bit isStore, dag outs, dag ins, string opcodestr> {
-  def NAME : sme_spill_fill_inst<isStore, outs, ins, opcodestr>;
-
+let mayStore = 1 in
+class sme_spill_inst<string opcodestr>
+    : sme_spill_fill_base<0b1, (outs),
+                          (ins MatrixOp:$ZAt, MatrixIndexGPR32Op12_15:$Rv,
+                               sme_elm_idx0_15:$imm4, GPR64sp:$Rn,
+                               imm0_15:$offset),
+                          opcodestr>;
+let mayLoad = 1 in
+class sme_fill_inst<string opcodestr>
+    : sme_spill_fill_base<0b0, (outs MatrixOp:$ZAt),
+                          (ins MatrixIndexGPR32Op12_15:$Rv,
+                               sme_elm_idx0_15:$imm4, GPR64sp:$Rn,
+                               imm0_15:$offset),
+                          opcodestr>;
+multiclass sme_spill<string opcodestr> {
+  def NAME : sme_spill_inst<opcodestr>;
   def : InstAlias<opcodestr # "\t$ZAt[$Rv, $imm4], [$Rn]",
                   (!cast<Instruction>(NAME) MatrixOp:$ZAt,
                    MatrixIndexGPR32Op12_15:$Rv, sme_elm_idx0_15:$imm4, GPR64sp:$Rn, 0), 1>;
-}
-
-multiclass sme_spill<string opcodestr> {
-  defm NAME : sme_spill_fill<0b1, (outs),
-                             (ins MatrixOp:$ZAt, MatrixIndexGPR32Op12_15:$Rv,
-                              sme_elm_idx0_15:$imm4, GPR64sp:$Rn,
-                              imm0_15:$offset),
-                             opcodestr>;
+  // base
+  def : Pat<(int_aarch64_sme_str MatrixIndexGPR32Op12_15:$idx, GPR64sp:$base),
+            (!cast<Instruction>(NAME) ZA, $idx, 0, $base, 0)>;
+  // scalar + immediate (mul vl)
+  let AddedComplexity = 2 in {
+    def : Pat<(int_aarch64_sme_str MatrixIndexGPR32Op12_15:$idx,
+                                   (am_sme_indexed_b4 GPR64sp:$base, imm0_15:$imm4)),
+              (!cast<Instruction>(NAME) ZA, $idx, 0, $base, $imm4)>;
+  }
 }
 
 multiclass sme_fill<string opcodestr> {
-  defm NAME : sme_spill_fill<0b0, (outs MatrixOp:$ZAt),
-                             (ins MatrixIndexGPR32Op12_15:$Rv,
-                              sme_elm_idx0_15:$imm4, GPR64sp:$Rn,
-                              imm0_15:$offset),
-                             opcodestr>;
+  def NAME : sme_fill_inst<opcodestr>;
+  def : InstAlias<opcodestr # "\t$ZAt[$Rv, $imm4], [$Rn]",
+                  (!cast<Instruction>(NAME) MatrixOp:$ZAt,
+                   MatrixIndexGPR32Op12_15:$Rv, sme_elm_idx0_15:$imm4, GPR64sp:$Rn, 0), 1>;
+  def NAME # _PSEUDO
+      : Pseudo<(outs),
+               (ins MatrixIndexGPR32Op12_15:$idx, imm0_15:$imm4,
+                    GPR64sp:$base), []>,
+        Sched<[]> {
+    // Translated to actual instruction in AArch64ISelLowering.cpp
+    let usesCustomInserter = 1;
+    let mayLoad = 1;
+  }
+  // base
+  def : Pat<(int_aarch64_sme_ldr MatrixIndexGPR32Op12_15:$idx, GPR64sp:$base),
+            (!cast<Instruction>(NAME # _PSEUDO) $idx, 0, $base)>;
+  // scalar + immediate (mul vl)
+  let AddedComplexity = 2 in {
+    def : Pat<(int_aarch64_sme_ldr MatrixIndexGPR32Op12_15:$idx,
+                                   (am_sme_indexed_b4 GPR64sp:$base, imm0_15:$imm4)),
+              (!cast<Instruction>(NAME # _PSEUDO) $idx, $imm4, $base)>;
+  }
 }
 
 //===----------------------------------------------------------------------===//
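; ---------------------------------------------------------------------------
; Illustrative sketch (not part of the patch): minimal LLVM IR showing the two
; shapes the new str patterns above are intended to match. The first uses a
; plain base pointer; the second offsets the base by 3 * VL (VL = vscale x 16
; bytes), which am_sme_indexed_b4 is expected to fold into the "#imm, mul vl"
; form for immediates in [0, 15]. Function names here are illustrative only;
; the authoritative expected output is in the tests below.
define void @str_sketch_base(i8* %ptr) {
  call void @llvm.aarch64.sme.str(i32 0, i8* %ptr)
  ret void
}

define void @str_sketch_base_plus_3vl(i8* %ptr) {
  %vscale = call i64 @llvm.vscale.i64()
  %mulvl = mul i64 %vscale, 48            ; 3 * VL, with VL = vscale * 16 bytes
  %base = getelementptr i8, i8* %ptr, i64 %mulvl
  call void @llvm.aarch64.sme.str(i32 0, i8* %base)
  ret void
}

declare void @llvm.aarch64.sme.str(i32, i8*)
declare i64 @llvm.vscale.i64()
; ---------------------------------------------------------------------------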
diff --git a/llvm/test/CodeGen/AArch64/SME/sme-intrinsics-loads.ll b/llvm/test/CodeGen/AArch64/SME/sme-intrinsics-loads.ll
--- a/llvm/test/CodeGen/AArch64/SME/sme-intrinsics-loads.ll
+++ b/llvm/test/CodeGen/AArch64/SME/sme-intrinsics-loads.ll
@@ -246,6 +246,55 @@
   ret void;
 }
 
+define void @ldr(i8* %ptr) {
+; CHECK-LABEL: ldr:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov w12, wzr
+; CHECK-NEXT:    ldr za[w12, 0], [x0]
+; CHECK-NEXT:    ret
+  call void @llvm.aarch64.sme.ldr(i32 0, i8* %ptr)
+  ret void;
+}
+
+define void @ldr_with_off_15(i8* %ptr) {
+; CHECK-LABEL: ldr_with_off_15:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov w12, wzr
+; CHECK-NEXT:    add x8, x0, #15
+; CHECK-NEXT:    ldr za[w12, 0], [x8]
+; CHECK-NEXT:    ret
+  %base = getelementptr i8, i8* %ptr, i64 15
+  call void @llvm.aarch64.sme.ldr(i32 0, i8* %base)
+  ret void;
+}
+
+define void @ldr_with_off_15mulvl(i8* %ptr) {
+; CHECK-LABEL: ldr_with_off_15mulvl:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov w12, wzr
+; CHECK-NEXT:    ldr za[w12, 15], [x0, #15, mul vl]
+; CHECK-NEXT:    ret
+  %vscale = call i64 @llvm.vscale.i64()
+  %mulvl = mul i64 %vscale, 240
+  %base = getelementptr i8, i8* %ptr, i64 %mulvl
+  call void @llvm.aarch64.sme.ldr(i32 0, i8* %base)
+  ret void;
+}
+
+define void @ldr_with_off_16mulvl(i8* %ptr) {
+; CHECK-LABEL: ldr_with_off_16mulvl:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov w12, wzr
+; CHECK-NEXT:    addvl x8, x0, #16
+; CHECK-NEXT:    ldr za[w12, 0], [x8]
+; CHECK-NEXT:    ret
+  %vscale = call i64 @llvm.vscale.i64()
+  %mulvl = mul i64 %vscale, 256
+  %base = getelementptr i8, i8* %ptr, i64 %mulvl
+  call void @llvm.aarch64.sme.ldr(i32 0, i8* %base)
+  ret void;
+}
+
 declare void @llvm.aarch64.sme.ld1b.horiz(<vscale x 16 x i1>, i8*, i64, i32)
 declare void @llvm.aarch64.sme.ld1h.horiz(<vscale x 8 x i1>, i16*, i64, i32)
 declare void @llvm.aarch64.sme.ld1w.horiz(<vscale x 4 x i1>, i32*, i64, i32)
@@ -256,3 +305,6 @@
 declare void @llvm.aarch64.sme.ld1w.vert(<vscale x 4 x i1>, i32*, i64, i32)
 declare void @llvm.aarch64.sme.ld1d.vert(<vscale x 2 x i1>, i64*, i64, i32)
 declare void @llvm.aarch64.sme.ld1q.vert(<vscale x 1 x i1>, i128*, i64, i32)
+
+declare void @llvm.aarch64.sme.ldr(i32, i8*)
+declare i64 @llvm.vscale.i64()
diff --git a/llvm/test/CodeGen/AArch64/SME/sme-intrinsics-stores.ll b/llvm/test/CodeGen/AArch64/SME/sme-intrinsics-stores.ll
--- a/llvm/test/CodeGen/AArch64/SME/sme-intrinsics-stores.ll
+++ b/llvm/test/CodeGen/AArch64/SME/sme-intrinsics-stores.ll
@@ -246,6 +246,55 @@
   ret void;
 }
 
+define void @str(i8* %ptr) {
+; CHECK-LABEL: str:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov w12, wzr
+; CHECK-NEXT:    str za[w12, 0], [x0]
+; CHECK-NEXT:    ret
+  call void @llvm.aarch64.sme.str(i32 0, i8* %ptr)
+  ret void;
+}
+
+define void @str_with_off_15(i8* %ptr) {
+; CHECK-LABEL: str_with_off_15:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov w12, wzr
+; CHECK-NEXT:    add x8, x0, #15
+; CHECK-NEXT:    str za[w12, 0], [x8]
+; CHECK-NEXT:    ret
+  %base = getelementptr i8, i8* %ptr, i64 15
+  call void @llvm.aarch64.sme.str(i32 0, i8* %base)
+  ret void;
+}
+
+define void @str_with_off_15mulvl(i8* %ptr) {
+; CHECK-LABEL: str_with_off_15mulvl:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov w12, wzr
+; CHECK-NEXT:    str za[w12, 0], [x0, #15, mul vl]
+; CHECK-NEXT:    ret
+  %vscale = call i64 @llvm.vscale.i64()
+  %mulvl = mul i64 %vscale, 240
+  %base = getelementptr i8, i8* %ptr, i64 %mulvl
+  call void @llvm.aarch64.sme.str(i32 0, i8* %base)
+  ret void;
+}
+
+define void @str_with_off_16mulvl(i8* %ptr) {
+; CHECK-LABEL: str_with_off_16mulvl:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov w12, wzr
+; CHECK-NEXT:    addvl x8, x0, #16
+; CHECK-NEXT:    str za[w12, 0], [x8]
+; CHECK-NEXT:    ret
+  %vscale = call i64 @llvm.vscale.i64()
+  %mulvl = mul i64 %vscale, 256
+  %base = getelementptr i8, i8* %ptr, i64 %mulvl
+  call void @llvm.aarch64.sme.str(i32 0, i8* %base)
+  ret void;
+}
+
 declare void @llvm.aarch64.sme.st1b.horiz(<vscale x 16 x i1>, i8*, i64, i32)
 declare void @llvm.aarch64.sme.st1h.horiz(<vscale x 8 x i1>, i16*, i64, i32)
 declare void @llvm.aarch64.sme.st1w.horiz(<vscale x 4 x i1>, i32*, i64, i32)
@@ -256,3 +305,6 @@
 declare void @llvm.aarch64.sme.st1w.vert(<vscale x 4 x i1>, i32*, i64, i32)
 declare void @llvm.aarch64.sme.st1d.vert(<vscale x 2 x i1>, i64*, i64, i32)
 declare void @llvm.aarch64.sme.st1q.vert(<vscale x 1 x i1>, i128*, i64, i32)
+
+declare void @llvm.aarch64.sme.str(i32, i8*)
+declare i64 @llvm.vscale.i64()
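; ---------------------------------------------------------------------------
; Illustrative usage sketch (not part of the patch): spill one ZA vector to a
; buffer, call a routine that may clobber ZA, then restore it. @clobber_za and
; the assumption that %buf holds at least one vector length (vscale x 16
; bytes) are hypothetical and only serve to show how the two intrinsics pair.
define void @za_spill_fill_sketch(i8* %buf) {
  call void @llvm.aarch64.sme.str(i32 0, i8* %buf)   ; save ZA slice 0
  call void @clobber_za()
  call void @llvm.aarch64.sme.ldr(i32 0, i8* %buf)   ; restore ZA slice 0
  ret void
}

declare void @clobber_za()
declare void @llvm.aarch64.sme.str(i32, i8*)
declare void @llvm.aarch64.sme.ldr(i32, i8*)
; ---------------------------------------------------------------------------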