diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h --- a/llvm/include/llvm/Analysis/TargetTransformInfo.h +++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h @@ -21,6 +21,7 @@ #ifndef LLVM_ANALYSIS_TARGETTRANSFORMINFO_H #define LLVM_ANALYSIS_TARGETTRANSFORMINFO_H +#include "llvm/IR/IRBuilder.h" #include "llvm/IR/InstrTypes.h" #include "llvm/IR/Operator.h" #include "llvm/IR/PassManager.h" @@ -63,6 +64,7 @@ class User; class Value; class VPIntrinsic; +class ComplexArithmeticCandidate; struct KnownBits; template class Optional; @@ -753,6 +755,16 @@ /// the scalarization cost of a load/store. bool supportsEfficientVectorElementLoadStore() const; + bool supportsComplexNumberArithmetic() const; + Intrinsic::ID getComplexArithmeticIntrinsic(ComplexArithmeticCandidate *C, + unsigned &IntArgCount) const; + bool validateComplexCandidateDataFlow(ComplexArithmeticCandidate *C, + Instruction *I) const; + void + filterComplexArithmeticOperand(ComplexArithmeticCandidate *C, Value *V, + SmallVector &Operands, + SmallVector &DeadInsts); + /// Don't restrict interleaved unrolling to small loops. 
bool enableAggressiveInterleaving(bool LoopHasReductions) const; @@ -1543,6 +1555,16 @@ getOperandsScalarizationOverhead(ArrayRef Args, ArrayRef Tys) = 0; virtual bool supportsEfficientVectorElementLoadStore() = 0; + virtual bool supportsComplexNumberArithmetic() = 0; + virtual Intrinsic::ID + getComplexArithmeticIntrinsic(ComplexArithmeticCandidate *C, + unsigned &IntArgCount) = 0; + virtual bool validateComplexCandidateDataFlow(ComplexArithmeticCandidate *C, + Instruction *I) = 0; + virtual void + filterComplexArithmeticOperand(ComplexArithmeticCandidate *C, Value *V, + SmallVector &Operands, + SmallVector &DeadInsts) = 0; virtual bool enableAggressiveInterleaving(bool LoopHasReductions) = 0; virtual MemCmpExpansionOptions enableMemCmpExpansion(bool OptSize, bool IsZeroCmp) const = 0; @@ -1964,6 +1986,26 @@ return Impl.supportsEfficientVectorElementLoadStore(); } + bool supportsComplexNumberArithmetic() override { + return Impl.supportsComplexNumberArithmetic(); + } + + Intrinsic::ID getComplexArithmeticIntrinsic(ComplexArithmeticCandidate *C, + unsigned &IntArgCount) override { + return Impl.getComplexArithmeticIntrinsic(C, IntArgCount); + } + + bool validateComplexCandidateDataFlow(ComplexArithmeticCandidate *C, + Instruction *I) override { + return Impl.validateComplexCandidateDataFlow(C, I); + } + void filterComplexArithmeticOperand( + ComplexArithmeticCandidate *C, Value *V, + SmallVector &Operands, + SmallVector &DeadInsts) override { + Impl.filterComplexArithmeticOperand(C, V, Operands, DeadInsts); + } + bool enableAggressiveInterleaving(bool LoopHasReductions) override { return Impl.enableAggressiveInterleaving(LoopHasReductions); } diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h --- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h +++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h @@ -319,6 +319,21 @@ bool supportsEfficientVectorElementLoadStore() const { return false; 
} + bool supportsComplexNumberArithmetic() const { return false; } + + Intrinsic::ID getComplexArithmeticIntrinsic(ComplexArithmeticCandidate *C, + unsigned &IntArgCount) const { + return Intrinsic::not_intrinsic; + } + bool validateComplexCandidateDataFlow(ComplexArithmeticCandidate *C, + Instruction *I) { + return false; + } + void + filterComplexArithmeticOperand(ComplexArithmeticCandidate *C, Value *V, + SmallVector &Operands, + SmallVector &DeadInsts) {} + bool enableAggressiveInterleaving(bool LoopHasReductions) const { return false; } diff --git a/llvm/include/llvm/CodeGen/ComplexArithmeticPass.h b/llvm/include/llvm/CodeGen/ComplexArithmeticPass.h new file mode 100644 --- /dev/null +++ b/llvm/include/llvm/CodeGen/ComplexArithmeticPass.h @@ -0,0 +1,275 @@ +//===- ComplexArithmeticPass.h - Complex Arithmetic Pass --------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This pass implements generation of target-specific intrinsics to support +// handling of complex number arithmetic +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_TRANSFORMS_SCALAR_COMPLEXARITHMETIC_H +#define LLVM_TRANSFORMS_SCALAR_COMPLEXARITHMETIC_H + +#include "llvm/IR/PassManager.h" +#include "llvm/IR/PatternMatch.h" + +namespace llvm { + +class Function; + +struct ComplexArithmeticPass : public PassInfoMixin { +public: + PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM); +}; + +class ComplexArithmeticCandidateSet; + +class ComplexArithmeticCandidate { +public: + // Visible data + Instruction *RootInstruction; + SmallVector Instructions; + CallInst *Intrinsic = nullptr; + ComplexArithmeticCandidate *PairedCandidate = nullptr; + bool IsPairOwner = false; + Instruction *InstrToReplace = nullptr; + int TypeWidth = 4; + SmallVector ExtraOperandsPre; + enum Type { Complex_Add, Complex_Mul, Complex_Mla } Type; + + Value *getOutput() { + if (Intrinsic) + return Intrinsic; + return RootInstruction; + } + + LLVMContext &getContext() { return RootInstruction->getContext(); } + + bool isTerminalInPair() { + if (PairedCandidate) + return PairedCandidate->getRotation() < getRotation(); + return true; + } + + ComplexArithmeticCandidate(Instruction *rootInstruction, enum Type type) + : RootInstruction(rootInstruction), Type(type) { + addInstruction(RootInstruction); + } + + llvm::Type *getDataType() { + return FixedVectorType::get(RootInstruction->getType()->getScalarType(), + TypeWidth); + } + + bool getOperands(TargetTransformInfo *TTI, ComplexArithmeticCandidateSet *Set, + SmallVector &DeadInsts, + SmallVector &Operands); + + SmallVector getInputs() { + SmallVector Inputs; + for (const auto &item : Instructions) { + for (unsigned i = 0; i < 
item->getNumOperands(); i++) { + auto *Op = item->getOperand(i); + if (auto *I = dyn_cast(Op)) { + if (!containsInstruction(I)) { + Inputs.push_back(Op); + } + } else { + Inputs.push_back(Op); + } + } + } + return Inputs; + } + + // Analysis + /// Returns true if this candidate contains the given instruction + bool containsInstruction(Instruction *I) { + return std::find(Instructions.begin(), Instructions.end(), I) != + Instructions.end(); + } + + /// Returns true if this candidate contains an instruction with the given + /// opcode + bool containsInstruction(unsigned int opcode) { + return std::find_if(Instructions.begin(), Instructions.end(), + [&](Instruction *I) { + return I->getOpcode() == opcode; + }) != Instructions.end(); + } + + const char *typeName() { + switch (Type) { + case Type::Complex_Add: + return "Complex Add"; + case Complex_Mul: + return "Complex Mul"; + case Complex_Mla: + return "Complex Mla"; + } + return "Unknown type"; + } + + void addInstruction(Value *V) { + if (auto *I = dyn_cast(V)) + addInstruction(I); + } + + void addInstruction(Instruction *I) { Instructions.push_back(I); } + + /// Calculates the rotation of this candidate, based on its type + unsigned int getRotation() { + unsigned Rot = 0; + + if (Type == Complex_Add) { + // TODO figure out what the add rotations look like in IR + } else { + if (RootInstruction->getOpcode() == Instruction::FSub) + Rot += 90; + if (containsInstruction(Instruction::FNeg)) + Rot += 180; + } + return Rot; + } + + bool isSelfContained() { + for (const auto &item : Instructions) { + for (auto &U : item->uses()) { + if (!isUseSelfContained(item, U)) + return false; + } + } + return true; + } + + void setIntrinsic(CallInst *Intrinsic) { this->Intrinsic = Intrinsic; } + +private: + bool isUseSelfContained(Instruction *Def, Use &Use) { + if (Def == RootInstruction) + return true; + if (auto *I = dyn_cast(Use)) + return containsInstruction(I); + return true; + } +}; + +using namespace PatternMatch; + +class
ComplexArithmeticCandidateSet { +private: + const TargetTransformInfo *TTI; + +public: + ComplexArithmeticCandidateSet(const TargetTransformInfo *tti) : TTI(tti) {} + + void clear() { + for (auto *C : Candidates) + delete C; + Candidates.clear(); + } + + void remove(ComplexArithmeticCandidate *C) { + auto *It = std::find(Candidates.begin(), Candidates.end(), C); + if (It != Candidates.end()) { + Candidates.erase(It); + delete C; + } + } + + ComplexArithmeticCandidate * + getCandidate(Instruction *I, enum ComplexArithmeticCandidate::Type T) { + auto *C = new ComplexArithmeticCandidate(I, T); + Candidates.push_back(C); + return C; + } + + void addCandidate(ComplexArithmeticCandidate *C) { Candidates.push_back(C); } + + SmallVector &get() { return Candidates; } + + bool isInstrPartOfCandidate(Instruction *I) { + for (const auto &C : Candidates) { + if (C->containsInstruction(I)) + return true; + } + return false; + } + + ComplexArithmeticCandidate *getCandidateHoldingInstr(Instruction *I) { + for (auto &C : Candidates) { + if (C->containsInstruction(I)) + return C; + } + return nullptr; + } + + ComplexArithmeticCandidate *getPaired(ComplexArithmeticCandidate *C) { + if (C->PairedCandidate) + return C->PairedCandidate; + + auto CUsers = C->RootInstruction->users(); + auto SharesUse = [&](ComplexArithmeticCandidate *OtherC) { + for (auto *User : OtherC->RootInstruction->users()) { + if (std::find(CUsers.begin(), CUsers.end(), User) != CUsers.end()) + return true; + } + return false; + }; + + auto StoreNeigbour = [&](ComplexArithmeticCandidate *OtherC) { + if (!C->RootInstruction->hasOneUser()) + return false; + if (!OtherC->RootInstruction->hasOneUser()) + return false; + + auto *CUser = *C->RootInstruction->user_begin(); + auto *OtherCUser = *OtherC->RootInstruction->user_begin(); + + auto *CStr = dyn_cast(CUser); + auto *OtherCStr = dyn_cast(OtherCUser); + + if (!CStr || !OtherCStr) + return false; + + auto *CPtr = CStr->getOperand(1); + auto *OtherCPtr = 
OtherCStr->getOperand(1); + + auto *OtherCGep = dyn_cast(OtherCPtr); + + if (!OtherCGep) + return false; + + return CPtr == OtherCGep->getOperand(0) && + cast((*OtherCGep->indices().begin()))->isOne(); + }; + + for (auto *OtherC : Candidates) { + if ((SharesUse(OtherC) /*|| StoreNeigbour(OtherC)*/) && + C->getRotation() != OtherC->getRotation()) { + C->PairedCandidate = OtherC; + OtherC->PairedCandidate = C; + C->PairedCandidate->IsPairOwner = true; + return OtherC; + } + } + + return nullptr; + } + + bool candidateHasValidDataFlow(ComplexArithmeticCandidate *C); + + bool tryLower(const TargetTransformInfo &TTI, ComplexArithmeticCandidate *C, + SmallVector &DeadInsts); + +private: + SmallVector Candidates; +}; + +} // end namespace llvm + +#endif // LLVM_TRANSFORMS_SCALAR_COMPLEXARITHMETIC_H diff --git a/llvm/include/llvm/CodeGen/Passes.h b/llvm/include/llvm/CodeGen/Passes.h --- a/llvm/include/llvm/CodeGen/Passes.h +++ b/llvm/include/llvm/CodeGen/Passes.h @@ -416,6 +416,8 @@ /// createJumpInstrTables - This pass creates jump-instruction tables. ModulePass *createJumpInstrTablesPass(); + FunctionPass *createComplexArithmeticPass(); + /// InterleavedAccess Pass - This pass identifies and matches interleaved /// memory accesses to target specific intrinsics. 
/// diff --git a/llvm/include/llvm/InitializePasses.h b/llvm/include/llvm/InitializePasses.h --- a/llvm/include/llvm/InitializePasses.h +++ b/llvm/include/llvm/InitializePasses.h @@ -114,9 +114,10 @@ void initializeCallSiteSplittingLegacyPassPass(PassRegistry&); void initializeCalledValuePropagationLegacyPassPass(PassRegistry &); void initializeCheckDebugMachineModulePass(PassRegistry &); -void initializeCodeGenPreparePass(PassRegistry&); -void initializeConstantHoistingLegacyPassPass(PassRegistry&); -void initializeConstantMergeLegacyPassPass(PassRegistry&); +void initializeCodeGenPreparePass(PassRegistry &); +void initializeComplexArithmeticLegacyPassPass(PassRegistry &); +void initializeConstantHoistingLegacyPassPass(PassRegistry &); +void initializeConstantMergeLegacyPassPass(PassRegistry &); void initializeConstraintEliminationPass(PassRegistry &); void initializeControlHeightReductionLegacyPassPass(PassRegistry&); void initializeCorrelatedValuePropagationPass(PassRegistry&); diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp --- a/llvm/lib/Analysis/TargetTransformInfo.cpp +++ b/llvm/lib/Analysis/TargetTransformInfo.cpp @@ -492,6 +492,26 @@ return TTIImpl->supportsEfficientVectorElementLoadStore(); } +bool TargetTransformInfo::supportsComplexNumberArithmetic() const { + return TTIImpl->supportsComplexNumberArithmetic(); +} + +Intrinsic::ID TargetTransformInfo::getComplexArithmeticIntrinsic( + ComplexArithmeticCandidate *C, unsigned &IntArgCount) const { + return TTIImpl->getComplexArithmeticIntrinsic(C, IntArgCount); +} + +bool TargetTransformInfo::validateComplexCandidateDataFlow( + ComplexArithmeticCandidate *C, Instruction *I) const { + return TTIImpl->validateComplexCandidateDataFlow(C, I); +} + +void TargetTransformInfo::filterComplexArithmeticOperand( + ComplexArithmeticCandidate *C, Value *V, SmallVector &Operands, + SmallVector &DeadInsts) { + TTIImpl->filterComplexArithmeticOperand(C, V, Operands, 
DeadInsts); +} + bool TargetTransformInfo::enableAggressiveInterleaving( bool LoopHasReductions) const { return TTIImpl->enableAggressiveInterleaving(LoopHasReductions); diff --git a/llvm/lib/CodeGen/CMakeLists.txt b/llvm/lib/CodeGen/CMakeLists.txt --- a/llvm/lib/CodeGen/CMakeLists.txt +++ b/llvm/lib/CodeGen/CMakeLists.txt @@ -17,6 +17,7 @@ CodeGenPassBuilder.cpp CodeGenPrepare.cpp CommandFlags.cpp + ComplexArithmeticPass.cpp CriticalAntiDepBreaker.cpp DeadMachineInstructionElim.cpp DetectDeadLanes.cpp diff --git a/llvm/lib/CodeGen/ComplexArithmeticPass.cpp b/llvm/lib/CodeGen/ComplexArithmeticPass.cpp new file mode 100644 --- /dev/null +++ b/llvm/lib/CodeGen/ComplexArithmeticPass.cpp @@ -0,0 +1,776 @@ +//===- ComplexArithmeticPass.cpp ------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// +//===----------------------------------------------------------------------===// + +#include "llvm/CodeGen/ComplexArithmeticPass.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/CodeGen/TargetLowering.h" +#include "llvm/CodeGen/TargetPassConfig.h" +#include "llvm/CodeGen/TargetSubtargetInfo.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/InstIterator.h" +#include "llvm/IR/PatternMatch.h" +#include "llvm/InitializePasses.h" +#include "llvm/Target/TargetMachine.h" +#include + +using namespace llvm; +using namespace PatternMatch; + +#define DEBUG_TYPE "complex-arithmetic" + +STATISTIC(NumComplexIntrinsics, "Number of complex intrinsics generated"); + +static cl::opt + ComplexArithmeticEnabled("enable-complex-arithmetic", + cl::desc("Enable complex arithmetic"), + 
cl::init(true), cl::Hidden); + +static cl::opt + ComplexArithmeticForceEnabled("force-complex-arithmetic", + cl::desc("Force complex arithmetic"), + cl::init(false), cl::Hidden); + +namespace { + +class ComplexArithmeticLegacyPass : public FunctionPass { +public: + static char ID; + + ComplexArithmeticLegacyPass() : FunctionPass(ID) { + initializeComplexArithmeticLegacyPassPass(*PassRegistry::getPassRegistry()); + } + + StringRef getPassName() const override { return "Complex Arithmetic Pass"; } + + bool runOnFunction(Function &F) override; + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.setPreservesCFG(); + } + +private: + DominatorTree *DT = nullptr; + const TargetLowering *TLI = nullptr; +}; + +class ComplexArithmetic { +public: + ComplexArithmetic(TargetTransformInfo *TTI) : TTI(TTI) {} + + bool runOnFunction(Function &F); + +private: + /// Performs a search through the given BasicBlock, looking for potential + /// complex arithmetic candidates and replacing them where applicable. + /// \return true if any change has been made, false otherwise. + bool evaluateComplexArithmetic(BasicBlock *B, + SmallVector &DeadInsts); + + /// Performs a naive check for specific patterns within the given block, + /// and replaces the matched instructions with the complex arithmetic + /// equivalent. + /// \return true if any change has been made, false otherwise. + bool + evaluateComplexArithmeticNaive(BasicBlock *B, + SmallVector &DeadInsts); + + const TargetTransformInfo *TTI = nullptr; + const TargetLowering *TLI = nullptr; +}; + +} // end anonymous namespace. 
+ +char ComplexArithmeticLegacyPass::ID = 0; + +INITIALIZE_PASS_BEGIN(ComplexArithmeticLegacyPass, DEBUG_TYPE, + "Complex Arithmetic", false, false) +INITIALIZE_PASS_END(ComplexArithmeticLegacyPass, DEBUG_TYPE, + "Complex Arithmetic", false, false) + +PreservedAnalyses ComplexArithmeticPass::run(Function &F, + FunctionAnalysisManager &AM) { + auto &TTI = AM.getResult(F); + + if (!ComplexArithmetic(&TTI).runOnFunction(F)) + return PreservedAnalyses::all(); + + PreservedAnalyses PA; + PA.preserve(); + return PA; +} + +FunctionPass *llvm::createComplexArithmeticPass() { + return new ComplexArithmeticLegacyPass(); +} + +/// Checks the given instruction, and if it's a recognised pattern, returns a +/// complex arithmetic candidate. +/// \param I The instruction to check. +/// \return A pointer to a complex arithmetic candidate, or nullptr if I does +/// not match a recognised pattern. +static ComplexArithmeticCandidate *getCandidate(Instruction *I) { + if (match(I, m_FSub(m_FMul(m_Value(), m_Value()), + m_FMul(m_Value(), m_Value()))) || + match(I, m_FAdd(m_FMul(m_Value(), m_Value()), + m_FMul(m_Value(), m_Value())))) { + auto *C = new ComplexArithmeticCandidate( + I, llvm::ComplexArithmeticCandidate::Complex_Mul); + if (auto *Ty = dyn_cast(I->getType())) + C->TypeWidth = Ty->getNumElements(); + C->addInstruction(I->getOperand(0)); + C->addInstruction(I->getOperand(1)); + return C; + } + + return nullptr; +} + +bool ComplexArithmetic::evaluateComplexArithmeticNaive( + BasicBlock *B, SmallVector &DeadInsts) { + /// Check a very specific case where the vector width is twice that of the + /// width of a vector register. 
+ LLVM_DEBUG(dbgs() << "evaluateComplexArithmeticNaive" + << ".\n"); + for (auto &I : *B) { + if (isa(I)) { + + LLVM_DEBUG(dbgs() << "Found potential root of naive pattern: "; I.dump()); + + auto *StoreTy = I.getOperand(0)->getType(); + + if (!StoreTy->isVectorTy()) { + LLVM_DEBUG(dbgs() << "Not vector type" + << ".\n"); + continue; + } + + auto *ScalarTy = StoreTy->getScalarType(); + if (!ScalarTy->isFloatTy() && !ScalarTy->isDoubleTy()) { + LLVM_DEBUG(dbgs() << "Vector element type is not float or double" + << ".\n"); + continue; + } + + unsigned AllowedElements = ScalarTy->isDoubleTy() ? 4 : 8; + auto ElemCount = cast(StoreTy)->getElementCount(); + if (ElemCount.isScalable() || + cast(StoreTy)->getNumElements() != AllowedElements) { + LLVM_DEBUG(dbgs() << "Element count is scalable, or number of elements " + "is not as allowed" + << ".\n"); + continue; + } + + auto *StorePtrTy = PointerType::getUnqual(StoreTy); + + ArrayRef RealMask; + ArrayRef ImagMask; + ArrayRef InterleaveMask; + + LLVM_DEBUG(dbgs() << "Allowed elements deemed to be " << AllowedElements + << ".\n"); + if (AllowedElements == 8) { + std::array RealMaskArr = {0, 2, 4, 6}; + std::array ImagMaskArr = {1, 3, 5, 7}; + std::array InterleaveMaskArr = {0, 4, 1, 5, 2, 6, 3, 7}; + RealMask = ArrayRef(RealMaskArr); + ImagMask = ArrayRef(ImagMaskArr); + InterleaveMask = ArrayRef(InterleaveMaskArr); + } else if (AllowedElements == 4) { + std::array RealMaskArr = {0, 2}; + std::array ImagMaskArr = {1, 3}; + std::array InterleaveMaskArr = {0, 2, 1, 3}; + RealMask = ArrayRef(RealMaskArr); + ImagMask = ArrayRef(ImagMaskArr); + InterleaveMask = ArrayRef(InterleaveMaskArr); + } + + // Check naive complex multiply pattern + Value *LoadInst0; + Value *LoadInst1; + Value *LoadInst2; + Value *LoadInst3; + Value *LoadInst4; + Value *LoadInst5; + Value *LoadInst6; + Value *LoadInst7; + + Instruction *BitcastInstr; + + bool M = match( + &I, + m_Store( + m_Shuffle(m_FSub(m_FMul(m_Shuffle(m_Value(LoadInst0), m_Poison(), + 
m_SpecificMask(RealMask)), + m_Shuffle(m_Value(LoadInst1), m_Poison(), + m_SpecificMask(RealMask))), + m_FMul(m_Shuffle(m_Value(LoadInst2), m_Poison(), + m_SpecificMask(ImagMask)), + m_Shuffle(m_Value(LoadInst3), m_Poison(), + m_SpecificMask(ImagMask)))), + m_FAdd(m_FMul(m_Shuffle(m_Value(LoadInst4), m_Poison(), + m_SpecificMask(ImagMask)), + m_Shuffle(m_Value(LoadInst5), m_Poison(), + m_SpecificMask(RealMask))), + m_FMul(m_Shuffle(m_Value(LoadInst6), m_Poison(), + m_SpecificMask(RealMask)), + m_Shuffle(m_Value(LoadInst7), m_Poison(), + m_SpecificMask(ImagMask)))), + m_SpecificMask(InterleaveMask)), + m_Instruction(BitcastInstr))); + + if (!M) { + LLVM_DEBUG(dbgs() << "Failed to match full pattern" + << ".\n"); + return false; + } + + // Check that the 8 LoadInsts are made up of 2 distinct LoadInsts + bool LoadInstsAMatch = LoadInst1 == LoadInst3 && LoadInst1 == LoadInst5 && + LoadInst1 == LoadInst7; + bool LoadInstsBMatch = LoadInst0 == LoadInst2 && LoadInst0 == LoadInst4 && + LoadInst0 == LoadInst6; + + if (!LoadInstsAMatch || !LoadInstsBMatch) { + LLVM_DEBUG(dbgs() << "Loads do not match" + << ".\n"); + return false; + } + + if (LoadInst0->getType() != StoreTy || LoadInst1->getType() != StoreTy) { + LLVM_DEBUG(dbgs() << "Load types do not match store type" + << ".\n"); + return false; + } + + auto *LoadInstA = LoadInst1; + auto *LoadInstB = LoadInst0; + + std::array Shuffles; + auto *InterleavedVec = cast(I.getOperand(0)); + + // Gather the shuffles from the FSub + auto *FSub = cast(InterleavedVec->getOperand(0)); + auto *FSub0 = cast(FSub->getOperand(0)); + auto *FSub1 = cast(FSub->getOperand(1)); + auto *FSub00 = cast(FSub0->getOperand(0)); + auto *FSub01 = cast(FSub0->getOperand(1)); + auto *FSub10 = cast(FSub1->getOperand(0)); + auto *FSub11 = cast(FSub1->getOperand(1)); + + Shuffles[0] = FSub01; + Shuffles[1] = FSub11; + Shuffles[2] = FSub00; + Shuffles[3] = FSub10; + + // Check that all shuffles are distinct + for (unsigned i = 0; i < 4; i++) { + for 
(unsigned j = 0; j < 4; j++) { + if (i == j) + continue; + + if (Shuffles[i] == Shuffles[j]) { + LLVM_DEBUG(dbgs() << "Shuffles aren't distinct" + << ".\n"); + return false; + } + } + } + + // Compare the shuffles with the FAdd + auto *FAdd = cast(InterleavedVec->getOperand(1)); + auto *FAdd0 = cast(FAdd->getOperand(0)); + auto *FAdd1 = cast(FAdd->getOperand(1)); + + auto *FAdd00 = cast(FAdd0->getOperand(0)); + auto *FAdd01 = cast(FAdd0->getOperand(1)); + auto *FAdd10 = cast(FAdd1->getOperand(0)); + auto *FAdd11 = cast(FAdd1->getOperand(1)); + + if (FAdd00 != Shuffles[3] || FAdd01 != Shuffles[0] || + FAdd10 != Shuffles[2] || FAdd11 != Shuffles[1]) { + LLVM_DEBUG(dbgs() << "Shuffles do not match" + << ".\n"); + return false; + } + + // At this point, we can assume all is good to be replaced + IRBuilder<> Builder(&I); + std::array StorageShuffleMaskArr = {0, 1, 2, 3, 4, 5, 6, 7}; + std::array Shuffle1MaskArr = {0, 1, 2, 3}; + std::array Shuffle2MaskArr = {4, 5, 6, 7}; + ArrayRef StorageShuffleMask(StorageShuffleMaskArr); + ArrayRef Shuffle1Mask(Shuffle1MaskArr); + ArrayRef Shuffle2Mask(Shuffle2MaskArr); + + auto *ShuffleA1 = Builder.CreateShuffleVector(LoadInstA, Shuffle1Mask); + auto *ShuffleA2 = Builder.CreateShuffleVector(LoadInstA, Shuffle2Mask); + + auto *ShuffleB1 = Builder.CreateShuffleVector(LoadInstB, Shuffle1Mask); + auto *ShuffleB2 = Builder.CreateShuffleVector(LoadInstB, Shuffle2Mask); + + auto *C1 = getCandidate(FSub); + auto *C2 = getCandidate(FAdd); + + ComplexArithmeticCandidateSet Set(TTI); + Set.addCandidate(C1); + Set.addCandidate(C2); + + unsigned int OperandCount; + Intrinsic::ID IntC1 = + TTI->getComplexArithmeticIntrinsic(C1, OperandCount); + Intrinsic::ID IntC2 = + TTI->getComplexArithmeticIntrinsic(C2, OperandCount); + + std::vector Operands; + for (auto &item : C1->ExtraOperandsPre) + Operands.push_back(item); + + if (C1->Type == llvm::ComplexArithmeticCandidate::Complex_Mla) + Operands.push_back(ConstantFP::get(C1->getDataType(), 0)); + + 
Operands.push_back(ShuffleA1); + Operands.push_back(ShuffleB1); + auto *IntrinsicLow1 = + Builder.CreateIntrinsic(IntC1, C1->getDataType(), Operands); + + Operands.clear(); + for (auto &item : C2->ExtraOperandsPre) + Operands.push_back(item); + Operands.push_back(IntrinsicLow1); + Operands.push_back(ShuffleA1); + Operands.push_back(ShuffleB1); + auto *IntrinsicLow2 = + Builder.CreateIntrinsic(IntC2, C2->getDataType(), Operands); + + Operands.clear(); + for (auto &item : C1->ExtraOperandsPre) + Operands.push_back(item); + + if (C1->Type == llvm::ComplexArithmeticCandidate::Complex_Mla) + Operands.push_back(ConstantFP::get(C1->getDataType(), 0)); + + Operands.push_back(ShuffleA2); + Operands.push_back(ShuffleB2); + auto *IntrinsicHigh1 = + Builder.CreateIntrinsic(IntC1, C1->getDataType(), Operands); + + Operands.clear(); + for (auto &item : C2->ExtraOperandsPre) + Operands.push_back(item); + Operands.push_back(IntrinsicHigh1); + Operands.push_back(ShuffleA2); + Operands.push_back(ShuffleB2); + auto *IntrinsicHigh2 = + Builder.CreateIntrinsic(IntC2, C2->getDataType(), Operands); + + auto *StorageShuffle = Builder.CreateShuffleVector( + IntrinsicLow2, IntrinsicHigh2, StorageShuffleMask); + InterleavedVec->replaceAllUsesWith(StorageShuffle); + + NumComplexIntrinsics += 4; + + // Clear old instructions + DeadInsts.push_back(InterleavedVec); + DeadInsts.push_back(FSub); + DeadInsts.push_back(FSub0); + DeadInsts.push_back(FSub00); + DeadInsts.push_back(FSub01); + DeadInsts.push_back(FSub1); + DeadInsts.push_back(FSub10); + DeadInsts.push_back(FSub11); + DeadInsts.push_back(FAdd); + DeadInsts.push_back(FAdd0); + DeadInsts.push_back(FAdd00); + DeadInsts.push_back(FAdd01); + DeadInsts.push_back(FAdd1); + DeadInsts.push_back(FAdd10); + DeadInsts.push_back(FAdd11); + + LLVM_DEBUG(dbgs() << "Naive pattern matching succeeded" + << ".\n"); + return true; + } + } + + LLVM_DEBUG(dbgs() << "Naive pattern matching failed" + << ".\n"); + return false; +} + +bool 
ComplexArithmetic::evaluateComplexArithmetic( + BasicBlock *B, SmallVector &DeadInsts) { + + LLVM_DEBUG(dbgs() << "Evaluating complex arithmetic on block: "; B->dump()); + + if (evaluateComplexArithmeticNaive(B, DeadInsts)) + return true; + + ComplexArithmeticCandidateSet CandidateSet(TTI); + for (auto &I : *B) { + auto *C = getCandidate(&I); + if (C) + CandidateSet.addCandidate(C); + } + + SmallVector InvalidCandidates; + for (const auto &C : CandidateSet.get()) { + if (!CandidateSet.candidateHasValidDataFlow(C)) + InvalidCandidates.push_back(C); + } + + for (auto &item : InvalidCandidates) + CandidateSet.remove(item); + InvalidCandidates.clear(); + + bool Changed = false; + + for (auto *C : CandidateSet.get()) { + Changed |= CandidateSet.tryLower(*TTI, C, DeadInsts); + } + + // Clean up loose memory objects + CandidateSet.clear(); + + return Changed; +} + +bool ComplexArithmeticLegacyPass::runOnFunction(Function &F) { + auto *TPC = getAnalysisIfAvailable(); + auto &TM = TPC->getTM(); + auto TTI = TM.getTargetTransformInfo(F); + return ComplexArithmetic(&TTI).runOnFunction(F); +} + +static bool HasBeenDisabled = false; +bool ComplexArithmetic::runOnFunction(Function &F) { + + if (!ComplexArithmeticForceEnabled) { + if (!ComplexArithmeticEnabled) { + LLVM_DEBUG(if (!HasBeenDisabled) dbgs() + << "Complex Arithmetic has been explicitly disabled.\n"); + HasBeenDisabled = true; + return false; + } + + if (!TTI->supportsComplexNumberArithmetic()) { + LLVM_DEBUG(if (!HasBeenDisabled) dbgs() + << "Complex Arithmetic has been disabled. 
" + << "Target does not support lowering of complex numbers.\n"); + HasBeenDisabled = true; + return false; + } + } + + SmallVector DeadInsts; + bool Changed = false; + + SmallVector ChangedBlocks; + + for (auto &B : F) { + if (evaluateComplexArithmetic(&B, DeadInsts)) { + Changed |= true; + ChangedBlocks.push_back(&B); + } + } + + if (Changed) { + // TODO clean up the dead instructions better + unsigned iter = 0; + unsigned count = DeadInsts.size(); + unsigned remaining = DeadInsts.size(); + while (!DeadInsts.empty() && remaining > 0 && iter < count) { + ++iter; + remaining = 0; + for (auto *It = DeadInsts.begin(); It != DeadInsts.end(); It++) { + auto *I = *It; + + if (I->getParent()) + remaining++; + + if (I->getNumUses() == 0 && I->getParent()) { + remaining--; + I->eraseFromParent(); + } + } + } + + DeadInsts.clear(); + } + + return Changed; +} + +bool ComplexArithmeticCandidate::getOperands( + TargetTransformInfo *TTI, ComplexArithmeticCandidateSet *Set, + SmallVector &DeadInsts, + SmallVector &Operands) { + + auto *PairedC = Set->getPaired(this); + if (!PairedC) + return false; + + for (auto *V : ExtraOperandsPre) { + Operands.push_back(V); + } + + if (Type == ComplexArithmeticCandidate::Complex_Mla) { + if (PairedC->getRotation() < this->getRotation()) { + Set->tryLower(*TTI, PairedC, DeadInsts); + Operands.push_back(PairedC->getOutput()); + } else { + auto *Zero = ConstantFP::get(PairedC->getDataType(), 0); + Operands.push_back(Zero); + } + } + + auto Inputs = getInputs(); + for (auto *V : Inputs) { + if (auto *I = dyn_cast(V)) { + if (auto *C = Set->getCandidateHoldingInstr(I)) { + Set->tryLower(*TTI, C, DeadInsts); + if (C->getRotation() < C->PairedCandidate->getRotation()) + continue; + Operands.push_back(C->getOutput()); + continue; + } + + if (auto *SVI = dyn_cast(I)) { + LoadInst *LI; + if (auto *IEI = dyn_cast(SVI->getOperand(0))) { + LI = dyn_cast(IEI->getOperand(1)); + if (LI) { + if (std::find(DeadInsts.begin(), DeadInsts.end(), IEI) == + 
DeadInsts.end()) { + DeadInsts.push_back(IEI); + } + if (std::find(DeadInsts.begin(), DeadInsts.end(), SVI) == + DeadInsts.end()) { + DeadInsts.push_back(SVI); + } + if (isa(LI->getOperand(0))) + LI = nullptr; + } + } else { + LI = dyn_cast(SVI->getOperand(0)); + } + + auto *Ty = PointerType::get(getDataType(), 0); + + if (LI) { + if (LI->getType() != getDataType()) { + LLVM_DEBUG(dbgs() << "LI->getType() "; LI->getType()->dump()); + LLVM_DEBUG(dbgs() << "getDataType() "; getDataType()->dump()); + IRBuilder<> B(LI); + if (std::find(DeadInsts.begin(), DeadInsts.end(), LI) == + DeadInsts.end()) { + DeadInsts.push_back(LI); + } + + auto *BitCast = B.CreateBitOrPointerCast(LI->getOperand(0), Ty); + auto *NewLI = B.CreateLoad(Ty->getElementType(), BitCast); + LI = NewLI; + } + + if (std::find(Operands.begin(), Operands.end(), LI) == + Operands.end()) { + Operands.push_back(LI); + } + if (std::find(DeadInsts.begin(), DeadInsts.end(), SVI) == + DeadInsts.end()) { + DeadInsts.push_back(SVI); + } + continue; + } + } + + TTI->filterComplexArithmeticOperand(this, V, Operands, DeadInsts); + } + } + + LLVM_DEBUG({ + dbgs() << "Operands: \n"; + for (auto *Op : Operands) { + Op->dump(); + } + }); + + return true; +} + +bool ComplexArithmeticCandidateSet::tryLower( + const TargetTransformInfo &TTI, ComplexArithmeticCandidate *C, + SmallVector &DeadInsts) { + + LLVM_DEBUG(dbgs() << "Trying to lower candidate starting at "; + C->RootInstruction->dump()); + + if (C->Intrinsic) { + LLVM_DEBUG(dbgs() << "Candidate already lowered" + << ".\n"); + return false; + } + + if (!C->RootInstruction->getType()->isVectorTy()) { + LLVM_DEBUG(dbgs() << "Candidate type is not vector type" + << ".\n"); + return false; + } + + getPaired(C); + + // Is the lowered operator going to need splitting? 
If so, don't lower it + auto *VTy = cast(C->getDataType()); + unsigned VTyWidth = VTy->getScalarSizeInBits() * VTy->getNumElements(); + if (VTyWidth > 128 /* Width of a single vector register */) { + LLVM_DEBUG(dbgs() << "Vector type is too wide" + << ".\n"); + LLVM_DEBUG(dbgs() << " Width: " << VTyWidth << ". "; VTy->dump()); + return false; + } + + unsigned IntArgCount = 0; + Intrinsic::ID IntId = TTI.getComplexArithmeticIntrinsic(C, IntArgCount); + + if (IntId != Intrinsic::not_intrinsic) { + IRBuilder<> Builder(cast(*C->RootInstruction->user_begin())); + Type *DataType = C->getDataType(); + SmallVector OperandsArr; + if (!C->getOperands(const_cast(&TTI), this, + DeadInsts, OperandsArr)) { + LLVM_DEBUG(dbgs() << "Failed to get operands" + << ".\n"); + return false; + } + + if (OperandsArr.size() != IntArgCount) { + LLVM_DEBUG(dbgs() << "Got incorrect amount of operands" + << ".\n"); + LLVM_DEBUG(dbgs() << "Expected " << IntArgCount << ", got " + << OperandsArr.size() << ".\n"); + return false; + } + + ArrayRef Operands(OperandsArr); + C->setIntrinsic(Builder.CreateIntrinsic(IntId, {DataType}, Operands)); + NumComplexIntrinsics++; + + if (C->isTerminalInPair()) { + if (auto *SVI = + dyn_cast(*C->RootInstruction->user_begin())) { + + if (SVI->hasOneUse() && isa(*SVI->users().begin())) { + auto *CI = cast(*SVI->users().begin()); + if (CI->mayWriteToMemory()) { + // Convert a storing CallInst to a standard store, keeping it + // compatible with the complex intrinsic structure + // TODO make this better + DeadInsts.push_back(CI); + IRBuilder<> B(CI); + auto *BCI = cast(CI->getOperand(2)); + DeadInsts.push_back(BCI); + B.CreateStore(C->Intrinsic, BCI->getOperand(0)); + } + } else + SVI->replaceAllUsesWith(C->Intrinsic); + DeadInsts.push_back(SVI); + } + C->PairedCandidate->Intrinsic->moveBefore(C->Intrinsic); + } + + for (auto &item : C->Instructions) + DeadInsts.push_back(item); + + LLVM_DEBUG(dbgs() << "Success?" 
+ << ".\n"); + return true; + } + LLVM_DEBUG(dbgs() << "No intrinsic offered by TTI" + << ".\n"); + + return false; +} +bool ComplexArithmeticCandidateSet::candidateHasValidDataFlow( + ComplexArithmeticCandidate *C) { + // * Crawl through operands and evaluate each one + // * An operand within another candidate is counted as valid, if that + // candidate is also valid + // * A shuffle operand that matches a recognised pattern is valid + auto Inputs = C->getInputs(); + + for (auto *V : Inputs) { + if (auto *I = dyn_cast(V)) { + if (auto *ParentC = getCandidateHoldingInstr(I)) + if (!candidateHasValidDataFlow(ParentC)) + return false; + + if (isa(I)) + continue; + + if (auto *SVI = dyn_cast(I)) { + // Permuted/splatted access + std::array PermuteMaskArr = {1, 0}; + ArrayRef PermuteMask(PermuteMaskArr); + if (match(SVI, m_Shuffle(m_Value(), m_Undef(), + m_SpecificMask(PermuteMask)))) { + C->TypeWidth = cast(SVI->getOperand(0)->getType()) + ->getNumElements(); + continue; + } + if (match(SVI, m_Shuffle(m_InsertElt(m_Undef(), m_Value(), m_ZeroInt()), + m_Undef(), m_ZeroMask()))) { + auto *IEI = cast(SVI->getOperand(0)); + C->TypeWidth = cast(IEI->getOperand(0)->getType()) + ->getNumElements(); + continue; + } + + // Interleaved access + std::vector> ValidMasks; + std::array Even1 = {0}; + std::array Odd1 = {1}; + std::array Even2 = {0, 2}; + std::array Odd2 = {1, 3}; + std::array Even4 = {0, 2, 4, 6}; + std::array Odd4 = {1, 3, 5, 7}; + ValidMasks.emplace_back(Even1); + ValidMasks.emplace_back(Odd1); + ValidMasks.emplace_back(Even2); + ValidMasks.emplace_back(Odd2); + ValidMasks.emplace_back(Even4); + ValidMasks.emplace_back(Odd4); + + bool ValidMatch = false; + for (auto M : ValidMasks) { + if (match(SVI, m_Shuffle(m_Value(), m_Undef(), m_SpecificMask(M)))) { + ValidMatch = true; + C->TypeWidth = M.size() * 2; + break; + } + } + if (ValidMatch) + continue; + return false; + } + + if (TTI->validateComplexCandidateDataFlow(C, I)) + continue; + + return false; + } // else 
if (!isa(I)) + return false; + } + + return true; +} diff --git a/llvm/lib/Passes/PassBuilder.cpp b/llvm/lib/Passes/PassBuilder.cpp --- a/llvm/lib/Passes/PassBuilder.cpp +++ b/llvm/lib/Passes/PassBuilder.cpp @@ -69,6 +69,7 @@ #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/Analysis/TypeBasedAliasAnalysis.h" +#include "llvm/CodeGen/ComplexArithmeticPass.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/IRPrintingPasses.h" #include "llvm/IR/PassManager.h" diff --git a/llvm/lib/Passes/PassRegistry.def b/llvm/lib/Passes/PassRegistry.def --- a/llvm/lib/Passes/PassRegistry.def +++ b/llvm/lib/Passes/PassRegistry.def @@ -239,6 +239,7 @@ FUNCTION_PASS("bounds-checking", BoundsCheckingPass()) FUNCTION_PASS("break-crit-edges", BreakCriticalEdgesPass()) FUNCTION_PASS("callsite-splitting", CallSiteSplittingPass()) +//FUNCTION_PASS("complex-arithmetic", ComplexArithmeticPass()) FUNCTION_PASS("consthoist", ConstantHoistingPass()) FUNCTION_PASS("constraint-elimination", ConstraintEliminationPass()) FUNCTION_PASS("chr", ControlHeightReductionPass()) diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h @@ -223,6 +223,9 @@ FCMLEz, FCMLTz, + // Vector complex arithmetic + FCADD, + // Vector across-lanes addition // Only the lower result lane is defined. 
SADDV, diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -1980,6 +1980,7 @@ MAKE_CASE(AArch64ISD::FCMP) MAKE_CASE(AArch64ISD::STRICT_FCMP) MAKE_CASE(AArch64ISD::STRICT_FCMPE) + MAKE_CASE(AArch64ISD::FCADD) MAKE_CASE(AArch64ISD::DUP) MAKE_CASE(AArch64ISD::DUPLANE8) MAKE_CASE(AArch64ISD::DUPLANE16) @@ -17631,6 +17632,8 @@ ConstantSDNode *CN = cast(N->getOperand(0)); Intrinsic::ID IntID = static_cast(CN->getZExtValue()); + LLVM_DEBUG(dbgs() << "Trying to lower intrinsic with id " << IntID << "\n"; + N->dump()); switch (IntID) { default: return; diff --git a/llvm/lib/Target/AArch64/AArch64TargetMachine.cpp b/llvm/lib/Target/AArch64/AArch64TargetMachine.cpp --- a/llvm/lib/Target/AArch64/AArch64TargetMachine.cpp +++ b/llvm/lib/Target/AArch64/AArch64TargetMachine.cpp @@ -43,6 +43,7 @@ #include "llvm/Target/TargetLoweringObjectFile.h" #include "llvm/Target/TargetOptions.h" #include "llvm/Transforms/CFGuard.h" +#include "llvm/Transforms/IPO.h" #include "llvm/Transforms/Scalar.h" #include #include @@ -601,8 +602,10 @@ } void AArch64PassConfig::addCodeGenPrepare() { - if (getOptLevel() != CodeGenOpt::None) + if (getOptLevel() != CodeGenOpt::None) { + addPass(createComplexArithmeticPass()); addPass(createTypePromotionPass()); + } TargetPassConfig::addCodeGenPrepare(); } diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h --- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h +++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h @@ -24,7 +24,10 @@ #include "llvm/CodeGen/BasicTTIImpl.h" #include "llvm/IR/Function.h" #include "llvm/IR/Intrinsics.h" +#include "llvm/MC/SubtargetFeature.h" #include +#include +#include namespace llvm { @@ -317,6 +320,104 @@ bool supportsScalableVectors() const { return ST->hasSVE(); } + bool 
supportsComplexNumberArithmetic() const { return ST->hasComplxNum(); } + + Intrinsic::ID getComplexArithmeticIntrinsic(ComplexArithmeticCandidate *C, + unsigned &IntArgCount) const { + if (!ST->hasNEON() || !ST->hasComplxNum()) + return Intrinsic::not_intrinsic; + + unsigned Rot = C->getRotation(); + switch (C->Type) { + case ComplexArithmeticCandidate::Complex_Mul: + case ComplexArithmeticCandidate::Complex_Mla: { + // AArch64 doesn't support Complex Mul, so use Mla instead + C->Type = ComplexArithmeticCandidate::Complex_Mla; + + IntArgCount = 3; + + if (Rot == 0) + return Intrinsic::aarch64_neon_vcmla_rot0; + if (Rot == 90) + return Intrinsic::aarch64_neon_vcmla_rot90; + if (Rot == 180) + return Intrinsic::aarch64_neon_vcmla_rot180; + if (Rot == 270) + return Intrinsic::aarch64_neon_vcmla_rot270; + break; + } + case ComplexArithmeticCandidate::Complex_Add: { + IntArgCount = 2; + if (Rot == 90) + return Intrinsic::aarch64_neon_vcadd_rot90; + if (Rot == 270) + return Intrinsic::aarch64_neon_vcadd_rot270; + } + } + + return Intrinsic::not_intrinsic; + } + + bool validateComplexCandidateDataFlow(ComplexArithmeticCandidate *C, + Instruction *I) const { + if (auto *EXT = dyn_cast(I)) { + auto *Op = EXT->getOperand(0); + auto Idx = EXT->getIndices()[0]; + if (Idx != 0 && Idx != 1) + return false; + if (auto *Int = dyn_cast(Op)) { + if (Int->getIntrinsicID() != Intrinsic::aarch64_neon_ld2) + return false; + + if (auto *STy = dyn_cast(Int->getType())) { + if (STy->getNumElements() != 2) + return false; + Type *ExpectedTy = C->getDataType(); + return STy->getElementType(0) == ExpectedTy && + STy->getElementType(1) == ExpectedTy; + } + } + return false; + } + return false; + } + void + filterComplexArithmeticOperand(ComplexArithmeticCandidate *C, Value *V, + SmallVector &Operands, + SmallVector &DeadInsts) { + if (auto *EVI = dyn_cast(V)) { + if (auto *CI = dyn_cast(EVI->getOperand(0))) { + if (CI->getIntrinsicID() == Intrinsic::aarch64_neon_ld2) { + 
DeadInsts.push_back(CI); + DeadInsts.push_back(EVI); + IRBuilder<> B(CI); + auto *Ptr = CI->getOperand(0); + + if (auto *BC = dyn_cast(Ptr)) { + DeadInsts.push_back(BC); + Ptr = BC->getOperand(0); + } + + auto ContainsLoadForPtr = [Operands](Value *Ptr) { + for (auto *Op : Operands) { + if (auto *LOp = dyn_cast(Op)) { + if (LOp->getOperand(0) == Ptr) + return true; + } + } + return false; + }; + + if (!ContainsLoadForPtr(Ptr)) { + auto *Ty = C->getDataType(); + auto *LI = B.CreateLoad(Ty, Ptr); + Operands.push_back(LI); + } + } + } + } + } + bool isLegalToVectorizeReduction(const RecurrenceDescriptor &RdxDesc, ElementCount VF) const; diff --git a/llvm/lib/Target/ARM/ARMTargetMachine.cpp b/llvm/lib/Target/ARM/ARMTargetMachine.cpp --- a/llvm/lib/Target/ARM/ARMTargetMachine.cpp +++ b/llvm/lib/Target/ARM/ARMTargetMachine.cpp @@ -416,6 +416,9 @@ return ST.hasAnyDataBarrier() && !ST.isThumb1Only(); })); + if (TM->getOptLevel() != CodeGenOpt::None) + addPass(createComplexArithmeticPass()); + addPass(createMVEGatherScatterLoweringPass()); addPass(createMVELaneInterleavingPass()); diff --git a/llvm/lib/Target/ARM/ARMTargetTransformInfo.h b/llvm/lib/Target/ARM/ARMTargetTransformInfo.h --- a/llvm/lib/Target/ARM/ARMTargetTransformInfo.h +++ b/llvm/lib/Target/ARM/ARMTargetTransformInfo.h @@ -199,6 +199,13 @@ int getNumMemOps(const IntrinsicInst *I) const; + bool supportsComplexNumberArithmetic() const { + return ST->hasNEON() || ST->hasMVEFloatOps(); + } + + Intrinsic::ID getComplexArithmeticIntrinsic(ComplexArithmeticCandidate *C, + unsigned &IntArgCount) const; + InstructionCost getShuffleCost(TTI::ShuffleKind Kind, VectorType *Tp, ArrayRef Mask, int Index, VectorType *SubTp); diff --git a/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp b/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp --- a/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp +++ b/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp @@ -12,6 +12,7 @@ #include "llvm/ADT/APInt.h" #include "llvm/ADT/SmallVector.h" #include 
"llvm/Analysis/LoopInfo.h" +#include "llvm/CodeGen/ComplexArithmeticPass.h" #include "llvm/CodeGen/CostTable.h" #include "llvm/CodeGen/ISDOpcodes.h" #include "llvm/CodeGen/ValueTypes.h" @@ -20,8 +21,8 @@ #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" -#include "llvm/IR/Intrinsics.h" #include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Intrinsics.h" #include "llvm/IR/IntrinsicsARM.h" #include "llvm/IR/PatternMatch.h" #include "llvm/IR/Type.h" @@ -2337,3 +2338,56 @@ return false; return true; } + +Intrinsic::ID +ARMTTIImpl::getComplexArithmeticIntrinsic(ComplexArithmeticCandidate *C, + unsigned &IntArgCount) const { + if (C->getDataType()->isDoubleTy() || + (C->getDataType()->isVectorTy() && + cast(C->getDataType())->getElementType()->isDoubleTy())) + return Intrinsic::not_intrinsic; + + unsigned Rotation = C->getRotation(); + if (ST->hasMVEFloatOps()) { + Value *RotVal = ConstantInt::get(Type::getInt32Ty(C->getContext()), + (Rotation == 90 ? 
0 : 1)); + + if (C->Type == ComplexArithmeticCandidate::Complex_Add) { + Value *Halving = ConstantInt::get(Type::getInt32Ty(C->getContext()), 1); + C->ExtraOperandsPre.push_back(Halving); + C->ExtraOperandsPre.push_back(RotVal); + IntArgCount = 4; + return Intrinsic::arm_mve_vcaddq; + } + + RotVal = + ConstantInt::get(Type::getInt32Ty(C->getContext()), (Rotation / 90)); + + if (C->Type == ComplexArithmeticCandidate::Complex_Mla || + (C->Type == ComplexArithmeticCandidate::Complex_Mul && + C->isTerminalInPair())) { + C->ExtraOperandsPre.push_back(RotVal); + C->Type = ComplexArithmeticCandidate::Complex_Mla; + IntArgCount = 4; + return Intrinsic::arm_mve_vcmlaq; + } + + if (C->Type == ComplexArithmeticCandidate::Complex_Mul) { + C->ExtraOperandsPre.push_back(RotVal); + IntArgCount = 3; + return Intrinsic::arm_mve_vcmulq; + } + } + + if (ST->hasNEON()) { + if (C->Type == ComplexArithmeticCandidate::Complex_Add) { + IntArgCount = 2; + if (Rotation == 90) + return Intrinsic::arm_neon_vcadd_rot90; + if (Rotation == 270) + return Intrinsic::arm_neon_vcadd_rot270; + } + } + + return Intrinsic::not_intrinsic; +} diff --git a/llvm/test/CodeGen/AArch64/O3-pipeline.ll b/llvm/test/CodeGen/AArch64/O3-pipeline.ll --- a/llvm/test/CodeGen/AArch64/O3-pipeline.ll +++ b/llvm/test/CodeGen/AArch64/O3-pipeline.ll @@ -76,6 +76,7 @@ ; CHECK-NEXT: Interleaved Load Combine Pass ; CHECK-NEXT: Dominator Tree Construction ; CHECK-NEXT: Interleaved Access Pass +; CHECK-NEXT: Complex Arithmetic Pass ; CHECK-NEXT: Type Promotion ; CHECK-NEXT: Natural Loop Information ; CHECK-NEXT: CodeGen Prepare diff --git a/llvm/test/CodeGen/ARM/complex-arithmetic-arm.ll b/llvm/test/CodeGen/ARM/complex-arithmetic-arm.ll new file mode 100644 --- /dev/null +++ b/llvm/test/CodeGen/ARM/complex-arithmetic-arm.ll @@ -0,0 +1,175 @@ +; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py +; RUN: llc < %s -enable-complex-arithmetic=true -o - | FileCheck %s + +target 
triple = "thumbv8.1m.main-arm-none-eabi" + +define void @complex_mul_v8f32(float* %a, float* %b, float* %c) local_unnamed_addr #0 { +; CHECK-LABEL: complex_mul_v8f32: +; CHECK: @ %bb.0: @ %vector.ph +; CHECK-NEXT: .vsave {d8, d9, d10, d11} +; CHECK-NEXT: vpush {d8, d9, d10, d11} +; CHECK-NEXT: .p2align 2 +; CHECK-NEXT: .LBB0_1: @ %vector.body +; CHECK-NEXT: @ =>This Inner Loop Header: Depth=1 +; CHECK-NEXT: vld20.32 {q0, q1}, [r0] +; CHECK-NEXT: vld20.32 {q2, q3}, [r1] +; CHECK-NEXT: vld21.32 {q0, q1}, [r0] +; CHECK-NEXT: vld21.32 {q2, q3}, [r1] +; CHECK-NEXT: vmul.f32 q4, q3, q1 +; CHECK-NEXT: vneg.f32 q4, q4 +; CHECK-NEXT: vmul.f32 q5, q2, q1 +; CHECK-NEXT: vfma.f32 q4, q2, q0 +; CHECK-NEXT: vfma.f32 q5, q3, q0 +; CHECK-NEXT: vst20.32 {q4, q5}, [r0] +; CHECK-NEXT: vst21.32 {q4, q5}, [r0] +; CHECK-NEXT: b .LBB0_1 +vector.ph: + br label %vector.body + +vector.body: ; preds = %vector.body, %vector.ph + %a.ptr = bitcast float* %a to <8 x float>* + %b.ptr = bitcast float* %b to <8 x float>* + %c.ptr = bitcast float* %c to <8 x float>* + %a.val = load <8 x float>, <8 x float>* %a.ptr + %b.val = load <8 x float>, <8 x float>* %b.ptr + %strided.vec = shufflevector <8 x float> %a.val, <8 x float> poison, <4 x i32> + %strided.vec46 = shufflevector <8 x float> %a.val, <8 x float> poison, <4 x i32> + %strided.vec48 = shufflevector <8 x float> %b.val, <8 x float> poison, <4 x i32> + %strided.vec49 = shufflevector <8 x float> %b.val, <8 x float> poison, <4 x i32> + %0 = fmul fast <4 x float> %strided.vec48, %strided.vec + %1 = fmul fast <4 x float> %strided.vec49, %strided.vec46 + %2 = fsub fast <4 x float> %0, %1 + %3 = fmul fast <4 x float> %strided.vec49, %strided.vec + %4 = fmul fast <4 x float> %strided.vec48, %strided.vec46 + %5 = fadd fast <4 x float> %3, %4 + %6 = bitcast float* undef to <8 x float>* + %interleaved.vec = shufflevector <4 x float> %2, <4 x float> %5, <8 x i32> + store <8 x float> %interleaved.vec, <8 x float>* %6, align 4 + br label %vector.body +} + 
+define void @complex_mul_v4f32(float* %a, float* %b, float* %c) local_unnamed_addr #0 { +; CHECK-LABEL: complex_mul_v4f32: +; CHECK: @ %bb.0: @ %vector.ph +; CHECK-NEXT: .p2align 2 +; CHECK-NEXT: .LBB1_1: @ %vector.body +; CHECK-NEXT: @ =>This Inner Loop Header: Depth=1 +; CHECK-NEXT: vldrw.u32 q0, [r0] +; CHECK-NEXT: vldrw.u32 q1, [r1] +; CHECK-NEXT: vcmul.f32 q2, q1, q0, #0 +; CHECK-NEXT: vcmla.f32 q2, q1, q0, #90 +; CHECK-NEXT: vstrw.32 q2, [r0] +; CHECK-NEXT: b .LBB1_1 +vector.ph: + br label %vector.body + +vector.body: ; preds = %vector.body, %vector.ph + %a.ptr = bitcast float* %a to <4 x float>* + %b.ptr = bitcast float* %b to <4 x float>* + %c.ptr = bitcast float* %c to <4 x float>* + %a.val = load <4 x float>, <4 x float>* %a.ptr + %b.val = load <4 x float>, <4 x float>* %b.ptr + %strided.vec = shufflevector <4 x float> %a.val, <4 x float> poison, <2 x i32> + %strided.vec46 = shufflevector <4 x float> %a.val, <4 x float> poison, <2 x i32> + %strided.vec48 = shufflevector <4 x float> %b.val, <4 x float> poison, <2 x i32> + %strided.vec49 = shufflevector <4 x float> %b.val, <4 x float> poison, <2 x i32> + %0 = fmul fast <2 x float> %strided.vec48, %strided.vec + %1 = fmul fast <2 x float> %strided.vec49, %strided.vec46 + %2 = fsub fast <2 x float> %0, %1 + %3 = fmul fast <2 x float> %strided.vec49, %strided.vec + %4 = fmul fast <2 x float> %strided.vec48, %strided.vec46 + %5 = fadd fast <2 x float> %3, %4 + %6 = bitcast float* undef to <4 x float>* + %interleaved.vec = shufflevector <2 x float> %2, <2 x float> %5, <4 x i32> + store <4 x float> %interleaved.vec, <4 x float>* %6, align 4 + br label %vector.body +} + +define void @complex_mul_v4f64(double* %a, double* %b, double* %c) local_unnamed_addr #0 { +; CHECK-LABEL: complex_mul_v4f64: +; CHECK: @ %bb.0: @ %vector.ph +; CHECK-NEXT: .vsave {d8, d9} +; CHECK-NEXT: vpush {d8, d9} +; CHECK-NEXT: .p2align 2 +; CHECK-NEXT: .LBB2_1: @ %vector.body +; CHECK-NEXT: @ =>This Inner Loop Header: Depth=1 +; 
CHECK-NEXT: vldrw.u32 q0, [r0, #16] +; CHECK-NEXT: vldrw.u32 q1, [r1, #16] +; CHECK-NEXT: vldrw.u32 q3, [r0] +; CHECK-NEXT: vldrw.u32 q4, [r1] +; CHECK-NEXT: vmul.f64 d4, d3, d1 +; CHECK-NEXT: vmul.f64 d5, d2, d1 +; CHECK-NEXT: vfnms.f64 d4, d2, d0 +; CHECK-NEXT: vfma.f64 d5, d3, d0 +; CHECK-NEXT: vmul.f64 d0, d9, d7 +; CHECK-NEXT: vmul.f64 d1, d8, d7 +; CHECK-NEXT: vfnms.f64 d0, d8, d6 +; CHECK-NEXT: vfma.f64 d1, d9, d6 +; CHECK-NEXT: vstrw.32 q2, [r0] +; CHECK-NEXT: vstrw.32 q0, [r0] +; CHECK-NEXT: b .LBB2_1 +vector.ph: + br label %vector.body + +vector.body: ; preds = %vector.body, %vector.ph + %a.ptr = bitcast double* %a to <4 x double>* + %b.ptr = bitcast double* %b to <4 x double>* + %c.ptr = bitcast double* %c to <4 x double>* + %a.val = load <4 x double>, <4 x double>* %a.ptr + %b.val = load <4 x double>, <4 x double>* %b.ptr + %strided.vec = shufflevector <4 x double> %a.val, <4 x double> poison, <2 x i32> + %strided.vec46 = shufflevector <4 x double> %a.val, <4 x double> poison, <2 x i32> + %strided.vec48 = shufflevector <4 x double> %b.val, <4 x double> poison, <2 x i32> + %strided.vec49 = shufflevector <4 x double> %b.val, <4 x double> poison, <2 x i32> + %0 = fmul fast <2 x double> %strided.vec48, %strided.vec + %1 = fmul fast <2 x double> %strided.vec49, %strided.vec46 + %2 = fsub fast <2 x double> %0, %1 + %3 = fmul fast <2 x double> %strided.vec49, %strided.vec + %4 = fmul fast <2 x double> %strided.vec48, %strided.vec46 + %5 = fadd fast <2 x double> %3, %4 + %6 = bitcast double* undef to <4 x double>* + %interleaved.vec = shufflevector <2 x double> %2, <2 x double> %5, <4 x i32> + store <4 x double> %interleaved.vec, <4 x double>* %6, align 4 + br label %vector.body +} + +define void @complex_mul_v2f64(double* %a, double* %b, double* %c) local_unnamed_addr #0 { +; CHECK-LABEL: complex_mul_v2f64: +; CHECK: @ %bb.0: @ %vector.ph +; CHECK-NEXT: .p2align 2 +; CHECK-NEXT: .LBB3_1: @ %vector.body +; CHECK-NEXT: @ =>This Inner Loop Header: Depth=1 +; 
CHECK-NEXT: vldrw.u32 q0, [r0] +; CHECK-NEXT: vldrw.u32 q1, [r1] +; CHECK-NEXT: vmul.f64 d4, d3, d1 +; CHECK-NEXT: vmul.f64 d5, d2, d1 +; CHECK-NEXT: vfnms.f64 d4, d2, d0 +; CHECK-NEXT: vfma.f64 d5, d3, d0 +; CHECK-NEXT: vstrw.32 q2, [r0] +; CHECK-NEXT: b .LBB3_1 +vector.ph: + br label %vector.body + +vector.body: ; preds = %vector.body, %vector.ph + %a.ptr = bitcast double* %a to <2 x double>* + %b.ptr = bitcast double* %b to <2 x double>* + %c.ptr = bitcast double* %c to <2 x double>* + %a.val = load <2 x double>, <2 x double>* %a.ptr + %b.val = load <2 x double>, <2 x double>* %b.ptr + %strided.vec = shufflevector <2 x double> %a.val, <2 x double> poison, <1 x i32> + %strided.vec46 = shufflevector <2 x double> %a.val, <2 x double> poison, <1 x i32> + %strided.vec48 = shufflevector <2 x double> %b.val, <2 x double> poison, <1 x i32> + %strided.vec49 = shufflevector <2 x double> %b.val, <2 x double> poison, <1 x i32> + %0 = fmul fast <1 x double> %strided.vec48, %strided.vec + %1 = fmul fast <1 x double> %strided.vec49, %strided.vec46 + %2 = fsub fast <1 x double> %0, %1 + %3 = fmul fast <1 x double> %strided.vec49, %strided.vec + %4 = fmul fast <1 x double> %strided.vec48, %strided.vec46 + %5 = fadd fast <1 x double> %3, %4 + %6 = bitcast double* undef to <2 x double>* + %interleaved.vec = shufflevector <1 x double> %2, <1 x double> %5, <2 x i32> + store <2 x double> %interleaved.vec, <2 x double>* %6, align 4 + br label %vector.body +} + +attributes #0 = { "target-cpu"="cortex-m55" }