diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h --- a/llvm/include/llvm/CodeGen/TargetLowering.h +++ b/llvm/include/llvm/CodeGen/TargetLowering.h @@ -3196,7 +3196,7 @@ /// If one cannot be created using all the given inputs, nullptr should be /// returned. virtual Value *createComplexDeinterleavingIR( - Instruction *I, ComplexDeinterleavingOperation OperationType, + IRBuilderBase &B, ComplexDeinterleavingOperation OperationType, ComplexDeinterleavingRotation Rotation, Value *InputA, Value *InputB, Value *Accumulator = nullptr) const { return nullptr; diff --git a/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp b/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp --- a/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp +++ b/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp @@ -267,7 +267,7 @@ /// intrinsic (for both fixed and scalable vectors) NodePtr identifyDeinterleave(Instruction *Real, Instruction *Imag); - Value *replaceNode(RawNodePtr Node); + Value *replaceNode(IRBuilderBase &Builder, RawNodePtr Node); public: void dump() { dump(dbgs()); } @@ -1011,7 +1011,8 @@ return submitCompositeNode(PlaceholderNode); } -static Value *replaceSymmetricNode(ComplexDeinterleavingGraph::RawNodePtr Node, +static Value *replaceSymmetricNode(IRBuilderBase &B, + ComplexDeinterleavingGraph::RawNodePtr Node, Value *InputA, Value *InputB) { Instruction *I = Node->Real; if (I->isUnaryOp()) @@ -1021,8 +1022,6 @@ assert(InputB && "Binary symmetric operations need two inputs, only one " "was provided."); - IRBuilder<> B(I); - switch (I->getOpcode()) { case Instruction::FNeg: return B.CreateFNegFMF(InputA, I); @@ -1037,27 +1036,28 @@ return nullptr; } -Value *ComplexDeinterleavingGraph::replaceNode( - ComplexDeinterleavingGraph::RawNodePtr Node) { +Value *ComplexDeinterleavingGraph::replaceNode(IRBuilderBase &Builder, + RawNodePtr Node) { if (Node->ReplacementNode) return Node->ReplacementNode; - Value *Input0 = replaceNode(Node->Operands[0]); - 
Value *Input1 = - Node->Operands.size() > 1 ? replaceNode(Node->Operands[1]) : nullptr; - Value *Accumulator = - Node->Operands.size() > 2 ? replaceNode(Node->Operands[2]) : nullptr; + Value *Input0 = replaceNode(Builder, Node->Operands[0]); + Value *Input1 = Node->Operands.size() > 1 + ? replaceNode(Builder, Node->Operands[1]) + : nullptr; + Value *Accumulator = Node->Operands.size() > 2 + ? replaceNode(Builder, Node->Operands[2]) + : nullptr; if (Input1) assert(Input0->getType() == Input1->getType() && "Node inputs need to be of the same type"); if (Node->Operation == ComplexDeinterleavingOperation::Symmetric) - Node->ReplacementNode = replaceSymmetricNode(Node, Input0, Input1); + Node->ReplacementNode = replaceSymmetricNode(Builder, Node, Input0, Input1); else Node->ReplacementNode = TL->createComplexDeinterleavingIR( - Node->Real, Node->Operation, Node->Rotation, Input0, Input1, - Accumulator); + Builder, Node->Operation, Node->Rotation, Input0, Input1, Accumulator); assert(Node->ReplacementNode && "Target failed to create Intrinsic call."); NumComplexTransformations += 1; @@ -1074,7 +1074,7 @@ IRBuilder<> Builder(RootInstruction); auto RootNode = RootToNode[RootInstruction]; - Value *R = replaceNode(RootNode.get()); + Value *R = replaceNode(Builder, RootNode.get()); assert(R && "Unable to find replacement for RootInstruction"); DeadInstrRoots.push_back(RootInstruction); RootInstruction->replaceAllUsesWith(R); diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h @@ -843,7 +843,7 @@ ComplexDeinterleavingOperation Operation, Type *Ty) const override; Value *createComplexDeinterleavingIR( - Instruction *I, ComplexDeinterleavingOperation OperationType, + IRBuilderBase &B, ComplexDeinterleavingOperation OperationType, ComplexDeinterleavingRotation Rotation, Value *InputA, Value *InputB, Value *Accumulator = 
nullptr) const override; diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -25286,14 +25286,12 @@ } Value *AArch64TargetLowering::createComplexDeinterleavingIR( - Instruction *I, ComplexDeinterleavingOperation OperationType, + IRBuilderBase &B, ComplexDeinterleavingOperation OperationType, ComplexDeinterleavingRotation Rotation, Value *InputA, Value *InputB, Value *Accumulator) const { VectorType *Ty = cast<VectorType>(InputA->getType()); bool IsScalable = Ty->isScalableTy(); - IRBuilder<> B(I); - unsigned TyWidth = Ty->getScalarSizeInBits() * Ty->getElementCount().getKnownMinValue(); @@ -25317,9 +25315,9 @@ B.CreateExtractVector(HalfTy, Accumulator, B.getInt64(Stride)); } auto *LowerSplitInt = createComplexDeinterleavingIR( - I, OperationType, Rotation, LowerSplitA, LowerSplitB, LowerSplitAcc); + B, OperationType, Rotation, LowerSplitA, LowerSplitB, LowerSplitAcc); auto *UpperSplitInt = createComplexDeinterleavingIR( - I, OperationType, Rotation, UpperSplitA, UpperSplitB, UpperSplitAcc); + B, OperationType, Rotation, UpperSplitA, UpperSplitB, UpperSplitAcc); auto *Result = B.CreateInsertVector(Ty, PoisonValue::get(Ty), LowerSplitInt, B.getInt64(0)); diff --git a/llvm/lib/Target/ARM/ARMISelLowering.h b/llvm/lib/Target/ARM/ARMISelLowering.h --- a/llvm/lib/Target/ARM/ARMISelLowering.h +++ b/llvm/lib/Target/ARM/ARMISelLowering.h @@ -750,7 +750,7 @@ ComplexDeinterleavingOperation Operation, Type *Ty) const override; Value *createComplexDeinterleavingIR( - Instruction *I, ComplexDeinterleavingOperation OperationType, + IRBuilderBase &B, ComplexDeinterleavingOperation OperationType, ComplexDeinterleavingRotation Rotation, Value *InputA, Value *InputB, Value *Accumulator = nullptr) const override; diff --git a/llvm/lib/Target/ARM/ARMISelLowering.cpp b/llvm/lib/Target/ARM/ARMISelLowering.cpp --- 
a/llvm/lib/Target/ARM/ARMISelLowering.cpp +++ b/llvm/lib/Target/ARM/ARMISelLowering.cpp @@ -22060,14 +22060,12 @@ } Value *ARMTargetLowering::createComplexDeinterleavingIR( - Instruction *I, ComplexDeinterleavingOperation OperationType, + IRBuilderBase &B, ComplexDeinterleavingOperation OperationType, ComplexDeinterleavingRotation Rotation, Value *InputA, Value *InputB, Value *Accumulator) const { FixedVectorType *Ty = cast<FixedVectorType>(InputA->getType()); - IRBuilder<> B(I); - unsigned TyWidth = Ty->getScalarSizeInBits() * Ty->getNumElements(); assert(TyWidth >= 128 && "Width of vector type must be at least 128 bits"); @@ -22092,9 +22090,9 @@ } auto *LowerSplitInt = createComplexDeinterleavingIR( - I, OperationType, Rotation, LowerSplitA, LowerSplitB, LowerSplitAcc); + B, OperationType, Rotation, LowerSplitA, LowerSplitB, LowerSplitAcc); auto *UpperSplitInt = createComplexDeinterleavingIR( - I, OperationType, Rotation, UpperSplitA, UpperSplitB, UpperSplitAcc); + B, OperationType, Rotation, UpperSplitA, UpperSplitB, UpperSplitAcc); ArrayRef<int> JoinMask(&SplitSeqVec[0], Ty->getNumElements()); return B.CreateShuffleVector(LowerSplitInt, UpperSplitInt, JoinMask); diff --git a/llvm/test/CodeGen/AArch64/complex-deinterleaving-mixed-cases.ll b/llvm/test/CodeGen/AArch64/complex-deinterleaving-mixed-cases.ll --- a/llvm/test/CodeGen/AArch64/complex-deinterleaving-mixed-cases.ll +++ b/llvm/test/CodeGen/AArch64/complex-deinterleaving-mixed-cases.ll @@ -220,11 +220,11 @@ ; CHECK: // %bb.0: // %entry ; CHECK-NEXT: movi v3.2d, #0000000000000000 ; CHECK-NEXT: movi v4.2d, #0000000000000000 -; CHECK-NEXT: fcmla v3.4s, v1.4s, v0.4s, #0 -; CHECK-NEXT: fcmla v4.4s, v2.4s, v0.4s, #0 -; CHECK-NEXT: fcmla v3.4s, v1.4s, v0.4s, #90 -; CHECK-NEXT: fcmla v4.4s, v2.4s, v0.4s, #90 -; CHECK-NEXT: fcadd v0.4s, v4.4s, v3.4s, #90 +; CHECK-NEXT: fcmla v3.4s, v2.4s, v0.4s, #0 +; CHECK-NEXT: fcmla v4.4s, v1.4s, v0.4s, #0 +; CHECK-NEXT: fcmla v3.4s, v2.4s, v0.4s, #90 +; CHECK-NEXT: fcmla v4.4s, v1.4s, v0.4s, #90 +; 
CHECK-NEXT: fcadd v0.4s, v3.4s, v4.4s, #90 ; CHECK-NEXT: ret entry: %ar = shufflevector <4 x float> %a, <4 x float> poison, <2 x i32>