diff --git a/llvm/include/llvm/CodeGen/TargetInstrInfo.h b/llvm/include/llvm/CodeGen/TargetInstrInfo.h --- a/llvm/include/llvm/CodeGen/TargetInstrInfo.h +++ b/llvm/include/llvm/CodeGen/TargetInstrInfo.h @@ -1283,9 +1283,11 @@ /// \p BaseOps1 and \p BaseOps2 are memory operands of two memory operations. /// \p NumLoads is the number of loads that will be in the cluster if this /// hook returns true. + /// \p NumBytes is the number of bytes that will be loaded from all the + /// clustered loads if this hook returns true. virtual bool shouldClusterMemOps(ArrayRef BaseOps1, ArrayRef BaseOps2, - unsigned NumLoads) const { + unsigned NumLoads, unsigned NumBytes) const { llvm_unreachable("target did not implement shouldClusterMemOps()"); } diff --git a/llvm/lib/CodeGen/MachineScheduler.cpp b/llvm/lib/CodeGen/MachineScheduler.cpp --- a/llvm/lib/CodeGen/MachineScheduler.cpp +++ b/llvm/lib/CodeGen/MachineScheduler.cpp @@ -1473,10 +1473,12 @@ SUnit *SU; SmallVector BaseOps; int64_t Offset; + unsigned Width; MemOpInfo(SUnit *SU, ArrayRef BaseOps, - int64_t Offset) - : SU(SU), BaseOps(BaseOps.begin(), BaseOps.end()), Offset(Offset) {} + int64_t Offset, unsigned Width) + : SU(SU), BaseOps(BaseOps.begin(), BaseOps.end()), Offset(Offset), + Width(Width) {} static bool Compare(const MachineOperand *const &A, const MachineOperand *const &B) { @@ -1565,12 +1567,16 @@ ArrayRef MemOps, ScheduleDAGInstrs *DAG) { SmallVector MemOpRecords; for (SUnit *SU : MemOps) { + const MachineInstr &MI = *SU->getInstr(); SmallVector BaseOps; int64_t Offset; bool OffsetIsScalable; - if (TII->getMemOperandsWithOffset(*SU->getInstr(), BaseOps, Offset, - OffsetIsScalable, TRI)) - MemOpRecords.push_back(MemOpInfo(SU, BaseOps, Offset)); + if (TII->getMemOperandsWithOffset(MI, BaseOps, Offset, OffsetIsScalable, + TRI)) { + unsigned Width = + !MI.memoperands_empty() ? MI.memoperands().front()->getSize() : 0; + MemOpRecords.push_back(MemOpInfo(SU, BaseOps, Offset, Width)); + } #ifndef NDEBUG for (auto *Op : BaseOps) assert(Op); @@ -1584,16 +1590,19 @@ // At this point, `MemOpRecords` array must hold atleast two mem ops. Try to // cluster mem ops collected within `MemOpRecords` array. unsigned ClusterLength = 1; + unsigned CurrentClusterBytes = MemOpRecords[0].Width; for (unsigned Idx = 0, End = MemOpRecords.size(); Idx < (End - 1); ++Idx) { // Decision to cluster mem ops is taken based on target dependent logic auto MemOpa = MemOpRecords[Idx]; auto MemOpb = MemOpRecords[Idx + 1]; ++ClusterLength; - if (!TII->shouldClusterMemOps(MemOpa.BaseOps, MemOpb.BaseOps, - ClusterLength)) { + CurrentClusterBytes += MemOpb.Width; + if (!TII->shouldClusterMemOps(MemOpa.BaseOps, MemOpb.BaseOps, ClusterLength, + CurrentClusterBytes)) { // Current mem ops pair could not be clustered, reset cluster length, and // go to next pair ClusterLength = 1; + CurrentClusterBytes = MemOpb.Width; continue; } @@ -1605,6 +1614,7 @@ // FIXME: Is this check really required? if (!DAG->addEdge(SUb, SDep(SUa, SDep::Cluster))) { ClusterLength = 1; + CurrentClusterBytes = MemOpb.Width; continue; } diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.h b/llvm/lib/Target/AArch64/AArch64InstrInfo.h --- a/llvm/lib/Target/AArch64/AArch64InstrInfo.h +++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.h @@ -140,7 +140,7 @@ bool shouldClusterMemOps(ArrayRef BaseOps1, ArrayRef BaseOps2, - unsigned NumLoads) const override; + unsigned NumLoads, unsigned NumBytes) const override; void copyPhysRegTuple(MachineBasicBlock &MBB, MachineBasicBlock::iterator I, const DebugLoc &DL, MCRegister DestReg, diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp --- a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp +++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp @@ -2492,7 +2492,8 @@ /// Only called for LdSt for which getMemOperandWithOffset returns true. bool AArch64InstrInfo::shouldClusterMemOps( ArrayRef BaseOps1, - ArrayRef BaseOps2, unsigned NumLoads) const { + ArrayRef BaseOps2, unsigned NumLoads, + unsigned NumBytes) const { assert(BaseOps1.size() == 1 && BaseOps2.size() == 1); const MachineOperand &BaseOp1 = *BaseOps1.front(); const MachineOperand &BaseOp2 = *BaseOps2.front(); diff --git a/llvm/lib/Target/AMDGPU/SIInsertHardClauses.cpp b/llvm/lib/Target/AMDGPU/SIInsertHardClauses.cpp --- a/llvm/lib/Target/AMDGPU/SIInsertHardClauses.cpp +++ b/llvm/lib/Target/AMDGPU/SIInsertHardClauses.cpp @@ -156,6 +156,15 @@ } } + unsigned WidthA = CI.Last + ? !CI.Last->memoperands_empty() + ? CI.Last->memoperands().front()->getSize() + : 0 + : 0; + unsigned WidthB = !MI.memoperands_empty() + ? MI.memoperands().front()->getSize() + : 0; + if (CI.Length == 64 || (CI.Length && Type != HARDCLAUSE_INTERNAL && (Type != CI.Type || @@ -164,7 +173,8 @@ // scheduler it limits the size of the cluster to avoid increasing // register pressure too much, but this pass runs after register // allocation so there is no need for that kind of limit. - !SII->shouldClusterMemOps(CI.BaseOps, BaseOps, 2)))) { + !SII->shouldClusterMemOps(CI.BaseOps, BaseOps, 2, + WidthA + WidthB)))) { // Finish the current clause. Changed |= emitClause(CI, SII); CI = ClauseInfo(); diff --git a/llvm/lib/Target/AMDGPU/SIInstrInfo.h b/llvm/lib/Target/AMDGPU/SIInstrInfo.h --- a/llvm/lib/Target/AMDGPU/SIInstrInfo.h +++ b/llvm/lib/Target/AMDGPU/SIInstrInfo.h @@ -189,7 +189,7 @@ bool shouldClusterMemOps(ArrayRef BaseOps1, ArrayRef BaseOps2, - unsigned NumLoads) const override; + unsigned NumLoads, unsigned NumBytes) const override; bool shouldScheduleLoadsNear(SDNode *Load0, SDNode *Load1, int64_t Offset0, int64_t Offset1, unsigned NumLoads) const override; diff --git a/llvm/lib/Target/AMDGPU/SIInstrInfo.cpp b/llvm/lib/Target/AMDGPU/SIInstrInfo.cpp --- a/llvm/lib/Target/AMDGPU/SIInstrInfo.cpp +++ b/llvm/lib/Target/AMDGPU/SIInstrInfo.cpp @@ -430,7 +430,8 @@ bool SIInstrInfo::shouldClusterMemOps(ArrayRef BaseOps1, ArrayRef BaseOps2, - unsigned NumLoads) const { + unsigned NumLoads, + unsigned NumBytes) const { assert(!BaseOps1.empty() && !BaseOps2.empty()); const MachineInstr &FirstLdSt = *BaseOps1.front()->getParent(); const MachineInstr &SecondLdSt = *BaseOps2.front()->getParent();