Skip to content

Commit 7b84fd7

Browse files
committedSep 14, 2018
[ARM] bottom-top mul support in ARMParallelDSP
When no sequences can be converted into dual MACs, try to find sequential 16-bit loads that are used by muls; each such pair of loads can then be replaced with a single wide load so that the muls can use smultb, smulbt or smultt. Differential Revision: https://reviews.llvm.org/D51983 llvm-svn: 342210
1 parent 3afb974 commit 7b84fd7

File tree

3 files changed

+612
-27
lines changed

3 files changed

+612
-27
lines changed
 

‎llvm/lib/Target/ARM/ARMParallelDSP.cpp

+152-27
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ namespace {
5555
using ReductionList = SmallVector<Reduction, 8>;
5656
using ValueList = SmallVector<Value*, 8>;
5757
using MemInstList = SmallVector<Instruction*, 8>;
58+
using LoadInstList = SmallVector<LoadInst*, 8>;
5859
using PMACPair = std::pair<BinOpChain*,BinOpChain*>;
5960
using PMACPairList = SmallVector<PMACPair, 8>;
6061
using Instructions = SmallVector<Instruction*,16>;
@@ -63,7 +64,8 @@ namespace {
6364
struct OpChain {
6465
Instruction *Root;
6566
ValueList AllValues;
66-
MemInstList VecLd; // List of all load instructions.
67+
MemInstList VecLd; // List of all sequential load instructions.
68+
LoadInstList Loads; // List of all load instructions.
6769
MemLocList MemLocs; // All memory locations read by this tree.
6870
bool ReadOnly = true;
6971

@@ -76,8 +78,10 @@ namespace {
7678
if (auto *I = dyn_cast<Instruction>(V)) {
7779
if (I->mayWriteToMemory())
7880
ReadOnly = false;
79-
if (auto *Ld = dyn_cast<LoadInst>(V))
81+
if (auto *Ld = dyn_cast<LoadInst>(V)) {
8082
MemLocs.push_back(MemoryLocation(Ld->getPointerOperand(), Size));
83+
Loads.push_back(Ld);
84+
}
8185
}
8286
}
8387
}
@@ -135,6 +139,7 @@ namespace {
135139
/// exchange the halfwords of the second operand before performing the
136140
/// arithmetic.
137141
bool MatchSMLAD(Function &F);
142+
bool MatchTopBottomMuls(BasicBlock *LoopBody);
138143

139144
public:
140145
static char ID;
@@ -203,6 +208,8 @@ namespace {
203208
LLVM_DEBUG(dbgs() << "\n== Parallel DSP pass ==\n");
204209
LLVM_DEBUG(dbgs() << " - " << F.getName() << "\n\n");
205210
Changes = MatchSMLAD(F);
211+
if (!Changes)
212+
Changes = MatchTopBottomMuls(Header);
206213
return Changes;
207214
}
208215
};
@@ -496,10 +503,10 @@ static void MatchReductions(Function &F, Loop *TheLoop, BasicBlock *Header,
496503
);
497504
}
498505

499-
static void AddMACCandidate(OpChainList &Candidates,
506+
static void AddMulCandidate(OpChainList &Candidates,
500507
Instruction *Mul,
501508
Value *MulOp0, Value *MulOp1) {
502-
LLVM_DEBUG(dbgs() << "OK, found acc mul:\t"; Mul->dump());
509+
LLVM_DEBUG(dbgs() << "OK, found mul:\t"; Mul->dump());
503510
assert(Mul->getOpcode() == Instruction::Mul &&
504511
"expected mul instruction");
505512
ValueList LHS;
@@ -533,14 +540,14 @@ static void MatchParallelMACSequences(Reduction &R,
533540
break;
534541
case Instruction::Mul:
535542
if (match (I, (m_Mul(m_Value(MulOp0), m_Value(MulOp1))))) {
536-
AddMACCandidate(Candidates, I, MulOp0, MulOp1);
543+
AddMulCandidate(Candidates, I, MulOp0, MulOp1);
537544
return false;
538545
}
539546
break;
540547
case Instruction::SExt:
541548
if (match (I, (m_SExt(m_Mul(m_Value(MulOp0), m_Value(MulOp1)))))) {
542549
Instruction *Mul = cast<Instruction>(I->getOperand(0));
543-
AddMACCandidate(Candidates, Mul, MulOp0, MulOp1);
550+
AddMulCandidate(Candidates, Mul, MulOp0, MulOp1);
544551
return false;
545552
}
546553
break;
@@ -569,23 +576,24 @@ static void AliasCandidates(BasicBlock *Header, Instructions &Reads,
569576
// the memory locations accessed by the MAC-chains.
570577
// TODO: we need the read statements when we accept more complicated chains.
571578
static bool AreAliased(AliasAnalysis *AA, Instructions &Reads,
572-
Instructions &Writes, OpChainList &MACCandidates) {
579+
Instructions &Writes, OpChainList &Candidates) {
573580
LLVM_DEBUG(dbgs() << "Alias checks:\n");
574-
for (auto &MAC : MACCandidates) {
575-
LLVM_DEBUG(dbgs() << "mul: "; MAC->Root->dump());
581+
for (auto &Candidate : Candidates) {
582+
LLVM_DEBUG(dbgs() << "mul: "; Candidate->Root->dump());
583+
Candidate->SetMemoryLocations();
576584

577585
// At the moment, we allow only simple chains that only consist of reads,
578586
// accumulate their result with an integer add, and thus that don't write
579587
// memory, and simply bail if they do.
580-
if (!MAC->ReadOnly)
588+
if (!Candidate->ReadOnly)
581589
return true;
582590

583591
// Now for all writes in the basic block, check that they don't alias with
584592
// the memory locations accessed by our MAC-chain:
585593
for (auto *I : Writes) {
586594
LLVM_DEBUG(dbgs() << "- "; I->dump());
587-
assert(MAC->MemLocs.size() >= 2 && "expecting at least 2 memlocs");
588-
for (auto &MemLoc : MAC->MemLocs) {
595+
assert(Candidate->MemLocs.size() >= 2 && "expecting at least 2 memlocs");
596+
for (auto &MemLoc : Candidate->MemLocs) {
589597
if (isModOrRefSet(intersectModRef(AA->getModRefInfo(I, MemLoc),
590598
ModRefInfo::ModRef))) {
591599
LLVM_DEBUG(dbgs() << "Yes, aliases found\n");
@@ -599,15 +607,14 @@ static bool AreAliased(AliasAnalysis *AA, Instructions &Reads,
599607
return false;
600608
}
601609

602-
static bool CheckMACMemory(OpChainList &Candidates) {
610+
static bool CheckMulMemory(OpChainList &Candidates) {
603611
for (auto &C : Candidates) {
604612
// A mul has 2 operands, and a narrow op consist of sext and a load; thus
605613
// we expect at least 4 items in this operand value list.
606614
if (C->size() < 4) {
607615
LLVM_DEBUG(dbgs() << "Operand list too short.\n");
608616
return false;
609617
}
610-
C->SetMemoryLocations();
611618
ValueList &LHS = static_cast<BinOpChain*>(C.get())->LHS;
612619
ValueList &RHS = static_cast<BinOpChain*>(C.get())->RHS;
613620

@@ -620,6 +627,131 @@ static bool CheckMACMemory(OpChainList &Candidates) {
620627
return true;
621628
}
622629

630+
/// Emit a load of type \p LoadTy from the address that \p BaseLoad reads.
///
/// The pointer operand of \p BaseLoad is bitcast to a pointer to \p LoadTy in
/// the same address space, and the new load is created through \p IRB at its
/// current insertion point.
///
/// \param IRB      Builder used to create the bitcast and the load.
/// \param BaseLoad Existing load whose address, address space and alignment
///                 are reused.
/// \param LoadTy   Type of the value to load.
/// \returns the newly created load instruction.
static LoadInst *CreateLoadIns(IRBuilder<NoFolder> &IRB, LoadInst *BaseLoad,
631+
const Type *LoadTy) {
632+
const unsigned AddrSpace = BaseLoad->getPointerAddressSpace();
633+
634+
Value *VecPtr = IRB.CreateBitCast(BaseLoad->getPointerOperand(),
635+
LoadTy->getPointerTo(AddrSpace));
636+
// The new load inherits BaseLoad's alignment unchanged, whatever LoadTy is.
// NOTE(review): callers pass a wider type than the base load - confirm the
// narrower alignment is what is wanted here.
return IRB.CreateAlignedLoad(VecPtr, BaseLoad->getAlignment());
637+
}
638+
639+
/// Attempt to widen loads and use smulbb, smulbt, smultb and smultt muls.
///
/// Collects every 32- or 64-bit mul in \p LoopBody as a candidate, rejects the
/// block if the alias check fails, then pairs up loads from two different
/// candidates that AreSequentialLoads accepts. Each pair is replaced by one
/// 32-bit load plus shifts that feed the two muls, and the original sign
/// extends and loads are erased.
///
/// \param LoopBody Block that is scanned and rewritten in place.
/// \returns false if there are no mul candidates, if AreAliased reports a
/// conflict, or if no sequential load pair is found; true once at least one
/// pair has been found and the rewrite loop has run.
640+
// TODO: This, like smlad generation, expects the leaf operands to be loads
641+
// that are sign extended. We should be able to handle scalar values as well
642+
// performing these muls on word x half types to generate smulwb and smulwt.
643+
bool ARMParallelDSP::MatchTopBottomMuls(BasicBlock *LoopBody) {
644+
LLVM_DEBUG(dbgs() << "Attempting to find BT|TB muls.\n");
645+
646+
OpChainList Candidates;
647+
// Every mul whose scalar type is 32 or 64 bits wide becomes a candidate.
for (auto &I : *LoopBody) {
648+
if (I.getOpcode() == Instruction::Mul) {
649+
if (I.getType()->getScalarSizeInBits() == 32 ||
650+
I.getType()->getScalarSizeInBits() == 64)
651+
AddMulCandidate(Candidates, &I, I.getOperand(0), I.getOperand(1));
652+
}
653+
}
654+
655+
if (Candidates.empty())
656+
return false;
657+
658+
Instructions Reads;
659+
Instructions Writes;
660+
AliasCandidates(LoopBody, Reads, Writes);
661+
662+
if (AreAliased(AA, Reads, Writes, Candidates))
663+
return false;
664+
665+
// LoadUsers:   load -> root mul of the chain the load was found in.
// SeqLoads:    base load -> the load that sequentially follows it.
// OffsetLoads: loads already used as the second element of a pair.
DenseMap<LoadInst*, Instruction*> LoadUsers;
666+
DenseMap<LoadInst*, LoadInst*> SeqLoads;
667+
SmallPtrSet<LoadInst*, 8> OffsetLoads;
668+
669+
// Try every ordered pair of distinct candidates (i, j), so each pair of
// chains is visited in both orders.
for (unsigned i = 0; i < Candidates.size(); ++i) {
670+
for (unsigned j = 0; j < Candidates.size(); ++j) {
671+
if (i == j)
672+
continue;
673+
674+
OpChain *MulChain0 = Candidates[i].get();
675+
OpChain *MulChain1 = Candidates[j].get();
676+
677+
for (auto *Ld0 : MulChain0->Loads) {
678+
// Skip loads that are already part of a recorded pair.
if (SeqLoads.count(Ld0) || OffsetLoads.count(Ld0))
679+
continue;
680+
681+
for (auto *Ld1 : MulChain1->Loads) {
682+
if (SeqLoads.count(Ld1) || OffsetLoads.count(Ld1))
683+
continue;
684+
685+
// VecMem is only an out-parameter for AreSequentialLoads; its
// contents are not read here.
MemInstList VecMem;
686+
if (AreSequentialLoads(Ld0, Ld1, VecMem)) {
687+
SeqLoads[Ld0] = Ld1;
688+
OffsetLoads.insert(Ld1);
689+
LoadUsers[Ld0] = MulChain0->Root;
690+
LoadUsers[Ld1] = MulChain1->Root;
691+
}
692+
}
693+
}
694+
}
695+
}
696+
697+
if (SeqLoads.empty())
698+
return false;
699+
700+
IRBuilder<NoFolder> IRB(LoopBody);
701+
// The wide load is always 32 bits, independent of the mul's own width.
const Type *Ty = IntegerType::get(M->getContext(), 32);
702+
703+
// We know that at least one of the operands is a SExt of Ld.
// Returns nullptr only when operand OpIdx of I is not an instruction.
// NOTE(review): when operand OpIdx matches Ld this returns operand 0, and
// otherwise operand 1, regardless of OpIdx. For OpIdx == 1 that looks
// unintended (I->getOperand(OpIdx) seems to be meant) - confirm.
704+
auto GetSExt = [](Instruction *I, LoadInst *Ld, unsigned OpIdx) -> Instruction* {
705+
if (!isa<Instruction>(I->getOperand(OpIdx)))
706+
return nullptr;
707+
708+
Value *SExt = nullptr;
709+
if (cast<Instruction>(I->getOperand(OpIdx))->getOperand(0) == Ld)
710+
SExt = I->getOperand(0);
711+
else
712+
SExt = I->getOperand(1);
713+
714+
return cast<Instruction>(SExt);
715+
};
716+
717+
LLVM_DEBUG(dbgs() << "Found some sequential loads, now widening:\n");
718+
for (auto &Pair : SeqLoads) {
719+
LoadInst *BaseLd = Pair.first;
720+
LoadInst *OffsetLd = Pair.second;
721+
// Everything created below is inserted immediately before the base load.
IRB.SetInsertPoint(BaseLd);
722+
LoadInst *WideLd = CreateLoadIns(IRB, BaseLd, Ty);
723+
LLVM_DEBUG(dbgs() << " - with base load: " << *BaseLd << "\n");
724+
LLVM_DEBUG(dbgs() << " - created wide load: " << *WideLd << "\n");
725+
Instruction *BaseUser = LoadUsers[BaseLd];
726+
Instruction *OffsetUser = LoadUsers[OffsetLd];
727+
728+
// Look for the sign extend on operand 0 first, then fall back to operand 1.
Instruction *BaseSExt = GetSExt(BaseUser, BaseLd, 0);
729+
if (!BaseSExt)
730+
BaseSExt = GetSExt(BaseUser, BaseLd, 1);
731+
Instruction *OffsetSExt = GetSExt(OffsetUser, OffsetLd, 0);
732+
if (!OffsetSExt)
733+
OffsetSExt = GetSExt(OffsetUser, OffsetLd, 1);
734+
735+
assert((BaseSExt && OffsetSExt) && "failed to find SExts");
736+
737+
// BaseUser needs to: (asr (shl WideLoad, 16), 16)
738+
// OffsetUser needs to: (asr WideLoad, 16)
// NOTE(review): mapping the base load to the low half presumes a
// little-endian layout - confirm.
// NOTE(review): Bottom and Top are 32-bit values, but 64-bit muls are also
// accepted as candidates above - verify the replaced SExt has a matching
// type.
739+
auto *Shl = cast<Instruction>(IRB.CreateShl(WideLd, 16));
740+
auto *Bottom = cast<Instruction>(IRB.CreateAShr(Shl, 16));
741+
auto *Top = cast<Instruction>(IRB.CreateAShr(WideLd, 16));
742+
BaseUser->replaceUsesOfWith(BaseSExt, Bottom);
743+
OffsetUser->replaceUsesOfWith(OffsetSExt, Top);
744+
745+
// NOTE(review): the erasures below assume the sign extends and the narrow
// loads have no users other than the ones rewritten above - confirm.
BaseSExt->eraseFromParent();
746+
OffsetSExt->eraseFromParent();
747+
BaseLd->eraseFromParent();
748+
OffsetLd->eraseFromParent();
749+
}
750+
LLVM_DEBUG(dbgs() << "Block after top bottom mul replacements:\n"
751+
<< *LoopBody << "\n");
752+
return true;
753+
}
754+
623755
// Loop Pass that needs to identify integer add/sub reductions of 16-bit vector
624756
// multiplications.
625757
// To use SMLAD:
@@ -658,14 +790,15 @@ bool ARMParallelDSP::MatchSMLAD(Function &F) {
658790
dbgs() << "Header block:\n"; Header->dump();
659791
dbgs() << "Loop info:\n\n"; L->dump());
660792

661-
bool Changed = false;
662793
ReductionList Reductions;
663794
MatchReductions(F, L, Header, Reductions);
795+
if (Reductions.empty())
796+
return false;
664797

665798
for (auto &R : Reductions) {
666799
OpChainList MACCandidates;
667800
MatchParallelMACSequences(R, MACCandidates);
668-
if (!CheckMACMemory(MACCandidates))
801+
if (!CheckMulMemory(MACCandidates))
669802
continue;
670803

671804
R.MACCandidates = std::move(MACCandidates);
@@ -682,6 +815,7 @@ bool ARMParallelDSP::MatchSMLAD(Function &F) {
682815
Instructions Reads, Writes;
683816
AliasCandidates(Header, Reads, Writes);
684817

818+
bool Changed = false;
685819
for (auto &R : Reductions) {
686820
if (AreAliased(AA, Reads, Writes, R.MACCandidates))
687821
return false;
@@ -693,15 +827,6 @@ bool ARMParallelDSP::MatchSMLAD(Function &F) {
693827
return Changed;
694828
}
695829

696-
static LoadInst *CreateLoadIns(IRBuilder<NoFolder> &IRB, LoadInst &BaseLoad,
697-
const Type *LoadTy) {
698-
const unsigned AddrSpace = BaseLoad.getPointerAddressSpace();
699-
700-
Value *VecPtr = IRB.CreateBitCast(BaseLoad.getPointerOperand(),
701-
LoadTy->getPointerTo(AddrSpace));
702-
return IRB.CreateAlignedLoad(VecPtr, BaseLoad.getAlignment());
703-
}
704-
705830
Instruction *ARMParallelDSP::CreateSMLADCall(LoadInst *VecLd0, LoadInst *VecLd1,
706831
Instruction *Acc, bool Exchange,
707832
Instruction *InsertAfter) {
@@ -716,8 +841,8 @@ Instruction *ARMParallelDSP::CreateSMLADCall(LoadInst *VecLd0, LoadInst *VecLd1,
716841

717842
// Replace the reduction chain with an intrinsic call
718843
const Type *Ty = IntegerType::get(M->getContext(), 32);
719-
LoadInst *NewLd0 = CreateLoadIns(Builder, VecLd0[0], Ty);
720-
LoadInst *NewLd1 = CreateLoadIns(Builder, VecLd1[0], Ty);
844+
LoadInst *NewLd0 = CreateLoadIns(Builder, &VecLd0[0], Ty);
845+
LoadInst *NewLd1 = CreateLoadIns(Builder, &VecLd1[0], Ty);
721846
Value* Args[] = { NewLd0, NewLd1, Acc };
722847
Function *SMLAD = nullptr;
723848
if (Exchange)

0 commit comments

Comments
 (0)