diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -7164,6 +7164,10 @@ SDValue AArch64TargetLowering::changeStreamingMode( SelectionDAG &DAG, SDLoc DL, bool Enable, SDValue Chain, SDValue InGlue, SDValue PStateSM, bool Entry) const { + MachineFunction &MF = DAG.getMachineFunction(); + AArch64FunctionInfo *FuncInfo = MF.getInfo<AArch64FunctionInfo>(); + FuncInfo->setHasStreamingModeChanges(true); + const AArch64RegisterInfo *TRI = Subtarget->getRegisterInfo(); SDValue RegMask = DAG.getRegisterMask(TRI->getSMStartStopCallPreservedMask()); SDValue MSROp = diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.h b/llvm/lib/Target/AArch64/AArch64InstrInfo.h --- a/llvm/lib/Target/AArch64/AArch64InstrInfo.h +++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.h @@ -355,6 +355,8 @@ static void decomposeStackOffsetForDwarfOffsets(const StackOffset &Offset, int64_t &ByteSized, int64_t &VGSized); + + bool isReallyTriviallyReMaterializable(const MachineInstr &MI) const override; #define GET_INSTRINFO_HELPER_DECLS #include "AArch64GenInstrInfo.inc" diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp --- a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp +++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp @@ -8524,6 +8524,54 @@ return AArch64::BLR; } +bool AArch64InstrInfo::isReallyTriviallyReMaterializable( + const MachineInstr &MI) const { + const MachineFunction &MF = *MI.getMF(); + const AArch64FunctionInfo &AFI = *MF.getInfo<AArch64FunctionInfo>(); + + // If the function contains changes to streaming mode, then there + // is a danger that rematerialised instructions end up between + // instruction sequences (e.g. call sequences, or prolog/epilogue) + // where the streaming-SVE mode is temporarily changed. 
+ if (AFI.hasStreamingModeChanges()) { + // Avoid rematerializing rematerializable instructions that use/define + // scalable values, such as 'pfalse' or 'ptrue', which result in different + // results when the runtime vector length is different. + const MachineRegisterInfo &MRI = MF.getRegInfo(); + if (any_of(MI.operands(), [&MRI](const MachineOperand &MO) { + if (!MO.isReg()) + return false; + + if (MO.getReg().isVirtual()) { + const TargetRegisterClass *RC = MRI.getRegClass(MO.getReg()); + return AArch64::ZPRRegClass.hasSubClassEq(RC) || + AArch64::PPRRegClass.hasSubClassEq(RC); + } + return AArch64::ZPRRegClass.contains(MO.getReg()) || + AArch64::PPRRegClass.contains(MO.getReg()); + })) + return false; + + // Avoid rematerializing instructions that return a value that is + // different depending on vector length, even when it is not returned + // in a scalable vector/predicate register. + switch (MI.getOpcode()) { + default: + break; + case AArch64::RDVLI_XI: + case AArch64::ADDVL_XXI: + case AArch64::ADDPL_XXI: + case AArch64::CNTB_XPiI: + case AArch64::CNTH_XPiI: + case AArch64::CNTW_XPiI: + case AArch64::CNTD_XPiI: + return false; + } + } + + return TargetInstrInfo::isReallyTriviallyReMaterializable(MI); +} + #define GET_INSTRINFO_HELPERS #define GET_INSTRMAP_INFO #include "AArch64GenInstrInfo.inc" diff --git a/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h b/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h --- a/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h +++ b/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h @@ -185,6 +185,8 @@ /// The frame-index for the TPIDR2 object used for lazy saves. Register LazySaveTPIDR2Obj = 0; + /// Whether this function changes streaming mode within the function. + bool HasStreamingModeChanges = false; /// True if the function need unwind information. 
mutable std::optional<bool> NeedsDwarfUnwindInfo; @@ -447,6 +449,11 @@ bool needsDwarfUnwindInfo(const MachineFunction &MF) const; bool needsAsyncDwarfUnwindInfo(const MachineFunction &MF) const; + bool hasStreamingModeChanges() const { return HasStreamingModeChanges; } + void setHasStreamingModeChanges(bool HasChanges) { + HasStreamingModeChanges = HasChanges; + } + private: // Hold the lists of LOHs. MILOHContainer LOHContainerSet; diff --git a/llvm/test/CodeGen/AArch64/sme-disable-rematerialize-with-streaming-mode-changes.ll b/llvm/test/CodeGen/AArch64/sme-disable-rematerialize-with-streaming-mode-changes.ll new file mode 100644 --- /dev/null +++ b/llvm/test/CodeGen/AArch64/sme-disable-rematerialize-with-streaming-mode-changes.ll @@ -0,0 +1,71 @@ +; RUN: llc < %s | FileCheck %s + +target triple = "aarch64" + + +define void @dont_rematerialize_cntd(i32 %N) #0 { +; CHECK-LABEL: dont_rematerialize_cntd: +; CHECK: cntd +; CHECK: smstop sm +; CHECK-NOT: cntd +; CHECK: bl foo +; CHECK: smstart sm +entry: + %cmp2 = icmp sgt i32 %N, 0 + br i1 %cmp2, label %for.body, label %for.cond.cleanup + +for.body: ; preds = %entry, %for.body + %index.03 = phi i32 [ %inc, %for.body ], [ 0, %entry ] + call void asm sideeffect "", "~{x19},~{x20},~{x21},~{x22},~{x23},~{x24},~{x25},~{x26},~{x27}"() nounwind + %.tr = call i32 @llvm.vscale.i32() + %conv = shl nuw nsw i32 %.tr, 4 + call void @foo(i32 %conv) + %inc = add nuw nsw i32 %index.03, 1 + %exitcond.not = icmp eq i32 %inc, %N + br i1 %exitcond.not, label %for.cond.cleanup, label %for.body + +for.cond.cleanup: ; preds = %for.body, %entry + ret void +} + +; This test doesn't strictly make sense, because it passes a scalable predicate +; to a function, which makes little sense if the VL is not the same in/out of +; streaming-SVE mode. If the VL is known to be the same, then we could just as +; well rematerialize the `ptrue` inside the call sequence. 
However, the purpose +; of this test is more to ensure that the logic works, which may also trigger +; when the value is not being passed as argument (e.g. when it is hoisted from +; a loop and placed inside the call sequence). +; +; FIXME: This test also exposes another bug, where the 'mul vl' addressing mode +; is used before/after the smstop. This will be fixed in a future patch. +define void @dont_rematerialize_ptrue(i32 %N) #0 { +; CHECK-LABEL: dont_rematerialize_ptrue: +; CHECK: ptrue [[PTRUE:p[0-9]+]].b +; CHECK: str [[PTRUE]], [[[SPILL_ADDR:.*]]] +; CHECK: smstop sm +; CHECK: ldr p0, [[[SPILL_ADDR]]] +; CHECK-NOT: ptrue +; CHECK: bl bar +; CHECK: smstart sm +entry: + %cmp2 = icmp sgt i32 %N, 0 + br i1 %cmp2, label %for.body, label %for.cond.cleanup + +for.body: ; preds = %entry, %for.body + %index.03 = phi i32 [ %inc, %for.body ], [ 0, %entry ] + call void asm sideeffect "", "~{x19},~{x20},~{x21},~{x22},~{x23},~{x24},~{x25},~{x26},~{x27}"() nounwind + %ptrue.ins = insertelement <vscale x 16 x i1> poison, i1 1, i32 0 + %ptrue = shufflevector <vscale x 16 x i1> %ptrue.ins, <vscale x 16 x i1> poison, <vscale x 16 x i32> zeroinitializer + call void @bar(<vscale x 16 x i1> %ptrue) + %inc = add nuw nsw i32 %index.03, 1 + %exitcond.not = icmp eq i32 %inc, %N + br i1 %exitcond.not, label %for.cond.cleanup, label %for.body + +for.cond.cleanup: ; preds = %for.body, %entry + ret void +} +declare void @foo(i32) +declare void @bar(<vscale x 16 x i1>) +declare i32 @llvm.vscale.i32() + +attributes #0 = { "aarch64_pstate_sm_enabled" "frame-pointer"="non-leaf" "target-features"="+sme,+sve" }