diff --git a/llvm/include/llvm/CodeGen/ComplexDeinterleavingPass.h b/llvm/include/llvm/CodeGen/ComplexDeinterleavingPass.h
new file mode 100644
--- /dev/null
+++ b/llvm/include/llvm/CodeGen/ComplexDeinterleavingPass.h
@@ -0,0 +1,66 @@
+//===- ComplexDeinterleavingPass.h - Complex Deinterleaving Pass -*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass implements generation of target-specific intrinsics to support
+// handling of complex number arithmetic and deinterleaving.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_CODEGEN_COMPLEXDEINTERLEAVING_H
+#define LLVM_CODEGEN_COMPLEXDEINTERLEAVING_H
+
+#include "llvm/IR/PassManager.h"
+
+namespace llvm {
+
+class Function;
+class TargetMachine;
+
+struct ComplexDeinterleavingPass
+    : public PassInfoMixin<ComplexDeinterleavingPass> {
+private:
+  TargetMachine *TM;
+
+public:
+  ComplexDeinterleavingPass(TargetMachine *TM) : TM(TM) {}
+
+  PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);
+};
+
+enum class ComplexDeinterleavingOperation {
+  None,
+  CAdd,
+  CMul,
+  CMulPartial
+};
+
+/// Describes the target's level of support for the complex arithmetic
+/// operations this pass can recognize.
+struct ComplexDeinterleavingSupport {
+public:
+  // Default assumption is that support exists unless otherwise stated.
+  bool SupportedOnTarget = true;
+
+  // Widest legal vector, in bits.
+  unsigned MaxVectorWidth = 0;
+
+  // Floating point complex operations.
+  bool FPAdd = false;
+  bool FPMul = false;
+  bool FPPartialMul = false;
+
+  /// Helper function to define no complex arithmetic support inline.
+  static ComplexDeinterleavingSupport noSupport() {
+    ComplexDeinterleavingSupport S;
+    S.SupportedOnTarget = false;
+    return S;
+  }
+};
+
+} // namespace llvm
+
+#endif // LLVM_CODEGEN_COMPLEXDEINTERLEAVING_H
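For orientation, a minimal usage sketch of the new-PM pass declared above; the pass class and its TargetMachine constructor argument come from this patch, while the surrounding pass-manager setup is an assumption for illustration:

    // Sketch: scheduling the pass in a FunctionPassManager (new pass manager).
    llvm::FunctionPassManager FPM;
    FPM.addPass(llvm::ComplexDeinterleavingPass(TM)); // TM: TargetMachine *
    // FPM would then be nested into a ModulePassManager and run as usual.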
diff --git a/llvm/include/llvm/CodeGen/Passes.h b/llvm/include/llvm/CodeGen/Passes.h
--- a/llvm/include/llvm/CodeGen/Passes.h
+++ b/llvm/include/llvm/CodeGen/Passes.h
@@ -82,6 +82,10 @@
   /// matching during instruction selection.
   FunctionPass *createCodeGenPreparePass();
 
+  /// This pass implements generation of target-specific intrinsics to support
+  /// handling of complex number arithmetic.
+  FunctionPass *createComplexDeinterleavingPass(const TargetMachine *TM);
+
   /// AtomicExpandID -- Lowers atomic operations in terms of either cmpxchg
   /// load-linked/store-conditional loops.
   extern char &AtomicExpandID;
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -22,6 +22,7 @@
 #ifndef LLVM_CODEGEN_TARGETLOWERING_H
 #define LLVM_CODEGEN_TARGETLOWERING_H
 
+#include "llvm/CodeGen/ComplexDeinterleavingPass.h"
 #include "llvm/ADT/APInt.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/DenseMap.h"
@@ -2970,6 +2971,16 @@
     return isOperationLegalOrCustom(Op, VT);
   }
 
+  /// Returns the level of support this target has for the complex arithmetic
+  /// operations recognized by the complex deinterleaving pass.
+  virtual ComplexDeinterleavingSupport
+  getComplexDeinterleavingSupport() const {
+    return ComplexDeinterleavingSupport::noSupport();
+  }
+
+  /// Create the IR (usually a target intrinsic) that implements the given
+  /// complex operation, or return nullptr if it cannot be lowered.
+  virtual Value *createComplexDeinterleavingIR(
+      Instruction *I, ComplexDeinterleavingOperation OperationType,
+      unsigned Rotation, Value *InputA, Value *InputB) const {
+    return nullptr;
+  }
+
   //===--------------------------------------------------------------------===//
   // Runtime Library hooks
   //
diff --git a/llvm/include/llvm/InitializePasses.h b/llvm/include/llvm/InitializePasses.h
--- a/llvm/include/llvm/InitializePasses.h
+++ b/llvm/include/llvm/InitializePasses.h
@@ -114,9 +114,10 @@
 void initializeCallSiteSplittingLegacyPassPass(PassRegistry&);
 void initializeCalledValuePropagationLegacyPassPass(PassRegistry &);
 void initializeCheckDebugMachineModulePass(PassRegistry &);
 void initializeCodeGenPreparePass(PassRegistry&);
+void initializeComplexDeinterleavingLegacyPassPass(PassRegistry &);
 void initializeConstantHoistingLegacyPassPass(PassRegistry&);
 void initializeConstantMergeLegacyPassPass(PassRegistry&);
 void initializeConstraintEliminationPass(PassRegistry &);
 void initializeControlHeightReductionLegacyPassPass(PassRegistry&);
 void initializeCorrelatedValuePropagationPass(PassRegistry&);
diff --git a/llvm/lib/CodeGen/CMakeLists.txt b/llvm/lib/CodeGen/CMakeLists.txt
--- a/llvm/lib/CodeGen/CMakeLists.txt
+++ b/llvm/lib/CodeGen/CMakeLists.txt
@@ -44,6 +44,7 @@
   CodeGenPassBuilder.cpp
   CodeGenPrepare.cpp
   CommandFlags.cpp
+  ComplexDeinterleavingPass.cpp
   CriticalAntiDepBreaker.cpp
   DeadMachineInstructionElim.cpp
   DetectDeadLanes.cpp
diff --git a/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp b/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
new file mode 100644
--- /dev/null
+++ b/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
@@ -0,0 +1,704 @@
+//===- ComplexDeinterleavingPass.cpp --------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass implements generation of target-specific intrinsics to support
+// handling of complex number arithmetic and deinterleaving.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/CodeGen/ComplexDeinterleavingPass.h"
+#include "llvm/ADT/Statistic.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
+#include "llvm/CodeGen/TargetLowering.h"
+#include "llvm/CodeGen/TargetPassConfig.h"
+#include "llvm/CodeGen/TargetSubtargetInfo.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/PatternMatch.h"
+#include "llvm/InitializePasses.h"
+#include "llvm/Target/TargetMachine.h"
+#include "llvm/Transforms/Scalar.h"
+#include "llvm/Transforms/Utils/Local.h"
+
+using namespace llvm;
+using namespace PatternMatch;
+
+#define DEBUG_TYPE "complex-deinterleaving"
+
+STATISTIC(NumComplexIntrinsics, "Number of complex intrinsics generated");
+
+static cl::opt<bool> ComplexArithmeticEnabled(
+    "enable-complex-arithmetic",
+    cl::desc("Enable generation of complex arithmetic instructions"),
+    cl::init(true), cl::Hidden);
+
+namespace {
+
+/// Creates a contiguous mask of the given length, optionally with a base
+/// offset.
+static SmallVector<int> createContiguousMask(int Len, int Offset = 0) {
+  SmallVector<int> Arr(Len);
+  for (int J = 0; J < Len; J++)
+    Arr[J] = J + Offset;
+  return Arr;
+}
+
+/// Creates an integer array of length \p Len, where each item is \p Step more
+/// than the previous. An offset can be provided to specify the first element.
+static SmallVector<int> createArrayWithStep(int Len, int Step, int Offset = 0) {
+  SmallVector<int> Arr(Len);
+  for (int J = 0; J < Len; J++)
+    Arr[J] = (J * Step) + Offset;
+  return Arr;
+}
+
+/// Creates an interleaving mask of the given length.
+/// An interleaving mask looks like <0, 2, 1, 3> or <0, 4, 1, 5, 2, 6, 3, 7>.
+static SmallVector<int> createInterleavingMask(int Len) {
+  int Step = Len / 2;
+  SmallVector<int> Arr(Len);
+  int Idx = 0;
+  for (int J = 0; J < Len; J += 2) {
+    Arr[J] = Idx;
+    Arr[J + 1] = Idx + Step;
+    Idx++;
+  }
+  return Arr;
+}
+
+/// Creates a deinterleaving mask of the given length at the given offset.
+/// A deinterleaving mask looks like <0, 2, 4, 6> or <1, 3, 5, 7>.
+static SmallVector<int> createDeinterleavingMask(int Len, int Offset = 0) {
+  return createArrayWithStep(Len, 2, Offset);
+}
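+
+// For illustration, sample outputs of the helpers above (the 8-lane and
+// 4-lane sizes are arbitrary choices for the example):
+//   createInterleavingMask(8)      -> <0, 4, 1, 5, 2, 6, 3, 7>
+//   createDeinterleavingMask(4)    -> <0, 2, 4, 6>  (real lanes)
+//   createDeinterleavingMask(4, 1) -> <1, 3, 5, 7>  (imaginary lanes)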
+
+class ComplexDeinterleavingLegacyPass : public FunctionPass {
+public:
+  static char ID;
+
+  ComplexDeinterleavingLegacyPass(const TargetMachine *TM = nullptr)
+      : FunctionPass(ID), TM(TM) {
+    initializeComplexDeinterleavingLegacyPassPass(
+        *PassRegistry::getPassRegistry());
+  }
+
+  StringRef getPassName() const override {
+    return "Complex Deinterleaving Pass";
+  }
+
+  bool runOnFunction(Function &F) override;
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    AU.addRequired<TargetLibraryInfoWrapperPass>();
+    AU.setPreservesCFG();
+  }
+
+private:
+  const TargetMachine *TM;
+};
+
+/// Annotated graph-like structure that enriches the existing Instruction
+/// graph, allowing contextual clues relevant to complex arithmetic to be
+/// recorded and handed to the TargetLowering hooks as required.
+class ComplexDeinterleavingData {
+public:
+  SmallVector<Value *> getNodes() { return Nodes; }
+
+  void addNode(Value *V) {
+    LLVM_DEBUG(dbgs() << "Adding node: "; V->dump());
+    if (llvm::is_contained(Nodes, V)) {
+      LLVM_DEBUG(dbgs() << " - Already added.\n");
+      return;
+    }
+
+    Nodes.push_back(V);
+  }
+
+  unsigned Rotation = 0;
+  ComplexDeinterleavingOperation Type = ComplexDeinterleavingOperation::None;
+
+private:
+  SmallVector<Value *> Nodes;
+};
+
+class ComplexDeinterleaving {
+public:
+  ComplexDeinterleaving(const TargetLowering *TL, const TargetLibraryInfo *TLI)
+      : TL(TL), TLI(TLI) {}
+  bool runOnFunction(Function &F);
+
+private:
+  bool evaluateComplexDeinterleavingBasicBlock(
+      BasicBlock *B, ComplexDeinterleavingSupport &Support);
+
+  const TargetLowering *TL = nullptr;
+  const TargetLibraryInfo *TLI = nullptr;
+};
+
+} // namespace
+
+char ComplexDeinterleavingLegacyPass::ID = 0;
+
+INITIALIZE_PASS_BEGIN(ComplexDeinterleavingLegacyPass, DEBUG_TYPE,
+                      "Complex Deinterleaving", false, false)
+INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
+INITIALIZE_PASS_END(ComplexDeinterleavingLegacyPass, DEBUG_TYPE,
+                    "Complex Deinterleaving", false, false)
+
+PreservedAnalyses ComplexDeinterleavingPass::run(Function &F,
+                                                 FunctionAnalysisManager &AM) {
+  const TargetLowering *TL = TM->getSubtargetImpl(F)->getTargetLowering();
+  auto &TLI = AM.getResult<TargetLibraryAnalysis>(F);
+  if (!ComplexDeinterleaving(TL, &TLI).runOnFunction(F))
+    return PreservedAnalyses::all();
+
+  PreservedAnalyses PA;
+  PA.preserveSet<CFGAnalyses>();
+  return PA;
+}
+
+FunctionPass *llvm::createComplexDeinterleavingPass(const TargetMachine *TM) {
+  return new ComplexDeinterleavingLegacyPass(TM);
+}
+
+bool ComplexDeinterleavingLegacyPass::runOnFunction(Function &F) {
+  const auto *TL = TM->getSubtargetImpl(F)->getTargetLowering();
+  auto TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
+  return ComplexDeinterleaving(TL, &TLI).runOnFunction(F);
+}
+
+bool ComplexDeinterleaving::runOnFunction(Function &F) {
+  if (!ComplexArithmeticEnabled) {
+    LLVM_DEBUG(dbgs() << "Complex deinterleaving has been explicitly "
+                         "disabled.\n");
+    return false;
+  }
+
+  auto Support = TL->getComplexDeinterleavingSupport();
+  if (!Support.SupportedOnTarget) {
+    LLVM_DEBUG(dbgs() << "Complex deinterleaving has been disabled, target "
+                         "does not support lowering of complex numbers.\n");
+    return false;
+  }
+
+  bool Changed = false;
+  for (auto &B : F)
+    Changed |= evaluateComplexDeinterleavingBasicBlock(&B, Support);
+
+  return Changed;
+}
+
+/// Checks the given mask, and determines whether said mask is interleaving.
+///
+/// To be interleaving, a mask must alternate between `i` and
+/// `i + (Length / 2)`, and must contain all numbers within the range of
+/// `[0..Length)` (e.g. a 4x vector interleaving mask would be <0, 2, 1, 3>).
+static bool isInterleavingMask(ArrayRef<int> Mask, unsigned NumElements) {
+  if (Mask.size() != NumElements * 2)
+    return false;
+
+  for (unsigned Idx = 0; Idx < NumElements; ++Idx) {
+    if (Mask[(Idx * 2) + 1] != (Mask[Idx * 2] + NumElements))
+      return false;
+  }
+
+  return true;
+}
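+
+// Example (hypothetical <2 x float> inputs): the shuffle below interleaves
+// %re and %im with mask <0, 2, 1, 3>, which isInterleavingMask accepts for
+// NumElements == 2:
+//   %vec = shufflevector <2 x float> %re, <2 x float> %im,
+//                        <4 x i32> <i32 0, i32 2, i32 1, i32 3>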
+
+/// Checks the mask of the given ShuffleVectorInst, and determines whether
+/// said shuffle is interleaving. See isInterleavingMask.
+static bool isInterleaving(ShuffleVectorInst *SVI) {
+  auto *Ty = dyn_cast<FixedVectorType>(SVI->getOperand(0)->getType());
+  if (!Ty)
+    return false;
+
+  unsigned NumElements = Ty->getNumElements();
+  return isInterleavingMask(SVI->getShuffleMask(), NumElements);
+}
+
+static bool matchComplexPartialMul(ShuffleVectorInst *SVI,
+                                   ComplexDeinterleavingData &G) {
+  auto InterleavingMask = createInterleavingMask(SVI->getShuffleMask().size());
+  auto DeinterleavingLength = InterleavingMask.size() / 2;
+  auto DeinterleavingRealMask =
+      createDeinterleavingMask(DeinterleavingLength, 0);
+  auto DeinterleavingImagMask =
+      createDeinterleavingMask(DeinterleavingLength, 1);
+
+  ArrayRef<int> DeinterleavingRealMaskRef(DeinterleavingRealMask);
+  ArrayRef<int> DeinterleavingImagMaskRef(DeinterleavingImagMask);
+  ArrayRef<int> InterleavingMaskRef(InterleavingMask);
+
+  Value *LoadA, *LoadB;
+  Value *AssertLoadA, *AssertLoadB;
+
+  auto MulByRealPatternA =
+      m_Shuffle(m_Shuffle(m_FMul(m_Value(LoadB), m_Value(LoadA)), m_Poison(),
+                          m_SpecificMask(DeinterleavingRealMaskRef)),
+                m_FMul(m_Shuffle(m_Value(AssertLoadB), m_Poison(),
+                                 m_SpecificMask(DeinterleavingRealMaskRef)),
+                       m_Shuffle(m_Value(AssertLoadA), m_Poison(),
+                                 m_SpecificMask(DeinterleavingImagMaskRef))),
+                m_SpecificMask(InterleavingMaskRef));
+
+  if (!match(SVI, MulByRealPatternA)) {
+    LLVM_DEBUG(dbgs() << "Failed to match MulByReal pattern.\n"; SVI->dump();
+               SVI->getParent()->dump());
+    return false;
+  }
+
+  if (LoadA != AssertLoadA || LoadB != AssertLoadB) {
+    LLVM_DEBUG(dbgs() << "Loads don't match expected pattern.\n");
+    return false;
+  }
+
+  G.addNode(LoadA);
+  G.addNode(LoadB);
+
+  G.Rotation = 0;
+  G.Type = ComplexDeinterleavingOperation::CMulPartial;
+
+  return true;
+}
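+
+// For reference when reading the multiply patterns below: a complex multiply
+// expands as
+//   (a + bi) * (c + di) = (ac - bd) + (ad + bc)i
+// and the rot-180 variant negates both result components, which is where the
+// extra fneg in the second pattern comes from.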
+
+static bool matchComplexMul(ShuffleVectorInst *SVI,
+                            ComplexDeinterleavingData &G) {
+  unsigned LikelyRotation = 0;
+
+  Value *LeftShuffle00;
+  Value *LeftShuffle01;
+  Value *LeftShuffle10;
+  Value *LeftShuffle11;
+
+  Value *RightShuffle00;
+  Value *RightShuffle01;
+  Value *RightShuffle10;
+  Value *RightShuffle11;
+
+  auto Mask = createInterleavingMask(SVI->getShuffleMask().size());
+  ArrayRef<int> MaskRef(Mask);
+
+  auto InterleaveShuffleRot0Pattern = m_Shuffle(
+      m_FSub(m_FMul(m_Value(LeftShuffle00), m_Value(LeftShuffle01)),
+             m_FMul(m_Value(LeftShuffle10), m_Value(LeftShuffle11))),
+      m_FAdd(m_FMul(m_Value(RightShuffle00), m_Value(RightShuffle01)),
+             m_FMul(m_Value(RightShuffle10), m_Value(RightShuffle11))),
+      m_SpecificMask(MaskRef));
+
+  auto InterleaveShuffleRot180Pattern = m_Shuffle(
+      m_FSub(m_FMul(m_Value(LeftShuffle11), m_Value(LeftShuffle01)),
+             m_FMul(m_Value(LeftShuffle10), m_Value(LeftShuffle00))),
+      m_FSub(m_FMul(m_Value(RightShuffle10), m_FNeg(m_Value(RightShuffle01))),
+             m_FMul(m_Value(RightShuffle11), m_Value(RightShuffle00))),
+      m_SpecificMask(MaskRef));
+
+  if (match(SVI, InterleaveShuffleRot0Pattern))
+    LikelyRotation = 0;
+  else if (match(SVI, InterleaveShuffleRot180Pattern))
+    LikelyRotation = 180;
+  else {
+    LLVM_DEBUG(dbgs() << "SVI does not match expected patterns.\n";
+               SVI->dump());
+    return false;
+  }
+
+  LLVM_DEBUG(dbgs() << "Rotation: " << LikelyRotation << ".\n");
+
+  struct ShuffleScore {
+    Value *Shuffle;
+    unsigned Score;
+
+    unsigned leftScore() { return (Score & 0b110) >> 1; }
+
+    unsigned rightScore() { return Score & 0b011; }
+
+    ShuffleScore(Value *Shuffle, unsigned Score)
+        : Shuffle(Shuffle), Score(Score) {}
+  };
+
+  auto *Sub = cast<Instruction>(SVI->getOperand(0));
+  auto *Add = cast<Instruction>(SVI->getOperand(1));
+
+  auto *SubMul0 = cast<Instruction>(Sub->getOperand(0));
+  // SubMul1 is not needed for the comparisons below.
+  auto *AddMul0 = cast<Instruction>(Add->getOperand(0));
+  auto *AddMul1 = cast<Instruction>(Add->getOperand(1));
+
+  // Evaluate the score of a given Value (Shuffle) in isolation.
+  // The score is made up of 3 bits that are then split into two 2-bit
+  // numbers, representing the "left" and "right" scores, akin to a hash
+  // function. These scores must be distinct within their respective sets,
+  // otherwise this pattern is not valid for complex deinterleaving.
+  //
+  // How a score is evaluated is based on 3 conditions, each one relating to
+  // the respective bit in the final score: Crossover, MulSlot, MulOperand.
+  //
+  // Crossover: If Shuffle contributes to the first operand of the `sub` or
+  // to the first operand of the `add`, but not to both, the `Crossover`
+  // score is 1.
+  //
+  // MulSlot: If Shuffle contributes to the second operand of the `add`,
+  // the `MulSlot` score is 1.
+  //
+  // MulOperand: If Shuffle is the second operand of either of the `add`'s
+  // operands, the `MulOperand` score is 1.
+  SmallVector<Value *> DiscoveredNegs;
+
+  auto CompareValues = [&](Value *A, Value *B) {
+    if (LikelyRotation == 180) {
+      if (auto *I = dyn_cast<Instruction>(A)) {
+        if (I->getOpcode() == Instruction::FNeg) {
+          if (!llvm::is_contained(DiscoveredNegs, A))
+            DiscoveredNegs.push_back(A);
+          A = I->getOperand(0);
+        }
+      }
+    }
+
+    return A == B;
+  };
+
+  auto EvaluateScore = [&](Value *Shuffle) {
+    bool Cross = false;
+    bool MulSlot = false;
+    bool MulOperand = false;
+
+    // Evaluate crossover.
+    {
+      bool IsSlot0Left = CompareValues(SubMul0->getOperand(0), Shuffle) ||
+                         CompareValues(SubMul0->getOperand(1), Shuffle);
+      bool IsSlot0Right = CompareValues(AddMul0->getOperand(0), Shuffle) ||
+                          CompareValues(AddMul0->getOperand(1), Shuffle);
+      Cross = IsSlot0Left != IsSlot0Right;
+    }
+
+    // Evaluate add(mul) slot.
+    MulSlot = CompareValues(AddMul1->getOperand(0), Shuffle) ||
+              CompareValues(AddMul1->getOperand(1), Shuffle);
+
+    // Evaluate operand slot.
+    MulOperand = CompareValues(AddMul0->getOperand(1), Shuffle) ||
+                 CompareValues(AddMul1->getOperand(1), Shuffle);
+
+    return (Cross << 2) | (MulSlot << 1) | MulOperand;
+  };
+
+  ShuffleScore Scores[4] = {
+      ShuffleScore(LeftShuffle00, EvaluateScore(LeftShuffle00)),
+      ShuffleScore(LeftShuffle01, EvaluateScore(LeftShuffle01)),
+      ShuffleScore(LeftShuffle10, EvaluateScore(LeftShuffle10)),
+      ShuffleScore(LeftShuffle11, EvaluateScore(LeftShuffle11))};
+
+  if (DiscoveredNegs.size() > 1) {
+    LLVM_DEBUG(dbgs() << "Too many negations in pattern to be confidently "
+                         "matched.\n");
+    return false;
+  }
+
+  // All of these must be true after the following loop, otherwise the
+  // pattern is not valid for complex deinterleaving. They must also be set
+  // only once, as a duplicate score on either side also means that the
+  // pattern is not valid.
+  bool LeftScores[4] = {false, false, false, false};
+  bool RightScores[4] = {false, false, false, false};
+
+  for (unsigned I = 0; I < 4; I++) {
+    ShuffleScore S = Scores[I];
+    unsigned LIdx = S.leftScore();
+    unsigned RIdx = S.rightScore();
+
+    if (LeftScores[LIdx]) {
+      LLVM_DEBUG(dbgs() << "Bad left score for S: " << S.Score << ". ";
+                 S.Shuffle->dump());
+      return false;
+    }
+    LeftScores[LIdx] = true;
+    if (RightScores[RIdx]) {
+      LLVM_DEBUG(dbgs() << "Bad right score for S: " << S.Score << ". ";
+                 S.Shuffle->dump());
"; S.Shuffle->dump()); + return false; + } + RightScores[RIdx] = true; + } + + for(unsigned i = 0; i < 4; i++) { + if(!LeftScores[i]) { + LLVM_DEBUG(dbgs() << "Left Score " << i << " was left false" << ".\n"); + return false; + } + if(!RightScores[i]) { + LLVM_DEBUG(dbgs() << "Right Score " << i << " was left false" << ".\n"); + return false; + } + } + + + Value *ShuffleAR = LeftShuffle01; + Value *ShuffleBR = LeftShuffle00; + Value *ShuffleAI = LeftShuffle11; + Value *ShuffleBI = LeftShuffle10; + + // Add the first operand of all 4 shuffles. 2 of these should be duplicate, + // but we don't know which 2. + // If there aren't 2 nodes after this, the pattern is not deinterleaved. + G.addNode(cast(ShuffleAR)->getOperand(0)); + G.addNode(cast(ShuffleBR)->getOperand(0)); + G.addNode(cast(ShuffleAI)->getOperand(0)); + G.addNode(cast(ShuffleBI)->getOperand(0)); + + G.Rotation = LikelyRotation; + G.Type = llvm::ComplexDeinterleavingOperation::CMul; + + return true; +} + +static bool matchComplexAdd(ShuffleVectorInst *SVI, + ComplexDeinterleavingData &G) { + Value *ShuffleAR; + Value *ShuffleAI; + Value *ShuffleBR; + Value *ShuffleBI; + + auto *Op0 = dyn_cast(SVI->getOperand(0)); + auto *Op1 = dyn_cast(SVI->getOperand(1)); + + if (!Op0 || !Op1) + return false; + + unsigned Rotation; + if (Op0->getOpcode() == Instruction::FSub && + Op1->getOpcode() == Instruction::FAdd) { + Rotation = 90; + } else if (Op0->getOpcode() == Instruction::FAdd && + Op1->getOpcode() == Instruction::FSub) { + Rotation = 270; + } else { + return false; + } + + auto ShuffleMask = createInterleavingMask(SVI->getShuffleMask().size()); + ArrayRef ShuffleMaskRef(ShuffleMask); + + if (Rotation == 90) { + if (!match(SVI, m_Shuffle(m_FSub(m_Value(ShuffleAR), m_Value(ShuffleBI)), + m_FAdd(m_Value(ShuffleAI), m_Value(ShuffleBR)), + m_SpecificMask(ShuffleMaskRef)))) { + LLVM_DEBUG( + dbgs() << "SVI does not match expected pattern for complex add rot " + << Rotation << ".\n"); + return false; + } + } else if (Rotation == 270) { + if (!match(SVI, m_Shuffle(m_FAdd(m_Value(ShuffleBI), m_Value(ShuffleAR)), + m_FSub(m_Value(ShuffleAI), m_Value(ShuffleBR)), + m_SpecificMask(ShuffleMaskRef)))) { + LLVM_DEBUG( + dbgs() << "SVI does not match expected pattern for complex add rot " + << Rotation << ".\n"); + return false; + } + } + + if (!isa(ShuffleAR) || + !isa(ShuffleAI) || + !isa(ShuffleBR) || + !isa(ShuffleAI)) { + LLVM_DEBUG(dbgs() << "SVI does not match expected pattern for complex add, " + "inputs aren't all shuffles.\n"); + return false; + } + + auto *InputA = cast(ShuffleAR)->getOperand(0); + auto *InputB = cast(ShuffleBR)->getOperand(0); + + G.addNode(InputA); + G.addNode(InputB); + + G.Rotation = Rotation; + G.Type = ComplexDeinterleavingOperation::CAdd; + + return true; +} + +static bool substituteGraph(const TargetLowering *TL, Instruction *I, + ComplexDeinterleavingData &G, + ComplexDeinterleavingSupport &Support) { + SmallVector Inputs = G.getNodes(); + + if(Inputs.size() != 2) { + LLVM_DEBUG(dbgs() << "Unexpected amount of inputs.\n"); + return false; + } + + auto *LoadA = Inputs[0]; + auto *LoadB = Inputs[1]; + + auto *TyA = cast(LoadA->getType()); + auto *TyB = cast(LoadB->getType()); + + FixedVectorType *WideType; + FixedVectorType *NarrowType; + if (TyA->getNumElements() >= TyB->getNumElements()) { + WideType = TyA; + NarrowType = TyB; + } else { + WideType = TyB; + NarrowType = TyA; + } + + if (NarrowType->getNumElements() != WideType->getNumElements() && + NarrowType->getNumElements() != WideType->getNumElements() / 2) { 
+    LLVM_DEBUG(dbgs() << "Narrow type is not equal to or half the width of "
+                         "the wide type.\n");
+    return false;
+  }
+
+  unsigned WideStride;
+  unsigned NarrowStride;
+
+  const unsigned MaxVectorWidth = Support.MaxVectorWidth;
+
+  unsigned NumBits =
+      WideType->getScalarSizeInBits() * WideType->getNumElements();
+  WideStride = MaxVectorWidth / WideType->getScalarSizeInBits();
+  if (NarrowType->getNumElements() == WideType->getNumElements())
+    NarrowStride = WideStride;
+  else
+    NarrowStride = WideStride / 2;
+
+  if (NumBits > MaxVectorWidth) {
+    LLVM_DEBUG(dbgs() << "Split required, " << NumBits
+                      << " is greater than the max vector width ("
+                      << MaxVectorWidth << ").\n");
+    if (NumBits % MaxVectorWidth != 0) {
+      LLVM_DEBUG(dbgs() << "Vector can't be split evenly.\n");
+      return false;
+    }
+
+    IRBuilder<> B(I);
+
+    unsigned SplitCount = NumBits / MaxVectorWidth;
+
+    if (SplitCount > 2) {
+      LLVM_DEBUG(dbgs() << "Cannot split the operation more than two "
+                           "ways.\n");
+      return false;
+    }
+
+    SmallVector<Value *> CreatedInsts;
+    SmallVector<Value *> ComplexIR;
+    for (unsigned Idx = 0; Idx < SplitCount; ++Idx) {
+      SmallVector<int> WideMask =
+          createContiguousMask(WideStride, WideStride * Idx);
+      SmallVector<int> NarrowMask =
+          createContiguousMask(NarrowStride, NarrowStride * Idx);
+
+      ArrayRef<int> WideMaskRef(WideMask);
+      ArrayRef<int> NarrowMaskRef(NarrowMask);
+
+      auto *UndefA = UndefValue::get(LoadA->getType());
+      auto *UndefB = UndefValue::get(LoadB->getType());
+      Value *ShuffleA, *ShuffleB;
+      if (TyA == WideType) {
+        ShuffleA = B.CreateShuffleVector(
+            LoadA, UndefA, WideMaskRef.take_front(TyA->getNumElements() / 2));
+        ShuffleB = B.CreateShuffleVector(
+            LoadB, UndefB,
+            NarrowMaskRef.take_front(TyB->getNumElements() / 2));
+      } else {
+        ShuffleA = B.CreateShuffleVector(
+            LoadB, UndefB, WideMaskRef.take_front(TyB->getNumElements() / 2));
+        ShuffleB = B.CreateShuffleVector(
+            LoadA, UndefA,
+            NarrowMaskRef.take_front(TyA->getNumElements() / 2));
+      }
+
+      CreatedInsts.push_back(ShuffleA);
+      CreatedInsts.push_back(ShuffleB);
+
+      auto *IR = TL->createComplexDeinterleavingIR(I, G.Type, G.Rotation,
+                                                   ShuffleA, ShuffleB);
+      if (!IR)
+        return false;
+      NumComplexIntrinsics++;
+      ComplexIR.push_back(IR);
+      CreatedInsts.push_back(IR);
+    }
+    auto Mask = createContiguousMask(WideStride * 2);
+    ArrayRef<int> MaskRef(Mask);
+    auto *Shuffle = B.CreateShuffleVector(ComplexIR[0], ComplexIR[1], MaskRef);
+    I->replaceAllUsesWith(Shuffle);
+  } else {
+    auto *Mla = TL->createComplexDeinterleavingIR(I, G.Type, G.Rotation,
+                                                  LoadA, LoadB);
+    if (!Mla)
+      return false;
+    NumComplexIntrinsics++;
+    I->replaceAllUsesWith(Mla);
+  }
+
+  return true;
+}
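+
+// Illustration of the splitting path above, assuming an MVE-like target with
+// MaxVectorWidth == 128: a pair of <8 x float> inputs is 256 bits, so
+// SplitCount == 2. Each input is halved with the contiguous masks
+// <0, 1, 2, 3> and <4, 5, 6, 7>, each half is lowered to one target
+// intrinsic, and the two results are rejoined with mask
+// <0, 1, 2, 3, 4, 5, 6, 7>.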
+
+static bool traverseAndPopulateGraph(const TargetLowering *TL, Instruction *I,
+                                     ComplexDeinterleavingData &G,
+                                     ComplexDeinterleavingSupport &Support) {
+  // The shuffle mask needs to interleave vectors, e.g.
+  //   <4 x i32> <0, 2, 1, 3>
+  //   <8 x i32> <0, 4, 1, 5, 2, 6, 3, 7>
+  if (auto *SVI = dyn_cast<ShuffleVectorInst>(I)) {
+    if (!isInterleaving(SVI)) {
+      LLVM_DEBUG(dbgs() << "SVI doesn't appear to perform interleaving.\n");
+      return false;
+    }
+
+    if (Support.FPPartialMul && matchComplexPartialMul(SVI, G))
+      return substituteGraph(TL, I, G, Support);
+
+    if (Support.FPMul && matchComplexMul(SVI, G))
+      return substituteGraph(TL, I, G, Support);
+
+    if (Support.FPAdd && matchComplexAdd(SVI, G))
+      return substituteGraph(TL, I, G, Support);
+  }
+
+  return false;
+}
+
+bool ComplexDeinterleaving::evaluateComplexDeinterleavingBasicBlock(
+    BasicBlock *B, ComplexDeinterleavingSupport &Support) {
+  bool Changed = false;
+
+  SmallVector<Instruction *> DeadInstrRoots;
+
+  for (auto &I : *B) {
+    if (auto *SVI = dyn_cast<ShuffleVectorInst>(&I)) {
+      if (isInterleaving(SVI)) {
+        // Use a fresh graph for each candidate so nodes from a previous
+        // match don't leak into the next one.
+        ComplexDeinterleavingData Graph;
+        if (traverseAndPopulateGraph(TL, SVI, Graph, Support)) {
+          Changed = true;
+          DeadInstrRoots.push_back(SVI);
+        }
+      }
+    }
+  }
+
+  for (const auto &I : DeadInstrRoots)
+    llvm::RecursivelyDeleteTriviallyDeadInstructions(I, TLI);
+
+  return Changed;
+}
diff --git a/llvm/lib/Target/ARM/ARMISelLowering.h b/llvm/lib/Target/ARM/ARMISelLowering.h
--- a/llvm/lib/Target/ARM/ARMISelLowering.h
+++ b/llvm/lib/Target/ARM/ARMISelLowering.h
@@ -740,6 +740,12 @@
     bool shouldConvertFpToSat(unsigned Op, EVT FPVT, EVT VT) const override;
 
+    ComplexDeinterleavingSupport
+    getComplexDeinterleavingSupport() const override;
+
+    Value *createComplexDeinterleavingIR(
+        Instruction *I, ComplexDeinterleavingOperation OperationType,
+        unsigned Rotation, Value *InputA, Value *InputB) const override;
+
   protected:
     std::pair<const TargetRegisterClass *, uint8_t>
     findRepresentativeClass(const TargetRegisterInfo *TRI,
diff --git a/llvm/lib/Target/ARM/ARMISelLowering.cpp b/llvm/lib/Target/ARM/ARMISelLowering.cpp
--- a/llvm/lib/Target/ARM/ARMISelLowering.cpp
+++ b/llvm/lib/Target/ARM/ARMISelLowering.cpp
@@ -21714,3 +21714,90 @@
   MF.getFrameInfo().computeMaxCallFrameSize(MF);
   TargetLoweringBase::finalizeLowering(MF);
 }
+
+ComplexDeinterleavingSupport
+ARMTargetLowering::getComplexDeinterleavingSupport() const {
+  if (!Subtarget->hasMVEFloatOps())
+    return ComplexDeinterleavingSupport::noSupport();
+
+  ComplexDeinterleavingSupport S;
+  S.MaxVectorWidth = 128;
+  S.FPAdd = true;
+  S.FPMul = true;
+  S.FPPartialMul = true;
+  return S;
+}
+
+Value *ARMTargetLowering::createComplexDeinterleavingIR(
+    Instruction *I, ComplexDeinterleavingOperation OperationType,
+    unsigned Rotation, Value *InputA, Value *InputB) const {
+  auto *Ty = InputA->getType();
+  if (!isa<FixedVectorType>(Ty))
+    return nullptr;
+  auto *VTy = cast<FixedVectorType>(Ty);
+
+  // Cannot widen complex intrinsics to fill vectors.
+  if (VTy->getNumElements() * VTy->getScalarSizeInBits() != 128)
+    return nullptr;
+
+  // MVE does not support double complex operations.
+  if (VTy->getScalarType()->isDoubleTy())
+    return nullptr;
+
+  IRBuilder<> B(I);
+  auto *IntTy = Type::getInt32Ty(B.getContext());
+
+  if (OperationType == ComplexDeinterleavingOperation::CMulPartial) {
+    ConstantInt *ConstMulRot = nullptr;
+
+    if (Rotation == 0)
+      ConstMulRot = ConstantInt::get(IntTy, 0);
+    else if (Rotation == 180)
+      ConstMulRot = ConstantInt::get(IntTy, 2);
+
+    if (!ConstMulRot)
+      return nullptr;
+
+    auto *Mul = B.CreateIntrinsic(Intrinsic::arm_mve_vcmulq, Ty,
+                                  {ConstMulRot, InputB, InputA});
+    return Mul;
+  }
+
+  if (OperationType == ComplexDeinterleavingOperation::CMul) {
+    int RotIdx = Rotation / 90;
+
+    auto *ConstMulRot =
ConstantInt::get(IntTy, RotIdx); + auto *ConstMlaRot = ConstantInt::get(IntTy, (RotIdx + 1) % 4); + auto *Mul = B.CreateIntrinsic(Intrinsic::arm_mve_vcmulq, Ty, + {ConstMulRot, InputA, InputB}); + auto *Mla = B.CreateIntrinsic(Intrinsic::arm_mve_vcmlaq, Ty, + {ConstMlaRot, Mul, InputA, InputB}); + return Mla; + } + + if (OperationType == ComplexDeinterleavingOperation::CAdd) { + + // 1 means the value is not halved. + unsigned HalvingVal = 1; + auto *Halving = ConstantInt::get(IntTy, HalvingVal); + + unsigned RotKey; + if (Rotation == 90) + RotKey = 0; + else if (Rotation == 270) + RotKey = 1; + else + return nullptr; // Invalid rotation for arm_mve_vcaddq + + auto *RotVal = ConstantInt::get(IntTy, RotKey); + return B.CreateIntrinsic(Intrinsic::arm_mve_vcaddq, Ty, + {Halving, RotVal, InputA, InputB}); + } + + return nullptr; +} diff --git a/llvm/lib/Target/ARM/ARMTargetMachine.cpp b/llvm/lib/Target/ARM/ARMTargetMachine.cpp --- a/llvm/lib/Target/ARM/ARMTargetMachine.cpp +++ b/llvm/lib/Target/ARM/ARMTargetMachine.cpp @@ -425,12 +425,17 @@ TargetPassConfig::addIRPasses(); // Run the parallel DSP pass. - if (getOptLevel() == CodeGenOpt::Aggressive) + if (getOptLevel() == CodeGenOpt::Aggressive) addPass(createARMParallelDSPPass()); + // Match complex arithmetic patterns + if (TM->getOptLevel() >= CodeGenOpt::Default) + addPass(createComplexDeinterleavingPass(TM)); + // Match interleaved memory accesses to ldN/stN intrinsics. - if (TM->getOptLevel() != CodeGenOpt::None) + if (TM->getOptLevel() != CodeGenOpt::None) { addPass(createInterleavedAccessPass()); + } // Add Control Flow Guard checks. if (TM->getTargetTriple().isOSWindows()) diff --git a/llvm/test/CodeGen/ARM/ComplexArithmetic/complex-arithmetic-f16-add.ll b/llvm/test/CodeGen/ARM/ComplexArithmetic/complex-arithmetic-f16-add.ll --- a/llvm/test/CodeGen/ARM/ComplexArithmetic/complex-arithmetic-f16-add.ll +++ b/llvm/test/CodeGen/ARM/ComplexArithmetic/complex-arithmetic-f16-add.ll @@ -55,16 +55,11 @@ define <8 x half> @complex_add_v8f16(<8 x half> %a, <8 x half> %b) #0 { ; CHECK-LABEL: complex_add_v8f16: ; CHECK: @ %bb.0: @ %entry -; CHECK-NEXT: mov r12, sp -; CHECK-NEXT: vld1.64 {d0, d1}, [r12] -; CHECK-NEXT: vorr d18, d0, d0 -; CHECK-NEXT: vmov d16, r2, r3 -; CHECK-NEXT: vmov d17, r0, r1 -; CHECK-NEXT: vuzp.16 d17, d16 -; CHECK-NEXT: vuzp.16 d18, d1 -; CHECK-NEXT: vadd.f16 d1, d1, d17 -; CHECK-NEXT: vsub.f16 d0, d18, d16 -; CHECK-NEXT: vzip.16 d0, d1 +; CHECK-NEXT: vmov d0, r0, r1 +; CHECK-NEXT: mov r0, sp +; CHECK-NEXT: vld1.64 {d2, d3}, [r0] +; CHECK-NEXT: vmov d1, r2, r3 +; CHECK-NEXT: vcadd.f16 q0, q1, q0, #90 ; CHECK-NEXT: vmov r0, r1, d0 ; CHECK-NEXT: vmov r2, r3, d1 ; CHECK-NEXT: bx lr @@ -81,21 +76,18 @@ define <16 x half> @complex_add_v16f16(<16 x half> %a, <16 x half> %b) #0 { ; CHECK-LABEL: complex_add_v16f16: ; CHECK: @ %bb.0: @ %entry +; CHECK-NEXT: add r1, sp, #24 +; CHECK-NEXT: vldr d1, [sp] +; CHECK-NEXT: vld1.64 {d2, d3}, [r1] +; CHECK-NEXT: vmov d0, r2, r3 ; CHECK-NEXT: add r1, sp, #8 -; CHECK-NEXT: vldr d5, [sp] +; CHECK-NEXT: vcadd.f16 q0, q1, q0, #90 +; CHECK-NEXT: vst1.16 {d0, d1}, [r0:128]! ; CHECK-NEXT: vld1.64 {d0, d1}, [r1] ; CHECK-NEXT: add r1, sp, #40 -; CHECK-NEXT: vmov d4, r2, r3 ; CHECK-NEXT: vld1.64 {d2, d3}, [r1] -; CHECK-NEXT: add r1, sp, #24 -; CHECK-NEXT: vld1.64 {d6, d7}, [r1] -; CHECK-NEXT: vuzp.16 q2, q0 -; CHECK-NEXT: vuzp.16 q3, q1 -; CHECK-NEXT: vadd.f16 q1, q1, q2 -; CHECK-NEXT: vsub.f16 q0, q3, q0 -; CHECK-NEXT: vzip.16 q0, q1 -; CHECK-NEXT: vst1.16 {d0, d1}, [r0:128]! 
-; CHECK-NEXT: vst1.64 {d2, d3}, [r0:128] +; CHECK-NEXT: vcadd.f16 q0, q1, q0, #90 +; CHECK-NEXT: vst1.64 {d0, d1}, [r0:128] ; CHECK-NEXT: bx lr entry: %strided.vec = shufflevector <16 x half> %a, <16 x half> zeroinitializer, <8 x i32> diff --git a/llvm/test/CodeGen/ARM/ComplexArithmetic/complex-arithmetic-f16-mul.ll b/llvm/test/CodeGen/ARM/ComplexArithmetic/complex-arithmetic-f16-mul.ll --- a/llvm/test/CodeGen/ARM/ComplexArithmetic/complex-arithmetic-f16-mul.ll +++ b/llvm/test/CodeGen/ARM/ComplexArithmetic/complex-arithmetic-f16-mul.ll @@ -68,19 +68,13 @@ ; CHECK-LABEL: complex_mul_v8f16: ; CHECK: @ %bb.0: @ %entry ; CHECK-NEXT: mov r12, sp -; CHECK-NEXT: vld1.64 {d0, d1}, [r12] -; CHECK-NEXT: vmov d16, r2, r3 -; CHECK-NEXT: vmov d17, r0, r1 -; CHECK-NEXT: vorr d18, d0, d0 -; CHECK-NEXT: vuzp.16 d17, d16 -; CHECK-NEXT: vuzp.16 d18, d1 -; CHECK-NEXT: vmul.f16 d3, d1, d17 -; CHECK-NEXT: vmul.f16 d0, d18, d17 -; CHECK-NEXT: vmla.f16 d3, d18, d16 -; CHECK-NEXT: vmls.f16 d0, d1, d16 -; CHECK-NEXT: vzip.16 d0, d3 -; CHECK-NEXT: vmov r0, r1, d0 -; CHECK-NEXT: vmov r2, r3, d3 +; CHECK-NEXT: vld1.64 {d2, d3}, [r12] +; CHECK-NEXT: vmov d1, r2, r3 +; CHECK-NEXT: vmov d0, r0, r1 +; CHECK-NEXT: vcmul.f16 q2, q0, q1, #0 +; CHECK-NEXT: vcmla.f16 q2, q0, q1, #90 +; CHECK-NEXT: vmov r0, r1, d4 +; CHECK-NEXT: vmov r2, r3, d5 ; CHECK-NEXT: bx lr entry: %strided.vec = shufflevector <8 x half> %a, <8 x half> poison, <4 x i32> @@ -100,27 +94,20 @@ define <16 x half> @complex_mul_v16f16(<16 x half> %a, <16 x half> %b) #0 { ; CHECK-LABEL: complex_mul_v16f16: ; CHECK: @ %bb.0: @ %entry -; CHECK-NEXT: .vsave {d8, d9} -; CHECK-NEXT: vpush {d8, d9} +; CHECK-NEXT: add r1, sp, #40 +; CHECK-NEXT: vld1.64 {d0, d1}, [r1] +; CHECK-NEXT: add r1, sp, #8 +; CHECK-NEXT: vld1.64 {d2, d3}, [r1] ; CHECK-NEXT: add r1, sp, #24 -; CHECK-NEXT: vldr d1, [sp, #16] +; CHECK-NEXT: vcmul.f16 q2, q1, q0, #0 +; CHECK-NEXT: vcmla.f16 q2, q1, q0, #90 +; CHECK-NEXT: vldr d1, [sp] ; CHECK-NEXT: vld1.64 {d2, d3}, [r1] -; CHECK-NEXT: add r1, sp, #56 ; CHECK-NEXT: vmov d0, r2, r3 -; CHECK-NEXT: vld1.64 {d4, d5}, [r1] -; CHECK-NEXT: add r1, sp, #40 -; CHECK-NEXT: vld1.64 {d6, d7}, [r1] -; CHECK-NEXT: vuzp.16 q0, q1 -; CHECK-NEXT: vuzp.16 q3, q2 -; CHECK-NEXT: vmul.f16 q4, q3, q1 -; CHECK-NEXT: vmul.f16 q1, q2, q1 -; CHECK-NEXT: vfma.f16 q4, q2, q0 -; CHECK-NEXT: vneg.f16 q1, q1 -; CHECK-NEXT: vfma.f16 q1, q3, q0 -; CHECK-NEXT: vzip.16 q1, q4 -; CHECK-NEXT: vst1.16 {d2, d3}, [r0:128]! -; CHECK-NEXT: vst1.64 {d8, d9}, [r0:128] -; CHECK-NEXT: vpop {d8, d9} +; CHECK-NEXT: vcmul.f16 q3, q0, q1, #0 +; CHECK-NEXT: vcmla.f16 q3, q0, q1, #90 +; CHECK-NEXT: vst1.16 {d6, d7}, [r0:128]! 
+; CHECK-NEXT: vst1.64 {d4, d5}, [r0:128] ; CHECK-NEXT: bx lr entry: %strided.vec = shufflevector <16 x half> %a, <16 x half> poison, <8 x i32> diff --git a/llvm/test/CodeGen/ARM/ComplexArithmetic/complex-arithmetic-f32-add.ll b/llvm/test/CodeGen/ARM/ComplexArithmetic/complex-arithmetic-f32-add.ll --- a/llvm/test/CodeGen/ARM/ComplexArithmetic/complex-arithmetic-f32-add.ll +++ b/llvm/test/CodeGen/ARM/ComplexArithmetic/complex-arithmetic-f32-add.ll @@ -27,18 +27,13 @@ define <4 x float> @complex_add_v4f32(<4 x float> %a, <4 x float> %b) #0 { ; CHECK-LABEL: complex_add_v4f32: ; CHECK: @ %bb.0: @ %entry -; CHECK-NEXT: mov r12, sp -; CHECK-NEXT: vld1.64 {d0, d1}, [r12] -; CHECK-NEXT: vorr d18, d0, d0 -; CHECK-NEXT: vmov d16, r2, r3 -; CHECK-NEXT: vmov d17, r0, r1 -; CHECK-NEXT: vtrn.32 d17, d16 -; CHECK-NEXT: vtrn.32 d18, d1 -; CHECK-NEXT: vadd.f32 d1, d1, d17 -; CHECK-NEXT: vsub.f32 d0, d18, d16 -; CHECK-NEXT: vtrn.32 d0, d1 -; CHECK-NEXT: vmov r0, r1, d0 -; CHECK-NEXT: vmov r2, r3, d1 +; CHECK-NEXT: vmov d0, r0, r1 +; CHECK-NEXT: mov r0, sp +; CHECK-NEXT: vld1.64 {d2, d3}, [r0] +; CHECK-NEXT: vmov d1, r2, r3 +; CHECK-NEXT: vcadd.f32 q2, q1, q0, #90 +; CHECK-NEXT: vmov r0, r1, d4 +; CHECK-NEXT: vmov r2, r3, d5 ; CHECK-NEXT: bx lr entry: %strided.vec = shufflevector <4 x float> %a, <4 x float> zeroinitializer, <2 x i32> @@ -53,21 +48,21 @@ define <8 x float> @complex_add_v8f32(<8 x float> %a, <8 x float> %b) #0 { ; CHECK-LABEL: complex_add_v8f32: ; CHECK: @ %bb.0: @ %entry -; CHECK-NEXT: add r1, sp, #8 -; CHECK-NEXT: vldr d5, [sp] +; CHECK-NEXT: .vsave {d8, d9} +; CHECK-NEXT: vpush {d8, d9} +; CHECK-NEXT: add r1, sp, #24 +; CHECK-NEXT: vldr d3, [sp, #16] ; CHECK-NEXT: vld1.64 {d0, d1}, [r1] +; CHECK-NEXT: add r1, sp, #56 +; CHECK-NEXT: vmov d2, r2, r3 +; CHECK-NEXT: vld1.64 {d4, d5}, [r1] ; CHECK-NEXT: add r1, sp, #40 -; CHECK-NEXT: vmov d4, r2, r3 -; CHECK-NEXT: vld1.64 {d2, d3}, [r1] -; CHECK-NEXT: add r1, sp, #24 ; CHECK-NEXT: vld1.64 {d6, d7}, [r1] -; CHECK-NEXT: vuzp.32 q2, q0 -; CHECK-NEXT: vuzp.32 q3, q1 -; CHECK-NEXT: vadd.f32 q1, q1, q2 -; CHECK-NEXT: vsub.f32 q0, q3, q0 -; CHECK-NEXT: vzip.32 q0, q1 -; CHECK-NEXT: vst1.32 {d0, d1}, [r0:128]! +; CHECK-NEXT: vcadd.f32 q4, q3, q1, #90 +; CHECK-NEXT: vcadd.f32 q1, q2, q0, #90 +; CHECK-NEXT: vst1.32 {d8, d9}, [r0:128]! 
; CHECK-NEXT: vst1.64 {d2, d3}, [r0:128] +; CHECK-NEXT: vpop {d8, d9} ; CHECK-NEXT: bx lr entry: %strided.vec = shufflevector <8 x float> %a, <8 x float> zeroinitializer, <4 x i32> diff --git a/llvm/test/CodeGen/ARM/ComplexArithmetic/complex-arithmetic-f32-mul.ll b/llvm/test/CodeGen/ARM/ComplexArithmetic/complex-arithmetic-f32-mul.ll --- a/llvm/test/CodeGen/ARM/ComplexArithmetic/complex-arithmetic-f32-mul.ll +++ b/llvm/test/CodeGen/ARM/ComplexArithmetic/complex-arithmetic-f32-mul.ll @@ -33,19 +33,13 @@ ; CHECK-LABEL: complex_mul_v4f32: ; CHECK: @ %bb.0: @ %entry ; CHECK-NEXT: mov r12, sp -; CHECK-NEXT: vld1.64 {d0, d1}, [r12] -; CHECK-NEXT: vmov d16, r2, r3 -; CHECK-NEXT: vmov d17, r0, r1 -; CHECK-NEXT: vorr d18, d0, d0 -; CHECK-NEXT: vtrn.32 d17, d16 -; CHECK-NEXT: vtrn.32 d18, d1 -; CHECK-NEXT: vmul.f32 d3, d1, d17 -; CHECK-NEXT: vmul.f32 d0, d18, d17 -; CHECK-NEXT: vmla.f32 d3, d18, d16 -; CHECK-NEXT: vmls.f32 d0, d1, d16 -; CHECK-NEXT: vtrn.32 d0, d3 -; CHECK-NEXT: vmov r0, r1, d0 -; CHECK-NEXT: vmov r2, r3, d3 +; CHECK-NEXT: vld1.64 {d2, d3}, [r12] +; CHECK-NEXT: vmov d1, r2, r3 +; CHECK-NEXT: vmov d0, r0, r1 +; CHECK-NEXT: vcmul.f32 q2, q0, q1, #0 +; CHECK-NEXT: vcmla.f32 q2, q0, q1, #90 +; CHECK-NEXT: vmov r0, r1, d4 +; CHECK-NEXT: vmov r2, r3, d5 ; CHECK-NEXT: bx lr entry: %strided.vec = shufflevector <4 x float> %a, <4 x float> poison, <2 x i32> @@ -67,24 +61,20 @@ ; CHECK: @ %bb.0: @ %entry ; CHECK-NEXT: .vsave {d8, d9} ; CHECK-NEXT: vpush {d8, d9} -; CHECK-NEXT: add r1, sp, #24 -; CHECK-NEXT: vldr d1, [sp, #16] -; CHECK-NEXT: vld1.64 {d4, d5}, [r1] ; CHECK-NEXT: add r1, sp, #56 -; CHECK-NEXT: vmov d0, r2, r3 -; CHECK-NEXT: vld1.64 {d2, d3}, [r1] +; CHECK-NEXT: vldr d3, [sp, #16] +; CHECK-NEXT: vld1.64 {d0, d1}, [r1] ; CHECK-NEXT: add r1, sp, #40 +; CHECK-NEXT: vmov d2, r2, r3 +; CHECK-NEXT: vld1.64 {d4, d5}, [r1] +; CHECK-NEXT: add r1, sp, #24 +; CHECK-NEXT: vcmul.f32 q4, q1, q2, #0 ; CHECK-NEXT: vld1.64 {d6, d7}, [r1] -; CHECK-NEXT: vuzp.32 q0, q2 -; CHECK-NEXT: vuzp.32 q3, q1 -; CHECK-NEXT: vmul.f32 q4, q1, q2 -; CHECK-NEXT: vmul.f32 q2, q3, q2 -; CHECK-NEXT: vneg.f32 q4, q4 -; CHECK-NEXT: vfma.f32 q2, q1, q0 -; CHECK-NEXT: vfma.f32 q4, q3, q0 -; CHECK-NEXT: vzip.32 q4, q2 +; CHECK-NEXT: vcmla.f32 q4, q1, q2, #90 +; CHECK-NEXT: vcmul.f32 q1, q3, q0, #0 ; CHECK-NEXT: vst1.32 {d8, d9}, [r0:128]! 
-; CHECK-NEXT: vst1.64 {d4, d5}, [r0:128] +; CHECK-NEXT: vcmla.f32 q1, q3, q0, #90 +; CHECK-NEXT: vst1.64 {d2, d3}, [r0:128] ; CHECK-NEXT: vpop {d8, d9} ; CHECK-NEXT: bx lr entry: diff --git a/llvm/test/CodeGen/ARM/ComplexArithmetic/complex-arithmetic-rotations-add.ll b/llvm/test/CodeGen/ARM/ComplexArithmetic/complex-arithmetic-rotations-add.ll --- a/llvm/test/CodeGen/ARM/ComplexArithmetic/complex-arithmetic-rotations-add.ll +++ b/llvm/test/CodeGen/ARM/ComplexArithmetic/complex-arithmetic-rotations-add.ll @@ -12,10 +12,10 @@ ; CHECK-NEXT: @ =>This Inner Loop Header: Depth=1 ; CHECK-NEXT: vld1.32 {d0, d1}, [r0] ; CHECK-NEXT: vld1.32 {d2, d3}, [r1] -; CHECK-NEXT: vuzp.32 q1, q0 -; CHECK-NEXT: vsub.f32 q8, q1, q0 -; CHECK-NEXT: vadd.f32 q9, q0, q1 -; CHECK-NEXT: vst2.32 {d16, d17, d18, d19}, [r1] +; CHECK-NEXT: vcadd.f32 q2, q0, q0, #90 +; CHECK-NEXT: vcadd.f32 q0, q1, q1, #90 +; CHECK-NEXT: vst1.32 {d4, d5}, [r0] +; CHECK-NEXT: vst1.32 {d0, d1}, [r1] ; CHECK-NEXT: b .LBB0_1 vector.ph: br label %vector.body @@ -44,10 +44,10 @@ ; CHECK-NEXT: @ =>This Inner Loop Header: Depth=1 ; CHECK-NEXT: vld1.32 {d0, d1}, [r0] ; CHECK-NEXT: vld1.32 {d2, d3}, [r1] -; CHECK-NEXT: vuzp.32 q1, q0 -; CHECK-NEXT: vadd.f32 q8, q0, q1 -; CHECK-NEXT: vsub.f32 q9, q0, q1 -; CHECK-NEXT: vst2.32 {d16, d17, d18, d19}, [r1] +; CHECK-NEXT: vcadd.f32 q2, q0, q0, #270 +; CHECK-NEXT: vcadd.f32 q0, q1, q1, #270 +; CHECK-NEXT: vst1.32 {d4, d5}, [r0] +; CHECK-NEXT: vst1.32 {d0, d1}, [r1] ; CHECK-NEXT: b .LBB1_1 vector.ph: br label %vector.body diff --git a/llvm/test/CodeGen/ARM/ComplexArithmetic/complex-arithmetic-rotations-mul.ll b/llvm/test/CodeGen/ARM/ComplexArithmetic/complex-arithmetic-rotations-mul.ll --- a/llvm/test/CodeGen/ARM/ComplexArithmetic/complex-arithmetic-rotations-mul.ll +++ b/llvm/test/CodeGen/ARM/ComplexArithmetic/complex-arithmetic-rotations-mul.ll @@ -12,13 +12,12 @@ ; CHECK-NEXT: @ =>This Inner Loop Header: Depth=1 ; CHECK-NEXT: vld1.32 {d0, d1}, [r0] ; CHECK-NEXT: vld1.32 {d2, d3}, [r1] -; CHECK-NEXT: vuzp.32 q1, q0 -; CHECK-NEXT: vmul.f32 q2, q0, q0 -; CHECK-NEXT: vmul.f32 q9, q1, q0 -; CHECK-NEXT: vneg.f32 q8, q2 -; CHECK-NEXT: vfma.f32 q9, q0, q1 -; CHECK-NEXT: vfma.f32 q8, q1, q1 -; CHECK-NEXT: vst2.32 {d16, d17, d18, d19}, [r1] +; CHECK-NEXT: vcmul.f32 q2, q0, q0, #0 +; CHECK-NEXT: vcmla.f32 q2, q0, q0, #90 +; CHECK-NEXT: vcmul.f32 q0, q1, q1, #0 +; CHECK-NEXT: vst1.32 {d4, d5}, [r0] +; CHECK-NEXT: vcmla.f32 q0, q1, q1, #90 +; CHECK-NEXT: vst1.32 {d0, d1}, [r1] ; CHECK-NEXT: b .LBB0_1 vector.ph: br label %vector.body @@ -93,14 +92,12 @@ ; CHECK-NEXT: @ =>This Inner Loop Header: Depth=1 ; CHECK-NEXT: vld1.32 {d0, d1}, [r0] ; CHECK-NEXT: vld1.32 {d2, d3}, [r1] -; CHECK-NEXT: vuzp.32 q1, q0 -; CHECK-NEXT: vmul.f32 q3, q1, q1 -; CHECK-NEXT: vmul.f32 q2, q0, q1 -; CHECK-NEXT: vneg.f32 q8, q3 -; CHECK-NEXT: vfma.f32 q2, q1, q0 -; CHECK-NEXT: vfma.f32 q8, q0, q0 -; CHECK-NEXT: vneg.f32 q9, q2 -; CHECK-NEXT: vst2.32 {d16, d17, d18, d19}, [r1] +; CHECK-NEXT: vcmul.f32 q2, q0, q0, #180 +; CHECK-NEXT: vcmla.f32 q2, q0, q0, #270 +; CHECK-NEXT: vcmul.f32 q0, q1, q1, #180 +; CHECK-NEXT: vst1.32 {d4, d5}, [r0] +; CHECK-NEXT: vcmla.f32 q0, q1, q1, #270 +; CHECK-NEXT: vst1.32 {d0, d1}, [r1] ; CHECK-NEXT: b .LBB2_1 vector.ph: br label %vector.body diff --git a/llvm/test/CodeGen/ARM/O3-pipeline.ll b/llvm/test/CodeGen/ARM/O3-pipeline.ll --- a/llvm/test/CodeGen/ARM/O3-pipeline.ll +++ b/llvm/test/CodeGen/ARM/O3-pipeline.ll @@ -46,6 +46,7 @@ ; CHECK-NEXT: Basic Alias Analysis (stateless AA impl) ; CHECK-NEXT: Function 
Alias Analysis Results
 ; CHECK-NEXT: Transform functions to use DSP intrinsics
+; CHECK-NEXT: Complex Deinterleaving Pass
 ; CHECK-NEXT: Interleaved Access Pass
 ; CHECK-NEXT: Type Promotion
 ; CHECK-NEXT: CodeGen Prepare
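For reviewers porting this to another backend, a minimal sketch of the two TargetLowering hooks is shown below. `MyTargetLowering` and the 256-bit width are assumptions for illustration; only the hook signatures and the ComplexDeinterleavingSupport fields come from this patch, and the intrinsic creation is elided because it is target-specific.

    #include "llvm/CodeGen/ComplexDeinterleavingPass.h"
    #include "llvm/IR/IRBuilder.h"
    using namespace llvm;

    // Advertise support: the pass only runs when SupportedOnTarget is true,
    // and only forms the operations whose flags are set.
    ComplexDeinterleavingSupport
    MyTargetLowering::getComplexDeinterleavingSupport() const {
      ComplexDeinterleavingSupport S;
      S.MaxVectorWidth = 256; // Hypothetical widest legal vector, in bits.
      S.FPAdd = true;         // Form complex additions (vcadd-style).
      S.FPMul = true;         // Form full complex multiplies (vcmul-style).
      return S;
    }

    // Emit the target operation for one matched pattern, or nullptr to reject.
    Value *MyTargetLowering::createComplexDeinterleavingIR(
        Instruction *I, ComplexDeinterleavingOperation OperationType,
        unsigned Rotation, Value *InputA, Value *InputB) const {
      IRBuilder<> B(I);
      if (OperationType == ComplexDeinterleavingOperation::CAdd &&
          (Rotation == 90 || Rotation == 270)) {
        // B.CreateIntrinsic(<target complex-add intrinsic>, ...) goes here.
      }
      return nullptr; // Reject anything the target cannot lower.
    }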