Changeset View
Changeset View
Standalone View
Standalone View
llvm/include/llvm/CodeGen/ComplexArithmeticPass.h
- This file was added.
//===- ComplexArithmeticPass.h - Complex Arithmetic Pass --------*- C++ -*-===// | |||||
// | |||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | |||||
// See https://llvm.org/LICENSE.txt for license information. | |||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | |||||
// | |||||
//===----------------------------------------------------------------------===// | |||||
// | |||||
// This pass implements generation of target-specific intrinsics to support | |||||
// handling of complex number arithmetic | |||||
// | |||||
//===----------------------------------------------------------------------===// | |||||
#ifndef LLVM_TRANSFORMS_SCALAR_COMPLEXARITHMETIC_H | |||||
#define LLVM_TRANSFORMS_SCALAR_COMPLEXARITHMETIC_H | |||||
#include "llvm/IR/PassManager.h" | |||||
#include "llvm/IR/PatternMatch.h" | |||||
namespace llvm { | |||||
class Function; | |||||
struct ComplexArithmeticPass : public PassInfoMixin<ComplexArithmeticPass> { | |||||
public: | |||||
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM); | |||||
}; | |||||
class ComplexArithmeticCandidateSet; | |||||
class ComplexArithmeticCandidate { | |||||
public: | |||||
// Visible data | |||||
Instruction *RootInstruction; | |||||
SmallVector<Instruction *, 4> Instructions; | |||||
CallInst *Intrinsic = nullptr; | |||||
ComplexArithmeticCandidate *PairedCandidate = nullptr; | |||||
bool IsPairOwner = false; | |||||
Instruction *InstrToReplace = nullptr; | |||||
int TypeWidth = 4; | |||||
SmallVector<Value *, 4> ExtraOperandsPre; | |||||
enum Type { Complex_Add, Complex_Mul, Complex_Mla } Type; | |||||
Value *getOutput() { | |||||
if (Intrinsic) | |||||
return Intrinsic; | |||||
return RootInstruction; | |||||
} | |||||
LLVMContext &getContext() { return RootInstruction->getContext(); } | |||||
bool isTerminalInPair() { | |||||
if (PairedCandidate) | |||||
return PairedCandidate->getRotation() < getRotation(); | |||||
return true; | |||||
} | |||||
ComplexArithmeticCandidate(Instruction *rootInstruction, enum Type type) | |||||
: RootInstruction(rootInstruction), Type(type) { | |||||
addInstruction(RootInstruction); | |||||
} | |||||
llvm::Type *getDataType() { | |||||
return FixedVectorType::get(RootInstruction->getType()->getScalarType(), | |||||
TypeWidth); | |||||
} | |||||
bool getOperands(TargetTransformInfo *TTI, ComplexArithmeticCandidateSet *Set, | |||||
SmallVector<Instruction *, 32> &DeadInsts, | |||||
SmallVector<Value *, 4> &Operands); | |||||
SmallVector<Value *, 4> getInputs() { | |||||
SmallVector<Value *, 4> Inputs; | |||||
for (const auto &item : Instructions) { | |||||
for (unsigned i = 0; i < item->getNumOperands(); i++) { | |||||
auto *Op = item->getOperand(i); | |||||
if (auto *I = dyn_cast<Instruction>(Op)) { | |||||
if (!containsInstruction(I)) { | |||||
Inputs.push_back(Op); | |||||
} | |||||
} else { | |||||
Inputs.push_back(Op); | |||||
} | |||||
} | |||||
} | |||||
return Inputs; | |||||
} | |||||
// Analysis | |||||
/// Returns true if this candidate contains the given instruction | |||||
bool containsInstruction(Instruction *I) { | |||||
return std::find(Instructions.begin(), Instructions.end(), I) != | |||||
Instructions.end(); | |||||
} | |||||
/// Returns true if this candidate contains an instruction with the given | |||||
/// opcode | |||||
bool containsInstruction(unsigned int opcode) { | |||||
return std::find_if(Instructions.begin(), Instructions.end(), | |||||
[&](Instruction *I) { | |||||
return I->getOpcode() == opcode; | |||||
}) != Instructions.end(); | |||||
} | |||||
const char *typeName() { | |||||
switch (Type) { | |||||
case Type::Complex_Add: | |||||
return "Complex Add"; | |||||
case Complex_Mul: | |||||
return "Complex Mul"; | |||||
case Complex_Mla: | |||||
return "Complex Mla"; | |||||
} | |||||
return "Unknown type"; | |||||
} | |||||
void addInstruction(Value *V) { | |||||
if (auto *I = dyn_cast<Instruction>(V)) | |||||
addInstruction(I); | |||||
} | |||||
void addInstruction(Instruction *I) { Instructions.push_back(I); } | |||||
/// Calculates the rotation of this candidate, based on it's type | |||||
unsigned int getRotation() { | |||||
unsigned Rot = 0; | |||||
if (Type == Complex_Add) { | |||||
// TODO figure out what the add rotations look like in IR | |||||
} else { | |||||
if (RootInstruction->getOpcode() == Instruction::FSub) | |||||
Rot += 90; | |||||
if (containsInstruction(Instruction::FNeg)) | |||||
Rot += 180; | |||||
} | |||||
return Rot; | |||||
} | |||||
bool isSelfContained() { | |||||
for (const auto &item : Instructions) { | |||||
for (auto &U : item->uses()) { | |||||
if (!isUseSelfContained(item, U)) | |||||
return false; | |||||
} | |||||
} | |||||
return true; | |||||
} | |||||
void setIntrinsic(CallInst *Intrinsic) { this->Intrinsic = Intrinsic; } | |||||
private: | |||||
bool isUseSelfContained(Instruction *Def, Use &Use) { | |||||
if (Def == RootInstruction) | |||||
return true; | |||||
if (auto *I = dyn_cast<Instruction>(Use)) | |||||
return containsInstruction(I); | |||||
return true; | |||||
} | |||||
}; | |||||
using namespace PatternMatch; | |||||
class ComplexArithmeticCandidateSet { | |||||
private: | |||||
const TargetTransformInfo *TTI; | |||||
public: | |||||
ComplexArithmeticCandidateSet(const TargetTransformInfo *tti) : TTI(tti) {} | |||||
void clear() { | |||||
for (auto *C : Candidates) | |||||
delete C; | |||||
Candidates.clear(); | |||||
} | |||||
void remove(ComplexArithmeticCandidate *C) { | |||||
auto *It = std::find(Candidates.begin(), Candidates.end(), C); | |||||
if (It != Candidates.end()) { | |||||
Candidates.erase(It); | |||||
delete C; | |||||
} | |||||
} | |||||
ComplexArithmeticCandidate * | |||||
getCandidate(Instruction *I, enum ComplexArithmeticCandidate::Type T) { | |||||
auto *C = new ComplexArithmeticCandidate(I, T); | |||||
Candidates.push_back(C); | |||||
return C; | |||||
} | |||||
void addCandidate(ComplexArithmeticCandidate *C) { Candidates.push_back(C); } | |||||
SmallVector<ComplexArithmeticCandidate *, 4> &get() { return Candidates; } | |||||
bool isInstrPartOfCandidate(Instruction *I) { | |||||
for (const auto &C : Candidates) { | |||||
if (C->containsInstruction(I)) | |||||
return true; | |||||
} | |||||
return false; | |||||
} | |||||
ComplexArithmeticCandidate *getCandidateHoldingInstr(Instruction *I) { | |||||
for (auto &C : Candidates) { | |||||
if (C->containsInstruction(I)) | |||||
return C; | |||||
} | |||||
return nullptr; | |||||
} | |||||
ComplexArithmeticCandidate *getPaired(ComplexArithmeticCandidate *C) { | |||||
if (C->PairedCandidate) | |||||
return C->PairedCandidate; | |||||
auto CUsers = C->RootInstruction->users(); | |||||
auto SharesUse = [&](ComplexArithmeticCandidate *OtherC) { | |||||
for (auto *User : OtherC->RootInstruction->users()) { | |||||
if (std::find(CUsers.begin(), CUsers.end(), User) != CUsers.end()) | |||||
return true; | |||||
} | |||||
return false; | |||||
}; | |||||
auto StoreNeigbour = [&](ComplexArithmeticCandidate *OtherC) { | |||||
if (!C->RootInstruction->hasOneUser()) | |||||
return false; | |||||
if (!OtherC->RootInstruction->hasOneUser()) | |||||
return false; | |||||
auto *CUser = *C->RootInstruction->user_begin(); | |||||
auto *OtherCUser = *OtherC->RootInstruction->user_begin(); | |||||
auto *CStr = dyn_cast<StoreInst>(CUser); | |||||
auto *OtherCStr = dyn_cast<StoreInst>(OtherCUser); | |||||
if (!CStr || !OtherCStr) | |||||
return false; | |||||
auto *CPtr = CStr->getOperand(1); | |||||
auto *OtherCPtr = OtherCStr->getOperand(1); | |||||
auto *OtherCGep = dyn_cast<GetElementPtrInst>(OtherCPtr); | |||||
if (!OtherCGep) | |||||
return false; | |||||
return CPtr == OtherCGep->getOperand(0) && | |||||
cast<ConstantInt>((*OtherCGep->indices().begin()))->isOne(); | |||||
}; | |||||
for (auto *OtherC : Candidates) { | |||||
if ((SharesUse(OtherC) /*|| StoreNeigbour(OtherC)*/) && | |||||
C->getRotation() != OtherC->getRotation()) { | |||||
C->PairedCandidate = OtherC; | |||||
OtherC->PairedCandidate = C; | |||||
C->PairedCandidate->IsPairOwner = true; | |||||
return OtherC; | |||||
} | |||||
} | |||||
return nullptr; | |||||
} | |||||
bool candidateHasValidDataFlow(ComplexArithmeticCandidate *C); | |||||
bool tryLower(const TargetTransformInfo &TTI, ComplexArithmeticCandidate *C, | |||||
SmallVector<Instruction *, 32> &DeadInsts); | |||||
private: | |||||
SmallVector<ComplexArithmeticCandidate *, 4> Candidates; | |||||
}; | |||||
} // end namespace llvm | |||||
#endif // LLVM_TRANSFORMS_SCALAR_COMPLEXARITHMETIC_H |