diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -601,7 +601,8 @@
   MachineBasicBlock *EmitTileLoad(unsigned Opc, unsigned BaseReg,
                                   MachineInstr &MI, MachineBasicBlock *BB) const;
-  MachineBasicBlock *EmitFill(MachineInstr &MI, MachineBasicBlock *BB) const;
+  MachineBasicBlock *EmitZASpillFill(MachineInstr &MI, MachineBasicBlock *BB,
+                                     bool IsSpill) const;
   MachineBasicBlock *EmitZAInstr(unsigned Opc, unsigned BaseReg,
                                  MachineInstr &MI, MachineBasicBlock *BB,
                                  bool HasTile) const;
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -2595,13 +2595,18 @@
   return BB;
 }
 
-MachineBasicBlock *
-AArch64TargetLowering::EmitFill(MachineInstr &MI, MachineBasicBlock *BB) const {
+MachineBasicBlock *AArch64TargetLowering::EmitZASpillFill(MachineInstr &MI,
+                                                          MachineBasicBlock *BB,
+                                                          bool IsSpill) const {
   const TargetInstrInfo *TII = Subtarget->getInstrInfo();
   MachineInstrBuilder MIB =
-      BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(AArch64::LDR_ZA));
+      BuildMI(*BB, MI, MI.getDebugLoc(),
+              TII->get(IsSpill ? AArch64::STR_ZA : AArch64::LDR_ZA));
-  MIB.addReg(AArch64::ZA, RegState::Define);
+  if (IsSpill)
+    MIB.addReg(AArch64::ZA);
+  else
+    MIB.addReg(AArch64::ZA, RegState::Define);
   MIB.add(MI.getOperand(0)); // Vector select register
   MIB.add(MI.getOperand(1)); // Vector select offset
   MIB.add(MI.getOperand(2)); // Base
@@ -2722,7 +2727,9 @@
   case AArch64::LD1_MXIPXX_V_PSEUDO_Q:
     return EmitTileLoad(AArch64::LD1_MXIPXX_V_Q, AArch64::ZAQ0, MI, BB);
   case AArch64::LDR_ZA_PSEUDO:
-    return EmitFill(MI, BB);
+    return EmitZASpillFill(MI, BB, /*IsSpill=*/false);
+  case AArch64::STR_ZA_PSEUDO:
+    return EmitZASpillFill(MI, BB, /*IsSpill=*/true);
   case AArch64::ZERO_M_PSEUDO:
     return EmitZero(MI, BB);
   }
diff --git a/llvm/lib/Target/AArch64/SMEInstrFormats.td b/llvm/lib/Target/AArch64/SMEInstrFormats.td
--- a/llvm/lib/Target/AArch64/SMEInstrFormats.td
+++ b/llvm/lib/Target/AArch64/SMEInstrFormats.td
@@ -757,6 +757,15 @@
   def : InstAlias<opcodestr # "\t$ZAt[$Rv, $imm4], [$Rn]",
                   (!cast<Instruction>(NAME) MatrixOp:$ZAt,
                    MatrixIndexGPR32Op12_15:$Rv, sme_elm_idx0_15:$imm4, GPR64sp:$Rn, 0), 1>;
+  def NAME # _PSEUDO
+      : Pseudo<(outs),
+               (ins MatrixIndexGPR32Op12_15:$idx, imm0_15:$imm4,
+                    GPR64sp:$base), []>,
+        Sched<[]> {
+    // Translated to actual instruction in AArch64ISelLowering.cpp
+    let usesCustomInserter = 1;
+    let mayStore = 1;
+  }
   // base
   def : Pat<(int_aarch64_sme_str MatrixIndexGPR32Op12_15:$idx, GPR64sp:$base),
             (!cast<Instruction>(NAME) ZA, $idx, 0, $base, 0)>;
@@ -764,7 +773,7 @@
   let AddedComplexity = 2 in {
     def : Pat<(int_aarch64_sme_str MatrixIndexGPR32Op12_15:$idx,
                                    (am_sme_indexed_b4 GPR64sp:$base, imm0_15:$imm4)),
-              (!cast<Instruction>(NAME) ZA, $idx, 0, $base, $imm4)>;
+              (!cast<Instruction>(NAME # _PSEUDO) $idx, $imm4, $base)>;
   }
 }
diff --git a/llvm/test/CodeGen/AArch64/sme-intrinsics-stores.ll b/llvm/test/CodeGen/AArch64/sme-intrinsics-stores.ll
--- a/llvm/test/CodeGen/AArch64/sme-intrinsics-stores.ll
+++ b/llvm/test/CodeGen/AArch64/sme-intrinsics-stores.ll
@@ -272,7 +272,7 @@
 ; CHECK-LABEL: str_with_off_15mulvl:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    mov w12, wzr
-; CHECK-NEXT:    str za[w12, 0], [x0, #15, mul vl]
+; CHECK-NEXT:    str za[w12, 15], [x0, #15, mul vl]
 ; CHECK-NEXT:    ret
   %vscale = call i64 @llvm.vscale.i64()
   %mulvl = mul i64 %vscale, 240