diff --git a/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp b/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
--- a/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
+++ b/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
@@ -210,6 +210,10 @@
   }
 
   NodePtr submitCompositeNode(NodePtr Node) {
+    LLVM_DEBUG({
+      dbgs() << "Node Submitted:\n";
+      Node->dump(dbgs());
+    });
     CompositeNodes.push_back(Node);
     AllInstructions.insert(Node->Real);
     AllInstructions.insert(Node->Imag);
@@ -460,6 +464,7 @@
     return nullptr;
   }
 
+  LLVM_DEBUG(dbgs() << "Negs: " << Negs << ".\n");
   ComplexDeinterleavingRotation Rotation = (ComplexDeinterleavingRotation)Negs;
   Instruction *CommonOperand;
 
@@ -633,27 +638,39 @@
 ComplexDeinterleavingGraph::identifyAdd(Instruction *Real, Instruction *Imag) {
   LLVM_DEBUG(dbgs() << "identifyAdd " << *Real << " / " << *Imag << "\n");
 
-  // Determine rotation
-  ComplexDeinterleavingRotation Rotation;
-  if ((Real->getOpcode() == Instruction::FSub &&
-       Imag->getOpcode() == Instruction::FAdd) ||
-      (Real->getOpcode() == Instruction::Sub &&
-       Imag->getOpcode() == Instruction::Add))
-    Rotation = ComplexDeinterleavingRotation::Rotation_90;
-  else if ((Real->getOpcode() == Instruction::FAdd &&
-            Imag->getOpcode() == Instruction::FSub) ||
-           (Real->getOpcode() == Instruction::Add &&
-            Imag->getOpcode() == Instruction::Sub))
-    Rotation = ComplexDeinterleavingRotation::Rotation_270;
-  else {
-    LLVM_DEBUG(dbgs() << " - Unhandled case, rotation is not assigned.\n");
+  bool RealIsSub = Real->getOpcode() == Instruction::FSub ||
+                   Real->getOpcode() == Instruction::Sub;
+  bool ImagIsSub = Imag->getOpcode() == Instruction::FSub ||
+                   Imag->getOpcode() == Instruction::Sub;
+
+  // Determine rotation: add/add -> 0, sub/add -> 90, sub/sub -> 180 and
+  // add/sub -> 270. The cast below relies on ComplexDeinterleavingRotation
+  // numbering Rotation_0 .. Rotation_270 as 0 .. 3.
+  unsigned RotKey = RealIsSub ? (ImagIsSub ? 2 : 1) : (ImagIsSub ? 3 : 0);
+  ComplexDeinterleavingRotation Rotation =
+      (ComplexDeinterleavingRotation)RotKey;
+
+  LLVM_DEBUG(dbgs() << " - RotKey: " << RotKey << ".\n");
+
+  if (Rotation == ComplexDeinterleavingRotation::Rotation_180) {
+    LLVM_DEBUG(dbgs() << " - Unsupported rotation: 180.\n");
     return nullptr;
   }
 
   auto *AR = dyn_cast<Instruction>(Real->getOperand(0));
-  auto *BI = dyn_cast<Instruction>(Real->getOperand(1));
   auto *AI = dyn_cast<Instruction>(Imag->getOperand(0));
-  auto *BR = dyn_cast<Instruction>(Imag->getOperand(1));
+  // Operand 1 of Real/Imag holds B's matching part for rotation 0, while for
+  // rotations 90 and 270 B's parts are crossed over (e.g. rotation 90:
+  // Real = AR - BI, Imag = AI + BR).
+  Instruction *BR;
+  Instruction *BI;
+  if (Rotation == ComplexDeinterleavingRotation::Rotation_0) {
+    BR = dyn_cast<Instruction>(Real->getOperand(1));
+    BI = dyn_cast<Instruction>(Imag->getOperand(1));
+  } else {
+    BI = dyn_cast<Instruction>(Real->getOperand(1));
+    BR = dyn_cast<Instruction>(Imag->getOperand(1));
+  }
 
   if (!AR || !AI || !BR || !BI) {
     LLVM_DEBUG(dbgs() << " - Not all operands are instructions.\n");
@@ -683,10 +700,12 @@
   unsigned OpcA = A->getOpcode();
   unsigned OpcB = B->getOpcode();
 
-  return (OpcA == Instruction::FSub && OpcB == Instruction::FAdd) ||
-         (OpcA == Instruction::FAdd && OpcB == Instruction::FSub) ||
-         (OpcA == Instruction::Sub && OpcB == Instruction::Add) ||
-         (OpcA == Instruction::Add && OpcB == Instruction::Sub);
+  // Any add/sub pairing is accepted here; sub/sub (rotation 180) is rejected
+  auto IsAddOrSub = [](unsigned Opc) {
+    return Opc == Instruction::FSub || Opc == Instruction::FAdd ||
+           Opc == Instruction::Sub || Opc == Instruction::Add;
+  };
+  return IsAddOrSub(OpcA) && IsAddOrSub(OpcB);
 }
 
 static bool isInstructionPairMul(Instruction *A, Instruction *B) {
@@ -788,8 +807,10 @@
     PlaceholderNode->ReplacementNode = RealShuffle->getOperand(0);
     return submitCompositeNode(PlaceholderNode);
   }
-  if (RealShuffle || ImagShuffle)
+  if (RealShuffle || ImagShuffle) {
+    LLVM_DEBUG(dbgs() << " - There's a shuffle where there shouldn't be.\n");
     return nullptr;
+  }
 
   auto *VTy = cast<FixedVectorType>(Real->getType());
   auto *NewVTy =
@@ -807,6 +828,7 @@
     return identifyAdd(Real, Imag);
   }
 
+  LLVM_DEBUG(dbgs() << " - Not recognised as a valid pattern.\n");
   return nullptr;
 }
 
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -23875,6 +23875,13 @@
   }
 
   if (OperationType == ComplexDeinterleavingOperation::CAdd) {
+    // Rotation 0 is a plain element-wise add of the interleaved vectors.
+    if (Rotation == ComplexDeinterleavingRotation::Rotation_0) {
+      if (InputA->getType()->isFPOrFPVectorTy())
+        return B.CreateFAdd(InputA, InputB);
+      return B.CreateAdd(InputA, InputB);
+    }
+
     Intrinsic::ID IntId = Intrinsic::not_intrinsic;
     if (Rotation == ComplexDeinterleavingRotation::Rotation_90)
       IntId = Intrinsic::aarch64_neon_vcadd_rot90;
diff --git a/llvm/lib/Target/ARM/ARMISelLowering.cpp b/llvm/lib/Target/ARM/ARMISelLowering.cpp
--- a/llvm/lib/Target/ARM/ARMISelLowering.cpp
+++ b/llvm/lib/Target/ARM/ARMISelLowering.cpp
@@ -21977,6 +21977,13 @@
   }
 
   if (OperationType == ComplexDeinterleavingOperation::CAdd) {
+    // Rotation 0 is a plain element-wise add of the interleaved vectors.
+    if (Rotation == ComplexDeinterleavingRotation::Rotation_0) {
+      if (InputA->getType()->isFPOrFPVectorTy())
+        return B.CreateFAdd(InputA, InputB);
+      return B.CreateAdd(InputA, InputB);
+    }
+
     // 1 means the value is not halved.
     auto *ConstHalving = ConstantInt::get(IntTy, 1);
 
diff --git a/llvm/test/CodeGen/AArch64/complex-deinterleaving-mixed-cases.ll b/llvm/test/CodeGen/AArch64/complex-deinterleaving-mixed-cases.ll
--- a/llvm/test/CodeGen/AArch64/complex-deinterleaving-mixed-cases.ll
+++ b/llvm/test/CodeGen/AArch64/complex-deinterleaving-mixed-cases.ll
@@ -366,24 +366,10 @@
 define <4 x float> @mul_add_rot0(<4 x float> %a, <4 x float> %b, <4 x float> %c) {
 ; CHECK-LABEL: mul_add_rot0:
 ; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    ext v3.16b, v1.16b, v1.16b, #8
-; CHECK-NEXT:    ext v4.16b, v0.16b, v0.16b, #8
-; CHECK-NEXT:    ext v7.16b, v2.16b, v2.16b, #8
-; CHECK-NEXT:    zip2 v5.2s, v1.2s, v3.2s
-; CHECK-NEXT:    zip1 v1.2s, v1.2s, v3.2s
-; CHECK-NEXT:    zip2 v6.2s, v0.2s, v4.2s
-; CHECK-NEXT:    zip1 v0.2s, v0.2s, v4.2s
-; CHECK-NEXT:    zip2 v4.2s, v2.2s, v7.2s
-; CHECK-NEXT:    fmul v16.2s, v6.2s, v5.2s
-; CHECK-NEXT:    fmla v4.2s, v0.2s, v5.2s
-; CHECK-NEXT:    fneg v3.2s, v16.2s
-; CHECK-NEXT:    fmla v4.2s, v6.2s, v1.2s
-; CHECK-NEXT:    fmla v3.2s, v0.2s, v1.2s
-; CHECK-NEXT:    zip1 v0.2s, v2.2s, v7.2s
-; CHECK-NEXT:    fadd v0.2s, v3.2s, v0.2s
-; CHECK-NEXT:    zip2 v1.2s, v0.2s, v4.2s
-; CHECK-NEXT:    zip1 v0.2s, v0.2s, v4.2s
-; CHECK-NEXT:    mov v0.d[1], v1.d[0]
+; CHECK-NEXT:    movi v3.2d, #0000000000000000
+; CHECK-NEXT:    fcmla v3.4s, v0.4s, v1.4s, #0
+; CHECK-NEXT:    fcmla v3.4s, v0.4s, v1.4s, #90
+; CHECK-NEXT:    fadd v0.4s, v3.4s, v2.4s
 ; CHECK-NEXT:    ret
 entry:
   %strided.vec = shufflevector <4 x float> %a, <4 x float> poison, <2 x i32> <i32 0, i32 2>
diff --git a/llvm/test/CodeGen/Thumb2/mve-complex-deinterleaving-mixed-cases.ll b/llvm/test/CodeGen/Thumb2/mve-complex-deinterleaving-mixed-cases.ll
--- a/llvm/test/CodeGen/Thumb2/mve-complex-deinterleaving-mixed-cases.ll
+++ b/llvm/test/CodeGen/Thumb2/mve-complex-deinterleaving-mixed-cases.ll
@@ -390,32 +390,9 @@
 define arm_aapcs_vfpcc <4 x float> @mul_add_rot0(<4 x float> %a, <4 x float> %b, <4 x float> %c) {
 ; CHECK-LABEL: mul_add_rot0:
 ; CHECK:       @ %bb.0: @ %entry
-; CHECK-NEXT:    .vsave {d10}
-; CHECK-NEXT:    vpush {d10}
-; CHECK-NEXT:    .vsave {d8}
-; CHECK-NEXT:    vpush {d8}
-; CHECK-NEXT:    vmov.f32 s16, s0
-; CHECK-NEXT:    vmov.f32 s20, s5
-; CHECK-NEXT:    vmov.f32 s12, s9
-; CHECK-NEXT:    vmov.f32 s17, s2
-; CHECK-NEXT:    vmov.f32 s21, s7
-; CHECK-NEXT:    vmov.f32 s13, s11
-; CHECK-NEXT:    vmov.f32 s0, s1
-; CHECK-NEXT:    vfma.f32 q3, q5, q4
-; CHECK-NEXT:    vmov.f32 s1, s3
-; CHECK-NEXT:    vmov.f32 s5, s6
-; CHECK-NEXT:    vfma.f32 q3, q1, q0
-; CHECK-NEXT:    vmul.f32 q0, q0, q5
-; CHECK-NEXT:    vneg.f32 q0, q0
-; CHECK-NEXT:    vmov.f32 s9, s10
-; CHECK-NEXT:    vfma.f32 q0, q1, q4
-; CHECK-NEXT:    vadd.f32 q1, q0, q2
-; CHECK-NEXT:    vmov.f32 s1, s12
-; CHECK-NEXT:    vmov.f32 s0, s4
-; CHECK-NEXT:    vmov.f32 s2, s5
-; CHECK-NEXT:    vmov.f32 s3, s13
-; CHECK-NEXT:    vpop {d8}
-; CHECK-NEXT:    vpop {d10}
+; CHECK-NEXT:    vcmul.f32 q3, q0, q1, #0
+; CHECK-NEXT:    vcmla.f32 q3, q0, q1, #90
+; CHECK-NEXT:    vadd.f32 q0, q3, q2
 ; CHECK-NEXT:    bx lr
 entry:
   %strided.vec = shufflevector <4 x float> %a, <4 x float> poison, <2 x i32> <i32 0, i32 2>