Changeset View
Changeset View
Standalone View
Standalone View
llvm/lib/CodeGen/ComplexArithmeticPass.cpp
- This file was added.
//===- ComplexArithmeticPass.cpp ------------------------------------------===// | |||||
// | |||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | |||||
// See https://llvm.org/LICENSE.txt for license information. | |||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | |||||
// | |||||
//===----------------------------------------------------------------------===// | |||||
// | |||||
// | |||||
//===----------------------------------------------------------------------===// | |||||
#include "llvm/CodeGen/ComplexArithmeticPass.h" | |||||
#include "llvm/ADT/ArrayRef.h" | |||||
#include "llvm/ADT/DenseMap.h" | |||||
#include "llvm/ADT/Statistic.h" | |||||
#include "llvm/Analysis/TargetTransformInfo.h" | |||||
#include "llvm/CodeGen/TargetLowering.h" | |||||
#include "llvm/CodeGen/TargetPassConfig.h" | |||||
#include "llvm/CodeGen/TargetSubtargetInfo.h" | |||||
#include "llvm/IR/Dominators.h" | |||||
#include "llvm/IR/InstIterator.h" | |||||
#include "llvm/IR/PatternMatch.h" | |||||
#include "llvm/InitializePasses.h" | |||||
#include "llvm/Target/TargetMachine.h" | |||||
#include <utility> | |||||
using namespace llvm; | |||||
using namespace PatternMatch; | |||||
#define DEBUG_TYPE "complex-arithmetic" | |||||
STATISTIC(NumComplexIntrinsics, "Number of complex intrinsics generated"); | |||||
static cl::opt<bool> | |||||
ComplexArithmeticEnabled("enable-complex-arithmetic", | |||||
cl::desc("Enable complex arithmetic"), | |||||
cl::init(true), cl::Hidden); | |||||
static cl::opt<bool> | |||||
ComplexArithmeticForceEnabled("force-complex-arithmetic", | |||||
cl::desc("Force complex arithmetic"), | |||||
cl::init(false), cl::Hidden); | |||||
namespace { | |||||
class ComplexArithmeticLegacyPass : public FunctionPass { | |||||
public: | |||||
static char ID; | |||||
ComplexArithmeticLegacyPass() : FunctionPass(ID) { | |||||
initializeComplexArithmeticLegacyPassPass(*PassRegistry::getPassRegistry()); | |||||
} | |||||
StringRef getPassName() const override { return "Complex Arithmetic Pass"; } | |||||
bool runOnFunction(Function &F) override; | |||||
void getAnalysisUsage(AnalysisUsage &AU) const override { | |||||
AU.setPreservesCFG(); | |||||
} | |||||
private: | |||||
DominatorTree *DT = nullptr; | |||||
const TargetLowering *TLI = nullptr; | |||||
}; | |||||
class ComplexArithmetic { | |||||
public: | |||||
ComplexArithmetic(TargetTransformInfo *TTI) : TTI(TTI) {} | |||||
bool runOnFunction(Function &F); | |||||
private: | |||||
/// Performs a search through the given BasicBlock, looking for potential | |||||
/// complex arithmetic candidates and replacing them where applicable. | |||||
/// \return true if any change has been made, false otherwise. | |||||
bool evaluateComplexArithmetic(BasicBlock *B, | |||||
SmallVector<Instruction *, 32> &DeadInsts); | |||||
/// Performs a naive check for specific patterns within the given block, | |||||
/// and replaces the matched instructions with the complex arithmetic | |||||
/// equivalent. | |||||
/// \return true if any change has been made, false otherwise. | |||||
bool | |||||
evaluateComplexArithmeticNaive(BasicBlock *B, | |||||
SmallVector<Instruction *, 32> &DeadInsts); | |||||
const TargetTransformInfo *TTI = nullptr; | |||||
const TargetLowering *TLI = nullptr; | |||||
}; | |||||
} // end anonymous namespace. | |||||
char ComplexArithmeticLegacyPass::ID = 0; | |||||
INITIALIZE_PASS_BEGIN(ComplexArithmeticLegacyPass, DEBUG_TYPE, | |||||
"Complex Arithmetic", false, false) | |||||
INITIALIZE_PASS_END(ComplexArithmeticLegacyPass, DEBUG_TYPE, | |||||
"Complex Arithmetic", false, false) | |||||
/// New pass manager entry point.
/// \return PreservedAnalyses::all() when \p F was left untouched; otherwise a
/// set that preserves only FunctionAnalysisManagerModuleProxy.
PreservedAnalyses ComplexArithmeticPass::run(Function &F,
                                             FunctionAnalysisManager &AM) {
  auto &TargetInfo = AM.getResult<TargetIRAnalysis>(F);
  ComplexArithmetic Impl(&TargetInfo);

  const bool Changed = Impl.runOnFunction(F);
  if (Changed) {
    PreservedAnalyses Preserved;
    Preserved.preserve<FunctionAnalysisManagerModuleProxy>();
    return Preserved;
  }
  return PreservedAnalyses::all();
}
/// Factory for the legacy pass manager wrapper.
/// \return A newly allocated ComplexArithmeticLegacyPass.
FunctionPass *llvm::createComplexArithmeticPass() {
  FunctionPass *Pass = new ComplexArithmeticLegacyPass();
  return Pass;
}
/// Builds a complex arithmetic candidate rooted at \p I when \p I is a
/// floating-point subtraction or addition whose two operands are both
/// floating-point multiplications.
/// \param I The instruction to check.
/// \return A candidate allocated with `new` (this function does not release
/// it), or nullptr if \p I does not match a recognised pattern.
static ComplexArithmeticCandidate *getCandidate(Instruction *I) {
  const bool IsProductDifference =
      match(I, m_FSub(m_FMul(m_Value(), m_Value()),
                      m_FMul(m_Value(), m_Value())));
  const bool IsProductSum =
      !IsProductDifference && match(I, m_FAdd(m_FMul(m_Value(), m_Value()),
                                              m_FMul(m_Value(), m_Value())));
  if (!IsProductDifference && !IsProductSum)
    return nullptr;

  auto *Candidate = new ComplexArithmeticCandidate(
      I, llvm::ComplexArithmeticCandidate::Complex_Mul);

  // TypeWidth is only overwritten for fixed-width vector results.
  if (auto *VecTy = dyn_cast<FixedVectorType>(I->getType()))
    Candidate->TypeWidth = VecTy->getNumElements();

  // Record both multiplies feeding the root.
  for (unsigned OpIdx = 0; OpIdx != 2; ++OpIdx)
    Candidate->addInstruction(I->getOperand(OpIdx));

  return Candidate;
}
/// Looks for a store of an interleaved complex product whose vector type has
/// 8 x float or 4 x double elements, and rewrites it as two half-width chains
/// of target complex intrinsics whose results are concatenated again.
///
/// The matched shape is
///   store (shuffle (fsub (fmul Re(B), Re(A)), (fmul Im(B), Im(A))),
///                  (fadd (fmul Im(B), Re(A)), (fmul Re(B), Im(A))),
///                  InterleaveMask), Ptr
/// where Re/Im are the even/odd-lane deinterleaving shuffles of the same two
/// source vectors A and B, and each of the four deinterleaving shuffles is
/// shared between the fsub and the fadd.
///
/// Stores that do not match are skipped and the scan continues with the next
/// store in the block. The IR is only modified once the whole pattern has
/// matched and the target has offered an intrinsic for both halves.
///
/// \param B The block to scan.
/// \param DeadInsts Receives the instructions made redundant by the rewrite.
/// \return true if a store was rewritten (the scan stops at the first
/// rewrite), false otherwise.
bool ComplexArithmetic::evaluateComplexArithmeticNaive(
    BasicBlock *B, SmallVector<Instruction *, 32> &DeadInsts) {
  /// Check a very specific case where the vector width is twice that of the
  /// width of a vector register.
  LLVM_DEBUG(dbgs() << "evaluateComplexArithmeticNaive"
                    << ".\n");

  for (auto &I : *B) {
    if (!isa<StoreInst>(I))
      continue;

    LLVM_DEBUG(dbgs() << "Found potential root of naive pattern: "; I.dump());

    auto *StoreTy = I.getOperand(0)->getType();
    if (!StoreTy->isVectorTy()) {
      LLVM_DEBUG(dbgs() << "Not vector type"
                        << ".\n");
      continue;
    }

    auto *ScalarTy = StoreTy->getScalarType();
    if (!ScalarTy->isFloatTy() && !ScalarTy->isDoubleTy()) {
      LLVM_DEBUG(dbgs() << "Vector element type is not float or double"
                        << ".\n");
      continue;
    }

    const int AllowedElements = ScalarTy->isDoubleTy() ? 4 : 8;
    // dyn_cast yields null for scalable vectors, which are not handled.
    auto *FixedStoreTy = dyn_cast<FixedVectorType>(StoreTy);
    if (!FixedStoreTy ||
        FixedStoreTy->getNumElements() != unsigned(AllowedElements)) {
      LLVM_DEBUG(dbgs() << "Element count is scalable, or number of elements "
                           "is not as allowed"
                        << ".\n");
      continue;
    }

    LLVM_DEBUG(dbgs() << "Allowed elements deemed to be " << AllowedElements
                      << ".\n");

    // Build every mask from the element count. The backing storage lives for
    // the whole loop iteration, so the ArrayRefs handed to m_SpecificMask
    // (which keeps a reference to them) never dangle.
    const int HalfElements = AllowedElements / 2;
    SmallVector<int, 4> RealMaskStorage;
    SmallVector<int, 4> ImagMaskStorage;
    SmallVector<int, 4> LowHalfMaskStorage;
    SmallVector<int, 4> HighHalfMaskStorage;
    SmallVector<int, 8> InterleaveMaskStorage;
    SmallVector<int, 8> ConcatMaskStorage;
    for (int Lane = 0; Lane != HalfElements; ++Lane) {
      RealMaskStorage.push_back(2 * Lane);
      ImagMaskStorage.push_back(2 * Lane + 1);
      InterleaveMaskStorage.push_back(Lane);
      InterleaveMaskStorage.push_back(Lane + HalfElements);
      LowHalfMaskStorage.push_back(Lane);
      HighHalfMaskStorage.push_back(Lane + HalfElements);
    }
    for (int Lane = 0; Lane != AllowedElements; ++Lane)
      ConcatMaskStorage.push_back(Lane);

    ArrayRef<int> RealMask(RealMaskStorage);
    ArrayRef<int> ImagMask(ImagMaskStorage);
    ArrayRef<int> InterleaveMask(InterleaveMaskStorage);
    ArrayRef<int> Shuffle1Mask(LowHalfMaskStorage);
    ArrayRef<int> Shuffle2Mask(HighHalfMaskStorage);
    ArrayRef<int> StorageShuffleMask(ConcatMaskStorage);

    // Check naive complex multiply pattern
    Value *LoadInst0 = nullptr;
    Value *LoadInst1 = nullptr;
    Value *LoadInst2 = nullptr;
    Value *LoadInst3 = nullptr;
    Value *LoadInst4 = nullptr;
    Value *LoadInst5 = nullptr;
    Value *LoadInst6 = nullptr;
    Value *LoadInst7 = nullptr;
    // Only bound so that the pattern requires the store address to be an
    // instruction; the value itself is not used afterwards.
    Instruction *BitcastInstr = nullptr;
    bool M = match(
        &I,
        m_Store(
            m_Shuffle(m_FSub(m_FMul(m_Shuffle(m_Value(LoadInst0), m_Poison(),
                                              m_SpecificMask(RealMask)),
                                    m_Shuffle(m_Value(LoadInst1), m_Poison(),
                                              m_SpecificMask(RealMask))),
                             m_FMul(m_Shuffle(m_Value(LoadInst2), m_Poison(),
                                              m_SpecificMask(ImagMask)),
                                    m_Shuffle(m_Value(LoadInst3), m_Poison(),
                                              m_SpecificMask(ImagMask)))),
                      m_FAdd(m_FMul(m_Shuffle(m_Value(LoadInst4), m_Poison(),
                                              m_SpecificMask(ImagMask)),
                                    m_Shuffle(m_Value(LoadInst5), m_Poison(),
                                              m_SpecificMask(RealMask))),
                             m_FMul(m_Shuffle(m_Value(LoadInst6), m_Poison(),
                                              m_SpecificMask(RealMask)),
                                    m_Shuffle(m_Value(LoadInst7), m_Poison(),
                                              m_SpecificMask(ImagMask)))),
                      m_SpecificMask(InterleaveMask)),
            m_Instruction(BitcastInstr)));
    if (!M) {
      LLVM_DEBUG(dbgs() << "Failed to match full pattern"
                        << ".\n");
      continue;
    }

    // Check that the 8 LoadInsts are made up of 2 distinct LoadInsts
    bool LoadInstsAMatch = LoadInst1 == LoadInst3 && LoadInst1 == LoadInst5 &&
                           LoadInst1 == LoadInst7;
    bool LoadInstsBMatch = LoadInst0 == LoadInst2 && LoadInst0 == LoadInst4 &&
                           LoadInst0 == LoadInst6;
    if (!LoadInstsAMatch || !LoadInstsBMatch) {
      LLVM_DEBUG(dbgs() << "Loads do not match"
                        << ".\n");
      continue;
    }

    if (LoadInst0->getType() != StoreTy || LoadInst1->getType() != StoreTy) {
      LLVM_DEBUG(dbgs() << "Load types do not match store type"
                        << ".\n");
      continue;
    }

    auto *LoadInstA = LoadInst1;
    auto *LoadInstB = LoadInst0;

    // The casts below cannot fail: the match above fixed the shape.
    auto *InterleavedVec = cast<ShuffleVectorInst>(I.getOperand(0));

    // Gather the shuffles from the FSub
    auto *FSub = cast<Instruction>(InterleavedVec->getOperand(0));
    auto *FSub0 = cast<Instruction>(FSub->getOperand(0));
    auto *FSub1 = cast<Instruction>(FSub->getOperand(1));
    auto *FSub00 = cast<ShuffleVectorInst>(FSub0->getOperand(0));
    auto *FSub01 = cast<ShuffleVectorInst>(FSub0->getOperand(1));
    auto *FSub10 = cast<ShuffleVectorInst>(FSub1->getOperand(0));
    auto *FSub11 = cast<ShuffleVectorInst>(FSub1->getOperand(1));

    // Check that all shuffles are distinct
    Instruction *Shuffles[4] = {FSub01, FSub11, FSub00, FSub10};
    bool AllDistinct = true;
    for (unsigned First = 0; First != 4 && AllDistinct; ++First)
      for (unsigned Second = First + 1; Second != 4; ++Second)
        if (Shuffles[First] == Shuffles[Second]) {
          AllDistinct = false;
          break;
        }
    if (!AllDistinct) {
      LLVM_DEBUG(dbgs() << "Shuffles aren't distinct"
                        << ".\n");
      continue;
    }

    // Compare the shuffles with the FAdd
    auto *FAdd = cast<Instruction>(InterleavedVec->getOperand(1));
    auto *FAdd0 = cast<Instruction>(FAdd->getOperand(0));
    auto *FAdd1 = cast<Instruction>(FAdd->getOperand(1));
    auto *FAdd00 = cast<ShuffleVectorInst>(FAdd0->getOperand(0));
    auto *FAdd01 = cast<ShuffleVectorInst>(FAdd0->getOperand(1));
    auto *FAdd10 = cast<ShuffleVectorInst>(FAdd1->getOperand(0));
    auto *FAdd11 = cast<ShuffleVectorInst>(FAdd1->getOperand(1));
    if (FAdd00 != FSub10 || FAdd01 != FSub01 || FAdd10 != FSub00 ||
        FAdd11 != FSub11) {
      LLVM_DEBUG(dbgs() << "Shuffles do not match"
                        << ".\n");
      continue;
    }

    // Ask the target for the intrinsics before touching the IR, so that a
    // refusal leaves the block unchanged. FSub and FAdd are guaranteed by the
    // match above to have the shape getCandidate() recognises.
    auto *C1 = getCandidate(FSub);
    auto *C2 = getCandidate(FAdd);
    ComplexArithmeticCandidateSet Set(TTI);
    Set.addCandidate(C1);
    Set.addCandidate(C2);

    unsigned int OperandCount = 0;
    Intrinsic::ID IntC1 = TTI->getComplexArithmeticIntrinsic(C1, OperandCount);
    Intrinsic::ID IntC2 = TTI->getComplexArithmeticIntrinsic(C2, OperandCount);
    if (IntC1 == Intrinsic::not_intrinsic ||
        IntC2 == Intrinsic::not_intrinsic) {
      LLVM_DEBUG(dbgs() << "No intrinsic offered by TTI"
                        << ".\n");
      // NOTE(review): mirrors the clean-up done in evaluateComplexArithmetic;
      // confirm clear() is what releases the candidates.
      Set.clear();
      continue;
    }

    // At this point, we can assume all is good to be replaced
    IRBuilder<> Builder(&I);

    // Split each source vector into its low and high halves.
    auto *ShuffleA1 = Builder.CreateShuffleVector(LoadInstA, Shuffle1Mask);
    auto *ShuffleA2 = Builder.CreateShuffleVector(LoadInstA, Shuffle2Mask);
    auto *ShuffleB1 = Builder.CreateShuffleVector(LoadInstB, Shuffle1Mask);
    auto *ShuffleB2 = Builder.CreateShuffleVector(LoadInstB, Shuffle2Mask);

    // Emits the C1 intrinsic on one pair of halves and feeds its result,
    // together with the same halves, into the C2 intrinsic.
    auto EmitHalf = [&](Value *HalfA, Value *HalfB) -> Value * {
      SmallVector<Value *, 4> Operands;
      for (auto &Item : C1->ExtraOperandsPre)
        Operands.push_back(Item);
      if (C1->Type == llvm::ComplexArithmeticCandidate::Complex_Mla)
        Operands.push_back(ConstantFP::get(C1->getDataType(), 0.0));
      Operands.push_back(HalfA);
      Operands.push_back(HalfB);
      auto *FirstStep =
          Builder.CreateIntrinsic(IntC1, C1->getDataType(), Operands);

      Operands.clear();
      for (auto &Item : C2->ExtraOperandsPre)
        Operands.push_back(Item);
      Operands.push_back(FirstStep);
      Operands.push_back(HalfA);
      Operands.push_back(HalfB);
      return Builder.CreateIntrinsic(IntC2, C2->getDataType(), Operands);
    };

    Value *IntrinsicLow2 = EmitHalf(ShuffleA1, ShuffleB1);
    Value *IntrinsicHigh2 = EmitHalf(ShuffleA2, ShuffleB2);

    // Concatenate the two halves back into a full-width vector.
    auto *StorageShuffle = Builder.CreateShuffleVector(
        IntrinsicLow2, IntrinsicHigh2, StorageShuffleMask);
    InterleavedVec->replaceAllUsesWith(StorageShuffle);
    NumComplexIntrinsics += 4;

    // Clear old instructions
    DeadInsts.push_back(InterleavedVec);
    DeadInsts.push_back(FSub);
    DeadInsts.push_back(FSub0);
    DeadInsts.push_back(FSub00);
    DeadInsts.push_back(FSub01);
    DeadInsts.push_back(FSub1);
    DeadInsts.push_back(FSub10);
    DeadInsts.push_back(FSub11);
    DeadInsts.push_back(FAdd);
    DeadInsts.push_back(FAdd0);
    DeadInsts.push_back(FAdd00);
    DeadInsts.push_back(FAdd01);
    DeadInsts.push_back(FAdd1);
    DeadInsts.push_back(FAdd10);
    DeadInsts.push_back(FAdd11);

    // C1 and C2 are not used past this point.
    Set.clear();

    LLVM_DEBUG(dbgs() << "Naive pattern matching succeeded"
                      << ".\n");
    return true;
  }

  LLVM_DEBUG(dbgs() << "Naive pattern matching failed"
                    << ".\n");
  return false;
}
/// Tries the naive whole-pattern rewrite on \p B first. If that does not
/// fire, collects every candidate in the block, discards those whose data
/// flow is rejected, and attempts to lower the rest.
/// \param B The block to process.
/// \param DeadInsts Receives instructions made redundant by the rewrite.
/// \return true if any change has been made, false otherwise.
bool ComplexArithmetic::evaluateComplexArithmetic(
    BasicBlock *B, SmallVector<Instruction *, 32> &DeadInsts) {
  LLVM_DEBUG(dbgs() << "Evaluating complex arithmetic on block: "; B->dump());

  const bool NaiveRewriteFired = evaluateComplexArithmeticNaive(B, DeadInsts);
  if (NaiveRewriteFired)
    return true;

  // Gather every instruction that roots a recognised pattern.
  ComplexArithmeticCandidateSet CandidateSet(TTI);
  for (auto &Inst : *B)
    if (auto *Found = getCandidate(&Inst))
      CandidateSet.addCandidate(Found);

  // Removal is deferred so the set is not modified while it is iterated.
  SmallVector<ComplexArithmeticCandidate *, 4> Rejected;
  for (const auto &Cand : CandidateSet.get())
    if (!CandidateSet.candidateHasValidDataFlow(Cand))
      Rejected.push_back(Cand);
  for (auto &Cand : Rejected)
    CandidateSet.remove(Cand);

  bool AnyLowered = false;
  for (auto *Cand : CandidateSet.get())
    AnyLowered |= CandidateSet.tryLower(*TTI, Cand, DeadInsts);

  // Clean up loose memory objects
  CandidateSet.clear();
  return AnyLowered;
}
/// Legacy pass manager entry point: obtains the TargetTransformInfo for \p F
/// through the TargetPassConfig and delegates to ComplexArithmetic.
/// \return true if \p F was modified. Returns false without touching \p F
/// when no TargetPassConfig is available.
bool ComplexArithmeticLegacyPass::runOnFunction(Function &F) {
  // getAnalysisIfAvailable yields null when the analysis has not been
  // scheduled (the pass does not declare it as required), so it must be
  // checked before use.
  auto *TPC = getAnalysisIfAvailable<TargetPassConfig>();
  if (!TPC)
    return false;

  auto &TM = TPC->getTM<TargetMachine>();
  auto TTI = TM.getTargetTransformInfo(F);
  return ComplexArithmetic(&TTI).runOnFunction(F);
}
static bool HasBeenDisabled = false; | |||||
bool ComplexArithmetic::runOnFunction(Function &F) { | |||||
if (!ComplexArithmeticForceEnabled) { | |||||
if (!ComplexArithmeticEnabled) { | |||||
LLVM_DEBUG(if (!HasBeenDisabled) dbgs() | |||||
<< "Complex Arithmetic has been explicitly disabled.\n"); | |||||
HasBeenDisabled = true; | |||||
return false; | |||||
} | |||||
if (!TTI->supportsComplexNumberArithmetic()) { | |||||
LLVM_DEBUG(if (!HasBeenDisabled) dbgs() | |||||
<< "Complex Arithmetic has been disabled. " | |||||
<< "Target does not support lowering of complex numbers.\n"); | |||||
HasBeenDisabled = true; | |||||
return false; | |||||
} | |||||
} | |||||
SmallVector<Instruction *, 32> DeadInsts; | |||||
bool Changed = false; | |||||
SmallVector<BasicBlock *, 4> ChangedBlocks; | |||||
for (auto &B : F) { | |||||
if (evaluateComplexArithmetic(&B, DeadInsts)) { | |||||
Changed |= true; | |||||
ChangedBlocks.push_back(&B); | |||||
} | |||||
} | |||||
if (Changed) { | |||||
// TODO clean up the dead instructions better | |||||
unsigned iter = 0; | |||||
unsigned count = DeadInsts.size(); | |||||
unsigned remaining = DeadInsts.size(); | |||||
while (!DeadInsts.empty() && remaining > 0 && iter < count) { | |||||
++iter; | |||||
remaining = 0; | |||||
for (auto *It = DeadInsts.begin(); It != DeadInsts.end(); It++) { | |||||
auto *I = *It; | |||||
if (I->getParent()) | |||||
remaining++; | |||||
if (I->getNumUses() == 0 && I->getParent()) { | |||||
remaining--; | |||||
I->eraseFromParent(); | |||||
} | |||||
} | |||||
} | |||||
DeadInsts.clear(); | |||||
} | |||||
return Changed; | |||||
} | |||||
/// Collects, into \p Operands, the values to pass to the intrinsic that will
/// replace this candidate.
///
/// The list starts with ExtraOperandsPre. For Complex_Mla candidates it is
/// followed by either the output of the paired candidate (which is lowered
/// first when its rotation is smaller than this one's) or a zero constant.
/// Each instruction input is then resolved in one of three ways: through the
/// candidate that holds it, through a load found behind a shuffle, or by the
/// target's filterComplexArithmeticOperand hook. Non-instruction inputs are
/// skipped.
///
/// \param TTI Target hooks.
/// \param Set The set this candidate belongs to.
/// \param DeadInsts Receives instructions made redundant along the way.
/// \param Operands Output list; entries are appended.
/// \return false if this candidate has no pair in \p Set, true otherwise.
bool ComplexArithmeticCandidate::getOperands(
    TargetTransformInfo *TTI, ComplexArithmeticCandidateSet *Set,
    SmallVector<Instruction *, 32> &DeadInsts,
    SmallVector<Value *, 4> &Operands) {
  auto *PairedC = Set->getPaired(this);
  if (!PairedC)
    return false;

  // Appends Inst to DeadInsts unless it is already listed.
  auto MarkDead = [&DeadInsts](Instruction *Inst) {
    if (std::find(DeadInsts.begin(), DeadInsts.end(), Inst) == DeadInsts.end())
      DeadInsts.push_back(Inst);
  };

  for (auto *Extra : ExtraOperandsPre)
    Operands.push_back(Extra);

  if (Type == ComplexArithmeticCandidate::Complex_Mla) {
    if (PairedC->getRotation() < this->getRotation()) {
      // The pair comes first: lower it and accumulate onto its result.
      Set->tryLower(*TTI, PairedC, DeadInsts);
      Operands.push_back(PairedC->getOutput());
    } else {
      // Otherwise start the accumulation from zero.
      auto *Zero = ConstantFP::get(PairedC->getDataType(), 0);
      Operands.push_back(Zero);
    }
  }

  auto Inputs = getInputs();
  for (auto *V : Inputs) {
    auto *I = dyn_cast<Instruction>(V);
    if (!I)
      continue;

    // Input produced by another candidate: lower that one and use its output,
    // unless its rotation is smaller than that of its own pair.
    if (auto *C = Set->getCandidateHoldingInstr(I)) {
      Set->tryLower(*TTI, C, DeadInsts);
      if (!(C->getRotation() < C->PairedCandidate->getRotation()))
        Operands.push_back(C->getOutput());
      continue;
    }

    if (auto *SVI = dyn_cast<ShuffleVectorInst>(I)) {
      // Look for a load behind the shuffle, either directly or as the element
      // inserted by an insertelement feeding the shuffle.
      LoadInst *LI = nullptr;
      auto *IEI = dyn_cast<InsertElementInst>(SVI->getOperand(0));
      if (!IEI) {
        LI = dyn_cast<LoadInst>(SVI->getOperand(0));
      } else if ((LI = dyn_cast<LoadInst>(IEI->getOperand(1)))) {
        MarkDead(IEI);
        MarkDead(SVI);
        // Loads addressed through a GEP are not used as operands here.
        if (isa<GetElementPtrInst>(LI->getOperand(0)))
          LI = nullptr;
      }

      auto *Ty = PointerType::get(getDataType(), 0);
      if (LI) {
        if (LI->getType() != getDataType()) {
          LLVM_DEBUG(dbgs() << "LI->getType() "; LI->getType()->dump());
          LLVM_DEBUG(dbgs() << "getDataType() "; getDataType()->dump());
          // Reload the same address with the candidate's data type.
          IRBuilder<> B(LI);
          MarkDead(LI);
          auto *BitCast = B.CreateBitOrPointerCast(LI->getOperand(0), Ty);
          auto *NewLI = B.CreateLoad(Ty->getElementType(), BitCast);
          LI = NewLI;
        }
        if (std::find(Operands.begin(), Operands.end(), LI) == Operands.end())
          Operands.push_back(LI);
        MarkDead(SVI);
        continue;
      }
    }

    // Anything else is left to the target.
    TTI->filterComplexArithmeticOperand(this, V, Operands, DeadInsts);
  }

  LLVM_DEBUG({
    dbgs() << "Operands: \n";
    for (auto *Op : Operands) {
      Op->dump();
    }
  });
  return true;
}
bool ComplexArithmeticCandidateSet::tryLower( | |||||
const TargetTransformInfo &TTI, ComplexArithmeticCandidate *C, | |||||
SmallVector<Instruction *, 32> &DeadInsts) { | |||||
LLVM_DEBUG(dbgs() << "Trying to lower candidate starting at "; | |||||
C->RootInstruction->dump()); | |||||
if (C->Intrinsic) { | |||||
LLVM_DEBUG(dbgs() << "Candidate already lowered" | |||||
<< ".\n"); | |||||
return false; | |||||
} | |||||
if (!C->RootInstruction->getType()->isVectorTy()) { | |||||
LLVM_DEBUG(dbgs() << "Candidate type is not vector type" | |||||
<< ".\n"); | |||||
return false; | |||||
} | |||||
getPaired(C); | |||||
// Is the lowered operator going to need splitting? If so, don't lower it | |||||
auto *VTy = cast<FixedVectorType>(C->getDataType()); | |||||
unsigned VTyWidth = VTy->getScalarSizeInBits() * VTy->getNumElements(); | |||||
if (VTyWidth > 128 /* Width of a single vector register */) { | |||||
LLVM_DEBUG(dbgs() << "Vector type is too wide" | |||||
<< ".\n"); | |||||
LLVM_DEBUG(dbgs() << " Width: " << VTyWidth << ". "; VTy->dump()); | |||||
return false; | |||||
} | |||||
unsigned IntArgCount = 0; | |||||
Intrinsic::ID IntId = TTI.getComplexArithmeticIntrinsic(C, IntArgCount); | |||||
if (IntId != Intrinsic::not_intrinsic) { | |||||
IRBuilder<> Builder(cast<Instruction>(*C->RootInstruction->user_begin())); | |||||
Type *DataType = C->getDataType(); | |||||
SmallVector<Value *, 4> OperandsArr; | |||||
if (!C->getOperands(const_cast<TargetTransformInfo *>(&TTI), this, | |||||
DeadInsts, OperandsArr)) { | |||||
LLVM_DEBUG(dbgs() << "Failed to get operands" | |||||
<< ".\n"); | |||||
return false; | |||||
} | |||||
if (OperandsArr.size() != IntArgCount) { | |||||
LLVM_DEBUG(dbgs() << "Got incorrect amount of operands" | |||||
<< ".\n"); | |||||
LLVM_DEBUG(dbgs() << "Expected " << IntArgCount << ", got " | |||||
<< OperandsArr.size() << ".\n"); | |||||
return false; | |||||
} | |||||
ArrayRef<Value *> Operands(OperandsArr); | |||||
C->setIntrinsic(Builder.CreateIntrinsic(IntId, {DataType}, Operands)); | |||||
NumComplexIntrinsics++; | |||||
if (C->isTerminalInPair()) { | |||||
if (auto *SVI = | |||||
dyn_cast<ShuffleVectorInst>(*C->RootInstruction->user_begin())) { | |||||
if (SVI->hasOneUse() && isa<CallInst>(*SVI->users().begin())) { | |||||
auto *CI = cast<CallInst>(*SVI->users().begin()); | |||||
if (CI->mayWriteToMemory()) { | |||||
// Convert a storing CallInst to a standard store, keeping it | |||||
// compatible with the complex intrinsic structure | |||||
// TODO make this better | |||||
DeadInsts.push_back(CI); | |||||
IRBuilder<> B(CI); | |||||
auto *BCI = cast<BitCastInst>(CI->getOperand(2)); | |||||
DeadInsts.push_back(BCI); | |||||
B.CreateStore(C->Intrinsic, BCI->getOperand(0)); | |||||
} | |||||
} else | |||||
SVI->replaceAllUsesWith(C->Intrinsic); | |||||
DeadInsts.push_back(SVI); | |||||
} | |||||
C->PairedCandidate->Intrinsic->moveBefore(C->Intrinsic); | |||||
} | |||||
for (auto &item : C->Instructions) | |||||
DeadInsts.push_back(item); | |||||
LLVM_DEBUG(dbgs() << "Success?" | |||||
<< ".\n"); | |||||
return true; | |||||
} | |||||
LLVM_DEBUG(dbgs() << "No intrinsic offered by TTI" | |||||
<< ".\n"); | |||||
return false; | |||||
} | |||||
/// Decides whether every input of \p C comes from a source this pass knows
/// how to handle.
///
///  * A non-instruction input makes the candidate invalid.
///  * An input held by another candidate requires that candidate to be valid
///    as well.
///  * A load is always accepted.
///  * A shuffle is accepted when it is a two-lane swap, a splat of an inserted
///    element, or an even/odd deinterleave of 1, 2 or 4 lanes; C->TypeWidth is
///    updated accordingly. Any other shuffle makes the candidate invalid.
///  * Any other instruction is accepted only if the target's
///    validateComplexCandidateDataFlow hook approves it.
///
/// \return true if all inputs are accepted, false otherwise.
bool ComplexArithmeticCandidateSet::candidateHasValidDataFlow(
    ComplexArithmeticCandidate *C) {
  // Recognised shuffle masks. ArrayRef objects are created from these where
  // needed because m_SpecificMask takes its mask by reference.
  static const int SwapLanes[] = {1, 0};
  static const int Even1[] = {0};
  static const int Odd1[] = {1};
  static const int Even2[] = {0, 2};
  static const int Odd2[] = {1, 3};
  static const int Even4[] = {0, 2, 4, 6};
  static const int Odd4[] = {1, 3, 5, 7};

  auto Inputs = C->getInputs();
  for (auto *V : Inputs) {
    auto *I = dyn_cast<Instruction>(V);
    if (!I)
      return false;

    if (auto *ParentC = getCandidateHoldingInstr(I))
      if (!candidateHasValidDataFlow(ParentC))
        return false;

    if (isa<LoadInst>(I))
      continue;

    auto *SVI = dyn_cast<ShuffleVectorInst>(I);
    if (!SVI) {
      // Not a shuffle: defer to the target.
      if (!TTI->validateComplexCandidateDataFlow(C, I))
        return false;
      continue;
    }

    // Permuted/splatted access
    ArrayRef<int> PermuteMask(SwapLanes);
    if (match(SVI,
              m_Shuffle(m_Value(), m_Undef(), m_SpecificMask(PermuteMask)))) {
      C->TypeWidth = cast<FixedVectorType>(SVI->getOperand(0)->getType())
                         ->getNumElements();
      continue;
    }
    if (match(SVI, m_Shuffle(m_InsertElt(m_Undef(), m_Value(), m_ZeroInt()),
                             m_Undef(), m_ZeroMask()))) {
      auto *IEI = cast<InsertElementInst>(SVI->getOperand(0));
      C->TypeWidth = cast<FixedVectorType>(IEI->getOperand(0)->getType())
                         ->getNumElements();
      continue;
    }

    // Interleaved access
    ArrayRef<int> InterleaveMasks[] = {Even1, Odd1, Even2, Odd2, Even4, Odd4};
    bool IsDeinterleave = false;
    for (ArrayRef<int> Mask : InterleaveMasks) {
      if (!match(SVI, m_Shuffle(m_Value(), m_Undef(), m_SpecificMask(Mask))))
        continue;
      IsDeinterleave = true;
      C->TypeWidth = Mask.size() * 2;
      break;
    }
    if (!IsDeinterleave)
      return false;
  }
  return true;
}