@@ -55,6 +55,7 @@ namespace {
55
55
using ReductionList = SmallVector<Reduction, 8 >;
56
56
using ValueList = SmallVector<Value*, 8 >;
57
57
using MemInstList = SmallVector<Instruction*, 8 >;
58
+ using LoadInstList = SmallVector<LoadInst*, 8 >;
58
59
using PMACPair = std::pair<BinOpChain*,BinOpChain*>;
59
60
using PMACPairList = SmallVector<PMACPair, 8 >;
60
61
using Instructions = SmallVector<Instruction*,16 >;
@@ -63,7 +64,8 @@ namespace {
63
64
struct OpChain {
64
65
Instruction *Root;
65
66
ValueList AllValues;
66
- MemInstList VecLd; // List of all load instructions.
67
+ MemInstList VecLd; // List of all sequential load instructions.
68
+ LoadInstList Loads; // List of all load instructions.
67
69
MemLocList MemLocs; // All memory locations read by this tree.
68
70
bool ReadOnly = true ;
69
71
@@ -76,8 +78,10 @@ namespace {
76
78
if (auto *I = dyn_cast<Instruction>(V)) {
77
79
if (I->mayWriteToMemory ())
78
80
ReadOnly = false ;
79
- if (auto *Ld = dyn_cast<LoadInst>(V))
81
+ if (auto *Ld = dyn_cast<LoadInst>(V)) {
80
82
MemLocs.push_back (MemoryLocation (Ld->getPointerOperand (), Size ));
83
+ Loads.push_back (Ld);
84
+ }
81
85
}
82
86
}
83
87
}
@@ -135,6 +139,7 @@ namespace {
135
139
/// exchange the halfwords of the second operand before performing the
136
140
/// arithmetic.
137
141
bool MatchSMLAD (Function &F);
142
+ bool MatchTopBottomMuls (BasicBlock *LoopBody);
138
143
139
144
public:
140
145
static char ID;
@@ -203,6 +208,8 @@ namespace {
203
208
LLVM_DEBUG (dbgs () << " \n == Parallel DSP pass ==\n " );
204
209
LLVM_DEBUG (dbgs () << " - " << F.getName () << " \n\n " );
205
210
Changes = MatchSMLAD (F);
211
+ if (!Changes)
212
+ Changes = MatchTopBottomMuls (Header);
206
213
return Changes;
207
214
}
208
215
};
@@ -496,10 +503,10 @@ static void MatchReductions(Function &F, Loop *TheLoop, BasicBlock *Header,
496
503
);
497
504
}
498
505
499
- static void AddMACCandidate (OpChainList &Candidates,
506
+ static void AddMulCandidate (OpChainList &Candidates,
500
507
Instruction *Mul,
501
508
Value *MulOp0, Value *MulOp1) {
502
- LLVM_DEBUG (dbgs () << " OK, found acc mul:\t " ; Mul->dump ());
509
+ LLVM_DEBUG (dbgs () << " OK, found mul:\t " ; Mul->dump ());
503
510
assert (Mul->getOpcode () == Instruction::Mul &&
504
511
" expected mul instruction" );
505
512
ValueList LHS;
@@ -533,14 +540,14 @@ static void MatchParallelMACSequences(Reduction &R,
533
540
break ;
534
541
case Instruction::Mul:
535
542
if (match (I, (m_Mul (m_Value (MulOp0), m_Value (MulOp1))))) {
536
- AddMACCandidate (Candidates, I, MulOp0, MulOp1);
543
+ AddMulCandidate (Candidates, I, MulOp0, MulOp1);
537
544
return false ;
538
545
}
539
546
break ;
540
547
case Instruction::SExt:
541
548
if (match (I, (m_SExt (m_Mul (m_Value (MulOp0), m_Value (MulOp1)))))) {
542
549
Instruction *Mul = cast<Instruction>(I->getOperand (0 ));
543
- AddMACCandidate (Candidates, Mul, MulOp0, MulOp1);
550
+ AddMulCandidate (Candidates, Mul, MulOp0, MulOp1);
544
551
return false ;
545
552
}
546
553
break ;
@@ -569,23 +576,24 @@ static void AliasCandidates(BasicBlock *Header, Instructions &Reads,
569
576
// the memory locations accessed by the MAC-chains.
570
577
// TODO: we need the read statements when we accept more complicated chains.
571
578
static bool AreAliased (AliasAnalysis *AA, Instructions &Reads,
572
- Instructions &Writes, OpChainList &MACCandidates ) {
579
+ Instructions &Writes, OpChainList &Candidates ) {
573
580
LLVM_DEBUG (dbgs () << " Alias checks:\n " );
574
- for (auto &MAC : MACCandidates) {
575
- LLVM_DEBUG (dbgs () << " mul: " ; MAC->Root ->dump ());
581
+ for (auto &Candidate : Candidates) {
582
+ LLVM_DEBUG (dbgs () << " mul: " ; Candidate->Root ->dump ());
583
+ Candidate->SetMemoryLocations ();
576
584
577
585
// At the moment, we allow only simple chains that only consist of reads,
578
586
// accumulate their result with an integer add, and thus that don't write
579
587
// memory, and simply bail if they do.
580
- if (!MAC ->ReadOnly )
588
+ if (!Candidate ->ReadOnly )
581
589
return true ;
582
590
583
591
// Now for all writes in the basic block, check that they don't alias with
584
592
// the memory locations accessed by our MAC-chain:
585
593
for (auto *I : Writes) {
586
594
LLVM_DEBUG (dbgs () << " - " ; I->dump ());
587
- assert (MAC ->MemLocs .size () >= 2 && " expecting at least 2 memlocs" );
588
- for (auto &MemLoc : MAC ->MemLocs ) {
595
+ assert (Candidate ->MemLocs .size () >= 2 && " expecting at least 2 memlocs" );
596
+ for (auto &MemLoc : Candidate ->MemLocs ) {
589
597
if (isModOrRefSet (intersectModRef (AA->getModRefInfo (I, MemLoc),
590
598
ModRefInfo::ModRef))) {
591
599
LLVM_DEBUG (dbgs () << " Yes, aliases found\n " );
@@ -599,15 +607,14 @@ static bool AreAliased(AliasAnalysis *AA, Instructions &Reads,
599
607
return false ;
600
608
}
601
609
602
- static bool CheckMACMemory (OpChainList &Candidates) {
610
+ static bool CheckMulMemory (OpChainList &Candidates) {
603
611
for (auto &C : Candidates) {
604
612
// A mul has 2 operands, and a narrow op consists of a sext and a load; thus
605
613
// we expect at least 4 items in this operand value list.
606
614
if (C->size () < 4 ) {
607
615
LLVM_DEBUG (dbgs () << " Operand list too short.\n " );
608
616
return false ;
609
617
}
610
- C->SetMemoryLocations ();
611
618
ValueList &LHS = static_cast <BinOpChain*>(C.get ())->LHS ;
612
619
ValueList &RHS = static_cast <BinOpChain*>(C.get ())->RHS ;
613
620
@@ -620,6 +627,131 @@ static bool CheckMACMemory(OpChainList &Candidates) {
620
627
return true ;
621
628
}
622
629
630
+ static LoadInst *CreateLoadIns (IRBuilder<NoFolder> &IRB, LoadInst *BaseLoad,
631
+ const Type *LoadTy) {
632
+ const unsigned AddrSpace = BaseLoad->getPointerAddressSpace ();
633
+
634
+ Value *VecPtr = IRB.CreateBitCast (BaseLoad->getPointerOperand (),
635
+ LoadTy->getPointerTo (AddrSpace));
636
+ return IRB.CreateAlignedLoad (VecPtr, BaseLoad->getAlignment ());
637
+ }
638
+
639
+ // / Attempt to widen loads and use smulbb, smulbt, smultb and smultt muls.
640
+ // TODO: This, like smlad generation, expects the leave operands to be loads
641
+ // that are sign extended. We should be able to handle scalar values as well
642
+ // performing these muls on word x half types to generate smulwb and smulwt.
643
+ bool ARMParallelDSP::MatchTopBottomMuls (BasicBlock *LoopBody) {
644
+ LLVM_DEBUG (dbgs () << " Attempting to find BT|TB muls.\n " );
645
+
646
+ OpChainList Candidates;
647
+ for (auto &I : *LoopBody) {
648
+ if (I.getOpcode () == Instruction::Mul) {
649
+ if (I.getType ()->getScalarSizeInBits () == 32 ||
650
+ I.getType ()->getScalarSizeInBits () == 64 )
651
+ AddMulCandidate (Candidates, &I, I.getOperand (0 ), I.getOperand (1 ));
652
+ }
653
+ }
654
+
655
+ if (Candidates.empty ())
656
+ return false ;
657
+
658
+ Instructions Reads;
659
+ Instructions Writes;
660
+ AliasCandidates (LoopBody, Reads, Writes);
661
+
662
+ if (AreAliased (AA, Reads, Writes, Candidates))
663
+ return false ;
664
+
665
+ DenseMap<LoadInst*, Instruction*> LoadUsers;
666
+ DenseMap<LoadInst*, LoadInst*> SeqLoads;
667
+ SmallPtrSet<LoadInst*, 8 > OffsetLoads;
668
+
669
+ for (unsigned i = 0 ; i < Candidates.size (); ++i) {
670
+ for (unsigned j = 0 ; j < Candidates.size (); ++j) {
671
+ if (i == j)
672
+ continue ;
673
+
674
+ OpChain *MulChain0 = Candidates[i].get ();
675
+ OpChain *MulChain1 = Candidates[j].get ();
676
+
677
+ for (auto *Ld0 : MulChain0->Loads ) {
678
+ if (SeqLoads.count (Ld0) || OffsetLoads.count (Ld0))
679
+ continue ;
680
+
681
+ for (auto *Ld1 : MulChain1->Loads ) {
682
+ if (SeqLoads.count (Ld1) || OffsetLoads.count (Ld1))
683
+ continue ;
684
+
685
+ MemInstList VecMem;
686
+ if (AreSequentialLoads (Ld0, Ld1, VecMem)) {
687
+ SeqLoads[Ld0] = Ld1;
688
+ OffsetLoads.insert (Ld1);
689
+ LoadUsers[Ld0] = MulChain0->Root ;
690
+ LoadUsers[Ld1] = MulChain1->Root ;
691
+ }
692
+ }
693
+ }
694
+ }
695
+ }
696
+
697
+ if (SeqLoads.empty ())
698
+ return false ;
699
+
700
+ IRBuilder<NoFolder> IRB (LoopBody);
701
+ const Type *Ty = IntegerType::get (M->getContext (), 32 );
702
+
703
+ // We know that at least one of the operands is a SExt of Ld.
704
+ auto GetSExt = [](Instruction *I, LoadInst *Ld, unsigned OpIdx) -> Instruction* {
705
+ if (!isa<Instruction>(I->getOperand (OpIdx)))
706
+ return nullptr ;
707
+
708
+ Value *SExt = nullptr ;
709
+ if (cast<Instruction>(I->getOperand (OpIdx))->getOperand (0 ) == Ld)
710
+ SExt = I->getOperand (0 );
711
+ else
712
+ SExt = I->getOperand (1 );
713
+
714
+ return cast<Instruction>(SExt);
715
+ };
716
+
717
+ LLVM_DEBUG (dbgs () << " Found some sequential loads, now widening:\n " );
718
+ for (auto &Pair : SeqLoads) {
719
+ LoadInst *BaseLd = Pair.first ;
720
+ LoadInst *OffsetLd = Pair.second ;
721
+ IRB.SetInsertPoint (BaseLd);
722
+ LoadInst *WideLd = CreateLoadIns (IRB, BaseLd, Ty);
723
+ LLVM_DEBUG (dbgs () << " - with base load: " << *BaseLd << " \n " );
724
+ LLVM_DEBUG (dbgs () << " - created wide load: " << *WideLd << " \n " );
725
+ Instruction *BaseUser = LoadUsers[BaseLd];
726
+ Instruction *OffsetUser = LoadUsers[OffsetLd];
727
+
728
+ Instruction *BaseSExt = GetSExt (BaseUser, BaseLd, 0 );
729
+ if (!BaseSExt)
730
+ BaseSExt = GetSExt (BaseUser, BaseLd, 1 );
731
+ Instruction *OffsetSExt = GetSExt (OffsetUser, OffsetLd, 0 );
732
+ if (!OffsetSExt)
733
+ OffsetSExt = GetSExt (OffsetUser, OffsetLd, 1 );
734
+
735
+ assert ((BaseSExt && OffsetSExt) && " failed to find SExts" );
736
+
737
+ // BaseUser needs to: (asr (shl WideLoad, 16), 16)
738
+ // OffsetUser needs to: (asr WideLoad, 16)
739
+ auto *Shl = cast<Instruction>(IRB.CreateShl (WideLd, 16 ));
740
+ auto *Bottom = cast<Instruction>(IRB.CreateAShr (Shl, 16 ));
741
+ auto *Top = cast<Instruction>(IRB.CreateAShr (WideLd, 16 ));
742
+ BaseUser->replaceUsesOfWith (BaseSExt, Bottom);
743
+ OffsetUser->replaceUsesOfWith (OffsetSExt, Top);
744
+
745
+ BaseSExt->eraseFromParent ();
746
+ OffsetSExt->eraseFromParent ();
747
+ BaseLd->eraseFromParent ();
748
+ OffsetLd->eraseFromParent ();
749
+ }
750
+ LLVM_DEBUG (dbgs () << " Block after top bottom mul replacements:\n "
751
+ << *LoopBody << " \n " );
752
+ return true ;
753
+ }
754
+
623
755
// Loop Pass that needs to identify integer add/sub reductions of 16-bit vector
624
756
// multiplications.
625
757
// To use SMLAD:
@@ -658,14 +790,15 @@ bool ARMParallelDSP::MatchSMLAD(Function &F) {
658
790
dbgs () << " Header block:\n " ; Header->dump ();
659
791
dbgs () << " Loop info:\n\n " ; L->dump ());
660
792
661
- bool Changed = false ;
662
793
ReductionList Reductions;
663
794
MatchReductions (F, L, Header, Reductions);
795
+ if (Reductions.empty ())
796
+ return false ;
664
797
665
798
for (auto &R : Reductions) {
666
799
OpChainList MACCandidates;
667
800
MatchParallelMACSequences (R, MACCandidates);
668
- if (!CheckMACMemory (MACCandidates))
801
+ if (!CheckMulMemory (MACCandidates))
669
802
continue ;
670
803
671
804
R.MACCandidates = std::move (MACCandidates);
@@ -682,6 +815,7 @@ bool ARMParallelDSP::MatchSMLAD(Function &F) {
682
815
Instructions Reads, Writes;
683
816
AliasCandidates (Header, Reads, Writes);
684
817
818
+ bool Changed = false ;
685
819
for (auto &R : Reductions) {
686
820
if (AreAliased (AA, Reads, Writes, R.MACCandidates ))
687
821
return false ;
@@ -693,15 +827,6 @@ bool ARMParallelDSP::MatchSMLAD(Function &F) {
693
827
return Changed;
694
828
}
695
829
696
- static LoadInst *CreateLoadIns (IRBuilder<NoFolder> &IRB, LoadInst &BaseLoad,
697
- const Type *LoadTy) {
698
- const unsigned AddrSpace = BaseLoad.getPointerAddressSpace ();
699
-
700
- Value *VecPtr = IRB.CreateBitCast (BaseLoad.getPointerOperand (),
701
- LoadTy->getPointerTo (AddrSpace));
702
- return IRB.CreateAlignedLoad (VecPtr, BaseLoad.getAlignment ());
703
- }
704
-
705
830
Instruction *ARMParallelDSP::CreateSMLADCall (LoadInst *VecLd0, LoadInst *VecLd1,
706
831
Instruction *Acc, bool Exchange,
707
832
Instruction *InsertAfter) {
@@ -716,8 +841,8 @@ Instruction *ARMParallelDSP::CreateSMLADCall(LoadInst *VecLd0, LoadInst *VecLd1,
716
841
717
842
// Replace the reduction chain with an intrinsic call
718
843
const Type *Ty = IntegerType::get (M->getContext (), 32 );
719
- LoadInst *NewLd0 = CreateLoadIns (Builder, VecLd0[0 ], Ty);
720
- LoadInst *NewLd1 = CreateLoadIns (Builder, VecLd1[0 ], Ty);
844
+ LoadInst *NewLd0 = CreateLoadIns (Builder, & VecLd0[0 ], Ty);
845
+ LoadInst *NewLd1 = CreateLoadIns (Builder, & VecLd1[0 ], Ty);
721
846
Value* Args[] = { NewLd0, NewLd1, Acc };
722
847
Function *SMLAD = nullptr ;
723
848
if (Exchange)
0 commit comments