@@ -55,7 +55,6 @@ namespace {
55
55
using ReductionList = SmallVector<Reduction, 8 >;
56
56
using ValueList = SmallVector<Value*, 8 >;
57
57
using MemInstList = SmallVector<Instruction*, 8 >;
58
- using LoadInstList = SmallVector<LoadInst*, 8 >;
59
58
using PMACPair = std::pair<BinOpChain*,BinOpChain*>;
60
59
using PMACPairList = SmallVector<PMACPair, 8 >;
61
60
using Instructions = SmallVector<Instruction*,16 >;
@@ -64,8 +63,7 @@ namespace {
64
63
struct OpChain {
65
64
Instruction *Root;
66
65
ValueList AllValues;
67
- MemInstList VecLd; // List of all sequential load instructions.
68
- LoadInstList Loads; // List of all load instructions.
66
+ MemInstList VecLd; // List of all load instructions.
69
67
MemLocList MemLocs; // All memory locations read by this tree.
70
68
bool ReadOnly = true ;
71
69
@@ -78,10 +76,8 @@ namespace {
78
76
if (auto *I = dyn_cast<Instruction>(V)) {
79
77
if (I->mayWriteToMemory ())
80
78
ReadOnly = false ;
81
- if (auto *Ld = dyn_cast<LoadInst>(V)) {
79
+ if (auto *Ld = dyn_cast<LoadInst>(V))
82
80
MemLocs.push_back (MemoryLocation (Ld->getPointerOperand (), Size ));
83
- Loads.push_back (Ld);
84
- }
85
81
}
86
82
}
87
83
}
@@ -139,7 +135,6 @@ namespace {
139
135
/// exchange the halfwords of the second operand before performing the
140
136
/// arithmetic.
141
137
bool MatchSMLAD (Function &F);
142
- bool MatchTopBottomMuls (BasicBlock *LoopBody);
143
138
144
139
public:
145
140
static char ID;
@@ -208,8 +203,6 @@ namespace {
208
203
LLVM_DEBUG (dbgs () << " \n == Parallel DSP pass ==\n " );
209
204
LLVM_DEBUG (dbgs () << " - " << F.getName () << " \n\n " );
210
205
Changes = MatchSMLAD (F);
211
- if (!Changes)
212
- Changes = MatchTopBottomMuls (Header);
213
206
return Changes;
214
207
}
215
208
};
@@ -503,10 +496,10 @@ static void MatchReductions(Function &F, Loop *TheLoop, BasicBlock *Header,
503
496
);
504
497
}
505
498
506
- static void AddMulCandidate (OpChainList &Candidates,
499
+ static void AddMACCandidate (OpChainList &Candidates,
507
500
Instruction *Mul,
508
501
Value *MulOp0, Value *MulOp1) {
509
- LLVM_DEBUG (dbgs () << " OK, found mul:\t " ; Mul->dump ());
502
+ LLVM_DEBUG (dbgs () << " OK, found acc mul:\t " ; Mul->dump ());
510
503
assert (Mul->getOpcode () == Instruction::Mul &&
511
504
" expected mul instruction" );
512
505
ValueList LHS;
@@ -540,14 +533,14 @@ static void MatchParallelMACSequences(Reduction &R,
540
533
break ;
541
534
case Instruction::Mul:
542
535
if (match (I, (m_Mul (m_Value (MulOp0), m_Value (MulOp1))))) {
543
- AddMulCandidate (Candidates, I, MulOp0, MulOp1);
536
+ AddMACCandidate (Candidates, I, MulOp0, MulOp1);
544
537
return false ;
545
538
}
546
539
break ;
547
540
case Instruction::SExt:
548
541
if (match (I, (m_SExt (m_Mul (m_Value (MulOp0), m_Value (MulOp1)))))) {
549
542
Instruction *Mul = cast<Instruction>(I->getOperand (0 ));
550
- AddMulCandidate (Candidates, Mul, MulOp0, MulOp1);
543
+ AddMACCandidate (Candidates, Mul, MulOp0, MulOp1);
551
544
return false ;
552
545
}
553
546
break ;
@@ -576,24 +569,23 @@ static void AliasCandidates(BasicBlock *Header, Instructions &Reads,
576
569
// the memory locations accessed by the MAC-chains.
577
570
// TODO: we need the read statements when we accept more complicated chains.
578
571
static bool AreAliased (AliasAnalysis *AA, Instructions &Reads,
579
- Instructions &Writes, OpChainList &Candidates ) {
572
+ Instructions &Writes, OpChainList &MACCandidates ) {
580
573
LLVM_DEBUG (dbgs () << " Alias checks:\n " );
581
- for (auto &Candidate : Candidates) {
582
- LLVM_DEBUG (dbgs () << " mul: " ; Candidate->Root ->dump ());
583
- Candidate->SetMemoryLocations ();
574
+ for (auto &MAC : MACCandidates) {
575
+ LLVM_DEBUG (dbgs () << " mul: " ; MAC->Root ->dump ());
584
576
585
577
// At the moment, we allow only simple chains that only consist of reads,
586
578
// accumulate their result with an integer add, and thus that don't write
587
579
// memory, and simply bail if they do.
588
- if (!Candidate ->ReadOnly )
580
+ if (!MAC ->ReadOnly )
589
581
return true ;
590
582
591
583
// Now for all writes in the basic block, check that they don't alias with
592
584
// the memory locations accessed by our MAC-chain:
593
585
for (auto *I : Writes) {
594
586
LLVM_DEBUG (dbgs () << " - " ; I->dump ());
595
- assert (Candidate ->MemLocs .size () >= 2 && " expecting at least 2 memlocs" );
596
- for (auto &MemLoc : Candidate ->MemLocs ) {
587
+ assert (MAC ->MemLocs .size () >= 2 && " expecting at least 2 memlocs" );
588
+ for (auto &MemLoc : MAC ->MemLocs ) {
597
589
if (isModOrRefSet (intersectModRef (AA->getModRefInfo (I, MemLoc),
598
590
ModRefInfo::ModRef))) {
599
591
LLVM_DEBUG (dbgs () << " Yes, aliases found\n " );
@@ -607,14 +599,15 @@ static bool AreAliased(AliasAnalysis *AA, Instructions &Reads,
607
599
return false ;
608
600
}
609
601
610
- static bool CheckMulMemory (OpChainList &Candidates) {
602
+ static bool CheckMACMemory (OpChainList &Candidates) {
611
603
for (auto &C : Candidates) {
612
604
// A mul has 2 operands, and a narrow op consists of a sext and a load; thus
613
605
// we expect at least 4 items in this operand value list.
614
606
if (C->size () < 4 ) {
615
607
LLVM_DEBUG (dbgs () << " Operand list too short.\n " );
616
608
return false ;
617
609
}
610
+ C->SetMemoryLocations ();
618
611
ValueList &LHS = static_cast <BinOpChain*>(C.get ())->LHS ;
619
612
ValueList &RHS = static_cast <BinOpChain*>(C.get ())->RHS ;
620
613
@@ -627,131 +620,6 @@ static bool CheckMulMemory(OpChainList &Candidates) {
627
620
return true ;
628
621
}
629
622
630
- static LoadInst *CreateLoadIns (IRBuilder<NoFolder> &IRB, LoadInst *BaseLoad,
631
- const Type *LoadTy) {
632
- const unsigned AddrSpace = BaseLoad->getPointerAddressSpace ();
633
-
634
- Value *VecPtr = IRB.CreateBitCast (BaseLoad->getPointerOperand (),
635
- LoadTy->getPointerTo (AddrSpace));
636
- return IRB.CreateAlignedLoad (VecPtr, BaseLoad->getAlignment ());
637
- }
638
-
639
- // / Attempt to widen loads and use smulbb, smulbt, smultb and smultt muls.
640
- // TODO: This, like smlad generation, expects the leave operands to be loads
641
- // that are sign extended. We should be able to handle scalar values as well
642
- // performing these muls on word x half types to generate smulwb and smulwt.
643
- bool ARMParallelDSP::MatchTopBottomMuls (BasicBlock *LoopBody) {
644
- LLVM_DEBUG (dbgs () << " Attempting to find BT|TB muls.\n " );
645
-
646
- OpChainList Candidates;
647
- for (auto &I : *LoopBody) {
648
- if (I.getOpcode () == Instruction::Mul) {
649
- if (I.getType ()->getScalarSizeInBits () == 32 ||
650
- I.getType ()->getScalarSizeInBits () == 64 )
651
- AddMulCandidate (Candidates, &I, I.getOperand (0 ), I.getOperand (1 ));
652
- }
653
- }
654
-
655
- if (Candidates.empty ())
656
- return false ;
657
-
658
- Instructions Reads;
659
- Instructions Writes;
660
- AliasCandidates (LoopBody, Reads, Writes);
661
-
662
- if (AreAliased (AA, Reads, Writes, Candidates))
663
- return false ;
664
-
665
- DenseMap<LoadInst*, Instruction*> LoadUsers;
666
- DenseMap<LoadInst*, LoadInst*> SeqLoads;
667
- SmallPtrSet<LoadInst*, 8 > OffsetLoads;
668
-
669
- for (unsigned i = 0 ; i < Candidates.size (); ++i) {
670
- for (unsigned j = 0 ; j < Candidates.size (); ++j) {
671
- if (i == j)
672
- continue ;
673
-
674
- OpChain *MulChain0 = Candidates[i].get ();
675
- OpChain *MulChain1 = Candidates[j].get ();
676
-
677
- for (auto *Ld0 : MulChain0->Loads ) {
678
- if (SeqLoads.count (Ld0) || OffsetLoads.count (Ld0))
679
- continue ;
680
-
681
- for (auto *Ld1 : MulChain1->Loads ) {
682
- if (SeqLoads.count (Ld1) || OffsetLoads.count (Ld1))
683
- continue ;
684
-
685
- MemInstList VecMem;
686
- if (AreSequentialLoads (Ld0, Ld1, VecMem)) {
687
- SeqLoads[Ld0] = Ld1;
688
- OffsetLoads.insert (Ld1);
689
- LoadUsers[Ld0] = MulChain0->Root ;
690
- LoadUsers[Ld1] = MulChain1->Root ;
691
- }
692
- }
693
- }
694
- }
695
- }
696
-
697
- if (SeqLoads.empty ())
698
- return false ;
699
-
700
- IRBuilder<NoFolder> IRB (LoopBody);
701
- const Type *Ty = IntegerType::get (M->getContext (), 32 );
702
-
703
- // We know that at least one of the operands is a SExt of Ld.
704
- auto GetSExt = [](Instruction *I, LoadInst *Ld, unsigned OpIdx) -> Instruction* {
705
- if (!isa<Instruction>(I->getOperand (OpIdx)))
706
- return nullptr ;
707
-
708
- Value *SExt = nullptr ;
709
- if (cast<Instruction>(I->getOperand (OpIdx))->getOperand (0 ) == Ld)
710
- SExt = I->getOperand (0 );
711
- else
712
- SExt = I->getOperand (1 );
713
-
714
- return cast<Instruction>(SExt);
715
- };
716
-
717
- LLVM_DEBUG (dbgs () << " Found some sequential loads, now widening:\n " );
718
- for (auto &Pair : SeqLoads) {
719
- LoadInst *BaseLd = Pair.first ;
720
- LoadInst *OffsetLd = Pair.second ;
721
- IRB.SetInsertPoint (BaseLd);
722
- LoadInst *WideLd = CreateLoadIns (IRB, BaseLd, Ty);
723
- LLVM_DEBUG (dbgs () << " - with base load: " << *BaseLd << " \n " );
724
- LLVM_DEBUG (dbgs () << " - created wide load: " << *WideLd << " \n " );
725
- Instruction *BaseUser = LoadUsers[BaseLd];
726
- Instruction *OffsetUser = LoadUsers[OffsetLd];
727
-
728
- Instruction *BaseSExt = GetSExt (BaseUser, BaseLd, 0 );
729
- if (!BaseSExt)
730
- BaseSExt = GetSExt (BaseUser, BaseLd, 1 );
731
- Instruction *OffsetSExt = GetSExt (OffsetUser, OffsetLd, 0 );
732
- if (!OffsetSExt)
733
- OffsetSExt = GetSExt (OffsetUser, OffsetLd, 1 );
734
-
735
- assert ((BaseSExt && OffsetSExt) && " failed to find SExts" );
736
-
737
- // BaseUser needs to: (asr (shl WideLoad, 16), 16)
738
- // OffsetUser needs to: (asr WideLoad, 16)
739
- auto *Shl = cast<Instruction>(IRB.CreateShl (WideLd, 16 ));
740
- auto *Bottom = cast<Instruction>(IRB.CreateAShr (Shl, 16 ));
741
- auto *Top = cast<Instruction>(IRB.CreateAShr (WideLd, 16 ));
742
- BaseUser->replaceUsesOfWith (BaseSExt, Bottom);
743
- OffsetUser->replaceUsesOfWith (OffsetSExt, Top);
744
-
745
- BaseSExt->eraseFromParent ();
746
- OffsetSExt->eraseFromParent ();
747
- BaseLd->eraseFromParent ();
748
- OffsetLd->eraseFromParent ();
749
- }
750
- LLVM_DEBUG (dbgs () << " Block after top bottom mul replacements:\n "
751
- << *LoopBody << " \n " );
752
- return true ;
753
- }
754
-
755
623
// Loop Pass that needs to identify integer add/sub reductions of 16-bit vector
756
624
// multiplications.
757
625
// To use SMLAD:
@@ -790,15 +658,14 @@ bool ARMParallelDSP::MatchSMLAD(Function &F) {
790
658
dbgs () << " Header block:\n " ; Header->dump ();
791
659
dbgs () << " Loop info:\n\n " ; L->dump ());
792
660
661
+ bool Changed = false ;
793
662
ReductionList Reductions;
794
663
MatchReductions (F, L, Header, Reductions);
795
- if (Reductions.empty ())
796
- return false ;
797
664
798
665
for (auto &R : Reductions) {
799
666
OpChainList MACCandidates;
800
667
MatchParallelMACSequences (R, MACCandidates);
801
- if (!CheckMulMemory (MACCandidates))
668
+ if (!CheckMACMemory (MACCandidates))
802
669
continue ;
803
670
804
671
R.MACCandidates = std::move (MACCandidates);
@@ -815,7 +682,6 @@ bool ARMParallelDSP::MatchSMLAD(Function &F) {
815
682
Instructions Reads, Writes;
816
683
AliasCandidates (Header, Reads, Writes);
817
684
818
- bool Changed = false ;
819
685
for (auto &R : Reductions) {
820
686
if (AreAliased (AA, Reads, Writes, R.MACCandidates ))
821
687
return false ;
@@ -827,6 +693,15 @@ bool ARMParallelDSP::MatchSMLAD(Function &F) {
827
693
return Changed;
828
694
}
829
695
696
+ static LoadInst *CreateLoadIns (IRBuilder<NoFolder> &IRB, LoadInst &BaseLoad,
697
+ const Type *LoadTy) {
698
+ const unsigned AddrSpace = BaseLoad.getPointerAddressSpace ();
699
+
700
+ Value *VecPtr = IRB.CreateBitCast (BaseLoad.getPointerOperand (),
701
+ LoadTy->getPointerTo (AddrSpace));
702
+ return IRB.CreateAlignedLoad (VecPtr, BaseLoad.getAlignment ());
703
+ }
704
+
830
705
Instruction *ARMParallelDSP::CreateSMLADCall (LoadInst *VecLd0, LoadInst *VecLd1,
831
706
Instruction *Acc, bool Exchange,
832
707
Instruction *InsertAfter) {
@@ -841,8 +716,8 @@ Instruction *ARMParallelDSP::CreateSMLADCall(LoadInst *VecLd0, LoadInst *VecLd1,
841
716
842
717
// Replace the reduction chain with an intrinsic call
843
718
const Type *Ty = IntegerType::get (M->getContext (), 32 );
844
- LoadInst *NewLd0 = CreateLoadIns (Builder, & VecLd0[0 ], Ty);
845
- LoadInst *NewLd1 = CreateLoadIns (Builder, & VecLd1[0 ], Ty);
719
+ LoadInst *NewLd0 = CreateLoadIns (Builder, VecLd0[0 ], Ty);
720
+ LoadInst *NewLd1 = CreateLoadIns (Builder, VecLd1[0 ], Ty);
846
721
Value* Args[] = { NewLd0, NewLd1, Acc };
847
722
Function *SMLAD = nullptr ;
848
723
if (Exchange)
0 commit comments