[ComplexDeinterleaving] Add Passthrough nodes for element-wise operations

Teach the complex deinterleaving graph to look through fadd, fsub, fmul and
fneg instructions that apply the same operation to the real and imaginary
halves, re-emitting that operation on the interleaved vectors. Targets that
implement createComplexDeinterleavingIR must now cope with a null InputB for
unary passthroughs.

Review fixes relative to the previous revision of this patch:
 * identifyPassthrough rejects Real/Imag pairs with different opcodes. The
   replacement is built from Real's opcode alone, so e.g. (fadd, fsub) would
   have been rewritten as a single fadd, computing the wrong imaginary part.
 * ARM: in the >128-bit split path, guard both InputB shuffles on InputB (as
   AArch64 already does). Previously LowerSplitB was created unconditionally,
   passing a null InputB to CreateShuffleVector for unary passthroughs such
   as fneg, while UpperSplitA (always needed) was only created when InputB
   was non-null.
 * Passthrough is declared above the "internal states" comment, since
   backends do have to lower it.
 * replaceNode folds the conditional type check into a single assert, and
   the invalid-internal-uses debug message prints the instruction rather
   than a pointer value.
 * identifyAdd keeps `auto *` for its dyn_casts and documents the RotKey
   encoding.

diff --git a/llvm/include/llvm/CodeGen/ComplexDeinterleavingPass.h b/llvm/include/llvm/CodeGen/ComplexDeinterleavingPass.h
--- a/llvm/include/llvm/CodeGen/ComplexDeinterleavingPass.h
+++ b/llvm/include/llvm/CodeGen/ComplexDeinterleavingPass.h
@@ -36,6 +36,9 @@
 enum class ComplexDeinterleavingOperation {
   CAdd,
   CMulPartial,
+  // Applies the same element-wise operation (e.g. fadd, fneg) to the real and
+  // imaginary parts; backends re-emit it on the interleaved vectors.
+  Passthrough,
   // The following 'operations' are used to represent internal states. Backends
   // are not expected to try and support these in any capacity.
   Shuffle
diff --git a/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp b/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
--- a/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
+++ b/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
@@ -254,6 +254,11 @@
   /// 270: r: ar + bi
   ///      i: ai - br
   NodePtr identifyAdd(Instruction *Real, Instruction *Imag);
+
+  /// Identifies an element-wise operation applied identically to the real
+  /// and imaginary parts (e.g. fadd/fadd, fneg/fneg) whose operands are
+  /// themselves identifiable complex nodes.
+  NodePtr identifyPassthrough(Instruction *Real, Instruction *Imag);
 
   NodePtr identifyNode(Instruction *I, Instruction *J);
 
@@ -598,8 +603,16 @@
        Rotation == ComplexDeinterleavingRotation::Rotation_270)
           ? CommonOperand
           : nullptr);
-  NodePtr CNode = identifyNodeWithImplicitAdd(
-      cast<Instruction>(CR), cast<Instruction>(CI), PartialMatch);
+
+  auto *CRInst = dyn_cast<Instruction>(CR);
+  auto *CIInst = dyn_cast<Instruction>(CI);
+
+  if (!CRInst || !CIInst) {
+    LLVM_DEBUG(dbgs() << "  - Common operands are not instructions.\n");
+    return nullptr;
+  }
+
+  NodePtr CNode = identifyNodeWithImplicitAdd(CRInst, CIInst, PartialMatch);
   if (!CNode) {
     LLVM_DEBUG(dbgs() << "  - No cnode identified\n");
     return nullptr;
@@ -633,20 +646,26 @@
 ComplexDeinterleavingGraph::identifyAdd(Instruction *Real, Instruction *Imag) {
   LLVM_DEBUG(dbgs() << "identifyAdd " << *Real << " / " << *Imag << "\n");
 
-  // Determine rotation
-  ComplexDeinterleavingRotation Rotation;
-  if ((Real->getOpcode() == Instruction::FSub &&
-       Imag->getOpcode() == Instruction::FAdd) ||
-      (Real->getOpcode() == Instruction::Sub &&
-       Imag->getOpcode() == Instruction::Add))
-    Rotation = ComplexDeinterleavingRotation::Rotation_90;
-  else if ((Real->getOpcode() == Instruction::FAdd &&
-            Imag->getOpcode() == Instruction::FSub) ||
-           (Real->getOpcode() == Instruction::Add &&
-            Imag->getOpcode() == Instruction::Sub))
-    Rotation = ComplexDeinterleavingRotation::Rotation_270;
-  else {
-    LLVM_DEBUG(dbgs() << "  - Unhandled case, rotation is not assigned.\n");
+  // Build a 2-bit key from which halves subtract and cast it to the rotation.
+  // Real/Imag: add/add -> Rotation_0, sub/add -> Rotation_90,
+  // sub/sub -> Rotation_180, add/sub -> Rotation_270.
+  unsigned RotKey = 0;
+  RotKey |= Real->getOpcode() == Instruction::FSub ||
+            Real->getOpcode() == Instruction::Sub;
+  RotKey |= ((Imag->getOpcode() == Instruction::FSub ||
+              Imag->getOpcode() == Instruction::Sub)
+             << 1);
+  if ((RotKey & 2) == 2)
+    RotKey ^= 1;
+
+  ComplexDeinterleavingRotation Rotation =
+      (ComplexDeinterleavingRotation)RotKey;
+
+  LLVM_DEBUG(dbgs() << "  - RotKey: " << RotKey << ".\n");
+
+  if (Rotation == llvm::ComplexDeinterleavingRotation::Rotation_0 ||
+      Rotation == llvm::ComplexDeinterleavingRotation::Rotation_180) {
+    LLVM_DEBUG(dbgs() << "  - Unsupported rotation.\n");
     return nullptr;
   }
 
@@ -696,6 +715,77 @@
   return match(A, Pattern) && match(B, Pattern);
 }
 
+/// Returns true if \p I is an element-wise operation (currently fadd, fsub,
+/// fmul or fneg) that can be re-emitted unchanged on interleaved vectors.
+static bool isInstructionValidForPassthrough(Instruction *I) {
+  // Must be either a unary or binary op
+  if (!I->isUnaryOp() && !I->isBinaryOp())
+    return false;
+
+  // Must not have any side effects, or touch memory
+  if (I->mayHaveSideEffects() || I->mayReadOrWriteMemory())
+    return false;
+
+  // Must not throw exceptions
+  if (I->mayThrow())
+    return false;
+
+  switch (I->getOpcode()) {
+  case Instruction::FAdd:
+  case Instruction::FSub:
+  case Instruction::FMul:
+  case Instruction::FNeg:
+    return true;
+  }
+
+  // TODO ask target for valid passthrough instructions
+
+  return false;
+}
+
+ComplexDeinterleavingGraph::NodePtr
+ComplexDeinterleavingGraph::identifyPassthrough(Instruction *Real,
+                                                Instruction *Imag) {
+  if (!isInstructionValidForPassthrough(Real) ||
+      !isInstructionValidForPassthrough(Imag))
+    return nullptr;
+
+  // The replacement is re-emitted using Real's opcode only, so both halves
+  // must perform the same operation (and hence have the same arity).
+  if (Real->getOpcode() != Imag->getOpcode())
+    return nullptr;
+
+  auto *R0 = dyn_cast<Instruction>(Real->getOperand(0));
+  auto *I0 = dyn_cast<Instruction>(Imag->getOperand(0));
+
+  if (!R0 || !I0)
+    return nullptr;
+
+  NodePtr Op0 = identifyNode(R0, I0);
+  NodePtr Op1 = nullptr;
+  if (Op0 == nullptr)
+    return nullptr;
+
+  if (Real->isBinaryOp()) {
+    auto *R1 = dyn_cast<Instruction>(Real->getOperand(1));
+    auto *I1 = dyn_cast<Instruction>(Imag->getOperand(1));
+    if (!R1 || !I1)
+      return nullptr;
+
+    Op1 = identifyNode(R1, I1);
+    if (Op1 == nullptr)
+      return nullptr;
+  }
+
+  auto Node = prepareCompositeNode(ComplexDeinterleavingOperation::Passthrough,
+                                   Real, Imag);
+  Node->addOperand(Op0);
+  if (Real->isBinaryOp())
+    Node->addOperand(Op1);
+
+  return submitCompositeNode(Node);
+}
+
 ComplexDeinterleavingGraph::NodePtr
 ComplexDeinterleavingGraph::identifyNode(Instruction *Real, Instruction *Imag) {
   LLVM_DEBUG(dbgs() << "identifyNode on " << *Real << " / " << *Imag << "\n");
@@ -788,8 +878,10 @@
     PlaceholderNode->ReplacementNode = RealShuffle->getOperand(0);
     return submitCompositeNode(PlaceholderNode);
   }
-  if (RealShuffle || ImagShuffle)
+  if (RealShuffle || ImagShuffle) {
+    LLVM_DEBUG(dbgs() << "  - There's a shuffle where there shouldn't be.\n");
     return nullptr;
+  }
 
   auto *VTy = cast<FixedVectorType>(Real->getType());
   auto *NewVTy =
@@ -807,7 +899,10 @@
     return identifyAdd(Real, Imag);
   }
 
-  return nullptr;
+  NodePtr Passthrough = identifyPassthrough(Real, Imag);
+  if (!Passthrough)
+    LLVM_DEBUG(dbgs() << "  - Not recognised as a valid pattern.\n");
+  return Passthrough;
 }
 
 bool ComplexDeinterleavingGraph::identifyNodes(Instruction *RootI) {
@@ -832,7 +927,8 @@
   // Check all instructions have internal uses
   for (const auto &Node : CompositeNodes) {
     if (!Node->hasAllInternalUses(AllInstructions)) {
-      LLVM_DEBUG(dbgs() << "  - Invalid internal uses\n");
+      LLVM_DEBUG(dbgs() << "  - Invalid internal uses in node with real part "
+                        << *Node->Real << "\n");
       return false;
     }
   }
@@ -845,12 +941,13 @@
     return Node->ReplacementNode;
 
   Value *Input0 = replaceNode(Node->Operands[0]);
-  Value *Input1 = replaceNode(Node->Operands[1]);
+  Value *Input1 =
+      Node->Operands.size() > 1 ? replaceNode(Node->Operands[1]) : nullptr;
   Value *Accumulator =
       Node->Operands.size() > 2 ? replaceNode(Node->Operands[2]) : nullptr;
 
-  assert(Input0->getType() == Input1->getType() &&
-         "Node inputs need to be of the same type");
+  assert((!Input1 || Input0->getType() == Input1->getType()) &&
+         "Node inputs need to be of the same type");
 
   Node->ReplacementNode = TL->createComplexDeinterleavingIR(
       Node->Real, Node->Operation, Node->Rotation, Input0, Input1, Accumulator);
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -23841,12 +23841,17 @@
     ArrayRef<int> UpperSplitMask(&SplitSeqVec[Stride], Stride);
 
     auto *LowerSplitA = B.CreateShuffleVector(InputA, LowerSplitMask);
-    auto *LowerSplitB = B.CreateShuffleVector(InputB, LowerSplitMask);
     auto *UpperSplitA = B.CreateShuffleVector(InputA, UpperSplitMask);
-    auto *UpperSplitB = B.CreateShuffleVector(InputB, UpperSplitMask);
+    Value *LowerSplitB = nullptr;
+    Value *UpperSplitB = nullptr;
     Value *LowerSplitAcc = nullptr;
     Value *UpperSplitAcc = nullptr;
 
+    if (InputB) {
+      LowerSplitB = B.CreateShuffleVector(InputB, LowerSplitMask);
+      UpperSplitB = B.CreateShuffleVector(InputB, UpperSplitMask);
+    }
+
     if (Accumulator) {
       LowerSplitAcc = B.CreateShuffleVector(Accumulator, LowerSplitMask);
       UpperSplitAcc = B.CreateShuffleVector(Accumulator, UpperSplitMask);
@@ -23862,6 +23867,8 @@
   }
 
   if (OperationType == ComplexDeinterleavingOperation::CMulPartial) {
+    assert(InputB && "Complex multiplication is a binary operation, and "
+                     "requires 2 inputs. Only one provided.");
     Intrinsic::ID IdMap[4] = {Intrinsic::aarch64_neon_vcmla_rot0,
                               Intrinsic::aarch64_neon_vcmla_rot90,
                               Intrinsic::aarch64_neon_vcmla_rot180,
@@ -23875,6 +23882,11 @@
   }
 
   if (OperationType == ComplexDeinterleavingOperation::CAdd) {
+    assert(InputB && "Complex addition is a binary operation, and requires 2 "
+                     "inputs. Only one provided.");
+    if (Rotation == ComplexDeinterleavingRotation::Rotation_0)
+      return B.CreateFAdd(InputA, InputB);
+
     Intrinsic::ID IntId = Intrinsic::not_intrinsic;
     if (Rotation == ComplexDeinterleavingRotation::Rotation_90)
       IntId = Intrinsic::aarch64_neon_vcadd_rot90;
@@ -23887,5 +23899,24 @@
     return B.CreateIntrinsic(IntId, Ty, {InputA, InputB});
   }
 
+  if (OperationType == ComplexDeinterleavingOperation::Passthrough) {
+    if (I->isUnaryOp())
+      assert(!InputB &&
+             "Unary complex passthroughs need one input, but two were provided.");
+    else if (I->isBinaryOp())
+      assert(InputB && "Binary complex passthroughs need two inputs, only one "
+                       "was provided.");
+    switch (I->getOpcode()) {
+    case Instruction::FNeg:
+      return B.CreateFNeg(InputA);
+    case Instruction::FAdd:
+      return B.CreateFAdd(InputA, InputB);
+    case Instruction::FSub:
+      return B.CreateFSub(InputA, InputB);
+    case Instruction::FMul:
+      return B.CreateFMul(InputA, InputB);
+    }
+  }
+
   return nullptr;
 }
diff --git a/llvm/lib/Target/ARM/ARMISelLowering.cpp b/llvm/lib/Target/ARM/ARMISelLowering.cpp
--- a/llvm/lib/Target/ARM/ARMISelLowering.cpp
+++ b/llvm/lib/Target/ARM/ARMISelLowering.cpp
@@ -21943,12 +21943,17 @@
     ArrayRef<int> UpperSplitMask(&SplitSeqVec[Stride], Stride);
 
     auto *LowerSplitA = B.CreateShuffleVector(InputA, LowerSplitMask);
-    auto *LowerSplitB = B.CreateShuffleVector(InputB, LowerSplitMask);
     auto *UpperSplitA = B.CreateShuffleVector(InputA, UpperSplitMask);
-    auto *UpperSplitB = B.CreateShuffleVector(InputB, UpperSplitMask);
+    Value *LowerSplitB = nullptr;
+    Value *UpperSplitB = nullptr;
     Value *LowerSplitAcc = nullptr;
     Value *UpperSplitAcc = nullptr;
 
+    if (InputB) {
+      LowerSplitB = B.CreateShuffleVector(InputB, LowerSplitMask);
+      UpperSplitB = B.CreateShuffleVector(InputB, UpperSplitMask);
+    }
+
     if (Accumulator) {
       LowerSplitAcc = B.CreateShuffleVector(Accumulator, LowerSplitMask);
       UpperSplitAcc = B.CreateShuffleVector(Accumulator, UpperSplitMask);
@@ -21967,6 +21972,8 @@
 
   ConstantInt *ConstRotation = nullptr;
   if (OperationType == ComplexDeinterleavingOperation::CMulPartial) {
+    assert(InputB && "Complex multiplication is a binary operation, and "
+                     "requires 2 inputs. Only one provided.");
     ConstRotation = ConstantInt::get(IntTy, (int)Rotation);
 
     if (Accumulator)
@@ -21977,6 +21984,16 @@
   }
 
   if (OperationType == ComplexDeinterleavingOperation::CAdd) {
+    assert(InputB && "Complex addition is a binary operation, and requires 2 "
+                     "inputs. Only one provided.");
+    if (Rotation == ComplexDeinterleavingRotation::Rotation_0) {
+      auto *ScalarTy = Ty->getScalarType();
+      if (ScalarTy->isHalfTy() || ScalarTy->isFloatTy() ||
+          ScalarTy->isDoubleTy())
+        return B.CreateFAdd(InputA, InputB);
+      return B.CreateAdd(InputA, InputB);
+    }
+
     // 1 means the value is not halved.
     auto *ConstHalving = ConstantInt::get(IntTy, 1);
 
@@ -21992,5 +22009,24 @@
                              {ConstHalving, ConstRotation, InputA, InputB});
   }
 
+  if (OperationType == ComplexDeinterleavingOperation::Passthrough) {
+    if (I->isUnaryOp())
+      assert(!InputB &&
+             "Unary complex passthroughs need one input, but two were provided.");
+    else if (I->isBinaryOp())
+      assert(InputB && "Binary complex passthroughs need two inputs, only one "
+                       "was provided.");
+    switch (I->getOpcode()) {
+    case Instruction::FNeg:
+      return B.CreateFNeg(InputA);
+    case Instruction::FAdd:
+      return B.CreateFAdd(InputA, InputB);
+    case Instruction::FSub:
+      return B.CreateFSub(InputA, InputB);
+    case Instruction::FMul:
+      return B.CreateFMul(InputA, InputB);
+    }
+  }
+
   return nullptr;
 }
diff --git a/llvm/test/CodeGen/AArch64/complex-deinterleaving-mixed-cases.ll b/llvm/test/CodeGen/AArch64/complex-deinterleaving-mixed-cases.ll
--- a/llvm/test/CodeGen/AArch64/complex-deinterleaving-mixed-cases.ll
+++ b/llvm/test/CodeGen/AArch64/complex-deinterleaving-mixed-cases.ll
@@ -366,24 +366,10 @@
 define <4 x float> @mul_add_rot0(<4 x float> %a, <4 x float> %b, <4 x float> %c) {
 ; CHECK-LABEL: mul_add_rot0:
 ; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    ext v3.16b, v1.16b, v1.16b, #8
-; CHECK-NEXT:    ext v4.16b, v0.16b, v0.16b, #8
-; CHECK-NEXT:    ext v7.16b, v2.16b, v2.16b, #8
-; CHECK-NEXT:    zip2 v5.2s, v1.2s, v3.2s
-; CHECK-NEXT:    zip1 v1.2s, v1.2s, v3.2s
-; CHECK-NEXT:    zip2 v6.2s, v0.2s, v4.2s
-; CHECK-NEXT:    zip1 v0.2s, v0.2s, v4.2s
-; CHECK-NEXT:    zip2 v4.2s, v2.2s, v7.2s
-; CHECK-NEXT:    fmul v16.2s, v6.2s, v5.2s
-; CHECK-NEXT:    fmla v4.2s, v0.2s, v5.2s
-; CHECK-NEXT:    fneg v3.2s, v16.2s
-; CHECK-NEXT:    fmla v4.2s, v6.2s, v1.2s
-; CHECK-NEXT:    fmla v3.2s, v0.2s, v1.2s
-; CHECK-NEXT:    zip1 v0.2s, v2.2s, v7.2s
-; CHECK-NEXT:    fadd v0.2s, v3.2s, v0.2s
-; CHECK-NEXT:    zip2 v1.2s, v0.2s, v4.2s
-; CHECK-NEXT:    zip1 v0.2s, v0.2s, v4.2s
-; CHECK-NEXT:    mov v0.d[1], v1.d[0]
+; CHECK-NEXT:    movi v3.2d, #0000000000000000
+; CHECK-NEXT:    fcmla v3.4s, v0.4s, v1.4s, #0
+; CHECK-NEXT:    fcmla v3.4s, v0.4s, v1.4s, #90
+; CHECK-NEXT:    fadd v0.4s, v3.4s, v2.4s
 ; CHECK-NEXT:    ret
 entry:
   %strided.vec = shufflevector <4 x float> %a, <4 x float> poison, <2 x i32> <i32 0, i32 2>
diff --git a/llvm/test/CodeGen/Thumb2/mve-complex-deinterleaving-mixed-cases.ll b/llvm/test/CodeGen/Thumb2/mve-complex-deinterleaving-mixed-cases.ll
--- a/llvm/test/CodeGen/Thumb2/mve-complex-deinterleaving-mixed-cases.ll
+++ b/llvm/test/CodeGen/Thumb2/mve-complex-deinterleaving-mixed-cases.ll
@@ -390,32 +390,9 @@
 define arm_aapcs_vfpcc <4 x float> @mul_add_rot0(<4 x float> %a, <4 x float> %b, <4 x float> %c) {
 ; CHECK-LABEL: mul_add_rot0:
 ; CHECK:       @ %bb.0: @ %entry
-; CHECK-NEXT:    .vsave {d10}
-; CHECK-NEXT:    vpush {d10}
-; CHECK-NEXT:    .vsave {d8}
-; CHECK-NEXT:    vpush {d8}
-; CHECK-NEXT:    vmov.f32 s16, s0
-; CHECK-NEXT:    vmov.f32 s20, s5
-; CHECK-NEXT:    vmov.f32 s12, s9
-; CHECK-NEXT:    vmov.f32 s17, s2
-; CHECK-NEXT:    vmov.f32 s21, s7
-; CHECK-NEXT:    vmov.f32 s13, s11
-; CHECK-NEXT:    vmov.f32 s0, s1
-; CHECK-NEXT:    vfma.f32 q3, q5, q4
-; CHECK-NEXT:    vmov.f32 s1, s3
-; CHECK-NEXT:    vmov.f32 s5, s6
-; CHECK-NEXT:    vfma.f32 q3, q1, q0
-; CHECK-NEXT:    vmul.f32 q0, q0, q5
-; CHECK-NEXT:    vneg.f32 q0, q0
-; CHECK-NEXT:    vmov.f32 s9, s10
-; CHECK-NEXT:    vfma.f32 q0, q1, q4
-; CHECK-NEXT:    vadd.f32 q1, q0, q2
-; CHECK-NEXT:    vmov.f32 s1, s12
-; CHECK-NEXT:    vmov.f32 s0, s4
-; CHECK-NEXT:    vmov.f32 s2, s5
-; CHECK-NEXT:    vmov.f32 s3, s13
-; CHECK-NEXT:    vpop {d8}
-; CHECK-NEXT:    vpop {d10}
+; CHECK-NEXT:    vcmul.f32 q3, q0, q1, #0
+; CHECK-NEXT:    vcmla.f32 q3, q0, q1, #90
+; CHECK-NEXT:    vadd.f32 q0, q3, q2
 ; CHECK-NEXT:    bx lr
 entry:
   %strided.vec = shufflevector <4 x float> %a, <4 x float> poison, <2 x i32> <i32 0, i32 2>