diff --git a/llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp b/llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp
--- a/llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp
@@ -1377,6 +1377,17 @@
       NextMBBI = MBB.end(); // The NextMBBI iterator is invalidated.
     return true;
   }
+  case AArch64::OBSCURE_COPY: {
+    if (MI.getOperand(0).getReg() != MI.getOperand(1).getReg()) {
+      BuildMI(MBB, MBBI, MI.getDebugLoc(), TII->get(AArch64::ORRXrs))
+          .add(MI.getOperand(0))
+          .addReg(AArch64::XZR)
+          .add(MI.getOperand(1))
+          .addImm(0);
+    }
+    MI.eraseFromParent();
+    return true;
+  }
   }
   return false;
 }
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -58,6 +58,13 @@
 
   CALL_BTI, // Function call followed by a BTI instruction.
 
+  // Essentially like a normal COPY that works on GPRs, but cannot be
+  // rematerialised by passes like the simple register coalescer. It's
+  // required for SME when lowering calls because we cannot allow frame
+  // index calculations using addvl to slip in between the smstart/smstop
+  // and the bl instruction. The scalable vector length may change across
+  // the smstart/smstop boundary.
+  OBSCURE_COPY,
   SMSTART,
   SMSTOP,
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -2065,6 +2065,7 @@
   switch ((AArch64ISD::NodeType)Opcode) {
   case AArch64ISD::FIRST_NUMBER:
     break;
+    MAKE_CASE(AArch64ISD::OBSCURE_COPY)
     MAKE_CASE(AArch64ISD::SMSTART)
     MAKE_CASE(AArch64ISD::SMSTOP)
     MAKE_CASE(AArch64ISD::CALL)
@@ -7036,6 +7037,11 @@
           return ArgReg.Reg == VA.getLocReg();
         });
       } else {
+        // Add an extra level of indirection for streaming mode changes by
+        // using a pseudo copy node that cannot be rematerialised between a
+        // smstart/smstop and the call by the simple register coalescer.
+        if (RequiresSMChange && isa<FrameIndexSDNode>(Arg))
+          Arg = DAG.getNode(AArch64ISD::OBSCURE_COPY, DL, MVT::i64, Arg);
         RegsToPass.emplace_back(VA.getLocReg(), Arg);
         RegsUsed.insert(VA.getLocReg());
         const TargetOptions &Options = DAG.getTarget().Options;
diff --git a/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
--- a/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
@@ -19,6 +19,8 @@
                              [SDNPHasChain, SDNPSideEffect, SDNPVariadic,
                               SDNPOptInGlue, SDNPOutGlue]>;
 
+def AArch64ObscureCopy : SDNode<"AArch64ISD::OBSCURE_COPY", SDTypeProfile<1, 1, []>, []>;
+
 //===----------------------------------------------------------------------===//
 // Add vector elements horizontally or vertically to ZA tile.
 //===----------------------------------------------------------------------===//
@@ -202,6 +204,10 @@
 def : Pat<(i64 (int_aarch64_sme_get_tpidr2)),
           (MRS 0xde85)>;
 
+def OBSCURE_COPY : Pseudo<(outs GPR64:$dst), (ins GPR64:$idx), []>, Sched<[]> { }
+def : Pat<(i64 (AArch64ObscureCopy (i64 GPR64:$idx))),
+          (OBSCURE_COPY GPR64:$idx)>;
+
 //===----------------------------------------------------------------------===//
 // SVE2 instructions
 //===----------------------------------------------------------------------===//
diff --git a/llvm/test/CodeGen/AArch64/sme-streaming-interface.ll b/llvm/test/CodeGen/AArch64/sme-streaming-interface.ll
--- a/llvm/test/CodeGen/AArch64/sme-streaming-interface.ll
+++ b/llvm/test/CodeGen/AArch64/sme-streaming-interface.ll
@@ -360,4 +360,50 @@
   ret void;
 }
 
+define i8 @call_to_non_streaming_pass_sve_objects(ptr nocapture noundef readnone %ptr) #1 {
+; CHECK-LABEL: call_to_non_streaming_pass_sve_objects:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    stp d15, d14, [sp, #-80]! // 16-byte Folded Spill
+; CHECK-NEXT:    stp d13, d12, [sp, #16] // 16-byte Folded Spill
+; CHECK-NEXT:    stp d11, d10, [sp, #32] // 16-byte Folded Spill
+; CHECK-NEXT:    stp d9, d8, [sp, #48] // 16-byte Folded Spill
+; CHECK-NEXT:    stp x29, x30, [sp, #64] // 16-byte Folded Spill
+; CHECK-NEXT:    addvl sp, sp, #-3
+; CHECK-NEXT:    rdsvl x8, #1
+; CHECK-NEXT:    addvl x9, sp, #2
+; CHECK-NEXT:    addvl x10, sp, #1
+; CHECK-NEXT:    mov x11, sp
+; CHECK-NEXT:    smstop sm
+; CHECK-NEXT:    mov x0, x9
+; CHECK-NEXT:    mov x1, x10
+; CHECK-NEXT:    mov x2, x11
+; CHECK-NEXT:    mov x3, x8
+; CHECK-NEXT:    bl foo
+; CHECK-NEXT:    smstart sm
+; CHECK-NEXT:    ptrue p0.b
+; CHECK-NEXT:    ld1b { z0.b }, p0/z, [sp, #2, mul vl]
+; CHECK-NEXT:    fmov w0, s0
+; CHECK-NEXT:    addvl sp, sp, #3
+; CHECK-NEXT:    ldp x29, x30, [sp, #64] // 16-byte Folded Reload
+; CHECK-NEXT:    ldp d9, d8, [sp, #48] // 16-byte Folded Reload
+; CHECK-NEXT:    ldp d11, d10, [sp, #32] // 16-byte Folded Reload
+; CHECK-NEXT:    ldp d13, d12, [sp, #16] // 16-byte Folded Reload
+; CHECK-NEXT:    ldp d15, d14, [sp], #80 // 16-byte Folded Reload
+; CHECK-NEXT:    ret
+entry:
+  %Data1 = alloca <vscale x 16 x i8>, align 16
+  %Data2 = alloca <vscale x 16 x i8>, align 16
+  %Data3 = alloca <vscale x 16 x i8>, align 16
+  %0 = tail call i64 @llvm.aarch64.sme.cntsb()
+  call void @foo(ptr noundef nonnull %Data1, ptr noundef nonnull %Data2, ptr noundef nonnull %Data3, i64 noundef %0)
+  %1 = load <vscale x 16 x i8>, ptr %Data1, align 16
+  %vecext = extractelement <vscale x 16 x i8> %1, i64 0
+  ret i8 %vecext
+}
+
+declare i64 @llvm.aarch64.sme.cntsb()
+
+declare void @foo(ptr noundef, ptr noundef, ptr noundef, i64 noundef)
+
 attributes #0 = { nounwind "target-features"="+sve" }
+attributes #1 = { nounwind vscale_range(1,16) "aarch64_pstate_sm_enabled" }
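
Note (illustration only, not part of the patch): the CHECK lines above show the
schedule this change is meant to guarantee. Without OBSCURE_COPY, the frame-index
address computation could be rematerialised after the streaming-mode switch,
roughly:

  smstop  sm
  addvl   x0, sp, #2        // address computed with the wrong (non-streaming) VL
  bl      foo

With the pseudo in place, the addvl stays ahead of the mode switch, and only a
plain register move, which OBSCURE_COPY later expands to (orr xD, xzr, xS, i.e.
mov), sits between the smstop and the call:

  addvl   x9, sp, #2        // address computed while still in streaming mode
  smstop  sm
  mov     x0, x9            // expansion of OBSCURE_COPY
  bl      foo

(Register numbers and the addvl offset are taken from the test above; the first
sequence is a hypothetical sketch of what the pseudo prevents.)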